add psrpc redis keepalive (#2398)

* add psrpc redis keepalive

* deps
Paul Wells
2024-01-21 06:16:40 -08:00
committed by GitHub
parent 8c932da678
commit cb42c6152c
8 changed files with 51 additions and 298 deletions
+1 -1
@@ -18,7 +18,7 @@ require (
github.com/jxskiss/base62 v1.1.0
github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1
github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f
- github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8
+ github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485
github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9
github.com/mackerelio/go-osstat v0.2.4
github.com/magefile/mage v1.15.0
+2
@@ -128,6 +128,8 @@ github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f h1:XHrw
github.com/livekit/mediatransportutil v0.0.0-20231213075826-cccbf2b93d3f/go.mod h1:GBzn9xL+mivI1pW+tyExcKgbc0VOc29I9yJsNcAVaAc=
github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8 h1:E9s9KFCuKgYWYgaKz0ZmC7K3cPr8Iij77HbnwhQ4JZw=
github.com/livekit/protocol v1.9.5-0.20240118112540-cf33ad3861d8/go.mod h1:Qv55+z0kD0NYp/G0qAaFA4Mjalxt7tsOJwrvV3HymsA=
+ github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485 h1:X75uVI0+YA7QN28NaVniP4IjhbcDWlktZ3Ec+PHjoHA=
+ github.com/livekit/protocol v1.9.5-0.20240121141201-9e82495c0485/go.mod h1:Qv55+z0kD0NYp/G0qAaFA4Mjalxt7tsOJwrvV3HymsA=
github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9 h1:kXXV/NLVDHZ+Gn7xrR+UPpdwbH48n7WReBjLHAzqzhY=
github.com/livekit/psrpc v0.5.3-0.20231214055026-06ce27a934c9/go.mod h1:cQjxg1oCxYHhxxv6KJH1gSvdtCHQoRZCHgPdm5N8v2g=
github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs=
+3 -17
@@ -24,6 +24,7 @@ import (
"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
)
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
@@ -62,21 +63,6 @@ type ParticipantInit struct {
SubscriberAllowPause *bool
}
- type NewParticipantCallback func(
- ctx context.Context,
- roomName livekit.RoomName,
- pi ParticipantInit,
- requestSource MessageSource,
- responseSink MessageSink,
- ) error
- type RTCMessageCallback func(
- ctx context.Context,
- roomName livekit.RoomName,
- identity livekit.ParticipantIdentity,
- msg *livekit.RTCNodeMessage,
- )
// Router allows multiple nodes to coordinate the participant session
//
//counterfeiter:generate . Router
@@ -113,11 +99,11 @@ type MessageRouter interface {
StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error)
}
- func CreateRouter(rc redis.UniversalClient, node LocalNode, signalClient SignalClient) Router {
+ func CreateRouter(rc redis.UniversalClient, node LocalNode, signalClient SignalClient, kps rpc.KeepalivePubSub) Router {
lr := NewLocalRouter(node, signalClient)
if rc != nil {
- return NewRedisRouter(lr, rc)
+ return NewRedisRouter(lr, rc, kps)
}
// local routing and store
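With the callback types gone, the router surface is slimmer and CreateRouter now requires a keepalive pub/sub alongside the Redis client. A minimal call-site sketch, assuming the same names that appear in the wire_gen.go changes further down (clientParams, universalClient, currentNode, signalClient are taken from there, not new API):

kps, err := rpc.NewKeepalivePubSub(clientParams)
if err != nil {
	return nil, err
}
// With a non-nil Redis client the call returns a RedisRouter that uses kps
// for node keepalives; with rc == nil it falls back to local routing.
router := routing.CreateRouter(universalClient, currentNode, signalClient, kps)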
+3 -3
@@ -39,12 +39,12 @@ func TestMessageChannel_WriteMessageClosed(t *testing.T) {
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
_ = m.WriteMessage(&livekit.RTCNodeMessage{})
_ = m.WriteMessage(&livekit.SignalRequest{})
}
}()
_ = m.WriteMessage(&livekit.RTCNodeMessage{})
_ = m.WriteMessage(&livekit.SignalRequest{})
m.Close()
_ = m.WriteMessage(&livekit.RTCNodeMessage{})
_ = m.WriteMessage(&livekit.SignalRequest{})
wg.Wait()
}
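The test swaps livekit.RTCNodeMessage for livekit.SignalRequest because the RTC node message path is deleted in this commit; the invariant it guards is unchanged. A small sketch of that invariant against any MessageSink (the helper name is illustrative only):

// Writes racing with Close must not panic, and writes after Close are
// expected to return without delivering (or report ErrChannelClosed).
func writeThenClose(sink routing.MessageSink) {
	_ = sink.WriteMessage(&livekit.SignalRequest{})
	sink.Close()
	_ = sink.WriteMessage(&livekit.SignalRequest{}) // safe after Close
}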
-211
@@ -1,211 +0,0 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"github.com/redis/go-redis/v9"
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/livekit"
)
const (
// hash of node_id => Node proto
NodesKey = "nodes"
// hash of room_name => node_id
NodeRoomKey = "room_node_map"
)
var redisCtx = context.Background()
// location of the participant's RTC connection, hash
func participantRTCKey(participantKey livekit.ParticipantKey) string {
return "participant_rtc:" + string(participantKey)
}
// location of the participant's Signal connection, hash
func participantSignalKey(connectionID livekit.ConnectionID) string {
return "participant_signal:" + string(connectionID)
}
func rtcNodeChannel(nodeID livekit.NodeID) string {
return "rtc_channel:" + string(nodeID)
}
func signalNodeChannel(nodeID livekit.NodeID) string {
return "signal_channel:" + string(nodeID)
}
func publishRTCMessage(rc redis.UniversalClient, nodeID livekit.NodeID, participantKey livekit.ParticipantKey, participantKeyB62 livekit.ParticipantKey, msg proto.Message) error {
rm := &livekit.RTCNodeMessage{
ParticipantKey: string(participantKey),
ParticipantKeyB62: string(participantKeyB62),
}
switch o := msg.(type) {
case *livekit.StartSession:
rm.Message = &livekit.RTCNodeMessage_StartSession{
StartSession: o,
}
case *livekit.SignalRequest:
rm.Message = &livekit.RTCNodeMessage_Request{
Request: o,
}
case *livekit.RTCNodeMessage:
rm = o
rm.ParticipantKey = string(participantKey)
rm.ParticipantKeyB62 = string(participantKeyB62)
default:
return ErrInvalidRouterMessage
}
data, err := proto.Marshal(rm)
if err != nil {
return err
}
// logger.Debugw("publishing to rtc", "rtcChannel", rtcNodeChannel(nodeID),
// "message", rm.Message)
return rc.Publish(redisCtx, rtcNodeChannel(nodeID), data).Err()
}
func publishSignalMessage(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID, msg proto.Message) error {
rm := &livekit.SignalNodeMessage{
ConnectionId: string(connectionID),
}
switch o := msg.(type) {
case *livekit.SignalResponse:
rm.Message = &livekit.SignalNodeMessage_Response{
Response: o,
}
case *livekit.EndSession:
rm.Message = &livekit.SignalNodeMessage_EndSession{
EndSession: o,
}
default:
return ErrInvalidRouterMessage
}
data, err := proto.Marshal(rm)
if err != nil {
return err
}
// logger.Debugw("publishing to signal", "signalChannel", signalNodeChannel(nodeID),
// "message", rm.Message)
return rc.Publish(redisCtx, signalNodeChannel(nodeID), data).Err()
}
type RTCNodeSink struct {
rc redis.UniversalClient
nodeID livekit.NodeID
connectionID livekit.ConnectionID
participantKey livekit.ParticipantKey
participantKeyB62 livekit.ParticipantKey
isClosed atomic.Bool
onClose func()
}
func NewRTCNodeSink(
rc redis.UniversalClient,
nodeID livekit.NodeID,
connectionID livekit.ConnectionID,
participantKey livekit.ParticipantKey,
participantKeyB62 livekit.ParticipantKey,
) *RTCNodeSink {
return &RTCNodeSink{
rc: rc,
nodeID: nodeID,
connectionID: connectionID,
participantKey: participantKey,
participantKeyB62: participantKeyB62,
}
}
func (s *RTCNodeSink) WriteMessage(msg proto.Message) error {
if s.isClosed.Load() {
return ErrChannelClosed
}
return publishRTCMessage(s.rc, s.nodeID, s.participantKey, s.participantKeyB62, msg)
}
func (s *RTCNodeSink) Close() {
if s.isClosed.Swap(true) {
return
}
if s.onClose != nil {
s.onClose()
}
}
func (s *RTCNodeSink) IsClosed() bool {
return s.isClosed.Load()
}
func (s *RTCNodeSink) OnClose(f func()) {
s.onClose = f
}
func (s *RTCNodeSink) ConnectionID() livekit.ConnectionID {
return s.connectionID
}
// ----------------------------------------------------------------------
type SignalNodeSink struct {
rc redis.UniversalClient
nodeID livekit.NodeID
connectionID livekit.ConnectionID
isClosed atomic.Bool
onClose func()
}
func NewSignalNodeSink(rc redis.UniversalClient, nodeID livekit.NodeID, connectionID livekit.ConnectionID) *SignalNodeSink {
return &SignalNodeSink{
rc: rc,
nodeID: nodeID,
connectionID: connectionID,
}
}
func (s *SignalNodeSink) WriteMessage(msg proto.Message) error {
if s.isClosed.Load() {
return ErrChannelClosed
}
return publishSignalMessage(s.rc, s.nodeID, s.connectionID, msg)
}
func (s *SignalNodeSink) Close() {
if s.isClosed.Swap(true) {
return
}
_ = publishSignalMessage(s.rc, s.nodeID, s.connectionID, &livekit.EndSession{})
if s.onClose != nil {
s.onClose()
}
}
func (s *SignalNodeSink) IsClosed() bool {
return s.isClosed.Load()
}
func (s *SignalNodeSink) OnClose(f func()) {
s.onClose = f
}
func (s *SignalNodeSink) ConnectionID() livekit.ConnectionID {
return s.connectionID
}
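Both sinks existed mainly to push node-to-node traffic, including the periodic keepalive, over raw Redis pub/sub channels. For contrast with the psrpc-based replacement below, the removed keepalive path amounted to roughly this (reconstructed from the deleted code above, not a verbatim excerpt):

// Old path: wrap a KeepAlive in an RTCNodeMessage and publish it to
// rtc_channel:<node_id>; the receiving node's redisWorker unmarshalled it.
sink := NewRTCNodeSink(rc, nodeID, "ephemeral", participantKey, participantKeyB62)
defer sink.Close()
_ = sink.WriteMessage(&livekit.RTCNodeMessage{
	Message: &livekit.RTCNodeMessage_KeepAlive{},
})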
+25 -63
@@ -28,6 +28,7 @@ import (
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
"github.com/livekit/livekit-server/pkg/routing/selector"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
@@ -38,6 +39,12 @@ const (
participantMappingTTL = 24 * time.Hour
statsUpdateInterval = 2 * time.Second
statsMaxDelaySeconds = 30
+ // hash of node_id => Node proto
+ NodesKey = "nodes"
+ // hash of room_name => node_id
+ NodeRoomKey = "room_node_map"
)
var _ Router = (*RedisRouter)(nil)
@@ -49,20 +56,21 @@ type RedisRouter struct {
*LocalRouter
rc redis.UniversalClient
+ kps rpc.KeepalivePubSub
ctx context.Context
isStarted atomic.Bool
nodeMu sync.RWMutex
// previous stats for computing averages
prevStats *livekit.NodeStats
- pubsub *redis.PubSub
cancel func()
}
- func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient) *RedisRouter {
+ func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient, kps rpc.KeepalivePubSub) *RedisRouter {
rr := &RedisRouter{
LocalRouter: lr,
rc: rc,
+ kps: kps,
}
rr.ctx, rr.cancel = context.WithCancel(context.Background())
return rr
@@ -164,33 +172,17 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek
return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id))
}
- func (r *RedisRouter) WriteNodeRTC(_ context.Context, rtcNodeID string, msg *livekit.RTCNodeMessage) error {
- rtcSink := NewRTCNodeSink(r.rc, livekit.NodeID(rtcNodeID), "ephemeral", livekit.ParticipantKey(msg.ParticipantKey), livekit.ParticipantKey(msg.ParticipantKeyB62))
- defer rtcSink.Close()
- return r.writeRTCMessage(rtcSink, msg)
- }
- func (r *LocalRouter) writeRTCMessage(sink MessageSink, msg *livekit.RTCNodeMessage) error {
- msg.SenderTime = time.Now().Unix()
- return sink.WriteMessage(msg)
- }
func (r *RedisRouter) Start() error {
if r.isStarted.Swap(true) {
return nil
}
- workerStarted := make(chan struct{})
+ workerStarted := make(chan error)
go r.statsWorker()
- go r.redisWorker(workerStarted)
+ go r.keepaliveWorker(workerStarted)
// wait until worker is running
- select {
- case <-workerStarted:
- return nil
- case <-time.After(3 * time.Second):
- return errors.New("Unable to start redis router")
- }
+ return <-workerStarted
}
func (r *RedisRouter) Drain() {
@@ -207,7 +199,6 @@ func (r *RedisRouter) Stop() {
return
}
logger.Debugw("stopping RedisRouter")
- _ = r.pubsub.Close()
_ = r.UnregisterNode()
r.cancel()
}
@@ -219,9 +210,8 @@ func (r *RedisRouter) statsWorker() {
// update periodically
select {
case <-time.After(statsUpdateInterval):
- _ = r.WriteNodeRTC(context.Background(), r.currentNode.Id, &livekit.RTCNodeMessage{
- Message: &livekit.RTCNodeMessage_KeepAlive{},
- })
+ r.kps.PublishPing(r.ctx, livekit.NodeID(r.currentNode.Id), &rpc.KeepalivePing{Timestamp: time.Now().Unix()})
r.nodeMu.RLock()
stats := r.currentNode.Stats
r.nodeMu.RUnlock()
@@ -245,44 +235,17 @@ func (r *RedisRouter) statsWorker() {
}
}
- // worker that consumes redis messages intended for this node
- func (r *RedisRouter) redisWorker(startedChan chan struct{}) {
- defer func() {
- logger.Debugw("finishing redisWorker", "nodeID", r.currentNode.Id)
- }()
- logger.Debugw("starting redisWorker", "nodeID", r.currentNode.Id)
- rtcChannel := rtcNodeChannel(livekit.NodeID(r.currentNode.Id))
- r.pubsub = r.rc.Subscribe(r.ctx, rtcChannel)
- close(startedChan)
- for msg := range r.pubsub.Channel() {
- if msg == nil {
- return
- }
- if msg.Channel == rtcChannel {
- rm := livekit.RTCNodeMessage{}
- if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil {
- logger.Errorw("could not unmarshal RTC message on rtcchan", err)
- prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1)
- continue
- }
- if err := r.handleRTCMessage(&rm); err != nil {
- logger.Errorw("error processing RTC message", err)
- prometheus.MessageCounter.WithLabelValues("rtc", "failure").Add(1)
- continue
- }
- prometheus.MessageCounter.WithLabelValues("rtc", "success").Add(1)
- }
+ func (r *RedisRouter) keepaliveWorker(startedChan chan error) {
+ pings, err := r.kps.SubscribePing(r.ctx, livekit.NodeID(r.currentNode.Id))
+ if err != nil {
+ startedChan <- err
+ return
+ }
}
+ close(startedChan)
- func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error {
- switch rm.Message.(type) {
- case *livekit.RTCNodeMessage_KeepAlive:
- if time.Since(time.Unix(rm.SenderTime, 0)) > statsUpdateInterval {
- logger.Infow("keep alive too old, skipping", "senderTime", rm.SenderTime)
+ for ping := range pings.Channel() {
+ if time.Since(time.Unix(ping.Timestamp, 0)) > statsUpdateInterval {
+ logger.Infow("keep alive too old, skipping", "timestamp", ping.Timestamp)
break
}
@@ -294,7 +257,7 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error {
if err != nil {
logger.Errorw("could not update node stats", err)
r.nodeMu.Unlock()
- return err
+ continue
}
r.currentNode.Stats = updated
if computedAvg {
@@ -307,5 +270,4 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error {
logger.Errorw("could not update node", err)
}
}
- return nil
}
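The keepalive loop is now symmetric around rpc.KeepalivePubSub: statsWorker publishes a timestamped ping addressed to its own node ID, and keepaliveWorker consumes that stream and refreshes the node's stats. A condensed sketch of both halves, assuming only the PublishPing/SubscribePing calls already visible above:

// Publisher side (statsWorker): one ping per statsUpdateInterval.
kps.PublishPing(ctx, livekit.NodeID(nodeID), &rpc.KeepalivePing{Timestamp: time.Now().Unix()})

// Subscriber side (keepaliveWorker): report startup errors on the channel,
// then treat each fresh ping as a cue to update and persist node stats.
pings, err := kps.SubscribePing(ctx, livekit.NodeID(nodeID))
if err != nil {
	startedChan <- err
	return
}
close(startedChan)
for ping := range pings.Channel() {
	if time.Since(time.Unix(ping.Timestamp, 0)) > statsUpdateInterval {
		continue // drop stale pings (the hunk above uses break here)
	}
	// refresh r.currentNode.Stats and write the updated Node proto to Redis
}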
+4
@@ -82,6 +82,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
getSignalRelayConfig,
NewDefaultSignalServer,
routing.NewSignalClient,
+ rpc.NewKeepalivePubSub,
getPSRPCConfig,
getPSRPCClientParams,
rpc.NewTopicFormatter,
@@ -103,7 +104,10 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi
getNodeID,
getMessageBus,
getSignalRelayConfig,
+ getPSRPCConfig,
+ getPSRPCClientParams,
routing.NewSignalClient,
+ rpc.NewKeepalivePubSub,
routing.CreateRouter,
)
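For orientation, these providers slot into the usual google/wire injector for this package; a trimmed sketch of what the wireinject declaration looks like with the new provider (the function name is illustrative, and the real build list also includes providers not visible in this hunk, such as the Redis client):

//go:build wireinject

func initializeRouterSketch(conf *config.Config, currentNode routing.LocalNode) (routing.Router, error) {
	wire.Build(
		getNodeID,
		getMessageBus,
		getSignalRelayConfig,
		getPSRPCConfig,
		getPSRPCClientParams,
		routing.NewSignalClient,
		rpc.NewKeepalivePubSub, // new: satisfies the rpc.KeepalivePubSub parameter of CreateRouter
		routing.CreateRouter,
	)
	return nil, nil
}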
+13 -3
@@ -50,7 +50,12 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
- router := routing.CreateRouter(universalClient, currentNode, signalClient)
+ clientParams := getPSRPCClientParams(psrpcConfig, messageBus)
+ keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams)
+ if err != nil {
+ return nil, err
+ }
+ router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub)
objectStore := createStore(universalClient)
roomAllocator, err := NewRoomAllocator(conf, router, objectStore)
if err != nil {
@@ -60,7 +65,6 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live
if err != nil {
return nil, err
}
- clientParams := getPSRPCClientParams(psrpcConfig, messageBus)
egressClient, err := rpc.NewEgressClient(clientParams)
if err != nil {
return nil, err
@@ -149,7 +153,13 @@ func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routi
if err != nil {
return nil, err
}
- router := routing.CreateRouter(universalClient, currentNode, signalClient)
+ psrpcConfig := getPSRPCConfig(conf)
+ clientParams := getPSRPCClientParams(psrpcConfig, messageBus)
+ keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams)
+ if err != nil {
+ return nil, err
+ }
+ router := routing.CreateRouter(universalClient, currentNode, signalClient, keepalivePubSub)
return router, nil
}
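Callers of the generated injector are unaffected; the keepalive pub/sub is built and threaded through internally. A minimal usage sketch, assuming Router exposes Start and Stop as RedisRouter does above:

router, err := service.InitializeRouter(conf, currentNode)
if err != nil {
	return err
}
if err := router.Start(); err != nil {
	return err
}
defer router.Stop()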