diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index da586f028..bbd3d44a7 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -99,11 +99,6 @@ type Buffer struct { frameRateCalculated bool } -// BufferOptions provides configuration options for the buffer -type Options struct { - MaxBitRate uint64 -} - // NewBuffer constructs a new Buffer func NewBuffer(ssrc uint32, vp, ap *sync.Pool) *Buffer { l := logger.GetDefaultLogger() // will be reset with correct context via SetLogger @@ -272,16 +267,22 @@ func (b *Buffer) Read(buff []byte) (n int, err error) { } } -func (b *Buffer) ReadExtended() (*ExtPacket, error) { +func (b *Buffer) ReadExtended(buf []byte) (*ExtPacket, error) { for { if b.closed.Load() { return nil, io.EOF } b.Lock() if b.extPackets.Len() > 0 { - extPkt := b.extPackets.PopFront().(*ExtPacket) + ep := b.extPackets.PopFront().(*ExtPacket) + ep = b.patchExtPacket(ep, buf) + if ep == nil { + b.Unlock() + continue + } + b.Unlock() - return extPkt, nil + return ep, nil } b.Unlock() time.Sleep(10 * time.Millisecond) @@ -363,7 +364,7 @@ func (b *Buffer) SetRTT(rtt uint32) { } func (b *Buffer) calc(pkt []byte, arrivalTime int64) { - pb, err := b.bucket.AddPacket(pkt) + pktBuf, err := b.bucket.AddPacket(pkt) if err != nil { // // Even when erroring, do @@ -385,7 +386,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { } var p rtp.Packet - err = p.Unmarshal(pb) + err = p.Unmarshal(pktBuf) if err != nil { b.logger.Warnw("error unmarshaling RTP packet", err) return @@ -394,19 +395,41 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { b.updateStreamState(&p, arrivalTime) b.processHeaderExtensions(&p, arrivalTime) - ep := b.getExtPacket(pb, &p, arrivalTime) + b.doNACKs() + + b.doReports(arrivalTime) + + ep := b.getExtPacket(&p, arrivalTime) if ep == nil { return } b.extPackets.PushBack(ep) - b.doNACKs() - - b.doReports(arrivalTime) - b.doFpsCalc(ep) } +func (b *Buffer) patchExtPacket(ep *ExtPacket, buf []byte) *ExtPacket { + n, err := b.getPacket(buf, ep.Packet.SequenceNumber) + if err != nil { + b.logger.Warnw("could not get packet", err, "sn", ep.Packet.SequenceNumber) + return nil + } + ep.RawPacket = buf[:n] + + // patch RTP packet to point payload to new buffer + rtp := *ep.Packet + payloadStart := ep.Packet.Header.MarshalSize() + payloadEnd := payloadStart + len(ep.Packet.Payload) + if payloadEnd > n { + b.logger.Warnw("unexpected marshal size", nil, "max", n, "need", payloadEnd) + return nil + } + rtp.Payload = buf[payloadStart:payloadEnd] + ep.Packet = &rtp + + return ep +} + func (b *Buffer) doFpsCalc(ep *ExtPacket) { if b.frameRateCalculated || len(ep.Packet.Payload) == 0 { return @@ -478,11 +501,10 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime int64) { } } -func (b *Buffer) getExtPacket(rawPacket []byte, rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket { +func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket { ep := &ExtPacket{ - Packet: rtpPacket, - Arrival: arrivalTime, - RawPacket: rawPacket, + Packet: rtpPacket, + Arrival: arrivalTime, VideoLayer: VideoLayer{ Spatial: InvalidLayerSpatial, Temporal: InvalidLayerTemporal, @@ -619,6 +641,11 @@ func (b *Buffer) getRTCP() []rtcp.Packet { func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) { b.Lock() defer b.Unlock() + + return b.getPacket(buff, sn) +} + +func (b *Buffer) getPacket(buff []byte, sn uint16) (int, error) { if b.closed.Load() { return 0, io.EOF } diff --git a/pkg/sfu/buffer/buffer_test.go b/pkg/sfu/buffer/buffer_test.go index 91e5d5196..a3207336e 100644 --- a/pkg/sfu/buffer/buffer_test.go +++ b/pkg/sfu/buffer/buffer_test.go @@ -146,20 +146,11 @@ func TestNack(t *testing.T) { } func TestNewBuffer(t *testing.T) { - type args struct { - options Options - } tests := []struct { name string - args args }{ { name: "Must not be nil and add packets in sequence", - args: args{ - options: Options{ - MaxBitRate: 1e6, - }, - }, }, } for _, tt := range tests { diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 1cbfaf2d7..3b911b729 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -12,6 +12,7 @@ import ( "github.com/pion/webrtc/v3" "go.uber.org/atomic" + "github.com/livekit/mediatransportutil/pkg/bucket" "github.com/livekit/mediatransportutil/pkg/twcc" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -516,6 +517,7 @@ func (w *WebRTCReceiver) getDeltaStats() map[uint32]*buffer.StreamStatsWithLayer } func (w *WebRTCReceiver) forwardRTP(layer int32) { + pktBuf := make([]byte, bucket.MaxPktSize) tracker := w.streamTrackerManager.GetTracker(layer) defer func() { @@ -541,7 +543,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { buf := w.buffers[layer] redPktWriter := w.redPktWriter w.bufferMu.RUnlock() - pkt, err := buf.ReadExtended() + pkt, err := buf.ReadExtended(pktBuf) if err == io.EOF { return }