From 01f90d185f5a2de7fcfece86a261bf5b1f1ca295 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Thu, 21 Dec 2023 08:23:22 -0800 Subject: [PATCH] copy receivers on write (#2336) * copy receivers on write * cleanup * cleanup * test --- pkg/rtc/mediatrackreceiver.go | 85 ++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index d04db1a02..3942264ef 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -24,6 +24,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" "google.golang.org/protobuf/proto" "github.com/livekit/protocol/livekit" @@ -137,7 +138,7 @@ func (t *MediaTrackReceiver) Restart() { hq := buffer.VideoQualityToSpatialLayer(livekit.VideoQuality_HIGH, t.trackInfo) t.lock.RUnlock() - for _, receiver := range t.Receivers() { + for _, receiver := range t.loadReceivers() { receiver.SetMaxExpectedSpatialLayer(hq) } } @@ -156,9 +157,11 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority return } + receivers := slices.Clone(t.receivers) + // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver var upgradeReceiver bool - for _, r := range t.receivers { + for _, r := range receivers { if strings.EqualFold(r.Codec().MimeType, receiver.Codec().MimeType) { if d, ok := r.TrackReceiver.(*DummyReceiver); ok { d.Upgrade(receiver) @@ -168,11 +171,11 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority } } if !upgradeReceiver { - t.receivers = append(t.receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) + receivers = append(receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) } - sort.Slice(t.receivers, func(i, j int) bool { - return t.receivers[i].Priority() < t.receivers[j].Priority() + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() }) if mid != "" { @@ -195,8 +198,12 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority } } + t.receivers = receivers + onSetupReceiver := t.onSetupReceiver + t.lock.Unlock() + var receiverCodecs []string - for _, r := range t.receivers { + for _, r := range receivers { receiverCodecs = append(receiverCodecs, r.Codec().MimeType) } t.params.Logger.Debugw( @@ -206,8 +213,6 @@ func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority "receivers", receiverCodecs, "mid", mid, ) - onSetupReceiver := t.onSetupReceiver - t.lock.Unlock() if onSetupReceiver != nil { onSetupReceiver(receiver.Codec().MimeType) @@ -225,10 +230,11 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete } } t.lock.Lock() + receivers := slices.Clone(t.receivers) t.potentialCodecs = codecs for i, c := range codecs { var exist bool - for _, r := range t.receivers { + for _, r := range receivers { if strings.EqualFold(c.MimeType, r.Codec().MimeType) { exist = true break @@ -239,28 +245,30 @@ func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParamete if !sfu.IsSvcCodec(c.MimeType) { extHeaders = headersWithoutDD } - t.receivers = append(t.receivers, &simulcastReceiver{ + receivers = append(receivers, &simulcastReceiver{ TrackReceiver: NewDummyReceiver(livekit.TrackID(t.trackInfo.Sid), string(t.PublisherID()), c, extHeaders), priority: i, }) } } - sort.Slice(t.receivers, func(i, j int) bool { - return t.receivers[i].Priority() < t.receivers[j].Priority() + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() }) + t.receivers = receivers t.lock.Unlock() } func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { - t.params.Logger.Debugw("clearing receiver", "mime", mime) t.lock.Lock() - for idx, receiver := range t.receivers { + receivers := slices.Clone(t.receivers) + for idx, receiver := range receivers { if strings.EqualFold(receiver.Codec().MimeType, mime) { - t.receivers[idx] = t.receivers[len(t.receivers)-1] - t.receivers = t.receivers[:len(t.receivers)-1] + receivers[idx] = receivers[len(receivers)-1] + receivers = receivers[:len(receivers)-1] break } } + t.receivers = receivers t.lock.Unlock() t.removeAllSubscribersForMime(mime, willBeResumed) @@ -268,15 +276,13 @@ func (t *MediaTrackReceiver) ClearReceiver(mime string, willBeResumed bool) { func (t *MediaTrackReceiver) ClearAllReceivers(willBeResumed bool) { t.params.Logger.Debugw("clearing all receivers") - t.lock.RLock() - var mimes []string - for _, receiver := range t.receivers { - mimes = append(mimes, receiver.Codec().MimeType) - } - t.lock.RUnlock() + t.lock.Lock() + receivers := t.receivers + t.receivers = nil + t.lock.Unlock() - for _, mime := range mimes { - t.ClearReceiver(mime, willBeResumed) + for _, r := range receivers { + t.removeAllSubscribersForMime(r.Codec().MimeType, willBeResumed) } } @@ -415,7 +421,7 @@ func (t *MediaTrackReceiver) SetMuted(muted bool) { t.trackInfo.Muted = muted t.lock.Unlock() - for _, receiver := range t.Receivers() { + for _, receiver := range t.loadReceivers() { receiver.SetUpTrackPaused(muted) } @@ -440,7 +446,7 @@ func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.Su return nil, ErrNotOpen } - receivers := t.simulcastReceiversLocked() + receivers := t.receivers potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) copy(potentialCodecs, t.potentialCodecs) t.lock.RUnlock() @@ -527,7 +533,7 @@ func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { ti := proto.Clone(t.trackInfo).(*livekit.TrackInfo) t.lock.RUnlock() - for _, r := range t.Receivers() { + for _, r := range t.loadReceivers() { r.UpdateTrackInfo(ti) } } @@ -815,7 +821,7 @@ func (t *MediaTrackReceiver) DebugInfo() map[string]interface{} { info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() - for _, receiver := range t.Receivers() { + for _, receiver := range t.loadReceivers() { info[receiver.Codec().MimeType] = receiver.DebugInfo() } @@ -853,25 +859,22 @@ func (t *MediaTrackReceiver) Receiver(mime string) sfu.TrackReceiver { func (t *MediaTrackReceiver) Receivers() []sfu.TrackReceiver { t.lock.RLock() defer t.lock.RUnlock() - - receivers := make([]sfu.TrackReceiver, 0, len(t.receivers)) - for _, r := range t.receivers { - receivers = append(receivers, r.TrackReceiver) + receivers := make([]sfu.TrackReceiver, len(t.receivers)) + for i, r := range t.receivers { + receivers[i] = r.TrackReceiver } return receivers } -func (t *MediaTrackReceiver) simulcastReceiversLocked() []*simulcastReceiver { - receivers := make([]*simulcastReceiver, 0, len(t.receivers)) - for _, r := range t.receivers { - receivers = append(receivers, r) - } - return receivers +func (t *MediaTrackReceiver) loadReceivers() []*simulcastReceiver { + t.lock.RLock() + defer t.lock.RUnlock() + return t.receivers } func (t *MediaTrackReceiver) SetRTT(rtt uint32) { - for _, r := range t.Receivers() { - if wr, ok := r.(*sfu.WebRTCReceiver); ok { + for _, r := range t.loadReceivers() { + if wr, ok := r.TrackReceiver.(*sfu.WebRTCReceiver); ok { wr.SetRTT(rtt) } } @@ -901,7 +904,7 @@ func (t *MediaTrackReceiver) IsEncrypted() bool { } func (t *MediaTrackReceiver) GetTrackStats() *livekit.RTPStats { - receivers := t.Receivers() + receivers := t.loadReceivers() stats := make([]*livekit.RTPStats, 0, len(receivers)) for _, receiver := range receivers { receiverStats := receiver.GetTrackStats()