diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index c334258f8..bf253bb04 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -49,6 +49,8 @@ const ( keyFrameIntervalMin = 200 keyFrameIntervalMax = 1000 flushTimeout = 1 * time.Second + + maxPadding = 2000 ) var ( @@ -307,7 +309,13 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.handleRTCP(pkt) }) } - d.sequencer = newSequencer(d.maxTrack, d.logger) + + if d.kind == webrtc.RTPCodecTypeAudio { + d.sequencer = newSequencer(d.maxTrack, 0, d.logger) + } else { + d.sequencer = newSequencer(d.maxTrack, maxPadding, d.logger) + } + d.codec = codec.RTPCodecCapability d.forwarder.DetermineCodec(d.codec) if d.onBind != nil { @@ -615,7 +623,7 @@ func (d *DownTrack) WritePaddingRTP(bytesToSend int) int { // So, retransmitting padding packets is only going to make matters worse. // if d.sequencer != nil { - d.sequencer.push(0, hdr.SequenceNumber, hdr.Timestamp, int8(InvalidLayerSpatial)) + d.sequencer.pushPadding(hdr.SequenceNumber) } bytesSent += size @@ -1235,11 +1243,6 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { nackMisses := uint32(0) numRepeatedNACKs := uint32(0) for _, meta := range d.sequencer.getPacketsMeta(filtered) { - if meta.layer == int8(InvalidLayerSpatial) { - // padding packet, no RTX for those - continue - } - if disallowedLayers[meta.layer] { continue } diff --git a/pkg/sfu/sequencer.go b/pkg/sfu/sequencer.go index f20f85384..52ee2f54f 100644 --- a/pkg/sfu/sequencer.go +++ b/pkg/sfu/sequencer.go @@ -11,7 +11,6 @@ import ( ) const ( - maxPadding = 2000 defaultRtt = 70 ignoreRetransmission = 100 // Ignore packet retransmission after ignoreRetransmission milliseconds ) @@ -96,7 +95,7 @@ type sequencer struct { sync.Mutex init bool max int - seq []packetMeta + seq []*packetMeta step int headSN uint16 startTime int64 @@ -104,11 +103,11 @@ type sequencer struct { logger logger.Logger } -func newSequencer(maxTrack int, logger logger.Logger) *sequencer { +func newSequencer(maxTrack int, maxPadding int, logger logger.Logger) *sequencer { return &sequencer{ startTime: time.Now().UnixNano() / 1e6, max: maxTrack + maxPadding, - seq: make([]packetMeta, maxTrack+maxPadding), + seq: make([]*packetMeta, maxTrack+maxPadding), rtt: defaultRtt, logger: logger, } @@ -129,6 +128,33 @@ func (s *sequencer) push(sn, offSn uint16, timeStamp uint32, layer int8) *packet s.Lock() defer s.Unlock() + slot, isValid := s.getSlot(offSn) + if !isValid { + return nil + } + + s.seq[slot] = &packetMeta{ + sourceSeqNo: sn, + targetSeqNo: offSn, + timestamp: timeStamp, + layer: layer, + } + return s.seq[slot] +} + +func (s *sequencer) pushPadding(offSn uint16) { + s.Lock() + defer s.Unlock() + + slot, isValid := s.getSlot(offSn) + if !isValid { + return + } + + s.seq[slot] = nil +} + +func (s *sequencer) getSlot(offSn uint16) (int, bool) { if !s.init { s.headSN = offSn - 1 s.init = true @@ -137,7 +163,7 @@ func (s *sequencer) push(sn, offSn uint16, timeStamp uint32, layer int8) *packet diff := offSn - s.headSN if diff == 0 { // duplicate - return nil + return 0, false } slot := 0 @@ -145,33 +171,32 @@ func (s *sequencer) push(sn, offSn uint16, timeStamp uint32, layer int8) *packet // out-of-order back := int(s.headSN - offSn) if back >= s.max { - s.logger.Debugw("old packet, can not be sequenced", "head", sn, "received", offSn) - return nil + s.logger.Debugw("old packet, can not be sequenced", "head", s.headSN, "received", offSn) + return 0, false } slot = s.step - back - 1 } else { + s.headSN = offSn + + // invalidate intervening slots + for idx := 0; idx < int(diff)-1; idx++ { + s.seq[s.wrap(s.step+idx)] = nil + } + slot = s.step + int(diff) - 1 - s.headSN = offSn // for next packet s.step = s.wrap(s.step + int(diff)) } - slot = s.wrap(slot) - s.seq[slot] = packetMeta{ - sourceSeqNo: sn, - targetSeqNo: offSn, - timestamp: timeStamp, - layer: layer, - } - return &s.seq[slot] + return s.wrap(slot), true } -func (s *sequencer) getPacketsMeta(seqNo []uint16) []packetMeta { +func (s *sequencer) getPacketsMeta(seqNo []uint16) []*packetMeta { s.Lock() defer s.Unlock() - meta := make([]packetMeta, 0, len(seqNo)) + meta := make([]*packetMeta, 0, len(seqNo)) refTime := uint32(time.Now().UnixNano()/1e6 - s.startTime) for _, sn := range seqNo { diff := s.headSN - sn @@ -181,15 +206,15 @@ func (s *sequencer) getPacketsMeta(seqNo []uint16) []packetMeta { } slot := s.wrap(s.step - int(diff) - 1) - seq := &s.seq[slot] - if seq.targetSeqNo != sn { + seq := s.seq[slot] + if seq == nil || seq.targetSeqNo != sn { continue } if seq.lastNack == 0 || refTime-seq.lastNack > uint32(math.Min(float64(ignoreRetransmission), float64(2*s.rtt))) { seq.nacked++ seq.lastNack = refTime - meta = append(meta, *seq) + meta = append(meta, seq) } } diff --git a/pkg/sfu/sequencer_test.go b/pkg/sfu/sequencer_test.go index ffd624fc2..5156b0bee 100644 --- a/pkg/sfu/sequencer_test.go +++ b/pkg/sfu/sequencer_test.go @@ -13,7 +13,7 @@ import ( ) func Test_sequencer(t *testing.T) { - seq := newSequencer(500, logger.GetDefaultLogger()) + seq := newSequencer(500, 0, logger.GetDefaultLogger()) off := uint16(15) for i := uint16(1); i < 518; i++ { @@ -57,8 +57,9 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) { seqNo []uint16 } type fields struct { - input []uint16 - offset uint16 + input []uint16 + padding []uint16 + offset uint16 } tests := []struct { @@ -70,11 +71,12 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) { { name: "Should get correct seq numbers", fields: fields{ - input: []uint16{2, 3, 4, 7, 8}, - offset: 5, + input: []uint16{2, 3, 4, 7, 8}, + padding: []uint16{9, 10}, + offset: 5, }, args: args{ - seqNo: []uint16{4 + 5, 5 + 5, 8 + 5}, + seqNo: []uint16{4 + 5, 5 + 5, 8 + 5, 9 + 5, 10 + 5}, }, want: []uint16{4, 8}, }, @@ -82,11 +84,14 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - n := newSequencer(500, logger.GetDefaultLogger()) + n := newSequencer(500, 0, logger.GetDefaultLogger()) for _, i := range tt.fields.input { n.push(i, i+tt.fields.offset, 123, 3) } + for _, i := range tt.fields.padding { + n.pushPadding(i + tt.fields.offset) + } g := n.getPacketsMeta(tt.args.seqNo) var got []uint16