mirror of
https://github.com/livekit/livekit.git
synced 2026-05-14 20:35:27 +00:00
Do not accept websocket connection if response not received (#923)
When the instance handling the signal request did not respond to the initial connection, we will fail the connection attempt instead of having it hang forever.
This commit is contained in:
+7
-5
@@ -21,7 +21,9 @@ const (
|
||||
type grantsKey struct{}
|
||||
|
||||
var (
|
||||
ErrPermissionDenied = errors.New("permissions denied")
|
||||
ErrPermissionDenied = errors.New("permissions denied")
|
||||
ErrMissingAuthorization = errors.New("invalid authorization header. Must start with " + bearerPrefix)
|
||||
ErrInvalidAuthorizationToken = errors.New("invalid authorization token")
|
||||
)
|
||||
|
||||
// authentication middleware
|
||||
@@ -45,7 +47,7 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
if authHeader != "" {
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
handleError(w, http.StatusUnauthorized, "invalid authorization header. Must start with "+bearerPrefix)
|
||||
handleError(w, http.StatusUnauthorized, ErrMissingAuthorization)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,19 +60,19 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request,
|
||||
if authToken != "" {
|
||||
v, err := auth.ParseAPIToken(authToken)
|
||||
if err != nil {
|
||||
handleError(w, http.StatusUnauthorized, "invalid authorization token")
|
||||
handleError(w, http.StatusUnauthorized, ErrInvalidAuthorizationToken)
|
||||
return
|
||||
}
|
||||
|
||||
secret := m.provider.GetSecret(v.APIKey())
|
||||
if secret == "" {
|
||||
handleError(w, http.StatusUnauthorized, "invalid API key")
|
||||
handleError(w, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey()))
|
||||
return
|
||||
}
|
||||
|
||||
grants, err := v.Verify(secret)
|
||||
if err != nil {
|
||||
handleError(w, http.StatusUnauthorized, "invalid token: "+authToken+", error: "+err.Error())
|
||||
handleError(w, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sebest/xff"
|
||||
@@ -65,7 +67,7 @@ func NewRTCService(
|
||||
func (s *RTCService) Validate(w http.ResponseWriter, r *http.Request) {
|
||||
_, _, code, err := s.validate(r)
|
||||
if err != nil {
|
||||
handleError(w, code, err.Error())
|
||||
handleError(w, code, err)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte("success"))
|
||||
@@ -148,18 +150,25 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
roomName, pi, code, err := s.validate(r)
|
||||
if err != nil {
|
||||
handleError(w, code, err.Error())
|
||||
handleError(w, code, err)
|
||||
return
|
||||
}
|
||||
|
||||
// for logger
|
||||
loggerFields := []interface{}{
|
||||
"participant", pi.Identity,
|
||||
"room", roomName,
|
||||
"remote", false,
|
||||
}
|
||||
|
||||
// when auto create is disabled, we'll check to ensure it's already created
|
||||
if !s.config.Room.AutoCreate {
|
||||
_, err := s.store.LoadRoom(context.Background(), roomName)
|
||||
if err == ErrRoomNotFound {
|
||||
handleError(w, 404, err.Error())
|
||||
handleError(w, 404, err, loggerFields...)
|
||||
return
|
||||
} else if err != nil {
|
||||
handleError(w, 500, err.Error())
|
||||
handleError(w, 500, err, loggerFields...)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -168,7 +177,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rm, err := s.roomAllocator.CreateRoom(r.Context(), &livekit.CreateRoomRequest{Name: string(roomName)})
|
||||
if err != nil {
|
||||
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "create_room").Add(1)
|
||||
handleError(w, http.StatusInternalServerError, err.Error())
|
||||
handleError(w, http.StatusInternalServerError, err, loggerFields...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -176,7 +185,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
connId, reqSink, resSource, err := s.router.StartParticipantSignal(r.Context(), roomName, pi)
|
||||
if err != nil {
|
||||
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "start_signal").Add(1)
|
||||
handleError(w, http.StatusInternalServerError, "could not start session: "+err.Error())
|
||||
handleError(w, http.StatusInternalServerError, err, loggerFields...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,6 +195,17 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
"",
|
||||
false,
|
||||
)
|
||||
|
||||
// wait for the first message before upgrading to websocket. If no one is
|
||||
// responding to our connection attempt, we should terminate the connection
|
||||
// instead of waiting forever on the WebSocket
|
||||
initialResponse, err := readInitialResponse(resSource, 5*time.Second)
|
||||
if err != nil {
|
||||
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "initial_response").Add(1)
|
||||
handleError(w, http.StatusInternalServerError, err, loggerFields...)
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
// function exits when websocket terminates, it'll close the event reading off of response sink as well
|
||||
defer func() {
|
||||
@@ -199,11 +219,16 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "upgrade").Add(1)
|
||||
pLogger.Warnw("could not upgrade to WS", err)
|
||||
handleError(w, http.StatusInternalServerError, err.Error())
|
||||
handleError(w, http.StatusInternalServerError, err, loggerFields...)
|
||||
return
|
||||
}
|
||||
|
||||
// websocket established
|
||||
sigConn := NewWSSignalConnection(conn)
|
||||
if err := sigConn.WriteResponse(initialResponse); err != nil {
|
||||
pLogger.Warnw("could not write initial response", err)
|
||||
return
|
||||
}
|
||||
|
||||
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "success", "").Add(1)
|
||||
pLogger.Infow("new client WS connected", "connID", connId)
|
||||
@@ -331,3 +356,24 @@ func (s *RTCService) ParseClientInfo(r *http.Request) *livekit.ClientInfo {
|
||||
|
||||
return ci
|
||||
}
|
||||
|
||||
func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) {
|
||||
responseTimer := time.NewTimer(timeout)
|
||||
defer responseTimer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-responseTimer.C:
|
||||
return nil, errors.New("timed out while waiting for signal response")
|
||||
case msg := <-source.ReadChan():
|
||||
if msg == nil {
|
||||
return nil, errors.New("connection closed by media")
|
||||
}
|
||||
res, ok := msg.(*livekit.SignalResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected message type: %T", msg)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"github.com/livekit/protocol/logger"
|
||||
)
|
||||
|
||||
func handleError(w http.ResponseWriter, status int, msg string) {
|
||||
// GetLogger already with extra depth 1
|
||||
logger.GetLogger().V(1).Info("error handling request", "error", msg, "status", status)
|
||||
func handleError(w http.ResponseWriter, status int, err error, keysAndValues ...interface{}) {
|
||||
keysAndValues = append(keysAndValues, "status", status)
|
||||
logger.Warnw("error handling request", err, keysAndValues...)
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write([]byte(msg))
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
}
|
||||
|
||||
func boolValue(s string) bool {
|
||||
|
||||
Reference in New Issue
Block a user