Do not accept websocket connection if response not received (#923)

When the instance handling the signal request did not respond to the
initial connection, we will fail the connection attempt instead of
having it hang forever.
This commit is contained in:
David Zhao
2022-08-16 19:57:41 -07:00
committed by GitHub
parent 3f53dea223
commit d9059f4f3b
3 changed files with 65 additions and 17 deletions
+7 -5
View File
@@ -21,7 +21,9 @@ const (
type grantsKey struct{}
var (
ErrPermissionDenied = errors.New("permissions denied")
ErrPermissionDenied = errors.New("permissions denied")
ErrMissingAuthorization = errors.New("invalid authorization header. Must start with " + bearerPrefix)
ErrInvalidAuthorizationToken = errors.New("invalid authorization token")
)
// authentication middleware
@@ -45,7 +47,7 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request,
if authHeader != "" {
if !strings.HasPrefix(authHeader, bearerPrefix) {
handleError(w, http.StatusUnauthorized, "invalid authorization header. Must start with "+bearerPrefix)
handleError(w, http.StatusUnauthorized, ErrMissingAuthorization)
return
}
@@ -58,19 +60,19 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request,
if authToken != "" {
v, err := auth.ParseAPIToken(authToken)
if err != nil {
handleError(w, http.StatusUnauthorized, "invalid authorization token")
handleError(w, http.StatusUnauthorized, ErrInvalidAuthorizationToken)
return
}
secret := m.provider.GetSecret(v.APIKey())
if secret == "" {
handleError(w, http.StatusUnauthorized, "invalid API key")
handleError(w, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey()))
return
}
grants, err := v.Verify(secret)
if err != nil {
handleError(w, http.StatusUnauthorized, "invalid token: "+authToken+", error: "+err.Error())
handleError(w, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error()))
return
}
+54 -8
View File
@@ -2,11 +2,13 @@ package service
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/sebest/xff"
@@ -65,7 +67,7 @@ func NewRTCService(
func (s *RTCService) Validate(w http.ResponseWriter, r *http.Request) {
_, _, code, err := s.validate(r)
if err != nil {
handleError(w, code, err.Error())
handleError(w, code, err)
return
}
_, _ = w.Write([]byte("success"))
@@ -148,18 +150,25 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
roomName, pi, code, err := s.validate(r)
if err != nil {
handleError(w, code, err.Error())
handleError(w, code, err)
return
}
// for logger
loggerFields := []interface{}{
"participant", pi.Identity,
"room", roomName,
"remote", false,
}
// when auto create is disabled, we'll check to ensure it's already created
if !s.config.Room.AutoCreate {
_, err := s.store.LoadRoom(context.Background(), roomName)
if err == ErrRoomNotFound {
handleError(w, 404, err.Error())
handleError(w, 404, err, loggerFields...)
return
} else if err != nil {
handleError(w, 500, err.Error())
handleError(w, 500, err, loggerFields...)
return
}
}
@@ -168,7 +177,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rm, err := s.roomAllocator.CreateRoom(r.Context(), &livekit.CreateRoomRequest{Name: string(roomName)})
if err != nil {
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "create_room").Add(1)
handleError(w, http.StatusInternalServerError, err.Error())
handleError(w, http.StatusInternalServerError, err, loggerFields...)
return
}
@@ -176,7 +185,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
connId, reqSink, resSource, err := s.router.StartParticipantSignal(r.Context(), roomName, pi)
if err != nil {
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "start_signal").Add(1)
handleError(w, http.StatusInternalServerError, "could not start session: "+err.Error())
handleError(w, http.StatusInternalServerError, err, loggerFields...)
return
}
@@ -186,6 +195,17 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
"",
false,
)
// wait for the first message before upgrading to websocket. If no one is
// responding to our connection attempt, we should terminate the connection
// instead of waiting forever on the WebSocket
initialResponse, err := readInitialResponse(resSource, 5*time.Second)
if err != nil {
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "initial_response").Add(1)
handleError(w, http.StatusInternalServerError, err, loggerFields...)
return
}
done := make(chan struct{})
// function exits when websocket terminates, it'll close the event reading off of response sink as well
defer func() {
@@ -199,11 +219,16 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "error", "upgrade").Add(1)
pLogger.Warnw("could not upgrade to WS", err)
handleError(w, http.StatusInternalServerError, err.Error())
handleError(w, http.StatusInternalServerError, err, loggerFields...)
return
}
// websocket established
sigConn := NewWSSignalConnection(conn)
if err := sigConn.WriteResponse(initialResponse); err != nil {
pLogger.Warnw("could not write initial response", err)
return
}
prometheus.ServiceOperationCounter.WithLabelValues("signal_ws", "success", "").Add(1)
pLogger.Infow("new client WS connected", "connID", connId)
@@ -331,3 +356,24 @@ func (s *RTCService) ParseClientInfo(r *http.Request) *livekit.ClientInfo {
return ci
}
func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) {
responseTimer := time.NewTimer(timeout)
defer responseTimer.Stop()
for {
select {
case <-responseTimer.C:
return nil, errors.New("timed out while waiting for signal response")
case msg := <-source.ReadChan():
if msg == nil {
return nil, errors.New("connection closed by media")
}
res, ok := msg.(*livekit.SignalResponse)
if !ok {
return nil, fmt.Errorf("unexpected message type: %T", msg)
}
return res, nil
}
}
}
+4 -4
View File
@@ -7,11 +7,11 @@ import (
"github.com/livekit/protocol/logger"
)
func handleError(w http.ResponseWriter, status int, msg string) {
// GetLogger already with extra depth 1
logger.GetLogger().V(1).Info("error handling request", "error", msg, "status", status)
func handleError(w http.ResponseWriter, status int, err error, keysAndValues ...interface{}) {
keysAndValues = append(keysAndValues, "status", status)
logger.Warnw("error handling request", err, keysAndValues...)
w.WriteHeader(status)
_, _ = w.Write([]byte(msg))
_, _ = w.Write([]byte(err.Error()))
}
func boolValue(s string) bool {