From d70a8e366c20280a4e95f7ee75c12a146ed54aa0 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Wed, 10 Jan 2024 03:31:54 -0800 Subject: [PATCH] inject logger constructor into signal server (#2372) * inject logger constructor into signal server * tidy * tidy * test --- .../servicefakes/fake_session_handler.go | 199 ++++++++++++++++++ pkg/service/signal.go | 95 +++++---- pkg/service/signal_test.go | 69 +++--- 3 files changed, 295 insertions(+), 68 deletions(-) create mode 100644 pkg/service/servicefakes/fake_session_handler.go diff --git a/pkg/service/servicefakes/fake_session_handler.go b/pkg/service/servicefakes/fake_session_handler.go new file mode 100644 index 000000000..552386918 --- /dev/null +++ b/pkg/service/servicefakes/fake_session_handler.go @@ -0,0 +1,199 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type FakeSessionHandler struct { + HandleSessionStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error + handleSessionMutex sync.RWMutex + handleSessionArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.ConnectionID + arg5 routing.MessageSource + arg6 routing.MessageSink + } + handleSessionReturns struct { + result1 error + } + handleSessionReturnsOnCall map[int]struct { + result1 error + } + LoggerStub func(context.Context) logger.Logger + loggerMutex sync.RWMutex + loggerArgsForCall []struct { + arg1 context.Context + } + loggerReturns struct { + result1 logger.Logger + } + loggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSessionHandler) HandleSession(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.ConnectionID, arg5 routing.MessageSource, arg6 routing.MessageSink) error { + fake.handleSessionMutex.Lock() + ret, specificReturn := fake.handleSessionReturnsOnCall[len(fake.handleSessionArgsForCall)] + fake.handleSessionArgsForCall = append(fake.handleSessionArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.ConnectionID + arg5 routing.MessageSource + arg6 routing.MessageSink + }{arg1, arg2, arg3, arg4, arg5, arg6}) + stub := fake.HandleSessionStub + fakeReturns := fake.handleSessionReturns + fake.recordInvocation("HandleSession", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6}) + fake.handleSessionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4, arg5, arg6) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) HandleSessionCallCount() int { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + return len(fake.handleSessionArgsForCall) +} + +func (fake *FakeSessionHandler) HandleSessionCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = stub +} + +func (fake *FakeSessionHandler) HandleSessionArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + argsForCall := fake.handleSessionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6 +} + +func (fake *FakeSessionHandler) HandleSessionReturns(result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + fake.handleSessionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) HandleSessionReturnsOnCall(i int, result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + if fake.handleSessionReturnsOnCall == nil { + fake.handleSessionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSessionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) Logger(arg1 context.Context) logger.Logger { + fake.loggerMutex.Lock() + ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)] + fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.LoggerStub + fakeReturns := fake.loggerReturns + fake.recordInvocation("Logger", []interface{}{arg1}) + fake.loggerMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) LoggerCallCount() int { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + return len(fake.loggerArgsForCall) +} + +func (fake *FakeSessionHandler) LoggerCalls(stub func(context.Context) logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = stub +} + +func (fake *FakeSessionHandler) LoggerArgsForCall(i int) context.Context { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + argsForCall := fake.loggerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSessionHandler) LoggerReturns(result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + fake.loggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) LoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + if fake.loggerReturnsOnCall == nil { + fake.loggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.loggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSessionHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.SessionHandler = new(FakeSessionHandler) diff --git a/pkg/service/signal.go b/pkg/service/signal.go index a181987e9..a0abc5952 100644 --- a/pkg/service/signal.go +++ b/pkg/service/signal.go @@ -31,14 +31,21 @@ import ( "github.com/livekit/psrpc/pkg/middleware" ) -type SessionHandler func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, -) error +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . SessionHandler +type SessionHandler interface { + Logger(ctx context.Context) logger.Logger + + HandleSession( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error +} type SignalServer struct { server rpc.TypedSignalServer @@ -72,43 +79,53 @@ func NewDefaultSignalServer( router routing.Router, roomManager *RoomManager, ) (r *SignalServer, err error) { - sessionHandler := func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - prometheus.IncrementParticipantRtcInit(1) + return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, &defaultSessionHandler{currentNode, router, roomManager}) +} - if rr, ok := router.(*routing.RedisRouter); ok { - rtcNode, err := router.GetNodeForRoom(ctx, roomName) - if err != nil { - return err - } +type defaultSessionHandler struct { + currentNode routing.LocalNode + router routing.Router + roomManager *RoomManager +} - if rtcNode.Id != currentNode.Id { - err = routing.ErrIncorrectRTCNode - logger.Errorw("called participant on incorrect node", err, - "rtcNode", rtcNode, - ) - return err - } +func (s *defaultSessionHandler) Logger(ctx context.Context) logger.Logger { + return logger.GetLogger() +} - pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity) - pKeyB62 := routing.ParticipantKey(roomName, pi.Identity) +func (s *defaultSessionHandler) HandleSession( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, +) error { + prometheus.IncrementParticipantRtcInit(1) - // RTC session should start on this node - if err := rr.SetParticipantRTCNode(pKey, pKeyB62, currentNode.Id); err != nil { - return err - } + if rr, ok := s.router.(*routing.RedisRouter); ok { + rtcNode, err := s.router.GetNodeForRoom(ctx, roomName) + if err != nil { + return err } - return roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink) + if rtcNode.Id != s.currentNode.Id { + err = routing.ErrIncorrectRTCNode + logger.Errorw("called participant on incorrect node", err, + "rtcNode", rtcNode, + ) + return err + } + + pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity) + pKeyB62 := routing.ParticipantKey(roomName, pi.Identity) + + // RTC session should start on this node + if err := rr.SetParticipantRTCNode(pKey, pKeyB62, s.currentNode.Id); err != nil { + return err + } } - return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, sessionHandler) + return s.roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink) } func (s *SignalServer) Start() error { @@ -142,7 +159,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe return errors.Wrap(err, "failed to read participant from session") } - l := logger.GetLogger().WithValues( + l := r.sessionHandler.Logger(stream.Context()).WithValues( "room", ss.RoomName, "participant", ss.Identity, "connID", ss.ConnectionId, @@ -175,7 +192,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe // copy the incoming rpc headers to avoid dropping any session vars. ctx := metadata.NewContextWithIncomingHeader(context.Background(), metadata.IncomingHeader(stream.Context())) - err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) + err = r.sessionHandler.HandleSession(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) if err != nil { sink.Close() l.Errorw("could not handle new participant", err) diff --git a/pkg/service/signal_test.go b/pkg/service/signal_test.go index 83bf391e9..a80935182 100644 --- a/pkg/service/signal_test.go +++ b/pkg/service/signal_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package service +package service_test import ( "context" @@ -26,8 +26,11 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/service/servicefakes" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" "github.com/livekit/psrpc" ) @@ -61,22 +64,26 @@ func TestSignal(t *testing.T) { client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) require.NoError(t, err) - server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - go func() { - reqMessageOut = <-requestSource.ReadChan() - resErr = responseSink.WriteMessage(resMessageIn) - responseSink.Close() - close(done) - }() - return nil - }) + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + go func() { + reqMessageOut = <-requestSource.ReadChan() + resErr = responseSink.WriteMessage(resMessageIn) + responseSink.Close() + close(done) + }() + return nil + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) require.NoError(t, err) err = server.Start() @@ -114,18 +121,22 @@ func TestSignal(t *testing.T) { client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) require.NoError(t, err) - server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( - ctx context.Context, - roomName livekit.RoomName, - pi routing.ParticipantInit, - connectionID livekit.ConnectionID, - requestSource routing.MessageSource, - responseSink routing.MessageSink, - ) error { - defer close(done) - resErr = responseSink.WriteMessage(resMessageIn) - return errors.New("start session failed") - }) + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + defer close(done) + resErr = responseSink.WriteMessage(resMessageIn) + return errors.New("start session failed") + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) require.NoError(t, err) err = server.Start()