Swap result sink atomically rather than closing and setting. (#4216)

To prevent WriteMessage after Close potentially.
This commit is contained in:
Raja Subramanian
2026-01-03 03:17:58 +05:30
committed by GitHub
parent 46651c1978
commit 335f4c33fb
10 changed files with 80 additions and 66 deletions
+1 -1
View File
@@ -392,7 +392,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) {
p.state.Store(livekit.ParticipantInfo_JOINING)
p.grants.Store(params.Grants.Clone())
p.SetResponseSink(params.Sink)
p.SwapResponseSink(params.Sink, types.SignallingCloseReasonUnknown)
p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs())
if p.supervisor != nil {
+4 -4
View File
@@ -378,7 +378,7 @@ func TestDisableCodecs(t *testing.T) {
// negotiated codec should not contain h264
sink := &routingfakes.FakeMessageSink{}
participant.SetResponseSink(sink)
participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown)
var answer webrtc.SessionDescription
var answerId uint32
var answerReceived atomic.Bool
@@ -435,7 +435,7 @@ func TestDisablePublishCodec(t *testing.T) {
}
sink := &routingfakes.FakeMessageSink{}
participant.SetResponseSink(sink)
participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown)
var publishReceived atomic.Bool
sink.WriteMessageCalls(func(msg proto.Message) error {
if res, ok := msg.(*livekit.SignalResponse); ok {
@@ -570,7 +570,7 @@ func TestPreferMediaCodecForPublisher(t *testing.T) {
offerId := uint32(23)
sink := &routingfakes.FakeMessageSink{}
participant.SetResponseSink(sink)
participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown)
var answer webrtc.SessionDescription
var answerId uint32
var answerReceived atomic.Bool
@@ -690,7 +690,7 @@ func TestPreferAudioCodecForRed(t *testing.T) {
offerId := uint32(0xffffff)
sink := &routingfakes.FakeMessageSink{}
participant.SetResponseSink(sink)
participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown)
var answer webrtc.SessionDescription
var answerId uint32
var answerReceived atomic.Bool
+2 -2
View File
@@ -27,8 +27,8 @@ import (
"github.com/livekit/livekit-server/pkg/rtc/types"
)
func (p *ParticipantImpl) SetResponseSink(sink routing.MessageSink) {
p.signaller.SetResponseSink(sink)
func (p *ParticipantImpl) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) {
p.signaller.SwapResponseSink(sink, reason)
}
func (p *ParticipantImpl) GetResponseSink() routing.MessageSink {
+1 -2
View File
@@ -551,8 +551,7 @@ func (r *Room) ResumeParticipant(
) error {
r.ReplaceParticipantRequestSource(p.Identity(), requestSource)
// close previous sink, and link to new one
p.CloseSignalConnection(types.SignallingCloseReasonResume)
p.SetResponseSink(responseSink)
p.SwapResponseSink(responseSink, types.SignallingCloseReasonResume)
p.SetSignalSourceValid(true)
+1 -1
View File
@@ -28,7 +28,7 @@ type ParticipantSignalHandler interface {
}
type ParticipantSignaller interface {
SetResponseSink(sink routing.MessageSink)
SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason)
GetResponseSink() routing.MessageSink
CloseSignalConnection(reason types.SignallingCloseReason)
+22 -10
View File
@@ -42,10 +42,29 @@ func newSignallerAsyncBase(params signallerAsyncBaseParams) *signallerAsyncBase
}
}
func (s *signallerAsyncBase) SetResponseSink(sink routing.MessageSink) {
func (s *signallerAsyncBase) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) {
s.resSinkMu.Lock()
defer s.resSinkMu.Unlock()
oldSink := s.resSink
s.resSink = sink
s.resSinkMu.Unlock()
if oldSink != nil {
if sink != nil {
s.params.Logger.Debugw(
"swapping signal connection",
"reason", reason,
"connID", oldSink.ConnectionID(),
"newConnID", sink.ConnectionID(),
)
} else {
s.params.Logger.Debugw(
"closing signal connection",
"reason", reason,
"connID", oldSink.ConnectionID(),
)
}
oldSink.Close()
}
}
func (s *signallerAsyncBase) GetResponseSink() routing.MessageSink {
@@ -56,12 +75,5 @@ func (s *signallerAsyncBase) GetResponseSink() routing.MessageSink {
// closes signal connection to notify client to resume/reconnect
func (s *signallerAsyncBase) CloseSignalConnection(reason types.SignallingCloseReason) {
sink := s.GetResponseSink()
if sink == nil {
return
}
s.params.Logger.Debugw("closing signal connection", "reason", reason, "connID", sink.ConnectionID())
sink.Close()
s.SetResponseSink(nil)
s.SwapResponseSink(nil, reason)
}
+2 -1
View File
@@ -25,7 +25,8 @@ var _ ParticipantSignaller = (*signallerUnimplemented)(nil)
type signallerUnimplemented struct{}
func (u *signallerUnimplemented) SetResponseSink(sink routing.MessageSink) {}
func (u *signallerUnimplemented) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) {
}
func (u *signallerUnimplemented) GetResponseSink() routing.MessageSink {
return nil
+1 -1
View File
@@ -400,7 +400,7 @@ type LocalParticipant interface {
SupportsMoving() error
GetLastReliableSequence(migrateOut bool) uint32
SetResponseSink(sink routing.MessageSink)
SwapResponseSink(sink routing.MessageSink, reason SignallingCloseReason)
GetResponseSink() routing.MessageSink
CloseSignalConnection(reason SignallingCloseReason)
UpdateLastSeenSignal()
@@ -1163,11 +1163,6 @@ type FakeLocalParticipant struct {
setPermissionReturnsOnCall map[int]struct {
result1 bool
}
SetResponseSinkStub func(routing.MessageSink)
setResponseSinkMutex sync.RWMutex
setResponseSinkArgsForCall []struct {
arg1 routing.MessageSink
}
SetSignalSourceValidStub func(bool)
setSignalSourceValidMutex sync.RWMutex
setSignalSourceValidArgsForCall []struct {
@@ -1288,6 +1283,12 @@ type FakeLocalParticipant struct {
supportsTransceiverReuseReturnsOnCall map[int]struct {
result1 bool
}
SwapResponseSinkStub func(routing.MessageSink, types.SignallingCloseReason)
swapResponseSinkMutex sync.RWMutex
swapResponseSinkArgsForCall []struct {
arg1 routing.MessageSink
arg2 types.SignallingCloseReason
}
TelemetryGuardStub func() *telemetry.ReferenceGuard
telemetryGuardMutex sync.RWMutex
telemetryGuardArgsForCall []struct {
@@ -7635,38 +7636,6 @@ func (fake *FakeLocalParticipant) SetPermissionReturnsOnCall(i int, result1 bool
}{result1}
}
func (fake *FakeLocalParticipant) SetResponseSink(arg1 routing.MessageSink) {
fake.setResponseSinkMutex.Lock()
fake.setResponseSinkArgsForCall = append(fake.setResponseSinkArgsForCall, struct {
arg1 routing.MessageSink
}{arg1})
stub := fake.SetResponseSinkStub
fake.recordInvocation("SetResponseSink", []interface{}{arg1})
fake.setResponseSinkMutex.Unlock()
if stub != nil {
fake.SetResponseSinkStub(arg1)
}
}
func (fake *FakeLocalParticipant) SetResponseSinkCallCount() int {
fake.setResponseSinkMutex.RLock()
defer fake.setResponseSinkMutex.RUnlock()
return len(fake.setResponseSinkArgsForCall)
}
func (fake *FakeLocalParticipant) SetResponseSinkCalls(stub func(routing.MessageSink)) {
fake.setResponseSinkMutex.Lock()
defer fake.setResponseSinkMutex.Unlock()
fake.SetResponseSinkStub = stub
}
func (fake *FakeLocalParticipant) SetResponseSinkArgsForCall(i int) routing.MessageSink {
fake.setResponseSinkMutex.RLock()
defer fake.setResponseSinkMutex.RUnlock()
argsForCall := fake.setResponseSinkArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeLocalParticipant) SetSignalSourceValid(arg1 bool) {
fake.setSignalSourceValidMutex.Lock()
fake.setSignalSourceValidArgsForCall = append(fake.setSignalSourceValidArgsForCall, struct {
@@ -8317,6 +8286,39 @@ func (fake *FakeLocalParticipant) SupportsTransceiverReuseReturnsOnCall(i int, r
}{result1}
}
func (fake *FakeLocalParticipant) SwapResponseSink(arg1 routing.MessageSink, arg2 types.SignallingCloseReason) {
fake.swapResponseSinkMutex.Lock()
fake.swapResponseSinkArgsForCall = append(fake.swapResponseSinkArgsForCall, struct {
arg1 routing.MessageSink
arg2 types.SignallingCloseReason
}{arg1, arg2})
stub := fake.SwapResponseSinkStub
fake.recordInvocation("SwapResponseSink", []interface{}{arg1, arg2})
fake.swapResponseSinkMutex.Unlock()
if stub != nil {
fake.SwapResponseSinkStub(arg1, arg2)
}
}
func (fake *FakeLocalParticipant) SwapResponseSinkCallCount() int {
fake.swapResponseSinkMutex.RLock()
defer fake.swapResponseSinkMutex.RUnlock()
return len(fake.swapResponseSinkArgsForCall)
}
func (fake *FakeLocalParticipant) SwapResponseSinkCalls(stub func(routing.MessageSink, types.SignallingCloseReason)) {
fake.swapResponseSinkMutex.Lock()
defer fake.swapResponseSinkMutex.Unlock()
fake.SwapResponseSinkStub = stub
}
func (fake *FakeLocalParticipant) SwapResponseSinkArgsForCall(i int) (routing.MessageSink, types.SignallingCloseReason) {
fake.swapResponseSinkMutex.RLock()
defer fake.swapResponseSinkMutex.RUnlock()
argsForCall := fake.swapResponseSinkArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2
}
func (fake *FakeLocalParticipant) TelemetryGuard() *telemetry.ReferenceGuard {
fake.telemetryGuardMutex.Lock()
ret, specificReturn := fake.telemetryGuardReturnsOnCall[len(fake.telemetryGuardArgsForCall)]
+7 -7
View File
@@ -89,23 +89,23 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
}
rtcEgressLauncher := NewEgressLauncher(egressClient, ioInfoService, objectStore)
topicFormatter := rpc.NewTopicFormatter()
roomClient, err := rpc.NewTypedRoomClient(clientParams)
v, err := rpc.NewTypedRoomClient(clientParams)
if err != nil {
return nil, err
}
participantClient, err := rpc.NewTypedParticipantClient(clientParams)
v2, err := rpc.NewTypedParticipantClient(clientParams)
if err != nil {
return nil, err
}
roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient)
roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, v, v2)
if err != nil {
return nil, err
}
agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams)
v3, err := rpc.NewTypedAgentDispatchInternalClient(clientParams)
if err != nil {
return nil, err
}
agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter, roomAllocator, router)
agentDispatchService := NewAgentDispatchService(v3, topicFormatter, roomAllocator, router)
egressService := NewEgressService(egressClient, rtcEgressLauncher, ioInfoService, roomService)
ingressConfig := getIngressConfig(conf)
ingressClient, err := rpc.NewIngressClient(clientParams)
@@ -120,11 +120,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
}
sipService := NewSIPService(sipConfig, nodeID, messageBus, sipClient, sipStore, roomService, telemetryService)
rtcService := NewRTCService(conf, roomAllocator, router, telemetryService)
whipParticipantClient, err := rpc.NewTypedWHIPParticipantClient(clientParams)
v4, err := rpc.NewTypedWHIPParticipantClient(clientParams)
if err != nil {
return nil, err
}
serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, whipParticipantClient)
serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, v4)
if err != nil {
return nil, err
}