From 03b0a01aaddc4fe9f2722565d0d8e66eae37a3d5 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 2 Jul 2022 10:52:55 +0530 Subject: [PATCH] Use a queue for add/remove subscribe operations. (#797) * Use a queue for add/remove subscribe operations. If subscribe/unsubscribe happens very quickly, the subscription state gets mixed up as things are keyed off of subscriberID. Use a queue of subscribe operations and process it serially. * set up callback for down track added * move the queue on unexpected type * move the queue if removeSubscirber does not have a subscribed track --- pkg/rtc/mediatrackreceiver.go | 15 +-- pkg/rtc/mediatracksubscriptions.go | 194 +++++++++++++++++++++++------ pkg/rtc/uptrackmanager.go | 2 +- 3 files changed, 162 insertions(+), 49 deletions(-) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 9a443d84a..aff5ed43a 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -91,6 +91,7 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver Telemetry: params.Telemetry, Logger: params.Logger, }) + t.MediaTrackSubscriptions.OnDownTrackCreated(t.onDownTrackCreated) if t.trackInfo.Muted { t.SetMuted(true) @@ -381,17 +382,11 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) error { streamId = PackStreamID(t.PublisherID(), t.ID()) } - downTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(receivers, t.ID(), streamId, potentialCodecs)) + err := t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(receivers, t.ID(), streamId, potentialCodecs)) if err != nil { return err } - if downTrack != nil { - if t.Kind() == livekit.TrackType_AUDIO { - downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) - } - - } return nil } @@ -550,6 +545,12 @@ func (t *MediaTrackReceiver) GetAudioLevel() (float64, bool) { return receiver.GetAudioLevel() } +func (t *MediaTrackReceiver) onDownTrackCreated(downTrack *sfu.DownTrack) { + if t.Kind() == livekit.TrackType_AUDIO { + downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) + } +} + // handles max loss for audio streams func (t *MediaTrackReceiver) handleMaxLossFeedback(_ *sfu.DownTrack, report *rtcp.ReceiverReport) { t.downFracLostLock.Lock() diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 865d0362a..bcfc9dc3a 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -27,14 +27,35 @@ const ( initialQualityUpdateWait = 10 * time.Second ) +var ( + errAlreadySubscribed = errors.New("already subscribed") + errNoTransceiver = errors.New("cannot subscribe without a transceiver in place") + errNoSender = errors.New("cannot subscribe without a sender in place") + errNotFound = errors.New("not found") +) + +type SubscribeRequestType int + +const ( + SubscribeRequestTypeRemove SubscribeRequestType = iota + SubscribeRequestTypeAdd +) + +type SubscribeRequest struct { + requestType SubscribeRequestType + sub types.LocalParticipant + wr *WrappedReceiver + willBeResumed bool +} + // MediaTrackSubscriptions manages subscriptions of a media track type MediaTrackSubscriptions struct { params MediaTrackSubscriptionsParams - subscribedTracksMu sync.RWMutex - subscribedTracks map[livekit.ParticipantID]types.SubscribedTrack - pendingSubscribeTracks sync.Map // livekit.ParticipantID -> bool - pendingClose map[livekit.ParticipantID]types.SubscribedTrack + subscribedTracksMu sync.RWMutex + subscribedTracks map[livekit.ParticipantID]types.SubscribedTrack + inProgress map[livekit.ParticipantID]bool + requestsQueue map[livekit.ParticipantID][]SubscribeRequest onNoSubscribers func() @@ -48,6 +69,8 @@ type MediaTrackSubscriptions struct { maxQualityTimer *time.Timer qualityNotifyOpQueue *utils.OpsQueue + + onDownTrackCreated func(downTrack *sfu.DownTrack) } type MediaTrackSubscriptionsParams struct { @@ -67,7 +90,8 @@ func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrac t := &MediaTrackSubscriptions{ params: params, subscribedTracks: make(map[livekit.ParticipantID]types.SubscribedTrack), - pendingClose: make(map[livekit.ParticipantID]types.SubscribedTrack), + inProgress: make(map[livekit.ParticipantID]bool), + requestsQueue: make(map[livekit.ParticipantID][]SubscribeRequest), maxSubscriberQuality: make(map[livekit.ParticipantID]*types.SubscribedCodecQuality), maxSubscriberNodeQuality: make(map[livekit.NodeID][]types.SubscribedCodecQuality), maxSubscribedQuality: make(map[string]livekit.VideoQuality), @@ -99,6 +123,10 @@ func (t *MediaTrackSubscriptions) OnNoSubscribers(f func()) { t.onNoSubscribers = f } +func (t *MediaTrackSubscriptions) OnDownTrackCreated(f func(downTrack *sfu.DownTrack)) { + t.onDownTrackCreated = f +} + func (t *MediaTrackSubscriptions) SetMuted(muted bool) { // update mute of all subscribed tracks for _, st := range t.getAllSubscribedTracks() { @@ -120,22 +148,72 @@ func (t *MediaTrackSubscriptions) AddCodec(mime string) { t.subscribedTracksMu.Unlock() } +func (t *MediaTrackSubscriptions) processRequestsQueue(subscriberID livekit.ParticipantID) { + t.subscribedTracksMu.Lock() + if t.inProgress[subscriberID] || len(t.requestsQueue[subscriberID]) == 0 { + t.subscribedTracksMu.Unlock() + return + } + + request := t.requestsQueue[subscriberID][0] + t.requestsQueue[subscriberID] = t.requestsQueue[subscriberID][1:] + if len(t.requestsQueue[subscriberID]) == 0 { + delete(t.requestsQueue, subscriberID) + } + + t.inProgress[subscriberID] = true + t.subscribedTracksMu.Unlock() + + switch request.requestType { + case SubscribeRequestTypeAdd: + err := t.addSubscriber(request.sub, request.wr) + if err != nil { + if err != errAlreadySubscribed { + t.params.Logger.Errorw("error adding subscriber", err, "subscriberID", subscriberID) + } + + // process pending request even if adding errors out + go t.clearInProgressAndProcessRequestsQueue(subscriberID) + } + + case SubscribeRequestTypeRemove: + err := t.removeSubscriber(subscriberID, request.willBeResumed) + if err != nil { + go t.clearInProgressAndProcessRequestsQueue(subscriberID) + } + + default: + t.params.Logger.Warnw("unknown request type", nil) + + // let the queue move forward + go t.clearInProgressAndProcessRequestsQueue(subscriberID) + } +} + // AddSubscriber subscribes sub to current mediaTrack -func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *WrappedReceiver) (*sfu.DownTrack, error) { +func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *WrappedReceiver) error { + subscriberID := sub.ID() + t.subscribedTracksMu.Lock() + t.requestsQueue[subscriberID] = append(t.requestsQueue[subscriberID], SubscribeRequest{ + requestType: SubscribeRequestTypeAdd, + sub: sub, + wr: wr, + }) + t.subscribedTracksMu.Unlock() + + t.processRequestsQueue(subscriberID) + return nil +} + +func (t *MediaTrackSubscriptions) addSubscriber(sub types.LocalParticipant, wr *WrappedReceiver) error { trackID := t.params.MediaTrack.ID() subscriberID := sub.ID() - if _, pending := t.pendingSubscribeTracks.LoadOrStore(subscriberID, true); pending { - return nil, nil - } else { - defer t.pendingSubscribeTracks.Delete(subscriberID) - } - // don't subscribe to the same track multiple times t.subscribedTracksMu.Lock() if _, ok := t.subscribedTracks[subscriberID]; ok { t.subscribedTracksMu.Unlock() - return nil, nil + return errAlreadySubscribed } t.subscribedTracksMu.Unlock() @@ -159,7 +237,11 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * LoggerWithTrack(sub.GetLogger(), trackID), ) if err != nil { - return nil, err + return err + } + + if t.onDownTrackCreated != nil { + t.onDownTrackCreated(downTrack) } subTrack := NewSubscribedTrack(SubscribedTrackParams{ @@ -236,7 +318,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // sender, err = sub.SubscriberPC().AddTrack(downTrack) if err != nil { - return nil, err + return err } // as there is no way to get transceiver from sender, search @@ -251,7 +333,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * Direction: webrtc.RTPTransceiverDirectionSendonly, }) if err != nil { - return nil, err + return err } sender = transceiver.Sender() @@ -259,11 +341,11 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * } if transceiver == nil { // cannot add, no transceiver - return nil, errors.New("cannot subscribe without a transceiver in place") + return errNoTransceiver } if sender == nil { // cannot add, no sender - return nil, errors.New("cannot subscribe without a sender in place") + return errNoSender } // wthether re-using or stopping remove transceiver from cache @@ -301,45 +383,67 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * if !replacedTrack { sub.Negotiate(false) } + + t.clearInProgressAndProcessRequestsQueue(subscriberID) }() - t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto(), - &livekit.ParticipantInfo{Sid: string(t.params.MediaTrack.PublisherID()), Identity: string(t.params.MediaTrack.PublisherIdentity())}) - return downTrack, nil + t.params.Telemetry.TrackSubscribed( + context.Background(), + subscriberID, + t.params.MediaTrack.ToProto(), + &livekit.ParticipantInfo{ + Sid: string(t.params.MediaTrack.PublisherID()), + Identity: string(t.params.MediaTrack.PublisherIdentity()), + }, + ) + return nil } // RemoveSubscriber removes participant from subscription // stop all forwarders to the client -func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool) { - subTrack := t.getSubscribedTrack(participantID) - +func (t *MediaTrackSubscriptions) RemoveSubscriber(subscriberID livekit.ParticipantID, willBeResumed bool) { t.subscribedTracksMu.Lock() - delete(t.subscribedTracks, participantID) - if subTrack != nil { - t.pendingClose[participantID] = subTrack - } + t.requestsQueue[subscriberID] = append(t.requestsQueue[subscriberID], SubscribeRequest{ + requestType: SubscribeRequestTypeRemove, + willBeResumed: willBeResumed, + }) t.subscribedTracksMu.Unlock() - if subTrack != nil { - t.closeSubscribedTrack(subTrack, willBeResumed) + t.processRequestsQueue(subscriberID) +} + +func (t *MediaTrackSubscriptions) removeSubscriber(subscriberID livekit.ParticipantID, willBeResumed bool) error { + t.params.Logger.Debugw("removing subscriber", "subscriberID", subscriberID, "willBeResumed", willBeResumed) + subTrack := t.getSubscribedTrack(subscriberID) + if subTrack == nil { + return errNotFound } + + t.closeSubscribedTrack(subTrack, willBeResumed) + return nil } func (t *MediaTrackSubscriptions) RemoveAllSubscribers(willBeResumed bool) { t.params.Logger.Debugw("removing all subscribers") + var subIDs []livekit.ParticipantID t.subscribedTracksMu.Lock() - subscribedTracks := t.getAllSubscribedTracksLocked() - t.subscribedTracks = make(map[livekit.ParticipantID]types.SubscribedTrack) + for _, subTrack := range t.getAllSubscribedTracksLocked() { + subscriberID := subTrack.SubscriberID() + t.requestsQueue[subscriberID] = append(t.requestsQueue[subscriberID], SubscribeRequest{ + requestType: SubscribeRequestTypeRemove, + willBeResumed: willBeResumed, + }) - for _, subTrack := range subscribedTracks { - t.pendingClose[subTrack.SubscriberID()] = subTrack + subIDs = append(subIDs, subscriberID) } t.subscribedTracksMu.Unlock() - for _, subTrack := range subscribedTracks { - t.closeSubscribedTrack(subTrack, willBeResumed) + for _, subID := range subIDs { + t.processRequestsQueue(subID) } + + t.maybeNotifyNoSubscribers() } func (t *MediaTrackSubscriptions) closeSubscribedTrack(subTrack types.SubscribedTrack, willBeResumed bool) { @@ -489,10 +593,10 @@ func (t *MediaTrackSubscriptions) OnSubscribedMaxQualityChange(f func(subscribed } func (t *MediaTrackSubscriptions) notifySubscriberMaxQuality(subscriberID livekit.ParticipantID, codec webrtc.RTPCodecCapability, quality livekit.VideoQuality) { - t.params.Logger.Debugw("notifying subscriber max quality", "subscriberID", subscriberID, "codec", codec, "quality", quality) if t.params.MediaTrack.Kind() != livekit.TrackType_VIDEO { return } + t.params.Logger.Debugw("notifying subscriber max quality", "subscriberID", subscriberID, "codec", codec, "quality", quality) if codec.MimeType == "" { t.params.Logger.Errorw("codec mime type is empty", nil) @@ -733,7 +837,7 @@ func (t *MediaTrackSubscriptions) maybeNotifyNoSubscribers() { } t.subscribedTracksMu.RLock() - empty := len(t.subscribedTracks) == 0 && len(t.pendingClose) == 0 + empty := len(t.subscribedTracks) == 0 && len(t.inProgress) == 0 && len(t.requestsQueue) == 0 t.subscribedTracksMu.RUnlock() if empty { @@ -750,11 +854,8 @@ func (t *MediaTrackSubscriptions) downTrackClosed( subscriberID := sub.ID() t.subscribedTracksMu.Lock() delete(t.subscribedTracks, subscriberID) - delete(t.pendingClose, subscriberID) t.subscribedTracksMu.Unlock() - t.maybeNotifyNoSubscribers() - if !willBeResumed { t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) @@ -794,4 +895,15 @@ func (t *MediaTrackSubscriptions) downTrackClosed( if !willBeResumed { sub.Negotiate(false) } + + t.clearInProgressAndProcessRequestsQueue(subscriberID) + t.maybeNotifyNoSubscribers() +} + +func (t *MediaTrackSubscriptions) clearInProgressAndProcessRequestsQueue(subscriberID livekit.ParticipantID) { + t.subscribedTracksMu.Lock() + delete(t.inProgress, subscriberID) + t.subscribedTracksMu.Unlock() + + t.processRequestsQueue(subscriberID) } diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 1b81c3bff..3b2d79408 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -124,7 +124,7 @@ func (u *UpTrackManager) AddSubscriber(sub types.LocalParticipant, params types. for _, track := range tracks { trackIDs = append(trackIDs, track.ID()) } - u.params.Logger.Debugw("subscribing new participant to tracks", + u.params.Logger.Debugw("subscribing participant to tracks", "subscriber", sub.Identity(), "subscriberID", sub.ID(), "trackIDs", trackIDs)