diff --git a/pkg/rtc/datatrack/extension_participant_sid.go b/pkg/rtc/datatrack/extension_participant_sid.go index bfad08a78..35ada18f1 100644 --- a/pkg/rtc/datatrack/extension_participant_sid.go +++ b/pkg/rtc/datatrack/extension_participant_sid.go @@ -25,7 +25,7 @@ type ExtensionParticipantSid struct { } func NewExtensionParticipantSid(participantID livekit.ParticipantID) (*ExtensionParticipantSid, error) { - if len(participantID) >= 65536 { + if len(participantID) >= 256 { return nil, errors.New("participantID too long") } @@ -40,13 +40,13 @@ func (e *ExtensionParticipantSid) Marshal() (Extension, error) { data := make([]byte, len(e.participantID)) copy(data, e.participantID) return Extension{ - id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), data: data, }, nil } func (e *ExtensionParticipantSid) Unmarshal(ext Extension) error { - if ext.id != uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) { + if ext.id != uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) { return errors.New("invalid extension ID") } diff --git a/pkg/rtc/datatrack/extension_participant_sid_test.go b/pkg/rtc/datatrack/extension_participant_sid_test.go index 243bfc6d3..125db5c1a 100644 --- a/pkg/rtc/datatrack/extension_participant_sid_test.go +++ b/pkg/rtc/datatrack/extension_participant_sid_test.go @@ -23,7 +23,7 @@ import ( ) func TestExtensionParticipantSid(t *testing.T) { - longTestParticipantID := livekit.ParticipantID(make([]byte, 65536)) + longTestParticipantID := livekit.ParticipantID(make([]byte, 256)) extParticipantSid, err := NewExtensionParticipantSid(longTestParticipantID) require.Error(t, err) @@ -32,7 +32,7 @@ func TestExtensionParticipantSid(t *testing.T) { require.NoError(t, err) expectedExt := Extension{ - id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + id: uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), data: []byte{'t', 'e', 's', 't'}, } ext, err := extParticipantSid.Marshal() diff --git a/pkg/rtc/datatrack/packet.go b/pkg/rtc/datatrack/packet.go index fe4da8fae..be36318a9 100644 --- a/pkg/rtc/datatrack/packet.go +++ b/pkg/rtc/datatrack/packet.go @@ -25,6 +25,7 @@ var ( errBufferSizeInsufficient = errors.New("data track packet buffer size insufficient") errExtensionSizeInsufficient = errors.New("data track packet extension size insufficient") errExtensionNotFound = errors.New("data track packet extension not found") + errExtensionSizeTooBig = errors.New("extension size is too big") ) const ( @@ -57,12 +58,12 @@ const ( extensionsSizeOffset = headerLength extensionsSizeLength = 2 - extensionIDLength = 2 - extensionSizeLength = 2 + extensionIDLength = 1 + extensionSizeLength = 1 ) type Extension struct { - id uint16 + id uint8 data []byte } @@ -98,9 +99,13 @@ type Header struct { ┆* 0 1 2 3 ┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - ┆* | Extension ID | Extension size | + ┆* | Extension ID | Extension size| Extension payload | ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |* Extension payload (padded to 4 byte boundary) | + + End of all extensions + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |* padded to 4 byte boundary if aggregate of `Extensions Size` | + |* field and all extensions do not end on a 4 byte boundary | ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ @@ -121,46 +126,59 @@ func (h *Header) Unmarshal(buf []byte) (int, error) { h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength]) if h.HasExtensions { - h.ExtensionsSize = (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength]) + 1) * 4 + extensionsSize := (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength])+1)*4 - extensionsSizeLength hdrSize += extensionsSizeLength - remainingSize := int(h.ExtensionsSize) + extensionHeaderSize := extensionIDLength + extensionSizeLength + remainingSize := int(extensionsSize) idx := extensionsSizeOffset + extensionsSizeLength for remainingSize != 0 { - if len(buf[idx:]) < 4 || remainingSize < 4 { - return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), 4) + // read extension header + if len(buf[idx:]) < extensionIDLength || remainingSize < extensionIDLength { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionIDLength) + } + id := buf[idx] + if id == 0 { + // end of extensions, padding has started + hdrSize += remainingSize + break } - id := binary.BigEndian.Uint16(buf[idx : idx+2]) - size := int(binary.BigEndian.Uint16(buf[idx+2 : idx+4])) - remainingSize -= 4 - idx += 4 - hdrSize += 4 + if len(buf[idx+1:]) < extensionSizeLength || remainingSize < extensionSizeLength { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), extensionSizeLength) + } + size := int(buf[idx+1]) + remainingSize -= extensionHeaderSize + idx += extensionHeaderSize + hdrSize += extensionHeaderSize + + // read extension data if len(buf[idx:]) < size || remainingSize < size { return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), size) } h.Extensions = append(h.Extensions, Extension{id: id, data: buf[idx : idx+size]}) - size = ((size + 3) / 4) * 4 remainingSize -= size idx += size hdrSize += size } + h.ExtensionsSize = extensionsSize - uint16(remainingSize) } return hdrSize, nil } func (h *Header) MarshalSize() int { - size := headerLength + extensionsSize := 0 if h.HasExtensions { - size += 2 // extensions size field + extensionsSize += extensionsSizeLength for _, ext := range h.Extensions { - size += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + extensionsSize += len(ext.data) + extensionIDLength + extensionSizeLength } } - return size + + return headerLength + (extensionsSize+3)/4*4 } func (h *Header) MarshalTo(buf []byte) (int, error) { @@ -186,18 +204,32 @@ func (h *Header) MarshalTo(buf []byte) (int, error) { binary.BigEndian.PutUint32(buf[timestampOffset:timestampOffset+timestampLength], h.Timestamp) if h.HasExtensions { - binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (h.ExtensionsSize/4)-1) + extensionsSize := (extensionsSizeLength + h.ExtensionsSize + 3) / 4 * 4 + binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (extensionsSize/4)-1) hdrSize += extensionsSizeLength + addedSize := 0 idx := extensionsSizeOffset + extensionsSizeLength for _, ext := range h.Extensions { - binary.BigEndian.PutUint16(buf[idx:idx+extensionIDLength], ext.id) - binary.BigEndian.PutUint16(buf[idx+extensionIDLength:idx+extensionIDLength+extensionSizeLength], uint16(len(ext.data))) + buf[idx] = ext.id + if len(ext.data) > 255 { + return 0, fmt.Errorf("%w: %d > 255", errExtensionSizeTooBig, len(ext.data)) + } + buf[idx+extensionIDLength] = byte(len(ext.data)) copy(buf[idx+extensionIDLength+extensionSizeLength:], ext.data) - idx += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ - hdrSize += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + extSize := len(ext.data) + extensionIDLength + extensionSizeLength + idx += extSize + hdrSize += extSize + addedSize += extSize } + + paddingSize := extensionsSize - extensionsSizeLength - uint16(addedSize) + for i := range paddingSize { + buf[idx+int(i)] = 0 + } + idx += int(paddingSize) + hdrSize += int(paddingSize) } return hdrSize, nil @@ -206,19 +238,19 @@ func (h *Header) MarshalTo(buf []byte) (int, error) { func (h *Header) AddExtension(ext Extension) { for i, existingExt := range h.Extensions { if existingExt.id == ext.id { - h.ExtensionsSize -= uint16((len(existingExt.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize -= uint16(len(existingExt.data) + extensionIDLength + extensionSizeLength) h.Extensions[i].data = ext.data - h.ExtensionsSize += uint16((len(h.Extensions[i].data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize += uint16(len(h.Extensions[i].data) + extensionIDLength + extensionSizeLength) return } } h.Extensions = append(h.Extensions, ext) - h.ExtensionsSize += uint16((len(ext.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.ExtensionsSize += uint16(len(ext.data) + extensionIDLength + extensionSizeLength) h.HasExtensions = true } -func (h *Header) GetExtension(id uint16) (Extension, error) { +func (h *Header) GetExtension(id uint8) (Extension, error) { for _, ext := range h.Extensions { if ext.id == id { return ext, nil diff --git a/pkg/rtc/datatrack/packet_test.go b/pkg/rtc/datatrack/packet_test.go index a144ebfe2..4d8b2c6cf 100644 --- a/pkg/rtc/datatrack/packet_test.go +++ b/pkg/rtc/datatrack/packet_test.go @@ -82,10 +82,10 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, - 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, - 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, - 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, + 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, + 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -94,7 +94,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -129,10 +129,9 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -141,7 +140,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -176,10 +175,9 @@ func TestPacket(t *testing.T) { expectedRawPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -194,10 +192,10 @@ func TestPacket(t *testing.T) { expectedRawPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, - 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, - 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, - 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x01, 0x10, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, + 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, + 0xff, 0xfe, 0xfd, 0xfc, } require.Equal(t, expectedRawPacket, rawPacket) @@ -206,7 +204,7 @@ func TestPacket(t *testing.T) { require.NoError(t, err) require.Equal(t, packet, &unmarshaled) - ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + ext, err := unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.NoError(t, err) var extParticipantSid ExtensionParticipantSid @@ -214,15 +212,14 @@ func TestPacket(t *testing.T) { require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID()) }) - t.Run("bad pcaket", func(t *testing.T) { + t.Run("bad packet", func(t *testing.T) { var unmarshaled Packet // extensions size too small badPacket := []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x00, 0x01, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x01, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err := unmarshaled.Unmarshal(badPacket) require.Error(t, err) @@ -230,23 +227,21 @@ func TestPacket(t *testing.T) { // get an invalid extension id badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x02, - 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x02, 0x0b, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.NoError(t, err) - _, err = unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + _, err = unmarshaled.GetExtension(uint8(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) require.Error(t, err) // extension payload size bigger than payload badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x0d, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.Error(t, err) @@ -254,10 +249,9 @@ func TestPacket(t *testing.T) { // extension payload size smaller than payload badPacket = []byte{ 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, - 0x00, 0x07, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, - 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, - 0xfd, 0xfc, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x01, 0x07, + 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, + 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, 0xfd, 0xfc, } err = unmarshaled.Unmarshal(badPacket) require.Error(t, err)