diff --git a/pkg/sfu/redprimaryreceiver.go b/pkg/sfu/redprimaryreceiver.go index 9469b5cd4..9df8d26ea 100644 --- a/pkg/sfu/redprimaryreceiver.go +++ b/pkg/sfu/redprimaryreceiver.go @@ -27,6 +27,12 @@ type RedPrimaryReceiver struct { downTrackSpreader *DownTrackSpreader logger logger.Logger closed atomic.Bool + + firstPktReceived bool + lastSeq uint16 + + // bitset for upstream packet receive history [lastSeq-8, lastSeq-1], bit 1 represents packet received + pktHistory byte } func NewRedPrimaryReceiver(receiver TrackReceiver, dsp DownTrackSpreaderParams) *RedPrimaryReceiver { @@ -42,25 +48,23 @@ func (r *RedPrimaryReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int3 if r.downTrackSpreader.DownTrackCount() == 0 { return } - payload, err := ExtractPrimaryEncodingForRED(pkt.Packet.Payload) + + pkts, err := r.getSendPktsFromRed(pkt.Packet) if err != nil { - r.logger.Errorw("get primary encoding for red failed", err, "payloadtype", pkt.Packet.PayloadType) + r.logger.Errorw("get encoding for red failed", err, "payloadtype", pkt.Packet.PayloadType) return } - pPkt := *pkt - primaryRtpPacket := *pkt.Packet - primaryRtpPacket.Payload = payload - pPkt.Packet = &primaryRtpPacket + for _, sendPkt := range pkts { + pPkt := *pkt + pPkt.Packet = sendPkt - // not modify the ExtPacket.RawPacket here for performance since it is not used by the DownTrack, - // otherwise it should be set to the correct value (marshal the primary rtp packet) - - r.downTrackSpreader.Broadcast(func(dt TrackSender) { - _ = dt.WriteRTP(&pPkt, spatialLayer) - }) - - // TODO : detect rtp packet lost, recover it from the redundant payload then send to downstreams. + // not modify the ExtPacket.RawPacket here for performance since it is not used by the DownTrack, + // otherwise it should be set to the correct value (marshal the primary rtp packet) + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + _ = dt.WriteRTP(&pPkt, spatialLayer) + }) + } } func (r *RedPrimaryReceiver) AddDownTrack(track TrackSender) error { @@ -99,7 +103,7 @@ func (r *RedPrimaryReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, e var pkt rtp.Packet pkt.Unmarshal(buf[:n]) - payload, err := ExtractPrimaryEncodingForRED(pkt.Payload) + payload, err := extractPrimaryEncodingForRED(pkt.Payload) if err != nil { return 0, err } @@ -108,7 +112,117 @@ func (r *RedPrimaryReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, e return pkt.MarshalTo(buf) } -func ExtractPrimaryEncodingForRED(payload []byte) ([]byte, error) { +func (r *RedPrimaryReceiver) getSendPktsFromRed(rtp *rtp.Packet) ([]*rtp.Packet, error) { + var needRecover bool + if !r.firstPktReceived { + r.lastSeq = rtp.SequenceNumber + r.pktHistory = 0 + r.firstPktReceived = true + } else { + diff := rtp.SequenceNumber - r.lastSeq + switch { + case diff == 0: // duplicate + break + case diff > 0x8000: // unorder + // in history + if 65535-diff < 8 { + r.pktHistory |= 1 << (65535 - diff) + needRecover = true + } + + case diff > 8: // long jump + r.lastSeq = rtp.SequenceNumber + r.pktHistory = 0 + needRecover = true + + default: + r.lastSeq = rtp.SequenceNumber + r.pktHistory = (r.pktHistory << byte(diff)) | 1<<(diff-1) + needRecover = true + } + } + + var recoverBits byte + if needRecover { + bitIndex := r.lastSeq - rtp.SequenceNumber + for i := 0; i < maxRedCount; i++ { + if bitIndex > 7 { + break + } + if r.pktHistory&byte(1<>= 10 + tsOffset := blockHead & 0x3FFF + blockHead >>= 14 + pt := uint8(blockHead & 0x7F) + payload = payload[4:] + blockLength += length + blocks = append(blocks, block{pt: pt, length: length, tsOffset: tsOffset}) + } + } + + if len(payload) < blockLength { + return nil, ErrIncompleteRedBlock + } + + pkts := make([]*rtp.Packet, 0, len(blocks)) + for i, b := range blocks { + if b.primary { + pkts = append(pkts, &rtp.Packet{Header: redPkt.Header, Payload: payload}) + break + } + + // last block is primary encoding + recoverIndex := len(blocks) - i - 1 + if recoverIndex < 1 || recoverBits&(1<<(recoverIndex-1)) == 0 { + payload = payload[b.length:] + continue + } + + header := redPkt.Header + header.SequenceNumber -= uint16(recoverIndex) + header.Timestamp -= b.tsOffset + header.PayloadType = b.pt + pkts = append(pkts, &rtp.Packet{Header: header, Payload: payload[:b.length]}) + payload = payload[b.length:] + } + + return pkts, nil +} + +func extractPrimaryEncodingForRED(payload []byte) ([]byte, error) { /* RED payload https://datatracker.ietf.org/doc/html/rfc2198#section-3 0 1 2 3 diff --git a/pkg/sfu/redreceiver.go b/pkg/sfu/redreceiver.go index 81e84bad7..7cc8ad08a 100644 --- a/pkg/sfu/redreceiver.go +++ b/pkg/sfu/redreceiver.go @@ -2,6 +2,7 @@ package sfu import ( "encoding/binary" + "fmt" "go.uber.org/atomic" @@ -98,8 +99,17 @@ func (r *RedReceiver) ReadRTP(buf []byte, layer uint8, sn uint16) (int, error) { func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (int, error) { redPkts := make([]*rtp.Packet, 0, maxRedCount+1) - for _, prev := range r.pktBuff { - if prev == nil || pkt.SequenceNumber == prev.SequenceNumber || + lastNilPkt := -1 + for i := len(r.pktBuff) - 1; i >= 0; i-- { + if r.pktBuff[i] == nil { + lastNilPkt = i + break + } + + } + + for _, prev := range r.pktBuff[lastNilPkt+1:] { + if pkt.SequenceNumber == prev.SequenceNumber || (pkt.SequenceNumber-prev.SequenceNumber) > uint16(maxRedCount) { continue } @@ -113,6 +123,10 @@ func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (i r.pktBuff[0], r.pktBuff[1] = r.pktBuff[1], pkt } + return encodeRedForPrimary(redPkts, pkt, redPayload) +} + +func encodeRedForPrimary(redPkts []*rtp.Packet, primary *rtp.Packet, redPayload []byte) (int, error) { var index int for _, p := range redPkts { /* RED payload https://datatracker.ietf.org/doc/html/rfc2198#section-3 @@ -127,7 +141,7 @@ func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (i */ header := uint32(0x80 | uint8(opusPT)) header <<= 14 - header |= (pkt.Timestamp - p.Timestamp) & 0x3FFF + header |= (primary.Timestamp - p.Timestamp) & 0x3FFF header <<= 10 header |= uint32(len(p.Payload)) & 0x3FF binary.BigEndian.PutUint32(redPayload[index:], header) @@ -138,11 +152,10 @@ func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (i index++ // append data blocks - redPkts = append(redPkts, pkt) + redPkts = append(redPkts, primary) for _, p := range redPkts { if copy(redPayload[index:], p.Payload) < len(p.Payload) { - r.logger.Errorw("red payload don't have enough space", nil, "needsize", p.Payload) - return 0, bucket.ErrBufferTooSmall + return 0, fmt.Errorf("red payload don't have enough space, needsize %d", len(p.Payload)) } index += len(p.Payload) } diff --git a/pkg/sfu/redreceiver_test.go b/pkg/sfu/redreceiver_test.go index 003f1feb2..3bb633e63 100644 --- a/pkg/sfu/redreceiver_test.go +++ b/pkg/sfu/redreceiver_test.go @@ -1,7 +1,6 @@ package sfu import ( - "encoding/binary" "testing" "github.com/pion/rtp" @@ -11,20 +10,22 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/buffer" ) +const tsStep = uint32(48000 / 1000 * 10) + type dummyDowntrack struct { TrackSender - pkt *rtp.Packet + lastReceivedPkt *rtp.Packet + receivedPkts []*rtp.Packet } func (dt *dummyDowntrack) WriteRTP(p *buffer.ExtPacket, _ int32) error { - dt.pkt = p.Packet + dt.lastReceivedPkt = p.Packet + dt.receivedPkts = append(dt.receivedPkts, p.Packet) return nil } func TestRedReceiver(t *testing.T) { - dt := &dummyDowntrack{TrackSender: &DownTrack{}} - tsStep := uint32(48000 / 1000 * 10) t.Run("normal", func(t *testing.T) { w := &WebRTCReceiver{isRED: true, kind: webrtc.RTPCodecTypeAudio} @@ -36,22 +37,15 @@ func TestRedReceiver(t *testing.T) { header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} expectPkt := make([]*rtp.Packet, 0, maxRedCount+1) - for i := 0; i < 10; i++ { - hbuf, _ := header.Marshal() - pkt1 := &rtp.Packet{ - Header: header, - Payload: hbuf, - } - expectPkt = append(expectPkt, pkt1) + for _, pkt := range generatePkts(header, 10, tsStep) { + expectPkt = append(expectPkt, pkt) if len(expectPkt) > maxRedCount+1 { expectPkt = expectPkt[1:] } red.ForwardRTP(&buffer.ExtPacket{ - Packet: pkt1, + Packet: pkt, }, 0) - verifyRedEncodings(t, dt.pkt, expectPkt) - header.SequenceNumber++ - header.Timestamp += tsStep + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) } }) @@ -81,7 +75,7 @@ func TestRedReceiver(t *testing.T) { red.ForwardRTP(&buffer.ExtPacket{ Packet: pkt1, }, 0) - verifyRedEncodings(t, dt.pkt, expectPkt) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) header.SequenceNumber++ header.Timestamp += tsStep } @@ -91,22 +85,15 @@ func TestRedReceiver(t *testing.T) { header.Timestamp += 10 * tsStep expectPkt = expectPkt[:0] - for i := 0; i < 3; i++ { - hbuf, _ := header.Marshal() - pkt1 := &rtp.Packet{ - Header: header, - Payload: hbuf, - } - expectPkt = append(expectPkt, pkt1) + for _, pkt := range generatePkts(header, 3, tsStep) { + expectPkt = append(expectPkt, pkt) if len(expectPkt) > maxRedCount+1 { expectPkt = expectPkt[len(expectPkt)-maxRedCount-1:] } red.ForwardRTP(&buffer.ExtPacket{ - Packet: pkt1, + Packet: pkt, }, 0) - verifyRedEncodings(t, dt.pkt, expectPkt) - header.SequenceNumber++ - header.Timestamp += tsStep + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) } }) @@ -117,18 +104,11 @@ func TestRedReceiver(t *testing.T) { header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} var prevPkts []*rtp.Packet - for i := 0; i < 10; i++ { - hbuf, _ := header.Marshal() - pkt1 := &rtp.Packet{ - Header: header, - Payload: hbuf, - } + for _, pkt := range generatePkts(header, 10, tsStep) { red.ForwardRTP(&buffer.ExtPacket{ - Packet: pkt1, + Packet: pkt, }, 0) - header.SequenceNumber++ - header.Timestamp += tsStep - prevPkts = append(prevPkts, pkt1) + prevPkts = append(prevPkts, pkt) } // old unorder data don't have red records @@ -136,14 +116,14 @@ func TestRedReceiver(t *testing.T) { red.ForwardRTP(&buffer.ExtPacket{ Packet: expectPkt[0], }, 0) - verifyRedEncodings(t, dt.pkt, expectPkt) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) // repeat packet only have 1 red records expectPkt = prevPkts[len(prevPkts)-2:] red.ForwardRTP(&buffer.ExtPacket{ Packet: expectPkt[1], }, 0) - verifyRedEncodings(t, dt.pkt, expectPkt) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) }) } @@ -154,7 +134,7 @@ func verifyRedEncodings(t *testing.T, red *rtp.Packet, redPkts []*rtp.Packet) { solidPkts = append(solidPkts, pkt) } } - pktsFromRed, err := extractPktsFromRed(red) + pktsFromRed, err := extractPktsFromRed(red, 0xFF) require.NoError(t, err) require.Len(t, pktsFromRed, len(solidPkts)) for i, pkt := range pktsFromRed { @@ -162,61 +142,162 @@ func verifyRedEncodings(t *testing.T, red *rtp.Packet, redPkts []*rtp.Packet) { } } +func verifyPktsEqual(t *testing.T, p1s, p2s []*rtp.Packet) { + require.Len(t, p1s, len(p2s)) + for i, pkt := range p1s { + verifyEncodingEqual(t, pkt, p2s[i]) + } +} + func verifyEncodingEqual(t *testing.T, p1, p2 *rtp.Packet) { require.Equal(t, p1.Header.Timestamp, p2.Header.Timestamp) require.Equal(t, p1.PayloadType, p2.PayloadType) - require.EqualValues(t, p1.Payload, p2.Payload) + require.EqualValues(t, p1.Payload, p2.Payload, "seq1 %s", p1.SequenceNumber) } -type block struct { - tsOffset uint32 - length int - pt uint8 +func generatePkts(header rtp.Header, count int, tsStep uint32) []*rtp.Packet { + pkts := make([]*rtp.Packet, 0, count) + for i := 0; i < count; i++ { + hbuf, _ := header.Marshal() + pkts = append(pkts, &rtp.Packet{ + Header: header, + Payload: hbuf, + }) + header.SequenceNumber++ + header.Timestamp += tsStep + } + return pkts } -func extractPktsFromRed(redPkt *rtp.Packet) ([]*rtp.Packet, error) { - payload := redPkt.Payload - var blocks []block - var blockLength int - for { - if payload[0]&0x80 == 0 { - // last block is primary encoding data - payload = payload[1:] - blocks = append(blocks, block{}) - break - } else { - if len(payload) < 4 { - // illegal data - return nil, ErrIncompleteRedHeader +func generateRedPkts(t *testing.T, pkts []*rtp.Packet, redCount int) []*rtp.Packet { + redPkts := make([]*rtp.Packet, 0, len(pkts)) + for i, pkt := range pkts { + encodingPkts := make([]*rtp.Packet, 0, redCount) + for j := i - redCount; j < i; j++ { + if j < 0 { + continue } - blockHead := binary.BigEndian.Uint32(payload[0:]) - length := int(blockHead & 0x03FF) - blockHead >>= 10 - tsOffset := blockHead & 0x3FFF - blockHead >>= 14 - pt := uint8(blockHead & 0x7F) - payload = payload[4:] - blockLength += length - blocks = append(blocks, block{pt: pt, length: length, tsOffset: tsOffset}) + encodingPkts = append(encodingPkts, pkts[j]) } + buf := make([]byte, mtuSize) + redPkt := *pkt + encoded, err := encodeRedForPrimary(encodingPkts, pkt, buf) + require.NoError(t, err) + redPkt.Payload = buf[:encoded] + redPkts = append(redPkts, &redPkt) } - - if len(payload) < blockLength { - return nil, ErrIncompleteRedBlock - } - - pkts := make([]*rtp.Packet, 0, len(blocks)) - for _, b := range blocks { - if b.tsOffset == 0 { - pkts = append(pkts, &rtp.Packet{Header: redPkt.Header, Payload: payload}) - break - } - header := redPkt.Header - header.Timestamp -= b.tsOffset - header.PayloadType = b.pt - pkts = append(pkts, &rtp.Packet{Header: header, Payload: payload[:b.length]}) - payload = payload[b.length:] - } - - return pkts, nil + return redPkts +} + +func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktIdx, expectPktIdx []int) { + dt := &dummyDowntrack{TrackSender: &DownTrack{}} + w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + require.Equal(t, w.GetPrimaryReceiverForRed(), w) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + primaryPkts := generatePkts(header, maxPktCount, tsStep) + redPkts := generateRedPkts(t, primaryPkts, redCount) + + for _, i := range sendPktIdx { + red.ForwardRTP(&buffer.ExtPacket{ + Packet: redPkts[i], + }, 0) + } + + expectPkts := make([]*rtp.Packet, 0, len(expectPktIdx)) + for _, i := range expectPktIdx { + expectPkts = append(expectPkts, primaryPkts[i]) + } + + verifyPktsEqual(t, expectPkts, dt.receivedPkts) +} + +func TestRedPrimaryReceiver(t *testing.T) { + w := &WebRTCReceiver{kind: webrtc.RTPCodecTypeAudio} + require.Equal(t, w.GetPrimaryReceiverForRed(), w) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + + t.Run("packet should send only once", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex []int + for i := 0; i < maxPktCount; i++ { + sendPktIndex = append(sendPktIndex, i) + } + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, sendPktIndex) + }) + + t.Run("packet duplicate and unorder", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex []int + for i := 0; i < maxPktCount; i++ { + sendPktIndex = append(sendPktIndex, i) + if i > 0 { + sendPktIndex = append(sendPktIndex, i-1) + } + sendPktIndex = append(sendPktIndex, i) + } + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, sendPktIndex) + }) + + t.Run("full recover", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex, recvPktIndex []int + for i := 0; i < maxPktCount; i++ { + recvPktIndex = append(recvPktIndex, i) + + // drop packets covered by red encoding + if i%(maxRedCount+1) != 0 { + continue + } + sendPktIndex = append(sendPktIndex, i) + } + + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) + + t.Run("lost 2 but red recover 1", func(t *testing.T) { + maxPktCount := 19 + sendPktIndex := []int{0, 3, 6, 9, 12} + recvPktIndex := []int{0, 2, 3, 5, 6, 8, 9, 11, 12} + testRedRedPrimaryReceiver(t, maxPktCount, 1, sendPktIndex, recvPktIndex) + }) + + t.Run("part recover and long jump", func(t *testing.T) { + maxPktCount := 50 + sendPktIndex := []int{0, 5, 12, 21 /*long jump*/, 24, 27} + recvPktIndex := []int{0, 3, 4, 5, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 27} + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) + + t.Run("unorder", func(t *testing.T) { + maxPktCount := 50 + sendPktIndex := []int{20, 10 /*unorder can't recover*/, 25, 23, 34} + recvPktIndex := []int{20, 10, 23, 24, 25, 21, 22, 23, 32, 33, 34} + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) +} + +func TestExtractPrimaryEncodingForRED(t *testing.T) { + header := rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + pkts := generatePkts(header, 10, tsStep) + redPkts := generateRedPkts(t, pkts, 2) + + primaryPkts := make([]*rtp.Packet, 0, len(redPkts)) + + for _, redPkt := range redPkts { + payload, err := extractPrimaryEncodingForRED(redPkt.Payload) + require.NoError(t, err) + primaryPkts = append(primaryPkts, &rtp.Packet{ + Header: redPkt.Header, + Payload: payload, + }) + } + + verifyPktsEqual(t, pkts, primaryPkts) }