Validation end point for v2 signalling. (#3811)

* WIP

* stricter check

* WIP

* WIP

* clean up
This commit is contained in:
Raja Subramanian
2025-07-23 11:59:30 +05:30
committed by GitHub
parent f2f595f448
commit b20db94dc9
3 changed files with 98 additions and 30 deletions
+28 -6
View File
@@ -70,14 +70,34 @@ func (s *signalhandlerv2) HandleRequest(msg proto.Message) error {
switch msg := req.GetMessage().(type) {
case *livekit.Signalv2WireMessage_Envelope:
for _, clientMessage := range msg.Envelope.ClientMessages {
// SIGNAL-V2-TODO: cannot do this comparison for very first message
if clientMessage.Sequencer.MessageId != s.lastProcessedRemoteMessageId.Load()+1 {
/* SIGNALLING-V2-TODO: uncommment once remote side sends proper messageId
sequencer := clientMessage.GetSequencer()
if sequencer == nil || sequencer.MessageId == 0 {
s.params.Logger.Warnw(
"skipping message without sequencer", nil,
"messageType", fmt.Sprintf("%T", clientMessage),
)
continue
}
lprmi := s.lastProcessedRemoteMessageId.Load()
if sequencer.MessageId <= lprmi {
s.params.Logger.Infow(
"duplicate in message stream",
"last", lprmi,
"current", clientMessage.Sequencer.MessageId,
)
continue
}
if lprmi != 0 && sequencer.MessageId != lprmi+1 {
s.params.Logger.Infow(
"gap in message stream",
"last", s.lastProcessedRemoteMessageId.Load(),
"last", lprmi,
"current", clientMessage.Sequencer.MessageId,
)
}
*/
switch payload := clientMessage.GetMessage().(type) {
case *livekit.Signalv2ClientMessage_PublisherSdp:
@@ -87,9 +107,11 @@ func (s *signalhandlerv2) HandleRequest(msg proto.Message) error {
s.params.Participant.HandleAnswer(protosignalling.FromProtoSessionDescription(payload.SubscriberSdp))
}
s.lastProcessedRemoteMessageId.Store(clientMessage.Sequencer.MessageId)
s.params.Signalling.AckMessageId(clientMessage.Sequencer.LastProcessedRemoteMessageId)
s.params.Signalling.SetLastProcessedRemoteMessageId(clientMessage.Sequencer.MessageId)
/* SIGNALLING-V2-TODO: uncomment once sequencer is implemented on both sides
s.lastProcessedRemoteMessageId.Store(sequencer.MessageId)
s.params.Signalling.AckMessageId(sequencer.LastProcessedRemoteMessageId)
s.params.Signalling.SetLastProcessedRemoteMessageId(sequencer.MessageId)
*/
}
case *livekit.Signalv2WireMessage_Fragment:
+2 -4
View File
@@ -213,11 +213,9 @@ func (s signalv2ParticipantService) RelaySignalv2Participant(ctx context.Context
}
var wireMessage *livekit.Signalv2WireMessage
pending := lp.SignalPendingMessages()
if pending != nil {
if pending := lp.SignalPendingMessages(); pending != nil {
var ok bool
wireMessage, ok = pending.(*livekit.Signalv2WireMessage)
if !ok {
if wireMessage, ok = pending.(*livekit.Signalv2WireMessage); !ok {
return nil, ErrInvalidMessageType
}
}
+68 -20
View File
@@ -24,6 +24,7 @@ import (
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/utils"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/rpc"
@@ -38,6 +39,7 @@ var (
const (
cRTCv2Path = "/rtc/v2"
cRTCv2ValidatePath = "/rtc/v2/validate"
cRTCv2ParticipantIDPath = "/rtc/v2/{participant_id}"
)
@@ -70,6 +72,7 @@ func NewRTCv2Service(
func (s *RTCv2Service) SetupRoutes(mux *http.ServeMux) {
mux.HandleFunc("POST "+cRTCv2Path, s.handlePost)
mux.HandleFunc("GET "+cRTCv2Path, s.validate)
mux.HandleFunc("PATCH "+cRTCv2ParticipantIDPath, s.handleParticipantPatch)
}
@@ -119,16 +122,9 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) {
return
}
body, err := ioutil.ReadAll(r.Body)
wireMessage, err := getWireMessage(r)
if err != nil {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not read request body: %w", err))
return
}
wireMessage := &livekit.Signalv2WireMessage{}
err = proto.Unmarshal(body, wireMessage)
if err != nil {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not unmarshal request: %w", err))
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err))
return
}
@@ -138,7 +134,7 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) {
switch clientMessage := innerMsg.GetMessage().(type) {
case *livekit.Signalv2ClientMessage_ConnectRequest:
roomName, participantIdentity, rscr, code, err := s.validateInternal(
logger.GetLogger(),
utils.GetLogger(r.Context()),
r,
clientMessage.ConnectRequest,
)
@@ -201,7 +197,7 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) {
}
case *livekit.Signalv2WireMessage_Fragment:
logger.Errorw("signalv2 bad request", errFragmentsInHTTP)
utils.GetLogger(r.Context()).Errorw("signalv2 bad request", errFragmentsInHTTP)
HandleErrorJson(w, r, http.StatusBadRequest, errFragmentsInHTTP)
return
}
@@ -209,6 +205,34 @@ func (s *RTCv2Service) handlePost(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (s *RTCv2Service) validate(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-type") != "application/x-protobuf" {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type")))
return
}
wireMessage, err := getWireMessage(r)
if err != nil {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err))
return
}
connectRequest := getConnectRequest(wireMessage)
if connectRequest == nil {
HandleErrorJson(w, r, http.StatusBadRequest, errors.New("no connect request"))
return
}
_, _, _, code, err := s.validateInternal(utils.GetLogger(r.Context()), r, connectRequest)
if err != nil {
HandleErrorJson(w, r, code, err)
return
}
_, _ = w.Write([]byte("success"))
w.WriteHeader(http.StatusOK)
}
func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-type") != "application/x-protobuf" {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type")))
@@ -243,16 +267,9 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req
return
}
body, err := ioutil.ReadAll(r.Body)
wireMessage, err := getWireMessage(r)
if err != nil {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not read request body: %w", err))
return
}
wireMessage := &livekit.Signalv2WireMessage{}
err = proto.Unmarshal(body, wireMessage)
if err != nil {
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not unmarshal request: %w", err))
HandleErrorJson(w, r, http.StatusBadRequest, fmt.Errorf("could not get wire message: %w", err))
return
}
@@ -303,3 +320,34 @@ func (s *RTCv2Service) handleParticipantPatch(w http.ResponseWriter, r *http.Req
w.WriteHeader(http.StatusOK)
}
// ---------------------------------------
func getWireMessage(r *http.Request) (*livekit.Signalv2WireMessage, error) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}
wireMessage := &livekit.Signalv2WireMessage{}
err = proto.Unmarshal(body, wireMessage)
if err != nil {
return nil, err
}
return wireMessage, nil
}
func getConnectRequest(wireMessage *livekit.Signalv2WireMessage) *livekit.ConnectRequest {
switch msg := wireMessage.GetMessage().(type) {
case *livekit.Signalv2WireMessage_Envelope:
for _, innerMsg := range msg.Envelope.GetClientMessages() {
switch clientMessage := innerMsg.GetMessage().(type) {
case *livekit.Signalv2ClientMessage_ConnectRequest:
return clientMessage.ConnectRequest
}
}
}
return nil
}