Add red encoding for opus only publisher (#1137)

* Add red encodings for opus only publisher

* Add test case for red receiver
This commit is contained in:
cnderrauber
2022-11-02 10:36:29 +08:00
committed by GitHub
parent 46f45e8892
commit bdd69c7a1c
7 changed files with 471 additions and 43 deletions
+9 -10
View File
@@ -1,7 +1,6 @@
package rtc
import (
"fmt"
"strings"
"github.com/pion/webrtc/v3"
@@ -11,6 +10,7 @@ import (
)
// opusCodecCapability is the Opus audio capability (48kHz, 2 channels) with
// in-band FEC enabled through its fmtp line.
var opusCodecCapability = webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2, SDPFmtpLine: "minptime=10;useinbandfec=1"}
// redCodecCapability is the audio RED capability (48kHz, 2 channels). Its fmtp
// line is hard-coded to "111/111", i.e. both blocks reference payload type 111.
var redCodecCapability = webrtc.RTPCodecCapability{MimeType: sfu.MimeTypeAudioRed, ClockRate: 48000, Channels: 2, SDPFmtpLine: "111/111"}
func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedback RTCPFeedbackConfig) error {
opusCodec := opusCodecCapability
@@ -24,15 +24,14 @@ func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedbac
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
}
}
redCodec := webrtc.RTPCodecCapability{MimeType: sfu.MimeTypeAudioRed, ClockRate: 48000, Channels: 2}
if opusPayload != 0 && isCodecEnabled(codecs, redCodec) {
redCodec.SDPFmtpLine = fmt.Sprintf("%d/%d", opusPayload, opusPayload)
if err := me.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: redCodec,
PayloadType: 63,
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
if isCodecEnabled(codecs, redCodecCapability) {
if err := me.RegisterCodec(webrtc.RTPCodecParameters{
RTPCodecCapability: redCodecCapability,
PayloadType: 63,
}, webrtc.RTPCodecTypeAudio); err != nil {
return err
}
}
}
+8 -1
View File
@@ -485,7 +485,14 @@ func (t *MediaTrackReceiver) addSubscriber(sub types.LocalParticipant) (err erro
}
tLogger := LoggerWithTrack(sub.GetLogger(), t.ID(), t.params.IsRelayed)
err = t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(receivers, t.ID(), streamId, potentialCodecs, tLogger))
err = t.MediaTrackSubscriptions.AddSubscriber(sub, NewWrappedReceiver(WrappedReceiverParams{
Receivers: receivers,
TrackID: t.ID(),
StreamId: streamId,
UpstreamCodecs: potentialCodecs,
Logger: tLogger,
DisableRed: t.trackInfo.GetDisableRed(),
}))
if err != nil {
return
}
+46 -25
View File
@@ -16,47 +16,61 @@ import (
// wrapper around WebRTC receiver, overriding its ID
type WrappedReceiver struct {
sfu.TrackReceiver
receivers []sfu.TrackReceiver
trackID livekit.TrackID
streamId string
codecs []webrtc.RTPCodecParameters
determinedCodec webrtc.RTPCodecCapability
logger logger.Logger
// WrappedReceiverParams bundles the arguments of NewWrappedReceiver.
type WrappedReceiverParams struct {
	// Receivers are the upstream receivers; their TrackReceiver is what gets wrapped.
	Receivers []*simulcastReceiver
	// TrackID is the value reported by WrappedReceiver.TrackID.
	TrackID livekit.TrackID
	// StreamId is the value reported by WrappedReceiver.StreamID.
	StreamId string
	// UpstreamCodecs are the codecs published upstream; NewWrappedReceiver may
	// extend this list with an opus or red entry.
	UpstreamCodecs []webrtc.RTPCodecParameters
	// Logger is used to report errors, e.g. when no receiver matches a codec.
	Logger logger.Logger
	// DisableRed, when true, prevents red from being added for an opus only upstream.
	DisableRed bool
}
func NewWrappedReceiver(receivers []*simulcastReceiver, trackID livekit.TrackID, streamId string, upstreamCodecs []webrtc.RTPCodecParameters, logger logger.Logger) *WrappedReceiver {
sfuReceivers := make([]sfu.TrackReceiver, 0, len(receivers))
for _, r := range receivers {
// WrappedReceiver is a wrapper around a WebRTC receiver that overrides its
// track and stream IDs with the ones supplied in params.
type WrappedReceiver struct {
	// TrackReceiver is the receiver currently selected for forwarding.
	sfu.TrackReceiver
	// params are the construction parameters (IDs, logger, red setting).
	params WrappedReceiverParams
	// receivers are the candidate upstream receivers taken from params.Receivers.
	receivers []sfu.TrackReceiver
	// codecs are the upstream codecs, possibly extended with opus/red.
	codecs []webrtc.RTPCodecParameters
	// NOTE(review): presumably the codec picked by DetermineReceiver - confirm against its full body.
	determinedCodec webrtc.RTPCodecCapability
}
func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver {
sfuReceivers := make([]sfu.TrackReceiver, 0, len(params.Receivers))
for _, r := range params.Receivers {
sfuReceivers = append(sfuReceivers, r.TrackReceiver)
}
codecs := upstreamCodecs
// if upstream is opus/red, then add opus to match clients that don't support red
if len(codecs) == 1 && strings.EqualFold(codecs[0].MimeType, sfu.MimeTypeAudioRed) {
codecs = append(codecs, webrtc.RTPCodecParameters{
RTPCodecCapability: opusCodecCapability,
PayloadType: 111,
})
codecs := params.UpstreamCodecs
if len(codecs) == 1 {
if strings.EqualFold(codecs[0].MimeType, sfu.MimeTypeAudioRed) {
// if upstream is opus/red, then add opus to match clients that don't support red
codecs = append(codecs, webrtc.RTPCodecParameters{
RTPCodecCapability: opusCodecCapability,
PayloadType: 111,
})
} else if !params.DisableRed && strings.EqualFold(codecs[0].MimeType, webrtc.MimeTypeOpus) {
// if upstream is opus only and red enabled, add red to match clients that support red
codecs = append(codecs, webrtc.RTPCodecParameters{
RTPCodecCapability: redCodecCapability,
PayloadType: 63,
})
// prefer red codec
codecs[0], codecs[1] = codecs[1], codecs[0]
}
}
return &WrappedReceiver{
params: params,
receivers: sfuReceivers,
trackID: trackID,
streamId: streamId,
codecs: codecs,
logger: logger,
}
}
func (r *WrappedReceiver) TrackID() livekit.TrackID {
return r.trackID
return r.params.TrackID
}
func (r *WrappedReceiver) StreamID() string {
return r.streamId
return r.params.StreamId
}
func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) {
@@ -69,10 +83,13 @@ func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) {
// audio opus/red can match opus only
r.TrackReceiver = receiver.GetPrimaryReceiverForRed()
break
} else if strings.EqualFold(c.MimeType, webrtc.MimeTypeOpus) && strings.EqualFold(codec.MimeType, sfu.MimeTypeAudioRed) {
r.TrackReceiver = receiver.GetRedReceiver()
break
}
}
if r.TrackReceiver == nil {
r.logger.Errorw("can't determine receiver for codec", nil, "codec", codec.MimeType)
r.params.Logger.Errorw("can't determine receiver for codec", nil, "codec", codec.MimeType)
if len(r.receivers) > 0 {
r.TrackReceiver = r.receivers[0]
}
@@ -266,3 +283,7 @@ func (d *DummyReceiver) GetPrimaryReceiverForRed() sfu.TrackReceiver {
// DummyReceiver used for video, it should not have RED codec
return d
}
// GetRedReceiver returns the DummyReceiver itself; no separate RED receiver is
// created for it.
func (d *DummyReceiver) GetRedReceiver() sfu.TrackReceiver {
	return d
}
+36 -3
View File
@@ -58,6 +58,9 @@ type TrackReceiver interface {
// Get primary receiver if this receiver represents a RED codec; otherwise it will return itself
GetPrimaryReceiverForRed() TrackReceiver
// Get red receiver for the primary codec; used to forward red encodings when the upstream codec is opus only
GetRedReceiver() TrackReceiver
GetTemporalLayerFpsForSpatial(layer int32) []float32
}
@@ -105,6 +108,8 @@ type WebRTCReceiver struct {
onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)
primaryReceiver atomic.Value // *RedPrimaryReceiver
redReceiver atomic.Value // *RedReceiver
redPktWriter func(pkt *buffer.ExtPacket, spatialLayer int32)
}
func IsSvcCodec(mime string) bool {
@@ -520,6 +525,9 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
if pr := w.primaryReceiver.Load(); pr != nil {
pr.(*RedPrimaryReceiver).Close()
}
if pr := w.redReceiver.Load(); pr != nil {
pr.(*RedReceiver).Close()
}
})
w.streamTrackerManager.RemoveTracker(layer)
@@ -531,6 +539,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
for {
w.bufferMu.RLock()
buf := w.buffers[layer]
redPktWriter := w.redPktWriter
w.bufferMu.RUnlock()
pkt, err := buf.ReadExtended()
if err == io.EOF {
@@ -555,8 +564,9 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
w.downTrackSpreader.Broadcast(func(dt TrackSender) {
_ = dt.WriteRTP(pkt, spatialLayer)
})
if pr := w.primaryReceiver.Load(); pr != nil {
pr.(*RedPrimaryReceiver).ForwardRTP(pkt, spatialLayer)
if redPktWriter != nil {
redPktWriter(pkt, spatialLayer)
}
}
}
@@ -607,11 +617,34 @@ func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver {
Threshold: w.lbThreshold,
Logger: w.logger,
})
w.primaryReceiver.CompareAndSwap(nil, pr)
if w.primaryReceiver.CompareAndSwap(nil, pr) {
w.bufferMu.Lock()
w.redPktWriter = pr.ForwardRTP
w.bufferMu.Unlock()
}
}
return w.primaryReceiver.Load().(*RedPrimaryReceiver)
}
// GetRedReceiver returns a receiver that delivers this track RED-encoded.
//
// When the upstream is already RED, or the receiver has been closed, the
// receiver itself is returned. Otherwise a RedReceiver is created lazily on
// first use and its ForwardRTP is installed as the packet writer invoked from
// the forwarding loop.
func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver {
	if w.isRED || w.closed.Load() {
		return w
	}

	if existing, ok := w.redReceiver.Load().(*RedReceiver); ok {
		return existing
	}

	candidate := NewRedReceiver(w, DownTrackSpreaderParams{
		Threshold: w.lbThreshold,
		Logger:    w.logger,
	})
	// only the caller that wins the swap installs the packet writer; a losing
	// candidate is simply discarded
	if installed := w.redReceiver.CompareAndSwap(nil, candidate); installed {
		w.bufferMu.Lock()
		w.redPktWriter = candidate.ForwardRTP
		w.bufferMu.Unlock()
	}

	return w.redReceiver.Load().(*RedReceiver)
}
func (w *WebRTCReceiver) GetTemporalLayerFpsForSpatial(layer int32) []float32 {
if !w.isSVC {
return w.getBuffer(layer).GetTemporalLayerFpsForSpatial(0)
-4
View File
@@ -84,10 +84,6 @@ func (r *RedPrimaryReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID)
r.downTrackSpreader.Free(subscriberID)
}
// CanClose reports whether the receiver is already closed or has no down
// tracks left.
func (r *RedPrimaryReceiver) CanClose() bool {
	return r.closed.Load() || r.downTrackSpreader.DownTrackCount() == 0
}
func (r *RedPrimaryReceiver) Close() {
r.closed.Store(true)
for _, dt := range r.downTrackSpreader.ResetAndGetDownTracks() {
+150
View File
@@ -0,0 +1,150 @@
package sfu
import (
"encoding/binary"
"go.uber.org/atomic"
"github.com/pion/rtp"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
"github.com/livekit/mediatransportutil/pkg/bucket"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
)
const (
	// maxRedCount is the number of previous packets kept as redundancy
	// candidates (the size of RedReceiver.pktBuff).
	maxRedCount = 2
	// mtuSize is the size in bytes of the scratch buffer that holds an encoded
	// RED payload.
	mtuSize = 1500
	// the RedReceiver is only for chrome / native webrtc now, we always negotiate opus payload to 111 with those clients,
	// so it is safe to use a fixed payload 111 here for performance (avoid encoding red blocks for each downtrack that
	// has a different opus payload type).
	opusPT = 111
)
// RedReceiver wraps a TrackReceiver and forwards its packets RED-encoded
// (RFC 2198) to its own set of down tracks.
type RedReceiver struct {
	// TrackReceiver is the wrapped upstream receiver.
	TrackReceiver
	// downTrackSpreader holds the down tracks that receive the RED stream.
	downTrackSpreader *DownTrackSpreader
	logger logger.Logger
	// closed is set by Close; AddDownTrack and DeleteDownTrack check it.
	closed atomic.Bool
	// pktBuff keeps the most recent packets used as redundancy; index 1 is the newest.
	pktBuff [maxRedCount]*rtp.Packet
	// redPayloadBuf is the scratch buffer the RED payload of each packet is encoded into.
	redPayloadBuf [mtuSize]byte
}
// NewRedReceiver creates a RedReceiver on top of receiver. dsp configures the
// down track spreader of the new receiver and supplies its logger.
func NewRedReceiver(receiver TrackReceiver, dsp DownTrackSpreaderParams) *RedReceiver {
	red := &RedReceiver{
		TrackReceiver: receiver,
		logger:        dsp.Logger,
	}
	red.downTrackSpreader = NewDownTrackSpreader(dsp)
	return red
}
// ForwardRTP wraps the payload of pkt into a RED payload (together with the
// redundancy of recently forwarded packets) and writes the result to every
// down track. Nothing is encoded while there is no down track.
func (r *RedReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int32) {
	if r.downTrackSpreader.DownTrackCount() == 0 {
		return
	}

	n, err := r.encodeRedForPrimary(pkt.Packet, r.redPayloadBuf[:])
	if err != nil {
		r.logger.Errorw("red encoding failed", err)
		return
	}

	// shallow copies keep the upstream packet untouched; only the payload is
	// pointed at the freshly encoded RED data
	redPacket := *pkt.Packet
	redPacket.Payload = r.redPayloadBuf[:n]
	extPkt := *pkt
	extPkt.Packet = &redPacket

	// ExtPacket.RawPacket is left as is for performance since it is not used by
	// the DownTrack; otherwise it would have to be set to the marshaled RED packet
	r.downTrackSpreader.Broadcast(func(dt TrackSender) {
		_ = dt.WriteRTP(&extPkt, spatialLayer)
	})
}
// AddDownTrack registers track as a receiver of the RED stream. A down track
// already stored for the same subscriber is replaced. It returns
// ErrReceiverClosed once the receiver has been closed.
func (r *RedReceiver) AddDownTrack(track TrackSender) error {
	if r.closed.Load() {
		return ErrReceiverClosed
	}

	if subscriberID := track.SubscriberID(); r.downTrackSpreader.HasDownTrack(subscriberID) {
		r.logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", subscriberID)
	}

	r.downTrackSpreader.Store(track)
	return nil
}
// DeleteDownTrack removes the down track of subscriberID. It is a no-op once
// the receiver has been closed.
func (r *RedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) {
	if !r.closed.Load() {
		r.downTrackSpreader.Free(subscriberID)
	}
}
// Close marks the receiver as closed, detaches all of its down tracks and
// closes each of them.
func (r *RedReceiver) Close() {
	r.closed.Store(true)

	downTracks := r.downTrackSpreader.ResetAndGetDownTracks()
	for i := range downTracks {
		downTracks[i].Close()
	}
}
// ReadRTP never finds a packet: it always returns 0 and
// bucket.ErrPacketNotFound, regardless of its arguments.
func (r *RedReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) {
	// red encoding doesn't support nack
	return 0, bucket.ErrPacketNotFound
}
// encodeRedForPrimary encodes pkt as the primary block of an RFC 2198 RED
// payload, preceded by up to maxRedCount redundant blocks taken from the
// previously seen packets, and writes the result into redPayload.
//
// A previous packet is used as redundancy only if it directly precedes pkt
// (sequence number distance 1..maxRedCount) and if it can be described by a RED
// block header at all: the timestamp offset must fit into 14 bits and the
// payload length into 10 bits. When the result would not fit into redPayload,
// the oldest redundant blocks are dropped first so that the primary data is
// still delivered.
//
// It returns the number of bytes written, or bucket.ErrBufferTooSmall if
// redPayload cannot even hold the primary block.
func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (int, error) {
	const (
		redBlockHeaderSize   = 4         // header of a redundant block
		redPrimaryHeaderSize = 1         // header of the primary (last) block
		maxRedTSOffset       = 1<<14 - 1 // timestamp offset is a 14-bit field
		maxRedBlockLength    = 1<<10 - 1 // block length is a 10-bit field
	)

	// collect the redundancy candidates, oldest first
	redPkts := make([]*rtp.Packet, 0, maxRedCount)
	for _, prev := range r.pktBuff {
		if prev == nil || pkt.SequenceNumber == prev.SequenceNumber ||
			(pkt.SequenceNumber-prev.SequenceNumber) > uint16(maxRedCount) {
			continue
		}
		// masking an oversized offset or length into the header would produce a
		// block that decodes to a wrong timestamp or shifts all following data,
		// so such a packet can't be carried as redundancy
		if pkt.Timestamp-prev.Timestamp > maxRedTSOffset || len(prev.Payload) > maxRedBlockLength {
			continue
		}
		redPkts = append(redPkts, prev)
	}

	if r.pktBuff[1] == nil || pkt.SequenceNumber-r.pktBuff[1].SequenceNumber < 0x8000 {
		/* update packet, not copy the rtp packet here since we only hold two packets for red encoding,
		the upstream buffer size is much larger than two, so it is safe to use packet directly
		*/
		r.pktBuff[0], r.pktBuff[1] = r.pktBuff[1], pkt
	}

	// make sure everything fits before writing anything; redundancy is optional,
	// so give up the oldest blocks first
	need := redPrimaryHeaderSize + len(pkt.Payload)
	for _, p := range redPkts {
		need += redBlockHeaderSize + len(p.Payload)
	}
	for len(redPkts) > 0 && need > len(redPayload) {
		need -= redBlockHeaderSize + len(redPkts[0].Payload)
		redPkts = redPkts[1:]
	}
	if need > len(redPayload) {
		r.logger.Errorw("red payload don't have enough space", nil, "needsize", need, "bufsize", len(redPayload))
		return 0, bucket.ErrBufferTooSmall
	}

	var index int
	for _, p := range redPkts {
		/* RED payload https://datatracker.ietf.org/doc/html/rfc2198#section-3
		    0                   1                   2                   3
		    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
		   +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
		   |F|   block PT  |  timestamp offset         |   block length    |
		   +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
		   F: 1 bit First bit in header indicates whether another header block
		      follows.  If 1 further header blocks follow, if 0 this is the
		      last header block.
		*/
		header := uint32(0x80 | uint8(opusPT))
		header <<= 14
		header |= pkt.Timestamp - p.Timestamp // checked above to fit 14 bits
		header <<= 10
		header |= uint32(len(p.Payload)) // checked above to fit 10 bits
		binary.BigEndian.PutUint32(redPayload[index:], header)
		index += redBlockHeaderSize
	}

	// last block header: F bit cleared, payload type only
	redPayload[index] = uint8(opusPT)
	index += redPrimaryHeaderSize

	// append data blocks in header order, the primary data comes last
	for _, p := range redPkts {
		index += copy(redPayload[index:], p.Payload)
	}
	index += copy(redPayload[index:], pkt.Payload)
	return index, nil
}
+222
View File
@@ -0,0 +1,222 @@
package sfu
import (
"encoding/binary"
"testing"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
)
// dummyDowntrack is a TrackSender test double that records the last RTP packet
// written to it.
type dummyDowntrack struct {
	TrackSender
	// pkt is the packet passed to the most recent WriteRTP call.
	pkt *rtp.Packet
}
// WriteRTP stores the RTP packet of p for later inspection and never fails.
func (dt *dummyDowntrack) WriteRTP(p *buffer.ExtPacket, _ int32) error {
	dt.pkt = p.Packet
	return nil
}
// TestRedReceiver feeds packets into a RedReceiver and checks, by decoding the
// RED payload captured on a dummy down track, that each forwarded packet
// carries the expected redundant and primary blocks.
func TestRedReceiver(t *testing.T) {
	dt := &dummyDowntrack{TrackSender: &DownTrack{}}
	// timestamp increment of one 10ms frame at a 48kHz clock
	tsStep := uint32(48000 / 1000 * 10)
	// consecutive packets: every packet carries up to maxRedCount predecessors
	t.Run("normal", func(t *testing.T) {
		// a receiver that is already RED hands out itself
		w := &WebRTCReceiver{isRED: true, kind: webrtc.RTPCodecTypeAudio}
		require.Equal(t, w.GetRedReceiver(), w)
		w.isRED = false
		red := w.GetRedReceiver().(*RedReceiver)
		require.NotNil(t, red)
		require.NoError(t, red.AddDownTrack(dt))
		// start at 65534 so that the 16-bit sequence number wraps during the run
		header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111}
		expectPkt := make([]*rtp.Packet, 0, maxRedCount+1)
		for i := 0; i < 10; i++ {
			// the marshaled header doubles as a payload that is unique per packet
			hbuf, _ := header.Marshal()
			pkt1 := &rtp.Packet{
				Header:  header,
				Payload: hbuf,
			}
			// sliding window of the packets expected in the RED payload
			expectPkt = append(expectPkt, pkt1)
			if len(expectPkt) > maxRedCount+1 {
				expectPkt = expectPkt[1:]
			}
			red.ForwardRTP(&buffer.ExtPacket{
				Packet: pkt1,
			}, 0)
			verifyRedEncodings(t, dt.pkt, expectPkt)
			header.SequenceNumber++
			header.Timestamp += tsStep
		}
	})
	// every other packet is lost, then the stream jumps ahead
	t.Run("packet lost and jump", func(t *testing.T) {
		w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio}
		red := w.GetRedReceiver().(*RedReceiver)
		require.NoError(t, red.AddDownTrack(dt))
		header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111}
		expectPkt := make([]*rtp.Packet, 0, maxRedCount+1)
		for i := 0; i < 10; i++ {
			if i%2 == 0 {
				// lost packet: advance the header and leave a nil placeholder in the window
				header.SequenceNumber++
				header.Timestamp += tsStep
				expectPkt = append(expectPkt, nil)
				continue
			}
			hbuf, _ := header.Marshal()
			pkt1 := &rtp.Packet{
				Header:  header,
				Payload: hbuf,
			}
			expectPkt = append(expectPkt, pkt1)
			if len(expectPkt) > maxRedCount+1 {
				expectPkt = expectPkt[len(expectPkt)-maxRedCount-1:]
			}
			red.ForwardRTP(&buffer.ExtPacket{
				Packet: pkt1,
			}, 0)
			verifyRedEncodings(t, dt.pkt, expectPkt)
			header.SequenceNumber++
			header.Timestamp += tsStep
		}
		// jump
		header.SequenceNumber += 10
		header.Timestamp += 10 * tsStep
		// packets from before the jump must not show up as redundancy
		expectPkt = expectPkt[:0]
		for i := 0; i < 3; i++ {
			hbuf, _ := header.Marshal()
			pkt1 := &rtp.Packet{
				Header:  header,
				Payload: hbuf,
			}
			expectPkt = append(expectPkt, pkt1)
			if len(expectPkt) > maxRedCount+1 {
				expectPkt = expectPkt[len(expectPkt)-maxRedCount-1:]
			}
			red.ForwardRTP(&buffer.ExtPacket{
				Packet: pkt1,
			}, 0)
			verifyRedEncodings(t, dt.pkt, expectPkt)
			header.SequenceNumber++
			header.Timestamp += tsStep
		}
	})
	// an old packet arriving late, and a repetition of the newest packet
	t.Run("unorder and repeat", func(t *testing.T) {
		w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio}
		red := w.GetRedReceiver().(*RedReceiver)
		require.NoError(t, red.AddDownTrack(dt))
		header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111}
		var prevPkts []*rtp.Packet
		// forward ten consecutive packets first and remember them
		for i := 0; i < 10; i++ {
			hbuf, _ := header.Marshal()
			pkt1 := &rtp.Packet{
				Header:  header,
				Payload: hbuf,
			}
			red.ForwardRTP(&buffer.ExtPacket{
				Packet: pkt1,
			}, 0)
			header.SequenceNumber++
			header.Timestamp += tsStep
			prevPkts = append(prevPkts, pkt1)
		}
		// old out-of-order data doesn't have red records
		expectPkt := prevPkts[len(prevPkts)-3 : len(prevPkts)-2]
		red.ForwardRTP(&buffer.ExtPacket{
			Packet: expectPkt[0],
		}, 0)
		verifyRedEncodings(t, dt.pkt, expectPkt)
		// a repeated packet only has 1 red record
		expectPkt = prevPkts[len(prevPkts)-2:]
		red.ForwardRTP(&buffer.ExtPacket{
			Packet: expectPkt[1],
		}, 0)
		verifyRedEncodings(t, dt.pkt, expectPkt)
	})
}
// verifyRedEncodings decodes the RED packet red and checks that it contains
// exactly the non-nil packets of redPkts, in order. nil entries of redPkts
// stand for lost packets and are ignored.
func verifyRedEncodings(t *testing.T, red *rtp.Packet, redPkts []*rtp.Packet) {
	// report failures at the line of the calling test case
	t.Helper()

	// nothing was forwarded to the down track: fail instead of dereferencing nil
	require.NotNil(t, red)

	solidPkts := make([]*rtp.Packet, 0, len(redPkts))
	for _, pkt := range redPkts {
		if pkt != nil {
			solidPkts = append(solidPkts, pkt)
		}
	}

	pktsFromRed, err := extractPktsFromRed(red)
	require.NoError(t, err)
	require.Len(t, pktsFromRed, len(solidPkts))
	for i, pkt := range pktsFromRed {
		verifyEncodingEqual(t, pkt, solidPkts[i])
	}
}
// verifyEncodingEqual checks that p1, a packet decoded from a RED payload, has
// the timestamp, payload type and payload of the expected packet p2.
func verifyEncodingEqual(t *testing.T, p1, p2 *rtp.Packet) {
	// report failures at the line of the caller
	t.Helper()

	// require takes the expected value first, so failure messages label p2 as
	// expected and p1 as actual
	require.Equal(t, p2.Header.Timestamp, p1.Header.Timestamp)
	require.Equal(t, p2.PayloadType, p1.PayloadType)
	require.EqualValues(t, p2.Payload, p1.Payload)
}
// block is one parsed RED block header.
type block struct {
	// tsOffset is the timestamp offset of the block relative to the RED packet.
	tsOffset uint32
	// length is the size in bytes of the block data.
	length int
	// pt is the payload type of the block.
	pt uint8
}
// extractPktsFromRed splits the RED payload (RFC 2198) of redPkt into one RTP
// packet per block, in the order the blocks appear: redundant blocks first, the
// primary block last.
//
// Redundant packets get the header of redPkt with the timestamp reduced by the
// block's timestamp offset and the payload type of the block; the primary
// packet keeps the header of redPkt unchanged. It returns
// ErrIncompleteRedHeader if the payload ends inside the block headers and
// ErrIncompleteRedBlock if it is shorter than the announced block data.
func extractPktsFromRed(redPkt *rtp.Packet) ([]*rtp.Packet, error) {
	payload := redPkt.Payload
	var blocks []block
	var blockLength int
	for {
		if len(payload) < 1 {
			// the payload ended before the header of the primary block
			return nil, ErrIncompleteRedHeader
		}
		if payload[0]&0x80 == 0 {
			// F bit cleared: one byte header of the last block, the primary encoding data
			payload = payload[1:]
			blocks = append(blocks, block{})
			break
		}
		if len(payload) < 4 {
			// illegal data
			return nil, ErrIncompleteRedHeader
		}
		// |F| block PT (7) | timestamp offset (14) | block length (10) |
		blockHead := binary.BigEndian.Uint32(payload)
		length := int(blockHead & 0x03FF)
		blockHead >>= 10
		tsOffset := blockHead & 0x3FFF
		blockHead >>= 14
		pt := uint8(blockHead & 0x7F)
		payload = payload[4:]
		blockLength += length
		blocks = append(blocks, block{pt: pt, length: length, tsOffset: tsOffset})
	}

	if len(payload) < blockLength {
		return nil, ErrIncompleteRedBlock
	}

	pkts := make([]*rtp.Packet, 0, len(blocks))
	primary := len(blocks) - 1
	for i, b := range blocks {
		// the primary block is identified by its position, not by its timestamp
		// offset: a redundant block may legally have an offset of 0
		if i == primary {
			// the primary data takes whatever is left of the payload
			pkts = append(pkts, &rtp.Packet{Header: redPkt.Header, Payload: payload})
			break
		}
		header := redPkt.Header
		header.Timestamp -= b.tsOffset
		header.PayloadType = b.pt
		pkts = append(pkts, &rtp.Packet{Header: header, Payload: payload[:b.length]})
		payload = payload[b.length:]
	}
	return pkts, nil
}