From 3bfdb2523eaf22de702cce1b7fc6e8ed694efe96 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 29 Jan 2022 20:14:36 +0530 Subject: [PATCH] Catch some instances of traversing map outside lock (#388) * one more place, do not range over map outside lock * Catch one more location * Catching a couple of more places --- pkg/rtc/mediatracksubscriptions.go | 57 +++++++--------- pkg/rtc/subscribedtrack.go | 4 ++ pkg/rtc/types/interfaces.go | 2 +- .../typesfakes/fake_local_media_track.go | 65 ------------------- pkg/rtc/types/typesfakes/fake_media_track.go | 65 ------------------- .../types/typesfakes/fake_subscribed_track.go | 65 +++++++++++++++++++ 6 files changed, 93 insertions(+), 165 deletions(-) diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 8b5b342fe..82d6374a0 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -76,12 +76,8 @@ func (t *MediaTrackSubscriptions) OnNoSubscribers(f func()) { } func (t *MediaTrackSubscriptions) SetMuted(muted bool) { - t.subscribedTracksMu.RLock() - subscribedTracks := t.subscribedTracks - t.subscribedTracksMu.RUnlock() - - // mute all subscribed tracks - for _, st := range subscribedTracks { + // update mute of all subscribed tracks + for _, st := range t.getAllSubscribedTracks() { st.SetPublisherMuted(muted) } @@ -267,7 +263,7 @@ func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { t.params.Logger.Debugw("removing all subscribers") t.subscribedTracksMu.Lock() - subscribedTracks := t.subscribedTracks + subscribedTracks := t.getAllSubscribedTracksLocked() t.subscribedTracks = make(map[livekit.ParticipantID]types.SubscribedTrack) t.subscribedTracksMu.Unlock() @@ -279,15 +275,11 @@ func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID { var revokedSubscriberIDs []livekit.ParticipantID - t.subscribedTracksMu.RLock() - subscribedTracks := t.subscribedTracks - t.subscribedTracksMu.RUnlock() - // LK-TODO: large number of subscribers needs to be solved for this loop - for subID, subTrack := range subscribedTracks { + for _, subTrack := range t.getAllSubscribedTracks() { found := false for _, allowedID := range allowedSubscriberIDs { - if subID == allowedID { + if subTrack.SubscriberID() == allowedID { found = true break } @@ -295,7 +287,7 @@ func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberI if !found { go subTrack.DownTrack().Close() - revokedSubscriberIDs = append(revokedSubscriberIDs, subID) + revokedSubscriberIDs = append(revokedSubscriberIDs, subTrack.SubscriberID()) } } @@ -303,11 +295,7 @@ func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberI } func (t *MediaTrackSubscriptions) UpdateVideoLayers() { - t.subscribedTracksMu.RLock() - subscribedTracks := t.subscribedTracks - t.subscribedTracksMu.RUnlock() - - for _, st := range subscribedTracks { + for _, st := range t.getAllSubscribedTracks() { st.UpdateVideoLayer() } } @@ -319,6 +307,21 @@ func (t *MediaTrackSubscriptions) getSubscribedTrack(subscriberID livekit.Partic return t.subscribedTracks[subscriberID] } +func (t *MediaTrackSubscriptions) getAllSubscribedTracks() []types.SubscribedTrack { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + return t.getAllSubscribedTracksLocked() +} + +func (t *MediaTrackSubscriptions) getAllSubscribedTracksLocked() []types.SubscribedTrack { + subTracks := make([]types.SubscribedTrack, 0, len(t.subscribedTracks)) + for _, subTrack := range t.subscribedTracks { + subTracks = append(subTracks, subTrack) + } + return subTracks +} + // TODO: send for all down tracks from the source participant // https://tools.ietf.org/html/rfc7941 func (t *MediaTrackSubscriptions) sendDownTrackBindingReports(sub types.LocalParticipant) { @@ -358,12 +361,8 @@ func (t *MediaTrackSubscriptions) sendDownTrackBindingReports(sub types.LocalPar } func (t *MediaTrackSubscriptions) DebugInfo() []map[string]interface{} { - t.subscribedTracksMu.RLock() - subscribedTracks := t.subscribedTracks - t.subscribedTracksMu.RUnlock() - subscribedTrackInfo := make([]map[string]interface{}, 0) - for _, val := range subscribedTracks { + for _, val := range t.getAllSubscribedTracks() { if st, ok := val.(*SubscribedTrack); ok { dt := st.DownTrack().DebugInfo() dt["PubMuted"] = st.pubMuted.Get() @@ -531,13 +530,3 @@ func (t *MediaTrackSubscriptions) maybeNotifyNoSubscribers() { t.onNoSubscribers() } } - -func (t *MediaTrackSubscriptions) GetAllSubscriberIDs() []livekit.ParticipantID { - t.subscribedTracksMu.RLock() - defer t.subscribedTracksMu.RUnlock() - ids := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) - for id := range t.subscribedTracks { - ids = append(ids, id) - } - return ids -} diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index 7a6ef25ea..fafe20280 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -64,6 +64,10 @@ func (t *SubscribedTrack) PublisherIdentity() livekit.ParticipantIdentity { return t.params.PublisherIdentity } +func (t *SubscribedTrack) SubscriberID() livekit.ParticipantID { + return t.params.SubscriberID +} + func (t *SubscribedTrack) DownTrack() *sfu.DownTrack { return t.params.DownTrack } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 3aa48070a..b19906187 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -193,7 +193,6 @@ type MediaTrack interface { AddSubscriber(participant LocalParticipant) error RemoveSubscriber(participantID livekit.ParticipantID, resume bool) IsSubscriber(subID livekit.ParticipantID) bool - GetAllSubscriberIDs() []livekit.ParticipantID RemoveAllSubscribers() RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID @@ -224,6 +223,7 @@ type SubscribedTrack interface { ID() livekit.TrackID PublisherID() livekit.ParticipantID PublisherIdentity() livekit.ParticipantIdentity + SubscriberID() livekit.ParticipantID DownTrack() *sfu.DownTrack MediaTrack() MediaTrack IsMuted() bool diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 595a4c7ea..f5189610f 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -26,16 +26,6 @@ type FakeLocalMediaTrack struct { addSubscriberReturnsOnCall map[int]struct { result1 error } - GetAllSubscriberIDsStub func() []livekit.ParticipantID - getAllSubscriberIDsMutex sync.RWMutex - getAllSubscriberIDsArgsForCall []struct { - } - getAllSubscriberIDsReturns struct { - result1 []livekit.ParticipantID - } - getAllSubscriberIDsReturnsOnCall map[int]struct { - result1 []livekit.ParticipantID - } GetAudioLevelStub func() (uint8, bool) getAudioLevelMutex sync.RWMutex getAudioLevelArgsForCall []struct { @@ -347,59 +337,6 @@ func (fake *FakeLocalMediaTrack) AddSubscriberReturnsOnCall(i int, result1 error }{result1} } -func (fake *FakeLocalMediaTrack) GetAllSubscriberIDs() []livekit.ParticipantID { - fake.getAllSubscriberIDsMutex.Lock() - ret, specificReturn := fake.getAllSubscriberIDsReturnsOnCall[len(fake.getAllSubscriberIDsArgsForCall)] - fake.getAllSubscriberIDsArgsForCall = append(fake.getAllSubscriberIDsArgsForCall, struct { - }{}) - stub := fake.GetAllSubscriberIDsStub - fakeReturns := fake.getAllSubscriberIDsReturns - fake.recordInvocation("GetAllSubscriberIDs", []interface{}{}) - fake.getAllSubscriberIDsMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeLocalMediaTrack) GetAllSubscriberIDsCallCount() int { - fake.getAllSubscriberIDsMutex.RLock() - defer fake.getAllSubscriberIDsMutex.RUnlock() - return len(fake.getAllSubscriberIDsArgsForCall) -} - -func (fake *FakeLocalMediaTrack) GetAllSubscriberIDsCalls(stub func() []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = stub -} - -func (fake *FakeLocalMediaTrack) GetAllSubscriberIDsReturns(result1 []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = nil - fake.getAllSubscriberIDsReturns = struct { - result1 []livekit.ParticipantID - }{result1} -} - -func (fake *FakeLocalMediaTrack) GetAllSubscriberIDsReturnsOnCall(i int, result1 []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = nil - if fake.getAllSubscriberIDsReturnsOnCall == nil { - fake.getAllSubscriberIDsReturnsOnCall = make(map[int]struct { - result1 []livekit.ParticipantID - }) - } - fake.getAllSubscriberIDsReturnsOnCall[i] = struct { - result1 []livekit.ParticipantID - }{result1} -} - func (fake *FakeLocalMediaTrack) GetAudioLevel() (uint8, bool) { fake.getAudioLevelMutex.Lock() ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] @@ -1566,8 +1503,6 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.addOnCloseMutex.RUnlock() fake.addSubscriberMutex.RLock() defer fake.addSubscriberMutex.RUnlock() - fake.getAllSubscriberIDsMutex.RLock() - defer fake.getAllSubscriberIDsMutex.RUnlock() fake.getAudioLevelMutex.RLock() defer fake.getAudioLevelMutex.RUnlock() fake.getConnectionScoreMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 34988407f..fb4f1560e 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -26,16 +26,6 @@ type FakeMediaTrack struct { addSubscriberReturnsOnCall map[int]struct { result1 error } - GetAllSubscriberIDsStub func() []livekit.ParticipantID - getAllSubscriberIDsMutex sync.RWMutex - getAllSubscriberIDsArgsForCall []struct { - } - getAllSubscriberIDsReturns struct { - result1 []livekit.ParticipantID - } - getAllSubscriberIDsReturnsOnCall map[int]struct { - result1 []livekit.ParticipantID - } GetQualityForDimensionStub func(uint32, uint32) livekit.VideoQuality getQualityForDimensionMutex sync.RWMutex getQualityForDimensionArgsForCall []struct { @@ -305,59 +295,6 @@ func (fake *FakeMediaTrack) AddSubscriberReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeMediaTrack) GetAllSubscriberIDs() []livekit.ParticipantID { - fake.getAllSubscriberIDsMutex.Lock() - ret, specificReturn := fake.getAllSubscriberIDsReturnsOnCall[len(fake.getAllSubscriberIDsArgsForCall)] - fake.getAllSubscriberIDsArgsForCall = append(fake.getAllSubscriberIDsArgsForCall, struct { - }{}) - stub := fake.GetAllSubscriberIDsStub - fakeReturns := fake.getAllSubscriberIDsReturns - fake.recordInvocation("GetAllSubscriberIDs", []interface{}{}) - fake.getAllSubscriberIDsMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeMediaTrack) GetAllSubscriberIDsCallCount() int { - fake.getAllSubscriberIDsMutex.RLock() - defer fake.getAllSubscriberIDsMutex.RUnlock() - return len(fake.getAllSubscriberIDsArgsForCall) -} - -func (fake *FakeMediaTrack) GetAllSubscriberIDsCalls(stub func() []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = stub -} - -func (fake *FakeMediaTrack) GetAllSubscriberIDsReturns(result1 []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = nil - fake.getAllSubscriberIDsReturns = struct { - result1 []livekit.ParticipantID - }{result1} -} - -func (fake *FakeMediaTrack) GetAllSubscriberIDsReturnsOnCall(i int, result1 []livekit.ParticipantID) { - fake.getAllSubscriberIDsMutex.Lock() - defer fake.getAllSubscriberIDsMutex.Unlock() - fake.GetAllSubscriberIDsStub = nil - if fake.getAllSubscriberIDsReturnsOnCall == nil { - fake.getAllSubscriberIDsReturnsOnCall = make(map[int]struct { - result1 []livekit.ParticipantID - }) - } - fake.getAllSubscriberIDsReturnsOnCall[i] = struct { - result1 []livekit.ParticipantID - }{result1} -} - func (fake *FakeMediaTrack) GetQualityForDimension(arg1 uint32, arg2 uint32) livekit.VideoQuality { fake.getQualityForDimensionMutex.Lock() ret, specificReturn := fake.getQualityForDimensionReturnsOnCall[len(fake.getQualityForDimensionArgsForCall)] @@ -1309,8 +1246,6 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.addOnCloseMutex.RUnlock() fake.addSubscriberMutex.RLock() defer fake.addSubscriberMutex.RUnlock() - fake.getAllSubscriberIDsMutex.RLock() - defer fake.getAllSubscriberIDsMutex.RUnlock() fake.getQualityForDimensionMutex.RLock() defer fake.getQualityForDimensionMutex.RUnlock() fake.iDMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_subscribed_track.go b/pkg/rtc/types/typesfakes/fake_subscribed_track.go index edf80a1dd..5cf5de198 100644 --- a/pkg/rtc/types/typesfakes/fake_subscribed_track.go +++ b/pkg/rtc/types/typesfakes/fake_subscribed_track.go @@ -80,6 +80,16 @@ type FakeSubscribedTrack struct { setPublisherMutedArgsForCall []struct { arg1 bool } + SubscriberIDStub func() livekit.ParticipantID + subscriberIDMutex sync.RWMutex + subscriberIDArgsForCall []struct { + } + subscriberIDReturns struct { + result1 livekit.ParticipantID + } + subscriberIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } UpdateSubscriberSettingsStub func(*livekit.UpdateTrackSettings) updateSubscriberSettingsMutex sync.RWMutex updateSubscriberSettingsArgsForCall []struct { @@ -475,6 +485,59 @@ func (fake *FakeSubscribedTrack) SetPublisherMutedArgsForCall(i int) bool { return argsForCall.arg1 } +func (fake *FakeSubscribedTrack) SubscriberID() livekit.ParticipantID { + fake.subscriberIDMutex.Lock() + ret, specificReturn := fake.subscriberIDReturnsOnCall[len(fake.subscriberIDArgsForCall)] + fake.subscriberIDArgsForCall = append(fake.subscriberIDArgsForCall, struct { + }{}) + stub := fake.SubscriberIDStub + fakeReturns := fake.subscriberIDReturns + fake.recordInvocation("SubscriberID", []interface{}{}) + fake.subscriberIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) SubscriberIDCallCount() int { + fake.subscriberIDMutex.RLock() + defer fake.subscriberIDMutex.RUnlock() + return len(fake.subscriberIDArgsForCall) +} + +func (fake *FakeSubscribedTrack) SubscriberIDCalls(stub func() livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = stub +} + +func (fake *FakeSubscribedTrack) SubscriberIDReturns(result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + fake.subscriberIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + if fake.subscriberIDReturnsOnCall == nil { + fake.subscriberIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.subscriberIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + func (fake *FakeSubscribedTrack) UpdateSubscriberSettings(arg1 *livekit.UpdateTrackSettings) { fake.updateSubscriberSettingsMutex.Lock() fake.updateSubscriberSettingsArgsForCall = append(fake.updateSubscriberSettingsArgsForCall, struct { @@ -550,6 +613,8 @@ func (fake *FakeSubscribedTrack) Invocations() map[string][][]interface{} { defer fake.publisherIdentityMutex.RUnlock() fake.setPublisherMutedMutex.RLock() defer fake.setPublisherMutedMutex.RUnlock() + fake.subscriberIDMutex.RLock() + defer fake.subscriberIDMutex.RUnlock() fake.updateSubscriberSettingsMutex.RLock() defer fake.updateSubscriberSettingsMutex.RUnlock() fake.updateVideoLayerMutex.RLock()