Forward correct payload type for mixed up red/primary payload (#2847)

* Forward correct payload type for mixed up red/primary payload

* empty line

* log field & test case
This commit is contained in:
cnderrauber
2024-07-09 23:04:47 +08:00
committed by GitHub
parent 27f6794e77
commit deee816d0a
4 changed files with 93 additions and 10 deletions
+1 -1
View File
@@ -241,10 +241,10 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra
layer := buffer.RidToSpatialLayer(track.RID(), ti)
t.params.Logger.Debugw(
"AddReceiver",
"mime", track.Codec().MimeType,
"rid", track.RID(),
"layer", layer,
"ssrc", track.SSRC(),
"codec", track.Codec(),
)
wr := t.MediaTrackReceiver.Receiver(mime)
if wr == nil {
+50 -7
View File
@@ -235,8 +235,14 @@ type DownTrack struct {
forwarder *Forwarder
upstreamCodecs []webrtc.RTPCodecParameters
codec webrtc.RTPCodecCapability
upstreamCodecs []webrtc.RTPCodecParameters
codec webrtc.RTPCodecCapability
// payload types for red codec only
isRED bool
upstreamPrimaryPT uint8
primaryPT uint8
absSendTimeExtID int
transportWideExtID int
dependencyDescriptorExtID int
@@ -365,7 +371,7 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) {
go d.maxLayerNotifierWorker()
go d.keyFrameRequester()
}
d.params.Logger.Debugw("downtrack created")
d.params.Logger.Debugw("downtrack created", "upstreamCodecs", d.upstreamCodecs)
return d, nil
}
@@ -379,11 +385,12 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
d.bindLock.Unlock()
return webrtc.RTPCodecParameters{}, ErrDownTrackAlreadyBound
}
var codec webrtc.RTPCodecParameters
var codec, matchedUpstreamCodec webrtc.RTPCodecParameters
for _, c := range d.upstreamCodecs {
matchCodec, err := utils.CodecParametersFuzzySearch(c, t.CodecParameters())
if err == nil {
codec = matchCodec
matchedUpstreamCodec = c
break
}
}
@@ -397,6 +404,18 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
onBinding(err)
}
return webrtc.RTPCodecParameters{}, err
} else if strings.EqualFold(matchedUpstreamCodec.MimeType, "audio/red") {
d.isRED = true
var primaryPT, secondaryPT int
if n, err := fmt.Sscanf(matchedUpstreamCodec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 {
d.params.Logger.Errorw("failed to parse upstream primary and secondary payload type for RED", err, "matchedCodec", codec)
}
d.upstreamPrimaryPT = uint8(primaryPT)
if n, err := fmt.Sscanf(codec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 {
d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec)
}
d.primaryPT = uint8(primaryPT)
}
// if a downtrack is closed before bind, it already unsubscribed from client, don't do subsequent operation and return here.
@@ -406,7 +425,22 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
return codec, nil
}
d.params.Logger.Debugw("DownTrack.Bind", "codecs", d.upstreamCodecs, "matchCodec", codec, "ssrc", t.SSRC())
logFields := []interface{}{
"codecs", d.upstreamCodecs,
"matchCodec", codec,
"ssrc", t.SSRC(),
}
if d.isRED {
logFields = append(logFields,
"isRED", d.isRED,
"upstreamPrimaryPT", d.upstreamPrimaryPT,
"primaryPT", d.primaryPT,
)
}
d.params.Logger.Debugw("DownTrack.Bind",
logFields...,
)
d.ssrc = uint32(t.SSRC())
d.payloadType = uint8(codec.PayloadType)
d.writeStream = t.WriteStream()
@@ -1756,7 +1790,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
pkt.Header.SequenceNumber = epm.targetSeqNo
pkt.Header.Timestamp = epm.timestamp
pkt.Header.SSRC = d.ssrc
pkt.Header.PayloadType = d.payloadType
pkt.Header.PayloadType = d.getTranslatedPayloadType(pkt.Header.PayloadType)
poolEntity := PacketFactory.Get().(*[]byte)
payload := *poolEntity
@@ -1843,7 +1877,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
func (d *DownTrack) getTranslatedRTPHeader(extPkt *buffer.ExtPacket, tp *TranslationParams) (*rtp.Header, error) {
hdr := extPkt.Packet.Header
hdr.PayloadType = d.payloadType
hdr.PayloadType = d.getTranslatedPayloadType(hdr.PayloadType)
hdr.Timestamp = uint32(tp.rtp.extTimestamp)
hdr.SequenceNumber = uint16(tp.rtp.extSequenceNumber)
hdr.SSRC = d.ssrc
@@ -1854,6 +1888,15 @@ func (d *DownTrack) getTranslatedRTPHeader(extPkt *buffer.ExtPacket, tp *Transla
return &hdr, nil
}
func (d *DownTrack) getTranslatedPayloadType(src uint8) uint8 {
// send primary codec to subscriber if the publisher send primary codec to us when red is negotiated,
// this will happen when the payload is too large to encode into red payload (exceeds mtu).
if d.isRED && src == d.upstreamPrimaryPT && d.primaryPT != 0 {
return d.primaryPT
}
return d.payloadType
}
func (d *DownTrack) DebugInfo() map[string]interface{} {
stats := map[string]interface{}{
"LastPli": d.rtpStats.LastPli(),
+9
View File
@@ -41,6 +41,7 @@ type RedPrimaryReceiver struct {
downTrackSpreader *DownTrackSpreader
logger logger.Logger
closed atomic.Bool
redPT uint8
firstPktReceived bool
lastSeq uint16
@@ -54,6 +55,7 @@ func NewRedPrimaryReceiver(receiver TrackReceiver, dsp DownTrackSpreaderParams)
TrackReceiver: receiver,
downTrackSpreader: NewDownTrackSpreader(dsp),
logger: dsp.Logger,
redPT: uint8(receiver.Codec().PayloadType),
}
}
@@ -63,6 +65,13 @@ func (r *RedPrimaryReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int3
return 0
}
if pkt.Packet.PayloadType != r.redPT {
// forward non-red packet directly
return r.downTrackSpreader.Broadcast(func(dt TrackSender) {
_ = dt.WriteRTP(pkt, spatialLayer)
})
}
pkts, err := r.getSendPktsFromRed(pkt.Packet)
if err != nil {
r.logger.Errorw("get encoding for red failed", err, "payloadtype", pkt.Packet.PayloadType)
+33 -2
View File
@@ -25,7 +25,10 @@ import (
"github.com/livekit/protocol/logger"
)
const tsStep = uint32(48000 / 1000 * 10)
const (
tsStep = uint32(48000 / 1000 * 10)
opusREDPT = 63
)
type dummyDowntrack struct {
TrackSender
@@ -268,6 +271,7 @@ func generateRedPkts(t *testing.T, pkts []*rtp.Packet, redCount int) []*rtp.Pack
}
buf := make([]byte, mtuSize)
redPkt := *pkt
redPkt.PayloadType = opusREDPT
encoded, err := encodeRedForPrimary(encodingPkts, pkt, buf)
require.NoError(t, err)
redPkt.Payload = buf[:encoded]
@@ -281,6 +285,7 @@ func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktI
w := &WebRTCReceiver{
kind: webrtc.RTPCodecTypeAudio,
logger: logger.GetLogger(),
codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}},
}
require.Equal(t, w.GetPrimaryReceiverForRed(), w)
w.isRED = true
@@ -374,6 +379,30 @@ func TestRedPrimaryReceiver(t *testing.T) {
recvPktIndex := []int{20, 10, 23, 24, 25, 21, 22, 23, 32, 33, 34}
testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex)
})
t.Run("mixed primary codec", func(t *testing.T) {
dt := &dummyDowntrack{TrackSender: &DownTrack{}}
w := &WebRTCReceiver{
kind: webrtc.RTPCodecTypeAudio,
logger: logger.GetLogger(),
codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}},
}
require.Equal(t, w.GetPrimaryReceiverForRed(), w)
w.isRED = true
red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver)
require.NotNil(t, red)
require.NoError(t, red.AddDownTrack(dt))
primaryPkt := &rtp.Packet{
Header: rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111},
Payload: []byte{1, 3, 5, 7, 9},
}
red.ForwardRTP(&buffer.ExtPacket{
Packet: primaryPkt,
}, 0)
verifyPktsEqual(t, []*rtp.Packet{primaryPkt}, dt.receivedPkts)
})
}
func TestExtractPrimaryEncodingForRED(t *testing.T) {
@@ -386,8 +415,10 @@ func TestExtractPrimaryEncodingForRED(t *testing.T) {
for _, redPkt := range redPkts {
payload, err := extractPrimaryEncodingForRED(redPkt.Payload)
require.NoError(t, err)
primaryHeader := redPkt.Header
primaryHeader.PayloadType = 111
primaryPkts = append(primaryPkts, &rtp.Packet{
Header: redPkt.Header,
Header: primaryHeader,
Payload: payload,
})
}