mirror of
https://github.com/livekit/livekit.git
synced 2026-05-11 01:47:18 +00:00
371 lines
9.1 KiB
Go
371 lines
9.1 KiB
Go
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package routing
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/atomic"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/livekit/livekit-server/pkg/config"
|
|
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
|
|
"github.com/livekit/protocol/livekit"
|
|
"github.com/livekit/protocol/logger"
|
|
"github.com/livekit/protocol/rpc"
|
|
"github.com/livekit/protocol/utils"
|
|
"github.com/livekit/psrpc"
|
|
"github.com/livekit/psrpc/pkg/middleware"
|
|
)
|
|
|
|
// Sentinel errors produced by the signal relay code in this file.
var (
	// ErrSignalWriteFailed is the error a signal sink closes its stream with
	// when sending failed and the sink is configured to close on failure.
	ErrSignalWriteFailed = errors.New("signal write failed")

	// ErrSignalMessageDropped is returned by the sequence-tracking reader when
	// an incoming message carries a sequence number ahead of the expected one.
	ErrSignalMessageDropped = errors.New("signal message dropped")
)
|
|
|
|
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
|
|
|
|
//counterfeiter:generate . SignalClient
|
|
type SignalClient interface {
|
|
ActiveCount() int
|
|
StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error)
|
|
}
|
|
|
|
type signalClient struct {
|
|
nodeID livekit.NodeID
|
|
config config.SignalRelayConfig
|
|
client rpc.TypedSignalClient
|
|
active atomic.Int32
|
|
}
|
|
|
|
func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus, config config.SignalRelayConfig) (SignalClient, error) {
|
|
c, err := rpc.NewTypedSignalClient(
|
|
nodeID,
|
|
bus,
|
|
middleware.WithClientMetrics(prometheus.PSRPCMetricsObserver{}),
|
|
psrpc.WithClientChannelSize(config.StreamBufferSize),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &signalClient{
|
|
nodeID: nodeID,
|
|
config: config,
|
|
client: c,
|
|
}, nil
|
|
}
|
|
|
|
func (r *signalClient) ActiveCount() int {
|
|
return int(r.active.Load())
|
|
}
|
|
|
|
func (r *signalClient) StartParticipantSignal(
|
|
ctx context.Context,
|
|
roomName livekit.RoomName,
|
|
pi ParticipantInit,
|
|
nodeID livekit.NodeID,
|
|
) (
|
|
connectionID livekit.ConnectionID,
|
|
reqSink MessageSink,
|
|
resSource MessageSource,
|
|
err error,
|
|
) {
|
|
connectionID = livekit.ConnectionID(utils.NewGuid("CO_"))
|
|
ss, err := pi.ToStartSession(roomName, connectionID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
l := logger.GetLogger().WithValues(
|
|
"room", roomName,
|
|
"reqNodeID", nodeID,
|
|
"participant", pi.Identity,
|
|
"connID", connectionID,
|
|
)
|
|
|
|
l.Debugw("starting signal connection")
|
|
|
|
stream, err := r.client.RelaySignal(ctx, nodeID)
|
|
if err != nil {
|
|
prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1)
|
|
return
|
|
}
|
|
|
|
err = stream.Send(&rpc.RelaySignalRequest{StartSession: ss})
|
|
if err != nil {
|
|
stream.Close(err)
|
|
prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1)
|
|
return
|
|
}
|
|
|
|
sink := NewSignalMessageSink(SignalSinkParams[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]{
|
|
Logger: l,
|
|
Stream: stream,
|
|
Config: r.config,
|
|
Writer: signalRequestMessageWriter{},
|
|
CloseOnFailure: true,
|
|
BlockOnClose: true,
|
|
ConnectionID: connectionID,
|
|
})
|
|
resChan := NewDefaultMessageChannel(connectionID)
|
|
|
|
go func() {
|
|
r.active.Inc()
|
|
defer r.active.Dec()
|
|
|
|
err := CopySignalStreamToMessageChannel[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse](
|
|
stream,
|
|
resChan,
|
|
signalResponseMessageReader{},
|
|
r.config,
|
|
)
|
|
l.Debugw("signal stream closed", "error", err)
|
|
|
|
resChan.Close()
|
|
}()
|
|
|
|
return connectionID, sink, resChan, nil
|
|
}
|
|
|
|
type signalRequestMessageWriter struct{}
|
|
|
|
func (e signalRequestMessageWriter) Write(seq uint64, close bool, msgs []proto.Message) *rpc.RelaySignalRequest {
|
|
r := &rpc.RelaySignalRequest{
|
|
Seq: seq,
|
|
Requests: make([]*livekit.SignalRequest, 0, len(msgs)),
|
|
Close: close,
|
|
}
|
|
for _, m := range msgs {
|
|
r.Requests = append(r.Requests, m.(*livekit.SignalRequest))
|
|
}
|
|
return r
|
|
}
|
|
|
|
type signalResponseMessageReader struct{}
|
|
|
|
func (e signalResponseMessageReader) Read(rm *rpc.RelaySignalResponse) ([]proto.Message, error) {
|
|
msgs := make([]proto.Message, 0, len(rm.Responses))
|
|
for _, m := range rm.Responses {
|
|
msgs = append(msgs, m)
|
|
}
|
|
return msgs, nil
|
|
}
|
|
|
|
type RelaySignalMessage interface {
|
|
proto.Message
|
|
GetSeq() uint64
|
|
GetClose() bool
|
|
}
|
|
|
|
type SignalMessageWriter[SendType RelaySignalMessage] interface {
|
|
Write(seq uint64, close bool, msgs []proto.Message) SendType
|
|
}
|
|
|
|
type SignalMessageReader[RecvType RelaySignalMessage] interface {
|
|
Read(msg RecvType) ([]proto.Message, error)
|
|
}
|
|
|
|
func CopySignalStreamToMessageChannel[SendType, RecvType RelaySignalMessage](
|
|
stream psrpc.Stream[SendType, RecvType],
|
|
ch *MessageChannel,
|
|
reader SignalMessageReader[RecvType],
|
|
config config.SignalRelayConfig,
|
|
) error {
|
|
r := &signalMessageReader[SendType, RecvType]{
|
|
reader: reader,
|
|
config: config,
|
|
}
|
|
for msg := range stream.Channel() {
|
|
res, err := r.Read(msg)
|
|
if err != nil {
|
|
prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1)
|
|
return err
|
|
}
|
|
|
|
for _, r := range res {
|
|
if err = ch.WriteMessage(r); err != nil {
|
|
prometheus.MessageCounter.WithLabelValues("signal", "failure").Add(1)
|
|
return err
|
|
}
|
|
prometheus.MessageCounter.WithLabelValues("signal", "success").Add(1)
|
|
}
|
|
|
|
if msg.GetClose() {
|
|
return stream.Close(nil)
|
|
}
|
|
}
|
|
return stream.Err()
|
|
}
|
|
|
|
type signalMessageReader[SendType, RecvType RelaySignalMessage] struct {
|
|
seq uint64
|
|
reader SignalMessageReader[RecvType]
|
|
config config.SignalRelayConfig
|
|
}
|
|
|
|
func (r *signalMessageReader[SendType, RecvType]) Read(msg RecvType) ([]proto.Message, error) {
|
|
res, err := r.reader.Read(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if r.seq < msg.GetSeq() {
|
|
return nil, ErrSignalMessageDropped
|
|
}
|
|
if r.seq > msg.GetSeq() {
|
|
n := int(r.seq - msg.GetSeq())
|
|
if n > len(res) {
|
|
n = len(res)
|
|
}
|
|
res = res[n:]
|
|
}
|
|
r.seq += uint64(len(res))
|
|
|
|
return res, nil
|
|
}
|
|
|
|
type SignalSinkParams[SendType, RecvType RelaySignalMessage] struct {
|
|
Stream psrpc.Stream[SendType, RecvType]
|
|
Logger logger.Logger
|
|
Config config.SignalRelayConfig
|
|
Writer SignalMessageWriter[SendType]
|
|
CloseOnFailure bool
|
|
BlockOnClose bool
|
|
ConnectionID livekit.ConnectionID
|
|
}
|
|
|
|
func NewSignalMessageSink[SendType, RecvType RelaySignalMessage](params SignalSinkParams[SendType, RecvType]) MessageSink {
|
|
return &signalMessageSink[SendType, RecvType]{
|
|
SignalSinkParams: params,
|
|
}
|
|
}
|
|
|
|
type signalMessageSink[SendType, RecvType RelaySignalMessage] struct {
|
|
SignalSinkParams[SendType, RecvType]
|
|
|
|
mu sync.Mutex
|
|
seq uint64
|
|
queue []proto.Message
|
|
writing bool
|
|
draining bool
|
|
}
|
|
|
|
func (s *signalMessageSink[SendType, RecvType]) Close() {
|
|
s.mu.Lock()
|
|
s.draining = true
|
|
if !s.writing {
|
|
s.writing = true
|
|
go s.write()
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
// conditionally block while closing to wait for outgoing messages to drain
|
|
//
|
|
// on media the signal sink shares a goroutine with other signal connection
|
|
// attempts from the same participant so blocking delays establishing new
|
|
// sessions during reconnect.
|
|
//
|
|
// on controller closing without waiting for the outstanding messages to
|
|
// drain causes leave messages to be dropped from the write queue. when
|
|
// this happens other participants in the room aren't notified about the
|
|
// departure until the participant times out.
|
|
if s.BlockOnClose {
|
|
<-s.Stream.Context().Done()
|
|
}
|
|
}
|
|
|
|
func (s *signalMessageSink[SendType, RecvType]) IsClosed() bool {
|
|
return s.Stream.Err() != nil
|
|
}
|
|
|
|
func (s *signalMessageSink[SendType, RecvType]) write() {
|
|
interval := s.Config.MinRetryInterval
|
|
deadline := time.Now().Add(s.Config.RetryTimeout)
|
|
var err error
|
|
|
|
s.mu.Lock()
|
|
for {
|
|
close := s.draining
|
|
if (!close && len(s.queue) == 0) || s.IsClosed() {
|
|
break
|
|
}
|
|
msg, n := s.Writer.Write(s.seq, close, s.queue), len(s.queue)
|
|
s.mu.Unlock()
|
|
|
|
err = s.Stream.Send(msg, psrpc.WithTimeout(interval))
|
|
if err != nil {
|
|
if time.Now().After(deadline) {
|
|
s.Logger.Warnw("could not send signal message", err)
|
|
|
|
s.mu.Lock()
|
|
s.seq += uint64(len(s.queue))
|
|
s.queue = nil
|
|
break
|
|
}
|
|
|
|
interval *= 2
|
|
if interval > s.Config.MaxRetryInterval {
|
|
interval = s.Config.MaxRetryInterval
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
if err == nil {
|
|
interval = s.Config.MinRetryInterval
|
|
deadline = time.Now().Add(s.Config.RetryTimeout)
|
|
|
|
s.seq += uint64(n)
|
|
s.queue = s.queue[n:]
|
|
|
|
if close {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
s.writing = false
|
|
if s.draining {
|
|
s.Stream.Close(nil)
|
|
}
|
|
if err != nil && s.CloseOnFailure {
|
|
s.Stream.Close(ErrSignalWriteFailed)
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *signalMessageSink[SendType, RecvType]) WriteMessage(msg proto.Message) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if err := s.Stream.Err(); err != nil {
|
|
return err
|
|
} else if s.draining {
|
|
return psrpc.ErrStreamClosed
|
|
}
|
|
|
|
s.queue = append(s.queue, msg)
|
|
if !s.writing {
|
|
s.writing = true
|
|
go s.write()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *signalMessageSink[SendType, RecvType]) ConnectionID() livekit.ConnectionID {
|
|
return s.SignalSinkParams.ConnectionID
|
|
}
|