add signal psrpc service (#1485)

* add signal psrpc service

* update protocol dep

* refactor for cloud

* update psrpc

* pr feedback
This commit is contained in:
Paul Wells
2023-03-03 15:49:46 -08:00
committed by GitHub
parent e48c818532
commit e22de045ba
13 changed files with 503 additions and 92 deletions
+3 -3
View File
@@ -17,8 +17,8 @@ require (
github.com/jxskiss/base62 v1.1.0
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1
github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a
github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4
github.com/livekit/psrpc v0.2.9
github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a
github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643
github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc
github.com/mackerelio/go-osstat v0.2.3
github.com/magefile/mage v1.14.0
@@ -98,7 +98,7 @@ require (
golang.org/x/sys v0.5.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/tools v0.5.0 // indirect
google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514 // indirect
google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5 // indirect
google.golang.org/grpc v1.53.0 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
+6 -6
View File
@@ -232,10 +232,10 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ=
github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a h1:5UkGQpskXp7HcBmyrCwWtO7ygDWbqtjN09Yva4l/nyE=
github.com/livekit/mediatransportutil v0.0.0-20230130133657-96cfb115473a/go.mod h1:1Dlx20JPoIKGP45eo+yuj0HjeE25zmyeX/EWHiPCjFw=
github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4 h1:Vr8mjL0sb0RR6nCr0LgNr+hcEvFCm48IWBEuJqBMDN4=
github.com/livekit/protocol v1.4.3-0.20230228000108-073251b64ab4/go.mod h1:hkK/G0wwFiLUGp9F5kxeQxq2CQuIzkmfBwKhTsc71us=
github.com/livekit/psrpc v0.2.9 h1:F9QatmORMcCzzzkDDqFJHe1ZIrJw9rXiluBk33Pmcdw=
github.com/livekit/psrpc v0.2.9/go.mod h1:K0j8f1PgLShR7Lx80KbmwFkDH2BvOnycXGV0OSRURKc=
github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a h1:3yPLmATyLh6EJxXi80MYc2vapr6b5Y00nzg8Prvgha4=
github.com/livekit/protocol v1.4.3-0.20230303025609-c0705dbb696a/go.mod h1:hkK/G0wwFiLUGp9F5kxeQxq2CQuIzkmfBwKhTsc71us=
github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643 h1:ftDwqesgXMu0hUXFxf4KWAqqDIXz8BBdNnit3xc6RQA=
github.com/livekit/psrpc v0.2.10-0.20230303054701-5853a56b4643/go.mod h1:K0j8f1PgLShR7Lx80KbmwFkDH2BvOnycXGV0OSRURKc=
github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc h1:C8gL3pCjKmevR38PJ7+TsPS+Rm4Kbba8lwJmjNMqdUU=
github.com/livekit/rtcscore-go v0.0.0-20230224125650-6a6442ef9ebc/go.mod h1:116ych8UaEs9vfIE8n6iZCZ30iagUFTls0vRmC+Ix5U=
github.com/mackerelio/go-osstat v0.2.3 h1:jAMXD5erlDE39kdX2CU7YwCGRcxIO33u/p8+Fhe5dJw=
@@ -728,8 +728,8 @@ google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7Fc
google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514 h1:rtNKfB++wz5mtDY2t5C8TXlU5y52ojSu7tZo0z7u8eQ=
google.golang.org/genproto v0.0.0-20230227214838-9b19f0bdc514/go.mod h1:TvhZT5f700eVlTNwND1xoEZQeWTB2RY/65kplwl/bFA=
google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5 h1:/cadn7taPtPlCgiWNetEPsle7jgnlad2R7gR5MXB6dM=
google.golang.org/genproto v0.0.0-20230301171018-9ab4bdc49ad5/go.mod h1:TvhZT5f700eVlTNwND1xoEZQeWTB2RY/65kplwl/bFA=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
+1
View File
@@ -61,6 +61,7 @@ type Config struct {
KeyFile string `yaml:"key_file,omitempty"`
Keys map[string]string `yaml:"keys,omitempty"`
Region string `yaml:"region,omitempty"`
UsePSRPCSignal bool `yaml:"use_psrpc_signal,omitempty"`
// LogLevel is deprecated
LogLevel string `yaml:"log_level,omitempty"`
Logging LoggingConfig `yaml:"logging,omitempty"`
+17 -3
View File
@@ -7,6 +7,7 @@ import (
"github.com/redis/go-redis/v9"
"google.golang.org/protobuf/proto"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
@@ -60,6 +61,17 @@ type RTCMessageCallback func(
msg *livekit.RTCNodeMessage,
)
// NewSignalClientCallabck is the signature of a callback used to start a
// participant signal connection against a specific node, yielding the new
// connection ID, a sink for outgoing requests, and a source of responses.
// NOTE(review): "Callabck" is a typo for "Callback"; renaming would change the
// exported interface, so it is preserved here — consider fixing at all use sites.
type NewSignalClientCallabck func(
	roomName livekit.RoomName,
	pi ParticipantInit,
	nodeID livekit.NodeID,
) (
	connectionID livekit.ConnectionID,
	reqSink MessageSink,
	resSource MessageSource,
	err error,
)
// Router allows multiple nodes to coordinate the participant session
//
//counterfeiter:generate . Router
@@ -98,14 +110,16 @@ type MessageRouter interface {
WriteRoomRTC(ctx context.Context, roomName livekit.RoomName, msg *livekit.RTCNodeMessage) error
}
func CreateRouter(rc redis.UniversalClient, node LocalNode) Router {
func CreateRouter(config *config.Config, rc redis.UniversalClient, node LocalNode, signalClient SignalClient) Router {
lr := NewLocalRouter(node, signalClient)
if rc != nil {
return NewRedisRouter(node, rc)
return NewRedisRouter(config, lr, rc)
}
// local routing and store
logger.Infow("using single-node routing")
return NewLocalRouter(node)
return lr
}
func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionID livekit.ConnectionID) (*livekit.StartSession, error) {
+17 -48
View File
@@ -8,10 +8,8 @@ import (
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils"
)
// aggregated channel for all participants
@@ -19,8 +17,10 @@ const localRTCChannelSize = 10000
// a router of messages on the same node, basic implementation for local testing
type LocalRouter struct {
currentNode LocalNode
lock sync.RWMutex
currentNode LocalNode
signalClient SignalClient
lock sync.RWMutex
// channels for each participant
requestChannels map[string]*MessageChannel
responseChannels map[string]*MessageChannel
@@ -32,9 +32,10 @@ type LocalRouter struct {
onRTCMessage RTCMessageCallback
}
func NewLocalRouter(currentNode LocalNode) *LocalRouter {
func NewLocalRouter(currentNode LocalNode, signalClient SignalClient) *LocalRouter {
return &LocalRouter{
currentNode: currentNode,
signalClient: signalClient,
requestChannels: make(map[string]*MessageChannel),
responseChannels: make(map[string]*MessageChannel),
rtcMessageChan: NewMessageChannel(localRTCChannelSize),
@@ -83,51 +84,19 @@ func (r *LocalRouter) ListNodes() ([]*livekit.Node, error) {
}
func (r *LocalRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) {
prometheus.IncrementParticipantRtcInit(1)
// treat it as a new participant connecting
if r.onNewParticipant == nil {
err = ErrHandlerNotDefined
return
}
return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(r.currentNode.Id))
}
// create a new connection id
connectionID = livekit.ConnectionID(utils.NewGuid("CO_"))
// index channels by roomName | identity
key := participantKey(roomName, pi.Identity)
key = key + "|" + livekit.ParticipantKey(connectionID)
// close older channels if one already exists
reqChan := r.getMessageChannel(r.requestChannels, string(key))
if reqChan != nil {
reqChan.Close()
}
resChan := r.getMessageChannel(r.responseChannels, string(key))
if resChan != nil {
resChan.Close()
}
reqChan = r.getOrCreateMessageChannel(r.requestChannels, string(key))
resChan = r.getOrCreateMessageChannel(r.responseChannels, string(key))
go func() {
err := r.onNewParticipant(
ctx,
roomName,
pi,
// request source
reqChan,
// response sink
resChan,
func (r *LocalRouter) StartParticipantSignalWithNodeID(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) {
connectionID, reqSink, resSource, err = r.signalClient.StartParticipantSignal(ctx, roomName, pi, livekit.NodeID(r.currentNode.Id))
if err != nil {
logger.Errorw("could not handle new participant", err,
"room", roomName,
"participant", pi.Identity,
"connectionID", connectionID,
)
if err != nil {
reqChan.Close()
resChan.Close()
logger.Errorw("could not handle new participant", err,
"room", roomName,
"participant", pi.Identity,
)
}
}()
return connectionID, reqChan, resChan, nil
}
return
}
func (r *LocalRouter) WriteParticipantRTC(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity, msg *livekit.RTCNodeMessage) error {
+4
View File
@@ -15,6 +15,10 @@ type MessageChannel struct {
lock sync.RWMutex
}
// NewDefaultMessageChannel returns a MessageChannel buffered with
// DefaultMessageChannelSize.
func NewDefaultMessageChannel() *MessageChannel {
	return NewMessageChannel(DefaultMessageChannelSize)
}
func NewMessageChannel(size int) *MessageChannel {
return &MessageChannel{
// allow some buffer to avoid blocked writes
+15 -8
View File
@@ -16,6 +16,7 @@ import (
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing/selector"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
)
@@ -31,12 +32,13 @@ const (
// It relies on the RTC node to be the primary driver of the participant connection.
// Because
type RedisRouter struct {
LocalRouter
*LocalRouter
rc redis.UniversalClient
ctx context.Context
isStarted atomic.Bool
nodeMu sync.RWMutex
rc redis.UniversalClient
usePSRPCSignal bool
ctx context.Context
isStarted atomic.Bool
nodeMu sync.RWMutex
// previous stats for computing averages
prevStats *livekit.NodeStats
@@ -44,10 +46,11 @@ type RedisRouter struct {
cancel func()
}
func NewRedisRouter(currentNode LocalNode, rc redis.UniversalClient) *RedisRouter {
func NewRedisRouter(config *config.Config, lr *LocalRouter, rc redis.UniversalClient) *RedisRouter {
rr := &RedisRouter{
LocalRouter: *NewLocalRouter(currentNode),
rc: rc,
LocalRouter: lr,
rc: rc,
usePSRPCSignal: config.UsePSRPCSignal,
}
rr.ctx, rr.cancel = context.WithCancel(context.Background())
return rr
@@ -146,6 +149,10 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek
return
}
if r.usePSRPCSignal {
return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id))
}
// create a new connection id
connectionID = livekit.ConnectionID(utils.NewGuid("CO_"))
pKey := participantKeyLegacy(roomName, pi.Identity)
@@ -0,0 +1,134 @@
// Code generated by counterfeiter. DO NOT EDIT.
package routingfakes
import (
"context"
"sync"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/protocol/livekit"
)
// FakeSignalClient is a counterfeiter-generated test double for
// routing.SignalClient. This file is machine-generated ("DO NOT EDIT");
// regenerate with counterfeiter rather than editing by hand.
type FakeSignalClient struct {
	StartParticipantSignalStub        func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error)
	startParticipantSignalMutex       sync.RWMutex
	startParticipantSignalArgsForCall []struct {
		arg1 context.Context
		arg2 livekit.RoomName
		arg3 routing.ParticipantInit
		arg4 livekit.NodeID
	}
	startParticipantSignalReturns struct {
		result1 livekit.ConnectionID
		result2 routing.MessageSink
		result3 routing.MessageSource
		result4 error
	}
	startParticipantSignalReturnsOnCall map[int]struct {
		result1 livekit.ConnectionID
		result2 routing.MessageSink
		result3 routing.MessageSource
		result4 error
	}
	invocations      map[string][][]interface{}
	invocationsMutex sync.RWMutex
}

// StartParticipantSignal records the call, then returns (in order of
// precedence) the stub's result, the per-call canned return, or the default
// canned return.
func (fake *FakeSignalClient) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) {
	fake.startParticipantSignalMutex.Lock()
	ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)]
	fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct {
		arg1 context.Context
		arg2 livekit.RoomName
		arg3 routing.ParticipantInit
		arg4 livekit.NodeID
	}{arg1, arg2, arg3, arg4})
	stub := fake.StartParticipantSignalStub
	fakeReturns := fake.startParticipantSignalReturns
	fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2, arg3, arg4})
	fake.startParticipantSignalMutex.Unlock()
	if stub != nil {
		return stub(arg1, arg2, arg3, arg4)
	}
	if specificReturn {
		return ret.result1, ret.result2, ret.result3, ret.result4
	}
	return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4
}

// StartParticipantSignalCallCount reports how many times the method was invoked.
func (fake *FakeSignalClient) StartParticipantSignalCallCount() int {
	fake.startParticipantSignalMutex.RLock()
	defer fake.startParticipantSignalMutex.RUnlock()
	return len(fake.startParticipantSignalArgsForCall)
}

// StartParticipantSignalCalls installs a stub invoked in place of canned returns.
func (fake *FakeSignalClient) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error)) {
	fake.startParticipantSignalMutex.Lock()
	defer fake.startParticipantSignalMutex.Unlock()
	fake.StartParticipantSignalStub = stub
}

// StartParticipantSignalArgsForCall returns the arguments of the i-th call.
func (fake *FakeSignalClient) StartParticipantSignalArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) {
	fake.startParticipantSignalMutex.RLock()
	defer fake.startParticipantSignalMutex.RUnlock()
	argsForCall := fake.startParticipantSignalArgsForCall[i]
	return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
}

// StartParticipantSignalReturns sets the default canned return (clears any stub).
func (fake *FakeSignalClient) StartParticipantSignalReturns(result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) {
	fake.startParticipantSignalMutex.Lock()
	defer fake.startParticipantSignalMutex.Unlock()
	fake.StartParticipantSignalStub = nil
	fake.startParticipantSignalReturns = struct {
		result1 livekit.ConnectionID
		result2 routing.MessageSink
		result3 routing.MessageSource
		result4 error
	}{result1, result2, result3, result4}
}

// StartParticipantSignalReturnsOnCall sets a canned return for the i-th call only.
func (fake *FakeSignalClient) StartParticipantSignalReturnsOnCall(i int, result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) {
	fake.startParticipantSignalMutex.Lock()
	defer fake.startParticipantSignalMutex.Unlock()
	fake.StartParticipantSignalStub = nil
	if fake.startParticipantSignalReturnsOnCall == nil {
		fake.startParticipantSignalReturnsOnCall = make(map[int]struct {
			result1 livekit.ConnectionID
			result2 routing.MessageSink
			result3 routing.MessageSource
			result4 error
		})
	}
	fake.startParticipantSignalReturnsOnCall[i] = struct {
		result1 livekit.ConnectionID
		result2 routing.MessageSink
		result3 routing.MessageSource
		result4 error
	}{result1, result2, result3, result4}
}

// Invocations returns a shallow copy of the recorded method-name -> args map.
func (fake *FakeSignalClient) Invocations() map[string][][]interface{} {
	fake.invocationsMutex.RLock()
	defer fake.invocationsMutex.RUnlock()
	fake.startParticipantSignalMutex.RLock()
	defer fake.startParticipantSignalMutex.RUnlock()
	copiedInvocations := map[string][][]interface{}{}
	for key, value := range fake.invocations {
		copiedInvocations[key] = value
	}
	return copiedInvocations
}

// recordInvocation appends one call record, lazily initializing the maps.
func (fake *FakeSignalClient) recordInvocation(key string, args []interface{}) {
	fake.invocationsMutex.Lock()
	defer fake.invocationsMutex.Unlock()
	if fake.invocations == nil {
		fake.invocations = map[string][][]interface{}{}
	}
	if fake.invocations[key] == nil {
		fake.invocations[key] = [][]interface{}{}
	}
	fake.invocations[key] = append(fake.invocations[key], args)
}

// Compile-time check that the fake satisfies the interface it doubles.
var _ routing.SignalClient = new(FakeSignalClient)
+112
View File
@@ -0,0 +1,112 @@
package routing
import (
"context"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils"
"github.com/livekit/psrpc"
)
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
//counterfeiter:generate . SignalClient

// SignalClient starts a relayed signal connection for a participant against a
// target signaling node, returning a sink for outgoing requests and a source
// of incoming responses.
type SignalClient interface {
	StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error)
}

// signalClient is the psrpc-backed implementation of SignalClient.
type signalClient struct {
	nodeID livekit.NodeID // this node's ID, used as the typed RPC client identity
	client rpc.TypedSignalClient
}

// NewSignalClient builds a SignalClient on top of the given message bus.
func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus) (SignalClient, error) {
	c, err := rpc.NewTypedSignalClient(nodeID, bus)
	if err != nil {
		return nil, err
	}
	return &signalClient{
		nodeID: nodeID,
		client: c,
	}, nil
}
// StartParticipantSignal opens a RelaySignal stream to nodeID for a new
// participant. It allocates a fresh connection ID, sends the StartSession
// envelope as the first frame, then bridges stream responses into a local
// message channel. The returned reqSink wraps the stream for outgoing
// requests; resSource yields the responses.
func (r *signalClient) StartParticipantSignal(
	ctx context.Context,
	roomName livekit.RoomName,
	pi ParticipantInit,
	nodeID livekit.NodeID,
) (
	connectionID livekit.ConnectionID,
	reqSink MessageSink,
	resSource MessageSource,
	err error,
) {
	// each attempt gets a unique "CO_"-prefixed connection ID
	connectionID = livekit.ConnectionID(utils.NewGuid("CO_"))
	ss, err := pi.ToStartSession(roomName, connectionID)
	if err != nil {
		return
	}

	logger.Debugw(
		"starting signal connection",
		"room", roomName,
		"reqNodeID", nodeID,
		"participant", pi.Identity,
		"connectionID", connectionID,
	)

	stream, err := r.client.RelaySignal(ctx, nodeID)
	if err != nil {
		return
	}

	// the first frame must carry the session parameters; on failure close the
	// stream with the error so the server side sees it
	err = stream.Send(&rpc.RelaySignalRequest{StartSession: ss})
	if err != nil {
		stream.Close(err)
		return
	}

	// pump responses off the stream into resChan until the stream closes or
	// resChan rejects a write; resChan is closed when the pump exits
	resChan := NewDefaultMessageChannel()
	go func() {
		var err error
		for msg := range stream.Channel() {
			if err = resChan.WriteMessage(msg.Response); err != nil {
				break
			}
		}
		logger.Debugw("participant signal stream closed",
			"error", err,
			"room", ss.RoomName,
			"participant", ss.Identity,
			"connectionID", connectionID,
		)
		resChan.Close()
	}()

	return connectionID, &relaySignalRequestSink{stream}, resChan, nil
}
// relaySignalRequestSink adapts the client side of a RelaySignal stream to the
// MessageSink interface, so signal requests can be written to the remote node.
type relaySignalRequestSink struct {
	psrpc.ClientStream[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]
}

// Close terminates the underlying stream without an error.
func (s *relaySignalRequestSink) Close() {
	s.ClientStream.Close(nil)
}

// IsClosed reports whether the stream's context has ended.
func (s *relaySignalRequestSink) IsClosed() bool {
	return s.Context().Err() != nil
}

// WriteMessage sends one request frame over the stream.
// NOTE(review): msg is assumed to be *livekit.SignalRequest; any other
// concrete type panics on the assertion — confirm all callers honor that.
func (s *relaySignalRequestSink) WriteMessage(msg proto.Message) error {
	return s.Send(&rpc.RelaySignalRequest{Request: msg.(*livekit.SignalRequest)})
}
+21 -17
View File
@@ -28,18 +28,19 @@ import (
)
type LivekitServer struct {
config *config.Config
ioService *IOInfoService
rtcService *RTCService
httpServer *http.Server
promServer *http.Server
router routing.Router
roomManager *RoomManager
turnServer *turn.Server
currentNode routing.LocalNode
running atomic.Bool
doneChan chan struct{}
closedChan chan struct{}
config *config.Config
ioService *IOInfoService
rtcService *RTCService
httpServer *http.Server
promServer *http.Server
router routing.Router
roomManager *RoomManager
signalServer *SignalServer
turnServer *turn.Server
currentNode routing.LocalNode
running atomic.Bool
doneChan chan struct{}
closedChan chan struct{}
}
func NewLivekitServer(conf *config.Config,
@@ -51,15 +52,17 @@ func NewLivekitServer(conf *config.Config,
keyProvider auth.KeyProvider,
router routing.Router,
roomManager *RoomManager,
signalServer *SignalServer,
turnServer *turn.Server,
currentNode routing.LocalNode,
) (s *LivekitServer, err error) {
s = &LivekitServer{
config: conf,
ioService: ioService,
rtcService: rtcService,
router: router,
roomManager: roomManager,
config: conf,
ioService: ioService,
rtcService: rtcService,
router: router,
roomManager: roomManager,
signalServer: signalServer,
// turn server starts automatically
turnServer: turnServer,
currentNode: currentNode,
@@ -252,6 +255,7 @@ func (s *LivekitServer) Start() error {
}
s.roomManager.Stop()
s.signalServer.Stop()
s.ioService.Stop()
close(s.closedChan)
+147
View File
@@ -0,0 +1,147 @@
package service
import (
"context"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/psrpc"
)
// SessionHandler starts a new participant session for a relayed signal
// connection: requests flow in via requestSource, responses flow out via
// responseSink.
type SessionHandler func(
	ctx context.Context,
	roomName livekit.RoomName,
	pi routing.ParticipantInit,
	connectionID livekit.ConnectionID,
	requestSource routing.MessageSource,
	responseSink routing.MessageSink,
) error

// SignalServer wraps the typed psrpc server that accepts relayed signal
// streams from other nodes.
type SignalServer struct {
	server rpc.TypedSignalServer
}
// NewSignalServer creates a typed RelaySignal server for this node and
// registers it on the topic keyed by nodeID, so peers can relay participant
// signal streams here. sessionHandler is invoked once per incoming session.
func NewSignalServer(
	nodeID livekit.NodeID,
	region string,
	bus psrpc.MessageBus,
	sessionHandler SessionHandler,
) (*SignalServer, error) {
	s, err := rpc.NewTypedSignalServer(nodeID, &signalService{region, sessionHandler}, bus)
	if err != nil {
		return nil, err
	}
	logger.Debugw("starting relay signal server", "topic", nodeID)
	if err := s.RegisterRelaySignalTopic(nodeID); err != nil {
		return nil, err
	}
	return &SignalServer{s}, nil
}
// NewDefaultSignalServer wires a SignalServer to the local RoomManager: every
// relayed session bumps the participant-RTC-init counter and is handed to
// RoomManager.StartSession. The router parameter is unused here; presumably it
// exists for dependency-injection wiring.
func NewDefaultSignalServer(
	currentNode routing.LocalNode,
	bus psrpc.MessageBus,
	router routing.Router,
	roomManager *RoomManager,
) (r *SignalServer, err error) {
	// connectionID is intentionally ignored: StartSession does not take it
	startSession := func(
		c context.Context,
		room livekit.RoomName,
		participant routing.ParticipantInit,
		_ livekit.ConnectionID,
		reqSource routing.MessageSource,
		resSink routing.MessageSink,
	) error {
		prometheus.IncrementParticipantRtcInit(1)
		return roomManager.StartSession(c, room, participant, reqSource, resSink)
	}

	return NewSignalServer(livekit.NodeID(currentNode.Id), currentNode.Region, bus, startSession)
}
// Stop shuts down the relay signal server, killing the underlying psrpc server.
func (r *SignalServer) Stop() {
	r.server.Kill()
}
// signalService implements the typed RelaySignal RPC, bridging incoming relay
// streams to the configured SessionHandler. region is stamped onto each
// ParticipantInit reconstructed from a StartSession.
type signalService struct {
	region         string
	sessionHandler SessionHandler
}
// RelaySignal serves one relayed participant signal stream. The first frame
// must carry StartSession; subsequent frames are forwarded into the session's
// request channel until the stream closes.
func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]) (err error) {
	// copy the context to prevent a race between the session handler closing
	// and the delivery of any parting messages from the client. take care to
	// copy the incoming rpc headers to avoid dropping any session vars.
	ctx, cancel := context.WithCancel(psrpc.NewContextWithIncomingHeader(context.Background(), psrpc.IncomingHeader(stream.Context())))
	defer cancel()

	req, ok := <-stream.Channel()
	if !ok {
		// client closed before sending anything; nothing to do
		return nil
	}

	ss := req.StartSession
	if ss == nil {
		return errors.New("expected start session message")
	}

	pi, err := routing.ParticipantInitFromStartSession(ss, r.region)
	if err != nil {
		return errors.Wrap(err, "failed to read participant from session")
	}

	// requests pulled off the stream buffer here for the session to consume
	reqChan := routing.NewDefaultMessageChannel()
	defer reqChan.Close()

	err = r.sessionHandler(
		ctx,
		livekit.RoomName(ss.RoomName),
		*pi,
		livekit.ConnectionID(ss.ConnectionId),
		reqChan,
		&relaySignalResponseSink{stream},
	)
	if err != nil {
		logger.Errorw("could not handle new participant", err,
			"room", ss.RoomName,
			"participant", ss.Identity,
			"connectionID", ss.ConnectionId,
		)
		// NOTE(review): on failure this falls through and keeps pumping the
		// stream into reqChan with no session consuming it — confirm an early
		// return is not intended here.
	}

	for msg := range stream.Channel() {
		if err = reqChan.WriteMessage(msg.Request); err != nil {
			break
		}
	}

	logger.Debugw("participant signal stream closed",
		"room", ss.RoomName,
		"participant", ss.Identity,
		"connectionID", ss.ConnectionId,
	)
	return
}
// relaySignalResponseSink adapts the server side of a RelaySignal stream to
// the routing.MessageSink interface, so the room session can write signal
// responses back to the originating node.
type relaySignalResponseSink struct {
	psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]
}

// Close terminates the underlying stream without an error.
func (s *relaySignalResponseSink) Close() {
	s.ServerStream.Close(nil)
}

// IsClosed reports whether the stream's context has ended.
func (s *relaySignalResponseSink) IsClosed() bool {
	return s.Context().Err() != nil
}

// WriteMessage sends one response frame over the stream.
// NOTE(review): msg is assumed to be *livekit.SignalResponse; any other
// concrete type panics on the assertion — confirm all callers honor that.
func (s *relaySignalResponseSink) WriteMessage(msg proto.Message) error {
	return s.Send(&rpc.RelaySignalResponse{Response: msg.(*livekit.SignalResponse)})
}
+6 -1
View File
@@ -57,6 +57,8 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
NewRoomAllocator,
NewRoomService,
NewRTCService,
NewDefaultSignalServer,
routing.NewSignalClient,
NewLocalRoomManager,
newTurnAuthHandler,
newInProcessTurnServer,
@@ -69,6 +71,9 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routing.Router, error) {
wire.Build(
createRedisClient,
getNodeID,
getMessageBus,
routing.NewSignalClient,
routing.CreateRouter,
)
@@ -136,7 +141,7 @@ func createStore(rc redis.UniversalClient) ObjectStore {
func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus {
if rc == nil {
return nil
return psrpc.NewLocalMessageBus()
}
return psrpc.NewRedisMessageBus(rc)
}
+20 -6
View File
@@ -40,14 +40,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
router := routing.CreateRouter(universalClient, currentNode)
nodeID := getNodeID(currentNode)
messageBus := getMessageBus(universalClient)
signalClient, err := routing.NewSignalClient(nodeID, messageBus)
if err != nil {
return nil, err
}
router := routing.CreateRouter(conf, universalClient, currentNode, signalClient)
objectStore := createStore(universalClient)
roomAllocator, err := NewRoomAllocator(conf, router, objectStore)
if err != nil {
return nil, err
}
nodeID := getNodeID(currentNode)
messageBus := getMessageBus(universalClient)
egressClient, err := getEgressClient(conf, nodeID, messageBus)
if err != nil {
return nil, err
@@ -88,12 +92,16 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
signalServer, err := NewDefaultSignalServer(currentNode, messageBus, router, roomManager)
if err != nil {
return nil, err
}
authHandler := newTurnAuthHandler(objectStore)
server, err := newInProcessTurnServer(conf, authHandler)
if err != nil {
return nil, err
}
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, server, currentNode)
livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode)
if err != nil {
return nil, err
}
@@ -105,7 +113,13 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi
if err != nil {
return nil, err
}
router := routing.CreateRouter(universalClient, currentNode)
nodeID := getNodeID(currentNode)
messageBus := getMessageBus(universalClient)
signalClient, err := routing.NewSignalClient(nodeID, messageBus)
if err != nil {
return nil, err
}
router := routing.CreateRouter(conf, universalClient, currentNode, signalClient)
return router, nil
}
@@ -172,7 +186,7 @@ func createStore(rc redis.UniversalClient) ObjectStore {
func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus {
if rc == nil {
return nil
return psrpc.NewLocalMessageBus()
}
return psrpc.NewRedisMessageBus(rc)
}