diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 198469408..1df02b419 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -178,7 +178,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra ) newWR.SetRTCPCh(t.params.RTCPChan) newWR.OnCloseHandler(func() { - t.RemoveAllSubscribers() + t.RemoveAllSubscribers(false) t.MediaTrackReceiver.ClearReceiver(mime) if t.MediaTrackReceiver.TryClose() { t.params.Telemetry.TrackUnpublished( diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 4b9bff75a..c425695cd 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -240,8 +240,8 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * go sub.UpdateRTT(rtt) }) - downTrack.OnCloseHandler(func() { - go t.downTrackClosed(sub, subTrack, sender) + downTrack.OnCloseHandler(func(willBeResumed bool) { + go t.downTrackClosed(sub, subTrack, willBeResumed, sender) }) t.subscribedTracksMu.Lock() @@ -261,7 +261,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // RemoveSubscriber removes participant from subscription // stop all forwarders to the client -func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.ParticipantID, resume bool) { +func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool) { subTrack := t.getSubscribedTrack(participantID) t.subscribedTracksMu.Lock() @@ -272,11 +272,11 @@ func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.Partici t.subscribedTracksMu.Unlock() if subTrack != nil { - subTrack.DownTrack().CloseWithFlush(!resume) + subTrack.DownTrack().CloseWithFlush(!willBeResumed) } } -func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { +func (t *MediaTrackSubscriptions) RemoveAllSubscribers(willBeResumed bool) { t.params.Logger.Debugw("removing all subscribers") t.subscribedTracksMu.Lock() @@ -289,7 +289,7 @@ func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { t.subscribedTracksMu.Unlock() for _, subTrack := range subscribedTracks { - subTrack.DownTrack().Close() + subTrack.DownTrack().CloseWithFlush(!willBeResumed) } } @@ -678,6 +678,7 @@ func (t *MediaTrackSubscriptions) maybeNotifyNoSubscribers() { func (t *MediaTrackSubscriptions) downTrackClosed( sub types.LocalParticipant, subTrack types.SubscribedTrack, + willBeResumed bool, sender *webrtc.RTPSender, ) { subscriberID := sub.ID() @@ -688,39 +689,43 @@ func (t *MediaTrackSubscriptions) downTrackClosed( t.maybeNotifyNoSubscribers() - t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) + if !willBeResumed { + t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) - // ignore if the subscribing sub is not connected - if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { - return - } - - // if the source has been terminated, we'll need to terminate all the subscribed tracks - // however, if the dest sub has disconnected, then we can skip - if sender == nil { - return - } - t.params.Logger.Debugw("removing peerconnection track", - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - "kind", t.params.MediaTrack.Kind(), - ) - if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { - if err == webrtc.ErrConnectionClosed { - // sub closing, can skip removing subscribedtracks + // ignore if the subscribing sub is not connected + if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { return } - if _, ok := err.(*rtcerr.InvalidStateError); !ok { - // most of these are safe to ignore, since the track state might have already - // been set to Inactive - t.params.Logger.Debugw("could not remove remoteTrack from forwarder", - "error", err, - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - ) + + // if the source has been terminated, we'll need to terminate all the subscribed tracks + // however, if the dest sub has disconnected, then we can skip + if sender == nil { + return + } + t.params.Logger.Debugw("removing peerconnection track", + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + "kind", t.params.MediaTrack.Kind(), + ) + if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { + if err == webrtc.ErrConnectionClosed { + // sub closing, can skip removing subscribedtracks + return + } + if _, ok := err.(*rtcerr.InvalidStateError); !ok { + // most of these are safe to ignore, since the track state might have already + // been set to Inactive + t.params.Logger.Debugw("could not remove remoteTrack from forwarder", + "error", err, + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + ) + } } } sub.RemoveSubscribedTrack(subTrack) - sub.Negotiate(false) + if !willBeResumed { + sub.Negotiate(false) + } } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index cd3edcfb6..75416faf6 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -664,7 +664,7 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea }) } - p.UpTrackManager.Close() + p.UpTrackManager.Close(!sendLeave) p.pendingTracksLock.Lock() p.pendingTracks = make(map[string]*pendingTrackInfo) diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 76f91ed62..0dc1895b6 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -288,9 +288,9 @@ type MediaTrack interface { // subscribers AddSubscriber(participant LocalParticipant) error - RemoveSubscriber(participantID livekit.ParticipantID, resume bool) + RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool) IsSubscriber(subID livekit.ParticipantID) bool - RemoveAllSubscribers() + RemoveAllSubscribers(willBeResumed bool) RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity GetAllSubscribers() []livekit.ParticipantID diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index b040eeb1a..ad5a74685 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -184,9 +184,10 @@ type FakeLocalMediaTrack struct { receiversReturnsOnCall map[int]struct { result1 []sfu.TrackReceiver } - RemoveAllSubscribersStub func() + RemoveAllSubscribersStub func(bool) removeAllSubscribersMutex sync.RWMutex removeAllSubscribersArgsForCall []struct { + arg1 bool } RemoveSubscriberStub func(livekit.ParticipantID, bool) removeSubscriberMutex sync.RWMutex @@ -1192,15 +1193,16 @@ func (fake *FakeLocalMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.Tra }{result1} } -func (fake *FakeLocalMediaTrack) RemoveAllSubscribers() { +func (fake *FakeLocalMediaTrack) RemoveAllSubscribers(arg1 bool) { fake.removeAllSubscribersMutex.Lock() fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct { - }{}) + arg1 bool + }{arg1}) stub := fake.RemoveAllSubscribersStub - fake.recordInvocation("RemoveAllSubscribers", []interface{}{}) + fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1}) fake.removeAllSubscribersMutex.Unlock() if stub != nil { - fake.RemoveAllSubscribersStub() + fake.RemoveAllSubscribersStub(arg1) } } @@ -1210,12 +1212,19 @@ func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCallCount() int { return len(fake.removeAllSubscribersArgsForCall) } -func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCalls(stub func()) { +func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) { fake.removeAllSubscribersMutex.Lock() defer fake.removeAllSubscribersMutex.Unlock() fake.RemoveAllSubscribersStub = stub } +func (fake *FakeLocalMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool { + fake.removeAllSubscribersMutex.RLock() + defer fake.removeAllSubscribersMutex.RUnlock() + argsForCall := fake.removeAllSubscribersArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { fake.removeSubscriberMutex.Lock() fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 12b5ef70e..770674215 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -151,9 +151,10 @@ type FakeMediaTrack struct { receiversReturnsOnCall map[int]struct { result1 []sfu.TrackReceiver } - RemoveAllSubscribersStub func() + RemoveAllSubscribersStub func(bool) removeAllSubscribersMutex sync.RWMutex removeAllSubscribersArgsForCall []struct { + arg1 bool } RemoveSubscriberStub func(livekit.ParticipantID, bool) removeSubscriberMutex sync.RWMutex @@ -974,15 +975,16 @@ func (fake *FakeMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackRec }{result1} } -func (fake *FakeMediaTrack) RemoveAllSubscribers() { +func (fake *FakeMediaTrack) RemoveAllSubscribers(arg1 bool) { fake.removeAllSubscribersMutex.Lock() fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct { - }{}) + arg1 bool + }{arg1}) stub := fake.RemoveAllSubscribersStub - fake.recordInvocation("RemoveAllSubscribers", []interface{}{}) + fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1}) fake.removeAllSubscribersMutex.Unlock() if stub != nil { - fake.RemoveAllSubscribersStub() + fake.RemoveAllSubscribersStub(arg1) } } @@ -992,12 +994,19 @@ func (fake *FakeMediaTrack) RemoveAllSubscribersCallCount() int { return len(fake.removeAllSubscribersArgsForCall) } -func (fake *FakeMediaTrack) RemoveAllSubscribersCalls(stub func()) { +func (fake *FakeMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) { fake.removeAllSubscribersMutex.Lock() defer fake.removeAllSubscribersMutex.Unlock() fake.RemoveAllSubscribersStub = stub } +func (fake *FakeMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool { + fake.removeAllSubscribersMutex.RLock() + defer fake.removeAllSubscribersMutex.RUnlock() + argsForCall := fake.removeAllSubscribersArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { fake.removeSubscriberMutex.Lock() fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index a6f7b9f00..ad6830027 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -56,7 +56,7 @@ func (u *UpTrackManager) Restart() { } } -func (u *UpTrackManager) Close() { +func (u *UpTrackManager) Close(willBeResumed bool) { u.lock.Lock() u.closed = true notify := len(u.publishedTracks) == 0 @@ -64,7 +64,7 @@ func (u *UpTrackManager) Close() { // remove all subscribers for _, t := range u.GetPublishedTracks() { - t.RemoveAllSubscribers() + t.RemoveAllSubscribers(willBeResumed) } if notify && u.onClose != nil { @@ -297,7 +297,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { } func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack) { - track.RemoveAllSubscribers() + track.RemoveAllSubscribers(false) u.lock.Lock() delete(u.publishedTracks, track.ID()) u.lock.Unlock() diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 4b876e23d..0a488319a 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -131,7 +131,7 @@ type DownTrack struct { receiver TrackReceiver transceiver *webrtc.RTPTransceiver writeStream webrtc.TrackLocalWriter - onCloseHandler func() + onCloseHandler func(willBeResumed bool) onBind func() receiverReportListeners []ReceiverReportListener listenerLock sync.RWMutex @@ -674,7 +674,7 @@ func (d *DownTrack) CloseWithFlush(flush bool) { } if d.onCloseHandler != nil { - d.onCloseHandler() + d.onCloseHandler(!flush) } d.stopKeyFrameRequester() @@ -737,7 +737,7 @@ func (d *DownTrack) UpTrackBitrateAvailabilityChange() { } // OnCloseHandler method to be called on remote tracked removed -func (d *DownTrack) OnCloseHandler(fn func()) { +func (d *DownTrack) OnCloseHandler(fn func(willBeResumed bool)) { d.onCloseHandler = fn }