From eb095db70a9beb4cc1f31a745d65e26c21abe48e Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Sun, 9 Apr 2023 18:18:21 -0700 Subject: [PATCH] Batch signal retries (#1593) * batch signal retries * cleanup * update protocol * range check message dedup * update protocol with codegen * block while draining * only log send timeouts * cleanup * cleanup * cleanup * typo * update config yaml options * update protocol --- go.mod | 8 +- go.sum | 16 +-- pkg/config/config.go | 13 +- pkg/routing/signal.go | 264 ++++++++++++++++++++++++++++++++----- pkg/service/signal.go | 119 ++++++----------- pkg/service/signal_test.go | 83 ++++++++++++ 6 files changed, 370 insertions(+), 133 deletions(-) create mode 100644 pkg/service/signal_test.go diff --git a/go.mod b/go.mod index 7c052ce66..b1d53bcec 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20230326055817-ed569ca13d26 - github.com/livekit/protocol v1.5.2 + github.com/livekit/protocol v1.5.3-0.20230410011118-30f8b4c081aa github.com/livekit/psrpc v0.2.11-0.20230405191830-d76f71512630 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.14.0 @@ -90,12 +90,12 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.uber.org/multierr v1.6.0 // indirect - golang.org/x/crypto v0.7.0 // indirect + golang.org/x/crypto v0.8.0 // indirect golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.8.0 // indirect + golang.org/x/net v0.9.0 // indirect golang.org/x/sys v0.7.0 // indirect - golang.org/x/text v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect golang.org/x/tools v0.6.0 // indirect google.golang.org/genproto v0.0.0-20230403163135-c38d8f061ccd // indirect google.golang.org/grpc v1.54.0 // indirect diff --git a/go.sum b/go.sum index 59640e93b..e763d9714 100644 --- a/go.sum +++ b/go.sum @@ -235,8 +235,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20230326055817-ed569ca13d26 h1:QlQFyMwCDgjyySsrgmrMcVbEBA6KZcyTzvK+z346tUA= github.com/livekit/mediatransportutil v0.0.0-20230326055817-ed569ca13d26/go.mod h1:eDA41kiySZoG+wy4Etsjb3w0jjLx69i/vAmSjG4bteA= -github.com/livekit/protocol v1.5.2 h1:mbbkJNxbStvb9sDtB7CFX7NnTObYKFumNU7wWm4UOfY= -github.com/livekit/protocol v1.5.2/go.mod h1:UFgAWejoO4eshaaDe2jynTdQWwSktNO+8Wx19V7bs+o= +github.com/livekit/protocol v1.5.3-0.20230410011118-30f8b4c081aa h1:s7ACG7CGvt12tiBYSsywSavYh3S/JLVZI7Ob3ot0rKs= +github.com/livekit/protocol v1.5.3-0.20230410011118-30f8b4c081aa/go.mod h1:GzQYVsW/eIsI7xdDTNUGed+SD7IpCI1dLdOlIqRmd2U= github.com/livekit/psrpc v0.2.11-0.20230405191830-d76f71512630 h1:Rm5KLZgQxWnTidY+H8MsAV6sk1iiFxeXqPFgSLkMing= github.com/livekit/psrpc v0.2.11-0.20230405191830-d76f71512630/go.mod h1:K0j8f1PgLShR7Lx80KbmwFkDH2BvOnycXGV0OSRURKc= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= @@ -436,8 +436,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -522,8 +522,8 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -626,8 +626,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= diff --git a/pkg/config/config.go b/pkg/config/config.go index 42ab4e735..50cdf5ccc 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -230,10 +230,11 @@ type NodeSelectorConfig struct { type SignalRelayConfig struct { Enabled bool `yaml:"enabled"` - MaxAttempts int `yaml:"max_attempts,omitempty"` - Timeout time.Duration `yaml:"timeout,omitempty"` - Backoff time.Duration `yaml:"backoff,omitempty"` + RetryTimeout time.Duration `yaml:"retry_timeout,omitempty"` + MinRetryInterval time.Duration `yaml:"min_retry_interval,omitempty"` + MaxRetryInterval time.Duration `yaml:"max_retry_interval,omitempty"` StreamBufferSize int `yaml:"stream_buffer_size,omitempty"` + MinVersion int `yaml:"min_version,omitempty"` } // RegionConfig lists available regions and their latitude/longitude, so the selector would prefer @@ -407,9 +408,9 @@ func NewConfig(confString string, strictMode bool, c *cli.Context, baseFlags []c }, SignalRelay: SignalRelayConfig{ Enabled: false, - MaxAttempts: 3, - Timeout: 500 * time.Millisecond, - Backoff: 500 * time.Millisecond, + RetryTimeout: 30 * time.Second, + MinRetryInterval: 500 * time.Millisecond, + MaxRetryInterval: 5 * time.Second, StreamBufferSize: 1000, }, Keys: map[string]string{}, diff --git a/pkg/routing/signal.go b/pkg/routing/signal.go index 418852c3d..5be76e70d 100644 --- a/pkg/routing/signal.go +++ b/pkg/routing/signal.go @@ -2,6 +2,9 @@ package routing import ( "context" + "errors" + "sync" + "time" "go.uber.org/atomic" "google.golang.org/protobuf/proto" @@ -36,11 +39,6 @@ func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus, config config. nodeID, bus, middleware.WithClientMetrics(prometheus.PSRPCMetricsObserver{}), - middleware.WithStreamRetries(middleware.RetryOptions{ - MaxAttempts: config.MaxAttempts, - Timeout: config.Timeout, - Backoff: config.Backoff, - }), psrpc.WithClientChannelSize(config.StreamBufferSize), ) if err != nil { @@ -75,14 +73,15 @@ func (r *signalClient) StartParticipantSignal( return } - logger.Debugw( - "starting signal connection", + l := logger.GetLogger().WithValues( "room", roomName, "reqNodeID", nodeID, "participant", pi.Identity, "connectionID", connectionID, ) + l.Debugw("starting signal connection") + stream, err := r.client.RelaySignal(ctx, nodeID) if err != nil { return @@ -94,49 +93,248 @@ func (r *signalClient) StartParticipantSignal( return } + sink := NewSignalMessageSink(SignalSinkParams[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]{ + Logger: l, + Stream: stream, + Config: r.config, + Writer: signalRequestMessageWriter{}, + CloseOnFailure: true, + }) resChan := NewDefaultMessageChannel() go func() { r.active.Inc() defer r.active.Dec() - var err error - for msg := range stream.Channel() { - if err = resChan.WriteMessage(msg.Response); err != nil { - break - } - for _, res := range msg.Responses { - if err = resChan.WriteMessage(res); err != nil { - break - } - } - } - - logger.Debugw("participant signal stream closed", - "error", err, - "room", ss.RoomName, - "participant", ss.Identity, - "connectionID", connectionID, + err = CopySignalStreamToMessageChannel[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]( + stream, + resChan, + signalResponseMessageReader{}, + r.config, ) + l.Debugw("participant signal stream closed", "error", err) resChan.Close() }() - return connectionID, &relaySignalRequestSink{stream}, resChan, nil + return connectionID, sink, resChan, nil } -type relaySignalRequestSink struct { - psrpc.ClientStream[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse] +type signalRequestMessageWriter struct{} + +func (e signalRequestMessageWriter) WriteOne(seq uint64, msg proto.Message) *rpc.RelaySignalRequest { + return &rpc.RelaySignalRequest{ + Seq: seq, + Request: msg.(*livekit.SignalRequest), + } } -func (s *relaySignalRequestSink) Close() { - s.ClientStream.Close(nil) +func (e signalRequestMessageWriter) WriteMany(seq uint64, msgs []proto.Message) *rpc.RelaySignalRequest { + r := &rpc.RelaySignalRequest{ + Seq: seq, + Requests: make([]*livekit.SignalRequest, 0, len(msgs)), + } + for _, m := range msgs { + r.Requests = append(r.Requests, m.(*livekit.SignalRequest)) + } + return r } -func (s *relaySignalRequestSink) IsClosed() bool { - return s.Context().Err() != nil +type signalResponseMessageReader struct{} + +func (e signalResponseMessageReader) Read(rm *rpc.RelaySignalResponse) ([]proto.Message, error) { + msgs := make([]proto.Message, 0, len(rm.Responses)+1) + if rm.Response != nil { + msgs = append(msgs, rm.Response) + } + for _, m := range rm.Responses { + msgs = append(msgs, m) + } + return msgs, nil } -func (s *relaySignalRequestSink) WriteMessage(msg proto.Message) error { - return s.Send(&rpc.RelaySignalRequest{Request: msg.(*livekit.SignalRequest)}) +type RelaySignalMessage interface { + proto.Message + GetSeq() uint64 +} + +type SignalMessageWriter[SendType RelaySignalMessage] interface { + WriteOne(seq uint64, msg proto.Message) SendType + WriteMany(seq uint64, msgs []proto.Message) SendType +} + +type SignalMessageReader[RecvType RelaySignalMessage] interface { + Read(msg RecvType) ([]proto.Message, error) +} + +func CopySignalStreamToMessageChannel[SendType, RecvType RelaySignalMessage]( + stream psrpc.Stream[SendType, RecvType], + ch *MessageChannel, + reader SignalMessageReader[RecvType], + config config.SignalRelayConfig, +) error { + r := &signalMessageReader[SendType, RecvType]{ + reader: reader, + config: config, + } + for msg := range stream.Channel() { + var res []proto.Message + res, err := r.Read(msg) + if err != nil { + return err + } + for _, r := range res { + if err = ch.WriteMessage(r); err != nil { + return err + } + } + } + return stream.Err() +} + +type signalMessageReader[SendType, RecvType RelaySignalMessage] struct { + seq uint64 + reader SignalMessageReader[RecvType] + config config.SignalRelayConfig +} + +func (r *signalMessageReader[SendType, RecvType]) Read(msg RecvType) ([]proto.Message, error) { + res, err := r.reader.Read(msg) + if err != nil { + return nil, err + } + + if r.config.MinVersion >= 1 { + if r.seq < msg.GetSeq() { + return nil, errors.New("signal message dropped") + } + if r.seq > msg.GetSeq() { + n := int(r.seq - msg.GetSeq()) + if n > len(res) { + n = len(res) + } + res = res[n:] + } + r.seq += uint64(len(res)) + } + return res, nil +} + +type SignalSinkParams[SendType, RecvType RelaySignalMessage] struct { + Stream psrpc.Stream[SendType, RecvType] + Logger logger.Logger + Config config.SignalRelayConfig + Writer SignalMessageWriter[SendType] + CloseOnFailure bool +} + +func NewSignalMessageSink[SendType, RecvType RelaySignalMessage](params SignalSinkParams[SendType, RecvType]) MessageSink { + return &signalMessageSink[SendType, RecvType]{ + SignalSinkParams: params, + } +} + +var ErrSignalFailed = errors.New("signal stream failed") + +type signalMessageSink[SendType, RecvType RelaySignalMessage] struct { + SignalSinkParams[SendType, RecvType] + + mu sync.Mutex + seq uint64 + queue []proto.Message + writing bool + draining bool +} + +func (s *signalMessageSink[SendType, RecvType]) Close() { + s.mu.Lock() + s.draining = true + if !s.writing { + s.Stream.Close(nil) + } + s.mu.Unlock() + + <-s.Stream.Context().Done() +} + +func (s *signalMessageSink[SendType, RecvType]) IsClosed() bool { + return s.Stream.Err() != nil +} + +func (s *signalMessageSink[SendType, RecvType]) nextMessage() (msg SendType, n int) { + if len(s.queue) == 0 { + return + } + if s.Config.MinVersion >= 1 { + return s.Writer.WriteMany(s.seq, s.queue), len(s.queue) + } + return s.Writer.WriteOne(s.seq, s.queue[0]), 1 +} + +func (s *signalMessageSink[SendType, RecvType]) write() { + interval := s.Config.MinRetryInterval + deadline := time.Now().Add(s.Config.RetryTimeout) + + s.mu.Lock() + for { + msg, n := s.nextMessage() + if n == 0 || s.IsClosed() { + if s.draining { + s.Stream.Close(nil) + } + s.writing = false + break + } + s.mu.Unlock() + + err := s.Stream.Send(msg, psrpc.WithTimeout(interval)) + if err != nil { + if time.Now().After(deadline) { + s.Logger.Warnw("could not send signal message", err) + + s.mu.Lock() + s.seq += uint64(len(s.queue)) + s.queue = nil + + if s.CloseOnFailure { + s.Stream.Close(ErrSignalFailed) + } + s.mu.Unlock() + return + } + + interval *= 2 + if interval > s.Config.MaxRetryInterval { + interval = s.Config.MaxRetryInterval + } + } + + s.mu.Lock() + if err == nil { + interval = s.Config.MinRetryInterval + deadline = time.Now().Add(s.Config.RetryTimeout) + + s.seq += uint64(n) + s.queue = s.queue[n:] + } + } + s.mu.Unlock() +} + +func (s *signalMessageSink[SendType, RecvType]) WriteMessage(msg proto.Message) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.Stream.Err(); err != nil { + return err + } else if s.draining { + return psrpc.ErrStreamClosed + } + + s.queue = append(s.queue, msg) + if !s.writing { + s.writing = true + go s.write() + } + return nil } diff --git a/pkg/service/signal.go b/pkg/service/signal.go index 407c173e1..3d9b6c17e 100644 --- a/pkg/service/signal.go +++ b/pkg/service/signal.go @@ -2,8 +2,6 @@ package service import ( "context" - "fmt" - "sync" "github.com/pkg/errors" "google.golang.org/protobuf/proto" @@ -40,14 +38,9 @@ func NewSignalServer( ) (*SignalServer, error) { s, err := rpc.NewTypedSignalServer( nodeID, - &signalService{region, sessionHandler}, + &signalService{region, sessionHandler, config}, bus, middleware.WithServerMetrics(prometheus.PSRPCMetricsObserver{}), - psrpc.WithServerStreamInterceptors(middleware.NewStreamRetryInterceptorFactory(middleware.RetryOptions{ - MaxAttempts: config.MaxAttempts, - Timeout: config.Timeout, - Backoff: config.Backoff, - })), psrpc.WithServerChannelSize(config.StreamBufferSize), ) if err != nil { @@ -101,6 +94,7 @@ func (r *SignalServer) Stop() { type signalService struct { region string sessionHandler SessionHandler + config config.SignalRelayConfig } func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]) (err error) { @@ -134,92 +128,53 @@ func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalRe reqChan := routing.NewDefaultMessageChannel() defer reqChan.Close() - err = r.sessionHandler( - ctx, - livekit.RoomName(ss.RoomName), - *pi, - livekit.ConnectionID(ss.ConnectionId), - reqChan, - &relaySignalResponseSink{ - ServerStream: stream, - logger: l, - }, - ) + sink := routing.NewSignalMessageSink(routing.SignalSinkParams[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]{ + Logger: l, + Stream: stream, + Config: r.config, + Writer: signalResponseMessageWriter{}, + }) + + err = r.sessionHandler(ctx, livekit.RoomName(ss.RoomName), *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) if err != nil { l.Errorw("could not handle new participant", err) } - for msg := range stream.Channel() { - if err = reqChan.WriteMessage(msg.Request); err != nil { - break - } - } + err = routing.CopySignalStreamToMessageChannel[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest](stream, reqChan, signalRequestMessageReader{}, r.config) + l.Debugw("participant signal stream closed", "error", err) - l.Debugw("participant signal stream closed") return } -type relaySignalResponseSink struct { - psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest] - logger logger.Logger +type signalResponseMessageWriter struct{} - mu sync.Mutex - queue []*livekit.SignalResponse - writing bool - draining bool -} - -func (s *relaySignalResponseSink) Close() { - s.mu.Lock() - s.draining = true - if !s.writing { - s.ServerStream.Close(nil) - } - s.mu.Unlock() -} - -func (s *relaySignalResponseSink) IsClosed() bool { - return s.Context().Err() != nil -} - -func (s *relaySignalResponseSink) write() { - for { - s.mu.Lock() - var msg *livekit.SignalResponse - if len(s.queue) != 0 && !s.IsClosed() { - msg = s.queue[0] - s.queue = s.queue[1:] - } else { - if s.draining { - s.ServerStream.Close(nil) - } - s.writing = false - s.mu.Unlock() - return - } - s.mu.Unlock() - - if err := s.Send(&rpc.RelaySignalResponse{Response: msg}); err != nil { - s.logger.Warnw( - "could not send message to participant", err, - "messageType", fmt.Sprintf("%T", msg.Message), - ) - } +func (e signalResponseMessageWriter) WriteOne(seq uint64, msg proto.Message) *rpc.RelaySignalResponse { + return &rpc.RelaySignalResponse{ + Seq: seq, + Response: msg.(*livekit.SignalResponse), } } -func (s *relaySignalResponseSink) WriteMessage(msg proto.Message) error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.draining || s.IsClosed() { - return psrpc.ErrStreamClosed +func (e signalResponseMessageWriter) WriteMany(seq uint64, msgs []proto.Message) *rpc.RelaySignalResponse { + r := &rpc.RelaySignalResponse{ + Seq: seq, + Responses: make([]*livekit.SignalResponse, 0, len(msgs)), } - - s.queue = append(s.queue, msg.(*livekit.SignalResponse)) - if !s.writing { - s.writing = true - go s.write() + for _, m := range msgs { + r.Responses = append(r.Responses, m.(*livekit.SignalResponse)) } - return nil + return r +} + +type signalRequestMessageReader struct{} + +func (e signalRequestMessageReader) Read(rm *rpc.RelaySignalRequest) ([]proto.Message, error) { + msgs := make([]proto.Message, 0, len(rm.Requests)+1) + if rm.Request != nil { + msgs = append(msgs, rm.Request) + } + for _, m := range rm.Requests { + msgs = append(msgs, m) + } + return msgs, nil } diff --git a/pkg/service/signal_test.go b/pkg/service/signal_test.go new file mode 100644 index 000000000..a953fd9dc --- /dev/null +++ b/pkg/service/signal_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" + "github.com/livekit/psrpc" +) + +func init() { + prometheus.Init("node", livekit.NodeType_CONTROLLER, "test") +} + +func TestSignal(t *testing.T) { + bus := psrpc.NewLocalMessageBus() + cfg := config.SignalRelayConfig{ + Enabled: false, + RetryTimeout: 30 * time.Second, + MinRetryInterval: 500 * time.Millisecond, + MaxRetryInterval: 5 * time.Second, + StreamBufferSize: 1000, + MinVersion: 1, + } + + reqMessageIn := &livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{Ping: 123}, + } + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } + + var reqMessageOut proto.Message + var resErr error + done := make(chan struct{}) + + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) + + _, err = NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, func( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + go func() { + reqMessageOut = <-requestSource.ReadChan() + resErr = responseSink.WriteMessage(resMessageIn) + responseSink.Close() + close(done) + }() + return nil + }) + require.NoError(t, err) + + _, reqSink, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + err = reqSink.WriteMessage(reqMessageIn) + require.NoError(t, err) + + <-done + require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) +}