inject logger constructor into signal server (#2372)

* inject logger constructor into signal server

* tidy

* tidy

* test
This commit is contained in:
Paul Wells
2024-01-10 03:31:54 -08:00
committed by GitHub
parent 4e30e1a86d
commit d70a8e366c
3 changed files with 295 additions and 68 deletions

View File

@@ -0,0 +1,199 @@
// Code generated by counterfeiter. DO NOT EDIT.
package servicefakes
import (
"context"
"sync"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/service"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
)
type FakeSessionHandler struct {
HandleSessionStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error
handleSessionMutex sync.RWMutex
handleSessionArgsForCall []struct {
arg1 context.Context
arg2 livekit.RoomName
arg3 routing.ParticipantInit
arg4 livekit.ConnectionID
arg5 routing.MessageSource
arg6 routing.MessageSink
}
handleSessionReturns struct {
result1 error
}
handleSessionReturnsOnCall map[int]struct {
result1 error
}
LoggerStub func(context.Context) logger.Logger
loggerMutex sync.RWMutex
loggerArgsForCall []struct {
arg1 context.Context
}
loggerReturns struct {
result1 logger.Logger
}
loggerReturnsOnCall map[int]struct {
result1 logger.Logger
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeSessionHandler) HandleSession(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.ConnectionID, arg5 routing.MessageSource, arg6 routing.MessageSink) error {
fake.handleSessionMutex.Lock()
ret, specificReturn := fake.handleSessionReturnsOnCall[len(fake.handleSessionArgsForCall)]
fake.handleSessionArgsForCall = append(fake.handleSessionArgsForCall, struct {
arg1 context.Context
arg2 livekit.RoomName
arg3 routing.ParticipantInit
arg4 livekit.ConnectionID
arg5 routing.MessageSource
arg6 routing.MessageSink
}{arg1, arg2, arg3, arg4, arg5, arg6})
stub := fake.HandleSessionStub
fakeReturns := fake.handleSessionReturns
fake.recordInvocation("HandleSession", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6})
fake.handleSessionMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3, arg4, arg5, arg6)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeSessionHandler) HandleSessionCallCount() int {
fake.handleSessionMutex.RLock()
defer fake.handleSessionMutex.RUnlock()
return len(fake.handleSessionArgsForCall)
}
func (fake *FakeSessionHandler) HandleSessionCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error) {
fake.handleSessionMutex.Lock()
defer fake.handleSessionMutex.Unlock()
fake.HandleSessionStub = stub
}
func (fake *FakeSessionHandler) HandleSessionArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) {
fake.handleSessionMutex.RLock()
defer fake.handleSessionMutex.RUnlock()
argsForCall := fake.handleSessionArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6
}
func (fake *FakeSessionHandler) HandleSessionReturns(result1 error) {
fake.handleSessionMutex.Lock()
defer fake.handleSessionMutex.Unlock()
fake.HandleSessionStub = nil
fake.handleSessionReturns = struct {
result1 error
}{result1}
}
func (fake *FakeSessionHandler) HandleSessionReturnsOnCall(i int, result1 error) {
fake.handleSessionMutex.Lock()
defer fake.handleSessionMutex.Unlock()
fake.HandleSessionStub = nil
if fake.handleSessionReturnsOnCall == nil {
fake.handleSessionReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.handleSessionReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeSessionHandler) Logger(arg1 context.Context) logger.Logger {
fake.loggerMutex.Lock()
ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)]
fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct {
arg1 context.Context
}{arg1})
stub := fake.LoggerStub
fakeReturns := fake.loggerReturns
fake.recordInvocation("Logger", []interface{}{arg1})
fake.loggerMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeSessionHandler) LoggerCallCount() int {
fake.loggerMutex.RLock()
defer fake.loggerMutex.RUnlock()
return len(fake.loggerArgsForCall)
}
func (fake *FakeSessionHandler) LoggerCalls(stub func(context.Context) logger.Logger) {
fake.loggerMutex.Lock()
defer fake.loggerMutex.Unlock()
fake.LoggerStub = stub
}
func (fake *FakeSessionHandler) LoggerArgsForCall(i int) context.Context {
fake.loggerMutex.RLock()
defer fake.loggerMutex.RUnlock()
argsForCall := fake.loggerArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeSessionHandler) LoggerReturns(result1 logger.Logger) {
fake.loggerMutex.Lock()
defer fake.loggerMutex.Unlock()
fake.LoggerStub = nil
fake.loggerReturns = struct {
result1 logger.Logger
}{result1}
}
func (fake *FakeSessionHandler) LoggerReturnsOnCall(i int, result1 logger.Logger) {
fake.loggerMutex.Lock()
defer fake.loggerMutex.Unlock()
fake.LoggerStub = nil
if fake.loggerReturnsOnCall == nil {
fake.loggerReturnsOnCall = make(map[int]struct {
result1 logger.Logger
})
}
fake.loggerReturnsOnCall[i] = struct {
result1 logger.Logger
}{result1}
}
func (fake *FakeSessionHandler) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
fake.handleSessionMutex.RLock()
defer fake.handleSessionMutex.RUnlock()
fake.loggerMutex.RLock()
defer fake.loggerMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeSessionHandler) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ service.SessionHandler = new(FakeSessionHandler)

View File

@@ -31,14 +31,21 @@ import (
"github.com/livekit/psrpc/pkg/middleware"
)
type SessionHandler func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
//counterfeiter:generate . SessionHandler
type SessionHandler interface {
Logger(ctx context.Context) logger.Logger
HandleSession(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error
}
type SignalServer struct {
server rpc.TypedSignalServer
@@ -72,43 +79,53 @@ func NewDefaultSignalServer(
router routing.Router,
roomManager *RoomManager,
) (r *SignalServer, err error) {
sessionHandler := func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
prometheus.IncrementParticipantRtcInit(1)
return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, &defaultSessionHandler{currentNode, router, roomManager})
}
if rr, ok := router.(*routing.RedisRouter); ok {
rtcNode, err := router.GetNodeForRoom(ctx, roomName)
if err != nil {
return err
}
type defaultSessionHandler struct {
currentNode routing.LocalNode
router routing.Router
roomManager *RoomManager
}
if rtcNode.Id != currentNode.Id {
err = routing.ErrIncorrectRTCNode
logger.Errorw("called participant on incorrect node", err,
"rtcNode", rtcNode,
)
return err
}
func (s *defaultSessionHandler) Logger(ctx context.Context) logger.Logger {
return logger.GetLogger()
}
pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity)
pKeyB62 := routing.ParticipantKey(roomName, pi.Identity)
func (s *defaultSessionHandler) HandleSession(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
prometheus.IncrementParticipantRtcInit(1)
// RTC session should start on this node
if err := rr.SetParticipantRTCNode(pKey, pKeyB62, currentNode.Id); err != nil {
return err
}
if rr, ok := s.router.(*routing.RedisRouter); ok {
rtcNode, err := s.router.GetNodeForRoom(ctx, roomName)
if err != nil {
return err
}
return roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink)
if rtcNode.Id != s.currentNode.Id {
err = routing.ErrIncorrectRTCNode
logger.Errorw("called participant on incorrect node", err,
"rtcNode", rtcNode,
)
return err
}
pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity)
pKeyB62 := routing.ParticipantKey(roomName, pi.Identity)
// RTC session should start on this node
if err := rr.SetParticipantRTCNode(pKey, pKeyB62, s.currentNode.Id); err != nil {
return err
}
}
return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, sessionHandler)
return s.roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink)
}
func (s *SignalServer) Start() error {
@@ -142,7 +159,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe
return errors.Wrap(err, "failed to read participant from session")
}
l := logger.GetLogger().WithValues(
l := r.sessionHandler.Logger(stream.Context()).WithValues(
"room", ss.RoomName,
"participant", ss.Identity,
"connID", ss.ConnectionId,
@@ -175,7 +192,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe
// copy the incoming rpc headers to avoid dropping any session vars.
ctx := metadata.NewContextWithIncomingHeader(context.Background(), metadata.IncomingHeader(stream.Context()))
err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink)
err = r.sessionHandler.HandleSession(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink)
if err != nil {
sink.Close()
l.Errorw("could not handle new participant", err)

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package service
package service_test
import (
"context"
@@ -26,8 +26,11 @@ import (
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/service"
"github.com/livekit/livekit-server/pkg/service/servicefakes"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/psrpc"
)
@@ -61,22 +64,26 @@ func TestSignal(t *testing.T) {
client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg)
require.NoError(t, err)
server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
go func() {
reqMessageOut = <-requestSource.ReadChan()
resErr = responseSink.WriteMessage(resMessageIn)
responseSink.Close()
close(done)
}()
return nil
})
handler := &servicefakes.FakeSessionHandler{
LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() },
HandleSessionStub: func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
go func() {
reqMessageOut = <-requestSource.ReadChan()
resErr = responseSink.WriteMessage(resMessageIn)
responseSink.Close()
close(done)
}()
return nil
},
}
server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler)
require.NoError(t, err)
err = server.Start()
@@ -114,18 +121,22 @@ func TestSignal(t *testing.T) {
client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg)
require.NoError(t, err)
server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
defer close(done)
resErr = responseSink.WriteMessage(resMessageIn)
return errors.New("start session failed")
})
handler := &servicefakes.FakeSessionHandler{
LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() },
HandleSessionStub: func(
ctx context.Context,
roomName livekit.RoomName,
pi routing.ParticipantInit,
connectionID livekit.ConnectionID,
requestSource routing.MessageSource,
responseSink routing.MessageSink,
) error {
defer close(done)
resErr = responseSink.WriteMessage(resMessageIn)
return errors.New("start session failed")
},
}
server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler)
require.NoError(t, err)
err = server.Start()