From 4fa60247c1b6952129f0dbb2d707986d78ef3d50 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Wed, 1 Nov 2023 16:30:49 +0530 Subject: [PATCH 01/18] Reduce log level (#2209) --- pkg/rtc/transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 1a20d6540..ca22ffeec 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -833,7 +833,7 @@ func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelIni } dcCloseHandler := func() { - t.params.Logger.Infow(dc.Label() + " data channel close") + t.params.Logger.Debugw(dc.Label() + " data channel close") } dcErrorHandler := func(err error) { From 9399fb2bfe038971183867b842a0acec48020ab0 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Wed, 1 Nov 2023 19:45:44 +0800 Subject: [PATCH 02/18] Only select alternative codec for video (#2210) * Only select alternative codec for video * Filter out empty mime --- pkg/rtc/mediaengine.go | 8 +++++--- pkg/rtc/participant.go | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index 0edd60565..f9e6c838b 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -134,7 +134,7 @@ func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool return false } -func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string { +func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { // sort these by compatibility, since we are looking for backups if slices.ContainsFunc(enabledCodecs, func(c *livekit.Codec) bool { return strings.EqualFold(c.Mime, webrtc.MimeTypeVP8) @@ -146,8 +146,10 @@ func selectAlternativeCodec(enabledCodecs []*livekit.Codec) string { }) { return webrtc.MimeTypeH264 } - if len(enabledCodecs) > 0 { - return enabledCodecs[0].Mime + for _, c := range enabledCodecs { + if strings.HasPrefix(c.Mime, "video/") { + return c.Mime + } } // uh oh. this should not happen return "" diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 88d7f2aaf..2ef16843d 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1626,7 +1626,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { mime = "video/" + mime if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { - altCodec := selectAlternativeCodec(p.enabledPublishCodecs) + altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", "codec", mime, "altCodec", altCodec, @@ -1639,7 +1639,7 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l mime = "audio/" + mime } - if _, ok := seenCodecs[mime]; ok { + if _, ok := seenCodecs[mime]; ok || mime == "" { continue } seenCodecs[mime] = struct{}{} From f38a5794a06b0aada2dfba034c9054b2352a9ac7 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 2 Nov 2023 00:03:47 +0530 Subject: [PATCH 03/18] fallback to vp8 if no viable codec (#2211) --- pkg/rtc/mediaengine.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/rtc/mediaengine.go b/pkg/rtc/mediaengine.go index f9e6c838b..8bbb3db20 100644 --- a/pkg/rtc/mediaengine.go +++ b/pkg/rtc/mediaengine.go @@ -151,6 +151,6 @@ func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { return c.Mime } } - // uh oh. this should not happen - return "" + // no viable codec in the list of enabled codecs, fall back to the most widely supported codec + return webrtc.MimeTypeVP8 } From 072bb9dd69582140aa322149d4a3fe4f2a26f4d1 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:30:29 -0700 Subject: [PATCH 04/18] Update module github.com/pion/interceptor to v0.1.25 (#2208) Generated by renovateBot Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 5df1d6509..801da0746 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/pion/dtls/v2 v2.2.7 github.com/pion/ice/v2 v2.3.11 - github.com/pion/interceptor v0.1.24 + github.com/pion/interceptor v0.1.25 github.com/pion/rtcp v1.2.10 github.com/pion/rtp v1.8.2 github.com/pion/sctp v1.8.9 diff --git a/go.sum b/go.sum index 4aedde2df..92e574678 100644 --- a/go.sum +++ b/go.sum @@ -185,8 +185,8 @@ github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw= github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= github.com/pion/interceptor v0.1.18/go.mod h1:tpvvF4cPM6NGxFA1DUMbhabzQBxdWMATDGEUYOR9x6I= -github.com/pion/interceptor v0.1.24 h1:lN4ua3yUAJCgNKQKcZIM52wFjBgjN0r7shLj91PkJ0c= -github.com/pion/interceptor v0.1.24/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= +github.com/pion/interceptor v0.1.25 h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc= +github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.8 h1:HhicWIg7OX5PVilyBO6plhMetInbzkVJAhbdJiAeVaI= From a6ede46adc1e48dfdbd6089c2cd8cb712cc6d49d Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Wed, 1 Nov 2023 18:04:04 -0700 Subject: [PATCH 05/18] add bounds check to dependency descriptor loop (#2214) --- pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go b/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go index 37ce7bcf8..bd7da49e3 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorwriter.go @@ -148,7 +148,7 @@ func (w *DependencyDescriptorWriter) calculateMatch(idx int, template *FrameDepe result.NeedCustomDtis = w.descriptor.FrameDependencies.DecodeTargetIndications != nil && !reflect.DeepEqual(w.descriptor.FrameDependencies.DecodeTargetIndications, template.DecodeTargetIndications) for i := 0; i < w.structure.NumChains; i++ { - if w.activeChains&(1< Date: Thu, 2 Nov 2023 11:10:28 +0530 Subject: [PATCH 06/18] Squelching DD reader error. (#2215) Squelching Structure is nil error as it can happen on packets received before a key frame is received. --- pkg/sfu/buffer/buffer.go | 5 ++-- .../dependencydescriptorreader.go | 30 +++++++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 3bb0a6715..471aea927 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -31,6 +31,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/audio" + dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" "github.com/livekit/livekit-server/pkg/sfu/utils" sutils "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/mediatransportutil" @@ -39,8 +40,6 @@ import ( "github.com/livekit/mediatransportutil/pkg/twcc" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - - dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" ) const ( @@ -615,7 +614,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow if b.ddParser != nil { ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet) if err != nil { - if err != ErrFrameEarlierThanKeyFrame { + if !errors.Is(err, ErrFrameEarlierThanKeyFrame) && !errors.Is(err, dd.ErrDDReaderNoStructure) { b.logger.Warnw("could not parse dependency descriptor", err) } return nil diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go index 04ae1ce7c..2ca21ff8f 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go @@ -18,6 +18,18 @@ import ( "errors" ) +var ( + ErrDDReaderNoStructure = errors.New("DependencyDescriptorReader: Structure is nil") + ErrDDReaderTemplateWithoutStructure = errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil") + ErrDDReaderTooManyTemplates = errors.New("DependencyDescriptorReader: too many templates") + ErrDDReaderTooManyTemporalLayers = errors.New("DependencyDescriptorReader: too many temporal layers") + ErrDDReaderTooManySpatialLayers = errors.New("DependencyDescriptorReader: too many spatial layers") + ErrDDReaderInvalidTemplateIndex = errors.New("DependencyDescriptorReader: invalid template index") + ErrDDReaderInvalidSpatialLayer = errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions") + ErrDDReaderNumDTIMismatch = errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets") + ErrDDReaderNumChainDiffsMismatch = errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains") +) + type DependencyDescriptorReader struct { // Output. descriptor *DependencyDescriptor @@ -59,7 +71,7 @@ func (r *DependencyDescriptorReader) Parse() (int, error) { if r.structure == nil { r.buffer.Invalidate() - return 0, errors.New("DependencyDescriptorReader: Structure is nil") + return 0, ErrDDReaderNoStructure } if r.activeDecodeTargetsPresentFlag { @@ -140,7 +152,7 @@ func (r *DependencyDescriptorReader) readExtendedFields() error { return err } if r.descriptor.AttachedStructure == nil { - return errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil") + return ErrDDReaderTemplateWithoutStructure } bitmask := uint32((uint64(1) << r.descriptor.AttachedStructure.NumDecodeTargets) - 1) r.descriptor.ActiveDecodeTargetsBitmask = &bitmask @@ -203,7 +215,7 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error { ) for { if len(templates) == MaxTemplates { - return errors.New("DependencyDescriptorReader: too many templates") + return ErrDDReaderTooManyTemplates } var lastTemplate FrameDependencyTemplate @@ -220,13 +232,13 @@ func (r *DependencyDescriptorReader) readTemplateLayers() error { if nextLayerIdc == nextTemporalLayer { temporalId++ if temporalId >= MaxTemporalIds { - return errors.New("DependencyDescriptorReader: too many temporal layers") + return ErrDDReaderTooManyTemporalLayers } } else if nextLayerIdc == nextSpatialLayer { spatialId++ temporalId = 0 if spatialId >= MaxSpatialIds { - return errors.New("DependencyDescriptorReader: too many spatial layers") + return ErrDDReaderTooManySpatialLayers } } @@ -340,7 +352,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { if templateIndex >= len(r.structure.Templates) { r.buffer.Invalidate() - return errors.New("DependencyDescriptorReader: invalid template index") + return ErrDDReaderInvalidTemplateIndex } // Copy all the fields from the matching template @@ -374,7 +386,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { // then each spatial layer got one. if r.descriptor.FrameDependencies.SpatialId >= len(r.structure.Resolutions) { r.buffer.Invalidate() - return errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions") + return ErrDDReaderInvalidSpatialLayer } res := r.structure.Resolutions[r.descriptor.FrameDependencies.SpatialId] r.descriptor.Resolution = &res @@ -385,7 +397,7 @@ func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { func (r *DependencyDescriptorReader) readFrameDtis() error { if len(r.descriptor.FrameDependencies.DecodeTargetIndications) != r.structure.NumDecodeTargets { - return errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets") + return ErrDDReaderNumDTIMismatch } for i := range r.descriptor.FrameDependencies.DecodeTargetIndications { @@ -420,7 +432,7 @@ func (r *DependencyDescriptorReader) readFrameFdiffs() error { func (r *DependencyDescriptorReader) readFrameChains() error { if len(r.descriptor.FrameDependencies.ChainDiffs) != r.structure.NumChains { - return errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains") + return ErrDDReaderNumChainDiffsMismatch } for i := range r.descriptor.FrameDependencies.ChainDiffs { From a7a227709a482f05af8c4411e0992be85204a84d Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Thu, 2 Nov 2023 12:33:02 +0530 Subject: [PATCH 07/18] Prevent out-of-bounds access. (#2216) * Prevent out-of-bounds access. Don't know which codec causes a spatial layer three access. Returning nil and also logging so that we know the trackID of offending track. * spelling --- pkg/sfu/streamtrackermanager.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/sfu/streamtrackermanager.go b/pkg/sfu/streamtrackermanager.go index 0aba20ff9..c9a22bff9 100644 --- a/pkg/sfu/streamtrackermanager.go +++ b/pkg/sfu/streamtrackermanager.go @@ -214,9 +214,9 @@ func (s *StreamTrackerManager) AddTracker(layer int32) streamtracker.StreamTrack }) } - s.logger.Debugw("StreamTrackerManager add track", "layer", layer) + s.logger.Debugw("stream tracker add track", "layer", layer) tracker.OnStatusChanged(func(status streamtracker.StreamStatus) { - s.logger.Debugw("StreamTrackerManager OnStatusChanged", "layer", layer, "status", status) + s.logger.Debugw("stream tracker status changed", "layer", layer, "status", status) if status == streamtracker.StreamStatusStopped { s.removeAvailableLayer(layer) } else { @@ -289,6 +289,10 @@ func (s *StreamTrackerManager) GetTracker(layer int32) streamtracker.StreamTrack s.lock.RLock() defer s.lock.RUnlock() + if int(layer) >= len(s.trackers) { + s.logger.Errorw("unexpected layer", nil, "layer", layer) + return nil + } return s.trackers[layer] } From f165ae1fa09d30d3d574e5e095d148b0a74093fc Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 3 Nov 2023 10:14:11 +0530 Subject: [PATCH 08/18] Separate publish and subscribe enabled codecs for finer grained control. (#2217) --- pkg/rtc/participant.go | 37 ++++++++++++++-------------- pkg/rtc/participant_internal_test.go | 27 ++++++++++---------- pkg/service/roommanager.go | 3 ++- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 2ef16843d..47b7d7f01 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -101,7 +101,8 @@ type ParticipantParams struct { PLIThrottleConfig config.PLIThrottleConfig CongestionControlConfig config.CongestionControlConfig // codecs that are enabled for this room - EnabledCodecs []*livekit.Codec + PublishEnabledCodecs []*livekit.Codec + SubscribeEnabledCodecs []*livekit.Codec Logger logger.Logger SimTracks map[uint32]SimulcastTrackInfo Grants *auth.ClaimGrants @@ -249,7 +250,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.state.Store(livekit.ParticipantInfo_JOINING) p.grants = params.Grants p.SetResponseSink(params.Sink) - p.setupEnabledCodecs(params.EnabledCodecs, params.ClientConf.GetDisabledCodecs()) + p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs()) p.supervisor.OnPublicationError(p.onPublicationError) @@ -2334,9 +2335,7 @@ func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket, data []byte) er return err } -func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { - subscribeCodecs := make([]*livekit.Codec, 0, len(codecs)) - publishCodecs := make([]*livekit.Codec, 0, len(codecs)) +func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Codec, subscribeEnabledCodecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { for _, disableCodec := range disabled { // disable codec's fmtp is empty means disable this codec entirely @@ -2346,22 +2345,22 @@ func (p *ParticipantImpl) setupEnabledCodecs(codecs []*livekit.Codec, disabledCo } return false } - for _, c := range codecs { - var publishDisabled bool - var subscribeDisabled bool + + publishCodecs := make([]*livekit.Codec, 0, len(publishEnabledCodecs)) + for _, c := range publishEnabledCodecs { + if shouldDisable(c, disabledCodecs.GetCodecs()) || shouldDisable(c, disabledCodecs.GetPublish()) { + continue + } + publishCodecs = append(publishCodecs, c) + } + p.enabledPublishCodecs = publishCodecs + + subscribeCodecs := make([]*livekit.Codec, 0, len(subscribeEnabledCodecs)) + for _, c := range subscribeEnabledCodecs { if shouldDisable(c, disabledCodecs.GetCodecs()) { - publishDisabled = true - subscribeDisabled = true - } else if shouldDisable(c, disabledCodecs.GetPublish()) { - publishDisabled = true - } - if !publishDisabled { - publishCodecs = append(publishCodecs, c) - } - if !subscribeDisabled { - subscribeCodecs = append(subscribeCodecs, c) + continue } + subscribeCodecs = append(subscribeCodecs, c) } p.enabledSubscribeCodecs = subscribeCodecs - p.enabledPublishCodecs = publishCodecs } diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2b2ce2505..2256a70ca 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -748,19 +748,20 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p } sid := livekit.ParticipantID(utils.NewGuid(utils.ParticipantPrefix)) p, _ := NewParticipant(ParticipantParams{ - SID: sid, - Identity: identity, - Config: rtcConf, - Sink: &routingfakes.FakeMessageSink{}, - ProtocolVersion: opts.protocolVersion, - PLIThrottleConfig: conf.RTC.PLIThrottle, - Grants: grants, - EnabledCodecs: enabledCodecs, - ClientConf: opts.clientConf, - ClientInfo: ClientInfo{ClientInfo: opts.clientInfo}, - Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false), - Telemetry: &telemetryfakes.FakeTelemetryService{}, - VersionGenerator: utils.NewDefaultTimedVersionGenerator(), + SID: sid, + Identity: identity, + Config: rtcConf, + Sink: &routingfakes.FakeMessageSink{}, + ProtocolVersion: opts.protocolVersion, + PLIThrottleConfig: conf.RTC.PLIThrottle, + Grants: grants, + PublishEnabledCodecs: enabledCodecs, + SubscribeEnabledCodecs: enabledCodecs, + ClientConf: opts.clientConf, + ClientInfo: ClientInfo{ClientInfo: opts.clientInfo}, + Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false), + Telemetry: &telemetryfakes.FakeTelemetryService{}, + VersionGenerator: utils.NewDefaultTimedVersionGenerator(), }) p.isPublisher.Store(opts.publisher) p.updateState(livekit.ParticipantInfo_ACTIVE) diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 918e847bf..738d1dd13 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -393,7 +393,8 @@ func (r *RoomManager) StartSession( Trailer: room.Trailer(), PLIThrottleConfig: r.config.RTC.PLIThrottle, CongestionControlConfig: r.config.RTC.CongestionControl, - EnabledCodecs: protoRoom.EnabledCodecs, + PublishEnabledCodecs: protoRoom.EnabledCodecs, + SubscribeEnabledCodecs: protoRoom.EnabledCodecs, Grants: pi.Grants, Logger: pLogger, ClientConf: clientConf, From f247b68ed67902947c0d8bdc3e24f3554739e971 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Fri, 3 Nov 2023 17:49:02 +0800 Subject: [PATCH 09/18] Make sure dd selector uses correct keyframe to select packets (#2218) * Make sure dd selector uses correct keyframe to select packets * Fix test case * remove unsed field --- pkg/sfu/buffer/buffer.go | 3 - pkg/sfu/buffer/dependencydescriptorparser.go | 36 ++-- pkg/sfu/streamtracker/streamtracker_dd.go | 2 +- .../dependencydescriptor.go | 179 ++++++++++++------ .../dependencydescriptor_test.go | 19 +- 5 files changed, 164 insertions(+), 75 deletions(-) diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 471aea927..c6a018972 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -614,9 +614,6 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime time.Time, flow if b.ddParser != nil { ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet) if err != nil { - if !errors.Is(err, ErrFrameEarlierThanKeyFrame) && !errors.Is(err, dd.ErrDDReaderNoStructure) { - b.logger.Warnw("could not parse dependency descriptor", err) - } return nil } else if ddVal != nil { ep.DependencyDescriptor = ddVal diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go index 05b675ad0..6c91af260 100644 --- a/pkg/sfu/buffer/dependencydescriptorparser.go +++ b/pkg/sfu/buffer/dependencydescriptorparser.go @@ -27,7 +27,8 @@ import ( ) var ( - ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") + ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") + ErrDDStructureAttachedToNonFirstPacket = fmt.Errorf("dependency descriptor structure is attached to non-first packet of a frame") ) type DependencyDescriptorParser struct { @@ -39,7 +40,6 @@ type DependencyDescriptorParser struct { seqWrapAround *utils.WrapAround[uint16, uint64] frameWrapAround *utils.WrapAround[uint16, uint64] - structureExtSeq uint64 structureExtFrameNum uint64 activeDecodeTargetsExtSeq uint64 activeDecodeTargetsMask uint32 @@ -66,12 +66,15 @@ type ExtDependencyDescriptor struct { ActiveDecodeTargetsUpdated bool Integrity bool ExtFrameNum uint64 + // the frame number of the keyframe which the current frame depends on + ExtKeyFrameNum uint64 } func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescriptor, VideoLayer, error) { var videoLayer VideoLayer ddBuf := pkt.GetExtension(r.ddExtID) if ddBuf == nil { + r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber) return nil, videoLayer, nil } @@ -82,7 +85,9 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } _, err := ext.Unmarshal(ddBuf) if err != nil { - // r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) + if err != dd.ErrDDReaderNoStructure { + r.logger.Warnw("failed to parse generic dependency descriptor", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) + } return nil, videoLayer, err } @@ -108,17 +113,21 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } if ddVal.AttachedStructure != nil { - r.logger.Debugw("parsed dependency descriptor", "extSeq", extSeq, "extFN", extFN, "structureID", ddVal.AttachedStructure.StructureId, "descriptor", ddVal.String()) - if extSeq > r.structureExtSeq { - r.structure = ddVal.AttachedStructure - r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) - r.structureExtSeq = extSeq - r.structureExtFrameNum = extFN - extDD.StructureUpdated = true - extDD.ActiveDecodeTargetsUpdated = true - // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, - // so don't need to notify max layer change here. + if !ddVal.FirstPacketInFrame { + r.logger.Warnw("attached structure is not the first packet in frame", nil, "extSeq", extSeq, "extFN", extFN) + return nil, videoLayer, ErrDDStructureAttachedToNonFirstPacket } + + if r.structure == nil || ddVal.AttachedStructure.StructureId != r.structure.StructureId { + r.logger.Infow("structure updated", "structureID", ddVal.AttachedStructure.StructureId, "extSeq", extSeq, "extFN", extFN, "descriptor", ddVal.String()) + } + r.structure = ddVal.AttachedStructure + r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) + r.structureExtFrameNum = extFN + extDD.StructureUpdated = true + extDD.ActiveDecodeTargetsUpdated = true + // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, + // so don't need to notify max layer change here. } if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil && extSeq > r.activeDecodeTargetsExtSeq { @@ -143,6 +152,7 @@ func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescr } extDD.DecodeTargets = r.decodeTargets + extDD.ExtKeyFrameNum = r.structureExtFrameNum return extDD, videoLayer, nil } diff --git a/pkg/sfu/streamtracker/streamtracker_dd.go b/pkg/sfu/streamtracker/streamtracker_dd.go index b0f74ef44..be8009eb1 100644 --- a/pkg/sfu/streamtracker/streamtracker_dd.go +++ b/pkg/sfu/streamtracker/streamtracker_dd.go @@ -193,7 +193,7 @@ func (s *StreamTrackerDependencyDescriptor) Observe(temporalLayer int32, pktSize for _, dt := range ddVal.DecodeTargets { if len(dtis) <= dt.Target { - s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtls", dtis) + s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtis", dtis) continue } // we are not dropping discardable frames now, so only ingore not present frames diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index f03f78c6c..f458928d8 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -16,6 +16,7 @@ package videolayerselector import ( "fmt" + "runtime/debug" "sync" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -31,6 +32,8 @@ type DependencyDescriptor struct { previousActiveDecodeTargetsBitmask *uint32 activeDecodeTargetsBitmask *uint32 structure *dede.FrameDependencyStructure + extKeyFrameNum uint64 + keyFrameValid bool chains []*FrameChain @@ -79,38 +82,76 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r Temporal: int32(fd.TemporalId), } + if !d.keyFrameValid && dd.AttachedStructure == nil { + return + } + // early return if this frame is already forwarded or dropped sd, err := d.decisions.GetDecision(extFrameNum) if err != nil { // do not mark as dropped as only error is an old frame - d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", - incomingLayer, - dd.FrameNumber, - extFrameNum, - extPkt.Packet.SequenceNumber, - ), "err", err) + // d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // ), "err", err) return } switch sd { case selectorDecisionDropped: // a packet of an alreadty dropped frame, maintain decision - d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", - incomingLayer, - dd.FrameNumber, - extFrameNum, - extPkt.Packet.SequenceNumber, - )) + // d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // )) return } if ddwdt.StructureUpdated { - d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets) + // TODO-REMOVE: remove this log after stable + d.logger.Infow("update dependency structure", + "structureID", dd.AttachedStructure.StructureId, + "structure", dd.AttachedStructure, + "decodeTargets", ddwdt.DecodeTargets, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "isKeyFrame", extPkt.KeyFrame, + "currentKeyframe", d.extKeyFrameNum, + ) + + d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum) + } + + if ddwdt.ExtKeyFrameNum != d.extKeyFrameNum { + // keyframe mismatch, drop and reset chains + // TODO-REMOVE: remove this log after stable + d.logger.Infow("drop packet for keyframe mismatch", "incoming", incomingLayer, "efn", extFrameNum, "sn", extPkt.Packet.SequenceNumber, "requiredKeyFrame", ddwdt.ExtKeyFrameNum, "structureKeyFrame", d.extKeyFrameNum) + d.decisions.AddDropped(extFrameNum) + d.invalidateKeyFrame() + return } if ddwdt.ActiveDecodeTargetsUpdated { d.updateActiveDecodeTargets(*dd.ActiveDecodeTargetsBitmask) } + // TODO-REMOVE: remove this log after stable + if len(fd.ChainDiffs) != len(d.chains) { + d.logger.Warnw("frame chain diff length mismatch", nil, + "incoming", incomingLayer, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "chainDiffs", fd.ChainDiffs, + "chains", len(d.chains), + "requiredKeyFrame", ddwdt.ExtKeyFrameNum, + "structureKeyFrame", d.extKeyFrameNum) + d.decisions.AddDropped(extFrameNum) + return + } + for _, chain := range d.chains { chain.OnFrame(extFrameNum, fd) } @@ -133,7 +174,7 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if err != nil { d.decodeTargetsLock.RUnlock() // dtis error, dependency descriptor might lost - d.logger.Debugw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), "err", err) + d.logger.Warnw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), err) d.decisions.AddDropped(extFrameNum) return } @@ -148,34 +189,34 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r if highestDecodeTarget.Target < 0 { // no active decode target, do not select - d.logger.Debugw( - "drop packet for no target found", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for no target found", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } // DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame if dti == dede.DecodeTargetNotPresent { - d.logger.Debugw( - "drop packet for decode target not present", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for decode target not present", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } @@ -195,17 +236,17 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r } } if !isDecodable { - d.logger.Debugw( - "drop packet for not decodable", - "highestDecodeTarget", highestDecodeTarget, - "decodeTargets", d.decodeTargets, - "tagetLayer", d.targetLayer, - "incoming", incomingLayer, - "fn", dd.FrameNumber, - "efn", extFrameNum, - "sn", extPkt.Packet.SequenceNumber, - "isKeyFrame", extPkt.KeyFrame, - ) + // d.logger.Debugw( + // "drop packet for not decodable", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) d.decisions.AddDropped(extFrameNum) return } @@ -263,11 +304,33 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r // d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask) } } - bytes, err := ddExtension.Marshal() - if err != nil { - d.logger.Warnw("error marshalling dependency descriptor extension", err) - } else { - result.DependencyDescriptorExtension = bytes + + var ddMarshaled bool + func() { + defer func() { + if r := recover(); r != nil { + d.logger.Errorw("panic marshalling dependency descriptor extension", nil, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "keyframeRequired", ddwdt.ExtKeyFrameNum, + "currentKeyframe", d.extKeyFrameNum, + "panic", r, + "stack", string(debug.Stack())) + } + }() + bytes, err := ddExtension.Marshal() + if err != nil { + d.logger.Warnw("error marshalling dependency descriptor extension", err) + } else { + result.DependencyDescriptorExtension = bytes + ddMarshaled = true + } + }() + + if !ddMarshaled { + // drop packet if we can't marshal dependency descriptor + d.decisions.AddDropped(extFrameNum) + return } if ddwdt.Integrity { @@ -284,8 +347,10 @@ func (d *DependencyDescriptor) Rollback() { d.Base.Rollback() } -func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget) { +func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget, extFrameNum uint64) { d.structure = structure + d.extKeyFrameNum = extFrameNum + d.keyFrameValid = true d.chains = d.chains[:0] @@ -329,6 +394,14 @@ func (d *DependencyDescriptor) updateActiveDecodeTargets(activeDecodeTargetsBitm } } +func (d *DependencyDescriptor) invalidateKeyFrame() { + d.keyFrameValid = false + d.chains = d.chains[:0] + d.decodeTargetsLock.Lock() + d.decodeTargets = d.decodeTargets[:0] + d.decodeTargetsLock.Unlock() +} + func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { layer = d.GetRequestSpatial() if !d.currentLayer.IsValid() { diff --git a/pkg/sfu/videolayerselector/dependencydescriptor_test.go b/pkg/sfu/videolayerselector/dependencydescriptor_test.go index c013e46af..863c6755f 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor_test.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor_test.go @@ -258,6 +258,14 @@ func TestDependencyDescriptor(t *testing.T) { locked, layer := ddSelector.CheckSync() require.True(t, locked) require.Equal(t, targetLayer.Spatial, layer) + + // should drop frame that relies on a keyframe is not present in current selection + framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000) + ret = ddSelector.Select(framesPrevious[1], 0) + require.False(t, ret.IsSelected) + // keyframe lost, out of sync + locked, _ = ddSelector.CheckSync() + require.False(t, locked) } func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket { @@ -279,7 +287,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer) }) - chainDiffs := make([]int, len(decodeTargets)) + chainDiffs := make([]int, int(maxLayer.Spatial)+1) dtis := make([]dd.DecodeTargetIndication, len(decodeTargets)) for _, dt := range decodeTargets { dtis[dt.Target] = dd.DecodeTargetSwitch @@ -319,6 +327,7 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff ActiveDecodeTargetsUpdated: true, Integrity: true, ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: uint64(startFrameNumber), }, Packet: &rtp.Packet{ Header: rtp.Header{ @@ -356,7 +365,6 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff } frame := &buffer.ExtPacket{ - KeyFrame: true, DependencyDescriptor: &buffer.ExtDependencyDescriptor{ Descriptor: &dd.DependencyDescriptor{ FrameNumber: startFrameNumber, @@ -367,9 +375,10 @@ func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buff DecodeTargetIndications: frameDtis, }, }, - DecodeTargets: decodeTargets, - Integrity: true, - ExtFrameNum: uint64(startFrameNumber), + DecodeTargets: decodeTargets, + Integrity: true, + ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: keyFrame.DependencyDescriptor.ExtFrameNum, }, Packet: &rtp.Packet{ Header: rtp.Header{ From f5047ab653d1eb3fdda36177facfeffa25e6543f Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Fri, 3 Nov 2023 18:40:36 +0800 Subject: [PATCH 10/18] Check request layer for DD selector sync (#2219) --- pkg/sfu/videolayerselector/dependencydescriptor.go | 4 ++-- pkg/sfu/videolayerselector/dependencydescriptor_test.go | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index f458928d8..3c091160e 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -404,7 +404,7 @@ func (d *DependencyDescriptor) invalidateKeyFrame() { func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { layer = d.GetRequestSpatial() - if !d.currentLayer.IsValid() { + if !d.currentLayer.IsValid() || !d.keyFrameValid { // always declare not locked when trying to resume from nothing return false, layer } @@ -412,7 +412,7 @@ func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { d.decodeTargetsLock.RLock() defer d.decodeTargetsLock.RUnlock() for _, dt := range d.decodeTargets { - if dt.Active() && dt.Layer.Spatial <= d.GetTarget().Spatial && dt.Valid() { + if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() { d.logger.Debugw(fmt.Sprintf("checking sync, matching decode target, layer: %d, dt: %s, dts: %+v", layer, dt, d.decodeTargets)) return true, layer } diff --git a/pkg/sfu/videolayerselector/dependencydescriptor_test.go b/pkg/sfu/videolayerselector/dependencydescriptor_test.go index 863c6755f..2b346dedb 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor_test.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor_test.go @@ -253,11 +253,15 @@ func TestDependencyDescriptor(t *testing.T) { } require.True(t, switchToLower) - // sync with requested layer + // not sync with requested layer ddSelector.SetRequestSpatial(targetLayer.Spatial) locked, layer := ddSelector.CheckSync() - require.True(t, locked) + require.False(t, locked) require.Equal(t, targetLayer.Spatial, layer) + // request to current layer, sync + ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial) + locked, _ = ddSelector.CheckSync() + require.True(t, locked) // should drop frame that relies on a keyframe is not present in current selection framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000) From 60374c64025440fc23d6e3826d6906f3b58f4f26 Mon Sep 17 00:00:00 2001 From: David Colburn Date: Fri, 3 Nov 2023 11:43:35 -0700 Subject: [PATCH 11/18] Agents (#2203) * agents * add test * undo name changes * remove debug logs * fixes * fix data race in test --- go.mod | 2 +- go.sum | 4 +- pkg/rtc/{room_egress.go => clients.go} | 4 + pkg/rtc/room.go | 30 +- pkg/rtc/room_test.go | 2 +- pkg/service/agentservice.go | 446 +++++++++++++++++++++++++ pkg/service/clients.go | 87 +++++ pkg/service/egress.go | 48 +-- pkg/service/interfaces.go | 10 - pkg/service/roomallocator.go | 4 +- pkg/service/roommanager.go | 5 +- pkg/service/roomservice.go | 37 +- pkg/service/roomservice_test.go | 1 + pkg/service/rtcservice.go | 36 +- pkg/service/server.go | 4 + pkg/service/wire.go | 4 +- pkg/service/wire_gen.go | 15 +- pkg/service/wsprotocol.go | 56 ++++ test/agent.go | 158 +++++++++ test/agent_test.go | 69 ++++ 20 files changed, 931 insertions(+), 91 deletions(-) rename pkg/rtc/{room_egress.go => clients.go} (98%) create mode 100644 pkg/service/agentservice.go create mode 100644 pkg/service/clients.go create mode 100644 test/agent.go create mode 100644 test/agent_test.go diff --git a/go.mod b/go.mod index 801da0746..654e94312 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e - github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 + github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42 github.com/livekit/psrpc v0.5.0 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 92e574678..675a1fd92 100644 --- a/go.sum +++ b/go.sum @@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY= -github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1 h1:WPWxU9w5XHAsonxnSSIIXbWMty9b5uHnTnyKS9TpaXM= -github.com/livekit/protocol v1.8.2-0.20231101040827-02a4a42603b1/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= +github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42 h1:uDziAK5uhQPOj0fCKl+YyJx51tdFORLjC+rHgNNBCmY= +github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo= github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= diff --git a/pkg/rtc/room_egress.go b/pkg/rtc/clients.go similarity index 98% rename from pkg/rtc/room_egress.go rename to pkg/rtc/clients.go index db6d12813..eeaf22869 100644 --- a/pkg/rtc/room_egress.go +++ b/pkg/rtc/clients.go @@ -27,6 +27,10 @@ import ( "github.com/livekit/protocol/webhook" ) +type AgentClient interface { + JobRequest(ctx context.Context, job *livekit.Job) +} + type EgressLauncher interface { StartEgress(context.Context, *rpc.StartEgressRequest) (*livekit.EgressInfo, error) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 77ef5d1bb..7ac27ca4a 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -74,6 +74,7 @@ type Room struct { audioConfig *config.AudioConfig serverInfo *livekit.ServerInfo telemetry telemetry.TelemetryService + agentClient AgentClient egressLauncher EgressLauncher trackManager *RoomTrackManager @@ -113,6 +114,7 @@ func NewRoom( audioConfig *config.AudioConfig, serverInfo *livekit.ServerInfo, telemetry telemetry.TelemetryService, + agentClient AgentClient, egressLauncher EgressLauncher, ) *Room { r := &Room{ @@ -127,6 +129,7 @@ func NewRoom( audioConfig: audioConfig, telemetry: telemetry, egressLauncher: egressLauncher, + agentClient: agentClient, trackManager: NewRoomTrackManager(), serverInfo: serverInfo, participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), @@ -898,10 +901,21 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. r.trackManager.AddTrack(track, participant.Identity(), participant.ID()) - // auto egress - if r.internal != nil { - if r.internal.ParticipantEgress != nil { - if _, hasPublished := r.hasPublished.Swap(participant.Identity(), true); !hasPublished { + // launch jobs + _, hasPublished := r.hasPublished.Swap(participant.Identity(), true) + if !hasPublished { + if r.agentClient != nil { + go func() { + r.agentClient.JobRequest(context.Background(), &livekit.Job{ + Id: utils.NewGuid("JP_"), + Type: livekit.JobType_JT_PUBLISHER, + Room: r.protoRoom, + Participant: participant.ToProto(), + }) + }() + } + if r.internal != nil && r.internal.ParticipantEgress != nil { + go func() { if err := StartParticipantEgress( context.Background(), r.egressLauncher, @@ -913,9 +927,11 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. ); err != nil { r.Logger.Errorw("failed to launch participant egress", err) } - } + }() } - if r.internal.TrackEgress != nil { + } + if r.internal != nil && r.internal.TrackEgress != nil { + go func() { if err := StartTrackEgress( context.Background(), r.egressLauncher, @@ -927,7 +943,7 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. ); err != nil { r.Logger.Errorw("failed to launch track egress", err) } - } + }() } } diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 0e31e4772..b86527583 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -739,7 +739,7 @@ func newRoomWithParticipants(t *testing.T, opts testRoomOpts) *Room { Region: "testregion", }, telemetry.NewTelemetryService(webhook.NewDefaultNotifier("", "", nil), &telemetryfakes.FakeAnalyticsService{}), - nil, + nil, nil, ) for i := 0; i < opts.num+opts.numHidden; i++ { identity := livekit.ParticipantIdentity(fmt.Sprintf("p%d", i)) diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go new file mode 100644 index 000000000..bb7a9afb0 --- /dev/null +++ b/pkg/service/agentservice.go @@ -0,0 +1,446 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "io" + "math/rand" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +type AgentService struct { + upgrader websocket.Upgrader + + *AgentHandler +} + +type AgentHandler struct { + agentServer rpc.AgentInternalServer + roomTopic string + participantTopic string + + mu sync.Mutex + availability map[string]chan *availability + unregistered map[*websocket.Conn]*worker + roomRegistered bool + roomWorkers map[string]*worker + participantRegistered bool + participantWorkers map[string]*worker +} + +type worker struct { + mu sync.Mutex + conn *websocket.Conn + sigConn *WSSignalConnection + + id string + jobType livekit.JobType + status livekit.WorkerStatus + activeJobs int +} + +type availability struct { + workerID string + available bool +} + +func NewAgentService(bus psrpc.MessageBus) (*AgentService, error) { + s := &AgentService{ + upgrader: websocket.Upgrader{}, + } + + // allow connections from any origin, since script may be hosted anywhere + // security is enforced by access tokens + s.upgrader.CheckOrigin = func(r *http.Request) bool { + return true + } + + agentServer, err := rpc.NewAgentInternalServer(s, bus) + if err != nil { + return nil, err + } + s.AgentHandler = NewAgentHandler(agentServer, "room", "participant") + + return s, nil +} + +func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) { + // reject non websocket requests + if !websocket.IsWebSocketUpgrade(r) { + writer.WriteHeader(404) + return + } + + // require a claim + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil || !claims.Video.Agent { + handleError(writer, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return + } + + // upgrade + conn, err := s.upgrader.Upgrade(writer, r, nil) + if err != nil { + handleError(writer, http.StatusInternalServerError, err) + return + } + + s.HandleConnection(conn) +} + +func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, participantTopic string) *AgentHandler { + return &AgentHandler{ + agentServer: agentServer, + roomTopic: roomTopic, + participantTopic: participantTopic, + availability: make(map[string]chan *availability), + unregistered: make(map[*websocket.Conn]*worker), + roomWorkers: make(map[string]*worker), + participantWorkers: make(map[string]*worker), + } +} + +func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { + sigConn := NewWSSignalConnection(conn) + w := &worker{ + conn: conn, + sigConn: sigConn, + } + + s.mu.Lock() + s.unregistered[conn] = w + s.mu.Unlock() + + defer func() { + s.mu.Lock() + if w.id == "" { + delete(s.unregistered, conn) + } else { + switch w.jobType { + case livekit.JobType_JT_ROOM: + delete(s.roomWorkers, w.id) + if s.roomRegistered && !s.roomAvailableLocked() { + s.roomRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.roomTopic) + } + case livekit.JobType_JT_PUBLISHER: + delete(s.participantWorkers, w.id) + if s.participantRegistered && !s.participantAvailableLocked() { + s.participantRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.participantTopic) + } + } + } + s.mu.Unlock() + }() + + // handle incoming requests from websocket + for { + req, _, err := sigConn.ReadWorkerMessage() + if err != nil { + // normal/expected closure + if err == io.EOF || + strings.HasSuffix(err.Error(), "use of closed network connection") || + strings.HasSuffix(err.Error(), "connection reset by peer") || + websocket.IsCloseError( + err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + ) { + logger.Infow("exit ws read loop for closed connection", "wsError", err) + } else { + logger.Errorw("error reading from websocket", err) + } + return + } + + switch m := req.Message.(type) { + case *livekit.WorkerMessage_Register: + go s.handleRegister(w, m.Register) + case *livekit.WorkerMessage_Availability: + go s.handleAvailability(w, m.Availability) + case *livekit.WorkerMessage_JobUpdate: + go s.handleJobUpdate(w, m.JobUpdate) + case *livekit.WorkerMessage_Status: + go s.handleStatus(w, m.Status) + } + } +} + +func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorkerRequest) { + s.mu.Lock() + defer s.mu.Unlock() + + switch msg.Type { + case livekit.JobType_JT_ROOM: + worker.id = msg.WorkerId + delete(s.unregistered, worker.conn) + s.roomWorkers[worker.id] = worker + + if !s.roomRegistered { + err := s.agentServer.RegisterJobRequestTopic(s.roomTopic) + if err != nil { + logger.Errorw("failed to register room agents", err) + } else { + s.roomRegistered = true + } + } + + case livekit.JobType_JT_PUBLISHER: + worker.id = msg.WorkerId + delete(s.unregistered, worker.conn) + s.participantWorkers[worker.id] = worker + + if !s.participantRegistered { + err := s.agentServer.RegisterJobRequestTopic(s.participantTopic) + if err != nil { + logger.Errorw("failed to register participant agents", err) + } else { + s.participantRegistered = true + } + } + } + + worker.sigConn.WriteServerMessage(&livekit.ServerMessage{ + Message: &livekit.ServerMessage_Register{ + Register: &livekit.RegisterWorkerResponse{ + WorkerId: worker.id, + ServerVersion: "version", + }, + }, + }) +} + +func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityResponse) { + s.mu.Lock() + availabilityChan, ok := s.availability[msg.JobId] + s.mu.Unlock() + + if ok { + availabilityChan <- &availability{ + workerID: w.id, + available: msg.Available, + } + } +} + +func (s *AgentHandler) handleJobUpdate(w *worker, msg *livekit.JobStatusUpdate) { + switch msg.Status { + case livekit.JobStatus_JS_SUCCESS: + logger.Debugw("job complete", "jobID", msg.JobId) + case livekit.JobStatus_JS_FAILED: + logger.Warnw("job failed", errors.New(msg.Error), "jobID", msg.JobId) + } + + w.mu.Lock() + w.activeJobs-- + w.mu.Unlock() +} + +func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) { + s.mu.Lock() + defer s.mu.Unlock() + + w.mu.Lock() + w.status = msg.Status + w.mu.Unlock() + + switch w.jobType { + case livekit.JobType_JT_ROOM: + if s.roomRegistered && !s.roomAvailableLocked() { + s.roomRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.roomTopic) + } else if !s.roomRegistered && s.roomAvailableLocked() { + if err := s.agentServer.RegisterJobRequestTopic(s.roomTopic); err != nil { + logger.Errorw("failed to register room agents", err) + } else { + s.roomRegistered = true + } + } + case livekit.JobType_JT_PUBLISHER: + if s.participantRegistered && !s.participantAvailableLocked() { + s.participantRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.participantTopic) + } else if !s.participantRegistered && s.participantAvailableLocked() { + if err := s.agentServer.RegisterJobRequestTopic(s.participantTopic); err != nil { + logger.Errorw("failed to register participant agents", err) + } else { + s.participantRegistered = true + } + } + } +} + +func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) { + s.mu.Lock() + ac := make(chan *availability, 100) + s.availability[job.Id] = ac + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.availability, job.Id) + s.mu.Unlock() + }() + + var pool map[string]*worker + switch job.Type { + case livekit.JobType_JT_ROOM: + pool = s.roomWorkers + case livekit.JobType_JT_PUBLISHER: + pool = s.participantWorkers + } + + attempted := make(map[string]bool) + for { + select { + case <-ctx.Done(): + return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out") + default: + s.mu.Lock() + var selected *worker + for _, w := range pool { + if attempted[w.id] { + continue + } + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + if w.activeJobs > 0 { + selected = w + break + } else if selected == nil { + selected = w + } + } + } + s.mu.Unlock() + + if selected == nil { + return nil, psrpc.NewErrorf(psrpc.Unavailable, "no workers available") + } + + attempted[selected.id] = true + _, err := selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Availability{ + Availability: &livekit.AvailabilityRequest{Job: job}, + }}) + if err != nil { + logger.Errorw("failed to send availability request", err) + return nil, err + } + + select { + case <-ctx.Done(): + return nil, psrpc.NewErrorf(psrpc.DeadlineExceeded, "request timed out") + case res := <-ac: + if res.available { + _, err = selected.sigConn.WriteServerMessage(&livekit.ServerMessage{Message: &livekit.ServerMessage_Assignment{ + Assignment: &livekit.JobAssignment{Job: job}, + }}) + if err != nil { + logger.Errorw("failed to assign job", err) + } else { + selected.mu.Lock() + selected.activeJobs++ + selected.mu.Unlock() + return &emptypb.Empty{}, nil + } + } + } + } + } +} + +func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 { + s.mu.Lock() + defer s.mu.Unlock() + + var pool map[string]*worker + switch job.Type { + case livekit.JobType_JT_ROOM: + pool = s.roomWorkers + case livekit.JobType_JT_PUBLISHER: + pool = s.participantWorkers + } + + var affinity float32 + for _, w := range pool { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + if w.activeJobs > 0 { + return 1 + } else { + affinity = 0.5 + } + } + } + + return affinity +} + +func (s *AgentHandler) DrainConnections(interval time.Duration) { + // jitter drain start + time.Sleep(time.Duration(rand.Int63n(int64(interval)))) + + t := time.NewTicker(interval) + defer t.Stop() + + s.mu.Lock() + defer s.mu.Unlock() + + for conn := range s.unregistered { + _ = conn.Close() + <-t.C + } + for _, w := range s.roomWorkers { + _ = w.conn.Close() + <-t.C + } + for _, w := range s.participantWorkers { + _ = w.conn.Close() + <-t.C + } +} + +func (s *AgentHandler) roomAvailableLocked() bool { + for _, w := range s.roomWorkers { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + return true + } + } + return false +} + +func (s *AgentHandler) participantAvailableLocked() bool { + for _, w := range s.participantWorkers { + if w.status == livekit.WorkerStatus_WS_AVAILABLE { + return true + } + } + return false +} diff --git a/pkg/service/clients.go b/pkg/service/clients.go new file mode 100644 index 000000000..cad477a93 --- /dev/null +++ b/pkg/service/clients.go @@ -0,0 +1,87 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" +) + +type agentClient struct { + s *AgentService +} + +func NewAgentClient(s *AgentService) rtc.AgentClient { + return &agentClient{s} +} + +func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) { + _, _ = c.s.JobRequest(ctx, job) +} + +type IOClient interface { + CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) + GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) + ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) +} + +type egressLauncher struct { + client rpc.EgressClient + io IOClient +} + +func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher { + if client == nil { + return nil + } + return &egressLauncher{ + client: client, + io: io, + } +} + +func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + info, err := s.StartEgressWithClusterId(ctx, "", req) + if err != nil { + return nil, err + } + + _, err = s.io.CreateEgress(ctx, info) + if err != nil { + logger.Errorw("failed to create egress", err) + } + + return info, nil +} + +func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + if s.client == nil { + return nil, ErrEgressNotConnected + } + + // Ensure we have an Egress ID + if req.EgressId == "" { + req.EgressId = utils.NewGuid(utils.EgressPrefix) + } + + return s.client.StartEgress(ctx, clusterId, req) +} diff --git a/pkg/service/egress.go b/pkg/service/egress.go index d971aaa34..a745caaf0 100644 --- a/pkg/service/egress.go +++ b/pkg/service/egress.go @@ -25,40 +25,23 @@ import ( "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/protocol/egress" "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" - "github.com/livekit/protocol/utils" ) type EgressService struct { + launcher rtc.EgressLauncher client rpc.EgressClient io IOClient roomService livekit.RoomService store ServiceStore - launcher rtc.EgressLauncher -} - -type egressLauncher struct { - client rpc.EgressClient - io IOClient -} - -func NewEgressLauncher(client rpc.EgressClient, io IOClient) rtc.EgressLauncher { - if client == nil { - return nil - } - return &egressLauncher{ - client: client, - io: io, - } } func NewEgressService( client rpc.EgressClient, + launcher rtc.EgressLauncher, store ServiceStore, io IOClient, rs livekit.RoomService, - launcher rtc.EgressLauncher, ) *EgressService { return &EgressService{ client: client, @@ -189,33 +172,6 @@ func (s *EgressService) startEgress(ctx context.Context, roomName livekit.RoomNa return s.launcher.StartEgress(ctx, req) } -func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { - info, err := s.StartEgressWithClusterId(ctx, "", req) - if err != nil { - return nil, err - } - - _, err = s.io.CreateEgress(ctx, info) - if err != nil { - logger.Errorw("failed to create egress", err) - } - - return info, nil -} - -func (s *egressLauncher) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { - if s.client == nil { - return nil, ErrEgressNotConnected - } - - // Ensure we have an Egress ID - if req.EgressId == "" { - req.EgressId = utils.NewGuid(utils.EgressPrefix) - } - - return s.client.StartEgress(ctx, clusterId, req) -} - type LayoutMetadata struct { Layout string `json:"layout"` } diff --git a/pkg/service/interfaces.go b/pkg/service/interfaces.go index 653269225..6c32d7c5c 100644 --- a/pkg/service/interfaces.go +++ b/pkg/service/interfaces.go @@ -18,10 +18,7 @@ import ( "context" "time" - "google.golang.org/protobuf/types/known/emptypb" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/rpc" ) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate @@ -74,13 +71,6 @@ type IngressStore interface { DeleteIngress(ctx context.Context, info *livekit.IngressInfo) error } -//counterfeiter:generate . IOClient -type IOClient interface { - CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) - GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) - ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) -} - //counterfeiter:generate . RoomAllocator type RoomAllocator interface { CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, bool, error) diff --git a/pkg/service/roomallocator.go b/pkg/service/roomallocator.go index f222f7b15..cf88261c6 100644 --- a/pkg/service/roomallocator.go +++ b/pkg/service/roomallocator.go @@ -60,8 +60,10 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre }() // find existing room and update it + var created bool rm, internal, err := r.roomStore.LoadRoom(ctx, livekit.RoomName(req.Name), true) if err == ErrRoomNotFound { + created = true rm = &livekit.Room{ Sid: utils.NewGuid(utils.RoomPrefix), Name: req.Name, @@ -114,7 +116,7 @@ func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.Cre return nil, false, routing.ErrNodeLimitReached } - return rm, false, nil + return rm, created, nil } // select a new node diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 738d1dd13..d26eec0d5 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -69,6 +69,7 @@ type RoomManager struct { roomStore ObjectStore telemetry telemetry.TelemetryService clientConfManager clientconfiguration.ClientConfigurationManager + agentClient rtc.AgentClient egressLauncher rtc.EgressLauncher versionGenerator utils.TimedVersionGenerator turnAuthHandler *TURNAuthHandler @@ -89,6 +90,7 @@ func NewLocalRoomManager( router routing.Router, telemetry telemetry.TelemetryService, clientConfManager clientconfiguration.ClientConfigurationManager, + agentClient rtc.AgentClient, egressLauncher rtc.EgressLauncher, versionGenerator utils.TimedVersionGenerator, turnAuthHandler *TURNAuthHandler, @@ -108,6 +110,7 @@ func NewLocalRoomManager( telemetry: telemetry, clientConfManager: clientConfManager, egressLauncher: egressLauncher, + agentClient: agentClient, versionGenerator: versionGenerator, turnAuthHandler: turnAuthHandler, bus: bus, @@ -527,7 +530,7 @@ func (r *RoomManager) getOrCreateRoom(ctx context.Context, roomName livekit.Room } // construct ice servers - newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.egressLauncher) + newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, &r.config.Audio, r.serverInfo, r.telemetry, r.agentClient, r.egressLauncher) roomTopic := rpc.FormatRoomTopic(roomName) roomServer := utils.Must(rpc.NewTypedRoomServer(r, r.bus)) diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index aa5e323d2..0576bd9c2 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -31,6 +31,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" ) // A rooms service that supports a single node @@ -41,6 +42,7 @@ type RoomService struct { router routing.MessageRouter roomAllocator RoomAllocator roomStore ServiceStore + agentClient rtc.AgentClient egressLauncher rtc.EgressLauncher topicFormatter rpc.TopicFormatter roomClient rpc.TypedRoomClient @@ -54,6 +56,7 @@ func NewRoomService( router routing.MessageRouter, roomAllocator RoomAllocator, serviceStore ServiceStore, + agentClient rtc.AgentClient, egressLauncher rtc.EgressLauncher, topicFormatter rpc.TopicFormatter, roomClient rpc.TypedRoomClient, @@ -66,6 +69,7 @@ func NewRoomService( router: router, roomAllocator: roomAllocator, roomStore: serviceStore, + agentClient: agentClient, egressLauncher: egressLauncher, topicFormatter: topicFormatter, roomClient: roomClient, @@ -112,14 +116,23 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq return nil, err } - if created && req.Egress != nil && req.Egress.Room != nil { - egress := &rpc.StartEgressRequest{ - Request: &rpc.StartEgressRequest_RoomComposite{ - RoomComposite: req.Egress.Room, - }, - RoomId: rm.Sid, + if created { + if s.agentClient != nil { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: utils.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: rm, + }) + } + if req.Egress != nil && req.Egress.Room != nil { + egress := &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_RoomComposite{ + RoomComposite: req.Egress.Room, + }, + RoomId: rm.Sid, + } + _, err = s.egressLauncher.StartEgress(ctx, egress) } - _, err = s.egressLauncher.StartEgress(ctx, egress) } return rm, err @@ -427,7 +440,7 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat // no one has joined the room, would not have been created on an RTC node. // in this case, we'd want to run create again - _, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{ + room, created, err := s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{ Name: req.Room, Metadata: req.Metadata, }) @@ -465,6 +478,14 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, err } + if created && s.agentClient != nil { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: utils.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: room, + }) + } + return room, nil } diff --git a/pkg/service/roomservice_test.go b/pkg/service/roomservice_test.go index f15a72cd7..f9de24a51 100644 --- a/pkg/service/roomservice_test.go +++ b/pkg/service/roomservice_test.go @@ -136,6 +136,7 @@ func newTestRoomService(conf config.RoomConfig) *TestRoomService { allocator, store, nil, + nil, rpc.NewTopicFormatter(), &rpcfakes.FakeTypedRoomClient{}, &rpcfakes.FakeTypedParticipantClient{}, diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index b0683115b..884a89656 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -31,17 +31,17 @@ import ( "github.com/ua-parser/uap-go/uaparser" "golang.org/x/exp/maps" - "github.com/livekit/livekit-server/pkg/utils" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" - "github.com/livekit/psrpc" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/routing/selector" "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + putil "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" ) type RTCService struct { @@ -54,6 +54,7 @@ type RTCService struct { isDev bool limits config.LimitConfig parser *uaparser.Parser + agentClient rtc.AgentClient telemetry telemetry.TelemetryService mu sync.Mutex @@ -66,6 +67,7 @@ func NewRTCService( store ServiceStore, router routing.MessageRouter, currentNode routing.LocalNode, + agentClient rtc.AgentClient, telemetry telemetry.TelemetryService, ) *RTCService { s := &RTCService{ @@ -78,6 +80,7 @@ func NewRTCService( isDev: conf.Development, limits: conf.Limit, parser: uaparser.NewFromSaved(), + agentClient: agentClient, telemetry: telemetry, connections: map[*websocket.Conn]struct{}{}, } @@ -229,6 +232,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Warnw("failed to start connection, retrying", err, fieldsWithAttempt...) } } + if err != nil { prometheus.IncrementParticipantJoinFail(1) handleError(w, http.StatusInternalServerError, err, loggerFields...) @@ -514,10 +518,16 @@ type connectionResult struct { ResponseSource routing.MessageSource } -func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomName, pi routing.ParticipantInit, timeout time.Duration) (connectionResult, *livekit.SignalResponse, error) { +func (s *RTCService) startConnection( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + timeout time.Duration, +) (connectionResult, *livekit.SignalResponse, error) { var cr connectionResult + var created bool var err error - cr.Room, _, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)}) + cr.Room, created, err = s.roomAllocator.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: string(roomName)}) if err != nil { return cr, nil, err } @@ -538,6 +548,17 @@ func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomN cr.ResponseSource.Close() return cr, nil, err } + + if created && s.agentClient != nil { + go func() { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: putil.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: cr.Room, + }) + }() + } + return cr, initialResponse, nil } @@ -559,5 +580,4 @@ func readInitialResponse(source routing.MessageSource, timeout time.Duration) (* return res, nil } } - } diff --git a/pkg/service/server.go b/pkg/service/server.go index c67555614..f7ade50d2 100644 --- a/pkg/service/server.go +++ b/pkg/service/server.go @@ -48,6 +48,7 @@ type LivekitServer struct { config *config.Config ioService *IOInfoService rtcService *RTCService + agentService *AgentService httpServer *http.Server promServer *http.Server router routing.Router @@ -66,6 +67,7 @@ func NewLivekitServer(conf *config.Config, ingressService *IngressService, ioService *IOInfoService, rtcService *RTCService, + agentService *AgentService, keyProvider auth.KeyProvider, router routing.Router, roomManager *RoomManager, @@ -77,6 +79,7 @@ func NewLivekitServer(conf *config.Config, config: conf, ioService: ioService, rtcService: rtcService, + agentService: agentService, router: router, roomManager: roomManager, signalServer: signalServer, @@ -125,6 +128,7 @@ func NewLivekitServer(conf *config.Config, mux.Handle(egressServer.PathPrefix(), egressServer) mux.Handle(ingressServer.PathPrefix(), ingressServer) mux.Handle("/rtc", rtcService) + mux.Handle("/agent", agentService) mux.HandleFunc("/rtc/validate", rtcService.Validate) mux.HandleFunc("/", s.defaultHandler) diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 2e2b39592..f53297d8f 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -62,16 +62,18 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewIOInfoService, wire.Bind(new(IOClient), new(*IOInfoService)), rpc.NewEgressClient, + rpc.NewIngressClient, getEgressStore, NewEgressLauncher, NewEgressService, - rpc.NewIngressClient, getIngressStore, getIngressConfig, NewIngressService, NewRoomAllocator, NewRoomService, NewRTCService, + NewAgentService, + NewAgentClient, getSignalRelayConfig, NewDefaultSignalServer, routing.NewSignalClient, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 322d125e6..30a61b66c 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -55,6 +55,11 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } + agentService, err := NewAgentService(messageBus) + if err != nil { + return nil, err + } + rtcAgentClient := NewAgentClient(agentService) egressClient, err := rpc.NewEgressClient(messageBus) if err != nil { return nil, err @@ -86,22 +91,22 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcAgentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient) if err != nil { return nil, err } - egressService := NewEgressService(egressClient, objectStore, ioInfoService, roomService, rtcEgressLauncher) + egressService := NewEgressService(egressClient, rtcEgressLauncher, objectStore, ioInfoService, roomService) ingressConfig := getIngressConfig(conf) ingressClient, err := rpc.NewIngressClient(messageBus) if err != nil { return nil, err } ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, roomService, telemetryService) - rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, telemetryService) + rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, rtcAgentClient, telemetryService) clientConfigurationManager := createClientConfiguration() timedVersionGenerator := utils.NewDefaultTimedVersionGenerator() turnAuthHandler := NewTURNAuthHandler(keyProvider) - roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) + roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcAgentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) if err != nil { return nil, err } @@ -114,7 +119,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, keyProvider, router, roomManager, signalServer, server, currentNode) + livekitServer, err := NewLivekitServer(conf, roomService, egressService, ingressService, ioInfoService, rtcService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) if err != nil { return nil, err } diff --git a/pkg/service/wsprotocol.go b/pkg/service/wsprotocol.go index 50a4dc2fb..5ac7c744f 100644 --- a/pkg/service/wsprotocol.go +++ b/pkg/service/wsprotocol.go @@ -83,6 +83,40 @@ func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) } } +func (c *WSSignalConnection) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) { + for { + // handle special messages and pass on the rest + messageType, payload, err := c.conn.ReadMessage() + if err != nil { + return nil, 0, err + } + + msg := &livekit.WorkerMessage{} + switch messageType { + case websocket.BinaryMessage: + if c.useJSON { + c.mu.Lock() + // switch to protobuf if client supports it + c.useJSON = false + c.mu.Unlock() + } + // protobuf encoded + err := proto.Unmarshal(payload, msg) + return msg, len(payload), err + case websocket.TextMessage: + c.mu.Lock() + // json encoded, also write back JSON + c.useJSON = true + c.mu.Unlock() + err := protojson.Unmarshal(payload, msg) + return msg, len(payload), err + default: + logger.Debugw("unsupported message", "message", messageType) + return nil, len(payload), nil + } + } +} + func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, error) { var msgType int var payload []byte @@ -105,6 +139,28 @@ func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, er return len(payload), c.conn.WriteMessage(msgType, payload) } +func (c *WSSignalConnection) WriteServerMessage(msg *livekit.ServerMessage) (int, error) { + var msgType int + var payload []byte + var err error + + c.mu.Lock() + defer c.mu.Unlock() + + if c.useJSON { + msgType = websocket.TextMessage + payload, err = protojson.Marshal(msg) + } else { + msgType = websocket.BinaryMessage + payload, err = proto.Marshal(msg) + } + if err != nil { + return 0, err + } + + return len(payload), c.conn.WriteMessage(msgType, payload) +} + func (c *WSSignalConnection) pingWorker() { for { <-time.After(pingFrequency) diff --git a/test/agent.go b/test/agent.go new file mode 100644 index 000000000..76f608314 --- /dev/null +++ b/test/agent.go @@ -0,0 +1,158 @@ +package test + +import ( + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/gorilla/websocket" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" +) + +type agentClient struct { + mu sync.Mutex + conn *websocket.Conn + + registered atomic.Int32 + roomAvailability atomic.Int32 + roomJobs atomic.Int32 + participantAvailability atomic.Int32 + participantJobs atomic.Int32 + + done chan struct{} +} + +func newAgentClient(token string) (*agentClient, error) { + host := fmt.Sprintf("ws://localhost:%d", defaultServerPort) + u, err := url.Parse(host + "/agent") + if err != nil { + return nil, err + } + requestHeader := make(http.Header) + requestHeader.Set("Authorization", "Bearer "+token) + + connectUrl := u.String() + conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader) + if err != nil { + return nil, err + } + + return &agentClient{ + conn: conn, + done: make(chan struct{}), + }, nil +} + +func (c *agentClient) Run() error { + go c.read() + + workerID := utils.NewGuid("W_") + + if err := c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_ROOM, + WorkerId: workerID, + Version: "version", + Name: "name", + }, + }, + }); err != nil { + return err + } + + if err := c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_PUBLISHER, + WorkerId: workerID, + Version: "version", + Name: "name", + }, + }, + }); err != nil { + return err + } + + return nil +} + +func (c *agentClient) read() { + for { + select { + case <-c.done: + return + default: + _, b, err := c.conn.ReadMessage() + if err != nil { + return + } + + msg := &livekit.ServerMessage{} + if err = proto.Unmarshal(b, msg); err != nil { + return + } + + switch m := msg.Message.(type) { + case *livekit.ServerMessage_Assignment: + go c.handleAssignment(m.Assignment) + case *livekit.ServerMessage_Availability: + go c.handleAvailability(m.Availability) + case *livekit.ServerMessage_Register: + go c.handleRegister(m.Register) + } + } + } +} + +func (c *agentClient) handleAssignment(req *livekit.JobAssignment) { + if req.Job.Type == livekit.JobType_JT_ROOM { + c.roomJobs.Inc() + } else { + c.participantJobs.Inc() + } +} + +func (c *agentClient) handleAvailability(req *livekit.AvailabilityRequest) { + if req.Job.Type == livekit.JobType_JT_ROOM { + c.roomAvailability.Inc() + } else { + c.participantAvailability.Inc() + } + + c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Availability{ + Availability: &livekit.AvailabilityResponse{ + JobId: req.Job.Id, + Available: true, + }, + }, + }) +} + +func (c *agentClient) handleRegister(req *livekit.RegisterWorkerResponse) { + c.registered.Inc() +} + +func (c *agentClient) write(msg *livekit.WorkerMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + b, err := proto.Marshal(msg) + if err != nil { + return err + } + + return c.conn.WriteMessage(websocket.BinaryMessage, b) +} + +func (c *agentClient) close() { + close(c.done) + _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + _ = c.conn.Close() +} diff --git a/test/agent_test.go b/test/agent_test.go new file mode 100644 index 000000000..172dec62f --- /dev/null +++ b/test/agent_test.go @@ -0,0 +1,69 @@ +package test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/auth" +) + +func TestAgents(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgents") + defer finish() + + ac1, err := newAgentClient(agentToken()) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken()) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run() + ac2.Run() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(2), ac1.registered.Load()) + require.Equal(t, int32(2), ac2.registered.Load()) + + c1 := createRTCClient("c1", defaultServerPort, nil) + c2 := createRTCClient("c2", defaultServerPort, nil) + waitUntilConnected(t, c1, c2) + + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load()) + require.Equal(t, int32(1), ac1.participantJobs.Load()+ac2.participantJobs.Load()) + + // publish 2 tracks + t3, err := c2.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t3.Stop() + t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t4.Stop() + + time.Sleep(time.Second * 3) + + require.Equal(t, int32(1), ac1.roomJobs.Load()+ac2.roomJobs.Load()) + require.Equal(t, int32(2), ac1.participantJobs.Load()+ac2.participantJobs.Load()) +} + +func agentToken() string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{Agent: true}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} From 2a941ba58df800120496740533b37b817b96cc91 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Sat, 4 Nov 2023 09:59:23 -0700 Subject: [PATCH 12/18] improve participant hidden (#2220) --- pkg/rtc/participant.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 47b7d7f01..f95d4efe3 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -139,6 +139,7 @@ type ParticipantImpl struct { resSinkMu sync.Mutex resSink routing.MessageSink grants *auth.ClaimGrants + hidden atomic.Bool isPublisher atomic.Bool // when first connected @@ -249,6 +250,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.migrateState.Store(types.MigrateStateInit) p.state.Store(livekit.ParticipantInfo_JOINING) p.grants = params.Grants + p.hidden.Store(p.grants.Video.Hidden) p.SetResponseSink(params.Sink) p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs()) @@ -426,6 +428,7 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio p.params.Logger.Infow("updating participant permission", "permission", permission) video.UpdateFromPermission(permission) + p.hidden.Store(permission.Hidden) p.dirty.Store(true) canPublish := video.GetCanPublish() @@ -1021,10 +1024,7 @@ func (p *ParticipantImpl) CanPublishData() bool { } func (p *ParticipantImpl) Hidden() bool { - p.lock.RLock() - defer p.lock.RUnlock() - - return p.grants.Video.Hidden + return p.hidden.Load() } func (p *ParticipantImpl) IsRecorder() bool { From eaf8834f0cac88cd234204837b2aeba982598117 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 10:29:47 -0700 Subject: [PATCH 13/18] Update module golang.org/x/sync to v0.5.0 (#2204) Generated by renovateBot Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 654e94312..6202fccc7 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/urfave/negroni/v3 v3.0.0 go.uber.org/atomic v1.11.0 golang.org/x/exp v0.0.0-20231006140011-7918f672742d - golang.org/x/sync v0.4.0 + golang.org/x/sync v0.5.0 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 675a1fd92..1c09de005 100644 --- a/go.sum +++ b/go.sum @@ -335,8 +335,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= From 43703538c65a7a80928abb2aacbb9619ad7d0416 Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Sat, 4 Nov 2023 10:32:03 -0700 Subject: [PATCH 14/18] clean up proto logging (#2221) --- pkg/clientconfiguration/staticconfiguration.go | 2 +- pkg/rtc/participant.go | 12 ++++++------ pkg/rtc/uptrackmanager.go | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pkg/clientconfiguration/staticconfiguration.go b/pkg/clientconfiguration/staticconfiguration.go index 2d0fcf984..e0309c26c 100644 --- a/pkg/clientconfiguration/staticconfiguration.go +++ b/pkg/clientconfiguration/staticconfiguration.go @@ -40,7 +40,7 @@ func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit. for _, c := range s.confs { matched, err := c.Match.Match(clientInfo) if err != nil { - logger.Errorw("matchrule failed", err, "clientInfo", clientInfo.String()) + logger.Errorw("matchrule failed", err, "clientInfo", logger.Proto(clientInfo)) continue } if !matched { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index f95d4efe3..b04dd2561 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -714,7 +714,7 @@ func (p *ParticipantImpl) SetMigrateInfo( p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted) p.pendingTracks[t.GetCid()] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}, migrated: true} - p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", ti.String()) + p.pubLogger.Infow("pending track added (migration)", "trackID", ti.Sid, "track", logger.Proto(ti)) } p.pendingTracksLock.Unlock() @@ -738,7 +738,7 @@ func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseRea "sendLeave", sendLeave, "reason", reason.String(), "isExpectedToResume", isExpectedToResume, - "clientInfo", p.params.ClientInfo.String(), + "clientInfo", logger.Proto(p.params.ClientInfo), ) p.clearDisconnectTimer() p.clearMigrationTimer() @@ -1659,12 +1659,12 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l } else { p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti) } - p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", ti.String(), "request", req.String()) + p.pubLogger.Infow("pending track queued", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) return nil } p.pendingTracks[req.Cid] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} - p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", ti.String(), "request", req.String()) + p.pubLogger.Infow("pending track added", "trackID", ti.Sid, "track", logger.Proto(ti), "request", logger.Proto(req)) return ti } @@ -1682,7 +1682,7 @@ func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.Trac } func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) { - p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", ti.String()) + p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti)) _ = p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_TrackPublished{ TrackPublished: &livekit.TrackPublishedResponse{ @@ -1794,7 +1794,7 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei } func (p *ParticipantImpl) addMigrateMutedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { - p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", ti.String()) + p.pubLogger.Infow("add migrate muted track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti)) rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid) if rtpReceiver == nil { p.pubLogger.Errorw("could not find receiver for migrated track", nil, "trackID", ti.Sid) diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 1d1dfa6e5..1c079dc34 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -160,8 +160,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( u.params.Logger.Debugw( "skipping older subscription permission version", "existingValue", perms, - "existingVersion", u.subscriptionPermissionVersion.ToProto().String(), - "requestingValue", subscriptionPermission.String(), + "existingVersion", u.subscriptionPermissionVersion.String(), + "requestingValue", logger.Proto(subscriptionPermission), "requestingVersion", timedVersion.String(), ) u.lock.Unlock() @@ -178,7 +178,7 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( if subscriptionPermission == nil { u.params.Logger.Debugw( "updating subscription permission, setting to nil", - "version", u.subscriptionPermissionVersion.ToProto().String(), + "version", u.subscriptionPermissionVersion.String(), ) // possible to get a nil when migrating u.lock.Unlock() @@ -187,8 +187,8 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( u.params.Logger.Debugw( "updating subscription permission", - "permissions", u.subscriptionPermission.String(), - "version", u.subscriptionPermissionVersion.ToProto().String(), + "permissions", logger.Proto(u.subscriptionPermission), + "version", u.subscriptionPermissionVersion.String(), ) if err := u.parseSubscriptionPermissionsLocked(subscriptionPermission, func(pID livekit.ParticipantID) types.LocalParticipant { u.lock.Unlock() @@ -247,7 +247,7 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { u.publishedTracks[track.ID()] = track } u.lock.Unlock() - u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", track.ToProto().String()) + u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto())) track.AddOnClose(func() { notifyClose := false From 12a9d74acbbbf2b96f1bae99cd3f03413cb2f518 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Mon, 6 Nov 2023 10:41:56 +0530 Subject: [PATCH 15/18] Do not restart on receiver side. (#2224) * Do not restart on receiver side. Restart with wrap back causes issues in the forwarding path as the subscriber assumes the extended type from receiver side does not restart. Restart was an attempt to include as many packets as possible, but in practice is not super useful. So, taking it out. Can clean up a bit more stuff, but want to run this first and check for any oddities. * fix test --- pkg/sfu/buffer/rtpstats_receiver.go | 47 ++++++++++++------------ pkg/sfu/buffer/rtpstats_receiver_test.go | 32 ++++++++-------- 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/pkg/sfu/buffer/rtpstats_receiver.go b/pkg/sfu/buffer/rtpstats_receiver.go index 3869ebae8..dad7c66ef 100644 --- a/pkg/sfu/buffer/rtpstats_receiver.go +++ b/pkg/sfu/buffer/rtpstats_receiver.go @@ -129,31 +129,30 @@ func (r *RTPStatsReceiver) Update( pktSize := uint64(hdrSize + payloadSize + paddingSize) gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest) if gapSN <= 0 { // duplicate OR out-of-order - if payloadSize == 0 { - // do not start on a padding only packet - if resTS.IsRestart { - r.logger.Infow( - "rolling back timestamp restart", - "tsBefore", resTS.PreExtendedStart, - "tsAfter", r.timestamp.GetExtendedStart(), - "snBefore", resSN.PreExtendedStart, - "snAfter", r.sequenceNumber.GetExtendedStart(), - ) - r.timestamp.RollbackRestart(resTS.PreExtendedStart) - } - if resSN.IsRestart { - r.logger.Infow( - "rolling back sequence number restart", - "snBefore", resSN.PreExtendedStart, - "snAfter", r.sequenceNumber.GetExtendedStart(), - "tsBefore", resTS.PreExtendedStart, - "tsAfter", r.timestamp.GetExtendedStart(), - ) - r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart) - flowState.IsNotHandled = true - return - } + // before start, don't restart + if resTS.IsRestart { + r.logger.Infow( + "rolling back timestamp restart", + "tsBefore", resTS.PreExtendedStart, + "tsAfter", r.timestamp.GetExtendedStart(), + "snBefore", resSN.PreExtendedStart, + "snAfter", r.sequenceNumber.GetExtendedStart(), + ) + r.timestamp.RollbackRestart(resTS.PreExtendedStart) } + if resSN.IsRestart { + r.logger.Infow( + "rolling back sequence number restart", + "snBefore", resSN.PreExtendedStart, + "snAfter", r.sequenceNumber.GetExtendedStart(), + "tsBefore", resTS.PreExtendedStart, + "tsAfter", r.timestamp.GetExtendedStart(), + ) + r.sequenceNumber.RollbackRestart(resSN.PreExtendedStart) + flowState.IsNotHandled = true + return + } + if -gapSN >= cNumSequenceNumbers/2 { r.logger.Warnw( "large sequence number gap negative", nil, diff --git a/pkg/sfu/buffer/rtpstats_receiver_test.go b/pkg/sfu/buffer/rtpstats_receiver_test.go index 34fea0f4b..7a39c1b54 100644 --- a/pkg/sfu/buffer/rtpstats_receiver_test.go +++ b/pkg/sfu/buffer/rtpstats_receiver_test.go @@ -130,7 +130,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - // out-of-order + // out-of-order, would cause a restart which is disallowed packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update( time.Now(), @@ -142,14 +142,15 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) + require.True(t, flowState.IsNotHandled) require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(1), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsOutOfOrder) require.Equal(t, uint64(0), r.packetsDuplicate) - // duplicate + // duplicate of the above out-of-order packet, but would not be handled as it causes a restart packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) flowState = r.Update( time.Now(), @@ -161,12 +162,13 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) + require.True(t, flowState.IsNotHandled) require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(2), r.packetsOutOfOrder) - require.Equal(t, uint64(1), r.packetsDuplicate) + require.Equal(t, uint64(0), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) // loss sequenceNumber += 10 @@ -184,10 +186,10 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.True(t, flowState.HasLoss) require.Equal(t, uint64(sequenceNumber-9), flowState.LossStartInclusive) require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) - require.Equal(t, uint64(17), r.packetsLost) + require.Equal(t, uint64(9), r.packetsLost) // out-of-order should decrement number of lost packets - packet = getPacket(sequenceNumber-15, timestamp-45000, 1000) + packet = getPacket(sequenceNumber-6, timestamp-45000, 1000) flowState = r.Update( time.Now(), packet.Header.SequenceNumber, @@ -202,9 +204,9 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) require.Equal(t, timestamp, r.timestamp.GetHighest()) require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) - require.Equal(t, uint64(3), r.packetsOutOfOrder) - require.Equal(t, uint64(1), r.packetsDuplicate) - require.Equal(t, uint64(16), r.packetsLost) + require.Equal(t, uint64(1), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) + require.Equal(t, uint64(8), r.packetsLost) // test sequence number history // with a gap @@ -223,7 +225,7 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { require.True(t, flowState.HasLoss) require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive) require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) - require.Equal(t, uint64(17), r.packetsLost) + require.Equal(t, uint64(9), r.packetsLost) require.False(t, r.history.IsSet(uint64(sequenceNumber)-1)) // out-of-order @@ -240,8 +242,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 0, ) require.False(t, flowState.HasLoss) - require.Equal(t, uint64(16), r.packetsLost) - require.Equal(t, uint64(4), r.packetsOutOfOrder) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) require.True(t, r.history.IsSet(uint64(sequenceNumber))) // padding only @@ -257,8 +259,8 @@ func Test_RTPStatsReceiver_Update(t *testing.T) { 25, ) require.False(t, flowState.HasLoss) - require.Equal(t, uint64(16), r.packetsLost) - require.Equal(t, uint64(4), r.packetsOutOfOrder) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) require.True(t, r.history.IsSet(uint64(sequenceNumber))) require.True(t, r.history.IsSet(uint64(sequenceNumber)-1)) require.True(t, r.history.IsSet(uint64(sequenceNumber)-2)) From 302facc60dae494f04ec2a94933b9e8a0379427e Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Mon, 6 Nov 2023 21:11:39 +0800 Subject: [PATCH 16/18] Reject migration if codec mismatch with published tracks (#2225) * Reject migrated/published track mismatch codec with track info * Check potential codecs * Issue full connect if mismatch * fix codec finding --- pkg/rtc/mediatrack.go | 13 ++++++++-- pkg/rtc/participant.go | 36 ++++++++++++++++++++++------ pkg/rtc/participant_internal_test.go | 2 +- pkg/rtc/participant_sdp.go | 6 ++--- pkg/rtc/types/interfaces.go | 5 +++- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 9b7f1c311..584c953c3 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -239,13 +239,22 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra t.params.Logger.Debugw("AddReceiver", "mime", track.Codec().MimeType) wr := t.MediaTrackReceiver.Receiver(mime) if wr == nil { - var priority int + priority := -1 for idx, c := range t.params.TrackInfo.Codecs { - if strings.HasSuffix(mime, c.MimeType) { + if strings.EqualFold(mime, c.MimeType) { priority = idx break } } + if len(t.params.TrackInfo.Codecs) == 0 { + priority = 0 + } + if priority < 0 { + t.params.Logger.Warnw("could not find codec for webrtc receiver", nil, "webrtcCodec", mime, "track", logger.Proto(t.params.TrackInfo)) + t.lock.Unlock() + return false + } + newWR := sfu.NewWebRTCReceiver( receiver, track, diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index b04dd2561..21134b0f1 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1624,8 +1624,10 @@ func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *l seenCodecs := make(map[string]struct{}) for _, codec := range req.SimulcastCodecs { mime := codec.Codec - if req.Type == livekit.TrackType_VIDEO && !strings.HasPrefix(mime, "video/") { - mime = "video/" + mime + if req.Type == livekit.TrackType_VIDEO { + if !strings.HasPrefix(mime, "video/") { + mime = "video/" + mime + } if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mime}) { altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) p.pubLogger.Infow("falling back to alternative codec", @@ -1762,12 +1764,32 @@ func (p *ParticipantImpl) mediaTrackReceived(track *webrtc.TrackRemote, rtpRecei // use existing media track to handle simulcast mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) if !ok { - signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + signalCid, ti, migrated := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) if ti == nil { p.pendingTracksLock.Unlock() return nil, false } + // check if the migrated track has correct codec + if migrated && len(ti.Codecs) > 0 { + parameters := rtpReceiver.GetParameters() + var codecFound int + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if strings.EqualFold(nc.MimeType, c.MimeType) { + codecFound++ + break + } + } + } + if codecFound != len(ti.Codecs) { + p.params.Logger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) + p.pendingTracksLock.Unlock() + p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) + return nil, false + } + } + ti.MimeType = track.Codec().MimeType mt = p.addMediaTrack(signalCid, track.ID(), ti) newTrack = true @@ -1976,7 +1998,7 @@ func (p *ParticipantImpl) onUpTrackManagerClose() { p.postRtcp(nil) } -func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { +func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) { signalCid := clientId pendingInfo := p.pendingTracks[clientId] if pendingInfo == nil { @@ -2012,10 +2034,10 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp // if still not found, we are done if pendingInfo == nil { p.pubLogger.Errorw("track info not published prior to track", nil, "clientId", clientId) - return signalCid, nil + return signalCid, nil, false } - return signalCid, pendingInfo.trackInfos[0] + return signalCid, pendingInfo.trackInfos[0], pendingInfo.migrated } // setStableTrackID either generates a new TrackID or reuses a previously used one @@ -2206,7 +2228,7 @@ func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason scr := types.SignallingCloseReasonUnknown switch reason { - case types.ParticipantCloseReasonPublicationError: + case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch: scr = types.SignallingCloseReasonFullReconnectPublicationError case types.ParticipantCloseReasonSubscriptionError: scr = types.SignallingCloseReasonFullReconnectSubscriptionError diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 2256a70ca..afd312bc8 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -284,7 +284,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) diff --git a/pkg/rtc/participant_sdp.go b/pkg/rtc/participant_sdp.go index 70829a34c..51e98c478 100644 --- a/pkg/rtc/participant_sdp.go +++ b/pkg/rtc/participant_sdp.go @@ -46,7 +46,7 @@ func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher(offer webrtc.Se } p.pendingTracksLock.RLock() - _, info := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, info, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO) // if RED is disabled for this track, don't prefer RED codec in offer disableRed := info != nil && info.DisableRed p.pendingTracksLock.RUnlock() @@ -132,7 +132,7 @@ func (p *ParticipantImpl) setCodecPreferencesVideoForPublisher(offer webrtc.Sess if mt != nil { info = mt.ToProto() } else { - _, info = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) + _, info, _ = p.getPendingTrack(streamID, livekit.TrackType_VIDEO) } if info == nil { @@ -239,7 +239,7 @@ func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescript track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack) if track == nil { p.pendingTracksLock.RLock() - _, ti = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) + _, ti, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO) p.pendingTracksLock.RUnlock() } else { ti = track.TrackInfo(false) diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index da3ee7178..6c2210613 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -104,6 +104,7 @@ const ( ParticipantCloseReasonPublicationError ParticipantCloseReasonSubscriptionError ParticipantCloseReasonDataChannelError + ParticipantCloseReasonMigrateCodecMismatch ) func (p ParticipantCloseReason) String() string { @@ -154,6 +155,8 @@ func (p ParticipantCloseReason) String() string { return "SUBSCRIPTION_ERROR" case ParticipantCloseReasonDataChannelError: return "DATA_CHANNEL_ERROR" + case ParticipantCloseReasonMigrateCodecMismatch: + return "MIGRATE_CODEC_MISMATCH" default: return fmt.Sprintf("%d", int(p)) } @@ -184,7 +187,7 @@ func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason { return livekit.DisconnectReason_SERVER_SHUTDOWN case ParticipantCloseReasonOvercommitted: return livekit.DisconnectReason_SERVER_SHUTDOWN - case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError: + case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch: return livekit.DisconnectReason_STATE_MISMATCH default: // the other types will map to unknown reason From 57643a42ed80be56187f48e3d09f769ff2fd723a Mon Sep 17 00:00:00 2001 From: David Colburn Date: Tue, 7 Nov 2023 19:19:07 -0800 Subject: [PATCH 17/18] Agents enabled check (#2227) * agents enabled check * participant -> publisher * nil check client * add NumConnections * add lock around agent check * do not launch agents against other agents * regen * don't need atomic anymore * update protocol --- go.mod | 2 +- go.sum | 4 +- pkg/rtc/agentclient.go | 92 +++++++++++++++ pkg/rtc/{clients.go => egress.go} | 4 - pkg/rtc/participant.go | 7 ++ pkg/rtc/room.go | 60 +++++++--- pkg/rtc/types/interfaces.go | 1 + .../typesfakes/fake_local_participant.go | 65 +++++++++++ pkg/rtc/types/typesfakes/fake_participant.go | 65 +++++++++++ pkg/service/agentservice.go | 106 +++++++++++------- pkg/service/clients.go | 12 -- pkg/service/roomservice.go | 27 +++-- pkg/service/wire.go | 3 +- pkg/service/wire_gen.go | 14 ++- test/agent_test.go | 14 +++ 15 files changed, 385 insertions(+), 91 deletions(-) create mode 100644 pkg/rtc/agentclient.go rename pkg/rtc/{clients.go => egress.go} (98%) diff --git a/go.mod b/go.mod index 6202fccc7..df947357d 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/jxskiss/base62 v1.1.0 github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e - github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42 + github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e github.com/livekit/psrpc v0.5.0 github.com/mackerelio/go-osstat v0.2.4 github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 1c09de005..c59b38f48 100644 --- a/go.sum +++ b/go.sum @@ -125,8 +125,8 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e h1:yNeIo7MSMUWgoLu7LkNKnBYnJBFPFH9Wq4S6h1kS44M= github.com/livekit/mediatransportutil v0.0.0-20231017082622-43f077b4e60e/go.mod h1:+WIOYwiBMive5T81V8B2wdAc2zQNRjNQiJIcPxMTILY= -github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42 h1:uDziAK5uhQPOj0fCKl+YyJx51tdFORLjC+rHgNNBCmY= -github.com/livekit/protocol v1.9.1-0.20231103182211-6d382559cf42/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= +github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e h1:YShBpEjkEBY7yil2gjMWlkVkxs3OI58LIIYsBdb8aBU= +github.com/livekit/protocol v1.9.1-0.20231107185101-e230ee2d840e/go.mod h1:l2WjlZWErS6vBlQaQyCGwWLt1aOx10XfQTsmvLjJWFQ= github.com/livekit/psrpc v0.5.0 h1:g+yYNSs6Y1/vM7UlFkB2s/ARe2y3RKWZhX8ata5j+eo= github.com/livekit/psrpc v0.5.0/go.mod h1:1XYH1LLoD/YbvBvt6xg2KQ/J3InLXSJK6PL/+DKmuAU= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= diff --git a/pkg/rtc/agentclient.go b/pkg/rtc/agentclient.go new file mode 100644 index 000000000..65be3ba7c --- /dev/null +++ b/pkg/rtc/agentclient.go @@ -0,0 +1,92 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +const ( + RoomAgentTopic = "room" + PublisherAgentTopic = "publisher" +) + +type AgentClient interface { + CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse + JobRequest(ctx context.Context, job *livekit.Job) +} + +type agentClient struct { + client rpc.AgentInternalClient +} + +func NewAgentClient(bus psrpc.MessageBus) (AgentClient, error) { + client, err := rpc.NewAgentInternalClient(bus) + if err != nil { + return nil, err + } + return &agentClient{client: client}, nil +} + +func (c *agentClient) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) *rpc.CheckEnabledResponse { + res := &rpc.CheckEnabledResponse{} + resChan, err := c.client.CheckEnabled(ctx, req, psrpc.WithRequestTimeout(time.Second)) + if err != nil { + return res + } + + for r := range resChan { + if r.Err != nil { + continue + } + if r.Result.RoomEnabled { + res.RoomEnabled = true + if res.PublisherEnabled { + return res + } + } + if r.Result.PublisherEnabled { + res.PublisherEnabled = true + if res.RoomEnabled { + return res + } + } + } + + return res +} + +func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) { + var topic string + var logError bool + switch job.Type { + case livekit.JobType_JT_ROOM: + topic = RoomAgentTopic + case livekit.JobType_JT_PUBLISHER: + topic = PublisherAgentTopic + logError = true + } + + _, err := c.client.JobRequest(ctx, topic, job) + if err != nil && logError { + logger.Warnw("agent job request failed", err) + } +} diff --git a/pkg/rtc/clients.go b/pkg/rtc/egress.go similarity index 98% rename from pkg/rtc/clients.go rename to pkg/rtc/egress.go index eeaf22869..db6d12813 100644 --- a/pkg/rtc/clients.go +++ b/pkg/rtc/egress.go @@ -27,10 +27,6 @@ import ( "github.com/livekit/protocol/webhook" ) -type AgentClient interface { - JobRequest(ctx context.Context, job *livekit.Job) -} - type EgressLauncher interface { StartEgress(context.Context, *rpc.StartEgressRequest) (*livekit.EgressInfo, error) StartEgressWithClusterId(ctx context.Context, clusterId string, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 21134b0f1..404cbf8a2 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -1034,6 +1034,13 @@ func (p *ParticipantImpl) IsRecorder() bool { return p.grants.Video.Recorder } +func (p *ParticipantImpl) IsAgent() bool { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.grants.Video.Agent +} + func (p *ParticipantImpl) VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) { if !p.IsReady() { // we have not sent a JoinResponse yet. metadata would be covered in JoinResponse diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 7ac27ca4a..e35fbc738 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -30,6 +30,7 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/utils" "github.com/livekit/livekit-server/pkg/config" @@ -74,15 +75,18 @@ type Room struct { audioConfig *config.AudioConfig serverInfo *livekit.ServerInfo telemetry telemetry.TelemetryService - agentClient AgentClient egressLauncher EgressLauncher trackManager *RoomTrackManager + // agents + agentClient AgentClient + publisherAgentsEnabled bool + // map of identity -> Participant participants map[livekit.ParticipantIdentity]types.LocalParticipant participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource - hasPublished sync.Map // map of identity -> bool + hasPublished map[livekit.ParticipantIdentity]bool bufferFactory *buffer.FactoryOfBufferFactory // batch update participant info for non-publishers @@ -135,11 +139,13 @@ func NewRoom( participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), + hasPublished: make(map[livekit.ParticipantIdentity]bool), bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSize), batchedUpdates: make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo), closed: make(chan struct{}), trailer: []byte(utils.RandomSecret()), } + r.protoProxy = utils.NewProtoProxy[*livekit.Room](roomUpdateInterval, r.updateProto) if r.protoRoom.EmptyTimeout == 0 { r.protoRoom.EmptyTimeout = DefaultEmptyTimeout @@ -148,6 +154,21 @@ func NewRoom( r.protoRoom.CreationTime = time.Now().Unix() } + if agentClient != nil { + go func() { + res := r.agentClient.CheckEnabled(context.Background(), &rpc.CheckEnabledRequest{}) + if res.PublisherEnabled { + r.lock.Lock() + r.publisherAgentsEnabled = true + // if there are already published tracks, start the agents + for identity := range r.hasPublished { + r.launchPublisherAgent(r.participants[identity]) + } + r.lock.Unlock() + } + }() + } + go r.audioUpdateWorker() go r.connectionQualityWorker() go r.changeUpdateWorker() @@ -477,6 +498,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek delete(r.participants, identity) delete(r.participantOpts, identity) delete(r.participantRequestSources, identity) + delete(r.hasPublished, identity) if !p.Hidden() { r.protoRoom.NumParticipants-- } @@ -512,7 +534,6 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity, pID livek for _, t := range p.GetPublishedTracks() { r.trackManager.RemoveTrack(t) } - r.hasPublished.Delete(p.Identity()) p.OnTrackUpdated(nil) p.OnTrackPublished(nil) @@ -902,17 +923,15 @@ func (r *Room) onTrackPublished(participant types.LocalParticipant, track types. r.trackManager.AddTrack(track, participant.Identity(), participant.ID()) // launch jobs - _, hasPublished := r.hasPublished.Swap(participant.Identity(), true) + r.lock.Lock() + hasPublished := r.hasPublished[participant.Identity()] + r.hasPublished[participant.Identity()] = true + publisherAgentsEnabled := r.publisherAgentsEnabled + r.lock.Unlock() + if !hasPublished { - if r.agentClient != nil { - go func() { - r.agentClient.JobRequest(context.Background(), &livekit.Job{ - Id: utils.NewGuid("JP_"), - Type: livekit.JobType_JT_PUBLISHER, - Room: r.protoRoom, - Participant: participant.ToProto(), - }) - }() + if publisherAgentsEnabled { + r.launchPublisherAgent(participant) } if r.internal != nil && r.internal.ParticipantEgress != nil { go func() { @@ -1302,6 +1321,21 @@ func (r *Room) connectionQualityWorker() { } } +func (r *Room) launchPublisherAgent(p types.Participant) { + if p == nil || p.IsRecorder() || p.IsAgent() { + return + } + + go func() { + r.agentClient.JobRequest(context.Background(), &livekit.Job{ + Id: utils.NewGuid("JP_"), + Type: livekit.JobType_JT_PUBLISHER, + Room: r.ToProto(), + Participant: p.ToProto(), + }) + }() +} + func (r *Room) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "Name": r.protoRoom.Name, diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6c2210613..bef8abb6e 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -264,6 +264,7 @@ type Participant interface { // permissions Hidden() bool IsRecorder() bool + IsAgent() bool Start() Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 4020502be..9e11470c3 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -408,6 +408,16 @@ type FakeLocalParticipant struct { identityReturnsOnCall map[int]struct { result1 livekit.ParticipantIdentity } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } IsClosedStub func() bool isClosedMutex sync.RWMutex isClosedArgsForCall []struct { @@ -2986,6 +2996,59 @@ func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.P }{result1} } +func (fake *FakeLocalParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeLocalParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeLocalParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) IsClosed() bool { fake.isClosedMutex.Lock() ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] @@ -6031,6 +6094,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.identityMutex.RLock() defer fake.identityMutex.RUnlock() + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() fake.isClosedMutex.RLock() defer fake.isClosedMutex.RUnlock() fake.isDisconnectedMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index fa92204fe..3f46bdbba 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -118,6 +118,16 @@ type FakeParticipant struct { identityReturnsOnCall map[int]struct { result1 livekit.ParticipantIdentity } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } IsPublisherStub func() bool isPublisherMutex sync.RWMutex isPublisherArgsForCall []struct { @@ -780,6 +790,59 @@ func (fake *FakeParticipant) IdentityReturnsOnCall(i int, result1 livekit.Partic }{result1} } +func (fake *FakeParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeParticipant) IsPublisher() bool { fake.isPublisherMutex.Lock() ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)] @@ -1318,6 +1381,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.identityMutex.RLock() defer fake.identityMutex.RUnlock() + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() fake.isPublisherMutex.RLock() defer fake.isPublisherMutex.RUnlock() fake.isRecorderMutex.RLock() diff --git a/pkg/service/agentservice.go b/pkg/service/agentservice.go index bb7a9afb0..baaf40da4 100644 --- a/pkg/service/agentservice.go +++ b/pkg/service/agentservice.go @@ -34,6 +34,8 @@ import ( "github.com/livekit/psrpc" ) +const AgentServiceVersion = "0.1.0" + type AgentService struct { upgrader websocket.Upgrader @@ -41,17 +43,17 @@ type AgentService struct { } type AgentHandler struct { - agentServer rpc.AgentInternalServer - roomTopic string - participantTopic string + agentServer rpc.AgentInternalServer + roomTopic string + publisherTopic string - mu sync.Mutex - availability map[string]chan *availability - unregistered map[*websocket.Conn]*worker - roomRegistered bool - roomWorkers map[string]*worker - participantRegistered bool - participantWorkers map[string]*worker + mu sync.Mutex + availability map[string]chan *availability + unregistered map[*websocket.Conn]*worker + roomRegistered bool + roomWorkers map[string]*worker + publisherRegistered bool + publisherWorkers map[string]*worker } type worker struct { @@ -85,7 +87,7 @@ func NewAgentService(bus psrpc.MessageBus) (*AgentService, error) { if err != nil { return nil, err } - s.AgentHandler = NewAgentHandler(agentServer, "room", "participant") + s.AgentHandler = NewAgentHandler(agentServer, rtc.RoomAgentTopic, rtc.PublisherAgentTopic) return s, nil } @@ -114,15 +116,15 @@ func (s *AgentService) ServeHTTP(writer http.ResponseWriter, r *http.Request) { s.HandleConnection(conn) } -func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, participantTopic string) *AgentHandler { +func NewAgentHandler(agentServer rpc.AgentInternalServer, roomTopic, publisherTopic string) *AgentHandler { return &AgentHandler{ - agentServer: agentServer, - roomTopic: roomTopic, - participantTopic: participantTopic, - availability: make(map[string]chan *availability), - unregistered: make(map[*websocket.Conn]*worker), - roomWorkers: make(map[string]*worker), - participantWorkers: make(map[string]*worker), + agentServer: agentServer, + roomTopic: roomTopic, + publisherTopic: publisherTopic, + availability: make(map[string]chan *availability), + unregistered: make(map[*websocket.Conn]*worker), + roomWorkers: make(map[string]*worker), + publisherWorkers: make(map[string]*worker), } } @@ -150,10 +152,10 @@ func (s *AgentHandler) HandleConnection(conn *websocket.Conn) { s.agentServer.DeregisterJobRequestTopic(s.roomTopic) } case livekit.JobType_JT_PUBLISHER: - delete(s.participantWorkers, w.id) - if s.participantRegistered && !s.participantAvailableLocked() { - s.participantRegistered = false - s.agentServer.DeregisterJobRequestTopic(s.participantTopic) + delete(s.publisherWorkers, w.id) + if s.publisherRegistered && !s.publisherAvailableLocked() { + s.publisherRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.publisherTopic) } } } @@ -217,26 +219,29 @@ func (s *AgentHandler) handleRegister(worker *worker, msg *livekit.RegisterWorke case livekit.JobType_JT_PUBLISHER: worker.id = msg.WorkerId delete(s.unregistered, worker.conn) - s.participantWorkers[worker.id] = worker + s.publisherWorkers[worker.id] = worker - if !s.participantRegistered { - err := s.agentServer.RegisterJobRequestTopic(s.participantTopic) + if !s.publisherRegistered { + err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic) if err != nil { - logger.Errorw("failed to register participant agents", err) + logger.Errorw("failed to register publisher agents", err) } else { - s.participantRegistered = true + s.publisherRegistered = true } } } - worker.sigConn.WriteServerMessage(&livekit.ServerMessage{ + _, err := worker.sigConn.WriteServerMessage(&livekit.ServerMessage{ Message: &livekit.ServerMessage_Register{ Register: &livekit.RegisterWorkerResponse{ WorkerId: worker.id, - ServerVersion: "version", + ServerVersion: AgentServiceVersion, }, }, }) + if err != nil { + logger.Errorw("failed to write server message", err) + } } func (s *AgentHandler) handleAvailability(w *worker, msg *livekit.AvailabilityResponse) { @@ -286,19 +291,29 @@ func (s *AgentHandler) handleStatus(w *worker, msg *livekit.UpdateWorkerStatus) } } case livekit.JobType_JT_PUBLISHER: - if s.participantRegistered && !s.participantAvailableLocked() { - s.participantRegistered = false - s.agentServer.DeregisterJobRequestTopic(s.participantTopic) - } else if !s.participantRegistered && s.participantAvailableLocked() { - if err := s.agentServer.RegisterJobRequestTopic(s.participantTopic); err != nil { - logger.Errorw("failed to register participant agents", err) + if s.publisherRegistered && !s.publisherAvailableLocked() { + s.publisherRegistered = false + s.agentServer.DeregisterJobRequestTopic(s.publisherTopic) + } else if !s.publisherRegistered && s.publisherAvailableLocked() { + if err := s.agentServer.RegisterJobRequestTopic(s.publisherTopic); err != nil { + logger.Errorw("failed to register publisher agents", err) } else { - s.participantRegistered = true + s.publisherRegistered = true } } } } +func (s *AgentHandler) CheckEnabled(_ context.Context, _ *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { + s.mu.Lock() + res := &rpc.CheckEnabledResponse{ + RoomEnabled: len(s.roomWorkers) > 0, + PublisherEnabled: len(s.publisherWorkers) > 0, + } + s.mu.Unlock() + return res, nil +} + func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*emptypb.Empty, error) { s.mu.Lock() ac := make(chan *availability, 100) @@ -316,7 +331,7 @@ func (s *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*empty case livekit.JobType_JT_ROOM: pool = s.roomWorkers case livekit.JobType_JT_PUBLISHER: - pool = s.participantWorkers + pool = s.publisherWorkers } attempted := make(map[string]bool) @@ -386,7 +401,7 @@ func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) case livekit.JobType_JT_ROOM: pool = s.roomWorkers case livekit.JobType_JT_PUBLISHER: - pool = s.participantWorkers + pool = s.publisherWorkers } var affinity float32 @@ -403,6 +418,13 @@ func (s *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) return affinity } +func (s *AgentHandler) NumConnections() int { + s.mu.Lock() + defer s.mu.Unlock() + + return len(s.unregistered) + len(s.roomWorkers) + len(s.publisherWorkers) +} + func (s *AgentHandler) DrainConnections(interval time.Duration) { // jitter drain start time.Sleep(time.Duration(rand.Int63n(int64(interval)))) @@ -421,7 +443,7 @@ func (s *AgentHandler) DrainConnections(interval time.Duration) { _ = w.conn.Close() <-t.C } - for _, w := range s.participantWorkers { + for _, w := range s.publisherWorkers { _ = w.conn.Close() <-t.C } @@ -436,8 +458,8 @@ func (s *AgentHandler) roomAvailableLocked() bool { return false } -func (s *AgentHandler) participantAvailableLocked() bool { - for _, w := range s.participantWorkers { +func (s *AgentHandler) publisherAvailableLocked() bool { + for _, w := range s.publisherWorkers { if w.status == livekit.WorkerStatus_WS_AVAILABLE { return true } diff --git a/pkg/service/clients.go b/pkg/service/clients.go index cad477a93..53c0e415b 100644 --- a/pkg/service/clients.go +++ b/pkg/service/clients.go @@ -26,18 +26,6 @@ import ( "github.com/livekit/protocol/utils" ) -type agentClient struct { - s *AgentService -} - -func NewAgentClient(s *AgentService) rtc.AgentClient { - return &agentClient{s} -} - -func (c *agentClient) JobRequest(ctx context.Context, job *livekit.Job) { - _, _ = c.s.JobRequest(ctx, job) -} - type IOClient interface { CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index 0576bd9c2..174ea3391 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -117,25 +117,28 @@ func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomReq } if created { - if s.agentClient != nil { + go func() { s.agentClient.JobRequest(ctx, &livekit.Job{ Id: utils.NewGuid("JR_"), Type: livekit.JobType_JT_ROOM, Room: rm, }) - } + }() + if req.Egress != nil && req.Egress.Room != nil { - egress := &rpc.StartEgressRequest{ + _, err = s.egressLauncher.StartEgress(ctx, &rpc.StartEgressRequest{ Request: &rpc.StartEgressRequest_RoomComposite{ RoomComposite: req.Egress.Room, }, RoomId: rm.Sid, + }) + if err != nil { + return nil, err } - _, err = s.egressLauncher.StartEgress(ctx, egress) } } - return rm, err + return rm, nil } func (s *RoomService) ListRooms(ctx context.Context, req *livekit.ListRoomsRequest) (*livekit.ListRoomsResponse, error) { @@ -478,12 +481,14 @@ func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.Updat return nil, err } - if created && s.agentClient != nil { - s.agentClient.JobRequest(ctx, &livekit.Job{ - Id: utils.NewGuid("JR_"), - Type: livekit.JobType_JT_ROOM, - Room: room, - }) + if created { + go func() { + s.agentClient.JobRequest(ctx, &livekit.Job{ + Id: utils.NewGuid("JR_"), + Type: livekit.JobType_JT_ROOM, + Room: room, + }) + }() } return room, nil diff --git a/pkg/service/wire.go b/pkg/service/wire.go index f53297d8f..afca2631e 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -30,6 +30,7 @@ import ( "github.com/livekit/livekit-server/pkg/clientconfiguration" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" @@ -73,7 +74,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live NewRoomService, NewRTCService, NewAgentService, - NewAgentClient, + rtc.NewAgentClient, getSignalRelayConfig, NewDefaultSignalServer, routing.NewSignalClient, diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index 30a61b66c..3a6470b88 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -11,6 +11,7 @@ import ( "github.com/livekit/livekit-server/pkg/clientconfiguration" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/protocol/auth" @@ -55,11 +56,10 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - agentService, err := NewAgentService(messageBus) + agentClient, err := rtc.NewAgentClient(messageBus) if err != nil { return nil, err } - rtcAgentClient := NewAgentClient(agentService) egressClient, err := rpc.NewEgressClient(messageBus) if err != nil { return nil, err @@ -91,7 +91,7 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live if err != nil { return nil, err } - roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, rtcAgentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + roomService, err := NewRoomService(roomConfig, apiConfig, psrpcConfig, router, roomAllocator, objectStore, agentClient, rtcEgressLauncher, topicFormatter, roomClient, participantClient) if err != nil { return nil, err } @@ -102,11 +102,15 @@ func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*Live return nil, err } ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, roomService, telemetryService) - rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, rtcAgentClient, telemetryService) + rtcService := NewRTCService(conf, roomAllocator, objectStore, router, currentNode, agentClient, telemetryService) + agentService, err := NewAgentService(messageBus) + if err != nil { + return nil, err + } clientConfigurationManager := createClientConfiguration() timedVersionGenerator := utils.NewDefaultTimedVersionGenerator() turnAuthHandler := NewTURNAuthHandler(keyProvider) - roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, rtcAgentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) + roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, telemetryService, clientConfigurationManager, agentClient, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus) if err != nil { return nil, err } diff --git a/test/agent_test.go b/test/agent_test.go index 172dec62f..0a45fa9ae 100644 --- a/test/agent_test.go +++ b/test/agent_test.go @@ -1,3 +1,17 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package test import ( From 440f00bcac5e126120a188e99745b164be77eb1a Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Wed, 8 Nov 2023 11:13:39 +0530 Subject: [PATCH 18/18] Declare audio inactive if stale. (#2229) * Declare audio inactive if stale. Stale samples were used to declare audio active. Maintain last update time and declare inactive if samples are stale. * correct comment * spelling * check level in test --- pkg/sfu/audio/audiolevel.go | 56 +++++++++++++++++++--------- pkg/sfu/audio/audiolevel_test.go | 64 ++++++++++++++++++++++++-------- pkg/sfu/buffer/buffer.go | 4 +- 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/pkg/sfu/audio/audiolevel.go b/pkg/sfu/audio/audiolevel.go index 7e834fda7..f908788d8 100644 --- a/pkg/sfu/audio/audiolevel.go +++ b/pkg/sfu/audio/audiolevel.go @@ -16,8 +16,8 @@ package audio import ( "math" - - "go.uber.org/atomic" + "sync" + "time" ) const ( @@ -40,11 +40,13 @@ type AudioLevel struct { smoothFactor float64 activeThreshold float64 - smoothedLevel atomic.Float64 + lock sync.Mutex + smoothedLevel float64 loudestObservedLevel uint8 activeDuration uint32 // ms observedDuration uint32 // ms + lastObservedAt time.Time } func NewAudioLevel(params AudioLevelParams) *AudioLevel { @@ -64,8 +66,13 @@ func NewAudioLevel(params AudioLevelParams) *AudioLevel { return l } -// Observes a new frame, must be called from the same thread -func (l *AudioLevel) Observe(level uint8, durationMs uint32) { +// Observes a new frame +func (l *AudioLevel) Observe(level uint8, durationMs uint32, arrivalTime time.Time) { + l.lock.Lock() + defer l.lock.Unlock() + + l.lastObservedAt = arrivalTime + l.observedDuration += durationMs if level <= l.params.ActiveLevel { @@ -76,6 +83,7 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { } if l.observedDuration >= l.params.ObserveDuration { + smoothedLevel := float64(0.0) // compute and reset if l.activeDuration >= l.minActiveDuration { // adjust loudest observed level by how much of the window was active. @@ -87,25 +95,39 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { linearLevel := ConvertAudioLevel(adjustedLevel) // exponential smoothing to dampen transients - smoothedLevel := l.smoothedLevel.Load() - smoothedLevel += (linearLevel - smoothedLevel) * l.smoothFactor - l.smoothedLevel.Store(smoothedLevel) - } else { - l.smoothedLevel.Store(0) + smoothedLevel = l.smoothedLevel + (linearLevel-l.smoothedLevel)*l.smoothFactor } - l.loudestObservedLevel = silentAudioLevel - l.activeDuration = 0 - l.observedDuration = 0 + l.resetLocked(smoothedLevel) } } // returns current soothed audio level -func (l *AudioLevel) GetLevel() (float64, bool) { - smoothedLevel := l.smoothedLevel.Load() - active := smoothedLevel >= l.activeThreshold - return smoothedLevel, active +func (l *AudioLevel) GetLevel(now time.Time) (float64, bool) { + l.lock.Lock() + defer l.lock.Unlock() + + l.resetIfStaleLocked(now) + + return l.smoothedLevel, l.smoothedLevel >= l.activeThreshold } +func (l *AudioLevel) resetIfStaleLocked(arrivalTime time.Time) { + if arrivalTime.Sub(l.lastObservedAt).Milliseconds() < int64(2*l.params.ObserveDuration) { + return + } + + l.resetLocked(0.0) +} + +func (l *AudioLevel) resetLocked(smoothedLevel float64) { + l.smoothedLevel = smoothedLevel + l.loudestObservedLevel = silentAudioLevel + l.activeDuration = 0 + l.observedDuration = 0 +} + +// --------------------------------------------------- + // convert decibel back to linear func ConvertAudioLevel(level float64) float64 { return math.Pow(10, level*negInv20) diff --git a/pkg/sfu/audio/audiolevel_test.go b/pkg/sfu/audio/audiolevel_test.go index 8b8f03eba..84d5a34dd 100644 --- a/pkg/sfu/audio/audiolevel_test.go +++ b/pkg/sfu/audio/audiolevel_test.go @@ -16,6 +16,7 @@ package audio import ( "testing" + "time" "github.com/stretchr/testify/require" ) @@ -30,46 +31,79 @@ const ( func TestAudioLevel(t *testing.T) { t.Run("initially to return not noisy, within a few samples", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - _, noisy := a.GetLevel() + + _, noisy := a.GetLevel(clock) require.False(t, noisy) - observeSamples(a, 28, 5) - _, noisy = a.GetLevel() + observeSamples(a, 28, 5, clock) + clock = clock.Add(5 * 20 * time.Millisecond) + + _, noisy = a.GetLevel(clock) require.False(t, noisy) }) t.Run("not noisy when all samples are below threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, 100) - _, noisy := a.GetLevel() + observeSamples(a, 35, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + + _, noisy := a.GetLevel(clock) require.False(t, noisy) }) t.Run("not noisy when less than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, samplesPerBatch-2) - observeSamples(a, 25, 1) - observeSamples(a, 35, 1) + observeSamples(a, 35, samplesPerBatch-2, clock) + clock = clock.Add((samplesPerBatch - 2) * 20 * time.Millisecond) + observeSamples(a, 25, 1, clock) + clock = clock.Add(20 * time.Millisecond) + observeSamples(a, 35, 1, clock) + clock = clock.Add(20 * time.Millisecond) - _, noisy := a.GetLevel() + _, noisy := a.GetLevel(clock) require.False(t, noisy) }) t.Run("noisy when higher than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) - observeSamples(a, 35, samplesPerBatch-16) - observeSamples(a, 25, 8) - observeSamples(a, 29, 8) + observeSamples(a, 35, samplesPerBatch-16, clock) + clock = clock.Add((samplesPerBatch - 16) * 20 * time.Millisecond) + observeSamples(a, 25, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) + observeSamples(a, 29, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) - level, noisy := a.GetLevel() + level, noisy := a.GetLevel(clock) require.True(t, noisy) require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) require.Less(t, level, ConvertAudioLevel(float64(25))) }) + + t.Run("not noisy when samples are stale", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 25, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + level, noisy := a.GetLevel(clock) + require.True(t, noisy) + require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) + require.Less(t, level, ConvertAudioLevel(float64(20))) + + // let enough time pass to make the samples stale + clock = clock.Add(1500 * time.Millisecond) + level, noisy = a.GetLevel(clock) + require.Equal(t, float64(0.0), level) + require.False(t, noisy) + }) } func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration uint32) *AudioLevel { @@ -80,8 +114,8 @@ func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration ui }) } -func observeSamples(a *AudioLevel, level uint8, count int) { +func observeSamples(a *AudioLevel, level uint8, count int, baseTime time.Time) { for i := 0; i < count; i++ { - a.Observe(level, 20) + a.Observe(level, 20, baseTime.Add(+time.Duration(i*20)*time.Millisecond)) } } diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index c6a018972..573b85649 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -583,7 +583,7 @@ func (b *Buffer) processHeaderExtensions(p *rtp.Packet, arrivalTime time.Time) { if (p.Timestamp - b.latestTSForAudioLevel) < (1 << 31) { duration := (int64(p.Timestamp) - int64(b.latestTSForAudioLevel)) * 1e3 / int64(b.clockRate) if duration > 0 { - b.audioLevel.Observe(ext.Level, uint32(duration)) + b.audioLevel.Observe(ext.Level, uint32(duration), arrivalTime) } b.latestTSForAudioLevel = p.Timestamp @@ -842,7 +842,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) { return 0, false } - return b.audioLevel.GetLevel() + return b.audioLevel.GetLevel(time.Now()) } func (b *Buffer) OnFpsChanged(f func()) {