diff --git a/config-sample.yaml b/config-sample.yaml index 8a6fcf7c5..7375b2398 100644 --- a/config-sample.yaml +++ b/config-sample.yaml @@ -261,3 +261,9 @@ keys: # num_tracks: -1 # # defaults to 1 GB/s, or just under 10 Gbps # bytes_per_sec: 1_000_000_000 +# # how many tracks (audio / video) that a single participant can subscribe at same time. +# # if the limit is exceeded, subscriptions will be pending until any subscribed track has been unsubscribed. +# # value less or equal than 0 means no limit. +# subscription_limit_video: 0 +# subscription_limit_audio: 0 + diff --git a/pkg/config/config.go b/pkg/config/config.go index 50cdf5ccc..bff9d19e1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -246,8 +246,10 @@ type RegionConfig struct { } type LimitConfig struct { - NumTracks int32 `yaml:"num_tracks,omitempty"` - BytesPerSec float32 `yaml:"bytes_per_sec,omitempty"` + NumTracks int32 `yaml:"num_tracks,omitempty"` + BytesPerSec float32 `yaml:"bytes_per_sec,omitempty"` + SubscriptionLimitVideo int32 `yaml:"subscription_limit_video,omitempty"` + SubscriptionLimitAudio int32 `yaml:"subscription_limit_audio,omitempty"` } type EgressConfig struct { diff --git a/pkg/rtc/errors.go b/pkg/rtc/errors.go index dcd31c979..20c41acb9 100644 --- a/pkg/rtc/errors.go +++ b/pkg/rtc/errors.go @@ -14,9 +14,10 @@ var ( ErrMissingGrants = errors.New("VideoGrant is missing") // Track subscription related - ErrNoTrackPermission = errors.New("participant is not allowed to subscribe to this track") - ErrNoSubscribePermission = errors.New("participant is not given permission to subscribe to tracks") - ErrTrackNotFound = errors.New("track cannot be found") - ErrTrackNotAttached = errors.New("track is not yet attached") - ErrTrackNotBound = errors.New("track not bound") + ErrNoTrackPermission = errors.New("participant is not allowed to subscribe to this track") + ErrNoSubscribePermission = errors.New("participant is not given permission to subscribe to tracks") + ErrTrackNotFound = errors.New("track cannot be found") + ErrTrackNotAttached = errors.New("track is not yet attached") + ErrTrackNotBound = errors.New("track not bound") + ErrSubscriptionLimitExceeded = errors.New("participant has exceeded its subscription limit") ) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 8e64941af..86c721ac7 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -90,6 +90,8 @@ type ParticipantParams struct { TrackResolver types.MediaTrackResolver DisableDynacast bool SubscriberAllowPause bool + SubscriptionLimitAudio int32 + SubscriptionLimitVideo int32 } type ParticipantImpl struct { @@ -1065,13 +1067,15 @@ func (p *ParticipantImpl) setupUpTrackManager() { func (p *ParticipantImpl) setupSubscriptionManager() { p.SubscriptionManager = NewSubscriptionManager(SubscriptionManagerParams{ - Participant: p, - Logger: p.params.Logger.WithoutSampler(), - TrackResolver: p.params.TrackResolver, - Telemetry: p.params.Telemetry, - OnTrackSubscribed: p.onTrackSubscribed, - OnTrackUnsubscribed: p.onTrackUnsubscribed, - OnSubscriptionError: p.onSubscriptionError, + Participant: p, + Logger: p.params.Logger.WithoutSampler(), + TrackResolver: p.params.TrackResolver, + Telemetry: p.params.Telemetry, + OnTrackSubscribed: p.onTrackSubscribed, + OnTrackUnsubscribed: p.onTrackUnsubscribed, + OnSubscriptionError: p.onSubscriptionError, + SubscriptionLimitVideo: p.params.SubscriptionLimitVideo, + SubscriptionLimitAudio: p.params.SubscriptionLimitAudio, }) } @@ -1765,11 +1769,6 @@ func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackTyp if pendingInfo == nil { track_loop: for cid, pti := range p.pendingTracks { - if cid == clientId { - pendingInfo = pti - signalCid = cid - break - } ti := pti.trackInfos[0] for _, c := range ti.Codecs { diff --git a/pkg/rtc/subscriptionmanager.go b/pkg/rtc/subscriptionmanager.go index 9e880aeb3..640e85d19 100644 --- a/pkg/rtc/subscriptionmanager.go +++ b/pkg/rtc/subscriptionmanager.go @@ -42,6 +42,10 @@ var ( trackRemoveGracePeriod = time.Second ) +const ( + trackIDForReconcileSubscriptions = livekit.TrackID("subscriptions_reconcile") +) + type SubscriptionManagerParams struct { Logger logger.Logger Participant types.LocalParticipant @@ -50,6 +54,8 @@ type SubscriptionManagerParams struct { OnTrackUnsubscribed func(subTrack types.SubscribedTrack) OnSubscriptionError func(trackID livekit.TrackID) Telemetry telemetry.TelemetryService + + SubscriptionLimitVideo, SubscriptionLimitAudio int32 } // SubscriptionManager manages a participant's subscriptions @@ -57,25 +63,25 @@ type SubscriptionManager struct { params SubscriptionManagerParams lock sync.RWMutex subscriptions map[livekit.TrackID]*trackSubscription - subscribedTo map[livekit.ParticipantID]map[livekit.TrackID]struct{} - // keeps track of tracks that are already queued for reconcile to avoid duplicating reconcile requests - pendingReconcile map[livekit.TrackID]struct{} - reconcileCh chan livekit.TrackID - closeCh chan struct{} - doneCh chan struct{} + + subscribedVideoCount, subscribedAudioCount atomic.Int32 + + subscribedTo map[livekit.ParticipantID]map[livekit.TrackID]struct{} + reconcileCh chan livekit.TrackID + closeCh chan struct{} + doneCh chan struct{} onSubscribeStatusChanged func(publisherID livekit.ParticipantID, subscribed bool) } func NewSubscriptionManager(params SubscriptionManagerParams) *SubscriptionManager { m := &SubscriptionManager{ - params: params, - subscriptions: make(map[livekit.TrackID]*trackSubscription), - subscribedTo: make(map[livekit.ParticipantID]map[livekit.TrackID]struct{}), - pendingReconcile: make(map[livekit.TrackID]struct{}), - reconcileCh: make(chan livekit.TrackID, 50), - closeCh: make(chan struct{}), - doneCh: make(chan struct{}), + params: params, + subscriptions: make(map[livekit.TrackID]*trackSubscription), + subscribedTo: make(map[livekit.ParticipantID]map[livekit.TrackID]struct{}), + reconcileCh: make(chan livekit.TrackID, 50), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), } go m.reconcileWorker() @@ -282,20 +288,21 @@ func (m *SubscriptionManager) reconcileSubscription(s *trackSubscription) { s.recordAttempt(false) switch err { - case ErrNoTrackPermission, ErrNoSubscribePermission, ErrNoReceiver, ErrNotOpen, ErrTrackNotAttached: + case ErrNoTrackPermission, ErrNoSubscribePermission, ErrNoReceiver, ErrNotOpen, ErrTrackNotAttached, ErrSubscriptionLimitExceeded: // these are errors that are outside of our control, so we'll keep trying // - ErrNoTrackPermission: publisher did not grant subscriber permission, may change any moment // - ErrNoSubscribePermission: participant was not granted canSubscribe, may change any moment // - ErrNoReceiver: Track is in the process of closing (another local track published to the same instance) // - ErrTrackNotAttached: Remote Track that is not attached, but may be attached later // - ErrNotOpen: Track is closing or already closed + // - ErrSubscriptionLimitExceeded: the participant have reached the limit of subscriptions, wait for the other subscription to be unsubscribed // We'll still log an event to reflect this in telemetry since it's been too long if s.durationSinceStart() > subscriptionTimeout { s.maybeRecordError(m.params.Telemetry, m.params.Participant.ID(), err, true) } case ErrTrackNotFound: // source track was never published or closed - // if after timeout, we'd unsubscribe from it. + // if after timeout we'd unsubscribe from it. // this is the *only* case we'd change desired state if s.durationSinceStart() > notFoundTimeout { s.maybeRecordError(m.params.Telemetry, m.params.Participant.ID(), err, true) @@ -353,13 +360,6 @@ func (m *SubscriptionManager) reconcileSubscription(s *trackSubscription) { // trigger an immediate reconciliation, when trackID is empty, will reconcile all subscriptions func (m *SubscriptionManager) queueReconcile(trackID livekit.TrackID) { - m.lock.Lock() - if _, ok := m.pendingReconcile[trackID]; ok { - // already reconciled - m.lock.Unlock() - return - } - m.lock.Unlock() select { case m.reconcileCh <- trackID: default: @@ -381,7 +381,6 @@ func (m *SubscriptionManager) reconcileWorker() { case trackID := <-m.reconcileCh: m.lock.Lock() s := m.subscriptions[trackID] - delete(m.pendingReconcile, trackID) m.lock.Unlock() if s != nil { m.reconcileSubscription(s) @@ -392,6 +391,21 @@ func (m *SubscriptionManager) reconcileWorker() { } } +func (m *SubscriptionManager) hasCapcityForSubscription(kind livekit.TrackType) bool { + switch kind { + case livekit.TrackType_VIDEO: + if m.params.SubscriptionLimitVideo > 0 && m.subscribedVideoCount.Load() >= m.params.SubscriptionLimitVideo { + return false + } + + case livekit.TrackType_AUDIO: + if m.params.SubscriptionLimitAudio > 0 && m.subscribedAudioCount.Load() >= m.params.SubscriptionLimitAudio { + return false + } + } + return true +} + func (m *SubscriptionManager) subscribe(s *trackSubscription) error { s.logger.Debugw("executing subscribe") @@ -399,6 +413,10 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { return ErrNoSubscribePermission } + if kind, ok := s.getKind(); ok && !m.hasCapcityForSubscription(kind) { + return ErrSubscriptionLimitExceeded + } + res := m.params.TrackResolver(m.params.Participant.Identity(), s.trackID) s.logger.Debugw("resolved track", "result", res) @@ -426,6 +444,10 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { if track == nil { return ErrTrackNotFound } + s.trySetKind(track.Kind()) + if !m.hasCapcityForSubscription(track.Kind()) { + return ErrSubscriptionLimitExceeded + } // since hasPermission defaults to true, we will want to send a message to the client the first time // that we discover permissions were denied @@ -453,6 +475,13 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { }) s.setSubscribedTrack(subTrack) + switch track.Kind() { + case livekit.TrackType_VIDEO: + m.subscribedVideoCount.Inc() + case livekit.TrackType_AUDIO: + m.subscribedAudioCount.Inc() + } + if subTrack.NeedsNegotiation() { m.params.Participant.Negotiate(false) } @@ -460,6 +489,8 @@ func (m *SubscriptionManager) subscribe(s *trackSubscription) error { go m.params.OnTrackSubscribed(subTrack) } + m.params.Logger.Debugw("subscribed to track", "track", s.trackID, "subscribedAudioCount", m.subscribedAudioCount.Load(), "subscribedVideoCount", m.subscribedVideoCount.Load()) + // add mark the participant as someone we've subscribed to firstSubscribe := false publisherID := s.getPublisherID() @@ -512,6 +543,14 @@ func (m *SubscriptionManager) handleSubscribedTrackClose(s *trackSubscription, w } s.setSubscribedTrack(nil) + var relieveFromLimits bool + switch subTrack.MediaTrack().Kind() { + case livekit.TrackType_VIDEO: + relieveFromLimits = m.params.SubscriptionLimitVideo > 0 && m.subscribedVideoCount.Dec() == m.params.SubscriptionLimitVideo-1 + case livekit.TrackType_AUDIO: + relieveFromLimits = m.params.SubscriptionLimitAudio > 0 && m.subscribedAudioCount.Dec() == m.params.SubscriptionLimitAudio-1 + } + // remove from subscribedTo publisherID := s.getPublisherID() lastSubscription := false @@ -581,7 +620,11 @@ func (m *SubscriptionManager) handleSubscribedTrackClose(s *trackSubscription, w m.params.Participant.Negotiate(false) } - m.queueReconcile(s.trackID) + if relieveFromLimits { + m.queueReconcile(trackIDForReconcileSubscriptions) + } else { + m.queueReconcile(s.trackID) + } } // -------------------------------------------------------------------------------------- @@ -603,6 +646,7 @@ type trackSubscription struct { eventSent atomic.Bool numAttempts atomic.Int32 bound bool + kind atomic.Pointer[livekit.TrackType] // the later of when subscription was requested OR when the first failure was encountered OR when permission is granted // this timestamp determines when failures are reported @@ -705,6 +749,18 @@ func (s *trackSubscription) setSubscribedTrack(track types.SubscribedTrack) { } } +func (s *trackSubscription) trySetKind(kind livekit.TrackType) { + s.kind.CompareAndSwap(nil, &kind) +} + +func (s *trackSubscription) getKind() (livekit.TrackType, bool) { + kind := s.kind.Load() + if kind == nil { + return livekit.TrackType_AUDIO, false + } + return *kind, true +} + func (s *trackSubscription) getSubscribedTrack() types.SubscribedTrack { s.lock.RLock() defer s.lock.RUnlock() diff --git a/pkg/rtc/subscriptionmanager_test.go b/pkg/rtc/subscriptionmanager_test.go index d5a9d0f4a..0c3de6955 100644 --- a/pkg/rtc/subscriptionmanager_test.go +++ b/pkg/rtc/subscriptionmanager_test.go @@ -340,7 +340,115 @@ func TestUpdateSettingsBeforeSubscription(t *testing.T) { require.Equal(t, settings.Height, applied.Height) } +func TestSubscriptionLimits(t *testing.T) { + sm := newTestSubscriptionManagerWithParams(t, testSubscriptionParams{ + SubscriptionLimitAudio: 1, + SubscriptionLimitVideo: 1, + }) + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + subCount := atomic.Int32{} + failed := atomic.Bool{} + sm.params.OnTrackSubscribed = func(subTrack types.SubscribedTrack) { + subCount.Add(1) + } + sm.params.OnSubscriptionError = func(trackID livekit.TrackID) { + failed.Store(true) + } + numParticipantSubscribed := atomic.Int32{} + numParticipantUnsubscribed := atomic.Int32{} + sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) { + if subscribed { + numParticipantSubscribed.Add(1) + } else { + numParticipantUnsubscribed.Add(1) + } + }) + + sm.SubscribeToTrack("track") + s := sm.subscriptions["track"] + require.True(t, s.isDesired()) + require.Eventually(t, func() bool { + return subCount.Load() == 1 + }, subSettleTimeout, subCheckInterval, "track was not subscribed") + + require.NotNil(t, s.getSubscribedTrack()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + require.Eventually(t, func() bool { + return len(sm.GetSubscribedParticipants()) == 1 + }, subSettleTimeout, subCheckInterval, "GetSubscribedParticipants should have returned one item") + require.Equal(t, "pubID", string(sm.GetSubscribedParticipants()[0])) + + // ensure telemetry events are sent + tm := sm.params.Telemetry.(*telemetryfakes.FakeTelemetryService) + require.Equal(t, 1, tm.TrackSubscribeRequestedCallCount()) + + // ensure bound + setTestSubscribedTrackBound(t, s.getSubscribedTrack()) + + require.Eventually(t, func() bool { + return !s.needsBind() + }, subSettleTimeout, subCheckInterval, "track was not bound") + + // telemetry event should have been sent + require.Equal(t, 1, tm.TrackSubscribedCallCount()) + + // reach subscription limit, subscribe pending + sm.SubscribeToTrack("track2") + s2 := sm.subscriptions["track2"] + time.Sleep(subscriptionTimeout * 2) + require.True(t, s2.needsSubscribe()) + require.Equal(t, 2, tm.TrackSubscribeRequestedCallCount()) + require.Equal(t, 1, tm.TrackSubscribeFailedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + // unsubscribe track1, then track2 should be subscribed + sm.UnsubscribeFromTrack("track") + require.False(t, s.isDesired()) + require.True(t, s.needsUnsubscribe()) + // wait for unsubscribe to take effect + time.Sleep(reconcileInterval) + setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false) + require.Nil(t, s.getSubscribedTrack()) + + time.Sleep(reconcileInterval) + require.True(t, s2.isDesired()) + require.False(t, s2.needsSubscribe()) + require.EqualValues(t, 2, subCount.Load()) + require.NotNil(t, s2.getSubscribedTrack()) + require.Equal(t, 2, tm.TrackSubscribeRequestedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + // ensure bound + setTestSubscribedTrackBound(t, s2.getSubscribedTrack()) + + require.Eventually(t, func() bool { + return !s2.needsBind() + }, subSettleTimeout, subCheckInterval, "track was not bound") + + // subscribe to track1 again, which should pending + sm.SubscribeToTrack("track") + s = sm.subscriptions["track"] + require.True(t, s.isDesired()) + time.Sleep(subscriptionTimeout * 2) + require.True(t, s.needsSubscribe()) + require.Equal(t, 3, tm.TrackSubscribeRequestedCallCount()) + require.Equal(t, 2, tm.TrackSubscribeFailedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) +} + +type testSubscriptionParams struct { + SubscriptionLimitAudio int32 + SubscriptionLimitVideo int32 +} + func newTestSubscriptionManager(t *testing.T) *SubscriptionManager { + return newTestSubscriptionManagerWithParams(t, testSubscriptionParams{}) +} + +func newTestSubscriptionManagerWithParams(t *testing.T, params testSubscriptionParams) *SubscriptionManager { p := &typesfakes.FakeLocalParticipant{} p.CanSubscribeReturns(true) p.IDReturns("subID") @@ -354,7 +462,9 @@ func newTestSubscriptionManager(t *testing.T) *SubscriptionManager { TrackResolver: func(identity livekit.ParticipantIdentity, trackID livekit.TrackID) types.MediaResolverResult { return types.MediaResolverResult{} }, - Telemetry: &telemetryfakes.FakeTelemetryService{}, + Telemetry: &telemetryfakes.FakeTelemetryService{}, + SubscriptionLimitAudio: params.SubscriptionLimitAudio, + SubscriptionLimitVideo: params.SubscriptionLimitVideo, }) } diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 9f4df01fb..0244abb84 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -341,6 +341,8 @@ func (r *RoomManager) StartSession( VersionGenerator: r.versionGenerator, TrackResolver: room.ResolveMediaTrackForSubscriber, SubscriberAllowPause: subscriberAllowPause, + SubscriptionLimitAudio: r.config.Limit.SubscriptionLimitAudio, + SubscriptionLimitVideo: r.config.Limit.SubscriptionLimitVideo, }) if err != nil { return err