diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index c55faf536..11f306a54 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -285,7 +285,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra sfu.WithLoadBalanceThreshold(20), sfu.WithStreamTrackers(), sfu.WithForwardStats(t.params.ForwardStats), - sfu.WithEverHasDownTrackAdded(t.handleReceiverEverAddDowntrack), ) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() @@ -434,7 +433,7 @@ func (t *MediaTrack) SetMuted(muted bool) { t.MediaTrackReceiver.SetMuted(muted) } -func (t *MediaTrack) handleReceiverEverAddDowntrack() { +func (t *MediaTrack) OnTrackSubscribed() { if !t.everSubscribed.Swap(true) && t.params.OnTrackEverSubscribed != nil { go t.params.OnTrackEverSubscribed(t.ID()) } diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index acd85037d..20b73db55 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -161,6 +161,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * AdaptiveStream: sub.GetAdaptiveStream(), }) + subTrack.AddOnBind(func(err error) { + if err == nil { + t.params.MediaTrack.OnTrackSubscribed() + } + }) + // Bind callback can happen from replaceTrack, so set it up early var reusingTransceiver atomic.Bool var dtState sfu.DownTrackState diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index e3d25a59d..1470e58b8 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -495,6 +495,7 @@ type MediaTrack interface { RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity GetAllSubscribers() []livekit.ParticipantID GetNumSubscribers() int + OnTrackSubscribed() // returns quality information that's appropriate for width & height GetQualityForDimension(width, height uint32) livekit.VideoQuality diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 1bf6297b4..382c609ae 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -221,6 +221,10 @@ type FakeLocalMediaTrack struct { arg1 livekit.NodeID arg2 uint8 } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } PublisherIDStub func() livekit.ParticipantID publisherIDMutex sync.RWMutex publisherIDArgsForCall []struct { @@ -1471,6 +1475,30 @@ func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossArgsForCall(i int) return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeLocalMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + func (fake *FakeLocalMediaTrack) PublisherID() livekit.ParticipantID { fake.publisherIDMutex.Lock() ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] @@ -2225,6 +2253,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() fake.notifySubscriberNodeMediaLossMutex.RLock() defer fake.notifySubscriberNodeMediaLossMutex.RUnlock() + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() fake.publisherIDMutex.RLock() defer fake.publisherIDMutex.RUnlock() fake.publisherIdentityMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 887646a18..ce6e9355e 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -176,6 +176,10 @@ type FakeMediaTrack struct { nameReturnsOnCall map[int]struct { result1 string } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } PublisherIDStub func() livekit.ParticipantID publisherIDMutex sync.RWMutex publisherIDArgsForCall []struct { @@ -1166,6 +1170,30 @@ func (fake *FakeMediaTrack) NameReturnsOnCall(i int, result1 string) { }{result1} } +func (fake *FakeMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + func (fake *FakeMediaTrack) PublisherID() livekit.ParticipantID { fake.publisherIDMutex.Lock() ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] @@ -1801,6 +1829,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.kindMutex.RUnlock() fake.nameMutex.RLock() defer fake.nameMutex.RUnlock() + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() fake.publisherIDMutex.RLock() defer fake.publisherIDMutex.RUnlock() fake.publisherIdentityMutex.RLock() diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index beb2fb06b..cc0241070 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -120,10 +120,8 @@ type WebRTCReceiver struct { connectionStats *connectionquality.ConnectionStats - onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) - onMaxLayerChange func(maxLayer int32) - downTrackEverAdded atomic.Bool - onDownTrackEverAdded func() + onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) + onMaxLayerChange func(maxLayer int32) primaryReceiver atomic.Pointer[RedPrimaryReceiver] redReceiver atomic.Pointer[RedReceiver] @@ -177,13 +175,6 @@ func WithForwardStats(forwardStats *ForwardStats) ReceiverOpts { } } -func WithEverHasDownTrackAdded(f func()) ReceiverOpts { - return func(w *WebRTCReceiver) *WebRTCReceiver { - w.onDownTrackEverAdded = f - return w - } -} - // NewWebRTCReceiver creates a new webrtc track receiver func NewWebRTCReceiver( receiver *webrtc.RTPReceiver, @@ -420,16 +411,9 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { w.downTrackSpreader.Store(track) w.logger.Debugw("downtrack added", "subscriberID", track.SubscriberID()) - w.handleDowntrackAdded() return nil } -func (w *WebRTCReceiver) handleDowntrackAdded() { - if !w.downTrackEverAdded.Swap(true) && w.onDownTrackEverAdded != nil { - w.onDownTrackEverAdded() - } -} - func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) { ti := w.TrackInfo() if ti == nil { @@ -792,7 +776,6 @@ func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() - w.handleDowntrackAdded() } } return w.primaryReceiver.Load() @@ -812,7 +795,6 @@ func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() - w.handleDowntrackAdded() } } return w.redReceiver.Load()