From d9059f4f3b5f957ff19fa0e6d8ef8457befecf76 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Tue, 16 Aug 2022 19:57:41 -0700 Subject: [PATCH] Do not accept websocket connection if response not received (#923) When the instance handling the signal request did not respond to the initial connection, we will fail the connection attempt instead of having it hang forever. --- pkg/service/auth.go | 12 ++++---- pkg/service/rtcservice.go | 62 ++++++++++++++++++++++++++++++++++----- pkg/service/utils.go | 8 ++--- 3 files changed, 65 insertions(+), 17 deletions(-) diff --git a/pkg/service/auth.go b/pkg/service/auth.go index 014641800..da3803f45 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -21,7 +21,9 @@ const ( type grantsKey struct{} var ( - ErrPermissionDenied = errors.New("permissions denied") + ErrPermissionDenied = errors.New("permissions denied") + ErrMissingAuthorization = errors.New("invalid authorization header. Must start with " + bearerPrefix) + ErrInvalidAuthorizationToken = errors.New("invalid authorization token") ) // authentication middleware @@ -45,7 +47,7 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, if authHeader != "" { if !strings.HasPrefix(authHeader, bearerPrefix) { - handleError(w, http.StatusUnauthorized, "invalid authorization header. Must start with "+bearerPrefix) + handleError(w, http.StatusUnauthorized, ErrMissingAuthorization) return } @@ -58,19 +60,19 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, if authToken != "" { v, err := auth.ParseAPIToken(authToken) if err != nil { - handleError(w, http.StatusUnauthorized, "invalid authorization token") + handleError(w, http.StatusUnauthorized, ErrInvalidAuthorizationToken) return } secret := m.provider.GetSecret(v.APIKey()) if secret == "" { - handleError(w, http.StatusUnauthorized, "invalid API key") + handleError(w, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey())) return } grants, err := v.Verify(secret) if err != nil { - handleError(w, http.StatusUnauthorized, "invalid token: "+authToken+", error: "+err.Error()) + handleError(w, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error())) return } diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index f954a2cc6..b3a4bdaa1 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -2,11 +2,13 @@ package service import ( "context" + "errors" "fmt" "io" "net/http" "strconv" "strings" + "time" "github.com/gorilla/websocket" "github.com/sebest/xff" @@ -65,7 +67,7 @@ func NewRTCService( func (s *RTCService) Validate(w http.ResponseWriter, r *http.Request) { _, _, code, err := s.validate(r) if err != nil { - handleError(w, code, err.Error()) + handleError(w, code, err) return } _, _ = w.Write([]byte("success")) @@ -148,18 +150,25 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { roomName, pi, code, err := s.validate(r) if err != nil { - handleError(w, code, err.Error()) + handleError(w, code, err) return } + // for logger + loggerFields := []interface{}{ + "participant", pi.Identity, + "room", roomName, + "remote", false, + } + // when auto create is disabled, we'll check to ensure it's already created if !s.config.Room.AutoCreate { _, err := s.store.LoadRoom(context.Background(), roomName) if err == ErrRoomNotFound { - handleError(w, 404, err.Error()) + handleError(w, 404, err, loggerFields...) return } else if err != nil { - handleError(w, 500, err.Error()) + handleError(w, 500, err, loggerFields...) return } } @@ -168,7 +177,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { rm, err := s.roomAllocator.CreateRoom(r.Context(), &livekit.CreateRoomRequest{Name: string(roomName)}) if err != nil { prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "create_room").Add(1) - handleError(w, http.StatusInternalServerError, err.Error()) + handleError(w, http.StatusInternalServerError, err, loggerFields...) return } @@ -176,7 +185,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { connId, reqSink, resSource, err := s.router.StartParticipantSignal(r.Context(), roomName, pi) if err != nil { prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "start_signal").Add(1) - handleError(w, http.StatusInternalServerError, "could not start session: "+err.Error()) + handleError(w, http.StatusInternalServerError, err, loggerFields...) return } @@ -186,6 +195,17 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { "", false, ) + + // wait for the first message before upgrading to websocket. If no one is + // responding to our connection attempt, we should terminate the connection + // instead of waiting forever on the WebSocket + initialResponse, err := readInitialResponse(resSource, 5*time.Second) + if err != nil { + prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "initial_response").Add(1) + handleError(w, http.StatusInternalServerError, err, loggerFields...) + return + } + done := make(chan struct{}) // function exits when websocket terminates, it'll close the event reading off of response sink as well defer func() { @@ -199,11 +219,16 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "upgrade").Add(1) - pLogger.Warnw("could not upgrade to WS", err) - handleError(w, http.StatusInternalServerError, err.Error()) + handleError(w, http.StatusInternalServerError, err, loggerFields...) return } + + // websocket established sigConn := NewWSSignalConnection(conn) + if err := sigConn.WriteResponse(initialResponse); err != nil { + pLogger.Warnw("could not write initial response", err) + return + } prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "success", "").Add(1) pLogger.Infow("new client WS connected", "connID", connId) @@ -331,3 +356,24 @@ func (s *RTCService) ParseClientInfo(r *http.Request) *livekit.ClientInfo { return ci } + +func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) { + responseTimer := time.NewTimer(timeout) + defer responseTimer.Stop() + for { + select { + case <-responseTimer.C: + return nil, errors.New("timed out while waiting for signal response") + case msg := <-source.ReadChan(): + if msg == nil { + return nil, errors.New("connection closed by media") + } + res, ok := msg.(*livekit.SignalResponse) + if !ok { + return nil, fmt.Errorf("unexpected message type: %T", msg) + } + return res, nil + } + } + +} diff --git a/pkg/service/utils.go b/pkg/service/utils.go index c47d1b61b..af5be9bc9 100644 --- a/pkg/service/utils.go +++ b/pkg/service/utils.go @@ -7,11 +7,11 @@ import ( "github.com/livekit/protocol/logger" ) -func handleError(w http.ResponseWriter, status int, msg string) { - // GetLogger already with extra depth 1 - logger.GetLogger().V(1).Info("error handling request", "error", msg, "status", status) +func handleError(w http.ResponseWriter, status int, err error, keysAndValues ...interface{}) { + keysAndValues = append(keysAndValues, "status", status) + logger.Warnw("error handling request", err, keysAndValues...) w.WriteHeader(status) - _, _ = w.Write([]byte(msg)) + _, _ = w.Write([]byte(err.Error())) } func boolValue(s string) bool {