Files
livekit/pkg/agent/agent_test.go

221 lines
6.1 KiB
Go

package agent_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"github.com/livekit/livekit-server/pkg/agent"
"github.com/livekit/livekit-server/pkg/agent/testutils"
"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/rpc"
"github.com/livekit/protocol/utils/guid"
"github.com/livekit/protocol/utils/must"
"github.com/livekit/psrpc"
)
// TestAgent exercises dispatching a job through the internal agent RPC client
// to a simulated agent worker.
func TestAgent(t *testing.T) {
	testAgentName := "test_agent"

	t.Run("dispatched jobs are assigned to a worker", func(t *testing.T) {
		bus := psrpc.NewLocalMessageBus()
		client := must.Get(rpc.NewAgentInternalClient(bus))
		// Release the client like the other subtests in this file do.
		t.Cleanup(client.Close)
		server := testutils.NewTestServer(bus)
		t.Cleanup(server.Close)

		worker := server.SimulateAgentWorker()
		worker.Register(testAgentName, livekit.JobType_JT_ROOM)
		// Subscribe before issuing the request so the assignment cannot be missed.
		jobAssignments := worker.JobAssignments.Observe()

		job := &livekit.Job{
			Id:         guid.New(guid.AgentJobPrefix),
			DispatchId: guid.New(guid.AgentDispatchPrefix),
			Type:       livekit.JobType_JT_ROOM,
			Room:       &livekit.Room{},
			AgentName:  testAgentName,
		}
		_, err := client.JobRequest(context.Background(), testAgentName, agent.RoomAgentTopic, job)
		require.NoError(t, err)

		select {
		case a := <-jobAssignments.Events():
			require.EqualValues(t, job.Id, a.Job.Id)
			// The assignment token must verify against the test server's secret and
			// carry the agent name as a participant attribute.
			v, err := auth.ParseAPIToken(a.Token)
			require.NoError(t, err)
			_, claims, err := v.Verify(server.TestAPISecret)
			require.NoError(t, err)
			require.Equal(t, testAgentName, claims.Attributes[agent.AgentNameAttributeKey])
		case <-time.After(time.Second):
			require.Fail(t, "job assignment timeout")
		}
	})
}
// testBatchJobRequest dispatches totalJobs room jobs for agent "test" through
// client. The jobs are issued from concurrent goroutines, each sending up to
// batchSize requests sequentially.
//
// It returns once every JobRequest call has completed. The returned channel is
// closed when the given workers have observed totalJobs job assignments in
// aggregate.
//
// Request errors are reported with t.Errorf rather than require.NoError:
// FailNow (used by require) must only be called from the test goroutine, while
// Errorf is safe from any goroutine.
func testBatchJobRequest(t require.TestingT, batchSize int, totalJobs int, client rpc.AgentInternalClient, workers []*testutils.AgentWorker) <-chan struct{} {
	var assigned atomic.Uint32
	done := make(chan struct{})
	for _, w := range workers {
		assignments := w.JobAssignments.Observe()
		go func() {
			defer assignments.Stop()
			for {
				select {
				case <-done:
					// A closed channel is always ready to receive, so looping here
					// would busy-spin forever and never reach assignments.Stop().
					return
				case <-assignments.Events():
					// Exactly one goroutine observes the count reaching totalJobs,
					// so done is closed only once.
					if assigned.Inc() == uint32(totalJobs) {
						close(done)
					}
				}
			}
		}()
	}

	// wait for agent registration
	time.Sleep(100 * time.Millisecond)

	var wg sync.WaitGroup
	for i := 0; i < totalJobs; i += batchSize {
		wg.Add(1)
		go func(start int) {
			defer wg.Done()
			for j := start; j < start+batchSize && j < totalJobs; j++ {
				job := &livekit.Job{
					Id:         guid.New(guid.AgentJobPrefix),
					DispatchId: guid.New(guid.AgentDispatchPrefix),
					Type:       livekit.JobType_JT_ROOM,
					Room:       &livekit.Room{},
					AgentName:  "test",
				}
				if _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job); err != nil {
					t.Errorf("job request %d: %v", j, err)
					// Abandon the rest of this batch, as the original FailNow did.
					return
				}
			}
		}(i)
	}
	wg.Wait()
	return done
}
// TestAgentLoadBalancing checks how room jobs for agent "test" are spread
// across several simulated workers with different reported loads.
func TestAgentLoadBalancing(t *testing.T) {
	t.Run("jobs are distributed normally with baseline worker load", func(t *testing.T) {
		totalWorkers := 5
		totalJobs := 100

		bus := psrpc.NewLocalMessageBus()
		client := must.Get(rpc.NewAgentInternalClient(bus))
		t.Cleanup(client.Close)
		server := testutils.NewTestServer(bus)
		t.Cleanup(server.Close)

		agents := make([]*testutils.AgentWorker, totalWorkers)
		for i := range totalWorkers {
			agents[i] = server.SimulateAgentWorker(
				testutils.WithLabel(fmt.Sprintf("agent-%d", i)),
				testutils.WithJobLoad(testutils.NewStableJobLoad(0.01)),
			)
			agents[i].Register("test", livekit.JobType_JT_ROOM)
		}

		select {
		case <-testBatchJobRequest(t, 10, totalJobs, client, agents):
		case <-time.After(time.Second):
			require.Fail(t, "job assignment timeout")
		}

		jobCount := make(map[string]int, totalWorkers)
		for _, w := range agents {
			jobCount[w.Label] = len(w.Jobs())
		}

		// check that jobs are distributed normally
		for i := range totalWorkers {
			label := fmt.Sprintf("agent-%d", i)
			// Every worker should get some share; a ">= 0" check on a length is
			// always true and verified nothing.
			require.Greater(t, jobCount[label], 0)
			require.Less(t, jobCount[label], 35) // three std deviations from the mean is 32
		}
	})

	t.Run("jobs are distributed with variable and overloaded worker load", func(t *testing.T) {
		totalWorkers := 4
		totalJobs := 15

		bus := psrpc.NewLocalMessageBus()
		client := must.Get(rpc.NewAgentInternalClient(bus))
		t.Cleanup(client.Close)
		server := testutils.NewTestServer(bus)
		t.Cleanup(server.Close)

		agents := make([]*testutils.AgentWorker, totalWorkers)
		for i := range totalWorkers {
			label := fmt.Sprintf("agent-%d", i)
			if i%2 == 0 {
				// make sure we have some workers that can accept jobs
				agents[i] = server.SimulateAgentWorker(testutils.WithLabel(label))
			} else {
				agents[i] = server.SimulateAgentWorker(testutils.WithLabel(label), testutils.WithDefaultWorkerLoad(0.9))
			}
			agents[i].Register("test", livekit.JobType_JT_ROOM)
		}

		select {
		case <-testBatchJobRequest(t, 1, totalJobs, client, agents):
		case <-time.After(time.Second):
			require.Fail(t, "job assignment timeout")
		}

		jobCount := make(map[string]int, totalWorkers)
		for _, w := range agents {
			jobCount[w.Label] = len(w.Jobs())
		}

		// Even-indexed workers (default load) must share the jobs; odd-indexed
		// workers (reported load 0.9) are expected to receive none.
		for i := range totalWorkers {
			label := fmt.Sprintf("agent-%d", i)
			if i%2 == 0 {
				require.GreaterOrEqual(t, jobCount[label], 2)
			} else {
				require.Equal(t, 0, jobCount[label])
			}
		}
	})
}
// TestConnectionClosedOnDispatchError verifies that the server closes a worker
// connection after the worker sends a message the dispatcher cannot handle.
func TestConnectionClosedOnDispatchError(t *testing.T) {
	t.Run("connection closed when unknown message type received", func(t *testing.T) {
		bus := psrpc.NewLocalMessageBus()
		server := testutils.NewTestServer(bus)
		t.Cleanup(server.Close)

		// register agent
		worker := server.SimulateAgentWorker()
		// Subscribe before registering: otherwise the registration response can
		// be delivered before the observer exists and the wait below times out.
		responses := worker.RegisterWorkerResponses.Observe()
		worker.Register("test_agent", livekit.JobType_JT_ROOM)

		select {
		case <-responses.Events():
			// registered
		case <-time.After(time.Second):
			require.Fail(t, "registration timeout")
		}
		responses.Stop()

		// send invalid message (nil Message field triggers ErrUnknownWorkerSignal)
		worker.SendMessage(&livekit.WorkerMessage{Message: nil})

		select {
		case <-worker.Closed():
			// connection closed
		case <-time.After(time.Second):
			require.Fail(t, "connection should have been closed after dispatch error")
		}
	})
}