diff --git a/pkg/sfu/buffer/buffer_test.go b/pkg/sfu/buffer/buffer_test.go index eb6582dbb..7f3edf125 100644 --- a/pkg/sfu/buffer/buffer_test.go +++ b/pkg/sfu/buffer/buffer_test.go @@ -212,8 +212,8 @@ func TestNewBuffer(t *testing.T) { buf, _ := p.Marshal() _, _ = buff.Write(buf) } - require.Equal(t, uint16(1), buff.rtpStats.cycles) - require.Equal(t, uint16(2), buff.rtpStats.highestSN) + require.Equal(t, uint16(2), buff.rtpStats.sequenceNumber.GetHighest()) + require.Equal(t, uint32(65536+2), buff.rtpStats.sequenceNumber.GetExtendedHighest()) }) } } diff --git a/pkg/sfu/buffer/rtpstats.go b/pkg/sfu/buffer/rtpstats.go index 91069d377..a55c12a42 100644 --- a/pkg/sfu/buffer/rtpstats.go +++ b/pkg/sfu/buffer/rtpstats.go @@ -25,6 +25,7 @@ import ( "github.com/pion/rtp" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/livekit/livekit-server/pkg/sfu/utils" "github.com/livekit/mediatransportutil" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -157,17 +158,13 @@ type RTPStats struct { startTime time.Time endTime time.Time - extStartSN uint32 - highestSN uint16 - cycles uint16 + sequenceNumber *utils.WrapAround[uint16, uint32] extHighestSNOverridden uint32 lastRRTime time.Time lastRR rtcp.ReceptionReport - extStartTS uint64 - highestTS uint32 - tsCycles uint32 + timestamp *utils.WrapAround[uint32, uint64] firstTime time.Time highestTime time.Time @@ -232,6 +229,8 @@ func NewRTPStats(params RTPStatsParams) *RTPStats { return &RTPStats{ params: params, logger: params.Logger, + sequenceNumber: utils.NewWrapAround[uint16, uint32](), + timestamp: utils.NewWrapAround[uint32, uint64](), nextSnapshotId: FirstSnapshotId, snapshots: make(map[uint32]*Snapshot), } @@ -251,17 +250,13 @@ func (r *RTPStats) Seed(from *RTPStats) { r.startTime = from.startTime // do not clone endTime as a non-zero endTime indicates an ended object - r.extStartSN = from.extStartSN - r.highestSN = from.highestSN - r.cycles = from.cycles + r.sequenceNumber.Seed(from.sequenceNumber) r.extHighestSNOverridden = from.extHighestSNOverridden r.lastRRTime = from.lastRRTime r.lastRR = from.lastRR - r.extStartTS = from.extStartTS - r.highestTS = from.highestTS - r.tsCycles = from.tsCycles + r.timestamp.Seed(from.timestamp) r.firstTime = from.firstTime r.highestTime = from.highestTime @@ -352,10 +347,11 @@ func (r *RTPStats) NewSnapshotId() uint32 { id := r.nextSnapshotId if r.initialized { + extStartSN := r.sequenceNumber.GetExtendedStart() r.snapshots[id] = &Snapshot{ startTime: time.Now(), - extStartSN: r.extStartSN, - extStartSNOverridden: r.extStartSN, + extStartSN: extStartSN, + extStartSNOverridden: extStartSN, } } @@ -379,7 +375,18 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa return } - first := false + if r.resyncOnNextPacket { + r.resyncOnNextPacket = false + + if r.initialized { + r.sequenceNumber.ResetHighest(rtph.SequenceNumber - 1) + r.timestamp.ResetHighest(rtph.Timestamp) + r.highestTime = packetTime + } + } + + var resSN utils.WrapAroundUpdateResult[uint32] + var resTS utils.WrapAroundUpdateResult[uint64] if !r.initialized { if payloadSize == 0 { // do not start on a padding only packet @@ -390,25 +397,19 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa r.startTime = time.Now() - r.extStartSN = uint32(rtph.SequenceNumber) - r.highestSN = rtph.SequenceNumber - 1 - r.cycles = 0 - - r.extStartTS = uint64(rtph.Timestamp) - r.highestTS = rtph.Timestamp - r.tsCycles = 0 - r.firstTime = packetTime r.highestTime = packetTime - first = true + resSN = r.sequenceNumber.Update(rtph.SequenceNumber) + resTS = r.timestamp.Update(rtph.Timestamp) // initialize snapshots if any for i := uint32(FirstSnapshotId); i < r.nextSnapshotId; i++ { + extStartSN := r.sequenceNumber.GetExtendedStart() r.snapshots[i] = &Snapshot{ startTime: r.startTime, - extStartSN: r.extStartSN, - extStartSNOverridden: r.extStartSN, + extStartSN: extStartSN, + extStartSNOverridden: extStartSN, } } @@ -416,92 +417,95 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa "rtp stream start", "startTime", r.startTime.String(), "firstTime", r.firstTime.String(), - "startSN", r.extStartSN, - "startTS", r.extStartTS, + "startSN", r.sequenceNumber.GetExtendedHighest(), + "startTS", r.timestamp.GetExtendedHighest(), ) - } - - if r.resyncOnNextPacket { - r.resyncOnNextPacket = false - - r.highestSN = rtph.SequenceNumber - 1 - r.highestTS = rtph.Timestamp - r.highestTime = packetTime + } else { + resSN = r.sequenceNumber.Update(rtph.SequenceNumber) + resTS = r.timestamp.Update(rtph.Timestamp) } hdrSize := uint64(rtph.MarshalSize()) pktSize := hdrSize + uint64(payloadSize+paddingSize) isDuplicate := false - diff := rtph.SequenceNumber - r.highestSN - switch { - // duplicate or out-of-order - case diff == 0 || diff > (1<<15): - if diff != 0 { + gapSN := resSN.ExtendedVal - resSN.PreExtendedHighest + if gapSN == 0 || gapSN > (1<<31) { // duplicate OR out-of-order + if payloadSize == 0 { + // do not start on a padding only packet + if resTS.IsRestart { + r.logger.Infow("rolling back timestamp restart", "tsBefore", r.timestamp.GetExtendedStart(), "tsAfter", resTS.PreExtendedStart) + r.timestamp.RollbackRestart(resTS.PreExtendedStart) + } + if resSN.IsRestart { + r.logger.Infow("rolling back sequence number restart", "snBefore", r.sequenceNumber.GetExtendedStart(), "snAfter", resSN.PreExtendedStart) + r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart) + return + } + } + + if gapSN != 0 { r.packetsOutOfOrder++ } - // adjust start to account for out-of-order packets before a cycle completes - if !r.maybeAdjustStart(rtph, pktSize, hdrSize, payloadSize) { - if !r.isSnInfoLost(rtph.SequenceNumber) { - r.bytesDuplicate += pktSize - r.headerBytesDuplicate += hdrSize - r.packetsDuplicate++ - isDuplicate = true - } else { - r.packetsLost-- - r.setSnInfo(rtph.SequenceNumber, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, true) + if resSN.IsRestart { + r.packetsLost += resSN.PreExtendedStart - resSN.ExtendedVal + + extStartSN := r.sequenceNumber.GetExtendedStart() + for _, s := range r.snapshots { + if s.extStartSN == resSN.PreExtendedStart { + s.extStartSN = extStartSN + } } + + r.logger.Infow( + "adjusting start sequence number", + "snBefore", resSN.PreExtendedStart, + "snAfter", resSN.ExtendedVal, + ) + } + + if resTS.IsRestart { + r.logger.Infow( + "adjusting start timestamp", + "tsBefore", resTS.PreExtendedStart, + "tsAfter", resTS.ExtendedVal, + ) + } + + if !r.isSnInfoLost(resSN.ExtendedVal, resSN.PreExtendedHighest) { + r.bytesDuplicate += pktSize + r.headerBytesDuplicate += hdrSize + r.packetsDuplicate++ + isDuplicate = true + } else { + r.packetsLost-- + r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, true) } flowState.IsOutOfOrder = true - - cycles := r.cycles - if rtph.SequenceNumber > r.highestSN { - cycles-- - } - flowState.ExtSeqNumber = getExtSN(rtph.SequenceNumber, cycles) - - // in-order - default: + flowState.ExtSeqNumber = resSN.ExtendedVal + } else { // in-order // update gap histogram - r.updateGapHistogram(int(diff)) + r.updateGapHistogram(int(gapSN)) // update missing sequence numbers - r.clearSnInfos(r.highestSN+1, rtph.SequenceNumber) - r.packetsLost += uint32(diff - 1) + r.clearSnInfos(resSN.PreExtendedHighest+1, resSN.ExtendedVal) + r.packetsLost += gapSN - 1 - r.setSnInfo(rtph.SequenceNumber, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, false) - - if diff > 1 { - flowState.HasLoss = true - - cycles := r.cycles - if r.highestSN+1 < r.highestSN { - cycles++ - } - flowState.LossStartInclusive = getExtSN(r.highestSN+1, cycles) - } - - if rtph.SequenceNumber < r.highestSN && !first { - r.cycles++ - } - r.highestSN = rtph.SequenceNumber - - if rtph.Timestamp != r.highestTS { - if rtph.Timestamp < r.highestTS && !first { - r.tsCycles++ - } - r.highestTS = rtph.Timestamp + r.setSnInfo(resSN.ExtendedVal, resSN.PreExtendedHighest, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, false) + if rtph.Timestamp != uint32(resTS.PreExtendedHighest) { // update only on first packet as same timestamp could be in multiple packets. // NOTE: this may not be the first packet with this time stamp if there is packet loss. r.highestTime = packetTime } - if flowState.HasLoss { - flowState.LossEndExclusive = getExtSN(rtph.SequenceNumber, r.cycles) + if gapSN > 1 { + flowState.HasLoss = true + flowState.LossStartInclusive = resSN.PreExtendedHighest + 1 + flowState.LossEndExclusive = resSN.ExtendedVal } - flowState.ExtSeqNumber = getExtSN(rtph.SequenceNumber, r.cycles) + flowState.ExtSeqNumber = resSN.ExtendedVal } if !isDuplicate { @@ -520,7 +524,6 @@ func (r *RTPStats) Update(rtph *rtp.Header, payloadSize int, paddingSize int, pa r.updateJitter(rtph, packetTime) } } - return } @@ -531,54 +534,8 @@ func (r *RTPStats) ResyncOnNextPacket() { r.resyncOnNextPacket = true } -func (r *RTPStats) maybeAdjustStart(rtph *rtp.Header, pktSize uint64, hdrSize uint64, payloadSize int) bool { - if (r.getExtHighestSN() - r.extStartSN + 1) >= (NumSequenceNumbers / 2) { - return false - } - - if (rtph.SequenceNumber - uint16(r.extStartSN)) < (1 << 15) { - return false - } - - if payloadSize == 0 { - // do not start on a padding only packet - r.logger.Infow("adjusting start, skipping on padding only packet") - return true - } - - r.packetsLost += uint32(uint16(r.extStartSN)-rtph.SequenceNumber) - 1 - snBeforeAdjust := r.extStartSN - r.extStartSN = uint32(rtph.SequenceNumber) - if r.extStartSN > snBeforeAdjust { - // wrapping back - r.cycles++ - } - - r.setSnInfo(rtph.SequenceNumber, uint16(pktSize), uint16(hdrSize), uint16(payloadSize), rtph.Marker, true) - - for _, s := range r.snapshots { - if s.extStartSN == snBeforeAdjust { - s.extStartSN = r.extStartSN - } - } - - tsBeforeAdjust := r.extStartTS - r.extStartTS = uint64(rtph.Timestamp) - if r.extStartTS > tsBeforeAdjust { - // wrapping back - r.tsCycles++ - } - r.logger.Infow( - "adjusting start", - "snBefore", snBeforeAdjust, - "snAfter", r.extStartSN, - "snCyles", r.cycles, - "tsBefore", tsBeforeAdjust, - "tsAfter", r.extStartTS, - "tsCyles", r.tsCycles, - ) - - return true +func (r *RTPStats) getPacketsExpected() uint32 { + return r.sequenceNumber.GetExtendedHighest() - r.sequenceNumber.GetExtendedStart() + 1 } func (r *RTPStats) GetTotalPacketsPrimary() uint32 { @@ -589,7 +546,7 @@ func (r *RTPStats) GetTotalPacketsPrimary() uint32 { } func (r *RTPStats) getTotalPacketsPrimary() uint32 { - packetsExpected := r.getExtHighestSN() - r.extStartSN + 1 + packetsExpected := r.getPacketsExpected() if r.packetsLost > packetsExpected { // should not happen return 0 @@ -607,7 +564,7 @@ func (r *RTPStats) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt uint32 r.lock.Lock() defer r.lock.Unlock() - if !r.initialized || !r.endTime.IsZero() || !r.params.IsReceiverReportDriven || rr.LastSequenceNumber < r.extStartSN { + if !r.initialized || !r.endTime.IsZero() || !r.params.IsReceiverReportDriven || rr.LastSequenceNumber < r.sequenceNumber.GetExtendedHighest() { // it is possible that the `LastSequenceNumber` in the receiver report is before the starting // sequence number when dummy packets are used to trigger Pion's OnTrack path. return @@ -846,7 +803,7 @@ func (r *RTPStats) maybeAdjustFirstPacketTime(ts uint32) { // abnormal delay (maybe due to pacing or maybe due to queuing // in some network element along the way), push back first time // to an earlier instance. - samplesDiff := int32(ts - uint32(r.extStartTS)) + samplesDiff := int32(ts - uint32(r.timestamp.GetExtendedStart())) if samplesDiff < 0 { // out-of-order, skip return @@ -863,7 +820,7 @@ func (r *RTPStats) maybeAdjustFirstPacketTime(ts uint32) { "after", firstTime.String(), "adjustment", r.firstTime.Sub(firstTime), "nowTS", ts, - "extStartTS", r.extStartTS, + "extStartTS", r.timestamp.GetExtendedStart(), ) if r.firstTime.Sub(firstTime) > firstPacketTimeAdjustThreshold { r.logger.Infow("first packet time adjustment too big, ignoring", @@ -873,7 +830,7 @@ func (r *RTPStats) maybeAdjustFirstPacketTime(ts uint32) { "after", firstTime.String(), "adjustment", r.firstTime.Sub(firstTime), "nowTS", ts, - "extStartTS", r.extStartTS, + "extStartTS", r.timestamp.GetExtendedStart(), ) } else { r.firstTime = firstTime @@ -967,7 +924,7 @@ func (r *RTPStats) SetRtcpSenderReportData(srData *RTCPSenderReportData) { "expectedTimeDiffSinceLast", expectedTimeDiffSinceLast, "packetDrift", packetDriftResult.String(), "reportDrift", reportDriftResult.String(), - "highestTS", r.highestTS, + "highestTS", r.timestamp.GetExtendedHighest(), "highestTime", r.highestTime.String(), ) } @@ -1000,7 +957,7 @@ func (r *RTPStats) GetExpectedRTPTimestamp(at time.Time) (expectedTSExt uint64, timeDiff := at.Sub(r.firstTime) expectedRTPDiff := timeDiff.Nanoseconds() * int64(r.params.ClockRate) / 1e9 - expectedTSExt = r.extStartTS + uint64(expectedRTPDiff) + expectedTSExt = r.timestamp.GetExtendedStart() + uint64(expectedRTPDiff) return } @@ -1018,19 +975,15 @@ func (r *RTPStats) GetRtcpSenderReport(ssrc uint32, calculatedClockRate uint32) nowNTP := mediatransportutil.ToNtpTime(now) timeSinceHighest := now.Sub(r.highestTime) - nowRTP := r.highestTS + uint32(timeSinceHighest.Nanoseconds()*int64(r.params.ClockRate)/1e9) + nowRTPExt := r.timestamp.GetExtendedHighest() + uint64(timeSinceHighest.Nanoseconds()*int64(r.params.ClockRate)/1e9) + nowRTP := uint32(nowRTPExt) // It is possible that publisher is pacing at a slower rate. // That would make `highestTS` to be lagging the RTP time stamp in the RTCP Sender Report from publisher. // Check for that using calculated clock rate and use the later time stamp if applicable. - tsCycles := r.tsCycles - if nowRTP < r.highestTS { - tsCycles++ - } - nowRTPExt := getExtTS(nowRTP, tsCycles) var nowRTPExtUsingRate uint64 if calculatedClockRate != 0 { - nowRTPExtUsingRate = r.extStartTS + uint64(float64(calculatedClockRate)*timeSinceFirst.Seconds()) + nowRTPExtUsingRate = r.timestamp.GetExtendedStart() + uint64(float64(calculatedClockRate)*timeSinceFirst.Seconds()) if nowRTPExtUsingRate > nowRTPExt { nowRTPExt = nowRTPExtUsingRate nowRTP = uint32(nowRTPExt) @@ -1104,7 +1057,7 @@ func (r *RTPStats) GetRtcpSenderReport(ssrc uint32, calculatedClockRate uint32) "expectedTimeDiffSinceLast", expectedTimeDiffSinceLast, "packetDrift", packetDriftResult.String(), "reportDrift", reportDriftResult.String(), - "highestTS", r.highestTS, + "highestTS", r.timestamp.GetExtendedHighest(), "highestTime", r.highestTime.String(), "calculatedClockRate", calculatedClockRate, "nowRTPExt", nowRTPExt, @@ -1145,7 +1098,7 @@ func (r *RTPStats) SnapshotRtcpReceptionReport(ssrc uint32, proxyFracLost uint8, return nil } - intervalStats := r.getIntervalStats(uint16(then.extStartSN), uint16(now.extStartSN)) + intervalStats := r.getIntervalStats(then.extStartSN, now.extStartSN) packetsLost := intervalStats.packetsLost lossRate := float32(packetsLost) / float32(packetsExpected) fracLost := uint8(lossRate * 256.0) @@ -1205,7 +1158,7 @@ func (r *RTPStats) DeltaInfo(snapshotId uint32) *RTPDeltaInfo { } } - intervalStats := r.getIntervalStats(uint16(then.extStartSN), uint16(now.extStartSN)) + intervalStats := r.getIntervalStats(then.extStartSN, now.extStartSN) return &RTPDeltaInfo{ StartTime: startTime, Duration: endTime.Sub(startTime), @@ -1260,7 +1213,7 @@ func (r *RTPStats) DeltaInfoOverridden(snapshotId uint32) *RTPDeltaInfo { return nil } - intervalStats := r.getIntervalStats(uint16(then.extStartSNOverridden), uint16(now.extStartSNOverridden)) + intervalStats := r.getIntervalStats(then.extStartSNOverridden, now.extStartSNOverridden) packetsLost := now.packetsLostOverridden - then.packetsLostOverridden if int32(packetsLost) < 0 { packetsLost = 0 @@ -1321,12 +1274,12 @@ func (r *RTPStats) ToString() string { r.lock.RLock() defer r.lock.RUnlock() - expectedPackets := r.getExtHighestSN() - r.extStartSN + 1 + expectedPackets := r.getPacketsExpected() expectedPacketRate := float64(expectedPackets) / p.Duration str := fmt.Sprintf("t: %+v|%+v|%.2fs", p.StartTime.AsTime().Format(time.UnixDate), p.EndTime.AsTime().Format(time.UnixDate), p.Duration) - str += fmt.Sprintf(", sn: %d|%d", r.extStartSN, r.getExtHighestSN()) + str += fmt.Sprintf(", sn: %d|%d", r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest()) str += fmt.Sprintf(", ep: %d|%.2f/s", expectedPackets, expectedPacketRate) str += fmt.Sprintf(", p: %d|%.2f/s", p.Packets, p.PacketRate) @@ -1411,7 +1364,7 @@ func (r *RTPStats) ToProto() *livekit.RTPStats { frameRate := float64(r.frames) / elapsed - packetsExpected := r.getExtHighestSN() - r.extStartSN + 1 + packetsExpected := r.getPacketsExpected() packetsLost := r.getPacketsLost() packetLostRate := float64(packetsLost) / elapsed packetLostPercentage := float32(packetsLost) / float32(packetsExpected) * 100.0 @@ -1503,16 +1456,12 @@ func (r *RTPStats) ToProto() *livekit.RTPStats { return p } -func (r *RTPStats) getExtHighestSN() uint32 { - return (uint32(r.cycles) << 16) | uint32(r.highestSN) -} - func (r *RTPStats) getExtHighestSNAdjusted() uint32 { if r.params.IsReceiverReportDriven && !r.lastRRTime.IsZero() { return r.extHighestSNOverridden } - return r.getExtHighestSN() + return r.sequenceNumber.GetExtendedHighest() } func (r *RTPStats) getPacketsLost() uint32 { @@ -1523,13 +1472,13 @@ func (r *RTPStats) getPacketsLost() uint32 { return r.packetsLost } -func (r *RTPStats) getSnInfoOutOfOrderPtr(sn uint16) int { - offset := sn - r.highestSN - if offset > 0 && offset < (1<<15) { +func (r *RTPStats) getSnInfoOutOfOrderPtr(esn uint32, ehsn uint32) int { + offset := esn - ehsn + if offset > 0 && offset < (1<<31) { return -1 // in-order, not expected, maybe too new } - offset = r.highestSN - sn + offset = ehsn - esn if int(offset) >= SnInfoSize { // too old, ignore return -1 @@ -1538,14 +1487,14 @@ func (r *RTPStats) getSnInfoOutOfOrderPtr(sn uint16) int { return (r.snInfoWritePtr - int(offset) - 1) & SnInfoMask } -func (r *RTPStats) setSnInfo(sn uint16, pktSize uint16, hdrSize uint16, payloadSize uint16, marker bool, isOutOfOrder bool) { +func (r *RTPStats) setSnInfo(esn uint32, ehsn uint32, pktSize uint16, hdrSize uint16, payloadSize uint16, marker bool, isOutOfOrder bool) { writePtr := 0 - ooo := (sn - r.highestSN) > (1 << 15) + ooo := (esn - ehsn) > (1 << 31) if !ooo { writePtr = r.snInfoWritePtr r.snInfoWritePtr = (writePtr + 1) & SnInfoMask } else { - writePtr = r.getSnInfoOutOfOrderPtr(sn) + writePtr = r.getSnInfoOutOfOrderPtr(esn, ehsn) if writePtr < 0 { return } @@ -1559,8 +1508,8 @@ func (r *RTPStats) setSnInfo(sn uint16, pktSize uint16, hdrSize uint16, payloadS snInfo.isOutOfOrder = isOutOfOrder } -func (r *RTPStats) clearSnInfos(startInclusive uint16, endExclusive uint16) { - for sn := startInclusive; sn != endExclusive; sn++ { +func (r *RTPStats) clearSnInfos(extStartInclusive uint32, extEndExclusive uint32) { + for esn := extStartInclusive; esn != extEndExclusive; esn++ { snInfo := &r.snInfos[r.snInfoWritePtr] snInfo.pktSize = 0 snInfo.hdrSize = 0 @@ -1571,8 +1520,8 @@ func (r *RTPStats) clearSnInfos(startInclusive uint16, endExclusive uint16) { } } -func (r *RTPStats) isSnInfoLost(sn uint16) bool { - readPtr := r.getSnInfoOutOfOrderPtr(sn) +func (r *RTPStats) isSnInfoLost(esn uint32, ehsn uint32) bool { + readPtr := r.getSnInfoOutOfOrderPtr(esn, ehsn) if readPtr < 0 { return false } @@ -1581,10 +1530,10 @@ func (r *RTPStats) isSnInfoLost(sn uint16) bool { return snInfo.pktSize == 0 } -func (r *RTPStats) getIntervalStats(startInclusive uint16, endExclusive uint16) (intervalStats IntervalStats) { +func (r *RTPStats) getIntervalStats(extStartInclusive uint32, extEndExclusive uint32) (intervalStats IntervalStats) { packetsNotFound := uint32(0) - processSN := func(sn uint16) { - readPtr := r.getSnInfoOutOfOrderPtr(sn) + processESN := func(esn uint32, ehsn uint32) { + readPtr := r.getSnInfoOutOfOrderPtr(esn, ehsn) if readPtr < 0 { packetsNotFound++ return @@ -1614,24 +1563,18 @@ func (r *RTPStats) getIntervalStats(startInclusive uint16, endExclusive uint16) } } - if startInclusive == endExclusive { - // do a full cycle - for sn := uint32(0); sn < NumSequenceNumbers; sn++ { - processSN(uint16(sn)) - } - } else { - for sn := startInclusive; sn != endExclusive; sn++ { - processSN(sn) - } + ehsn := r.sequenceNumber.GetExtendedHighest() + for esn := extStartInclusive; esn != extEndExclusive; esn++ { + processESN(esn, ehsn) } if packetsNotFound != 0 { r.logger.Errorw( "could not find some packets", nil, - "start", startInclusive, - "end", endExclusive, + "start", extStartInclusive, + "end", extEndExclusive, "count", packetsNotFound, - "highestSN", r.highestSN, + "highestSN", r.sequenceNumber.GetExtendedHighest(), ) } return @@ -1676,7 +1619,7 @@ func (r *RTPStats) updateJitter(rtph *rtp.Header, packetTime time.Time) { func (r *RTPStats) getDrift() (packetDrift driftResult, reportDrift driftResult) { packetDrift.timeSinceFirst = r.highestTime.Sub(r.firstTime) - packetDrift.rtpDiffSinceFirst = getExtTS(r.highestTS, r.tsCycles) - r.extStartTS + packetDrift.rtpDiffSinceFirst = r.timestamp.GetExtendedHighest() - r.timestamp.GetExtendedStart() packetDrift.driftSamples = int64(packetDrift.rtpDiffSinceFirst - uint64(packetDrift.timeSinceFirst.Nanoseconds()*int64(r.params.ClockRate)/1e9)) packetDrift.driftMs = (float64(packetDrift.driftSamples) * 1000) / float64(r.params.ClockRate) if packetDrift.timeSinceFirst.Seconds() != 0 { @@ -1715,10 +1658,11 @@ func (r *RTPStats) getAndResetSnapshot(snapshotId uint32, override bool) (*Snaps then := r.snapshots[snapshotId] if then == nil { + extStartSN := r.sequenceNumber.GetExtendedHighest() then = &Snapshot{ startTime: r.startTime, - extStartSN: r.extStartSN, - extStartSNOverridden: r.extStartSN, + extStartSN: extStartSN, + extStartSNOverridden: extStartSN, } r.snapshots[snapshotId] = then } @@ -1733,7 +1677,7 @@ func (r *RTPStats) getAndResetSnapshot(snapshotId uint32, override bool) (*Snaps // snapshot now r.snapshots[snapshotId] = &Snapshot{ startTime: startTime, - extStartSN: r.getExtHighestSN() + 1, + extStartSN: r.sequenceNumber.GetExtendedHighest() + 1, extStartSNOverridden: r.getExtHighestSNAdjusted() + 1, packetsDuplicate: r.packetsDuplicate, bytesDuplicate: r.bytesDuplicate, @@ -1754,14 +1698,6 @@ func (r *RTPStats) getAndResetSnapshot(snapshotId uint32, override bool) (*Snaps // ---------------------------------- -func getExtSN(sn uint16, cycles uint16) uint32 { - return (uint32(cycles) << 16) | uint32(sn) -} - -func getExtTS(ts uint32, cycles uint32) uint64 { - return (uint64(cycles) << 32) | uint64(ts) -} - func AggregateRTPStats(statsList []*livekit.RTPStats) *livekit.RTPStats { if len(statsList) == 0 { return nil diff --git a/pkg/sfu/buffer/rtpstats_test.go b/pkg/sfu/buffer/rtpstats_test.go index 74e4774f4..1a803578d 100644 --- a/pkg/sfu/buffer/rtpstats_test.go +++ b/pkg/sfu/buffer/rtpstats_test.go @@ -90,8 +90,10 @@ func TestRTPStats_Update(t *testing.T) { flowState := r.Update(&packet.Header, len(packet.Payload), 0, time.Now()) require.False(t, flowState.HasLoss) require.True(t, r.initialized) - require.Equal(t, sequenceNumber, r.highestSN) - require.Equal(t, timestamp, r.highestTS) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) // in-order, no loss sequenceNumber++ @@ -99,15 +101,19 @@ func TestRTPStats_Update(t *testing.T) { packet = getPacket(sequenceNumber, timestamp, 1000) flowState = r.Update(&packet.Header, len(packet.Payload), 0, time.Now()) require.False(t, flowState.HasLoss) - require.Equal(t, sequenceNumber, r.highestSN) - require.Equal(t, timestamp, r.highestTS) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) // out-of-order packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update(&packet.Header, len(packet.Payload), 0, time.Now()) require.False(t, flowState.HasLoss) - require.Equal(t, sequenceNumber, r.highestSN) - require.Equal(t, timestamp, r.highestTS) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) require.Equal(t, uint32(1), r.packetsOutOfOrder) require.Equal(t, uint32(0), r.packetsDuplicate) @@ -115,8 +121,10 @@ func TestRTPStats_Update(t *testing.T) { packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update(&packet.Header, len(packet.Payload), 0, time.Now()) require.False(t, flowState.HasLoss) - require.Equal(t, sequenceNumber, r.highestSN) - require.Equal(t, timestamp, r.highestTS) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) require.Equal(t, uint32(2), r.packetsOutOfOrder) require.Equal(t, uint32(1), r.packetsDuplicate) @@ -134,12 +142,14 @@ func TestRTPStats_Update(t *testing.T) { packet = getPacket(sequenceNumber-15, timestamp-45000, 1000) flowState = r.Update(&packet.Header, len(packet.Payload), 0, time.Now()) require.False(t, flowState.HasLoss) - require.Equal(t, sequenceNumber, r.highestSN) - require.Equal(t, timestamp, r.highestTS) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) require.Equal(t, uint32(3), r.packetsOutOfOrder) require.Equal(t, uint32(1), r.packetsDuplicate) require.Equal(t, uint32(16), r.packetsLost) - intervalStats := r.getIntervalStats(uint16(r.extStartSN), uint16(r.getExtHighestSN()+1)) + intervalStats := r.getIntervalStats(r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest()+1) require.Equal(t, uint32(16), intervalStats.packetsLost) r.Stop() diff --git a/pkg/sfu/utils/wraparound.go b/pkg/sfu/utils/wraparound.go index f9f102e1c..8845f2c19 100644 --- a/pkg/sfu/utils/wraparound.go +++ b/pkg/sfu/utils/wraparound.go @@ -49,14 +49,14 @@ func (w *WrapAround[T, ET]) Seed(from *WrapAround[T, ET]) { w.cycles = from.cycles } -type wrapAroundUpdateResult[ET extendedNumber] struct { +type WrapAroundUpdateResult[ET extendedNumber] struct { IsRestart bool PreExtendedStart ET // valid only if IsRestart = true PreExtendedHighest ET ExtendedVal ET } -func (w *WrapAround[T, ET]) Update(val T) (result wrapAroundUpdateResult[ET]) { +func (w *WrapAround[T, ET]) Update(val T) (result WrapAroundUpdateResult[ET]) { if !w.initialized { result.PreExtendedHighest = ET(val) - 1 result.ExtendedVal = ET(val) @@ -82,10 +82,17 @@ func (w *WrapAround[T, ET]) Update(val T) (result wrapAroundUpdateResult[ET]) { } w.highest = val - result.ExtendedVal = ET(w.cycles)*w.fullRange + ET(val) + result.ExtendedVal = w.getExtendedHighest(w.cycles, val) return } +func (w *WrapAround[T, ET]) RollbackRestart(ev ET) { + if w.isWrapBack(w.start, T(ev)) { + w.cycles-- + } + w.start = T(ev) +} + func (w *WrapAround[T, ET]) ResetHighest(val T) { w.highest = val } @@ -103,14 +110,10 @@ func (w *WrapAround[T, ET]) GetHighest() T { } func (w *WrapAround[T, ET]) GetExtendedHighest() ET { - return ET(w.cycles)*w.fullRange + ET(w.highest) + return w.getExtendedHighest(w.cycles, w.highest) } func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtendedStart ET, extendedVal ET) { - isWrapBack := func() bool { - return ET(w.highest) < (w.fullRange>>1) && ET(val) >= (w.fullRange>>1) - } - // re-adjust start if necessary. The conditions are // 1. Not seen more than half the range yet // 1. wrap back compared to start and not completed a half cycle, sequences like (10, 65530) in uint16 space @@ -118,10 +121,10 @@ func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtended cycles := w.cycles totalNum := w.GetExtendedHighest() - w.GetExtendedStart() + 1 if totalNum > (w.fullRange >> 1) { - if isWrapBack() { + if w.isWrapBack(val, w.highest) { cycles-- } - extendedVal = ET(cycles)*w.fullRange + ET(val) + extendedVal = w.getExtendedHighest(cycles, val) return } @@ -130,17 +133,24 @@ func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtended isRestart = true preExtendedStart = w.GetExtendedStart() - if val > w.highest { - // wrap around + if w.isWrapBack(val, w.highest) { w.cycles = 1 cycles = 0 } w.start = val } else { - if isWrapBack() { + if w.isWrapBack(val, w.highest) { cycles-- } } - extendedVal = ET(cycles)*w.fullRange + ET(val) + extendedVal = w.getExtendedHighest(cycles, val) return } + +func (w *WrapAround[T, ET]) isWrapBack(earlier T, later T) bool { + return ET(later) < (w.fullRange>>1) && ET(earlier) >= (w.fullRange>>1) +} + +func (w *WrapAround[T, ET]) getExtendedHighest(cycles int, val T) ET { + return ET(cycles)*w.fullRange + ET(val) +} diff --git a/pkg/sfu/utils/wraparound_test.go b/pkg/sfu/utils/wraparound_test.go index 9e3b8e555..c01729108 100644 --- a/pkg/sfu/utils/wraparound_test.go +++ b/pkg/sfu/utils/wraparound_test.go @@ -25,7 +25,7 @@ func TestWrapAroundUint16(t *testing.T) { testCases := []struct { name string input uint16 - updated wrapAroundUpdateResult[uint32] + updated WrapAroundUpdateResult[uint32] start uint16 extendedStart uint32 highest uint16 @@ -35,7 +35,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "initialize", input: 10, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: 9, @@ -50,7 +50,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "reset start no wrap around", input: 8, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: true, PreExtendedStart: 10, PreExtendedHighest: 10, @@ -65,7 +65,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "reset start wrap around", input: (1 << 16) - 6, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: true, PreExtendedStart: 8, PreExtendedHighest: 10, @@ -80,7 +80,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "reset start again", input: (1 << 16) - 12, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: true, PreExtendedStart: (1 << 16) - 6, PreExtendedHighest: (1 << 16) + 10, @@ -95,7 +95,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "out of order - no restart", input: (1 << 16) - 3, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + 10, @@ -110,7 +110,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "duplicate", input: 10, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + 10, @@ -125,7 +125,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "big in-order jump", input: (1 << 15) - 10, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + 10, @@ -140,7 +140,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "out-of-order after half range", input: (1 << 15) - 11, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + (1 << 15) - 10, @@ -155,7 +155,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "wrap back out-of-order after half range", input: (1 << 16) - 1, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + (1 << 15) - 10, @@ -170,7 +170,7 @@ func TestWrapAroundUint16(t *testing.T) { { name: "in-order", input: (1 << 15) + 3, - updated: wrapAroundUpdateResult[uint32]{ + updated: WrapAroundUpdateResult[uint32]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 16) + (1 << 15) - 10, @@ -194,12 +194,72 @@ func TestWrapAroundUint16(t *testing.T) { } } +func TestWrapAroundUint16RollbackRestart(t *testing.T) { + w := NewWrapAround[uint16, uint32]() + + // initialize + w.Update(23) + require.Equal(t, uint16(23), w.GetStart()) + require.Equal(t, uint32(23), w.GetExtendedStart()) + require.Equal(t, uint16(23), w.GetHighest()) + require.Equal(t, uint32(23), w.GetExtendedHighest()) + + // an in-order update + w.Update(25) + require.Equal(t, uint16(23), w.GetStart()) + require.Equal(t, uint32(23), w.GetExtendedStart()) + require.Equal(t, uint16(25), w.GetHighest()) + require.Equal(t, uint32(25), w.GetExtendedHighest()) + + // force restart without wrap + res := w.Update(12) + expectedResult := WrapAroundUpdateResult[uint32]{ + IsRestart: true, + PreExtendedStart: 23, + PreExtendedHighest: 25, + ExtendedVal: 12, + } + require.Equal(t, expectedResult, res) + require.Equal(t, uint16(12), w.GetStart()) + require.Equal(t, uint32(12), w.GetExtendedStart()) + require.Equal(t, uint16(25), w.GetHighest()) + require.Equal(t, uint32(25), w.GetExtendedHighest()) + + // roll back restart + w.RollbackRestart(res.PreExtendedStart) + require.Equal(t, uint16(23), w.GetStart()) + require.Equal(t, uint32(23), w.GetExtendedStart()) + require.Equal(t, uint16(25), w.GetHighest()) + require.Equal(t, uint32(25), w.GetExtendedHighest()) + + // force restart with wrap + res = w.Update(65533) + expectedResult = WrapAroundUpdateResult[uint32]{ + IsRestart: true, + PreExtendedStart: 23, + PreExtendedHighest: 25, + ExtendedVal: 65533, + } + require.Equal(t, expectedResult, res) + require.Equal(t, uint16(65533), w.GetStart()) + require.Equal(t, uint32(65533), w.GetExtendedStart()) + require.Equal(t, uint16(25), w.GetHighest()) + require.Equal(t, uint32(65536+25), w.GetExtendedHighest()) + + // roll back restart + w.RollbackRestart(res.PreExtendedStart) + require.Equal(t, uint16(23), w.GetStart()) + require.Equal(t, uint32(23), w.GetExtendedStart()) + require.Equal(t, uint16(25), w.GetHighest()) + require.Equal(t, uint32(25), w.GetExtendedHighest()) +} + func TestWrapAroundUint32(t *testing.T) { w := NewWrapAround[uint32, uint64]() testCases := []struct { name string input uint32 - updated wrapAroundUpdateResult[uint64] + updated WrapAroundUpdateResult[uint64] start uint32 extendedStart uint64 highest uint32 @@ -209,7 +269,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "initialize", input: 10, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: 9, @@ -224,7 +284,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "reset start no wrap around", input: 8, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: true, PreExtendedStart: 10, PreExtendedHighest: 10, @@ -239,7 +299,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "reset start wrap around", input: (1 << 32) - 6, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: true, PreExtendedStart: 8, PreExtendedHighest: 10, @@ -254,7 +314,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "reset start again", input: (1 << 32) - 12, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: true, PreExtendedStart: (1 << 32) - 6, PreExtendedHighest: (1 << 32) + 10, @@ -269,7 +329,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "duplicate", input: 10, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 32) + 10, @@ -284,7 +344,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "big in-order jump", input: 1 << 31, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 32) + 10, @@ -299,7 +359,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "out-of-order after half range", input: (1 << 31) - 1, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 32) + (1 << 31), @@ -314,7 +374,7 @@ func TestWrapAroundUint32(t *testing.T) { { name: "in-order", input: (1 << 31) + 3, - updated: wrapAroundUpdateResult[uint64]{ + updated: WrapAroundUpdateResult[uint64]{ IsRestart: false, PreExtendedStart: 0, PreExtendedHighest: (1 << 32) + (1 << 31),