diff --git a/go.mod b/go.mod index 5df1d6509..df947357d 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e - github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 + github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e github.com/livekit/psrpc v0.5.0 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 @@ -27,7 +27,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/pion/dtls/v2 v2.2.7 github.com/pion/ice/v2 v2.3.11 - github.com/pion/interceptor v0.1.24 + github.com/pion/interceptor v0.1.25 github.com/pion/rtcp v1.2.10 github.com/pion/rtp v1.8.2 github.com/pion/sctp v1.8.9 @@ -47,7 +47,7 @@ require ( github.com/urfave/negroni/v3 v3.0.0 go.uber.org/atomic v1.11.0 golang.org/x/exp v0.0.0-20231006140011-7918f672742d - golang.org/x/sync v0.4.0 + golang.org/x/sync v0.5.0 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 4aedde2df..c59b38f48 100644 --- a/go.sum +++ b/go.sum @@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY= -github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 h1:WPWxU9w5XHAsonxnSSIIXbWMty9b5uHnTnyKS9TpaXM= -github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= +github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e h1:YShBpEjkEBY7yil2gjMWlkVkxs3OI58LIIYsBdb8aBU= +github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo= github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= @@ -185,8 +185,8 @@ github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw= github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= github.com/pion/interceptor v0.1.18/go.mod h1:tpvvF4cPM6NGxFA1DUMbhabzQBxdWMATDGEUYOR9x6I= -github.com/pion/interceptor v0.1.24 h1:lN4ua3yUAJCgNKQKcZIM52wFjBgjN0r7shLj91PkJ0c= -github.com/pion/interceptor v0.1.24/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= +github.com/pion/interceptor v0.1.25 h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc= +github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.8 h1:HhicWIg7OX5PVilyBO6plhMetInbzkVJAhbdJiAeVaI= @@ -335,8 +335,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/clientconfiguration/staticconfiguration.go b/pkg/clientconfiguration/staticconfiguration.go index 2d0fcf984..e0309c26c 100644 --- a/pkg/clientconfiguration/staticconfiguration.go +++ b/pkg/clientconfiguration/staticconfiguration.go @@ -40,7 +40,7 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit. for _, c := range s.confs { matched, err := c.Match.Match(clientInfo) if err != nil { - logger.Errorw("matchrule failed", err, "clientInfo", clientInfo.String()) + logger.Errorw("matchrule failed", err, "clientInfo", logger.Proto(clientInfo)) continue } if !matched { diff --git a/pkg/rtc/agentclient.go b/pkg/rtc/agentclient.go new file mode 100644 index 000000000..65be3ba7c --- /dev/null +++ b/pkg/rtc/agentclient.go @@ -0,0 +1,92 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +const ( + RoomAgentTopic = "room" + PublisherAgentTopic = "publisher" +) + +type AgentClient interface { + CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse + JobRequest(ctx context.Context, job *livekit.Job) +} + +type agentClient struct { + client rpc.AgentInternalClient +} + +func NewAgentClient(bus psrpc.MessageBus) (AgentClient, error) { + client, err := rpc.NewAgentInternalClient(bus) + if err != nil { + return nil, err + } + return &agentClient{client: client}, nil +} + +func (c *agentClient) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse { + res := &rpc.CheckEnabledResponse{} + resChan, err := c.client.CheckEnabled(ctx, req, psrpc.WithRequestTimeout(time.Second)) + if err != nil { + return res + } + + for r := range resChan { + if r.Err != nil { + continue + } + if r.Result.RoomEnabled { + res.RoomEnabled = true + if res.PublisherEnabled { + return res + } + } + if r.Result.PublisherEnabled { + res.PublisherEnabled = true + if res.RoomEnabled { + return res + } + } + } + + return res +} + +func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) { + var topic string + var logError bool + switch job.Type { + case livekit.JobType_JT_ROOM: + topic = RoomAgentTopic + case livekit.JobType_JT_PUBLISHER: + topic = PublisherAgentTopic + logError = true + } + + _, err := c.client.JobRequest(ctx, topic, job) + if err != nil && logError { + logger.Warnw("agent job request failed", err) + } +} diff --git a/pkg/rtc/room_egress.go b/pkg/rtc/egress.go similarity index 100% rename from pkg/rtc/room_egress.go rename to pkg/rtc/egress.go diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 0edd60565..8bbb3db20 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -134,7 +134,7 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool return false } -func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string { +func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { // sort these by compatibility, since we are looking for backups if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool { return strings.EqualFold(c.Mime, webrtc.MimeTypeVP8) @@ -146,9 +146,11 @@ func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string { }) { return webrtc.MimeTypeH264 } - if len(enabledCodecs) > 0 { - return enabledCodecs[0].Mime + for _, c := range enabledCodecs { + if strings.HasPrefix(c.Mime, "video/") { + return c.Mime + } } - // uh oh. this should not happen - return "" + // no viable codec in the list of enabled codecs, fall back to the most widely supported codec + return webrtc.MimeTypeVP8 } diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 9b7f1c311..584c953c3 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -239,13 +239,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType) wr := t.MediaTrackReceiver.Receiver(mime) if wr == nil { - var priority int + priority := -1 for idx, c := range t.params.TrackInfo.Codecs { - if strings.HasSuffix(mime, c.MimeType) { + if strings.EqualFold(mime, c.MimeType) { priority = idx break } } + if len(t.params.TrackInfo.Codecs) == 0 { + priority = 0 + } + if priority < 0 { + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo)) + t.lock.Unlock() + return false + } + newWR := sfu.NewWebRTCReceiver( receiver, track, diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 88d7f2aaf..404cbf8a2 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -101,7 +101,8 @@ type ParticipantParams struct { PLIThrottleConfig config.PLIThrottleConfig CongestionControlConfig config.CongestionControlConfig // codecs that are enabled for this room - EnabledCodecs []*livekit.Codec + PublishEnabledCodecs []*livekit.Codec + SubscribeEnabledCodecs []*livekit.Codec Logger logger.Logger SimTracks map[uint32]SimulcastTrackInfo Grants *auth.ClaimGrants @@ -138,6 +139,7 @@ type ParticipantImpl struct { resSinkMu sync.Mutex resSink routing.MessageSink grants *auth.ClaimGrants + hidden atomic.Bool isPublisher atomic.Bool // when first connected @@ -248,8 +250,9 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.migrateState.Store(types.MigrateStateInit) p.state.Store(livekit.ParticipantInfo_JOINING) p.grants = params.Grants + p.hidden.Store(p.grants.Video.Hidden) p.SetResponseSink(params.Sink) - p.setupEnabledCodecs(params.EnabledCodecs, params.ClientConf.GetDisabledCodecs()) + p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs()) p.supervisor.OnPublicationError(p.onPublicationError) @@ -425,6 +428,7 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio p.params.Logger.Infow("updating participant permission", "permission", permission) video.UpdateFromPermission(permission) + p.hidden.Store(permission.Hidden) p.dirty.Store(true) canPublish := video.GetCanPublish() @@ -710,7 +714,7 @@ func (p *ParticipantImpl) SetMigrateInfo( p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted) p.pendingTracks[t.GetCid()] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}, migrated: true} - p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", ti.String()) + p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", logger.Proto(ti)) } p.pendingTracksLock.Unlock() @@ -734,7 +738,7 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea "sendLeave", sendLeave, "reason", reason.String(), "isExpectedToResume", isExpectedToResume, - "clientInfo", p.params.ClientInfo.String(), + "clientInfo", logger.Proto(p.params.ClientInfo), ) p.clearDisconnectTimer() p.clearMigrationTimer() @@ -1020,10 +1024,7 @@ func (p *ParticipantImpl) CanPublishData() bool { } func (p *ParticipantImpl) Hidden() bool { - p.lock.RLock() - defer p.lock.RUnlock() - - return p.grants.Video.Hidden + return p.hidden.Load() } func (p *ParticipantImpl) IsRecorder() bool { @@ -1033,6 +1034,13 @@ func (p *ParticipantImpl) IsRecorder() bool { return p.grants.Video.Recorder } +func (p *ParticipantImpl) IsAgent() bool { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.grants.Video.Agent +} + func (p *ParticipantImpl) VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) { if !p.IsReady() { // we have not sent a JoinResponse yet. metadata would be covered in JoinResponse @@ -1623,10 +1631,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { mime := codec.Codec - if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime + if req.Type == livekit.TrackType_VIDEO { + if !strings.HasPrefix(mime, "video/") { + mime = "video/" + mime + } if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { - altCodec := selectAlternativeCodec(p.enabledPublishCodecs) + altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", "codec", mime, "altCodec", altCodec, @@ -1639,7 +1649,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l mime = "audio/" + mime } - if _, ok := seenCodecs[mime]; ok { + if _, ok := seenCodecs[mime]; ok || mime == "" { continue } seenCodecs[mime] = struct{}{} @@ -1658,12 +1668,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } else { p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti) } - p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", ti.String(), "request", req.String()) + p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) return nil } p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} - p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", ti.String(), "request", req.String()) + p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) return ti } @@ -1681,7 +1691,7 @@ func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.Trac } func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) { - p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String()) + p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti)) _ = p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_TrackPublished{ TrackPublished: &livekit.TrackPublishedResponse{ @@ -1761,12 +1771,32 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei // use existing media track to handle simulcast mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) if !ok { - signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + signalCid, ti, migrated := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) if ti == nil { p.pendingTracksLock.Unlock() return nil, false } + // check if the migrated track has correct codec + if migrated && len(ti.Codecs) > 0 { + parameters := rtpReceiver.GetParameters() + var codecFound int + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if strings.EqualFold(nc.MimeType, c.MimeType) { + codecFound++ + break + } + } + } + if codecFound != len(ti.Codecs) { + p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) + p.pendingTracksLock.Unlock() + p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) + return nil, false + } + } + ti.MimeType = track.Codec().MimeType mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true @@ -1793,7 +1823,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei } func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { - p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", ti.String()) + p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti)) rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid) if rtpReceiver == nil { p.pubLogger.Errorw("could not find receiver for migrated track", nil, "trackID", ti.Sid) @@ -1975,7 +2005,7 @@ func (p *ParticipantImpl) onUpTrackManagerClose() { p.postRtcp(nil) } -func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { +func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) { signalCid := clientId pendingInfo := p.pendingTracks[clientId] if pendingInfo == nil { @@ -2011,10 +2041,10 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp // if still not found, we are done if pendingInfo == nil { p.pubLogger.Errorw("track info not published prior to track", nil, "clientId", clientId) - return signalCid, nil + return signalCid, nil, false } - return signalCid, pendingInfo.trackInfos[0] + return signalCid, pendingInfo.trackInfos[0], pendingInfo.migrated } // setStableTrackID either generates a new TrackID or reuses a previously used one @@ -2205,7 +2235,7 @@ func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason scr := types.SignallingCloseReasonUnknown switch reason { - case types.ParticipantCloseReasonPublicationError: + case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch: scr = types.SignallingCloseReasonFullReconnectPublicationError case types.ParticipantCloseReasonSubscriptionError: scr = types.SignallingCloseReasonFullReconnectSubscriptionError @@ -2334,9 +2364,7 @@ func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) er return err } -func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { - subscribeCodecs := make([]*livekit.Codec, 0, len(codecs)) - publishCodecs := make([]*livekit.Codec, 0, len(codecs)) +func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Codec, subscribeEnabledCodecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { for _, disableCodec := range disabled { // disable codec's fmtp is empty means disable this codec entirely @@ -2346,22 +2374,22 @@ func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCo } return false } - for _, c := range codecs { - var publishDisabled bool - var subscribeDisabled bool + + publishCodecs := make([]*livekit.Codec, 0, len(publishEnabledCodecs)) + for _, c := range publishEnabledCodecs { + if shouldDisable(c, disabledCodecs.GetCodecs()) || shouldDisable(c, disabledCodecs.GetPublish()) { + continue + } + publishCodecs = append(publishCodecs, c) + } + p.enabledPublishCodecs = publishCodecs + + subscribeCodecs := make([]*livekit.Codec, 0, len(subscribeEnabledCodecs)) + for _, c := range subscribeEnabledCodecs { if shouldDisable(c, disabledCodecs.GetCodecs()) { - publishDisabled = true - subscribeDisabled = true - } else if shouldDisable(c, disabledCodecs.GetPublish()) { - publishDisabled = true - } - if !publishDisabled { - publishCodecs = append(publishCodecs, c) - } - if !subscribeDisabled { - subscribeCodecs = append(subscribeCodecs, c) + continue } + subscribeCodecs = append(subscribeCodecs, c) } p.enabledSubscribeCodecs = subscribeCodecs - p.enabledPublishCodecs = publishCodecs } diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2b2ce2505..afd312bc8 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -284,7 +284,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) @@ -748,19 +748,20 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p } sid := livekit.ParticipantID(utils.NewGuid(utils.ParticipantPrefix)) p, _ := NewParticipant(ParticipantParams{ - SID: sid, - Identity: identity, - Config: rtcConf, - Sink: &routingfakes.FakeMessageSink{}, - ProtocolVersion: opts.protocolVersion, - PLIThrottleConfig: conf.RTC.PLIThrottle, - Grants: grants, - EnabledCodecs: enabledCodecs, - ClientConf: opts.clientConf, - ClientInfo: ClientInfo{ClientInfo: opts.clientInfo}, - Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false), - Telemetry: &telemetryfakes.FakeTelemetryService{}, - VersionGenerator: utils.NewDefaultTimedVersionGenerator(), + SID: sid, + Identity: identity, + Config: rtcConf, + Sink: &routingfakes.FakeMessageSink{}, + ProtocolVersion: opts.protocolVersion, + PLIThrottleConfig: conf.RTC.PLIThrottle, + Grants: grants, + PublishEnabledCodecs: enabledCodecs, + SubscribeEnabledCodecs: enabledCodecs, + ClientConf: opts.clientConf, + ClientInfo: ClientInfo{ClientInfo: opts.clientInfo}, + Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false), + Telemetry: &telemetryfakes.FakeTelemetryService{}, + VersionGenerator: utils.NewDefaultTimedVersionGenerator(), }) p.isPublisher.Store(opts.publisher) p.updateState(livekit.ParticipantInfo_ACTIVE) diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 70829a34c..51e98c478 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -46,7 +46,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se } p.pendingTracksLock.RLock() - _, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, info, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) // if RED is disabled for this track, don't prefer RED codec in offer disableRed := info != nil && info.DisableRed p.pendingTracksLock.RUnlock() @@ -132,7 +132,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess if mt != nil { info = mt.ToProto() } else { - _, info = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) + _, info, _ = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) } if info == nil { @@ -239,7 +239,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack) if track == nil { p.pendingTracksLock.RLock() - _, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) p.pendingTracksLock.RUnlock() } else { ti = track.TrackInfo(false) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 77ef5d1bb..e35fbc738 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -30,6 +30,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" "github.com/livekit/livekit-server/pkg/config" @@ -77,11 +78,15 @@ type Room struct { egressLauncher EgressLauncher trackManager *RoomTrackManager + // agents + agentClient AgentClient + publisherAgentsEnabled bool + // map of identity -> Participant participants map[livekit.ParticipantIdentity]types.LocalParticipant participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource - hasPublished sync.Map // map of identity -> bool + hasPublished map[livekit.ParticipantIdentity]bool bufferFactory *buffer.FactoryOfBufferFactory // batch update participant info for non-publishers @@ -113,6 +118,7 @@ func NewRoom( audioConfig *config.AudioConfig, serverInfo *livekit.ServerInfo, telemetry telemetry.TelemetryService, + agentClient AgentClient, egressLauncher EgressLauncher, ) *Room { r := &Room{ @@ -127,16 +133,19 @@ func NewRoom( audioConfig: audioConfig, telemetry: telemetry, egressLauncher: egressLauncher, + agentClient: agentClient, trackManager: NewRoomTrackManager(), serverInfo: serverInfo, participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), + hasPublished: make(map[livekit.ParticipantIdentity]bool), bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize), batchedUpdates: make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo), closed: make(chan struct{}), trailer: []byte(utils.RandomSecret()), } + r.protoProxy = utils.NewProtoProxy[*livekit.Room](roomUpdateInterval, r.updateProto) if r.protoRoom.EmptyTimeout == 0 { r.protoRoom.EmptyTimeout = DefaultEmptyTimeout @@ -145,6 +154,21 @@ func NewRoom( r.protoRoom.CreationTime = time.Now().Unix() } + if agentClient != nil { + go func() { + res := r.agentClient.CheckEnabled(context.Background(), &rpc.CheckEnabledRequest{}) + if res.PublisherEnabled { + r.lock.Lock() + r.publisherAgentsEnabled = true + // if there are already published tracks, start the agents + for identity := range r.hasPublished { + r.launchPublisherAgent(r.participants[identity]) + } + r.lock.Unlock() + } + }() + } + go r.audioUpdateWorker() go r.connectionQualityWorker() go r.changeUpdateWorker() @@ -474,6 +498,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek delete(r.participants, identity) delete(r.participantOpts, identity) delete(r.participantRequestSources, identity) + delete(r.hasPublished, identity) if !p.Hidden() { r.protoRoom.NumParticipants-- } @@ -509,7 +534,6 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek for _, t := range p.GetPublishedTracks() { r.trackManager.RemoveTrack(t) } - r.hasPublished.Delete(p.Identity()) p.OnTrackUpdated(nil) p.OnTrackPublished(nil) @@ -898,10 +922,19 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. r.trackManager.AddTrack(track, participant.Identity(), participant.ID()) - // auto egress - if r.internal != nil { - if r.internal.ParticipantEgress != nil { - if _, hasPublished := r.hasPublished.Swap(participant.Identity(), true); !hasPublished { + // launch jobs + r.lock.Lock() + hasPublished := r.hasPublished[participant.Identity()] + r.hasPublished[participant.Identity()] = true + publisherAgentsEnabled := r.publisherAgentsEnabled + r.lock.Unlock() + + if !hasPublished { + if publisherAgentsEnabled { + r.launchPublisherAgent(participant) + } + if r.internal != nil && r.internal.ParticipantEgress != nil { + go func() { if err := StartParticipantEgress( context.Background(), r.egressLauncher, @@ -913,9 +946,11 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. ); err != nil { r.Logger.Errorw("failed to launch participant egress", err) } - } + }() } - if r.internal.TrackEgress != nil { + } + if r.internal != nil && r.internal.TrackEgress != nil { + go func() { if err := StartTrackEgress( context.Background(), r.egressLauncher, @@ -927,7 +962,7 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. ); err != nil { r.Logger.Errorw("failed to launch track egress", err) } - } + }() } } @@ -1286,6 +1321,21 @@ func (r *Room) connectionQualityWorker() { } } +func (r *Room) launchPublisherAgent(p types.Participant) { + if p == nil || p.IsRecorder() || p.IsAgent() { + return + } + + go func() { + r.agentClient.JobRequest(context.Background(), &livekit.Job{ + Id: utils.NewGuid("JP_"), + Type: livekit.JobType_JT_PUBLISHER, + Room: r.ToProto(), + Participant: p.ToProto(), + }) + }() +} + func (r *Room) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "Name": r.protoRoom.Name, diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 0e31e4772..b86527583 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -739,7 +739,7 @@ func newRoomWithParticipants(t *testing.T, opts testRoomOpts) *Room { Region: "testregion", }, telemetry.NewTelemetryService(webhook.NewDefaultNotifier("", "", nil), &telemetryfakes.FakeAnalyticsService{}), - nil, + nil, nil, ) for i := 0; i < opts.num+opts.numHidden; i++ { identity := livekit.ParticipantIdentity(fmt.Sprintf("p%d", i)) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 1a20d6540..ca22ffeec 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -833,7 +833,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni } dcCloseHandler := func() { - t.params.Logger.Infow(dc.Label() + " data channel close") + t.params.Logger.Debugw(dc.Label() + " data channel close") } dcErrorHandler := func(err error) { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index da3ee7178..bef8abb6e 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -104,6 +104,7 @@ const ( ParticipantCloseReasonPublicationError ParticipantCloseReasonSubscriptionError ParticipantCloseReasonDataChannelError + ParticipantCloseReasonMigrateCodecMismatch ) func (p ParticipantCloseReason) String() string { @@ -154,6 +155,8 @@ func (p ParticipantCloseReason) String() string { return "SUBSCRIPTION_ERROR" case ParticipantCloseReasonDataChannelError: return "DATA_CHANNEL_ERROR" + case ParticipantCloseReasonMigrateCodecMismatch: + return "MIGRATE_CODEC_MISMATCH" default: return fmt.Sprintf("%d", int(p)) } @@ -184,7 +187,7 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason { return livekit.DisconnectReason_SERVER_SHUTDOWN case ParticipantCloseReasonOvercommitted: return livekit.DisconnectReason_SERVER_SHUTDOWN - case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError: + case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch: return livekit.DisconnectReason_STATE_MISMATCH default: // the other types will map to unknown reason @@ -261,6 +264,7 @@ type Participant interface { // permissions Hidden() bool IsRecorder() bool + IsAgent() bool Start() Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 4020502be..9e11470c3 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -408,6 +408,16 @@ type FakeLocalParticipant struct { identityReturnsOnCall map[int]struct { result1 livekit.ParticipantIdentity } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } IsClosedStub func() bool isClosedMutex sync.RWMutex isClosedArgsForCall []struct { @@ -2986,6 +2996,59 @@ func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.P }{result1} } +func (fake *FakeLocalParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeLocalParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeLocalParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) IsClosed() bool { fake.isClosedMutex.Lock() ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] @@ -6031,6 +6094,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.identityMutex.RLock() defer fake.identityMutex.RUnlock() + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() fake.isClosedMutex.RLock() defer fake.isClosedMutex.RUnlock() fake.isDisconnectedMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index fa92204fe..3f46bdbba 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -118,6 +118,16 @@ type FakeParticipant struct { identityReturnsOnCall map[int]struct { result1 livekit.ParticipantIdentity } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } IsPublisherStub func() bool isPublisherMutex sync.RWMutex isPublisherArgsForCall []struct { @@ -780,6 +790,59 @@ func (fake *FakeParticipant) IdentityReturnsOnCall(i int, result1 livekit.Partic }{result1} } +func (fake *FakeParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeParticipant) IsPublisher() bool { fake.isPublisherMutex.Lock() ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)] @@ -1318,6 +1381,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.identityMutex.RLock() defer fake.identityMutex.RUnlock() + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() fake.isPublisherMutex.RLock() defer fake.isPublisherMutex.RUnlock() fake.isRecorderMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 1d1dfa6e5..1c079dc34 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -160,8 +160,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( u.params.Logger.Debugw( "skipping older subscription permission version", "existingValue", perms, - "existingVersion", u.subscriptionPermissionVersion.ToProto().String(), - "requestingValue", subscriptionPermission.String(), + "existingVersion", u.subscriptionPermissionVersion.String(), + "requestingValue", logger.Proto(subscriptionPermission), "requestingVersion", timedVersion.String(), ) u.lock.Unlock() @@ -178,7 +178,7 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( if subscriptionPermission == nil { u.params.Logger.Debugw( "updating subscription permission, setting to nil", - "version", u.subscriptionPermissionVersion.ToProto().String(), + "version", u.subscriptionPermissionVersion.String(), ) // possible to get a nil when migrating u.lock.Unlock() @@ -187,8 +187,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( u.params.Logger.Debugw( "updating subscription permission", - "permissions", u.subscriptionPermission.String(), - "version", u.subscriptionPermissionVersion.ToProto().String(), + "permissions", logger.Proto(u.subscriptionPermission), + "version", u.subscriptionPermissionVersion.String(), ) if err := u.parseSubscriptionPermissionsLocked(subscriptionPermission, func(pID livekit.ParticipantID) types.LocalParticipant { u.lock.Unlock() @@ -247,7 +247,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { u.publishedTracks[track.ID()] = track } u.lock.Unlock() - u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", track.ToProto().String()) + u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto())) track.AddOnClose(func() { notifyClose := false diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go new file mode 100644 index 000000000..baaf40da4 --- /dev/null +++ b/pkg/service/agentservice.go @@ -0,0 +1,468 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "io" + "math/rand" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +const AgentServiceVersion = "0.1.0" + +type AgentService struct { + upgrader websocket.Upgrader + + *AgentHandler +} + +type AgentHandler struct { + agentServer rpc.AgentInternalServer + roomTopic string + publisherTopic string + + mu sync.Mutex + availability map[string]chan *availability + unregistered map[*websocket.Conn]*worker + roomRegistered bool + roomWorkers map[string]*worker + publisherRegistered bool + publisherWorkers map[string]*worker +} + +type worker struct { + mu sync.Mutex + conn *websocket.Conn + sigConn *WSSignalConnection + + id string + jobType livekit.JobType + status livekit.WorkerStatus + activeJobs int +} + +type availability struct { + workerID string + available bool +} + +func NewAgentService(bus psrpc.MessageBus) (*AgentService, error) { + s := &AgentService{ + upgrader: websocket.Upgrader{}, + } + + // allow connections from any origin, since script may be hosted anywhere + // security is enforced by access tokens + s.upgrader.CheckOrigin = func(r *http.Request) bool { + return true + } + + agentServer, err := rpc.NewAgentInternalServer(s, bus) + if err != nil { + return nil, err + } + s.AgentHandler = NewAgentHandler(agentServer, rtc.RoomAgentTopic, rtc.PublisherAgentTopic) + + return s, nil +} + +func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) { + // reject non websocket requests + if !websocket.IsWebSocketUpgrade(r) { + writer.WriteHeader(404) + return + } + + // require a claim + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil || !claims.Video.Agent { + handleError(writer, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return + } + + // upgrade + conn, err := s.upgrader.Upgrade(writer, r, nil) + if err != nil { + handleError(writer, http.StatusInternalServerError, err) + return + } + + s.HandleConnection(conn) +} + +func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTopic string) *AgentHandler { + return &AgentHandler{ + agentServer: agentServer, + roomTopic: roomTopic, + publisherTopic: publisherTopic, + availability: make(map[string]chan *availability), + unregistered: make(map[*websocket.Conn]*worker), + roomWorkers: make(map[string]*worker), + publisherWorkers: make(map[string]*worker), + } +} + +func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { + sigConn := NewWSSignalConnection(conn) + w := &worker{ + conn: conn, + sigConn: sigConn, + } + + s.mu.Lock() + s.unregistered[conn] = w + s.mu.Unlock() + + defer func() { + s.mu.Lock() + if w.id == "" { + delete(s.unregistered, conn) + } else { + switch w.jobType { + case livekit.JobType_JT_ROOM: + delete(s.roomWorkers, w.id) + if s.roomRegistered && !s.roomAvailableLocked() { + s.roomRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.roomTopic) + } + case livekit.JobType_JT_PUBLISHER: + delete(s.publisherWorkers, w.id) + if s.publisherRegistered && !s.publisherAvailableLocked() { + s.publisherRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.publisherTopic) + } + } + } + s.mu.Unlock() + }() + + // handle incoming requests from websocket + for { + req, _, err := sigConn.ReadWorkerMessage() + if err != nil { + // normal/expected closure + if err == io.EOF || + strings.HasSuffix(err.Error(), "use of closed network connection") || + strings.HasSuffix(err.Error(), "connection reset by peer") || + websocket.IsCloseError( + err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + ) { + logger.Infow("exit ws read loop for closed connection", "wsError", err) + } else { + logger.Errorw("error reading from websocket", err) + } + return + } + + switch m := req.Message.(type) { + case *livekit.WorkerMessage_Register: + go s.handleRegister(w, m.Register) + case *livekit.WorkerMessage_Availability: + go s.handleAvailability(w, m.Availability) + case *livekit.WorkerMessage_JobUpdate: + go s.handleJobUpdate(w, m.JobUpdate) + case *livekit.WorkerMessage_Status: + go s.handleStatus(w, m.Status) + } + } +} + +func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorkerRequest) { + s.mu.Lock() + defer s.mu.Unlock() + + switch msg.Type { + case livekit.JobType_JT_ROOM: + worker.id = msg.WorkerId + delete(s.unregistered, worker.conn) + s.roomWorkers[worker.id] = worker + + if !s.roomRegistered { + err := s.agentServer.RegisterJobRequestTopic(s.roomTopic) + if err != nil { + logger.Errorw("failed to register room agents", err) + } else { + s.roomRegistered = true + } + } + + case livekit.JobType_JT_PUBLISHER: + worker.id = msg.WorkerId + delete(s.unregistered, worker.conn) + s.publisherWorkers[worker.id] = worker + + if !s.publisherRegistered { + err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic) + if err != nil { + logger.Errorw("failed to register publisher agents", err) + } else { + s.publisherRegistered = true + } + } + } + + _, err := worker.sigConn.WriteServerMessage(&livekit.ServerMessage{ + Message: &livekit.ServerMessage_Register{ + Register: &livekit.RegisterWorkerResponse{ + WorkerId: worker.id, + ServerVersion: AgentServiceVersion, + }, + }, + }) + if err != nil { + logger.Errorw("failed to write server message", err) + } +} + +func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityResponse) { + s.mu.Lock() + availabilityChan, ok := s.availability[msg.JobId] + s.mu.Unlock() + + if ok { + availabilityChan <- &availability{ + workerID: w.id, + available: msg.Available, + } + } +} + +func (s *AgentHandler) handleJobUpdate(w *worker, msg *livekit.JobStatusUpdate) { + switch msg.Status { + case livekit.JobStatus_JS_SUCCESS: + logger.Debugw("job complete", "jobID", msg.JobId) + case livekit.JobStatus_JS_FAILED: + logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId) + } + + w.mu.Lock() + w.activeJobs-- + w.mu.Unlock() +} + +func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) { + s.mu.Lock() + defer s.mu.Unlock() + + w.mu.Lock() + w.status = msg.Status + w.mu.Unlock() + + switch w.jobType { + case livekit.JobType_JT_ROOM: + if s.roomRegistered && !s.roomAvailableLocked() { + s.roomRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.roomTopic) + } else if !s.roomRegistered && s.roomAvailableLocked() { + if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil { + logger.Errorw("failed to register room agents", err) + } else { + s.roomRegistered = true + } + } + case livekit.JobType_JT_PUBLISHER: + if s.publisherRegistered && !s.publisherAvailableLocked() { + s.publisherRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.publisherTopic) + } else if !s.publisherRegistered && s.publisherAvailableLocked() { + if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil { + logger.Errorw("failed to register publisher agents", err) + } else { + s.publisherRegistered = true + } + } + } +} + +func (s *AgentHandler) CheckEnabled(_ context.Context, _ *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { + s.mu.Lock() + res := &rpc.CheckEnabledResponse{ + RoomEnabled: len(s.roomWorkers) > 0, + PublisherEnabled: len(s.publisherWorkers) > 0, + } + s.mu.Unlock() + return res, nil +} + +func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) { + s.mu.Lock() + ac := make(chan *availability, 100) + s.availability[job.Id] = ac + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.availability, job.Id) + s.mu.Unlock() + }() + + var pool map[string]*worker + switch job.Type { + case livekit.JobType_JT_ROOM: + pool = s.roomWorkers + case livekit.JobType_JT_PUBLISHER: + pool = s.publisherWorkers + } + + attempted := make(map[string]bool) + for { + select { + case <-ctx.Done(): + return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out") + default: + s.mu.Lock() + var selected *worker + for _, w := range pool { + if attempted[w.id] { + continue + } + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + if w.activeJobs > 0 { + selected = w + break + } else if selected == nil { + selected = w + } + } + } + s.mu.Unlock() + + if selected == nil { + return nil, psrpc.NewErrorf(psrpc.Unavailable, "no workers available") + } + + attempted[selected.id] = true + _, err := selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Availability{ + Availability: &livekit.AvailabilityRequest{Job: job}, + }}) + if err != nil { + logger.Errorw("failed to send availability request", err) + return nil, err + } + + select { + case <-ctx.Done(): + return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out") + case res := <-ac: + if res.available { + _, err = selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Assignment{ + Assignment: &livekit.JobAssignment{Job: job}, + }}) + if err != nil { + logger.Errorw("failed to assign job", err) + } else { + selected.mu.Lock() + selected.activeJobs++ + selected.mu.Unlock() + return &emptypb.Empty{}, nil + } + } + } + } + } +} + +func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 { + s.mu.Lock() + defer s.mu.Unlock() + + var pool map[string]*worker + switch job.Type { + case livekit.JobType_JT_ROOM: + pool = s.roomWorkers + case livekit.JobType_JT_PUBLISHER: + pool = s.publisherWorkers + } + + var affinity float32 + for _, w := range pool { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + if w.activeJobs > 0 { + return 1 + } else { + affinity = 0.5 + } + } + } + + return affinity +} + +func (s *AgentHandler) NumConnections() int { + s.mu.Lock() + defer s.mu.Unlock() + + return len(s.unregistered) + len(s.roomWorkers) + len(s.publisherWorkers) +} + +func (s *AgentHandler) DrainConnections(interval time.Duration) { + // jitter drain start + time.Sleep(time.Duration(rand.Int63n(int64(interval)))) + + t := time.NewTicker(interval) + defer t.Stop() + + s.mu.Lock() + defer s.mu.Unlock() + + for conn := range s.unregistered { + _ = conn.Close() + <-t.C + } + for _, w := range s.roomWorkers { + _ = w.conn.Close() + <-t.C + } + for _, w := range s.publisherWorkers { + _ = w.conn.Close() + <-t.C + } +} + +func (s *AgentHandler) roomAvailableLocked() bool { + for _, w := range s.roomWorkers { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + return true + } + } + return false +} + +func (s *AgentHandler) publisherAvailableLocked() bool { + for _, w := range s.publisherWorkers { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + return true + } + } + return false +} diff --git a/pkg/service/clients.go b/pkg/service/clients.go new file mode 100644 index 000000000..53c0e415b --- /dev/null +++ b/pkg/service/clients.go @@ -0,0 +1,75 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" +) + +type IOClient interface { + CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) + GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) + ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) +} + +type egressLauncher struct { + client rpc.EgressClient + io IOClient +} + +func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher { + if client == nil { + return nil + } + return &egressLauncher{ + client: client, + io: io, + } +} + +func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + info, err := s.StartEgressWithClusterId(ctx, "", req) + if err != nil { + return nil, err + } + + _, err = s.io.CreateEgress(ctx, info) + if err != nil { + logger.Errorw("failed to create egress", err) + } + + return info, nil +} + +func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + if s.client == nil { + return nil, ErrEgressNotConnected + } + + // Ensure we have an Egress ID + if req.EgressId == "" { + req.EgressId = utils.NewGuid(utils.EgressPrefix) + } + + return s.client.StartEgress(ctx, clusterId, req) +} diff --git a/pkg/service/egress.go b/pkg/service/egress.go index d971aaa34..a745caaf0 100644 --- a/pkg/service/egress.go +++ b/pkg/service/egress.go @@ -25,40 +25,23 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/protocol/egress" "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" - "github.com/livekit/protocol/utils" ) type EgressService struct { + launcher rtc.EgressLauncher client rpc.EgressClient io IOClient roomService livekit.RoomService store ServiceStore - launcher rtc.EgressLauncher -} - -type egressLauncher struct { - client rpc.EgressClient - io IOClient -} - -func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher { - if client == nil { - return nil - } - return &egressLauncher{ - client: client, - io: io, - } } func NewEgressService( client rpc.EgressClient, + launcher rtc.EgressLauncher, store ServiceStore, io IOClient, rs livekit.RoomService, - launcher rtc.EgressLauncher, ) *EgressService { return &EgressService{ client: client, @@ -189,33 +172,6 @@ func (s *EgressService) startEgress(ctx context.Context, roomName livekit.RoomNa return s.launcher.StartEgress(ctx, req) } -func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { - info, err := s.StartEgressWithClusterId(ctx, "", req) - if err != nil { - return nil, err - } - - _, err = s.io.CreateEgress(ctx, info) - if err != nil { - logger.Errorw("failed to create egress", err) - } - - return info, nil -} - -func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { - if s.client == nil { - return nil, ErrEgressNotConnected - } - - // Ensure we have an Egress ID - if req.EgressId == "" { - req.EgressId = utils.NewGuid(utils.EgressPrefix) - } - - return s.client.StartEgress(ctx, clusterId, req) -} - type LayoutMetadata struct { Layout string `json:"layout"` } diff --git a/pkg/service/interfaces.go b/pkg/service/interfaces.go index 653269225..6c32d7c5c 100644 --- a/pkg/service/interfaces.go +++ b/pkg/service/interfaces.go @@ -18,10 +18,7 @@ import ( "context" "time" - "google.golang.org/protobuf/types/known/emptypb" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/rpc" ) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate @@ -74,13 +71,6 @@ type IngressStore interface { DeleteIngress(ctx context.Context, info *livekit.IngressInfo) error } -//counterfeiter:generate . IOClient -type IOClient interface { - CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) - GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) - ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) -} - //counterfeiter:generate . RoomAllocator type RoomAllocator interface { CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, bool, error) diff --git a/pkg/service/roomallocator.go b/pkg/service/roomallocator.go index f222f7b15..cf88261c6 100644 --- a/pkg/service/roomallocator.go +++ b/pkg/service/roomallocator.go @@ -60,8 +60,10 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre }() // find existing room and update it + var created bool rm, internal, err := r.roomStore.LoadRoom(ctx, livekit.RoomName(req.Name), true) if err == ErrRoomNotFound { + created = true rm = &livekit.Room{ Sid: utils.NewGuid(utils.RoomPrefix), Name: req.Name, @@ -114,7 +116,7 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre return nil, false, routing.ErrNodeLimitReached } - return rm, false, nil + return rm, created, nil } // select a new node diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 918e847bf..d26eec0d5 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -69,6 +69,7 @@ type RoomManager struct { roomStore ObjectStore telemetry telemetry.TelemetryService clientConfManager clientconfiguration.ClientConfigurationManager + agentClient rtc.AgentClient egressLauncher rtc.EgressLauncher versionGenerator utils.TimedVersionGenerator turnAuthHandler *TURNAuthHandler @@ -89,6 +90,7 @@ func NewLocalRoomManager( router routing.Router, telemetry telemetry.TelemetryService, clientConfManager clientconfiguration.ClientConfigurationManager, + agentClient rtc.AgentClient, egressLauncher rtc.EgressLauncher, versionGenerator utils.TimedVersionGenerator, turnAuthHandler *TURNAuthHandler, @@ -108,6 +110,7 @@ func NewLocalRoomManager( telemetry: telemetry, clientConfManager: clientConfManager, egressLauncher: egressLauncher, + agentClient: agentClient, versionGenerator: versionGenerator, turnAuthHandler: turnAuthHandler, bus: bus, @@ -393,7 +396,8 @@ func (r *RoomManager) StartSession( Trailer: room.Trailer(), PLIThrottleConfig: r.config.RTC.PLIThrottle, CongestionControlConfig: r.config.RTC.CongestionControl, - EnabledCodecs: protoRoom.EnabledCodecs, + PublishEnabledCodecs: protoRoom.EnabledCodecs, + SubscribeEnabledCodecs: protoRoom.EnabledCodecs, Grants: pi.Grants, Logger: pLogger, ClientConf: clientConf, @@ -526,7 +530,7 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room } // construct ice servers - newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.egressLauncher) + newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.agentClient, r.egressLauncher) roomTopic := rpc.FormatRoomTopic(roomName) roomServer := utils.Must(rpc.NewTypedRoomServer(r, r.bus)) diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index aa5e323d2..174ea3391 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -31,6 +31,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" ) // A rooms service that supports a single node @@ -41,6 +42,7 @@ type RoomService struct { router routing.MessageRouter roomAllocator RoomAllocator roomStore ServiceStore + agentClient rtc.AgentClient egressLauncher rtc.EgressLauncher topicFormatter rpc.TopicFormatter roomClient rpc.TypedRoomClient @@ -54,6 +56,7 @@ func NewRoomService( router routing.MessageRouter, roomAllocator RoomAllocator, serviceStore ServiceStore, + agentClient rtc.AgentClient, egressLauncher rtc.EgressLauncher, topicFormatter rpc.TopicFormatter, roomClient rpc.TypedRoomClient, @@ -66,6 +69,7 @@ func NewRoomService( router: router, roomAllocator: roomAllocator, roomStore: serviceStore, + agentClient: agentClient, egressLauncher: egressLauncher, topicFormatter: topicFormatter, roomClient: roomClient, @@ -112,17 +116,29 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq return nil, err } - if created && req.Egress != nil && req.Egress.Room != nil { - egress := &rpc.StartEgressRequest{ - Request: &rpc.StartEgressRequest_RoomComposite{ - RoomComposite: req.Egress.Room, - }, - RoomId: rm.Sid, + if created { + go func() { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: utils.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: rm, + }) + }() + + if req.Egress != nil && req.Egress.Room != nil { + _, err = s.egressLauncher.StartEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_RoomComposite{ + RoomComposite: req.Egress.Room, + }, + RoomId: rm.Sid, + }) + if err != nil { + return nil, err + } } - _, err = s.egressLauncher.StartEgress(ctx, egress) } - return rm, err + return rm, nil } func (s *RoomService) ListRooms(ctx context.Context, req *livekit.ListRoomsRequest) (*livekit.ListRoomsResponse, error) { @@ -427,7 +443,7 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat // no one has joined the room, would not have been created on an RTC node. // in this case, we'd want to run create again - _, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{ + room, created, err := s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{ Name: req.Room, Metadata: req.Metadata, }) @@ -465,6 +481,16 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, err } + if created { + go func() { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: utils.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: room, + }) + }() + } + return room, nil } diff --git a/pkg/service/roomservice_test.go b/pkg/service/roomservice_test.go index f15a72cd7..f9de24a51 100644 --- a/pkg/service/roomservice_test.go +++ b/pkg/service/roomservice_test.go @@ -136,6 +136,7 @@ func newTestRoomService(conf config.RoomConfig) *TestRoomService { allocator, store, nil, + nil, rpc.NewTopicFormatter(), &rpcfakes.FakeTypedRoomClient{}, &rpcfakes.FakeTypedParticipantClient{}, diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index b0683115b..884a89656 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -31,17 +31,17 @@ import ( "github.com/ua-parser/uap-go/uaparser" "golang.org/x/exp/maps" - "github.com/livekit/livekit-server/pkg/utils" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" - "github.com/livekit/psrpc" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/routing/selector" "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + putil "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" ) type RTCService struct { @@ -54,6 +54,7 @@ type RTCService struct { isDev bool limits config.LimitConfig parser *uaparser.Parser + agentClient rtc.AgentClient telemetry telemetry.TelemetryService mu sync.Mutex @@ -66,6 +67,7 @@ func NewRTCService( store ServiceStore, router routing.MessageRouter, currentNode routing.LocalNode, + agentClient rtc.AgentClient, telemetry telemetry.TelemetryService, ) *RTCService { s := &RTCService{ @@ -78,6 +80,7 @@ func NewRTCService( isDev: conf.Development, limits: conf.Limit, parser: uaparser.NewFromSaved(), + agentClient: agentClient, telemetry: telemetry, connections: map[*websocket.Conn]struct{}{}, } @@ -229,6 +232,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...) } } + if err != nil { prometheus.IncrementParticipantJoinFail(1) handleError(w, http.StatusInternalServerError, err, loggerFields...) @@ -514,10 +518,16 @@ type connectionResult struct { ResponseSource routing.MessageSource } -func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, *livekit.SignalResponse, error) { +func (s *RTCService) startConnection( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + timeout time.Duration, +) (connectionResult, *livekit.SignalResponse, error) { var cr connectionResult + var created bool var err error - cr.Room, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)}) + cr.Room, created, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)}) if err != nil { return cr, nil, err } @@ -538,6 +548,17 @@ func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomN cr.ResponseSource.Close() return cr, nil, err } + + if created && s.agentClient != nil { + go func() { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: putil.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: cr.Room, + }) + }() + } + return cr, initialResponse, nil } @@ -559,5 +580,4 @@ func readInitialResponse(source routing.MessageSource, timeout time.Duration) (* return res, nil } } - } diff --git a/pkg/service/server.go b/pkg/service/server.go index c67555614..f7ade50d2 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -48,6 +48,7 @@ type LivekitServer struct { config *config.Config ioService *IOInfoService rtcService *RTCService + agentService *AgentService httpServer *http.Server promServer *http.Server router routing.Router @@ -66,6 +67,7 @@ func NewLivekitServer(conf *config.Config, ingressService *IngressService, ioService *IOInfoService, rtcService *RTCService, + agentService *AgentService, keyProvider auth.KeyProvider, router routing.Router, roomManager *RoomManager, @@ -77,6 +79,7 @@ func NewLivekitServer(conf *config.Config, config: conf, ioService: ioService, rtcService: rtcService, + agentService: agentService, router: router, roomManager: roomManager, signalServer: signalServer, @@ -125,6 +128,7 @@ func NewLivekitServer(conf *config.Config, mux.Handle(egressServer.PathPrefix(), egressServer) mux.Handle(ingressServer.PathPrefix(), ingressServer) mux.Handle("/rtc", rtcService) + mux.Handle("/agent", agentService) mux.HandleFunc("/rtc/validate", rtcService.Validate) mux.HandleFunc("/", s.defaultHandler) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 2e2b39592..afca2631e 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -30,6 +30,7 @@ import ( "github.com/livekit/livekit-server/pkg/clientconfiguration" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" @@ -62,16 +63,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewIOInfoService, wire.Bind(new(IOClient), new(*IOInfoService)), rpc.NewEgressClient, + rpc.NewIngressClient, getEgressStore, NewEgressLauncher, NewEgressService, - rpc.NewIngressClient, getIngressStore, getIngressConfig, NewIngressService, NewRoomAllocator, NewRoomService, NewRTCService, + NewAgentService, + rtc.NewAgentClient, getSignalRelayConfig, NewDefaultSignalServer, routing.NewSignalClient, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 322d125e6..3a6470b88 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -11,6 +11,7 @@ import ( "github.com/livekit/livekit-server/pkg/clientconfiguration" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" @@ -55,6 +56,10 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } + agentClient, err := rtc.NewAgentClient(messageBus) + if err != nil { + return nil, err + } egressClient, err := rpc.NewEgressClient(messageBus) if err != nil { return nil, err @@ -86,22 +91,26 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, agentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient) if err != nil { return nil, err } - egressService := NewEgressService(egressClient, objectStore, ioInfoService, roomService, rtcEgressLauncher) + egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(messageBus) if err != nil { return nil, err } ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, roomService, telemetryService) - rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, telemetryService) + rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, agentClient, telemetryService) + agentService, err := NewAgentService(messageBus) + if err != nil { + return nil, err + } clientConfigurationManager := createClientConfiguration() timedVersionGenerator := utils.NewDefaultTimedVersionGenerator() turnAuthHandler := NewTURNAuthHandler(keyProvider) - roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) + roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, agentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) if err != nil { return nil, err } @@ -114,7 +123,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) if err != nil { return nil, err } diff --git a/pkg/service/wsprotocol.go b/pkg/service/wsprotocol.go index 50a4dc2fb..5ac7c744f 100644 --- a/pkg/service/wsprotocol.go +++ b/pkg/service/wsprotocol.go @@ -83,6 +83,40 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) } } +func (c *WSSignalConnection) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) { + for { + // handle special messages and pass on the rest + messageType, payload, err := c.conn.ReadMessage() + if err != nil { + return nil, 0, err + } + + msg := &livekit.WorkerMessage{} + switch messageType { + case websocket.BinaryMessage: + if c.useJSON { + c.mu.Lock() + // switch to protobuf if client supports it + c.useJSON = false + c.mu.Unlock() + } + // protobuf encoded + err := proto.Unmarshal(payload, msg) + return msg, len(payload), err + case websocket.TextMessage: + c.mu.Lock() + // json encoded, also write back JSON + c.useJSON = true + c.mu.Unlock() + err := protojson.Unmarshal(payload, msg) + return msg, len(payload), err + default: + logger.Debugw("unsupported message", "message", messageType) + return nil, len(payload), nil + } + } +} + func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, error) { var msgType int var payload []byte @@ -105,6 +139,28 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, er return len(payload), c.conn.WriteMessage(msgType, payload) } +func (c *WSSignalConnection) WriteServerMessage(msg *livekit.ServerMessage) (int, error) { + var msgType int + var payload []byte + var err error + + c.mu.Lock() + defer c.mu.Unlock() + + if c.useJSON { + msgType = websocket.TextMessage + payload, err = protojson.Marshal(msg) + } else { + msgType = websocket.BinaryMessage + payload, err = proto.Marshal(msg) + } + if err != nil { + return 0, err + } + + return len(payload), c.conn.WriteMessage(msgType, payload) +} + func (c *WSSignalConnection) pingWorker() { for { <-time.After(pingFrequency) diff --git a/pkg/sfu/audio/audiolevel.go b/pkg/sfu/audio/audiolevel.go index 7e834fda7..f908788d8 100644 --- a/pkg/sfu/audio/audiolevel.go +++ b/pkg/sfu/audio/audiolevel.go @@ -16,8 +16,8 @@ package audio import ( "math" - - "go.uber.org/atomic" + "sync" + "time" ) const ( @@ -40,11 +40,13 @@ type AudioLevel struct { smoothFactor float64 activeThreshold float64 - smoothedLevel atomic.Float64 + lock sync.Mutex + smoothedLevel float64 loudestObservedLevel uint8 activeDuration uint32 // ms observedDuration uint32 // ms + lastObservedAt time.Time } func NewAudioLevel(params AudioLevelParams) *AudioLevel { @@ -64,8 +66,13 @@ func NewAudioLevel(params AudioLevelParams) *AudioLevel { return l } -// Observes a new frame, must be called from the same thread -func (l *AudioLevel) Observe(level uint8, durationMs uint32) { +// Observes a new frame +func (l *AudioLevel) Observe(level uint8, durationMs uint32, arrivalTime time.Time) { + l.lock.Lock() + defer l.lock.Unlock() + + l.lastObservedAt = arrivalTime + l.observedDuration += durationMs if level <= l.params.ActiveLevel { @@ -76,6 +83,7 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { } if l.observedDuration >= l.params.ObserveDuration { + smoothedLevel := float64(0.0) // compute and reset if l.activeDuration >= l.minActiveDuration { // adjust loudest observed level by how much of the window was active. @@ -87,25 +95,39 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { linearLevel := ConvertAudioLevel(adjustedLevel) // exponential smoothing to dampen transients - smoothedLevel := l.smoothedLevel.Load() - smoothedLevel += (linearLevel - smoothedLevel) * l.smoothFactor - l.smoothedLevel.Store(smoothedLevel) - } else { - l.smoothedLevel.Store(0) + smoothedLevel = l.smoothedLevel + (linearLevel-l.smoothedLevel)*l.smoothFactor } - l.loudestObservedLevel = silentAudioLevel - l.activeDuration = 0 - l.observedDuration = 0 + l.resetLocked(smoothedLevel) } } // returns current soothed audio level -func (l *AudioLevel) GetLevel() (float64, bool) { - smoothedLevel := l.smoothedLevel.Load() - active := smoothedLevel >= l.activeThreshold - return smoothedLevel, active +func (l *AudioLevel) GetLevel(now time.Time) (float64, bool) { + l.lock.Lock() + defer l.lock.Unlock() + + l.resetIfStaleLocked(now) + + return l.smoothedLevel, l.smoothedLevel >= l.activeThreshold } +func (l *AudioLevel) resetIfStaleLocked(arrivalTime time.Time) { + if arrivalTime.Sub(l.lastObservedAt).Milliseconds() < int64(2*l.params.ObserveDuration) { + return + } + + l.resetLocked(0.0) +} + +func (l *AudioLevel) resetLocked(smoothedLevel float64) { + l.smoothedLevel = smoothedLevel + l.loudestObservedLevel = silentAudioLevel + l.activeDuration = 0 + l.observedDuration = 0 +} + +// --------------------------------------------------- + // convert decibel back to linear func ConvertAudioLevel(level float64) float64 { return math.Pow(10, level*negInv20) diff --git a/pkg/sfu/audio/audiolevel_test.go b/pkg/sfu/audio/audiolevel_test.go index 8b8f03eba..84d5a34dd 100644 --- a/pkg/sfu/audio/audiolevel_test.go +++ b/pkg/sfu/audio/audiolevel_test.go @@ -16,6 +16,7 @@ package audio import ( "testing" + "time" "github.com/stretchr/testify/require" ) @@ -30,46 +31,79 @@ const ( func TestAudioLevel(t *testing.T) { t.Run("initially to return not noisy, within a few samples", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - _, noisy := a.GetLevel() + + _, noisy := a.GetLevel(clock) require.False(t, noisy) - observeSamples(a, 28, 5) - _, noisy = a.GetLevel() + observeSamples(a, 28, 5, clock) + clock = clock.Add(5 * 20 * time.Millisecond) + + _, noisy = a.GetLevel(clock) require.False(t, noisy) }) t.Run("not noisy when all samples are below threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, 100) - _, noisy := a.GetLevel() + observeSamples(a, 35, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + + _, noisy := a.GetLevel(clock) require.False(t, noisy) }) t.Run("not noisy when less than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, samplesPerBatch-2) - observeSamples(a, 25, 1) - observeSamples(a, 35, 1) + observeSamples(a, 35, samplesPerBatch-2, clock) + clock = clock.Add((samplesPerBatch - 2) * 20 * time.Millisecond) + observeSamples(a, 25, 1, clock) + clock = clock.Add(20 * time.Millisecond) + observeSamples(a, 35, 1, clock) + clock = clock.Add(20 * time.Millisecond) - _, noisy := a.GetLevel() + _, noisy := a.GetLevel(clock) require.False(t, noisy) }) t.Run("noisy when higher than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, samplesPerBatch-16) - observeSamples(a, 25, 8) - observeSamples(a, 29, 8) + observeSamples(a, 35, samplesPerBatch-16, clock) + clock = clock.Add((samplesPerBatch - 16) * 20 * time.Millisecond) + observeSamples(a, 25, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) + observeSamples(a, 29, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) - level, noisy := a.GetLevel() + level, noisy := a.GetLevel(clock) require.True(t, noisy) require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) require.Less(t, level, ConvertAudioLevel(float64(25))) }) + + t.Run("not noisy when samples are stale", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 25, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + level, noisy := a.GetLevel(clock) + require.True(t, noisy) + require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) + require.Less(t, level, ConvertAudioLevel(float64(20))) + + // let enough time pass to make the samples stale + clock = clock.Add(1500 * time.Millisecond) + level, noisy = a.GetLevel(clock) + require.Equal(t, float64(0.0), level) + require.False(t, noisy) + }) } func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration uint32) *AudioLevel { @@ -80,8 +114,8 @@ func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration ui }) } -func observeSamples(a *AudioLevel, level uint8, count int) { +func observeSamples(a *AudioLevel, level uint8, count int, baseTime time.Time) { for i := 0; i < count; i++ { - a.Observe(level, 20) + a.Observe(level, 20, baseTime.Add(+time.Duration(i*20)*time.Millisecond)) } } diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index f9e761c1a..e0d3241e5 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -31,6 +31,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/audio" + dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/utils" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/mediatransportutil" @@ -39,8 +40,6 @@ import ( "github.com/livekit/mediatransportutil/pkg/twcc" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - - dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" ) const ( @@ -593,7 +592,7 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime time.Time) { if (p.Timestamp - b.latestTSForAudioLevel) < (1 << 31) { duration := (int64(p.Timestamp) - int64(b.latestTSForAudioLevel)) * 1e3 / int64(b.clockRate) if duration > 0 { - b.audioLevel.Observe(ext.Level, uint32(duration)) + b.audioLevel.Observe(ext.Level, uint32(duration), arrivalTime) } b.latestTSForAudioLevel = p.Timestamp @@ -624,9 +623,6 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow if b.ddParser != nil { ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet) if err != nil { - if err != ErrFrameEarlierThanKeyFrame { - b.logger.Warnw("could not parse dependency descriptor", err) - } return nil } else if ddVal != nil { ep.DependencyDescriptor = ddVal @@ -855,7 +851,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) { return 0, false } - return b.audioLevel.GetLevel() + return b.audioLevel.GetLevel(time.Now()) } func (b *Buffer) OnFpsChanged(f func()) { diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go index 05b675ad0..6c91af260 100644 --- a/pkg/sfu/buffer/dependencydescriptorparser.go +++ b/pkg/sfu/buffer/dependencydescriptorparser.go @@ -27,7 +27,8 @@ import ( ) var ( - ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") + ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") + ErrDDStructureAttachedToNonFirstPacket = fmt.Errorf("dependency descriptor structure is attached to non-first packet of a frame") ) type DependencyDescriptorParser struct { @@ -39,7 +40,6 @@ type DependencyDescriptorParser struct { seqWrapAround *utils.WrapAround[uint16, uint64] frameWrapAround *utils.WrapAround[uint16, uint64] - structureExtSeq uint64 structureExtFrameNum uint64 activeDecodeTargetsExtSeq uint64 activeDecodeTargetsMask uint32 @@ -66,12 +66,15 @@ type ExtDependencyDescriptor struct { ActiveDecodeTargetsUpdated bool Integrity bool ExtFrameNum uint64 + // the frame number of the keyframe which the current frame depends on + ExtKeyFrameNum uint64 } func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescriptor, VideoLayer, error) { var videoLayer VideoLayer ddBuf := pkt.GetExtension(r.ddExtID) if ddBuf == nil { + r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber) return nil, videoLayer, nil } @@ -82,7 +85,9 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } _, err := ext.Unmarshal(ddBuf) if err != nil { - // r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) + if err != dd.ErrDDReaderNoStructure { + r.logger.Warnw("failed to parse generic dependency descriptor", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) + } return nil, videoLayer, err } @@ -108,17 +113,21 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } if ddVal.AttachedStructure != nil { - r.logger.Debugw("parsed dependency descriptor", "extSeq", extSeq, "extFN", extFN, "structureID", ddVal.AttachedStructure.StructureId, "descriptor", ddVal.String()) - if extSeq > r.structureExtSeq { - r.structure = ddVal.AttachedStructure - r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) - r.structureExtSeq = extSeq - r.structureExtFrameNum = extFN - extDD.StructureUpdated = true - extDD.ActiveDecodeTargetsUpdated = true - // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, - // so don't need to notify max layer change here. + if !ddVal.FirstPacketInFrame { + r.logger.Warnw("attached structure is not the first packet in frame", nil, "extSeq", extSeq, "extFN", extFN) + return nil, videoLayer, ErrDDStructureAttachedToNonFirstPacket } + + if r.structure == nil || ddVal.AttachedStructure.StructureId != r.structure.StructureId { + r.logger.Infow("structure updated", "structureID", ddVal.AttachedStructure.StructureId, "extSeq", extSeq, "extFN", extFN, "descriptor", ddVal.String()) + } + r.structure = ddVal.AttachedStructure + r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) + r.structureExtFrameNum = extFN + extDD.StructureUpdated = true + extDD.ActiveDecodeTargetsUpdated = true + // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, + // so don't need to notify max layer change here. } if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil && extSeq > r.activeDecodeTargetsExtSeq { @@ -143,6 +152,7 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } extDD.DecodeTargets = r.decodeTargets + extDD.ExtKeyFrameNum = r.structureExtFrameNum return extDD, videoLayer, nil } diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index 3869ebae8..dad7c66ef 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -129,31 +129,30 @@ func (r *RTPStatsReceiver) Update( pktSize := uint64(hdrSize + payloadSize + paddingSize) gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest) if gapSN <= 0 { // duplicate OR out-of-order - if payloadSize == 0 { - // do not start on a padding only packet - if resTS.IsRestart { - r.logger.Infow( - "rolling back timestamp restart", - "tsBefore", resTS.PreExtendedStart, - "tsAfter", r.timestamp.GetExtendedStart(), - "snBefore", resSN.PreExtendedStart, - "snAfter", r.sequenceNumber.GetExtendedStart(), - ) - r.timestamp.RollbackRestart(resTS.PreExtendedStart) - } - if resSN.IsRestart { - r.logger.Infow( - "rolling back sequence number restart", - "snBefore", resSN.PreExtendedStart, - "snAfter", r.sequenceNumber.GetExtendedStart(), - "tsBefore", resTS.PreExtendedStart, - "tsAfter", r.timestamp.GetExtendedStart(), - ) - r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart) - flowState.IsNotHandled = true - return - } + // before start, don't restart + if resTS.IsRestart { + r.logger.Infow( + "rolling back timestamp restart", + "tsBefore", resTS.PreExtendedStart, + "tsAfter", r.timestamp.GetExtendedStart(), + "snBefore", resSN.PreExtendedStart, + "snAfter", r.sequenceNumber.GetExtendedStart(), + ) + r.timestamp.RollbackRestart(resTS.PreExtendedStart) } + if resSN.IsRestart { + r.logger.Infow( + "rolling back sequence number restart", + "snBefore", resSN.PreExtendedStart, + "snAfter", r.sequenceNumber.GetExtendedStart(), + "tsBefore", resTS.PreExtendedStart, + "tsAfter", r.timestamp.GetExtendedStart(), + ) + r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart) + flowState.IsNotHandled = true + return + } + if -gapSN >= cNumSequenceNumbers/2 { r.logger.Warnw( "large sequence number gap negative", nil, diff --git a/pkg/sfu/buffer/rtpstats_receiver_test.go b/pkg/sfu/buffer/rtpstats_receiver_test.go index 34fea0f4b..7a39c1b54 100644 --- a/pkg/sfu/buffer/rtpstats_receiver_test.go +++ b/pkg/sfu/buffer/rtpstats_receiver_test.go @@ -130,7 +130,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - // out-of-order + // out-of-order, would cause a restart which is disallowed packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update( time.Now(), @@ -142,14 +142,15 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) + require.True(t, flowState.IsNotHandled) require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(1), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsOutOfOrder) require.Equal(t, uint64(0), r.packetsDuplicate) - // duplicate + // duplicate of the above out-of-order packet, but would not be handled as it causes a restart packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update( time.Now(), @@ -161,12 +162,13 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) + require.True(t, flowState.IsNotHandled) require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(2), r.packetsOutOfOrder) - require.Equal(t, uint64(1), r.packetsDuplicate) + require.Equal(t, uint64(0), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) // loss sequenceNumber += 10 @@ -184,10 +186,10 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.True(t, flowState.HasLoss) require.Equal(t, uint64(sequenceNumber-9), flowState.LossStartInclusive) require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) - require.Equal(t, uint64(17), r.packetsLost) + require.Equal(t, uint64(9), r.packetsLost) // out-of-order should decrement number of lost packets - packet = getPacket(sequenceNumber-15, timestamp-45000, 1000) + packet = getPacket(sequenceNumber-6, timestamp-45000, 1000) flowState = r.Update( time.Now(), packet.Header.SequenceNumber, @@ -202,9 +204,9 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(3), r.packetsOutOfOrder) - require.Equal(t, uint64(1), r.packetsDuplicate) - require.Equal(t, uint64(16), r.packetsLost) + require.Equal(t, uint64(1), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) + require.Equal(t, uint64(8), r.packetsLost) // test sequence number history // with a gap @@ -223,7 +225,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.True(t, flowState.HasLoss) require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive) require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) - require.Equal(t, uint64(17), r.packetsLost) + require.Equal(t, uint64(9), r.packetsLost) require.False(t, r.history.IsSet(uint64(sequenceNumber)-1)) // out-of-order @@ -240,8 +242,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) - require.Equal(t, uint64(16), r.packetsLost) - require.Equal(t, uint64(4), r.packetsOutOfOrder) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) require.True(t, r.history.IsSet(uint64(sequenceNumber))) // padding only @@ -257,8 +259,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 25, ) require.False(t, flowState.HasLoss) - require.Equal(t, uint64(16), r.packetsLost) - require.Equal(t, uint64(4), r.packetsOutOfOrder) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) require.True(t, r.history.IsSet(uint64(sequenceNumber))) require.True(t, r.history.IsSet(uint64(sequenceNumber)-1)) require.True(t, r.history.IsSet(uint64(sequenceNumber)-2)) diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go index 04ae1ce7c..2ca21ff8f 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go @@ -18,6 +18,18 @@ import ( "errors" ) +var ( + ErrDDReaderNoStructure = errors.New("DependencyDescriptorReader: Structure is nil") + ErrDDReaderTemplateWithoutStructure = errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil") + ErrDDReaderTooManyTemplates = errors.New("DependencyDescriptorReader: too many templates") + ErrDDReaderTooManyTemporalLayers = errors.New("DependencyDescriptorReader: too many temporal layers") + ErrDDReaderTooManySpatialLayers = errors.New("DependencyDescriptorReader: too many spatial layers") + ErrDDReaderInvalidTemplateIndex = errors.New("DependencyDescriptorReader: invalid template index") + ErrDDReaderInvalidSpatialLayer = errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions") + ErrDDReaderNumDTIMismatch = errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets") + ErrDDReaderNumChainDiffsMismatch = errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains") +) + type DependencyDescriptorReader struct { // Output. descriptor *DependencyDescriptor @@ -59,7 +71,7 @@ func (r *DependencyDescriptorReader) Parse() (int, error) { if r.structure == nil { r.buffer.Invalidate() - return 0, errors.New("DependencyDescriptorReader: Structure is nil") + return 0, ErrDDReaderNoStructure } if r.activeDecodeTargetsPresentFlag { @@ -140,7 +152,7 @@ func (r *DependencyDescriptorReader) readExtendedFields() error { return err } if r.descriptor.AttachedStructure == nil { - return errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil") + return ErrDDReaderTemplateWithoutStructure } bitmask := uint32((uint64(1) << r.descriptor.AttachedStructure.NumDecodeTargets) - 1) r.descriptor.ActiveDecodeTargetsBitmask = &bitmask @@ -203,7 +215,7 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error { ) for { if len(templates) == MaxTemplates { - return errors.New("DependencyDescriptorReader: too many templates") + return ErrDDReaderTooManyTemplates } var lastTemplate FrameDependencyTemplate @@ -220,13 +232,13 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error { if nextLayerIdc == nextTemporalLayer { temporalId++ if temporalId >= MaxTemporalIds { - return errors.New("DependencyDescriptorReader: too many temporal layers") + return ErrDDReaderTooManyTemporalLayers } } else if nextLayerIdc == nextSpatialLayer { spatialId++ temporalId = 0 if spatialId >= MaxSpatialIds { - return errors.New("DependencyDescriptorReader: too many spatial layers") + return ErrDDReaderTooManySpatialLayers } } @@ -340,7 +352,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { if templateIndex >= len(r.structure.Templates) { r.buffer.Invalidate() - return errors.New("DependencyDescriptorReader: invalid template index") + return ErrDDReaderInvalidTemplateIndex } // Copy all the fields from the matching template @@ -374,7 +386,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { // then each spatial layer got one. if r.descriptor.FrameDependencies.SpatialId >= len(r.structure.Resolutions) { r.buffer.Invalidate() - return errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions") + return ErrDDReaderInvalidSpatialLayer } res := r.structure.Resolutions[r.descriptor.FrameDependencies.SpatialId] r.descriptor.Resolution = &res @@ -385,7 +397,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { func (r *DependencyDescriptorReader) readFrameDtis() error { if len(r.descriptor.FrameDependencies.DecodeTargetIndications) != r.structure.NumDecodeTargets { - return errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets") + return ErrDDReaderNumDTIMismatch } for i := range r.descriptor.FrameDependencies.DecodeTargetIndications { @@ -420,7 +432,7 @@ func (r *DependencyDescriptorReader) readFrameFdiffs() error { func (r *DependencyDescriptorReader) readFrameChains() error { if len(r.descriptor.FrameDependencies.ChainDiffs) != r.structure.NumChains { - return errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains") + return ErrDDReaderNumChainDiffsMismatch } for i := range r.descriptor.FrameDependencies.ChainDiffs { diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go b/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go index 37ce7bcf8..bd7da49e3 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go @@ -148,7 +148,7 @@ func (w *DependencyDescriptorWriter) calculateMatch(idx int, template *FrameDepe result.NeedCustomDtis = w.descriptor.FrameDependencies.DecodeTargetIndications != nil && !reflect.DeepEqual(w.descriptor.FrameDependencies.DecodeTargetIndications, template.DecodeTargetIndications) for i := 0; i < w.structure.NumChains; i++ { - if w.activeChains&(1<= len(s.trackers) { + s.logger.Errorw("unexpected layer", nil, "layer", layer) + return nil + } return s.trackers[layer] } diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index f03f78c6c..3c091160e 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -16,6 +16,7 @@ package videolayerselector import ( "fmt" + "runtime/debug" "sync" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -31,6 +32,8 @@ type DependencyDescriptor struct { previousActiveDecodeTargetsBitmask *uint32 activeDecodeTargetsBitmask *uint32 structure *dede.FrameDependencyStructure + extKeyFrameNum uint64 + keyFrameValid bool chains []*FrameChain @@ -79,38 +82,76 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r Temporal: int32(fd.TemporalId), } + if !d.keyFrameValid && dd.AttachedStructure == nil { + return + } + // early return if this frame is already forwarded or dropped sd, err := d.decisions.GetDecision(extFrameNum) if err != nil { // do not mark as dropped as only error is an old frame - d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", - incomingLayer, - dd.FrameNumber, - extFrameNum, - extPkt.Packet.SequenceNumber, - ), "err", err) + // d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // ), "err", err) return } switch sd { case selectorDecisionDropped: // a packet of an alreadty dropped frame, maintain decision - d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", - incomingLayer, - dd.FrameNumber, - extFrameNum, - extPkt.Packet.SequenceNumber, - )) + // d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // )) return } if ddwdt.StructureUpdated { - d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets) + // TODO-REMOVE: remove this log after stable + d.logger.Infow("update dependency structure", + "structureID", dd.AttachedStructure.StructureId, + "structure", dd.AttachedStructure, + "decodeTargets", ddwdt.DecodeTargets, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "isKeyFrame", extPkt.KeyFrame, + "currentKeyframe", d.extKeyFrameNum, + ) + + d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum) + } + + if ddwdt.ExtKeyFrameNum != d.extKeyFrameNum { + // keyframe mismatch, drop and reset chains + // TODO-REMOVE: remove this log after stable + d.logger.Infow("drop packet for keyframe mismatch", "incoming", incomingLayer, "efn", extFrameNum, "sn", extPkt.Packet.SequenceNumber, "requiredKeyFrame", ddwdt.ExtKeyFrameNum, "structureKeyFrame", d.extKeyFrameNum) + d.decisions.AddDropped(extFrameNum) + d.invalidateKeyFrame() + return } if ddwdt.ActiveDecodeTargetsUpdated { d.updateActiveDecodeTargets(*dd.ActiveDecodeTargetsBitmask) } + // TODO-REMOVE: remove this log after stable + if len(fd.ChainDiffs) != len(d.chains) { + d.logger.Warnw("frame chain diff length mismatch", nil, + "incoming", incomingLayer, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "chainDiffs", fd.ChainDiffs, + "chains", len(d.chains), + "requiredKeyFrame", ddwdt.ExtKeyFrameNum, + "structureKeyFrame", d.extKeyFrameNum) + d.decisions.AddDropped(extFrameNum) + return + } + for _, chain := range d.chains { chain.OnFrame(extFrameNum, fd) } @@ -133,7 +174,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if err != nil { d.decodeTargetsLock.RUnlock() // dtis error, dependency descriptor might lost - d.logger.Debugw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), "err", err) + d.logger.Warnw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), err) d.decisions.AddDropped(extFrameNum) return } @@ -148,34 +189,34 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if highestDecodeTarget.Target < 0 { // no active decode target, do not select - d.logger.Debugw( - "drop packet for no target found", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for no target found", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } // DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame if dti == dede.DecodeTargetNotPresent { - d.logger.Debugw( - "drop packet for decode target not present", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for decode target not present", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } @@ -195,17 +236,17 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r } } if !isDecodable { - d.logger.Debugw( - "drop packet for not decodable", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for not decodable", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } @@ -263,11 +304,33 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r // d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask) } } - bytes, err := ddExtension.Marshal() - if err != nil { - d.logger.Warnw("error marshalling dependency descriptor extension", err) - } else { - result.DependencyDescriptorExtension = bytes + + var ddMarshaled bool + func() { + defer func() { + if r := recover(); r != nil { + d.logger.Errorw("panic marshalling dependency descriptor extension", nil, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "keyframeRequired", ddwdt.ExtKeyFrameNum, + "currentKeyframe", d.extKeyFrameNum, + "panic", r, + "stack", string(debug.Stack())) + } + }() + bytes, err := ddExtension.Marshal() + if err != nil { + d.logger.Warnw("error marshalling dependency descriptor extension", err) + } else { + result.DependencyDescriptorExtension = bytes + ddMarshaled = true + } + }() + + if !ddMarshaled { + // drop packet if we can't marshal dependency descriptor + d.decisions.AddDropped(extFrameNum) + return } if ddwdt.Integrity { @@ -284,8 +347,10 @@ func (d *DependencyDescriptor) Rollback() { d.Base.Rollback() } -func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget) { +func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget, extFrameNum uint64) { d.structure = structure + d.extKeyFrameNum = extFrameNum + d.keyFrameValid = true d.chains = d.chains[:0] @@ -329,9 +394,17 @@ func (d *DependencyDescriptor) updateActiveDecodeTargets(activeDecodeTargetsBitm } } +func (d *DependencyDescriptor) invalidateKeyFrame() { + d.keyFrameValid = false + d.chains = d.chains[:0] + d.decodeTargetsLock.Lock() + d.decodeTargets = d.decodeTargets[:0] + d.decodeTargetsLock.Unlock() +} + func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { layer = d.GetRequestSpatial() - if !d.currentLayer.IsValid() { + if !d.currentLayer.IsValid() || !d.keyFrameValid { // always declare not locked when trying to resume from nothing return false, layer } @@ -339,7 +412,7 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { d.decodeTargetsLock.RLock() defer d.decodeTargetsLock.RUnlock() for _, dt := range d.decodeTargets { - if dt.Active() && dt.Layer.Spatial <= d.GetTarget().Spatial && dt.Valid() { + if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() { d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets)) return true, layer } diff --git a/pkg/sfu/videolayerselector/dependencydescriptor_test.go b/pkg/sfu/videolayerselector/dependencydescriptor_test.go index c013e46af..2b346dedb 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor_test.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor_test.go @@ -253,11 +253,23 @@ func TestDependencyDescriptor(t *testing.T) { } require.True(t, switchToLower) - // sync with requested layer + // not sync with requested layer ddSelector.SetRequestSpatial(targetLayer.Spatial) locked, layer := ddSelector.CheckSync() - require.True(t, locked) + require.False(t, locked) require.Equal(t, targetLayer.Spatial, layer) + // request to current layer, sync + ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial) + locked, _ = ddSelector.CheckSync() + require.True(t, locked) + + // should drop frame that relies on a keyframe is not present in current selection + framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000) + ret = ddSelector.Select(framesPrevious[1], 0) + require.False(t, ret.IsSelected) + // keyframe lost, out of sync + locked, _ = ddSelector.CheckSync() + require.False(t, locked) } func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket { @@ -279,7 +291,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer) }) - chainDiffs := make([]int, len(decodeTargets)) + chainDiffs := make([]int, int(maxLayer.Spatial)+1) dtis := make([]dd.DecodeTargetIndication, len(decodeTargets)) for _, dt := range decodeTargets { dtis[dt.Target] = dd.DecodeTargetSwitch @@ -319,6 +331,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff ActiveDecodeTargetsUpdated: true, Integrity: true, ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: uint64(startFrameNumber), }, Packet: &rtp.Packet{ Header: rtp.Header{ @@ -356,7 +369,6 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff } frame := &buffer.ExtPacket{ - KeyFrame: true, DependencyDescriptor: &buffer.ExtDependencyDescriptor{ Descriptor: &dd.DependencyDescriptor{ FrameNumber: startFrameNumber, @@ -367,9 +379,10 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff DecodeTargetIndications: frameDtis, }, }, - DecodeTargets: decodeTargets, - Integrity: true, - ExtFrameNum: uint64(startFrameNumber), + DecodeTargets: decodeTargets, + Integrity: true, + ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: keyFrame.DependencyDescriptor.ExtFrameNum, }, Packet: &rtp.Packet{ Header: rtp.Header{ diff --git a/test/agent.go b/test/agent.go new file mode 100644 index 000000000..76f608314 --- /dev/null +++ b/test/agent.go @@ -0,0 +1,158 @@ +package test + +import ( + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/gorilla/websocket" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" +) + +type agentClient struct { + mu sync.Mutex + conn *websocket.Conn + + registered atomic.Int32 + roomAvailability atomic.Int32 + roomJobs atomic.Int32 + participantAvailability atomic.Int32 + participantJobs atomic.Int32 + + done chan struct{} +} + +func newAgentClient(token string) (*agentClient, error) { + host := fmt.Sprintf("ws://localhost:%d", defaultServerPort) + u, err := url.Parse(host + "/agent") + if err != nil { + return nil, err + } + requestHeader := make(http.Header) + requestHeader.Set("Authorization", "Bearer "+token) + + connectUrl := u.String() + conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader) + if err != nil { + return nil, err + } + + return &agentClient{ + conn: conn, + done: make(chan struct{}), + }, nil +} + +func (c *agentClient) Run() error { + go c.read() + + workerID := utils.NewGuid("W_") + + if err := c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_ROOM, + WorkerId: workerID, + Version: "version", + Name: "name", + }, + }, + }); err != nil { + return err + } + + if err := c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_PUBLISHER, + WorkerId: workerID, + Version: "version", + Name: "name", + }, + }, + }); err != nil { + return err + } + + return nil +} + +func (c *agentClient) read() { + for { + select { + case <-c.done: + return + default: + _, b, err := c.conn.ReadMessage() + if err != nil { + return + } + + msg := &livekit.ServerMessage{} + if err = proto.Unmarshal(b, msg); err != nil { + return + } + + switch m := msg.Message.(type) { + case *livekit.ServerMessage_Assignment: + go c.handleAssignment(m.Assignment) + case *livekit.ServerMessage_Availability: + go c.handleAvailability(m.Availability) + case *livekit.ServerMessage_Register: + go c.handleRegister(m.Register) + } + } + } +} + +func (c *agentClient) handleAssignment(req *livekit.JobAssignment) { + if req.Job.Type == livekit.JobType_JT_ROOM { + c.roomJobs.Inc() + } else { + c.participantJobs.Inc() + } +} + +func (c *agentClient) handleAvailability(req *livekit.AvailabilityRequest) { + if req.Job.Type == livekit.JobType_JT_ROOM { + c.roomAvailability.Inc() + } else { + c.participantAvailability.Inc() + } + + c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Availability{ + Availability: &livekit.AvailabilityResponse{ + JobId: req.Job.Id, + Available: true, + }, + }, + }) +} + +func (c *agentClient) handleRegister(req *livekit.RegisterWorkerResponse) { + c.registered.Inc() +} + +func (c *agentClient) write(msg *livekit.WorkerMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + b, err := proto.Marshal(msg) + if err != nil { + return err + } + + return c.conn.WriteMessage(websocket.BinaryMessage, b) +} + +func (c *agentClient) close() { + close(c.done) + _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + _ = c.conn.Close() +} diff --git a/test/agent_test.go b/test/agent_test.go new file mode 100644 index 000000000..0a45fa9ae --- /dev/null +++ b/test/agent_test.go @@ -0,0 +1,83 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/auth" +) + +func TestAgents(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgents") + defer finish() + + ac1, err := newAgentClient(agentToken()) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken()) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run() + ac2.Run() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(2), ac1.registered.Load()) + require.Equal(t, int32(2), ac2.registered.Load()) + + c1 := createRTCClient("c1", defaultServerPort, nil) + c2 := createRTCClient("c2", defaultServerPort, nil) + waitUntilConnected(t, c1, c2) + + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load()) + require.Equal(t, int32(1), ac1.participantJobs.Load()+ac2.participantJobs.Load()) + + // publish 2 tracks + t3, err := c2.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t3.Stop() + t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t4.Stop() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load()) + require.Equal(t, int32(2), ac1.participantJobs.Load()+ac2.participantJobs.Load()) +} + +func agentToken() string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{Agent: true}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +}