Prevent rtx buffer and forwarding path colliding (#1174)

* Prevent rtx buffer and forwarding path colliding

Received packets are put into RTX buffer which is
a circular buffer and the packet (sequence number) is
queued for forwarding. If the RTX buffer fills up
and cycles before forwarding happens, forwarding
would pick the wrong packet (as it is holding a
reference to a byte slice in the RTX buffer) to forward.

Prevent it by moving reading from RTX buffer just
before forwarding. Adds an extra copy from RTX buffer
-> temp buffer for forwarding, but ensures that forwarding
buffer is not used by another go routine.

* Revert some changes from previous commit

Details:
- Do all forward processing as before.
- One difference is not load raw packet into ExtPacket.
- Load raw packet into provided buffer when module that reads
using ReadExtended calls that function. If the packet is
not there in the retransmission buffer, that packet will be
dropped. This is the case we are trying to fix, i. e. the RTX
buffer has cycled before ReadExtended could pull the packet.
This makes a copy into the provided buffer so that the data
does not change underneath.

* Remove debug comment

* Oops missed a function call
This commit is contained in:
Raja Subramanian
2022-11-19 13:19:49 +05:30
committed by GitHub
parent f9bdcdf201
commit aba18accd9
3 changed files with 49 additions and 29 deletions

View File

@@ -99,11 +99,6 @@ type Buffer struct {
frameRateCalculated bool
}
// BufferOptions provides configuration options for the buffer
type Options struct {
MaxBitRate uint64
}
// NewBuffer constructs a new Buffer
func NewBuffer(ssrc uint32, vp, ap *sync.Pool) *Buffer {
l := logger.GetDefaultLogger() // will be reset with correct context via SetLogger
@@ -272,16 +267,22 @@ func (b *Buffer) Read(buff []byte) (n int, err error) {
}
}
func (b *Buffer) ReadExtended() (*ExtPacket, error) {
func (b *Buffer) ReadExtended(buf []byte) (*ExtPacket, error) {
for {
if b.closed.Load() {
return nil, io.EOF
}
b.Lock()
if b.extPackets.Len() > 0 {
extPkt := b.extPackets.PopFront().(*ExtPacket)
ep := b.extPackets.PopFront().(*ExtPacket)
ep = b.patchExtPacket(ep, buf)
if ep == nil {
b.Unlock()
continue
}
b.Unlock()
return extPkt, nil
return ep, nil
}
b.Unlock()
time.Sleep(10 * time.Millisecond)
@@ -363,7 +364,7 @@ func (b *Buffer) SetRTT(rtt uint32) {
}
func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
pb, err := b.bucket.AddPacket(pkt)
pktBuf, err := b.bucket.AddPacket(pkt)
if err != nil {
//
// Even when erroring, do
@@ -385,7 +386,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
}
var p rtp.Packet
err = p.Unmarshal(pb)
err = p.Unmarshal(pktBuf)
if err != nil {
b.logger.Warnw("error unmarshaling RTP packet", err)
return
@@ -394,19 +395,41 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
b.updateStreamState(&p, arrivalTime)
b.processHeaderExtensions(&p, arrivalTime)
ep := b.getExtPacket(pb, &p, arrivalTime)
b.doNACKs()
b.doReports(arrivalTime)
ep := b.getExtPacket(&p, arrivalTime)
if ep == nil {
return
}
b.extPackets.PushBack(ep)
b.doNACKs()
b.doReports(arrivalTime)
b.doFpsCalc(ep)
}
func (b *Buffer) patchExtPacket(ep *ExtPacket, buf []byte) *ExtPacket {
n, err := b.getPacket(buf, ep.Packet.SequenceNumber)
if err != nil {
b.logger.Warnw("could not get packet", err, "sn", ep.Packet.SequenceNumber)
return nil
}
ep.RawPacket = buf[:n]
// patch RTP packet to point payload to new buffer
rtp := *ep.Packet
payloadStart := ep.Packet.Header.MarshalSize()
payloadEnd := payloadStart + len(ep.Packet.Payload)
if payloadEnd > n {
b.logger.Warnw("unexpected marshal size", nil, "max", n, "need", payloadEnd)
return nil
}
rtp.Payload = buf[payloadStart:payloadEnd]
ep.Packet = &rtp
return ep
}
func (b *Buffer) doFpsCalc(ep *ExtPacket) {
if b.frameRateCalculated || len(ep.Packet.Payload) == 0 {
return
@@ -478,11 +501,10 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime int64) {
}
}
func (b *Buffer) getExtPacket(rawPacket []byte, rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket {
func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket {
ep := &ExtPacket{
Packet: rtpPacket,
Arrival: arrivalTime,
RawPacket: rawPacket,
Packet: rtpPacket,
Arrival: arrivalTime,
VideoLayer: VideoLayer{
Spatial: InvalidLayerSpatial,
Temporal: InvalidLayerTemporal,
@@ -619,6 +641,11 @@ func (b *Buffer) getRTCP() []rtcp.Packet {
func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) {
b.Lock()
defer b.Unlock()
return b.getPacket(buff, sn)
}
func (b *Buffer) getPacket(buff []byte, sn uint16) (int, error) {
if b.closed.Load() {
return 0, io.EOF
}

View File

@@ -146,20 +146,11 @@ func TestNack(t *testing.T) {
}
func TestNewBuffer(t *testing.T) {
type args struct {
options Options
}
tests := []struct {
name string
args args
}{
{
name: "Must not be nil and add packets in sequence",
args: args{
options: Options{
MaxBitRate: 1e6,
},
},
},
}
for _, tt := range tests {

View File

@@ -12,6 +12,7 @@ import (
"github.com/pion/webrtc/v3"
"go.uber.org/atomic"
"github.com/livekit/mediatransportutil/pkg/bucket"
"github.com/livekit/mediatransportutil/pkg/twcc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
@@ -516,6 +517,7 @@ func (w *WebRTCReceiver) getDeltaStats() map[uint32]*buffer.StreamStatsWithLayer
}
func (w *WebRTCReceiver) forwardRTP(layer int32) {
pktBuf := make([]byte, bucket.MaxPktSize)
tracker := w.streamTrackerManager.GetTracker(layer)
defer func() {
@@ -541,7 +543,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
buf := w.buffers[layer]
redPktWriter := w.redPktWriter
w.bufferMu.RUnlock()
pkt, err := buf.ReadExtended()
pkt, err := buf.ReadExtended(pktBuf)
if err == io.EOF {
return
}