mirror of
https://github.com/livekit/livekit.git
synced 2026-05-24 16:55:35 +00:00
Refactor Bucket a little bit (#459)
This commit is contained in:
+92
-47
@@ -6,7 +6,13 @@ import (
|
||||
"math"
|
||||
)
|
||||
|
||||
const maxPktSize = 1500
|
||||
const (
|
||||
maxPktSize = 1500
|
||||
pktSizeHeader = 2
|
||||
seqNumOffset = 2
|
||||
seqNumSize = 2
|
||||
invalidPktSize = uint16(65535)
|
||||
)
|
||||
|
||||
type Bucket struct {
|
||||
buf []byte
|
||||
@@ -19,32 +25,30 @@ type Bucket struct {
|
||||
}
|
||||
|
||||
func NewBucket(buf *[]byte) *Bucket {
|
||||
return &Bucket{
|
||||
b := &Bucket{
|
||||
src: buf,
|
||||
buf: *buf,
|
||||
maxSteps: int(math.Floor(float64(len(*buf))/float64(maxPktSize))) - 1,
|
||||
maxSteps: int(math.Floor(float64(len(*buf)) / float64(maxPktSize))),
|
||||
}
|
||||
|
||||
b.invalidate(0, b.maxSteps)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bucket) AddPacket(pkt []byte) ([]byte, error) {
|
||||
sn := binary.BigEndian.Uint16(pkt[2:4])
|
||||
sn := binary.BigEndian.Uint16(pkt[seqNumOffset : seqNumOffset+seqNumSize])
|
||||
if !b.init {
|
||||
b.headSN = sn - 1
|
||||
b.init = true
|
||||
}
|
||||
|
||||
diff := sn - b.headSN
|
||||
if diff == 0 || diff > (1<<15) {
|
||||
// duplicate of last packet or out-of-order
|
||||
return b.set(sn, pkt)
|
||||
}
|
||||
b.headSN = sn
|
||||
for i := uint16(1); i < diff; i++ {
|
||||
b.step++
|
||||
if b.step >= b.maxSteps {
|
||||
b.step = 0
|
||||
}
|
||||
}
|
||||
return b.push(pkt), nil
|
||||
|
||||
return b.push(sn, pkt)
|
||||
}
|
||||
|
||||
func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) {
|
||||
@@ -65,53 +69,94 @@ func (b *Bucket) GetPacket(buf []byte, sn uint16) (i int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (b *Bucket) push(pkt []byte) []byte {
|
||||
binary.BigEndian.PutUint16(b.buf[b.step*maxPktSize:], uint16(len(pkt)))
|
||||
off := b.step*maxPktSize + 2
|
||||
copy(b.buf[off:], pkt)
|
||||
b.step++
|
||||
if b.step > b.maxSteps {
|
||||
b.step = 0
|
||||
}
|
||||
return b.buf[off : off+len(pkt)]
|
||||
func (b *Bucket) push(sn uint16, pkt []byte) ([]byte, error) {
|
||||
diff := int(sn-b.headSN) - 1
|
||||
b.headSN = sn
|
||||
|
||||
// invalidate slots if there is a gap in the sequence number
|
||||
b.invalidate(b.step, diff)
|
||||
|
||||
// store headSN packet
|
||||
off := b.offset(b.step + diff)
|
||||
storedPkt := b.store(off, pkt)
|
||||
|
||||
// for next packet
|
||||
b.step = b.wrap(b.step + diff + 1)
|
||||
|
||||
return storedPkt, nil
|
||||
}
|
||||
|
||||
func (b *Bucket) get(sn uint16) []byte {
|
||||
pos := b.step - int(b.headSN-sn+1)
|
||||
if pos < 0 {
|
||||
if pos*-1 > b.maxSteps+1 {
|
||||
return nil
|
||||
}
|
||||
pos = b.maxSteps + pos + 1
|
||||
}
|
||||
off := pos * maxPktSize
|
||||
if off > len(b.buf) {
|
||||
diff := b.headSN - sn
|
||||
if int(diff) >= b.maxSteps {
|
||||
// too old or asking for something ahead of headSN (which is effectively too old with wrap around)
|
||||
return nil
|
||||
}
|
||||
if binary.BigEndian.Uint16(b.buf[off+4:off+6]) != sn {
|
||||
|
||||
off := b.offset(b.step - int(diff) - 1)
|
||||
if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) != sn {
|
||||
return nil
|
||||
}
|
||||
sz := int(binary.BigEndian.Uint16(b.buf[off : off+2]))
|
||||
return b.buf[off+2 : off+2+sz]
|
||||
|
||||
sz := binary.BigEndian.Uint16(b.buf[off : off+pktSizeHeader])
|
||||
if sz == invalidPktSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
off += pktSizeHeader
|
||||
return b.buf[off : off+int(sz)]
|
||||
}
|
||||
|
||||
func (b *Bucket) set(sn uint16, pkt []byte) ([]byte, error) {
|
||||
if b.headSN-sn >= uint16(b.maxSteps+1) {
|
||||
diff := b.headSN - sn
|
||||
if int(diff) >= b.maxSteps {
|
||||
return nil, fmt.Errorf("%w, headSN %d, sn %d", ErrPacketTooOld, b.headSN, sn)
|
||||
}
|
||||
pos := b.step - int(b.headSN-sn+1)
|
||||
if pos < 0 {
|
||||
pos = b.maxSteps + pos + 1
|
||||
}
|
||||
off := pos * maxPktSize
|
||||
if off > len(b.buf) || off < 0 {
|
||||
return nil, ErrPacketTooOld
|
||||
}
|
||||
// Do not overwrite if packet exist
|
||||
if binary.BigEndian.Uint16(b.buf[off+4:off+6]) == sn {
|
||||
|
||||
off := b.offset(b.step - int(diff) - 1)
|
||||
|
||||
// Do not overwrite if duplicate
|
||||
if binary.BigEndian.Uint16(b.buf[off+pktSizeHeader+seqNumOffset:off+pktSizeHeader+seqNumOffset+seqNumSize]) == sn {
|
||||
return nil, ErrRTXPacket
|
||||
}
|
||||
binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt)))
|
||||
copy(b.buf[off+2:], pkt)
|
||||
return b.buf[off+2 : off+2+len(pkt)], nil
|
||||
|
||||
return b.store(off, pkt), nil
|
||||
}
|
||||
|
||||
func (b *Bucket) store(off int, pkt []byte) []byte {
|
||||
// store packet size
|
||||
binary.BigEndian.PutUint16(b.buf[off:], uint16(len(pkt)))
|
||||
|
||||
// store packet
|
||||
off += pktSizeHeader
|
||||
copy(b.buf[off:], pkt)
|
||||
|
||||
return b.buf[off : off+len(pkt)]
|
||||
}
|
||||
|
||||
func (b *Bucket) wrap(slot int) int {
|
||||
for slot < 0 {
|
||||
slot += b.maxSteps
|
||||
}
|
||||
|
||||
for slot >= b.maxSteps {
|
||||
slot -= b.maxSteps
|
||||
}
|
||||
|
||||
return slot
|
||||
}
|
||||
|
||||
func (b *Bucket) offset(slot int) int {
|
||||
return b.wrap(slot) * maxPktSize
|
||||
}
|
||||
|
||||
func (b *Bucket) invalidate(startSlot int, numSlots int) {
|
||||
if numSlots > b.maxSteps {
|
||||
numSlots = b.maxSteps
|
||||
}
|
||||
|
||||
for i := 0; i < numSlots; i++ {
|
||||
off := b.offset(startSlot + i)
|
||||
binary.BigEndian.PutUint16(b.buf[off:], invalidPktSize)
|
||||
}
|
||||
}
|
||||
|
||||
+197
-66
@@ -4,85 +4,94 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var TestPackets = []*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func Test_queue(t *testing.T) {
|
||||
b := make([]byte, 25000)
|
||||
TestPackets := []*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b := make([]byte, 16000)
|
||||
q := NewBucket(&b)
|
||||
|
||||
for _, p := range TestPackets {
|
||||
p := p
|
||||
buf, err := p.Marshal()
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
require.NoError(t, err)
|
||||
require.NotPanics(t, func() {
|
||||
_, _ = q.AddPacket(buf)
|
||||
})
|
||||
}
|
||||
var expectedSN uint16
|
||||
expectedSN = 6
|
||||
|
||||
expectedSN := uint16(6)
|
||||
np := rtp.Packet{}
|
||||
buff := make([]byte, maxPktSize)
|
||||
i, err := q.GetPacket(buff, 6)
|
||||
assert.NoError(t, err)
|
||||
i, err := q.GetPacket(buff, expectedSN)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSN, np.SequenceNumber)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
// add an out-of-order packet and ensure it can be retrieved
|
||||
np2 := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 8,
|
||||
},
|
||||
}
|
||||
buf, err := np2.Marshal()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
_, err = q.AddPacket(buf)
|
||||
require.NoError(t, err)
|
||||
expectedSN = 8
|
||||
_, _ = q.AddPacket(buf)
|
||||
i, err = q.GetPacket(buff, expectedSN)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSN, np.SequenceNumber)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
_, err = q.AddPacket(buf)
|
||||
assert.ErrorIs(t, err, ErrRTXPacket)
|
||||
require.ErrorIs(t, err, ErrRTXPacket)
|
||||
|
||||
// try to get old packets
|
||||
_, err = q.GetPacket(buff, 0)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
|
||||
// ask for soemething ahead of headSN
|
||||
_, err = q.GetPacket(buff, 11)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
}
|
||||
|
||||
func Test_queue_edges(t *testing.T) {
|
||||
var TestPackets = []*rtp.Packet{
|
||||
TestPackets := []*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 65533,
|
||||
@@ -99,41 +108,163 @@ func Test_queue_edges(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b := make([]byte, 25000)
|
||||
q := NewBucket(&b)
|
||||
|
||||
for _, p := range TestPackets {
|
||||
p := p
|
||||
assert.NotNil(t, p)
|
||||
assert.NotPanics(t, func() {
|
||||
p := p
|
||||
require.NotPanics(t, func() {
|
||||
buf, err := p.Marshal()
|
||||
assert.NoError(t, err)
|
||||
assert.NotPanics(t, func() {
|
||||
require.NoError(t, err)
|
||||
require.NotPanics(t, func() {
|
||||
_, _ = q.AddPacket(buf)
|
||||
})
|
||||
})
|
||||
}
|
||||
var expectedSN uint16
|
||||
expectedSN = 65534
|
||||
|
||||
expectedSN := uint16(65534)
|
||||
np := rtp.Packet{}
|
||||
buff := make([]byte, maxPktSize)
|
||||
i, err := q.GetPacket(buff, expectedSN)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSN, np.SequenceNumber)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
// add an out-of-order packet where the head sequence has wrapped and ensure it can be retrieved
|
||||
np2 := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 65535,
|
||||
},
|
||||
}
|
||||
buf, err := np2.Marshal()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
_, _ = q.AddPacket(buf)
|
||||
i, err = q.GetPacket(buff, expectedSN+1)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSN+1, np.SequenceNumber)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN+1, np.SequenceNumber)
|
||||
}
|
||||
|
||||
func Test_queue_wrap(t *testing.T) {
|
||||
TestPackets := []*rtp.Packet{
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 13,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 15,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b := make([]byte, 16000)
|
||||
q := NewBucket(&b)
|
||||
|
||||
for _, p := range TestPackets {
|
||||
buf, err := p.Marshal()
|
||||
require.NoError(t, err)
|
||||
require.NotPanics(t, func() {
|
||||
_, _ = q.AddPacket(buf)
|
||||
})
|
||||
}
|
||||
|
||||
buff := make([]byte, maxPktSize)
|
||||
|
||||
// try to get old packets, but were valid before the bucket wrapped
|
||||
_, err := q.GetPacket(buff, 1)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 3)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 4)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
|
||||
expectedSN := uint16(6)
|
||||
np := rtp.Packet{}
|
||||
i, err := q.GetPacket(buff, expectedSN)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
// add an out-of-order packet and ensure it can be retrieved
|
||||
np2 := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 8,
|
||||
},
|
||||
}
|
||||
buf, err := np2.Marshal()
|
||||
require.NoError(t, err)
|
||||
_, err = q.AddPacket(buf)
|
||||
require.NoError(t, err)
|
||||
expectedSN = 8
|
||||
i, err = q.GetPacket(buff, expectedSN)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
// add a packet with a large gap in sequence number which will invalidate all the slots
|
||||
np3 := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
SequenceNumber: 56,
|
||||
},
|
||||
}
|
||||
buf, err = np3.Marshal()
|
||||
require.NoError(t, err)
|
||||
_, err = q.AddPacket(buf)
|
||||
require.NoError(t, err)
|
||||
expectedSN = 56
|
||||
i, err = q.GetPacket(buff, expectedSN)
|
||||
require.NoError(t, err)
|
||||
err = np.Unmarshal(buff[:i])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSN, np.SequenceNumber)
|
||||
|
||||
// after the large jump invalidating all slots, retrieving previously added packets should fail
|
||||
_, err = q.GetPacket(buff, 6)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 7)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 8)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 10)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 13)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
_, err = q.GetPacket(buff, 15)
|
||||
require.ErrorIs(t, err, ErrPacketNotFound)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package buffer
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVP8Helper_Unmarshal(t *testing.T) {
|
||||
@@ -75,19 +75,19 @@ func TestVP8Helper_Unmarshal(t *testing.T) {
|
||||
t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.checkTemporal {
|
||||
assert.Equal(t, tt.temporalSupport, p.TIDPresent == 1)
|
||||
require.Equal(t, tt.temporalSupport, p.TIDPresent == 1)
|
||||
}
|
||||
if tt.checkKeyFrame {
|
||||
assert.Equal(t, tt.keyFrame, p.IsKeyFrame)
|
||||
require.Equal(t, tt.keyFrame, p.IsKeyFrame)
|
||||
}
|
||||
if tt.checkPictureID {
|
||||
assert.Equal(t, tt.pictureID, p.PictureID)
|
||||
require.Equal(t, tt.pictureID, p.PictureID)
|
||||
}
|
||||
if tt.checkTlzIdx {
|
||||
assert.Equal(t, tt.tlzIdx, p.TL0PICIDX)
|
||||
require.Equal(t, tt.tlzIdx, p.TL0PICIDX)
|
||||
}
|
||||
if tt.checkTempID {
|
||||
assert.Equal(t, tt.temporalID, p.TID)
|
||||
require.Equal(t, tt.temporalID, p.TID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user