diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 53e6f327e..d48b98460 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -237,6 +237,7 @@ type DownTrack struct { isClosed atomic.Bool connected atomic.Bool bindAndConnectedOnce atomic.Bool + writable atomic.Bool rtpStats *buffer.RTPStats @@ -420,7 +421,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.forwarder.DetermineCodec(d.codec, d.params.Receiver.HeaderExtensions()) d.params.Logger.Debugw("downtrack bound") - d.onBindAndConnected() + d.onBindAndConnectedChange() return codec, nil } @@ -429,6 +430,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { d.bound.Store(false) + d.onBindAndConnectedChange() return nil } @@ -600,7 +602,7 @@ func (d *DownTrack) keyFrameRequester(generation uint32, layer int32) { return } - if d.connected.Load() { + if d.writable.Load() { d.params.Logger.Debugw("sending PLI for layer lock", "generation", generation, "layer", layer) d.params.Receiver.SendPLI(layer, false) d.rtpStats.UpdateLayerLockPliAndTime(1) @@ -608,7 +610,7 @@ func (d *DownTrack) keyFrameRequester(generation uint32, layer int32) { <-ticker.C - if generation != d.keyFrameRequestGeneration.Load() || !d.bound.Load() { + if generation != d.keyFrameRequestGeneration.Load() || !d.writable.Load() { return } } @@ -643,7 +645,7 @@ func (d *DownTrack) maxLayerNotifierWorker() { // WriteRTP writes an RTP Packet to the DownTrack func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { - if !d.bound.Load() || !d.connected.Load() { + if !d.writable.Load() { return nil } @@ -720,6 +722,10 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { // WritePaddingRTP tries to write as many padding only RTP packets as necessary // to satisfy given size to the DownTrack func (d *DownTrack) WritePaddingRTP(bytesToSend int, paddingOnMute bool, forceMarker bool) int { + if !d.writable.Load() { + return 0 + } + if !d.rtpStats.IsActive() && !paddingOnMute { return 0 } @@ -1224,8 +1230,8 @@ func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan struct{} { done := make(chan struct{}) go func() { - // don't send if nothing has been sent - if !d.rtpStats.IsActive() { + // don't send if not writable OR nothing has been sent + if !d.writable.Load() || !d.rtpStats.IsActive() { close(done) return } @@ -1492,7 +1498,7 @@ func (d *DownTrack) handleRTCP(bytes []byte) { func (d *DownTrack) SetConnected() { if !d.connected.Swap(true) { - d.onBindAndConnected() + d.onBindAndConnectedChange() } } @@ -1710,7 +1716,7 @@ func (d *DownTrack) GetAndResetBytesSent() (uint32, uint32) { return d.bytesSent.Swap(0), d.bytesRetransmitted.Swap(0) } -func (d *DownTrack) onBindAndConnected() { +func (d *DownTrack) onBindAndConnectedChange() { if d.connected.Load() && d.bound.Load() && !d.bindAndConnectedOnce.Swap(true) { if d.kind == webrtc.RTPCodecTypeVideo { _, layer := d.forwarder.CheckSync() @@ -1723,6 +1729,7 @@ func (d *DownTrack) onBindAndConnected() { go d.sendPaddingOnMute() } } + d.writable.Store(d.connected.Load() && d.bound.Load()) } func (d *DownTrack) sendPaddingOnMute() { diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 0b7320991..49c56bc2f 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -1622,9 +1622,6 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e // should be called with lock held func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer int32, tp *TranslationParams) (*TranslationParams, error) { - if tp == nil { - tp = &TranslationParams{} - } if f.lastSSRC != extPkt.Packet.SSRC { if err := f.processSourceSwitch(extPkt, layer); err != nil { tp.shouldDrop = true @@ -1649,7 +1646,7 @@ func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer i // should be called with lock held func (f *Forwarder) getTranslationParamsAudio(extPkt *buffer.ExtPacket, layer int32) (*TranslationParams, error) { - return f.getTranslationParamsCommon(extPkt, layer, nil) + return f.getTranslationParamsCommon(extPkt, layer, &TranslationParams{}) } // should be called with lock held @@ -1661,7 +1658,6 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in } tp := &TranslationParams{} - if !f.vls.GetTarget().IsValid() { // stream is paused by streamallocator tp.shouldDrop = true diff --git a/pkg/sfu/rtpmunger.go b/pkg/sfu/rtpmunger.go index c0ce6ae6a..8176cfe81 100644 --- a/pkg/sfu/rtpmunger.go +++ b/pkg/sfu/rtpmunger.go @@ -72,9 +72,12 @@ type RTPMunger struct { extLastSN uint64 extSecondLastSN uint64 - extLastTS uint64 - tsOffset uint64 - lastMarker bool + snOffset uint64 + + extLastTS uint64 + tsOffset uint64 + + lastMarker bool extRtxGateSn uint64 isInRtxGateRegion bool @@ -88,12 +91,11 @@ func NewRTPMunger(logger logger.Logger) *RTPMunger { } func (r *RTPMunger) DebugInfo() map[string]interface{} { - snOffset, _ := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) return map[string]interface{}{ "ExtHighestIncomingSN": r.extHighestIncomingSN, "ExtLastSN": r.extLastSN, "ExtSecondLastSN": r.extSecondLastSN, - "SNOffset": snOffset, + "SNOffset": r.snOffset, "ExtLastTS": r.extLastTS, "TSOffset": r.tsOffset, "LastMarker": r.lastMarker, @@ -116,14 +118,20 @@ func (r *RTPMunger) SeedLast(state RTPMungerState) { func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) { r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 + r.extLastSN = extPkt.ExtSequenceNumber r.extSecondLastSN = r.extLastSN - 1 + r.updateSnOffset() + r.extLastTS = extPkt.ExtTimestamp } func (r *RTPMunger) UpdateSnTsOffsets(extPkt *buffer.ExtPacket, snAdjust uint64, tsAdjust uint64) { r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 + r.snRangeMap.ClearAndResetValue(extPkt.ExtSequenceNumber - r.extLastSN - snAdjust) + r.updateSnOffset() + r.tsOffset = extPkt.ExtTimestamp - r.extLastTS - tsAdjust } @@ -148,16 +156,42 @@ func (r *RTPMunger) PacketDropped(extPkt *buffer.ExtPacket) { } r.extLastSN = r.extSecondLastSN + r.updateSnOffset() } func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationParamsRTP, error) { diff := int64(extPkt.ExtSequenceNumber - r.extHighestIncomingSN) + if (diff == 1 && len(extPkt.Packet.Payload) != 0) || diff > 1 { + // in-order - either contiguous packet with payload OR packet following a gap, may or may not have payload + r.extHighestIncomingSN = extPkt.ExtSequenceNumber + + ordering := SequenceNumberOrderingContiguous + if diff > 1 { + ordering = SequenceNumberOrderingGap + } + + extMungedSN := extPkt.ExtSequenceNumber - r.snOffset + extMungedTS := extPkt.ExtTimestamp - r.tsOffset + + r.extSecondLastSN = r.extLastSN + r.extLastSN = extMungedSN + r.extLastTS = extMungedTS + r.lastMarker = extPkt.Packet.Marker + + if extPkt.KeyFrame { + r.extRtxGateSn = extMungedSN + r.isInRtxGateRegion = true + } + + if r.isInRtxGateRegion && (extMungedSN-r.extRtxGateSn) > RtxGateWindow { + r.isInRtxGateRegion = false + } - // can get duplicate packet due to FEC - if diff == 0 { return &TranslationParamsRTP{ - snOrdering: SequenceNumberOrderingDuplicate, - }, ErrDuplicatePacket + snOrdering: ordering, + sequenceNumber: uint16(extMungedSN), + timestamp: uint32(extMungedTS), + }, nil } if diff < 0 { @@ -176,53 +210,25 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara }, nil } - ordering := SequenceNumberOrderingContiguous - if diff > 1 { - ordering = SequenceNumberOrderingGap - } - - r.extHighestIncomingSN = extPkt.ExtSequenceNumber - // if padding only packet, can be dropped and sequence number adjusted, if contiguous - if diff == 1 && len(extPkt.Packet.Payload) == 0 { + if diff == 1 { + r.extHighestIncomingSN = extPkt.ExtSequenceNumber + if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil { r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) } + + r.updateSnOffset() + return &TranslationParamsRTP{ - snOrdering: ordering, + snOrdering: SequenceNumberOrderingContiguous, }, ErrPaddingOnlyPacket } - snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) - if err != nil { - r.logger.Errorw("could not get sequence number adjustment", err, "sn", extPkt.ExtSequenceNumber, "payloadSize", len(extPkt.Packet.Payload)) - return &TranslationParamsRTP{ - snOrdering: ordering, - }, ErrSequenceNumberOffsetNotFound - } - - extMungedSN := extPkt.ExtSequenceNumber - snOffset - extMungedTS := extPkt.ExtTimestamp - r.tsOffset - - r.extSecondLastSN = r.extLastSN - r.extLastSN = extMungedSN - r.extLastTS = extMungedTS - r.lastMarker = extPkt.Packet.Marker - - if extPkt.KeyFrame { - r.extRtxGateSn = extMungedSN - r.isInRtxGateRegion = true - } - - if r.isInRtxGateRegion && (extMungedSN-r.extRtxGateSn) > RtxGateWindow { - r.isInRtxGateRegion = false - } - + // can get duplicate packet due to FEC return &TranslationParamsRTP{ - snOrdering: ordering, - sequenceNumber: uint16(extMungedSN), - timestamp: uint32(extMungedTS), - }, nil + snOrdering: SequenceNumberOrderingDuplicate, + }, ErrDuplicatePacket } func (r *RTPMunger) FilterRTX(nacks []uint16) []uint16 { @@ -297,3 +303,11 @@ func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate func (r *RTPMunger) IsOnFrameBoundary() bool { return r.lastMarker } + +func (r *RTPMunger) updateSnOffset() { + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) + if err != nil { + r.logger.Errorw("could not get SN offset", err) + } + r.snOffset = snOffset +}