diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 3bfc96639..b1f6bd9ee 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -225,7 +225,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra ) newWR.SetRTCPCh(t.params.RTCPChan) newWR.OnCloseHandler(func() { - t.RemoveAllSubscribers(false) t.MediaTrackReceiver.ClearReceiver(mime) if t.MediaTrackReceiver.TryClose() { if t.dynacastManager != nil { diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 56bfa2cff..ecb926bab 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -26,7 +26,8 @@ const ( ) var ( - ErrNoReceiver = errors.New("cannot subscribe without a receiver in place") + ErrClosingOrClosed = errors.New("track is closing or closed") + ErrNoReceiver = errors.New("cannot subscribe without a receiver in place") ) type simulcastReceiver struct { @@ -64,6 +65,9 @@ type MediaTrackReceiver struct { layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer potentialCodecs []webrtc.RTPCodecParameters pendingSubscribeOp map[livekit.ParticipantID]int + isMimeClosed map[string]bool + isClosing bool + isClosed bool onSetupReceiver func(mime string) onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) @@ -79,6 +83,7 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo), layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer), pendingSubscribeOp: make(map[livekit.ParticipantID]int), + isMimeClosed: make(map[string]bool), } t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ @@ -126,10 +131,23 @@ func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime string)) { func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority int, mid string) { t.lock.Lock() + if t.isClosing || t.isClosed { + t.params.Logger.Warnw("cannot set up receiver on closing or closed track", nil) + t.lock.Unlock() + return + } + + mimeType := receiver.Codec().MimeType + if t.isMimeClosed[mimeType] { + t.params.Logger.Warnw("cannot set up receiver on closing mime", nil, "mime", mimeType) + t.lock.Unlock() + return + } + // codec postion maybe taked by DumbReceiver, check and upgrade to WebRTCReceiver var upgradeReceiver bool for _, r := range t.receivers { - if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { + if strings.EqualFold(r.Codec().MimeType, mimeType) { if d, ok := r.TrackReceiver.(*DummyReceiver); ok { d.Upgrade(receiver) upgradeReceiver = true @@ -229,19 +247,34 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string) { if strings.EqualFold(receiver.Codec().MimeType, mime) { t.receivers[idx] = t.receivers[len(t.receivers)-1] t.receivers = t.receivers[:len(t.receivers)-1] + + t.isMimeClosed[mime] = true break } } t.shadowReceiversLocked() t.lock.Unlock() + + t.removeAllSubscribersForMime(mime, false) } func (t *MediaTrackReceiver) ClearAllReceivers() { t.lock.Lock() + var mimes []string + for _, receiver := range t.receivers { + mime := receiver.Codec().MimeType + t.isMimeClosed[mime] = true + mimes = append(mimes, mime) + } + t.receivers = t.receivers[:0] t.receiversShadow = nil t.lock.Unlock() + + for _, mime := range mimes { + t.ClearReceiver(mime) + } } func (t *MediaTrackReceiver) OnMediaLossFeedback(f func(dt *sfu.DownTrack, rr *rtcp.ReceiverReport)) { @@ -254,6 +287,11 @@ func (t *MediaTrackReceiver) OnVideoLayerUpdate(f func(layers []*livekit.VideoLa func (t *MediaTrackReceiver) TryClose() bool { t.lock.RLock() + if t.isClosed { + t.lock.RUnlock() + return true + } + if len(t.receiversShadow) > 0 { t.lock.RUnlock() return false @@ -266,6 +304,7 @@ func (t *MediaTrackReceiver) TryClose() bool { func (t *MediaTrackReceiver) Close() { t.lock.RLock() + t.isClosed = true onclose := t.onClose t.lock.RUnlock() @@ -387,6 +426,12 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) (err erro }() t.lock.RLock() + if t.isClosing || t.isClosed { + t.lock.RUnlock() + err = ErrClosingOrClosed + return + } + receivers := t.receiversShadow potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) copy(potentialCodecs, t.potentialCodecs) @@ -453,8 +498,19 @@ func (t *MediaTrackReceiver) removeSubscriber(subscriberID livekit.ParticipantID return } -func (t *MediaTrackReceiver) RemoveAllSubscribers(willBeResumed bool) { - t.params.Logger.Infow("removing all subscribers") +func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime string, willBeResumed bool) { + t.params.Logger.Infow("removing all subscribers", "mime", mime) + for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) { + t.RemoveSubscriber(subscriberID, willBeResumed) + } +} + +func (t *MediaTrackReceiver) InitiateClose(willBeResumed bool) { + t.params.Logger.Infow("initiating close") + t.lock.Lock() + t.isClosing = true + t.lock.Unlock() + for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribers() { t.RemoveSubscriber(subscriberID, willBeResumed) } diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 84b29ea15..1e18b7fe4 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -335,6 +335,21 @@ func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID { return subs } +func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime string) []livekit.ParticipantID { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) + for id, subTrack := range t.subscribedTracks { + if subTrack.DownTrack().Codec().MimeType != mime { + continue + } + + subs = append(subs, id) + } + return subs +} + func (t *MediaTrackSubscriptions) GetNumSubscribers() int { t.subscribedTracksMu.RLock() defer t.subscribedTracksMu.RUnlock() diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 57060f451..28fb88c68 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -340,7 +340,7 @@ type MediaTrack interface { AddSubscriber(participant LocalParticipant) error RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool) IsSubscriber(subID livekit.ParticipantID) bool - RemoveAllSubscribers(willBeResumed bool) + InitiateClose(willBeResumed bool) RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity GetAllSubscribers() []livekit.ParticipantID GetNumSubscribers() int diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 04a1999de..a2c416dd1 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -101,6 +101,11 @@ type FakeLocalMediaTrack struct { iDReturnsOnCall map[int]struct { result1 livekit.TrackID } + InitiateCloseStub func(bool) + initiateCloseMutex sync.RWMutex + initiateCloseArgsForCall []struct { + arg1 bool + } IsMutedStub func() bool isMutedMutex sync.RWMutex isMutedArgsForCall []struct { @@ -204,11 +209,6 @@ type FakeLocalMediaTrack struct { receiversReturnsOnCall map[int]struct { result1 []sfu.TrackReceiver } - RemoveAllSubscribersStub func(bool) - removeAllSubscribersMutex sync.RWMutex - removeAllSubscribersArgsForCall []struct { - arg1 bool - } RemoveSubscriberStub func(livekit.ParticipantID, bool) removeSubscriberMutex sync.RWMutex removeSubscriberArgsForCall []struct { @@ -763,6 +763,38 @@ func (fake *FakeLocalMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) }{result1} } +func (fake *FakeLocalMediaTrack) InitiateClose(arg1 bool) { + fake.initiateCloseMutex.Lock() + fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.InitiateCloseStub + fake.recordInvocation("InitiateClose", []interface{}{arg1}) + fake.initiateCloseMutex.Unlock() + if stub != nil { + fake.InitiateCloseStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) InitiateCloseCallCount() int { + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() + return len(fake.initiateCloseArgsForCall) +} + +func (fake *FakeLocalMediaTrack) InitiateCloseCalls(stub func(bool)) { + fake.initiateCloseMutex.Lock() + defer fake.initiateCloseMutex.Unlock() + fake.InitiateCloseStub = stub +} + +func (fake *FakeLocalMediaTrack) InitiateCloseArgsForCall(i int) bool { + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() + argsForCall := fake.initiateCloseArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalMediaTrack) IsMuted() bool { fake.isMutedMutex.Lock() ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)] @@ -1319,38 +1351,6 @@ func (fake *FakeLocalMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.Tra }{result1} } -func (fake *FakeLocalMediaTrack) RemoveAllSubscribers(arg1 bool) { - fake.removeAllSubscribersMutex.Lock() - fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct { - arg1 bool - }{arg1}) - stub := fake.RemoveAllSubscribersStub - fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1}) - fake.removeAllSubscribersMutex.Unlock() - if stub != nil { - fake.RemoveAllSubscribersStub(arg1) - } -} - -func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCallCount() int { - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() - return len(fake.removeAllSubscribersArgsForCall) -} - -func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) { - fake.removeAllSubscribersMutex.Lock() - defer fake.removeAllSubscribersMutex.Unlock() - fake.RemoveAllSubscribersStub = stub -} - -func (fake *FakeLocalMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool { - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() - argsForCall := fake.removeAllSubscribersArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeLocalMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { fake.removeSubscriberMutex.Lock() fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { @@ -1755,6 +1755,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.hasSdpCidMutex.RUnlock() fake.iDMutex.RLock() defer fake.iDMutex.RUnlock() + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() fake.isMutedMutex.RLock() defer fake.isMutedMutex.RUnlock() fake.isSimulcastMutex.RLock() @@ -1777,8 +1779,6 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.publisherVersionMutex.RUnlock() fake.receiversMutex.RLock() defer fake.receiversMutex.RUnlock() - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() fake.restartMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 04ce00dc8..9eb7b525d 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -68,6 +68,11 @@ type FakeMediaTrack struct { iDReturnsOnCall map[int]struct { result1 livekit.TrackID } + InitiateCloseStub func(bool) + initiateCloseMutex sync.RWMutex + initiateCloseArgsForCall []struct { + arg1 bool + } IsMutedStub func() bool isMutedMutex sync.RWMutex isMutedArgsForCall []struct { @@ -159,11 +164,6 @@ type FakeMediaTrack struct { receiversReturnsOnCall map[int]struct { result1 []sfu.TrackReceiver } - RemoveAllSubscribersStub func(bool) - removeAllSubscribersMutex sync.RWMutex - removeAllSubscribersArgsForCall []struct { - arg1 bool - } RemoveSubscriberStub func(livekit.ParticipantID, bool) removeSubscriberMutex sync.RWMutex removeSubscriberArgsForCall []struct { @@ -529,6 +529,38 @@ func (fake *FakeMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) { }{result1} } +func (fake *FakeMediaTrack) InitiateClose(arg1 bool) { + fake.initiateCloseMutex.Lock() + fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.InitiateCloseStub + fake.recordInvocation("InitiateClose", []interface{}{arg1}) + fake.initiateCloseMutex.Unlock() + if stub != nil { + fake.InitiateCloseStub(arg1) + } +} + +func (fake *FakeMediaTrack) InitiateCloseCallCount() int { + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() + return len(fake.initiateCloseArgsForCall) +} + +func (fake *FakeMediaTrack) InitiateCloseCalls(stub func(bool)) { + fake.initiateCloseMutex.Lock() + defer fake.initiateCloseMutex.Unlock() + fake.InitiateCloseStub = stub +} + +func (fake *FakeMediaTrack) InitiateCloseArgsForCall(i int) bool { + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() + argsForCall := fake.initiateCloseArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeMediaTrack) IsMuted() bool { fake.isMutedMutex.Lock() ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)] @@ -1014,38 +1046,6 @@ func (fake *FakeMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackRec }{result1} } -func (fake *FakeMediaTrack) RemoveAllSubscribers(arg1 bool) { - fake.removeAllSubscribersMutex.Lock() - fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct { - arg1 bool - }{arg1}) - stub := fake.RemoveAllSubscribersStub - fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1}) - fake.removeAllSubscribersMutex.Unlock() - if stub != nil { - fake.RemoveAllSubscribersStub(arg1) - } -} - -func (fake *FakeMediaTrack) RemoveAllSubscribersCallCount() int { - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() - return len(fake.removeAllSubscribersArgsForCall) -} - -func (fake *FakeMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) { - fake.removeAllSubscribersMutex.Lock() - defer fake.removeAllSubscribersMutex.Unlock() - fake.RemoveAllSubscribersStub = stub -} - -func (fake *FakeMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool { - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() - argsForCall := fake.removeAllSubscribersArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { fake.removeSubscriberMutex.Lock() fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { @@ -1335,6 +1335,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.getQualityForDimensionMutex.RUnlock() fake.iDMutex.RLock() defer fake.iDMutex.RUnlock() + fake.initiateCloseMutex.RLock() + defer fake.initiateCloseMutex.RUnlock() fake.isMutedMutex.RLock() defer fake.isMutedMutex.RUnlock() fake.isSimulcastMutex.RLock() @@ -1353,8 +1355,6 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.publisherVersionMutex.RUnlock() fake.receiversMutex.RLock() defer fake.receiversMutex.RUnlock() - fake.removeAllSubscribersMutex.RLock() - defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() fake.revokeDisallowedSubscribersMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index eebac7819..e8fc1b02a 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -68,7 +68,7 @@ func (u *UpTrackManager) Close(willBeResumed bool) { // remove all subscribers for _, t := range u.GetPublishedTracks() { - t.RemoveAllSubscribers(willBeResumed) + t.InitiateClose(willBeResumed) } if notify && u.onClose != nil { @@ -317,7 +317,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { } func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack, willBeResumed bool) { - track.RemoveAllSubscribers(willBeResumed) + track.InitiateClose(willBeResumed) u.lock.Lock() delete(u.publishedTracks, track.ID()) u.lock.Unlock()