mirror of
https://github.com/livekit/livekit.git
synced 2026-03-29 09:19:53 +00:00
Switch data track extension to 1-byte ID/length. (#4362)
And match design to RTP header extension, i. e. the padding for extensions is not at per extension level (which was the case before), but has been changed to padding the aggregate of all extensions in this PR.
This commit is contained in:
@@ -25,7 +25,7 @@ type ExtensionParticipantSid struct {
|
||||
}
|
||||
|
||||
func NewExtensionParticipantSid(participantID livekit.ParticipantID) (*ExtensionParticipantSid, error) {
|
||||
if len(participantID) >= 65536 {
|
||||
if len(participantID) >= 256 {
|
||||
return nil, errors.New("participantID too long")
|
||||
}
|
||||
|
||||
@@ -40,13 +40,13 @@ func (e *ExtensionParticipantSid) Marshal() (Extension, error) {
|
||||
data := make([]byte, len(e.participantID))
|
||||
copy(data, e.participantID)
|
||||
return Extension{
|
||||
id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
|
||||
id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
|
||||
data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *ExtensionParticipantSid) Unmarshal(ext Extension) error {
|
||||
if ext.id != uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) {
|
||||
if ext.id != uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) {
|
||||
return errors.New("invalid extension ID")
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
func TestExtensionParticipantSid(t *testing.T) {
|
||||
longTestParticipantID := livekit.ParticipantID(make([]byte, 65536))
|
||||
longTestParticipantID := livekit.ParticipantID(make([]byte, 256))
|
||||
extParticipantSid, err := NewExtensionParticipantSid(longTestParticipantID)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestExtensionParticipantSid(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedExt := Extension{
|
||||
id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
|
||||
id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
|
||||
data: []byte{'t', 'e', 's', 't'},
|
||||
}
|
||||
ext, err := extParticipantSid.Marshal()
|
||||
|
||||
@@ -25,6 +25,7 @@ var (
|
||||
errBufferSizeInsufficient = errors.New("data track packet buffer size insufficient")
|
||||
errExtensionSizeInsufficient = errors.New("data track packet extension size insufficient")
|
||||
errExtensionNotFound = errors.New("data track packet extension not found")
|
||||
errExtensionSizeTooBig = errors.New("extension size is too big")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -57,12 +58,12 @@ const (
|
||||
extensionsSizeOffset = headerLength
|
||||
extensionsSizeLength = 2
|
||||
|
||||
extensionIDLength = 2
|
||||
extensionSizeLength = 2
|
||||
extensionIDLength = 1
|
||||
extensionSizeLength = 1
|
||||
)
|
||||
|
||||
type Extension struct {
|
||||
id uint16
|
||||
id uint8
|
||||
data []byte
|
||||
}
|
||||
|
||||
@@ -98,9 +99,13 @@ type Header struct {
|
||||
┆* 0 1 2 3
|
||||
┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
┆* | Extension ID | Extension size |
|
||||
┆* | Extension ID | Extension size| Extension payload |
|
||||
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|* Extension payload (padded to 4 byte boundary) |
|
||||
|
||||
End of all extensions
|
||||
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|* padded to 4 byte boundary if aggregate of `Extensions Size` |
|
||||
|* field and all extensions do not end on a 4 byte boundary |
|
||||
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
*/
|
||||
|
||||
@@ -121,46 +126,59 @@ func (h *Header) Unmarshal(buf []byte) (int, error) {
|
||||
h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength])
|
||||
|
||||
if h.HasExtensions {
|
||||
h.ExtensionsSize = (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength]) + 1) * 4
|
||||
extensionsSize := (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength])+1)*4 - extensionsSizeLength
|
||||
hdrSize += extensionsSizeLength
|
||||
|
||||
remainingSize := int(h.ExtensionsSize)
|
||||
extensionHeaderSize := extensionIDLength + extensionSizeLength
|
||||
remainingSize := int(extensionsSize)
|
||||
idx := extensionsSizeOffset + extensionsSizeLength
|
||||
for remainingSize != 0 {
|
||||
if len(buf[idx:]) < 4 || remainingSize < 4 {
|
||||
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), 4)
|
||||
// read extension header
|
||||
if len(buf[idx:]) < extensionIDLength || remainingSize < extensionIDLength {
|
||||
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionIDLength)
|
||||
}
|
||||
id := buf[idx]
|
||||
if id == 0 {
|
||||
// end of extensions, padding has started
|
||||
hdrSize += remainingSize
|
||||
break
|
||||
}
|
||||
|
||||
id := binary.BigEndian.Uint16(buf[idx : idx+2])
|
||||
size := int(binary.BigEndian.Uint16(buf[idx+2 : idx+4]))
|
||||
remainingSize -= 4
|
||||
idx += 4
|
||||
hdrSize += 4
|
||||
if len(buf[idx+1:]) < extensionSizeLength || remainingSize < extensionSizeLength {
|
||||
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionSizeLength)
|
||||
}
|
||||
size := int(buf[idx+1])
|
||||
|
||||
remainingSize -= extensionHeaderSize
|
||||
idx += extensionHeaderSize
|
||||
hdrSize += extensionHeaderSize
|
||||
|
||||
// read extension data
|
||||
if len(buf[idx:]) < size || remainingSize < size {
|
||||
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), size)
|
||||
}
|
||||
h.Extensions = append(h.Extensions, Extension{id: id, data: buf[idx : idx+size]})
|
||||
|
||||
size = ((size + 3) / 4) * 4
|
||||
remainingSize -= size
|
||||
idx += size
|
||||
hdrSize += size
|
||||
}
|
||||
h.ExtensionsSize = extensionsSize - uint16(remainingSize)
|
||||
}
|
||||
|
||||
return hdrSize, nil
|
||||
}
|
||||
|
||||
func (h *Header) MarshalSize() int {
|
||||
size := headerLength
|
||||
extensionsSize := 0
|
||||
if h.HasExtensions {
|
||||
size += 2 // extensions size field
|
||||
extensionsSize += extensionsSizeLength
|
||||
for _, ext := range h.Extensions {
|
||||
size += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
|
||||
extensionsSize += len(ext.data) + extensionIDLength + extensionSizeLength
|
||||
}
|
||||
}
|
||||
return size
|
||||
|
||||
return headerLength + (extensionsSize+3)/4*4
|
||||
}
|
||||
|
||||
func (h *Header) MarshalTo(buf []byte) (int, error) {
|
||||
@@ -186,18 +204,32 @@ func (h *Header) MarshalTo(buf []byte) (int, error) {
|
||||
binary.BigEndian.PutUint32(buf[timestampOffset:timestampOffset+timestampLength], h.Timestamp)
|
||||
|
||||
if h.HasExtensions {
|
||||
binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (h.ExtensionsSize/4)-1)
|
||||
extensionsSize := (extensionsSizeLength + h.ExtensionsSize + 3) / 4 * 4
|
||||
binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (extensionsSize/4)-1)
|
||||
hdrSize += extensionsSizeLength
|
||||
|
||||
addedSize := 0
|
||||
idx := extensionsSizeOffset + extensionsSizeLength
|
||||
for _, ext := range h.Extensions {
|
||||
binary.BigEndian.PutUint16(buf[idx:idx+extensionIDLength], ext.id)
|
||||
binary.BigEndian.PutUint16(buf[idx+extensionIDLength:idx+extensionIDLength+extensionSizeLength], uint16(len(ext.data)))
|
||||
buf[idx] = ext.id
|
||||
if len(ext.data) > 255 {
|
||||
return 0, fmt.Errorf("%w: %d > 255", errExtensionSizeTooBig, len(ext.data))
|
||||
}
|
||||
buf[idx+extensionIDLength] = byte(len(ext.data))
|
||||
copy(buf[idx+extensionIDLength+extensionSizeLength:], ext.data)
|
||||
|
||||
idx += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
|
||||
hdrSize += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
|
||||
extSize := len(ext.data) + extensionIDLength + extensionSizeLength
|
||||
idx += extSize
|
||||
hdrSize += extSize
|
||||
addedSize += extSize
|
||||
}
|
||||
|
||||
paddingSize := extensionsSize - extensionsSizeLength - uint16(addedSize)
|
||||
for i := range paddingSize {
|
||||
buf[idx+int(i)] = 0
|
||||
}
|
||||
idx += int(paddingSize)
|
||||
hdrSize += int(paddingSize)
|
||||
}
|
||||
|
||||
return hdrSize, nil
|
||||
@@ -206,19 +238,19 @@ func (h *Header) MarshalTo(buf []byte) (int, error) {
|
||||
func (h *Header) AddExtension(ext Extension) {
|
||||
for i, existingExt := range h.Extensions {
|
||||
if existingExt.id == ext.id {
|
||||
h.ExtensionsSize -= uint16((len(existingExt.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
|
||||
h.ExtensionsSize -= uint16(len(existingExt.data) + extensionIDLength + extensionSizeLength)
|
||||
h.Extensions[i].data = ext.data
|
||||
h.ExtensionsSize += uint16((len(h.Extensions[i].data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
|
||||
h.ExtensionsSize += uint16(len(h.Extensions[i].data) + extensionIDLength + extensionSizeLength)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.Extensions = append(h.Extensions, ext)
|
||||
h.ExtensionsSize += uint16((len(ext.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
|
||||
h.ExtensionsSize += uint16(len(ext.data) + extensionIDLength + extensionSizeLength)
|
||||
h.HasExtensions = true
|
||||
}
|
||||
|
||||
func (h *Header) GetExtension(id uint16) (Extension, error) {
|
||||
func (h *Header) GetExtension(id uint8) (Extension, error) {
|
||||
for _, ext := range h.Extensions {
|
||||
if ext.id == id {
|
||||
return ext, nil
|
||||
|
||||
@@ -82,10 +82,10 @@ func TestPacket(t *testing.T) {
|
||||
|
||||
expectedRawPacket := []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01,
|
||||
0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70,
|
||||
0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61,
|
||||
0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10,
|
||||
0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72,
|
||||
0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74,
|
||||
0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
require.Equal(t, expectedRawPacket, rawPacket)
|
||||
|
||||
@@ -94,7 +94,7 @@ func TestPacket(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packet, &unmarshaled)
|
||||
|
||||
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
require.NoError(t, err)
|
||||
|
||||
var extParticipantSid ExtensionParticipantSid
|
||||
@@ -129,10 +129,9 @@ func TestPacket(t *testing.T) {
|
||||
|
||||
expectedRawPacket := []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
|
||||
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
require.Equal(t, expectedRawPacket, rawPacket)
|
||||
|
||||
@@ -141,7 +140,7 @@ func TestPacket(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packet, &unmarshaled)
|
||||
|
||||
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
require.NoError(t, err)
|
||||
|
||||
var extParticipantSid ExtensionParticipantSid
|
||||
@@ -176,10 +175,9 @@ func TestPacket(t *testing.T) {
|
||||
|
||||
expectedRawPacket := []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
|
||||
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
require.Equal(t, expectedRawPacket, rawPacket)
|
||||
|
||||
@@ -194,10 +192,10 @@ func TestPacket(t *testing.T) {
|
||||
|
||||
expectedRawPacket = []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01,
|
||||
0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70,
|
||||
0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61,
|
||||
0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10,
|
||||
0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72,
|
||||
0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74,
|
||||
0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
require.Equal(t, expectedRawPacket, rawPacket)
|
||||
|
||||
@@ -206,7 +204,7 @@ func TestPacket(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packet, &unmarshaled)
|
||||
|
||||
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
require.NoError(t, err)
|
||||
|
||||
var extParticipantSid ExtensionParticipantSid
|
||||
@@ -214,15 +212,14 @@ func TestPacket(t *testing.T) {
|
||||
require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID())
|
||||
})
|
||||
|
||||
t.Run("bad pcaket", func(t *testing.T) {
|
||||
t.Run("bad packet", func(t *testing.T) {
|
||||
var unmarshaled Packet
|
||||
// extensions size too small
|
||||
badPacket := []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x00, 0x01,
|
||||
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x01, 0x0b,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
err := unmarshaled.Unmarshal(badPacket)
|
||||
require.Error(t, err)
|
||||
@@ -230,23 +227,21 @@ func TestPacket(t *testing.T) {
|
||||
// get an invalid extension id
|
||||
badPacket = []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x02,
|
||||
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x02, 0x0b,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
err = unmarshaled.Unmarshal(badPacket)
|
||||
require.NoError(t, err)
|
||||
_, err = unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
_, err = unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
|
||||
require.Error(t, err)
|
||||
|
||||
// extension payload size bigger than payload
|
||||
badPacket = []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
|
||||
0x00, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0d,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
err = unmarshaled.Unmarshal(badPacket)
|
||||
require.Error(t, err)
|
||||
@@ -254,10 +249,9 @@ func TestPacket(t *testing.T) {
|
||||
// extension payload size smaller than payload
|
||||
badPacket = []byte{
|
||||
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
|
||||
0x00, 0x07, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
|
||||
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
|
||||
0xfd, 0xfc,
|
||||
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x07,
|
||||
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
|
||||
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
|
||||
}
|
||||
err = unmarshaled.Unmarshal(badPacket)
|
||||
require.Error(t, err)
|
||||
|
||||
Reference in New Issue
Block a user