diff --git a/go.mod b/go.mod index f94bfbc4f..c702db297 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded - github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df + github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059 github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741 github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index e920532ee..c7aedfdf2 100644 --- a/go.sum +++ b/go.sum @@ -167,8 +167,8 @@ github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5AT github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded h1:ylZPdnlX1RW9Z15SD4mp87vT2D2shsk0hpLJwSPcq3g= github.com/livekit/mediatransportutil v0.0.0-20250519131108-fb90f5acfded/go.mod h1:mSNtYzSf6iY9xM3UX42VEI+STHvMgHmrYzEHPcdhB8A= -github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df h1:3YB9qvVAPK0SNWDngETtE7UL75xAygZw07DtsYbKSNk= -github.com/livekit/protocol v1.39.4-0.20250806031641-1edabe8e86df/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU= +github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059 h1:z9J/0wbTfbJ5QmFhPoYNR+VR3vaOaifQfNYELJHjTZs= +github.com/livekit/protocol v1.39.4-0.20250807053007-7f6468a6a059/go.mod h1:YlgUxAegtU8jZ0tVXoIV/4fHeHqqLvS+6JnPKDbpFPU= github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741 h1:KKL1u94l6dF9u4cBwnnfozk27GH1txWy2SlvkfgmzoY= github.com/livekit/psrpc v0.6.1-0.20250726180611-3915e005e741/go.mod h1:AuDC5uOoEjQJEc69v4Li3t77Ocz0e0NdjQEuFfO+vfk= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 7550960c1..8f32d0dd3 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -197,6 +197,8 @@ type ParticipantInit struct { SubscriberAllowPause *bool DisableICELite bool CreateRoom *livekit.CreateRoomRequest + AddTrackRequests []*livekit.AddTrackRequest + PublisherOffer *livekit.SessionDescription } func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error { @@ -224,6 +226,8 @@ func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error { logBoolPtr("SubscriberAllowPause", pi.SubscriberAllowPause) logBoolPtr("DisableICELite", &pi.DisableICELite) e.AddObject("CreateRoom", logger.Proto(pi.CreateRoom)) + e.AddArray("AddTrackRequests", logger.ProtoSlice(pi.AddTrackRequests)) + e.AddObject("PublisherOffer", logger.Proto(pi.PublisherOffer)) return nil } @@ -234,19 +238,21 @@ func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionI } ss := &livekit.StartSession{ - RoomName: string(roomName), - Identity: string(pi.Identity), - Name: string(pi.Name), - ConnectionId: string(connectionID), - Reconnect: pi.Reconnect, - ReconnectReason: pi.ReconnectReason, - AutoSubscribe: pi.AutoSubscribe, - Client: pi.Client, - GrantsJson: string(claims), - AdaptiveStream: pi.AdaptiveStream, - ParticipantId: string(pi.ID), - DisableIceLite: pi.DisableICELite, - CreateRoom: pi.CreateRoom, + RoomName: string(roomName), + Identity: string(pi.Identity), + Name: string(pi.Name), + ConnectionId: string(connectionID), + Reconnect: pi.Reconnect, + ReconnectReason: pi.ReconnectReason, + AutoSubscribe: pi.AutoSubscribe, + Client: pi.Client, + GrantsJson: string(claims), + AdaptiveStream: pi.AdaptiveStream, + ParticipantId: string(pi.ID), + DisableIceLite: pi.DisableICELite, + CreateRoom: pi.CreateRoom, + AddTrackRequests: pi.AddTrackRequests, + PublisherOffer: pi.PublisherOffer, } if pi.SubscriberAllowPause != nil { subscriberAllowPause := *pi.SubscriberAllowPause @@ -263,18 +269,20 @@ func ParticipantInitFromStartSession(ss *livekit.StartSession, region string) (* } pi := &ParticipantInit{ - Identity: livekit.ParticipantIdentity(ss.Identity), - Name: livekit.ParticipantName(ss.Name), - Reconnect: ss.Reconnect, - ReconnectReason: ss.ReconnectReason, - Client: ss.Client, - AutoSubscribe: ss.AutoSubscribe, - Grants: claims, - Region: region, - AdaptiveStream: ss.AdaptiveStream, - ID: livekit.ParticipantID(ss.ParticipantId), - DisableICELite: ss.DisableIceLite, - CreateRoom: ss.CreateRoom, + Identity: livekit.ParticipantIdentity(ss.Identity), + Name: livekit.ParticipantName(ss.Name), + Reconnect: ss.Reconnect, + ReconnectReason: ss.ReconnectReason, + Client: ss.Client, + AutoSubscribe: ss.AutoSubscribe, + Grants: claims, + Region: region, + AdaptiveStream: ss.AdaptiveStream, + ID: livekit.ParticipantID(ss.ParticipantId), + DisableICELite: ss.DisableIceLite, + CreateRoom: ss.CreateRoom, + AddTrackRequests: ss.AddTrackRequests, + PublisherOffer: ss.PublisherOffer, } if ss.SubscriberAllowPause != nil { subscriberAllowPause := *ss.SubscriberAllowPause diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index bf5d9bc0a..8e6d16c31 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1243,6 +1243,7 @@ func (p *ParticipantImpl) handleMigrateTracks() []*MediaTrack { // AddTrack is called when client intends to publish track. // records track details and lets client know it's ok to proceed func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { + p.params.Logger.Debugw("add track request", "trackID", req.Cid) if !p.CanPublishSource(req.Source) { p.pubLogger.Warnw("no permission to publish track", nil, "trackID", req.Sid, "kind", req.Type) return diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index a9de18a03..6c8eda1fe 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -1053,7 +1053,10 @@ func (r *Room) autoSubscribe(participant types.LocalParticipant) bool { return true } -func (r *Room) createJoinResponseLocked(participant types.LocalParticipant, iceServers []*livekit.ICEServer) *livekit.JoinResponse { +func (r *Room) createJoinResponseLocked( + participant types.LocalParticipant, + iceServers []*livekit.ICEServer, +) *livekit.JoinResponse { iceConfig := participant.GetICEConfig() hasICEFallback := iceConfig.GetPreferencePublisher() != livekit.ICECandidateType_ICT_NONE || iceConfig.GetPreferenceSubscriber() != livekit.ICECandidateType_ICT_NONE return &livekit.JoinResponse{ diff --git a/pkg/rtc/signalling/signalhandler.go b/pkg/rtc/signalling/signalhandler.go index bdc42f85d..3d3404c1b 100644 --- a/pkg/rtc/signalling/signalhandler.go +++ b/pkg/rtc/signalling/signalhandler.go @@ -71,7 +71,6 @@ func (s *signalhandler) HandleMessage(msg proto.Message) error { s.params.Participant.AddICECandidate(candidateInit, msg.Trickle.Target) case *livekit.SignalRequest_AddTrack: - s.params.Logger.Debugw("add track request", "trackID", msg.AddTrack.Cid) s.params.Participant.AddTrack(msg.AddTrack) case *livekit.SignalRequest_Mute: diff --git a/pkg/rtc/subscriptionmanager.go b/pkg/rtc/subscriptionmanager.go index 682b8f372..c7d6c7051 100644 --- a/pkg/rtc/subscriptionmanager.go +++ b/pkg/rtc/subscriptionmanager.go @@ -42,9 +42,8 @@ var ( // ensuring this is longer than iceFailedTimeout so we are certain the participant won't return notFoundTimeout = time.Minute // amount of time to try otherwise before flagging subscription as failed - subscriptionTimeout = iceFailedTimeoutTotal - trackRemoveGracePeriod = time.Second - maxUnsubscribeWait = time.Second + subscriptionTimeout = iceFailedTimeoutTotal + maxUnsubscribeWait = time.Second ) const ( diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index a3a680167..615dcf1c7 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -34,6 +34,7 @@ import ( "github.com/livekit/protocol/observability" "github.com/livekit/protocol/observability/roomobs" "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/signalling" "github.com/livekit/protocol/utils" "github.com/livekit/protocol/utils/guid" "github.com/livekit/protocol/utils/must" @@ -561,6 +562,13 @@ func (r *RoomManager) StartSession( r.iceConfigCache.Put(iceConfigCacheKey{room.Name(), participant.Identity()}, iceConfig) }) + for _, addTrackRequest := range pi.AddTrackRequests { + participant.AddTrack(addTrackRequest) + } + if pi.PublisherOffer != nil { + participant.HandleOffer(signalling.FromProtoSessionDescription(pi.PublisherOffer)) + } + go r.rtcSessionWorker(room, participant, requestSource) return nil } diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 7675e1cb6..eaf2c64cd 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -15,11 +15,14 @@ package service import ( + "bytes" + "compress/gzip" "context" "encoding/base64" "encoding/json" "errors" "fmt" + "io" "math/rand" "net/http" "os" @@ -30,6 +33,7 @@ import ( "github.com/gorilla/websocket" "go.uber.org/atomic" "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" @@ -114,19 +118,46 @@ func decodeAttributes(str string) (map[string]string, error) { func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict bool) (livekit.RoomName, routing.ParticipantInit, int, error) { var params ValidateConnectRequestParams - params.publish = r.FormValue("publish") + joinRequest := &livekit.JoinRequest{} - attributesStrParam := r.FormValue("attributes") - if attributesStrParam != "" { - attrs, err := decodeAttributes(attributesStrParam) - if err != nil { - if strict { - return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes") + joinRequestBase64 := r.FormValue("join_request") + if joinRequestBase64 == "" { + params.publish = r.FormValue("publish") + + attributesStrParam := r.FormValue("attributes") + if attributesStrParam != "" { + attrs, err := decodeAttributes(attributesStrParam) + if err != nil { + if strict { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes") + } + lgr.Debugw("failed to decode attributes", "error", err) + // attrs will be empty here, so just proceed + } + params.attributes = attrs + } + } else { + if compressedBytes, err := base64.URLEncoding.DecodeString(joinRequestBase64); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot base64 decode join request") + } else { + b := bytes.NewReader(compressedBytes) + if reader, err := gzip.NewReader(b); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decompress join request") + } else { + protoBytes, err := io.ReadAll(reader) + reader.Close() + if err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot read decompressed join request") + } + + if err := proto.Unmarshal(protoBytes, joinRequest); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request") + } + + params.metadata = joinRequest.Metadata + params.attributes = joinRequest.ParticipantAttributes } - lgr.Debugw("failed to decode attributes", "error", err) - // attrs will be empty here, so just proceed } - params.attributes = attrs } res, code, err := ValidateConnectRequest( @@ -142,33 +173,55 @@ func (s *RTCService) validateInternal(lgr logger.Logger, r *http.Request, strict } pi := routing.ParticipantInit{ - Reconnect: boolValue(r.FormValue("reconnect")), - Identity: livekit.ParticipantIdentity(res.grants.Identity), - Name: livekit.ParticipantName(res.grants.Name), - Client: ParseClientInfo(r), - Grants: res.grants, - Region: res.region, - CreateRoom: res.createRoomRequest, - AutoSubscribe: true, - AdaptiveStream: boolValue(r.FormValue("adaptive_stream")), - DisableICELite: boolValue(r.FormValue("disable_ice_lite")), + Identity: livekit.ParticipantIdentity(res.grants.Identity), + Name: livekit.ParticipantName(res.grants.Name), + Grants: res.grants, + Region: res.region, + CreateRoom: res.createRoomRequest, } - reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason - pi.ReconnectReason = livekit.ReconnectReason(reconnectReason) + if joinRequestBase64 == "" { + pi.Reconnect = boolValue(r.FormValue("reconnect")) + pi.Client = ParseClientInfo(r) + pi.AutoSubscribe = true + pi.AdaptiveStream = boolValue(r.FormValue("adaptive_stream")) + pi.DisableICELite = boolValue(r.FormValue("disable_ice_lite")) - if pi.Reconnect { - pi.ID = livekit.ParticipantID(r.FormValue("sid")) - } + reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason + pi.ReconnectReason = livekit.ReconnectReason(reconnectReason) - if autoSubscribe := r.FormValue("auto_subscribe"); autoSubscribe != "" { - pi.AutoSubscribe = boolValue(autoSubscribe) - } + if pi.Reconnect { + pi.ID = livekit.ParticipantID(r.FormValue("sid")) + } - subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause") - if subscriberAllowPauseParam != "" { - subscriberAllowPause := boolValue(subscriberAllowPauseParam) + if autoSubscribe := r.FormValue("auto_subscribe"); autoSubscribe != "" { + pi.AutoSubscribe = boolValue(autoSubscribe) + } + + subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause") + if subscriberAllowPauseParam != "" { + subscriberAllowPause := boolValue(subscriberAllowPauseParam) + pi.SubscriberAllowPause = &subscriberAllowPause + } + } else { + lgr.Debugw("processing join request", "joinRequest", logger.Proto(joinRequest)) + + AugmentClientInfo(joinRequest.ClientInfo, r) + pi.Client = joinRequest.ClientInfo + + pi.AutoSubscribe = joinRequest.GetConnectionSettings().GetAutoSubscribe() + pi.AdaptiveStream = joinRequest.GetConnectionSettings().GetAdaptiveStream() + pi.DisableICELite = joinRequest.GetConnectionSettings().GetDisableIceLite() + + subscriberAllowPause := joinRequest.GetConnectionSettings().GetSubscriberAllowPause() pi.SubscriberAllowPause = &subscriberAllowPause + + pi.AddTrackRequests = joinRequest.AddTrackRequests + pi.PublisherOffer = joinRequest.PublisherOffer + + pi.Reconnect = joinRequest.Reconnect + pi.ReconnectReason = joinRequest.ReconnectReason + pi.ID = livekit.ParticipantID(joinRequest.ParticipantSid) } return res.roomName, pi, code, err diff --git a/test/agent_test.go b/test/agent_test.go index 795fa1111..d6a2571f2 100644 --- a/test/agent_test.go +++ b/test/agent_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/testutils" + testclient "github.com/livekit/livekit-server/test/client" "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" ) @@ -69,7 +70,7 @@ func TestAgents(t *testing.T) { }, RegisterTimeout) c1 := createRTCClient("c1", defaultServerPort, nil) - c2 := createRTCClient("c2", defaultServerPort, nil) + c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{UseJoinRequestQueryParam: true}) waitUntilConnected(t, c1, c2) // publish 2 tracks diff --git a/test/client/client.go b/test/client/client.go index 6d479aba7..53626547c 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -15,6 +15,8 @@ package client import ( + "bytes" + "compress/gzip" "context" "encoding/base64" "encoding/json" @@ -24,6 +26,7 @@ import ( "net/http" "net/url" "path/filepath" + "runtime" "sync" "time" @@ -122,10 +125,11 @@ type Options struct { TokenCustomizer func(token *auth.AccessToken, grants *auth.VideoGrant) SignalRequestInterceptor SignalRequestInterceptor SignalResponseInterceptor SignalResponseInterceptor + UseJoinRequestQueryParam bool } func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { - u, err := url.Parse(host + fmt.Sprintf("/rtc?protocol=%d", types.CurrentProtocol)) + u, err := url.Parse(host + "/rtc") if err != nil { return nil, err } @@ -133,32 +137,65 @@ func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error SetAuthorizationToken(requestHeader, token) connectUrl := u.String() - sdk := "go" - if opts != nil { - connectUrl = fmt.Sprintf("%s&auto_subscribe=%t", connectUrl, opts.AutoSubscribe) - if opts.Publish != "" { - connectUrl += encodeQueryParam("publish", opts.Publish) - } - if len(opts.Attributes) != 0 { - data, err := json.Marshal(opts.Attributes) - if err != nil { - return nil, err - } - connectUrl += encodeQueryParam("attributes", base64.URLEncoding.EncodeToString(data)) + if opts != nil && opts.UseJoinRequestQueryParam { + clientInfo := &livekit.ClientInfo{ + Os: runtime.GOOS, + Sdk: livekit.ClientInfo_GO, + Protocol: types.CurrentProtocol, } if opts.ClientInfo != nil { - if opts.ClientInfo.DeviceModel != "" { - connectUrl += encodeQueryParam("device_model", opts.ClientInfo.DeviceModel) + clientInfo = opts.ClientInfo + } + + connectionSettings := &livekit.ConnectionSettings{ + AutoSubscribe: opts.AutoSubscribe, + } + + joinRequest := &livekit.JoinRequest{ + ClientInfo: clientInfo, + ConnectionSettings: connectionSettings, + ParticipantAttributes: opts.Attributes, + } + + if marshalled, err := proto.Marshal(joinRequest); err == nil { + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + writer.Write(marshalled) + writer.Close() + + connectUrl += fmt.Sprintf("?join_request=%s", base64.URLEncoding.EncodeToString(buf.Bytes())) + } + } else { + connectUrl += fmt.Sprintf("?protocol=%d", types.CurrentProtocol) + + sdk := "go" + if opts != nil { + connectUrl += fmt.Sprintf("&auto_subscribe=%t", opts.AutoSubscribe) + if opts.Publish != "" { + connectUrl += encodeQueryParam("publish", opts.Publish) } - if opts.ClientInfo.Os != "" { - connectUrl += encodeQueryParam("os", opts.ClientInfo.Os) + if len(opts.Attributes) != 0 { + data, err := json.Marshal(opts.Attributes) + if err != nil { + return nil, err + } + connectUrl += encodeQueryParam("attributes", base64.URLEncoding.EncodeToString(data)) } - if opts.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN { - sdk = opts.ClientInfo.Sdk.String() + if opts.ClientInfo != nil { + if opts.ClientInfo.DeviceModel != "" { + connectUrl += encodeQueryParam("device_model", opts.ClientInfo.DeviceModel) + } + if opts.ClientInfo.Os != "" { + connectUrl += encodeQueryParam("os", opts.ClientInfo.Os) + } + if opts.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN { + sdk = opts.ClientInfo.Sdk.String() + } } } + connectUrl += encodeQueryParam("sdk", sdk) } - connectUrl += encodeQueryParam("sdk", sdk) + conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader) return conn, err } diff --git a/test/singlenode_test.go b/test/singlenode_test.go index 761f8398b..fe76a8f94 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -564,6 +564,7 @@ func TestSingleNodeAttributes(t *testing.T) { "b": "1", }) }, + UseJoinRequestQueryParam: true, }) grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} grant.SetCanSubscribe(false) @@ -663,7 +664,9 @@ func TestSubscribeToCodecUnsupported(t *testing.T) { _, finish := setupSingleNodeTest("TestSubscribeToCodecUnsupported") defer finish() - c1 := createRTCClient("c1", defaultServerPort, nil) + c1 := createRTCClient("c1", defaultServerPort, &testclient.Options{ + UseJoinRequestQueryParam: true, + }) // create a client that doesn't support H264 c2 := createRTCClient("c2", defaultServerPort, &testclient.Options{ AutoSubscribe: true,