diff --git a/pkg/routing/signal.go b/pkg/routing/signal.go index d1a759199..20077b233 100644 --- a/pkg/routing/signal.go +++ b/pkg/routing/signal.go @@ -180,7 +180,6 @@ func CopySignalStreamToMessageChannel[SendType, RecvType RelaySignalMessage]( config: config, } for msg := range stream.Channel() { - var res []proto.Message res, err := r.Read(msg) if err != nil { return err @@ -280,20 +279,17 @@ func (s *signalMessageSink[SendType, RecvType]) nextMessage() (msg SendType, n i func (s *signalMessageSink[SendType, RecvType]) write() { interval := s.Config.MinRetryInterval deadline := time.Now().Add(s.Config.RetryTimeout) + var err error s.mu.Lock() for { msg, n := s.nextMessage() if n == 0 || s.IsClosed() { - if s.draining { - s.Stream.Close(nil) - } - s.writing = false break } s.mu.Unlock() - err := s.Stream.Send(msg, psrpc.WithTimeout(interval)) + err = s.Stream.Send(msg, psrpc.WithTimeout(interval)) if err != nil { if time.Now().After(deadline) { s.Logger.Warnw("could not send signal message", err) @@ -301,12 +297,7 @@ func (s *signalMessageSink[SendType, RecvType]) write() { s.mu.Lock() s.seq += uint64(len(s.queue)) s.queue = nil - - if s.CloseOnFailure { - s.Stream.Close(ErrSignalFailed) - } - s.mu.Unlock() - return + break } interval *= 2 @@ -324,6 +315,14 @@ func (s *signalMessageSink[SendType, RecvType]) write() { s.queue = s.queue[n:] } } + + s.writing = false + if s.draining { + s.Stream.Close(nil) + } + if err != nil && s.CloseOnFailure { + s.Stream.Close(ErrSignalFailed) + } s.mu.Unlock() } diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index d6f622239..16639c9d5 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -190,10 +190,11 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { // give it a few attempts to start session var cr connectionResult + var initialResponse *livekit.SignalResponse for i := 0; i < 3; i++ { connectionTimeout := 3 * time.Second * time.Duration(i+1) ctx := utils.ContextWithAttempt(r.Context(), i) - cr, err = s.startConnection(ctx, roomName, pi, connectionTimeout) + cr, initialResponse, err = s.startConnection(ctx, roomName, pi, connectionTimeout) if err == nil { break } @@ -210,8 +211,8 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { prometheus.IncrementParticipantJoin(1) - if !pi.Reconnect && cr.InitialResponse.GetJoin() != nil { - pi.ID = livekit.ParticipantID(cr.InitialResponse.GetJoin().GetParticipant().GetSid()) + if !pi.Reconnect && initialResponse.GetJoin() != nil { + pi.ID = livekit.ParticipantID(initialResponse.GetJoin().GetParticipant().GetSid()) } var signalStats *telemetry.BytesTrackStats @@ -251,7 +252,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { // websocket established sigConn := NewWSSignalConnection(conn) - if count, err := sigConn.WriteResponse(cr.InitialResponse); err != nil { + if count, err := sigConn.WriteResponse(initialResponse); err != nil { pLogger.Warnw("could not write initial response", err) return } else { @@ -301,14 +302,6 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { pLogger.Debugw("sending answer", "answer", m) } - if pi.ID == "" && cr.InitialResponse.GetJoin() != nil { - pi.ID = livekit.ParticipantID(cr.InitialResponse.GetJoin().GetParticipant().GetSid()) - signalStats = telemetry.NewBytesTrackStats( - telemetry.BytesTrackIDForParticipantID(telemetry.BytesTrackTypeSignal, pi.ID), - pi.ID, - s.telemetry) - } - if count, err := sigConn.WriteResponse(res); err != nil { pLogger.Warnw("error writing to websocket", err) return @@ -443,38 +436,37 @@ func (s *RTCService) ParseClientInfo(r *http.Request) *livekit.ClientInfo { } type connectionResult struct { - Room *livekit.Room - ConnectionID livekit.ConnectionID - RequestSink routing.MessageSink - ResponseSource routing.MessageSource - InitialResponse *livekit.SignalResponse + Room *livekit.Room + ConnectionID livekit.ConnectionID + RequestSink routing.MessageSink + ResponseSource routing.MessageSource } -func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, error) { +func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, *livekit.SignalResponse, error) { var cr connectionResult var err error cr.Room, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)}) if err != nil { - return cr, err + return cr, nil, err } // this needs to be started first *before* using router functions on this node cr.ConnectionID, cr.RequestSink, cr.ResponseSource, err = s.router.StartParticipantSignal(ctx, roomName, pi) if err != nil { - return cr, err + return cr, nil, err } // wait for the first message before upgrading to websocket. If no one is // responding to our connection attempt, we should terminate the connection // instead of waiting forever on the WebSocket - cr.InitialResponse, err = readInitialResponse(cr.ResponseSource, timeout) + initialResponse, err := readInitialResponse(cr.ResponseSource, timeout) if err != nil { // close the connection to avoid leaking cr.RequestSink.Close() cr.ResponseSource.Close() - return cr, err + return cr, nil, err } - return cr, nil + return cr, initialResponse, nil } func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) {