Refactor Bucket a little bit (#459)

This commit is contained in:
Raja Subramanian
2022-02-24 13:43:12 +05:30
committed by GitHub
parent 7eb2fecadd
commit eed241cf7f
3 changed files with 295 additions and 119 deletions
+92 -47
View File
@@ -6,7 +6,13 @@ import (
"math"
)
const maxPktSize = 1500
const (
maxPktSize = 1500
pktSizeHeader = 2
seqNumOffset = 2
seqNumSize = 2
invalidPktSize = uint16(65535)
)
type Bucket struct {
buf []byte
@@ -19,32 +25,30 @@ type Bucket struct {
}
func NewBucket(buf *[]byte) *Bucket {
return &Bucket{
b := &Bucket{
src: buf,
buf: *buf,
maxSteps: int(math.Floor(float64(len(*buf))/float64(maxPktSize))) - 1,
maxSteps: int(math.Floor(float64(len(*buf)) / float64(maxPktSize))),
}
b.invalidate(0, b.maxSteps)
return b
}
func (b *Bucket) AddPacket(pkt []byte) ([]byte, error) {
sn := binary.BigEndian.Uint16(pkt[2:4])
sn := binary.BigEndian.Uint16(pkt[seqNumOffset : seqNumOffset+seqNumSize])
if !b.init {
b.headSN = sn - 1
b.init = true
}
diff := sn - b.headSN
if diff == 0 || diff > (1<<15) {
// duplicate of last packet or out-of-order
return b.set(sn, pkt)
}
b.headSN = sn
for i := uint16(1); i < diff; i++ {
b.step++
if b.step >= b.maxSteps {
b.step = 0
}
}
return b.push(pkt), nil
return b.push(sn, pkt)
}
func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) {
@@ -65,53 +69,94 @@ func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) {
return
}
func (b *Bucket) push(pkt []byte) []byte {
binary.BigEndian.PutUint16(b.buf[b.step*maxPktSize:], uint16(len(pkt)))
off := b.step*maxPktSize + 2
copy(b.buf[off:], pkt)
b.step++
if b.step > b.maxSteps {
b.step = 0
}
return b.buf[off : off+len(pkt)]
func (b *Bucket) push(sn uint16, pkt []byte) ([]byte, error) {
diff := int(sn-b.headSN) - 1
b.headSN = sn
// invalidate slots if there is a gap in the sequence number
b.invalidate(b.step, diff)
// store headSN packet
off := b.offset(b.step + diff)
storedPkt := b.store(off, pkt)
// for next packet
b.step = b.wrap(b.step + diff + 1)
return storedPkt, nil
}
func (b *Bucket) get(sn uint16) []byte {
pos := b.step - int(b.headSN-sn+1)
if pos < 0 {
if pos*-1 > b.maxSteps+1 {
return nil
}
pos = b.maxSteps + pos + 1
}
off := pos * maxPktSize
if off > len(b.buf) {
diff := b.headSN - sn
if int(diff) >= b.maxSteps {
// too old or asking for something ahead of headSN (which is effectively too old with wrap around)
return nil
}
if binary.BigEndian.Uint16(b.buf[off+4:off+6]) != sn {
off := b.offset(b.step - int(diff) - 1)
if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) != sn {
return nil
}
sz := int(binary.BigEndian.Uint16(b.buf[off : off+2]))
return b.buf[off+2 : off+2+sz]
sz := binary.BigEndian.Uint16(b.buf[off : off+pktSizeHeader])
if sz == invalidPktSize {
return nil
}
off += pktSizeHeader
return b.buf[off : off+int(sz)]
}
func (b *Bucket) set(sn uint16, pkt []byte) ([]byte, error) {
if b.headSN-sn >= uint16(b.maxSteps+1) {
diff := b.headSN - sn
if int(diff) >= b.maxSteps {
return nil, fmt.Errorf("%w, headSN %d, sn %d", ErrPacketTooOld, b.headSN, sn)
}
pos := b.step - int(b.headSN-sn+1)
if pos < 0 {
pos = b.maxSteps + pos + 1
}
off := pos * maxPktSize
if off > len(b.buf) || off < 0 {
return nil, ErrPacketTooOld
}
// Do not overwrite if packet exist
if binary.BigEndian.Uint16(b.buf[off+4:off+6]) == sn {
off := b.offset(b.step - int(diff) - 1)
// Do not overwrite if duplicate
if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) == sn {
return nil, ErrRTXPacket
}
binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt)))
copy(b.buf[off+2:], pkt)
return b.buf[off+2 : off+2+len(pkt)], nil
return b.store(off, pkt), nil
}
func (b *Bucket) store(off int, pkt []byte) []byte {
// store packet size
binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt)))
// store packet
off += pktSizeHeader
copy(b.buf[off:], pkt)
return b.buf[off : off+len(pkt)]
}
func (b *Bucket) wrap(slot int) int {
for slot < 0 {
slot += b.maxSteps
}
for slot >= b.maxSteps {
slot -= b.maxSteps
}
return slot
}
func (b *Bucket) offset(slot int) int {
return b.wrap(slot) * maxPktSize
}
func (b *Bucket) invalidate(startSlot int, numSlots int) {
if numSlots > b.maxSteps {
numSlots = b.maxSteps
}
for i := 0; i < numSlots; i++ {
off := b.offset(startSlot + i)
binary.BigEndian.PutUint16(b.buf[off:], invalidPktSize)
}
}
+197 -66
View File
@@ -4,85 +4,94 @@ import (
"testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var TestPackets = []*rtp.Packet{
{
Header: rtp.Header{
SequenceNumber: 1,
},
},
{
Header: rtp.Header{
SequenceNumber: 3,
},
},
{
Header: rtp.Header{
SequenceNumber: 4,
},
},
{
Header: rtp.Header{
SequenceNumber: 6,
},
},
{
Header: rtp.Header{
SequenceNumber: 7,
},
},
{
Header: rtp.Header{
SequenceNumber: 10,
},
},
}
func Test_queue(t *testing.T) {
b := make([]byte, 25000)
TestPackets := []*rtp.Packet{
{
Header: rtp.Header{
SequenceNumber: 1,
},
},
{
Header: rtp.Header{
SequenceNumber: 3,
},
},
{
Header: rtp.Header{
SequenceNumber: 4,
},
},
{
Header: rtp.Header{
SequenceNumber: 6,
},
},
{
Header: rtp.Header{
SequenceNumber: 7,
},
},
{
Header: rtp.Header{
SequenceNumber: 10,
},
},
}
b := make([]byte, 16000)
q := NewBucket(&b)
for _, p := range TestPackets {
p := p
buf, err := p.Marshal()
assert.NoError(t, err)
assert.NotPanics(t, func() {
require.NoError(t, err)
require.NotPanics(t, func() {
_, _ = q.AddPacket(buf)
})
}
var expectedSN uint16
expectedSN = 6
expectedSN := uint16(6)
np := rtp.Packet{}
buff := make([]byte, maxPktSize)
i, err := q.GetPacket(buff, 6)
assert.NoError(t, err)
i, err := q.GetPacket(buff, expectedSN)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
assert.NoError(t, err)
assert.Equal(t, expectedSN, np.SequenceNumber)
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
// add an out-of-order packet and ensure it can be retrieved
np2 := &rtp.Packet{
Header: rtp.Header{
SequenceNumber: 8,
},
}
buf, err := np2.Marshal()
assert.NoError(t, err)
require.NoError(t, err)
_, err = q.AddPacket(buf)
require.NoError(t, err)
expectedSN = 8
_, _ = q.AddPacket(buf)
i, err = q.GetPacket(buff, expectedSN)
assert.NoError(t, err)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
assert.NoError(t, err)
assert.Equal(t, expectedSN, np.SequenceNumber)
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
_, err = q.AddPacket(buf)
assert.ErrorIs(t, err, ErrRTXPacket)
require.ErrorIs(t, err, ErrRTXPacket)
// try to get old packets
_, err = q.GetPacket(buff, 0)
require.ErrorIs(t, err, ErrPacketNotFound)
// ask for soemething ahead of headSN
_, err = q.GetPacket(buff, 11)
require.ErrorIs(t, err, ErrPacketNotFound)
}
func Test_queue_edges(t *testing.T) {
var TestPackets = []*rtp.Packet{
TestPackets := []*rtp.Packet{
{
Header: rtp.Header{
SequenceNumber: 65533,
@@ -99,41 +108,163 @@ func Test_queue_edges(t *testing.T) {
},
},
}
b := make([]byte, 25000)
q := NewBucket(&b)
for _, p := range TestPackets {
p := p
assert.NotNil(t, p)
assert.NotPanics(t, func() {
p := p
require.NotPanics(t, func() {
buf, err := p.Marshal()
assert.NoError(t, err)
assert.NotPanics(t, func() {
require.NoError(t, err)
require.NotPanics(t, func() {
_, _ = q.AddPacket(buf)
})
})
}
var expectedSN uint16
expectedSN = 65534
expectedSN := uint16(65534)
np := rtp.Packet{}
buff := make([]byte, maxPktSize)
i, err := q.GetPacket(buff, expectedSN)
assert.NoError(t, err)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
assert.NoError(t, err)
assert.Equal(t, expectedSN, np.SequenceNumber)
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
// add an out-of-order packet where the head sequence has wrapped and ensure it can be retrieved
np2 := rtp.Packet{
Header: rtp.Header{
SequenceNumber: 65535,
},
}
buf, err := np2.Marshal()
assert.NoError(t, err)
require.NoError(t, err)
_, _ = q.AddPacket(buf)
i, err = q.GetPacket(buff, expectedSN+1)
assert.NoError(t, err)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
assert.NoError(t, err)
assert.Equal(t, expectedSN+1, np.SequenceNumber)
require.NoError(t, err)
require.Equal(t, expectedSN+1, np.SequenceNumber)
}
func Test_queue_wrap(t *testing.T) {
TestPackets := []*rtp.Packet{
{
Header: rtp.Header{
SequenceNumber: 1,
},
},
{
Header: rtp.Header{
SequenceNumber: 3,
},
},
{
Header: rtp.Header{
SequenceNumber: 4,
},
},
{
Header: rtp.Header{
SequenceNumber: 6,
},
},
{
Header: rtp.Header{
SequenceNumber: 7,
},
},
{
Header: rtp.Header{
SequenceNumber: 10,
},
},
{
Header: rtp.Header{
SequenceNumber: 13,
},
},
{
Header: rtp.Header{
SequenceNumber: 15,
},
},
}
b := make([]byte, 16000)
q := NewBucket(&b)
for _, p := range TestPackets {
buf, err := p.Marshal()
require.NoError(t, err)
require.NotPanics(t, func() {
_, _ = q.AddPacket(buf)
})
}
buff := make([]byte, maxPktSize)
// try to get old packets, but were valid before the bucket wrapped
_, err := q.GetPacket(buff, 1)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 3)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 4)
require.ErrorIs(t, err, ErrPacketNotFound)
expectedSN := uint16(6)
np := rtp.Packet{}
i, err := q.GetPacket(buff, expectedSN)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
// add an out-of-order packet and ensure it can be retrieved
np2 := &rtp.Packet{
Header: rtp.Header{
SequenceNumber: 8,
},
}
buf, err := np2.Marshal()
require.NoError(t, err)
_, err = q.AddPacket(buf)
require.NoError(t, err)
expectedSN = 8
i, err = q.GetPacket(buff, expectedSN)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
// add a packet with a large gap in sequence number which will invalidate all the slots
np3 := &rtp.Packet{
Header: rtp.Header{
SequenceNumber: 56,
},
}
buf, err = np3.Marshal()
require.NoError(t, err)
_, err = q.AddPacket(buf)
require.NoError(t, err)
expectedSN = 56
i, err = q.GetPacket(buff, expectedSN)
require.NoError(t, err)
err = np.Unmarshal(buff[:i])
require.NoError(t, err)
require.Equal(t, expectedSN, np.SequenceNumber)
// after the large jump invalidating all slots, retrieving previously added packets should fail
_, err = q.GetPacket(buff, 6)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 7)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 8)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 10)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 13)
require.ErrorIs(t, err, ErrPacketNotFound)
_, err = q.GetPacket(buff, 15)
require.ErrorIs(t, err, ErrPacketNotFound)
}
+6 -6
View File
@@ -3,7 +3,7 @@ package buffer
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVP8Helper_Unmarshal(t *testing.T) {
@@ -75,19 +75,19 @@ func TestVP8Helper_Unmarshal(t *testing.T) {
t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.checkTemporal {
assert.Equal(t, tt.temporalSupport, p.TIDPresent == 1)
require.Equal(t, tt.temporalSupport, p.TIDPresent == 1)
}
if tt.checkKeyFrame {
assert.Equal(t, tt.keyFrame, p.IsKeyFrame)
require.Equal(t, tt.keyFrame, p.IsKeyFrame)
}
if tt.checkPictureID {
assert.Equal(t, tt.pictureID, p.PictureID)
require.Equal(t, tt.pictureID, p.PictureID)
}
if tt.checkTlzIdx {
assert.Equal(t, tt.tlzIdx, p.TL0PICIDX)
require.Equal(t, tt.tlzIdx, p.TL0PICIDX)
}
if tt.checkTempID {
assert.Equal(t, tt.temporalID, p.TID)
require.Equal(t, tt.temporalID, p.TID)
}
})
}