diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index ec5030ca3..83094d7f6 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -241,10 +241,10 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra layer := buffer.RidToSpatialLayer(track.RID(), ti) t.params.Logger.Debugw( "AddReceiver", - "mime", track.Codec().MimeType, "rid", track.RID(), "layer", layer, "ssrc", track.SSRC(), + "codec", track.Codec(), ) wr := t.MediaTrackReceiver.Receiver(mime) if wr == nil { diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index c52958ec6..11bca8b52 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -235,8 +235,14 @@ type DownTrack struct { forwarder *Forwarder - upstreamCodecs []webrtc.RTPCodecParameters - codec webrtc.RTPCodecCapability + upstreamCodecs []webrtc.RTPCodecParameters + codec webrtc.RTPCodecCapability + + // payload types for red codec only + isRED bool + upstreamPrimaryPT uint8 + primaryPT uint8 + absSendTimeExtID int transportWideExtID int dependencyDescriptorExtID int @@ -365,7 +371,7 @@ func NewDownTrack(params DowntrackParams) (*DownTrack, error) { go d.maxLayerNotifierWorker() go d.keyFrameRequester() } - d.params.Logger.Debugw("downtrack created") + d.params.Logger.Debugw("downtrack created", "upstreamCodecs", d.upstreamCodecs) return d, nil } @@ -379,11 +385,12 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, d.bindLock.Unlock() return webrtc.RTPCodecParameters{}, ErrDownTrackAlreadyBound } - var codec webrtc.RTPCodecParameters + var codec, matchedUpstreamCodec webrtc.RTPCodecParameters for _, c := range d.upstreamCodecs { matchCodec, err := utils.CodecParametersFuzzySearch(c, t.CodecParameters()) if err == nil { codec = matchCodec + matchedUpstreamCodec = c break } } @@ -397,6 +404,18 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, onBinding(err) } return webrtc.RTPCodecParameters{}, err + } else if strings.EqualFold(matchedUpstreamCodec.MimeType, "audio/red") { + d.isRED = true + var primaryPT, secondaryPT int + if n, err := fmt.Sscanf(matchedUpstreamCodec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 { + d.params.Logger.Errorw("failed to parse upstream primary and secondary payload type for RED", err, "matchedCodec", codec) + } + d.upstreamPrimaryPT = uint8(primaryPT) + + if n, err := fmt.Sscanf(codec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 { + d.params.Logger.Errorw("failed to parse primary and secondary payload type for RED", err, "matchedCodec", codec) + } + d.primaryPT = uint8(primaryPT) } // if a downtrack is closed before bind, it already unsubscribed from client, don't do subsequent operation and return here. @@ -406,7 +425,22 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, return codec, nil } - d.params.Logger.Debugw("DownTrack.Bind", "codecs", d.upstreamCodecs, "matchCodec", codec, "ssrc", t.SSRC()) + logFields := []interface{}{ + "codecs", d.upstreamCodecs, + "matchCodec", codec, + "ssrc", t.SSRC(), + } + if d.isRED { + logFields = append(logFields, + "isRED", d.isRED, + "upstreamPrimaryPT", d.upstreamPrimaryPT, + "primaryPT", d.primaryPT, + ) + } + d.params.Logger.Debugw("DownTrack.Bind", + logFields..., + ) + d.ssrc = uint32(t.SSRC()) d.payloadType = uint8(codec.PayloadType) d.writeStream = t.WriteStream() @@ -1756,7 +1790,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { pkt.Header.SequenceNumber = epm.targetSeqNo pkt.Header.Timestamp = epm.timestamp pkt.Header.SSRC = d.ssrc - pkt.Header.PayloadType = d.payloadType + pkt.Header.PayloadType = d.getTranslatedPayloadType(pkt.Header.PayloadType) poolEntity := PacketFactory.Get().(*[]byte) payload := *poolEntity @@ -1843,7 +1877,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { func (d *DownTrack) getTranslatedRTPHeader(extPkt *buffer.ExtPacket, tp *TranslationParams) (*rtp.Header, error) { hdr := extPkt.Packet.Header - hdr.PayloadType = d.payloadType + hdr.PayloadType = d.getTranslatedPayloadType(hdr.PayloadType) hdr.Timestamp = uint32(tp.rtp.extTimestamp) hdr.SequenceNumber = uint16(tp.rtp.extSequenceNumber) hdr.SSRC = d.ssrc @@ -1854,6 +1888,15 @@ func (d *DownTrack) getTranslatedRTPHeader(extPkt *buffer.ExtPacket, tp *Transla return &hdr, nil } +func (d *DownTrack) getTranslatedPayloadType(src uint8) uint8 { + // send primary codec to subscriber if the publisher send primary codec to us when red is negotiated, + // this will happen when the payload is too large to encode into red payload (exceeds mtu). + if d.isRED && src == d.upstreamPrimaryPT && d.primaryPT != 0 { + return d.primaryPT + } + return d.payloadType +} + func (d *DownTrack) DebugInfo() map[string]interface{} { stats := map[string]interface{}{ "LastPli": d.rtpStats.LastPli(), diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index 89b437124..956ee7251 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -41,6 +41,7 @@ type RedPrimaryReceiver struct { downTrackSpreader *DownTrackSpreader logger logger.Logger closed atomic.Bool + redPT uint8 firstPktReceived bool lastSeq uint16 @@ -54,6 +55,7 @@ func NewRedPrimaryReceiver(receiver TrackReceiver, dsp DownTrackSpreaderParams) TrackReceiver: receiver, downTrackSpreader: NewDownTrackSpreader(dsp), logger: dsp.Logger, + redPT: uint8(receiver.Codec().PayloadType), } } @@ -63,6 +65,13 @@ func (r *RedPrimaryReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int3 return 0 } + if pkt.Packet.PayloadType != r.redPT { + // forward non-red packet directly + return r.downTrackSpreader.Broadcast(func(dt TrackSender) { + _ = dt.WriteRTP(pkt, spatialLayer) + }) + } + pkts, err := r.getSendPktsFromRed(pkt.Packet) if err != nil { r.logger.Errorw("get encoding for red failed", err, "payloadtype", pkt.Packet.PayloadType) diff --git a/pkg/sfu/redreceiver_test.go b/pkg/sfu/redreceiver_test.go index d19181446..d3fdc3ebc 100644 --- a/pkg/sfu/redreceiver_test.go +++ b/pkg/sfu/redreceiver_test.go @@ -25,7 +25,10 @@ import ( "github.com/livekit/protocol/logger" ) -const tsStep = uint32(48000 / 1000 * 10) +const ( + tsStep = uint32(48000 / 1000 * 10) + opusREDPT = 63 +) type dummyDowntrack struct { TrackSender @@ -268,6 +271,7 @@ func generateRedPkts(t *testing.T, pkts []*rtp.Packet, redCount int) []*rtp.Pack } buf := make([]byte, mtuSize) redPkt := *pkt + redPkt.PayloadType = opusREDPT encoded, err := encodeRedForPrimary(encodingPkts, pkt, buf) require.NoError(t, err) redPkt.Payload = buf[:encoded] @@ -281,6 +285,7 @@ func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktI w := &WebRTCReceiver{ kind: webrtc.RTPCodecTypeAudio, logger: logger.GetLogger(), + codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}}, } require.Equal(t, w.GetPrimaryReceiverForRed(), w) w.isRED = true @@ -374,6 +379,30 @@ func TestRedPrimaryReceiver(t *testing.T) { recvPktIndex := []int{20, 10, 23, 24, 25, 21, 22, 23, 32, 33, 34} testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) }) + + t.Run("mixed primary codec", func(t *testing.T) { + dt := &dummyDowntrack{TrackSender: &DownTrack{}} + w := &WebRTCReceiver{ + kind: webrtc.RTPCodecTypeAudio, + logger: logger.GetLogger(), + codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}}, + } + require.Equal(t, w.GetPrimaryReceiverForRed(), w) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + primaryPkt := &rtp.Packet{ + Header: rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111}, + Payload: []byte{1, 3, 5, 7, 9}, + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: primaryPkt, + }, 0) + + verifyPktsEqual(t, []*rtp.Packet{primaryPkt}, dt.receivedPkts) + }) } func TestExtractPrimaryEncodingForRED(t *testing.T) { @@ -386,8 +415,10 @@ func TestExtractPrimaryEncodingForRED(t *testing.T) { for _, redPkt := range redPkts { payload, err := extractPrimaryEncodingForRED(redPkt.Payload) require.NoError(t, err) + primaryHeader := redPkt.Header + primaryHeader.PayloadType = 111 primaryPkts = append(primaryPkts, &rtp.Packet{ - Header: redPkt.Header, + Header: primaryHeader, Payload: payload, }) }