From e511464d3d4a98e685facd4228cd8d3b18c4a5d4 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Tue, 2 Jul 2024 13:11:08 -0700 Subject: [PATCH] add handler interface to receive agent worker updates (#2830) * add handler interface to receive agent worker updates * cleanup --- pkg/agent/worker.go | 203 ++++++----- pkg/rtc/types/interfaces.go | 1 + .../types/typesfakes/fake_websocket_client.go | 65 ++++ pkg/service/agentservice.go | 324 ++++++++---------- pkg/service/rtcservice.go | 14 +- pkg/service/wsprotocol.go | 21 ++ 6 files changed, 343 insertions(+), 285 deletions(-) diff --git a/pkg/agent/worker.go b/pkg/agent/worker.go index daed90b85..a7d2a4bad 100644 --- a/pkg/agent/worker.go +++ b/pkg/agent/worker.go @@ -21,13 +21,9 @@ import ( "sync/atomic" "time" - "github.com/gorilla/websocket" - pagent "github.com/livekit/protocol/agent" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" - putil "github.com/livekit/protocol/utils" "github.com/livekit/protocol/utils/guid" ) @@ -42,13 +38,40 @@ const ( ) var ( - ErrWorkerClosed = errors.New("worker closed") - ErrWorkerNotAvailable = errors.New("worker not available") - ErrAvailabilityTimeout = errors.New("agent worker availability timeout") + ErrWorkerClosed = errors.New("worker closed") + ErrWorkerNotAvailable = errors.New("worker not available") + ErrAvailabilityTimeout = errors.New("agent worker availability timeout") + ErrDuplicateJobAssignment = errors.New("duplicate job assignment") ) -type sigConn interface { +type SignalConn interface { WriteServerMessage(msg *livekit.ServerMessage) (int, error) + ReadWorkerMessage() (*livekit.WorkerMessage, int, error) + Close() error +} + +type WorkerHandler interface { + HandleWorkerRegister(w *Worker) + HandleWorkerDeregister(w *Worker) + HandleWorkerStatus(w *Worker, status *livekit.UpdateWorkerStatus) + HandleWorkerJobStatus(w *Worker, status *livekit.UpdateJobStatus) + HandleWorkerSimulateJob(w *Worker, job 
*livekit.Job) + HandleWorkerMigrateJob(w *Worker, request *livekit.MigrateJobRequest) +} + +var _ WorkerHandler = UnimplementedWorkerHandler{} + +type UnimplementedWorkerHandler struct{} + +func (UnimplementedWorkerHandler) HandleWorkerRegister(*Worker) {} +func (UnimplementedWorkerHandler) HandleWorkerDeregister(*Worker) {} +func (UnimplementedWorkerHandler) HandleWorkerStatus(*Worker, *livekit.UpdateWorkerStatus) {} +func (UnimplementedWorkerHandler) HandleWorkerJobStatus(*Worker, *livekit.UpdateJobStatus) {} +func (UnimplementedWorkerHandler) HandleWorkerSimulateJob(*Worker, *livekit.Job) {} +func (UnimplementedWorkerHandler) HandleWorkerMigrateJob(*Worker, *livekit.MigrateJobRequest) {} + +func JobStatusIsEnded(s livekit.JobStatus) bool { + return s == livekit.JobStatus_JS_SUCCESS || s == livekit.JobStatus_JS_FAILED } type Worker struct { @@ -69,18 +92,17 @@ type Worker struct { status livekit.WorkerStatus runningJobs map[string]*livekit.Job - onWorkerRegistered func(w *Worker) + handler WorkerHandler - conn *websocket.Conn - sigConn sigConn - closed chan struct{} + conn SignalConn + closed chan struct{} availability map[string]chan *livekit.AvailabilityResponse ctx context.Context cancel context.CancelFunc - Logger logger.Logger + logger logger.Logger } func NewWorker( @@ -88,14 +110,15 @@ func NewWorker( apiKey string, apiSecret string, serverInfo *livekit.ServerInfo, - conn *websocket.Conn, - sigConn sigConn, + conn SignalConn, logger logger.Logger, + handler WorkerHandler, ) *Worker { ctx, cancel := context.WithCancel(context.Background()) + id := guid.New(guid.AgentWorkerPrefix) w := &Worker{ - id: putil.NewGuid(utils.AgentWorkerPrefix), + id: id, protocolVersion: protocolVersion, apiKey: apiKey, apiSecret: apiSecret, @@ -104,26 +127,25 @@ func NewWorker( runningJobs: make(map[string]*livekit.Job), availability: make(map[string]chan *livekit.AvailabilityResponse), conn: conn, - sigConn: sigConn, ctx: ctx, cancel: cancel, - Logger: logger, + logger: 
logger.WithValues("workerID", id), + handler: handler, } - go func() { - <-time.After(registerTimeout) + time.AfterFunc(registerTimeout, func() { if !w.registered.Load() && !w.IsClosed() { - w.Logger.Warnw("worker did not register in time", nil, "id", w.id) + w.logger.Warnw("worker did not register in time", nil, "id", w.id) w.Close() } - }() + }) return w } func (w *Worker) sendRequest(req *livekit.ServerMessage) { - if _, err := w.sigConn.WriteServerMessage(req); err != nil { - w.Logger.Errorw("error writing to websocket", err) + if _, err := w.conn.WriteServerMessage(req); err != nil { + w.logger.Errorw("error writing to websocket", err) } } @@ -132,15 +154,10 @@ func (w *Worker) ID() string { } func (w *Worker) JobType() livekit.JobType { - w.mu.Lock() - defer w.mu.Unlock() - return w.jobType } func (w *Worker) Namespace() string { - w.mu.Lock() - defer w.mu.Unlock() return w.namespace } @@ -156,20 +173,14 @@ func (w *Worker) Load() float32 { return w.load } -func (w *Worker) OnWorkerRegistered(f func(w *Worker)) { - w.mu.Lock() - defer w.mu.Unlock() - w.onWorkerRegistered = f -} - -func (w *Worker) Registered() bool { - return w.registered.Load() +func (w *Worker) Logger() logger.Logger { + return w.logger } func (w *Worker) RunningJobs() map[string]*livekit.Job { - jobs := make(map[string]*livekit.Job, len(w.runningJobs)) w.mu.Lock() defer w.mu.Unlock() + jobs := make(map[string]*livekit.Job, len(w.runningJobs)) for k, v := range w.runningJobs { jobs[k] = v } @@ -180,9 +191,20 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { availCh := make(chan *livekit.AvailabilityResponse, 1) w.mu.Lock() + if _, ok := w.availability[job.Id]; ok { + w.mu.Unlock() + return ErrDuplicateJobAssignment + } + w.availability[job.Id] = availCh w.mu.Unlock() + defer func() { + w.mu.Lock() + delete(w.availability, job.Id) + w.mu.Unlock() + }() + if job.State == nil { job.State = &livekit.JobState{} } @@ -191,6 +213,9 @@ func (w *Worker) AssignJob(ctx 
context.Context, job *livekit.Job) error { Availability: &livekit.AvailabilityRequest{Job: job}, }}) + timeout := time.NewTimer(assignJobTimeout) + defer timeout.Stop() + // See handleAvailability for the response select { case res := <-availCh: @@ -200,7 +225,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { token, err := pagent.BuildAgentToken(w.apiKey, w.apiSecret, job.Room.Name, res.ParticipantIdentity, res.ParticipantName, res.ParticipantMetadata, w.permissions) if err != nil { - w.Logger.Errorw("failed to build agent token", err) + w.logger.Errorw("failed to build agent token", err) return err } @@ -216,7 +241,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { // TODO sweep jobs that are never started. We can't do this until all SDKs actually update the the JOB state return nil - case <-time.After(assignJobTimeout): + case <-timeout.C: return ErrAvailabilityTimeout case <-w.ctx.Done(): return ErrWorkerClosed @@ -225,17 +250,8 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { } } -func (w *Worker) UpdateStatus(status *livekit.UpdateWorkerStatus) { - w.mu.Lock() - if status.Status != nil { - w.status = status.GetStatus() - } - w.load = status.GetLoad() - w.mu.Unlock() -} - func (w *Worker) UpdateMetadata(metadata string) { - w.Logger.Debugw("worker metadata updated", nil, "metadata", metadata) + w.logger.Debugw("worker metadata updated", nil, "metadata", metadata) } func (w *Worker) IsClosed() bool { @@ -254,41 +270,52 @@ func (w *Worker) Close() { return } - w.Logger.Infow("closing worker") + w.logger.Infow("closing worker") close(w.closed) w.cancel() _ = w.conn.Close() w.mu.Unlock() + + if w.registered.Load() { + w.handler.HandleWorkerDeregister(w) + } } func (w *Worker) HandleMessage(req *livekit.WorkerMessage) { switch m := req.Message.(type) { case *livekit.WorkerMessage_Register: - go w.handleRegister(m.Register) + w.handleRegister(m.Register) case 
*livekit.WorkerMessage_Availability: - go w.handleAvailability(m.Availability) + w.handleAvailability(m.Availability) case *livekit.WorkerMessage_UpdateJob: - go w.handleJobUpdate(m.UpdateJob) + w.handleJobUpdate(m.UpdateJob) case *livekit.WorkerMessage_SimulateJob: - go w.handleSimulateJob(m.SimulateJob) + w.handleSimulateJob(m.SimulateJob) case *livekit.WorkerMessage_Ping: - go w.handleWorkerPing(m.Ping) + w.handleWorkerPing(m.Ping) case *livekit.WorkerMessage_UpdateWorker: - go w.handleWorkerStatus(m.UpdateWorker) + w.handleWorkerStatus(m.UpdateWorker) case *livekit.WorkerMessage_MigrateJob: - go w.handleMigrateJob(m.MigrateJob) + w.handleMigrateJob(m.MigrateJob) } } func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) { - if w.registered.Load() { - w.Logger.Warnw("worker already registered", nil, "id", w.id) + w.mu.Lock() + var err error + if w.IsClosed() { + err = errors.New("worker closed") + } + if w.registered.Swap(true) { + err = errors.New("worker already registered") + } + if err != nil { + w.mu.Unlock() + w.logger.Warnw("unable to register worker", err, "id", w.id) return } - w.mu.Lock() - onWorkerRegistered := w.onWorkerRegistered w.version = req.Version w.name = req.Name w.namespace = req.GetNamespace() @@ -307,10 +334,9 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) { } w.status = livekit.WorkerStatus_WS_AVAILABLE - w.registered.Store(true) w.mu.Unlock() - w.Logger.Debugw("worker registered", "request", req) + w.logger.Debugw("worker registered", "request", logger.Proto(req)) w.sendRequest(&livekit.ServerMessage{ Message: &livekit.ServerMessage_Register{ @@ -321,9 +347,7 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) { }, }) - if onWorkerRegistered != nil { - onWorkerRegistered(w) - } + w.handler.HandleWorkerRegister(w) } func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) { @@ -332,7 +356,7 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) { 
availCh, ok := w.availability[res.JobId] if !ok { - w.Logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId) + w.logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId) return } @@ -342,32 +366,35 @@ func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) { w.mu.Lock() - defer w.mu.Unlock() job, ok := w.runningJobs[update.JobId] if !ok { - w.Logger.Infow("received job update for unknown job", "jobId", update.JobId) + w.mu.Unlock() + w.logger.Infow("received job update for unknown job", "jobId", update.JobId) return } now := time.Now() job.State.UpdatedAt = now.UnixNano() - if job.State.Status == livekit.JobStatus_JS_PENDING && update.Status >= livekit.JobStatus_JS_RUNNING { + if job.State.Status == livekit.JobStatus_JS_PENDING && update.Status >= livekit.JobStatus_JS_RUNNING { job.State.StartedAt = now.UnixNano() } - if job.State.Status < livekit.JobStatus_JS_SUCCESS && update.Status >= livekit.JobStatus_JS_SUCCESS { + if job.State.Status < livekit.JobStatus_JS_SUCCESS && JobStatusIsEnded(update.Status) { job.State.EndedAt = now.UnixNano() } job.State.Status = update.Status job.State.Error = update.Error - // TODO do not delete, leafve inside the JobDefinition - if job.State.Status >= livekit.JobStatus_JS_SUCCESS { + // TODO do not delete, leave inside the JobDefinition + if JobStatusIsEnded(job.State.Status) { delete(w.runningJobs, job.Id) } + w.mu.Unlock() + + w.handler.HandleWorkerJobStatus(w, update) } func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) { @@ -377,19 +403,21 @@ func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) { } job := &livekit.Job{ - Id: guid.New(utils.AgentJobPrefix), + Id: guid.New(guid.AgentJobPrefix), Type: jobType, Room: simulate.Room, Participant: simulate.Participant, Namespace: w.Namespace(), } - ctx := context.Background() - err := w.AssignJob(ctx, job) - if err != nil { -
w.Logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id) - } - + go func() { + err := w.AssignJob(w.ctx, job) + if err != nil { + w.logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id) + } else { + w.handler.HandleWorkerSimulateJob(w, job) + } + }() } func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) { @@ -402,11 +430,20 @@ func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) { } func (w *Worker) handleWorkerStatus(update *livekit.UpdateWorkerStatus) { - w.Logger.Debugw("worker status update", "status", update.Status, "load", update.Load) - w.UpdateStatus(update) + w.logger.Debugw("worker status update", "update", logger.Proto(update)) + + w.mu.Lock() + if update.Status != nil { + w.status = update.GetStatus() + } + w.load = update.GetLoad() + w.mu.Unlock() + + w.handler.HandleWorkerStatus(w, update) } func (w *Worker) handleMigrateJob(migrate *livekit.MigrateJobRequest) { // TODO(theomonnom): On OSS this is not implemented // We could maybe just move a specific job to another worker + w.handler.HandleWorkerMigrateJob(w, migrate) } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index e548900e8..e877677c7 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -39,6 +39,7 @@ type WebsocketClient interface { ReadMessage() (messageType int, p []byte, err error) WriteMessage(messageType int, data []byte) error WriteControl(messageType int, data []byte, deadline time.Time) error + Close() error } type AddSubscriberParams struct { diff --git a/pkg/rtc/types/typesfakes/fake_websocket_client.go b/pkg/rtc/types/typesfakes/fake_websocket_client.go index 8cb00b9d9..0a9c14b7f 100644 --- a/pkg/rtc/types/typesfakes/fake_websocket_client.go +++ b/pkg/rtc/types/typesfakes/fake_websocket_client.go @@ -9,6 +9,16 @@ import ( ) type FakeWebsocketClient struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + 
result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } ReadMessageStub func() (int, []byte, error) readMessageMutex sync.RWMutex readMessageArgsForCall []struct { @@ -52,6 +62,59 @@ type FakeWebsocketClient struct { invocationsMutex sync.RWMutex } +func (fake *FakeWebsocketClient) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebsocketClient) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeWebsocketClient) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeWebsocketClient) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeWebsocketClient) ReadMessage() (int, []byte, error) { fake.readMessageMutex.Lock() ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)] @@ -249,6 +312,8 @@ func (fake *FakeWebsocketClient) WriteMessageReturnsOnCall(i int, result1 error) func (fake *FakeWebsocketClient) Invocations() map[string][][]interface{} { 
fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() fake.readMessageMutex.RLock() defer fake.readMessageMutex.RUnlock() fake.writeControlMutex.RLock() diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index 82f6b3af7..dd3cbd6dc 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -17,11 +17,11 @@ package service import ( "context" "errors" - "io" "math/rand" "net/http" + "slices" + "sort" "strconv" - "strings" "sync" "time" @@ -41,13 +41,48 @@ import ( "github.com/livekit/psrpc" ) +type AgentSocketUpgrader struct { + websocket.Upgrader +} + +func (u AgentSocketUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, agent.WorkerProtocolVersion, bool) { + // reject non websocket requests + if !websocket.IsWebSocketUpgrade(r) { + w.WriteHeader(404) + return nil, 0, false + } + + // require a claim + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil || !claims.Video.Agent { + handleError(w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return nil, 0, false + } + + // upgrade + conn, err := u.Upgrader.Upgrade(w, r, responseHeader) + if err != nil { + handleError(w, r, http.StatusInternalServerError, err) + return nil, 0, false + } + + var protocol agent.WorkerProtocolVersion = agent.CurrentProtocol + if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil { + protocol = agent.WorkerProtocolVersion(pv) + } + + return conn, protocol, true +} + type AgentService struct { - upgrader websocket.Upgrader + upgrader AgentSocketUpgrader *AgentHandler } type AgentHandler struct { + agent.UnimplementedWorkerHandler + agentServer rpc.AgentInternalServer mu sync.Mutex logger logger.Logger @@ -56,18 +91,19 @@ type AgentHandler struct { workers map[string]*agent.Worker keyProvider auth.KeyProvider + namespaceWorkers map[workerKey][]*agent.Worker + roomKeyCount int + 
publisherKeyCount int // TODO remove once deprecated CheckEnabled is removed - namespaces map[string]*namespaceInfo - publisherEnabled bool - roomEnabled bool + namespaces []string roomTopic string publisherTopic string } -type namespaceInfo struct { - numPublishers int32 - numRooms int32 +type workerKey struct { + namespace string + jobType livekit.JobType } func NewAgentService(conf *config.Config, @@ -75,9 +111,7 @@ func NewAgentService(conf *config.Config, bus psrpc.MessageBus, keyProvider auth.KeyProvider, ) (*AgentService, error) { - s := &AgentService{ - upgrader: websocket.Upgrader{}, - } + s := &AgentService{} // allow connections from any origin, since script may be hosted anywhere // security is enforced by access tokens @@ -110,27 +144,9 @@ func NewAgentService(conf *config.Config, } func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) { - // reject non websocket requests - if !websocket.IsWebSocketUpgrade(r) { - writer.WriteHeader(404) - return + if conn, protocol, ok := s.upgrader.Upgrade(writer, r, nil); ok { + s.HandleConnection(r.Context(), NewWSSignalConnection(conn), protocol) } - - // require a claim - claims := GetGrants(r.Context()) - if claims == nil || claims.Video == nil || !claims.Video.Agent { - handleError(writer, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) - return - } - - // upgrade - conn, err := s.upgrader.Upgrade(writer, r, nil) - if err != nil { - handleError(writer, r, http.StatusInternalServerError, err) - return - } - - s.HandleConnection(r, conn, nil) } func NewAgentHandler( @@ -142,193 +158,136 @@ func NewAgentHandler( publisherTopic string, ) *AgentHandler { return &AgentHandler{ - agentServer: agentServer, - logger: logger, - workers: make(map[string]*agent.Worker), - namespaces: make(map[string]*namespaceInfo), - serverInfo: serverInfo, - keyProvider: keyProvider, - roomTopic: roomTopic, - publisherTopic: publisherTopic, + agentServer: agentServer, + logger: logger, + workers: 
make(map[string]*agent.Worker), + namespaceWorkers: make(map[workerKey][]*agent.Worker), + serverInfo: serverInfo, + keyProvider: keyProvider, + roomTopic: roomTopic, + publisherTopic: publisherTopic, } } -func (h *AgentHandler) HandleConnection(r *http.Request, conn *websocket.Conn, onIdle func()) { - var protocol agent.WorkerProtocolVersion - if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil { - protocol = agent.WorkerProtocolVersion(pv) - } - - sigConn := NewWSSignalConnection(conn) - - apiKey := GetAPIKey(r.Context()) +func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalConn, protocol agent.WorkerProtocolVersion) { + apiKey := GetAPIKey(ctx) apiSecret := h.keyProvider.GetSecret(apiKey) - worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, sigConn, h.logger) - worker.OnWorkerRegistered(h.handleWorkerRegister) + worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, h.logger, h) h.mu.Lock() h.workers[worker.ID()] = worker h.mu.Unlock() - defer func() { - worker.Close() - - h.mu.Lock() - delete(h.workers, worker.ID()) - numWorkers := len(h.workers) - h.mu.Unlock() - - if worker.Registered() { - h.handleWorkerDeregister(worker) - } - - if numWorkers == 0 && onIdle != nil { - onIdle() - } - }() - for { - req, _, err := sigConn.ReadWorkerMessage() + req, _, err := conn.ReadWorkerMessage() if err != nil { - // normal/expected closure - if err == io.EOF || - strings.HasSuffix(err.Error(), "use of closed network connection") || - strings.HasSuffix(err.Error(), "connection reset by peer") || - websocket.IsCloseError( - err, - websocket.CloseAbnormalClosure, - websocket.CloseGoingAway, - websocket.CloseNormalClosure, - websocket.CloseNoStatusReceived, - ) { - worker.Logger.Infow("worker closed WS connection", "wsError", err) + if IsWebSocketCloseError(err) { + worker.Logger().Infow("worker closed WS connection", "wsError", err) } else { - worker.Logger.Errorw("error reading from websocket", 
err) + worker.Logger().Errorw("error reading from websocket", err) } - return + break } worker.HandleMessage(req) } -} -func (h *AgentHandler) handleWorkerRegister(w *agent.Worker) { h.mu.Lock() - - info, ok := h.namespaces[w.Namespace()] - numPublishers := int32(0) - numRooms := int32(0) - if ok { - numPublishers = info.numPublishers - numRooms = info.numRooms - } - - shouldNotify := false - var err error - if w.JobType() == livekit.JobType_JT_PUBLISHER { - numPublishers++ - if numPublishers == 1 { - shouldNotify = true - err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.publisherTopic) - } - - } else if w.JobType() == livekit.JobType_JT_ROOM { - numRooms++ - if numRooms == 1 { - shouldNotify = true - err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.roomTopic) - } - } - - if err != nil { - w.Logger.Errorw("failed to register job request topic", err) - h.mu.Unlock() - w.Close() // Close the worker - return - } - - h.namespaces[w.Namespace()] = &namespaceInfo{ - numPublishers: numPublishers, - numRooms: numRooms, - } - - h.roomEnabled = h.roomAvailableLocked() - h.publisherEnabled = h.publisherAvailableLocked() + delete(h.workers, worker.ID()) h.mu.Unlock() - if shouldNotify { - h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType()) - err = h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{}) + worker.Close() +} + +func (h *AgentHandler) HandleWorkerRegister(w *agent.Worker) { + h.mu.Lock() + + key := workerKey{w.Namespace(), w.JobType()} + + workers := h.namespaceWorkers[key] + created := len(workers) == 0 + + if created { + topic := h.roomTopic + if w.JobType() == livekit.JobType_JT_PUBLISHER { + topic = h.publisherTopic + } + err := h.agentServer.RegisterJobRequestTopic(w.Namespace(), topic) if err != nil { - w.Logger.Errorw("failed to publish worker registered", err) + h.mu.Unlock() + + w.Logger().Errorw("failed to register job request topic", 
err) + w.Close() + return + } + + if w.JobType() == livekit.JobType_JT_ROOM { + h.roomKeyCount++ + } else { + h.publisherKeyCount++ + } + + h.namespaces = append(h.namespaces, w.Namespace()) + sort.Strings(h.namespaces) + } + + h.namespaceWorkers[key] = append(workers, w) + h.mu.Unlock() + + if created { + h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType()) + err := h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{}) + if err != nil { + w.Logger().Errorw("failed to publish worker registered", err) } } } -func (h *AgentHandler) handleWorkerDeregister(worker *agent.Worker) { +func (h *AgentHandler) HandleWorkerDeregister(w *agent.Worker) { h.mu.Lock() defer h.mu.Unlock() - info, ok := h.namespaces[worker.Namespace()] + key := workerKey{w.Namespace(), w.JobType()} + + workers, ok := h.namespaceWorkers[key] if !ok { return } - - if worker.JobType() == livekit.JobType_JT_PUBLISHER { - info.numPublishers-- - if info.numPublishers == 0 { - h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.publisherTopic) - } - } else if worker.JobType() == livekit.JobType_JT_ROOM { - info.numRooms-- - if info.numRooms == 0 { - h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.roomTopic) - } + index := slices.Index(workers, w) + if index == -1 { + return } - if info.numPublishers == 0 && info.numRooms == 0 { + if len(workers) > 1 { + h.namespaceWorkers[key] = slices.Delete(workers, index, index+1) + } else { h.logger.Debugw("last worker deregistered") - delete(h.namespaces, worker.Namespace()) - } + delete(h.namespaceWorkers, key) - h.roomEnabled = h.roomAvailableLocked() - h.publisherEnabled = h.publisherAvailableLocked() -} - -func (h *AgentHandler) roomAvailableLocked() bool { - for _, w := range h.workers { if w.JobType() == livekit.JobType_JT_ROOM { - return true + h.roomKeyCount-- + h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.roomTopic) + } else { 
+ h.publisherKeyCount-- + h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.publisherTopic) + } + + if i := slices.Index(h.namespaces, w.Namespace()); i != -1 { + h.namespaces = slices.Delete(h.namespaces, i, i+1) } } - return false - -} - -func (h *AgentHandler) publisherAvailableLocked() bool { - for _, w := range h.workers { - if w.JobType() == livekit.JobType_JT_PUBLISHER { - return true - } - } - - return false } func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) { - attempted := make(map[string]bool) + key := workerKey{job.Namespace, job.Type} + attempted := make(map[*agent.Worker]struct{}) for { h.mu.Lock() var selected *agent.Worker var maxLoad float32 - for _, w := range h.workers { - if w.Namespace() != job.Namespace || w.JobType() != job.Type { - continue - } - - _, ok := attempted[w.ID()] - if ok { + for _, w := range h.namespaceWorkers[key] { + if _, ok := attempted[w]; ok { continue } @@ -341,7 +300,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty selected = w } } - } h.mu.Unlock() @@ -349,7 +307,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available") } - attempted[selected.ID()] = true + attempted[selected] = struct{}{} values := []interface{}{ "jobID", job.Id, @@ -373,7 +331,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty return &emptypb.Empty{}, nil } - } func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 { @@ -396,7 +353,6 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) affinity = 0.5 } } - } return affinity @@ -405,24 +361,14 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { h.mu.Lock() 
defer h.mu.Unlock() - namespaces := make([]string, 0, len(h.namespaces)) - for ns := range h.namespaces { - namespaces = append(namespaces, ns) - } return &rpc.CheckEnabledResponse{ - Namespaces: namespaces, - RoomEnabled: h.roomEnabled, - PublisherEnabled: h.publisherEnabled, + Namespaces: slices.Compact(slices.Clone(h.namespaces)), + RoomEnabled: h.roomKeyCount != 0, + PublisherEnabled: h.publisherKeyCount != 0, }, nil } -func (h *AgentHandler) NumConnections() int { - h.mu.Lock() - defer h.mu.Unlock() - return len(h.workers) -} - func (h *AgentHandler) DrainConnections(interval time.Duration) { // jitter drain start time.Sleep(time.Duration(rand.Int63n(int64(interval)))) diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 74ce2f564..86952b767 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -18,12 +18,10 @@ import ( "context" "errors" "fmt" - "io" "math/rand" "net/http" "os" "strconv" - "strings" "sync" "time" @@ -369,17 +367,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { for { req, count, err := sigConn.ReadRequest() if err != nil { - // normal/expected closure - if errors.Is(err, io.EOF) || - strings.HasSuffix(err.Error(), "use of closed network connection") || - strings.HasSuffix(err.Error(), "connection reset by peer") || - websocket.IsCloseError( - err, - websocket.CloseAbnormalClosure, - websocket.CloseGoingAway, - websocket.CloseNormalClosure, - websocket.CloseNoStatusReceived, - ) { + if IsWebSocketCloseError(err) { closedByClient.Store(true) } else { pLogger.Errorw("error reading from websocket", err, "connID", cr.ConnectionID) diff --git a/pkg/service/wsprotocol.go b/pkg/service/wsprotocol.go index d94ce3f90..cdfb72a8a 100644 --- a/pkg/service/wsprotocol.go +++ b/pkg/service/wsprotocol.go @@ -15,6 +15,9 @@ package service import ( + "errors" + "io" + "strings" "sync" "time" @@ -49,6 +52,10 @@ func NewWSSignalConnection(conn types.WebsocketClient) *WSSignalConnection { return 
wsc } +func (c *WSSignalConnection) Close() error { + return c.conn.Close() +} + func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) { for { // handle special messages and pass on the rest @@ -172,3 +179,17 @@ func (c *WSSignalConnection) pingWorker() { } } } + +// IsWebSocketCloseError checks that error is normal/expected closure +func IsWebSocketCloseError(err error) bool { + return errors.Is(err, io.EOF) || + strings.HasSuffix(err.Error(), "use of closed network connection") || + strings.HasSuffix(err.Error(), "connection reset by peer") || + websocket.IsCloseError( + err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + ) +}