mirror of
https://github.com/livekit/livekit.git
synced 2026-05-14 07:35:17 +00:00
Merge remote-tracking branch 'origin/master' into raja_fr
This commit is contained in:
@@ -417,16 +417,35 @@ func (b *Buffer) calc(pkt []byte, arrivalTime time.Time) {
|
||||
|
||||
flowState := b.updateStreamState(&rtpPacket, arrivalTime)
|
||||
b.processHeaderExtensions(&rtpPacket, arrivalTime)
|
||||
if !flowState.IsOutOfOrder && len(rtpPacket.Payload) == 0 {
|
||||
// drop padding only in-order packet
|
||||
b.snRangeMap.IncValue(1)
|
||||
if len(rtpPacket.Payload) == 0 && (!flowState.IsOutOfOrder || flowState.IsDuplicate) {
|
||||
// drop padding only in-order or duplicate packet
|
||||
if !flowState.IsOutOfOrder {
|
||||
// in-order packet - increment sequence number offset for subsequent packets
|
||||
// Example:
|
||||
// 40 - regular packet - pass through as sequence number 40
|
||||
// 41 - missing packet - don't know what it is, could be padding or not
|
||||
// 42 - padding only packet - in-order - drop - increment sequence number offset to 1 -
|
||||
// range[0, 42] = 0 offset
|
||||
// 41 - arrives out of order - get offset 0 from cache - passed through as sequence number 41
|
||||
// 43 - regular packet - offset = 1 (running offset) - passes through as sequence number 42
|
||||
// 44 - padding only - in order - drop - increment sequence number offset to 2
|
||||
// range[0, 42] = 0 offset, range[43, 44] = 1 offset
|
||||
// 43 - regular packet - out of order + duplicate - offset = 1 from cache -
|
||||
// adjusted sequence number is 42, will be dropped by RTX buffer AddPacket method as duplicate
|
||||
// 45 - regular packet - offset = 2 (running offset) - passed through with adjusted sequence number as 43
|
||||
// 44 - padding only - out-of-order + duplicate - dropped as duplicate
|
||||
//
|
||||
if err := b.snRangeMap.ExcludeRange(flowState.ExtSequenceNumber, flowState.ExtSequenceNumber+1); err != nil {
|
||||
b.logger.Errorw("could not exclude range", err, "sn", flowState.ExtSequenceNumber)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// add to RTX buffer using sequence number after accounting for dropped padding only packets
|
||||
snAdjustment, err := b.snRangeMap.GetValue(flowState.ExtSequenceNumber)
|
||||
if err != nil {
|
||||
b.logger.Errorw("could not get sequence number adjustment", err)
|
||||
b.logger.Errorw("could not get sequence number adjustment", err, "sn", flowState.ExtSequenceNumber, "payloadSize", len(rtpPacket.Payload))
|
||||
return
|
||||
}
|
||||
rtpPacket.Header.SequenceNumber = uint16(flowState.ExtSequenceNumber - snAdjustment)
|
||||
@@ -516,7 +535,6 @@ func (b *Buffer) updateStreamState(p *rtp.Packet, arrivalTime time.Time) RTPFlow
|
||||
b.nacker.Remove(p.SequenceNumber)
|
||||
|
||||
if flowState.HasLoss {
|
||||
b.snRangeMap.AddRange(flowState.LossStartInclusive, flowState.LossEndExclusive)
|
||||
for lost := flowState.LossStartInclusive; lost != flowState.LossEndExclusive; lost++ {
|
||||
b.nacker.Push(uint16(lost))
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ type RTPFlowState struct {
|
||||
LossStartInclusive uint64
|
||||
LossEndExclusive uint64
|
||||
|
||||
IsDuplicate bool
|
||||
IsOutOfOrder bool
|
||||
|
||||
ExtSequenceNumber uint64
|
||||
@@ -409,7 +410,6 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa
|
||||
|
||||
hdrSize := uint64(rtph.MarshalSize())
|
||||
pktSize := hdrSize + uint64(payloadSize+paddingSize)
|
||||
isDuplicate := false
|
||||
gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest)
|
||||
if gapSN <= 0 { // duplicate OR out-of-order
|
||||
if payloadSize == 0 {
|
||||
@@ -458,7 +458,7 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa
|
||||
r.bytesDuplicate += pktSize
|
||||
r.headerBytesDuplicate += hdrSize
|
||||
r.packetsDuplicate++
|
||||
isDuplicate = true
|
||||
flowState.IsDuplicate = true
|
||||
} else {
|
||||
r.packetsLost--
|
||||
r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, true)
|
||||
@@ -492,7 +492,7 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa
|
||||
flowState.ExtTimestamp = resTS.ExtendedVal
|
||||
}
|
||||
|
||||
if !isDuplicate {
|
||||
if !flowState.IsDuplicate {
|
||||
if payloadSize == 0 {
|
||||
r.packetsPadding++
|
||||
r.bytesPadding += pktSize
|
||||
@@ -565,7 +565,7 @@ func (r *RTPStats) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt uint32
|
||||
extHighestSNOverridden += (1 << 32)
|
||||
}
|
||||
}
|
||||
if extHighestSNOverridden < r.sequenceNumber.GetExtendedHighest() {
|
||||
if extHighestSNOverridden < r.sequenceNumber.GetExtendedStart() {
|
||||
// it is possible that the `LastSequenceNumber` in the receiver report is before the starting
|
||||
// sequence number when dummy packets are used to trigger Pion's OnTrack path.
|
||||
r.lastRRTime = time.Now()
|
||||
@@ -1591,8 +1591,8 @@ func (r *RTPStats) getDrift() (packetDrift *livekit.RTPDrift, reportDrift *livek
|
||||
StartTime: timestamppb.New(r.srFirst.NTPTimestamp.Time()),
|
||||
EndTime: timestamppb.New(r.srNewest.NTPTimestamp.Time()),
|
||||
Duration: elapsed.Seconds(),
|
||||
StartTimestamp: r.timestamp.GetExtendedStart(),
|
||||
EndTimestamp: r.timestamp.GetExtendedHighest(),
|
||||
StartTimestamp: r.srFirst.RTPTimestampExt,
|
||||
EndTimestamp: r.srNewest.RTPTimestampExt,
|
||||
RtpClockTicks: rtpClockTicks,
|
||||
DriftSamples: driftSamples,
|
||||
DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate),
|
||||
|
||||
@@ -47,12 +47,7 @@ func NewDownTrackSpreader(params DownTrackSpreaderParams) *DownTrackSpreader {
|
||||
func (d *DownTrackSpreader) GetDownTracks() []TrackSender {
|
||||
d.downTrackMu.RLock()
|
||||
defer d.downTrackMu.RUnlock()
|
||||
|
||||
downTracks := make([]TrackSender, 0, len(d.downTracksShadow))
|
||||
for _, dt := range d.downTracksShadow {
|
||||
downTracks = append(downTracks, dt)
|
||||
}
|
||||
return downTracks
|
||||
return d.downTracksShadow
|
||||
}
|
||||
|
||||
func (d *DownTrackSpreader) ResetAndGetDownTracks() []TrackSender {
|
||||
|
||||
@@ -1214,10 +1214,9 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
require.Equal(t, expectedTP, *actualTP)
|
||||
|
||||
// add a missing sequence number to the cache
|
||||
f.rtpMunger.snRangeMap.IncValue(10)
|
||||
f.rtpMunger.snRangeMap.AddRange(23332, 23333)
|
||||
f.rtpMunger.snRangeMap.ExcludeRange(23332, 23333)
|
||||
|
||||
// out-of-order packet not in cache should be dropped
|
||||
// out-of-order packet should get offset from cache
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 23331,
|
||||
Timestamp: 0xabcdef,
|
||||
@@ -1227,7 +1226,11 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
extPkt, _ = testutils.GetTestExtPacket(params)
|
||||
|
||||
expectedTP = TranslationParams{
|
||||
shouldDrop: true,
|
||||
rtp: &TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
sequenceNumber: 23331,
|
||||
timestamp: 0xabcdef,
|
||||
},
|
||||
}
|
||||
actualTP, err = f.GetTranslationParams(extPkt, 0)
|
||||
require.NoError(t, err)
|
||||
@@ -1260,7 +1263,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
expectedTP = TranslationParams{
|
||||
rtp: &TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingContiguous,
|
||||
sequenceNumber: 23324,
|
||||
sequenceNumber: 23333,
|
||||
timestamp: 0xabcdef,
|
||||
},
|
||||
}
|
||||
@@ -1279,7 +1282,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
expectedTP = TranslationParams{
|
||||
rtp: &TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingGap,
|
||||
sequenceNumber: 23326,
|
||||
sequenceNumber: 23335,
|
||||
timestamp: 0xabcdef,
|
||||
},
|
||||
}
|
||||
@@ -1299,7 +1302,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
expectedTP = TranslationParams{
|
||||
rtp: &TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
sequenceNumber: 23325,
|
||||
sequenceNumber: 23334,
|
||||
timestamp: 0xabcdef,
|
||||
},
|
||||
}
|
||||
@@ -1319,7 +1322,7 @@ func TestForwarderGetTranslationParamsAudio(t *testing.T) {
|
||||
expectedTP = TranslationParams{
|
||||
rtp: &TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingContiguous,
|
||||
sequenceNumber: 23327,
|
||||
sequenceNumber: 23336,
|
||||
timestamp: 0xabcdf0,
|
||||
},
|
||||
}
|
||||
|
||||
+39
-18
@@ -53,12 +53,13 @@ type SnTs struct {
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
type RTPMungerState struct {
|
||||
ExtLastSN uint64
|
||||
ExtLastTS uint64
|
||||
ExtLastSN uint64
|
||||
ExtSecondLastSN uint64
|
||||
ExtLastTS uint64
|
||||
}
|
||||
|
||||
func (r RTPMungerState) String() string {
|
||||
return fmt.Sprintf("RTPMungerState{extLastSN: %d, extLastTS: %d)", r.ExtLastSN, r.ExtLastTS)
|
||||
return fmt.Sprintf("RTPMungerState{extLastSN: %d, extSecondLastSN: %d, extLastTS: %d)", r.ExtLastSN, r.ExtSecondLastSN, r.ExtLastTS)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
@@ -69,10 +70,11 @@ type RTPMunger struct {
|
||||
extHighestIncomingSN uint64
|
||||
snRangeMap *utils.RangeMap[uint64, uint64]
|
||||
|
||||
extLastSN uint64
|
||||
extLastTS uint64
|
||||
tsOffset uint64
|
||||
lastMarker bool
|
||||
extLastSN uint64
|
||||
extSecondLastSN uint64
|
||||
extLastTS uint64
|
||||
tsOffset uint64
|
||||
lastMarker bool
|
||||
|
||||
extRtxGateSn uint64
|
||||
isInRtxGateRegion bool
|
||||
@@ -86,10 +88,11 @@ func NewRTPMunger(logger logger.Logger) *RTPMunger {
|
||||
}
|
||||
|
||||
func (r *RTPMunger) DebugInfo() map[string]interface{} {
|
||||
snOffset, _ := r.snRangeMap.GetValue(r.extHighestIncomingSN)
|
||||
snOffset, _ := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1)
|
||||
return map[string]interface{}{
|
||||
"ExtHighestIncomingSN": r.extHighestIncomingSN,
|
||||
"ExtLastSN": r.extLastSN,
|
||||
"ExtSecondLastSN": r.extSecondLastSN,
|
||||
"SNOffset": snOffset,
|
||||
"ExtLastTS": r.extLastTS,
|
||||
"TSOffset": r.tsOffset,
|
||||
@@ -99,19 +102,22 @@ func (r *RTPMunger) DebugInfo() map[string]interface{} {
|
||||
|
||||
func (r *RTPMunger) GetLast() RTPMungerState {
|
||||
return RTPMungerState{
|
||||
ExtLastSN: r.extLastSN,
|
||||
ExtLastTS: r.extLastTS,
|
||||
ExtLastSN: r.extLastSN,
|
||||
ExtSecondLastSN: r.extSecondLastSN,
|
||||
ExtLastTS: r.extLastTS,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RTPMunger) SeedLast(state RTPMungerState) {
|
||||
r.extLastSN = state.ExtLastSN
|
||||
r.extSecondLastSN = state.ExtSecondLastSN
|
||||
r.extLastTS = state.ExtLastTS
|
||||
}
|
||||
|
||||
func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) {
|
||||
r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1
|
||||
r.extLastSN = extPkt.ExtSequenceNumber
|
||||
r.extSecondLastSN = r.extLastSN - 1
|
||||
r.extLastTS = extPkt.ExtTimestamp
|
||||
}
|
||||
|
||||
@@ -126,15 +132,22 @@ func (r *RTPMunger) PacketDropped(extPkt *buffer.ExtPacket) {
|
||||
return
|
||||
}
|
||||
|
||||
r.snRangeMap.IncValue(1)
|
||||
|
||||
snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber)
|
||||
if err != nil {
|
||||
r.logger.Errorw("could not get sequence number offset", err)
|
||||
return
|
||||
if err == nil {
|
||||
outSN := extPkt.ExtSequenceNumber - snOffset
|
||||
if outSN != r.extLastSN {
|
||||
r.logger.Warnw("last outgoing sequence number mismatch", nil, "expected", r.extLastSN, "got", outSN)
|
||||
}
|
||||
}
|
||||
if r.extLastSN == r.extSecondLastSN {
|
||||
r.logger.Warnw("cannot roll back on drop", nil, "extLastSN", r.extLastSN, "secondLastSN", r.extSecondLastSN)
|
||||
}
|
||||
|
||||
r.extLastSN = extPkt.ExtSequenceNumber - snOffset
|
||||
if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil {
|
||||
r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN)
|
||||
}
|
||||
|
||||
r.extLastSN = r.extSecondLastSN
|
||||
}
|
||||
|
||||
func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationParamsRTP, error) {
|
||||
@@ -166,14 +179,15 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara
|
||||
ordering := SequenceNumberOrderingContiguous
|
||||
if diff > 1 {
|
||||
ordering = SequenceNumberOrderingGap
|
||||
r.snRangeMap.AddRange(r.extHighestIncomingSN+1, extPkt.ExtSequenceNumber)
|
||||
}
|
||||
|
||||
r.extHighestIncomingSN = extPkt.ExtSequenceNumber
|
||||
|
||||
// if padding only packet, can be dropped and sequence number adjusted, if contiguous
|
||||
if diff == 1 && len(extPkt.Packet.Payload) == 0 {
|
||||
r.snRangeMap.IncValue(1)
|
||||
if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil {
|
||||
r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN)
|
||||
}
|
||||
return &TranslationParamsRTP{
|
||||
snOrdering: ordering,
|
||||
}, ErrPaddingOnlyPacket
|
||||
@@ -181,6 +195,7 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara
|
||||
|
||||
snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber)
|
||||
if err != nil {
|
||||
r.logger.Errorw("could not get sequence number adjustment", err, "sn", extPkt.ExtSequenceNumber, "payloadSize", len(extPkt.Packet.Payload))
|
||||
return &TranslationParamsRTP{
|
||||
snOrdering: ordering,
|
||||
}, ErrSequenceNumberOffsetNotFound
|
||||
@@ -189,6 +204,7 @@ func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket) (*TranslationPara
|
||||
extMungedSN := extPkt.ExtSequenceNumber - snOffset
|
||||
extMungedTS := extPkt.ExtTimestamp - r.tsOffset
|
||||
|
||||
r.extSecondLastSN = r.extLastSN
|
||||
r.extLastSN = extMungedSN
|
||||
r.extLastTS = extMungedTS
|
||||
r.lastMarker = extPkt.Packet.Marker
|
||||
@@ -225,6 +241,10 @@ func (r *RTPMunger) FilterRTX(nacks []uint16) []uint16 {
|
||||
}
|
||||
|
||||
func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate uint32, forceMarker bool, extRtpTimestamp uint64) ([]SnTs, error) {
|
||||
if num == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
useLastTSForFirst := false
|
||||
tsOffset := 0
|
||||
if !r.lastMarker {
|
||||
@@ -260,6 +280,7 @@ func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate
|
||||
}
|
||||
}
|
||||
|
||||
r.extSecondLastSN = extLastSN - 1
|
||||
r.extLastSN = extLastSN
|
||||
r.snRangeMap.DecValue(uint64(num))
|
||||
|
||||
|
||||
+32
-22
@@ -122,10 +122,13 @@ func TestPacketDropped(t *testing.T) {
|
||||
extPkt, _ = testutils.GetTestExtPacket(params)
|
||||
|
||||
r.UpdateAndGetSnTs(extPkt) // update sequence number offset
|
||||
require.Equal(t, uint64(44444), r.extLastSN)
|
||||
|
||||
r.PacketDropped(extPkt)
|
||||
require.Equal(t, uint64(44443), r.extLastSN)
|
||||
require.Equal(t, uint64(23333), r.extLastSN)
|
||||
snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN)
|
||||
require.Error(t, err)
|
||||
snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN + 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), snOffset)
|
||||
|
||||
@@ -158,10 +161,9 @@ func TestOutOfOrderSequenceNumber(t *testing.T) {
|
||||
r.UpdateAndGetSnTs(extPkt)
|
||||
|
||||
// add a missing sequence number to the cache
|
||||
r.snRangeMap.IncValue(10)
|
||||
r.snRangeMap.AddRange(23332, 23333)
|
||||
r.snRangeMap.ExcludeRange(23332, 23333)
|
||||
|
||||
// out-of-order sequence number not in the missing sequence number cache
|
||||
// out-of-order sequence number should be munged using cache
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 23331,
|
||||
Timestamp: 0xabcdef,
|
||||
@@ -171,12 +173,13 @@ func TestOutOfOrderSequenceNumber(t *testing.T) {
|
||||
extPkt, _ = testutils.GetTestExtPacket(params)
|
||||
|
||||
tpExpected := TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
sequenceNumber: 23331,
|
||||
timestamp: 0xabcdef,
|
||||
}
|
||||
|
||||
tp, err := r.UpdateAndGetSnTs(extPkt)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrOutOfOrderSequenceNumberCacheMiss)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tpExpected, *tp)
|
||||
|
||||
params = &testutils.TestExtPacketParams{
|
||||
@@ -188,13 +191,11 @@ func TestOutOfOrderSequenceNumber(t *testing.T) {
|
||||
extPkt, _ = testutils.GetTestExtPacket(params)
|
||||
|
||||
tpExpected = TranslationParamsRTP{
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
sequenceNumber: 23322,
|
||||
timestamp: 0xabcdef,
|
||||
snOrdering: SequenceNumberOrderingOutOfOrder,
|
||||
}
|
||||
|
||||
tp, err = r.UpdateAndGetSnTs(extPkt)
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err, ErrOutOfOrderSequenceNumberCacheMiss)
|
||||
require.Equal(t, tpExpected, *tp)
|
||||
}
|
||||
|
||||
@@ -246,8 +247,7 @@ func TestPaddingOnlyPacket(t *testing.T) {
|
||||
require.Equal(t, uint64(23333), r.extHighestIncomingSN)
|
||||
require.Equal(t, uint64(23333), r.extLastSN)
|
||||
snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), snOffset)
|
||||
require.Error(t, err)
|
||||
|
||||
// padding only packet with a gap should not report an error
|
||||
params = &testutils.TestExtPacketParams{
|
||||
@@ -313,9 +313,7 @@ func TestGapInSequenceNumber(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(0), snOffset)
|
||||
|
||||
// ensure missing sequence numbers got recorded in cache
|
||||
|
||||
// last received, three missing in between and current received should all be in cache
|
||||
// ensure missing sequence numbers have correct cached offset
|
||||
for i := uint64(65534); i != 65536+1; i++ {
|
||||
offset, err := r.snRangeMap.GetValue(i)
|
||||
require.NoError(t, err)
|
||||
@@ -341,10 +339,9 @@ func TestGapInSequenceNumber(t *testing.T) {
|
||||
require.Equal(t, uint64(65536+2), r.extHighestIncomingSN)
|
||||
require.Equal(t, uint64(65536+1), r.extLastSN)
|
||||
snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), snOffset)
|
||||
require.Error(t, err)
|
||||
|
||||
// a packet with a gap should be adding to missing cache
|
||||
// a packet with a gap should be adjusting for dropped padding packet
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 4,
|
||||
SNCycles: 1,
|
||||
@@ -369,6 +366,11 @@ func TestGapInSequenceNumber(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), snOffset)
|
||||
|
||||
// ensure missing sequence number has correct cached offset
|
||||
offset, err := r.snRangeMap.GetValue(65536 + 3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), offset)
|
||||
|
||||
// another contiguous padding only packet should be dropped
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 5,
|
||||
@@ -388,10 +390,9 @@ func TestGapInSequenceNumber(t *testing.T) {
|
||||
require.Equal(t, uint64(65536+5), r.extHighestIncomingSN)
|
||||
require.Equal(t, uint64(65536+3), r.extLastSN)
|
||||
snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(2), snOffset)
|
||||
require.Error(t, err)
|
||||
|
||||
// a packet with a gap should be adding to missing cache
|
||||
// a packet with a gap should be adjusting for dropped packets
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 7,
|
||||
SNCycles: 1,
|
||||
@@ -416,6 +417,15 @@ func TestGapInSequenceNumber(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(2), snOffset)
|
||||
|
||||
// ensure missing sequence number has correct cached offset
|
||||
offset, err = r.snRangeMap.GetValue(65536 + 3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(1), offset)
|
||||
|
||||
offset, err = r.snRangeMap.GetValue(65536 + 6)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(2), offset)
|
||||
|
||||
// check the missing packets
|
||||
params = &testutils.TestExtPacketParams{
|
||||
SequenceNumber: 6,
|
||||
|
||||
+68
-42
@@ -25,8 +25,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
errReversedOrder = errors.New("end is before start")
|
||||
errReversedOrder = errors.New("end <= start")
|
||||
errKeyNotFound = errors.New("key not found")
|
||||
errKeyTooOld = errors.New("key too old")
|
||||
errKeyExcluded = errors.New("key excluded")
|
||||
)
|
||||
|
||||
type rangeType interface {
|
||||
@@ -46,60 +48,68 @@ type rangeVal[RT rangeType, VT valueType] struct {
|
||||
type RangeMap[RT rangeType, VT valueType] struct {
|
||||
halfRange RT
|
||||
|
||||
size int
|
||||
ranges []rangeVal[RT, VT]
|
||||
runningValue VT
|
||||
size int
|
||||
ranges []rangeVal[RT, VT]
|
||||
}
|
||||
|
||||
func NewRangeMap[RT rangeType, VT valueType](size int) *RangeMap[RT, VT] {
|
||||
var t RT
|
||||
return &RangeMap[RT, VT]{
|
||||
r := &RangeMap[RT, VT]{
|
||||
halfRange: 1 << ((unsafe.Sizeof(t) * 8) - 1),
|
||||
size: int(math.Max(float64(size), float64(minRanges))),
|
||||
}
|
||||
r.initRanges(0)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) ClearAndResetValue(val VT) {
|
||||
r.ranges = r.ranges[:0]
|
||||
r.runningValue = val
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) IncValue(inc VT) {
|
||||
r.runningValue += inc
|
||||
r.initRanges(val)
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) DecValue(dec VT) {
|
||||
r.runningValue -= dec
|
||||
r.ranges[len(r.ranges)-1].value -= dec
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) AddRange(startInclusive RT, endExclusive RT) error {
|
||||
func (r *RangeMap[RT, VT]) initRanges(val VT) {
|
||||
r.ranges = []rangeVal[RT, VT]{
|
||||
{
|
||||
start: 0,
|
||||
end: 0,
|
||||
value: val,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) ExcludeRange(startInclusive RT, endExclusive RT) error {
|
||||
if endExclusive == startInclusive || endExclusive-startInclusive > r.halfRange {
|
||||
return errReversedOrder
|
||||
}
|
||||
|
||||
isNewRange := true
|
||||
// check if last range can be extended
|
||||
if len(r.ranges) != 0 {
|
||||
lr := &r.ranges[len(r.ranges)-1]
|
||||
if startInclusive <= lr.end {
|
||||
return errReversedOrder
|
||||
}
|
||||
if lr.value == r.runningValue {
|
||||
lr.end = endExclusive - 1
|
||||
isNewRange = false
|
||||
} else {
|
||||
// end last range before start and start a new range
|
||||
lr.end = startInclusive - 1
|
||||
}
|
||||
lr := &r.ranges[len(r.ranges)-1]
|
||||
if lr.start > startInclusive {
|
||||
// start of open range is after start of exclusion range, cannot close the open range
|
||||
return errReversedOrder
|
||||
}
|
||||
|
||||
if isNewRange {
|
||||
r.ranges = append(r.ranges, rangeVal[RT, VT]{
|
||||
start: startInclusive,
|
||||
end: endExclusive - 1,
|
||||
value: r.runningValue,
|
||||
})
|
||||
newValue := lr.value + VT(endExclusive-startInclusive)
|
||||
|
||||
// if start of exclusion range matches start of open range, move the open range
|
||||
if lr.start == startInclusive {
|
||||
lr.start = endExclusive
|
||||
lr.value = newValue
|
||||
return nil
|
||||
}
|
||||
|
||||
// close previous range
|
||||
lr.end = startInclusive - 1
|
||||
|
||||
// start new open one after given exclusion range
|
||||
r.ranges = append(r.ranges, rangeVal[RT, VT]{
|
||||
start: endExclusive,
|
||||
end: 0,
|
||||
value: newValue,
|
||||
})
|
||||
|
||||
r.prune()
|
||||
return nil
|
||||
}
|
||||
@@ -107,26 +117,42 @@ func (r *RangeMap[RT, VT]) AddRange(startInclusive RT, endExclusive RT) error {
|
||||
func (r *RangeMap[RT, VT]) GetValue(key RT) (VT, error) {
|
||||
numRanges := len(r.ranges)
|
||||
if numRanges != 0 {
|
||||
if key > r.ranges[numRanges-1].end {
|
||||
return r.runningValue, nil
|
||||
if key >= r.ranges[numRanges-1].start {
|
||||
// in the open range
|
||||
return r.ranges[numRanges-1].value, nil
|
||||
}
|
||||
|
||||
if key < r.ranges[0].start {
|
||||
return 0, errKeyNotFound
|
||||
// too old
|
||||
return 0, errKeyTooOld
|
||||
}
|
||||
}
|
||||
|
||||
for _, rv := range r.ranges {
|
||||
if key-rv.start < r.halfRange && rv.end-key < r.halfRange {
|
||||
return rv.value, nil
|
||||
for idx := numRanges - 1; idx >= 0; idx-- {
|
||||
rv := &r.ranges[idx]
|
||||
if idx != numRanges-1 {
|
||||
// open range checked above
|
||||
if key-rv.start < r.halfRange && rv.end-key < r.halfRange {
|
||||
return rv.value, nil
|
||||
}
|
||||
}
|
||||
|
||||
if idx > 0 {
|
||||
rvPrev := &r.ranges[idx-1]
|
||||
beforeDiff := key - rvPrev.end
|
||||
afterDiff := rv.start - key
|
||||
if beforeDiff > 0 && beforeDiff < r.halfRange && afterDiff > 0 && afterDiff < r.halfRange {
|
||||
// in excluded range
|
||||
return 0, errKeyExcluded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.runningValue, nil
|
||||
return 0, errKeyNotFound
|
||||
}
|
||||
|
||||
func (r *RangeMap[RT, VT]) prune() {
|
||||
if len(r.ranges) > r.size {
|
||||
r.ranges = r.ranges[len(r.ranges)-r.size:]
|
||||
if len(r.ranges) > r.size+1 { // +1 to accommodate the open range
|
||||
r.ranges = r.ranges[len(r.ranges)-r.size-1:]
|
||||
}
|
||||
}
|
||||
|
||||
+253
-79
@@ -27,103 +27,277 @@ func TestRangeMapUint32(t *testing.T) {
|
||||
value, err := r.GetValue(33333)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(0), value)
|
||||
value, err = r.GetValue(0xffffffff)
|
||||
|
||||
expectedRangeVal := rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 0,
|
||||
value: 0,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
// add an exclusion, should create a new range
|
||||
err = r.ExcludeRange(10, 11)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 0,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 11,
|
||||
end: 0,
|
||||
value: 1,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
// getting value in old range should return 0
|
||||
value, err = r.GetValue(6)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(0), value)
|
||||
|
||||
// getting value for any key should be incremented value
|
||||
r.IncValue(2)
|
||||
value, err = r.GetValue(66666666)
|
||||
// newer should return 1
|
||||
value, err = r.GetValue(11)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), value)
|
||||
value, err = r.GetValue(0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), value)
|
||||
require.Equal(t, uint32(1), value)
|
||||
|
||||
// add a couple of ranges, as the value is same should just extend
|
||||
err = r.AddRange(10, 20)
|
||||
require.NoError(t, err)
|
||||
err = r.AddRange(30, 40)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(r.ranges))
|
||||
require.Equal(t, uint32(10), r.ranges[0].start)
|
||||
require.Equal(t, uint32(39), r.ranges[0].end)
|
||||
require.Equal(t, uint32(2), r.ranges[0].value)
|
||||
|
||||
// bump value
|
||||
r.IncValue(1)
|
||||
// getting value in previously added range should return 2
|
||||
value, err = r.GetValue(22)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), value)
|
||||
|
||||
// outside range should return 3
|
||||
value, err = r.GetValue(662)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), value)
|
||||
|
||||
// adding out-of-order range should return error
|
||||
err = r.AddRange(60, 50)
|
||||
require.Error(t, err, errReversedOrder)
|
||||
|
||||
// adding overlapping should return error
|
||||
err = r.AddRange(30, 50)
|
||||
require.Error(t, err, errReversedOrder)
|
||||
|
||||
// adding a non-overlapping range should extend previous range and add new one
|
||||
err = r.AddRange(50, 60)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(r.ranges))
|
||||
|
||||
require.Equal(t, uint32(10), r.ranges[0].start)
|
||||
require.Equal(t, uint32(49), r.ranges[0].end)
|
||||
require.Equal(t, uint32(2), r.ranges[0].value)
|
||||
|
||||
require.Equal(t, uint32(50), r.ranges[1].start)
|
||||
require.Equal(t, uint32(59), r.ranges[1].end)
|
||||
require.Equal(t, uint32(3), r.ranges[1].value)
|
||||
|
||||
// getting an old value should not succeed, but start of first range should return no error
|
||||
value, err = r.GetValue(9)
|
||||
require.Error(t, err, errKeyNotFound)
|
||||
// excluded range should return error
|
||||
value, err = r.GetValue(10)
|
||||
require.ErrorIs(t, err, errKeyExcluded)
|
||||
|
||||
// out-of-order exclusion should return error
|
||||
err = r.ExcludeRange(9, 10)
|
||||
require.ErrorIs(t, err, errReversedOrder)
|
||||
|
||||
// flipped exclusion should return error
|
||||
err = r.ExcludeRange(12, 11)
|
||||
require.ErrorIs(t, err, errReversedOrder)
|
||||
err = r.ExcludeRange(11, 11)
|
||||
require.ErrorIs(t, err, errReversedOrder)
|
||||
|
||||
// add adjacent exclusion range of length = 1
|
||||
err = r.ExcludeRange(11, 12)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 0,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 12,
|
||||
end: 0,
|
||||
value: 2,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
// excluded range should return error, now is excluded because exclusion range could be extended
|
||||
value, err = r.GetValue(11)
|
||||
require.ErrorIs(t, err, errKeyExcluded)
|
||||
|
||||
// getting value in old range should return 0
|
||||
value, err = r.GetValue(6)
|
||||
require.NoError(t, err)
|
||||
|
||||
// newer should return 2
|
||||
value, err = r.GetValue(12)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), value)
|
||||
|
||||
// adding another range should prune the first one as size if set to 2
|
||||
r.IncValue(10)
|
||||
err = r.AddRange(1000, 1233)
|
||||
// add adjacent exclusion range of length = 10
|
||||
err = r.ExcludeRange(12, 22)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, len(r.ranges))
|
||||
|
||||
require.Equal(t, uint32(50), r.ranges[0].start)
|
||||
require.Equal(t, uint32(999), r.ranges[0].end)
|
||||
require.Equal(t, uint32(3), r.ranges[0].value)
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 0,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
require.Equal(t, uint32(1000), r.ranges[1].start)
|
||||
require.Equal(t, uint32(1232), r.ranges[1].end)
|
||||
require.Equal(t, uint32(13), r.ranges[1].value)
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 22,
|
||||
end: 0,
|
||||
value: 12,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
// previously valid range should return key not found after pruning
|
||||
value, err = r.GetValue(10)
|
||||
require.Error(t, err, errKeyNotFound)
|
||||
// excluded range should return error, now is excluded because exclusion range could be extended
|
||||
value, err = r.GetValue(15)
|
||||
require.ErrorIs(t, err, errKeyExcluded)
|
||||
|
||||
value, err = r.GetValue(999)
|
||||
// newer should return 12
|
||||
value, err = r.GetValue(25)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), value)
|
||||
require.Equal(t, uint32(12), value)
|
||||
|
||||
value, err = r.GetValue(1200)
|
||||
// add a disjoint exclusion of length = 4
|
||||
err = r.ExcludeRange(26, 30)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(13), value)
|
||||
|
||||
// something newer than what is in ranges should return running value
|
||||
value, err = r.GetValue(3000)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(13), value)
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 0,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
// decrement running value
|
||||
r.DecValue(23)
|
||||
value, err = r.GetValue(3000)
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 22,
|
||||
end: 25,
|
||||
value: 12,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 30,
|
||||
end: 0,
|
||||
value: 16,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[2])
|
||||
|
||||
// get a value from newly closed range [22, 25]
|
||||
value, err = r.GetValue(23)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32((1<<32)-10), value)
|
||||
require.Equal(t, uint32(12), value)
|
||||
|
||||
// add a disjoint exclusion of length = 1
|
||||
err = r.ExcludeRange(50, 51)
|
||||
require.NoError(t, err)
|
||||
|
||||
// previously first range would have been pruned due to size limitations
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 22,
|
||||
end: 25,
|
||||
value: 12,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 30,
|
||||
end: 49,
|
||||
value: 16,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 51,
|
||||
end: 0,
|
||||
value: 17,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[2])
|
||||
|
||||
// excluded range should return error
|
||||
value, err = r.GetValue(50)
|
||||
require.ErrorIs(t, err, errKeyExcluded)
|
||||
value, err = r.GetValue(28)
|
||||
require.ErrorIs(t, err, errKeyExcluded)
|
||||
value, err = r.GetValue(17)
|
||||
require.ErrorIs(t, err, errKeyTooOld)
|
||||
|
||||
// previously valid, but aged out key should return error
|
||||
value, err = r.GetValue(5)
|
||||
require.ErrorIs(t, err, errKeyTooOld)
|
||||
|
||||
// valid range access should return values
|
||||
value, err = r.GetValue(24)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(12), value)
|
||||
|
||||
value, err = r.GetValue(34)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(16), value)
|
||||
|
||||
value, err = r.GetValue(49)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(16), value)
|
||||
|
||||
value, err = r.GetValue(55555555)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(17), value)
|
||||
|
||||
// reset
|
||||
r.ClearAndResetValue(23)
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 0,
|
||||
value: 23,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
value, err = r.GetValue(55555555)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(23), value)
|
||||
|
||||
// decrement value and ensure that any key returns that value
|
||||
r.DecValue(12)
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 0,
|
||||
value: 11,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
value, err = r.GetValue(55555555)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(11), value)
|
||||
|
||||
// add an exclusion and then decrement value
|
||||
err = r.ExcludeRange(10, 15)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 11,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 15,
|
||||
end: 0,
|
||||
value: 16,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
// first range access
|
||||
value, err = r.GetValue(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(11), value)
|
||||
|
||||
// open range access
|
||||
value, err = r.GetValue(55555555)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(16), value)
|
||||
|
||||
r.DecValue(6)
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 0,
|
||||
end: 9,
|
||||
value: 11,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[0])
|
||||
|
||||
expectedRangeVal = rangeVal[uint32, uint32]{
|
||||
start: 15,
|
||||
end: 0,
|
||||
value: 10,
|
||||
}
|
||||
require.Equal(t, expectedRangeVal, r.ranges[1])
|
||||
|
||||
// first range access
|
||||
value, err = r.GetValue(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(11), value)
|
||||
|
||||
// open range access
|
||||
value, err = r.GetValue(55555555)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(10), value)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user