From 096157e7068c5059a3556e8a6c80bbca70503574 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Wed, 25 Sep 2024 03:04:01 -0700 Subject: [PATCH] clean up worker jobs in handler when job ends (#3042) --- pkg/agent/worker.go | 52 +++++++++++++++++++++++-------------- pkg/service/agentservice.go | 49 +++++++++++++++++++++------------- 2 files changed, 64 insertions(+), 37 deletions(-) diff --git a/pkg/agent/worker.go b/pkg/agent/worker.go index c4729482b..5803b4918 100644 --- a/pkg/agent/worker.go +++ b/pkg/agent/worker.go @@ -36,6 +36,7 @@ var ( ErrUnimplementedWrorkerSignal = errors.New("unimplemented worker signal") ErrUnknownWorkerSignal = errors.New("unknown worker signal") ErrUnknownJobType = errors.New("unknown job type") + ErrJobNotFound = psrpc.NewErrorf(psrpc.NotFound, "no running job for given jobID") ErrWorkerClosed = errors.New("worker closed") ErrWorkerNotAvailable = errors.New("worker not available") ErrAvailabilityTimeout = errors.New("agent worker availability timeout") @@ -236,8 +237,8 @@ type Worker struct { load float32 status livekit.WorkerStatus - runningJobs map[string]*livekit.Job - availability map[string]chan *livekit.AvailabilityResponse + runningJobs map[livekit.JobID]*livekit.Job + availability map[livekit.JobID]chan *livekit.AvailabilityResponse } func NewWorker( @@ -264,8 +265,8 @@ func NewWorker( cancel: cancel, closed: make(chan struct{}), - runningJobs: make(map[string]*livekit.Job), - availability: make(map[string]chan *livekit.AvailabilityResponse), + runningJobs: make(map[livekit.JobID]*livekit.Job), + availability: make(map[livekit.JobID]chan *livekit.AvailabilityResponse), } } @@ -291,32 +292,43 @@ func (w *Worker) Logger() logger.Logger { return w.logger } -func (w *Worker) RunningJobs() map[string]*livekit.Job { +func (w *Worker) RunningJobs() map[livekit.JobID]*livekit.Job { w.mu.Lock() defer w.mu.Unlock() - jobs := make(map[string]*livekit.Job, len(w.runningJobs)) + jobs := make(map[livekit.JobID]*livekit.Job, len(w.runningJobs)) for k, v := range w.runningJobs { jobs[k] = v } return jobs } +func (w *Worker) GetJobState(jobID livekit.JobID) (*livekit.JobState, error) { + w.mu.Lock() + defer w.mu.Unlock() + j, ok := w.runningJobs[jobID] + if !ok { + return nil, ErrJobNotFound + } + return utils.CloneProto(j.State), nil +} + func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) (*livekit.JobState, error) { availCh := make(chan *livekit.AvailabilityResponse, 1) job = utils.CloneProto(job) + jobID := livekit.JobID(job.Id) w.mu.Lock() - if _, ok := w.availability[job.Id]; ok { + if _, ok := w.availability[jobID]; ok { w.mu.Unlock() return nil, ErrDuplicateJobAssignment } - w.availability[job.Id] = availCh + w.availability[jobID] = availCh w.mu.Unlock() defer func() { w.mu.Lock() - delete(w.availability, job.Id) + delete(w.availability, jobID) w.mu.Unlock() }() @@ -357,7 +369,7 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) (*livekit.JobS state := utils.CloneProto(job.State) w.mu.Lock() - w.runningJobs[job.Id] = job + w.runningJobs[jobID] = job w.mu.Unlock() // TODO sweep jobs that are never started. We can't do this until all SDKs actually update the the JOB state @@ -372,18 +384,18 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) (*livekit.JobS } } -func (w *Worker) TerminateJob(jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { +func (w *Worker) TerminateJob(jobID livekit.JobID, reason rpc.JobTerminateReason) (*livekit.JobState, error) { w.mu.Lock() job := w.runningJobs[jobID] w.mu.Unlock() if job == nil { - return nil, psrpc.NewErrorf(psrpc.NotFound, "no running job for given jobID") + return nil, ErrJobNotFound } w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Termination{ Termination: &livekit.JobTermination{ - JobId: jobID, + JobId: string(jobID), }, }}) @@ -395,7 +407,7 @@ func (w *Worker) TerminateJob(jobID string, reason rpc.JobTerminateReason) (*liv } return w.UpdateJobStatus(&livekit.UpdateJobStatus{ - JobId: jobID, + JobId: string(jobID), Status: status, Error: errorStr, }) @@ -433,14 +445,15 @@ func (w *Worker) HandleAvailability(res *livekit.AvailabilityResponse) error { w.mu.Lock() defer w.mu.Unlock() - availCh, ok := w.availability[res.JobId] + jobID := livekit.JobID(res.JobId) + availCh, ok := w.availability[jobID] if !ok { - w.logger.Warnw("received availability response for unknown job", nil, "jobId", res.JobId) + w.logger.Warnw("received availability response for unknown job", nil, "jobId", jobID) return nil } availCh <- res - delete(w.availability, res.JobId) + delete(w.availability, jobID) return nil } @@ -458,7 +471,8 @@ func (w *Worker) UpdateJobStatus(update *livekit.UpdateJobStatus) (*livekit.JobS w.mu.Lock() defer w.mu.Unlock() - job, ok := w.runningJobs[update.JobId] + jobID := livekit.JobID(update.JobId) + job, ok := w.runningJobs[jobID] if !ok { return nil, psrpc.NewErrorf(psrpc.NotFound, "received job update for unknown job") } @@ -479,7 +493,7 @@ func (w *Worker) UpdateJobStatus(update *livekit.UpdateJobStatus) (*livekit.JobS // TODO do not delete, leave inside the JobDefinition if JobStatusIsEnded(job.State.Status) { - delete(w.runningJobs, job.Id) + delete(w.runningJobs, jobID) } return proto.Clone(job.State).(*livekit.JobState), nil diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index 27e50802a..a04607d26 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -132,7 +132,7 @@ type AgentHandler struct { serverInfo *livekit.ServerInfo workers map[string]*agent.Worker - jobToWorker map[string]*agent.Worker + jobToWorker map[livekit.JobID]*agent.Worker keyProvider auth.KeyProvider namespaceWorkers map[workerKey][]*agent.Worker @@ -201,7 +201,7 @@ func NewAgentHandler( agentServer: agentServer, logger: logger, workers: make(map[string]*agent.Worker), - jobToWorker: make(map[string]*agent.Worker), + jobToWorker: make(map[livekit.JobID]*agent.Worker), namespaceWorkers: make(map[workerKey][]*agent.Worker), serverInfo: serverInfo, keyProvider: keyProvider, @@ -222,8 +222,9 @@ func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalCo worker := agent.NewWorker(registration, apiKey, apiSecret, conn, h.logger) h.registerWorker(worker) + handlerWorker := &agentHandlerWorker{h, worker} for ok := true; ok; { - ok = DispatchAgentWorkerSignal(conn, worker, worker.Logger()) + ok = DispatchAgentWorkerSignal(conn, handlerWorker, worker.Logger()) } h.deregisterWorker(worker) @@ -321,21 +322,13 @@ func (h *AgentHandler) deregisterWorker(w *agent.Worker) { } jobs := w.RunningJobs() - for _, j := range jobs { - h.deregisterJob(j.Id) + for jobID := range jobs { + h.deregisterJob(jobID) } } -func (h *AgentHandler) HandleWorkerJobStatus(w *agent.Worker, status *livekit.UpdateJobStatus) { - if agent.JobStatusIsEnded(status.Status) { - h.mu.Lock() - h.deregisterJob(status.JobId) - h.mu.Unlock() - } -} - -func (h *AgentHandler) deregisterJob(jobID string) { - h.agentServer.DeregisterJobTerminateTopic(jobID) +func (h *AgentHandler) deregisterJob(jobID livekit.JobID) { + h.agentServer.DeregisterJobTerminateTopic(string(jobID)) delete(h.jobToWorker, jobID) @@ -374,7 +367,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J return nil, err } h.mu.Lock() - h.jobToWorker[job.Id] = selected + h.jobToWorker[livekit.JobID(job.Id)] = selected h.mu.Unlock() err = h.agentServer.RegisterJobTerminateTopic(job.Id) @@ -408,14 +401,14 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) func (h *AgentHandler) JobTerminate(ctx context.Context, req *rpc.JobTerminateRequest) (*rpc.JobTerminateResponse, error) { h.mu.Lock() - w := h.jobToWorker[req.JobId] + w := h.jobToWorker[livekit.JobID(req.JobId)] h.mu.Unlock() if w == nil { return nil, psrpc.NewErrorf(psrpc.NotFound, "no worker for jobID") } - state, err := w.TerminateJob(req.JobId, req.Reason) + state, err := w.TerminateJob(livekit.JobID(req.JobId), req.Reason) if err != nil { return nil, err } @@ -485,3 +478,23 @@ func (h *AgentHandler) selectWorkerWeightedByLoad(key workerKey, ignore map[*age } return workers[0], nil } + +var _ agent.WorkerSignalHandler = (*agentHandlerWorker)(nil) + +type agentHandlerWorker struct { + h *AgentHandler + *agent.Worker +} + +func (w *agentHandlerWorker) HandleUpdateJob(update *livekit.UpdateJobStatus) error { + if err := w.Worker.HandleUpdateJob(update); err != nil { + return err + } + + if agent.JobStatusIsEnded(update.Status) { + w.h.mu.Lock() + w.h.deregisterJob(livekit.JobID(update.JobId)) + w.h.mu.Unlock() + } + return nil +}