diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 3c45151aa..7b76a3768 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -211,7 +211,7 @@ type ParticipantImpl struct { onStateChange func(p types.LocalParticipant, state livekit.ParticipantInfo_State) onMigrateStateChange func(p types.LocalParticipant, migrateState types.MigrateState) onParticipantUpdate func(types.LocalParticipant) - onDataPacket func(types.LocalParticipant, *livekit.DataPacket) + onDataPacket func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) migrateState atomic.Value // types.MigrateState @@ -625,7 +625,7 @@ func (p *ParticipantImpl) OnParticipantUpdate(callback func(types.LocalParticipa p.lock.Unlock() } -func (p *ParticipantImpl) OnDataPacket(callback func(types.LocalParticipant, *livekit.DataPacket)) { +func (p *ParticipantImpl) OnDataPacket(callback func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) { p.lock.Lock() p.onDataPacket = callback p.lock.Unlock() @@ -1472,8 +1472,8 @@ func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byt p.dataChannelStats.AddBytes(uint64(len(data)), false) - dp := livekit.DataPacket{} - if err := proto.Unmarshal(data, &dp); err != nil { + dp := &livekit.DataPacket{} + if err := proto.Unmarshal(data, dp); err != nil { p.pubLogger.Warnw("could not parse data packet", err) return } @@ -1481,6 +1481,12 @@ func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byt // trust the channel that it came in as the source of truth dp.Kind = kind + if p.Hidden() { + dp.ParticipantIdentity = "" + } else { + dp.ParticipantIdentity = string(p.params.Identity) + } + // only forward on user payloads switch payload := dp.Value.(type) { case *livekit.DataPacket_User: @@ -1488,14 +1494,34 @@ func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byt onDataPacket := p.onDataPacket p.lock.RUnlock() if onDataPacket != nil { + u := payload.User if p.Hidden() { - payload.User.ParticipantSid = "" - payload.User.ParticipantIdentity = "" + u.ParticipantSid = "" + u.ParticipantIdentity = "" } else { - payload.User.ParticipantSid = string(p.params.SID) - payload.User.ParticipantIdentity = string(p.params.Identity) + u.ParticipantSid = string(p.params.SID) + u.ParticipantIdentity = string(p.params.Identity) + } + if dp.ParticipantIdentity != "" { + u.ParticipantIdentity = dp.ParticipantIdentity + } else { + dp.ParticipantIdentity = u.ParticipantIdentity + } + if len(dp.DestinationIdentities) != 0 { + u.DestinationIdentities = dp.DestinationIdentities + } else { + dp.DestinationIdentities = u.DestinationIdentities + } + onDataPacket(p, kind, dp) + } + case *livekit.DataPacket_SipDtmf: + if p.grants.GetParticipantKind() == livekit.ParticipantInfo_SIP { + p.lock.RLock() + onDataPacket := p.onDataPacket + p.lock.RUnlock() + if onDataPacket != nil { + onDataPacket(p, kind, dp) } - onDataPacket(p, &dp) } default: p.pubLogger.Warnw("received unsupported data packet", nil, "payload", payload) @@ -2474,19 +2500,19 @@ func codecsFromMediaDescription(m *sdp.MediaDescription) (out []sdp.Codec, err e return out, nil } -func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) error { +func (p *ParticipantImpl) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { if p.State() != livekit.ParticipantInfo_ACTIVE { return ErrDataChannelUnavailable } - err := p.TransportManager.SendDataPacket(dp, data) + err := p.TransportManager.SendDataPacket(kind, encoded) if err != nil { if (errors.Is(err, sctp.ErrStreamClosed) || errors.Is(err, io.ErrClosedPipe)) && p.params.ReconnectOnDataChannelError { p.params.Logger.Infow("issuing full reconnect on data channel error", "error", err) p.IssueFullReconnect(types.ParticipantCloseReasonDataChannelError) } } else { - p.dataChannelStats.AddBytes(uint64(len(data)), true) + p.dataChannelStats.AddBytes(uint64(len(encoded)), true) } return err } diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 02343080f..5cfe8f5da 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "math" + "slices" "sort" "strings" "sync" @@ -820,14 +821,8 @@ func (r *Room) OnParticipantChanged(f func(participant types.LocalParticipant)) r.onParticipantChanged = f } -func (r *Room) SendDataPacket(up *livekit.UserPacket, kind livekit.DataPacket_Kind) { - dp := &livekit.DataPacket{ - Kind: kind, - Value: &livekit.DataPacket_User{ - User: up, - }, - } - r.onDataPacket(nil, dp) +func (r *Room) SendDataPacket(dp *livekit.DataPacket, kind livekit.DataPacket_Kind) { + r.onDataPacket(nil, kind, dp) } func (r *Room) SetMetadata(metadata string) <-chan struct{} { @@ -1085,8 +1080,8 @@ func (r *Room) onParticipantUpdate(p types.LocalParticipant) { } } -func (r *Room) onDataPacket(source types.LocalParticipant, dp *livekit.DataPacket) { - BroadcastDataPacketForRoom(r, source, dp, r.Logger) +func (r *Room) onDataPacket(source types.LocalParticipant, kind livekit.DataPacket_Kind, dp *livekit.DataPacket) { + BroadcastDataPacketForRoom(r, source, kind, dp, r.Logger) } func (r *Room) subscribeToExistingTracks(p types.LocalParticipant) { @@ -1171,33 +1166,6 @@ func (r *Room) sendParticipantUpdates(updates []*participantUpdate) { } } -// for protocol 2, send all active speakers -func (r *Room) sendActiveSpeakers(speakers []*livekit.SpeakerInfo) { - dp := &livekit.DataPacket{ - Kind: livekit.DataPacket_LOSSY, - Value: &livekit.DataPacket_Speaker{ - Speaker: &livekit.ActiveSpeakerUpdate{ - Speakers: speakers, - }, - }, - } - - var dpData []byte - for _, p := range r.GetParticipants() { - if p.ProtocolVersion().HandlesDataPackets() && !p.ProtocolVersion().SupportsSpeakerChanged() { - if dpData == nil { - var err error - dpData, err = proto.Marshal(dp) - if err != nil { - r.Logger.Errorw("failed to marshal ActiveSpeaker data packet", err) - return - } - } - _ = p.SendDataPacket(dp, dpData) - } - } -} - // for protocol 3, send only changed updates func (r *Room) sendSpeakerChanges(speakers []*livekit.SpeakerInfo) { for _, p := range r.GetParticipants() { @@ -1346,7 +1314,6 @@ func (r *Room) audioUpdateWorker() { // see if an update is needed if len(changedSpeakers) > 0 { - r.sendActiveSpeakers(activeSpeakers) r.sendSpeakerChanges(changedSpeakers) } @@ -1495,18 +1462,34 @@ func (r *Room) DebugInfo() map[string]interface{} { // ------------------------------------------------------------ -func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp *livekit.DataPacket, logger logger.Logger) { +func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, kind livekit.DataPacket_Kind, dp *livekit.DataPacket, logger logger.Logger) { + dp.Kind = kind // backward compatibility dest := dp.GetUser().GetDestinationSids() - var dpData []byte - destIdentities := dp.GetUser().GetDestinationIdentities() + if u := dp.GetUser(); u != nil { + if len(dp.DestinationIdentities) == 0 { + dp.DestinationIdentities = u.DestinationIdentities + } else { + u.DestinationIdentities = dp.DestinationIdentities + } + if dp.ParticipantIdentity != "" { + u.ParticipantIdentity = dp.ParticipantIdentity + } else { + dp.ParticipantIdentity = u.ParticipantIdentity + } + } + destIdentities := dp.DestinationIdentities participants := r.GetLocalParticipants() - capacity := len(dest) + capacity := len(destIdentities) + if capacity == 0 { + capacity = len(dest) + } if capacity == 0 { capacity = len(participants) } destParticipants := make([]types.LocalParticipant, 0, capacity) + var dpData []byte for _, op := range participants { if op.State() != livekit.ParticipantInfo_ACTIVE { continue @@ -1515,20 +1498,7 @@ func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp continue } if len(dest) > 0 || len(destIdentities) > 0 { - found := false - for _, dID := range dest { - if op.ID() == livekit.ParticipantID(dID) { - found = true - break - } - } - for _, dIdentity := range destIdentities { - if op.Identity() == livekit.ParticipantIdentity(dIdentity) { - found = true - break - } - } - if !found { + if !slices.Contains(dest, string(op.ID())) && !slices.Contains(destIdentities, string(op.Identity())) { continue } } @@ -1544,7 +1514,7 @@ func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, dp } utils.ParallelExec(destParticipants, dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { - err := op.SendDataPacket(dp, dpData) + err := op.SendDataPacket(kind, dpData) if err != nil && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, sctp.ErrStreamClosed) && !errors.Is(err, ErrTransportFailure) && !errors.Is(err, ErrDataChannelBufferFull) { op.GetLogger().Infow("send data packet error", "error", err) diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index ea09a91d0..c73b5fa5c 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -22,10 +22,11 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" - "github.com/livekit/livekit-server/version" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/webhook" + "github.com/livekit/livekit-server/version" + "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" @@ -569,65 +570,126 @@ func TestActiveSpeakers(t *testing.T) { func TestDataChannel(t *testing.T) { t.Parallel() - t.Run("participants should receive data", func(t *testing.T) { - rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) - defer rm.Close(types.ParticipantCloseReasonNone) - participants := rm.GetParticipants() - p := participants[0].(*typesfakes.FakeLocalParticipant) + const ( + curAPI = iota + legacySID + legacyIdentity + ) + modes := []int{ + curAPI, legacySID, legacyIdentity, + } + modeNames := []string{ + "cur", "legacy sid", "legacy identity", + } - packet := livekit.DataPacket{ - Kind: livekit.DataPacket_RELIABLE, - Value: &livekit.DataPacket_User{ - User: &livekit.UserPacket{ - ParticipantSid: string(p.ID()), - Payload: []byte("message.."), - }, - }, + setSource := func(mode int, dp *livekit.DataPacket, p types.LocalParticipant) { + switch mode { + case curAPI: + dp.ParticipantIdentity = string(p.Identity()) + case legacySID: + dp.GetUser().ParticipantSid = string(p.ID()) + case legacyIdentity: + dp.GetUser().ParticipantIdentity = string(p.Identity()) } - p.OnDataPacketArgsForCall(0)(p, &packet) + } + setDest := func(mode int, dp *livekit.DataPacket, p types.LocalParticipant) { + switch mode { + case curAPI: + dp.DestinationIdentities = []string{string(p.Identity())} + case legacySID: + dp.GetUser().DestinationSids = []string{string(p.ID())} + case legacyIdentity: + dp.GetUser().DestinationIdentities = []string{string(p.Identity())} + } + } - // ensure everyone has received the packet - for _, op := range participants { - fp := op.(*typesfakes.FakeLocalParticipant) - if fp == p { - require.Zero(t, fp.SendDataPacketCallCount()) - continue - } - require.Equal(t, 1, fp.SendDataPacketCallCount()) - dp, _ := fp.SendDataPacketArgsForCall(0) - require.Equal(t, packet.Value, dp.Value) + t.Run("participants should receive data", func(t *testing.T) { + for _, mode := range modes { + mode := mode + t.Run(modeNames[mode], func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) + defer rm.Close(types.ParticipantCloseReasonNone) + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + + packet := &livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: []byte("message.."), + }, + }, + } + setSource(mode, packet, p) + + packetExp := proto.Clone(packet).(*livekit.DataPacket) + if mode != legacySID { + packetExp.ParticipantIdentity = string(p.Identity()) + packetExp.GetUser().ParticipantIdentity = string(p.Identity()) + } + + encoded, _ := proto.Marshal(packetExp) + p.OnDataPacketArgsForCall(0)(p, packet.Kind, packet) + + // ensure everyone has received the packet + for _, op := range participants { + fp := op.(*typesfakes.FakeLocalParticipant) + if fp == p { + require.Zero(t, fp.SendDataPacketCallCount()) + continue + } + require.Equal(t, 1, fp.SendDataPacketCallCount()) + _, got := fp.SendDataPacketArgsForCall(0) + require.Equal(t, encoded, got) + } + }) } }) t.Run("only one participant should receive the data", func(t *testing.T) { - rm := newRoomWithParticipants(t, testRoomOpts{num: 4}) - defer rm.Close(types.ParticipantCloseReasonNone) - participants := rm.GetParticipants() - p := participants[0].(*typesfakes.FakeLocalParticipant) - p1 := participants[1].(*typesfakes.FakeLocalParticipant) + for _, mode := range modes { + mode := mode + t.Run(modeNames[mode], func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 4}) + defer rm.Close(types.ParticipantCloseReasonNone) + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + p1 := participants[1].(*typesfakes.FakeLocalParticipant) - packet := livekit.DataPacket{ - Kind: livekit.DataPacket_RELIABLE, - Value: &livekit.DataPacket_User{ - User: &livekit.UserPacket{ - ParticipantSid: string(p.ID()), - Payload: []byte("message to p1.."), - DestinationSids: []string{string(p1.ID())}, - }, - }, - } - p.OnDataPacketArgsForCall(0)(p, &packet) + packet := &livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: []byte("message to p1.."), + }, + }, + } + setSource(mode, packet, p) + setDest(mode, packet, p1) - // only p1 should receive the data - for _, op := range participants { - fp := op.(*typesfakes.FakeLocalParticipant) - if fp != p1 { - require.Zero(t, fp.SendDataPacketCallCount()) - } + packetExp := proto.Clone(packet).(*livekit.DataPacket) + if mode != legacySID { + packetExp.ParticipantIdentity = string(p.Identity()) + packetExp.GetUser().ParticipantIdentity = string(p.Identity()) + packetExp.DestinationIdentities = []string{string(p1.Identity())} + packetExp.GetUser().DestinationIdentities = []string{string(p1.Identity())} + } + + encoded, _ := proto.Marshal(packetExp) + p.OnDataPacketArgsForCall(0)(p, packet.Kind, packet) + + // only p1 should receive the data + for _, op := range participants { + fp := op.(*typesfakes.FakeLocalParticipant) + if fp != p1 { + require.Zero(t, fp.SendDataPacketCallCount()) + } + } + require.Equal(t, 1, p1.SendDataPacketCallCount()) + _, got := p1.SendDataPacketArgsForCall(0) + require.Equal(t, encoded, got) + }) } - require.Equal(t, 1, p1.SendDataPacketCallCount()) - dp, _ := p1.SendDataPacketArgsForCall(0) - require.Equal(t, packet.Value, dp.Value) }) t.Run("publishing disallowed", func(t *testing.T) { @@ -646,7 +708,7 @@ func TestDataChannel(t *testing.T) { }, } if p.CanPublishData() { - p.OnDataPacketArgsForCall(0)(p, &packet) + p.OnDataPacketArgsForCall(0)(p, packet.Kind, &packet) } // no one should've been sent packet diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 29558ae47..1e0cd1a93 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -35,6 +35,13 @@ import ( "github.com/pkg/errors" "go.uber.org/atomic" + lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor" + lktwcc "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/logger/pionlogger" + lksdp "github.com/livekit/protocol/sdp" + "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" @@ -46,12 +53,6 @@ import ( "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/livekit-server/pkg/utils" sutils "github.com/livekit/livekit-server/pkg/utils" - lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor" - lktwcc "github.com/livekit/mediatransportutil/pkg/twcc" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/logger/pionlogger" - lksdp "github.com/livekit/protocol/sdp" ) const ( @@ -799,16 +800,27 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni if err != nil { return err } + var ( + dcPtr **webrtc.DataChannel + dcReady *bool + ) + switch dc.Label() { + default: + // TODO: Appears that it's never called, so not sure what needs to be done here. We just keep the DC open? + // Maybe just add "reliable" parameter instead of checking the label. + t.params.Logger.Warnw("unknown data channel label", nil, "label", dc.Label()) + return nil + case ReliableDataChannel: + dcPtr = &t.reliableDC + dcReady = &t.reliableDCOpened + case LossyDataChannel: + dcPtr = &t.lossyDC + dcReady = &t.lossyDCOpened + } dcReadyHandler := func() { t.lock.Lock() - switch dc.Label() { - case ReliableDataChannel: - t.reliableDCOpened = true - - case LossyDataChannel: - t.lossyDCOpened = true - } + *dcReady = true t.lock.Unlock() t.params.Logger.Debugw(dc.Label() + " data channel open") @@ -826,30 +838,15 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni } t.lock.Lock() - switch dc.Label() { - case ReliableDataChannel: - t.reliableDC = dc - if t.params.DirectionConfig.StrictACKs { - t.reliableDC.OnOpen(dcReadyHandler) - } else { - t.reliableDC.OnDial(dcReadyHandler) - } - t.reliableDC.OnClose(dcCloseHandler) - t.reliableDC.OnError(dcErrorHandler) - case LossyDataChannel: - t.lossyDC = dc - if t.params.DirectionConfig.StrictACKs { - t.lossyDC.OnOpen(dcReadyHandler) - } else { - t.lossyDC.OnDial(dcReadyHandler) - } - t.lossyDC.OnClose(dcCloseHandler) - t.lossyDC.OnError(dcErrorHandler) - default: - t.params.Logger.Warnw("unknown data channel label", nil, "label", dc.Label()) + defer t.lock.Unlock() + *dcPtr = dc + if t.params.DirectionConfig.StrictACKs { + dc.OnOpen(dcReadyHandler) + } else { + dc.OnDial(dcReadyHandler) } - t.lock.Unlock() - + dc.OnClose(dcCloseHandler) + dc.OnError(dcErrorHandler) return nil } @@ -903,10 +900,10 @@ func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { return t.pc.WriteRTCP(pkts) } -func (t *PCTransport) SendDataPacket(dp *livekit.DataPacket, data []byte) error { +func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { var dc *webrtc.DataChannel t.lock.RLock() - if dp.Kind == livekit.DataPacket_RELIABLE { + if kind == livekit.DataPacket_RELIABLE { dc = t.reliableDC } else { dc = t.lossyDC @@ -925,7 +922,7 @@ func (t *PCTransport) SendDataPacket(dp *livekit.DataPacket, data []byte) error return ErrDataChannelBufferFull } - return dc.Send(data) + return dc.Send(encoded) } func (t *PCTransport) Close() { diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index cd1b23289..049b76a5e 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -28,14 +28,15 @@ import ( "go.uber.org/atomic" "google.golang.org/protobuf/proto" + "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/pacer" - "github.com/livekit/mediatransportutil/pkg/twcc" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" ) const ( @@ -240,9 +241,9 @@ func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) t.subscriber.RemoveTrackFromStreamAllocator(subTrack) } -func (t *TransportManager) SendDataPacket(dp *livekit.DataPacket, data []byte) error { +func (t *TransportManager) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { // downstream data is sent via primary peer connection - return t.getTransport(true).SendDataPacket(dp, data) + return t.getTransport(true).SendDataPacket(kind, encoded) } func (t *TransportManager) createDataChannelsForSubscriber(pendingDataChannels []*livekit.DataChannelInfo) error { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index e4e6b61d3..b0943d4f4 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -365,7 +365,7 @@ type LocalParticipant interface { SendJoinResponse(joinResponse *livekit.JoinResponse) error SendParticipantUpdate(participants []*livekit.ParticipantInfo) error SendSpeakerUpdate(speakers []*livekit.SpeakerInfo, force bool) error - SendDataPacket(packet *livekit.DataPacket, data []byte) error + SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error SendRoomUpdate(room *livekit.Room) error SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error SubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) @@ -384,7 +384,7 @@ type LocalParticipant interface { OnTrackUnpublished(callback func(LocalParticipant, MediaTrack)) // OnParticipantUpdate - metadata or permission is updated OnParticipantUpdate(callback func(LocalParticipant)) - OnDataPacket(callback func(LocalParticipant, *livekit.DataPacket)) + OnDataPacket(callback func(LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) OnSubscribeStatusChanged(fn func(publisherID livekit.ParticipantID, subscribed bool)) OnClose(callback func(LocalParticipant)) OnClaimsChanged(callback func(LocalParticipant)) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 464f4009c..37c03ef10 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -571,10 +571,10 @@ type FakeLocalParticipant struct { onCloseArgsForCall []struct { arg1 func(types.LocalParticipant) } - OnDataPacketStub func(func(types.LocalParticipant, *livekit.DataPacket)) + OnDataPacketStub func(func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) onDataPacketMutex sync.RWMutex onDataPacketArgsForCall []struct { - arg1 func(types.LocalParticipant, *livekit.DataPacket) + arg1 func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) } OnICEConfigChangedStub func(func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig)) onICEConfigChangedMutex sync.RWMutex @@ -660,10 +660,10 @@ type FakeLocalParticipant struct { sendConnectionQualityUpdateReturnsOnCall map[int]struct { result1 error } - SendDataPacketStub func(*livekit.DataPacket, []byte) error + SendDataPacketStub func(livekit.DataPacket_Kind, []byte) error sendDataPacketMutex sync.RWMutex sendDataPacketArgsForCall []struct { - arg1 *livekit.DataPacket + arg1 livekit.DataPacket_Kind arg2 []byte } sendDataPacketReturns struct { @@ -3932,10 +3932,10 @@ func (fake *FakeLocalParticipant) OnCloseArgsForCall(i int) func(types.LocalPart return argsForCall.arg1 } -func (fake *FakeLocalParticipant) OnDataPacket(arg1 func(types.LocalParticipant, *livekit.DataPacket)) { +func (fake *FakeLocalParticipant) OnDataPacket(arg1 func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) { fake.onDataPacketMutex.Lock() fake.onDataPacketArgsForCall = append(fake.onDataPacketArgsForCall, struct { - arg1 func(types.LocalParticipant, *livekit.DataPacket) + arg1 func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) }{arg1}) stub := fake.OnDataPacketStub fake.recordInvocation("OnDataPacket", []interface{}{arg1}) @@ -3951,13 +3951,13 @@ func (fake *FakeLocalParticipant) OnDataPacketCallCount() int { return len(fake.onDataPacketArgsForCall) } -func (fake *FakeLocalParticipant) OnDataPacketCalls(stub func(func(types.LocalParticipant, *livekit.DataPacket))) { +func (fake *FakeLocalParticipant) OnDataPacketCalls(stub func(func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket))) { fake.onDataPacketMutex.Lock() defer fake.onDataPacketMutex.Unlock() fake.OnDataPacketStub = stub } -func (fake *FakeLocalParticipant) OnDataPacketArgsForCall(i int) func(types.LocalParticipant, *livekit.DataPacket) { +func (fake *FakeLocalParticipant) OnDataPacketArgsForCall(i int) func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) { fake.onDataPacketMutex.RLock() defer fake.onDataPacketMutex.RUnlock() argsForCall := fake.onDataPacketArgsForCall[i] @@ -4461,7 +4461,7 @@ func (fake *FakeLocalParticipant) SendConnectionQualityUpdateReturnsOnCall(i int }{result1} } -func (fake *FakeLocalParticipant) SendDataPacket(arg1 *livekit.DataPacket, arg2 []byte) error { +func (fake *FakeLocalParticipant) SendDataPacket(arg1 livekit.DataPacket_Kind, arg2 []byte) error { var arg2Copy []byte if arg2 != nil { arg2Copy = make([]byte, len(arg2)) @@ -4470,7 +4470,7 @@ func (fake *FakeLocalParticipant) SendDataPacket(arg1 *livekit.DataPacket, arg2 fake.sendDataPacketMutex.Lock() ret, specificReturn := fake.sendDataPacketReturnsOnCall[len(fake.sendDataPacketArgsForCall)] fake.sendDataPacketArgsForCall = append(fake.sendDataPacketArgsForCall, struct { - arg1 *livekit.DataPacket + arg1 livekit.DataPacket_Kind arg2 []byte }{arg1, arg2Copy}) stub := fake.SendDataPacketStub @@ -4492,13 +4492,13 @@ func (fake *FakeLocalParticipant) SendDataPacketCallCount() int { return len(fake.sendDataPacketArgsForCall) } -func (fake *FakeLocalParticipant) SendDataPacketCalls(stub func(*livekit.DataPacket, []byte) error) { +func (fake *FakeLocalParticipant) SendDataPacketCalls(stub func(livekit.DataPacket_Kind, []byte) error) { fake.sendDataPacketMutex.Lock() defer fake.sendDataPacketMutex.Unlock() fake.SendDataPacketStub = stub } -func (fake *FakeLocalParticipant) SendDataPacketArgsForCall(i int) (*livekit.DataPacket, []byte) { +func (fake *FakeLocalParticipant) SendDataPacketArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { fake.sendDataPacketMutex.RLock() defer fake.sendDataPacketMutex.RUnlock() argsForCall := fake.sendDataPacketArgsForCall[i] diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 8c5aa60d5..96dd7036c 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -24,8 +24,6 @@ import ( "github.com/pkg/errors" "golang.org/x/exp/maps" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" - "github.com/livekit/livekit-server/version" "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" @@ -41,6 +39,8 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/version" ) const ( @@ -753,13 +753,18 @@ func (r *RoomManager) SendData(ctx context.Context, req *livekit.SendDataRequest } room.Logger.Debugw("api send data", "size", len(req.Data)) - up := &livekit.UserPacket{ - Payload: req.Data, - DestinationSids: req.DestinationSids, + room.SendDataPacket(&livekit.DataPacket{ + Kind: req.Kind, DestinationIdentities: req.DestinationIdentities, - Topic: req.Topic, - } - room.SendDataPacket(up, req.Kind) + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: req.Data, + DestinationSids: req.DestinationSids, + DestinationIdentities: req.DestinationIdentities, + Topic: req.Topic, + }, + }, + }, req.Kind) return &livekit.SendDataResponse{}, nil } diff --git a/test/client/client.go b/test/client/client.go index e0be22e47..5d461a74a 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -679,19 +679,17 @@ func (c *RTCClient) PublishData(data []byte, kind livekit.DataPacket_Kind) error return err } - dp := &livekit.DataPacket{ + dpData, err := proto.Marshal(&livekit.DataPacket{ Kind: kind, Value: &livekit.DataPacket_User{ User: &livekit.UserPacket{Payload: data}, }, - } - - dpData, err := proto.Marshal(dp) + }) if err != nil { return err } - return c.publisher.SendDataPacket(dp, dpData) + return c.publisher.SendDataPacket(kind, dpData) } func (c *RTCClient) GetPublishedTrackIDs() []string { @@ -732,12 +730,13 @@ func (c *RTCClient) ensurePublisherConnected() error { } } -func (c *RTCClient) handleDataMessage(_ livekit.DataPacket_Kind, data []byte) { +func (c *RTCClient) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { dp := &livekit.DataPacket{} err := proto.Unmarshal(data, dp) if err != nil { return } + dp.Kind = kind if val, ok := dp.Value.(*livekit.DataPacket_User); ok { if c.OnDataReceived != nil { c.OnDataReceived(val.User.Payload, val.User.ParticipantSid)