diff --git a/pkg/routing/routingfakes/fake_signal_client.go b/pkg/routing/routingfakes/fake_signal_client.go index 884dac4d5..0562b7c44 100644 --- a/pkg/routing/routingfakes/fake_signal_client.go +++ b/pkg/routing/routingfakes/fake_signal_client.go @@ -10,6 +10,16 @@ import ( ) type FakeSignalClient struct { + ActiveCountStub func() int + activeCountMutex sync.RWMutex + activeCountArgsForCall []struct { + } + activeCountReturns struct { + result1 int + } + activeCountReturnsOnCall map[int]struct { + result1 int + } StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) startParticipantSignalMutex sync.RWMutex startParticipantSignalArgsForCall []struct { @@ -34,6 +44,59 @@ type FakeSignalClient struct { invocationsMutex sync.RWMutex } +func (fake *FakeSignalClient) ActiveCount() int { + fake.activeCountMutex.Lock() + ret, specificReturn := fake.activeCountReturnsOnCall[len(fake.activeCountArgsForCall)] + fake.activeCountArgsForCall = append(fake.activeCountArgsForCall, struct { + }{}) + stub := fake.ActiveCountStub + fakeReturns := fake.activeCountReturns + fake.recordInvocation("ActiveCount", []interface{}{}) + fake.activeCountMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSignalClient) ActiveCountCallCount() int { + fake.activeCountMutex.RLock() + defer fake.activeCountMutex.RUnlock() + return len(fake.activeCountArgsForCall) +} + +func (fake *FakeSignalClient) ActiveCountCalls(stub func() int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = stub +} + +func (fake *FakeSignalClient) ActiveCountReturns(result1 int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = nil + fake.activeCountReturns = struct { + result1 int + }{result1} +} + +func (fake *FakeSignalClient) ActiveCountReturnsOnCall(i int, result1 int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = nil + if fake.activeCountReturnsOnCall == nil { + fake.activeCountReturnsOnCall = make(map[int]struct { + result1 int + }) + } + fake.activeCountReturnsOnCall[i] = struct { + result1 int + }{result1} +} + func (fake *FakeSignalClient) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) { fake.startParticipantSignalMutex.Lock() ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] @@ -110,6 +173,8 @@ func (fake *FakeSignalClient) StartParticipantSignalReturnsOnCall(i int, result1 func (fake *FakeSignalClient) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.activeCountMutex.RLock() + defer fake.activeCountMutex.RUnlock() fake.startParticipantSignalMutex.RLock() defer fake.startParticipantSignalMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/routing/signal.go b/pkg/routing/signal.go index 7bbb4df93..4d0f8b812 100644 --- a/pkg/routing/signal.go +++ b/pkg/routing/signal.go @@ -3,6 +3,7 @@ package routing import ( "context" + "go.uber.org/atomic" "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/config" @@ -18,6 +19,7 @@ import ( //counterfeiter:generate . SignalClient type SignalClient interface { + ActiveCount() int StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) } @@ -25,6 +27,7 @@ type signalClient struct { nodeID livekit.NodeID config config.SignalRelayConfig client rpc.TypedSignalClient + active atomic.Int32 } func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus, config config.SignalRelayConfig) (SignalClient, error) { @@ -45,6 +48,10 @@ func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus, config config. }, nil } +func (r *signalClient) ActiveCount() int { + return int(r.active.Load()) +} + func (r *signalClient) StartParticipantSignal( ctx context.Context, roomName livekit.RoomName, @@ -84,6 +91,9 @@ func (r *signalClient) StartParticipantSignal( resChan := NewDefaultMessageChannel() go func() { + r.active.Inc() + defer r.active.Dec() + var err error for msg := range stream.Channel() { if err = resChan.WriteMessage(msg.Response); err != nil {