From 335f4c33fbefcf87ba083de64fb42db97fc99a9a Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 3 Jan 2026 03:17:58 +0530 Subject: [PATCH] Swap response sink atomically rather than closing and setting. (#4216) This helps prevent a potential WriteMessage call on a sink after it has been closed. --- pkg/rtc/participant.go | 2 +- pkg/rtc/participant_internal_test.go | 8 +- pkg/rtc/participant_signal.go | 4 +- pkg/rtc/room.go | 3 +- pkg/rtc/signalling/interfaces.go | 2 +- pkg/rtc/signalling/signallerasyncbase.go | 32 +++++--- pkg/rtc/signalling/signallerunimplemented.go | 3 +- pkg/rtc/types/interfaces.go | 2 +- .../typesfakes/fake_local_participant.go | 76 ++++++++++--------- pkg/service/wire_gen.go | 14 ++-- 10 files changed, 80 insertions(+), 66 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 442528ca7..2c2a08b97 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -392,7 +392,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.state.Store(livekit.ParticipantInfo_JOINING) p.grants.Store(params.Grants.Clone()) - p.SetResponseSink(params.Sink) + p.SwapResponseSink(params.Sink, types.SignallingCloseReasonUnknown) p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs()) if p.supervisor != nil { diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2f4a0dbfa..f4d265361 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -378,7 +378,7 @@ func TestDisableCodecs(t *testing.T) { // negotiated codec should not contain h264 sink := &routingfakes.FakeMessageSink{} - participant.SetResponseSink(sink) + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) var answer webrtc.SessionDescription var answerId uint32 var answerReceived atomic.Bool @@ -435,7 +435,7 @@ func TestDisablePublishCodec(t *testing.T) { } sink := &routingfakes.FakeMessageSink{} - 
participant.SetResponseSink(sink) + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) var publishReceived atomic.Bool sink.WriteMessageCalls(func(msg proto.Message) error { if res, ok := msg.(*livekit.SignalResponse); ok { @@ -570,7 +570,7 @@ func TestPreferMediaCodecForPublisher(t *testing.T) { offerId := uint32(23) sink := &routingfakes.FakeMessageSink{} - participant.SetResponseSink(sink) + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) var answer webrtc.SessionDescription var answerId uint32 var answerReceived atomic.Bool @@ -690,7 +690,7 @@ func TestPreferAudioCodecForRed(t *testing.T) { offerId := uint32(0xffffff) sink := &routingfakes.FakeMessageSink{} - participant.SetResponseSink(sink) + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) var answer webrtc.SessionDescription var answerId uint32 var answerReceived atomic.Bool diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go index ad08ba0f7..3fa388adc 100644 --- a/pkg/rtc/participant_signal.go +++ b/pkg/rtc/participant_signal.go @@ -27,8 +27,8 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" ) -func (p *ParticipantImpl) SetResponseSink(sink routing.MessageSink) { - p.signaller.SetResponseSink(sink) +func (p *ParticipantImpl) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { + p.signaller.SwapResponseSink(sink, reason) } func (p *ParticipantImpl) GetResponseSink() routing.MessageSink { diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index fa41c7aa3..97a6019aa 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -551,8 +551,7 @@ func (r *Room) ResumeParticipant( ) error { r.ReplaceParticipantRequestSource(p.Identity(), requestSource) // close previous sink, and link to new one - p.CloseSignalConnection(types.SignallingCloseReasonResume) - p.SetResponseSink(responseSink) + p.SwapResponseSink(responseSink, types.SignallingCloseReasonResume) p.SetSignalSourceValid(true) diff --git 
a/pkg/rtc/signalling/interfaces.go b/pkg/rtc/signalling/interfaces.go index d193aa4fd..2470e9d31 100644 --- a/pkg/rtc/signalling/interfaces.go +++ b/pkg/rtc/signalling/interfaces.go @@ -28,7 +28,7 @@ type ParticipantSignalHandler interface { } type ParticipantSignaller interface { - SetResponseSink(sink routing.MessageSink) + SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) GetResponseSink() routing.MessageSink CloseSignalConnection(reason types.SignallingCloseReason) diff --git a/pkg/rtc/signalling/signallerasyncbase.go b/pkg/rtc/signalling/signallerasyncbase.go index 271b3618a..0b4d2b665 100644 --- a/pkg/rtc/signalling/signallerasyncbase.go +++ b/pkg/rtc/signalling/signallerasyncbase.go @@ -42,10 +42,29 @@ func newSignallerAsyncBase(params signallerAsyncBaseParams) *signallerAsyncBase } } -func (s *signallerAsyncBase) SetResponseSink(sink routing.MessageSink) { +func (s *signallerAsyncBase) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { s.resSinkMu.Lock() - defer s.resSinkMu.Unlock() + oldSink := s.resSink s.resSink = sink + s.resSinkMu.Unlock() + + if oldSink != nil { + if sink != nil { + s.params.Logger.Debugw( + "swapping signal connection", + "reason", reason, + "connID", oldSink.ConnectionID(), + "newConnID", sink.ConnectionID(), + ) + } else { + s.params.Logger.Debugw( + "closing signal connection", + "reason", reason, + "connID", oldSink.ConnectionID(), + ) + } + oldSink.Close() + } } func (s *signallerAsyncBase) GetResponseSink() routing.MessageSink { @@ -56,12 +75,5 @@ func (s *signallerAsyncBase) GetResponseSink() routing.MessageSink { // closes signal connection to notify client to resume/reconnect func (s *signallerAsyncBase) CloseSignalConnection(reason types.SignallingCloseReason) { - sink := s.GetResponseSink() - if sink == nil { - return - } - - s.params.Logger.Debugw("closing signal connection", "reason", reason, "connID", sink.ConnectionID()) - sink.Close() - s.SetResponseSink(nil) 
+ s.SwapResponseSink(nil, reason) } diff --git a/pkg/rtc/signalling/signallerunimplemented.go b/pkg/rtc/signalling/signallerunimplemented.go index 4c4be8bff..3592fd857 100644 --- a/pkg/rtc/signalling/signallerunimplemented.go +++ b/pkg/rtc/signalling/signallerunimplemented.go @@ -25,7 +25,8 @@ var _ ParticipantSignaller = (*signallerUnimplemented)(nil) type signallerUnimplemented struct{} -func (u *signallerUnimplemented) SetResponseSink(sink routing.MessageSink) {} +func (u *signallerUnimplemented) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { +} func (u *signallerUnimplemented) GetResponseSink() routing.MessageSink { return nil diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 5280b7498..ec8b00e7f 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -400,7 +400,7 @@ type LocalParticipant interface { SupportsMoving() error GetLastReliableSequence(migrateOut bool) uint32 - SetResponseSink(sink routing.MessageSink) + SwapResponseSink(sink routing.MessageSink, reason SignallingCloseReason) GetResponseSink() routing.MessageSink CloseSignalConnection(reason SignallingCloseReason) UpdateLastSeenSignal() diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 511b7f444..494d685fa 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -1163,11 +1163,6 @@ type FakeLocalParticipant struct { setPermissionReturnsOnCall map[int]struct { result1 bool } - SetResponseSinkStub func(routing.MessageSink) - setResponseSinkMutex sync.RWMutex - setResponseSinkArgsForCall []struct { - arg1 routing.MessageSink - } SetSignalSourceValidStub func(bool) setSignalSourceValidMutex sync.RWMutex setSignalSourceValidArgsForCall []struct { @@ -1288,6 +1283,12 @@ type FakeLocalParticipant struct { supportsTransceiverReuseReturnsOnCall map[int]struct { result1 bool } + 
SwapResponseSinkStub func(routing.MessageSink, types.SignallingCloseReason) + swapResponseSinkMutex sync.RWMutex + swapResponseSinkArgsForCall []struct { + arg1 routing.MessageSink + arg2 types.SignallingCloseReason + } TelemetryGuardStub func() *telemetry.ReferenceGuard telemetryGuardMutex sync.RWMutex telemetryGuardArgsForCall []struct { @@ -7635,38 +7636,6 @@ func (fake *FakeLocalParticipant) SetPermissionReturnsOnCall(i int, result1 bool }{result1} } -func (fake *FakeLocalParticipant) SetResponseSink(arg1 routing.MessageSink) { - fake.setResponseSinkMutex.Lock() - fake.setResponseSinkArgsForCall = append(fake.setResponseSinkArgsForCall, struct { - arg1 routing.MessageSink - }{arg1}) - stub := fake.SetResponseSinkStub - fake.recordInvocation("SetResponseSink", []interface{}{arg1}) - fake.setResponseSinkMutex.Unlock() - if stub != nil { - fake.SetResponseSinkStub(arg1) - } -} - -func (fake *FakeLocalParticipant) SetResponseSinkCallCount() int { - fake.setResponseSinkMutex.RLock() - defer fake.setResponseSinkMutex.RUnlock() - return len(fake.setResponseSinkArgsForCall) -} - -func (fake *FakeLocalParticipant) SetResponseSinkCalls(stub func(routing.MessageSink)) { - fake.setResponseSinkMutex.Lock() - defer fake.setResponseSinkMutex.Unlock() - fake.SetResponseSinkStub = stub -} - -func (fake *FakeLocalParticipant) SetResponseSinkArgsForCall(i int) routing.MessageSink { - fake.setResponseSinkMutex.RLock() - defer fake.setResponseSinkMutex.RUnlock() - argsForCall := fake.setResponseSinkArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeLocalParticipant) SetSignalSourceValid(arg1 bool) { fake.setSignalSourceValidMutex.Lock() fake.setSignalSourceValidArgsForCall = append(fake.setSignalSourceValidArgsForCall, struct { @@ -8317,6 +8286,39 @@ func (fake *FakeLocalParticipant) SupportsTransceiverReuseReturnsOnCall(i int, r }{result1} } +func (fake *FakeLocalParticipant) SwapResponseSink(arg1 routing.MessageSink, arg2 types.SignallingCloseReason) { + 
fake.swapResponseSinkMutex.Lock() + fake.swapResponseSinkArgsForCall = append(fake.swapResponseSinkArgsForCall, struct { + arg1 routing.MessageSink + arg2 types.SignallingCloseReason + }{arg1, arg2}) + stub := fake.SwapResponseSinkStub + fake.recordInvocation("SwapResponseSink", []interface{}{arg1, arg2}) + fake.swapResponseSinkMutex.Unlock() + if stub != nil { + fake.SwapResponseSinkStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) SwapResponseSinkCallCount() int { + fake.swapResponseSinkMutex.RLock() + defer fake.swapResponseSinkMutex.RUnlock() + return len(fake.swapResponseSinkArgsForCall) +} + +func (fake *FakeLocalParticipant) SwapResponseSinkCalls(stub func(routing.MessageSink, types.SignallingCloseReason)) { + fake.swapResponseSinkMutex.Lock() + defer fake.swapResponseSinkMutex.Unlock() + fake.SwapResponseSinkStub = stub +} + +func (fake *FakeLocalParticipant) SwapResponseSinkArgsForCall(i int) (routing.MessageSink, types.SignallingCloseReason) { + fake.swapResponseSinkMutex.RLock() + defer fake.swapResponseSinkMutex.RUnlock() + argsForCall := fake.swapResponseSinkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeLocalParticipant) TelemetryGuard() *telemetry.ReferenceGuard { fake.telemetryGuardMutex.Lock() ret, specificReturn := fake.telemetryGuardReturnsOnCall[len(fake.telemetryGuardArgsForCall)] diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index b33cc1744..22b1e0c1c 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -89,23 +89,23 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live } rtcEgressLauncher := NewEgressLauncher(egressClient, ioInfoService, objectStore) topicFormatter := rpc.NewTopicFormatter() - roomClient, err := rpc.NewTypedRoomClient(clientParams) + v, err := rpc.NewTypedRoomClient(clientParams) if err != nil { return nil, err } - participantClient, err := rpc.NewTypedParticipantClient(clientParams) + v2, err := 
rpc.NewTypedParticipantClient(clientParams) if err != nil { return nil, err } - roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, v, v2) if err != nil { return nil, err } - agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) + v3, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) if err != nil { return nil, err } - agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter, roomAllocator, router) + agentDispatchService := NewAgentDispatchService(v3, topicFormatter, roomAllocator, router) egressService := NewEgressService(egressClient, rtcEgressLauncher, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(clientParams) @@ -120,11 +120,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live } sipService := NewSIPService(sipConfig, nodeID, messageBus, sipClient, sipStore, roomService, telemetryService) rtcService := NewRTCService(conf, roomAllocator, router, telemetryService) - whipParticipantClient, err := rpc.NewTypedWHIPParticipantClient(clientParams) + v4, err := rpc.NewTypedWHIPParticipantClient(clientParams) if err != nil { return nil, err } - serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, whipParticipantClient) + serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, v4) if err != nil { return nil, err }