diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index c7159d810..312689133 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -3,10 +3,12 @@ package sfu import ( "fmt" "math" + "math/rand" "strings" "sync" "time" + "github.com/pion/rtp" "github.com/pion/webrtc/v3" "github.com/livekit/protocol/logger" @@ -136,9 +138,12 @@ type TranslationParams struct { // ------------------------------------------------------------------- type ForwarderState struct { - Started bool - RTP RTPMungerState - Codec interface{} + Started bool + PreStartTime time.Time + FirstTS uint32 + RefTSOffset uint32 + RTP RTPMungerState + Codec interface{} } func (f ForwarderState) String() string { @@ -147,7 +152,14 @@ func (f ForwarderState) String() string { case codecmunger.VP8State: codecString = codecState.String() } - return fmt.Sprintf("ForwarderState{started: %v, rtp: %s, codec: %s}", f.Started, f.RTP.String(), codecString) + return fmt.Sprintf("ForwarderState{started: %v, preStartTime: %s, firstTS: %d, refTSOffset: %d, rtp: %s, codec: %s}", + f.Started, + f.PreStartTime.String(), + f.FirstTS, + f.RefTSOffset, + f.RTP.String(), + codecString, + ) } // ------------------------------------------------------------------- @@ -164,8 +176,11 @@ type Forwarder struct { pubMuted bool started bool + preStartTime time.Time + firstTS uint32 lastSSRC uint32 referenceLayerSpatial int32 + refTSOffset uint32 parkedLayerTimer *time.Timer @@ -313,13 +328,14 @@ func (f *Forwarder) GetState() ForwarderState { return ForwarderState{} } - state := ForwarderState{ - Started: f.started, - RTP: f.rtpMunger.GetLast(), - Codec: f.codecMunger.GetState(), + return ForwarderState{ + Started: f.started, + PreStartTime: f.preStartTime, + FirstTS: f.firstTS, + RefTSOffset: f.refTSOffset, + RTP: f.rtpMunger.GetLast(), + Codec: f.codecMunger.GetState(), } - - return state } func (f *Forwarder) SeedState(state ForwarderState) { @@ -334,6 +350,9 @@ func (f *Forwarder) SeedState(state ForwarderState) { f.codecMunger.SeedState(state.Codec) f.started = true + f.preStartTime = state.PreStartTime + f.firstTS = state.FirstTS + f.refTSOffset = state.RefTSOffset } func (f *Forwarder) Mute(muted bool) (bool, buffer.VideoLayer) { @@ -1467,8 +1486,17 @@ func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer i ts, err := f.getExpectedRTPTimestamp(switchingAt) if err == nil { expectedTS = ts + } else { + rtpDiff := uint32(0) + if !f.preStartTime.IsZero() && f.refTSOffset == 0 { + timeSinceFirst := time.Since(f.preStartTime) + rtpDiff = uint32(timeSinceFirst.Nanoseconds() * int64(f.codec.ClockRate) / 1e9) + f.refTSOffset = f.firstTS + rtpDiff - refTS + } + expectedTS += rtpDiff } } + refTS += f.refTSOffset nextTS, explain := getNextTimestamp(lastTS, refTS, expectedTS) f.logger.Debugw( "next timestamp on switch", @@ -1588,12 +1616,35 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in return tp, nil } +func (f *Forwarder) maybeStart() { + if f.started { + return + } + + f.started = true + f.preStartTime = time.Now() + + extPkt := &buffer.ExtPacket{ + Packet: &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: uint16(rand.Intn(1<<14)) + uint16(1<<15), // a random number in third quartile of sequence number space + Timestamp: uint32(rand.Intn(1<<30)) + uint32(1<<31), // a random number in third quartile of time stamp space + }, + }, + } + f.rtpMunger.SetLastSnTs(extPkt) + + f.firstTS = extPkt.Packet.Timestamp +} + func (f *Forwarder) GetSnTsForPadding(num int) ([]SnTs, error) { f.lock.Lock() defer f.lock.Unlock() - // padding is used for probing. Padding packets should be - // at only the frame boundaries to ensure decoder sequencer does + f.maybeStart() + + // padding is used for probing. Padding packets should only + // be at frame boundaries to ensure decoder sequencer does // not get out-of-sync. But, when a stream is paused, // force a frame marker as a restart of the stream will // start with a key frame which will reset the decoder. @@ -1601,18 +1652,30 @@ func (f *Forwarder) GetSnTsForPadding(num int) ([]SnTs, error) { if !f.vls.GetTarget().IsValid() { forceMarker = true } - return f.rtpMunger.UpdateAndGetPaddingSnTs(num, 0, 0, forceMarker) + return f.rtpMunger.UpdateAndGetPaddingSnTs(num, 0, 0, forceMarker, 0) } func (f *Forwarder) GetSnTsForBlankFrames(frameRate uint32, numPackets int) ([]SnTs, bool, error) { f.lock.Lock() defer f.lock.Unlock() + f.maybeStart() + frameEndNeeded := !f.rtpMunger.IsOnFrameBoundary() if frameEndNeeded { numPackets++ } - snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.codec.ClockRate, frameRate, frameEndNeeded) + + lastTS := f.rtpMunger.GetLast().LastTS + expectedTS := lastTS + if f.getExpectedRTPTimestamp != nil { + ts, err := f.getExpectedRTPTimestamp(time.Now()) + if err == nil { + expectedTS = ts + } + } + nextTS, _ := getNextTimestamp(lastTS, expectedTS, expectedTS) + snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs(numPackets, f.codec.ClockRate, frameRate, frameEndNeeded, nextTS) return snts, frameEndNeeded, err } diff --git a/pkg/sfu/forwarder_test.go b/pkg/sfu/forwarder_test.go index c99374ded..df03a058a 100644 --- a/pkg/sfu/forwarder_test.go +++ b/pkg/sfu/forwarder_test.go @@ -1844,9 +1844,15 @@ func TestForwardGetSnTsForBlankFrames(t *testing.T) { frameRate := uint32(30) var sntsExpected = make([]SnTs, numPadding) for i := 0; i < numPadding; i++ { + // first blank frame should have same timestamp as last frame as end frame is synthesized + ts := params.Timestamp + if i != 0 { + // +1 here due to expected time stamp bumpint by at least one so that time stamp is always moving ahead + ts = params.Timestamp + 1 + ((uint32(i)*clockRate)+frameRate-1)/frameRate + } sntsExpected[i] = SnTs{ sequenceNumber: params.SequenceNumber + uint16(i) + 1, - timestamp: params.Timestamp + (uint32(i)*clockRate)/frameRate, + timestamp: ts, } } require.Equal(t, sntsExpected, snts) @@ -1858,7 +1864,8 @@ func TestForwardGetSnTsForBlankFrames(t *testing.T) { for i := 0; i < numPadding; i++ { sntsExpected[i] = SnTs{ sequenceNumber: params.SequenceNumber + uint16(len(snts)) + uint16(i) + 1, - timestamp: snts[len(snts)-1].timestamp + (uint32(i+1)*clockRate)/frameRate, + // +1 here due to expected time stamp bumpint by at least one so that time stamp is always moving ahead + timestamp: snts[len(snts)-1].timestamp + 1 + ((uint32(i+1)*clockRate)+frameRate-1)/frameRate, } } snts, frameEndNeeded, err = f.GetSnTsForBlankFrames(30, numBlankFrames) diff --git a/pkg/sfu/rtpmunger.go b/pkg/sfu/rtpmunger.go index 7880dfee7..a80ad6000 100644 --- a/pkg/sfu/rtpmunger.go +++ b/pkg/sfu/rtpmunger.go @@ -2,7 +2,6 @@ package sfu import ( "fmt" - "math/rand" "github.com/livekit/protocol/logger" @@ -42,19 +41,17 @@ type SnTs struct { // ---------------------------------------------------------------------- type RTPMungerState struct { - Started bool - LastSN uint16 - LastTS uint32 + LastSN uint16 + LastTS uint32 } func (r RTPMungerState) String() string { - return fmt.Sprintf("RTPMungerState{started: %v, lastSN: %d, lastTS: %d)", r.Started, r.LastSN, r.LastTS) + return fmt.Sprintf("RTPMungerState{lastSN: %d, lastTS: %d)", r.LastSN, r.LastTS) } // ---------------------------------------------------------------------- type RTPMungerParams struct { - started bool highestIncomingSN uint16 lastSN uint16 snOffset uint16 @@ -86,7 +83,6 @@ func NewRTPMunger(logger logger.Logger) *RTPMunger { func (r *RTPMunger) GetParams() RTPMungerParams { return RTPMungerParams{ - started: r.started, highestIncomingSN: r.highestIncomingSN, lastSN: r.lastSN, snOffset: r.snOffset, @@ -99,14 +95,12 @@ func (r *RTPMunger) GetParams() RTPMungerParams { func (r *RTPMunger) GetLast() RTPMungerState { return RTPMungerState{ - Started: r.started, - LastSN: r.lastSN, - LastTS: r.lastTS, + LastSN: r.lastSN, + LastTS: r.lastTS, } } func (r *RTPMunger) SeedLast(state RTPMungerState) { - r.started = state.Started r.lastSN = state.LastSN r.lastTS = state.LastTS } @@ -114,14 +108,8 @@ func (r *RTPMunger) SeedLast(state RTPMungerState) { func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) { r.highestIncomingSN = extPkt.Packet.SequenceNumber - 1 r.highestIncomingTS = extPkt.Packet.Timestamp - if !r.started { - r.lastSN = extPkt.Packet.SequenceNumber - r.lastTS = extPkt.Packet.Timestamp - } else { - r.snOffset = extPkt.Packet.SequenceNumber - r.lastSN - 1 - r.tsOffset = extPkt.Packet.Timestamp - r.lastTS - 1 - } - r.started = true + r.lastSN = extPkt.Packet.SequenceNumber + r.lastTS = extPkt.Packet.Timestamp } func (r *RTPMunger) UpdateSnTsOffsets(extPkt *buffer.ExtPacket, snAdjust uint16, tsAdjust uint32) { @@ -270,7 +258,8 @@ func (r *RTPMunger) FilterRTX(nacks []uint16) []uint16 { return filtered } -func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate uint32, forceMarker bool) ([]SnTs, error) { +func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate uint32, forceMarker bool, rtpTimestamp uint32) ([]SnTs, error) { + useLastTSForFirst := false tsOffset := 0 if !r.lastMarker { if !forceMarker { @@ -278,20 +267,25 @@ func (r *RTPMunger) UpdateAndGetPaddingSnTs(num int, clockRate uint32, frameRate } // if forcing frame end, use timestamp of latest received frame for the first one + useLastTSForFirst = true tsOffset = 1 } - if !r.started { - r.lastSN = uint16(rand.Intn(1<<14)) + uint16(1<<15) // a random number in third quartile of sequence number space - r.lastTS = uint32(rand.Intn(1<<30)) + uint32(1<<31) // a random number in third quartile of time stamp space - r.started = true - } - + lastTS := r.lastTS vals := make([]SnTs, num) for i := 0; i < num; i++ { vals[i].sequenceNumber = r.lastSN + uint16(i) + 1 if frameRate != 0 { - vals[i].timestamp = r.lastTS + uint32(i+1-tsOffset)*(clockRate/frameRate) + if useLastTSForFirst && i == 0 { + vals[i].timestamp = r.lastTS + } else { + ts := rtpTimestamp + ((uint32(i+1-tsOffset)*clockRate)+frameRate-1)/frameRate + if (ts-lastTS) == 0 || (ts-lastTS) > (1<<31) { + ts = lastTS + 1 + lastTS = ts + } + vals[i].timestamp = ts + } } else { vals[i].timestamp = r.lastTS } diff --git a/pkg/sfu/rtpmunger_test.go b/pkg/sfu/rtpmunger_test.go index e1e772023..68ea60716 100644 --- a/pkg/sfu/rtpmunger_test.go +++ b/pkg/sfu/rtpmunger_test.go @@ -28,49 +28,11 @@ func TestSetLastSnTs(t *testing.T) { r.SetLastSnTs(extPkt) require.Equal(t, uint16(23332), r.highestIncomingSN) + require.Equal(t, uint32(0xabcdef), r.highestIncomingTS) require.Equal(t, uint16(23333), r.lastSN) require.Equal(t, uint32(0xabcdef), r.lastTS) require.Equal(t, uint16(0), r.snOffset) require.Equal(t, uint32(0), r.tsOffset) - require.True(t, r.started) - - // force re-start - r.started = false - - params = &testutils.TestExtPacketParams{ - SequenceNumber: 43, - Timestamp: 0xabcdef, - SSRC: 0x12345678, - } - extPkt, err = testutils.GetTestExtPacket(params) - require.NoError(t, err) - require.NotNil(t, extPkt) - - r.SetLastSnTs(extPkt) - require.Equal(t, uint16(42), r.highestIncomingSN) - require.Equal(t, uint16(43), r.lastSN) - require.Equal(t, uint32(0xabcdef), r.lastTS) - require.Equal(t, uint16(0), r.snOffset) - require.Equal(t, uint32(0), r.tsOffset) - require.True(t, r.started) - - // set on a started munger - params = &testutils.TestExtPacketParams{ - SequenceNumber: 23457, - Timestamp: 0xabcdef, - SSRC: 0x12345678, - } - extPkt, err = testutils.GetTestExtPacket(params) - require.NoError(t, err) - require.NotNil(t, extPkt) - - r.SetLastSnTs(extPkt) - require.Equal(t, uint16(23456), r.highestIncomingSN) - require.Equal(t, uint16(43), r.lastSN) - require.Equal(t, uint32(0xabcdef), r.lastTS) - require.Equal(t, uint16(23413), r.snOffset) - require.Equal(t, uint32(0xffffffff), r.tsOffset) - require.True(t, r.started) } func TestUpdateSnTsOffsets(t *testing.T) { @@ -92,6 +54,7 @@ func TestUpdateSnTsOffsets(t *testing.T) { extPkt, _ = testutils.GetTestExtPacket(params) r.UpdateSnTsOffsets(extPkt, 1, 1) require.Equal(t, uint16(33332), r.highestIncomingSN) + require.Equal(t, uint32(0xabcdef), r.highestIncomingTS) require.Equal(t, uint16(23333), r.lastSN) require.Equal(t, uint32(0xabcdef), r.lastTS) require.Equal(t, uint16(9999), r.snOffset) @@ -109,9 +72,10 @@ func TestPacketDropped(t *testing.T) { } extPkt, _ := testutils.GetTestExtPacket(params) r.SetLastSnTs(extPkt) - require.Equal(t, r.highestIncomingSN, uint16(23332)) - require.Equal(t, r.lastSN, uint16(23333)) - require.Equal(t, r.lastTS, uint32(0xabcdef)) + require.Equal(t, uint16(23332), r.highestIncomingSN) + require.Equal(t, uint32(0xabcdef), r.highestIncomingTS) + require.Equal(t, uint16(23333), r.lastSN) + require.Equal(t, uint32(0xabcdef), r.lastTS) require.Equal(t, uint16(0), r.snOffset) require.Equal(t, uint32(0), r.tsOffset) @@ -126,8 +90,8 @@ func TestPacketDropped(t *testing.T) { } extPkt, _ = testutils.GetTestExtPacket(params) r.PacketDropped(extPkt) - require.Equal(t, r.highestIncomingSN, uint16(23333)) - require.Equal(t, r.lastSN, uint16(23333)) + require.Equal(t, uint16(23333), r.highestIncomingSN) + require.Equal(t, uint16(23333), r.lastSN) require.Equal(t, uint16(0), r.snOffset) // drop a head packet and check offset increases @@ -160,7 +124,7 @@ func TestPacketDropped(t *testing.T) { require.Equal(t, uint16(1), r.snOffsets[snOffsetWritePtr]) snOffsetWritePtr = (snOffsetWritePtr + 1) & SnOffsetCacheMask require.Equal(t, snOffsetWritePtr, r.snOffsetsWritePtr) - require.Equal(t, r.lastSN, uint16(44444)) + require.Equal(t, uint16(44444), r.lastSN) require.Equal(t, uint16(1), r.snOffset) } @@ -464,7 +428,7 @@ func TestUpdateAndGetPaddingSnTs(t *testing.T) { r.SetLastSnTs(extPkt) // getting padding without forcing marker should fail - _, err := r.UpdateAndGetPaddingSnTs(10, 10, 5, false) + _, err := r.UpdateAndGetPaddingSnTs(10, 10, 5, false, 0) require.Error(t, err) require.ErrorIs(t, err, ErrPaddingNotOnFrameBoundary) @@ -477,10 +441,10 @@ func TestUpdateAndGetPaddingSnTs(t *testing.T) { for i := 0; i < numPadding; i++ { sntsExpected[i] = SnTs{ sequenceNumber: params.SequenceNumber + uint16(i) + 1, - timestamp: params.Timestamp + (uint32(i)*clockRate)/frameRate, + timestamp: params.Timestamp + ((uint32(i)*clockRate)+frameRate-1)/frameRate, } } - snts, err := r.UpdateAndGetPaddingSnTs(numPadding, clockRate, frameRate, true) + snts, err := r.UpdateAndGetPaddingSnTs(numPadding, clockRate, frameRate, true, params.Timestamp) require.NoError(t, err) require.Equal(t, sntsExpected, snts) @@ -488,10 +452,10 @@ func TestUpdateAndGetPaddingSnTs(t *testing.T) { for i := 0; i < numPadding; i++ { sntsExpected[i] = SnTs{ sequenceNumber: params.SequenceNumber + uint16(len(snts)) + uint16(i) + 1, - timestamp: snts[len(snts)-1].timestamp + (uint32(i+1)*clockRate)/frameRate, + timestamp: snts[len(snts)-1].timestamp + ((uint32(i+1)*clockRate)+frameRate-1)/frameRate, } } - snts, err = r.UpdateAndGetPaddingSnTs(numPadding, clockRate, frameRate, false) + snts, err = r.UpdateAndGetPaddingSnTs(numPadding, clockRate, frameRate, false, snts[len(snts)-1].timestamp) require.NoError(t, err) require.Equal(t, sntsExpected, snts) }