From fdba70dab74a1dec448c88aad65ac679a55ded9a Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Tue, 29 Aug 2023 21:13:07 +0530 Subject: [PATCH 1/4] Use correct variables (#2010) --- pkg/sfu/buffer/rtpstats.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/sfu/buffer/rtpstats.go b/pkg/sfu/buffer/rtpstats.go index 2fdcbb4c5..3dd477481 100644 --- a/pkg/sfu/buffer/rtpstats.go +++ b/pkg/sfu/buffer/rtpstats.go @@ -1591,8 +1591,8 @@ func (r *RTPStats) getDrift() (packetDrift *livekit.RTPDrift, reportDrift *livek StartTime: timestamppb.New(r.srFirst.NTPTimestamp.Time()), EndTime: timestamppb.New(r.srNewest.NTPTimestamp.Time()), Duration: elapsed.Seconds(), - StartTimestamp: r.timestamp.GetExtendedStart(), - EndTimestamp: r.timestamp.GetExtendedHighest(), + StartTimestamp: r.srFirst.RTPTimestampExt, + EndTimestamp: r.srNewest.RTPTimestampExt, RtpClockTicks: rtpClockTicks, DriftSamples: driftSamples, DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate), From 33b48d986fd8e72fbc09cfa177e461234d8d7a0c Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Tue, 29 Aug 2023 21:45:58 +0530 Subject: [PATCH 2/4] Fix typo, have to check against start (#2011) --- pkg/sfu/buffer/rtpstats.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sfu/buffer/rtpstats.go b/pkg/sfu/buffer/rtpstats.go index 3dd477481..1a45b29b7 100644 --- a/pkg/sfu/buffer/rtpstats.go +++ b/pkg/sfu/buffer/rtpstats.go @@ -565,7 +565,7 @@ func (r *RTPStats) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt uint32 extHighestSNOverridden += (1 << 32) } } - if extHighestSNOverridden < r.sequenceNumber.GetExtendedHighest() { + if extHighestSNOverridden < r.sequenceNumber.GetExtendedStart() { // it is possible that the `LastSequenceNumber` in the receiver report is before the starting // sequence number when dummy packets are used to trigger Pion's OnTrack path. r.lastRRTime = time.Now() From 5e481fe6bfecfd09dc43fdd56dadbf0e871c04c0 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Wed, 30 Aug 2023 15:31:01 +0800 Subject: [PATCH 3/4] Don't create new slice when return broadcast downtracks (#2013) --- pkg/sfu/downtrackspreader.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pkg/sfu/downtrackspreader.go b/pkg/sfu/downtrackspreader.go index dd7ac59c6..3768592f3 100644 --- a/pkg/sfu/downtrackspreader.go +++ b/pkg/sfu/downtrackspreader.go @@ -47,12 +47,7 @@ func NewDownTrackSpreader(params DownTrackSpreaderParams) *DownTrackSpreader { func (d *DownTrackSpreader) GetDownTracks() []TrackSender { d.downTrackMu.RLock() defer d.downTrackMu.RUnlock() - - downTracks := make([]TrackSender, 0, len(d.downTracksShadow)) - for _, dt := range d.downTracksShadow { - downTracks = append(downTracks, dt) - } - return downTracks + return d.downTracksShadow } func (d *DownTrackSpreader) ResetAndGetDownTracks() []TrackSender { From 126872047de72b0434d24e3cde997637414bf963 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Wed, 30 Aug 2023 16:46:39 +0530 Subject: [PATCH 4/4] Handle duplicate padding packet in the up stream. (#2012) * Handle duplicate padding packet in the up stream. The following sequence would have produce incorrect results - Sequence number 39 - regular packet - offset = 0 - Sequence number 40 - padding only - drop - offset = 1 - Sequence number 40 - padding only duplicate - was not dropped (this is the bug) - apply offet - sequence number becomes 39 and clashes with previous packet - Sequence number 41 - regular packet - apply offset - goes through as 40. - Sequence number 40 again - does not get dropped - will pass through as 39. * fix duplicate dropping * fix tests * accept repeat last value as padding injection could cause that * use exclusion ranges * more UT and more specific errors --- pkg/sfu/buffer/buffer.go | 28 ++- pkg/sfu/buffer/rtpstats.go | 6 +- pkg/sfu/forwarder_test.go | 19 +- pkg/sfu/rtpmunger.go | 57 ++++-- pkg/sfu/rtpmunger_test.go | 54 +++--- pkg/sfu/utils/rangemap.go | 110 ++++++----- pkg/sfu/utils/rangemap_test.go | 332 +++++++++++++++++++++++++-------- 7 files changed, 429 insertions(+), 177 deletions(-) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 939f5c6d8..93471804e 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -417,16 +417,35 @@ func (b *Buffer) calc(pkt []byte, arrivalTime time.Time) { flowState := b.updateStreamState(&rtpPacket, arrivalTime) b.processHeaderExtensions(&rtpPacket, arrivalTime) - if !flowState.IsOutOfOrder && len(rtpPacket.Payload) == 0 { - // drop padding only in-order packet - b.snRangeMap.IncValue(1) + if len(rtpPacket.Payload) == 0 && (!flowState.IsOutOfOrder || flowState.IsDuplicate) { + // drop padding only in-order or duplicate packet + if !flowState.IsOutOfOrder { + // in-order packet - increment sequence number offset for subsequent packets + // Example: + // 40 - regular packet - pass through as sequence number 40 + // 41 - missing packet - don't know what it is, could be padding or not + // 42 - padding only packet - in-order - drop - increment sequence number offset to 1 - + // range[0, 42] = 0 offset + // 41 - arrives out of order - get offset 0 from cache - passed through as sequence number 41 + // 43 - regular packet - offset = 1 (running offset) - passes through as sequence number 42 + // 44 - padding only - in order - drop - increment sequence number offset to 2 + // range[0, 42] = 0 offset, range[43, 44] = 1 offset + // 43 - regular packet - out of order + duplicate - offset = 1 from cache - + // adjusted sequence number is 42, will be dropped by RTX buffer AddPacket method as duplicate + // 45 - regular packet - offset = 2 (running offset) - passed through with adjusted sequence number as 43 + // 44 - padding only - out-of-order + duplicate - dropped as duplicate + // + if err := b.snRangeMap.ExcludeRange(flowState.ExtSequenceNumber, flowState.ExtSequenceNumber+1); err != nil { + b.logger.Errorw("could not exclude range", err, "sn", flowState.ExtSequenceNumber) + } + } return } // add to RTX buffer using sequence number after accounting for dropped padding only packets snAdjustment, err := b.snRangeMap.GetValue(flowState.ExtSequenceNumber) if err != nil { - b.logger.Errorw("could not get sequence number adjustment", err) + b.logger.Errorw("could not get sequence number adjustment", err, "sn", flowState.ExtSequenceNumber, "payloadSize", len(rtpPacket.Payload)) return } rtpPacket.Header.SequenceNumber = uint16(flowState.ExtSequenceNumber - snAdjustment) @@ -507,7 +526,6 @@ func (b *Buffer) updateStreamState(p *rtp.Packet, arrivalTime time.Time) RTPFlow b.nacker.Remove(p.SequenceNumber) if flowState.HasLoss { - b.snRangeMap.AddRange(flowState.LossStartInclusive, flowState.LossEndExclusive) for lost := flowState.LossStartInclusive; lost != flowState.LossEndExclusive; lost++ { b.nacker.Push(uint16(lost)) } diff --git a/pkg/sfu/buffer/rtpstats.go b/pkg/sfu/buffer/rtpstats.go index 1a45b29b7..d9f58ba67 100644 --- a/pkg/sfu/buffer/rtpstats.go +++ b/pkg/sfu/buffer/rtpstats.go @@ -62,6 +62,7 @@ type RTPFlowState struct { LossStartInclusive uint64 LossEndExclusive uint64 + IsDuplicate bool IsOutOfOrder bool ExtSequenceNumber uint64 @@ -409,7 +410,6 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa hdrSize := uint64(rtph.MarshalSize()) pktSize := hdrSize + uint64(payloadSize+paddingSize) - isDuplicate := false gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest) if gapSN <= 0 { // duplicate OR out-of-order if payloadSize == 0 { @@ -458,7 +458,7 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa r.bytesDuplicate += pktSize r.headerBytesDuplicate += hdrSize r.packetsDuplicate++ - isDuplicate = true + flowState.IsDuplicate = true } else { r.packetsLost-- r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, true) @@ -492,7 +492,7 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa flowState.ExtTimestamp = resTS.ExtendedVal } - if !isDuplicate { + if !flowState.IsDuplicate { if payloadSize == 0 { r.packetsPadding++ r.bytesPadding += pktSize diff --git a/pkg/sfu/forwarder_test.go b/pkg/sfu/forwarder_test.go index b0e9bc2ea..7f1e936e5 100644 --- a/pkg/sfu/forwarder_test.go +++ b/pkg/sfu/forwarder_test.go @@ -1214,10 +1214,9 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { require.Equal(t, expectedTP, *actualTP) // add a missing sequence number to the cache - f.rtpMunger.snRangeMap.IncValue(10) - f.rtpMunger.snRangeMap.AddRange(23332, 23333) + f.rtpMunger.snRangeMap.ExcludeRange(23332, 23333) - // out-of-order packet not in cache should be dropped + // out-of-order packet should get offset from cache params = &testutils.TestExtPacketParams{ SequenceNumber: 23331, Timestamp: 0xabcdef, @@ -1227,7 +1226,11 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { extPkt, _ = testutils.GetTestExtPacket(params) expectedTP = TranslationParams{ - shouldDrop: true, + rtp: &TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + sequenceNumber: 23331, + timestamp: 0xabcdef, + }, } actualTP, err = f.GetTranslationParams(extPkt, 0) require.NoError(t, err) @@ -1260,7 +1263,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { expectedTP = TranslationParams{ rtp: &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingContiguous, - sequenceNumber: 23324, + sequenceNumber: 23333, timestamp: 0xabcdef, }, } @@ -1279,7 +1282,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { expectedTP = TranslationParams{ rtp: &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingGap, - sequenceNumber: 23326, + sequenceNumber: 23335, timestamp: 0xabcdef, }, } @@ -1299,7 +1302,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { expectedTP = TranslationParams{ rtp: &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingOutOfOrder, - sequenceNumber: 23325, + sequenceNumber: 23334, timestamp: 0xabcdef, }, } @@ -1319,7 +1322,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) { expectedTP = TranslationParams{ rtp: &TranslationParamsRTP{ snOrdering: SequenceNumberOrderingContiguous, - sequenceNumber: 23327, + sequenceNumber: 23336, timestamp: 0xabcdf0, }, } diff --git a/pkg/sfu/rtpmunger.go b/pkg/sfu/rtpmunger.go index 64ac05008..c0ce6ae6a 100644 --- a/pkg/sfu/rtpmunger.go +++ b/pkg/sfu/rtpmunger.go @@ -53,12 +53,13 @@ type SnTs struct { // ---------------------------------------------------------------------- type RTPMungerState struct { - ExtLastSN uint64 - ExtLastTS uint64 + ExtLastSN uint64 + ExtSecondLastSN uint64 + ExtLastTS uint64 } func (r RTPMungerState) String() string { - return fmt.Sprintf("RTPMungerState{extLastSN: %d, extLastTS: %d)", r.ExtLastSN, r.ExtLastTS) + return fmt.Sprintf("RTPMungerState{extLastSN: %d, extSecondLastSN: %d, extLastTS: %d)", r.ExtLastSN, r.ExtSecondLastSN, r.ExtLastTS) } // ---------------------------------------------------------------------- @@ -69,10 +70,11 @@ type RTPMunger struct { extHighestIncomingSN uint64 snRangeMap *utils.RangeMap[uint64, uint64] - extLastSN uint64 - extLastTS uint64 - tsOffset uint64 - lastMarker bool + extLastSN uint64 + extSecondLastSN uint64 + extLastTS uint64 + tsOffset uint64 + lastMarker bool extRtxGateSn uint64 isInRtxGateRegion bool @@ -86,10 +88,11 @@ func NewRTPMunger(logger logger.Logger) *RTPMunger { } func (r *RTPMunger) DebugInfo() map[string]interface{} { - snOffset, _ := r.snRangeMap.GetValue(r.extHighestIncomingSN) + snOffset, _ := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) return map[string]interface{}{ "ExtHighestIncomingSN": r.extHighestIncomingSN, "ExtLastSN": r.extLastSN, + "ExtSecondLastSN": r.extSecondLastSN, "SNOffset": snOffset, "ExtLastTS": r.extLastTS, "TSOffset": r.tsOffset, @@ -99,19 +102,22 @@ func (r *RTPMunger) DebugInfo() map[string]interface{} { func (r *RTPMunger) GetLast() RTPMungerState { return RTPMungerState{ - ExtLastSN: r.extLastSN, - ExtLastTS: r.extLastTS, + ExtLastSN: r.extLastSN, + ExtSecondLastSN: r.extSecondLastSN, + ExtLastTS: r.extLastTS, } } func (r *RTPMunger) SeedLast(state RTPMungerState) { r.extLastSN = state.ExtLastSN + r.extSecondLastSN = state.ExtSecondLastSN r.extLastTS = state.ExtLastTS } func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) { r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 r.extLastSN = extPkt.ExtSequenceNumber + r.extSecondLastSN = r.extLastSN - 1 r.extLastTS = extPkt.ExtTimestamp } @@ -126,15 +132,22 @@ func (r *RTPMunger) PacketDropped(extPkt *buffer.ExtPacket) { return } - r.snRangeMap.IncValue(1) - snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) - if err != nil { - r.logger.Errorw("could not get sequence number offset", err) - return + if err == nil { + outSN := extPkt.ExtSequenceNumber - snOffset + if outSN != r.extLastSN { + r.logger.Warnw("last outgoing sequence number mismatch", nil, "expected", r.extLastSN, "got", outSN) + } + } + if r.extLastSN == r.extSecondLastSN { + r.logger.Warnw("cannot roll back on drop", nil, "extLastSN", r.extLastSN, "secondLastSN", r.extSecondLastSN) } - r.extLastSN = extPkt.ExtSequenceNumber - snOffset + if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil { + r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) + } + + r.extLastSN = r.extSecondLastSN } func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationParamsRTP, error) { @@ -166,14 +179,15 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara ordering := SequenceNumberOrderingContiguous if diff > 1 { ordering = SequenceNumberOrderingGap - r.snRangeMap.AddRange(r.extHighestIncomingSN+1, extPkt.ExtSequenceNumber) } r.extHighestIncomingSN = extPkt.ExtSequenceNumber // if padding only packet, can be dropped and sequence number adjusted, if contiguous if diff == 1 && len(extPkt.Packet.Payload) == 0 { - r.snRangeMap.IncValue(1) + if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil { + r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) + } return &TranslationParamsRTP{ snOrdering: ordering, }, ErrPaddingOnlyPacket @@ -181,6 +195,7 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) if err != nil { + r.logger.Errorw("could not get sequence number adjustment", err, "sn", extPkt.ExtSequenceNumber, "payloadSize", len(extPkt.Packet.Payload)) return &TranslationParamsRTP{ snOrdering: ordering, }, ErrSequenceNumberOffsetNotFound @@ -189,6 +204,7 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara extMungedSN := extPkt.ExtSequenceNumber - snOffset extMungedTS := extPkt.ExtTimestamp - r.tsOffset + r.extSecondLastSN = r.extLastSN r.extLastSN = extMungedSN r.extLastTS = extMungedTS r.lastMarker = extPkt.Packet.Marker @@ -225,6 +241,10 @@ func (r *RTPMunger) FilterRTX(nacks []uint16) []uint16 { } func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate uint32, forceMarker bool, extRtpTimestamp uint64) ([]SnTs, error) { + if num == 0 { + return nil, nil + } + useLastTSForFirst := false tsOffset := 0 if !r.lastMarker { @@ -260,6 +280,7 @@ func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate } } + r.extSecondLastSN = extLastSN - 1 r.extLastSN = extLastSN r.snRangeMap.DecValue(uint64(num)) diff --git a/pkg/sfu/rtpmunger_test.go b/pkg/sfu/rtpmunger_test.go index 2ade34e7b..8e5573100 100644 --- a/pkg/sfu/rtpmunger_test.go +++ b/pkg/sfu/rtpmunger_test.go @@ -122,10 +122,13 @@ func TestPacketDropped(t *testing.T) { extPkt, _ = testutils.GetTestExtPacket(params) r.UpdateAndGetSnTs(extPkt) // update sequence number offset + require.Equal(t, uint64(44444), r.extLastSN) r.PacketDropped(extPkt) - require.Equal(t, uint64(44443), r.extLastSN) + require.Equal(t, uint64(23333), r.extLastSN) snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) require.NoError(t, err) require.Equal(t, uint64(1), snOffset) @@ -158,10 +161,9 @@ func TestOutOfOrderSequenceNumber(t *testing.T) { r.UpdateAndGetSnTs(extPkt) // add a missing sequence number to the cache - r.snRangeMap.IncValue(10) - r.snRangeMap.AddRange(23332, 23333) + r.snRangeMap.ExcludeRange(23332, 23333) - // out-of-order sequence number not in the missing sequence number cache + // out-of-order sequence number should be munged using cache params = &testutils.TestExtPacketParams{ SequenceNumber: 23331, Timestamp: 0xabcdef, @@ -171,12 +173,13 @@ func TestOutOfOrderSequenceNumber(t *testing.T) { extPkt, _ = testutils.GetTestExtPacket(params) tpExpected := TranslationParamsRTP{ - snOrdering: SequenceNumberOrderingOutOfOrder, + snOrdering: SequenceNumberOrderingOutOfOrder, + sequenceNumber: 23331, + timestamp: 0xabcdef, } tp, err := r.UpdateAndGetSnTs(extPkt) - require.Error(t, err) - require.ErrorIs(t, err, ErrOutOfOrderSequenceNumberCacheMiss) + require.NoError(t, err) require.Equal(t, tpExpected, *tp) params = &testutils.TestExtPacketParams{ @@ -188,13 +191,11 @@ func TestOutOfOrderSequenceNumber(t *testing.T) { extPkt, _ = testutils.GetTestExtPacket(params) tpExpected = TranslationParamsRTP{ - snOrdering: SequenceNumberOrderingOutOfOrder, - sequenceNumber: 23322, - timestamp: 0xabcdef, + snOrdering: SequenceNumberOrderingOutOfOrder, } tp, err = r.UpdateAndGetSnTs(extPkt) - require.NoError(t, err) + require.Error(t, err, ErrOutOfOrderSequenceNumberCacheMiss) require.Equal(t, tpExpected, *tp) } @@ -246,8 +247,7 @@ func TestPaddingOnlyPacket(t *testing.T) { require.Equal(t, uint64(23333), r.extHighestIncomingSN) require.Equal(t, uint64(23333), r.extLastSN) snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) - require.NoError(t, err) - require.Equal(t, uint64(1), snOffset) + require.Error(t, err) // padding only packet with a gap should not report an error params = &testutils.TestExtPacketParams{ @@ -313,9 +313,7 @@ func TestGapInSequenceNumber(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(0), snOffset) - // ensure missing sequence numbers got recorded in cache - - // last received, three missing in between and current received should all be in cache + // ensure missing sequence numbers have correct cached offset for i := uint64(65534); i != 65536+1; i++ { offset, err := r.snRangeMap.GetValue(i) require.NoError(t, err) @@ -341,10 +339,9 @@ func TestGapInSequenceNumber(t *testing.T) { require.Equal(t, uint64(65536+2), r.extHighestIncomingSN) require.Equal(t, uint64(65536+1), r.extLastSN) snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) - require.NoError(t, err) - require.Equal(t, uint64(1), snOffset) + require.Error(t, err) - // a packet with a gap should be adding to missing cache + // a packet with a gap should be adjusting for dropped padding packet params = &testutils.TestExtPacketParams{ SequenceNumber: 4, SNCycles: 1, @@ -369,6 +366,11 @@ func TestGapInSequenceNumber(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(1), snOffset) + // ensure missing sequence number has correct cached offset + offset, err := r.snRangeMap.GetValue(65536 + 3) + require.NoError(t, err) + require.Equal(t, uint64(1), offset) + // another contiguous padding only packet should be dropped params = &testutils.TestExtPacketParams{ SequenceNumber: 5, @@ -388,10 +390,9 @@ func TestGapInSequenceNumber(t *testing.T) { require.Equal(t, uint64(65536+5), r.extHighestIncomingSN) require.Equal(t, uint64(65536+3), r.extLastSN) snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) - require.NoError(t, err) - require.Equal(t, uint64(2), snOffset) + require.Error(t, err) - // a packet with a gap should be adding to missing cache + // a packet with a gap should be adjusting for dropped packets params = &testutils.TestExtPacketParams{ SequenceNumber: 7, SNCycles: 1, @@ -416,6 +417,15 @@ func TestGapInSequenceNumber(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(2), snOffset) + // ensure missing sequence number has correct cached offset + offset, err = r.snRangeMap.GetValue(65536 + 3) + require.NoError(t, err) + require.Equal(t, uint64(1), offset) + + offset, err = r.snRangeMap.GetValue(65536 + 6) + require.NoError(t, err) + require.Equal(t, uint64(2), offset) + // check the missing packets params = &testutils.TestExtPacketParams{ SequenceNumber: 6, diff --git a/pkg/sfu/utils/rangemap.go b/pkg/sfu/utils/rangemap.go index acaa4a320..ecf36ec93 100644 --- a/pkg/sfu/utils/rangemap.go +++ b/pkg/sfu/utils/rangemap.go @@ -25,8 +25,10 @@ const ( ) var ( - errReversedOrder = errors.New("end is before start") + errReversedOrder = errors.New("end <= start") errKeyNotFound = errors.New("key not found") + errKeyTooOld = errors.New("key too old") + errKeyExcluded = errors.New("key excluded") ) type rangeType interface { @@ -46,60 +48,68 @@ type rangeVal[RT rangeType, VT valueType] struct { type RangeMap[RT rangeType, VT valueType] struct { halfRange RT - size int - ranges []rangeVal[RT, VT] - runningValue VT + size int + ranges []rangeVal[RT, VT] } func NewRangeMap[RT rangeType, VT valueType](size int) *RangeMap[RT, VT] { var t RT - return &RangeMap[RT, VT]{ + r := &RangeMap[RT, VT]{ halfRange: 1 << ((unsafe.Sizeof(t) * 8) - 1), size: int(math.Max(float64(size), float64(minRanges))), } + r.initRanges(0) + return r } func (r *RangeMap[RT, VT]) ClearAndResetValue(val VT) { - r.ranges = r.ranges[:0] - r.runningValue = val -} - -func (r *RangeMap[RT, VT]) IncValue(inc VT) { - r.runningValue += inc + r.initRanges(val) } func (r *RangeMap[RT, VT]) DecValue(dec VT) { - r.runningValue -= dec + r.ranges[len(r.ranges)-1].value -= dec } -func (r *RangeMap[RT, VT]) AddRange(startInclusive RT, endExclusive RT) error { +func (r *RangeMap[RT, VT]) initRanges(val VT) { + r.ranges = []rangeVal[RT, VT]{ + { + start: 0, + end: 0, + value: val, + }, + } +} + +func (r *RangeMap[RT, VT]) ExcludeRange(startInclusive RT, endExclusive RT) error { if endExclusive == startInclusive || endExclusive-startInclusive > r.halfRange { return errReversedOrder } - isNewRange := true - // check if last range can be extended - if len(r.ranges) != 0 { - lr := &r.ranges[len(r.ranges)-1] - if startInclusive <= lr.end { - return errReversedOrder - } - if lr.value == r.runningValue { - lr.end = endExclusive - 1 - isNewRange = false - } else { - // end last range before start and start a new range - lr.end = startInclusive - 1 - } + lr := &r.ranges[len(r.ranges)-1] + if lr.start > startInclusive { + // start of open range is after start of exclusion range, cannot close the open range + return errReversedOrder } - if isNewRange { - r.ranges = append(r.ranges, rangeVal[RT, VT]{ - start: startInclusive, - end: endExclusive - 1, - value: r.runningValue, - }) + newValue := lr.value + VT(endExclusive-startInclusive) + + // if start of exclusion range matches start of open range, move the open range + if lr.start == startInclusive { + lr.start = endExclusive + lr.value = newValue + return nil } + + // close previous range + lr.end = startInclusive - 1 + + // start new open one after given exclusion range + r.ranges = append(r.ranges, rangeVal[RT, VT]{ + start: endExclusive, + end: 0, + value: newValue, + }) + r.prune() return nil } @@ -107,26 +117,42 @@ func (r *RangeMap[RT, VT]) AddRange(startInclusive RT, endExclusive RT) error { func (r *RangeMap[RT, VT]) GetValue(key RT) (VT, error) { numRanges := len(r.ranges) if numRanges != 0 { - if key > r.ranges[numRanges-1].end { - return r.runningValue, nil + if key >= r.ranges[numRanges-1].start { + // in the open range + return r.ranges[numRanges-1].value, nil } if key < r.ranges[0].start { - return 0, errKeyNotFound + // too old + return 0, errKeyTooOld } } - for _, rv := range r.ranges { - if key-rv.start < r.halfRange && rv.end-key < r.halfRange { - return rv.value, nil + for idx := numRanges - 1; idx >= 0; idx-- { + rv := &r.ranges[idx] + if idx != numRanges-1 { + // open range checked above + if key-rv.start < r.halfRange && rv.end-key < r.halfRange { + return rv.value, nil + } + } + + if idx > 0 { + rvPrev := &r.ranges[idx-1] + beforeDiff := key - rvPrev.end + afterDiff := rv.start - key + if beforeDiff > 0 && beforeDiff < r.halfRange && afterDiff > 0 && afterDiff < r.halfRange { + // in excluded range + return 0, errKeyExcluded + } } } - return r.runningValue, nil + return 0, errKeyNotFound } func (r *RangeMap[RT, VT]) prune() { - if len(r.ranges) > r.size { - r.ranges = r.ranges[len(r.ranges)-r.size:] + if len(r.ranges) > r.size+1 { // +1 to accommodate the open range + r.ranges = r.ranges[len(r.ranges)-r.size-1:] } } diff --git a/pkg/sfu/utils/rangemap_test.go b/pkg/sfu/utils/rangemap_test.go index ef154c1cd..2c1d38f1d 100644 --- a/pkg/sfu/utils/rangemap_test.go +++ b/pkg/sfu/utils/rangemap_test.go @@ -27,103 +27,277 @@ func TestRangeMapUint32(t *testing.T) { value, err := r.GetValue(33333) require.NoError(t, err) require.Equal(t, uint32(0), value) - value, err = r.GetValue(0xffffffff) + + expectedRangeVal := rangeVal[uint32, uint32]{ + start: 0, + end: 0, + value: 0, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + // add an exclusion, should create a new range + err = r.ExcludeRange(10, 11) + require.NoError(t, err) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 0, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 11, + end: 0, + value: 1, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + // getting value in old range should return 0 + value, err = r.GetValue(6) require.NoError(t, err) require.Equal(t, uint32(0), value) - // getting value for any key should be incremented value - r.IncValue(2) - value, err = r.GetValue(66666666) + // newer should return 1 + value, err = r.GetValue(11) require.NoError(t, err) - require.Equal(t, uint32(2), value) - value, err = r.GetValue(0) - require.NoError(t, err) - require.Equal(t, uint32(2), value) + require.Equal(t, uint32(1), value) - // add a couple of ranges, as the value is same should just extend - err = r.AddRange(10, 20) - require.NoError(t, err) - err = r.AddRange(30, 40) - require.NoError(t, err) - require.Equal(t, 1, len(r.ranges)) - require.Equal(t, uint32(10), r.ranges[0].start) - require.Equal(t, uint32(39), r.ranges[0].end) - require.Equal(t, uint32(2), r.ranges[0].value) - - // bump value - r.IncValue(1) - // getting value in previously added range should return 2 - value, err = r.GetValue(22) - require.NoError(t, err) - require.Equal(t, uint32(2), value) - - // outside range should return 3 - value, err = r.GetValue(662) - require.NoError(t, err) - require.Equal(t, uint32(3), value) - - // adding out-of-order range should return error - err = r.AddRange(60, 50) - require.Error(t, err, errReversedOrder) - - // adding overlapping should return error - err = r.AddRange(30, 50) - require.Error(t, err, errReversedOrder) - - // adding a non-overlapping range should extend previous range and add new one - err = r.AddRange(50, 60) - require.NoError(t, err) - require.Equal(t, 2, len(r.ranges)) - - require.Equal(t, uint32(10), r.ranges[0].start) - require.Equal(t, uint32(49), r.ranges[0].end) - require.Equal(t, uint32(2), r.ranges[0].value) - - require.Equal(t, uint32(50), r.ranges[1].start) - require.Equal(t, uint32(59), r.ranges[1].end) - require.Equal(t, uint32(3), r.ranges[1].value) - - // getting an old value should not succeed, but start of first range should return no error - value, err = r.GetValue(9) - require.Error(t, err, errKeyNotFound) + // excluded range should return error value, err = r.GetValue(10) + require.ErrorIs(t, err, errKeyExcluded) + + // out-of-order exclusion should return error + err = r.ExcludeRange(9, 10) + require.ErrorIs(t, err, errReversedOrder) + + // flipped exclusion should return error + err = r.ExcludeRange(12, 11) + require.ErrorIs(t, err, errReversedOrder) + err = r.ExcludeRange(11, 11) + require.ErrorIs(t, err, errReversedOrder) + + // add adjacent exclusion range of length = 1 + err = r.ExcludeRange(11, 12) + require.NoError(t, err) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 0, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 12, + end: 0, + value: 2, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + // excluded range should return error, now is excluded because exclusion range could be extended + value, err = r.GetValue(11) + require.ErrorIs(t, err, errKeyExcluded) + + // getting value in old range should return 0 + value, err = r.GetValue(6) + require.NoError(t, err) + + // newer should return 2 + value, err = r.GetValue(12) require.NoError(t, err) require.Equal(t, uint32(2), value) - // adding another range should prune the first one as size if set to 2 - r.IncValue(10) - err = r.AddRange(1000, 1233) + // add adjacent exclusion range of length = 10 + err = r.ExcludeRange(12, 22) require.NoError(t, err) - require.Equal(t, 2, len(r.ranges)) - require.Equal(t, uint32(50), r.ranges[0].start) - require.Equal(t, uint32(999), r.ranges[0].end) - require.Equal(t, uint32(3), r.ranges[0].value) + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 0, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) - require.Equal(t, uint32(1000), r.ranges[1].start) - require.Equal(t, uint32(1232), r.ranges[1].end) - require.Equal(t, uint32(13), r.ranges[1].value) + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 22, + end: 0, + value: 12, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) - // previously valid range should return key not found after pruning - value, err = r.GetValue(10) - require.Error(t, err, errKeyNotFound) + // excluded range should return error, now is excluded because exclusion range could be extended + value, err = r.GetValue(15) + require.ErrorIs(t, err, errKeyExcluded) - value, err = r.GetValue(999) + // newer should return 12 + value, err = r.GetValue(25) require.NoError(t, err) - require.Equal(t, uint32(3), value) + require.Equal(t, uint32(12), value) - value, err = r.GetValue(1200) + // add a disjoint exclusion of length = 4 + err = r.ExcludeRange(26, 30) require.NoError(t, err) - require.Equal(t, uint32(13), value) - // something newer than what is in ranges should return running value - value, err = r.GetValue(3000) - require.NoError(t, err) - require.Equal(t, uint32(13), value) + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 0, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) - // decrement running value - r.DecValue(23) - value, err = r.GetValue(3000) + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 22, + end: 25, + value: 12, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 30, + end: 0, + value: 16, + } + require.Equal(t, expectedRangeVal, r.ranges[2]) + + // get a value from newly closed range [22, 25] + value, err = r.GetValue(23) require.NoError(t, err) - require.Equal(t, uint32((1<<32)-10), value) + require.Equal(t, uint32(12), value) + + // add a disjoint exclusion of length = 1 + err = r.ExcludeRange(50, 51) + require.NoError(t, err) + + // previously first range would have been pruned due to size limitations + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 22, + end: 25, + value: 12, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 30, + end: 49, + value: 16, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 51, + end: 0, + value: 17, + } + require.Equal(t, expectedRangeVal, r.ranges[2]) + + // excluded range should return error + value, err = r.GetValue(50) + require.ErrorIs(t, err, errKeyExcluded) + value, err = r.GetValue(28) + require.ErrorIs(t, err, errKeyExcluded) + value, err = r.GetValue(17) + require.ErrorIs(t, err, errKeyTooOld) + + // previously valid, but aged out key should return error + value, err = r.GetValue(5) + require.ErrorIs(t, err, errKeyTooOld) + + // valid range access should return values + value, err = r.GetValue(24) + require.NoError(t, err) + require.Equal(t, uint32(12), value) + + value, err = r.GetValue(34) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + value, err = r.GetValue(49) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(17), value) + + // reset + r.ClearAndResetValue(23) + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 0, + value: 23, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(23), value) + + // decrement value and ensure that any key returns that value + r.DecValue(12) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 0, + value: 11, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(11), value) + + // add an exclusion and then decrement value + err = r.ExcludeRange(10, 15) + require.NoError(t, err) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 11, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 15, + end: 0, + value: 16, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + // first range access + value, err = r.GetValue(5) + require.NoError(t, err) + require.Equal(t, uint32(11), value) + + // open range access + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + r.DecValue(6) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 0, + end: 9, + value: 11, + } + require.Equal(t, expectedRangeVal, r.ranges[0]) + + expectedRangeVal = rangeVal[uint32, uint32]{ + start: 15, + end: 0, + value: 10, + } + require.Equal(t, expectedRangeVal, r.ranges[1]) + + // first range access + value, err = r.GetValue(5) + require.NoError(t, err) + require.Equal(t, uint32(11), value) + + // open range access + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(10), value) }