From d08487bf83fd5df64a14abf7683254ccc276e0fc Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Tue, 1 Apr 2025 21:59:31 +0530 Subject: [PATCH] Unlabeled (pass through) data channels. (#3567) * Unlabeled (pass through) data channels. Support data channels than can pass through raw data without any LK protocol marshaling/unmarshaling. * statischeck * test * error -> warn * reset data message callback --- pkg/rtc/participant.go | 53 ++++- pkg/rtc/room.go | 26 ++- pkg/rtc/room_test.go | 14 +- pkg/rtc/transport.go | 120 +++++++++-- pkg/rtc/transport/handler.go | 6 +- .../transport/transportfakes/fake_handler.go | 92 +++++--- pkg/rtc/transportmanager.go | 38 +++- pkg/rtc/types/interfaces.go | 4 +- .../typesfakes/fake_local_participant.go | 196 ++++++++++++++---- pkg/service/servicefakes/fake_sipstore.go | 162 --------------- test/client/client.go | 41 +++- test/multinode_test.go | 1 + test/scenarios.go | 27 +++ 13 files changed, 505 insertions(+), 275 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 1e0d09d11..a8e900609 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -243,6 +243,7 @@ type ParticipantImpl struct { onMigrateStateChange func(p types.LocalParticipant, migrateState types.MigrateState) onParticipantUpdate func(types.LocalParticipant) onDataPacket func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) + onDataMessage func(types.LocalParticipant, []byte) onMetrics func(types.Participant, *livekit.DataPacket) migrateState atomic.Value // types.MigrateState @@ -789,6 +790,18 @@ func (p *ParticipantImpl) getOnDataPacket() func(types.LocalParticipant, livekit return p.onDataPacket } +func (p *ParticipantImpl) OnDataMessage(callback func(types.LocalParticipant, []byte)) { + p.lock.Lock() + p.onDataMessage = callback + p.lock.Unlock() +} + +func (p *ParticipantImpl) getOnDataMessage() func(types.LocalParticipant, []byte) { + p.lock.RLock() + defer p.lock.RUnlock() + return p.onDataMessage +} + func (p *ParticipantImpl) OnMetrics(callback func(types.Participant, *livekit.DataPacket)) { p.lock.Lock() p.onMetrics = callback @@ -1465,8 +1478,16 @@ func (h PublisherTransportHandler) OnInitialConnected() { h.p.onPublisherInitialConnected() } -func (h PublisherTransportHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) { - h.p.onDataMessage(kind, data) +func (h PublisherTransportHandler) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) { + h.p.onReceivedDataMessage(kind, data) +} + +func (h PublisherTransportHandler) OnDataMessageUnlabeled(data []byte) { + h.p.onReceivedDataMessageUnlabeled(data) +} + +func (h PublisherTransportHandler) OnDataSendError(err error) { + h.p.onDataSendError(err) } // ---------------------------------------------------------- @@ -1655,7 +1676,7 @@ func (p *ParticipantImpl) MetricsReporterBatchReady(mb *livekit.MetricsBatch) { return } - p.TransportManager.SendDataPacket(livekit.DataPacket_RELIABLE, dpData) + p.TransportManager.SendDataMessage(livekit.DataPacket_RELIABLE, dpData) } func (p *ParticipantImpl) setupMetrics() { @@ -1835,7 +1856,7 @@ func (p *ParticipantImpl) handlePendingRemoteTracks() { } } -func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byte) { +func (p *ParticipantImpl) onReceivedDataMessage(kind livekit.DataPacket_Kind, data []byte) { if p.IsDisconnected() || !p.CanPublishData() { return } @@ -1958,6 +1979,18 @@ func (p *ParticipantImpl) onDataMessage(kind livekit.DataPacket_Kind, data []byt } } +func (p *ParticipantImpl) onReceivedDataMessageUnlabeled(data []byte) { + if p.IsDisconnected() || !p.CanPublishData() { + return + } + + p.dataChannelStats.AddBytes(uint64(len(data)), false) + + if onDataMessage := p.getOnDataMessage(); onDataMessage != nil { + onDataMessage(p, data) + } +} + func (p *ParticipantImpl) onICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { if p.IsDisconnected() || p.IsClosed() { return nil @@ -3004,12 +3037,20 @@ func (p *ParticipantImpl) SupportsTransceiverReuse() bool { return p.ProtocolVersion().SupportsTransceiverReuse() && !p.SupportsSyncStreamID() } -func (p *ParticipantImpl) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { +func (p *ParticipantImpl) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { if p.State() != livekit.ParticipantInfo_ACTIVE { return ErrDataChannelUnavailable } - return p.TransportManager.SendDataPacket(kind, encoded) + return p.TransportManager.SendDataMessage(kind, data) +} + +func (p *ParticipantImpl) SendDataMessageUnlabeled(data []byte) error { + if p.State() != livekit.ParticipantInfo_ACTIVE { + return ErrDataChannelUnavailable + } + + return p.TransportManager.SendDataMessageUnlabeled(data) } func (p *ParticipantImpl) onDataSendError(err error) { diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 76d9515da..eda9519df 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -484,6 +484,7 @@ func (r *Room) Join(participant types.LocalParticipant, requestSource routing.Me participant.OnTrackUnpublished(r.onTrackUnpublished) participant.OnParticipantUpdate(r.onParticipantUpdate) participant.OnDataPacket(r.onDataPacket) + participant.OnDataMessage(r.onDataMessage) participant.OnMetrics(r.onMetrics) participant.OnSubscribeStatusChanged(func(publisherID livekit.ParticipantID, subscribed bool) { if subscribed { @@ -731,6 +732,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek p.OnStateChange(nil) p.OnParticipantUpdate(nil) p.OnDataPacket(nil) + p.OnDataMessage(nil) p.OnMetrics(nil) p.OnSubscribeStatusChanged(nil) @@ -1285,6 +1287,10 @@ func (r *Room) onDataPacket(source types.LocalParticipant, kind livekit.DataPack BroadcastDataPacketForRoom(r, source, kind, dp, r.Logger) } +func (r *Room) onDataMessage(source types.LocalParticipant, data []byte) { + BroadcastDataMessageForRoom(r, source, data, r.Logger) +} + func (r *Room) onMetrics(source types.Participant, dp *livekit.DataPacket) { BroadcastMetricsForRoom(r, source, dp, r.Logger) } @@ -1760,7 +1766,13 @@ func (r *Room) IsDataMessageUserPacketDuplicate(up *livekit.UserPacket) bool { // ------------------------------------------------------------ -func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, kind livekit.DataPacket_Kind, dp *livekit.DataPacket, logger logger.Logger) { +func BroadcastDataPacketForRoom( + r types.Room, + source types.LocalParticipant, + kind livekit.DataPacket_Kind, + dp *livekit.DataPacket, + logger logger.Logger, +) { dp.Kind = kind // backward compatibility dest := dp.GetUser().GetDestinationSids() if u := dp.GetUser(); u != nil { @@ -1813,7 +1825,17 @@ func BroadcastDataPacketForRoom(r types.Room, source types.LocalParticipant, kin } utils.ParallelExec(destParticipants, dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { - op.SendDataPacket(kind, dpData) + op.SendDataMessage(kind, dpData) + }) +} + +func BroadcastDataMessageForRoom(r types.Room, source types.LocalParticipant, data []byte, logger logger.Logger) { + utils.ParallelExec(r.GetLocalParticipants(), dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { + if source != nil && op.ID() == source.ID() { + return + } + + op.SendDataMessageUnlabeled(data) }) } diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index b67f0b78b..f502e9ca5 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -637,11 +637,11 @@ func TestDataChannel(t *testing.T) { for _, op := range participants { fp := op.(*typesfakes.FakeLocalParticipant) if fp == p { - require.Zero(t, fp.SendDataPacketCallCount()) + require.Zero(t, fp.SendDataMessageCallCount()) continue } - require.Equal(t, 1, fp.SendDataPacketCallCount()) - _, got := fp.SendDataPacketArgsForCall(0) + require.Equal(t, 1, fp.SendDataMessageCallCount()) + _, got := fp.SendDataMessageArgsForCall(0) require.Equal(t, encoded, got) } }) @@ -684,11 +684,11 @@ func TestDataChannel(t *testing.T) { for _, op := range participants { fp := op.(*typesfakes.FakeLocalParticipant) if fp != p1 { - require.Zero(t, fp.SendDataPacketCallCount()) + require.Zero(t, fp.SendDataMessageCallCount()) } } - require.Equal(t, 1, p1.SendDataPacketCallCount()) - _, got := p1.SendDataPacketArgsForCall(0) + require.Equal(t, 1, p1.SendDataMessageCallCount()) + _, got := p1.SendDataMessageArgsForCall(0) require.Equal(t, encoded, got) }) } @@ -716,7 +716,7 @@ func TestDataChannel(t *testing.T) { // no one should've been sent packet for _, op := range participants { fp := op.(*typesfakes.FakeLocalParticipant) - require.Zero(t, fp.SendDataPacketCallCount()) + require.Zero(t, fp.SendDataMessageCallCount()) } }) } diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index e7d302ae8..ca56e2f56 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -196,6 +196,7 @@ type PCTransport struct { reliableDCOpened bool lossyDC *datachannel.DataChannelWriter[*webrtc.DataChannel] lossyDCOpened bool + unlabeledDataChannels []*datachannel.DataChannelWriter[*webrtc.DataChannel] iceStartedAt time.Time iceConnectedAt time.Time @@ -789,6 +790,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { dc.OnOpen(func() { t.params.Logger.Debugw(dc.Label() + " data channel open") var kind livekit.DataPacket_Kind + var isUnlabeled bool switch dc.Label() { case ReliableDataChannel: kind = livekit.DataPacket_RELIABLE @@ -797,8 +799,8 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { kind = livekit.DataPacket_LOSSY default: - t.params.Logger.Warnw("unsupported datachannel added", nil, "label", dc.Label()) - return + t.params.Logger.Infow("unlabeled datachannel added", "label", dc.Label()) + isUnlabeled = true } rawDC, err := dc.DetachWithDeadline() @@ -807,8 +809,16 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { return } - switch kind { - case livekit.DataPacket_RELIABLE: + switch { + case isUnlabeled: + t.lock.Lock() + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriter(dc, rawDC, t.params.DatachannelSlowThreshold), + ) + t.lock.Unlock() + + case kind == livekit.DataPacket_RELIABLE: t.lock.Lock() if t.reliableDC != nil { t.reliableDC.Close() @@ -817,7 +827,7 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { t.reliableDCOpened = true t.lock.Unlock() - case livekit.DataPacket_LOSSY: + case kind == livekit.DataPacket_LOSSY: t.lock.Lock() if t.lossyDC != nil { t.lossyDC.Close() @@ -839,7 +849,11 @@ func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { return } - t.params.Handler.OnDataPacket(kind, buffer[:n]) + if isUnlabeled { + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + } else { + t.params.Handler.OnDataMessage(kind, buffer[:n]) + } } }() @@ -969,15 +983,14 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni return err } var ( - dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] - dcReady *bool + dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] + dcReady *bool + isUnlabeled bool ) switch dc.Label() { default: - // TODO: Appears that it's never called, so not sure what needs to be done here. We just keep the DC open? - // Maybe just add "reliable" parameter instead of checking the label. - t.params.Logger.Warnw("unknown data channel label", nil, "label", dc.Label()) - return nil + isUnlabeled = true + t.params.Logger.Infow("unlabeled datachannel added", "label", dc.Label()) case ReliableDataChannel: dcPtr = &t.reliableDC dcReady = &t.reliableDCOpened @@ -994,16 +1007,23 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni } var slowThreshold int - if dc.Label() == ReliableDataChannel { + if dc.Label() == ReliableDataChannel || isUnlabeled { slowThreshold = t.params.DatachannelSlowThreshold } t.lock.Lock() - if *dcPtr != nil { - (*dcPtr).Close() + if isUnlabeled { + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriter(dc, rawDC, slowThreshold), + ) + } else { + if *dcPtr != nil { + (*dcPtr).Close() + } + *dcPtr = datachannel.NewDataChannelWriter(dc, rawDC, slowThreshold) + *dcReady = true } - *dcPtr = datachannel.NewDataChannelWriter(dc, rawDC, slowThreshold) - *dcReady = true t.lock.Unlock() t.params.Logger.Debugw(dc.Label() + " data channel open") @@ -1013,6 +1033,47 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni return nil } +// for testing only +func (t *PCTransport) CreateReadableDataChannel(label string, dci *webrtc.DataChannelInit) error { + dc, err := t.pc.CreateDataChannel(label, dci) + if err != nil { + return err + } + + dc.OnOpen(func() { + t.params.Logger.Debugw(dc.Label() + " data channel open") + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Errorw("failed to detach data channel", err, "label", dc.Label()) + return + } + + t.lock.Lock() + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriter(dc, rawDC, t.params.DatachannelSlowThreshold), + ) + t.lock.Unlock() + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "state=Closed") { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + } + }() + }) + return nil +} + func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataChannelInit) (label string, id uint16, existing bool, err error) { t.lock.RLock() var dcw *datachannel.DataChannelWriter[*webrtc.DataChannel] @@ -1073,7 +1134,7 @@ func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { return t.pc.WriteRTCP(pkts) } -func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { +func (t *PCTransport) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] t.lock.RLock() if kind == livekit.DataPacket_RELIABLE { @@ -1083,6 +1144,22 @@ func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byt } t.lock.RUnlock() + return t.sendDataMessage(dc, data) +} + +func (t *PCTransport) SendDataMessageUnlabeled(data []byte) error { + var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] + t.lock.RLock() + if len(t.unlabeledDataChannels) > 0 { + // use the first unlabeled to send + dc = t.unlabeledDataChannels[0] + } + t.lock.RUnlock() + + return t.sendDataMessage(dc, data) +} + +func (t *PCTransport) sendDataMessage(dc *datachannel.DataChannelWriter[*webrtc.DataChannel], data []byte) error { if dc == nil { return ErrDataChannelUnavailable } @@ -1094,7 +1171,7 @@ func (t *PCTransport) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byt if t.params.DatachannelSlowThreshold == 0 && t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmountGetter().BufferedAmount() > t.params.DataChannelMaxBufferedAmount { return ErrDataChannelBufferFull } - _, err := dc.Write(encoded) + _, err := dc.Write(data) return err } @@ -1129,6 +1206,11 @@ func (t *PCTransport) Close() { t.lossyDC = nil } + for _, dc := range t.unlabeledDataChannels { + dc.Close() + } + t.unlabeledDataChannels = nil + if t.mayFailedICEStatsTimer != nil { t.mayFailedICEStatsTimer.Stop() t.mayFailedICEStatsTimer = nil diff --git a/pkg/rtc/transport/handler.go b/pkg/rtc/transport/handler.go index 067fee80a..5fa6129e0 100644 --- a/pkg/rtc/transport/handler.go +++ b/pkg/rtc/transport/handler.go @@ -39,7 +39,8 @@ type Handler interface { OnFullyEstablished() OnFailed(isShortLived bool, iceConnectionInfo *types.ICEConnectionInfo) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) - OnDataPacket(kind livekit.DataPacket_Kind, data []byte) + OnDataMessage(kind livekit.DataPacket_Kind, data []byte) + OnDataMessageUnlabeled(data []byte) OnDataSendError(err error) OnOffer(sd webrtc.SessionDescription) error OnAnswer(sd webrtc.SessionDescription) error @@ -57,7 +58,8 @@ func (h UnimplementedHandler) OnInitialConnected() func (h UnimplementedHandler) OnFullyEstablished() {} func (h UnimplementedHandler) OnFailed(isShortLived bool) {} func (h UnimplementedHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {} -func (h UnimplementedHandler) OnDataPacket(kind livekit.DataPacket_Kind, data []byte) {} +func (h UnimplementedHandler) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) {} +func (h UnimplementedHandler) OnDataMessageUnlabeled(data []byte) {} func (h UnimplementedHandler) OnDataSendError(err error) {} func (h UnimplementedHandler) OnOffer(sd webrtc.SessionDescription) error { return ErrNoOfferHandler diff --git a/pkg/rtc/transport/transportfakes/fake_handler.go b/pkg/rtc/transport/transportfakes/fake_handler.go index a3d243d9d..5af37276b 100644 --- a/pkg/rtc/transport/transportfakes/fake_handler.go +++ b/pkg/rtc/transport/transportfakes/fake_handler.go @@ -23,12 +23,17 @@ type FakeHandler struct { onAnswerReturnsOnCall map[int]struct { result1 error } - OnDataPacketStub func(livekit.DataPacket_Kind, []byte) - onDataPacketMutex sync.RWMutex - onDataPacketArgsForCall []struct { + OnDataMessageStub func(livekit.DataPacket_Kind, []byte) + onDataMessageMutex sync.RWMutex + onDataMessageArgsForCall []struct { arg1 livekit.DataPacket_Kind arg2 []byte } + OnDataMessageUnlabeledStub func([]byte) + onDataMessageUnlabeledMutex sync.RWMutex + onDataMessageUnlabeledArgsForCall []struct { + arg1 []byte + } OnDataSendErrorStub func(error) onDataSendErrorMutex sync.RWMutex onDataSendErrorArgsForCall []struct { @@ -162,44 +167,81 @@ func (fake *FakeHandler) OnAnswerReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeHandler) OnDataPacket(arg1 livekit.DataPacket_Kind, arg2 []byte) { +func (fake *FakeHandler) OnDataMessage(arg1 livekit.DataPacket_Kind, arg2 []byte) { var arg2Copy []byte if arg2 != nil { arg2Copy = make([]byte, len(arg2)) copy(arg2Copy, arg2) } - fake.onDataPacketMutex.Lock() - fake.onDataPacketArgsForCall = append(fake.onDataPacketArgsForCall, struct { + fake.onDataMessageMutex.Lock() + fake.onDataMessageArgsForCall = append(fake.onDataMessageArgsForCall, struct { arg1 livekit.DataPacket_Kind arg2 []byte }{arg1, arg2Copy}) - stub := fake.OnDataPacketStub - fake.recordInvocation("OnDataPacket", []interface{}{arg1, arg2Copy}) - fake.onDataPacketMutex.Unlock() + stub := fake.OnDataMessageStub + fake.recordInvocation("OnDataMessage", []interface{}{arg1, arg2Copy}) + fake.onDataMessageMutex.Unlock() if stub != nil { - fake.OnDataPacketStub(arg1, arg2) + fake.OnDataMessageStub(arg1, arg2) } } -func (fake *FakeHandler) OnDataPacketCallCount() int { - fake.onDataPacketMutex.RLock() - defer fake.onDataPacketMutex.RUnlock() - return len(fake.onDataPacketArgsForCall) +func (fake *FakeHandler) OnDataMessageCallCount() int { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + return len(fake.onDataMessageArgsForCall) } -func (fake *FakeHandler) OnDataPacketCalls(stub func(livekit.DataPacket_Kind, []byte)) { - fake.onDataPacketMutex.Lock() - defer fake.onDataPacketMutex.Unlock() - fake.OnDataPacketStub = stub +func (fake *FakeHandler) OnDataMessageCalls(stub func(livekit.DataPacket_Kind, []byte)) { + fake.onDataMessageMutex.Lock() + defer fake.onDataMessageMutex.Unlock() + fake.OnDataMessageStub = stub } -func (fake *FakeHandler) OnDataPacketArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { - fake.onDataPacketMutex.RLock() - defer fake.onDataPacketMutex.RUnlock() - argsForCall := fake.onDataPacketArgsForCall[i] +func (fake *FakeHandler) OnDataMessageArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + argsForCall := fake.onDataMessageArgsForCall[i] return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeHandler) OnDataMessageUnlabeled(arg1 []byte) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.onDataMessageUnlabeledMutex.Lock() + fake.onDataMessageUnlabeledArgsForCall = append(fake.onDataMessageUnlabeledArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + stub := fake.OnDataMessageUnlabeledStub + fake.recordInvocation("OnDataMessageUnlabeled", []interface{}{arg1Copy}) + fake.onDataMessageUnlabeledMutex.Unlock() + if stub != nil { + fake.OnDataMessageUnlabeledStub(arg1) + } +} + +func (fake *FakeHandler) OnDataMessageUnlabeledCallCount() int { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + return len(fake.onDataMessageUnlabeledArgsForCall) +} + +func (fake *FakeHandler) OnDataMessageUnlabeledCalls(stub func([]byte)) { + fake.onDataMessageUnlabeledMutex.Lock() + defer fake.onDataMessageUnlabeledMutex.Unlock() + fake.OnDataMessageUnlabeledStub = stub +} + +func (fake *FakeHandler) OnDataMessageUnlabeledArgsForCall(i int) []byte { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + argsForCall := fake.onDataMessageUnlabeledArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeHandler) OnDataSendError(arg1 error) { fake.onDataSendErrorMutex.Lock() fake.onDataSendErrorArgsForCall = append(fake.onDataSendErrorArgsForCall, struct { @@ -591,8 +633,10 @@ func (fake *FakeHandler) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.onAnswerMutex.RLock() defer fake.onAnswerMutex.RUnlock() - fake.onDataPacketMutex.RLock() - defer fake.onDataPacketMutex.RUnlock() + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() fake.onDataSendErrorMutex.RLock() defer fake.onDataSendErrorMutex.RUnlock() fake.onFailedMutex.RLock() diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index bdf010431..671ed6fbf 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -297,25 +297,49 @@ func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) t.subscriber.RemoveTrackFromStreamAllocator(subTrack) } -func (t *TransportManager) SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error { +func (t *TransportManager) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { // downstream data is sent via primary peer connection - err := t.getTransport(true).SendDataPacket(kind, encoded) + return t.handleSendDataResult(t.getTransport(true).SendDataMessage(kind, data), kind.String(), len(data)) +} + +func (t *TransportManager) SendDataMessageUnlabeled(data []byte) error { + // downstream data is sent via primary peer connection + return t.handleSendDataResult(t.getTransport(true).SendDataMessageUnlabeled(data), "unlabeled", len(data)) +} + +func (t *TransportManager) handleSendDataResult(err error, kind string, size int) error { if err != nil { - if !utils.ErrorIsOneOf(err, io.ErrClosedPipe, sctp.ErrStreamClosed, ErrTransportFailure, ErrDataChannelBufferFull, context.DeadlineExceeded) { + if !utils.ErrorIsOneOf( + err, + io.ErrClosedPipe, + sctp.ErrStreamClosed, + ErrTransportFailure, + ErrDataChannelBufferFull, + context.DeadlineExceeded, + ) { if errors.Is(err, datachannel.ErrDataDroppedBySlowReader) { droppedBySlowReaderCount := t.droppedBySlowReaderCount.Inc() if (droppedBySlowReaderCount-1)%100 == 0 { - t.params.Logger.Infow("drop data packet by slow reader", "error", err, "kind", kind, "count", droppedBySlowReaderCount) + t.params.Logger.Infow( + "drop data message by slow reader", + "error", err, + "kind", kind, + "count", droppedBySlowReaderCount, + ) } } else { - t.params.Logger.Warnw("send data packet error", err) + t.params.Logger.Warnw("send data message error", err) } } if utils.ErrorIsOneOf(err, sctp.ErrStreamClosed, io.ErrClosedPipe) { - t.params.SubscriberHandler.OnDataSendError(err) + if t.params.SubscriberAsPrimary { + t.params.SubscriberHandler.OnDataSendError(err) + } else { + t.params.PublisherHandler.OnDataSendError(err) + } } } else { - t.params.DataChannelStats.AddBytes(uint64(len(encoded)), true) + t.params.DataChannelStats.AddBytes(uint64(size), true) } return err diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index d44bb329c..c840681db 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -403,7 +403,8 @@ type LocalParticipant interface { SendJoinResponse(joinResponse *livekit.JoinResponse) error SendParticipantUpdate(participants []*livekit.ParticipantInfo) error SendSpeakerUpdate(speakers []*livekit.SpeakerInfo, force bool) error - SendDataPacket(kind livekit.DataPacket_Kind, encoded []byte) error + SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error + SendDataMessageUnlabeled(data []byte) error SendRoomUpdate(room *livekit.Room) error SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error SubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) @@ -424,6 +425,7 @@ type LocalParticipant interface { // OnParticipantUpdate - metadata or permission is updated OnParticipantUpdate(callback func(LocalParticipant)) OnDataPacket(callback func(LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) + OnDataMessage(callback func(LocalParticipant, []byte)) OnSubscribeStatusChanged(fn func(publisherID livekit.ParticipantID, subscribed bool)) OnClose(callback func(LocalParticipant)) OnClaimsChanged(callback func(LocalParticipant)) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index f71d519c3..e0fbfa826 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -696,6 +696,11 @@ type FakeLocalParticipant struct { onCloseArgsForCall []struct { arg1 func(types.LocalParticipant) } + OnDataMessageStub func(func(types.LocalParticipant, []byte)) + onDataMessageMutex sync.RWMutex + onDataMessageArgsForCall []struct { + arg1 func(types.LocalParticipant, []byte) + } OnDataPacketStub func(func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) onDataPacketMutex sync.RWMutex onDataPacketArgsForCall []struct { @@ -785,16 +790,27 @@ type FakeLocalParticipant struct { sendConnectionQualityUpdateReturnsOnCall map[int]struct { result1 error } - SendDataPacketStub func(livekit.DataPacket_Kind, []byte) error - sendDataPacketMutex sync.RWMutex - sendDataPacketArgsForCall []struct { + SendDataMessageStub func(livekit.DataPacket_Kind, []byte) error + sendDataMessageMutex sync.RWMutex + sendDataMessageArgsForCall []struct { arg1 livekit.DataPacket_Kind arg2 []byte } - sendDataPacketReturns struct { + sendDataMessageReturns struct { result1 error } - sendDataPacketReturnsOnCall map[int]struct { + sendDataMessageReturnsOnCall map[int]struct { + result1 error + } + SendDataMessageUnlabeledStub func([]byte) error + sendDataMessageUnlabeledMutex sync.RWMutex + sendDataMessageUnlabeledArgsForCall []struct { + arg1 []byte + } + sendDataMessageUnlabeledReturns struct { + result1 error + } + sendDataMessageUnlabeledReturnsOnCall map[int]struct { result1 error } SendJoinResponseStub func(*livekit.JoinResponse) error @@ -4767,6 +4783,38 @@ func (fake *FakeLocalParticipant) OnCloseArgsForCall(i int) func(types.LocalPart return argsForCall.arg1 } +func (fake *FakeLocalParticipant) OnDataMessage(arg1 func(types.LocalParticipant, []byte)) { + fake.onDataMessageMutex.Lock() + fake.onDataMessageArgsForCall = append(fake.onDataMessageArgsForCall, struct { + arg1 func(types.LocalParticipant, []byte) + }{arg1}) + stub := fake.OnDataMessageStub + fake.recordInvocation("OnDataMessage", []interface{}{arg1}) + fake.onDataMessageMutex.Unlock() + if stub != nil { + fake.OnDataMessageStub(arg1) + } +} + +func (fake *FakeLocalParticipant) OnDataMessageCallCount() int { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + return len(fake.onDataMessageArgsForCall) +} + +func (fake *FakeLocalParticipant) OnDataMessageCalls(stub func(func(types.LocalParticipant, []byte))) { + fake.onDataMessageMutex.Lock() + defer fake.onDataMessageMutex.Unlock() + fake.OnDataMessageStub = stub +} + +func (fake *FakeLocalParticipant) OnDataMessageArgsForCall(i int) func(types.LocalParticipant, []byte) { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + argsForCall := fake.onDataMessageArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) OnDataPacket(arg1 func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) { fake.onDataPacketMutex.Lock() fake.onDataPacketArgsForCall = append(fake.onDataPacketArgsForCall, struct { @@ -5296,22 +5344,22 @@ func (fake *FakeLocalParticipant) SendConnectionQualityUpdateReturnsOnCall(i int }{result1} } -func (fake *FakeLocalParticipant) SendDataPacket(arg1 livekit.DataPacket_Kind, arg2 []byte) error { +func (fake *FakeLocalParticipant) SendDataMessage(arg1 livekit.DataPacket_Kind, arg2 []byte) error { var arg2Copy []byte if arg2 != nil { arg2Copy = make([]byte, len(arg2)) copy(arg2Copy, arg2) } - fake.sendDataPacketMutex.Lock() - ret, specificReturn := fake.sendDataPacketReturnsOnCall[len(fake.sendDataPacketArgsForCall)] - fake.sendDataPacketArgsForCall = append(fake.sendDataPacketArgsForCall, struct { + fake.sendDataMessageMutex.Lock() + ret, specificReturn := fake.sendDataMessageReturnsOnCall[len(fake.sendDataMessageArgsForCall)] + fake.sendDataMessageArgsForCall = append(fake.sendDataMessageArgsForCall, struct { arg1 livekit.DataPacket_Kind arg2 []byte }{arg1, arg2Copy}) - stub := fake.SendDataPacketStub - fakeReturns := fake.sendDataPacketReturns - fake.recordInvocation("SendDataPacket", []interface{}{arg1, arg2Copy}) - fake.sendDataPacketMutex.Unlock() + stub := fake.SendDataMessageStub + fakeReturns := fake.sendDataMessageReturns + fake.recordInvocation("SendDataMessage", []interface{}{arg1, arg2Copy}) + fake.sendDataMessageMutex.Unlock() if stub != nil { return stub(arg1, arg2) } @@ -5321,44 +5369,110 @@ func (fake *FakeLocalParticipant) SendDataPacket(arg1 livekit.DataPacket_Kind, a return fakeReturns.result1 } -func (fake *FakeLocalParticipant) SendDataPacketCallCount() int { - fake.sendDataPacketMutex.RLock() - defer fake.sendDataPacketMutex.RUnlock() - return len(fake.sendDataPacketArgsForCall) +func (fake *FakeLocalParticipant) SendDataMessageCallCount() int { + fake.sendDataMessageMutex.RLock() + defer fake.sendDataMessageMutex.RUnlock() + return len(fake.sendDataMessageArgsForCall) } -func (fake *FakeLocalParticipant) SendDataPacketCalls(stub func(livekit.DataPacket_Kind, []byte) error) { - fake.sendDataPacketMutex.Lock() - defer fake.sendDataPacketMutex.Unlock() - fake.SendDataPacketStub = stub +func (fake *FakeLocalParticipant) SendDataMessageCalls(stub func(livekit.DataPacket_Kind, []byte) error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = stub } -func (fake *FakeLocalParticipant) SendDataPacketArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { - fake.sendDataPacketMutex.RLock() - defer fake.sendDataPacketMutex.RUnlock() - argsForCall := fake.sendDataPacketArgsForCall[i] +func (fake *FakeLocalParticipant) SendDataMessageArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { + fake.sendDataMessageMutex.RLock() + defer fake.sendDataMessageMutex.RUnlock() + argsForCall := fake.sendDataMessageArgsForCall[i] return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeLocalParticipant) SendDataPacketReturns(result1 error) { - fake.sendDataPacketMutex.Lock() - defer fake.sendDataPacketMutex.Unlock() - fake.SendDataPacketStub = nil - fake.sendDataPacketReturns = struct { +func (fake *FakeLocalParticipant) SendDataMessageReturns(result1 error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = nil + fake.sendDataMessageReturns = struct { result1 error }{result1} } -func (fake *FakeLocalParticipant) SendDataPacketReturnsOnCall(i int, result1 error) { - fake.sendDataPacketMutex.Lock() - defer fake.sendDataPacketMutex.Unlock() - fake.SendDataPacketStub = nil - if fake.sendDataPacketReturnsOnCall == nil { - fake.sendDataPacketReturnsOnCall = make(map[int]struct { +func (fake *FakeLocalParticipant) SendDataMessageReturnsOnCall(i int, result1 error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = nil + if fake.sendDataMessageReturnsOnCall == nil { + fake.sendDataMessageReturnsOnCall = make(map[int]struct { result1 error }) } - fake.sendDataPacketReturnsOnCall[i] = struct { + fake.sendDataMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeled(arg1 []byte) error { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.sendDataMessageUnlabeledMutex.Lock() + ret, specificReturn := fake.sendDataMessageUnlabeledReturnsOnCall[len(fake.sendDataMessageUnlabeledArgsForCall)] + fake.sendDataMessageUnlabeledArgsForCall = append(fake.sendDataMessageUnlabeledArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + stub := fake.SendDataMessageUnlabeledStub + fakeReturns := fake.sendDataMessageUnlabeledReturns + fake.recordInvocation("SendDataMessageUnlabeled", []interface{}{arg1Copy}) + fake.sendDataMessageUnlabeledMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledCallCount() int { + fake.sendDataMessageUnlabeledMutex.RLock() + defer fake.sendDataMessageUnlabeledMutex.RUnlock() + return len(fake.sendDataMessageUnlabeledArgsForCall) +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledCalls(stub func([]byte) error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = stub +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledArgsForCall(i int) []byte { + fake.sendDataMessageUnlabeledMutex.RLock() + defer fake.sendDataMessageUnlabeledMutex.RUnlock() + argsForCall := fake.sendDataMessageUnlabeledArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledReturns(result1 error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = nil + fake.sendDataMessageUnlabeledReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledReturnsOnCall(i int, result1 error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = nil + if fake.sendDataMessageUnlabeledReturnsOnCall == nil { + fake.sendDataMessageUnlabeledReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendDataMessageUnlabeledReturnsOnCall[i] = struct { result1 error }{result1} } @@ -7656,6 +7770,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.onClaimsChangedMutex.RUnlock() fake.onCloseMutex.RLock() defer fake.onCloseMutex.RUnlock() + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() fake.onDataPacketMutex.RLock() defer fake.onDataPacketMutex.RUnlock() fake.onICEConfigChangedMutex.RLock() @@ -7684,8 +7800,10 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.removeTrackLocalMutex.RUnlock() fake.sendConnectionQualityUpdateMutex.RLock() defer fake.sendConnectionQualityUpdateMutex.RUnlock() - fake.sendDataPacketMutex.RLock() - defer fake.sendDataPacketMutex.RUnlock() + fake.sendDataMessageMutex.RLock() + defer fake.sendDataMessageMutex.RUnlock() + fake.sendDataMessageUnlabeledMutex.RLock() + defer fake.sendDataMessageUnlabeledMutex.RUnlock() fake.sendJoinResponseMutex.RLock() defer fake.sendJoinResponseMutex.RUnlock() fake.sendParticipantUpdateMutex.RLock() diff --git a/pkg/service/servicefakes/fake_sipstore.go b/pkg/service/servicefakes/fake_sipstore.go index 055018fe0..bc914444c 100644 --- a/pkg/service/servicefakes/fake_sipstore.go +++ b/pkg/service/servicefakes/fake_sipstore.go @@ -146,34 +146,6 @@ type FakeSIPStore struct { result1 *livekit.SIPTrunkInfo result2 error } - SelectSIPDispatchRuleStub func(context.Context, string) ([]*livekit.SIPDispatchRuleInfo, error) - selectSIPDispatchRuleMutex sync.RWMutex - selectSIPDispatchRuleArgsForCall []struct { - arg1 context.Context - arg2 string - } - selectSIPDispatchRuleReturns struct { - result1 []*livekit.SIPDispatchRuleInfo - result2 error - } - selectSIPDispatchRuleReturnsOnCall map[int]struct { - result1 []*livekit.SIPDispatchRuleInfo - result2 error - } - SelectSIPInboundTrunkStub func(context.Context, string) ([]*livekit.SIPInboundTrunkInfo, error) - selectSIPInboundTrunkMutex sync.RWMutex - selectSIPInboundTrunkArgsForCall []struct { - arg1 context.Context - arg2 string - } - selectSIPInboundTrunkReturns struct { - result1 []*livekit.SIPInboundTrunkInfo - result2 error - } - selectSIPInboundTrunkReturnsOnCall map[int]struct { - result1 []*livekit.SIPInboundTrunkInfo - result2 error - } StoreSIPDispatchRuleStub func(context.Context, *livekit.SIPDispatchRuleInfo) error storeSIPDispatchRuleMutex sync.RWMutex storeSIPDispatchRuleArgsForCall []struct { @@ -870,136 +842,6 @@ func (fake *FakeSIPStore) LoadSIPTrunkReturnsOnCall(i int, result1 *livekit.SIPT }{result1, result2} } -func (fake *FakeSIPStore) SelectSIPDispatchRule(arg1 context.Context, arg2 string) ([]*livekit.SIPDispatchRuleInfo, error) { - fake.selectSIPDispatchRuleMutex.Lock() - ret, specificReturn := fake.selectSIPDispatchRuleReturnsOnCall[len(fake.selectSIPDispatchRuleArgsForCall)] - fake.selectSIPDispatchRuleArgsForCall = append(fake.selectSIPDispatchRuleArgsForCall, struct { - arg1 context.Context - arg2 string - }{arg1, arg2}) - stub := fake.SelectSIPDispatchRuleStub - fakeReturns := fake.selectSIPDispatchRuleReturns - fake.recordInvocation("SelectSIPDispatchRule", []interface{}{arg1, arg2}) - fake.selectSIPDispatchRuleMutex.Unlock() - if stub != nil { - return stub(arg1, arg2) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeSIPStore) SelectSIPDispatchRuleCallCount() int { - fake.selectSIPDispatchRuleMutex.RLock() - defer fake.selectSIPDispatchRuleMutex.RUnlock() - return len(fake.selectSIPDispatchRuleArgsForCall) -} - -func (fake *FakeSIPStore) SelectSIPDispatchRuleCalls(stub func(context.Context, string) ([]*livekit.SIPDispatchRuleInfo, error)) { - fake.selectSIPDispatchRuleMutex.Lock() - defer fake.selectSIPDispatchRuleMutex.Unlock() - fake.SelectSIPDispatchRuleStub = stub -} - -func (fake *FakeSIPStore) SelectSIPDispatchRuleArgsForCall(i int) (context.Context, string) { - fake.selectSIPDispatchRuleMutex.RLock() - defer fake.selectSIPDispatchRuleMutex.RUnlock() - argsForCall := fake.selectSIPDispatchRuleArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - -func (fake *FakeSIPStore) SelectSIPDispatchRuleReturns(result1 []*livekit.SIPDispatchRuleInfo, result2 error) { - fake.selectSIPDispatchRuleMutex.Lock() - defer fake.selectSIPDispatchRuleMutex.Unlock() - fake.SelectSIPDispatchRuleStub = nil - fake.selectSIPDispatchRuleReturns = struct { - result1 []*livekit.SIPDispatchRuleInfo - result2 error - }{result1, result2} -} - -func (fake *FakeSIPStore) SelectSIPDispatchRuleReturnsOnCall(i int, result1 []*livekit.SIPDispatchRuleInfo, result2 error) { - fake.selectSIPDispatchRuleMutex.Lock() - defer fake.selectSIPDispatchRuleMutex.Unlock() - fake.SelectSIPDispatchRuleStub = nil - if fake.selectSIPDispatchRuleReturnsOnCall == nil { - fake.selectSIPDispatchRuleReturnsOnCall = make(map[int]struct { - result1 []*livekit.SIPDispatchRuleInfo - result2 error - }) - } - fake.selectSIPDispatchRuleReturnsOnCall[i] = struct { - result1 []*livekit.SIPDispatchRuleInfo - result2 error - }{result1, result2} -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunk(arg1 context.Context, arg2 string) ([]*livekit.SIPInboundTrunkInfo, error) { - fake.selectSIPInboundTrunkMutex.Lock() - ret, specificReturn := fake.selectSIPInboundTrunkReturnsOnCall[len(fake.selectSIPInboundTrunkArgsForCall)] - fake.selectSIPInboundTrunkArgsForCall = append(fake.selectSIPInboundTrunkArgsForCall, struct { - arg1 context.Context - arg2 string - }{arg1, arg2}) - stub := fake.SelectSIPInboundTrunkStub - fakeReturns := fake.selectSIPInboundTrunkReturns - fake.recordInvocation("SelectSIPInboundTrunk", []interface{}{arg1, arg2}) - fake.selectSIPInboundTrunkMutex.Unlock() - if stub != nil { - return stub(arg1, arg2) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunkCallCount() int { - fake.selectSIPInboundTrunkMutex.RLock() - defer fake.selectSIPInboundTrunkMutex.RUnlock() - return len(fake.selectSIPInboundTrunkArgsForCall) -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunkCalls(stub func(context.Context, string) ([]*livekit.SIPInboundTrunkInfo, error)) { - fake.selectSIPInboundTrunkMutex.Lock() - defer fake.selectSIPInboundTrunkMutex.Unlock() - fake.SelectSIPInboundTrunkStub = stub -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunkArgsForCall(i int) (context.Context, string) { - fake.selectSIPInboundTrunkMutex.RLock() - defer fake.selectSIPInboundTrunkMutex.RUnlock() - argsForCall := fake.selectSIPInboundTrunkArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunkReturns(result1 []*livekit.SIPInboundTrunkInfo, result2 error) { - fake.selectSIPInboundTrunkMutex.Lock() - defer fake.selectSIPInboundTrunkMutex.Unlock() - fake.SelectSIPInboundTrunkStub = nil - fake.selectSIPInboundTrunkReturns = struct { - result1 []*livekit.SIPInboundTrunkInfo - result2 error - }{result1, result2} -} - -func (fake *FakeSIPStore) SelectSIPInboundTrunkReturnsOnCall(i int, result1 []*livekit.SIPInboundTrunkInfo, result2 error) { - fake.selectSIPInboundTrunkMutex.Lock() - defer fake.selectSIPInboundTrunkMutex.Unlock() - fake.SelectSIPInboundTrunkStub = nil - if fake.selectSIPInboundTrunkReturnsOnCall == nil { - fake.selectSIPInboundTrunkReturnsOnCall = make(map[int]struct { - result1 []*livekit.SIPInboundTrunkInfo - result2 error - }) - } - fake.selectSIPInboundTrunkReturnsOnCall[i] = struct { - result1 []*livekit.SIPInboundTrunkInfo - result2 error - }{result1, result2} -} - func (fake *FakeSIPStore) StoreSIPDispatchRule(arg1 context.Context, arg2 *livekit.SIPDispatchRuleInfo) error { fake.storeSIPDispatchRuleMutex.Lock() ret, specificReturn := fake.storeSIPDispatchRuleReturnsOnCall[len(fake.storeSIPDispatchRuleArgsForCall)] @@ -1271,10 +1113,6 @@ func (fake *FakeSIPStore) Invocations() map[string][][]interface{} { defer fake.loadSIPOutboundTrunkMutex.RUnlock() fake.loadSIPTrunkMutex.RLock() defer fake.loadSIPTrunkMutex.RUnlock() - fake.selectSIPDispatchRuleMutex.RLock() - defer fake.selectSIPDispatchRuleMutex.RUnlock() - fake.selectSIPInboundTrunkMutex.RLock() - defer fake.selectSIPInboundTrunkMutex.RUnlock() fake.storeSIPDispatchRuleMutex.RLock() defer fake.storeSIPDispatchRuleMutex.RUnlock() fake.storeSIPInboundTrunkMutex.RLock() diff --git a/test/client/client.go b/test/client/client.go index 777c5a1e0..3b2d17464 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -81,10 +81,11 @@ type RTCClient struct { // tracks waiting to be acked, cid => trackInfo pendingPublishedTracks map[string]*livekit.TrackInfo - pendingTrackWriters []*TrackWriter - OnConnected func() - OnDataReceived func(data []byte, sid string) - refreshToken string + pendingTrackWriters []*TrackWriter + OnConnected func() + OnDataReceived func(data []byte, sid string) + OnDataUnlabeledReceived func(data []byte) + refreshToken string // map of livekit.ParticipantID and last packet lastPackets map[livekit.ParticipantID]*rtp.Packet @@ -262,6 +263,18 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { return nil, err } + if err := c.publisher.CreateDataChannel("pubraw", &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + + if err := c.subscriber.CreateReadableDataChannel("subraw", &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { if ic == nil { return nil @@ -271,7 +284,8 @@ func NewRTCClient(conn *websocket.Conn, opts *Options) (*RTCClient, error) { subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { go c.processTrack(track) }) - subscriberHandler.OnDataPacketCalls(c.handleDataMessage) + subscriberHandler.OnDataMessageCalls(c.handleDataMessage) + subscriberHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) subscriberHandler.OnInitialConnectedCalls(func() { logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) @@ -732,7 +746,16 @@ func (c *RTCClient) PublishData(data []byte, kind livekit.DataPacket_Kind) error return err } - return c.publisher.SendDataPacket(kind, dpData) + return c.publisher.SendDataMessage(kind, dpData) +} + +func (c *RTCClient) PublishDataUnlabeled(data []byte) error { + if err := c.ensurePublisherConnected(); err != nil { + return err + } + + fmt.Printf("RAJA sending unlabeled data: %s\n", string(data)) // REMOVE + return c.publisher.SendDataMessageUnlabeled(data) } func (c *RTCClient) GetPublishedTrackIDs() []string { @@ -787,6 +810,12 @@ func (c *RTCClient) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) } } +func (c *RTCClient) handleDataMessageUnlabeled(data []byte) { + if c.OnDataUnlabeledReceived != nil { + c.OnDataUnlabeledReceived(data) + } +} + // handles a server initiated offer, handle on subscriber PC func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) { c.subscriber.HandleRemoteDescription(desc) diff --git a/test/multinode_test.go b/test/multinode_test.go index e964f3096..e79c8316c 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -154,6 +154,7 @@ func TestMultinodeDataPublishing(t *testing.T) { defer finish() scenarioDataPublish(t) + scenarioDataUnlabeledPublish(t) } func TestMultiNodeJoinAfterClose(t *testing.T) { diff --git a/test/scenarios.go b/test/scenarios.go index d72578614..d1aa8f0b3 100644 --- a/test/scenarios.go +++ b/test/scenarios.go @@ -164,6 +164,33 @@ func scenarioDataPublish(t *testing.T) { }) } +func scenarioDataUnlabeledPublish(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, nil) + c2 := createRTCClient("dp2", secondServerPort, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + payload := "test unlabeled bytes" + + received := atomic.NewBool(false) + c2.OnDataUnlabeledReceived = func(data []byte) { + fmt.Printf("RAJA received data: message: %s, sid: %s\n", string(data), c1.ID()) // REMOVE + if string(data) == payload { + received.Store(true) + } + } + + require.NoError(t, c1.PublishDataUnlabeled([]byte(payload))) + + testutils.WithTimeout(t, func() string { + if received.Load() { + return "" + } else { + return "c2 did not receive published data unlabeled" + } + }) +} + func scenarioJoinClosedRoom(t *testing.T) { c1 := createRTCClient("jcr1", defaultServerPort, nil) waitUntilConnected(t, c1)