Switch data track extension to 1-byte ID/length. (#4362)

And match design to RTP header extension, i. e. the padding for
extensions is not at per extension level (which was the case before),
but has been changed to padding the aggregate of all extensions in this
PR.
This commit is contained in:
Raja Subramanian
2026-03-14 13:29:40 +05:30
committed by GitHub
parent 7323ad02b7
commit 5dc2e7b180
4 changed files with 96 additions and 70 deletions

View File

@@ -25,7 +25,7 @@ type ExtensionParticipantSid struct {
}
func NewExtensionParticipantSid(participantID livekit.ParticipantID) (*ExtensionParticipantSid, error) {
if len(participantID) >= 65536 {
if len(participantID) >= 256 {
return nil, errors.New("participantID too long")
}
@@ -40,13 +40,13 @@ func (e *ExtensionParticipantSid) Marshal() (Extension, error) {
data := make([]byte, len(e.participantID))
copy(data, e.participantID)
return Extension{
id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
data: data,
}, nil
}
func (e *ExtensionParticipantSid) Unmarshal(ext Extension) error {
if ext.id != uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) {
if ext.id != uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) {
return errors.New("invalid extension ID")
}

View File

@@ -23,7 +23,7 @@ import (
)
func TestExtensionParticipantSid(t *testing.T) {
longTestParticipantID := livekit.ParticipantID(make([]byte, 65536))
longTestParticipantID := livekit.ParticipantID(make([]byte, 256))
extParticipantSid, err := NewExtensionParticipantSid(longTestParticipantID)
require.Error(t, err)
@@ -32,7 +32,7 @@ func TestExtensionParticipantSid(t *testing.T) {
require.NoError(t, err)
expectedExt := Extension{
id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID),
data: []byte{'t', 'e', 's', 't'},
}
ext, err := extParticipantSid.Marshal()

View File

@@ -25,6 +25,7 @@ var (
errBufferSizeInsufficient = errors.New("data track packet buffer size insufficient")
errExtensionSizeInsufficient = errors.New("data track packet extension size insufficient")
errExtensionNotFound = errors.New("data track packet extension not found")
errExtensionSizeTooBig = errors.New("extension size is too big")
)
const (
@@ -57,12 +58,12 @@ const (
extensionsSizeOffset = headerLength
extensionsSizeLength = 2
extensionIDLength = 2
extensionSizeLength = 2
extensionIDLength = 1
extensionSizeLength = 1
)
type Extension struct {
id uint16
id uint8
data []byte
}
@@ -98,9 +99,13 @@ type Header struct {
┆* 0 1 2 3
┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
┆* | Extension ID | Extension size |
┆* | Extension ID | Extension size| Extension payload |
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|* Extension payload (padded to 4 byte boundary) |
End of all extensions
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|* padded to 4 byte boundary if aggregate of `Extensions Size` |
|* field and all extensions do not end on a 4 byte boundary |
┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
@@ -121,46 +126,59 @@ func (h *Header) Unmarshal(buf []byte) (int, error) {
h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength])
if h.HasExtensions {
h.ExtensionsSize = (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength]) + 1) * 4
extensionsSize := (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength])+1)*4 - extensionsSizeLength
hdrSize += extensionsSizeLength
remainingSize := int(h.ExtensionsSize)
extensionHeaderSize := extensionIDLength + extensionSizeLength
remainingSize := int(extensionsSize)
idx := extensionsSizeOffset + extensionsSizeLength
for remainingSize != 0 {
if len(buf[idx:]) < 4 || remainingSize < 4 {
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), 4)
// read extension header
if len(buf[idx:]) < extensionIDLength || remainingSize < extensionIDLength {
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionIDLength)
}
id := buf[idx]
if id == 0 {
// end of extensions, padding has started
hdrSize += remainingSize
break
}
id := binary.BigEndian.Uint16(buf[idx : idx+2])
size := int(binary.BigEndian.Uint16(buf[idx+2 : idx+4]))
remainingSize -= 4
idx += 4
hdrSize += 4
if len(buf[idx+1:]) < extensionSizeLength || remainingSize < extensionSizeLength {
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionSizeLength)
}
size := int(buf[idx+1])
remainingSize -= extensionHeaderSize
idx += extensionHeaderSize
hdrSize += extensionHeaderSize
// read extension data
if len(buf[idx:]) < size || remainingSize < size {
return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), size)
}
h.Extensions = append(h.Extensions, Extension{id: id, data: buf[idx : idx+size]})
size = ((size + 3) / 4) * 4
remainingSize -= size
idx += size
hdrSize += size
}
h.ExtensionsSize = extensionsSize - uint16(remainingSize)
}
return hdrSize, nil
}
func (h *Header) MarshalSize() int {
size := headerLength
extensionsSize := 0
if h.HasExtensions {
size += 2 // extensions size field
extensionsSize += extensionsSizeLength
for _, ext := range h.Extensions {
size += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
extensionsSize += len(ext.data) + extensionIDLength + extensionSizeLength
}
}
return size
return headerLength + (extensionsSize+3)/4*4
}
func (h *Header) MarshalTo(buf []byte) (int, error) {
@@ -186,18 +204,32 @@ func (h *Header) MarshalTo(buf []byte) (int, error) {
binary.BigEndian.PutUint32(buf[timestampOffset:timestampOffset+timestampLength], h.Timestamp)
if h.HasExtensions {
binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (h.ExtensionsSize/4)-1)
extensionsSize := (extensionsSizeLength + h.ExtensionsSize + 3) / 4 * 4
binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (extensionsSize/4)-1)
hdrSize += extensionsSizeLength
addedSize := 0
idx := extensionsSizeOffset + extensionsSizeLength
for _, ext := range h.Extensions {
binary.BigEndian.PutUint16(buf[idx:idx+extensionIDLength], ext.id)
binary.BigEndian.PutUint16(buf[idx+extensionIDLength:idx+extensionIDLength+extensionSizeLength], uint16(len(ext.data)))
buf[idx] = ext.id
if len(ext.data) > 255 {
return 0, fmt.Errorf("%w: %d > 255", errExtensionSizeTooBig, len(ext.data))
}
buf[idx+extensionIDLength] = byte(len(ext.data))
copy(buf[idx+extensionIDLength+extensionSizeLength:], ext.data)
idx += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
hdrSize += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */
extSize := len(ext.data) + extensionIDLength + extensionSizeLength
idx += extSize
hdrSize += extSize
addedSize += extSize
}
paddingSize := extensionsSize - extensionsSizeLength - uint16(addedSize)
for i := range paddingSize {
buf[idx+int(i)] = 0
}
idx += int(paddingSize)
hdrSize += int(paddingSize)
}
return hdrSize, nil
@@ -206,19 +238,19 @@ func (h *Header) MarshalTo(buf []byte) (int, error) {
func (h *Header) AddExtension(ext Extension) {
for i, existingExt := range h.Extensions {
if existingExt.id == ext.id {
h.ExtensionsSize -= uint16((len(existingExt.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
h.ExtensionsSize -= uint16(len(existingExt.data) + extensionIDLength + extensionSizeLength)
h.Extensions[i].data = ext.data
h.ExtensionsSize += uint16((len(h.Extensions[i].data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
h.ExtensionsSize += uint16(len(h.Extensions[i].data) + extensionIDLength + extensionSizeLength)
return
}
}
h.Extensions = append(h.Extensions, ext)
h.ExtensionsSize += uint16((len(ext.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */)
h.ExtensionsSize += uint16(len(ext.data) + extensionIDLength + extensionSizeLength)
h.HasExtensions = true
}
func (h *Header) GetExtension(id uint16) (Extension, error) {
func (h *Header) GetExtension(id uint8) (Extension, error) {
for _, ext := range h.Extensions {
if ext.id == id {
return ext, nil

View File

@@ -82,10 +82,10 @@ func TestPacket(t *testing.T) {
expectedRawPacket := []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01,
0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70,
0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61,
0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10,
0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72,
0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74,
0xff, 0xfe, 0xfd, 0xfc,
}
require.Equal(t, expectedRawPacket, rawPacket)
@@ -94,7 +94,7 @@ func TestPacket(t *testing.T) {
require.NoError(t, err)
require.Equal(t, packet, &unmarshaled)
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
require.NoError(t, err)
var extParticipantSid ExtensionParticipantSid
@@ -129,10 +129,9 @@ func TestPacket(t *testing.T) {
expectedRawPacket := []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
require.Equal(t, expectedRawPacket, rawPacket)
@@ -141,7 +140,7 @@ func TestPacket(t *testing.T) {
require.NoError(t, err)
require.Equal(t, packet, &unmarshaled)
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
require.NoError(t, err)
var extParticipantSid ExtensionParticipantSid
@@ -176,10 +175,9 @@ func TestPacket(t *testing.T) {
expectedRawPacket := []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
require.Equal(t, expectedRawPacket, rawPacket)
@@ -194,10 +192,10 @@ func TestPacket(t *testing.T) {
expectedRawPacket = []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01,
0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70,
0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61,
0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10,
0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72,
0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74,
0xff, 0xfe, 0xfd, 0xfc,
}
require.Equal(t, expectedRawPacket, rawPacket)
@@ -206,7 +204,7 @@ func TestPacket(t *testing.T) {
require.NoError(t, err)
require.Equal(t, packet, &unmarshaled)
ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
require.NoError(t, err)
var extParticipantSid ExtensionParticipantSid
@@ -214,15 +212,14 @@ func TestPacket(t *testing.T) {
require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID())
})
t.Run("bad pcaket", func(t *testing.T) {
t.Run("bad packet", func(t *testing.T) {
var unmarshaled Packet
// extensions size too small
badPacket := []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x00, 0x01,
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x01, 0x0b,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
err := unmarshaled.Unmarshal(badPacket)
require.Error(t, err)
@@ -230,23 +227,21 @@ func TestPacket(t *testing.T) {
// get an invalid extension id
badPacket = []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x02,
0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x02, 0x0b,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
err = unmarshaled.Unmarshal(badPacket)
require.NoError(t, err)
_, err = unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
_, err = unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID))
require.Error(t, err)
// extension payload size bigger than payload
badPacket = []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
0x00, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0d,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
err = unmarshaled.Unmarshal(badPacket)
require.Error(t, err)
@@ -254,10 +249,9 @@ func TestPacket(t *testing.T) {
// extension payload size smaller than payload
badPacket = []byte{
0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01,
0x00, 0x07, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63,
0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe,
0xfd, 0xfc,
0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x07,
0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70,
0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc,
}
err = unmarshaled.Unmarshal(badPacket)
require.Error(t, err)