diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 7d3628a43..52dee7839 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -2,6 +2,8 @@ package agent_test import ( "context" + "fmt" + "sync" "testing" "time" @@ -25,7 +27,7 @@ func TestAgent(t *testing.T) { t.Cleanup(server.Close) worker := server.SimulateAgentWorker() - worker.Register("test", livekit.JobType_JT_ROOM) + worker.Register("", "test", livekit.JobType_JT_ROOM) jobAssignments := worker.JobAssignments.Observe() job := &livekit.Job{ @@ -33,9 +35,10 @@ func TestAgent(t *testing.T) { DispatchId: guid.New(guid.AgentDispatchPrefix), Type: livekit.JobType_JT_ROOM, Room: &livekit.Room{}, - AgentName: "test", + Namespace: "test", } - client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + require.NoError(t, err) select { case a := <-jobAssignments.Events(): @@ -45,3 +48,141 @@ func TestAgent(t *testing.T) { } }) } + +func TestAgentLoadBalancing(t *testing.T) { + + batchJobCreate := func(wg *sync.WaitGroup, batchSize int, totalJobs int, client rpc.AgentInternalClient) { + for i := 0; i < totalJobs; i += batchSize { + wg.Add(1) + go func(start int) { + defer wg.Done() + for j := start; j < start+batchSize && j < totalJobs; j++ { + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + DispatchId: guid.New(guid.AgentDispatchPrefix), + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{}, + Namespace: "test", + } + _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + require.NoError(t, err) + } + }(i) + } + } + + t.Run("jobs are distributed normally with baseline worker load", func(t *testing.T) { + totalWorkers := 5 + totalJobs := 100 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutil.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutil.AgentWorker, totalWorkers) + for i := 0; i < totalWorkers; i++ { + agents[i] = server.SimulateAgentWorker() + agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM) + } + + jobAssignments := make(chan *livekit.Job, totalJobs) + for i := 0; i < totalWorkers; i++ { + worker := agents[i] + go func() { + for a := range worker.JobAssignments.Observe().Events() { + jobAssignments <- a.Job + } + }() + } + + var wg sync.WaitGroup + batchJobCreate(&wg, 10, totalJobs, client) + wg.Wait() + + jobCount := make(map[string]int) + for i := 0; i < totalJobs; i++ { + select { + case job := <-jobAssignments: + jobCount[job.AgentName]++ + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + } + + assignedJobs := 0 + // check that jobs are distributed normally + for i := 0; i < totalWorkers; i++ { + agentName := fmt.Sprintf("agent-%d", i) + assignedJobs += jobCount[agentName] + require.GreaterOrEqual(t, jobCount[agentName], 0) + require.Less(t, jobCount[agentName], 35) // three std deviations from the mean is 32 + } + + // ensure all jobs are assigned + require.Equal(t, 100, assignedJobs) + }) + + t.Run("jobs are distributed with variable and overloaded worker load", func(t *testing.T) { + totalWorkers := 4 + totalJobs := 15 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutil.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutil.AgentWorker, totalWorkers) + for i := 0; i < totalWorkers; i++ { + if i%2 == 0 { + // make sure we have some workers that can accept jobs + agents[i] = server.SimulateAgentWorker() + } else { + agents[i] = server.SimulateAgentWorker(testutil.WithDefaultWorkerLoad(0.9)) + } + agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM) + } + + jobAssignments := make(chan *livekit.Job, totalJobs) + for i := 0; i < totalWorkers; i++ { + worker := agents[i] + go func() { + for a := range worker.JobAssignments.Observe().Events() { + jobAssignments <- a.Job + } + }() + } + + var wg sync.WaitGroup + batchJobCreate(&wg, 1, totalJobs, client) + wg.Wait() + + jobCount := make(map[string]int) + for i := 0; i < totalJobs; i++ { + select { + case job := <-jobAssignments: + jobCount[job.AgentName]++ + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + } + + assignedJobs := 0 + for i := 0; i < totalWorkers; i++ { + agentName := fmt.Sprintf("agent-%d", i) + assignedJobs += jobCount[agentName] + + if i%2 == 0 { + require.GreaterOrEqual(t, jobCount[agentName], 2) + } else { + require.Equal(t, 0, jobCount[agentName]) + } + require.GreaterOrEqual(t, jobCount[agentName], 0) + } + + // ensure all jobs are assigned + require.Equal(t, 15, assignedJobs) + }) +} diff --git a/pkg/agent/testutil/server.go b/pkg/agent/testutil/server.go index 4ff2d560a..04f437534 100644 --- a/pkg/agent/testutil/server.go +++ b/pkg/agent/testutil/server.go @@ -49,6 +49,7 @@ type SimulatedWorkerOptions struct { SupportResume bool DefaultJobLoad float32 JobLoadThreshold float32 + DefaultWorkerLoad float32 HandleAvailability func(AgentJobRequest) HandleAssignment func(*livekit.Job) JobLoad } @@ -71,10 +72,17 @@ func WithJobLoad(l JobLoad) SimulatedWorkerOption { return WithJobAssignmentHandler(func(j *livekit.Job) JobLoad { return l }) } +func WithDefaultWorkerLoad(load float32) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.DefaultWorkerLoad = load + } +} + func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWorker { o := &SimulatedWorkerOptions{ DefaultJobLoad: 0.1, JobLoadThreshold: 0.8, + DefaultWorkerLoad: 0.0, HandleAvailability: func(r AgentJobRequest) { r.Accept() }, HandleAssignment: func(j *livekit.Job) JobLoad { return nil }, } @@ -95,6 +103,10 @@ func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWo } w.ctx, w.cancel = context.WithCancel(context.Background()) + if o.DefaultWorkerLoad > 0.0 { + w.sendStatus() + } + go w.worker() go h.handleConnection(w) return w @@ -173,6 +185,7 @@ func (r AgentJobRequest) Reject() { } type AgentWorker struct { + Name string *SimulatedWorkerOptions fuse core.Fuse @@ -290,12 +303,14 @@ func (w *AgentWorker) handleAvailability(m *livekit.AvailabilityRequest) { } func (w *AgentWorker) handleAssignment(m *livekit.JobAssignment) { + m.Job.AgentName = w.Name w.JobAssignments.Emit(m) var load JobLoad if w.HandleAssignment != nil { load = w.HandleAssignment(m.Job) } + if load == nil { load = NewStableJobLoad(w.DefaultJobLoad) } @@ -370,8 +385,13 @@ func (w *AgentWorker) sendStatus() { w.mu.Lock() var load float32 jobCount := len(w.jobs) - for _, j := range w.jobs { - load += j.Load() + + if len(w.jobs) == 0 { + load = w.DefaultWorkerLoad + } else { + for _, j := range w.jobs { + load += j.Load() + } } w.mu.Unlock() @@ -387,10 +407,11 @@ func (w *AgentWorker) sendStatus() { }) } -func (w *AgentWorker) Register(agentName string, jobType livekit.JobType) { +func (w *AgentWorker) Register(name string, namespace string, jobType livekit.JobType) { + w.Name = name w.SendRegister(&livekit.RegisterWorkerRequest{ Type: jobType, - AgentName: agentName, + Namespace: &namespace, }) w.sendStatus() } diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index abb708a3d..d626dbd53 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -308,28 +308,9 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J key := workerKey{job.AgentName, job.Namespace, job.Type} attempted := make(map[*agent.Worker]struct{}) for { - h.mu.Lock() - var selected *agent.Worker - var maxLoad float32 - for _, w := range h.namespaceWorkers[key] { - if _, ok := attempted[w]; ok { - continue - } - - if w.Status() == livekit.WorkerStatus_WS_AVAILABLE { - load := w.Load() - if len(w.RunningJobs()) > 0 && load > maxLoad { - maxLoad = load - selected = w - } else if selected == nil { - selected = w - } - } - } - h.mu.Unlock() - - if selected == nil { - return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available") + selected, err := h.selectWorkerWeightedByLoad(key, attempted) + if err != nil { + return nil, psrpc.NewError(psrpc.DeadlineExceeded, err) } attempted[selected] = struct{}{} @@ -347,7 +328,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J values = append(values, "participant", job.Participant.Identity) } h.logger.Debugw("assigning job", values...) - err := selected.AssignJob(ctx, job) + err = selected.AssignJob(ctx, job) if err != nil { if errors.Is(err, agent.ErrWorkerNotAvailable) { continue // Try another worker @@ -415,3 +396,44 @@ func (h *AgentHandler) DrainConnections(interval time.Duration) { <-t.C } } + +func (h *AgentHandler) selectWorkerWeightedByLoad(key workerKey, ignore map[*agent.Worker]struct{}) (*agent.Worker, error) { + h.mu.Lock() + defer h.mu.Unlock() + + workers, ok := h.namespaceWorkers[key] + if !ok { + return nil, errors.New("no workers available") + } + + normalizeLoad := func(load float32) int { + if load >= 1 { + return 0 + } + return int((1 - load) * 100) + } + + normalizedLoads := make(map[*agent.Worker]int) + var availableSum int + for _, w := range workers { + if _, ok := ignore[w]; !ok && w.Status() == livekit.WorkerStatus_WS_AVAILABLE { + normalizedLoads[w] = normalizeLoad(w.Load()) + availableSum += normalizedLoads[w] + } + } + + if availableSum == 0 { + return nil, errors.New("no workers with sufficient capacity") + } + + threshold := rand.Intn(availableSum) + var currentSum int + for w, load := range normalizedLoads { + currentSum += load + if currentSum >= threshold { + return w, nil + } + } + + return nil, errors.New("no workers available") +}