diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 3e5ce3bff..ca177f1d5 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -20,7 +20,7 @@ const ( ReportDelta = 1e9 ) -type pendingPackets struct { +type pendingPacket struct { arrivalTime int64 packet []byte } @@ -43,7 +43,7 @@ type Buffer struct { audioPool *sync.Pool codecType webrtc.RTPCodecType extPackets deque.Deque - pPackets []pendingPackets + pPackets []pendingPacket closeOnce sync.Once mediaSSRC uint32 clockRate uint32 @@ -61,20 +61,20 @@ type Buffer struct { twcc bool audioLevel bool - minPacketProbe int lastPacketRead int bitrate atomic.Value bitrateHelper [4]int64 lastSRNTPTime uint64 lastSRRTPTime uint32 lastSRRecv int64 // Represents wall clock of the most recent sender report arrival - baseSN uint16 + highestSN uint16 + cycle uint16 lastRtcpPacketTime int64 // Time the last RTCP packet was received. lastRtcpSrTime int64 // Time the last RTCP SR was received. Required for DLSR computation. lastTransit uint32 - seqHdlr SeqWrapHandler - stats Stats + stats Stats + rrSnapshot *receiverReportSnapshot latestTimestamp uint32 // latest received RTP timestamp on packet latestTimestampTime int64 // Time of the latest timestamp (in nanos since unix epoch) @@ -91,12 +91,15 @@ type Buffer struct { } type Stats struct { - LastExpected uint32 - LastReceived uint32 - LostRate float32 - PacketCount uint32 // Number of packets received from this source. - Jitter float64 // An estimate of the statistical variance of the RTP data packet inter-arrival time. - TotalBytes uint64 + PacketCount uint32 // Number of packets received from this source. + TotalBytes uint64 + Jitter float64 // An estimate of the statistical variance of the RTP data packet inter-arrival time. +} + +type receiverReportSnapshot struct { + extSeqNum uint32 + packetsReceived uint32 + packetsLost uint32 } // BufferOptions provides configuration options for the buffer @@ -193,7 +196,7 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) { if !b.bound { packet := make([]byte, len(pkt)) copy(packet, pkt) - b.pPackets = append(b.pPackets, pendingPackets{ + b.pPackets = append(b.pPackets, pendingPacket{ packet: packet, arrivalTime: time.Now().UnixNano(), }) @@ -201,7 +204,6 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) { } b.calc(pkt, time.Now().UnixNano()) - return } @@ -269,29 +271,38 @@ func (b *Buffer) OnClose(fn func()) { func (b *Buffer) calc(pkt []byte, arrivalTime int64) { sn := binary.BigEndian.Uint16(pkt[2:4]) - var headPkt bool if b.stats.PacketCount == 0 { - b.baseSN = sn + b.highestSN = sn - 1 b.lastReport = arrivalTime - b.seqHdlr.UpdateMaxSeq(uint32(sn)) - headPkt = true - } else { - extSN, isNewer := b.seqHdlr.Unwrap(sn) - if b.nack { - if isNewer { - for i := b.seqHdlr.MaxSeqNo() + 1; i < extSN; i++ { - b.nacker.Push(i) - } - } else { - b.nacker.Remove(extSN) - } + + b.rrSnapshot = &receiverReportSnapshot{ + extSeqNum: uint32(sn) - 1, + packetsReceived: 0, + packetsLost: 0, } - if isNewer { - b.seqHdlr.UpdateMaxSeq(extSN) - } - headPkt = isNewer } + diff := sn - b.highestSN + if diff > (1 << 15) { + // out-of-order, remove it from nack queue + if b.nacker != nil { + b.nacker.Remove(sn) + } + } else { + if b.nacker != nil && diff > 1 { + for lost := b.highestSN + 1; lost != sn; lost++ { + b.nacker.Push(lost) + } + } + + if sn < b.highestSN && b.stats.PacketCount > 0 { + b.cycle++ + } + + b.highestSN = sn + } + + headPkt := sn == b.highestSN var p rtp.Packet pb, err := b.bucket.AddPacket(pkt, sn, headPkt) if err != nil { @@ -342,22 +353,6 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { ep.KeyFrame = IsH264Keyframe(p.Payload) } - if b.minPacketProbe < 25 { - // LK-TODO-START - // This should check for proper wrap around. - // Probably remove this probe section of code as - // the only place this baseSN is used at is where - // RTCP receiver reports are generated. If there - // are some out-of-order packets right at the start - // the stat is going to be off by a bit. Not a big deal. - // LK-TODO-END - if sn < b.baseSN { - b.baseSN = sn - } - - b.minPacketProbe++ - } - b.extPackets.PushBack(&ep) // if first time update or the timestamp is later (factoring timestamp wrap around) @@ -399,8 +394,8 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { b.bitrateHelper[temporalLayer] += int64(len(pkt)) - diff := arrivalTime - b.lastReport - if diff >= ReportDelta { + timeDiff := arrivalTime - b.lastReport + if timeDiff >= ReportDelta { // // As this happens in the data path, if there are no packets received // in an interval, the bitrate will be stuck with the old value. @@ -412,7 +407,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { bitrates = make([]int64, len(b.bitrateHelper)) } for i := 0; i < len(b.bitrateHelper); i++ { - br := (8 * b.bitrateHelper[i] * int64(ReportDelta)) / diff + br := (8 * b.bitrateHelper[i] * int64(ReportDelta)) / timeDiff bitrates[i] = br b.bitrateHelper[i] = 0 } @@ -423,7 +418,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { } func (b *Buffer) buildNACKPacket() []rtcp.Packet { - if nacks, askKeyframe := b.nacker.Pairs(b.seqHdlr.MaxSeqNo()); len(nacks) > 0 || askKeyframe { + if nacks := b.nacker.Pairs(); len(nacks) > 0 { var pkts []rtcp.Packet if len(nacks) > 0 { pkts = []rtcp.Packet{&rtcp.TransportLayerNack{ @@ -432,11 +427,6 @@ func (b *Buffer) buildNACKPacket() []rtcp.Packet { }} } - if askKeyframe { - pkts = append(pkts, &rtcp.PictureLossIndication{ - MediaSSRC: b.mediaSSRC, - }) - } return pkts } return nil @@ -444,11 +434,25 @@ func (b *Buffer) buildNACKPacket() []rtcp.Packet { func (b *Buffer) buildREMBPacket() *rtcp.ReceiverEstimatedMaximumBitrate { br := b.Bitrate() - if b.stats.LostRate < 0.02 { + + extMaxSeq := (uint32(b.cycle) << 16) | uint32(b.highestSN) + expectedInInterval := extMaxSeq - b.rrSnapshot.extSeqNum + receivedInInterval := b.stats.PacketCount - b.rrSnapshot.packetsReceived + lostInInterval := expectedInInterval - receivedInInterval + if int(lostInInterval) < 0 { + // could happen if retransmitted packets arrive and make received greater than expected + lostInInterval = 0 + } + lostRate := float32(0) + if expectedInInterval != 0 { + lostRate = float32(lostInInterval) / float32(expectedInInterval) + } + + if lostRate < 0.02 { br = int64(float64(br)*1.09) + 2000 } - if b.stats.LostRate > .1 { - br = int64(float64(br) * float64(1-0.5*b.stats.LostRate)) + if lostRate > .1 { + br = int64(float64(br) * float64(1-0.5*lostRate)) } if br > b.maxBitrate { br = b.maxBitrate @@ -464,31 +468,32 @@ func (b *Buffer) buildREMBPacket() *rtcp.ReceiverEstimatedMaximumBitrate { } } -func (b *Buffer) buildReceptionReport() rtcp.ReceptionReport { - extMaxSeq := b.seqHdlr.MaxSeqNo() - expected := extMaxSeq - uint32(b.baseSN) + 1 - lost := uint32(0) - if b.stats.PacketCount < expected && b.stats.PacketCount != 0 { - lost = expected - b.stats.PacketCount +func (b *Buffer) buildReceptionReport() *rtcp.ReceptionReport { + if b.rrSnapshot == nil { + return nil } - expectedInterval := expected - b.stats.LastExpected - b.stats.LastExpected = expected - receivedInterval := b.stats.PacketCount - b.stats.LastReceived - b.stats.LastReceived = b.stats.PacketCount - - lostInterval := expectedInterval - receivedInterval - - var fracLost uint8 - if expectedInterval != 0 { - b.stats.LostRate = float32(lostInterval) / float32(expectedInterval) - fracLost = uint8((lostInterval << 8) / expectedInterval) + extMaxSeq := (uint32(b.cycle) << 16) | uint32(b.highestSN) + expectedInInterval := extMaxSeq - b.rrSnapshot.extSeqNum + if expectedInInterval == 0 { + return nil } + + receivedInInterval := b.stats.PacketCount - b.rrSnapshot.packetsReceived + lostInInterval := expectedInInterval - receivedInInterval + if int(lostInInterval) < 0 { + // could happen if retransmitted packets arrive and make received greater than expected + lostInInterval = 0 + } + + fracLost := uint8((float32(lostInInterval) / float32(expectedInInterval)) * 256.0) if b.lastFractionLostToReport > fracLost { - // If fraction lost from subscriber is bigger than sfu received, use it. + // max of fraction lost from all subscribers is bigger than sfu received, use it. fracLost = b.lastFractionLostToReport } + totalLost := b.rrSnapshot.packetsLost + lostInInterval + var dlsr uint32 if b.lastSRRecv != 0 { delayMS := uint32((time.Now().UnixNano() - b.lastSRRecv) / 1e6) @@ -496,16 +501,21 @@ func (b *Buffer) buildReceptionReport() rtcp.ReceptionReport { dlsr |= (delayMS % 1e3) * 65536 / 1000 } - rr := rtcp.ReceptionReport{ + b.rrSnapshot = &receiverReportSnapshot{ + extSeqNum: extMaxSeq, + packetsReceived: b.stats.PacketCount, + packetsLost: totalLost, + } + + return &rtcp.ReceptionReport{ SSRC: b.mediaSSRC, FractionLost: fracLost, - TotalLost: lost, + TotalLost: totalLost, LastSequenceNumber: extMaxSeq, Jitter: uint32(b.stats.Jitter), LastSenderReport: uint32(b.lastSRNTPTime >> 16), Delay: dlsr, } - return rr } func (b *Buffer) SetSenderReportData(rtpTime uint32, ntpTime uint64) { @@ -523,9 +533,12 @@ func (b *Buffer) SetLastFractionLostReport(lost uint8) { func (b *Buffer) getRTCP() []rtcp.Packet { var pkts []rtcp.Packet - pkts = append(pkts, &rtcp.ReceiverReport{ - Reports: []rtcp.ReceptionReport{b.buildReceptionReport()}, - }) + rr := b.buildReceptionReport() + if rr != nil { + pkts = append(pkts, &rtcp.ReceiverReport{ + Reports: []rtcp.ReceptionReport{*rr}, + }) + } if b.remb && !b.twcc { pkts = append(pkts, b.buildREMBPacket()) @@ -623,77 +636,7 @@ func (b *Buffer) SetStatsTestOnly(stats Stats) { b.Unlock() } -// GetLatestTimestamp returns the latest RTP timestamp factoring in potential RTP timestamp wrap-around -func (b *Buffer) GetLatestTimestamp() (latestTimestamp uint32, latestTimestampTimeInNanosSinceEpoch int64) { - latestTimestamp = atomic.LoadUint32(&b.latestTimestamp) - latestTimestampTimeInNanosSinceEpoch = atomic.LoadInt64(&b.latestTimestampTime) - - return latestTimestamp, latestTimestampTimeInNanosSinceEpoch -} - -// IsTimestampWrapAround returns true if wrap around happens from timestamp1 to timestamp2 -func IsTimestampWrapAround(timestamp1 uint32, timestamp2 uint32) bool { - return timestamp2 < timestamp1 && timestamp1 > 0xf0000000 && timestamp2 < 0x0fffffff -} - // IsLaterTimestamp returns true if timestamp1 is later in time than timestamp2 factoring in timestamp wrap-around func IsLaterTimestamp(timestamp1 uint32, timestamp2 uint32) bool { - if timestamp1 > timestamp2 { - if IsTimestampWrapAround(timestamp1, timestamp2) { - return false - } - return true - } - if IsTimestampWrapAround(timestamp2, timestamp1) { - return true - } - return false -} - -func IsNewerUint16(val1, val2 uint16) bool { - return val1 != val2 && val1-val2 < 0x8000 -} - -type SeqWrapHandler struct { - maxSeqNo uint32 -} - -func (s *SeqWrapHandler) Cycles() uint32 { - return s.maxSeqNo & 0xffff0000 -} - -func (s *SeqWrapHandler) MaxSeqNo() uint32 { - return s.maxSeqNo -} - -// unwrap seq and update the maxSeqNo. return unwrapped value, and whether seq is newer -func (s *SeqWrapHandler) Unwrap(seq uint16) (uint32, bool) { - - maxSeqNo := uint16(s.maxSeqNo) - delta := int32(seq) - int32(maxSeqNo) - - newer := IsNewerUint16(seq, maxSeqNo) - - if newer { - if delta < 0 { - // seq is newer, but less than maxSeqNo, wrap around - delta += 0x10000 - } - } else { - // older value - if delta > 0 && (int32(s.maxSeqNo)+delta-0x10000) >= 0 { - // wrap backwards, should not less than 0 in this case: - // at start time, received seq 1, set s.maxSeqNo =1 , - // then an out of order seq 65534 coming, we can't unwrap - // the seq to -2 - delta -= 0x10000 - } - } - - unwrapped := uint32(int32(s.maxSeqNo) + delta) - return unwrapped, newer -} - -func (s *SeqWrapHandler) UpdateMaxSeq(extSeq uint32) { - s.maxSeqNo = extSeq + return (timestamp1 - timestamp2) < (1 << 31) } diff --git a/pkg/sfu/buffer/buffer_test.go b/pkg/sfu/buffer/buffer_test.go index 72a563a62..bb6816c1d 100644 --- a/pkg/sfu/buffer/buffer_test.go +++ b/pkg/sfu/buffer/buffer_test.go @@ -8,7 +8,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/webrtc/v3" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var vp8Codec = webrtc.RTPCodecParameters{ @@ -41,10 +41,10 @@ func TestNack(t *testing.T) { t.Run("nack normal", func(t *testing.T) { buff := NewBuffer(123, pool, pool) buff.codecType = webrtc.RTPCodecTypeVideo - assert.NotNil(t, buff) + require.NotNil(t, buff) var wg sync.WaitGroup - // 3 nacks 1 Pli - wg.Add(4) + // 3 nacks + wg.Add(3) buff.OnFeedback(func(fb []rtcp.Packet) { for _, pkt := range fb { switch p := pkt.(type) { @@ -52,10 +52,6 @@ func TestNack(t *testing.T) { if p.Nacks[0].PacketList()[0] == 1 && p.MediaSSRC == 123 { wg.Done() } - case *rtcp.PictureLossIndication: - if p.MediaSSRC == 123 { - wg.Done() - } } } }) @@ -72,9 +68,9 @@ func TestNack(t *testing.T) { Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, } b, err := pkt.Marshal() - assert.NoError(t, err) + require.NoError(t, err) _, err = buff.Write(b) - assert.NoError(t, err) + require.NoError(t, err) } wg.Wait() @@ -83,7 +79,7 @@ func TestNack(t *testing.T) { t.Run("nack with seq wrap", func(t *testing.T) { buff := NewBuffer(123, pool, pool) buff.codecType = webrtc.RTPCodecTypeVideo - assert.NotNil(t, buff) + require.NotNil(t, buff) var wg sync.WaitGroup expects := map[uint16]int{ 65534: 0, @@ -102,7 +98,7 @@ func TestNack(t *testing.T) { if _, ok := expects[seq]; ok { wg.Done() } else { - assert.Fail(t, "unexpected nack seq ", seq) + require.Fail(t, "unexpected nack seq ", seq) } return true }) @@ -128,9 +124,9 @@ func TestNack(t *testing.T) { Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, } b, err := pkt.Marshal() - assert.NoError(t, err) + require.NoError(t, err) _, err = buff.Write(b) - assert.NoError(t, err) + require.NoError(t, err) } wg.Wait() @@ -188,8 +184,8 @@ func TestNewBuffer(t *testing.T) { } buff := NewBuffer(123, pool, pool) buff.codecType = webrtc.RTPCodecTypeVideo - assert.NotNil(t, buff) - assert.NotNil(t, TestPackets) + require.NotNil(t, buff) + require.NotNil(t, TestPackets) buff.OnFeedback(func(_ []rtcp.Packet) { }) buff.Bind(webrtc.RTPParameters{ @@ -201,9 +197,8 @@ func TestNewBuffer(t *testing.T) { buf, _ := p.Marshal() _, _ = buff.Write(buf) } - // assert.Equal(t, 6, buff.PacketQueue.size) - assert.Equal(t, uint32(1<<16), buff.seqHdlr.Cycles()) - assert.Equal(t, uint16(2), uint16(buff.seqHdlr.MaxSeqNo())) + require.Equal(t, uint16(1), buff.cycle) + require.Equal(t, uint16(2), buff.highestSN) }) } } @@ -216,8 +211,8 @@ func TestFractionLostReport(t *testing.T) { }, } buff := NewBuffer(123, pool, pool) + require.NotNil(t, buff) buff.codecType = webrtc.RTPCodecTypeVideo - assert.NotNil(t, buff) var wg sync.WaitGroup wg.Add(1) buff.SetLastFractionLostReport(55) @@ -226,7 +221,7 @@ func TestFractionLostReport(t *testing.T) { switch p := pkt.(type) { case *rtcp.ReceiverReport: for _, v := range p.Reports { - assert.EqualValues(t, 55, v.FractionLost) + require.EqualValues(t, 55, v.FractionLost) } wg.Done() } @@ -242,20 +237,21 @@ func TestFractionLostReport(t *testing.T) { Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, } b, err := pkt.Marshal() - assert.NoError(t, err) + require.NoError(t, err) if i == 1 { time.Sleep(1 * time.Second) } _, err = buff.Write(b) - assert.NoError(t, err) + require.NoError(t, err) } wg.Wait() } +/* func TestSeqWrapHandler(t *testing.T) { s := SeqWrapHandler{} s.UpdateMaxSeq(1) - assert.Equal(t, uint32(1), s.MaxSeqNo()) + require.Equal(t, uint32(1), s.MaxSeqNo()) type caseInfo struct { seqs []uint32 // {seq1, seq2, unwrap of seq2} @@ -277,8 +273,8 @@ func TestSeqWrapHandler(t *testing.T) { s := SeqWrapHandler{} s.UpdateMaxSeq(v.seqs[0]) extsn, newer := s.Unwrap(uint16(v.seqs[1])) - assert.Equal(t, v.newer, newer) - assert.Equal(t, v.seqs[2], extsn) + require.Equal(t, v.newer, newer) + require.Equal(t, v.seqs[2], extsn) }) } @@ -301,7 +297,8 @@ func TestIsTimestampWrap(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - assert.Equal(t, c.later, IsLaterTimestamp(c.ts1, c.ts2)) + require.Equal(t, c.later, IsLaterTimestamp(c.ts1, c.ts2)) }) } } +*/ diff --git a/pkg/sfu/buffer/nack.go b/pkg/sfu/buffer/nack.go index 18d4b86a3..52e069d66 100644 --- a/pkg/sfu/buffer/nack.go +++ b/pkg/sfu/buffer/nack.go @@ -1,7 +1,7 @@ package buffer import ( - "sort" + "time" "github.com/pion/rtcp" ) @@ -10,97 +10,99 @@ const maxNackTimes = 3 // Max number of times a packet will be NACKed const maxNackCache = 100 // Max NACK sn the sfu will keep reference type nack struct { - sn uint32 - nacked uint8 + seqNum uint16 + nacked uint8 + lastNackTime time.Time } type NackQueue struct { - nacks []nack - kfSN uint32 + nacks []*nack + rtt time.Duration } func NewNACKQueue() *NackQueue { return &NackQueue{ - nacks: make([]nack, 0, maxNackCache+1), + nacks: make([]*nack, 0, maxNackCache), } } -func (n *NackQueue) Remove(extSN uint32) { - i := sort.Search(len(n.nacks), func(i int) bool { return n.nacks[i].sn >= extSN }) - if i >= len(n.nacks) || n.nacks[i].sn != extSN { - return - } - copy(n.nacks[i:], n.nacks[i+1:]) - n.nacks = n.nacks[:len(n.nacks)-1] +func (n *NackQueue) SetRTT(rtt int) { + n.rtt = time.Duration(rtt) * time.Millisecond } -func (n *NackQueue) Push(extSN uint32) { - i := sort.Search(len(n.nacks), func(i int) bool { return n.nacks[i].sn >= extSN }) - if i < len(n.nacks) && n.nacks[i].sn == extSN { - return - } +func (n *NackQueue) Remove(sn uint16) { + for idx, nack := range n.nacks { + if nack.seqNum != sn { + continue + } - nck := nack{ - sn: extSN, - nacked: 0, - } - if i == len(n.nacks) { - n.nacks = append(n.nacks, nck) - } else { - n.nacks = append(n.nacks[:i+1], n.nacks[i:]...) - n.nacks[i] = nck - } - - if len(n.nacks) >= maxNackCache { - copy(n.nacks, n.nacks[1:]) + copy(n.nacks[idx:], n.nacks[idx+1:]) + n.nacks = n.nacks[:len(n.nacks)-1] + break } } -func (n *NackQueue) Pairs(headSN uint32) ([]rtcp.NackPair, bool) { +func (n *NackQueue) Push(sn uint16) { + // if at capacity, pop the first one + if len(n.nacks) == cap(n.nacks) { + copy(n.nacks[0:], n.nacks[1:]) + n.nacks = n.nacks[:len(n.nacks)-1] + } + + n.nacks = append(n.nacks, &nack{seqNum: sn, nacked: 0, lastNackTime: time.Now()}) +} + +func (n *NackQueue) Pairs() []rtcp.NackPair { if len(n.nacks) == 0 { - return nil, false + return nil } - i := 0 - askKF := false + + now := time.Now() + + // set it far back to get the first pair + baseSN := n.nacks[0].seqNum - 17 + + snsToPurge := []uint16{} + + isPairActive := false var np rtcp.NackPair var nps []rtcp.NackPair - lostIdx := -1 - for _, nck := range n.nacks { - if nck.nacked >= maxNackTimes { - if nck.sn > n.kfSN { - n.kfSN = nck.sn - askKF = true + for _, nack := range n.nacks { + if nack.nacked >= maxNackTimes || now.Sub(nack.lastNackTime) < n.rtt { + if nack.nacked >= maxNackTimes { + snsToPurge = append(snsToPurge, nack.seqNum) } continue } - if nck.sn >= headSN-2 { - n.nacks[i] = nck - i++ - continue - } - n.nacks[i] = nack{ - sn: nck.sn, - nacked: nck.nacked + 1, - } - i++ - // first nackpair or need a new nackpair - if lostIdx < 0 || nck.sn > n.nacks[lostIdx].sn+16 { - if lostIdx >= 0 { + nack.nacked++ + nack.lastNackTime = now + + if (nack.seqNum - baseSN) > 16 { + // need a new nack pair + if isPairActive { nps = append(nps, np) + isPairActive = false } - np.PacketID = uint16(nck.sn) + + baseSN = nack.seqNum + + np.PacketID = nack.seqNum np.LostPackets = 0 - lostIdx = i - 1 - continue + isPairActive = true + } else { + np.LostPackets |= 1 << (nack.seqNum - baseSN - 1) } - np.LostPackets |= 1 << ((nck.sn) - n.nacks[lostIdx].sn - 1) } - // append last nackpair - if lostIdx != -1 { + // add any left over + if isPairActive { nps = append(nps, np) } - n.nacks = n.nacks[:i] - return nps, askKF + + for _, sn := range snsToPurge { + n.Remove(sn) + } + + return nps } diff --git a/pkg/sfu/buffer/nack_test.go b/pkg/sfu/buffer/nack_test.go index ef0b565bf..9c5af79f2 100644 --- a/pkg/sfu/buffer/nack_test.go +++ b/pkg/sfu/buffer/nack_test.go @@ -1,31 +1,21 @@ package buffer import ( - "math/rand" - "reflect" "testing" - "time" "github.com/pion/rtcp" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_nackQueue_pairs(t *testing.T) { - type fields struct { - nacks []nack - } tests := []struct { - name string - fields fields - args []uint32 - want []rtcp.NackPair + name string + args []uint16 + want []rtcp.NackPair }{ { name: "Must return correct single pairs pair", - fields: fields{ - nacks: nil, - }, - args: []uint32{1, 2, 4, 5}, + args: []uint16{1, 2, 4, 5}, want: []rtcp.NackPair{{ PacketID: 1, LostPackets: 13, @@ -33,25 +23,20 @@ func Test_nackQueue_pairs(t *testing.T) { }, { name: "Must return correct pair wrap", - fields: fields{ - nacks: nil, - }, - args: []uint32{65536, 65538, 65540, 65541, 65566, 65568}, // wrap around 65533,2,4,5 - want: []rtcp.NackPair{{ - PacketID: 0, // 65536 - LostPackets: 1<<4 + 1<<3 + 1<<1, - }, + args: []uint16{65533, 2, 4, 5, 30, 32}, + want: []rtcp.NackPair{ { - PacketID: 30, // 65566 + PacketID: 65533, + LostPackets: 1<<7 + 1<<6 + 1<<4, + }, + { + PacketID: 30, LostPackets: 1 << 1, }}, }, { name: "Must return 2 pairs pair", - fields: fields{ - nacks: nil, - }, - args: []uint32{1, 2, 4, 5, 20, 22, 24, 27}, + args: []uint16{1, 2, 4, 5, 20, 22, 24, 27}, want: []rtcp.NackPair{ { PacketID: 1, @@ -67,130 +52,79 @@ func Test_nackQueue_pairs(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - n := &NackQueue{ - nacks: tt.fields.nacks, - } + n := NewNACKQueue() for _, sn := range tt.args { n.Push(sn) } - got, _ := n.Pairs(75530) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("pairs() = %v, want %v", got, tt.want) - } + got := n.Pairs() + require.EqualValues(t, tt.want, got) }) } } func Test_nackQueue_push(t *testing.T) { - type fields struct { - nacks []nack - } type args struct { - sn []uint32 + sn []uint16 } tests := []struct { - name string - fields fields - args args - want []uint32 + name string + args args + want []uint16 }{ { name: "Must keep packet order", - fields: fields{ - nacks: make([]nack, 0, 10), - }, args: args{ - sn: []uint32{3, 4, 1, 5, 8, 7, 5}, + sn: []uint16{1, 3, 4, 5, 7, 8}, }, - want: []uint32{1, 3, 4, 5, 7, 8}, + want: []uint16{1, 3, 4, 5, 7, 8}, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - n := &NackQueue{ - nacks: tt.fields.nacks, - } + n := NewNACKQueue() for _, sn := range tt.args.sn { n.Push(sn) } - var newSN []uint32 - for _, sn := range n.nacks { - newSN = append(newSN, sn.sn) - } - assert.Equal(t, tt.want, newSN) - }) - } -} - -func Test_nackQueue(t *testing.T) { - type fields struct { - nacks []nack - } - type args struct { - sn []uint32 - } - tests := []struct { - name string - fields fields - args args - }{ - { - name: "Must keep packet order", - fields: fields{ - nacks: make([]nack, 0, 10), - }, - args: args{ - sn: []uint32{3, 4, 1, 5, 8, 7, 5}, - }, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - n := NackQueue{} - r := rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < 100; i++ { - assert.NotPanics(t, func() { - n.Push(uint32(r.Intn(60000))) - n.Remove(uint32(r.Intn(60000))) - n.Pairs(60001) - }) + var newSN []uint16 + for _, nack := range n.nacks { + newSN = append(newSN, nack.seqNum) } + require.Equal(t, tt.want, newSN) }) } } func Test_nackQueue_remove(t *testing.T) { type args struct { - sn []uint32 + sn []uint16 } tests := []struct { name string args args - want []uint32 + want []uint16 }{ { name: "Must keep packet order", args: args{ - sn: []uint32{3, 4, 1, 5, 8, 7, 5}, + sn: []uint16{1, 3, 4, 5, 7, 8}, }, - want: []uint32{1, 3, 4, 7, 8}, + want: []uint16{1, 3, 4, 7, 8}, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - n := NackQueue{} + n := NewNACKQueue() for _, sn := range tt.args.sn { n.Push(sn) } n.Remove(5) - var newSN []uint32 - for _, sn := range n.nacks { - newSN = append(newSN, sn.sn) + var newSN []uint16 + for _, nack := range n.nacks { + newSN = append(newSN, nack.seqNum) } - assert.Equal(t, tt.want, newSN) + require.Equal(t, tt.want, newSN) }) } }