diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index c16b4bb7f..acd85037d 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -21,6 +21,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" + "go.uber.org/atomic" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" @@ -161,6 +162,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * }) // Bind callback can happen from replaceTrack, so set it up early + var reusingTransceiver atomic.Bool var dtState sfu.DownTrackState downTrack.OnBinding(func(err error) { if err != nil { @@ -168,7 +170,9 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * return } wr.DetermineReceiver(downTrack.Codec()) - downTrack.SeedState(dtState) + if reusingTransceiver.Load() { + downTrack.SeedState(dtState) + } if err = wr.AddDownTrack(downTrack); err != nil && err != sfu.ErrReceiverClosed { sub.GetLogger().Errorw( "could not add down track", err, @@ -216,6 +220,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * "publisherID", subTrack.PublisherID(), "trackID", trackID, ) + reusingTransceiver.Store(true) rtpSender := existingTransceiver.Sender() if rtpSender != nil { // replaced track will bind immediately without negotiation, SetTransceiver first before bind @@ -246,6 +251,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * existingTransceiver.Stop() } } + reusingTransceiver.Store(false) // if cannot replace, find an unused transceiver or add new one if transceiver == nil { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 0bf290ceb..853936955 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -129,6 +129,7 @@ type ParticipantParams struct { TURNSEnabled bool GetParticipantInfo func(pID livekit.ParticipantID) *livekit.ParticipantInfo GetRegionSettings func(ip string) *livekit.RegionSettings + GetSubscriberForwarderState func(p types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error) DisableSupervisor bool ReconnectOnPublicationError bool ReconnectOnSubscriptionError bool @@ -231,6 +232,7 @@ type ParticipantImpl struct { onICEConfigChanged func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) cachedDownTracks map[livekit.TrackID]*downTrackState + forwarderState map[livekit.TrackID]*livekit.RTPForwarderState supervisor *supervisor.ParticipantSupervisor @@ -862,7 +864,6 @@ func (p *ParticipantImpl) SetMigrateInfo( previousOffer, previousAnswer *webrtc.SessionDescription, mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo, - forwarderStates map[livekit.TrackID]*livekit.RTPForwarderState, ) { p.pendingTracksLock.Lock() for _, t := range mediaTracks { @@ -883,10 +884,6 @@ func (p *ParticipantImpl) SetMigrateInfo( } p.TransportManager.SetMigrateInfo(previousOffer, previousAnswer, dataChannels) - - for trackID, fs := range forwarderStates { - p.CacheDownTrack(trackID, nil, sfu.DownTrackState{ForwarderState: fs}) - } } func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseReason, isExpectedToResume bool) error { @@ -1056,6 +1053,7 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { case types.MigrateStateComplete: p.TransportManager.ProcessPendingPublisherDataChannels() + p.cacheForwarderState() } if onMigrateStateChange := p.getOnMigrateStateChange(); onMigrateStateChange != nil { @@ -1214,7 +1212,9 @@ func (p *ParticipantImpl) onTrackSubscribed(subTrack types.SubscribedTrack) { return } if p.TransportManager.HasSubscriberEverConnected() { - subTrack.DownTrack().SetConnected() + dt := subTrack.DownTrack() + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(subTrack.ID())}) + dt.SetConnected() } p.TransportManager.AddSubscribedTrack(subTrack) }) @@ -1641,7 +1641,7 @@ func (p *ParticipantImpl) onPublisherInitialConnected() { func (p *ParticipantImpl) onSubscriberInitialConnected() { go p.subscriberRTCPWorker() - p.setDowntracksConnected() + p.setDownTracksConnected() } func (p *ParticipantImpl) onPrimaryTransportInitialConnected() { @@ -2453,14 +2453,35 @@ func (p *ParticipantImpl) postRtcp(pkts []rtcp.Packet) { }, postRtcpOp{p, pkts}) } -func (p *ParticipantImpl) setDowntracksConnected() { +func (p *ParticipantImpl) setDownTracksConnected() { for _, t := range p.SubscriptionManager.GetSubscribedTracks() { if dt := t.DownTrack(); dt != nil { + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(t.ID())}) dt.SetConnected() } } } +func (p *ParticipantImpl) cacheForwarderState() { + // if migrating in, get forwarder state from migrating out node to facilitate resume + if f := p.params.GetSubscriberForwarderState; f != nil { + if fs, err := f(p); err == nil { + p.lock.Lock() + p.forwarderState = fs + p.lock.Unlock() + } + } +} + +func (p *ParticipantImpl) getAndDeleteForwarderState(trackID livekit.TrackID) *livekit.RTPForwarderState { + p.lock.Lock() + fs := p.forwarderState[trackID] + delete(p.forwarderState, trackID) + p.lock.Unlock() + + return fs +} + func (p *ParticipantImpl) CacheDownTrack(trackID livekit.TrackID, rtpTransceiver *webrtc.RTPTransceiver, downTrack sfu.DownTrackState) { p.lock.Lock() if existing := p.cachedDownTracks[trackID]; existing != nil && existing.transceiver != rtpTransceiver { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 593c47e96..e3d25a59d 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -417,7 +417,6 @@ type LocalParticipant interface { previousOffer, previousAnswer *webrtc.SessionDescription, mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo, - forwarderStates map[livekit.TrackID]*livekit.RTPForwarderState, ) UpdateMediaRTT(rtt uint32) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 918a5f6a0..79f22b445 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -782,14 +782,13 @@ type FakeLocalParticipant struct { setMetadataArgsForCall []struct { arg1 string } - SetMigrateInfoStub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState) + SetMigrateInfoStub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo) setMigrateInfoMutex sync.RWMutex setMigrateInfoArgsForCall []struct { arg1 *webrtc.SessionDescription arg2 *webrtc.SessionDescription arg3 []*livekit.TrackPublishedResponse arg4 []*livekit.DataChannelInfo - arg5 map[livekit.TrackID]*livekit.RTPForwarderState } SetMigrateStateStub func(types.MigrateState) setMigrateStateMutex sync.RWMutex @@ -5204,7 +5203,7 @@ func (fake *FakeLocalParticipant) SetMetadataArgsForCall(i int) string { return argsForCall.arg1 } -func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription, arg2 *webrtc.SessionDescription, arg3 []*livekit.TrackPublishedResponse, arg4 []*livekit.DataChannelInfo, arg5 map[livekit.TrackID]*livekit.RTPForwarderState) { +func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription, arg2 *webrtc.SessionDescription, arg3 []*livekit.TrackPublishedResponse, arg4 []*livekit.DataChannelInfo) { var arg3Copy []*livekit.TrackPublishedResponse if arg3 != nil { arg3Copy = make([]*livekit.TrackPublishedResponse, len(arg3)) @@ -5221,13 +5220,12 @@ func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription arg2 *webrtc.SessionDescription arg3 []*livekit.TrackPublishedResponse arg4 []*livekit.DataChannelInfo - arg5 map[livekit.TrackID]*livekit.RTPForwarderState - }{arg1, arg2, arg3Copy, arg4Copy, arg5}) + }{arg1, arg2, arg3Copy, arg4Copy}) stub := fake.SetMigrateInfoStub - fake.recordInvocation("SetMigrateInfo", []interface{}{arg1, arg2, arg3Copy, arg4Copy, arg5}) + fake.recordInvocation("SetMigrateInfo", []interface{}{arg1, arg2, arg3Copy, arg4Copy}) fake.setMigrateInfoMutex.Unlock() if stub != nil { - fake.SetMigrateInfoStub(arg1, arg2, arg3, arg4, arg5) + fake.SetMigrateInfoStub(arg1, arg2, arg3, arg4) } } @@ -5237,17 +5235,17 @@ func (fake *FakeLocalParticipant) SetMigrateInfoCallCount() int { return len(fake.setMigrateInfoArgsForCall) } -func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState)) { +func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo)) { fake.setMigrateInfoMutex.Lock() defer fake.setMigrateInfoMutex.Unlock() fake.SetMigrateInfoStub = stub } -func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) (*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState) { +func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) (*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo) { fake.setMigrateInfoMutex.RLock() defer fake.setMigrateInfoMutex.RUnlock() argsForCall := fake.setMigrateInfoArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 } func (fake *FakeLocalParticipant) SetMigrateState(arg1 types.MigrateState) { diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 8aa926a49..7f7f06d0f 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -385,11 +385,13 @@ func (f *Forwarder) GetState() *livekit.RTPForwarderState { state := &livekit.RTPForwarderState{ Started: f.started, ReferenceLayerSpatial: f.referenceLayerSpatial, - PreStartTime: f.preStartTime.UnixNano(), ExtFirstTimestamp: f.extFirstTS, DummyStartTimestampOffset: f.dummyStartTSOffset, RtpMunger: f.rtpMunger.GetState(), } + if !f.preStartTime.IsZero() { + state.PreStartTime = f.preStartTime.UnixNano() + } codecMungerState := f.codecMunger.GetState() if vp8MungerState, ok := codecMungerState.(*livekit.VP8MungerState); ok { @@ -413,7 +415,9 @@ func (f *Forwarder) SeedState(state *livekit.RTPForwarderState) { f.started = true f.referenceLayerSpatial = state.ReferenceLayerSpatial - f.preStartTime = time.Unix(0, state.PreStartTime) + if state.PreStartTime != 0 { + f.preStartTime = time.Unix(0, state.PreStartTime) + } f.extFirstTS = state.ExtFirstTimestamp f.dummyStartTSOffset = state.DummyStartTimestampOffset } @@ -1633,12 +1637,14 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e f.logger.Debugw( message, "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, "extExpectedTS", extExpectedTS, "incomingTS", extPkt.Packet.Timestamp, "extIncomingTS", extPkt.ExtTimestamp, "extRefTS", extRefTS, "extLastTS", extLastTS, "diffSeconds", math.Abs(diffSeconds), + "refInfos", wrappedRefInfoLogger{f}, ) } // TODO-REMOVE-AFTER-DATA-COLLECTION @@ -1646,12 +1652,14 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e f.logger.Infow( message, "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, "extExpectedTS", extExpectedTS, "incomingTS", extPkt.Packet.Timestamp, "extIncomingTS", extPkt.ExtTimestamp, "extRefTS", extRefTS, "extLastTS", extLastTS, "diffSeconds", math.Abs(diffSeconds), + "refInfos", wrappedRefInfoLogger{f}, ) }