From cb42c6152c37bf78cbe7f2c3d419b51527eeae0a Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Sun, 21 Jan 2024 06:16:40 -0800 Subject: [PATCH] add psrpc redis keepalive (#2398) * add psrpc redis keepalive * deps --- go.mod | 2 +- go.sum | 2 + pkg/routing/interfaces.go | 20 +-- pkg/routing/messagechannel_test.go | 6 +- pkg/routing/redis.go | 211 ----------------------------- pkg/routing/redisrouter.go | 88 ++++-------- pkg/service/wire.go | 4 + pkg/service/wire_gen.go | 16 ++- 8 files changed, 51 insertions(+), 298 deletions(-) delete mode 100644 pkg/routing/redis.go diff --git a/go.mod b/go.mod index 36dfb175b..e111fbb21 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f - github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8 + github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485 github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 43692a9cd..cb4e8d092 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f h1:XHrw github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f/go.mod h1:GBzn9xL+mivI1pW+tyExcKgbc0VOc29I9yJsNcAVaAc= github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8 h1:E9s9KFCuKgYWYgaKz0ZmC7K3cPr8Iij77HbnwhQ4JZw= github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8/go.mod h1:Qv55+z0kD0NYp/G0qAaFA4Mjalxt7tsOJwrvV3HymsA= +github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485 h1:X75uVI0+YA7QN28NaVniP4IjhbcDWlktZ3Ec+PHjoHA= +github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485/go.mod h1:Qv55+z0kD0NYp/G0qAaFA4Mjalxt7tsOJwrvV3HymsA= github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9 h1:kXXV/NLVDHZ+Gn7xrR+UPpdwbH48n7WReBjLHAzqzhY= github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9/go.mod h1:cQjxg1oCxYHhxxv6KJH1gSvdtCHQoRZCHgPdm5N8v2g= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 0b4f31c41..d13d734a0 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -24,6 +24,7 @@ import ( "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" ) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate @@ -62,21 +63,6 @@ type ParticipantInit struct { SubscriberAllowPause *bool } -type NewParticipantCallback func( - ctx context.Context, - roomName livekit.RoomName, - pi ParticipantInit, - requestSource MessageSource, - responseSink MessageSink, -) error - -type RTCMessageCallback func( - ctx context.Context, - roomName livekit.RoomName, - identity livekit.ParticipantIdentity, - msg *livekit.RTCNodeMessage, -) - // Router allows multiple nodes to coordinate the participant session // //counterfeiter:generate . Router @@ -113,11 +99,11 @@ type MessageRouter interface { StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) } -func CreateRouter(rc redis.UniversalClient, node LocalNode, signalClient SignalClient) Router { +func CreateRouter(rc redis.UniversalClient, node LocalNode, signalClient SignalClient, kps rpc.KeepalivePubSub) Router { lr := NewLocalRouter(node, signalClient) if rc != nil { - return NewRedisRouter(lr, rc) + return NewRedisRouter(lr, rc, kps) } // local routing and store diff --git a/pkg/routing/messagechannel_test.go b/pkg/routing/messagechannel_test.go index 5d78c2104..bac68ef1e 100644 --- a/pkg/routing/messagechannel_test.go +++ b/pkg/routing/messagechannel_test.go @@ -39,12 +39,12 @@ func TestMessageChannel_WriteMessageClosed(t *testing.T) { go func() { defer wg.Done() for i := 0; i < 100; i++ { - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) } }() - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) m.Close() - _ = m.WriteMessage(&livekit.RTCNodeMessage{}) + _ = m.WriteMessage(&livekit.SignalRequest{}) wg.Wait() } diff --git a/pkg/routing/redis.go b/pkg/routing/redis.go deleted file mode 100644 index 5d81ce088..000000000 --- a/pkg/routing/redis.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2023 LiveKit, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package routing - -import ( - "context" - - "github.com/redis/go-redis/v9" - "go.uber.org/atomic" - "google.golang.org/protobuf/proto" - - "github.com/livekit/protocol/livekit" -) - -const ( - // hash of node_id => Node proto - NodesKey = "nodes" - - // hash of room_name => node_id - NodeRoomKey = "room_node_map" -) - -var redisCtx = context.Background() - -// location of the participant's RTC connection, hash -func participantRTCKey(participantKey livekit.ParticipantKey) string { - return "participant_rtc:" + string(participantKey) -} - -// location of the participant's Signal connection, hash -func participantSignalKey(connectionID livekit.ConnectionID) string { - return "participant_signal:" + string(connectionID) -} - -func rtcNodeChannel(nodeID livekit.NodeID) string { - return "rtc_channel:" + string(nodeID) -} - -func signalNodeChannel(nodeID livekit.NodeID) string { - return "signal_channel:" + string(nodeID) -} - -func publishRTCMessage(rc redis.UniversalClient, nodeID livekit.NodeID, participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey, msg proto.Message) error { - rm := &livekit.RTCNodeMessage{ - ParticipantKey: string(participantKey), - ParticipantKeyB62: string(participantKeyB62), - } - switch o := msg.(type) { - case *livekit.StartSession: - rm.Message = &livekit.RTCNodeMessage_StartSession{ - StartSession: o, - } - case *livekit.SignalRequest: - rm.Message = &livekit.RTCNodeMessage_Request{ - Request: o, - } - case *livekit.RTCNodeMessage: - rm = o - rm.ParticipantKey = string(participantKey) - rm.ParticipantKeyB62 = string(participantKeyB62) - default: - return ErrInvalidRouterMessage - } - data, err := proto.Marshal(rm) - if err != nil { - return err - } - - // logger.Debugw("publishing to rtc", "rtcChannel", rtcNodeChannel(nodeID), - // "message", rm.Message) - return rc.Publish(redisCtx, rtcNodeChannel(nodeID), data).Err() -} - -func publishSignalMessage(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID, msg proto.Message) error { - rm := &livekit.SignalNodeMessage{ - ConnectionId: string(connectionID), - } - switch o := msg.(type) { - case *livekit.SignalResponse: - rm.Message = &livekit.SignalNodeMessage_Response{ - Response: o, - } - case *livekit.EndSession: - rm.Message = &livekit.SignalNodeMessage_EndSession{ - EndSession: o, - } - default: - return ErrInvalidRouterMessage - } - data, err := proto.Marshal(rm) - if err != nil { - return err - } - - // logger.Debugw("publishing to signal", "signalChannel", signalNodeChannel(nodeID), - // "message", rm.Message) - return rc.Publish(redisCtx, signalNodeChannel(nodeID), data).Err() -} - -type RTCNodeSink struct { - rc redis.UniversalClient - nodeID livekit.NodeID - connectionID livekit.ConnectionID - participantKey livekit.ParticipantKey - participantKeyB62 livekit.ParticipantKey - isClosed atomic.Bool - onClose func() -} - -func NewRTCNodeSink( - rc redis.UniversalClient, - nodeID livekit.NodeID, - connectionID livekit.ConnectionID, - participantKey livekit.ParticipantKey, - participantKeyB62 livekit.ParticipantKey, -) *RTCNodeSink { - return &RTCNodeSink{ - rc: rc, - nodeID: nodeID, - connectionID: connectionID, - participantKey: participantKey, - participantKeyB62: participantKeyB62, - } -} - -func (s *RTCNodeSink) WriteMessage(msg proto.Message) error { - if s.isClosed.Load() { - return ErrChannelClosed - } - return publishRTCMessage(s.rc, s.nodeID, s.participantKey, s.participantKeyB62, msg) -} - -func (s *RTCNodeSink) Close() { - if s.isClosed.Swap(true) { - return - } - if s.onClose != nil { - s.onClose() - } -} - -func (s *RTCNodeSink) IsClosed() bool { - return s.isClosed.Load() -} - -func (s *RTCNodeSink) OnClose(f func()) { - s.onClose = f -} - -func (s *RTCNodeSink) ConnectionID() livekit.ConnectionID { - return s.connectionID -} - -// ---------------------------------------------------------------------- - -type SignalNodeSink struct { - rc redis.UniversalClient - nodeID livekit.NodeID - connectionID livekit.ConnectionID - isClosed atomic.Bool - onClose func() -} - -func NewSignalNodeSink(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID) *SignalNodeSink { - return &SignalNodeSink{ - rc: rc, - nodeID: nodeID, - connectionID: connectionID, - } -} - -func (s *SignalNodeSink) WriteMessage(msg proto.Message) error { - if s.isClosed.Load() { - return ErrChannelClosed - } - return publishSignalMessage(s.rc, s.nodeID, s.connectionID, msg) -} - -func (s *SignalNodeSink) Close() { - if s.isClosed.Swap(true) { - return - } - _ = publishSignalMessage(s.rc, s.nodeID, s.connectionID, &livekit.EndSession{}) - if s.onClose != nil { - s.onClose() - } -} - -func (s *SignalNodeSink) IsClosed() bool { - return s.isClosed.Load() -} - -func (s *SignalNodeSink) OnClose(f func()) { - s.onClose = f -} - -func (s *SignalNodeSink) ConnectionID() livekit.ConnectionID { - return s.connectionID -} diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 743d63d1c..579671814 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -28,6 +28,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/livekit-server/pkg/routing/selector" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" @@ -38,6 +39,12 @@ const ( participantMappingTTL = 24 * time.Hour statsUpdateInterval = 2 * time.Second statsMaxDelaySeconds = 30 + + // hash of node_id => Node proto + NodesKey = "nodes" + + // hash of room_name => node_id + NodeRoomKey = "room_node_map" ) var _ Router = (*RedisRouter)(nil) @@ -49,20 +56,21 @@ type RedisRouter struct { *LocalRouter rc redis.UniversalClient + kps rpc.KeepalivePubSub ctx context.Context isStarted atomic.Bool nodeMu sync.RWMutex // previous stats for computing averages prevStats *livekit.NodeStats - pubsub *redis.PubSub cancel func() } -func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient) *RedisRouter { +func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient, kps rpc.KeepalivePubSub) *RedisRouter { rr := &RedisRouter{ LocalRouter: lr, rc: rc, + kps: kps, } rr.ctx, rr.cancel = context.WithCancel(context.Background()) return rr @@ -164,33 +172,17 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id)) } -func (r *RedisRouter) WriteNodeRTC(_ context.Context, rtcNodeID string, msg *livekit.RTCNodeMessage) error { - rtcSink := NewRTCNodeSink(r.rc, livekit.NodeID(rtcNodeID), "ephemeral", livekit.ParticipantKey(msg.ParticipantKey), livekit.ParticipantKey(msg.ParticipantKeyB62)) - defer rtcSink.Close() - return r.writeRTCMessage(rtcSink, msg) -} - -func (r *LocalRouter) writeRTCMessage(sink MessageSink, msg *livekit.RTCNodeMessage) error { - msg.SenderTime = time.Now().Unix() - return sink.WriteMessage(msg) -} - func (r *RedisRouter) Start() error { if r.isStarted.Swap(true) { return nil } - workerStarted := make(chan struct{}) + workerStarted := make(chan error) go r.statsWorker() - go r.redisWorker(workerStarted) + go r.keepaliveWorker(workerStarted) // wait until worker is running - select { - case <-workerStarted: - return nil - case <-time.After(3 * time.Second): - return errors.New("Unable to start redis router") - } + return <-workerStarted } func (r *RedisRouter) Drain() { @@ -207,7 +199,6 @@ func (r *RedisRouter) Stop() { return } logger.Debugw("stopping RedisRouter") - _ = r.pubsub.Close() _ = r.UnregisterNode() r.cancel() } @@ -219,9 +210,8 @@ func (r *RedisRouter) statsWorker() { // update periodically select { case <-time.After(statsUpdateInterval): - _ = r.WriteNodeRTC(context.Background(), r.currentNode.Id, &livekit.RTCNodeMessage{ - Message: &livekit.RTCNodeMessage_KeepAlive{}, - }) + r.kps.PublishPing(r.ctx, livekit.NodeID(r.currentNode.Id), &rpc.KeepalivePing{Timestamp: time.Now().Unix()}) + r.nodeMu.RLock() stats := r.currentNode.Stats r.nodeMu.RUnlock() @@ -245,44 +235,17 @@ func (r *RedisRouter) statsWorker() { } } -// worker that consumes redis messages intended for this node -func (r *RedisRouter) redisWorker(startedChan chan struct{}) { - defer func() { - logger.Debugw("finishing redisWorker", "nodeID", r.currentNode.Id) - }() - logger.Debugw("starting redisWorker", "nodeID", r.currentNode.Id) - - rtcChannel := rtcNodeChannel(livekit.NodeID(r.currentNode.Id)) - r.pubsub = r.rc.Subscribe(r.ctx, rtcChannel) - - close(startedChan) - for msg := range r.pubsub.Channel() { - if msg == nil { - return - } - - if msg.Channel == rtcChannel { - rm := livekit.RTCNodeMessage{} - if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil { - logger.Errorw("could not unmarshal RTC message on rtcchan", err) - prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1) - continue - } - if err := r.handleRTCMessage(&rm); err != nil { - logger.Errorw("error processing RTC message", err) - prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1) - continue - } - prometheus.MessageCounter.WithLabelValues("rtc", "success").Add(1) - } +func (r *RedisRouter) keepaliveWorker(startedChan chan error) { + pings, err := r.kps.SubscribePing(r.ctx, livekit.NodeID(r.currentNode.Id)) + if err != nil { + startedChan <- err + return } -} + close(startedChan) -func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { - switch rm.Message.(type) { - case *livekit.RTCNodeMessage_KeepAlive: - if time.Since(time.Unix(rm.SenderTime, 0)) > statsUpdateInterval { - logger.Infow("keep alive too old, skipping", "senderTime", rm.SenderTime) + for ping := range pings.Channel() { + if time.Since(time.Unix(ping.Timestamp, 0)) > statsUpdateInterval { + logger.Infow("keep alive too old, skipping", "timestamp", ping.Timestamp) break } @@ -294,7 +257,7 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { if err != nil { logger.Errorw("could not update node stats", err) r.nodeMu.Unlock() - return err + continue } r.currentNode.Stats = updated if computedAvg { @@ -307,5 +270,4 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { logger.Errorw("could not update node", err) } } - return nil } diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 93e3784b0..0a2501a1d 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -82,6 +82,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live getSignalRelayConfig, NewDefaultSignalServer, routing.NewSignalClient, + rpc.NewKeepalivePubSub, getPSRPCConfig, getPSRPCClientParams, rpc.NewTopicFormatter, @@ -103,7 +104,10 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi getNodeID, getMessageBus, getSignalRelayConfig, + getPSRPCConfig, + getPSRPCClientParams, routing.NewSignalClient, + rpc.NewKeepalivePubSub, routing.CreateRouter, ) diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 5b2a3aab6..e76fa56c7 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -50,7 +50,12 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - router := routing.CreateRouter(universalClient, currentNode, signalClient) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub) objectStore := createStore(universalClient) roomAllocator, err := NewRoomAllocator(conf, router, objectStore) if err != nil { @@ -60,7 +65,6 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - clientParams := getPSRPCClientParams(psrpcConfig, messageBus) egressClient, err := rpc.NewEgressClient(clientParams) if err != nil { return nil, err @@ -149,7 +153,13 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi if err != nil { return nil, err } - router := routing.CreateRouter(universalClient, currentNode, signalClient) + psrpcConfig := getPSRPCConfig(conf) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub) return router, nil }