From ef66404a1a73e5044fb153cbf725fd7ef181a60b Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Wed, 6 Jul 2022 23:48:28 +0530 Subject: [PATCH] Keep track of pending subscriber operations. (#814) * Keep track of pending subscriber operations. This is required to determine if a receiver does not have any subscription. * correct spelling of queuing * lock around hasPermission --- pkg/rtc/mediatrackreceiver.go | 93 ++++++++++++++++++++++++------ pkg/rtc/mediatracksubscriptions.go | 10 ++-- pkg/rtc/participant.go | 12 ++-- pkg/rtc/uptrackmanager.go | 6 +- 4 files changed, 90 insertions(+), 31 deletions(-) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index d77e8bdb8..6d9d8dc65 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -28,6 +28,10 @@ const ( layerSelectionTolerance = 0.9 ) +var ( + ErrNoReceiver = errors.New("cannot subscribe without a receiver in place") +) + type simulcastReceiver struct { sfu.TrackReceiver priority int @@ -57,12 +61,13 @@ type MediaTrackReceiver struct { muted atomic.Bool simulcasted atomic.Bool - lock sync.RWMutex - receivers []*simulcastReceiver - receiversShadow []*simulcastReceiver - trackInfo *livekit.TrackInfo - layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer - potentialCodecs []webrtc.RTPCodecParameters + lock sync.RWMutex + receivers []*simulcastReceiver + receiversShadow []*simulcastReceiver + trackInfo *livekit.TrackInfo + layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer + potentialCodecs []webrtc.RTPCodecParameters + pendingSubscribeOp map[livekit.ParticipantID]int // track audio fraction lost downFracLostLock sync.Mutex @@ -78,9 +83,10 @@ type MediaTrackReceiver struct { func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver { t := &MediaTrackReceiver{ - params: params, - trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo), - layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer), + params: params, + trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo), + layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer), + pendingSubscribeOp: make(map[livekit.ParticipantID]int), } t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ @@ -94,7 +100,8 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver }) t.MediaTrackSubscriptions.OnDownTrackCreated(t.onDownTrackCreated) t.MediaTrackSubscriptions.OnSubscriptionOperationComplete(func(sub types.LocalParticipant) { - go sub.ClearInProgressAndProcessSubscriptionRequestsQueue(t.ID()) + t.removePendingSubscribeOp(sub.ID()) + sub.ClearInProgressAndProcessSubscriptionRequestsQueue(t.ID()) }) if t.trackInfo.Muted { @@ -230,8 +237,7 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string) { } } - t.receiversShadow = make([]*simulcastReceiver, len(t.receivers)) - copy(t.receiversShadow, t.receivers) + t.shadowReceiversLocked() stopSubscription := len(t.receiversShadow) == 0 t.lock.Unlock() @@ -274,7 +280,7 @@ func (t *MediaTrackReceiver) Close() { onclose := t.onClose t.lock.RUnlock() - t.MediaTrackSubscriptions.Close() + t.MediaTrackSubscriptions.Stop() for _, f := range onclose { f() } @@ -355,14 +361,43 @@ func (t *MediaTrackReceiver) AddOnClose(f func()) { t.lock.Unlock() } +func (t *MediaTrackReceiver) addPendingSubscribeOp(subscriberID livekit.ParticipantID) { + t.lock.Lock() + if c, ok := t.pendingSubscribeOp[subscriberID]; !ok { + t.pendingSubscribeOp[subscriberID] = 1 + } else { + t.pendingSubscribeOp[subscriberID] = c + 1 + } + t.lock.Unlock() +} + +func (t *MediaTrackReceiver) removePendingSubscribeOp(subscriberID livekit.ParticipantID) { + t.lock.Lock() + if c, ok := t.pendingSubscribeOp[subscriberID]; ok { + t.pendingSubscribeOp[subscriberID] = c - 1 + if t.pendingSubscribeOp[subscriberID] == 0 { + delete(t.pendingSubscribeOp, subscriberID) + } + } + t.lock.Unlock() +} + // AddSubscriber subscribes sub to current mediaTrack func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) error { + t.addPendingSubscribeOp(sub.ID()) + trackID := t.ID() sub.EnqueueSubscribeTrack(trackID, t.addSubscriber) return nil } -func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) error { +func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) (err error) { + defer func() { + if err != nil { + t.removePendingSubscribeOp(sub.ID()) + } + }() + t.lock.RLock() receivers := t.receiversShadow potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) @@ -371,7 +406,8 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) error { if len(receivers) == 0 { // cannot add, no receiver - return errors.New("cannot subscribe without a receiver in place") + err = ErrNoReceiver + return } for _, receiver := range receivers { @@ -396,9 +432,9 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) error { streamId = PackStreamID(t.PublisherID(), t.ID()) } - err := t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(receivers, t.ID(), streamId, potentialCodecs)) + err = t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(receivers, t.ID(), streamId, potentialCodecs)) if err != nil { - return err + return } return nil @@ -413,8 +449,20 @@ func (t *MediaTrackReceiver) RemoveSubscriber(subscriberID livekit.ParticipantID } sub := subTrack.Subscriber() - trackID := subTrack.ID() - sub.EnqueueUnsubscribeTrack(trackID, willBeResumed, t.MediaTrackSubscriptions.RemoveSubscriber) + t.addPendingSubscribeOp(sub.ID()) + + sub.EnqueueUnsubscribeTrack(subTrack.ID(), willBeResumed, t.removeSubscriber) +} + +func (t *MediaTrackReceiver) removeSubscriber(subscriberID livekit.ParticipantID, willBeResumed bool) (err error) { + defer func() { + if err != nil { + t.removePendingSubscribeOp(subscriberID) + } + }() + + err = t.MediaTrackSubscriptions.RemoveSubscriber(subscriberID, willBeResumed) + return } func (t *MediaTrackReceiver) RemoveAllSubscribers(willBeResumed bool) { @@ -424,6 +472,13 @@ func (t *MediaTrackReceiver) RemoveAllSubscribers(willBeResumed bool) { } } +func (t *MediaTrackReceiver) IsSubscribed() bool { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.MediaTrackSubscriptions.GetNumSubscribers() != 0 || len(t.pendingSubscribeOp) != 0 +} + func (t *MediaTrackReceiver) RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { var revokedSubscriberIdentities []livekit.ParticipantIdentity diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 76547dd34..5ee5432cc 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -93,11 +93,8 @@ func (t *MediaTrackSubscriptions) Restart() { } func (t *MediaTrackSubscriptions) Stop() { - t.stopMaxQualityTimer() -} - -func (t *MediaTrackSubscriptions) Close() { t.qualityNotifyOpQueue.Stop() + t.stopMaxQualityTimer() } func (t *MediaTrackSubscriptions) OnDownTrackCreated(f func(downTrack *sfu.DownTrack)) { @@ -702,6 +699,11 @@ func (t *MediaTrackSubscriptions) startMaxQualityTimer(force bool) { return } + if t.maxQualityTimer != nil { + t.maxQualityTimer.Stop() + t.maxQualityTimer = nil + } + t.maxQualityTimer = time.AfterFunc(initialQualityUpdateWait, func() { t.stopMaxQualityTimer() t.UpdateQualityChange(force) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 03a9c1267..e117b3419 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -2112,7 +2112,7 @@ func (p *ParticipantImpl) handleNegotiationFailed() { } func (p *ParticipantImpl) EnqueueSubscribeTrack(trackID livekit.TrackID, f func(sub types.LocalParticipant) error) { - p.params.Logger.Infow("queueing subscribe", "trackID", trackID) + p.params.Logger.Infow("queuing subscribe", "trackID", trackID) p.lock.Lock() p.subscriptionRequestsQueue[trackID] = append(p.subscriptionRequestsQueue[trackID], SubscribeRequest{ @@ -2125,7 +2125,7 @@ func (p *ParticipantImpl) EnqueueSubscribeTrack(trackID livekit.TrackID, f func( } func (p *ParticipantImpl) EnqueueUnsubscribeTrack(trackID livekit.TrackID, willBeResumed bool, f func(subscriberID livekit.ParticipantID, willBeResumed bool) error) { - p.params.Logger.Infow("queueing unsubscribe", "trackID", trackID) + p.params.Logger.Infow("queuing unsubscribe", "trackID", trackID) p.lock.Lock() p.subscriptionRequestsQueue[trackID] = append(p.subscriptionRequestsQueue[trackID], SubscribeRequest{ @@ -2163,20 +2163,20 @@ func (p *ParticipantImpl) ProcessSubscriptionRequestsQueue(trackID livekit.Track } // process pending request even if adding errors out - go p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) + p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) } case SubscribeRequestTypeRemove: err := request.removeCb(p.ID(), request.willBeResumed) if err != nil { - go p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) + p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) } default: p.params.Logger.Warnw("unknown request type", nil) // let the queue move forward - go p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) + p.ClearInProgressAndProcessSubscriptionRequestsQueue(trackID) } } @@ -2185,5 +2185,5 @@ func (p *ParticipantImpl) ClearInProgressAndProcessSubscriptionRequestsQueue(tra delete(p.subscriptionInProgress, trackID) p.lock.Unlock() - p.ProcessSubscriptionRequestsQueue(trackID) + go p.ProcessSubscriptionRequestsQueue(trackID) } diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index a1719488a..e62a12a1a 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -133,12 +133,13 @@ func (u *UpTrackManager) AddSubscriber(sub types.LocalParticipant, params types. for _, track := range tracks { trackID := track.ID() subscriberIdentity := sub.Identity() + u.lock.Lock() if !u.hasPermission(trackID, subscriberIdentity) { - u.lock.Lock() u.maybeAddPendingSubscription(trackID, sub) u.lock.Unlock() continue } + u.lock.Unlock() if err := track.AddSubscriber(sub); err != nil { return n, err @@ -209,14 +210,15 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( u.lock.Lock() defer u.lock.Unlock() - u.params.Logger.Debugw("updating subscription permission", "permissions", subscriptionPermission) if subscriptionPermission == nil { + u.params.Logger.Debugw("updating subscription permission, setting to nil") // store as is for use when migrating u.subscriptionPermission = subscriptionPermission // possible to get a nil when migrating return nil } + u.params.Logger.Debugw("updating subscription permission", "permissions", subscriptionPermission.String()) if err := u.parseSubscriptionPermissions(subscriptionPermission, resolverBySid); err != nil { // when failed, do not override previous permissions return err