distribute load to agents probabilistically, inversely proportional to load (#2902)

* select the least loaded agent worker for job dispatch

* update to load balance using inverse load

* remove unused file

* adding unit tests for worker job distribution
This commit is contained in:
Dan McFaul
2024-08-07 21:05:47 -06:00
committed by GitHub
parent 63dd744f58
commit 489f73f0a4
3 changed files with 214 additions and 30 deletions

View File

@@ -2,6 +2,8 @@ package agent_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -25,7 +27,7 @@ func TestAgent(t *testing.T) {
t.Cleanup(server.Close)
worker := server.SimulateAgentWorker()
worker.Register("test", livekit.JobType_JT_ROOM)
worker.Register("", "test", livekit.JobType_JT_ROOM)
jobAssignments := worker.JobAssignments.Observe()
job := &livekit.Job{
@@ -33,9 +35,10 @@ func TestAgent(t *testing.T) {
DispatchId: guid.New(guid.AgentDispatchPrefix),
Type: livekit.JobType_JT_ROOM,
Room: &livekit.Room{},
AgentName: "test",
Namespace: "test",
}
client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job)
_, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job)
require.NoError(t, err)
select {
case a := <-jobAssignments.Events():
@@ -45,3 +48,141 @@ func TestAgent(t *testing.T) {
}
})
}
// TestAgentLoadBalancing verifies that job dispatch spreads work across
// simulated agent workers in proportion to their spare capacity.
func TestAgentLoadBalancing(t *testing.T) {
	// batchJobCreate issues totalJobs JobRequest RPCs from goroutines, one
	// goroutine per batch of batchSize requests; callers wait on wg for all
	// requests to be sent.
	batchJobCreate := func(wg *sync.WaitGroup, batchSize int, totalJobs int, client rpc.AgentInternalClient) {
		for i := 0; i < totalJobs; i += batchSize {
			wg.Add(1)
			go func(start int) {
				defer wg.Done()
				for j := start; j < start+batchSize && j < totalJobs; j++ {
					job := &livekit.Job{
						Id:         guid.New(guid.AgentJobPrefix),
						DispatchId: guid.New(guid.AgentDispatchPrefix),
						Type:       livekit.JobType_JT_ROOM,
						Room:       &livekit.Room{},
						Namespace:  "test",
					}
					_, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job)
					// require.* must not be used off the test goroutine:
					// FailNow only exits the goroutine that calls it, so the
					// failure would be lost. t.Error is safe for concurrent use.
					if err != nil {
						t.Error("job request failed:", err)
					}
				}
			}(i)
		}
	}

	t.Run("jobs are distributed normally with baseline worker load", func(t *testing.T) {
		totalWorkers := 5
		totalJobs := 100
		bus := psrpc.NewLocalMessageBus()
		client := must.Get(rpc.NewAgentInternalClient(bus))
		server := testutil.NewTestServer(bus)
		t.Cleanup(server.Close)

		// Register workers with no baseline load.
		agents := make([]*testutil.AgentWorker, totalWorkers)
		for i := 0; i < totalWorkers; i++ {
			agents[i] = server.SimulateAgentWorker()
			agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM)
		}

		// Funnel every worker's assignments into one channel so the test
		// goroutine can count them with a single receive loop.
		jobAssignments := make(chan *livekit.Job, totalJobs)
		for i := 0; i < totalWorkers; i++ {
			worker := agents[i]
			go func() {
				for a := range worker.JobAssignments.Observe().Events() {
					jobAssignments <- a.Job
				}
			}()
		}

		var wg sync.WaitGroup
		batchJobCreate(&wg, 10, totalJobs, client)
		wg.Wait()

		jobCount := make(map[string]int)
		for i := 0; i < totalJobs; i++ {
			select {
			case job := <-jobAssignments:
				jobCount[job.AgentName]++
			case <-time.After(time.Second):
				require.Fail(t, "job assignment timeout")
			}
		}

		assignedJobs := 0
		// check that jobs are distributed normally
		for i := 0; i < totalWorkers; i++ {
			agentName := fmt.Sprintf("agent-%d", i)
			assignedJobs += jobCount[agentName]
			require.GreaterOrEqual(t, jobCount[agentName], 0)
			require.Less(t, jobCount[agentName], 35) // three std deviations from the mean is 32
		}
		// ensure all jobs are assigned
		require.Equal(t, totalJobs, assignedJobs)
	})

	t.Run("jobs are distributed with variable and overloaded worker load", func(t *testing.T) {
		totalWorkers := 4
		totalJobs := 15
		bus := psrpc.NewLocalMessageBus()
		client := must.Get(rpc.NewAgentInternalClient(bus))
		server := testutil.NewTestServer(bus)
		t.Cleanup(server.Close)

		// Alternate idle workers with workers pinned near full load so the
		// balancer has both kinds to choose from.
		agents := make([]*testutil.AgentWorker, totalWorkers)
		for i := 0; i < totalWorkers; i++ {
			if i%2 == 0 {
				// make sure we have some workers that can accept jobs
				agents[i] = server.SimulateAgentWorker()
			} else {
				agents[i] = server.SimulateAgentWorker(testutil.WithDefaultWorkerLoad(0.9))
			}
			agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM)
		}

		jobAssignments := make(chan *livekit.Job, totalJobs)
		for i := 0; i < totalWorkers; i++ {
			worker := agents[i]
			go func() {
				for a := range worker.JobAssignments.Observe().Events() {
					jobAssignments <- a.Job
				}
			}()
		}

		var wg sync.WaitGroup
		batchJobCreate(&wg, 1, totalJobs, client)
		wg.Wait()

		jobCount := make(map[string]int)
		for i := 0; i < totalJobs; i++ {
			select {
			case job := <-jobAssignments:
				jobCount[job.AgentName]++
			case <-time.After(time.Second):
				require.Fail(t, "job assignment timeout")
			}
		}

		assignedJobs := 0
		for i := 0; i < totalWorkers; i++ {
			agentName := fmt.Sprintf("agent-%d", i)
			assignedJobs += jobCount[agentName]
			if i%2 == 0 {
				// Idle workers should absorb the work...
				require.GreaterOrEqual(t, jobCount[agentName], 2)
			} else {
				// ...while heavily loaded workers receive nothing.
				require.Equal(t, 0, jobCount[agentName])
			}
			require.GreaterOrEqual(t, jobCount[agentName], 0)
		}
		// ensure all jobs are assigned
		require.Equal(t, totalJobs, assignedJobs)
	})
}

View File

@@ -49,6 +49,7 @@ type SimulatedWorkerOptions struct {
SupportResume bool
DefaultJobLoad float32
JobLoadThreshold float32
DefaultWorkerLoad float32
HandleAvailability func(AgentJobRequest)
HandleAssignment func(*livekit.Job) JobLoad
}
@@ -71,10 +72,17 @@ func WithJobLoad(l JobLoad) SimulatedWorkerOption {
return WithJobAssignmentHandler(func(j *livekit.Job) JobLoad { return l })
}
// WithDefaultWorkerLoad sets the baseline load a simulated worker reports
// before any jobs have been assigned to it.
func WithDefaultWorkerLoad(load float32) SimulatedWorkerOption {
	return func(opts *SimulatedWorkerOptions) {
		opts.DefaultWorkerLoad = load
	}
}
func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWorker {
o := &SimulatedWorkerOptions{
DefaultJobLoad: 0.1,
JobLoadThreshold: 0.8,
DefaultWorkerLoad: 0.0,
HandleAvailability: func(r AgentJobRequest) { r.Accept() },
HandleAssignment: func(j *livekit.Job) JobLoad { return nil },
}
@@ -95,6 +103,10 @@ func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWo
}
w.ctx, w.cancel = context.WithCancel(context.Background())
if o.DefaultWorkerLoad > 0.0 {
w.sendStatus()
}
go w.worker()
go h.handleConnection(w)
return w
@@ -173,6 +185,7 @@ func (r AgentJobRequest) Reject() {
}
type AgentWorker struct {
Name string
*SimulatedWorkerOptions
fuse core.Fuse
@@ -290,12 +303,14 @@ func (w *AgentWorker) handleAvailability(m *livekit.AvailabilityRequest) {
}
func (w *AgentWorker) handleAssignment(m *livekit.JobAssignment) {
m.Job.AgentName = w.Name
w.JobAssignments.Emit(m)
var load JobLoad
if w.HandleAssignment != nil {
load = w.HandleAssignment(m.Job)
}
if load == nil {
load = NewStableJobLoad(w.DefaultJobLoad)
}
@@ -370,8 +385,13 @@ func (w *AgentWorker) sendStatus() {
w.mu.Lock()
var load float32
jobCount := len(w.jobs)
for _, j := range w.jobs {
load += j.Load()
if len(w.jobs) == 0 {
load = w.DefaultWorkerLoad
} else {
for _, j := range w.jobs {
load += j.Load()
}
}
w.mu.Unlock()
@@ -387,10 +407,11 @@ func (w *AgentWorker) sendStatus() {
})
}
func (w *AgentWorker) Register(agentName string, jobType livekit.JobType) {
// Register records the worker's display name locally, sends a
// RegisterWorkerRequest for the given namespace and job type, and then
// immediately reports status so the server sees this worker's load.
// NOTE(review): this span is rendered from a diff; the `AgentName: agentName`
// line references an identifier that is not a parameter of the new signature
// and is presumably the removed side of the hunk — confirm against the
// applied file before relying on it.
func (w *AgentWorker) Register(name string, namespace string, jobType livekit.JobType) {
	w.Name = name
	w.SendRegister(&livekit.RegisterWorkerRequest{
		Type: jobType,
		AgentName: agentName,
		Namespace: &namespace,
	})
	// Push an initial status update so load-based dispatch has data at once.
	w.sendStatus()
}

View File

@@ -308,28 +308,9 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J
key := workerKey{job.AgentName, job.Namespace, job.Type}
attempted := make(map[*agent.Worker]struct{})
for {
h.mu.Lock()
var selected *agent.Worker
var maxLoad float32
for _, w := range h.namespaceWorkers[key] {
if _, ok := attempted[w]; ok {
continue
}
if w.Status() == livekit.WorkerStatus_WS_AVAILABLE {
load := w.Load()
if len(w.RunningJobs()) > 0 && load > maxLoad {
maxLoad = load
selected = w
} else if selected == nil {
selected = w
}
}
}
h.mu.Unlock()
if selected == nil {
return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available")
selected, err := h.selectWorkerWeightedByLoad(key, attempted)
if err != nil {
return nil, psrpc.NewError(psrpc.DeadlineExceeded, err)
}
attempted[selected] = struct{}{}
@@ -347,7 +328,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J
values = append(values, "participant", job.Participant.Identity)
}
h.logger.Debugw("assigning job", values...)
err := selected.AssignJob(ctx, job)
err = selected.AssignJob(ctx, job)
if err != nil {
if errors.Is(err, agent.ErrWorkerNotAvailable) {
continue // Try another worker
@@ -415,3 +396,44 @@ func (h *AgentHandler) DrainConnections(interval time.Duration) {
<-t.C
}
}
// selectWorkerWeightedByLoad picks an available worker for key at random,
// weighted by spare capacity (1 - load): lightly loaded workers are
// proportionally more likely to be chosen. Workers in ignore (already
// attempted for this job) and workers at or above full load are excluded.
func (h *AgentHandler) selectWorkerWeightedByLoad(key workerKey, ignore map[*agent.Worker]struct{}) (*agent.Worker, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	workers, ok := h.namespaceWorkers[key]
	if !ok {
		return nil, errors.New("no workers available")
	}

	// Map a load in [0, 1) to an integer weight in (0, 100]; loads >= 1
	// normalize to 0 (no spare capacity).
	normalizeLoad := func(load float32) int {
		if load >= 1 {
			return 0
		}
		return int((1 - load) * 100)
	}

	normalizedLoads := make(map[*agent.Worker]int)
	var availableSum int
	for _, w := range workers {
		if _, skip := ignore[w]; skip || w.Status() != livekit.WorkerStatus_WS_AVAILABLE {
			continue
		}
		// Drop zero-weight (fully loaded) workers here so they can never be
		// chosen by the draw below.
		if weight := normalizeLoad(w.Load()); weight > 0 {
			normalizedLoads[w] = weight
			availableSum += weight
		}
	}
	if availableSum == 0 {
		return nil, errors.New("no workers with sufficient capacity")
	}

	// Weighted random selection: draw threshold in [0, availableSum) and walk
	// the prefix sums. The strict ">" comparison gives each worker a selection
	// probability of exactly weight/availableSum; the previous ">=" biased the
	// draw toward whichever worker the map iterated first and, when the draw
	// was 0, could even select a zero-weight (fully loaded) worker.
	threshold := rand.Intn(availableSum)
	var currentSum int
	for w, weight := range normalizedLoads {
		currentSum += weight
		if currentSum > threshold {
			return w, nil
		}
	}
	// Unreachable: the prefix sums total availableSum, which exceeds threshold.
	return nil, errors.New("no workers available")
}