mirror of
https://github.com/livekit/livekit.git
synced 2026-04-25 19:52:11 +00:00
Prevent rtx buffer and forwarding path colliding (#1174)
* Prevent rtx buffer and forwarding path colliding Received packets are put into RTX buffer which is a circular buffer and the packet (sequence number) is queued for forwarding. If the RTX buffer fills up and cycles before forwarding happens, forwarding would pick the wrong packet (as it is holding a reference to a byte slice in the RTX buffer) to forward. Prevent it by moving reading from RTX buffer just before forwarding. Adds an extra copy from RTX buffer -> temp buffer for forwarding, but ensures that forwarding buffer is not used by another go routine. * Revert some changes from previous commit Details: - Do all forward processing as before. - One difference is not load raw packet into ExtPacket. - Load raw packet into provided buffer when module that reads using ReadExtended calls that function. If the packet is not there in the retransmission buffer, that packet will be dropped. This is the case we are trying to fix, i. e. the RTX buffer has cycled before ReadExtended could pull the packet. This makes a copy into the provided buffer so that the data does not change underneath. * Remove debug comment * Oops missed a function call
This commit is contained in:
@@ -99,11 +99,6 @@ type Buffer struct {
|
||||
frameRateCalculated bool
|
||||
}
|
||||
|
||||
// BufferOptions provides configuration options for the buffer
|
||||
type Options struct {
|
||||
MaxBitRate uint64
|
||||
}
|
||||
|
||||
// NewBuffer constructs a new Buffer
|
||||
func NewBuffer(ssrc uint32, vp, ap *sync.Pool) *Buffer {
|
||||
l := logger.GetDefaultLogger() // will be reset with correct context via SetLogger
|
||||
@@ -272,16 +267,22 @@ func (b *Buffer) Read(buff []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadExtended() (*ExtPacket, error) {
|
||||
func (b *Buffer) ReadExtended(buf []byte) (*ExtPacket, error) {
|
||||
for {
|
||||
if b.closed.Load() {
|
||||
return nil, io.EOF
|
||||
}
|
||||
b.Lock()
|
||||
if b.extPackets.Len() > 0 {
|
||||
extPkt := b.extPackets.PopFront().(*ExtPacket)
|
||||
ep := b.extPackets.PopFront().(*ExtPacket)
|
||||
ep = b.patchExtPacket(ep, buf)
|
||||
if ep == nil {
|
||||
b.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
b.Unlock()
|
||||
return extPkt, nil
|
||||
return ep, nil
|
||||
}
|
||||
b.Unlock()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
@@ -363,7 +364,7 @@ func (b *Buffer) SetRTT(rtt uint32) {
|
||||
}
|
||||
|
||||
func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
|
||||
pb, err := b.bucket.AddPacket(pkt)
|
||||
pktBuf, err := b.bucket.AddPacket(pkt)
|
||||
if err != nil {
|
||||
//
|
||||
// Even when erroring, do
|
||||
@@ -385,7 +386,7 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
|
||||
}
|
||||
|
||||
var p rtp.Packet
|
||||
err = p.Unmarshal(pb)
|
||||
err = p.Unmarshal(pktBuf)
|
||||
if err != nil {
|
||||
b.logger.Warnw("error unmarshaling RTP packet", err)
|
||||
return
|
||||
@@ -394,19 +395,41 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) {
|
||||
b.updateStreamState(&p, arrivalTime)
|
||||
b.processHeaderExtensions(&p, arrivalTime)
|
||||
|
||||
ep := b.getExtPacket(pb, &p, arrivalTime)
|
||||
b.doNACKs()
|
||||
|
||||
b.doReports(arrivalTime)
|
||||
|
||||
ep := b.getExtPacket(&p, arrivalTime)
|
||||
if ep == nil {
|
||||
return
|
||||
}
|
||||
b.extPackets.PushBack(ep)
|
||||
|
||||
b.doNACKs()
|
||||
|
||||
b.doReports(arrivalTime)
|
||||
|
||||
b.doFpsCalc(ep)
|
||||
}
|
||||
|
||||
func (b *Buffer) patchExtPacket(ep *ExtPacket, buf []byte) *ExtPacket {
|
||||
n, err := b.getPacket(buf, ep.Packet.SequenceNumber)
|
||||
if err != nil {
|
||||
b.logger.Warnw("could not get packet", err, "sn", ep.Packet.SequenceNumber)
|
||||
return nil
|
||||
}
|
||||
ep.RawPacket = buf[:n]
|
||||
|
||||
// patch RTP packet to point payload to new buffer
|
||||
rtp := *ep.Packet
|
||||
payloadStart := ep.Packet.Header.MarshalSize()
|
||||
payloadEnd := payloadStart + len(ep.Packet.Payload)
|
||||
if payloadEnd > n {
|
||||
b.logger.Warnw("unexpected marshal size", nil, "max", n, "need", payloadEnd)
|
||||
return nil
|
||||
}
|
||||
rtp.Payload = buf[payloadStart:payloadEnd]
|
||||
ep.Packet = &rtp
|
||||
|
||||
return ep
|
||||
}
|
||||
|
||||
func (b *Buffer) doFpsCalc(ep *ExtPacket) {
|
||||
if b.frameRateCalculated || len(ep.Packet.Payload) == 0 {
|
||||
return
|
||||
@@ -478,11 +501,10 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) getExtPacket(rawPacket []byte, rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket {
|
||||
func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPacket {
|
||||
ep := &ExtPacket{
|
||||
Packet: rtpPacket,
|
||||
Arrival: arrivalTime,
|
||||
RawPacket: rawPacket,
|
||||
Packet: rtpPacket,
|
||||
Arrival: arrivalTime,
|
||||
VideoLayer: VideoLayer{
|
||||
Spatial: InvalidLayerSpatial,
|
||||
Temporal: InvalidLayerTemporal,
|
||||
@@ -619,6 +641,11 @@ func (b *Buffer) getRTCP() []rtcp.Packet {
|
||||
func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
return b.getPacket(buff, sn)
|
||||
}
|
||||
|
||||
func (b *Buffer) getPacket(buff []byte, sn uint16) (int, error) {
|
||||
if b.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
@@ -146,20 +146,11 @@ func TestNack(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
type args struct {
|
||||
options Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
name: "Must not be nil and add packets in sequence",
|
||||
args: args{
|
||||
options: Options{
|
||||
MaxBitRate: 1e6,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/pion/webrtc/v3"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/mediatransportutil/pkg/bucket"
|
||||
"github.com/livekit/mediatransportutil/pkg/twcc"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
@@ -516,6 +517,7 @@ func (w *WebRTCReceiver) getDeltaStats() map[uint32]*buffer.StreamStatsWithLayer
|
||||
}
|
||||
|
||||
func (w *WebRTCReceiver) forwardRTP(layer int32) {
|
||||
pktBuf := make([]byte, bucket.MaxPktSize)
|
||||
tracker := w.streamTrackerManager.GetTracker(layer)
|
||||
|
||||
defer func() {
|
||||
@@ -541,7 +543,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
|
||||
buf := w.buffers[layer]
|
||||
redPktWriter := w.redPktWriter
|
||||
w.bufferMu.RUnlock()
|
||||
pkt, err := buf.ReadExtended()
|
||||
pkt, err := buf.ReadExtended(pktBuf)
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user