From 7d06cfca8b935ed111eb0a0664fad561db50f5ab Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 2 Apr 2026 19:54:14 +0530 Subject: [PATCH] Keep subscription synchronous when publisher is expected to resume. (#4424) Subscription can switch between remote track and local track or vice-versa. When that happens, closing the subscribed track of one or the other asynchronously means the re-subscribe could race with subscribed track closing. Keeping the case of `isExpectedToResume` sync to prevent the race. Would be good to support multiple subscribed tracks per subscription. So, when subscribed track closes, subscription manager can check and close the correct subscribed track. But, it gets complex to clearly determine if a subccription is pending or not and other events. So, keeping it sync. --- pkg/rtc/mediatracksubscriptions.go | 1 + pkg/rtc/participant.go | 2 +- pkg/rtc/subscribedtrack.go | 12 +- .../servicefakes/fake_service_store.go | 160 ------------------ pkg/service/wire_gen.go | 14 +- 5 files changed, 14 insertions(+), 175 deletions(-) diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 0eda270ed..69575fe6b 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -247,6 +247,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * // But, the subscription could be removed early if the published track is closed // while adding subscription. In those cases, subscription manager would not have set // the `OnClose` callback. So, set it here to handle cases of early close. + // Subscription manager will reset this if this subscription proceeds till that point. subTrack.OnClose(func(isExpectedToResume bool) { if !isExpectedToResume { if err := sub.RemoveTrackLocal(sender); err != nil { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 20eaaa9d8..4efd90f23 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1620,7 +1620,7 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { // callback could close the remote participant/tracks before the local track // is fully active. // - // that could lead subscribers to unsubscribe due to source + // that could lead to subscribers unsubscribing due to source // track going away, i. e. in this case, the remote track close would have // notified the subscription manager, the subscription manager would // re-resolve to check if the track is still active and unsubscribe if none diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index eae084956..9f6c6ebe1 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -229,7 +229,7 @@ func (t *SubscribedTrack) Bound(err error) { // for DownTrack callback to notify us that it's closed func (t *SubscribedTrack) Close(isExpectedToResume bool) { if onClose := t.onClose.Load(); onClose != nil { - go onClose.(func(bool))(isExpectedToResume) + onClose.(func(bool))(isExpectedToResume) } } @@ -475,10 +475,8 @@ func (t *SubscribedTrack) OnDownTrackClose(isExpectedToResume bool) { } } - go func() { - if t.params.OnDownTrackClosed != nil { - t.params.OnDownTrackClosed(t.params.Subscriber.ID()) - } - t.Close(isExpectedToResume) - }() + if t.params.OnDownTrackClosed != nil { + t.params.OnDownTrackClosed(t.params.Subscriber.ID()) + } + t.Close(isExpectedToResume) } diff --git a/pkg/service/servicefakes/fake_service_store.go b/pkg/service/servicefakes/fake_service_store.go index e4808aa54..45a6ed556 100644 --- a/pkg/service/servicefakes/fake_service_store.go +++ b/pkg/service/servicefakes/fake_service_store.go @@ -10,20 +10,6 @@ import ( ) type FakeServiceStore struct { - ListParticipantsStub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error) - listParticipantsMutex sync.RWMutex - listParticipantsArgsForCall []struct { - arg1 context.Context - arg2 livekit.RoomName - } - listParticipantsReturns struct { - result1 []*livekit.ParticipantInfo - result2 error - } - listParticipantsReturnsOnCall map[int]struct { - result1 []*livekit.ParticipantInfo - result2 error - } ListRoomsStub func(context.Context, []livekit.RoomName) ([]*livekit.Room, error) listRoomsMutex sync.RWMutex listRoomsArgsForCall []struct { @@ -38,21 +24,6 @@ type FakeServiceStore struct { result1 []*livekit.Room result2 error } - LoadParticipantStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) - loadParticipantMutex sync.RWMutex - loadParticipantArgsForCall []struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 livekit.ParticipantIdentity - } - loadParticipantReturns struct { - result1 *livekit.ParticipantInfo - result2 error - } - loadParticipantReturnsOnCall map[int]struct { - result1 *livekit.ParticipantInfo - result2 error - } LoadRoomStub func(context.Context, livekit.RoomName, bool) (*livekit.Room, *livekit.RoomInternal, error) loadRoomMutex sync.RWMutex loadRoomArgsForCall []struct { @@ -88,71 +59,6 @@ type FakeServiceStore struct { invocationsMutex sync.RWMutex } -func (fake *FakeServiceStore) ListParticipants(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.ParticipantInfo, error) { - fake.listParticipantsMutex.Lock() - ret, specificReturn := fake.listParticipantsReturnsOnCall[len(fake.listParticipantsArgsForCall)] - fake.listParticipantsArgsForCall = append(fake.listParticipantsArgsForCall, struct { - arg1 context.Context - arg2 livekit.RoomName - }{arg1, arg2}) - stub := fake.ListParticipantsStub - fakeReturns := fake.listParticipantsReturns - fake.recordInvocation("ListParticipants", []interface{}{arg1, arg2}) - fake.listParticipantsMutex.Unlock() - if stub != nil { - return stub(arg1, arg2) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeServiceStore) ListParticipantsCallCount() int { - fake.listParticipantsMutex.RLock() - defer fake.listParticipantsMutex.RUnlock() - return len(fake.listParticipantsArgsForCall) -} - -func (fake *FakeServiceStore) ListParticipantsCalls(stub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error)) { - fake.listParticipantsMutex.Lock() - defer fake.listParticipantsMutex.Unlock() - fake.ListParticipantsStub = stub -} - -func (fake *FakeServiceStore) ListParticipantsArgsForCall(i int) (context.Context, livekit.RoomName) { - fake.listParticipantsMutex.RLock() - defer fake.listParticipantsMutex.RUnlock() - argsForCall := fake.listParticipantsArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - -func (fake *FakeServiceStore) ListParticipantsReturns(result1 []*livekit.ParticipantInfo, result2 error) { - fake.listParticipantsMutex.Lock() - defer fake.listParticipantsMutex.Unlock() - fake.ListParticipantsStub = nil - fake.listParticipantsReturns = struct { - result1 []*livekit.ParticipantInfo - result2 error - }{result1, result2} -} - -func (fake *FakeServiceStore) ListParticipantsReturnsOnCall(i int, result1 []*livekit.ParticipantInfo, result2 error) { - fake.listParticipantsMutex.Lock() - defer fake.listParticipantsMutex.Unlock() - fake.ListParticipantsStub = nil - if fake.listParticipantsReturnsOnCall == nil { - fake.listParticipantsReturnsOnCall = make(map[int]struct { - result1 []*livekit.ParticipantInfo - result2 error - }) - } - fake.listParticipantsReturnsOnCall[i] = struct { - result1 []*livekit.ParticipantInfo - result2 error - }{result1, result2} -} - func (fake *FakeServiceStore) ListRooms(arg1 context.Context, arg2 []livekit.RoomName) ([]*livekit.Room, error) { var arg2Copy []livekit.RoomName if arg2 != nil { @@ -223,72 +129,6 @@ func (fake *FakeServiceStore) ListRoomsReturnsOnCall(i int, result1 []*livekit.R }{result1, result2} } -func (fake *FakeServiceStore) LoadParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) { - fake.loadParticipantMutex.Lock() - ret, specificReturn := fake.loadParticipantReturnsOnCall[len(fake.loadParticipantArgsForCall)] - fake.loadParticipantArgsForCall = append(fake.loadParticipantArgsForCall, struct { - arg1 context.Context - arg2 livekit.RoomName - arg3 livekit.ParticipantIdentity - }{arg1, arg2, arg3}) - stub := fake.LoadParticipantStub - fakeReturns := fake.loadParticipantReturns - fake.recordInvocation("LoadParticipant", []interface{}{arg1, arg2, arg3}) - fake.loadParticipantMutex.Unlock() - if stub != nil { - return stub(arg1, arg2, arg3) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeServiceStore) LoadParticipantCallCount() int { - fake.loadParticipantMutex.RLock() - defer fake.loadParticipantMutex.RUnlock() - return len(fake.loadParticipantArgsForCall) -} - -func (fake *FakeServiceStore) LoadParticipantCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error)) { - fake.loadParticipantMutex.Lock() - defer fake.loadParticipantMutex.Unlock() - fake.LoadParticipantStub = stub -} - -func (fake *FakeServiceStore) LoadParticipantArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity) { - fake.loadParticipantMutex.RLock() - defer fake.loadParticipantMutex.RUnlock() - argsForCall := fake.loadParticipantArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 -} - -func (fake *FakeServiceStore) LoadParticipantReturns(result1 *livekit.ParticipantInfo, result2 error) { - fake.loadParticipantMutex.Lock() - defer fake.loadParticipantMutex.Unlock() - fake.LoadParticipantStub = nil - fake.loadParticipantReturns = struct { - result1 *livekit.ParticipantInfo - result2 error - }{result1, result2} -} - -func (fake *FakeServiceStore) LoadParticipantReturnsOnCall(i int, result1 *livekit.ParticipantInfo, result2 error) { - fake.loadParticipantMutex.Lock() - defer fake.loadParticipantMutex.Unlock() - fake.LoadParticipantStub = nil - if fake.loadParticipantReturnsOnCall == nil { - fake.loadParticipantReturnsOnCall = make(map[int]struct { - result1 *livekit.ParticipantInfo - result2 error - }) - } - fake.loadParticipantReturnsOnCall[i] = struct { - result1 *livekit.ParticipantInfo - result2 error - }{result1, result2} -} - func (fake *FakeServiceStore) LoadRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 bool) (*livekit.Room, *livekit.RoomInternal, error) { fake.loadRoomMutex.Lock() ret, specificReturn := fake.loadRoomReturnsOnCall[len(fake.loadRoomArgsForCall)] diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 2a3a7590c..7872e1e45 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -90,23 +90,23 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live } rtcEgressLauncher := NewEgressLauncher(egressClient, ioInfoService, objectStore) topicFormatter := rpc.NewTopicFormatter() - roomClient, err := rpc.NewTypedRoomClient(clientParams) + v, err := rpc.NewTypedRoomClient(clientParams) if err != nil { return nil, err } - participantClient, err := rpc.NewTypedParticipantClient(clientParams) + v2, err := rpc.NewTypedParticipantClient(clientParams) if err != nil { return nil, err } - roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, v, v2) if err != nil { return nil, err } - agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) + v3, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) if err != nil { return nil, err } - agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter, roomAllocator, router) + agentDispatchService := NewAgentDispatchService(v3, topicFormatter, roomAllocator, router) egressService := NewEgressService(egressClient, rtcEgressLauncher, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(clientParams) @@ -121,11 +121,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live } sipService := NewSIPService(sipConfig, nodeID, messageBus, sipClient, sipStore, roomService, telemetryService) rtcService := NewRTCService(conf, roomAllocator, router, telemetryService) - whipParticipantClient, err := rpc.NewTypedWHIPParticipantClient(clientParams) + v4, err := rpc.NewTypedWHIPParticipantClient(clientParams) if err != nil { return nil, err } - serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, whipParticipantClient) + serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, v4) if err != nil { return nil, err }