From 5dc2e7b180c11c2d3a8ccd7df096384141c81d70 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 14 Mar 2026 13:29:40 +0530 Subject: [PATCH] Switch data track extension to 1-byte ID/length. (#4362) And match design to RTP header extension, i. e. the padding for extensions is not at per extension level (which was the case before), but has been changed to padding the aggregate of all extensions in this PR. --- .../datatrack/extension_participant_sid.go | 6 +- .../extension_participant_sid_test.go | 4 +- pkg/rtc/datatrack/packet.go | 88 +++++++++++++------ pkg/rtc/datatrack/packet_test.go | 68 +++++++------- 4 files changed, 96 insertions(+), 70 deletions(-) diff --git a/pkg/rtc/datatrack/extension_participant_sid.go b/pkg/rtc/datatrack/extension_participant_sid.go index bfad08a78..35ada18f1 100644 --- a/pkg/rtc/datatrack/extension_participant_sid.go +++ b/pkg/rtc/datatrack/extension_participant_sid.go @@ -25,7 +25,7 @@ type ExtensionParticipantSid struct { } func NewExtensionParticipantSid(participantID livekit.ParticipantID) (*ExtensionParticipantSid, error) { - if len(participantID) >= 65536 { + if len(participantID) >= 256 { return nil, errors.New("participantID too long") } @@ -40,13 +40,13 @@ func (e *ExtensionParticipantSid) Marshal() (Extension, error) { data := make([]byte, len(e.participantID)) copy(data, e.participantID) return Extension{ - id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), data: data, }, nil } func (e *ExtensionParticipantSid) Unmarshal(ext Extension) error { - if ext.id != uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) { + if ext.id != uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) { return errors.New("invalid extension ID") } diff --git a/pkg/rtc/datatrack/extension_participant_sid_test.go b/pkg/rtc/datatrack/extension_participant_sid_test.go index 243bfc6d3..125db5c1a 100644 --- a/pkg/rtc/datatrack/extension_participant_sid_test.go +++ b/pkg/rtc/datatrack/extension_participant_sid_test.go @@ -23,7 +23,7 @@ import ( ) func TestExtensionParticipantSid(t *testing.T) { - longTestParticipantID := livekit.ParticipantID(make([]byte, 65536)) + longTestParticipantID := livekit.ParticipantID(make([]byte, 256)) extParticipantSid, err := NewExtensionParticipantSid(longTestParticipantID) require.Error(t, err) @@ -32,7 +32,7 @@ func TestExtensionParticipantSid(t *testing.T) { require.NoError(t, err) expectedExt := Extension{ - id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), data: []byte{'t', 'e', 's', 't'}, } ext, err := extParticipantSid.Marshal() diff --git a/pkg/rtc/datatrack/packet.go b/pkg/rtc/datatrack/packet.go index fe4da8fae..be36318a9 100644 --- a/pkg/rtc/datatrack/packet.go +++ b/pkg/rtc/datatrack/packet.go @@ -25,6 +25,7 @@ var ( errBufferSizeInsufficient = errors.New("data track packet buffer size insufficient") errExtensionSizeInsufficient = errors.New("data track packet extension size insufficient") errExtensionNotFound = errors.New("data track packet extension not found") + errExtensionSizeTooBig = errors.New("extension size is too big") ) const ( @@ -57,12 +58,12 @@ const ( extensionsSizeOffset = headerLength extensionsSizeLength = 2 - extensionIDLength = 2 - extensionSizeLength = 2 + extensionIDLength = 1 + extensionSizeLength = 1 ) type Extension struct { - id uint16 + id uint8 data []byte } @@ -98,9 +99,13 @@ type Header struct { ┆* 0 1 2 3 ┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - ┆* | Extension ID | Extension size | + ┆* | Extension ID | Extension size| Extension payload | ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |* Extension payload (padded to 4 byte boundary) | + + End of all extensions + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |* padded to 4 byte boundary if aggregate of `Extensions Size` | + |* field and all extensions do not end on a 4 byte boundary | ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ @@ -121,46 +126,59 @@ func (h *Header) Unmarshal(buf []byte) (int, error) { h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength]) if h.HasExtensions { - h.ExtensionsSize = (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength]) + 1) * 4 + extensionsSize := (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength])+1)*4 - extensionsSizeLength hdrSize += extensionsSizeLength - remainingSize := int(h.ExtensionsSize) + extensionHeaderSize := extensionIDLength + extensionSizeLength + remainingSize := int(extensionsSize) idx := extensionsSizeOffset + extensionsSizeLength for remainingSize != 0 { - if len(buf[idx:]) < 4 || remainingSize < 4 { - return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), 4) + // read extension header + if len(buf[idx:]) < extensionIDLength || remainingSize < extensionIDLength { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionIDLength) + } + id := buf[idx] + if id == 0 { + // end of extensions, padding has started + hdrSize += remainingSize + break } - id := binary.BigEndian.Uint16(buf[idx : idx+2]) - size := int(binary.BigEndian.Uint16(buf[idx+2 : idx+4])) - remainingSize -= 4 - idx += 4 - hdrSize += 4 + if len(buf[idx+1:]) < extensionSizeLength || remainingSize < extensionSizeLength { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionSizeLength) + } + size := int(buf[idx+1]) + remainingSize -= extensionHeaderSize + idx += extensionHeaderSize + hdrSize += extensionHeaderSize + + // read extension data if len(buf[idx:]) < size || remainingSize < size { return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), size) } h.Extensions = append(h.Extensions, Extension{id: id, data: buf[idx : idx+size]}) - size = ((size + 3) / 4) * 4 remainingSize -= size idx += size hdrSize += size } + h.ExtensionsSize = extensionsSize - uint16(remainingSize) } return hdrSize, nil } func (h *Header) MarshalSize() int { - size := headerLength + extensionsSize := 0 if h.HasExtensions { - size += 2 // extensions size field + extensionsSize += extensionsSizeLength for _, ext := range h.Extensions { - size += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + extensionsSize += len(ext.data) + extensionIDLength + extensionSizeLength } } - return size + + return headerLength + (extensionsSize+3)/4*4 } func (h *Header) MarshalTo(buf []byte) (int, error) { @@ -186,18 +204,32 @@ func (h *Header) MarshalTo(buf []byte) (int, error) { binary.BigEndian.PutUint32(buf[timestampOffset:timestampOffset+timestampLength], h.Timestamp) if h.HasExtensions { - binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (h.ExtensionsSize/4)-1) + extensionsSize := (extensionsSizeLength + h.ExtensionsSize + 3) / 4 * 4 + binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (extensionsSize/4)-1) hdrSize += extensionsSizeLength + addedSize := 0 idx := extensionsSizeOffset + extensionsSizeLength for _, ext := range h.Extensions { - binary.BigEndian.PutUint16(buf[idx:idx+extensionIDLength], ext.id) - binary.BigEndian.PutUint16(buf[idx+extensionIDLength:idx+extensionIDLength+extensionSizeLength], uint16(len(ext.data))) + buf[idx] = ext.id + if len(ext.data) > 255 { + return 0, fmt.Errorf("%w: %d > 255", errExtensionSizeTooBig, len(ext.data)) + } + buf[idx+extensionIDLength] = byte(len(ext.data)) copy(buf[idx+extensionIDLength+extensionSizeLength:], ext.data) - idx += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ - hdrSize += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + extSize := len(ext.data) + extensionIDLength + extensionSizeLength + idx += extSize + hdrSize += extSize + addedSize += extSize } + + paddingSize := extensionsSize - extensionsSizeLength - uint16(addedSize) + for i := range paddingSize { + buf[idx+int(i)] = 0 + } + idx += int(paddingSize) + hdrSize += int(paddingSize) } return hdrSize, nil @@ -206,19 +238,19 @@ func (h *Header) MarshalTo(buf []byte) (int, error) { func (h *Header) AddExtension(ext Extension) { for i, existingExt := range h.Extensions { if existingExt.id == ext.id { - h.ExtensionsSize -= uint16((len(existingExt.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize -= uint16(len(existingExt.data) + extensionIDLength + extensionSizeLength) h.Extensions[i].data = ext.data - h.ExtensionsSize += uint16((len(h.Extensions[i].data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize += uint16(len(h.Extensions[i].data) + extensionIDLength + extensionSizeLength) return } } h.Extensions = append(h.Extensions, ext) - h.ExtensionsSize += uint16((len(ext.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize += uint16(len(ext.data) + extensionIDLength + extensionSizeLength) h.HasExtensions = true } -func (h *Header) GetExtension(id uint16) (Extension, error) { +func (h *Header) GetExtension(id uint8) (Extension, error) { for _, ext := range h.Extensions { if ext.id == id { return ext, nil diff --git a/pkg/rtc/datatrack/packet_test.go b/pkg/rtc/datatrack/packet_test.go index a144ebfe2..4d8b2c6cf 100644 --- a/pkg/rtc/datatrack/packet_test.go +++ b/pkg/rtc/datatrack/packet_test.go @@ -82,10 +82,10 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, - 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, - 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, - 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, + 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, + 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -94,7 +94,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -129,10 +129,9 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -141,7 +140,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -176,10 +175,9 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -194,10 +192,10 @@ func TestPacket(t *testing.T) { expectedRawPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, - 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, - 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, - 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, + 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, + 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -206,7 +204,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -214,15 +212,14 @@ func TestPacket(t *testing.T) { require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID()) }) - t.Run("bad pcaket", func(t *testing.T) { + t.Run("bad packet", func(t *testing.T) { var unmarshaled Packet // extensions size too small badPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err := unmarshaled.Unmarshal(badPacket) require.Error(t, err) @@ -230,23 +227,21 @@ func TestPacket(t *testing.T) { // get an invalid extension id badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x02, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x02, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.NoError(t, err) - _, err = unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + _, err = unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.Error(t, err) // extension payload size bigger than payload badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0d, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.Error(t, err) @@ -254,10 +249,9 @@ func TestPacket(t *testing.T) { // extension payload size smaller than payload badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x07, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x07, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.Error(t, err)