Delay getting forwarder state till migration is complete. (#2909)

This commit is contained in:
Raja Subramanian
2024-08-06 12:45:46 +05:30
committed by GitHub
parent 8c323330b6
commit 13ee1aca28
5 changed files with 54 additions and 22 deletions
+7 -1
View File
@@ -21,6 +21,7 @@ import (
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
"go.uber.org/atomic"
sutils "github.com/livekit/livekit-server/pkg/utils"
"github.com/livekit/protocol/livekit"
@@ -161,6 +162,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *
})
// Bind callback can happen from replaceTrack, so set it up early
var reusingTransceiver atomic.Bool
var dtState sfu.DownTrackState
downTrack.OnBinding(func(err error) {
if err != nil {
@@ -168,7 +170,9 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *
return
}
wr.DetermineReceiver(downTrack.Codec())
downTrack.SeedState(dtState)
if reusingTransceiver.Load() {
downTrack.SeedState(dtState)
}
if err = wr.AddDownTrack(downTrack); err != nil && err != sfu.ErrReceiverClosed {
sub.GetLogger().Errorw(
"could not add down track", err,
@@ -216,6 +220,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *
"publisherID", subTrack.PublisherID(),
"trackID", trackID,
)
reusingTransceiver.Store(true)
rtpSender := existingTransceiver.Sender()
if rtpSender != nil {
// replaced track will bind immediately without negotiation, SetTransceiver first before bind
@@ -246,6 +251,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *
existingTransceiver.Stop()
}
}
reusingTransceiver.Store(false)
// if cannot replace, find an unused transceiver or add new one
if transceiver == nil {
+29 -8
View File
@@ -129,6 +129,7 @@ type ParticipantParams struct {
TURNSEnabled bool
GetParticipantInfo func(pID livekit.ParticipantID) *livekit.ParticipantInfo
GetRegionSettings func(ip string) *livekit.RegionSettings
GetSubscriberForwarderState func(p types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error)
DisableSupervisor bool
ReconnectOnPublicationError bool
ReconnectOnSubscriptionError bool
@@ -231,6 +232,7 @@ type ParticipantImpl struct {
onICEConfigChanged func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig)
cachedDownTracks map[livekit.TrackID]*downTrackState
forwarderState map[livekit.TrackID]*livekit.RTPForwarderState
supervisor *supervisor.ParticipantSupervisor
@@ -862,7 +864,6 @@ func (p *ParticipantImpl) SetMigrateInfo(
previousOffer, previousAnswer *webrtc.SessionDescription,
mediaTracks []*livekit.TrackPublishedResponse,
dataChannels []*livekit.DataChannelInfo,
forwarderStates map[livekit.TrackID]*livekit.RTPForwarderState,
) {
p.pendingTracksLock.Lock()
for _, t := range mediaTracks {
@@ -883,10 +884,6 @@ func (p *ParticipantImpl) SetMigrateInfo(
}
p.TransportManager.SetMigrateInfo(previousOffer, previousAnswer, dataChannels)
for trackID, fs := range forwarderStates {
p.CacheDownTrack(trackID, nil, sfu.DownTrackState{ForwarderState: fs})
}
}
func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseReason, isExpectedToResume bool) error {
@@ -1056,6 +1053,7 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) {
case types.MigrateStateComplete:
p.TransportManager.ProcessPendingPublisherDataChannels()
p.cacheForwarderState()
}
if onMigrateStateChange := p.getOnMigrateStateChange(); onMigrateStateChange != nil {
@@ -1214,7 +1212,9 @@ func (p *ParticipantImpl) onTrackSubscribed(subTrack types.SubscribedTrack) {
return
}
if p.TransportManager.HasSubscriberEverConnected() {
subTrack.DownTrack().SetConnected()
dt := subTrack.DownTrack()
dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(subTrack.ID())})
dt.SetConnected()
}
p.TransportManager.AddSubscribedTrack(subTrack)
})
@@ -1641,7 +1641,7 @@ func (p *ParticipantImpl) onPublisherInitialConnected() {
func (p *ParticipantImpl) onSubscriberInitialConnected() {
go p.subscriberRTCPWorker()
p.setDowntracksConnected()
p.setDownTracksConnected()
}
func (p *ParticipantImpl) onPrimaryTransportInitialConnected() {
@@ -2453,14 +2453,35 @@ func (p *ParticipantImpl) postRtcp(pkts []rtcp.Packet) {
}, postRtcpOp{p, pkts})
}
func (p *ParticipantImpl) setDowntracksConnected() {
func (p *ParticipantImpl) setDownTracksConnected() {
for _, t := range p.SubscriptionManager.GetSubscribedTracks() {
if dt := t.DownTrack(); dt != nil {
dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(t.ID())})
dt.SetConnected()
}
}
}
func (p *ParticipantImpl) cacheForwarderState() {
// if migrating in, get forwarder state from migrating out node to facilitate resume
if f := p.params.GetSubscriberForwarderState; f != nil {
if fs, err := f(p); err == nil {
p.lock.Lock()
p.forwarderState = fs
p.lock.Unlock()
}
}
}
func (p *ParticipantImpl) getAndDeleteForwarderState(trackID livekit.TrackID) *livekit.RTPForwarderState {
p.lock.Lock()
fs := p.forwarderState[trackID]
delete(p.forwarderState, trackID)
p.lock.Unlock()
return fs
}
func (p *ParticipantImpl) CacheDownTrack(trackID livekit.TrackID, rtpTransceiver *webrtc.RTPTransceiver, downTrack sfu.DownTrackState) {
p.lock.Lock()
if existing := p.cachedDownTracks[trackID]; existing != nil && existing.transceiver != rtpTransceiver {
-1
View File
@@ -417,7 +417,6 @@ type LocalParticipant interface {
previousOffer, previousAnswer *webrtc.SessionDescription,
mediaTracks []*livekit.TrackPublishedResponse,
dataChannels []*livekit.DataChannelInfo,
forwarderStates map[livekit.TrackID]*livekit.RTPForwarderState,
)
UpdateMediaRTT(rtt uint32)
@@ -782,14 +782,13 @@ type FakeLocalParticipant struct {
setMetadataArgsForCall []struct {
arg1 string
}
SetMigrateInfoStub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState)
SetMigrateInfoStub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo)
setMigrateInfoMutex sync.RWMutex
setMigrateInfoArgsForCall []struct {
arg1 *webrtc.SessionDescription
arg2 *webrtc.SessionDescription
arg3 []*livekit.TrackPublishedResponse
arg4 []*livekit.DataChannelInfo
arg5 map[livekit.TrackID]*livekit.RTPForwarderState
}
SetMigrateStateStub func(types.MigrateState)
setMigrateStateMutex sync.RWMutex
@@ -5204,7 +5203,7 @@ func (fake *FakeLocalParticipant) SetMetadataArgsForCall(i int) string {
return argsForCall.arg1
}
func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription, arg2 *webrtc.SessionDescription, arg3 []*livekit.TrackPublishedResponse, arg4 []*livekit.DataChannelInfo, arg5 map[livekit.TrackID]*livekit.RTPForwarderState) {
func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription, arg2 *webrtc.SessionDescription, arg3 []*livekit.TrackPublishedResponse, arg4 []*livekit.DataChannelInfo) {
var arg3Copy []*livekit.TrackPublishedResponse
if arg3 != nil {
arg3Copy = make([]*livekit.TrackPublishedResponse, len(arg3))
@@ -5221,13 +5220,12 @@ func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription
arg2 *webrtc.SessionDescription
arg3 []*livekit.TrackPublishedResponse
arg4 []*livekit.DataChannelInfo
arg5 map[livekit.TrackID]*livekit.RTPForwarderState
}{arg1, arg2, arg3Copy, arg4Copy, arg5})
}{arg1, arg2, arg3Copy, arg4Copy})
stub := fake.SetMigrateInfoStub
fake.recordInvocation("SetMigrateInfo", []interface{}{arg1, arg2, arg3Copy, arg4Copy, arg5})
fake.recordInvocation("SetMigrateInfo", []interface{}{arg1, arg2, arg3Copy, arg4Copy})
fake.setMigrateInfoMutex.Unlock()
if stub != nil {
fake.SetMigrateInfoStub(arg1, arg2, arg3, arg4, arg5)
fake.SetMigrateInfoStub(arg1, arg2, arg3, arg4)
}
}
@@ -5237,17 +5235,17 @@ func (fake *FakeLocalParticipant) SetMigrateInfoCallCount() int {
return len(fake.setMigrateInfoArgsForCall)
}
func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState)) {
func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo)) {
fake.setMigrateInfoMutex.Lock()
defer fake.setMigrateInfoMutex.Unlock()
fake.SetMigrateInfoStub = stub
}
func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) (*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, map[livekit.TrackID]*livekit.RTPForwarderState) {
func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) (*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo) {
fake.setMigrateInfoMutex.RLock()
defer fake.setMigrateInfoMutex.RUnlock()
argsForCall := fake.setMigrateInfoArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
}
func (fake *FakeLocalParticipant) SetMigrateState(arg1 types.MigrateState) {
+10 -2
View File
@@ -385,11 +385,13 @@ func (f *Forwarder) GetState() *livekit.RTPForwarderState {
state := &livekit.RTPForwarderState{
Started: f.started,
ReferenceLayerSpatial: f.referenceLayerSpatial,
PreStartTime: f.preStartTime.UnixNano(),
ExtFirstTimestamp: f.extFirstTS,
DummyStartTimestampOffset: f.dummyStartTSOffset,
RtpMunger: f.rtpMunger.GetState(),
}
if !f.preStartTime.IsZero() {
state.PreStartTime = f.preStartTime.UnixNano()
}
codecMungerState := f.codecMunger.GetState()
if vp8MungerState, ok := codecMungerState.(*livekit.VP8MungerState); ok {
@@ -413,7 +415,9 @@ func (f *Forwarder) SeedState(state *livekit.RTPForwarderState) {
f.started = true
f.referenceLayerSpatial = state.ReferenceLayerSpatial
f.preStartTime = time.Unix(0, state.PreStartTime)
if state.PreStartTime != 0 {
f.preStartTime = time.Unix(0, state.PreStartTime)
}
f.extFirstTS = state.ExtFirstTimestamp
f.dummyStartTSOffset = state.DummyStartTimestampOffset
}
@@ -1633,12 +1637,14 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e
f.logger.Debugw(
message,
"layer", layer,
"referenceLayerSpatial", f.referenceLayerSpatial,
"extExpectedTS", extExpectedTS,
"incomingTS", extPkt.Packet.Timestamp,
"extIncomingTS", extPkt.ExtTimestamp,
"extRefTS", extRefTS,
"extLastTS", extLastTS,
"diffSeconds", math.Abs(diffSeconds),
"refInfos", wrappedRefInfoLogger{f},
)
}
// TODO-REMOVE-AFTER-DATA-COLLECTION
@@ -1646,12 +1652,14 @@ func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) e
f.logger.Infow(
message,
"layer", layer,
"referenceLayerSpatial", f.referenceLayerSpatial,
"extExpectedTS", extExpectedTS,
"incomingTS", extPkt.Packet.Timestamp,
"extIncomingTS", extPkt.ExtTimestamp,
"extRefTS", extRefTS,
"extLastTS", extLastTS,
"diffSeconds", math.Abs(diffSeconds),
"refInfos", wrappedRefInfoLogger{f},
)
}