move TrackSubscribed trigger to MediaSubscription (#2916)

This commit is contained in:
cnderrauber
2024-08-07 22:30:52 +08:00
committed by GitHub
parent 2346c8a6b7
commit a8730b04b8
6 changed files with 70 additions and 22 deletions

View File

@@ -285,7 +285,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
sfu.WithLoadBalanceThreshold(20),
sfu.WithStreamTrackers(),
sfu.WithForwardStats(t.params.ForwardStats),
sfu.WithEverHasDownTrackAdded(t.handleReceiverEverAddDowntrack),
)
newWR.OnCloseHandler(func() {
t.MediaTrackReceiver.SetClosing()
@@ -434,7 +433,7 @@ func (t *MediaTrack) SetMuted(muted bool) {
t.MediaTrackReceiver.SetMuted(muted)
}
func (t *MediaTrack) handleReceiverEverAddDowntrack() {
func (t *MediaTrack) OnTrackSubscribed() {
if !t.everSubscribed.Swap(true) && t.params.OnTrackEverSubscribed != nil {
go t.params.OnTrackEverSubscribed(t.ID())
}

View File

@@ -161,6 +161,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *
AdaptiveStream: sub.GetAdaptiveStream(),
})
subTrack.AddOnBind(func(err error) {
if err == nil {
t.params.MediaTrack.OnTrackSubscribed()
}
})
// Bind callback can happen from replaceTrack, so set it up early
var reusingTransceiver atomic.Bool
var dtState sfu.DownTrackState

View File

@@ -495,6 +495,7 @@ type MediaTrack interface {
RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity
GetAllSubscribers() []livekit.ParticipantID
GetNumSubscribers() int
OnTrackSubscribed()
// returns quality information that's appropriate for width & height
GetQualityForDimension(width, height uint32) livekit.VideoQuality

View File

@@ -221,6 +221,10 @@ type FakeLocalMediaTrack struct {
arg1 livekit.NodeID
arg2 uint8
}
OnTrackSubscribedStub func()
onTrackSubscribedMutex sync.RWMutex
onTrackSubscribedArgsForCall []struct {
}
PublisherIDStub func() livekit.ParticipantID
publisherIDMutex sync.RWMutex
publisherIDArgsForCall []struct {
@@ -1471,6 +1475,30 @@ func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossArgsForCall(i int)
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeLocalMediaTrack) OnTrackSubscribed() {
fake.onTrackSubscribedMutex.Lock()
fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct {
}{})
stub := fake.OnTrackSubscribedStub
fake.recordInvocation("OnTrackSubscribed", []interface{}{})
fake.onTrackSubscribedMutex.Unlock()
if stub != nil {
fake.OnTrackSubscribedStub()
}
}
func (fake *FakeLocalMediaTrack) OnTrackSubscribedCallCount() int {
fake.onTrackSubscribedMutex.RLock()
defer fake.onTrackSubscribedMutex.RUnlock()
return len(fake.onTrackSubscribedArgsForCall)
}
func (fake *FakeLocalMediaTrack) OnTrackSubscribedCalls(stub func()) {
fake.onTrackSubscribedMutex.Lock()
defer fake.onTrackSubscribedMutex.Unlock()
fake.OnTrackSubscribedStub = stub
}
func (fake *FakeLocalMediaTrack) PublisherID() livekit.ParticipantID {
fake.publisherIDMutex.Lock()
ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)]
@@ -2225,6 +2253,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} {
defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock()
fake.notifySubscriberNodeMediaLossMutex.RLock()
defer fake.notifySubscriberNodeMediaLossMutex.RUnlock()
fake.onTrackSubscribedMutex.RLock()
defer fake.onTrackSubscribedMutex.RUnlock()
fake.publisherIDMutex.RLock()
defer fake.publisherIDMutex.RUnlock()
fake.publisherIdentityMutex.RLock()

View File

@@ -176,6 +176,10 @@ type FakeMediaTrack struct {
nameReturnsOnCall map[int]struct {
result1 string
}
OnTrackSubscribedStub func()
onTrackSubscribedMutex sync.RWMutex
onTrackSubscribedArgsForCall []struct {
}
PublisherIDStub func() livekit.ParticipantID
publisherIDMutex sync.RWMutex
publisherIDArgsForCall []struct {
@@ -1166,6 +1170,30 @@ func (fake *FakeMediaTrack) NameReturnsOnCall(i int, result1 string) {
}{result1}
}
func (fake *FakeMediaTrack) OnTrackSubscribed() {
fake.onTrackSubscribedMutex.Lock()
fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct {
}{})
stub := fake.OnTrackSubscribedStub
fake.recordInvocation("OnTrackSubscribed", []interface{}{})
fake.onTrackSubscribedMutex.Unlock()
if stub != nil {
fake.OnTrackSubscribedStub()
}
}
func (fake *FakeMediaTrack) OnTrackSubscribedCallCount() int {
fake.onTrackSubscribedMutex.RLock()
defer fake.onTrackSubscribedMutex.RUnlock()
return len(fake.onTrackSubscribedArgsForCall)
}
func (fake *FakeMediaTrack) OnTrackSubscribedCalls(stub func()) {
fake.onTrackSubscribedMutex.Lock()
defer fake.onTrackSubscribedMutex.Unlock()
fake.OnTrackSubscribedStub = stub
}
func (fake *FakeMediaTrack) PublisherID() livekit.ParticipantID {
fake.publisherIDMutex.Lock()
ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)]
@@ -1801,6 +1829,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} {
defer fake.kindMutex.RUnlock()
fake.nameMutex.RLock()
defer fake.nameMutex.RUnlock()
fake.onTrackSubscribedMutex.RLock()
defer fake.onTrackSubscribedMutex.RUnlock()
fake.publisherIDMutex.RLock()
defer fake.publisherIDMutex.RUnlock()
fake.publisherIdentityMutex.RLock()

View File

@@ -120,10 +120,8 @@ type WebRTCReceiver struct {
connectionStats *connectionquality.ConnectionStats
onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)
onMaxLayerChange func(maxLayer int32)
downTrackEverAdded atomic.Bool
onDownTrackEverAdded func()
onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)
onMaxLayerChange func(maxLayer int32)
primaryReceiver atomic.Pointer[RedPrimaryReceiver]
redReceiver atomic.Pointer[RedReceiver]
@@ -177,13 +175,6 @@ func WithForwardStats(forwardStats *ForwardStats) ReceiverOpts {
}
}
func WithEverHasDownTrackAdded(f func()) ReceiverOpts {
return func(w *WebRTCReceiver) *WebRTCReceiver {
w.onDownTrackEverAdded = f
return w
}
}
// NewWebRTCReceiver creates a new webrtc track receiver
func NewWebRTCReceiver(
receiver *webrtc.RTPReceiver,
@@ -420,16 +411,9 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error {
w.downTrackSpreader.Store(track)
w.logger.Debugw("downtrack added", "subscriberID", track.SubscriberID())
w.handleDowntrackAdded()
return nil
}
func (w *WebRTCReceiver) handleDowntrackAdded() {
if !w.downTrackEverAdded.Swap(true) && w.onDownTrackEverAdded != nil {
w.onDownTrackEverAdded()
}
}
func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) {
ti := w.TrackInfo()
if ti == nil {
@@ -792,7 +776,6 @@ func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver {
w.bufferMu.Lock()
w.redPktWriter = pr.ForwardRTP
w.bufferMu.Unlock()
w.handleDowntrackAdded()
}
}
return w.primaryReceiver.Load()
@@ -812,7 +795,6 @@ func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver {
w.bufferMu.Lock()
w.redPktWriter = pr.ForwardRTP
w.bufferMu.Unlock()
w.handleDowntrackAdded()
}
}
return w.redReceiver.Load()