From 67a3f04d5ed4f1f232015e5f06c456375f4ad0ec Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:44:53 +0530 Subject: [PATCH 01/10] Bump github.com/docker/docker (#2911) Bumps [github.com/docker/docker](https://github.com/docker/docker) from 27.0.0+incompatible to 27.1.0+incompatible. - [Release notes](https://github.com/docker/docker/releases) - [Commits](https://github.com/docker/docker/commits/v27.1.0) --- updated-dependencies: - dependency-name: github.com/docker/docker dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 3977bd586..a664dc4a7 100644 --- a/go.mod +++ b/go.mod @@ -75,7 +75,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/docker/cli v26.1.4+incompatible // indirect - github.com/docker/docker v27.0.0+incompatible // indirect + github.com/docker/docker v27.1.0+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/eapache/channels v1.1.0 // indirect diff --git a/go.sum b/go.sum index f78e4d63f..27d1955b5 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/docker/cli v26.1.4+incompatible h1:I8PHdc0MtxEADqYJZvhBrW9bo8gawKwwenxRM7/rLu8= github.com/docker/cli v26.1.4+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= -github.com/docker/docker v27.0.0+incompatible h1:JRugTYuelmWlW0M3jakcIadDx2HUoUO6+Tf2C5jVfwA= -github.com/docker/docker v27.0.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v27.1.0+incompatible h1:rEHVQc4GZ0MIQKifQPHSFGV/dVgaZafgRf8fCPtDYBs= +github.com/docker/docker v27.1.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= From de0c5bbd91577465b6f44da793bfbb8994394809 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Tue, 6 Aug 2024 19:45:00 -0700 Subject: [PATCH 02/10] use structured logging for create room request (#2912) --- pkg/service/roomservice.go | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index f4b4b8116..8c817b860 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -19,7 +19,6 @@ import ( "fmt" "strconv" - "github.com/avast/retry-go/v4" "github.com/pkg/errors" "github.com/twitchtv/twirp" "google.golang.org/protobuf/proto" @@ -30,6 +29,7 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/protocol/egress" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" ) @@ -77,9 +77,7 @@ func NewRoomService( } func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, error) { - clone := redactCreateRoomRequest(req) - - AppendLogFields(ctx, "room", clone.Name, "request", clone) + AppendLogFields(ctx, "room", req.Name, "request", logger.Proto(redactCreateRoomRequest(req))) if err := EnsureCreatePermission(ctx); err != nil { return nil, twirpAuthError(err) } else if req.Egress != nil && s.egressLauncher == nil { @@ -281,12 +279,12 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, twirpAuthError(err) } - room, _, err := s.roomStore.LoadRoom(ctx, livekit.RoomName(req.Room), false) + _, _, err := s.roomStore.LoadRoom(ctx, livekit.RoomName(req.Room), false) if err != nil { return nil, err } - room, err = s.roomClient.UpdateRoomMetadata(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + room, err := s.roomClient.UpdateRoomMetadata(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) if err != nil { return nil, err } @@ -294,18 +292,6 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return room, nil } -func (s *RoomService) confirmExecution(ctx context.Context, f func() error) error { - ctx, cancel := context.WithTimeout(ctx, s.apiConf.ExecutionTimeout) - defer cancel() - return retry.Do( - f, - retry.Context(ctx), - retry.Delay(s.apiConf.CheckInterval), - retry.MaxDelay(s.apiConf.MaxCheckInterval), - retry.DelayType(retry.BackOffDelay), - ) -} - // startRoom starts the room on an RTC node, to ensure metadata & empty timeout functionality func (s *RoomService) startRoom(ctx context.Context, roomName livekit.RoomName) (func(), error) { res, err := s.router.StartParticipantSignal(ctx, roomName, routing.ParticipantInit{}) From e9b6bf43c3067199a0f8bd5e172ed34d05764874 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Tue, 6 Aug 2024 19:46:52 -0700 Subject: [PATCH 03/10] add mock agent for integration tests (#2913) * add mock agent for integration tests * cleanup --- pkg/agent/testutil/server.go | 474 +++++++++++++++++++++++++++++++++++ pkg/service/agentservice.go | 27 +- 2 files changed, 495 insertions(+), 6 deletions(-) create mode 100644 pkg/agent/testutil/server.go diff --git a/pkg/agent/testutil/server.go b/pkg/agent/testutil/server.go new file mode 100644 index 000000000..4ff2d560a --- /dev/null +++ b/pkg/agent/testutil/server.go @@ -0,0 +1,474 @@ +package testutil + +import ( + "context" + "errors" + "io" + "math" + "math/rand/v2" + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/gammazero/deque" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/must" + "github.com/livekit/psrpc" +) + +type TestServer struct { + *service.AgentService + keyProvider auth.KeyProvider +} + +func NewTestServer(bus psrpc.MessageBus) *TestServer { + keyProvider := auth.NewSimpleKeyProvider("test", "verysecretsecret") + + s := must.Get(service.NewAgentService( + &config.Config{Region: "test"}, + &livekit.Node{Id: guid.New("N_")}, + bus, + keyProvider, + )) + + return &TestServer{ + AgentService: s, + keyProvider: keyProvider, + } +} + +type SimulatedWorkerOptions struct { + SupportResume bool + DefaultJobLoad float32 + JobLoadThreshold float32 + HandleAvailability func(AgentJobRequest) + HandleAssignment func(*livekit.Job) JobLoad +} + +type SimulatedWorkerOption func(*SimulatedWorkerOptions) + +func WithJobAvailabilityHandler(h func(AgentJobRequest)) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.HandleAvailability = h + } +} + +func WithJobAssignmentHandler(h func(*livekit.Job) JobLoad) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.HandleAssignment = h + } +} + +func WithJobLoad(l JobLoad) SimulatedWorkerOption { + return WithJobAssignmentHandler(func(j *livekit.Job) JobLoad { return l }) +} + +func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWorker { + o := &SimulatedWorkerOptions{ + DefaultJobLoad: 0.1, + JobLoadThreshold: 0.8, + HandleAvailability: func(r AgentJobRequest) { r.Accept() }, + HandleAssignment: func(j *livekit.Job) JobLoad { return nil }, + } + for _, opt := range opts { + opt(o) + } + + w := &AgentWorker{ + workerMessages: make(chan *livekit.WorkerMessage, 1), + jobs: map[string]*AgentJob{}, + SimulatedWorkerOptions: o, + + RegisterWorkerResponses: utils.NewDefaultEventObserverList[*livekit.RegisterWorkerResponse](), + AvailabilityRequests: utils.NewDefaultEventObserverList[*livekit.AvailabilityRequest](), + JobAssignments: utils.NewDefaultEventObserverList[*livekit.JobAssignment](), + JobTerminations: utils.NewDefaultEventObserverList[*livekit.JobTermination](), + WorkerPongs: utils.NewDefaultEventObserverList[*livekit.WorkerPong](), + } + w.ctx, w.cancel = context.WithCancel(context.Background()) + + go w.worker() + go h.handleConnection(w) + return w +} + +func (h *TestServer) handleConnection(w *AgentWorker) { + worker := agent.NewWorker( + agent.CurrentProtocol, + "test", + h.keyProvider.GetSecret("test"), + &livekit.ServerInfo{}, + w, + logger.GetLogger(), + h, + ) + + h.InsertWorker(worker) + + for { + req, _, err := w.ReadWorkerMessage() + if err != nil { + if service.IsWebSocketCloseError(err) { + worker.Logger().Infow("worker closed WS connection", "wsError", err) + } else { + worker.Logger().Errorw("error reading from websocket", err) + } + break + } + + worker.HandleMessage(req) + } + + h.DeleteWorker(worker) + + worker.Close() +} + +func (h *TestServer) Close() { + for _, w := range h.Workers() { + w.Close() + } +} + +var _ agent.SignalConn = (*AgentWorker)(nil) + +type JobLoad interface { + Load() float32 +} + +type AgentJob struct { + *livekit.Job + JobLoad +} + +type AgentJobRequest struct { + w *AgentWorker + *livekit.AvailabilityRequest +} + +func (r AgentJobRequest) Accept() { + identity := guid.New("PI_") + r.w.SendAvailability(&livekit.AvailabilityResponse{ + JobId: r.Job.Id, + Available: true, + SupportsResume: r.w.SupportResume, + ParticipantName: identity, + ParticipantIdentity: identity, + }) +} + +func (r AgentJobRequest) Reject() { + r.w.SendAvailability(&livekit.AvailabilityResponse{ + JobId: r.Job.Id, + Available: false, + }) +} + +type AgentWorker struct { + *SimulatedWorkerOptions + + fuse core.Fuse + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc + workerMessages chan *livekit.WorkerMessage + serverMessages deque.Deque[*livekit.ServerMessage] + jobs map[string]*AgentJob + + RegisterWorkerResponses *utils.EventObserverList[*livekit.RegisterWorkerResponse] + AvailabilityRequests *utils.EventObserverList[*livekit.AvailabilityRequest] + JobAssignments *utils.EventObserverList[*livekit.JobAssignment] + JobTerminations *utils.EventObserverList[*livekit.JobTermination] + WorkerPongs *utils.EventObserverList[*livekit.WorkerPong] +} + +func (w *AgentWorker) worker() { + t := time.NewTicker(5 * time.Second) + defer t.Stop() + + for !w.fuse.IsBroken() { + <-t.C + w.sendStatus() + } +} + +func (w *AgentWorker) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + w.fuse.Break() + return nil +} + +func (w *AgentWorker) SetReadDeadline(t time.Time) error { + w.mu.Lock() + defer w.mu.Unlock() + if !w.fuse.IsBroken() { + cancel := w.cancel + if t.IsZero() { + w.ctx, w.cancel = context.WithCancel(context.Background()) + } else { + w.ctx, w.cancel = context.WithDeadline(context.Background(), t) + } + cancel() + } + return nil +} + +func (w *AgentWorker) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) { + for { + w.mu.Lock() + ctx := w.ctx + w.mu.Unlock() + + select { + case <-w.fuse.Watch(): + return nil, 0, io.EOF + case <-ctx.Done(): + if err := ctx.Err(); errors.Is(err, context.DeadlineExceeded) { + return nil, 0, err + } + case m := <-w.workerMessages: + return m, 0, nil + } + } +} + +func (w *AgentWorker) WriteServerMessage(m *livekit.ServerMessage) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + w.serverMessages.PushBack(m) + if w.serverMessages.Len() == 1 { + go w.handleServerMessages() + } + return 0, nil +} + +func (w *AgentWorker) handleServerMessages() { + w.mu.Lock() + for w.serverMessages.Len() != 0 { + m := w.serverMessages.Front() + w.mu.Unlock() + + switch m := m.Message.(type) { + case *livekit.ServerMessage_Register: + w.handleRegister(m.Register) + case *livekit.ServerMessage_Availability: + w.handleAvailability(m.Availability) + case *livekit.ServerMessage_Assignment: + w.handleAssignment(m.Assignment) + case *livekit.ServerMessage_Termination: + w.handleTermination(m.Termination) + case *livekit.ServerMessage_Pong: + w.handlePong(m.Pong) + } + + w.mu.Lock() + w.serverMessages.PopFront() + } + w.mu.Unlock() +} + +func (w *AgentWorker) handleRegister(m *livekit.RegisterWorkerResponse) { + w.RegisterWorkerResponses.Emit(m) +} + +func (w *AgentWorker) handleAvailability(m *livekit.AvailabilityRequest) { + w.AvailabilityRequests.Emit(m) + if w.HandleAvailability != nil { + w.HandleAvailability(AgentJobRequest{w, m}) + } else { + AgentJobRequest{w, m}.Accept() + } +} + +func (w *AgentWorker) handleAssignment(m *livekit.JobAssignment) { + w.JobAssignments.Emit(m) + + var load JobLoad + if w.HandleAssignment != nil { + load = w.HandleAssignment(m.Job) + } + if load == nil { + load = NewStableJobLoad(w.DefaultJobLoad) + } + + w.mu.Lock() + defer w.mu.Unlock() + w.jobs[m.Job.Id] = &AgentJob{m.Job, load} +} + +func (w *AgentWorker) handleTermination(m *livekit.JobTermination) { + w.JobTerminations.Emit(m) + + w.mu.Lock() + defer w.mu.Unlock() + delete(w.jobs, m.JobId) +} + +func (w *AgentWorker) handlePong(m *livekit.WorkerPong) { + w.WorkerPongs.Emit(m) +} + +func (w *AgentWorker) sendMessage(m *livekit.WorkerMessage) { + select { + case <-w.fuse.Watch(): + case w.workerMessages <- m: + } +} + +func (w *AgentWorker) SendRegister(m *livekit.RegisterWorkerRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Register{ + Register: m, + }}) +} + +func (w *AgentWorker) SendAvailability(m *livekit.AvailabilityResponse) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Availability{ + Availability: m, + }}) +} + +func (w *AgentWorker) SendUpdateWorker(m *livekit.UpdateWorkerStatus) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_UpdateWorker{ + UpdateWorker: m, + }}) +} + +func (w *AgentWorker) SendUpdateJob(m *livekit.UpdateJobStatus) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_UpdateJob{ + UpdateJob: m, + }}) +} + +func (w *AgentWorker) SendPing(m *livekit.WorkerPing) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Ping{ + Ping: m, + }}) +} + +func (w *AgentWorker) SendSimulateJob(m *livekit.SimulateJobRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_SimulateJob{ + SimulateJob: m, + }}) +} + +func (w *AgentWorker) SendMigrateJob(m *livekit.MigrateJobRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_MigrateJob{ + MigrateJob: m, + }}) +} + +func (w *AgentWorker) sendStatus() { + w.mu.Lock() + var load float32 + jobCount := len(w.jobs) + for _, j := range w.jobs { + load += j.Load() + } + w.mu.Unlock() + + status := livekit.WorkerStatus_WS_AVAILABLE + if load > w.JobLoadThreshold { + status = livekit.WorkerStatus_WS_FULL + } + + w.SendUpdateWorker(&livekit.UpdateWorkerStatus{ + Status: &status, + Load: load, + JobCount: int32(jobCount), + }) +} + +func (w *AgentWorker) Register(agentName string, jobType livekit.JobType) { + w.SendRegister(&livekit.RegisterWorkerRequest{ + Type: jobType, + AgentName: agentName, + }) + w.sendStatus() +} + +func (w *AgentWorker) SimulateRoomJob(roomName string) { + w.SendSimulateJob(&livekit.SimulateJobRequest{ + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{ + Sid: guid.New(guid.RoomPrefix), + Name: roomName, + }, + }) +} + +type stableJobLoad struct { + load float32 +} + +func NewStableJobLoad(load float32) JobLoad { + return stableJobLoad{load} +} + +func (s stableJobLoad) Load() float32 { + return s.load +} + +type periodicJobLoad struct { + amplitude float64 + period time.Duration + epoch time.Time +} + +func NewPeriodicJobLoad(max float32, period time.Duration) JobLoad { + return periodicJobLoad{ + amplitude: float64(max / 2), + period: period, + epoch: time.Now().Add(-time.Duration(rand.Int64N(int64(period)))), + } +} + +func (s periodicJobLoad) Load() float32 { + a := math.Sin(time.Since(s.epoch).Seconds() / s.period.Seconds() * math.Pi * 2) + return float32(s.amplitude + a*s.amplitude) +} + +type uniformRandomJobLoad struct { + min, max float32 + rng func() float64 +} + +func NewUniformRandomJobLoad(min, max float32) JobLoad { + return uniformRandomJobLoad{min, max, rand.Float64} +} + +func NewUniformRandomJobLoadWithRNG(min, max float32, rng *rand.Rand) JobLoad { + return uniformRandomJobLoad{min, max, rng.Float64} +} + +func (s uniformRandomJobLoad) Load() float32 { + return rand.Float32()*(s.max-s.min) + s.min +} + +type normalRandomJobLoad struct { + mean, stddev float64 + rng func() float64 +} + +func NewNormalRandomJobLoad(mean, stddev float64) JobLoad { + return normalRandomJobLoad{mean, stddev, rand.Float64} +} + +func NewNormalRandomJobLoadWithRNG(mean, stddev float64, rng *rand.Rand) JobLoad { + return normalRandomJobLoad{mean, stddev, rng.Float64} +} + +func (s normalRandomJobLoad) Load() float32 { + u := 1 - s.rng() + v := s.rng() + z := math.Sqrt(-2.0*math.Log(u)) * math.Cos(2.0*math.Pi*v) + return float32(max(0, z*s.stddev+s.mean)) +} diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index 734609354..abb708a3d 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -26,6 +26,7 @@ import ( "time" "github.com/gorilla/websocket" + "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/emptypb" "github.com/livekit/livekit-server/pkg/agent" @@ -170,15 +171,31 @@ func NewAgentHandler( } } +func (h *AgentHandler) InsertWorker(w *agent.Worker) { + h.mu.Lock() + defer h.mu.Unlock() + h.workers[w.ID()] = w +} + +func (h *AgentHandler) DeleteWorker(w *agent.Worker) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.workers, w.ID()) +} + +func (h *AgentHandler) Workers() []*agent.Worker { + h.mu.Lock() + defer h.mu.Unlock() + return maps.Values(h.workers) +} + func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalConn, protocol agent.WorkerProtocolVersion) { apiKey := GetAPIKey(ctx) apiSecret := h.keyProvider.GetSecret(apiKey) worker := agent.NewWorker(protocol, apiKey, apiSecret, h.serverInfo, conn, h.logger, h) - h.mu.Lock() - h.workers[worker.ID()] = worker - h.mu.Unlock() + h.InsertWorker(worker) for { req, _, err := conn.ReadWorkerMessage() @@ -194,9 +211,7 @@ func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalCo worker.HandleMessage(req) } - h.mu.Lock() - delete(h.workers, worker.ID()) - h.mu.Unlock() + h.DeleteWorker(worker) worker.Close() } From 2346c8a6b762447859fc90697afb18cf54011623 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Tue, 6 Aug 2024 19:51:30 -0700 Subject: [PATCH 04/10] add example agent test (#2914) --- pkg/agent/agent_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 pkg/agent/agent_test.go diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go new file mode 100644 index 000000000..7d3628a43 --- /dev/null +++ b/pkg/agent/agent_test.go @@ -0,0 +1,47 @@ +package agent_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/agent/testutil" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/must" + "github.com/livekit/psrpc" +) + +func TestAgent(t *testing.T) { + t.Run("dispatched jobs are assigned to a worker", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutil.NewTestServer(bus) + t.Cleanup(server.Close) + + worker := server.SimulateAgentWorker() + worker.Register("test", livekit.JobType_JT_ROOM) + jobAssignments := worker.JobAssignments.Observe() + + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + DispatchId: guid.New(guid.AgentDispatchPrefix), + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{}, + AgentName: "test", + } + client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + + select { + case a := <-jobAssignments.Events(): + require.EqualValues(t, job.Id, a.Job.Id) + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + }) +} From a8730b04b8670378dd2aed927e0d83dfe130e18d Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Wed, 7 Aug 2024 22:30:52 +0800 Subject: [PATCH 05/10] move TrackSubscribed trigger to MediaSubscription (#2916) --- pkg/rtc/mediatrack.go | 3 +- pkg/rtc/mediatracksubscriptions.go | 6 ++++ pkg/rtc/types/interfaces.go | 1 + .../typesfakes/fake_local_media_track.go | 30 +++++++++++++++++++ pkg/rtc/types/typesfakes/fake_media_track.go | 30 +++++++++++++++++++ pkg/sfu/receiver.go | 22 ++------------ 6 files changed, 70 insertions(+), 22 deletions(-) diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index c55faf536..11f306a54 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -285,7 +285,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra sfu.WithLoadBalanceThreshold(20), sfu.WithStreamTrackers(), sfu.WithForwardStats(t.params.ForwardStats), - sfu.WithEverHasDownTrackAdded(t.handleReceiverEverAddDowntrack), ) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() @@ -434,7 +433,7 @@ func (t *MediaTrack) SetMuted(muted bool) { t.MediaTrackReceiver.SetMuted(muted) } -func (t *MediaTrack) handleReceiverEverAddDowntrack() { +func (t *MediaTrack) OnTrackSubscribed() { if !t.everSubscribed.Swap(true) && t.params.OnTrackEverSubscribed != nil { go t.params.OnTrackEverSubscribed(t.ID()) } diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index acd85037d..20b73db55 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -161,6 +161,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr * AdaptiveStream: sub.GetAdaptiveStream(), }) + subTrack.AddOnBind(func(err error) { + if err == nil { + t.params.MediaTrack.OnTrackSubscribed() + } + }) + // Bind callback can happen from replaceTrack, so set it up early var reusingTransceiver atomic.Bool var dtState sfu.DownTrackState diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index e3d25a59d..1470e58b8 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -495,6 +495,7 @@ type MediaTrack interface { RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity GetAllSubscribers() []livekit.ParticipantID GetNumSubscribers() int + OnTrackSubscribed() // returns quality information that's appropriate for width & height GetQualityForDimension(width, height uint32) livekit.VideoQuality diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index 1bf6297b4..382c609ae 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -221,6 +221,10 @@ type FakeLocalMediaTrack struct { arg1 livekit.NodeID arg2 uint8 } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } PublisherIDStub func() livekit.ParticipantID publisherIDMutex sync.RWMutex publisherIDArgsForCall []struct { @@ -1471,6 +1475,30 @@ func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossArgsForCall(i int) return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeLocalMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + func (fake *FakeLocalMediaTrack) PublisherID() livekit.ParticipantID { fake.publisherIDMutex.Lock() ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] @@ -2225,6 +2253,8 @@ func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() fake.notifySubscriberNodeMediaLossMutex.RLock() defer fake.notifySubscriberNodeMediaLossMutex.RUnlock() + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() fake.publisherIDMutex.RLock() defer fake.publisherIDMutex.RUnlock() fake.publisherIdentityMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 887646a18..ce6e9355e 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -176,6 +176,10 @@ type FakeMediaTrack struct { nameReturnsOnCall map[int]struct { result1 string } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } PublisherIDStub func() livekit.ParticipantID publisherIDMutex sync.RWMutex publisherIDArgsForCall []struct { @@ -1166,6 +1170,30 @@ func (fake *FakeMediaTrack) NameReturnsOnCall(i int, result1 string) { }{result1} } +func (fake *FakeMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + func (fake *FakeMediaTrack) PublisherID() livekit.ParticipantID { fake.publisherIDMutex.Lock() ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] @@ -1801,6 +1829,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.kindMutex.RUnlock() fake.nameMutex.RLock() defer fake.nameMutex.RUnlock() + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() fake.publisherIDMutex.RLock() defer fake.publisherIDMutex.RUnlock() fake.publisherIdentityMutex.RLock() diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index beb2fb06b..cc0241070 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -120,10 +120,8 @@ type WebRTCReceiver struct { connectionStats *connectionquality.ConnectionStats - onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) - onMaxLayerChange func(maxLayer int32) - downTrackEverAdded atomic.Bool - onDownTrackEverAdded func() + onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) + onMaxLayerChange func(maxLayer int32) primaryReceiver atomic.Pointer[RedPrimaryReceiver] redReceiver atomic.Pointer[RedReceiver] @@ -177,13 +175,6 @@ func WithForwardStats(forwardStats *ForwardStats) ReceiverOpts { } } -func WithEverHasDownTrackAdded(f func()) ReceiverOpts { - return func(w *WebRTCReceiver) *WebRTCReceiver { - w.onDownTrackEverAdded = f - return w - } -} - // NewWebRTCReceiver creates a new webrtc track receiver func NewWebRTCReceiver( receiver *webrtc.RTPReceiver, @@ -420,16 +411,9 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { w.downTrackSpreader.Store(track) w.logger.Debugw("downtrack added", "subscriberID", track.SubscriberID()) - w.handleDowntrackAdded() return nil } -func (w *WebRTCReceiver) handleDowntrackAdded() { - if !w.downTrackEverAdded.Swap(true) && w.onDownTrackEverAdded != nil { - w.onDownTrackEverAdded() - } -} - func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) { ti := w.TrackInfo() if ti == nil { @@ -792,7 +776,6 @@ func (w *WebRTCReceiver) GetPrimaryReceiverForRed() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() - w.handleDowntrackAdded() } } return w.primaryReceiver.Load() @@ -812,7 +795,6 @@ func (w *WebRTCReceiver) GetRedReceiver() TrackReceiver { w.bufferMu.Lock() w.redPktWriter = pr.ForwardRTP w.bufferMu.Unlock() - w.handleDowntrackAdded() } } return w.redReceiver.Load() From 63dd744f580aaa3be5e7b25bbe6f64ae23325e3d Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 7 Aug 2024 12:54:21 -0700 Subject: [PATCH 06/10] Update pion deps (#2877) Generated by renovateBot Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- go.mod | 19 ++++++++++--------- go.sum | 56 ++++++++++++++++++++++---------------------------------- 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/go.mod b/go.mod index a664dc4a7..d7a675538 100644 --- a/go.mod +++ b/go.mod @@ -28,16 +28,16 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/ory/dockertest/v3 v3.10.0 - github.com/pion/dtls/v2 v2.2.11 - github.com/pion/ice/v2 v2.3.29 - github.com/pion/interceptor v0.1.29 + github.com/pion/dtls/v2 v2.2.12 + github.com/pion/ice/v2 v2.3.34 + github.com/pion/interceptor v0.1.30 github.com/pion/rtcp v1.2.14 - github.com/pion/rtp v1.8.7 - github.com/pion/sctp v1.8.19 + github.com/pion/rtp v1.8.9 + github.com/pion/sctp v1.8.20 github.com/pion/sdp/v3 v3.0.9 - github.com/pion/transport/v2 v2.2.5 + github.com/pion/transport/v2 v2.2.10 github.com/pion/turn/v2 v2.1.6 - github.com/pion/webrtc/v3 v3.2.47 + github.com/pion/webrtc/v3 v3.2.51 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.19.1 github.com/redis/go-redis/v9 v9.6.1 @@ -108,11 +108,11 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/runc v1.1.13 // indirect - github.com/pion/datachannel v1.5.5 // indirect + github.com/pion/datachannel v1.5.8 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns v0.0.12 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/srtp/v2 v2.0.18 // indirect + github.com/pion/srtp/v2 v2.0.20 // indirect github.com/pion/stun v0.6.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect @@ -122,6 +122,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect + github.com/wlynxg/anet v0.0.3 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect diff --git a/go.sum b/go.sum index 27d1955b5..b45182d54 100644 --- a/go.sum +++ b/go.sum @@ -228,15 +228,15 @@ github.com/opencontainers/runc v1.1.13 h1:98S2srgG9vw0zWcDpFMn5TRrh8kLxa/5OFUstu github.com/opencontainers/runc v1.1.13/go.mod h1:R016aXacfp/gwQBYw2FDGa9m+n6atbLWrYY8hNMT/sA= github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= -github.com/pion/datachannel v1.5.5 h1:10ef4kwdjije+M9d7Xm9im2Y3O6A6ccQb0zcqZcJew8= -github.com/pion/datachannel v1.5.5/go.mod h1:iMz+lECmfdCMqFRhXhcA/219B0SQlbpoR2V118yimL0= +github.com/pion/datachannel v1.5.8 h1:ph1P1NsGkazkjrvyMfhRBUAWMxugJjq2HfQifaOoSNo= +github.com/pion/datachannel v1.5.8/go.mod h1:PgmdpoaNBLX9HNzNClmdki4DYW5JtI7Yibu8QzbL3tI= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= -github.com/pion/dtls/v2 v2.2.11 h1:9U/dpCYl1ySttROPWJgqWKEylUdT0fXp/xst6JwY5Ks= -github.com/pion/dtls/v2 v2.2.11/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/ice/v2 v2.3.29 h1:nKSU0Kb7F0Idfaz15EwGB1GbOxBlONXnWma5p1lOFcE= -github.com/pion/ice/v2 v2.3.29/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw= -github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M= -github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4= +github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= +github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= +github.com/pion/ice/v2 v2.3.34 h1:Ic1ppYCj4tUOcPAp76U6F3fVrlSw8A9JtRXLqw6BbUM= +github.com/pion/ice/v2 v2.3.34/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ= +github.com/pion/interceptor v0.1.30 h1:au5rlVHsgmxNi+v/mjOPazbW1SHzfx7/hYOEYQnUcxA= +github.com/pion/interceptor v0.1.30/go.mod h1:RQuKT5HTdkP2Fi0cuOS5G5WNymTjzXaGF75J4k7z2nc= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= @@ -247,33 +247,29 @@ github.com/pion/rtcp v1.2.12/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9 github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE= github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= github.com/pion/rtp v1.8.3/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/rtp v1.8.7 h1:qslKkG8qxvQ7hqaxkmL7Pl0XcUm+/Er7nMnu6Vq+ZxM= -github.com/pion/rtp v1.8.7/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/sctp v1.8.5/go.mod h1:SUFFfDpViyKejTAdwD1d/HQsCu+V/40cCs2nZIvC3s0= -github.com/pion/sctp v1.8.19 h1:2CYuw+SQ5vkQ9t0HdOPccsCz1GQMDuVy5PglLgKVBW8= -github.com/pion/sctp v1.8.19/go.mod h1:P6PbDVA++OJMrVNg2AL3XtYHV4uD6dvfyOovCgMs0PE= +github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk= +github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/sctp v1.8.20 h1:sOc3lkV/tQaP57ZUEXIMdM2V92IIB2ia5v/ygnBxaEg= +github.com/pion/sctp v1.8.20/go.mod h1:oTxw8i5m+WbDHZJL/xUpe6CPIn1Y0GIKKwTLF4h53H8= github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M= -github.com/pion/srtp/v2 v2.0.18 h1:vKpAXfawO9RtTRKZJbG4y0v1b11NZxQnxRl85kGuUlo= -github.com/pion/srtp/v2 v2.0.18/go.mod h1:0KJQjA99A6/a0DOVTu1PhDSw0CXF2jTkqOoMg3ODqdA= +github.com/pion/srtp/v2 v2.0.20 h1:HNNny4s+OUmG280ETrCdgFndp4ufx3/uy85EawYEhTk= +github.com/pion/srtp/v2 v2.0.20/go.mod h1:0KJQjA99A6/a0DOVTu1PhDSw0CXF2jTkqOoMg3ODqdA= github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= -github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= -github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= -github.com/pion/transport/v2 v2.2.2/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc= github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= -github.com/pion/transport/v2 v2.2.5 h1:iyi25i/21gQck4hfRhomF6SktmUQjRsRW4WJdhfc3Kc= -github.com/pion/transport/v2 v2.2.5/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= +github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q= +github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= -github.com/pion/transport/v3 v3.0.2 h1:r+40RJR25S9w3jbA6/5uEPTzcdn7ncyU44RWCbHkLg4= -github.com/pion/transport/v3 v3.0.2/go.mod h1:nIToODoOlb5If2jF9y2Igfx3PFYWfuXi37m0IlWa/D0= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc= github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/webrtc/v3 v3.2.47 h1:2DrJ7YnxiZVcmVA+HRyyACCSYvVW8E1YpOvF/EXeRYI= -github.com/pion/webrtc/v3 v3.2.47/go.mod h1:g7pwdiN9Gj2zZZlSTW5XC7OzrgHS9QzRM0y+O2jtjVg= +github.com/pion/webrtc/v3 v3.2.51 h1:NVelmwm/t/QAIb9qNuVDNitLo/858j7DSK3Tk3TwW5s= +github.com/pion/webrtc/v3 v3.2.51/go.mod h1:hVmrDJvwhEertRWObeb1xzulzHGeVUoPlWvxdGzcfU0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -327,6 +323,8 @@ github.com/urfave/cli/v2 v2.27.3 h1:/POWahRmdh7uztQ3CYnaDddk0Rm90PyOgIxgW2rr41M= github.com/urfave/cli/v2 v2.27.3/go.mod h1:m4QzxcD2qpra4z7WhzEGn74WZLViBnMpb1ToCAKdGRQ= github.com/urfave/negroni/v3 v3.1.1 h1:6MS4nG9Jk/UuCACaUlNXCbiKa0ywF9LXz5dGu09v8hw= github.com/urfave/negroni/v3 v3.1.1/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -358,7 +356,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= -golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= @@ -393,11 +390,9 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= @@ -441,13 +436,10 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -456,11 +448,9 @@ golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= @@ -469,10 +459,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= From 489f73f0a49fa332f73c7113cb88d34014643ee1 Mon Sep 17 00:00:00 2001 From: Dan McFaul <55854809+real-danm@users.noreply.github.com> Date: Wed, 7 Aug 2024 21:05:47 -0600 Subject: [PATCH 07/10] distribute load to agents probabilistically, inversely proportionate to load (#2902) * select the least loaded agent worker for job dispatch * update to load balance using inverse load * remove unused file * adding unit tests for worker job distribution --- pkg/agent/agent_test.go | 147 ++++++++++++++++++++++++++++++++++- pkg/agent/testutil/server.go | 29 ++++++- pkg/service/agentservice.go | 68 ++++++++++------ 3 files changed, 214 insertions(+), 30 deletions(-) diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 7d3628a43..52dee7839 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -2,6 +2,8 @@ package agent_test import ( "context" + "fmt" + "sync" "testing" "time" @@ -25,7 +27,7 @@ func TestAgent(t *testing.T) { t.Cleanup(server.Close) worker := server.SimulateAgentWorker() - worker.Register("test", livekit.JobType_JT_ROOM) + worker.Register("", "test", livekit.JobType_JT_ROOM) jobAssignments := worker.JobAssignments.Observe() job := &livekit.Job{ @@ -33,9 +35,10 @@ func TestAgent(t *testing.T) { DispatchId: guid.New(guid.AgentDispatchPrefix), Type: livekit.JobType_JT_ROOM, Room: &livekit.Room{}, - AgentName: "test", + Namespace: "test", } - client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + require.NoError(t, err) select { case a := <-jobAssignments.Events(): @@ -45,3 +48,141 @@ func TestAgent(t *testing.T) { } }) } + +func TestAgentLoadBalancing(t *testing.T) { + + batchJobCreate := func(wg *sync.WaitGroup, batchSize int, totalJobs int, client rpc.AgentInternalClient) { + for i := 0; i < totalJobs; i += batchSize { + wg.Add(1) + go func(start int) { + defer wg.Done() + for j := start; j < start+batchSize && j < totalJobs; j++ { + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + DispatchId: guid.New(guid.AgentDispatchPrefix), + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{}, + Namespace: "test", + } + _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + require.NoError(t, err) + } + }(i) + } + } + + t.Run("jobs are distributed normally with baseline worker load", func(t *testing.T) { + totalWorkers := 5 + totalJobs := 100 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutil.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutil.AgentWorker, totalWorkers) + for i := 0; i < totalWorkers; i++ { + agents[i] = server.SimulateAgentWorker() + agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM) + } + + jobAssignments := make(chan *livekit.Job, totalJobs) + for i := 0; i < totalWorkers; i++ { + worker := agents[i] + go func() { + for a := range worker.JobAssignments.Observe().Events() { + jobAssignments <- a.Job + } + }() + } + + var wg sync.WaitGroup + batchJobCreate(&wg, 10, totalJobs, client) + wg.Wait() + + jobCount := make(map[string]int) + for i := 0; i < totalJobs; i++ { + select { + case job := <-jobAssignments: + jobCount[job.AgentName]++ + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + } + + assignedJobs := 0 + // check that jobs are distributed normally + for i := 0; i < totalWorkers; i++ { + agentName := fmt.Sprintf("agent-%d", i) + assignedJobs += jobCount[agentName] + require.GreaterOrEqual(t, jobCount[agentName], 0) + require.Less(t, jobCount[agentName], 35) // three std deviations from the mean is 32 + } + + // ensure all jobs are assigned + require.Equal(t, 100, assignedJobs) + }) + + t.Run("jobs are distributed with variable and overloaded worker load", func(t *testing.T) { + totalWorkers := 4 + totalJobs := 15 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutil.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutil.AgentWorker, totalWorkers) + for i := 0; i < totalWorkers; i++ { + if i%2 == 0 { + // make sure we have some workers that can accept jobs + agents[i] = server.SimulateAgentWorker() + } else { + agents[i] = server.SimulateAgentWorker(testutil.WithDefaultWorkerLoad(0.9)) + } + agents[i].Register(fmt.Sprintf("agent-%d", i), "test", livekit.JobType_JT_ROOM) + } + + jobAssignments := make(chan *livekit.Job, totalJobs) + for i := 0; i < totalWorkers; i++ { + worker := agents[i] + go func() { + for a := range worker.JobAssignments.Observe().Events() { + jobAssignments <- a.Job + } + }() + } + + var wg sync.WaitGroup + batchJobCreate(&wg, 1, totalJobs, client) + wg.Wait() + + jobCount := make(map[string]int) + for i := 0; i < totalJobs; i++ { + select { + case job := <-jobAssignments: + jobCount[job.AgentName]++ + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + } + + assignedJobs := 0 + for i := 0; i < totalWorkers; i++ { + agentName := fmt.Sprintf("agent-%d", i) + assignedJobs += jobCount[agentName] + + if i%2 == 0 { + require.GreaterOrEqual(t, jobCount[agentName], 2) + } else { + require.Equal(t, 0, jobCount[agentName]) + } + require.GreaterOrEqual(t, jobCount[agentName], 0) + } + + // ensure all jobs are assigned + require.Equal(t, 15, assignedJobs) + }) +} diff --git a/pkg/agent/testutil/server.go b/pkg/agent/testutil/server.go index 4ff2d560a..04f437534 100644 --- a/pkg/agent/testutil/server.go +++ b/pkg/agent/testutil/server.go @@ -49,6 +49,7 @@ type SimulatedWorkerOptions struct { SupportResume bool DefaultJobLoad float32 JobLoadThreshold float32 + DefaultWorkerLoad float32 HandleAvailability func(AgentJobRequest) HandleAssignment func(*livekit.Job) JobLoad } @@ -71,10 +72,17 @@ func WithJobLoad(l JobLoad) SimulatedWorkerOption { return WithJobAssignmentHandler(func(j *livekit.Job) JobLoad { return l }) } +func WithDefaultWorkerLoad(load float32) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.DefaultWorkerLoad = load + } +} + func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWorker { o := &SimulatedWorkerOptions{ DefaultJobLoad: 0.1, JobLoadThreshold: 0.8, + DefaultWorkerLoad: 0.0, HandleAvailability: func(r AgentJobRequest) { r.Accept() }, HandleAssignment: func(j *livekit.Job) JobLoad { return nil }, } @@ -95,6 +103,10 @@ func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWo } w.ctx, w.cancel = context.WithCancel(context.Background()) + if o.DefaultWorkerLoad > 0.0 { + w.sendStatus() + } + go w.worker() go h.handleConnection(w) return w @@ -173,6 +185,7 @@ func (r AgentJobRequest) Reject() { } type AgentWorker struct { + Name string *SimulatedWorkerOptions fuse core.Fuse @@ -290,12 +303,14 @@ func (w *AgentWorker) handleAvailability(m *livekit.AvailabilityRequest) { } func (w *AgentWorker) handleAssignment(m *livekit.JobAssignment) { + m.Job.AgentName = w.Name w.JobAssignments.Emit(m) var load JobLoad if w.HandleAssignment != nil { load = w.HandleAssignment(m.Job) } + if load == nil { load = NewStableJobLoad(w.DefaultJobLoad) } @@ -370,8 +385,13 @@ func (w *AgentWorker) sendStatus() { w.mu.Lock() var load float32 jobCount := len(w.jobs) - for _, j := range w.jobs { - load += j.Load() + + if len(w.jobs) == 0 { + load = w.DefaultWorkerLoad + } else { + for _, j := range w.jobs { + load += j.Load() + } } w.mu.Unlock() @@ -387,10 +407,11 @@ func (w *AgentWorker) sendStatus() { }) } -func (w *AgentWorker) Register(agentName string, jobType livekit.JobType) { +func (w *AgentWorker) Register(name string, namespace string, jobType livekit.JobType) { + w.Name = name w.SendRegister(&livekit.RegisterWorkerRequest{ Type: jobType, - AgentName: agentName, + Namespace: &namespace, }) w.sendStatus() } diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index abb708a3d..d626dbd53 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -308,28 +308,9 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J key := workerKey{job.AgentName, job.Namespace, job.Type} attempted := make(map[*agent.Worker]struct{}) for { - h.mu.Lock() - var selected *agent.Worker - var maxLoad float32 - for _, w := range h.namespaceWorkers[key] { - if _, ok := attempted[w]; ok { - continue - } - - if w.Status() == livekit.WorkerStatus_WS_AVAILABLE { - load := w.Load() - if len(w.RunningJobs()) > 0 && load > maxLoad { - maxLoad = load - selected = w - } else if selected == nil { - selected = w - } - } - } - h.mu.Unlock() - - if selected == nil { - return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "no workers available") + selected, err := h.selectWorkerWeightedByLoad(key, attempted) + if err != nil { + return nil, psrpc.NewError(psrpc.DeadlineExceeded, err) } attempted[selected] = struct{}{} @@ -347,7 +328,7 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J values = append(values, "participant", job.Participant.Identity) } h.logger.Debugw("assigning job", values...) - err := selected.AssignJob(ctx, job) + err = selected.AssignJob(ctx, job) if err != nil { if errors.Is(err, agent.ErrWorkerNotAvailable) { continue // Try another worker @@ -415,3 +396,44 @@ func (h *AgentHandler) DrainConnections(interval time.Duration) { <-t.C } } + +func (h *AgentHandler) selectWorkerWeightedByLoad(key workerKey, ignore map[*agent.Worker]struct{}) (*agent.Worker, error) { + h.mu.Lock() + defer h.mu.Unlock() + + workers, ok := h.namespaceWorkers[key] + if !ok { + return nil, errors.New("no workers available") + } + + normalizeLoad := func(load float32) int { + if load >= 1 { + return 0 + } + return int((1 - load) * 100) + } + + normalizedLoads := make(map[*agent.Worker]int) + var availableSum int + for _, w := range workers { + if _, ok := ignore[w]; !ok && w.Status() == livekit.WorkerStatus_WS_AVAILABLE { + normalizedLoads[w] = normalizeLoad(w.Load()) + availableSum += normalizedLoads[w] + } + } + + if availableSum == 0 { + return nil, errors.New("no workers with sufficient capacity") + } + + threshold := rand.Intn(availableSum) + var currentSum int + for w, load := range normalizedLoads { + currentSum += load + if currentSum >= threshold { + return w, nil + } + } + + return nil, errors.New("no workers available") +} From 7018e485f295da1b6e6bcc73c1bb3ab67e6f6ff2 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 8 Aug 2024 23:15:04 +0530 Subject: [PATCH 08/10] Do not start forwarding on an out-of-order packet. (#2917) It is possible that old packets arrive on receiver. If subscriber starts on that, the first packet time would be incorrect. Do not start forwarding on out-of-order packets. --- pkg/sfu/buffer/buffer.go | 2 ++ pkg/sfu/downtrack.go | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 505ddd694..ccaf9cd4c 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -66,6 +66,7 @@ type ExtPacket struct { RawPacket []byte DependencyDescriptor *ExtDependencyDescriptor AbsCaptureTimeExt *act.AbsCaptureTime + IsOutOfOrder bool } // Buffer contains all packets @@ -763,6 +764,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64, flowStat Spatial: InvalidLayerSpatial, Temporal: InvalidLayerTemporal, }, + IsOutOfOrder: flowState.IsOutOfOrder, } if len(rtpPacket.Payload) == 0 { diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 54c56714c..b94cd308d 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -745,7 +745,8 @@ func (d *DownTrack) maxLayerNotifierWorker() { // WriteRTP writes an RTP Packet to the DownTrack func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { - if !d.writable.Load() { + if !d.writable.Load() || (extPkt.IsOutOfOrder && !d.rtpStats.IsActive()) { + // do not start on an out-of-order packet return nil } @@ -801,7 +802,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { // the corresponding sequence number is received. // The extreme case is all packets containing the play out delay are lost and // all of them retransmitted and an RTCP Receiver Report received for those - // retransmited sequence numbers. But, that is highly improbable, if not impossible. + // retransmitted sequence numbers. But, that is highly improbable, if not impossible. } } var actBytes []byte From 64057c3e4dd855d10f7874595a7fc0eb87e08de3 Mon Sep 17 00:00:00 2001 From: Benjamin Pracht Date: Thu, 8 Aug 2024 22:31:23 +0200 Subject: [PATCH 09/10] Implement AgentDispatch service (#2919) This allows listing, adding and deleting agent dispatches on an existing room. Requests go to a new AgentDispatchService, which sends them over RPC to the rtc.Room via the RoomManager. The rtc.Room then does agent job management using RPCs to the agent service. --- go.mod | 3 +- go.sum | 6 +- pkg/agent/client.go | 27 ++- pkg/agent/worker.go | 54 ++++- pkg/rtc/room.go | 288 ++++++++++++++++++++++---- pkg/service/agent_dispatch_service.go | 61 ++++++ pkg/service/agentservice.go | 50 +++++ pkg/service/ingress.go | 6 +- pkg/service/roommanager.go | 60 +++++- pkg/service/server.go | 3 + pkg/service/wire.go | 2 + pkg/service/wire_gen.go | 9 +- 12 files changed, 510 insertions(+), 59 deletions(-) create mode 100644 pkg/service/agent_dispatch_service.go diff --git a/go.mod b/go.mod index d7a675538..170969453 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/livekit/livekit-server go 1.22 require ( - github.com/avast/retry-go/v4 v4.6.0 github.com/bep/debounce v1.2.1 github.com/d5/tengo/v2 v2.17.0 github.com/dustin/go-humanize v1.0.1 @@ -20,7 +19,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598 - github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1 + github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309 github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index b45182d54..2140186e4 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEV github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= -github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA= -github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -167,8 +165,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598 h1:yLlkHk2feSLHstD9n4VKg7YEBR4rLODTI4WE8gNBEnQ= github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598/go.mod h1:jwKUCmObuiEDH0iiuJHaGMXwRs3RjrB4G6qqgkr/5oE= -github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1 h1:GP4QtOjYE6zDdtIi8AyM6ukse55HXr0174uOYXxb/H8= -github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1/go.mod h1:oU5XbEaQlywdgXcSQDzrI5CPnwuGn/HuRXuQaDxVryQ= +github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309 h1:iGGiQkgRkDND59LZDPDNsRNBEWF/rpLfaH6BGy4KqWI= +github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309/go.mod h1:oU5XbEaQlywdgXcSQDzrI5CPnwuGn/HuRXuQaDxVryQ= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a h1:EQAHmcYEGlc6V517cQ3Iy0+jHgP6+tM/B4l2vGuLpQo= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= diff --git a/pkg/agent/client.go b/pkg/agent/client.go index df1ff7dfe..95acedb5f 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -43,6 +43,7 @@ type Client interface { // LaunchJob starts a room or participant job on an agent. // it will launch a job once for each worker in each namespace LaunchJob(ctx context.Context, desc *JobRequest) *serverutils.IncrementalDispatcher[*livekit.Job] + TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) Stop() error } @@ -118,7 +119,15 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut jobTypeTopic = PublisherAgentTopic } + var wg sync.WaitGroup ret := serverutils.NewIncrementalDispatcher[*livekit.Job]() + defer func() { + c.workers.Submit(func() { + wg.Wait() + ret.Done() + }) + }() + dispatcher := c.getDispatcher(desc.AgentName, desc.JobType) if dispatcher == nil { @@ -130,7 +139,6 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut return ret } - var wg sync.WaitGroup dispatcher.ForEach(func(curNs string) { topic := GetAgentTopic(desc.AgentName, curNs) @@ -157,14 +165,23 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut ret.Add(job) }) }) - c.workers.Submit(func() { - wg.Wait() - ret.Done() - }) return ret } +func (c *agentClient) TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + resp, err := c.client.JobTerminate(context.Background(), jobID, &rpc.JobTerminateRequest{ + JobId: jobID, + Reason: reason, + }) + if err != nil { + logger.Infow("failed to send job request", "error", err, "jobID", jobID) + return nil, err + } + + return resp.State, nil +} + func (c *agentClient) getDispatcher(agName string, jobType livekit.JobType) *serverutils.IncrementalDispatcher[string] { c.mu.Lock() diff --git a/pkg/agent/worker.go b/pkg/agent/worker.go index 143feee79..7f3f201ae 100644 --- a/pkg/agent/worker.go +++ b/pkg/agent/worker.go @@ -24,7 +24,9 @@ import ( pagent "github.com/livekit/protocol/agent" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" ) type WorkerProtocolVersion int @@ -90,7 +92,7 @@ type Worker struct { protocolVersion WorkerProtocolVersion registered atomic.Bool status livekit.WorkerStatus - runningJobs map[string]*livekit.Job + runningJobs map[string]*livekit.Job // JobID -> Job handler WorkerHandler @@ -145,7 +147,7 @@ func NewWorker( func (w *Worker) sendRequest(req *livekit.ServerMessage) { if _, err := w.conn.WriteServerMessage(req); err != nil { - w.logger.Errorw("error writing to websocket", err) + w.logger.Warnw("error writing to websocket", err) } } @@ -232,6 +234,8 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { return ErrWorkerNotAvailable } + job.State.ParticipantIdentity = res.ParticipantIdentity + token, err := pagent.BuildAgentToken(w.apiKey, w.apiSecret, job.Room.Name, res.ParticipantIdentity, res.ParticipantName, res.ParticipantMetadata, w.permissions) if err != nil { w.logger.Errorw("failed to build agent token", err) @@ -259,6 +263,37 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { } } +func (w *Worker) TerminateJob(jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + w.mu.Lock() + job := w.runningJobs[jobID] + w.mu.Unlock() + + if job == nil { + return nil, psrpc.NewErrorf(psrpc.NotFound, "no running job for given jobID") + } + + w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Termination{ + Termination: &livekit.JobTermination{ + JobId: jobID, + }, + }}) + + status := livekit.JobStatus_JS_SUCCESS + errorStr := "" + if reason == rpc.JobTerminateReason_AGENT_LEFT_ROOM { + status = livekit.JobStatus_JS_FAILED + errorStr = "agent worker left the room" + } + + w.updateJobStatus(&livekit.UpdateJobStatus{ + JobId: jobID, + Status: status, + Error: errorStr, + }) + + return job.State, nil +} + func (w *Worker) UpdateMetadata(metadata string) { w.logger.Debugw("worker metadata updated", nil, "metadata", metadata) } @@ -374,12 +409,21 @@ func (w *Worker) handleAvailability(res *livekit.AvailabilityResponse) { } func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) { + err := w.updateJobStatus(update) + + if err != nil { + w.logger.Infow("received job update for unknown job", "jobID", update.JobId) + } +} + +func (w *Worker) updateJobStatus(update *livekit.UpdateJobStatus) error { w.mu.Lock() job, ok := w.runningJobs[update.JobId] if !ok { - w.logger.Infow("received job update for unknown job", "jobId", update.JobId) - return + w.mu.Unlock() + + return psrpc.NewErrorf(psrpc.NotFound, "received job update for unknown job") } now := time.Now() @@ -403,6 +447,8 @@ func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) { w.mu.Unlock() w.handler.HandleWorkerJobStatus(w, update) + + return nil } func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) { diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 74ddbe81d..8b2933842 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -34,8 +34,10 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" "github.com/livekit/livekit-server/pkg/agent" "github.com/livekit/livekit-server/pkg/config" @@ -61,6 +63,8 @@ const ( var ( // var to allow unit test override roomUpdateInterval = 5 * time.Second // frequency to update room participant counts + + ErrJobShutdownTimeout = psrpc.NewErrorf(psrpc.DeadlineExceeded, "timed out waiting for agent job to shutdown") ) // Duplicate the service.AgentStore interface to avoid a rtc -> service -> rtc import cycle @@ -112,7 +116,7 @@ type Room struct { telemetry telemetry.TelemetryService egressLauncher EgressLauncher trackManager *RoomTrackManager - agentDispatches []*livekit.AgentDispatch + agentDispatches map[string]*agentDispatch // agents agentClient agent.Client @@ -123,6 +127,7 @@ type Room struct { participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource hasPublished map[livekit.ParticipantIdentity]bool + agentParticpants map[livekit.ParticipantIdentity]*agentJob bufferFactory *buffer.FactoryOfBufferFactory // batch update participant info for non-publishers @@ -146,6 +151,88 @@ type ParticipantOptions struct { AutoSubscribe bool } +type agentDispatch struct { + *livekit.AgentDispatch + lock sync.Mutex + pending map[chan struct{}]struct{} +} + +type agentJob struct { + *livekit.Job + lock sync.Mutex + done chan struct{} +} + +// This provides utilities attached the agent dispatch to ensure that all pending jobs are created +// before terminating jobs attached to an agent dispatch. This avoids a race that could cause some pending jobs +// to not be terminated when a dispatch is deleted. +func newAgentDispatch(ad *livekit.AgentDispatch) *agentDispatch { + return &agentDispatch{ + AgentDispatch: ad, + pending: make(map[chan struct{}]struct{}), + } +} + +func (ad *agentDispatch) jobsLaunching() (jobsLaunched func()) { + ad.lock.Lock() + c := make(chan struct{}) + ad.pending[c] = struct{}{} + ad.lock.Unlock() + + return func() { + close(c) + ad.lock.Lock() + delete(ad.pending, c) + ad.lock.Unlock() + } +} + +func (ad *agentDispatch) waitForPendingJobs() { + ad.lock.Lock() + cs := maps.Keys(ad.pending) + ad.lock.Unlock() + + for _, c := range cs { + <-c + } +} + +// This provides utilities to ensure that an agent left the room when killing a job +func newAgentJob(j *livekit.Job) *agentJob { + return &agentJob{ + Job: j, + done: make(chan struct{}), + } +} + +func (j *agentJob) participantLeft() { + j.lock.Lock() + if j.done != nil { + close(j.done) + j.done = nil + } + j.lock.Unlock() +} + +func (j *agentJob) waitForParticipantLeaving() error { + var done chan struct{} + + j.lock.Lock() + done = j.done + j.lock.Unlock() + + if done != nil { + select { + case <-done: + return nil + case <-time.After(3 * time.Second): + return ErrJobShutdownTimeout + } + } + + return nil +} + func NewRoom( room *livekit.Room, internal *livekit.RoomInternal, @@ -172,12 +259,14 @@ func NewRoom( egressLauncher: egressLauncher, agentClient: agentClient, agentStore: agentStore, + agentDispatches: make(map[string]*agentDispatch), trackManager: NewRoomTrackManager(), serverInfo: serverInfo, participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), hasPublished: make(map[livekit.ParticipantIdentity]bool), + agentParticpants: make(map[livekit.ParticipantIdentity]*agentJob), bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSizeVideo, config.Receiver.PacketBufferSizeAudio), batchedUpdates: make(map[livekit.ParticipantIdentity]*participantUpdate), closed: make(chan struct{}), @@ -199,7 +288,7 @@ func NewRoom( r.createAgentDispatchesFromRoomAgent() - r.launchRoomAgents() + r.launchRoomAgents(maps.Values(r.agentDispatches)) go r.audioUpdateWorker() go r.connectionQualityWorker() @@ -576,10 +665,13 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek return } + agentJob := r.agentParticpants[identity] + delete(r.participants, identity) delete(r.participantOpts, identity) delete(r.participantRequestSources, identity) delete(r.hasPublished, identity) + delete(r.agentParticpants, identity) if !p.Hidden() { r.protoRoom.NumParticipants-- } @@ -617,6 +709,17 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek r.trackManager.RemoveTrack(t) } + if agentJob != nil { + agentJob.participantLeft() + + go func() { + _, err := r.agentClient.TerminateJob(context.Background(), agentJob.Id, rpc.JobTerminateReason_AGENT_LEFT_ROOM) + if err != nil { + r.Logger.Infow("failed sending TerminateJob RPC", "error", err, "jobID", agentJob.Id, "participant", identity) + } + }() + } + p.OnTrackUpdated(nil) p.OnTrackPublished(nil) p.OnTrackUnpublished(nil) @@ -853,6 +956,91 @@ func (r *Room) sendRoomUpdate() { } } +func (r *Room) GetAgentDispatches(dispatchID string) ([]*livekit.AgentDispatch, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + var ret []*livekit.AgentDispatch + + for _, ad := range r.agentDispatches { + if dispatchID == "" || ad.Id == dispatchID { + ret = append(ret, proto.Clone(ad.AgentDispatch).(*livekit.AgentDispatch)) + } + } + + return ret, nil +} + +func (r *Room) AddAgentDispatch(agentName string, metadata string) (*livekit.AgentDispatch, error) { + ad, err := r.createAgentDispatchFromParams(agentName, metadata) + if err != nil { + return nil, err + } + + r.launchRoomAgents([]*agentDispatch{ad}) + + r.lock.RLock() + // launchPublisherAgents starts a goroutine to send requests, so is safe to call locked + for _, p := range r.participants { + r.launchPublisherAgents([]*agentDispatch{ad}, p) + } + r.lock.RUnlock() + + return ad.AgentDispatch, nil +} + +func (r *Room) DeleteAgentDispatch(dispatchID string) (*livekit.AgentDispatch, error) { + r.lock.Lock() + ad := r.agentDispatches[dispatchID] + if ad == nil { + r.lock.Unlock() + return nil, psrpc.NewErrorf(psrpc.NotFound, "dispatch ID not found") + } + + delete(r.agentDispatches, dispatchID) + r.lock.Unlock() + + // Should Delete be synchronous instead? + go func() { + ad.waitForPendingJobs() + + var jobs []*livekit.Job + r.lock.RLock() + if ad.State != nil { + jobs = ad.State.Jobs + } + r.lock.RUnlock() + + for _, j := range jobs { + state, err := r.agentClient.TerminateJob(context.Background(), j.Id, rpc.JobTerminateReason_TERINATION_REQUESTED) + if err != nil { + continue + } + if state.ParticipantIdentity != "" { + r.lock.RLock() + agentJob := r.agentParticpants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + p := r.participants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + r.lock.RUnlock() + + if p != nil { + if agentJob != nil { + err := agentJob.waitForParticipantLeaving() + if err == ErrJobShutdownTimeout { + r.Logger.Infow("Agent Worker did not disconnect after 3s") + } + } + r.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonServiceRequestRemoveParticipant) + } + } + r.lock.Lock() + j.State = state + r.lock.Unlock() + } + }() + + return ad.AgentDispatch, nil +} + func (r *Room) OnRoomUpdated(f func()) { r.onRoomUpdated = f } @@ -1013,7 +1201,9 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. r.lock.Unlock() if !hasPublished { - r.launchPublisherAgents(participant) + r.lock.RLock() + r.launchPublisherAgents(maps.Values(r.agentDispatches), participant) + r.lock.RUnlock() if r.internal != nil && r.internal.ParticipantEgress != nil { go func() { if err := StartParticipantEgress( @@ -1422,49 +1612,63 @@ func (r *Room) simulationCleanupWorker() { } } -func (r *Room) launchRoomAgents() { +func (r *Room) launchRoomAgents(ads []*agentDispatch) { if r.agentClient == nil { return } - for _, ag := range r.agentDispatches { + for _, ad := range ads { + done := ad.jobsLaunching() + go func() { inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ JobType: livekit.JobType_JT_ROOM, Room: r.ToProto(), - Metadata: ag.Metadata, - AgentName: ag.AgentName, - DispatchId: ag.Id, - }) - inc.ForEach(func(job *livekit.Job) { - r.agentStore.StoreAgentJob(context.Background(), job) + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() }() } } -func (r *Room) launchPublisherAgents(p types.Participant) { +func (r *Room) launchPublisherAgents(ads []*agentDispatch, p types.Participant) { if p == nil || p.IsDependent() || r.agentClient == nil { return } - for _, ag := range r.agentDispatches { + for _, ad := range ads { + done := ad.jobsLaunching() + go func() { inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ JobType: livekit.JobType_JT_PUBLISHER, Room: r.ToProto(), Participant: p.ToProto(), - Metadata: ag.Metadata, - AgentName: ag.AgentName, - DispatchId: ag.Id, - }) - inc.ForEach(func(job *livekit.Job) { - r.agentStore.StoreAgentJob(context.Background(), job) + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() }() } } +func (r *Room) handleNewJobs(ad *livekit.AgentDispatch, inc *sutils.IncrementalDispatcher[*livekit.Job]) { + inc.ForEach(func(job *livekit.Job) { + r.agentStore.StoreAgentJob(context.Background(), job) + r.lock.Lock() + ad.State.Jobs = append(ad.State.Jobs, job) + if job.State != nil && job.State.ParticipantIdentity != "" { + r.agentParticpants[livekit.ParticipantIdentity(job.State.ParticipantIdentity)] = newAgentJob(job) + } + r.lock.Unlock() + }) +} + func (r *Room) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "Name": r.protoRoom.Name, @@ -1482,8 +1686,34 @@ func (r *Room) DebugInfo() map[string]interface{} { return info } -func (r *Room) createAgentDispatchesFromRoomAgent() { +func (r *Room) createAgentDispatchFromParams(agentName string, metadata string) (*agentDispatch, error) { now := time.Now() + + ad := newAgentDispatch( + &livekit.AgentDispatch{ + Id: guid.New(guid.AgentDispatchPrefix), + AgentName: agentName, + Metadata: metadata, + Room: r.protoRoom.Name, + State: &livekit.AgentDispatchState{ + CreatedAt: now.UnixNano(), + }, + }, + ) + r.lock.RLock() + r.agentDispatches[ad.Id] = ad + r.lock.RUnlock() + if r.agentStore != nil { + err := r.agentStore.StoreAgentDispatch(context.Background(), ad.AgentDispatch) + if err != nil { + return nil, err + } + } + + return ad, nil +} + +func (r *Room) createAgentDispatchesFromRoomAgent() { if r.internal == nil { return } @@ -1497,21 +1727,9 @@ func (r *Room) createAgentDispatchesFromRoomAgent() { } for _, ag := range roomDisp { - ad := &livekit.AgentDispatch{ - Id: guid.New(guid.AgentDispatchPrefix), - AgentName: ag.AgentName, - Metadata: ag.Metadata, - Room: r.protoRoom.Name, - State: &livekit.AgentDispatchState{ - CreatedAt: now.UnixNano(), - }, - } - r.agentDispatches = append(r.agentDispatches, ad) - if r.agentStore != nil { - err := r.agentStore.StoreAgentDispatch(context.Background(), ad) - if err != nil { - r.Logger.Warnw("failed storing room dispatch", err) - } + _, err := r.createAgentDispatchFromParams(ag.AgentName, ag.Metadata) + if err != nil { + r.Logger.Warnw("failed storing room dispatch", err) } } } diff --git a/pkg/service/agent_dispatch_service.go b/pkg/service/agent_dispatch_service.go new file mode 100644 index 000000000..680c98268 --- /dev/null +++ b/pkg/service/agent_dispatch_service.go @@ -0,0 +1,61 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" +) + +type AgentDispatchService struct { + agentDispatchClient rpc.TypedAgentDispatchInternalClient + topicFormatter rpc.TopicFormatter +} + +func NewAgentDispatchService(agentDispatchClient rpc.TypedAgentDispatchInternalClient, topicFormatter rpc.TopicFormatter) *AgentDispatchService { + return &AgentDispatchService{ + agentDispatchClient: agentDispatchClient, + topicFormatter: topicFormatter, + } +} + +func (ag *AgentDispatchService) CreateDispatch(ctx context.Context, req *livekit.CreateAgentDispatchRequest) (*livekit.AgentDispatch, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.CreateDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func (ag *AgentDispatchService) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.DeleteDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func (ag *AgentDispatchService) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.ListDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index d626dbd53..c7c0846c0 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -90,6 +90,7 @@ type AgentHandler struct { serverInfo *livekit.ServerInfo workers map[string]*agent.Worker + jobToWorker map[string]*agent.Worker keyProvider auth.KeyProvider namespaceWorkers map[workerKey][]*agent.Worker @@ -163,6 +164,7 @@ func NewAgentHandler( agentServer: agentServer, logger: logger, workers: make(map[string]*agent.Worker), + jobToWorker: make(map[string]*agent.Worker), namespaceWorkers: make(map[workerKey][]*agent.Worker), serverInfo: serverInfo, keyProvider: keyProvider, @@ -302,6 +304,27 @@ func (h *AgentHandler) HandleWorkerDeregister(w *agent.Worker) { h.agentNames = slices.Delete(h.agentNames, i, i+1) } } + + jobs := w.RunningJobs() + for _, j := range jobs { + h.deregisterJob(j.Id) + } +} + +func (h *AgentHandler) HandleWorkerJobStatus(w *agent.Worker, status *livekit.UpdateJobStatus) { + if agent.JobStatusIsEnded(status.Status) { + h.mu.Lock() + h.deregisterJob(status.JobId) + h.mu.Unlock() + } +} + +func (h *AgentHandler) deregisterJob(jobID string) { + h.agentServer.DeregisterJobTerminateTopic(jobID) + + delete(h.jobToWorker, jobID) + + // TODO update dispatch state } func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.JobRequestResponse, error) { @@ -335,6 +358,14 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.J } return nil, err } + h.mu.Lock() + h.jobToWorker[job.Id] = selected + h.mu.Unlock() + + err = h.agentServer.RegisterJobTerminateTopic(job.Id) + if err != nil { + h.logger.Errorw("failes registering JobTerminate handler", err, values...) + } return &rpc.JobRequestResponse{ State: job.State, @@ -367,6 +398,25 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) return affinity } +func (h *AgentHandler) JobTerminate(ctx context.Context, req *rpc.JobTerminateRequest) (*rpc.JobTerminateResponse, error) { + h.mu.Lock() + w := h.jobToWorker[req.JobId] + h.mu.Unlock() + + if w == nil { + return nil, psrpc.NewErrorf(psrpc.NotFound, "no worker for jobID") + } + + state, err := w.TerminateJob(req.JobId, req.Reason) + if err != nil { + return nil, err + } + + return &rpc.JobTerminateResponse{ + State: state, + }, nil +} + func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { h.mu.Lock() defer h.mu.Unlock() diff --git a/pkg/service/ingress.go b/pkg/service/ingress.go index e0a08b720..fcee84634 100644 --- a/pkg/service/ingress.go +++ b/pkg/service/ingress.go @@ -41,7 +41,6 @@ type IngressService struct { psrpcClient rpc.IngressClient store IngressStore io IOClient - roomService livekit.RoomService telemetry telemetry.TelemetryService launcher IngressLauncher } @@ -53,7 +52,6 @@ func NewIngressServiceWithIngressLauncher( psrpcClient rpc.IngressClient, store IngressStore, io IOClient, - rs livekit.RoomService, ts telemetry.TelemetryService, launcher IngressLauncher, ) *IngressService { @@ -65,7 +63,6 @@ func NewIngressServiceWithIngressLauncher( psrpcClient: psrpcClient, store: store, io: io, - roomService: rs, telemetry: ts, launcher: launcher, } @@ -78,10 +75,9 @@ func NewIngressService( psrpcClient rpc.IngressClient, store IngressStore, io IOClient, - rs livekit.RoomService, ts telemetry.TelemetryService, ) *IngressService { - s := NewIngressServiceWithIngressLauncher(conf, nodeID, bus, psrpcClient, store, io, rs, ts, nil) + s := NewIngressServiceWithIngressLauncher(conf, nodeID, bus, psrpcClient, store, io, ts, nil) s.launcher = s diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 06f783a43..7024d9fe9 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -82,8 +82,9 @@ type RoomManager struct { rooms map[livekit.RoomName]*rtc.Room - roomServers utils.MultitonService[rpc.RoomTopic] - participantServers utils.MultitonService[rpc.ParticipantTopic] + roomServers utils.MultitonService[rpc.RoomTopic] + agentDispatchServers utils.MultitonService[rpc.RoomTopic] + participantServers utils.MultitonService[rpc.ParticipantTopic] iceConfigCache *sutils.IceConfigCache[iceConfigCacheKey] @@ -229,6 +230,7 @@ func (r *RoomManager) Stop() { } r.roomServers.Kill() + r.agentDispatchServers.Kill() r.participantServers.Kill() if r.rtcConfig != nil { @@ -570,9 +572,17 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room r.lock.Unlock() return nil, err } + agentDispatchServer := must.Get(rpc.NewTypedAgentDispatchInternalServer(r, r.bus)) + killDispServer := r.agentDispatchServers.Replace(roomTopic, agentDispatchServer) + if err := agentDispatchServer.RegisterAllRoomTopics(roomTopic); err != nil { + killDispServer() + r.lock.Unlock() + return nil, err + } newRoom.OnClose(func() { killRoomServer() + killDispServer() roomInfo := newRoom.ToProto() r.telemetry.RoomEnded(ctx, roomInfo) @@ -807,6 +817,52 @@ func (r *RoomManager) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return room.ToProto(), nil } +func (r *RoomManager) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.GetAgentDispatches(req.DispatchId) + if err != nil { + return nil, err + } + + ret := &livekit.ListAgentDispatchResponse{ + AgentDispatches: disp, + } + + return ret, nil +} + +func (r *RoomManager) CreateDispatch(ctx context.Context, req *livekit.CreateAgentDispatchRequest) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.AddAgentDispatch(req.AgentName, req.Metadata) + if err != nil { + return nil, err + } + + return disp, nil +} + +func (r *RoomManager) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.DeleteAgentDispatch(req.DispatchId) + if err != nil { + return nil, err + } + + return disp, nil +} + func (r *RoomManager) iceServersForParticipant(apiKey string, participant types.LocalParticipant, tlsOnly bool) []*livekit.ICEServer { var iceServers []*livekit.ICEServer rtcConf := r.config.RTC diff --git a/pkg/service/server.go b/pkg/service/server.go index 88498518e..bfdd5d4e5 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -62,6 +62,7 @@ type LivekitServer struct { func NewLivekitServer(conf *config.Config, roomService livekit.RoomService, + agentDispatchService *AgentDispatchService, egressService *EgressService, ingressService *IngressService, sipService *SIPService, @@ -109,6 +110,7 @@ func NewLivekitServer(conf *config.Config, twirpLoggingHook := TwirpLogger() twirpRequestStatusHook := TwirpRequestStatusReporter() roomServer := livekit.NewRoomServiceServer(roomService, twirpLoggingHook) + agentDispatchServer := livekit.NewAgentDispatchServiceServer(agentDispatchService, twirpLoggingHook) egressServer := livekit.NewEgressServer(egressService, twirp.WithServerHooks( twirp.ChainHooks( twirpLoggingHook, @@ -127,6 +129,7 @@ func NewLivekitServer(conf *config.Config, } mux.Handle(roomServer.PathPrefix(), roomServer) + mux.Handle(agentDispatchServer.PathPrefix(), agentDispatchServer) mux.Handle(egressServer.PathPrefix(), egressServer) mux.Handle(ingressServer.PathPrefix(), ingressServer) mux.Handle(sipServer.PathPrefix(), sipServer) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 75b3509a1..56d601bf6 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -79,6 +79,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewRoomService, NewRTCService, NewAgentService, + NewAgentDispatchService, agent.NewAgentClient, getAgentStore, getSignalRelayConfig, @@ -90,6 +91,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live rpc.NewTopicFormatter, rpc.NewTypedRoomClient, rpc.NewTypedParticipantClient, + rpc.NewTypedAgentDispatchInternalClient, NewLocalRoomManager, NewTURNAuthHandler, getTURNAuthHandlerFunc, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 15f81061f..f27f4dde6 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -100,13 +100,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } + agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) + if err != nil { + return nil, err + } + agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter) egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(clientParams) if err != nil { return nil, err } - ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, ioInfoService, roomService, telemetryService) + ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, ioInfoService, telemetryService) sipConfig := getSIPConfig(conf) sipClient, err := rpc.NewSIPClient(messageBus) if err != nil { @@ -136,7 +141,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, sipService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, agentDispatchService, egressService, ingressService, sipService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) if err != nil { return nil, err } From d0ac19779e1dacb5900877bec18a643b934e8156 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 10 Aug 2024 15:42:14 +0530 Subject: [PATCH 10/10] Reset DD tracker layers when muted. (#2920) * Reset DD tracker layers when muted. @cnderrauber, I think this is okay to do, but please let me know if there are gotchas in there. * copy * more compact form --- pkg/sfu/streamallocator/streamallocator.go | 4 +--- pkg/sfu/streamtracker/streamtracker_dd.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pkg/sfu/streamallocator/streamallocator.go b/pkg/sfu/streamallocator/streamallocator.go index caeb594db..e1595b427 100644 --- a/pkg/sfu/streamallocator/streamallocator.go +++ b/pkg/sfu/streamallocator/streamallocator.go @@ -545,9 +545,7 @@ func (s *StreamAllocator) maybePostEventAllocateTrack(downTrack *sfu.DownTrack) shouldPost := false s.videoTracksMu.Lock() if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { - if track.SetDirty(true) { - shouldPost = true - } + shouldPost = track.SetDirty(true) } s.videoTracksMu.Unlock() diff --git a/pkg/sfu/streamtracker/streamtracker_dd.go b/pkg/sfu/streamtracker/streamtracker_dd.go index 29b876c70..5feb33d9e 100644 --- a/pkg/sfu/streamtracker/streamtracker_dd.go +++ b/pkg/sfu/streamtracker/streamtracker_dd.go @@ -117,6 +117,9 @@ func (s *StreamTrackerDependencyDescriptor) resetLocked() { s.bitrate[i][j] = 0 } } + + s.maxSpatialLayer = buffer.InvalidLayerSpatial + s.maxTemporalLayer = buffer.InvalidLayerTemporal } func (s *StreamTrackerDependencyDescriptor) SetPaused(paused bool) { @@ -126,8 +129,14 @@ func (s *StreamTrackerDependencyDescriptor) SetPaused(paused bool) { return } s.paused = paused + + var notifyFns []func(status StreamStatus) + var notifyStatus StreamStatus if !paused { s.resetLocked() + + notifyStatus = StreamStatusStopped + notifyFns = append(notifyFns, s.onStatusChanged[:]...) } else { s.lastBitrateReport = time.Now() go s.worker(s.generation.Inc()) @@ -135,6 +144,11 @@ func (s *StreamTrackerDependencyDescriptor) SetPaused(paused bool) { } s.lock.Unlock() + for _, fn := range notifyFns { + if fn != nil { + fn(notifyStatus) + } + } } func (s *StreamTrackerDependencyDescriptor) Observe(temporalLayer int32, pktSize int, payloadSize int, hasMarker bool, ts uint32, ddVal *buffer.ExtDependencyDescriptor) {