Prevent track subscriptions/adding receivers after close (#924)

* Prevent track subscriptions/adding receivers after close

With subscribe/unsubscribe queuing, a subscribe may be
attempted after a call to `RemoveAllSubscribers`.
So, renaming `RemoveAllSubscribers` to `InitiateClose`
and maintaining state that track is in the process of closing.

* Mime specific remove

* Remove unused error

* do not add receiver when closing
This commit is contained in:
Raja Subramanian
2022-08-17 13:07:59 +05:30
committed by GitHub
parent d9059f4f3b
commit f5627c3859
7 changed files with 156 additions and 86 deletions
-1
View File
@@ -225,7 +225,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
)
newWR.SetRTCPCh(t.params.RTCPChan)
newWR.OnCloseHandler(func() {
t.RemoveAllSubscribers(false)
t.MediaTrackReceiver.ClearReceiver(mime)
if t.MediaTrackReceiver.TryClose() {
if t.dynacastManager != nil {
+60 -4
View File
@@ -26,7 +26,8 @@ const (
)
var (
ErrNoReceiver = errors.New("cannot subscribe without a receiver in place")
ErrClosingOrClosed = errors.New("track is closing or closed")
ErrNoReceiver = errors.New("cannot subscribe without a receiver in place")
)
type simulcastReceiver struct {
@@ -64,6 +65,9 @@ type MediaTrackReceiver struct {
layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer
potentialCodecs []webrtc.RTPCodecParameters
pendingSubscribeOp map[livekit.ParticipantID]int
isMimeClosed map[string]bool
isClosing bool
isClosed bool
onSetupReceiver func(mime string)
onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport)
@@ -79,6 +83,7 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver
trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo),
layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer),
pendingSubscribeOp: make(map[livekit.ParticipantID]int),
isMimeClosed: make(map[string]bool),
}
t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{
@@ -126,10 +131,23 @@ func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime string)) {
func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority int, mid string) {
t.lock.Lock()
if t.isClosing || t.isClosed {
t.params.Logger.Warnw("cannot set up receiver on closing or closed track", nil)
t.lock.Unlock()
return
}
mimeType := receiver.Codec().MimeType
if t.isMimeClosed[mimeType] {
t.params.Logger.Warnw("cannot set up receiver on closing mime", nil, "mime", mimeType)
t.lock.Unlock()
return
}
// codec postion maybe taked by DumbReceiver, check and upgrade to WebRTCReceiver
var upgradeReceiver bool
for _, r := range t.receivers {
if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) {
if strings.EqualFold(r.Codec().MimeType, mimeType) {
if d, ok := r.TrackReceiver.(*DummyReceiver); ok {
d.Upgrade(receiver)
upgradeReceiver = true
@@ -229,19 +247,34 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string) {
if strings.EqualFold(receiver.Codec().MimeType, mime) {
t.receivers[idx] = t.receivers[len(t.receivers)-1]
t.receivers = t.receivers[:len(t.receivers)-1]
t.isMimeClosed[mime] = true
break
}
}
t.shadowReceiversLocked()
t.lock.Unlock()
t.removeAllSubscribersForMime(mime, false)
}
func (t *MediaTrackReceiver) ClearAllReceivers() {
t.lock.Lock()
var mimes []string
for _, receiver := range t.receivers {
mime := receiver.Codec().MimeType
t.isMimeClosed[mime] = true
mimes = append(mimes, mime)
}
t.receivers = t.receivers[:0]
t.receiversShadow = nil
t.lock.Unlock()
for _, mime := range mimes {
t.ClearReceiver(mime)
}
}
func (t *MediaTrackReceiver) OnMediaLossFeedback(f func(dt *sfu.DownTrack, rr *rtcp.ReceiverReport)) {
@@ -254,6 +287,11 @@ func (t *MediaTrackReceiver) OnVideoLayerUpdate(f func(layers []*livekit.VideoLa
func (t *MediaTrackReceiver) TryClose() bool {
t.lock.RLock()
if t.isClosed {
t.lock.RUnlock()
return true
}
if len(t.receiversShadow) > 0 {
t.lock.RUnlock()
return false
@@ -266,6 +304,7 @@ func (t *MediaTrackReceiver) TryClose() bool {
func (t *MediaTrackReceiver) Close() {
t.lock.RLock()
t.isClosed = true
onclose := t.onClose
t.lock.RUnlock()
@@ -387,6 +426,12 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) (err erro
}()
t.lock.RLock()
if t.isClosing || t.isClosed {
t.lock.RUnlock()
err = ErrClosingOrClosed
return
}
receivers := t.receiversShadow
potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs))
copy(potentialCodecs, t.potentialCodecs)
@@ -453,8 +498,19 @@ func (t *MediaTrackReceiver) removeSubscriber(subscriberID livekit.ParticipantID
return
}
func (t *MediaTrackReceiver) RemoveAllSubscribers(willBeResumed bool) {
t.params.Logger.Infow("removing all subscribers")
func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime string, willBeResumed bool) {
t.params.Logger.Infow("removing all subscribers", "mime", mime)
for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) {
t.RemoveSubscriber(subscriberID, willBeResumed)
}
}
func (t *MediaTrackReceiver) InitiateClose(willBeResumed bool) {
t.params.Logger.Infow("initiating close")
t.lock.Lock()
t.isClosing = true
t.lock.Unlock()
for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribers() {
t.RemoveSubscriber(subscriberID, willBeResumed)
}
+15
View File
@@ -335,6 +335,21 @@ func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID {
return subs
}
func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime string) []livekit.ParticipantID {
t.subscribedTracksMu.RLock()
defer t.subscribedTracksMu.RUnlock()
subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks))
for id, subTrack := range t.subscribedTracks {
if subTrack.DownTrack().Codec().MimeType != mime {
continue
}
subs = append(subs, id)
}
return subs
}
func (t *MediaTrackSubscriptions) GetNumSubscribers() int {
t.subscribedTracksMu.RLock()
defer t.subscribedTracksMu.RUnlock()
+1 -1
View File
@@ -340,7 +340,7 @@ type MediaTrack interface {
AddSubscriber(participant LocalParticipant) error
RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool)
IsSubscriber(subID livekit.ParticipantID) bool
RemoveAllSubscribers(willBeResumed bool)
InitiateClose(willBeResumed bool)
RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity
GetAllSubscribers() []livekit.ParticipantID
GetNumSubscribers() int
@@ -101,6 +101,11 @@ type FakeLocalMediaTrack struct {
iDReturnsOnCall map[int]struct {
result1 livekit.TrackID
}
InitiateCloseStub func(bool)
initiateCloseMutex sync.RWMutex
initiateCloseArgsForCall []struct {
arg1 bool
}
IsMutedStub func() bool
isMutedMutex sync.RWMutex
isMutedArgsForCall []struct {
@@ -204,11 +209,6 @@ type FakeLocalMediaTrack struct {
receiversReturnsOnCall map[int]struct {
result1 []sfu.TrackReceiver
}
RemoveAllSubscribersStub func(bool)
removeAllSubscribersMutex sync.RWMutex
removeAllSubscribersArgsForCall []struct {
arg1 bool
}
RemoveSubscriberStub func(livekit.ParticipantID, bool)
removeSubscriberMutex sync.RWMutex
removeSubscriberArgsForCall []struct {
@@ -763,6 +763,38 @@ func (fake *FakeLocalMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID)
}{result1}
}
func (fake *FakeLocalMediaTrack) InitiateClose(arg1 bool) {
fake.initiateCloseMutex.Lock()
fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct {
arg1 bool
}{arg1})
stub := fake.InitiateCloseStub
fake.recordInvocation("InitiateClose", []interface{}{arg1})
fake.initiateCloseMutex.Unlock()
if stub != nil {
fake.InitiateCloseStub(arg1)
}
}
func (fake *FakeLocalMediaTrack) InitiateCloseCallCount() int {
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
return len(fake.initiateCloseArgsForCall)
}
func (fake *FakeLocalMediaTrack) InitiateCloseCalls(stub func(bool)) {
fake.initiateCloseMutex.Lock()
defer fake.initiateCloseMutex.Unlock()
fake.InitiateCloseStub = stub
}
func (fake *FakeLocalMediaTrack) InitiateCloseArgsForCall(i int) bool {
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
argsForCall := fake.initiateCloseArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeLocalMediaTrack) IsMuted() bool {
fake.isMutedMutex.Lock()
ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)]
@@ -1319,38 +1351,6 @@ func (fake *FakeLocalMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.Tra
}{result1}
}
func (fake *FakeLocalMediaTrack) RemoveAllSubscribers(arg1 bool) {
fake.removeAllSubscribersMutex.Lock()
fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct {
arg1 bool
}{arg1})
stub := fake.RemoveAllSubscribersStub
fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1})
fake.removeAllSubscribersMutex.Unlock()
if stub != nil {
fake.RemoveAllSubscribersStub(arg1)
}
}
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCallCount() int {
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
return len(fake.removeAllSubscribersArgsForCall)
}
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) {
fake.removeAllSubscribersMutex.Lock()
defer fake.removeAllSubscribersMutex.Unlock()
fake.RemoveAllSubscribersStub = stub
}
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool {
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
argsForCall := fake.removeAllSubscribersArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeLocalMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) {
fake.removeSubscriberMutex.Lock()
fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct {
@@ -1755,6 +1755,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} {
defer fake.hasSdpCidMutex.RUnlock()
fake.iDMutex.RLock()
defer fake.iDMutex.RUnlock()
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
fake.isMutedMutex.RLock()
defer fake.isMutedMutex.RUnlock()
fake.isSimulcastMutex.RLock()
@@ -1777,8 +1779,6 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} {
defer fake.publisherVersionMutex.RUnlock()
fake.receiversMutex.RLock()
defer fake.receiversMutex.RUnlock()
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
fake.removeSubscriberMutex.RLock()
defer fake.removeSubscriberMutex.RUnlock()
fake.restartMutex.RLock()
+39 -39
View File
@@ -68,6 +68,11 @@ type FakeMediaTrack struct {
iDReturnsOnCall map[int]struct {
result1 livekit.TrackID
}
InitiateCloseStub func(bool)
initiateCloseMutex sync.RWMutex
initiateCloseArgsForCall []struct {
arg1 bool
}
IsMutedStub func() bool
isMutedMutex sync.RWMutex
isMutedArgsForCall []struct {
@@ -159,11 +164,6 @@ type FakeMediaTrack struct {
receiversReturnsOnCall map[int]struct {
result1 []sfu.TrackReceiver
}
RemoveAllSubscribersStub func(bool)
removeAllSubscribersMutex sync.RWMutex
removeAllSubscribersArgsForCall []struct {
arg1 bool
}
RemoveSubscriberStub func(livekit.ParticipantID, bool)
removeSubscriberMutex sync.RWMutex
removeSubscriberArgsForCall []struct {
@@ -529,6 +529,38 @@ func (fake *FakeMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) {
}{result1}
}
func (fake *FakeMediaTrack) InitiateClose(arg1 bool) {
fake.initiateCloseMutex.Lock()
fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct {
arg1 bool
}{arg1})
stub := fake.InitiateCloseStub
fake.recordInvocation("InitiateClose", []interface{}{arg1})
fake.initiateCloseMutex.Unlock()
if stub != nil {
fake.InitiateCloseStub(arg1)
}
}
func (fake *FakeMediaTrack) InitiateCloseCallCount() int {
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
return len(fake.initiateCloseArgsForCall)
}
func (fake *FakeMediaTrack) InitiateCloseCalls(stub func(bool)) {
fake.initiateCloseMutex.Lock()
defer fake.initiateCloseMutex.Unlock()
fake.InitiateCloseStub = stub
}
func (fake *FakeMediaTrack) InitiateCloseArgsForCall(i int) bool {
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
argsForCall := fake.initiateCloseArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeMediaTrack) IsMuted() bool {
fake.isMutedMutex.Lock()
ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)]
@@ -1014,38 +1046,6 @@ func (fake *FakeMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackRec
}{result1}
}
func (fake *FakeMediaTrack) RemoveAllSubscribers(arg1 bool) {
fake.removeAllSubscribersMutex.Lock()
fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct {
arg1 bool
}{arg1})
stub := fake.RemoveAllSubscribersStub
fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1})
fake.removeAllSubscribersMutex.Unlock()
if stub != nil {
fake.RemoveAllSubscribersStub(arg1)
}
}
func (fake *FakeMediaTrack) RemoveAllSubscribersCallCount() int {
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
return len(fake.removeAllSubscribersArgsForCall)
}
func (fake *FakeMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) {
fake.removeAllSubscribersMutex.Lock()
defer fake.removeAllSubscribersMutex.Unlock()
fake.RemoveAllSubscribersStub = stub
}
func (fake *FakeMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool {
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
argsForCall := fake.removeAllSubscribersArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) {
fake.removeSubscriberMutex.Lock()
fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct {
@@ -1335,6 +1335,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} {
defer fake.getQualityForDimensionMutex.RUnlock()
fake.iDMutex.RLock()
defer fake.iDMutex.RUnlock()
fake.initiateCloseMutex.RLock()
defer fake.initiateCloseMutex.RUnlock()
fake.isMutedMutex.RLock()
defer fake.isMutedMutex.RUnlock()
fake.isSimulcastMutex.RLock()
@@ -1353,8 +1355,6 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} {
defer fake.publisherVersionMutex.RUnlock()
fake.receiversMutex.RLock()
defer fake.receiversMutex.RUnlock()
fake.removeAllSubscribersMutex.RLock()
defer fake.removeAllSubscribersMutex.RUnlock()
fake.removeSubscriberMutex.RLock()
defer fake.removeSubscriberMutex.RUnlock()
fake.revokeDisallowedSubscribersMutex.RLock()
+2 -2
View File
@@ -68,7 +68,7 @@ func (u *UpTrackManager) Close(willBeResumed bool) {
// remove all subscribers
for _, t := range u.GetPublishedTracks() {
t.RemoveAllSubscribers(willBeResumed)
t.InitiateClose(willBeResumed)
}
if notify && u.onClose != nil {
@@ -317,7 +317,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) {
}
func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack, willBeResumed bool) {
track.RemoveAllSubscribers(willBeResumed)
track.InitiateClose(willBeResumed)
u.lock.Lock()
delete(u.publishedTracks, track.ID())
u.lock.Unlock()