Files
livekit/pkg/service/rtcservice.go
Raja Subramanian a35a6ae751 Add participant option for data track auto-subscribe. (#4240)
* Add participant option for data track auto-subscribe.

Default disabled.

* protocol update to use data track auto subscribe setting

* deps
2026-01-14 13:22:43 +05:30

665 lines
19 KiB
Go

// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package service
import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
"go.uber.org/atomic"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/psrpc"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/routing"
"github.com/livekit/livekit-server/pkg/rtc"
"github.com/livekit/livekit-server/pkg/telemetry"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/livekit-server/pkg/utils"
)
// RTCService terminates client signaling websockets for the /rtc endpoints
// and relays messages between clients and media nodes via the router.
type RTCService struct {
	router        routing.MessageRouter      // routes signal messages to the media node handling the room
	roomAllocator RoomAllocator              // selects/creates the room's node before signaling starts
	upgrader      websocket.Upgrader         // upgrades validated HTTP requests to websockets
	config        *config.Config
	isDev         bool                       // mirrors config.Development
	limits        config.LimitConfig         // connection/validation limits applied per request
	telemetry     telemetry.TelemetryService

	mu          sync.Mutex
	connections map[*websocket.Conn]struct{} // active websockets, guarded by mu; used by DrainConnections
}
// NewRTCService constructs the signaling service backed by the given
// configuration, room allocator, message router, and telemetry sink.
func NewRTCService(
	conf *config.Config,
	ra RoomAllocator,
	router routing.MessageRouter,
	telemetry telemetry.TelemetryService,
) *RTCService {
	return &RTCService{
		router:        router,
		roomAllocator: ra,
		config:        conf,
		isDev:         conf.Development,
		limits:        conf.Limit,
		telemetry:     telemetry,
		connections:   make(map[*websocket.Conn]struct{}),
		upgrader: websocket.Upgrader{
			EnableCompression: true,
			// allow connections from any origin, since script may be hosted anywhere
			// security is enforced by access tokens
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
	}
}
// SetupRoutes registers the v0 and v1 signaling and validation endpoints
// on the provided mux.
func (s *RTCService) SetupRoutes(mux *http.ServeMux) {
	routes := map[string]http.HandlerFunc{
		"/rtc":             s.v0,
		"/rtc/validate":    s.v0Validate,
		"/rtc/v1":          s.v1,
		"/rtc/v1/validate": s.v1Validate,
	}
	for pattern, handler := range routes {
		mux.HandleFunc(pattern, handler)
	}
}
// v0Validate handles /rtc/validate: it checks whether a v0 connection attempt
// with the request's parameters would be accepted, without upgrading to a
// websocket. Responds "success" on acceptance, an error status otherwise.
func (s *RTCService) v0Validate(w http.ResponseWriter, r *http.Request) {
	lgr := utils.GetLogger(r.Context())
	if _, _, code, err := s.validateInternal(lgr, r, false, true); err != nil {
		HandleError(w, r, code, err)
		return
	}
	_, _ = w.Write([]byte("success"))
}
// v1Validate handles /rtc/v1/validate: like v0Validate, but requires the
// wrapped join_request parameter (v1 semantics).
func (s *RTCService) v1Validate(w http.ResponseWriter, r *http.Request) {
	lgr := utils.GetLogger(r.Context())
	if _, _, code, err := s.validateInternal(lgr, r, true, true); err != nil {
		HandleError(w, r, code, err)
		return
	}
	_, _ = w.Write([]byte("success"))
}
// decodeAttributes parses a base64url-encoded JSON object into a string map.
// A JSON "null" payload yields a nil map, matching json.Unmarshal semantics.
func decodeAttributes(str string) (map[string]string, error) {
	var attrs map[string]string
	raw, err := base64.URLEncoding.DecodeString(str)
	if err == nil {
		err = json.Unmarshal(raw, &attrs)
	}
	if err != nil {
		return nil, err
	}
	return attrs, nil
}
// gzipReaderPool recycles gzip.Reader instances used to decompress
// gzip-wrapped join requests, avoiding an allocation per request.
var gzipReaderPool = sync.Pool{
	New: func() any { return &gzip.Reader{} },
}
// validateInternal authenticates a /rtc connection attempt and assembles the
// routing.ParticipantInit that seeds the signal session.
//
// Two request shapes are supported:
//   - legacy (v0): individual query parameters, used when "join_request" is absent
//   - v1: a base64url-encoded, optionally gzip-compressed WrappedJoinRequest
//     proto carrying all connection settings; its presence also selects the
//     single-peer-connection signaling mode
//
// needsJoinRequest rejects requests without the v1 join_request parameter.
// strict turns attribute-decoding failures into hard errors instead of
// best-effort logging.
//
// Returns the room name, the participant init, an HTTP status code, and an error.
func (s *RTCService) validateInternal(
	lgr logger.Logger,
	r *http.Request,
	needsJoinRequest bool,
	strict bool,
) (livekit.RoomName, routing.ParticipantInit, int, error) {
	var params ValidateConnectRequestParams
	useSinglePeerConnection := false
	joinRequest := &livekit.JoinRequest{}
	wrappedJoinRequestBase64 := r.FormValue("join_request")
	if wrappedJoinRequestBase64 == "" {
		// legacy (v0) path: options arrive as individual query parameters
		if needsJoinRequest {
			return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("join_request is required")
		}
		params.publish = r.FormValue("publish")
		attributesStrParam := r.FormValue("attributes")
		if attributesStrParam != "" {
			attrs, err := decodeAttributes(attributesStrParam)
			if err != nil {
				if strict {
					return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes")
				}
				lgr.Debugw("failed to decode attributes", "error", err)
				// attrs will be empty here, so just proceed
			}
			params.attributes = attrs
		}
	} else {
		// v1 path: decode and unwrap the WrappedJoinRequest proto
		useSinglePeerConnection = true
		if wrappedProtoBytes, err := base64.URLEncoding.DecodeString(wrappedJoinRequestBase64); err != nil {
			return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot base64 decode wrapped join request")
		} else {
			wrappedJoinRequest := &livekit.WrappedJoinRequest{}
			if err := proto.Unmarshal(wrappedProtoBytes, wrappedJoinRequest); err != nil {
				return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal wrapped join request")
			}
			switch wrappedJoinRequest.Compression {
			case livekit.WrappedJoinRequest_NONE:
				if err := proto.Unmarshal(wrappedJoinRequest.JoinRequest, joinRequest); err != nil {
					return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request")
				}
			case livekit.WrappedJoinRequest_GZIP:
				// pooled gzip reader avoids a per-request allocation.
				// NOTE(review): Reset's error is discarded — a malformed gzip
				// header still surfaces through io.ReadAll below, but checking
				// it here would be cleaner.
				reader := gzipReaderPool.Get().(*gzip.Reader)
				defer gzipReaderPool.Put(reader)
				reader.Reset(bytes.NewReader(wrappedJoinRequest.JoinRequest))
				protoBytes, err := io.ReadAll(reader)
				if err != nil {
					return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot read decompressed join request")
				}
				if err := proto.Unmarshal(protoBytes, joinRequest); err != nil {
					return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request")
				}
			}
			// NOTE(review): an unknown Compression value falls through with an
			// empty joinRequest rather than an error — confirm that is intended.
			params.metadata = joinRequest.Metadata
			params.attributes = joinRequest.ParticipantAttributes
		}
	}
	// shared validation: token grants, limits, and room lookup/creation
	res, code, err := ValidateConnectRequest(
		lgr,
		r,
		s.limits,
		params,
		s.router,
		s.roomAllocator,
	)
	if err != nil {
		return res.roomName, routing.ParticipantInit{}, code, err
	}
	pi := routing.ParticipantInit{
		Identity:                livekit.ParticipantIdentity(res.grants.Identity),
		Name:                    livekit.ParticipantName(res.grants.Name),
		Grants:                  res.grants,
		Region:                  res.region,
		CreateRoom:              res.createRoomRequest,
		UseSinglePeerConnection: useSinglePeerConnection,
	}
	if wrappedJoinRequestBase64 == "" {
		// v0: fill connection settings from the individual query parameters
		pi.Reconnect = boolValue(r.FormValue("reconnect"))
		pi.Client = ParseClientInfo(r)
		pi.AutoSubscribe = true // v0 defaults auto-subscribe on unless explicitly overridden
		if autoSubscribeParam := r.FormValue("auto_subscribe"); autoSubscribeParam != "" {
			pi.AutoSubscribe = boolValue(autoSubscribeParam)
		}
		if autoSubscribeDataTrackParam := r.FormValue("auto_subscribe_data_track"); autoSubscribeDataTrackParam != "" {
			autoSubscribeDataTrack := boolValue(autoSubscribeDataTrackParam)
			pi.AutoSubscribeDataTrack = &autoSubscribeDataTrack
		}
		pi.AdaptiveStream = boolValue(r.FormValue("adaptive_stream"))
		pi.DisableICELite = boolValue(r.FormValue("disable_ice_lite"))
		reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason
		pi.ReconnectReason = livekit.ReconnectReason(reconnectReason)
		if pi.Reconnect {
			pi.ID = livekit.ParticipantID(r.FormValue("sid"))
		}
		if subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause"); subscriberAllowPauseParam != "" {
			subscriberAllowPause := boolValue(subscriberAllowPauseParam)
			pi.SubscriberAllowPause = &subscriberAllowPause
		}
	} else {
		// v1: fill connection settings from the decoded join request proto
		lgr.Debugw("processing join request", "joinRequest", logger.Proto(joinRequest))
		AugmentClientInfo(joinRequest.ClientInfo, r)
		pi.Client = joinRequest.ClientInfo
		pi.AutoSubscribe = joinRequest.GetConnectionSettings().GetAutoSubscribe()
		autoSubscribeDataTrack := joinRequest.GetConnectionSettings().GetAutoSubscribeDataTrack()
		pi.AutoSubscribeDataTrack = &autoSubscribeDataTrack
		pi.AdaptiveStream = joinRequest.GetConnectionSettings().GetAdaptiveStream()
		pi.DisableICELite = joinRequest.GetConnectionSettings().GetDisableIceLite()
		subscriberAllowPause := joinRequest.GetConnectionSettings().GetSubscriberAllowPause()
		pi.SubscriberAllowPause = &subscriberAllowPause
		pi.AddTrackRequests = joinRequest.AddTrackRequests
		pi.PublisherOffer = joinRequest.PublisherOffer
		pi.Reconnect = joinRequest.Reconnect
		pi.ReconnectReason = joinRequest.ReconnectReason
		pi.ID = livekit.ParticipantID(joinRequest.ParticipantSid)
	}
	return res.roomName, pi, code, err
}
// v0 serves the legacy /rtc signaling endpoint, which does not require a
// wrapped join request.
func (s *RTCService) v0(w http.ResponseWriter, r *http.Request) {
	const needsJoinRequest = false
	s.serve(w, r, needsJoinRequest)
}
// v1 serves the /rtc/v1 signaling endpoint, which requires a wrapped join
// request in the query parameters.
func (s *RTCService) v1(w http.ResponseWriter, r *http.Request) {
	const needsJoinRequest = true
	s.serve(w, r, needsJoinRequest)
}
// serve validates a connection attempt, upgrades it to a websocket, and then
// relays signal messages in both directions until either side disconnects:
// a goroutine pumps media-node responses out to the websocket while the main
// loop reads client requests and forwards them to the request sink.
// needsJoinRequest selects v1 semantics (join_request parameter required).
func (s *RTCService) serve(w http.ResponseWriter, r *http.Request, needsJoinRequest bool) {
	// reject non websocket requests
	if !websocket.IsWebSocketUpgrade(r) {
		w.WriteHeader(404)
		return
	}

	var (
		roomName            livekit.RoomName
		roomID              livekit.RoomID
		participantIdentity livekit.ParticipantIdentity
		pID                 livekit.ParticipantID
		loggerResolved      bool
		pi                  routing.ParticipantInit
		code                int
		err                 error
	)

	// deferred logger: identifier fields are published once known
	pLogger, loggerResolver := utils.GetLogger(r.Context()).WithDeferredValues()
	getLoggerFields := func() []any {
		return []any{
			"room", roomName,
			"roomID", roomID,
			"participant", participantIdentity,
			"pID", pID,
		}
	}
	// resolveLogger publishes the deferred fields; unless forced, it waits
	// until all four identifiers are populated
	resolveLogger := func(force bool) {
		if loggerResolved {
			return
		}
		if force || (roomName != "" && roomID != "" && participantIdentity != "" && pID != "") {
			loggerResolved = true
			loggerResolver.Resolve(getLoggerFields()...)
		}
	}
	// resetLogger clears identifiers (used on room move) so they can be
	// re-resolved against the new room
	resetLogger := func() {
		loggerResolver.Reset()
		roomName = ""
		roomID = ""
		participantIdentity = ""
		pID = ""
		loggerResolved = false
	}

	// non-strict validation: attribute decode failures are logged, not fatal
	roomName, pi, code, err = s.validateInternal(pLogger, r, needsJoinRequest, false)
	if err != nil {
		HandleError(w, r, code, err)
		return
	}
	participantIdentity = pi.Identity
	if pi.ID != "" {
		pID = pi.ID
	}

	// give it a few attempts to start session
	var cr connectionResult
	var initialResponse *livekit.SignalResponse
	for attempt := 0; attempt < s.config.SignalRelay.ConnectAttempts; attempt++ {
		// timeout grows linearly with each attempt
		connectionTimeout := 3 * time.Second * time.Duration(attempt+1)
		ctx := utils.ContextWithAttempt(r.Context(), attempt)
		cr, initialResponse, err = s.startConnection(ctx, roomName, pi, connectionTimeout)
		if err == nil || errors.Is(err, context.Canceled) {
			break
		}
	}
	if err != nil {
		prometheus.IncrementParticipantJoinFail(1)
		status := http.StatusInternalServerError
		var psrpcErr psrpc.Error
		if errors.As(err, &psrpcErr) {
			status = psrpcErr.ToHttp()
		}
		HandleError(w, r, status, err, getLoggerFields()...)
		return
	}
	prometheus.IncrementParticipantJoin(1)
	pLogger = pLogger.WithValues("connID", cr.ConnectionID)

	// fresh joins learn their room/participant SIDs from the initial join response
	if !pi.Reconnect && initialResponse.GetJoin() != nil {
		joinRoomID := livekit.RoomID(initialResponse.GetJoin().GetRoom().GetSid())
		if joinRoomID != "" {
			roomID = joinRoomID
		}
		pi.ID = livekit.ParticipantID(initialResponse.GetJoin().GetParticipant().GetSid())
		pID = pi.ID
		resolveLogger(false)
	}

	// byte-level signal traffic accounting for telemetry
	signalStats := telemetry.NewBytesSignalStats(r.Context(), s.telemetry)
	if join := initialResponse.GetJoin(); join != nil {
		signalStats.ResolveRoom(join.GetRoom())
		signalStats.ResolveParticipant(join.GetParticipant())
	}
	if pi.Reconnect && pi.ID != "" {
		signalStats.ResolveParticipant(&livekit.ParticipantInfo{
			Sid:      string(pi.ID),
			Identity: string(pi.Identity),
		})
	}

	closedByClient := atomic.NewBool(false)
	done := make(chan struct{})
	// function exits when websocket terminates, it'll close the event reading off of request sink and response source as well
	defer func() {
		pLogger.Debugw("finishing WS connection", "closedByClient", closedByClient.Load())
		cr.ResponseSource.Close()
		cr.RequestSink.Close()
		close(done)
		signalStats.Stop()
	}()

	// upgrade only once the basics are good to go
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		HandleError(w, r, http.StatusInternalServerError, err, getLoggerFields()...)
		return
	}

	// track the connection so DrainConnections can close it on shutdown
	s.mu.Lock()
	s.connections[conn] = struct{}{}
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.connections, conn)
		s.mu.Unlock()
	}()

	// websocket established
	sigConn := NewWSSignalConnection(conn)
	pLogger.Debugw("sending initial response", "response", logger.Proto(initialResponse))
	count, err := sigConn.WriteResponse(initialResponse)
	if err != nil {
		resolveLogger(true)
		pLogger.Warnw("could not write initial response", err)
		return
	}
	signalStats.AddBytes(uint64(count), true)

	pLogger.Debugw(
		"new client WS connected",
		"reconnect", pi.Reconnect,
		"reconnectReason", pi.ReconnectReason,
		"adaptiveStream", pi.AdaptiveStream,
		"selectedNodeID", cr.NodeID,
		"nodeSelectionReason", cr.NodeSelectionReason,
	)

	// handle responses: pump media-node messages out to the websocket
	go func() {
		defer func() {
			// when the source is terminated, this means Participant.Close had been called and RTC connection is done
			// we would terminate the signal connection as well
			closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
			_ = conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second))
			_ = conn.Close()
		}()
		defer func() {
			// a panic in the response pump is fatal for the process
			if r := rtc.Recover(pLogger); r != nil {
				os.Exit(1)
			}
		}()
		for {
			select {
			case <-done:
				return
			case msg := <-cr.ResponseSource.ReadChan():
				if msg == nil {
					resolveLogger(true)
					pLogger.Debugw("nothing to read from response source")
					return
				}
				res, ok := msg.(*livekit.SignalResponse)
				if !ok {
					pLogger.Errorw(
						"unexpected message type", nil,
						"type", fmt.Sprintf("%T", msg),
					)
					continue
				}
				// message-specific logging and stats bookkeeping before forwarding
				switch m := res.Message.(type) {
				case *livekit.SignalResponse_Offer:
					pLogger.Debugw("sending offer", "offer", m)
				case *livekit.SignalResponse_Answer:
					pLogger.Debugw("sending answer", "answer", m)
				case *livekit.SignalResponse_Join:
					pLogger.Debugw("sending join", "join", m)
					signalStats.ResolveRoom(m.Join.GetRoom())
					signalStats.ResolveParticipant(m.Join.GetParticipant())
				case *livekit.SignalResponse_RoomUpdate:
					updateRoomID := livekit.RoomID(m.RoomUpdate.GetRoom().GetSid())
					if updateRoomID != "" {
						roomID = updateRoomID
						resolveLogger(false)
					}
					pLogger.Debugw("sending room update", "roomUpdate", m)
					signalStats.ResolveRoom(m.RoomUpdate.GetRoom())
				case *livekit.SignalResponse_Update:
					pLogger.Debugw("sending participant update", "participantUpdate", m)
				case *livekit.SignalResponse_RoomMoved:
					// participant moved to another room: restart logger/stats
					// resolution against the new room's identifiers
					resetLogger()
					signalStats.Reset()
					roomName = livekit.RoomName(m.RoomMoved.GetRoom().GetName())
					moveRoomID := livekit.RoomID(m.RoomMoved.GetRoom().GetSid())
					if moveRoomID != "" {
						roomID = moveRoomID
					}
					participantIdentity = livekit.ParticipantIdentity(m.RoomMoved.GetParticipant().GetIdentity())
					pID = livekit.ParticipantID(m.RoomMoved.GetParticipant().GetSid())
					resolveLogger(false)
					signalStats.ResolveRoom(m.RoomMoved.GetRoom())
					signalStats.ResolveParticipant(m.RoomMoved.GetParticipant())
					pLogger.Debugw("sending room moved", "roomMoved", m)
				default:
					pLogger.Debugw("sending signal response", "response", m)
				}
				if count, err := sigConn.WriteResponse(res); err != nil {
					pLogger.Warnw("error writing to websocket", err)
					return
				} else {
					signalStats.AddBytes(uint64(count), true)
				}
			}
		}
	}()

	// handle incoming requests from websocket
	for {
		req, count, err := sigConn.ReadRequest()
		if err != nil {
			if IsWebSocketCloseError(err) {
				closedByClient.Store(true)
			} else {
				pLogger.Errorw("error reading from websocket", err)
			}
			return
		}
		signalStats.AddBytes(uint64(count), false)

		// answer pings directly from this node; note they are still forwarded
		// to the request sink below
		switch m := req.Message.(type) {
		case *livekit.SignalRequest_Ping:
			count, perr := sigConn.WriteResponse(&livekit.SignalResponse{
				Message: &livekit.SignalResponse_Pong{
					//
					// Although this field is int64, some clients (like JS) cause overflow if nanosecond granularity is used.
					// So. use UnixMillis().
					//
					Pong: time.Now().UnixMilli(),
				},
			})
			if perr == nil {
				signalStats.AddBytes(uint64(count), true)
			}
		case *livekit.SignalRequest_PingReq:
			count, perr := sigConn.WriteResponse(&livekit.SignalResponse{
				Message: &livekit.SignalResponse_PongResp{
					PongResp: &livekit.Pong{
						LastPingTimestamp: m.PingReq.Timestamp,
						Timestamp:         time.Now().UnixMilli(),
					},
				},
			})
			if perr == nil {
				signalStats.AddBytes(uint64(count), true)
			}
		}

		switch m := req.Message.(type) {
		case *livekit.SignalRequest_Offer:
			pLogger.Debugw("received offer", "offer", m)
		case *livekit.SignalRequest_Answer:
			pLogger.Debugw("received answer", "answer", m)
		default:
			pLogger.Debugw("received signal request", "request", m)
		}

		// forward every request to the media node
		if err := cr.RequestSink.WriteMessage(req); err != nil {
			pLogger.Warnw("error writing to request sink", err)
			return
		}
	}
}
// DrainConnections closes every active signaling websocket, spacing closes
// `interval` apart so clients do not reconnect in a thundering herd. Drain
// start is jittered by up to one interval. A non-positive interval closes
// all connections immediately (previously this panicked in rand.Int63n).
func (s *RTCService) DrainConnections(interval time.Duration) {
	s.mu.Lock()
	conns := maps.Clone(s.connections)
	s.mu.Unlock()

	if len(conns) == 0 {
		return
	}

	// rand.Int63n panics for n <= 0; treat a non-positive interval as
	// "close everything now"
	if interval <= 0 {
		for c := range conns {
			_ = c.Close()
		}
		return
	}

	// jitter drain start
	time.Sleep(time.Duration(rand.Int63n(int64(interval))))

	t := time.NewTicker(interval)
	defer t.Stop()
	for c := range conns {
		_ = c.Close()
		<-t.C
	}
}
// connectionResult bundles the routing handles returned by starting a
// participant signal session. Room is presumably populated by callers that
// need the resolved room — it is not set within this file; verify usage.
type connectionResult struct {
	routing.StartParticipantSignalResults
	Room *livekit.Room
}
// startConnection selects a node for the room, starts the participant signal
// session through the router, and waits (up to timeout) for the media node's
// first response before the caller upgrades to a websocket. On failure the
// request sink and response source are closed to avoid leaks.
func (s *RTCService) startConnection(
	ctx context.Context,
	roomName livekit.RoomName,
	pi routing.ParticipantInit,
	timeout time.Duration,
) (connectionResult, *livekit.SignalResponse, error) {
	var cr connectionResult

	if err := s.roomAllocator.SelectRoomNode(ctx, roomName, ""); err != nil {
		return cr, nil, err
	}

	// this needs to be started first *before* using router functions on this node
	signalResults, err := s.router.StartParticipantSignal(ctx, roomName, pi)
	if err != nil {
		return cr, nil, err
	}
	cr.StartParticipantSignalResults = signalResults

	// wait for the first message before upgrading to websocket. If no one is
	// responding to our connection attempt, we should terminate the connection
	// instead of waiting forever on the WebSocket
	initialResponse, err := readInitialResponse(cr.ResponseSource, timeout)
	if err != nil {
		// close the connection to avoid leaking
		cr.RequestSink.Close()
		cr.ResponseSource.Close()
		return cr, nil, err
	}

	return cr, initialResponse, nil
}
// readInitialResponse blocks until the media node produces its first
// SignalResponse, or the timeout elapses. A nil message means the source was
// closed; any non-SignalResponse payload is an error. (The original wrapped
// this select in a for loop, but every branch returns, so one select suffices.)
func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) {
	responseTimer := time.NewTimer(timeout)
	defer responseTimer.Stop()

	select {
	case <-responseTimer.C:
		return nil, errors.New("timed out while waiting for signal response")
	case msg := <-source.ReadChan():
		if msg == nil {
			return nil, errors.New("connection closed by media")
		}
		if res, ok := msg.(*livekit.SignalResponse); ok {
			return res, nil
		}
		return nil, fmt.Errorf("unexpected message type: %T", msg)
	}
}