diff --git a/go.mod b/go.mod index d7a675538..170969453 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/livekit/livekit-server go 1.22 require ( - github.com/avast/retry-go/v4 v4.6.0 github.com/bep/debounce v1.2.1 github.com/d5/tengo/v2 v2.17.0 github.com/dustin/go-humanize v1.0.1 @@ -20,7 +19,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598 - github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1 + github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309 github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a github.com/mackerelio/go-osstat v0.2.5 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index b45182d54..2140186e4 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEV github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= -github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA= -github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -167,8 +165,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598 
h1:yLlkHk2feSLHstD9n4VKg7YEBR4rLODTI4WE8gNBEnQ= github.com/livekit/mediatransportutil v0.0.0-20240730083616-559fa5ece598/go.mod h1:jwKUCmObuiEDH0iiuJHaGMXwRs3RjrB4G6qqgkr/5oE= -github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1 h1:GP4QtOjYE6zDdtIi8AyM6ukse55HXr0174uOYXxb/H8= -github.com/livekit/protocol v1.19.4-0.20240805121416-5be7cb358ec1/go.mod h1:oU5XbEaQlywdgXcSQDzrI5CPnwuGn/HuRXuQaDxVryQ= +github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309 h1:iGGiQkgRkDND59LZDPDNsRNBEWF/rpLfaH6BGy4KqWI= +github.com/livekit/protocol v1.19.4-0.20240808180722-581b59b65309/go.mod h1:oU5XbEaQlywdgXcSQDzrI5CPnwuGn/HuRXuQaDxVryQ= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a h1:EQAHmcYEGlc6V517cQ3Iy0+jHgP6+tM/B4l2vGuLpQo= github.com/livekit/psrpc v0.5.3-0.20240616012458-ac39c8549a0a/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.5 h1:+MqTbZUhoIt4m8qzkVoXUJg1EuifwlAJSk4Yl2GXh+o= diff --git a/pkg/agent/client.go b/pkg/agent/client.go index df1ff7dfe..95acedb5f 100644 --- a/pkg/agent/client.go +++ b/pkg/agent/client.go @@ -43,6 +43,7 @@ type Client interface { // LaunchJob starts a room or participant job on an agent. 
// it will launch a job once for each worker in each namespace LaunchJob(ctx context.Context, desc *JobRequest) *serverutils.IncrementalDispatcher[*livekit.Job] + TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) Stop() error } @@ -118,7 +119,15 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut jobTypeTopic = PublisherAgentTopic } + var wg sync.WaitGroup ret := serverutils.NewIncrementalDispatcher[*livekit.Job]() + defer func() { + c.workers.Submit(func() { + wg.Wait() + ret.Done() + }) + }() + dispatcher := c.getDispatcher(desc.AgentName, desc.JobType) if dispatcher == nil { @@ -130,7 +139,6 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut return ret } - var wg sync.WaitGroup dispatcher.ForEach(func(curNs string) { topic := GetAgentTopic(desc.AgentName, curNs) @@ -157,14 +165,23 @@ func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverut ret.Add(job) }) }) - c.workers.Submit(func() { - wg.Wait() - ret.Done() - }) return ret } +func (c *agentClient) TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + resp, err := c.client.JobTerminate(ctx, jobID, &rpc.JobTerminateRequest{ // use the caller's ctx so deadlines and cancellation propagate to the RPC + JobId: jobID, + Reason: reason, + }) + if err != nil { + logger.Infow("failed to send job terminate request", "error", err, "jobID", jobID) + return nil, err + } + + return resp.State, nil +} + func (c *agentClient) getDispatcher(agName string, jobType livekit.JobType) *serverutils.IncrementalDispatcher[string] { c.mu.Lock() diff --git a/pkg/agent/worker.go b/pkg/agent/worker.go index 143feee79..7f3f201ae 100644 --- a/pkg/agent/worker.go +++ b/pkg/agent/worker.go @@ -24,7 +24,9 @@ import ( pagent "github.com/livekit/protocol/agent" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils/guid" + 
"github.com/livekit/psrpc" ) type WorkerProtocolVersion int @@ -90,7 +92,7 @@ type Worker struct { protocolVersion WorkerProtocolVersion registered atomic.Bool status livekit.WorkerStatus - runningJobs map[string]*livekit.Job + runningJobs map[string]*livekit.Job // JobID -> Job handler WorkerHandler @@ -145,7 +147,7 @@ func NewWorker( func (w *Worker) sendRequest(req *livekit.ServerMessage) { if _, err := w.conn.WriteServerMessage(req); err != nil { - w.logger.Errorw("error writing to websocket", err) + w.logger.Warnw("error writing to websocket", err) } } @@ -232,6 +234,8 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { return ErrWorkerNotAvailable } + job.State.ParticipantIdentity = res.ParticipantIdentity + token, err := pagent.BuildAgentToken(w.apiKey, w.apiSecret, job.Room.Name, res.ParticipantIdentity, res.ParticipantName, res.ParticipantMetadata, w.permissions) if err != nil { w.logger.Errorw("failed to build agent token", err) @@ -259,6 +263,37 @@ func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) error { } } +func (w *Worker) TerminateJob(jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + w.mu.Lock() + job := w.runningJobs[jobID] + w.mu.Unlock() + + if job == nil { + return nil, psrpc.NewErrorf(psrpc.NotFound, "no running job for given jobID") + } + + w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Termination{ + Termination: &livekit.JobTermination{ + JobId: jobID, + }, + }}) + + status := livekit.JobStatus_JS_SUCCESS + errorStr := "" + if reason == rpc.JobTerminateReason_AGENT_LEFT_ROOM { + status = livekit.JobStatus_JS_FAILED + errorStr = "agent worker left the room" + } + + w.updateJobStatus(&livekit.UpdateJobStatus{ + JobId: jobID, + Status: status, + Error: errorStr, + }) + + return job.State, nil +} + func (w *Worker) UpdateMetadata(metadata string) { w.logger.Debugw("worker metadata updated", nil, "metadata", metadata) } @@ -374,12 +409,21 @@ func (w *Worker) 
handleAvailability(res *livekit.AvailabilityResponse) { } func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) { + err := w.updateJobStatus(update) + + if err != nil { + w.logger.Infow("received job update for unknown job", "jobID", update.JobId) + } +} + +func (w *Worker) updateJobStatus(update *livekit.UpdateJobStatus) error { w.mu.Lock() job, ok := w.runningJobs[update.JobId] if !ok { - w.logger.Infow("received job update for unknown job", "jobId", update.JobId) - return + w.mu.Unlock() + + return psrpc.NewErrorf(psrpc.NotFound, "received job update for unknown job") } now := time.Now() @@ -403,6 +447,8 @@ func (w *Worker) handleJobUpdate(update *livekit.UpdateJobStatus) { w.mu.Unlock() w.handler.HandleWorkerJobStatus(w, update) + + return nil } func (w *Worker) handleSimulateJob(simulate *livekit.SimulateJobRequest) { diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 74ddbe81d..8b2933842 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -34,8 +34,10 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" "github.com/livekit/livekit-server/pkg/agent" "github.com/livekit/livekit-server/pkg/config" @@ -61,6 +63,8 @@ const ( var ( // var to allow unit test override roomUpdateInterval = 5 * time.Second // frequency to update room participant counts + + ErrJobShutdownTimeout = psrpc.NewErrorf(psrpc.DeadlineExceeded, "timed out waiting for agent job to shutdown") ) // Duplicate the service.AgentStore interface to avoid a rtc -> service -> rtc import cycle @@ -112,7 +116,7 @@ type Room struct { telemetry telemetry.TelemetryService egressLauncher EgressLauncher trackManager *RoomTrackManager - agentDispatches []*livekit.AgentDispatch + agentDispatches map[string]*agentDispatch // agents agentClient agent.Client @@ -123,6 +127,7 @@ type Room struct { participantOpts 
map[livekit.ParticipantIdentity]*ParticipantOptions participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource hasPublished map[livekit.ParticipantIdentity]bool + agentParticpants map[livekit.ParticipantIdentity]*agentJob bufferFactory *buffer.FactoryOfBufferFactory // batch update participant info for non-publishers @@ -146,6 +151,88 @@ type ParticipantOptions struct { AutoSubscribe bool } +type agentDispatch struct { + *livekit.AgentDispatch + lock sync.Mutex + pending map[chan struct{}]struct{} +} + +type agentJob struct { + *livekit.Job + lock sync.Mutex + done chan struct{} +} + +// This provides utilities attached to the agent dispatch to ensure that all pending jobs are created +// before terminating jobs attached to an agent dispatch. This avoids a race that could cause some pending jobs +// to not be terminated when a dispatch is deleted. +func newAgentDispatch(ad *livekit.AgentDispatch) *agentDispatch { + return &agentDispatch{ + AgentDispatch: ad, + pending: make(map[chan struct{}]struct{}), + } +} + +func (ad *agentDispatch) jobsLaunching() (jobsLaunched func()) { + ad.lock.Lock() + c := make(chan struct{}) + ad.pending[c] = struct{}{} + ad.lock.Unlock() + + return func() { + close(c) + ad.lock.Lock() + delete(ad.pending, c) + ad.lock.Unlock() + } +} + +func (ad *agentDispatch) waitForPendingJobs() { + ad.lock.Lock() + cs := maps.Keys(ad.pending) + ad.lock.Unlock() + + for _, c := range cs { + <-c + } +} + +// This provides utilities to ensure that an agent has left the room when killing a job +func newAgentJob(j *livekit.Job) *agentJob { + return &agentJob{ + Job: j, + done: make(chan struct{}), + } +} + +func (j *agentJob) participantLeft() { + j.lock.Lock() + if j.done != nil { + close(j.done) + j.done = nil + } + j.lock.Unlock() +} + +func (j *agentJob) waitForParticipantLeaving() error { + var done chan struct{} + + j.lock.Lock() + done = j.done + j.lock.Unlock() + + if done != nil { + select { + case <-done: + return nil + case 
<-time.After(3 * time.Second): + return ErrJobShutdownTimeout + } + } + + return nil +} + func NewRoom( room *livekit.Room, internal *livekit.RoomInternal, @@ -172,12 +259,14 @@ func NewRoom( egressLauncher: egressLauncher, agentClient: agentClient, agentStore: agentStore, + agentDispatches: make(map[string]*agentDispatch), trackManager: NewRoomTrackManager(), serverInfo: serverInfo, participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), hasPublished: make(map[livekit.ParticipantIdentity]bool), + agentParticpants: make(map[livekit.ParticipantIdentity]*agentJob), bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSizeVideo, config.Receiver.PacketBufferSizeAudio), batchedUpdates: make(map[livekit.ParticipantIdentity]*participantUpdate), closed: make(chan struct{}), @@ -199,7 +288,7 @@ func NewRoom( r.createAgentDispatchesFromRoomAgent() - r.launchRoomAgents() + r.launchRoomAgents(maps.Values(r.agentDispatches)) go r.audioUpdateWorker() go r.connectionQualityWorker() @@ -576,10 +665,13 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek return } + agentJob := r.agentParticpants[identity] + delete(r.participants, identity) delete(r.participantOpts, identity) delete(r.participantRequestSources, identity) delete(r.hasPublished, identity) + delete(r.agentParticpants, identity) if !p.Hidden() { r.protoRoom.NumParticipants-- } @@ -617,6 +709,17 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek r.trackManager.RemoveTrack(t) } + if agentJob != nil { + agentJob.participantLeft() + + go func() { + _, err := r.agentClient.TerminateJob(context.Background(), agentJob.Id, rpc.JobTerminateReason_AGENT_LEFT_ROOM) + if err != nil { + r.Logger.Infow("failed sending TerminateJob RPC", "error", err, "jobID", 
agentJob.Id, "participant", identity) + } + }() + } + p.OnTrackUpdated(nil) p.OnTrackPublished(nil) p.OnTrackUnpublished(nil) @@ -853,6 +956,91 @@ func (r *Room) sendRoomUpdate() { } } +func (r *Room) GetAgentDispatches(dispatchID string) ([]*livekit.AgentDispatch, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + var ret []*livekit.AgentDispatch + + for _, ad := range r.agentDispatches { + if dispatchID == "" || ad.Id == dispatchID { + ret = append(ret, proto.Clone(ad.AgentDispatch).(*livekit.AgentDispatch)) + } + } + + return ret, nil +} + +func (r *Room) AddAgentDispatch(agentName string, metadata string) (*livekit.AgentDispatch, error) { + ad, err := r.createAgentDispatchFromParams(agentName, metadata) + if err != nil { + return nil, err + } + + r.launchRoomAgents([]*agentDispatch{ad}) + + r.lock.RLock() + // launchPublisherAgents starts a goroutine to send requests, so is safe to call locked + for _, p := range r.participants { + r.launchPublisherAgents([]*agentDispatch{ad}, p) + } + r.lock.RUnlock() + + return ad.AgentDispatch, nil +} + +func (r *Room) DeleteAgentDispatch(dispatchID string) (*livekit.AgentDispatch, error) { + r.lock.Lock() + ad := r.agentDispatches[dispatchID] + if ad == nil { + r.lock.Unlock() + return nil, psrpc.NewErrorf(psrpc.NotFound, "dispatch ID not found") + } + + delete(r.agentDispatches, dispatchID) + r.lock.Unlock() + + // Should Delete be synchronous instead? 
+ go func() { + ad.waitForPendingJobs() + + var jobs []*livekit.Job + r.lock.RLock() + if ad.State != nil { + jobs = ad.State.Jobs + } + r.lock.RUnlock() + + for _, j := range jobs { + state, err := r.agentClient.TerminateJob(context.Background(), j.Id, rpc.JobTerminateReason_TERINATION_REQUESTED) + if err != nil { + continue + } + if state.ParticipantIdentity != "" { + r.lock.RLock() + agentJob := r.agentParticpants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + p := r.participants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + r.lock.RUnlock() + + if p != nil { + if agentJob != nil { + err := agentJob.waitForParticipantLeaving() + if err == ErrJobShutdownTimeout { + r.Logger.Infow("Agent Worker did not disconnect after 3s") + } + } + r.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonServiceRequestRemoveParticipant) + } + } + r.lock.Lock() + j.State = state + r.lock.Unlock() + } + }() + + return ad.AgentDispatch, nil +} + func (r *Room) OnRoomUpdated(f func()) { r.onRoomUpdated = f } @@ -1013,7 +1201,9 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. 
r.lock.Unlock() if !hasPublished { - r.launchPublisherAgents(participant) + r.lock.RLock() + r.launchPublisherAgents(maps.Values(r.agentDispatches), participant) + r.lock.RUnlock() if r.internal != nil && r.internal.ParticipantEgress != nil { go func() { if err := StartParticipantEgress( @@ -1422,49 +1612,63 @@ func (r *Room) simulationCleanupWorker() { } } -func (r *Room) launchRoomAgents() { +func (r *Room) launchRoomAgents(ads []*agentDispatch) { if r.agentClient == nil { return } - for _, ag := range r.agentDispatches { + for _, ad := range ads { + done := ad.jobsLaunching() + go func() { inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ JobType: livekit.JobType_JT_ROOM, Room: r.ToProto(), - Metadata: ag.Metadata, - AgentName: ag.AgentName, - DispatchId: ag.Id, - }) - inc.ForEach(func(job *livekit.Job) { - r.agentStore.StoreAgentJob(context.Background(), job) + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() }() } } -func (r *Room) launchPublisherAgents(p types.Participant) { +func (r *Room) launchPublisherAgents(ads []*agentDispatch, p types.Participant) { if p == nil || p.IsDependent() || r.agentClient == nil { return } - for _, ag := range r.agentDispatches { + for _, ad := range ads { + done := ad.jobsLaunching() + go func() { inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ JobType: livekit.JobType_JT_PUBLISHER, Room: r.ToProto(), Participant: p.ToProto(), - Metadata: ag.Metadata, - AgentName: ag.AgentName, - DispatchId: ag.Id, - }) - inc.ForEach(func(job *livekit.Job) { - r.agentStore.StoreAgentJob(context.Background(), job) + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() }() } } +func (r *Room) handleNewJobs(ad *livekit.AgentDispatch, inc *sutils.IncrementalDispatcher[*livekit.Job]) { + inc.ForEach(func(job *livekit.Job) { + 
r.agentStore.StoreAgentJob(context.Background(), job) + r.lock.Lock() + ad.State.Jobs = append(ad.State.Jobs, job) + if job.State != nil && job.State.ParticipantIdentity != "" { + r.agentParticpants[livekit.ParticipantIdentity(job.State.ParticipantIdentity)] = newAgentJob(job) + } + r.lock.Unlock() + }) +} + func (r *Room) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "Name": r.protoRoom.Name, @@ -1482,8 +1686,34 @@ func (r *Room) DebugInfo() map[string]interface{} { return info } -func (r *Room) createAgentDispatchesFromRoomAgent() { +func (r *Room) createAgentDispatchFromParams(agentName string, metadata string) (*agentDispatch, error) { now := time.Now() + + ad := newAgentDispatch( + &livekit.AgentDispatch{ + Id: guid.New(guid.AgentDispatchPrefix), + AgentName: agentName, + Metadata: metadata, + Room: r.protoRoom.Name, + State: &livekit.AgentDispatchState{ + CreatedAt: now.UnixNano(), + }, + }, + ) + r.lock.Lock() // write lock: the dispatch map is mutated here, a read lock does not exclude other writers or readers + r.agentDispatches[ad.Id] = ad + r.lock.Unlock() + if r.agentStore != nil { + err := r.agentStore.StoreAgentDispatch(context.Background(), ad.AgentDispatch) + if err != nil { + return nil, err + } + } + + return ad, nil +} + +func (r *Room) createAgentDispatchesFromRoomAgent() { if r.internal == nil { return } @@ -1497,21 +1727,9 @@ func (r *Room) createAgentDispatchesFromRoomAgent() { } for _, ag := range roomDisp { - ad := &livekit.AgentDispatch{ - Id: guid.New(guid.AgentDispatchPrefix), - AgentName: ag.AgentName, - Metadata: ag.Metadata, - Room: r.protoRoom.Name, - State: &livekit.AgentDispatchState{ - CreatedAt: now.UnixNano(), - }, - } - r.agentDispatches = append(r.agentDispatches, ad) - if r.agentStore != nil { - err := r.agentStore.StoreAgentDispatch(context.Background(), ad) - if err != nil { - r.Logger.Warnw("failed storing room dispatch", err) - } + _, err := r.createAgentDispatchFromParams(ag.AgentName, ag.Metadata) + if err != nil { + r.Logger.Warnw("failed storing room dispatch", err) } } } diff --git 
a/pkg/service/agent_dispatch_service.go b/pkg/service/agent_dispatch_service.go new file mode 100644 index 000000000..680c98268 --- /dev/null +++ b/pkg/service/agent_dispatch_service.go @@ -0,0 +1,61 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" +) + +type AgentDispatchService struct { + agentDispatchClient rpc.TypedAgentDispatchInternalClient + topicFormatter rpc.TopicFormatter +} + +func NewAgentDispatchService(agentDispatchClient rpc.TypedAgentDispatchInternalClient, topicFormatter rpc.TopicFormatter) *AgentDispatchService { + return &AgentDispatchService{ + agentDispatchClient: agentDispatchClient, + topicFormatter: topicFormatter, + } +} + +func (ag *AgentDispatchService) CreateDispatch(ctx context.Context, req *livekit.CreateAgentDispatchRequest) (*livekit.AgentDispatch, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.CreateDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func (ag *AgentDispatchService) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return 
ag.agentDispatchClient.DeleteDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func (ag *AgentDispatchService) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.ListDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index d626dbd53..c7c0846c0 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -90,6 +90,7 @@ type AgentHandler struct { serverInfo *livekit.ServerInfo workers map[string]*agent.Worker + jobToWorker map[string]*agent.Worker keyProvider auth.KeyProvider namespaceWorkers map[workerKey][]*agent.Worker @@ -163,6 +164,7 @@ func NewAgentHandler( agentServer: agentServer, logger: logger, workers: make(map[string]*agent.Worker), + jobToWorker: make(map[string]*agent.Worker), namespaceWorkers: make(map[workerKey][]*agent.Worker), serverInfo: serverInfo, keyProvider: keyProvider, @@ -302,6 +304,27 @@ func (h *AgentHandler) HandleWorkerDeregister(w *agent.Worker) { h.agentNames = slices.Delete(h.agentNames, i, i+1) } } + + jobs := w.RunningJobs() + for _, j := range jobs { + h.deregisterJob(j.Id) + } +} + +func (h *AgentHandler) HandleWorkerJobStatus(w *agent.Worker, status *livekit.UpdateJobStatus) { + if agent.JobStatusIsEnded(status.Status) { + h.mu.Lock() + h.deregisterJob(status.JobId) + h.mu.Unlock() + } +} + +func (h *AgentHandler) deregisterJob(jobID string) { + h.agentServer.DeregisterJobTerminateTopic(jobID) + + delete(h.jobToWorker, jobID) + + // TODO update dispatch state } func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.JobRequestResponse, error) { @@ -335,6 +358,14 @@ func (h *AgentHandler) JobRequest(ctx context.Context, job 
*livekit.Job) (*rpc.J } return nil, err } + h.mu.Lock() + h.jobToWorker[job.Id] = selected + h.mu.Unlock() + + err = h.agentServer.RegisterJobTerminateTopic(job.Id) + if err != nil { + h.logger.Errorw("failes registering JobTerminate handler", err, values...) + } return &rpc.JobRequestResponse{ State: job.State, @@ -367,6 +398,25 @@ func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) return affinity } +func (h *AgentHandler) JobTerminate(ctx context.Context, req *rpc.JobTerminateRequest) (*rpc.JobTerminateResponse, error) { + h.mu.Lock() + w := h.jobToWorker[req.JobId] + h.mu.Unlock() + + if w == nil { + return nil, psrpc.NewErrorf(psrpc.NotFound, "no worker for jobID") + } + + state, err := w.TerminateJob(req.JobId, req.Reason) + if err != nil { + return nil, err + } + + return &rpc.JobTerminateResponse{ + State: state, + }, nil +} + func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { h.mu.Lock() defer h.mu.Unlock() diff --git a/pkg/service/ingress.go b/pkg/service/ingress.go index e0a08b720..fcee84634 100644 --- a/pkg/service/ingress.go +++ b/pkg/service/ingress.go @@ -41,7 +41,6 @@ type IngressService struct { psrpcClient rpc.IngressClient store IngressStore io IOClient - roomService livekit.RoomService telemetry telemetry.TelemetryService launcher IngressLauncher } @@ -53,7 +52,6 @@ func NewIngressServiceWithIngressLauncher( psrpcClient rpc.IngressClient, store IngressStore, io IOClient, - rs livekit.RoomService, ts telemetry.TelemetryService, launcher IngressLauncher, ) *IngressService { @@ -65,7 +63,6 @@ func NewIngressServiceWithIngressLauncher( psrpcClient: psrpcClient, store: store, io: io, - roomService: rs, telemetry: ts, launcher: launcher, } @@ -78,10 +75,9 @@ func NewIngressService( psrpcClient rpc.IngressClient, store IngressStore, io IOClient, - rs livekit.RoomService, ts telemetry.TelemetryService, ) *IngressService { - s := 
NewIngressServiceWithIngressLauncher(conf, nodeID, bus, psrpcClient, store, io, rs, ts, nil) + s := NewIngressServiceWithIngressLauncher(conf, nodeID, bus, psrpcClient, store, io, ts, nil) s.launcher = s diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 06f783a43..7024d9fe9 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -82,8 +82,9 @@ type RoomManager struct { rooms map[livekit.RoomName]*rtc.Room - roomServers utils.MultitonService[rpc.RoomTopic] - participantServers utils.MultitonService[rpc.ParticipantTopic] + roomServers utils.MultitonService[rpc.RoomTopic] + agentDispatchServers utils.MultitonService[rpc.RoomTopic] + participantServers utils.MultitonService[rpc.ParticipantTopic] iceConfigCache *sutils.IceConfigCache[iceConfigCacheKey] @@ -229,6 +230,7 @@ func (r *RoomManager) Stop() { } r.roomServers.Kill() + r.agentDispatchServers.Kill() r.participantServers.Kill() if r.rtcConfig != nil { @@ -570,9 +572,17 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room r.lock.Unlock() return nil, err } + agentDispatchServer := must.Get(rpc.NewTypedAgentDispatchInternalServer(r, r.bus)) + killDispServer := r.agentDispatchServers.Replace(roomTopic, agentDispatchServer) + if err := agentDispatchServer.RegisterAllRoomTopics(roomTopic); err != nil { + killDispServer() + r.lock.Unlock() + return nil, err + } newRoom.OnClose(func() { killRoomServer() + killDispServer() roomInfo := newRoom.ToProto() r.telemetry.RoomEnded(ctx, roomInfo) @@ -807,6 +817,52 @@ func (r *RoomManager) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return room.ToProto(), nil } +func (r *RoomManager) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.GetAgentDispatches(req.DispatchId) + if err != nil { + return nil, 
err + } + + ret := &livekit.ListAgentDispatchResponse{ + AgentDispatches: disp, + } + + return ret, nil +} + +func (r *RoomManager) CreateDispatch(ctx context.Context, req *livekit.CreateAgentDispatchRequest) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.AddAgentDispatch(req.AgentName, req.Metadata) + if err != nil { + return nil, err + } + + return disp, nil +} + +func (r *RoomManager) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.DeleteAgentDispatch(req.DispatchId) + if err != nil { + return nil, err + } + + return disp, nil +} + func (r *RoomManager) iceServersForParticipant(apiKey string, participant types.LocalParticipant, tlsOnly bool) []*livekit.ICEServer { var iceServers []*livekit.ICEServer rtcConf := r.config.RTC diff --git a/pkg/service/server.go b/pkg/service/server.go index 88498518e..bfdd5d4e5 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -62,6 +62,7 @@ type LivekitServer struct { func NewLivekitServer(conf *config.Config, roomService livekit.RoomService, + agentDispatchService *AgentDispatchService, egressService *EgressService, ingressService *IngressService, sipService *SIPService, @@ -109,6 +110,7 @@ func NewLivekitServer(conf *config.Config, twirpLoggingHook := TwirpLogger() twirpRequestStatusHook := TwirpRequestStatusReporter() roomServer := livekit.NewRoomServiceServer(roomService, twirpLoggingHook) + agentDispatchServer := livekit.NewAgentDispatchServiceServer(agentDispatchService, twirpLoggingHook) egressServer := livekit.NewEgressServer(egressService, twirp.WithServerHooks( twirp.ChainHooks( twirpLoggingHook, @@ -127,6 +129,7 @@ func NewLivekitServer(conf *config.Config, } mux.Handle(roomServer.PathPrefix(), 
roomServer) + mux.Handle(agentDispatchServer.PathPrefix(), agentDispatchServer) mux.Handle(egressServer.PathPrefix(), egressServer) mux.Handle(ingressServer.PathPrefix(), ingressServer) mux.Handle(sipServer.PathPrefix(), sipServer) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 75b3509a1..56d601bf6 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -79,6 +79,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewRoomService, NewRTCService, NewAgentService, + NewAgentDispatchService, agent.NewAgentClient, getAgentStore, getSignalRelayConfig, @@ -90,6 +91,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live rpc.NewTopicFormatter, rpc.NewTypedRoomClient, rpc.NewTypedParticipantClient, + rpc.NewTypedAgentDispatchInternalClient, NewLocalRoomManager, NewTURNAuthHandler, getTURNAuthHandlerFunc, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 15f81061f..f27f4dde6 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -100,13 +100,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } + agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) + if err != nil { + return nil, err + } + agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter) egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(clientParams) if err != nil { return nil, err } - ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, ioInfoService, roomService, telemetryService) + ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, ioInfoService, telemetryService) sipConfig := getSIPConfig(conf) sipClient, err := 
rpc.NewSIPClient(messageBus) if err != nil { @@ -136,7 +141,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, sipService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, agentDispatchService, egressService, ingressService, sipService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) if err != nil { return nil, err }