add handler interface to receive agent worker updates (#2830)

* add handler interface to receive agent worker updates

* cleanup
Paul Wells
2024-07-02 13:11:08 -07:00
committed by GitHub
parent 7dff092285
commit e511464d3d
6 changed files with 343 additions and 285 deletions
+120 -83
@@ -21,13 +21,9 @@ import (
"sync/atomic"
"time"
"github.com/gorilla/websocket"
pagent "github.com/livekit/protocol/agent"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils"
putil "github.com/livekit/protocol/utils"
"github.com/livekit/protocol/utils/guid"
)
@@ -42,13 +38,40 @@ const (
)
var (
ErrWorkerClosed = errors.New("worker closed")
ErrWorkerNotAvailable = errors.New("worker not available")
ErrAvailabilityTimeout = errors.New("agent worker availability timeout")
ErrWorkerClosed = errors.New("worker closed")
ErrWorkerNotAvailable = errors.New("worker not available")
ErrAvailabilityTimeout = errors.New("agent worker availability timeout")
ErrDuplicateJobAssignment = errors.New("duplicate job assignment")
)
type sigConn interface {
type SignalConn interface {
WriteServerMessage(msg *livekit.ServerMessage) (int, error)
ReadWorkerMessage() (*livekit.WorkerMessage, int, error)
Close() error
}
type WorkerHandler interface {
HandleWorkerRegister(w *Worker)
HandleWorkerDeregister(w *Worker)
HandleWorkerStatus(w *Worker, status *livekit.UpdateWorkerStatus)
HandleWorkerJobStatus(w *Worker, status *livekit.UpdateJobStatus)
HandleWorkerSimulateJob(w *Worker, job *livekit.Job)
HandleWorkerMigrateJob(w *Worker, request *livekit.MigrateJobRequest)
}
var _ WorkerHandler = UnimplementedWorkerHandler{}
type UnimplementedWorkerHandler struct{}
func (UnimplementedWorkerHandler) HandleWorkerRegister(*Worker) {}
func (UnimplementedWorkerHandler) HandleWorkerDeregister(*Worker) {}
func (UnimplementedWorkerHandler) HandleWorkerStatus(*Worker, *livekit.UpdateWorkerStatus) {}
func (UnimplementedWorkerHandler) HandleWorkerJobStatus(*Worker, *livekit.UpdateJobStatus) {}
func (UnimplementedWorkerHandler) HandleWorkerSimulateJob(*Worker, *livekit.Job) {}
func (UnimplementedWorkerHandler) HandleWorkerMigrateJob(*Worker, *livekit.MigrateJobRequest) {}
func JobStatusIsEnded(s livekit.JobStatus) bool {
return s == livekit.JobStatus_JS_SUCCESS || s == livekit.JobStatus_JS_FAILED
}
type Worker struct {
@@ -69,18 +92,17 @@ type Worker struct {
status livekit.WorkerStatus
runningJobs map[string]*livekit.Job
onWorkerRegistered func(w *Worker)
handler WorkerHandler
conn *websocket.Conn
sigConn sigConn
closed chan struct{}
conn SignalConn
closed chan struct{}
availability map[string]chan *livekit.AvailabilityResponse
ctx context.Context
cancel context.CancelFunc
Logger logger.Logger
logger logger.Logger
}
func NewWorker(
@@ -88,14 +110,15 @@ func NewWorker(
apiKey string,
apiSecret string,
serverInfo *livekit.ServerInfo,
conn *websocket.Conn,
sigConn sigConn,
conn SignalConn,
logger logger.Logger,
handler WorkerHandler,
) *Worker {
ctx, cancel := context.WithCancel(context.Background())
id := guid.New(guid.AgentWorkerPrefix)
w := &Worker{
id: putil.NewGuid(utils.AgentWorkerPrefix),
id: id,
protocolVersion: protocolVersion,
apiKey: apiKey,
apiSecret: apiSecret,
@@ -104,26 +127,25 @@ func NewWorker(
runningJobs: make(map[string]*livekit.Job),
availability: make(map[string]chan *livekit.AvailabilityResponse),
conn: conn,
sigConn: sigConn,
ctx: ctx,
cancel: cancel,
Logger: logger,
logger: logger.WithValues("workerID", id),
handler: handler,
}
go func() {
<-time.After(registerTimeout)
time.AfterFunc(registerTimeout, func() {
if !w.registered.Load() && !w.IsClosed() {
w.Logger.Warnw("worker did not register in time", nil, "id", w.id)
w.logger.Warnw("worker did not register in time", nil, "id", w.id)
w.Close()
}
}()
})
return w
}
func (w *Worker) sendRequest(req *livekit.ServerMessage) {
if _, err := w.sigConn.WriteServerMessage(req); err != nil {
w.Logger.Errorw("error writing to websocket", err)
if _, err := w.conn.WriteServerMessage(req); err != nil {
w.logger.Errorw("error writing to websocket", err)
}
}
@@ -132,15 +154,10 @@ func (w *Worker) ID() string {
}
func (w *Worker) JobType() livekit.JobType {
w.mu.Lock()
defer w.mu.Unlock()
return w.jobType
}
func (w *Worker) Namespace() string {
w.mu.Lock()
defer w.mu.Unlock()
return w.namespace
}
@@ -156,20 +173,14 @@ func (w *Worker) Load() float32 {
return w.load
}
func (w *Worker) OnWorkerRegistered(f func(w *Worker)) {
w.mu.Lock()
defer w.mu.Unlock()
w.onWorkerRegistered = f
}
func (w *Worker) Registered() bool {
return w.registered.Load()
}
func (w *Worker) Logger() logger.Logger {
return w.logger
}
func (w *Worker) RunningJobs() map[string]*livekit.Job {
jobs := make(map[string]*livekit.Job, len(w.runningJobs))
w.mu.Lock()
defer w.mu.Unlock()
jobs := make(map[string]*livekit.Job, len(w.runningJobs))
for k, v := range w.runningJobs {
jobs[k] = v
}
@@ -180,9 +191,20 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
availCh := make(chan *livekit.AvailabilityResponse, 1)
w.mu.Lock()
if _, ok := w.availability[job.Id]; ok {
w.mu.Unlock()
return ErrDuplicateJobAssignment
}
w.availability[job.Id] = availCh
w.mu.Unlock()
defer func() {
w.mu.Lock()
delete(w.availability, job.Id)
w.mu.Unlock()
}()
if job.State == nil {
job.State = &livekit.JobState{}
}
@@ -191,6 +213,9 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
Availability: &livekit.AvailabilityRequest{Job: job},
}})
timeout := time.NewTimer(assignJobTimeout)
defer timeout.Stop()
// See handleAvailability for the response
select {
case res := <-availCh:
@@ -200,7 +225,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
token, err := pagent.BuildAgentToken(w.apiKey, w.apiSecret, job.Room.Name, res.ParticipantIdentity, res.ParticipantName, res.ParticipantMetadata, w.permissions)
if err != nil {
w.Logger.Errorw("failed to build agent token", err)
w.logger.Errorw("failed to build agent token", err)
return err
}
@@ -216,7 +241,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
// TODO sweep jobs that are never started. We can't do this until all SDKs actually update the JOB state
return nil
case <-time.After(assignJobTimeout):
case <-timeout.C:
return ErrAvailabilityTimeout
case <-w.ctx.Done():
return ErrWorkerClosed
@@ -225,17 +250,8 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
}
}
func (w *Worker) UpdateStatus(status *livekit.UpdateWorkerStatus) {
w.mu.Lock()
if status.Status != nil {
w.status = status.GetStatus()
}
w.load = status.GetLoad()
w.mu.Unlock()
}
func (w *Worker) UpdateMetadata(metadata string) {
w.Logger.Debugw("worker metadata updated", nil, "metadata", metadata)
w.logger.Debugw("worker metadata updated", nil, "metadata", metadata)
}
func (w *Worker) IsClosed() bool {
@@ -254,41 +270,52 @@ func (w *Worker) Close() {
return
}
w.Logger.Infow("closing worker")
w.logger.Infow("closing worker")
close(w.closed)
w.cancel()
_ = w.conn.Close()
w.mu.Unlock()
if w.registered.Load() {
w.handler.HandleWorkerDeregister(w)
}
}
func (w *Worker) HandleMessage(req *livekit.WorkerMessage) {
switch m := req.Message.(type) {
case *livekit.WorkerMessage_Register:
go w.handleRegister(m.Register)
w.handleRegister(m.Register)
case *livekit.WorkerMessage_Availability:
go w.handleAvailability(m.Availability)
w.handleAvailability(m.Availability)
case *livekit.WorkerMessage_UpdateJob:
go w.handleJobUpdate(m.UpdateJob)
w.handleJobUpdate(m.UpdateJob)
case *livekit.WorkerMessage_SimulateJob:
go w.handleSimulateJob(m.SimulateJob)
w.handleSimulateJob(m.SimulateJob)
case *livekit.WorkerMessage_Ping:
go w.handleWorkerPing(m.Ping)
w.handleWorkerPing(m.Ping)
case *livekit.WorkerMessage_UpdateWorker:
go w.handleWorkerStatus(m.UpdateWorker)
w.handleWorkerStatus(m.UpdateWorker)
case *livekit.WorkerMessage_MigrateJob:
go w.handleMigrateJob(m.MigrateJob)
w.handleMigrateJob(m.MigrateJob)
}
}
func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
if w.registered.Load() {
w.Logger.Warnw("worker already registered", nil, "id", w.id)
w.mu.Lock()
var err error
if w.IsClosed() {
err = errors.New("worker closed")
}
if w.registered.Swap(true) {
err = errors.New("worker already registered")
}
if err != nil {
w.mu.Unlock()
w.logger.Warnw("unable to register worker", err, "id", w.id)
return
}
w.mu.Lock()
onWorkerRegistered := w.onWorkerRegistered
w.version = req.Version
w.name = req.Name
w.namespace = req.GetNamespace()
@@ -307,10 +334,9 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
}
w.status = livekit.WorkerStatus_WS_AVAILABLE
w.registered.Store(true)
w.mu.Unlock()
w.Logger.Debugw("worker registered", "request", req)
w.logger.Debugw("worker registered", "request", logger.Proto(req))
w.sendRequest(&livekit.ServerMessage{
Message: &livekit.ServerMessage_Register{
@@ -321,9 +347,7 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
},
})
if onWorkerRegistered != nil {
onWorkerRegistered(w)
}
w.handler.HandleWorkerRegister(w)
}
func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
@@ -332,7 +356,7 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
availCh, ok := w.availability[res.JobId]
if !ok {
w.Logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId)
w.logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId)
return
}
@@ -342,32 +366,34 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) {
w.mu.Lock()
defer w.mu.Unlock()
job, ok := w.runningJobs[update.JobId]
if !ok {
w.Logger.Infow("received job update for unknown job", "jobId", update.JobId)
w.logger.Infow("received job update for unknown job", "jobId", update.JobId)
return
}
now := time.Now()
job.State.UpdatedAt = now.UnixNano()
if job.State.Status == livekit.JobStatus_JS_PENDING && update.Status >= livekit.JobStatus_JS_RUNNING {
if job.State.Status == livekit.JobStatus_JS_PENDING && JobStatusIsEnded(update.Status) {
job.State.StartedAt = now.UnixNano()
}
if job.State.Status < livekit.JobStatus_JS_SUCCESS && update.Status >= livekit.JobStatus_JS_SUCCESS {
if job.State.Status < livekit.JobStatus_JS_SUCCESS && JobStatusIsEnded(update.Status) {
job.State.EndedAt = now.UnixNano()
}
job.State.Status = update.Status
job.State.Error = update.Error
// TODO do not delete, leafve inside the JobDefinition
if job.State.Status >= livekit.JobStatus_JS_SUCCESS {
// TODO do not delete, leave inside the JobDefinition
if JobStatusIsEnded(job.State.Status) {
delete(w.runningJobs, job.Id)
}
w.mu.Unlock()
w.handler.HandleWorkerJobStatus(w, update)
}
func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) {
@@ -377,19 +403,21 @@ func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) {
}
job := &livekit.Job{
Id: guid.New(utils.AgentJobPrefix),
Id: guid.New(guid.AgentJobPrefix),
Type: jobType,
Room: simulate.Room,
Participant: simulate.Participant,
Namespace: w.Namespace(),
}
ctx := context.Background()
err := w.AssignJob(ctx, job)
if err != nil {
w.Logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id)
}
go func() {
err := w.AssignJob(w.ctx, job)
if err != nil {
w.logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id)
} else {
w.handler.HandleWorkerSimulateJob(w, job)
}
}()
}
func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) {
@@ -402,11 +430,20 @@ func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) {
}
func (w *Worker) handleWorkerStatus(update *livekit.UpdateWorkerStatus) {
w.Logger.Debugw("worker status update", "status", update.Status, "load", update.Load)
w.UpdateStatus(update)
w.logger.Debugw("worker status update", "update", logger.Proto(update))
w.mu.Lock()
if update.Status != nil {
w.status = update.GetStatus()
}
w.load = update.GetLoad()
w.mu.Unlock()
w.handler.HandleWorkerStatus(w, update)
}
func (w *Worker) handleMigrateJob(migrate *livekit.MigrateJobRequest) {
// TODO(theomonnom): On OSS this is not implemented
// We could maybe just move a specific job to another worker
w.handler.HandleWorkerMigrateJob(w, migrate)
}
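The `WorkerHandler` interface above replaces the old `onWorkerRegistered` callback with a full set of update hooks. As an illustrative sketch (not part of this commit), a consumer can embed `UnimplementedWorkerHandler` so the implementation keeps compiling if methods are added later — the same pattern `AgentHandler` adopts in the service diff below. The import paths here are assumptions:

```go
package main

import (
	"fmt"

	"github.com/livekit/protocol/livekit"

	// assumed import path for the agent package shown in this diff
	"github.com/livekit/livekit-server/pkg/agent"
)

// loggingHandler is a hypothetical WorkerHandler that only cares about
// registration and status updates; the embedded UnimplementedWorkerHandler
// supplies no-op implementations of the remaining methods.
type loggingHandler struct {
	agent.UnimplementedWorkerHandler
}

// compile-time check, mirroring the one in the diff above
var _ agent.WorkerHandler = (*loggingHandler)(nil)

func (h *loggingHandler) HandleWorkerRegister(w *agent.Worker) {
	fmt.Printf("worker %s registered (namespace=%q, jobType=%s)\n",
		w.ID(), w.Namespace(), w.JobType())
}

func (h *loggingHandler) HandleWorkerStatus(w *agent.Worker, status *livekit.UpdateWorkerStatus) {
	fmt.Printf("worker %s status=%s load=%.2f\n",
		w.ID(), status.GetStatus(), status.GetLoad())
}
```

A handler like this would be passed as the final argument to `agent.NewWorker`, which now invokes it in place of the removed callback.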
+1
@@ -39,6 +39,7 @@ type WebsocketClient interface {
ReadMessage() (messageType int, p []byte, err error)
WriteMessage(messageType int, data []byte) error
WriteControl(messageType int, data []byte, deadline time.Time) error
Close() error
}
type AddSubscriberParams struct {
+65
@@ -9,6 +9,16 @@ import (
)
type FakeWebsocketClient struct {
CloseStub func() error
closeMutex sync.RWMutex
closeArgsForCall []struct {
}
closeReturns struct {
result1 error
}
closeReturnsOnCall map[int]struct {
result1 error
}
ReadMessageStub func() (int, []byte, error)
readMessageMutex sync.RWMutex
readMessageArgsForCall []struct {
@@ -52,6 +62,59 @@ type FakeWebsocketClient struct {
invocationsMutex sync.RWMutex
}
func (fake *FakeWebsocketClient) Close() error {
fake.closeMutex.Lock()
ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)]
fake.closeArgsForCall = append(fake.closeArgsForCall, struct {
}{})
stub := fake.CloseStub
fakeReturns := fake.closeReturns
fake.recordInvocation("Close", []interface{}{})
fake.closeMutex.Unlock()
if stub != nil {
return stub()
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeWebsocketClient) CloseCallCount() int {
fake.closeMutex.RLock()
defer fake.closeMutex.RUnlock()
return len(fake.closeArgsForCall)
}
func (fake *FakeWebsocketClient) CloseCalls(stub func() error) {
fake.closeMutex.Lock()
defer fake.closeMutex.Unlock()
fake.CloseStub = stub
}
func (fake *FakeWebsocketClient) CloseReturns(result1 error) {
fake.closeMutex.Lock()
defer fake.closeMutex.Unlock()
fake.CloseStub = nil
fake.closeReturns = struct {
result1 error
}{result1}
}
func (fake *FakeWebsocketClient) CloseReturnsOnCall(i int, result1 error) {
fake.closeMutex.Lock()
defer fake.closeMutex.Unlock()
fake.CloseStub = nil
if fake.closeReturnsOnCall == nil {
fake.closeReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.closeReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeWebsocketClient) ReadMessage() (int, []byte, error) {
fake.readMessageMutex.Lock()
ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)]
@@ -249,6 +312,8 @@ func (fake *FakeWebsocketClient) WriteMessageReturnsOnCall(i int, result1 error)
func (fake *FakeWebsocketClient) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
fake.closeMutex.RLock()
defer fake.closeMutex.RUnlock()
fake.readMessageMutex.RLock()
defer fake.readMessageMutex.RUnlock()
fake.writeControlMutex.RLock()
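With `Close() error` added to the `WebsocketClient` interface, the counterfeiter-generated fake gains the stubs above. A hypothetical test sketch (assuming the fake lives in a `servicefakes` package alongside package `service`) exercising the new stubs through `WSSignalConnection.Close`:

```go
func TestSignalConnClose(t *testing.T) {
	// assumed fake package name; the stub methods themselves are shown above
	fake := &servicefakes.FakeWebsocketClient{}
	fake.CloseReturns(nil)

	conn := NewWSSignalConnection(fake)
	if err := conn.Close(); err != nil {
		t.Fatalf("Close() returned %v", err)
	}
	if got := fake.CloseCallCount(); got != 1 {
		t.Fatalf("CloseCallCount() = %d, want 1", got)
	}
}
```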
+135 -189
@@ -17,11 +17,11 @@ package service
import (
"context"
"errors"
"io"
"math/rand"
"net/http"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -41,13 +41,48 @@ import (
"github.com/livekit/psrpc"
)
type AgentSocketUpgrader struct {
websocket.Upgrader
}
func (u AgentSocketUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, agent.WorkerProtocolVersion, bool) {
// reject non websocket requests
if !websocket.IsWebSocketUpgrade(r) {
w.WriteHeader(404)
return nil, 0, false
}
// require a claim
claims := GetGrants(r.Context())
if claims == nil || claims.Video == nil || !claims.Video.Agent {
handleError(w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied)
return nil, 0, false
}
// upgrade
conn, err := u.Upgrader.Upgrade(w, r, responseHeader)
if err != nil {
handleError(w, r, http.StatusInternalServerError, err)
return nil, 0, false
}
var protocol agent.WorkerProtocolVersion = agent.CurrentProtocol
if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil {
protocol = agent.WorkerProtocolVersion(pv)
}
return conn, protocol, true
}
type AgentService struct {
upgrader websocket.Upgrader
upgrader AgentSocketUpgrader
*AgentHandler
}
type AgentHandler struct {
agent.UnimplementedWorkerHandler
agentServer rpc.AgentInternalServer
mu sync.Mutex
logger logger.Logger
@@ -56,18 +91,19 @@ type AgentHandler struct {
workers map[string]*agent.Worker
keyProvider auth.KeyProvider
namespaceWorkers map[workerKey][]*agent.Worker
roomKeyCount int
publisherKeyCount int
// TODO remove once deprecated CheckEnabled is removed
namespaces map[string]*namespaceInfo
publisherEnabled bool
roomEnabled bool
namespaces []string
roomTopic string
publisherTopic string
}
type namespaceInfo struct {
numPublishers int32
numRooms int32
type workerKey struct {
namespace string
jobType livekit.JobType
}
func NewAgentService(conf *config.Config,
@@ -75,9 +111,7 @@ func NewAgentService(conf *config.Config,
bus psrpc.MessageBus,
keyProvider auth.KeyProvider,
) (*AgentService, error) {
s := &AgentService{
upgrader: websocket.Upgrader{},
}
s := &AgentService{}
// allow connections from any origin, since script may be hosted anywhere
// security is enforced by access tokens
@@ -110,27 +144,9 @@ func NewAgentService(conf *config.Config,
}
func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) {
// reject non websocket requests
if !websocket.IsWebSocketUpgrade(r) {
writer.WriteHeader(404)
return
if conn, protocol, ok := s.upgrader.Upgrade(writer, r, nil); ok {
s.HandleConnection(r.Context(), NewWSSignalConnection(conn), protocol)
}
// require a claim
claims := GetGrants(r.Context())
if claims == nil || claims.Video == nil || !claims.Video.Agent {
handleError(writer, r, http.StatusUnauthorized, rtc.ErrPermissionDenied)
return
}
// upgrade
conn, err := s.upgrader.Upgrade(writer, r, nil)
if err != nil {
handleError(writer, r, http.StatusInternalServerError, err)
return
}
s.HandleConnection(r, conn, nil)
}
func NewAgentHandler(
@@ -142,193 +158,136 @@ func NewAgentHandler(
publisherTopic string,
) *AgentHandler {
return &AgentHandler{
agentServer: agentServer,
logger: logger,
workers: make(map[string]*agent.Worker),
namespaces: make(map[string]*namespaceInfo),
serverInfo: serverInfo,
keyProvider: keyProvider,
roomTopic: roomTopic,
publisherTopic: publisherTopic,
agentServer: agentServer,
logger: logger,
workers: make(map[string]*agent.Worker),
namespaceWorkers: make(map[workerKey][]*agent.Worker),
serverInfo: serverInfo,
keyProvider: keyProvider,
roomTopic: roomTopic,
publisherTopic: publisherTopic,
}
}
func (h *AgentHandler) HandleConnection(r *http.Request, conn *websocket.Conn, onIdle func()) {
var protocol agent.WorkerProtocolVersion
if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil {
protocol = agent.WorkerProtocolVersion(pv)
}
sigConn := NewWSSignalConnection(conn)
apiKey := GetAPIKey(r.Context())
func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalConn, protocol agent.WorkerProtocolVersion) {
apiKey := GetAPIKey(ctx)
apiSecret := h.keyProvider.GetSecret(apiKey)
worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, sigConn, h.logger)
worker.OnWorkerRegistered(h.handleWorkerRegister)
worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, h.logger, h)
h.mu.Lock()
h.workers[worker.ID()] = worker
h.mu.Unlock()
defer func() {
worker.Close()
h.mu.Lock()
delete(h.workers, worker.ID())
numWorkers := len(h.workers)
h.mu.Unlock()
if worker.Registered() {
h.handleWorkerDeregister(worker)
}
if numWorkers == 0 && onIdle != nil {
onIdle()
}
}()
for {
req, _, err := sigConn.ReadWorkerMessage()
req, _, err := conn.ReadWorkerMessage()
if err != nil {
// normal/expected closure
if err == io.EOF ||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
websocket.IsCloseError(
err,
websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNormalClosure,
websocket.CloseNoStatusReceived,
) {
worker.Logger.Infow("worker closed WS connection", "wsError", err)
if IsWebSocketCloseError(err) {
worker.Logger().Infow("worker closed WS connection", "wsError", err)
} else {
worker.Logger.Errorw("error reading from websocket", err)
worker.Logger().Errorw("error reading from websocket", err)
}
return
break
}
worker.HandleMessage(req)
}
}
func (h *AgentHandler) handleWorkerRegister(w *agent.Worker) {
h.mu.Lock()
info, ok := h.namespaces[w.Namespace()]
numPublishers := int32(0)
numRooms := int32(0)
if ok {
numPublishers = info.numPublishers
numRooms = info.numRooms
}
shouldNotify := false
var err error
if w.JobType() == livekit.JobType_JT_PUBLISHER {
numPublishers++
if numPublishers == 1 {
shouldNotify = true
err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.publisherTopic)
}
} else if w.JobType() == livekit.JobType_JT_ROOM {
numRooms++
if numRooms == 1 {
shouldNotify = true
err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.roomTopic)
}
}
if err != nil {
w.Logger.Errorw("failed to register job request topic", err)
h.mu.Unlock()
w.Close() // Close the worker
return
}
h.namespaces[w.Namespace()] = &namespaceInfo{
numPublishers: numPublishers,
numRooms: numRooms,
}
h.roomEnabled = h.roomAvailableLocked()
h.publisherEnabled = h.publisherAvailableLocked()
delete(h.workers, worker.ID())
h.mu.Unlock()
if shouldNotify {
h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType())
err = h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{})
worker.Close()
}
func (h *AgentHandler) HandleWorkerRegister(w *agent.Worker) {
h.mu.Lock()
key := workerKey{w.Namespace(), w.JobType()}
workers := h.namespaceWorkers[key]
created := len(workers) == 0
if created {
topic := h.roomTopic
if w.JobType() == livekit.JobType_JT_PUBLISHER {
topic = h.publisherTopic
}
err := h.agentServer.RegisterJobRequestTopic(w.Namespace(), topic)
if err != nil {
w.Logger.Errorw("failed to publish worker registered", err)
h.mu.Unlock()
w.Logger().Errorw("failed to register job request topic", err)
w.Close()
return
}
if w.JobType() == livekit.JobType_JT_ROOM {
h.roomKeyCount++
} else {
h.publisherKeyCount++
}
h.namespaces = append(h.namespaces, w.Namespace())
sort.Strings(h.namespaces)
}
h.namespaceWorkers[key] = append(workers, w)
h.mu.Unlock()
if created {
h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType())
err := h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{})
if err != nil {
w.Logger().Errorw("failed to publish worker registered", err)
}
}
}
func (h *AgentHandler) handleWorkerDeregister(worker *agent.Worker) {
func (h *AgentHandler) HandleWorkerDeregister(w *agent.Worker) {
h.mu.Lock()
defer h.mu.Unlock()
info, ok := h.namespaces[worker.Namespace()]
key := workerKey{w.Namespace(), w.JobType()}
workers, ok := h.namespaceWorkers[key]
if !ok {
return
}
if worker.JobType() == livekit.JobType_JT_PUBLISHER {
info.numPublishers--
if info.numPublishers == 0 {
h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.publisherTopic)
}
} else if worker.JobType() == livekit.JobType_JT_ROOM {
info.numRooms--
if info.numRooms == 0 {
h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.roomTopic)
}
index := slices.Index(workers, w)
if index == -1 {
return
}
if info.numPublishers == 0 && info.numRooms == 0 {
if len(workers) > 1 {
h.namespaceWorkers[key] = slices.Delete(workers, index, index+1)
} else {
h.logger.Debugw("last worker deregistered")
delete(h.namespaces, worker.Namespace())
}
delete(h.namespaceWorkers, key)
h.roomEnabled = h.roomAvailableLocked()
h.publisherEnabled = h.publisherAvailableLocked()
}
func (h *AgentHandler) roomAvailableLocked() bool {
for _, w := range h.workers {
if w.JobType() == livekit.JobType_JT_ROOM {
return true
h.roomKeyCount--
h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.roomTopic)
} else {
h.publisherKeyCount--
h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.publisherTopic)
}
if i := slices.Index(h.namespaces, w.Namespace()); i != -1 {
h.namespaces = slices.Delete(h.namespaces, i, i+1)
}
}
return false
}
func (h *AgentHandler) publisherAvailableLocked() bool {
for _, w := range h.workers {
if w.JobType() == livekit.JobType_JT_PUBLISHER {
return true
}
}
return false
}
func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) {
attempted := make(map[string]bool)
key := workerKey{job.Namespace, job.Type}
attempted := make(map[*agent.Worker]struct{})
for {
h.mu.Lock()
var selected *agent.Worker
var maxLoad float32
for _, w := range h.workers {
if w.Namespace() != job.Namespace || w.JobType() != job.Type {
continue
}
_, ok := attempted[w.ID()]
if ok {
for _, w := range h.namespaceWorkers[key] {
if _, ok := attempted[w]; ok {
continue
}
@@ -341,7 +300,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
selected = w
}
}
}
h.mu.Unlock()
@@ -349,7 +307,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available")
}
attempted[selected.ID()] = true
attempted[selected] = struct{}{}
values := []interface{}{
"jobID", job.Id,
@@ -373,7 +331,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
return &emptypb.Empty{}, nil
}
}
func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 {
@@ -396,7 +353,6 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job)
affinity = 0.5
}
}
}
return affinity
@@ -405,24 +361,14 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job)
func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) {
h.mu.Lock()
defer h.mu.Unlock()
namespaces := make([]string, 0, len(h.namespaces))
for ns := range h.namespaces {
namespaces = append(namespaces, ns)
}
return &rpc.CheckEnabledResponse{
Namespaces: namespaces,
RoomEnabled: h.roomEnabled,
PublisherEnabled: h.publisherEnabled,
Namespaces: slices.Compact(slices.Clone(h.namespaces)),
RoomEnabled: h.roomKeyCount != 0,
PublisherEnabled: h.publisherKeyCount != 0,
}, nil
}
func (h *AgentHandler) NumConnections() int {
h.mu.Lock()
defer h.mu.Unlock()
return len(h.workers)
}
func (h *AgentHandler) DrainConnections(interval time.Duration) {
// jitter drain start
time.Sleep(time.Duration(rand.Int63n(int64(interval))))
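For orientation, a hedged sketch of how the refactored `AgentService` might be wired up, assumed to live in package `service` where `NewAgentService` is defined. One parameter between `conf` and `bus` is elided in the hunks above, so `localNode` and its type are assumptions:

```go
// mountAgentService is a hypothetical helper, not part of this commit.
func mountAgentService(
	mux *http.ServeMux,
	conf *config.Config,
	localNode routing.LocalNode, // assumed type; elided in the diff
	bus psrpc.MessageBus,
	keyProvider auth.KeyProvider,
) error {
	svc, err := NewAgentService(conf, localNode, bus, keyProvider)
	if err != nil {
		return err
	}
	// ServeHTTP delegates to AgentSocketUpgrader.Upgrade, which rejects
	// non-websocket requests (404) and missing agent claims (401) before
	// handing a SignalConn and protocol version to HandleConnection.
	mux.Handle("/agent", svc)
	return nil
}
```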
+1 -13
@@ -18,12 +18,10 @@ import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -369,17 +367,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for {
req, count, err := sigConn.ReadRequest()
if err != nil {
// normal/expected closure
if errors.Is(err, io.EOF) ||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
websocket.IsCloseError(
err,
websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNormalClosure,
websocket.CloseNoStatusReceived,
) {
if IsWebSocketCloseError(err) {
closedByClient.Store(true)
} else {
pLogger.Errorw("error reading from websocket", err, "connID", cr.ConnectionID)
+21
@@ -15,6 +15,9 @@
package service
import (
"errors"
"io"
"strings"
"sync"
"time"
@@ -49,6 +52,10 @@ func NewWSSignalConnection(conn types.WebsocketClient) *WSSignalConnection {
return wsc
}
func (c *WSSignalConnection) Close() error {
return c.conn.Close()
}
func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) {
for {
// handle special messages and pass on the rest
@@ -172,3 +179,17 @@ func (c *WSSignalConnection) pingWorker() {
}
}
}
// IsWebSocketCloseError reports whether the error represents a normal/expected closure
func IsWebSocketCloseError(err error) bool {
return errors.Is(err, io.EOF) ||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
websocket.IsCloseError(
err,
websocket.CloseAbnormalClosure,
websocket.CloseGoingAway,
websocket.CloseNormalClosure,
websocket.CloseNoStatusReceived,
)
}
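A hypothetical table-driven test for the new helper (not part of this commit); the cases cover each branch of the predicate — io.EOF, a suffix match, a gorilla close code, and a non-close error:

```go
func TestIsWebSocketCloseError(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want bool
	}{
		{"eof", io.EOF, true},
		{"reset", errors.New("read tcp 10.0.0.1:443: connection reset by peer"), true},
		{"normal close", &websocket.CloseError{Code: websocket.CloseNormalClosure}, true},
		{"other", errors.New("unexpected failure"), false},
	}
	for _, c := range cases {
		if got := IsWebSocketCloseError(c.err); got != c.want {
			t.Errorf("%s: got %v, want %v", c.name, got, c.want)
		}
	}
}
```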