mirror of
https://github.com/livekit/livekit.git
synced 2026-03-30 15:35:41 +00:00
inject logger constructor into signal server (#2372)
* inject logger constructor into signal server * tidy * tidy * test
This commit is contained in:
199
pkg/service/servicefakes/fake_session_handler.go
Normal file
199
pkg/service/servicefakes/fake_session_handler.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Code generated by counterfeiter. DO NOT EDIT.
|
||||
package servicefakes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/routing"
|
||||
"github.com/livekit/livekit-server/pkg/service"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
)
|
||||
|
||||
type FakeSessionHandler struct {
|
||||
HandleSessionStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error
|
||||
handleSessionMutex sync.RWMutex
|
||||
handleSessionArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
arg2 livekit.RoomName
|
||||
arg3 routing.ParticipantInit
|
||||
arg4 livekit.ConnectionID
|
||||
arg5 routing.MessageSource
|
||||
arg6 routing.MessageSink
|
||||
}
|
||||
handleSessionReturns struct {
|
||||
result1 error
|
||||
}
|
||||
handleSessionReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
LoggerStub func(context.Context) logger.Logger
|
||||
loggerMutex sync.RWMutex
|
||||
loggerArgsForCall []struct {
|
||||
arg1 context.Context
|
||||
}
|
||||
loggerReturns struct {
|
||||
result1 logger.Logger
|
||||
}
|
||||
loggerReturnsOnCall map[int]struct {
|
||||
result1 logger.Logger
|
||||
}
|
||||
invocations map[string][][]interface{}
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSession(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.ConnectionID, arg5 routing.MessageSource, arg6 routing.MessageSink) error {
|
||||
fake.handleSessionMutex.Lock()
|
||||
ret, specificReturn := fake.handleSessionReturnsOnCall[len(fake.handleSessionArgsForCall)]
|
||||
fake.handleSessionArgsForCall = append(fake.handleSessionArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
arg2 livekit.RoomName
|
||||
arg3 routing.ParticipantInit
|
||||
arg4 livekit.ConnectionID
|
||||
arg5 routing.MessageSource
|
||||
arg6 routing.MessageSink
|
||||
}{arg1, arg2, arg3, arg4, arg5, arg6})
|
||||
stub := fake.HandleSessionStub
|
||||
fakeReturns := fake.handleSessionReturns
|
||||
fake.recordInvocation("HandleSession", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6})
|
||||
fake.handleSessionMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1, arg2, arg3, arg4, arg5, arg6)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSessionCallCount() int {
|
||||
fake.handleSessionMutex.RLock()
|
||||
defer fake.handleSessionMutex.RUnlock()
|
||||
return len(fake.handleSessionArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSessionCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error) {
|
||||
fake.handleSessionMutex.Lock()
|
||||
defer fake.handleSessionMutex.Unlock()
|
||||
fake.HandleSessionStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSessionArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) {
|
||||
fake.handleSessionMutex.RLock()
|
||||
defer fake.handleSessionMutex.RUnlock()
|
||||
argsForCall := fake.handleSessionArgsForCall[i]
|
||||
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSessionReturns(result1 error) {
|
||||
fake.handleSessionMutex.Lock()
|
||||
defer fake.handleSessionMutex.Unlock()
|
||||
fake.HandleSessionStub = nil
|
||||
fake.handleSessionReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) HandleSessionReturnsOnCall(i int, result1 error) {
|
||||
fake.handleSessionMutex.Lock()
|
||||
defer fake.handleSessionMutex.Unlock()
|
||||
fake.HandleSessionStub = nil
|
||||
if fake.handleSessionReturnsOnCall == nil {
|
||||
fake.handleSessionReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.handleSessionReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) Logger(arg1 context.Context) logger.Logger {
|
||||
fake.loggerMutex.Lock()
|
||||
ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)]
|
||||
fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct {
|
||||
arg1 context.Context
|
||||
}{arg1})
|
||||
stub := fake.LoggerStub
|
||||
fakeReturns := fake.loggerReturns
|
||||
fake.recordInvocation("Logger", []interface{}{arg1})
|
||||
fake.loggerMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub(arg1)
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) LoggerCallCount() int {
|
||||
fake.loggerMutex.RLock()
|
||||
defer fake.loggerMutex.RUnlock()
|
||||
return len(fake.loggerArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) LoggerCalls(stub func(context.Context) logger.Logger) {
|
||||
fake.loggerMutex.Lock()
|
||||
defer fake.loggerMutex.Unlock()
|
||||
fake.LoggerStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) LoggerArgsForCall(i int) context.Context {
|
||||
fake.loggerMutex.RLock()
|
||||
defer fake.loggerMutex.RUnlock()
|
||||
argsForCall := fake.loggerArgsForCall[i]
|
||||
return argsForCall.arg1
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) LoggerReturns(result1 logger.Logger) {
|
||||
fake.loggerMutex.Lock()
|
||||
defer fake.loggerMutex.Unlock()
|
||||
fake.LoggerStub = nil
|
||||
fake.loggerReturns = struct {
|
||||
result1 logger.Logger
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) LoggerReturnsOnCall(i int, result1 logger.Logger) {
|
||||
fake.loggerMutex.Lock()
|
||||
defer fake.loggerMutex.Unlock()
|
||||
fake.LoggerStub = nil
|
||||
if fake.loggerReturnsOnCall == nil {
|
||||
fake.loggerReturnsOnCall = make(map[int]struct {
|
||||
result1 logger.Logger
|
||||
})
|
||||
}
|
||||
fake.loggerReturnsOnCall[i] = struct {
|
||||
result1 logger.Logger
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
fake.handleSessionMutex.RLock()
|
||||
defer fake.handleSessionMutex.RUnlock()
|
||||
fake.loggerMutex.RLock()
|
||||
defer fake.loggerMutex.RUnlock()
|
||||
copiedInvocations := map[string][][]interface{}{}
|
||||
for key, value := range fake.invocations {
|
||||
copiedInvocations[key] = value
|
||||
}
|
||||
return copiedInvocations
|
||||
}
|
||||
|
||||
func (fake *FakeSessionHandler) recordInvocation(key string, args []interface{}) {
|
||||
fake.invocationsMutex.Lock()
|
||||
defer fake.invocationsMutex.Unlock()
|
||||
if fake.invocations == nil {
|
||||
fake.invocations = map[string][][]interface{}{}
|
||||
}
|
||||
if fake.invocations[key] == nil {
|
||||
fake.invocations[key] = [][]interface{}{}
|
||||
}
|
||||
fake.invocations[key] = append(fake.invocations[key], args)
|
||||
}
|
||||
|
||||
var _ service.SessionHandler = new(FakeSessionHandler)
|
||||
@@ -31,14 +31,21 @@ import (
|
||||
"github.com/livekit/psrpc/pkg/middleware"
|
||||
)
|
||||
|
||||
type SessionHandler func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error
|
||||
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
|
||||
|
||||
//counterfeiter:generate . SessionHandler
|
||||
type SessionHandler interface {
|
||||
Logger(ctx context.Context) logger.Logger
|
||||
|
||||
HandleSession(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error
|
||||
}
|
||||
|
||||
type SignalServer struct {
|
||||
server rpc.TypedSignalServer
|
||||
@@ -72,43 +79,53 @@ func NewDefaultSignalServer(
|
||||
router routing.Router,
|
||||
roomManager *RoomManager,
|
||||
) (r *SignalServer, err error) {
|
||||
sessionHandler := func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
prometheus.IncrementParticipantRtcInit(1)
|
||||
return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, &defaultSessionHandler{currentNode, router, roomManager})
|
||||
}
|
||||
|
||||
if rr, ok := router.(*routing.RedisRouter); ok {
|
||||
rtcNode, err := router.GetNodeForRoom(ctx, roomName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
type defaultSessionHandler struct {
|
||||
currentNode routing.LocalNode
|
||||
router routing.Router
|
||||
roomManager *RoomManager
|
||||
}
|
||||
|
||||
if rtcNode.Id != currentNode.Id {
|
||||
err = routing.ErrIncorrectRTCNode
|
||||
logger.Errorw("called participant on incorrect node", err,
|
||||
"rtcNode", rtcNode,
|
||||
)
|
||||
return err
|
||||
}
|
||||
func (s *defaultSessionHandler) Logger(ctx context.Context) logger.Logger {
|
||||
return logger.GetLogger()
|
||||
}
|
||||
|
||||
pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity)
|
||||
pKeyB62 := routing.ParticipantKey(roomName, pi.Identity)
|
||||
func (s *defaultSessionHandler) HandleSession(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
prometheus.IncrementParticipantRtcInit(1)
|
||||
|
||||
// RTC session should start on this node
|
||||
if err := rr.SetParticipantRTCNode(pKey, pKeyB62, currentNode.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
if rr, ok := s.router.(*routing.RedisRouter); ok {
|
||||
rtcNode, err := s.router.GetNodeForRoom(ctx, roomName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink)
|
||||
if rtcNode.Id != s.currentNode.Id {
|
||||
err = routing.ErrIncorrectRTCNode
|
||||
logger.Errorw("called participant on incorrect node", err,
|
||||
"rtcNode", rtcNode,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
pKey := routing.ParticipantKeyLegacy(roomName, pi.Identity)
|
||||
pKeyB62 := routing.ParticipantKey(roomName, pi.Identity)
|
||||
|
||||
// RTC session should start on this node
|
||||
if err := rr.SetParticipantRTCNode(pKey, pKeyB62, s.currentNode.Id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, config, sessionHandler)
|
||||
return s.roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink)
|
||||
}
|
||||
|
||||
func (s *SignalServer) Start() error {
|
||||
@@ -142,7 +159,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe
|
||||
return errors.Wrap(err, "failed to read participant from session")
|
||||
}
|
||||
|
||||
l := logger.GetLogger().WithValues(
|
||||
l := r.sessionHandler.Logger(stream.Context()).WithValues(
|
||||
"room", ss.RoomName,
|
||||
"participant", ss.Identity,
|
||||
"connID", ss.ConnectionId,
|
||||
@@ -175,7 +192,7 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe
|
||||
// copy the incoming rpc headers to avoid dropping any session vars.
|
||||
ctx := metadata.NewContextWithIncomingHeader(context.Background(), metadata.IncomingHeader(stream.Context()))
|
||||
|
||||
err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink)
|
||||
err = r.sessionHandler.HandleSession(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink)
|
||||
if err != nil {
|
||||
sink.Close()
|
||||
l.Errorw("could not handle new participant", err)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package service
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -26,8 +26,11 @@ import (
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
"github.com/livekit/livekit-server/pkg/routing"
|
||||
"github.com/livekit/livekit-server/pkg/service"
|
||||
"github.com/livekit/livekit-server/pkg/service/servicefakes"
|
||||
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/psrpc"
|
||||
)
|
||||
|
||||
@@ -61,22 +64,26 @@ func TestSignal(t *testing.T) {
|
||||
client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
go func() {
|
||||
reqMessageOut = <-requestSource.ReadChan()
|
||||
resErr = responseSink.WriteMessage(resMessageIn)
|
||||
responseSink.Close()
|
||||
close(done)
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
handler := &servicefakes.FakeSessionHandler{
|
||||
LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() },
|
||||
HandleSessionStub: func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
go func() {
|
||||
reqMessageOut = <-requestSource.ReadChan()
|
||||
resErr = responseSink.WriteMessage(resMessageIn)
|
||||
responseSink.Close()
|
||||
close(done)
|
||||
}()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.Start()
|
||||
@@ -114,18 +121,22 @@ func TestSignal(t *testing.T) {
|
||||
client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
server, err := NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
defer close(done)
|
||||
resErr = responseSink.WriteMessage(resMessageIn)
|
||||
return errors.New("start session failed")
|
||||
})
|
||||
handler := &servicefakes.FakeSessionHandler{
|
||||
LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() },
|
||||
HandleSessionStub: func(
|
||||
ctx context.Context,
|
||||
roomName livekit.RoomName,
|
||||
pi routing.ParticipantInit,
|
||||
connectionID livekit.ConnectionID,
|
||||
requestSource routing.MessageSource,
|
||||
responseSink routing.MessageSink,
|
||||
) error {
|
||||
defer close(done)
|
||||
resErr = responseSink.WriteMessage(resMessageIn)
|
||||
return errors.New("start session failed")
|
||||
},
|
||||
}
|
||||
server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.Start()
|
||||
|
||||
Reference in New Issue
Block a user