From aba18accd9cc2ec638c01b4296d21ede8d21ee3e Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 19 Nov 2022 13:19:49 +0530 Subject: [PATCH] Prevent rtx buffer and forwarding path colliding (#1174) * Prevent rtx buffer and forwarding path colliding Received packets are put into RTX buffer which is a circular buffer and the packet (sequence number) is queued for forwarding. If the RTX buffer fills up and cycles before forwarding happens, forwarding would pick the wrong packet (as it is holding a reference to a byte slice in the RTX buffer) to forward. Prevent it by moving reading from RTX buffer just before forwarding. Adds an extra copy from RTX buffer -> temp buffer for forwarding, but ensures that forwarding buffer is not used by another go routine. * Revert some changes from previous commit Details: - Do all forward processing as before. - One difference is not load raw packet into ExtPacket. - Load raw packet into provided buffer when module that reads using ReadExtended calls that function. If the packet is not there in the retransmission buffer, that packet will be dropped. This is the case we are trying to fix, i. e. the RTX buffer has cycled before ReadExtended could pull the packet. This makes a copy into the provided buffer so that the data does not change underneath. * Remove debug comment * Oops missed a function call --- pkg/sfu/buffer/buffer.go | 65 +++++++++++++++++++++++++---------- pkg/sfu/buffer/buffer_test.go | 9 ----- pkg/sfu/receiver.go | 4 ++- 3 files changed, 49 insertions(+), 29 deletions(-) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index da586f028..bbd3d44a7 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -99,11 +99,6 @@ type Buffer struct { frameRateCalculated bool } -// BufferOptions provides configuration options for the buffer -type Options struct { - MaxBitRate uint64 -} - // NewBuffer constructs a new Buffer func NewBuffer(ssrc uint32, vp, ap *sync.Pool) *Buffer { l := logger.GetDefaultLogger() // will be reset with correct context via SetLogger @@ -272,16 +267,22 @@ func (b *Buffer) Read(buff []byte) (n int, err error) { } } -func (b *Buffer) ReadExtended() (*ExtPacket, error) { +func (b *Buffer) ReadExtended(buf []byte) (*ExtPacket, error) { for { if b.closed.Load() { return nil, io.EOF } b.Lock() if b.extPackets.Len() > 0 { - extPkt := b.extPackets.PopFront().(*ExtPacket) + ep := b.extPackets.PopFront().(*ExtPacket) + ep = b.patchExtPacket(ep, buf) + if ep == nil { + b.Unlock() + continue + } + b.Unlock() - return extPkt, nil + return ep, nil } b.Unlock() time.Sleep(10 * time.Millisecond) @@ -363,7 +364,7 @@ func (b *Buffer) SetRTT(rtt uint32) { } func (b *Buffer) calc(pkt []byte, arrivalTime int64) { - pb, err := b.bucket.AddPacket(pkt) + pktBuf, err := b.bucket.AddPacket(pkt) if err != nil { // // Even when erroring, do @@ -385,7 +386,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { } var p rtp.Packet - err = p.Unmarshal(pb) + err = p.Unmarshal(pktBuf) if err != nil { b.logger.Warnw("error unmarshaling RTP packet", err) return @@ -394,19 +395,41 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { b.updateStreamState(&p, arrivalTime) b.processHeaderExtensions(&p, arrivalTime) - ep := b.getExtPacket(pb, &p, arrivalTime) + b.doNACKs() + + b.doReports(arrivalTime) + + ep := b.getExtPacket(&p, arrivalTime) if ep == nil { return } b.extPackets.PushBack(ep) - b.doNACKs() - - b.doReports(arrivalTime) - b.doFpsCalc(ep) } +func (b *Buffer) patchExtPacket(ep *ExtPacket, buf []byte) *ExtPacket { + n, err := b.getPacket(buf, ep.Packet.SequenceNumber) + if err != nil { + b.logger.Warnw("could not get packet", err, "sn", ep.Packet.SequenceNumber) + return nil + } + ep.RawPacket = buf[:n] + + // patch RTP packet to point payload to new buffer + rtp := *ep.Packet + payloadStart := ep.Packet.Header.MarshalSize() + payloadEnd := payloadStart + len(ep.Packet.Payload) + if payloadEnd > n { + b.logger.Warnw("unexpected marshal size", nil, "max", n, "need", payloadEnd) + return nil + } + rtp.Payload = buf[payloadStart:payloadEnd] + ep.Packet = &rtp + + return ep +} + func (b *Buffer) doFpsCalc(ep *ExtPacket) { if b.frameRateCalculated || len(ep.Packet.Payload) == 0 { return @@ -478,11 +501,10 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime int64) { } } -func (b *Buffer) getExtPacket(rawPacket []byte, rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket { +func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket { ep := &ExtPacket{ - Packet: rtpPacket, - Arrival: arrivalTime, - RawPacket: rawPacket, + Packet: rtpPacket, + Arrival: arrivalTime, VideoLayer: VideoLayer{ Spatial: InvalidLayerSpatial, Temporal: InvalidLayerTemporal, @@ -619,6 +641,11 @@ func (b *Buffer) getRTCP() []rtcp.Packet { func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) { b.Lock() defer b.Unlock() + + return b.getPacket(buff, sn) +} + +func (b *Buffer) getPacket(buff []byte, sn uint16) (int, error) { if b.closed.Load() { return 0, io.EOF } diff --git a/pkg/sfu/buffer/buffer_test.go b/pkg/sfu/buffer/buffer_test.go index 91e5d5196..a3207336e 100644 --- a/pkg/sfu/buffer/buffer_test.go +++ b/pkg/sfu/buffer/buffer_test.go @@ -146,20 +146,11 @@ func TestNack(t *testing.T) { } func TestNewBuffer(t *testing.T) { - type args struct { - options Options - } tests := []struct { name string - args args }{ { name: "Must not be nil and add packets in sequence", - args: args{ - options: Options{ - MaxBitRate: 1e6, - }, - }, }, } for _, tt := range tests { diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 1cbfaf2d7..3b911b729 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -12,6 +12,7 @@ import ( "github.com/pion/webrtc/v3" "go.uber.org/atomic" + "github.com/livekit/mediatransportutil/pkg/bucket" "github.com/livekit/mediatransportutil/pkg/twcc" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -516,6 +517,7 @@ func (w *WebRTCReceiver) getDeltaStats() map[uint32]*buffer.StreamStatsWithLayer } func (w *WebRTCReceiver) forwardRTP(layer int32) { + pktBuf := make([]byte, bucket.MaxPktSize) tracker := w.streamTrackerManager.GetTracker(layer) defer func() { @@ -541,7 +543,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { buf := w.buffers[layer] redPktWriter := w.redPktWriter w.bufferMu.RUnlock() - pkt, err := buf.ReadExtended() + pkt, err := buf.ReadExtended(pktBuf) if err == io.EOF { return }