diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 4793820c8..4ab352b41 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -142,13 +142,18 @@ type TrackReceiver interface { CodecState() ReceiverCodecState } -type redPktWriteFunc func(pkt *buffer.ExtPacket, spatialLayer int32) int -type redSenderReportWriteFunc func( - payloadType webrtc.PayloadType, - isSVC bool, - layer int32, - publisherSRData *livekit.RTCPSenderReportState, -) +type REDTransformer interface { + ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int32) int + ForwardRTCPSenderReport( + payloadType webrtc.PayloadType, + isSVC bool, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, + ) + ResyncDownTracks() + CanClose() bool + Close() +} // WebRTCReceiver receives a media track type WebRTCReceiver struct { @@ -191,10 +196,7 @@ type WebRTCReceiver struct { onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) onMaxLayerChange func(maxLayer int32) - primaryReceiver atomic.Pointer[RedPrimaryReceiver] - redReceiver atomic.Pointer[RedReceiver] - redPktWriter atomic.Value // redPktWriteFunc - redSenderReportWriter atomic.Value // redPktWriteFunc + redTransformer atomic.Value // redTransformer interface forwardStats *ForwardStats } @@ -413,8 +415,8 @@ func (w *WebRTCReceiver) AddUpTrack(track TrackRemote, buff *buffer.Buffer) erro _ = dt.HandleRTCPSenderReportData(w.codec.PayloadType, w.isSVC, layer, srData) }) - if f := w.redSenderReportWriter.Load(); f != nil { - f.(redSenderReportWriteFunc)(w.codec.PayloadType, w.isSVC, layer, srData) + if rt := w.redTransformer.Load(); rt != nil { + rt.(REDTransformer).ForwardRTCPSenderReport(w.codec.PayloadType, w.isSVC, layer, srData) } }) @@ -743,11 +745,8 @@ func (w *WebRTCReceiver) forwardRTP(layer int32, buff *buffer.Buffer) { w.closeOnce.Do(func() { w.closed.Store(true) w.closeTracks() - if pr := w.primaryReceiver.Load(); pr != nil { - pr.Close() - } - if pr := w.redReceiver.Load(); pr != nil { - pr.Close() + if rt := w.redTransformer.Load(); rt != nil { + rt.(REDTransformer).Close() } }) @@ -794,8 +793,8 @@ func (w *WebRTCReceiver) forwardRTP(layer int32, buff *buffer.Buffer) { _ = dt.WriteRTP(pkt, spatialLayer) }) - if f := w.redPktWriter.Load(); f != nil { - writeCount += f.(redPktWriteFunc)(pkt, spatialLayer) + if rt := w.redTransformer.Load(); rt != nil { + writeCount += rt.(REDTransformer).ForwardRTP(pkt, spatialLayer) } // track delay/jitter @@ -866,39 +865,51 @@ func (w *WebRTCReceiver) DebugInfo() map[string]interface{} { } func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver { + w.bufferMu.Lock() + defer w.bufferMu.Unlock() + if !w.isRED || w.closed.Load() { return w } - if w.primaryReceiver.Load() == nil { + rt := w.redTransformer.Load() + if rt == nil { pr := NewRedPrimaryReceiver(w, DownTrackSpreaderParams{ Threshold: w.lbThreshold, Logger: w.logger, }) - if w.primaryReceiver.CompareAndSwap(nil, pr) { - w.redPktWriter.Store(redPktWriteFunc(pr.ForwardRTP)) - w.redSenderReportWriter.Store(redSenderReportWriteFunc(pr.ForwardRTCPSenderReport)) + w.redTransformer.Store(pr) + return pr + } else { + if pr, ok := rt.(*RedPrimaryReceiver); ok { + return pr } } - return w.primaryReceiver.Load() + return nil } func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver { + w.bufferMu.Lock() + defer w.bufferMu.Unlock() + if w.isRED || w.closed.Load() { return w } - if w.redReceiver.Load() == nil { + rt := w.redTransformer.Load() + if rt == nil { pr := NewRedReceiver(w, DownTrackSpreaderParams{ Threshold: w.lbThreshold, Logger: w.logger, }) - if w.redReceiver.CompareAndSwap(nil, pr) { - w.redPktWriter.Store(redPktWriteFunc(pr.ForwardRTP)) - w.redSenderReportWriter.Store(redSenderReportWriteFunc(pr.ForwardRTCPSenderReport)) + w.redTransformer.Store(pr) + return pr + } else { + if pr, ok := rt.(*RedReceiver); ok { + return pr } } - return w.redReceiver.Load() + return nil } func (w *WebRTCReceiver) GetTemporalLayerFpsForSpatial(layer int32) []float32 {