Support join request as proto + base64 encoded query param (#3836)

* Support join request as proto + base64 encoded query param

* joinPublish

* staticcheck

* deps

* tests

* gzip

* test

* deps

* clean up
This commit is contained in:
Raja Subramanian
2025-08-07 11:13:27 +05:30
committed by GitHub
parent 7dea101286
commit 5ca1626439
12 changed files with 198 additions and 86 deletions
+1 -1
View File
@@ -23,7 +23,7 @@ require (
github.com/jxskiss/base62 v1.1.0
github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731
github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded
github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df
github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059
github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741
github.com/mackerelio/go-osstat v0.2.5
github.com/magefile/mage v1.15.0
+2 -2
View File
@@ -167,8 +167,8 @@ github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5AT
github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ=
github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded h1:ylZPdnlX1RW9Z15SD4mp87vT2D2shsk0hpLJwSPcq3g=
github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded/go.mod h1:mSNtYzSf6iY9xM3UX42VEI+STHvMgHmrYzEHPcdhB8A=
github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df h1:3YB9qvVAPK0SNWDngETtE7UL75xAygZw07DtsYbKSNk=
github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU=
github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059 h1:z9J/0wbTfbJ5QmFhPoYNR+VR3vaOaifQfNYELJHjTZs=
github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU=
github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741 h1:KKL1u94l6dF9u4cBwnnfozk27GH1txWy2SlvkfgmzoY=
github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741/go.mod h1:AuDC5uOoEjQJEc69v4Li3t77Ocz0e0NdjQEuFfO+vfk=
github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o=
+33 -25
View File
@@ -197,6 +197,8 @@ type ParticipantInit struct {
SubscriberAllowPause *bool
DisableICELite bool
CreateRoom *livekit.CreateRoomRequest
AddTrackRequests []*livekit.AddTrackRequest
PublisherOffer *livekit.SessionDescription
}
func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error {
@@ -224,6 +226,8 @@ func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error {
logBoolPtr("SubscriberAllowPause", pi.SubscriberAllowPause)
logBoolPtr("DisableICELite", &pi.DisableICELite)
e.AddObject("CreateRoom", logger.Proto(pi.CreateRoom))
e.AddArray("AddTrackRequests", logger.ProtoSlice(pi.AddTrackRequests))
e.AddObject("PublisherOffer", logger.Proto(pi.PublisherOffer))
return nil
}
@@ -234,19 +238,21 @@ func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionI
}
ss := &livekit.StartSession{
RoomName: string(roomName),
Identity: string(pi.Identity),
Name: string(pi.Name),
ConnectionId: string(connectionID),
Reconnect: pi.Reconnect,
ReconnectReason: pi.ReconnectReason,
AutoSubscribe: pi.AutoSubscribe,
Client: pi.Client,
GrantsJson: string(claims),
AdaptiveStream: pi.AdaptiveStream,
ParticipantId: string(pi.ID),
DisableIceLite: pi.DisableICELite,
CreateRoom: pi.CreateRoom,
RoomName: string(roomName),
Identity: string(pi.Identity),
Name: string(pi.Name),
ConnectionId: string(connectionID),
Reconnect: pi.Reconnect,
ReconnectReason: pi.ReconnectReason,
AutoSubscribe: pi.AutoSubscribe,
Client: pi.Client,
GrantsJson: string(claims),
AdaptiveStream: pi.AdaptiveStream,
ParticipantId: string(pi.ID),
DisableIceLite: pi.DisableICELite,
CreateRoom: pi.CreateRoom,
AddTrackRequests: pi.AddTrackRequests,
PublisherOffer: pi.PublisherOffer,
}
if pi.SubscriberAllowPause != nil {
subscriberAllowPause := *pi.SubscriberAllowPause
@@ -263,18 +269,20 @@ func ParticipantInitFromStartSession(ss *livekit.StartSession, region string) (*
}
pi := &ParticipantInit{
Identity: livekit.ParticipantIdentity(ss.Identity),
Name: livekit.ParticipantName(ss.Name),
Reconnect: ss.Reconnect,
ReconnectReason: ss.ReconnectReason,
Client: ss.Client,
AutoSubscribe: ss.AutoSubscribe,
Grants: claims,
Region: region,
AdaptiveStream: ss.AdaptiveStream,
ID: livekit.ParticipantID(ss.ParticipantId),
DisableICELite: ss.DisableIceLite,
CreateRoom: ss.CreateRoom,
Identity: livekit.ParticipantIdentity(ss.Identity),
Name: livekit.ParticipantName(ss.Name),
Reconnect: ss.Reconnect,
ReconnectReason: ss.ReconnectReason,
Client: ss.Client,
AutoSubscribe: ss.AutoSubscribe,
Grants: claims,
Region: region,
AdaptiveStream: ss.AdaptiveStream,
ID: livekit.ParticipantID(ss.ParticipantId),
DisableICELite: ss.DisableIceLite,
CreateRoom: ss.CreateRoom,
AddTrackRequests: ss.AddTrackRequests,
PublisherOffer: ss.PublisherOffer,
}
if ss.SubscriberAllowPause != nil {
subscriberAllowPause := *ss.SubscriberAllowPause
+1
View File
@@ -1243,6 +1243,7 @@ func (p *ParticipantImpl) handleMigrateTracks() []*MediaTrack {
// AddTrack is called when client intends to publish track.
// records track details and lets client know it's ok to proceed
func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) {
p.params.Logger.Debugw("add track request", "trackID", req.Cid)
if !p.CanPublishSource(req.Source) {
p.pubLogger.Warnw("no permission to publish track", nil, "trackID", req.Sid, "kind", req.Type)
return
+4 -1
View File
@@ -1053,7 +1053,10 @@ func (r *Room) autoSubscribe(participant types.LocalParticipant) bool {
return true
}
func (r *Room) createJoinResponseLocked(participant types.LocalParticipant, iceServers []*livekit.ICEServer) *livekit.JoinResponse {
func (r *Room) createJoinResponseLocked(
participant types.LocalParticipant,
iceServers []*livekit.ICEServer,
) *livekit.JoinResponse {
iceConfig := participant.GetICEConfig()
hasICEFallback := iceConfig.GetPreferencePublisher() != livekit.ICECandidateType_ICT_NONE || iceConfig.GetPreferenceSubscriber() != livekit.ICECandidateType_ICT_NONE
return &livekit.JoinResponse{
-1
View File
@@ -71,7 +71,6 @@ func (s *signalhandler) HandleMessage(msg proto.Message) error {
s.params.Participant.AddICECandidate(candidateInit, msg.Trickle.Target)
case *livekit.SignalRequest_AddTrack:
s.params.Logger.Debugw("add track request", "trackID", msg.AddTrack.Cid)
s.params.Participant.AddTrack(msg.AddTrack)
case *livekit.SignalRequest_Mute:
+2 -3
View File
@@ -42,9 +42,8 @@ var (
// ensuring this is longer than iceFailedTimeout so we are certain the participant won't return
notFoundTimeout = time.Minute
// amount of time to try otherwise before flagging subscription as failed
subscriptionTimeout = iceFailedTimeoutTotal
trackRemoveGracePeriod = time.Second
maxUnsubscribeWait = time.Second
subscriptionTimeout = iceFailedTimeoutTotal
maxUnsubscribeWait = time.Second
)
const (
+8
View File
@@ -34,6 +34,7 @@ import (
"github.com/livekit/protocol/observability"
"github.com/livekit/protocol/observability/roomobs"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/signalling"
"github.com/livekit/protocol/utils"
"github.com/livekit/protocol/utils/guid"
"github.com/livekit/protocol/utils/must"
@@ -561,6 +562,13 @@ func (r *RoomManager) StartSession(
r.iceConfigCache.Put(iceConfigCacheKey{room.Name(), participant.Identity()}, iceConfig)
})
for _, addTrackRequest := range pi.AddTrackRequests {
participant.AddTrack(addTrackRequest)
}
if pi.PublisherOffer != nil {
participant.HandleOffer(signalling.FromProtoSessionDescription(pi.PublisherOffer))
}
go r.rtcSessionWorker(room, participant, requestSource)
return nil
}
+84 -31
View File
@@ -15,11 +15,14 @@
package service
import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"os"
@@ -30,6 +33,7 @@ import (
"github.com/gorilla/websocket"
"go.uber.org/atomic"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
@@ -114,19 +118,46 @@ func decodeAttributes(str string) (map[string]string, error) {
func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict bool) (livekit.RoomName, routing.ParticipantInit, int, error) {
var params ValidateConnectRequestParams
params.publish = r.FormValue("publish")
joinRequest := &livekit.JoinRequest{}
attributesStrParam := r.FormValue("attributes")
if attributesStrParam != "" {
attrs, err := decodeAttributes(attributesStrParam)
if err != nil {
if strict {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes")
joinRequestBase64 := r.FormValue("join_request")
if joinRequestBase64 == "" {
params.publish = r.FormValue("publish")
attributesStrParam := r.FormValue("attributes")
if attributesStrParam != "" {
attrs, err := decodeAttributes(attributesStrParam)
if err != nil {
if strict {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes")
}
lgr.Debugw("failed to decode attributes", "error", err)
// attrs will be empty here, so just proceed
}
params.attributes = attrs
}
} else {
if compressedBytes, err := base64.URLEncoding.DecodeString(joinRequestBase64); err != nil {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot base64 decode join request")
} else {
b := bytes.NewReader(compressedBytes)
if reader, err := gzip.NewReader(b); err != nil {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decompress join request")
} else {
protoBytes, err := io.ReadAll(reader)
reader.Close()
if err != nil {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot read decompressed join request")
}
if err := proto.Unmarshal(protoBytes, joinRequest); err != nil {
return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request")
}
params.metadata = joinRequest.Metadata
params.attributes = joinRequest.ParticipantAttributes
}
lgr.Debugw("failed to decode attributes", "error", err)
// attrs will be empty here, so just proceed
}
params.attributes = attrs
}
res, code, err := ValidateConnectRequest(
@@ -142,33 +173,55 @@ func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict
}
pi := routing.ParticipantInit{
Reconnect: boolValue(r.FormValue("reconnect")),
Identity: livekit.ParticipantIdentity(res.grants.Identity),
Name: livekit.ParticipantName(res.grants.Name),
Client: ParseClientInfo(r),
Grants: res.grants,
Region: res.region,
CreateRoom: res.createRoomRequest,
AutoSubscribe: true,
AdaptiveStream: boolValue(r.FormValue("adaptive_stream")),
DisableICELite: boolValue(r.FormValue("disable_ice_lite")),
Identity: livekit.ParticipantIdentity(res.grants.Identity),
Name: livekit.ParticipantName(res.grants.Name),
Grants: res.grants,
Region: res.region,
CreateRoom: res.createRoomRequest,
}
reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason
pi.ReconnectReason = livekit.ReconnectReason(reconnectReason)
if joinRequestBase64 == "" {
pi.Reconnect = boolValue(r.FormValue("reconnect"))
pi.Client = ParseClientInfo(r)
pi.AutoSubscribe = true
pi.AdaptiveStream = boolValue(r.FormValue("adaptive_stream"))
pi.DisableICELite = boolValue(r.FormValue("disable_ice_lite"))
if pi.Reconnect {
pi.ID = livekit.ParticipantID(r.FormValue("sid"))
}
reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason
pi.ReconnectReason = livekit.ReconnectReason(reconnectReason)
if autoSubscribe := r.FormValue("auto_subscribe"); autoSubscribe != "" {
pi.AutoSubscribe = boolValue(autoSubscribe)
}
if pi.Reconnect {
pi.ID = livekit.ParticipantID(r.FormValue("sid"))
}
subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause")
if subscriberAllowPauseParam != "" {
subscriberAllowPause := boolValue(subscriberAllowPauseParam)
if autoSubscribe := r.FormValue("auto_subscribe"); autoSubscribe != "" {
pi.AutoSubscribe = boolValue(autoSubscribe)
}
subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause")
if subscriberAllowPauseParam != "" {
subscriberAllowPause := boolValue(subscriberAllowPauseParam)
pi.SubscriberAllowPause = &subscriberAllowPause
}
} else {
lgr.Debugw("processing join request", "joinRequest", logger.Proto(joinRequest))
AugmentClientInfo(joinRequest.ClientInfo, r)
pi.Client = joinRequest.ClientInfo
pi.AutoSubscribe = joinRequest.GetConnectionSettings().GetAutoSubscribe()
pi.AdaptiveStream = joinRequest.GetConnectionSettings().GetAdaptiveStream()
pi.DisableICELite = joinRequest.GetConnectionSettings().GetDisableIceLite()
subscriberAllowPause := joinRequest.GetConnectionSettings().GetSubscriberAllowPause()
pi.SubscriberAllowPause = &subscriberAllowPause
pi.AddTrackRequests = joinRequest.AddTrackRequests
pi.PublisherOffer = joinRequest.PublisherOffer
pi.Reconnect = joinRequest.Reconnect
pi.ReconnectReason = joinRequest.ReconnectReason
pi.ID = livekit.ParticipantID(joinRequest.ParticipantSid)
}
return res.roomName, pi, code, err
+2 -1
View File
@@ -22,6 +22,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/livekit/livekit-server/pkg/testutils"
testclient "github.com/livekit/livekit-server/test/client"
"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
)
@@ -69,7 +70,7 @@ func TestAgents(t *testing.T) {
}, RegisterTimeout)
c1 := createRTCClient("c1", defaultServerPort, nil)
c2 := createRTCClient("c2", defaultServerPort, nil)
c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{UseJoinRequestQueryParam: true})
waitUntilConnected(t, c1, c2)
// publish 2 tracks
+57 -20
View File
@@ -15,6 +15,8 @@
package client
import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
@@ -24,6 +26,7 @@ import (
"net/http"
"net/url"
"path/filepath"
"runtime"
"sync"
"time"
@@ -122,10 +125,11 @@ type Options struct {
TokenCustomizer func(token *auth.AccessToken, grants *auth.VideoGrant)
SignalRequestInterceptor SignalRequestInterceptor
SignalResponseInterceptor SignalResponseInterceptor
UseJoinRequestQueryParam bool
}
func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) {
u, err := url.Parse(host + fmt.Sprintf("/rtc?protocol=%d", types.CurrentProtocol))
u, err := url.Parse(host + "/rtc")
if err != nil {
return nil, err
}
@@ -133,32 +137,65 @@ func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error
SetAuthorizationToken(requestHeader, token)
connectUrl := u.String()
sdk := "go"
if opts != nil {
connectUrl = fmt.Sprintf("%s&auto_subscribe=%t", connectUrl, opts.AutoSubscribe)
if opts.Publish != "" {
connectUrl += encodeQueryParam("publish", opts.Publish)
}
if len(opts.Attributes) != 0 {
data, err := json.Marshal(opts.Attributes)
if err != nil {
return nil, err
}
connectUrl += encodeQueryParam("attributes", base64.URLEncoding.EncodeToString(data))
if opts != nil && opts.UseJoinRequestQueryParam {
clientInfo := &livekit.ClientInfo{
Os: runtime.GOOS,
Sdk: livekit.ClientInfo_GO,
Protocol: types.CurrentProtocol,
}
if opts.ClientInfo != nil {
if opts.ClientInfo.DeviceModel != "" {
connectUrl += encodeQueryParam("device_model", opts.ClientInfo.DeviceModel)
clientInfo = opts.ClientInfo
}
connectionSettings := &livekit.ConnectionSettings{
AutoSubscribe: opts.AutoSubscribe,
}
joinRequest := &livekit.JoinRequest{
ClientInfo: clientInfo,
ConnectionSettings: connectionSettings,
ParticipantAttributes: opts.Attributes,
}
if marshalled, err := proto.Marshal(joinRequest); err == nil {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
writer.Write(marshalled)
writer.Close()
connectUrl += fmt.Sprintf("?join_request=%s", base64.URLEncoding.EncodeToString(buf.Bytes()))
}
} else {
connectUrl += fmt.Sprintf("?protocol=%d", types.CurrentProtocol)
sdk := "go"
if opts != nil {
connectUrl += fmt.Sprintf("&auto_subscribe=%t", opts.AutoSubscribe)
if opts.Publish != "" {
connectUrl += encodeQueryParam("publish", opts.Publish)
}
if opts.ClientInfo.Os != "" {
connectUrl += encodeQueryParam("os", opts.ClientInfo.Os)
if len(opts.Attributes) != 0 {
data, err := json.Marshal(opts.Attributes)
if err != nil {
return nil, err
}
connectUrl += encodeQueryParam("attributes", base64.URLEncoding.EncodeToString(data))
}
if opts.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN {
sdk = opts.ClientInfo.Sdk.String()
if opts.ClientInfo != nil {
if opts.ClientInfo.DeviceModel != "" {
connectUrl += encodeQueryParam("device_model", opts.ClientInfo.DeviceModel)
}
if opts.ClientInfo.Os != "" {
connectUrl += encodeQueryParam("os", opts.ClientInfo.Os)
}
if opts.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN {
sdk = opts.ClientInfo.Sdk.String()
}
}
}
connectUrl += encodeQueryParam("sdk", sdk)
}
connectUrl += encodeQueryParam("sdk", sdk)
conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader)
return conn, err
}
+4 -1
View File
@@ -564,6 +564,7 @@ func TestSingleNodeAttributes(t *testing.T) {
"b": "1",
})
},
UseJoinRequestQueryParam: true,
})
grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom}
grant.SetCanSubscribe(false)
@@ -663,7 +664,9 @@ func TestSubscribeToCodecUnsupported(t *testing.T) {
_, finish := setupSingleNodeTest("TestSubscribeToCodecUnsupported")
defer finish()
c1 := createRTCClient("c1", defaultServerPort, nil)
c1 := createRTCClient("c1", defaultServerPort, &testclient.Options{
UseJoinRequestQueryParam: true,
})
// create a client that doesn't support H264
c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{
AutoSubscribe: true,