diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index be2301bc1..207165fbb 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -26,6 +26,7 @@ import ( var ( ErrReceiverClosed = errors.New("receiver closed") ErrDownTrackAlreadyExist = errors.New("DownTrack already exist") + ErrBufferNotFound = errors.New("buffer not found") ) type AudioLevelHandle func(level uint8, duration uint32) @@ -451,22 +452,33 @@ func (w *WebRTCReceiver) SetRTCPCh(ch chan []rtcp.Packet) { } func (w *WebRTCReceiver) getBuffer(layer int32) *buffer.Buffer { + w.bufferMu.RLock() + defer w.bufferMu.RUnlock() + + return w.getBufferLocked(layer) +} + +func (w *WebRTCReceiver) getBufferLocked(layer int32) *buffer.Buffer { // for svc codecs, use layer full quality instead. // we only have buffer for full quality if w.isSVC { layer = int32(len(w.buffers)) - 1 } - w.bufferMu.RLock() - buff := w.buffers[layer] - w.bufferMu.RUnlock() - if buff == nil { - w.logger.Warnw("getBuffer failed, buffer not found", nil, "layer", layer) + + if int(layer) >= len(w.buffers) { + return nil } - return buff + + return w.buffers[layer] } func (w *WebRTCReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) { - return w.getBuffer(int32(layer)).GetPacket(buf, sn) + b := w.getBuffer(int32(layer)) + if b == nil { + return 0, ErrBufferNotFound + } + + return b.GetPacket(buf, sn) } func (w *WebRTCReceiver) GetTrackStats() *livekit.RTPStats { @@ -674,10 +686,16 @@ func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver { } func (w *WebRTCReceiver) GetTemporalLayerFpsForSpatial(layer int32) []float32 { - if !w.isSVC { - return w.getBuffer(layer).GetTemporalLayerFpsForSpatial(0) + b := w.getBuffer(layer) + if b == nil { + return nil } - return w.getBuffer(layer).GetTemporalLayerFpsForSpatial(layer) + + if !w.isSVC { + return b.GetTemporalLayerFpsForSpatial(0) + } + + return b.GetTemporalLayerFpsForSpatial(layer) } func (w *WebRTCReceiver) GetRTCPSenderReportData(layer int32) *buffer.RTCPSenderReportData { @@ -699,18 +717,20 @@ func (w *WebRTCReceiver) GetReferenceLayerRTPTimestamp(ts uint32, layer int32, r return ts, nil } - if layer == InvalidLayerSpatial || int(layer) >= len(w.buffers) { + bLayer := w.getBufferLocked(layer) + if bLayer == nil { return 0, fmt.Errorf("invalid layer: %d", layer) } - srLayer := w.buffers[layer].GetSenderReportData() + srLayer := bLayer.GetSenderReportData() if srLayer == nil || srLayer.NTPTimestamp == 0 { return 0, fmt.Errorf("layer rtcp sender report not available: %d", layer) } - if referenceLayer == InvalidLayerSpatial || int(referenceLayer) >= len(w.buffers) { + bReferenceLayer := w.getBufferLocked(referenceLayer) + if bReferenceLayer == nil { return 0, fmt.Errorf("invalid reference layer: %d", referenceLayer) } - srRef := w.buffers[referenceLayer].GetSenderReportData() + srRef := bReferenceLayer.GetSenderReportData() if srRef == nil || srRef.NTPTimestamp == 0 { return 0, fmt.Errorf("reference layer rtcp sender report not available: %d", referenceLayer) }