diff --git a/pkg/rtc/signalling/signalhandlerv2.go b/pkg/rtc/signalling/signalhandlerv2.go index b0f722da3..eda49d55a 100644 --- a/pkg/rtc/signalling/signalhandlerv2.go +++ b/pkg/rtc/signalling/signalhandlerv2.go @@ -70,14 +70,34 @@ func (s *signalhandlerv2) HandleRequest(msg proto.Message) error { switch msg := req.GetMessage().(type) { case *livekit.Signalv2WireMessage_Envelope: for _, clientMessage := range msg.Envelope.ClientMessages { - // SIGNAL-V2-TODO: cannot do this comparison for very first message - if clientMessage.Sequencer.MessageId != s.lastProcessedRemoteMessageId.Load()+1 { + /* SIGNALLING-V2-TODO: uncommment once remote side sends proper messageId + sequencer := clientMessage.GetSequencer() + if sequencer == nil || sequencer.MessageId == 0 { + s.params.Logger.Warnw( + "skipping message without sequencer", nil, + "messageType", fmt.Sprintf("%T", clientMessage), + ) + continue + } + + lprmi := s.lastProcessedRemoteMessageId.Load() + if sequencer.MessageId <= lprmi { + s.params.Logger.Infow( + "duplicate in message stream", + "last", lprmi, + "current", clientMessage.Sequencer.MessageId, + ) + continue + } + + if lprmi != 0 && sequencer.MessageId != lprmi+1 { s.params.Logger.Infow( "gap in message stream", - "last", s.lastProcessedRemoteMessageId.Load(), + "last", lprmi, "current", clientMessage.Sequencer.MessageId, ) } + */ switch payload := clientMessage.GetMessage().(type) { case *livekit.Signalv2ClientMessage_PublisherSdp: @@ -87,9 +107,11 @@ func (s *signalhandlerv2) HandleRequest(msg proto.Message) error { s.params.Participant.HandleAnswer(protosignalling.FromProtoSessionDescription(payload.SubscriberSdp)) } - s.lastProcessedRemoteMessageId.Store(clientMessage.Sequencer.MessageId) - s.params.Signalling.AckMessageId(clientMessage.Sequencer.LastProcessedRemoteMessageId) - s.params.Signalling.SetLastProcessedRemoteMessageId(clientMessage.Sequencer.MessageId) + /* SIGNALLING-V2-TODO: uncomment once sequencer is implemented on both sides + s.lastProcessedRemoteMessageId.Store(sequencer.MessageId) + s.params.Signalling.AckMessageId(sequencer.LastProcessedRemoteMessageId) + s.params.Signalling.SetLastProcessedRemoteMessageId(sequencer.MessageId) + */ } case *livekit.Signalv2WireMessage_Fragment: diff --git a/pkg/service/roommanager_service.go b/pkg/service/roommanager_service.go index a8b42dfc9..ef56c52e7 100644 --- a/pkg/service/roommanager_service.go +++ b/pkg/service/roommanager_service.go @@ -213,11 +213,9 @@ func (s signalv2ParticipantService) RelaySignalv2Participant(ctx context.Context } var wireMessage *livekit.Signalv2WireMessage - pending := lp.SignalPendingMessages() - if pending != nil { + if pending := lp.SignalPendingMessages(); pending != nil { var ok bool - wireMessage, ok = pending.(*livekit.Signalv2WireMessage) - if !ok { + if wireMessage, ok = pending.(*livekit.Signalv2WireMessage); !ok { return nil, ErrInvalidMessageType } } diff --git a/pkg/service/rtcv2service.go b/pkg/service/rtcv2service.go index 31d38b73e..9f79e9cee 100644 --- a/pkg/service/rtcv2service.go +++ b/pkg/service/rtcv2service.go @@ -24,6 +24,7 @@ import ( "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" @@ -38,6 +39,7 @@ var ( const ( cRTCv2Path = "/rtc/v2" + cRTCv2ValidatePath = "/rtc/v2/validate" cRTCv2ParticipantIDPath = "/rtc/v2/{participant_id}" ) @@ -70,6 +72,7 @@ func NewRTCv2Service( func (s *RTCv2Service) SetupRoutes(mux *http.ServeMux) { mux.HandleFunc("POST "+cRTCv2Path, s.handlePost) + mux.HandleFunc("GET "+cRTCv2Path, s.validate) mux.HandleFunc("PATCH "+cRTCv2ParticipantIDPath, s.handleParticipantPatch) } @@ -119,16 +122,9 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) { return } - body, err := ioutil.ReadAll(r.Body) + wireMessage, err := getWireMessage(r) if err != nil { - HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not read request body: %w", err)) - return - } - - wireMessage := &livekit.Signalv2WireMessage{} - err = proto.Unmarshal(body, wireMessage) - if err != nil { - HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not unmarshal request: %w", err)) + HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err)) return } @@ -138,7 +134,7 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) { switch clientMessage := innerMsg.GetMessage().(type) { case *livekit.Signalv2ClientMessage_ConnectRequest: roomName, participantIdentity, rscr, code, err := s.validateInternal( - logger.GetLogger(), + utils.GetLogger(r.Context()), r, clientMessage.ConnectRequest, ) @@ -201,7 +197,7 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) { } case *livekit.Signalv2WireMessage_Fragment: - logger.Errorw("signalv2 bad request", errFragmentsInHTTP) + utils.GetLogger(r.Context()).Errorw("signalv2 bad request", errFragmentsInHTTP) HandleErrorJson(w, r, http.StatusBadRequest, errFragmentsInHTTP) return } @@ -209,6 +205,34 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } +func (s *RTCv2Service) validate(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-type") != "application/x-protobuf" { + HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type"))) + return + } + + wireMessage, err := getWireMessage(r) + if err != nil { + HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err)) + return + } + + connectRequest := getConnectRequest(wireMessage) + if connectRequest == nil { + HandleErrorJson(w, r, http.StatusBadRequest, errors.New("no connect request")) + return + } + + _, _, _, code, err := s.validateInternal(utils.GetLogger(r.Context()), r, connectRequest) + if err != nil { + HandleErrorJson(w, r, code, err) + return + } + + _, _ = w.Write([]byte("success")) + w.WriteHeader(http.StatusOK) +} + func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Content-type") != "application/x-protobuf" { HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type"))) @@ -243,16 +267,9 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req return } - body, err := ioutil.ReadAll(r.Body) + wireMessage, err := getWireMessage(r) if err != nil { - HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not read request body: %w", err)) - return - } - - wireMessage := &livekit.Signalv2WireMessage{} - err = proto.Unmarshal(body, wireMessage) - if err != nil { - HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not unmarshal request: %w", err)) + HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err)) return } @@ -303,3 +320,34 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req w.WriteHeader(http.StatusOK) } + +// --------------------------------------- + +func getWireMessage(r *http.Request) (*livekit.Signalv2WireMessage, error) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + + wireMessage := &livekit.Signalv2WireMessage{} + err = proto.Unmarshal(body, wireMessage) + if err != nil { + return nil, err + } + + return wireMessage, nil +} + +func getConnectRequest(wireMessage *livekit.Signalv2WireMessage) *livekit.ConnectRequest { + switch msg := wireMessage.GetMessage().(type) { + case *livekit.Signalv2WireMessage_Envelope: + for _, innerMsg := range msg.Envelope.GetClientMessages() { + switch clientMessage := innerMsg.GetMessage().(type) { + case *livekit.Signalv2ClientMessage_ConnectRequest: + return clientMessage.ConnectRequest + } + } + } + + return nil +}