From a4888fcf8fada3aecfac42cfe2668514ea875c85 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 21 Dec 2023 16:02:10 +0530 Subject: [PATCH] Prevent unsafe access (hopefully). (#2332) * Prevent unsafe access (hopefully). Thank you @paulwe for catching it. * prevent recursive locks --- pkg/rtc/mediatrackreceiver.go | 54 +++++++++++++++-------------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index b5fefa8a3..d04db1a02 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -133,13 +133,11 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams, ti *livekit.TrackInf } func (t *MediaTrackReceiver) Restart() { - t.lock.Lock() - receivers := t.receivers - ti := t.trackInfo - t.lock.Unlock() + t.lock.RLock() + hq := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.trackInfo) + t.lock.RUnlock() - hq := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, ti) - for _, receiver := range receivers { + for _, receiver := range t.Receivers() { receiver.SetMaxExpectedSpatialLayer(hq) } } @@ -263,7 +261,6 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { break } } - t.lock.Unlock() t.removeAllSubscribersForMime(mime, willBeResumed) @@ -271,14 +268,12 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { func (t *MediaTrackReceiver) ClearAllReceivers(willBeResumed bool) { t.params.Logger.Debugw("clearing all receivers") - t.lock.Lock() + t.lock.RLock() var mimes []string for _, receiver := range t.receivers { mimes = append(mimes, receiver.Codec().MimeType) } - - t.receivers = nil - t.lock.Unlock() + t.lock.RUnlock() for _, mime := range mimes { t.ClearReceiver(mime, willBeResumed) @@ -418,9 +413,9 @@ func (t *MediaTrackReceiver) IsMuted() bool { func (t *MediaTrackReceiver) SetMuted(muted bool) { t.lock.Lock() t.trackInfo.Muted = muted - receivers := t.receivers t.lock.Unlock() - for _, receiver := range receivers { + + for _, receiver := range t.Receivers() { receiver.SetUpTrackPaused(muted) } @@ -445,7 +440,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su return nil, ErrNotOpen } - receivers := t.receivers + receivers := t.simulcastReceiversLocked() potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) copy(potentialCodecs, t.potentialCodecs) t.lock.RUnlock() @@ -529,11 +524,10 @@ func (t *MediaTrackReceiver) RevokeDisallowedSubscribers(allowedSubscriberIdenti func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { t.lock.RLock() - receivers := t.receivers - ti := t.trackInfo + ti := proto.Clone(t.trackInfo).(*livekit.TrackInfo) t.lock.RUnlock() - for _, r := range receivers { + for _, r := range t.Receivers() { r.UpdateTrackInfo(ti) } } @@ -821,10 +815,7 @@ func (t *MediaTrackReceiver) DebugInfo() map[string]interface{} { info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() - t.lock.RLock() - receivers := t.receivers - t.lock.RUnlock() - for _, receiver := range receivers { + for _, receiver := range t.Receivers() { info[receiver.Codec().MimeType] = receiver.DebugInfo() } @@ -870,13 +861,17 @@ func (t *MediaTrackReceiver) Receivers() []sfu.TrackReceiver { return receivers } -func (t *MediaTrackReceiver) SetRTT(rtt uint32) { - t.lock.RLock() - receivers := t.receivers - t.lock.RUnlock() +func (t *MediaTrackReceiver) simulcastReceiversLocked() []*simulcastReceiver { + receivers := make([]*simulcastReceiver, 0, len(t.receivers)) + for _, r := range t.receivers { + receivers = append(receivers, r) + } + return receivers +} - for _, r := range receivers { - if wr, ok := r.TrackReceiver.(*sfu.WebRTCReceiver); ok { +func (t *MediaTrackReceiver) SetRTT(rtt uint32) { + for _, r := range t.Receivers() { + if wr, ok := r.(*sfu.WebRTCReceiver); ok { wr.SetRTT(rtt) } } @@ -906,10 +901,7 @@ func (t *MediaTrackReceiver) IsEncrypted() bool { } func (t *MediaTrackReceiver) GetTrackStats() *livekit.RTPStats { - t.lock.RLock() - receivers := t.receivers - t.lock.RUnlock() - + receivers := t.Receivers() stats := make([]*livekit.RTPStats, 0, len(receivers)) for _, receiver := range receivers { receiverStats := receiver.GetTrackStats()