diff --git a/pkg/rtc/mediatrack_test.go b/pkg/rtc/mediatrack_test.go index b3fbd2b6c..be8228f20 100644 --- a/pkg/rtc/mediatrack_test.go +++ b/pkg/rtc/mediatrack_test.go @@ -189,6 +189,7 @@ func TestSubscribedMaxQuality(t *testing.T) { return nil }) + mt.maxSubscribedQuality = livekit.VideoQuality_LOW mt.notifySubscriberMaxQuality("s1", livekit.VideoQuality_HIGH) mt.notifySubscriberMaxQuality("s2", livekit.VideoQuality_MEDIUM) diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go index 0b58ef26e..c523378fe 100644 --- a/pkg/rtc/mediatrackreceiver.go +++ b/pkg/rtc/mediatrackreceiver.go @@ -85,6 +85,17 @@ func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver return t } +func (t *MediaTrackReceiver) Restart() { + t.lock.Lock() + receiver := t.receiver + t.lock.Unlock() + + if receiver != nil { + receiver.SetMaxExpectedSpatialLayer(sfu.DefaultMaxLayerSpatial) + t.MediaTrackSubscriptions.Restart() + } +} + func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver) { t.lock.Lock() t.receiver = receiver diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index a434da117..d3cf67672 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -65,7 +65,7 @@ func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrac pendingClose: make(map[livekit.ParticipantID]types.SubscribedTrack), maxSubscriberQuality: make(map[livekit.ParticipantID]livekit.VideoQuality), maxSubscriberNodeQuality: make(map[livekit.NodeID]livekit.VideoQuality), - maxSubscribedQuality: livekit.VideoQuality_LOW, + maxSubscribedQuality: livekit.VideoQuality_HIGH, maxSubscribedQualityDebounce: debounce.New(params.VideoConfig.DynacastPauseDelay), } @@ -73,7 +73,11 @@ func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrac } func (t *MediaTrackSubscriptions) Start() { - t.startMaxQualityTimer() + t.startMaxQualityTimer(false) +} + +func (t *MediaTrackSubscriptions) Restart() { + t.startMaxQualityTimer(true) } func (t *MediaTrackSubscriptions) Close() { @@ -546,7 +550,7 @@ func (t *MediaTrackSubscriptions) UpdateQualityChange(force bool) { } } -func (t *MediaTrackSubscriptions) startMaxQualityTimer() { +func (t *MediaTrackSubscriptions) startMaxQualityTimer(force bool) { t.maxQualityLock.Lock() defer t.maxQualityLock.Unlock() @@ -556,7 +560,7 @@ func (t *MediaTrackSubscriptions) startMaxQualityTimer() { t.maxQualityTimer = time.AfterFunc(initialQualityUpdateWait, func() { t.stopMaxQualityTimer() - t.UpdateQualityChange(false) + t.UpdateQualityChange(force) }) } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 84ad4fb76..452badaa9 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -666,6 +666,9 @@ func (p *ParticipantImpl) ICERestart() error { // not connected, skip return nil } + + p.UpTrackManager.Restart() + return p.subscriber.CreateAndSendOffer(&webrtc.OfferOptions{ ICERestart: true, }) diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 7cc17985d..8bd5531ac 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -192,6 +192,7 @@ type MediaTrack interface { IsSimulcast() bool Receiver() sfu.TrackReceiver + Restart() // callbacks AddOnClose(func()) diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 0371852eb..cd027a4d1 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -183,6 +183,10 @@ type FakeLocalMediaTrack struct { arg1 livekit.ParticipantID arg2 bool } + RestartStub func() + restartMutex sync.RWMutex + restartArgsForCall []struct { + } RevokeDisallowedSubscribersStub func([]livekit.ParticipantID) []livekit.ParticipantID revokeDisallowedSubscribersMutex sync.RWMutex revokeDisallowedSubscribersArgsForCall []struct { @@ -1178,6 +1182,30 @@ func (fake *FakeLocalMediaTrack) RemoveSubscriberArgsForCall(i int) (livekit.Par return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeLocalMediaTrack) Restart() { + fake.restartMutex.Lock() + fake.restartArgsForCall = append(fake.restartArgsForCall, struct { + }{}) + stub := fake.RestartStub + fake.recordInvocation("Restart", []interface{}{}) + fake.restartMutex.Unlock() + if stub != nil { + fake.RestartStub() + } +} + +func (fake *FakeLocalMediaTrack) RestartCallCount() int { + fake.restartMutex.RLock() + defer fake.restartMutex.RUnlock() + return len(fake.restartArgsForCall) +} + +func (fake *FakeLocalMediaTrack) RestartCalls(stub func()) { + fake.restartMutex.Lock() + defer fake.restartMutex.Unlock() + fake.RestartStub = stub +} + func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantID) []livekit.ParticipantID { var arg1Copy []livekit.ParticipantID if arg1 != nil { @@ -1598,6 +1626,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() + fake.restartMutex.RLock() + defer fake.restartMutex.RUnlock() fake.revokeDisallowedSubscribersMutex.RLock() defer fake.revokeDisallowedSubscribersMutex.RUnlock() fake.sdpCidMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 64583dc07..9d01a3017 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -161,6 +161,10 @@ type FakeMediaTrack struct { arg1 livekit.ParticipantID arg2 bool } + RestartStub func() + restartMutex sync.RWMutex + restartArgsForCall []struct { + } RevokeDisallowedSubscribersStub func([]livekit.ParticipantID) []livekit.ParticipantID revokeDisallowedSubscribersMutex sync.RWMutex revokeDisallowedSubscribersArgsForCall []struct { @@ -1022,6 +1026,30 @@ func (fake *FakeMediaTrack) RemoveSubscriberArgsForCall(i int) (livekit.Particip return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeMediaTrack) Restart() { + fake.restartMutex.Lock() + fake.restartArgsForCall = append(fake.restartArgsForCall, struct { + }{}) + stub := fake.RestartStub + fake.recordInvocation("Restart", []interface{}{}) + fake.restartMutex.Unlock() + if stub != nil { + fake.RestartStub() + } +} + +func (fake *FakeMediaTrack) RestartCallCount() int { + fake.restartMutex.RLock() + defer fake.restartMutex.RUnlock() + return len(fake.restartArgsForCall) +} + +func (fake *FakeMediaTrack) RestartCalls(stub func()) { + fake.restartMutex.Lock() + defer fake.restartMutex.Unlock() + fake.RestartStub = stub +} + func (fake *FakeMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantID) []livekit.ParticipantID { var arg1Copy []livekit.ParticipantID if arg1 != nil { @@ -1300,6 +1328,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() + fake.restartMutex.RLock() + defer fake.restartMutex.RUnlock() fake.revokeDisallowedSubscribersMutex.RLock() defer fake.revokeDisallowedSubscribersMutex.RUnlock() fake.setMutedMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 648781b62..f6b1a2530 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -46,6 +46,12 @@ func NewUpTrackManager(params UpTrackManagerParams) *UpTrackManager { func (u *UpTrackManager) Start() { } +func (u *UpTrackManager) Restart() { + for _, t := range u.GetPublishedTracks() { + t.Restart() + } +} + func (u *UpTrackManager) Close() { u.lock.Lock() u.closed = true