diff --git a/go.mod b/go.mod index 66e5242fb..718c71201 100644 --- a/go.mod +++ b/go.mod @@ -17,8 +17,8 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a - github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4 - github.com/livekit/psrpc v0.2.9 + github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a + github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643 github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc github.com/mackerelio/go-osstat v0.2.3 github.com/magefile/mage v1.14.0 @@ -98,7 +98,7 @@ require ( golang.org/x/sys v0.5.0 // indirect golang.org/x/text v0.7.0 // indirect golang.org/x/tools v0.5.0 // indirect - google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514 // indirect + google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5 // indirect google.golang.org/grpc v1.53.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 0188d4cd1..e3e483782 100644 --- a/go.sum +++ b/go.sum @@ -232,10 +232,10 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a h1:5UkGQpskXp7HcBmyrCwWtO7ygDWbqtjN09Yva4l/nyE= github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a/go.mod h1:1Dlx20JPoIKGP45eo+yuj0HjeE25zmyeX/EWHiPCjFw= -github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4 h1:Vr8mjL0sb0RR6nCr0LgNr+hcEvFCm48IWBEuJqBMDN4= -github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4/go.mod h1:hkK/G0wwFiLUGp9F5kxeQxq2CQuIzkmfBwKhTsc71us= -github.com/livekit/psrpc v0.2.9 h1:F9QatmORMcCzzzkDDqFJHe1ZIrJw9rXiluBk33Pmcdw= -github.com/livekit/psrpc v0.2.9/go.mod h1:K0j8f1PgLShR7Lx80KbmwFkDH2BvOnycXGV0OSRURKc= +github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a h1:3yPLmATyLh6EJxXi80MYc2vapr6b5Y00nzg8Prvgha4= +github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a/go.mod h1:hkK/G0wwFiLUGp9F5kxeQxq2CQuIzkmfBwKhTsc71us= +github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643 h1:ftDwqesgXMu0hUXFxf4KWAqqDIXz8BBdNnit3xc6RQA= +github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643/go.mod h1:K0j8f1PgLShR7Lx80KbmwFkDH2BvOnycXGV0OSRURKc= github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc h1:C8gL3pCjKmevR38PJ7+TsPS+Rm4Kbba8lwJmjNMqdUU= github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc/go.mod h1:116ych8UaEs9vfIE8n6iZCZ30iagUFTls0vRmC+Ix5U= github.com/mackerelio/go-osstat v0.2.3 h1:jAMXD5erlDE39kdX2CU7YwCGRcxIO33u/p8+Fhe5dJw= @@ -728,8 +728,8 @@ google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7Fc google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514 h1:rtNKfB++wz5mtDY2t5C8TXlU5y52ojSu7tZo0z7u8eQ= -google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514/go.mod h1:TvhZT5f700eVlTNwND1xoEZQeWTB2RY/65kplwl/bFA= +google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5 h1:/cadn7taPtPlCgiWNetEPsle7jgnlad2R7gR5MXB6dM= +google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5/go.mod h1:TvhZT5f700eVlTNwND1xoEZQeWTB2RY/65kplwl/bFA= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/pkg/config/config.go b/pkg/config/config.go index 729baa5e9..7eadbcc40 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -61,6 +61,7 @@ type Config struct { KeyFile string `yaml:"key_file,omitempty"` Keys map[string]string `yaml:"keys,omitempty"` Region string `yaml:"region,omitempty"` + UsePSRPCSignal bool `yaml:"use_psrpc_signal,omitempty"` // LogLevel is deprecated LogLevel string `yaml:"log_level,omitempty"` Logging LoggingConfig `yaml:"logging,omitempty"` diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 64b6c16c6..d7968e413 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -7,6 +7,7 @@ import ( "github.com/redis/go-redis/v9" "google.golang.org/protobuf/proto" + "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -60,6 +61,17 @@ type RTCMessageCallback func( msg *livekit.RTCNodeMessage, ) +type NewSignalClientCallabck func( + roomName livekit.RoomName, + pi ParticipantInit, + nodeID livekit.NodeID, +) ( + connectionID livekit.ConnectionID, + reqSink MessageSink, + resSource MessageSource, + err error, +) + // Router allows multiple nodes to coordinate the participant session // //counterfeiter:generate . Router @@ -98,14 +110,16 @@ type MessageRouter interface { WriteRoomRTC(ctx context.Context, roomName livekit.RoomName, msg *livekit.RTCNodeMessage) error } -func CreateRouter(rc redis.UniversalClient, node LocalNode) Router { +func CreateRouter(config *config.Config, rc redis.UniversalClient, node LocalNode, signalClient SignalClient) Router { + lr := NewLocalRouter(node, signalClient) + if rc != nil { - return NewRedisRouter(node, rc) + return NewRedisRouter(config, lr, rc) } // local routing and store logger.Infow("using single-node routing") - return NewLocalRouter(node) + return lr } func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionID livekit.ConnectionID) (*livekit.StartSession, error) { diff --git a/pkg/routing/localrouter.go b/pkg/routing/localrouter.go index 49bd4136c..a70cec005 100644 --- a/pkg/routing/localrouter.go +++ b/pkg/routing/localrouter.go @@ -8,10 +8,8 @@ import ( "go.uber.org/atomic" "google.golang.org/protobuf/proto" - "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" ) // aggregated channel for all participants @@ -19,8 +17,10 @@ const localRTCChannelSize = 10000 // a router of messages on the same node, basic implementation for local testing type LocalRouter struct { - currentNode LocalNode - lock sync.RWMutex + currentNode LocalNode + signalClient SignalClient + + lock sync.RWMutex // channels for each participant requestChannels map[string]*MessageChannel responseChannels map[string]*MessageChannel @@ -32,9 +32,10 @@ type LocalRouter struct { onRTCMessage RTCMessageCallback } -func NewLocalRouter(currentNode LocalNode) *LocalRouter { +func NewLocalRouter(currentNode LocalNode, signalClient SignalClient) *LocalRouter { return &LocalRouter{ currentNode: currentNode, + signalClient: signalClient, requestChannels: make(map[string]*MessageChannel), responseChannels: make(map[string]*MessageChannel), rtcMessageChan: NewMessageChannel(localRTCChannelSize), @@ -83,51 +84,19 @@ func (r *LocalRouter) ListNodes() ([]*livekit.Node, error) { } func (r *LocalRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) { - prometheus.IncrementParticipantRtcInit(1) - // treat it as a new participant connecting - if r.onNewParticipant == nil { - err = ErrHandlerNotDefined - return - } + return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(r.currentNode.Id)) +} - // create a new connection id - connectionID = livekit.ConnectionID(utils.NewGuid("CO_")) - // index channels by roomName | identity - key := participantKey(roomName, pi.Identity) - key = key + "|" + livekit.ParticipantKey(connectionID) - - // close older channels if one already exists - reqChan := r.getMessageChannel(r.requestChannels, string(key)) - if reqChan != nil { - reqChan.Close() - } - resChan := r.getMessageChannel(r.responseChannels, string(key)) - if resChan != nil { - resChan.Close() - } - reqChan = r.getOrCreateMessageChannel(r.requestChannels, string(key)) - resChan = r.getOrCreateMessageChannel(r.responseChannels, string(key)) - - go func() { - err := r.onNewParticipant( - ctx, - roomName, - pi, - // request source - reqChan, - // response sink - resChan, +func (r *LocalRouter) StartParticipantSignalWithNodeID(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) { + connectionID, reqSink, resSource, err = r.signalClient.StartParticipantSignal(ctx, roomName, pi, livekit.NodeID(r.currentNode.Id)) + if err != nil { + logger.Errorw("could not handle new participant", err, + "room", roomName, + "participant", pi.Identity, + "connectionID", connectionID, ) - if err != nil { - reqChan.Close() - resChan.Close() - logger.Errorw("could not handle new participant", err, - "room", roomName, - "participant", pi.Identity, - ) - } - }() - return connectionID, reqChan, resChan, nil + } + return } func (r *LocalRouter) WriteParticipantRTC(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error { diff --git a/pkg/routing/messagechannel.go b/pkg/routing/messagechannel.go index 6b6493b76..de5bd9075 100644 --- a/pkg/routing/messagechannel.go +++ b/pkg/routing/messagechannel.go @@ -15,6 +15,10 @@ type MessageChannel struct { lock sync.RWMutex } +func NewDefaultMessageChannel() *MessageChannel { + return NewMessageChannel(DefaultMessageChannelSize) +} + func NewMessageChannel(size int) *MessageChannel { return &MessageChannel{ // allow some buffer to avoid blocked writes diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 5f251e285..4acbbdff4 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -16,6 +16,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils" + "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing/selector" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" ) @@ -31,12 +32,13 @@ const ( // It relies on the RTC node to be the primary driver of the participant connection. // Because type RedisRouter struct { - LocalRouter + *LocalRouter - rc redis.UniversalClient - ctx context.Context - isStarted atomic.Bool - nodeMu sync.RWMutex + rc redis.UniversalClient + usePSRPCSignal bool + ctx context.Context + isStarted atomic.Bool + nodeMu sync.RWMutex // previous stats for computing averages prevStats *livekit.NodeStats @@ -44,10 +46,11 @@ type RedisRouter struct { cancel func() } -func NewRedisRouter(currentNode LocalNode, rc redis.UniversalClient) *RedisRouter { +func NewRedisRouter(config *config.Config, lr *LocalRouter, rc redis.UniversalClient) *RedisRouter { rr := &RedisRouter{ - LocalRouter: *NewLocalRouter(currentNode), - rc: rc, + LocalRouter: lr, + rc: rc, + usePSRPCSignal: config.UsePSRPCSignal, } rr.ctx, rr.cancel = context.WithCancel(context.Background()) return rr @@ -146,6 +149,10 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek return } + if r.usePSRPCSignal { + return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id)) + } + // create a new connection id connectionID = livekit.ConnectionID(utils.NewGuid("CO_")) pKey := participantKeyLegacy(roomName, pi.Identity) diff --git a/pkg/routing/routingfakes/fake_signal_client.go b/pkg/routing/routingfakes/fake_signal_client.go new file mode 100644 index 000000000..884dac4d5 --- /dev/null +++ b/pkg/routing/routingfakes/fake_signal_client.go @@ -0,0 +1,134 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" +) + +type FakeSignalClient struct { + StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) + startParticipantSignalMutex sync.RWMutex + startParticipantSignalArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.NodeID + } + startParticipantSignalReturns struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + } + startParticipantSignalReturnsOnCall map[int]struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSignalClient) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) { + fake.startParticipantSignalMutex.Lock() + ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] + fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.NodeID + }{arg1, arg2, arg3, arg4}) + stub := fake.StartParticipantSignalStub + fakeReturns := fake.startParticipantSignalReturns + fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2, arg3, arg4}) + fake.startParticipantSignalMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3, ret.result4 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 +} + +func (fake *FakeSignalClient) StartParticipantSignalCallCount() int { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + return len(fake.startParticipantSignalArgsForCall) +} + +func (fake *FakeSignalClient) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error)) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = stub +} + +func (fake *FakeSignalClient) StartParticipantSignalArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + argsForCall := fake.startParticipantSignalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeSignalClient) StartParticipantSignalReturns(result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + fake.startParticipantSignalReturns = struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeSignalClient) StartParticipantSignalReturnsOnCall(i int, result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + if fake.startParticipantSignalReturnsOnCall == nil { + fake.startParticipantSignalReturnsOnCall = make(map[int]struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }) + } + fake.startParticipantSignalReturnsOnCall[i] = struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeSignalClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSignalClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.SignalClient = new(FakeSignalClient) diff --git a/pkg/routing/signal.go b/pkg/routing/signal.go new file mode 100644 index 000000000..c5cede85e --- /dev/null +++ b/pkg/routing/signal.go @@ -0,0 +1,112 @@ +package routing + +import ( + "context" + + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . SignalClient +type SignalClient interface { + StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) +} + +type signalClient struct { + nodeID livekit.NodeID + client rpc.TypedSignalClient +} + +func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus) (SignalClient, error) { + c, err := rpc.NewTypedSignalClient(nodeID, bus) + if err != nil { + return nil, err + } + + return &signalClient{ + nodeID: nodeID, + client: c, + }, nil +} + +func (r *signalClient) StartParticipantSignal( + ctx context.Context, + roomName livekit.RoomName, + pi ParticipantInit, + nodeID livekit.NodeID, +) ( + connectionID livekit.ConnectionID, + reqSink MessageSink, + resSource MessageSource, + err error, +) { + connectionID = livekit.ConnectionID(utils.NewGuid("CO_")) + ss, err := pi.ToStartSession(roomName, connectionID) + if err != nil { + return + } + + logger.Debugw( + "starting signal connection", + "room", roomName, + "reqNodeID", nodeID, + "participant", pi.Identity, + "connectionID", connectionID, + ) + + stream, err := r.client.RelaySignal(ctx, nodeID) + if err != nil { + return + } + + err = stream.Send(&rpc.RelaySignalRequest{StartSession: ss}) + if err != nil { + stream.Close(err) + return + } + + resChan := NewDefaultMessageChannel() + + go func() { + var err error + for msg := range stream.Channel() { + if err = resChan.WriteMessage(msg.Response); err != nil { + break + } + } + + logger.Debugw("participant signal stream closed", + "error", err, + "room", ss.RoomName, + "participant", ss.Identity, + "connectionID", connectionID, + ) + + resChan.Close() + }() + + return connectionID, &relaySignalRequestSink{stream}, resChan, nil +} + +type relaySignalRequestSink struct { + psrpc.ClientStream[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse] +} + +func (s *relaySignalRequestSink) Close() { + s.ClientStream.Close(nil) +} + +func (s *relaySignalRequestSink) IsClosed() bool { + return s.Context().Err() != nil +} + +func (s *relaySignalRequestSink) WriteMessage(msg proto.Message) error { + return s.Send(&rpc.RelaySignalRequest{Request: msg.(*livekit.SignalRequest)}) +} diff --git a/pkg/service/server.go b/pkg/service/server.go index 10c0f4ada..bbd9f6775 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -28,18 +28,19 @@ import ( ) type LivekitServer struct { - config *config.Config - ioService *IOInfoService - rtcService *RTCService - httpServer *http.Server - promServer *http.Server - router routing.Router - roomManager *RoomManager - turnServer *turn.Server - currentNode routing.LocalNode - running atomic.Bool - doneChan chan struct{} - closedChan chan struct{} + config *config.Config + ioService *IOInfoService + rtcService *RTCService + httpServer *http.Server + promServer *http.Server + router routing.Router + roomManager *RoomManager + signalServer *SignalServer + turnServer *turn.Server + currentNode routing.LocalNode + running atomic.Bool + doneChan chan struct{} + closedChan chan struct{} } func NewLivekitServer(conf *config.Config, @@ -51,15 +52,17 @@ func NewLivekitServer(conf *config.Config, keyProvider auth.KeyProvider, router routing.Router, roomManager *RoomManager, + signalServer *SignalServer, turnServer *turn.Server, currentNode routing.LocalNode, ) (s *LivekitServer, err error) { s = &LivekitServer{ - config: conf, - ioService: ioService, - rtcService: rtcService, - router: router, - roomManager: roomManager, + config: conf, + ioService: ioService, + rtcService: rtcService, + router: router, + roomManager: roomManager, + signalServer: signalServer, // turn server starts automatically turnServer: turnServer, currentNode: currentNode, @@ -252,6 +255,7 @@ func (s *LivekitServer) Start() error { } s.roomManager.Stop() + s.signalServer.Stop() s.ioService.Stop() close(s.closedChan) diff --git a/pkg/service/signal.go b/pkg/service/signal.go new file mode 100644 index 000000000..abf9b05dd --- /dev/null +++ b/pkg/service/signal.go @@ -0,0 +1,147 @@ +package service + +import ( + "context" + + "github.com/pkg/errors" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +type SessionHandler func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, +) error + +type SignalServer struct { + server rpc.TypedSignalServer +} + +func NewSignalServer( + nodeID livekit.NodeID, + region string, + bus psrpc.MessageBus, + sessionHandler SessionHandler, +) (*SignalServer, error) { + s, err := rpc.NewTypedSignalServer(nodeID, &signalService{region, sessionHandler}, bus) + if err != nil { + return nil, err + } + logger.Debugw("starting relay signal server", "topic", nodeID) + if err := s.RegisterRelaySignalTopic(nodeID); err != nil { + return nil, err + } + + return &SignalServer{s}, nil +} + +func NewDefaultSignalServer( + currentNode routing.LocalNode, + bus psrpc.MessageBus, + router routing.Router, + roomManager *RoomManager, +) (r *SignalServer, err error) { + sessionHandler := func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + prometheus.IncrementParticipantRtcInit(1) + return roomManager.StartSession(ctx, roomName, pi, requestSource, responseSink) + } + + return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, sessionHandler) +} + +func (r *SignalServer) Stop() { + r.server.Kill() +} + +type signalService struct { + region string + sessionHandler SessionHandler +} + +func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]) (err error) { + // copy the context to prevent a race between the session handler closing + // and the delivery of any parting messages from the client. take care to + // copy the incoming rpc headers to avoid dropping any session vars. + ctx, cancel := context.WithCancel(psrpc.NewContextWithIncomingHeader(context.Background(), psrpc.IncomingHeader(stream.Context()))) + defer cancel() + + req, ok := <-stream.Channel() + if !ok { + return nil + } + + ss := req.StartSession + if ss == nil { + return errors.New("expected start session message") + } + + pi, err := routing.ParticipantInitFromStartSession(ss, r.region) + if err != nil { + return errors.Wrap(err, "failed to read participant from session") + } + + reqChan := routing.NewDefaultMessageChannel() + defer reqChan.Close() + + err = r.sessionHandler( + ctx, + livekit.RoomName(ss.RoomName), + *pi, + livekit.ConnectionID(ss.ConnectionId), + reqChan, + &relaySignalResponseSink{stream}, + ) + if err != nil { + logger.Errorw("could not handle new participant", err, + "room", ss.RoomName, + "participant", ss.Identity, + "connectionID", ss.ConnectionId, + ) + } + + for msg := range stream.Channel() { + if err = reqChan.WriteMessage(msg.Request); err != nil { + break + } + } + + logger.Debugw("participant signal stream closed", + "room", ss.RoomName, + "participant", ss.Identity, + "connectionID", ss.ConnectionId, + ) + return +} + +type relaySignalResponseSink struct { + psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest] +} + +func (s *relaySignalResponseSink) Close() { + s.ServerStream.Close(nil) +} + +func (s *relaySignalResponseSink) IsClosed() bool { + return s.Context().Err() != nil +} + +func (s *relaySignalResponseSink) WriteMessage(msg proto.Message) error { + return s.Send(&rpc.RelaySignalResponse{Response: msg.(*livekit.SignalResponse)}) +} diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 9b4eaefa8..f010c7d3a 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -57,6 +57,8 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewRoomAllocator, NewRoomService, NewRTCService, + NewDefaultSignalServer, + routing.NewSignalClient, NewLocalRoomManager, newTurnAuthHandler, newInProcessTurnServer, @@ -69,6 +71,9 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routing.Router, error) { wire.Build( createRedisClient, + getNodeID, + getMessageBus, + routing.NewSignalClient, routing.CreateRouter, ) @@ -136,7 +141,7 @@ func createStore(rc redis.UniversalClient) ObjectStore { func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus { if rc == nil { - return nil + return psrpc.NewLocalMessageBus() } return psrpc.NewRedisMessageBus(rc) } diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 0dbdcd8f0..9231a74d2 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -40,14 +40,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - router := routing.CreateRouter(universalClient, currentNode) + nodeID := getNodeID(currentNode) + messageBus := getMessageBus(universalClient) + signalClient, err := routing.NewSignalClient(nodeID, messageBus) + if err != nil { + return nil, err + } + router := routing.CreateRouter(conf, universalClient, currentNode, signalClient) objectStore := createStore(universalClient) roomAllocator, err := NewRoomAllocator(conf, router, objectStore) if err != nil { return nil, err } - nodeID := getNodeID(currentNode) - messageBus := getMessageBus(universalClient) egressClient, err := getEgressClient(conf, nodeID, messageBus) if err != nil { return nil, err @@ -88,12 +92,16 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } + signalServer, err := NewDefaultSignalServer(currentNode, messageBus, router, roomManager) + if err != nil { + return nil, err + } authHandler := newTurnAuthHandler(objectStore) server, err := newInProcessTurnServer(conf, authHandler) if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode) if err != nil { return nil, err } @@ -105,7 +113,13 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi if err != nil { return nil, err } - router := routing.CreateRouter(universalClient, currentNode) + nodeID := getNodeID(currentNode) + messageBus := getMessageBus(universalClient) + signalClient, err := routing.NewSignalClient(nodeID, messageBus) + if err != nil { + return nil, err + } + router := routing.CreateRouter(conf, universalClient, currentNode, signalClient) return router, nil } @@ -172,7 +186,7 @@ func createStore(rc redis.UniversalClient) ObjectStore { func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus { if rc == nil { - return nil + return psrpc.NewLocalMessageBus() } return psrpc.NewRedisMessageBus(rc) }