mirror of
https://github.com/livekit/livekit.git
synced 2026-05-14 18:25:24 +00:00
add handler interface to receive agent worker updates (#2830)
* add handler interface to receive agent worker updates * cleanup
This commit is contained in:
+120
-83
@@ -21,13 +21,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
pagent "github.com/livekit/protocol/agent"
|
||||
"github.com/livekit/protocol/livekit"
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/livekit/protocol/utils"
|
||||
putil "github.com/livekit/protocol/utils"
|
||||
"github.com/livekit/protocol/utils/guid"
|
||||
)
|
||||
|
||||
@@ -42,13 +38,40 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWorkerClosed = errors.New("worker closed")
|
||||
ErrWorkerNotAvailable = errors.New("worker not available")
|
||||
ErrAvailabilityTimeout = errors.New("agent worker availability timeout")
|
||||
ErrWorkerClosed = errors.New("worker closed")
|
||||
ErrWorkerNotAvailable = errors.New("worker not available")
|
||||
ErrAvailabilityTimeout = errors.New("agent worker availability timeout")
|
||||
ErrDuplicateJobAssignment = errors.New("duplicate job assignment")
|
||||
)
|
||||
|
||||
type sigConn interface {
|
||||
type SignalConn interface {
|
||||
WriteServerMessage(msg *livekit.ServerMessage) (int, error)
|
||||
ReadWorkerMessage() (*livekit.WorkerMessage, int, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type WorkerHandler interface {
|
||||
HandleWorkerRegister(w *Worker)
|
||||
HandleWorkerDeregister(w *Worker)
|
||||
HandleWorkerStatus(w *Worker, status *livekit.UpdateWorkerStatus)
|
||||
HandleWorkerJobStatus(w *Worker, status *livekit.UpdateJobStatus)
|
||||
HandleWorkerSimulateJob(w *Worker, job *livekit.Job)
|
||||
HandleWorkerMigrateJob(w *Worker, request *livekit.MigrateJobRequest)
|
||||
}
|
||||
|
||||
var _ WorkerHandler = UnimplementedWorkerHandler{}
|
||||
|
||||
type UnimplementedWorkerHandler struct{}
|
||||
|
||||
func (UnimplementedWorkerHandler) HandleWorkerRegister(*Worker) {}
|
||||
func (UnimplementedWorkerHandler) HandleWorkerDeregister(*Worker) {}
|
||||
func (UnimplementedWorkerHandler) HandleWorkerStatus(*Worker, *livekit.UpdateWorkerStatus) {}
|
||||
func (UnimplementedWorkerHandler) HandleWorkerJobStatus(*Worker, *livekit.UpdateJobStatus) {}
|
||||
func (UnimplementedWorkerHandler) HandleWorkerSimulateJob(*Worker, *livekit.Job) {}
|
||||
func (UnimplementedWorkerHandler) HandleWorkerMigrateJob(*Worker, *livekit.MigrateJobRequest) {}
|
||||
|
||||
func JobStatusIsEnded(s livekit.JobStatus) bool {
|
||||
return s == livekit.JobStatus_JS_SUCCESS || s == livekit.JobStatus_JS_FAILED
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
@@ -69,18 +92,17 @@ type Worker struct {
|
||||
status livekit.WorkerStatus
|
||||
runningJobs map[string]*livekit.Job
|
||||
|
||||
onWorkerRegistered func(w *Worker)
|
||||
handler WorkerHandler
|
||||
|
||||
conn *websocket.Conn
|
||||
sigConn sigConn
|
||||
closed chan struct{}
|
||||
conn SignalConn
|
||||
closed chan struct{}
|
||||
|
||||
availability map[string]chan *livekit.AvailabilityResponse
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
Logger logger.Logger
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewWorker(
|
||||
@@ -88,14 +110,15 @@ func NewWorker(
|
||||
apiKey string,
|
||||
apiSecret string,
|
||||
serverInfo *livekit.ServerInfo,
|
||||
conn *websocket.Conn,
|
||||
sigConn sigConn,
|
||||
conn SignalConn,
|
||||
logger logger.Logger,
|
||||
handler WorkerHandler,
|
||||
) *Worker {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
id := guid.New(guid.AgentWorkerPrefix)
|
||||
|
||||
w := &Worker{
|
||||
id: putil.NewGuid(utils.AgentWorkerPrefix),
|
||||
id: id,
|
||||
protocolVersion: protocolVersion,
|
||||
apiKey: apiKey,
|
||||
apiSecret: apiSecret,
|
||||
@@ -104,26 +127,25 @@ func NewWorker(
|
||||
runningJobs: make(map[string]*livekit.Job),
|
||||
availability: make(map[string]chan *livekit.AvailabilityResponse),
|
||||
conn: conn,
|
||||
sigConn: sigConn,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
Logger: logger,
|
||||
logger: logger.WithValues("workerID", id),
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-time.After(registerTimeout)
|
||||
time.AfterFunc(registerTimeout, func() {
|
||||
if !w.registered.Load() && !w.IsClosed() {
|
||||
w.Logger.Warnw("worker did not register in time", nil, "id", w.id)
|
||||
w.logger.Warnw("worker did not register in time", nil, "id", w.id)
|
||||
w.Close()
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *Worker) sendRequest(req *livekit.ServerMessage) {
|
||||
if _, err := w.sigConn.WriteServerMessage(req); err != nil {
|
||||
w.Logger.Errorw("error writing to websocket", err)
|
||||
if _, err := w.conn.WriteServerMessage(req); err != nil {
|
||||
w.logger.Errorw("error writing to websocket", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,15 +154,10 @@ func (w *Worker) ID() string {
|
||||
}
|
||||
|
||||
func (w *Worker) JobType() livekit.JobType {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.jobType
|
||||
}
|
||||
|
||||
func (w *Worker) Namespace() string {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.namespace
|
||||
}
|
||||
|
||||
@@ -156,20 +173,14 @@ func (w *Worker) Load() float32 {
|
||||
return w.load
|
||||
}
|
||||
|
||||
func (w *Worker) OnWorkerRegistered(f func(w *Worker)) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.onWorkerRegistered = f
|
||||
}
|
||||
|
||||
func (w *Worker) Registered() bool {
|
||||
return w.registered.Load()
|
||||
func (w *Worker) Logger() logger.Logger {
|
||||
return w.logger
|
||||
}
|
||||
|
||||
func (w *Worker) RunningJobs() map[string]*livekit.Job {
|
||||
jobs := make(map[string]*livekit.Job, len(w.runningJobs))
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
jobs := make(map[string]*livekit.Job, len(w.runningJobs))
|
||||
for k, v := range w.runningJobs {
|
||||
jobs[k] = v
|
||||
}
|
||||
@@ -180,9 +191,20 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
|
||||
availCh := make(chan *livekit.AvailabilityResponse, 1)
|
||||
|
||||
w.mu.Lock()
|
||||
if _, ok := w.availability[job.Id]; ok {
|
||||
w.mu.Unlock()
|
||||
return ErrDuplicateJobAssignment
|
||||
}
|
||||
|
||||
w.availability[job.Id] = availCh
|
||||
w.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
w.mu.Lock()
|
||||
delete(w.availability, job.Id)
|
||||
w.mu.Unlock()
|
||||
}()
|
||||
|
||||
if job.State == nil {
|
||||
job.State = &livekit.JobState{}
|
||||
}
|
||||
@@ -191,6 +213,9 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
|
||||
Availability: &livekit.AvailabilityRequest{Job: job},
|
||||
}})
|
||||
|
||||
timeout := time.NewTimer(assignJobTimeout)
|
||||
defer timeout.Stop()
|
||||
|
||||
// See handleAvailability for the response
|
||||
select {
|
||||
case res := <-availCh:
|
||||
@@ -200,7 +225,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
|
||||
|
||||
token, err := pagent.BuildAgentToken(w.apiKey, w.apiSecret, job.Room.Name, res.ParticipantIdentity, res.ParticipantName, res.ParticipantMetadata, w.permissions)
|
||||
if err != nil {
|
||||
w.Logger.Errorw("failed to build agent token", err)
|
||||
w.logger.Errorw("failed to build agent token", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -216,7 +241,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
|
||||
// TODO sweep jobs that are never started. We can't do this until all SDKs actually update the the JOB state
|
||||
|
||||
return nil
|
||||
case <-time.After(assignJobTimeout):
|
||||
case <-timeout.C:
|
||||
return ErrAvailabilityTimeout
|
||||
case <-w.ctx.Done():
|
||||
return ErrWorkerClosed
|
||||
@@ -225,17 +250,8 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) UpdateStatus(status *livekit.UpdateWorkerStatus) {
|
||||
w.mu.Lock()
|
||||
if status.Status != nil {
|
||||
w.status = status.GetStatus()
|
||||
}
|
||||
w.load = status.GetLoad()
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Worker) UpdateMetadata(metadata string) {
|
||||
w.Logger.Debugw("worker metadata updated", nil, "metadata", metadata)
|
||||
w.logger.Debugw("worker metadata updated", nil, "metadata", metadata)
|
||||
}
|
||||
|
||||
func (w *Worker) IsClosed() bool {
|
||||
@@ -254,41 +270,52 @@ func (w *Worker) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
w.Logger.Infow("closing worker")
|
||||
w.logger.Infow("closing worker")
|
||||
|
||||
close(w.closed)
|
||||
w.cancel()
|
||||
_ = w.conn.Close()
|
||||
w.mu.Unlock()
|
||||
|
||||
if w.registered.Load() {
|
||||
w.handler.HandleWorkerDeregister(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) HandleMessage(req *livekit.WorkerMessage) {
|
||||
switch m := req.Message.(type) {
|
||||
case *livekit.WorkerMessage_Register:
|
||||
go w.handleRegister(m.Register)
|
||||
w.handleRegister(m.Register)
|
||||
case *livekit.WorkerMessage_Availability:
|
||||
go w.handleAvailability(m.Availability)
|
||||
w.handleAvailability(m.Availability)
|
||||
case *livekit.WorkerMessage_UpdateJob:
|
||||
go w.handleJobUpdate(m.UpdateJob)
|
||||
w.handleJobUpdate(m.UpdateJob)
|
||||
case *livekit.WorkerMessage_SimulateJob:
|
||||
go w.handleSimulateJob(m.SimulateJob)
|
||||
w.handleSimulateJob(m.SimulateJob)
|
||||
case *livekit.WorkerMessage_Ping:
|
||||
go w.handleWorkerPing(m.Ping)
|
||||
w.handleWorkerPing(m.Ping)
|
||||
case *livekit.WorkerMessage_UpdateWorker:
|
||||
go w.handleWorkerStatus(m.UpdateWorker)
|
||||
w.handleWorkerStatus(m.UpdateWorker)
|
||||
case *livekit.WorkerMessage_MigrateJob:
|
||||
go w.handleMigrateJob(m.MigrateJob)
|
||||
w.handleMigrateJob(m.MigrateJob)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
|
||||
if w.registered.Load() {
|
||||
w.Logger.Warnw("worker already registered", nil, "id", w.id)
|
||||
w.mu.Lock()
|
||||
var err error
|
||||
if w.IsClosed() {
|
||||
err = errors.New("worker closed")
|
||||
}
|
||||
if w.registered.Swap(true) {
|
||||
err = errors.New("worker already registered")
|
||||
}
|
||||
if err != nil {
|
||||
w.mu.Unlock()
|
||||
w.logger.Warnw("unable to register worker", err, "id", w.id)
|
||||
return
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
onWorkerRegistered := w.onWorkerRegistered
|
||||
w.version = req.Version
|
||||
w.name = req.Name
|
||||
w.namespace = req.GetNamespace()
|
||||
@@ -307,10 +334,9 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
|
||||
}
|
||||
|
||||
w.status = livekit.WorkerStatus_WS_AVAILABLE
|
||||
w.registered.Store(true)
|
||||
w.mu.Unlock()
|
||||
|
||||
w.Logger.Debugw("worker registered", "request", req)
|
||||
w.logger.Debugw("worker registered", "request", logger.Proto(req))
|
||||
|
||||
w.sendRequest(&livekit.ServerMessage{
|
||||
Message: &livekit.ServerMessage_Register{
|
||||
@@ -321,9 +347,7 @@ func (w *Worker) handleRegister(req *livekit.RegisterWorkerRequest) {
|
||||
},
|
||||
})
|
||||
|
||||
if onWorkerRegistered != nil {
|
||||
onWorkerRegistered(w)
|
||||
}
|
||||
w.handler.HandleWorkerRegister(w)
|
||||
}
|
||||
|
||||
func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
|
||||
@@ -332,7 +356,7 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
|
||||
|
||||
availCh, ok := w.availability[res.JobId]
|
||||
if !ok {
|
||||
w.Logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId)
|
||||
w.logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -342,32 +366,34 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) {
|
||||
|
||||
func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
job, ok := w.runningJobs[update.JobId]
|
||||
if !ok {
|
||||
w.Logger.Infow("received job update for unknown job", "jobId", update.JobId)
|
||||
w.logger.Infow("received job update for unknown job", "jobId", update.JobId)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
job.State.UpdatedAt = now.UnixNano()
|
||||
|
||||
if job.State.Status == livekit.JobStatus_JS_PENDING && update.Status >= livekit.JobStatus_JS_RUNNING {
|
||||
if job.State.Status == livekit.JobStatus_JS_PENDING && JobStatusIsEnded(update.Status) {
|
||||
job.State.StartedAt = now.UnixNano()
|
||||
}
|
||||
|
||||
if job.State.Status < livekit.JobStatus_JS_SUCCESS && update.Status >= livekit.JobStatus_JS_SUCCESS {
|
||||
if job.State.Status < livekit.JobStatus_JS_SUCCESS && JobStatusIsEnded(update.Status) {
|
||||
job.State.EndedAt = now.UnixNano()
|
||||
}
|
||||
|
||||
job.State.Status = update.Status
|
||||
job.State.Error = update.Error
|
||||
|
||||
// TODO do not delete, leafve inside the JobDefinition
|
||||
if job.State.Status >= livekit.JobStatus_JS_SUCCESS {
|
||||
// TODO do not delete, leave inside the JobDefinition
|
||||
if JobStatusIsEnded(job.State.Status) {
|
||||
delete(w.runningJobs, job.Id)
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
w.handler.HandleWorkerJobStatus(w, update)
|
||||
}
|
||||
|
||||
func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) {
|
||||
@@ -377,19 +403,21 @@ func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) {
|
||||
}
|
||||
|
||||
job := &livekit.Job{
|
||||
Id: guid.New(utils.AgentJobPrefix),
|
||||
Id: guid.New(guid.AgentJobPrefix),
|
||||
Type: jobType,
|
||||
Room: simulate.Room,
|
||||
Participant: simulate.Participant,
|
||||
Namespace: w.Namespace(),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := w.AssignJob(ctx, job)
|
||||
if err != nil {
|
||||
w.Logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := w.AssignJob(w.ctx, job)
|
||||
if err != nil {
|
||||
w.logger.Errorw("failed to simulate job, assignment failed", err, "jobId", job.Id)
|
||||
} else {
|
||||
w.handler.HandleWorkerSimulateJob(w, job)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) {
|
||||
@@ -402,11 +430,20 @@ func (w *Worker) handleWorkerPing(ping *livekit.WorkerPing) {
|
||||
}
|
||||
|
||||
func (w *Worker) handleWorkerStatus(update *livekit.UpdateWorkerStatus) {
|
||||
w.Logger.Debugw("worker status update", "status", update.Status, "load", update.Load)
|
||||
w.UpdateStatus(update)
|
||||
w.logger.Debugw("worker status update", "update", logger.Proto(update))
|
||||
|
||||
w.mu.Lock()
|
||||
if update.Status != nil {
|
||||
w.status = update.GetStatus()
|
||||
}
|
||||
w.load = update.GetLoad()
|
||||
w.mu.Unlock()
|
||||
|
||||
w.handler.HandleWorkerStatus(w, update)
|
||||
}
|
||||
|
||||
func (w *Worker) handleMigrateJob(migrate *livekit.MigrateJobRequest) {
|
||||
// TODO(theomonnom): On OSS this is not implemented
|
||||
// We could maybe just move a specific job to another worker
|
||||
w.handler.HandleWorkerMigrateJob(w, migrate)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type WebsocketClient interface {
|
||||
ReadMessage() (messageType int, p []byte, err error)
|
||||
WriteMessage(messageType int, data []byte) error
|
||||
WriteControl(messageType int, data []byte, deadline time.Time) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type AddSubscriberParams struct {
|
||||
|
||||
@@ -9,6 +9,16 @@ import (
|
||||
)
|
||||
|
||||
type FakeWebsocketClient struct {
|
||||
CloseStub func() error
|
||||
closeMutex sync.RWMutex
|
||||
closeArgsForCall []struct {
|
||||
}
|
||||
closeReturns struct {
|
||||
result1 error
|
||||
}
|
||||
closeReturnsOnCall map[int]struct {
|
||||
result1 error
|
||||
}
|
||||
ReadMessageStub func() (int, []byte, error)
|
||||
readMessageMutex sync.RWMutex
|
||||
readMessageArgsForCall []struct {
|
||||
@@ -52,6 +62,59 @@ type FakeWebsocketClient struct {
|
||||
invocationsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) Close() error {
|
||||
fake.closeMutex.Lock()
|
||||
ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)]
|
||||
fake.closeArgsForCall = append(fake.closeArgsForCall, struct {
|
||||
}{})
|
||||
stub := fake.CloseStub
|
||||
fakeReturns := fake.closeReturns
|
||||
fake.recordInvocation("Close", []interface{}{})
|
||||
fake.closeMutex.Unlock()
|
||||
if stub != nil {
|
||||
return stub()
|
||||
}
|
||||
if specificReturn {
|
||||
return ret.result1
|
||||
}
|
||||
return fakeReturns.result1
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) CloseCallCount() int {
|
||||
fake.closeMutex.RLock()
|
||||
defer fake.closeMutex.RUnlock()
|
||||
return len(fake.closeArgsForCall)
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) CloseCalls(stub func() error) {
|
||||
fake.closeMutex.Lock()
|
||||
defer fake.closeMutex.Unlock()
|
||||
fake.CloseStub = stub
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) CloseReturns(result1 error) {
|
||||
fake.closeMutex.Lock()
|
||||
defer fake.closeMutex.Unlock()
|
||||
fake.CloseStub = nil
|
||||
fake.closeReturns = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) CloseReturnsOnCall(i int, result1 error) {
|
||||
fake.closeMutex.Lock()
|
||||
defer fake.closeMutex.Unlock()
|
||||
fake.CloseStub = nil
|
||||
if fake.closeReturnsOnCall == nil {
|
||||
fake.closeReturnsOnCall = make(map[int]struct {
|
||||
result1 error
|
||||
})
|
||||
}
|
||||
fake.closeReturnsOnCall[i] = struct {
|
||||
result1 error
|
||||
}{result1}
|
||||
}
|
||||
|
||||
func (fake *FakeWebsocketClient) ReadMessage() (int, []byte, error) {
|
||||
fake.readMessageMutex.Lock()
|
||||
ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)]
|
||||
@@ -249,6 +312,8 @@ func (fake *FakeWebsocketClient) WriteMessageReturnsOnCall(i int, result1 error)
|
||||
func (fake *FakeWebsocketClient) Invocations() map[string][][]interface{} {
|
||||
fake.invocationsMutex.RLock()
|
||||
defer fake.invocationsMutex.RUnlock()
|
||||
fake.closeMutex.RLock()
|
||||
defer fake.closeMutex.RUnlock()
|
||||
fake.readMessageMutex.RLock()
|
||||
defer fake.readMessageMutex.RUnlock()
|
||||
fake.writeControlMutex.RLock()
|
||||
|
||||
+135
-189
@@ -17,11 +17,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -41,13 +41,48 @@ import (
|
||||
"github.com/livekit/psrpc"
|
||||
)
|
||||
|
||||
type AgentSocketUpgrader struct {
|
||||
websocket.Upgrader
|
||||
}
|
||||
|
||||
func (u AgentSocketUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, agent.WorkerProtocolVersion, bool) {
|
||||
// reject non websocket requests
|
||||
if !websocket.IsWebSocketUpgrade(r) {
|
||||
w.WriteHeader(404)
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
// require a claim
|
||||
claims := GetGrants(r.Context())
|
||||
if claims == nil || claims.Video == nil || !claims.Video.Agent {
|
||||
handleError(w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied)
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
// upgrade
|
||||
conn, err := u.Upgrader.Upgrade(w, r, responseHeader)
|
||||
if err != nil {
|
||||
handleError(w, r, http.StatusInternalServerError, err)
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
var protocol agent.WorkerProtocolVersion = agent.CurrentProtocol
|
||||
if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil {
|
||||
protocol = agent.WorkerProtocolVersion(pv)
|
||||
}
|
||||
|
||||
return conn, protocol, true
|
||||
}
|
||||
|
||||
type AgentService struct {
|
||||
upgrader websocket.Upgrader
|
||||
upgrader AgentSocketUpgrader
|
||||
|
||||
*AgentHandler
|
||||
}
|
||||
|
||||
type AgentHandler struct {
|
||||
agent.UnimplementedWorkerHandler
|
||||
|
||||
agentServer rpc.AgentInternalServer
|
||||
mu sync.Mutex
|
||||
logger logger.Logger
|
||||
@@ -56,18 +91,19 @@ type AgentHandler struct {
|
||||
workers map[string]*agent.Worker
|
||||
keyProvider auth.KeyProvider
|
||||
|
||||
namespaceWorkers map[workerKey][]*agent.Worker
|
||||
roomKeyCount int
|
||||
publisherKeyCount int
|
||||
// TODO remove once deprecated CheckEnabled is removed
|
||||
namespaces map[string]*namespaceInfo
|
||||
publisherEnabled bool
|
||||
roomEnabled bool
|
||||
namespaces []string
|
||||
|
||||
roomTopic string
|
||||
publisherTopic string
|
||||
}
|
||||
|
||||
type namespaceInfo struct {
|
||||
numPublishers int32
|
||||
numRooms int32
|
||||
type workerKey struct {
|
||||
namespace string
|
||||
jobType livekit.JobType
|
||||
}
|
||||
|
||||
func NewAgentService(conf *config.Config,
|
||||
@@ -75,9 +111,7 @@ func NewAgentService(conf *config.Config,
|
||||
bus psrpc.MessageBus,
|
||||
keyProvider auth.KeyProvider,
|
||||
) (*AgentService, error) {
|
||||
s := &AgentService{
|
||||
upgrader: websocket.Upgrader{},
|
||||
}
|
||||
s := &AgentService{}
|
||||
|
||||
// allow connections from any origin, since script may be hosted anywhere
|
||||
// security is enforced by access tokens
|
||||
@@ -110,27 +144,9 @@ func NewAgentService(conf *config.Config,
|
||||
}
|
||||
|
||||
func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) {
|
||||
// reject non websocket requests
|
||||
if !websocket.IsWebSocketUpgrade(r) {
|
||||
writer.WriteHeader(404)
|
||||
return
|
||||
if conn, protocol, ok := s.upgrader.Upgrade(writer, r, nil); ok {
|
||||
s.HandleConnection(r.Context(), NewWSSignalConnection(conn), protocol)
|
||||
}
|
||||
|
||||
// require a claim
|
||||
claims := GetGrants(r.Context())
|
||||
if claims == nil || claims.Video == nil || !claims.Video.Agent {
|
||||
handleError(writer, r, http.StatusUnauthorized, rtc.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
|
||||
// upgrade
|
||||
conn, err := s.upgrader.Upgrade(writer, r, nil)
|
||||
if err != nil {
|
||||
handleError(writer, r, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
s.HandleConnection(r, conn, nil)
|
||||
}
|
||||
|
||||
func NewAgentHandler(
|
||||
@@ -142,193 +158,136 @@ func NewAgentHandler(
|
||||
publisherTopic string,
|
||||
) *AgentHandler {
|
||||
return &AgentHandler{
|
||||
agentServer: agentServer,
|
||||
logger: logger,
|
||||
workers: make(map[string]*agent.Worker),
|
||||
namespaces: make(map[string]*namespaceInfo),
|
||||
serverInfo: serverInfo,
|
||||
keyProvider: keyProvider,
|
||||
roomTopic: roomTopic,
|
||||
publisherTopic: publisherTopic,
|
||||
agentServer: agentServer,
|
||||
logger: logger,
|
||||
workers: make(map[string]*agent.Worker),
|
||||
namespaceWorkers: make(map[workerKey][]*agent.Worker),
|
||||
serverInfo: serverInfo,
|
||||
keyProvider: keyProvider,
|
||||
roomTopic: roomTopic,
|
||||
publisherTopic: publisherTopic,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) HandleConnection(r *http.Request, conn *websocket.Conn, onIdle func()) {
|
||||
var protocol agent.WorkerProtocolVersion
|
||||
if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil {
|
||||
protocol = agent.WorkerProtocolVersion(pv)
|
||||
}
|
||||
|
||||
sigConn := NewWSSignalConnection(conn)
|
||||
|
||||
apiKey := GetAPIKey(r.Context())
|
||||
func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalConn, protocol agent.WorkerProtocolVersion) {
|
||||
apiKey := GetAPIKey(ctx)
|
||||
apiSecret := h.keyProvider.GetSecret(apiKey)
|
||||
|
||||
worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, sigConn, h.logger)
|
||||
worker.OnWorkerRegistered(h.handleWorkerRegister)
|
||||
worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, h.logger, h)
|
||||
|
||||
h.mu.Lock()
|
||||
h.workers[worker.ID()] = worker
|
||||
h.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
worker.Close()
|
||||
|
||||
h.mu.Lock()
|
||||
delete(h.workers, worker.ID())
|
||||
numWorkers := len(h.workers)
|
||||
h.mu.Unlock()
|
||||
|
||||
if worker.Registered() {
|
||||
h.handleWorkerDeregister(worker)
|
||||
}
|
||||
|
||||
if numWorkers == 0 && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
req, _, err := sigConn.ReadWorkerMessage()
|
||||
req, _, err := conn.ReadWorkerMessage()
|
||||
if err != nil {
|
||||
// normal/expected closure
|
||||
if err == io.EOF ||
|
||||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
|
||||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
|
||||
websocket.IsCloseError(
|
||||
err,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseNoStatusReceived,
|
||||
) {
|
||||
worker.Logger.Infow("worker closed WS connection", "wsError", err)
|
||||
if IsWebSocketCloseError(err) {
|
||||
worker.Logger().Infow("worker closed WS connection", "wsError", err)
|
||||
} else {
|
||||
worker.Logger.Errorw("error reading from websocket", err)
|
||||
worker.Logger().Errorw("error reading from websocket", err)
|
||||
}
|
||||
return
|
||||
break
|
||||
}
|
||||
|
||||
worker.HandleMessage(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleWorkerRegister(w *agent.Worker) {
|
||||
h.mu.Lock()
|
||||
|
||||
info, ok := h.namespaces[w.Namespace()]
|
||||
numPublishers := int32(0)
|
||||
numRooms := int32(0)
|
||||
if ok {
|
||||
numPublishers = info.numPublishers
|
||||
numRooms = info.numRooms
|
||||
}
|
||||
|
||||
shouldNotify := false
|
||||
var err error
|
||||
if w.JobType() == livekit.JobType_JT_PUBLISHER {
|
||||
numPublishers++
|
||||
if numPublishers == 1 {
|
||||
shouldNotify = true
|
||||
err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.publisherTopic)
|
||||
}
|
||||
|
||||
} else if w.JobType() == livekit.JobType_JT_ROOM {
|
||||
numRooms++
|
||||
if numRooms == 1 {
|
||||
shouldNotify = true
|
||||
err = h.agentServer.RegisterJobRequestTopic(w.Namespace(), h.roomTopic)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
w.Logger.Errorw("failed to register job request topic", err)
|
||||
h.mu.Unlock()
|
||||
w.Close() // Close the worker
|
||||
return
|
||||
}
|
||||
|
||||
h.namespaces[w.Namespace()] = &namespaceInfo{
|
||||
numPublishers: numPublishers,
|
||||
numRooms: numRooms,
|
||||
}
|
||||
|
||||
h.roomEnabled = h.roomAvailableLocked()
|
||||
h.publisherEnabled = h.publisherAvailableLocked()
|
||||
delete(h.workers, worker.ID())
|
||||
h.mu.Unlock()
|
||||
|
||||
if shouldNotify {
|
||||
h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType())
|
||||
err = h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{})
|
||||
worker.Close()
|
||||
}
|
||||
|
||||
func (h *AgentHandler) HandleWorkerRegister(w *agent.Worker) {
|
||||
h.mu.Lock()
|
||||
|
||||
key := workerKey{w.Namespace(), w.JobType()}
|
||||
|
||||
workers := h.namespaceWorkers[key]
|
||||
created := len(workers) == 0
|
||||
|
||||
if created {
|
||||
topic := h.roomTopic
|
||||
if w.JobType() == livekit.JobType_JT_PUBLISHER {
|
||||
topic = h.publisherTopic
|
||||
}
|
||||
err := h.agentServer.RegisterJobRequestTopic(w.Namespace(), topic)
|
||||
if err != nil {
|
||||
w.Logger.Errorw("failed to publish worker registered", err)
|
||||
h.mu.Unlock()
|
||||
|
||||
w.Logger().Errorw("failed to register job request topic", err)
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if w.JobType() == livekit.JobType_JT_ROOM {
|
||||
h.roomKeyCount++
|
||||
} else {
|
||||
h.publisherKeyCount++
|
||||
}
|
||||
|
||||
h.namespaces = append(h.namespaces, w.Namespace())
|
||||
sort.Strings(h.namespaces)
|
||||
}
|
||||
|
||||
h.namespaceWorkers[key] = append(workers, w)
|
||||
h.mu.Unlock()
|
||||
|
||||
if created {
|
||||
h.logger.Infow("initial worker registered", "namespace", w.Namespace(), "jobType", w.JobType())
|
||||
err := h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{})
|
||||
if err != nil {
|
||||
w.Logger().Errorw("failed to publish worker registered", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleWorkerDeregister(worker *agent.Worker) {
|
||||
func (h *AgentHandler) HandleWorkerDeregister(w *agent.Worker) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
info, ok := h.namespaces[worker.Namespace()]
|
||||
key := workerKey{w.Namespace(), w.JobType()}
|
||||
|
||||
workers, ok := h.namespaceWorkers[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if worker.JobType() == livekit.JobType_JT_PUBLISHER {
|
||||
info.numPublishers--
|
||||
if info.numPublishers == 0 {
|
||||
h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.publisherTopic)
|
||||
}
|
||||
} else if worker.JobType() == livekit.JobType_JT_ROOM {
|
||||
info.numRooms--
|
||||
if info.numRooms == 0 {
|
||||
h.agentServer.DeregisterJobRequestTopic(worker.Namespace(), h.roomTopic)
|
||||
}
|
||||
index := slices.Index(workers, w)
|
||||
if index == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
if info.numPublishers == 0 && info.numRooms == 0 {
|
||||
if len(workers) > 1 {
|
||||
h.namespaceWorkers[key] = slices.Delete(workers, index, index+1)
|
||||
} else {
|
||||
h.logger.Debugw("last worker deregistered")
|
||||
delete(h.namespaces, worker.Namespace())
|
||||
}
|
||||
delete(h.namespaceWorkers, key)
|
||||
|
||||
h.roomEnabled = h.roomAvailableLocked()
|
||||
h.publisherEnabled = h.publisherAvailableLocked()
|
||||
}
|
||||
|
||||
func (h *AgentHandler) roomAvailableLocked() bool {
|
||||
for _, w := range h.workers {
|
||||
if w.JobType() == livekit.JobType_JT_ROOM {
|
||||
return true
|
||||
h.roomKeyCount--
|
||||
h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.roomTopic)
|
||||
} else {
|
||||
h.publisherKeyCount--
|
||||
h.agentServer.DeregisterJobRequestTopic(w.Namespace(), h.publisherTopic)
|
||||
}
|
||||
|
||||
if i := slices.Index(h.namespaces, w.Namespace()); i != -1 {
|
||||
h.namespaces = slices.Delete(h.namespaces, i, i+1)
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func (h *AgentHandler) publisherAvailableLocked() bool {
|
||||
for _, w := range h.workers {
|
||||
if w.JobType() == livekit.JobType_JT_PUBLISHER {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) {
|
||||
attempted := make(map[string]bool)
|
||||
key := workerKey{job.Namespace, job.Type}
|
||||
attempted := make(map[*agent.Worker]struct{})
|
||||
for {
|
||||
h.mu.Lock()
|
||||
var selected *agent.Worker
|
||||
var maxLoad float32
|
||||
for _, w := range h.workers {
|
||||
if w.Namespace() != job.Namespace || w.JobType() != job.Type {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := attempted[w.ID()]
|
||||
if ok {
|
||||
for _, w := range h.namespaceWorkers[key] {
|
||||
if _, ok := attempted[w]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -341,7 +300,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
|
||||
selected = w
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
@@ -349,7 +307,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
|
||||
return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available")
|
||||
}
|
||||
|
||||
attempted[selected.ID()] = true
|
||||
attempted[selected] = struct{}{}
|
||||
|
||||
values := []interface{}{
|
||||
"jobID", job.Id,
|
||||
@@ -373,7 +331,6 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 {
|
||||
@@ -396,7 +353,6 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job)
|
||||
affinity = 0.5
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return affinity
|
||||
@@ -405,24 +361,14 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job)
|
||||
func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
namespaces := make([]string, 0, len(h.namespaces))
|
||||
for ns := range h.namespaces {
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return &rpc.CheckEnabledResponse{
|
||||
Namespaces: namespaces,
|
||||
RoomEnabled: h.roomEnabled,
|
||||
PublisherEnabled: h.publisherEnabled,
|
||||
Namespaces: slices.Compact(slices.Clone(h.namespaces)),
|
||||
RoomEnabled: h.roomKeyCount != 0,
|
||||
PublisherEnabled: h.publisherKeyCount != 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) NumConnections() int {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return len(h.workers)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DrainConnections(interval time.Duration) {
|
||||
// jitter drain start
|
||||
time.Sleep(time.Duration(rand.Int63n(int64(interval))))
|
||||
|
||||
@@ -18,12 +18,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -369,17 +367,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
for {
|
||||
req, count, err := sigConn.ReadRequest()
|
||||
if err != nil {
|
||||
// normal/expected closure
|
||||
if errors.Is(err, io.EOF) ||
|
||||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
|
||||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
|
||||
websocket.IsCloseError(
|
||||
err,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseNoStatusReceived,
|
||||
) {
|
||||
if IsWebSocketCloseError(err) {
|
||||
closedByClient.Store(true)
|
||||
} else {
|
||||
pLogger.Errorw("error reading from websocket", err, "connID", cr.ConnectionID)
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -49,6 +52,10 @@ func NewWSSignalConnection(conn types.WebsocketClient) *WSSignalConnection {
|
||||
return wsc
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) {
|
||||
for {
|
||||
// handle special messages and pass on the rest
|
||||
@@ -172,3 +179,17 @@ func (c *WSSignalConnection) pingWorker() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsWebSocketCloseError checks that error is normal/expected closure
|
||||
func IsWebSocketCloseError(err error) bool {
|
||||
return errors.Is(err, io.EOF) ||
|
||||
strings.HasSuffix(err.Error(), "use of closed network connection") ||
|
||||
strings.HasSuffix(err.Error(), "connection reset by peer") ||
|
||||
websocket.IsCloseError(
|
||||
err,
|
||||
websocket.CloseAbnormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseNoStatusReceived,
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user