diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 3443ce2b1..21c4e1413 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -853,22 +853,7 @@ func (p *ParticipantImpl) clearMigrationTimer() { p.lock.Unlock() } -func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { - allTransportConnected := p.TransportManager.HasSubscriberEverConnected() - if p.IsPublisher() { - allTransportConnected = allTransportConnected && p.TransportManager.HasPublisherEverConnected() - } - if !force && !allTransportConnected { - return false - } - - if onStart != nil { - onStart() - } - - p.sendLeaveRequest(types.ParticipantCloseReasonMigrationRequested, true, false, true) - p.CloseSignalConnection(types.SignallingCloseReasonMigration) - +func (p *ParticipantImpl) setupMigrationTimerLocked() { // // On subscriber peer connection, remote side will try ICE on both // pre- and post-migration ICE candidates as the migrating out @@ -880,9 +865,6 @@ func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { // to try and succeed. If not, close the subscriber peer connection // and help the remote side to narrow down its ICE candidate pool. // - p.clearMigrationTimer() - - p.lock.Lock() p.migrationTimer = time.AfterFunc(migrationWaitDuration, func() { p.clearMigrationTimer() @@ -901,11 +883,45 @@ func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { p.TransportManager.SubscriberClose() }) +} + +func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { + allTransportConnected := p.TransportManager.HasSubscriberEverConnected() + if p.IsPublisher() { + allTransportConnected = allTransportConnected && p.TransportManager.HasPublisherEverConnected() + } + if !force && !allTransportConnected { + return false + } + + if onStart != nil { + onStart() + } + + p.sendLeaveRequest(types.ParticipantCloseReasonMigrationRequested, true, false, true) + p.CloseSignalConnection(types.SignallingCloseReasonMigration) + + p.clearMigrationTimer() + + p.lock.Lock() + p.setupMigrationTimerLocked() p.lock.Unlock() return true } +func (p *ParticipantImpl) NotifyMigration() { + p.lock.Lock() + defer p.lock.Unlock() + + if p.migrationTimer != nil { + // already set up + return + } + + p.setupMigrationTimerLocked() +} + func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { preState := p.MigrateState() if preState == types.MigrateStateComplete || preState == s { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 88f624396..b57a7e92a 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -395,6 +395,7 @@ type LocalParticipant interface { // session migration MaybeStartMigration(force bool, onStart func()) bool + NotifyMigration() SetMigrateState(s MigrateState) MigrateState() MigrateState SetMigrateInfo(previousOffer, previousAnswer *webrtc.SessionDescription, mediaTracks []*livekit.TrackPublishedResponse, dataChannels []*livekit.DataChannelInfo) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 4cc2d1db4..6bbf2b5a6 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -557,6 +557,10 @@ type FakeLocalParticipant struct { negotiateArgsForCall []struct { arg1 bool } + NotifyMigrationStub func() + notifyMigrationMutex sync.RWMutex + notifyMigrationArgsForCall []struct { + } OnClaimsChangedStub func(func(types.LocalParticipant)) onClaimsChangedMutex sync.RWMutex onClaimsChangedArgsForCall []struct { @@ -3840,6 +3844,30 @@ func (fake *FakeLocalParticipant) NegotiateArgsForCall(i int) bool { return argsForCall.arg1 } +func (fake *FakeLocalParticipant) NotifyMigration() { + fake.notifyMigrationMutex.Lock() + fake.notifyMigrationArgsForCall = append(fake.notifyMigrationArgsForCall, struct { + }{}) + stub := fake.NotifyMigrationStub + fake.recordInvocation("NotifyMigration", []interface{}{}) + fake.notifyMigrationMutex.Unlock() + if stub != nil { + fake.NotifyMigrationStub() + } +} + +func (fake *FakeLocalParticipant) NotifyMigrationCallCount() int { + fake.notifyMigrationMutex.RLock() + defer fake.notifyMigrationMutex.RUnlock() + return len(fake.notifyMigrationArgsForCall) +} + +func (fake *FakeLocalParticipant) NotifyMigrationCalls(stub func()) { + fake.notifyMigrationMutex.Lock() + defer fake.notifyMigrationMutex.Unlock() + fake.NotifyMigrationStub = stub +} + func (fake *FakeLocalParticipant) OnClaimsChanged(arg1 func(types.LocalParticipant)) { fake.onClaimsChangedMutex.Lock() fake.onClaimsChangedArgsForCall = append(fake.onClaimsChangedArgsForCall, struct { @@ -6399,6 +6427,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.migrateStateMutex.RUnlock() fake.negotiateMutex.RLock() defer fake.negotiateMutex.RUnlock() + fake.notifyMigrationMutex.RLock() + defer fake.notifyMigrationMutex.RUnlock() fake.onClaimsChangedMutex.RLock() defer fake.onClaimsChangedMutex.RUnlock() fake.onCloseMutex.RLock()