From eed241cf7ff7a03164c14c1dc12c80e1bfa0ef26 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 24 Feb 2022 13:43:12 +0530 Subject: [PATCH] Refactor Bucket a little bit (#459) --- pkg/sfu/buffer/bucket.go | 139 +++++++++++------ pkg/sfu/buffer/bucket_test.go | 263 ++++++++++++++++++++++++--------- pkg/sfu/buffer/helpers_test.go | 12 +- 3 files changed, 295 insertions(+), 119 deletions(-) diff --git a/pkg/sfu/buffer/bucket.go b/pkg/sfu/buffer/bucket.go index 094df8b42..f09574dce 100644 --- a/pkg/sfu/buffer/bucket.go +++ b/pkg/sfu/buffer/bucket.go @@ -6,7 +6,13 @@ import ( "math" ) -const maxPktSize = 1500 +const ( + maxPktSize = 1500 + pktSizeHeader = 2 + seqNumOffset = 2 + seqNumSize = 2 + invalidPktSize = uint16(65535) +) type Bucket struct { buf []byte @@ -19,32 +25,30 @@ type Bucket struct { } func NewBucket(buf *[]byte) *Bucket { - return &Bucket{ + b := &Bucket{ src: buf, buf: *buf, - maxSteps: int(math.Floor(float64(len(*buf))/float64(maxPktSize))) - 1, + maxSteps: int(math.Floor(float64(len(*buf)) / float64(maxPktSize))), } + + b.invalidate(0, b.maxSteps) + return b } func (b *Bucket) AddPacket(pkt []byte) ([]byte, error) { - sn := binary.BigEndian.Uint16(pkt[2:4]) + sn := binary.BigEndian.Uint16(pkt[seqNumOffset : seqNumOffset+seqNumSize]) if !b.init { b.headSN = sn - 1 b.init = true } + diff := sn - b.headSN if diff == 0 || diff > (1<<15) { // duplicate of last packet or out-of-order return b.set(sn, pkt) } - b.headSN = sn - for i := uint16(1); i < diff; i++ { - b.step++ - if b.step >= b.maxSteps { - b.step = 0 - } - } - return b.push(pkt), nil + + return b.push(sn, pkt) } func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) { @@ -65,53 +69,94 @@ func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) { return } -func (b *Bucket) push(pkt []byte) []byte { - binary.BigEndian.PutUint16(b.buf[b.step*maxPktSize:], uint16(len(pkt))) - off := b.step*maxPktSize + 2 - copy(b.buf[off:], pkt) - b.step++ - if b.step > b.maxSteps { - b.step = 0 - } - return b.buf[off : off+len(pkt)] +func (b *Bucket) push(sn uint16, pkt []byte) ([]byte, error) { + diff := int(sn-b.headSN) - 1 + b.headSN = sn + + // invalidate slots if there is a gap in the sequence number + b.invalidate(b.step, diff) + + // store headSN packet + off := b.offset(b.step + diff) + storedPkt := b.store(off, pkt) + + // for next packet + b.step = b.wrap(b.step + diff + 1) + + return storedPkt, nil } func (b *Bucket) get(sn uint16) []byte { - pos := b.step - int(b.headSN-sn+1) - if pos < 0 { - if pos*-1 > b.maxSteps+1 { - return nil - } - pos = b.maxSteps + pos + 1 - } - off := pos * maxPktSize - if off > len(b.buf) { + diff := b.headSN - sn + if int(diff) >= b.maxSteps { + // too old or asking for something ahead of headSN (which is effectively too old with wrap around) return nil } - if binary.BigEndian.Uint16(b.buf[off+4:off+6]) != sn { + + off := b.offset(b.step - int(diff) - 1) + if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) != sn { return nil } - sz := int(binary.BigEndian.Uint16(b.buf[off : off+2])) - return b.buf[off+2 : off+2+sz] + + sz := binary.BigEndian.Uint16(b.buf[off : off+pktSizeHeader]) + if sz == invalidPktSize { + return nil + } + + off += pktSizeHeader + return b.buf[off : off+int(sz)] } func (b *Bucket) set(sn uint16, pkt []byte) ([]byte, error) { - if b.headSN-sn >= uint16(b.maxSteps+1) { + diff := b.headSN - sn + if int(diff) >= b.maxSteps { return nil, fmt.Errorf("%w, headSN %d, sn %d", ErrPacketTooOld, b.headSN, sn) } - pos := b.step - int(b.headSN-sn+1) - if pos < 0 { - pos = b.maxSteps + pos + 1 - } - off := pos * maxPktSize - if off > len(b.buf) || off < 0 { - return nil, ErrPacketTooOld - } - // Do not overwrite if packet exist - if binary.BigEndian.Uint16(b.buf[off+4:off+6]) == sn { + + off := b.offset(b.step - int(diff) - 1) + + // Do not overwrite if duplicate + if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) == sn { return nil, ErrRTXPacket } - binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt))) - copy(b.buf[off+2:], pkt) - return b.buf[off+2 : off+2+len(pkt)], nil + + return b.store(off, pkt), nil +} + +func (b *Bucket) store(off int, pkt []byte) []byte { + // store packet size + binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt))) + + // store packet + off += pktSizeHeader + copy(b.buf[off:], pkt) + + return b.buf[off : off+len(pkt)] +} + +func (b *Bucket) wrap(slot int) int { + for slot < 0 { + slot += b.maxSteps + } + + for slot >= b.maxSteps { + slot -= b.maxSteps + } + + return slot +} + +func (b *Bucket) offset(slot int) int { + return b.wrap(slot) * maxPktSize +} + +func (b *Bucket) invalidate(startSlot int, numSlots int) { + if numSlots > b.maxSteps { + numSlots = b.maxSteps + } + + for i := 0; i < numSlots; i++ { + off := b.offset(startSlot + i) + binary.BigEndian.PutUint16(b.buf[off:], invalidPktSize) + } } diff --git a/pkg/sfu/buffer/bucket_test.go b/pkg/sfu/buffer/bucket_test.go index 046aac9b4..b9d522b7b 100644 --- a/pkg/sfu/buffer/bucket_test.go +++ b/pkg/sfu/buffer/bucket_test.go @@ -4,85 +4,94 @@ import ( "testing" "github.com/pion/rtp" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var TestPackets = []*rtp.Packet{ - { - Header: rtp.Header{ - SequenceNumber: 1, - }, - }, - { - Header: rtp.Header{ - SequenceNumber: 3, - }, - }, - { - Header: rtp.Header{ - SequenceNumber: 4, - }, - }, - { - Header: rtp.Header{ - SequenceNumber: 6, - }, - }, - { - Header: rtp.Header{ - SequenceNumber: 7, - }, - }, - { - Header: rtp.Header{ - SequenceNumber: 10, - }, - }, -} - func Test_queue(t *testing.T) { - b := make([]byte, 25000) + TestPackets := []*rtp.Packet{ + { + Header: rtp.Header{ + SequenceNumber: 1, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 3, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 4, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 6, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 7, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 10, + }, + }, + } + + b := make([]byte, 16000) q := NewBucket(&b) for _, p := range TestPackets { - p := p buf, err := p.Marshal() - assert.NoError(t, err) - assert.NotPanics(t, func() { + require.NoError(t, err) + require.NotPanics(t, func() { _, _ = q.AddPacket(buf) }) } - var expectedSN uint16 - expectedSN = 6 + + expectedSN := uint16(6) np := rtp.Packet{} buff := make([]byte, maxPktSize) - i, err := q.GetPacket(buff, 6) - assert.NoError(t, err) + i, err := q.GetPacket(buff, expectedSN) + require.NoError(t, err) err = np.Unmarshal(buff[:i]) - assert.NoError(t, err) - assert.Equal(t, expectedSN, np.SequenceNumber) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) + // add an out-of-order packet and ensure it can be retrieved np2 := &rtp.Packet{ Header: rtp.Header{ SequenceNumber: 8, }, } buf, err := np2.Marshal() - assert.NoError(t, err) + require.NoError(t, err) + _, err = q.AddPacket(buf) + require.NoError(t, err) expectedSN = 8 - _, _ = q.AddPacket(buf) i, err = q.GetPacket(buff, expectedSN) - assert.NoError(t, err) + require.NoError(t, err) err = np.Unmarshal(buff[:i]) - assert.NoError(t, err) - assert.Equal(t, expectedSN, np.SequenceNumber) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) _, err = q.AddPacket(buf) - assert.ErrorIs(t, err, ErrRTXPacket) + require.ErrorIs(t, err, ErrRTXPacket) + + // try to get old packets + _, err = q.GetPacket(buff, 0) + require.ErrorIs(t, err, ErrPacketNotFound) + + // ask for soemething ahead of headSN + _, err = q.GetPacket(buff, 11) + require.ErrorIs(t, err, ErrPacketNotFound) } func Test_queue_edges(t *testing.T) { - var TestPackets = []*rtp.Packet{ + TestPackets := []*rtp.Packet{ { Header: rtp.Header{ SequenceNumber: 65533, @@ -99,41 +108,163 @@ func Test_queue_edges(t *testing.T) { }, }, } + b := make([]byte, 25000) q := NewBucket(&b) + for _, p := range TestPackets { - p := p - assert.NotNil(t, p) - assert.NotPanics(t, func() { - p := p + require.NotPanics(t, func() { buf, err := p.Marshal() - assert.NoError(t, err) - assert.NotPanics(t, func() { + require.NoError(t, err) + require.NotPanics(t, func() { _, _ = q.AddPacket(buf) }) }) } - var expectedSN uint16 - expectedSN = 65534 + + expectedSN := uint16(65534) np := rtp.Packet{} buff := make([]byte, maxPktSize) i, err := q.GetPacket(buff, expectedSN) - assert.NoError(t, err) + require.NoError(t, err) err = np.Unmarshal(buff[:i]) - assert.NoError(t, err) - assert.Equal(t, expectedSN, np.SequenceNumber) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) + // add an out-of-order packet where the head sequence has wrapped and ensure it can be retrieved np2 := rtp.Packet{ Header: rtp.Header{ SequenceNumber: 65535, }, } buf, err := np2.Marshal() - assert.NoError(t, err) + require.NoError(t, err) _, _ = q.AddPacket(buf) i, err = q.GetPacket(buff, expectedSN+1) - assert.NoError(t, err) + require.NoError(t, err) err = np.Unmarshal(buff[:i]) - assert.NoError(t, err) - assert.Equal(t, expectedSN+1, np.SequenceNumber) + require.NoError(t, err) + require.Equal(t, expectedSN+1, np.SequenceNumber) +} + +func Test_queue_wrap(t *testing.T) { + TestPackets := []*rtp.Packet{ + { + Header: rtp.Header{ + SequenceNumber: 1, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 3, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 4, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 6, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 7, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 10, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 13, + }, + }, + { + Header: rtp.Header{ + SequenceNumber: 15, + }, + }, + } + + b := make([]byte, 16000) + q := NewBucket(&b) + + for _, p := range TestPackets { + buf, err := p.Marshal() + require.NoError(t, err) + require.NotPanics(t, func() { + _, _ = q.AddPacket(buf) + }) + } + + buff := make([]byte, maxPktSize) + + // try to get old packets, but were valid before the bucket wrapped + _, err := q.GetPacket(buff, 1) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 3) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 4) + require.ErrorIs(t, err, ErrPacketNotFound) + + expectedSN := uint16(6) + np := rtp.Packet{} + i, err := q.GetPacket(buff, expectedSN) + require.NoError(t, err) + err = np.Unmarshal(buff[:i]) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) + + // add an out-of-order packet and ensure it can be retrieved + np2 := &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: 8, + }, + } + buf, err := np2.Marshal() + require.NoError(t, err) + _, err = q.AddPacket(buf) + require.NoError(t, err) + expectedSN = 8 + i, err = q.GetPacket(buff, expectedSN) + require.NoError(t, err) + err = np.Unmarshal(buff[:i]) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) + + // add a packet with a large gap in sequence number which will invalidate all the slots + np3 := &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: 56, + }, + } + buf, err = np3.Marshal() + require.NoError(t, err) + _, err = q.AddPacket(buf) + require.NoError(t, err) + expectedSN = 56 + i, err = q.GetPacket(buff, expectedSN) + require.NoError(t, err) + err = np.Unmarshal(buff[:i]) + require.NoError(t, err) + require.Equal(t, expectedSN, np.SequenceNumber) + + // after the large jump invalidating all slots, retrieving previously added packets should fail + _, err = q.GetPacket(buff, 6) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 7) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 8) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 10) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 13) + require.ErrorIs(t, err, ErrPacketNotFound) + _, err = q.GetPacket(buff, 15) + require.ErrorIs(t, err, ErrPacketNotFound) } diff --git a/pkg/sfu/buffer/helpers_test.go b/pkg/sfu/buffer/helpers_test.go index 0386d5530..219e65175 100644 --- a/pkg/sfu/buffer/helpers_test.go +++ b/pkg/sfu/buffer/helpers_test.go @@ -3,7 +3,7 @@ package buffer import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestVP8Helper_Unmarshal(t *testing.T) { @@ -75,19 +75,19 @@ func TestVP8Helper_Unmarshal(t *testing.T) { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } if tt.checkTemporal { - assert.Equal(t, tt.temporalSupport, p.TIDPresent == 1) + require.Equal(t, tt.temporalSupport, p.TIDPresent == 1) } if tt.checkKeyFrame { - assert.Equal(t, tt.keyFrame, p.IsKeyFrame) + require.Equal(t, tt.keyFrame, p.IsKeyFrame) } if tt.checkPictureID { - assert.Equal(t, tt.pictureID, p.PictureID) + require.Equal(t, tt.pictureID, p.PictureID) } if tt.checkTlzIdx { - assert.Equal(t, tt.tlzIdx, p.TL0PICIDX) + require.Equal(t, tt.tlzIdx, p.TL0PICIDX) } if tt.checkTempID { - assert.Equal(t, tt.temporalID, p.TID) + require.Equal(t, tt.temporalID, p.TID) } }) }