mirror of
https://github.com/livekit/livekit.git
synced 2026-04-27 02:05:25 +00:00
Prevent track subscriptions/adding receivers after close (#924)
* Prevent track subscriptions/adding receivers after close With subscribe/unsubscribe queuing, a subscribe may be attempted after a call to `RemoveAllSubscribers`. So, renaming `RemoveAllSubscribers` to `InitiateClose` and maintaining state that track is in the process of closing. * Mime specific remove * Remove unused error * do not add receiver when closing
This commit is contained in:
@@ -225,7 +225,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
|
||||
)
|
||||
newWR.SetRTCPCh(t.params.RTCPChan)
|
||||
newWR.OnCloseHandler(func() {
|
||||
t.RemoveAllSubscribers(false)
|
||||
t.MediaTrackReceiver.ClearReceiver(mime)
|
||||
if t.MediaTrackReceiver.TryClose() {
|
||||
if t.dynacastManager != nil {
|
||||
|
||||
@@ -26,7 +26,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoReceiver = errors.New("cannot subscribe without a receiver in place")
|
||||
ErrClosingOrClosed = errors.New("track is closing or closed")
|
||||
ErrNoReceiver = errors.New("cannot subscribe without a receiver in place")
|
||||
)
|
||||
|
||||
type simulcastReceiver struct {
|
||||
@@ -64,6 +65,9 @@ type MediaTrackReceiver struct {
|
||||
layerDimensions map[livekit.VideoQuality]*livekit.VideoLayer
|
||||
potentialCodecs []webrtc.RTPCodecParameters
|
||||
pendingSubscribeOp map[livekit.ParticipantID]int
|
||||
isMimeClosed map[string]bool
|
||||
isClosing bool
|
||||
isClosed bool
|
||||
|
||||
onSetupReceiver func(mime string)
|
||||
onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport)
|
||||
@@ -79,6 +83,7 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver
|
||||
trackInfo: proto.Clone(params.TrackInfo).(*livekit.TrackInfo),
|
||||
layerDimensions: make(map[livekit.VideoQuality]*livekit.VideoLayer),
|
||||
pendingSubscribeOp: make(map[livekit.ParticipantID]int),
|
||||
isMimeClosed: make(map[string]bool),
|
||||
}
|
||||
|
||||
t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{
|
||||
@@ -126,10 +131,23 @@ func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime string)) {
|
||||
func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority int, mid string) {
|
||||
t.lock.Lock()
|
||||
|
||||
if t.isClosing || t.isClosed {
|
||||
t.params.Logger.Warnw("cannot set up receiver on closing or closed track", nil)
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := receiver.Codec().MimeType
|
||||
if t.isMimeClosed[mimeType] {
|
||||
t.params.Logger.Warnw("cannot set up receiver on closing mime", nil, "mime", mimeType)
|
||||
t.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// codec postion maybe taked by DumbReceiver, check and upgrade to WebRTCReceiver
|
||||
var upgradeReceiver bool
|
||||
for _, r := range t.receivers {
|
||||
if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) {
|
||||
if strings.EqualFold(r.Codec().MimeType, mimeType) {
|
||||
if d, ok := r.TrackReceiver.(*DummyReceiver); ok {
|
||||
d.Upgrade(receiver)
|
||||
upgradeReceiver = true
|
||||
@@ -229,19 +247,34 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string) {
|
||||
if strings.EqualFold(receiver.Codec().MimeType, mime) {
|
||||
t.receivers[idx] = t.receivers[len(t.receivers)-1]
|
||||
t.receivers = t.receivers[:len(t.receivers)-1]
|
||||
|
||||
t.isMimeClosed[mime] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
t.shadowReceiversLocked()
|
||||
t.lock.Unlock()
|
||||
|
||||
t.removeAllSubscribersForMime(mime, false)
|
||||
}
|
||||
|
||||
func (t *MediaTrackReceiver) ClearAllReceivers() {
|
||||
t.lock.Lock()
|
||||
var mimes []string
|
||||
for _, receiver := range t.receivers {
|
||||
mime := receiver.Codec().MimeType
|
||||
t.isMimeClosed[mime] = true
|
||||
mimes = append(mimes, mime)
|
||||
}
|
||||
|
||||
t.receivers = t.receivers[:0]
|
||||
t.receiversShadow = nil
|
||||
t.lock.Unlock()
|
||||
|
||||
for _, mime := range mimes {
|
||||
t.ClearReceiver(mime)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MediaTrackReceiver) OnMediaLossFeedback(f func(dt *sfu.DownTrack, rr *rtcp.ReceiverReport)) {
|
||||
@@ -254,6 +287,11 @@ func (t *MediaTrackReceiver) OnVideoLayerUpdate(f func(layers []*livekit.VideoLa
|
||||
|
||||
func (t *MediaTrackReceiver) TryClose() bool {
|
||||
t.lock.RLock()
|
||||
if t.isClosed {
|
||||
t.lock.RUnlock()
|
||||
return true
|
||||
}
|
||||
|
||||
if len(t.receiversShadow) > 0 {
|
||||
t.lock.RUnlock()
|
||||
return false
|
||||
@@ -266,6 +304,7 @@ func (t *MediaTrackReceiver) TryClose() bool {
|
||||
|
||||
func (t *MediaTrackReceiver) Close() {
|
||||
t.lock.RLock()
|
||||
t.isClosed = true
|
||||
onclose := t.onClose
|
||||
t.lock.RUnlock()
|
||||
|
||||
@@ -387,6 +426,12 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) (err erro
|
||||
}()
|
||||
|
||||
t.lock.RLock()
|
||||
if t.isClosing || t.isClosed {
|
||||
t.lock.RUnlock()
|
||||
err = ErrClosingOrClosed
|
||||
return
|
||||
}
|
||||
|
||||
receivers := t.receiversShadow
|
||||
potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs))
|
||||
copy(potentialCodecs, t.potentialCodecs)
|
||||
@@ -453,8 +498,19 @@ func (t *MediaTrackReceiver) removeSubscriber(subscriberID livekit.ParticipantID
|
||||
return
|
||||
}
|
||||
|
||||
func (t *MediaTrackReceiver) RemoveAllSubscribers(willBeResumed bool) {
|
||||
t.params.Logger.Infow("removing all subscribers")
|
||||
func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime string, willBeResumed bool) {
|
||||
t.params.Logger.Infow("removing all subscribers", "mime", mime)
|
||||
for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) {
|
||||
t.RemoveSubscriber(subscriberID, willBeResumed)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MediaTrackReceiver) InitiateClose(willBeResumed bool) {
|
||||
t.params.Logger.Infow("initiating close")
|
||||
t.lock.Lock()
|
||||
t.isClosing = true
|
||||
t.lock.Unlock()
|
||||
|
||||
for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribers() {
|
||||
t.RemoveSubscriber(subscriberID, willBeResumed)
|
||||
}
|
||||
|
||||
@@ -335,6 +335,21 @@ func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID {
|
||||
return subs
|
||||
}
|
||||
|
||||
func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime string) []livekit.ParticipantID {
|
||||
t.subscribedTracksMu.RLock()
|
||||
defer t.subscribedTracksMu.RUnlock()
|
||||
|
||||
subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks))
|
||||
for id, subTrack := range t.subscribedTracks {
|
||||
if subTrack.DownTrack().Codec().MimeType != mime {
|
||||
continue
|
||||
}
|
||||
|
||||
subs = append(subs, id)
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
func (t *MediaTrackSubscriptions) GetNumSubscribers() int {
|
||||
t.subscribedTracksMu.RLock()
|
||||
defer t.subscribedTracksMu.RUnlock()
|
||||
|
||||
@@ -340,7 +340,7 @@ type MediaTrack interface {
|
||||
AddSubscriber(participant LocalParticipant) error
|
||||
RemoveSubscriber(participantID livekit.ParticipantID, willBeResumed bool)
|
||||
IsSubscriber(subID livekit.ParticipantID) bool
|
||||
RemoveAllSubscribers(willBeResumed bool)
|
||||
InitiateClose(willBeResumed bool)
|
||||
RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity
|
||||
GetAllSubscribers() []livekit.ParticipantID
|
||||
GetNumSubscribers() int
|
||||
|
||||
@@ -101,6 +101,11 @@ type FakeLocalMediaTrack struct {
|
||||
iDReturnsOnCall map[int]struct {
|
||||
result1 livekit.TrackID
|
||||
}
|
||||
InitiateCloseStub func(bool)
|
||||
initiateCloseMutex sync.RWMutex
|
||||
initiateCloseArgsForCall []struct {
|
||||
arg1 bool
|
||||
}
|
||||
IsMutedStub func() bool
|
||||
isMutedMutex sync.RWMutex
|
||||
isMutedArgsForCall []struct {
|
||||
@@ -204,11 +209,6 @@ type FakeLocalMediaTrack struct {
|
||||
receiversReturnsOnCall map[int]struct {
|
||||
result1 []sfu.TrackReceiver
|
||||
}
|
||||
RemoveAllSubscribersStub func(bool)
|
||||
removeAllSubscribersMutex sync.RWMutex
|
||||
removeAllSubscribersArgsForCall []struct {
|
||||
arg1 bool
|
||||
}
|
||||
RemoveSubscriberStub func(livekit.ParticipantID, bool)
|
||||
removeSubscriberMutex sync.RWMutex
|
||||
removeSubscriberArgsForCall []struct {
|
||||
@@ -763,6 +763,38 @@ func (fake *FakeLocalMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID)
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) InitiateClose(arg1 bool) {
|
||||
fake.initiateCloseMutex.Lock()
|
||||
fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct {
|
||||
arg1 bool
|
||||
}{arg1})
|
||||
stub := fake.InitiateCloseStub
|
||||
fake.recordInvocation("InitiateClose", []interface{}{arg1})
|
||||
fake.initiateCloseMutex.Unlock()
|
||||
if stub != nil {
|
||||
fake.InitiateCloseStub(arg1)
|
||||
}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) InitiateCloseCallCount() int {
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
return len(fake.initiateCloseArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) InitiateCloseCalls(stub func(bool)) {
|
||||
fake.initiateCloseMutex.Lock()
|
||||
defer fake.initiateCloseMutex.Unlock()
|
||||
fake.InitiateCloseStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) InitiateCloseArgsForCall(i int) bool {
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
argsForCall := fake.initiateCloseArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) IsMuted() bool {
|
||||
fake.isMutedMutex.Lock()
|
||||
ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)]
|
||||
@@ -1319,38 +1351,6 @@ func (fake *FakeLocalMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.Tra
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) RemoveAllSubscribers(arg1 bool) {
|
||||
fake.removeAllSubscribersMutex.Lock()
|
||||
fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct {
|
||||
arg1 bool
|
||||
}{arg1})
|
||||
stub := fake.RemoveAllSubscribersStub
|
||||
fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1})
|
||||
fake.removeAllSubscribersMutex.Unlock()
|
||||
if stub != nil {
|
||||
fake.RemoveAllSubscribersStub(arg1)
|
||||
}
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCallCount() int {
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
return len(fake.removeAllSubscribersArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) {
|
||||
fake.removeAllSubscribersMutex.Lock()
|
||||
defer fake.removeAllSubscribersMutex.Unlock()
|
||||
fake.RemoveAllSubscribersStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool {
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
argsForCall := fake.removeAllSubscribersArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeLocalMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) {
|
||||
fake.removeSubscriberMutex.Lock()
|
||||
fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct {
|
||||
@@ -1755,6 +1755,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} {
|
||||
defer fake.hasSdpCidMutex.RUnlock()
|
||||
fake.iDMutex.RLock()
|
||||
defer fake.iDMutex.RUnlock()
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
fake.isMutedMutex.RLock()
|
||||
defer fake.isMutedMutex.RUnlock()
|
||||
fake.isSimulcastMutex.RLock()
|
||||
@@ -1777,8 +1779,6 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} {
|
||||
defer fake.publisherVersionMutex.RUnlock()
|
||||
fake.receiversMutex.RLock()
|
||||
defer fake.receiversMutex.RUnlock()
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
fake.removeSubscriberMutex.RLock()
|
||||
defer fake.removeSubscriberMutex.RUnlock()
|
||||
fake.restartMutex.RLock()
|
||||
|
||||
@@ -68,6 +68,11 @@ type FakeMediaTrack struct {
|
||||
iDReturnsOnCall map[int]struct {
|
||||
result1 livekit.TrackID
|
||||
}
|
||||
InitiateCloseStub func(bool)
|
||||
initiateCloseMutex sync.RWMutex
|
||||
initiateCloseArgsForCall []struct {
|
||||
arg1 bool
|
||||
}
|
||||
IsMutedStub func() bool
|
||||
isMutedMutex sync.RWMutex
|
||||
isMutedArgsForCall []struct {
|
||||
@@ -159,11 +164,6 @@ type FakeMediaTrack struct {
|
||||
receiversReturnsOnCall map[int]struct {
|
||||
result1 []sfu.TrackReceiver
|
||||
}
|
||||
RemoveAllSubscribersStub func(bool)
|
||||
removeAllSubscribersMutex sync.RWMutex
|
||||
removeAllSubscribersArgsForCall []struct {
|
||||
arg1 bool
|
||||
}
|
||||
RemoveSubscriberStub func(livekit.ParticipantID, bool)
|
||||
removeSubscriberMutex sync.RWMutex
|
||||
removeSubscriberArgsForCall []struct {
|
||||
@@ -529,6 +529,38 @@ func (fake *FakeMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) {
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) InitiateClose(arg1 bool) {
|
||||
fake.initiateCloseMutex.Lock()
|
||||
fake.initiateCloseArgsForCall = append(fake.initiateCloseArgsForCall, struct {
|
||||
arg1 bool
|
||||
}{arg1})
|
||||
stub := fake.InitiateCloseStub
|
||||
fake.recordInvocation("InitiateClose", []interface{}{arg1})
|
||||
fake.initiateCloseMutex.Unlock()
|
||||
if stub != nil {
|
||||
fake.InitiateCloseStub(arg1)
|
||||
}
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) InitiateCloseCallCount() int {
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
return len(fake.initiateCloseArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) InitiateCloseCalls(stub func(bool)) {
|
||||
fake.initiateCloseMutex.Lock()
|
||||
defer fake.initiateCloseMutex.Unlock()
|
||||
fake.InitiateCloseStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) InitiateCloseArgsForCall(i int) bool {
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
argsForCall := fake.initiateCloseArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) IsMuted() bool {
|
||||
fake.isMutedMutex.Lock()
|
||||
ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)]
|
||||
@@ -1014,38 +1046,6 @@ func (fake *FakeMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackRec
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) RemoveAllSubscribers(arg1 bool) {
|
||||
fake.removeAllSubscribersMutex.Lock()
|
||||
fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct {
|
||||
arg1 bool
|
||||
}{arg1})
|
||||
stub := fake.RemoveAllSubscribersStub
|
||||
fake.recordInvocation("RemoveAllSubscribers", []interface{}{arg1})
|
||||
fake.removeAllSubscribersMutex.Unlock()
|
||||
if stub != nil {
|
||||
fake.RemoveAllSubscribersStub(arg1)
|
||||
}
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) RemoveAllSubscribersCallCount() int {
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
return len(fake.removeAllSubscribersArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) RemoveAllSubscribersCalls(stub func(bool)) {
|
||||
fake.removeAllSubscribersMutex.Lock()
|
||||
defer fake.removeAllSubscribersMutex.Unlock()
|
||||
fake.RemoveAllSubscribersStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) RemoveAllSubscribersArgsForCall(i int) bool {
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
argsForCall := fake.removeAllSubscribersArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) {
|
||||
fake.removeSubscriberMutex.Lock()
|
||||
fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct {
|
||||
@@ -1335,6 +1335,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} {
|
||||
defer fake.getQualityForDimensionMutex.RUnlock()
|
||||
fake.iDMutex.RLock()
|
||||
defer fake.iDMutex.RUnlock()
|
||||
fake.initiateCloseMutex.RLock()
|
||||
defer fake.initiateCloseMutex.RUnlock()
|
||||
fake.isMutedMutex.RLock()
|
||||
defer fake.isMutedMutex.RUnlock()
|
||||
fake.isSimulcastMutex.RLock()
|
||||
@@ -1353,8 +1355,6 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} {
|
||||
defer fake.publisherVersionMutex.RUnlock()
|
||||
fake.receiversMutex.RLock()
|
||||
defer fake.receiversMutex.RUnlock()
|
||||
fake.removeAllSubscribersMutex.RLock()
|
||||
defer fake.removeAllSubscribersMutex.RUnlock()
|
||||
fake.removeSubscriberMutex.RLock()
|
||||
defer fake.removeSubscriberMutex.RUnlock()
|
||||
fake.revokeDisallowedSubscribersMutex.RLock()
|
||||
|
||||
@@ -68,7 +68,7 @@ func (u *UpTrackManager) Close(willBeResumed bool) {
|
||||
|
||||
// remove all subscribers
|
||||
for _, t := range u.GetPublishedTracks() {
|
||||
t.RemoveAllSubscribers(willBeResumed)
|
||||
t.InitiateClose(willBeResumed)
|
||||
}
|
||||
|
||||
if notify && u.onClose != nil {
|
||||
@@ -317,7 +317,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) {
|
||||
}
|
||||
|
||||
func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack, willBeResumed bool) {
|
||||
track.RemoveAllSubscribers(willBeResumed)
|
||||
track.InitiateClose(willBeResumed)
|
||||
u.lock.Lock()
|
||||
delete(u.publishedTracks, track.ID())
|
||||
u.lock.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user