From 44d26f0cb477adf7f0cf1a89074657dacae2cce3 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sat, 30 Nov 2024 01:38:25 +0530 Subject: [PATCH] Probe controller refactor (#3221) * WIP * WIP * WIP --- pkg/rtc/transport.go | 9 +- pkg/sfu/bwe/bwe.go | 30 +- pkg/sfu/bwe/null_bwe.go | 10 +- pkg/sfu/bwe/remotebwe/channel_observer.go | 33 +- pkg/sfu/bwe/remotebwe/remote_bwe.go | 91 +--- pkg/sfu/ccutils/prober.go | 155 ++++--- pkg/sfu/pacer/pacer.go | 2 +- pkg/sfu/pacer/probe_observer.go | 58 +-- pkg/sfu/streamallocator/probe_controller.go | 459 ++++++++------------ pkg/sfu/streamallocator/streamallocator.go | 233 +++++----- 10 files changed, 476 insertions(+), 604 deletions(-) diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index e7bc5278c..18a6835c4 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -468,10 +468,11 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { } t.streamAllocator = streamallocator.NewStreamAllocator(streamallocator.StreamAllocatorParams{ - Config: params.CongestionControlConfig.StreamAllocator, - BWE: t.bwe, - Pacer: t.pacer, - Logger: params.Logger.WithComponent(utils.ComponentCongestionControl), + Config: params.CongestionControlConfig.StreamAllocator, + BWE: t.bwe, + Pacer: t.pacer, + RTTGetter: t.GetRTT, + Logger: params.Logger.WithComponent(utils.ComponentCongestionControl), }, params.CongestionControlConfig.Enabled, params.CongestionControlConfig.AllowPause) t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange) t.streamAllocator.Start() diff --git a/pkg/sfu/bwe/bwe.go b/pkg/sfu/bwe/bwe.go index 400861c49..c15eea60b 100644 --- a/pkg/sfu/bwe/bwe.go +++ b/pkg/sfu/bwe/bwe.go @@ -17,6 +17,7 @@ package bwe import ( "fmt" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/pion/rtcp" ) @@ -51,29 +52,6 @@ func (c CongestionState) String() string { // ------------------------------------------------ -type ChannelTrend int - -const ( - ChannelTrendNeutral ChannelTrend = 
iota - ChannelTrendClearing - ChannelTrendCongesting -) - -func (c ChannelTrend) String() string { - switch c { - case ChannelTrendNeutral: - return "NEUTRAL" - case ChannelTrendClearing: - return "CLEARING" - case ChannelTrendCongesting: - return "CONGESTING" - default: - return fmt.Sprintf("%d", int(c)) - } -} - -// ------------------------------------------------ - type BWE interface { SetBWEListener(bweListner BWEListener) @@ -83,7 +61,6 @@ type BWE interface { HandleREMB( receivedEstimate int64, - isProbeFinalizing bool, expectedBandwidthUsage int64, sentPackets uint32, repeatedNacks uint32, @@ -94,9 +71,8 @@ type BWE interface { HandleTWCCFeedback(report *rtcp.TransportLayerCC) - ProbingStart(expectedBandwidthUsage int64) - ProbingEnd(isNotFailing bool, isGoalReached bool) - GetProbeStatus() (isValidSignal bool, trend ChannelTrend, lowestEstimate int64, highestEstimate int64) + ProbeClusterStarting(pci ccutils.ProbeClusterInfo) + ProbeClusterDone(pci ccutils.ProbeClusterInfo) (bool, int64) } // ------------------------------------------------ diff --git a/pkg/sfu/bwe/null_bwe.go b/pkg/sfu/bwe/null_bwe.go index 436e2e361..4f810c678 100644 --- a/pkg/sfu/bwe/null_bwe.go +++ b/pkg/sfu/bwe/null_bwe.go @@ -15,6 +15,7 @@ package bwe import ( + "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/pion/rtcp" ) @@ -33,7 +34,6 @@ func (n *NullBWE) RecordPacketSendAndGetSequenceNumber(_atMicro int64, _size int func (n *NullBWE) HandleREMB( _receivedEstimate int64, - _isProbeFinalizing bool, _expectedBandwidthUsage int64, _sentPackets uint32, _repeatedNacks uint32, @@ -42,12 +42,10 @@ func (n *NullBWE) HandleREMB( func (n *NullBWE) HandleTWCCFeedback(_report *rtcp.TransportLayerCC) {} -func (n *NullBWE) ProbingStart(_expectedBandwidthUsage int64) {} +func (n *NullBWE) ProbeClusterStarting(_pci ccutils.ProbeClusterInfo) {} -func (n *NullBWE) ProbingEnd(_isNotFailing bool, _isGoalReached bool) {} - -func (n *NullBWE) GetProbeStatus() (bool, ChannelTrend, int64, 
int64) { - return false, ChannelTrendNeutral, 0, 0 +func (n *NullBWE) ProbeClusterDone(_pci ccutils.ProbeClusterInfo) (bool, int64) { + return false, 0 } // ------------------------------------------------ diff --git a/pkg/sfu/bwe/remotebwe/channel_observer.go b/pkg/sfu/bwe/remotebwe/channel_observer.go index adf4daba2..856d65537 100644 --- a/pkg/sfu/bwe/remotebwe/channel_observer.go +++ b/pkg/sfu/bwe/remotebwe/channel_observer.go @@ -18,13 +18,34 @@ import ( "fmt" "time" - "github.com/livekit/livekit-server/pkg/sfu/bwe" "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/protocol/logger" ) // ------------------------------------------------ +type channelTrend int + +const ( + channelTrendNeutral channelTrend = iota + channelTrendClearing + channelTrendCongesting +) + +func (c channelTrend) String() string { + switch c { + case channelTrendNeutral: + return "NEUTRAL" + case channelTrendClearing: + return "CLEARING" + case channelTrendCongesting: + return "CONGESTING" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// ------------------------------------------------ type channelCongestionReason int const ( @@ -149,23 +170,23 @@ func (c *channelObserver) GetNackHistory() []string { } */ -func (c *channelObserver) GetTrend() (bwe.ChannelTrend, channelCongestionReason) { +func (c *channelObserver) GetTrend() (channelTrend, channelCongestionReason) { estimateDirection := c.estimateTrend.GetDirection() switch { case estimateDirection == ccutils.TrendDirectionDownward: c.logger.Debugw("remote bwe: channel observer: estimate is trending downward", "channel", c) - return bwe.ChannelTrendCongesting, channelCongestionReasonEstimate + return channelTrendCongesting, channelCongestionReasonEstimate case c.nackTracker.IsTriggered(): c.logger.Debugw("remote bwe: channel observer: high rate of repeated NACKs", "channel", c) - return bwe.ChannelTrendCongesting, channelCongestionReasonLoss + return channelTrendCongesting, channelCongestionReasonLoss case 
estimateDirection == ccutils.TrendDirectionUpward: - return bwe.ChannelTrendClearing, channelCongestionReasonNone + return channelTrendClearing, channelCongestionReasonNone } - return bwe.ChannelTrendNeutral, channelCongestionReasonNone + return channelTrendNeutral, channelCongestionReasonNone } func (c *channelObserver) String() string { diff --git a/pkg/sfu/bwe/remotebwe/remote_bwe.go b/pkg/sfu/bwe/remotebwe/remote_bwe.go index f9fdedc25..4721f5f4b 100644 --- a/pkg/sfu/bwe/remotebwe/remote_bwe.go +++ b/pkg/sfu/bwe/remotebwe/remote_bwe.go @@ -20,6 +20,7 @@ import ( "github.com/frostbyte73/core" "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils/mono" ) @@ -65,7 +66,6 @@ type RemoteBWE struct { lastReceivedEstimate int64 lastExpectedBandwidthUsage int64 - isInProbe bool committedChannelCapacity int64 channelObserver *channelObserver @@ -106,7 +106,7 @@ func (r *RemoteBWE) Reset() { defer r.lock.Unlock() r.channelObserver = r.newChannelObserverNonProbe() - r.isInProbe = false + r.updateCongestionState(bwe.CongestionStateNone, channelCongestionReasonNone) } func (r *RemoteBWE) Stop() { @@ -115,7 +115,6 @@ func (r *RemoteBWE) Stop() { func (r *RemoteBWE) HandleREMB( receivedEstimate int64, - isProbeFinalizing bool, expectedBandwidthUsage int64, sentPackets uint32, repeatedNacks uint32, @@ -124,19 +123,10 @@ func (r *RemoteBWE) HandleREMB( r.lastReceivedEstimate = receivedEstimate r.lastExpectedBandwidthUsage = expectedBandwidthUsage - if !isProbeFinalizing { - r.channelObserver.AddEstimate(r.lastReceivedEstimate) - r.channelObserver.AddNack(sentPackets, repeatedNacks) - } + r.channelObserver.AddEstimate(r.lastReceivedEstimate) + r.channelObserver.AddNack(sentPackets, repeatedNacks) - var ( - shouldNotify bool - state bwe.CongestionState - committedChannelCapacity int64 - ) - if !r.isInProbe { - shouldNotify, state, committedChannelCapacity = 
r.congestionDetectionStateMachine() - } + shouldNotify, state, committedChannelCapacity := r.congestionDetectionStateMachine() r.lock.Unlock() if shouldNotify { @@ -152,14 +142,14 @@ func (r *RemoteBWE) congestionDetectionStateMachine() (bool, bwe.CongestionState trend, reason := r.channelObserver.GetTrend() switch r.congestionState { case bwe.CongestionStateNone: - if trend == bwe.ChannelTrendCongesting { + if trend == channelTrendCongesting { if r.estimateAvailableChannelCapacity(reason) { newState = bwe.CongestionStateCongested } } case bwe.CongestionStateCongested: - if trend == bwe.ChannelTrendCongesting { + if trend == channelTrendCongesting { if r.estimateAvailableChannelCapacity(reason) { // update state sa this needs to reset switch time to wait for congestion min duration again update = true @@ -252,17 +242,16 @@ func (r *RemoteBWE) newChannelObserverNonProbe() *channelObserver { ) } -func (r *RemoteBWE) ProbingStart(expectedBandwidthUsage int64) { +func (r *RemoteBWE) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { r.lock.Lock() defer r.lock.Unlock() - r.isInProbe = true - r.lastExpectedBandwidthUsage = expectedBandwidthUsage + r.lastExpectedBandwidthUsage = int64(pci.Goal.ExpectedUsageBps) r.params.Logger.Debugw( "remote bwe: starting probe", "lastReceived", r.lastReceivedEstimate, - "expectedBandwidthUsage", expectedBandwidthUsage, + "expectedBandwidthUsage", r.lastExpectedBandwidthUsage, "channel", r.channelObserver, ) @@ -276,56 +265,21 @@ func (r *RemoteBWE) ProbingStart(expectedBandwidthUsage int64) { r.channelObserver.SeedEstimate(r.lastReceivedEstimate) } -func (r *RemoteBWE) ProbingEnd(isNotFailing bool, isGoalReached bool) { +func (r *RemoteBWE) ProbeClusterDone(_pci ccutils.ProbeClusterInfo) (bool, int64) { r.lock.Lock() defer r.lock.Unlock() - highestEstimateInProbe := r.channelObserver.GetHighestEstimate() - - // - // Reset estimator at the end of a probe irrespective of probe result to get fresh readings. 
- // With a failed probe, the latest estimate could be lower than committed estimate. - // As bandwidth estimator (remote in REMB case, local in TWCC case) holds state, - // subsequent estimates could start from the lower point. That should not trigger a - // downward trend and get latched to committed estimate as that would trigger a re-allocation. - // With fresh readings, as long as the trend is not going downward, it will not get latched. - // - // BWE-TODO: clean up this comment after implementing probing in TWCC case - // NOTE: With TWCC, it is possible to reset bandwidth estimation to clean state as - // the send side is in full control of bandwidth estimation. - // - r.params.Logger.Debugw( - "remote bwe: probe done", - "isNotFailing", isNotFailing, - "isGoalReached", isGoalReached, - "committedEstimate", r.committedChannelCapacity, - "highestEstimate", highestEstimateInProbe, - "channel", r.channelObserver, - ) + // switch to a non-probe channel observer on probe end + pco := r.channelObserver r.channelObserver = r.newChannelObserverNonProbe() - r.isInProbe = false - if !isNotFailing { - return + + if !pco.HasEnoughEstimateSamples() { + // cannot decide success/failure without enough data + return false, pco.GetHighestEstimate() } - if highestEstimateInProbe > r.committedChannelCapacity { - r.committedChannelCapacity = highestEstimateInProbe - } -} - -func (r *RemoteBWE) GetProbeStatus() (bool, bwe.ChannelTrend, int64, int64) { - r.lock.RLock() - defer r.lock.RUnlock() - - if !r.isInProbe { - return false, bwe.ChannelTrendNeutral, 0, 0 - } - - trend, _ := r.channelObserver.GetTrend() - return r.channelObserver.HasEnoughEstimateSamples(), - trend, - r.channelObserver.GetLowestEstimate(), - r.channelObserver.GetHighestEstimate() + trend, _ := pco.GetTrend() + return trend == channelTrendClearing, pco.GetHighestEstimate() } func (r *RemoteBWE) worker() { @@ -345,13 +299,8 @@ func (r *RemoteBWE) worker() { } case <-ticker.C: - var ( - shouldNotify bool - state 
bwe.CongestionState - committedChannelCapacity int64 - ) r.lock.Lock() - shouldNotify, state, committedChannelCapacity = r.congestionDetectionStateMachine() + shouldNotify, state, committedChannelCapacity := r.congestionDetectionStateMachine() r.lock.Unlock() if shouldNotify { diff --git a/pkg/sfu/ccutils/prober.go b/pkg/sfu/ccutils/prober.go index 858db8f14..88e127922 100644 --- a/pkg/sfu/ccutils/prober.go +++ b/pkg/sfu/ccutils/prober.go @@ -130,11 +130,12 @@ import ( "go.uber.org/zap/zapcore" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" ) type ProberListener interface { + OnProbeClusterSwitch(info ProbeClusterInfo) OnSendProbe(bytesToSend int) - OnProbeClusterSwitch(probeClusterId ProbeClusterId, desiredBytes int) } type ProberParams struct { @@ -171,8 +172,8 @@ func (p *Prober) Reset(info ProbeClusterInfo) { p.clustersMu.Lock() defer p.clustersMu.Unlock() - if p.activeCluster != nil && p.activeCluster.Id() == info.ProbeClusterId { - p.activeCluster.MarkCompleted(info) + if p.activeCluster != nil && p.activeCluster.Id() == info.Id { + p.activeCluster.MarkCompleted(info.Result) p.params.Logger.Debugw("prober: resetting active cluster", "cluster", p.activeCluster) } @@ -180,30 +181,18 @@ func (p *Prober) Reset(info ProbeClusterInfo) { p.activeCluster = nil } -func (p *Prober) AddCluster( - mode ProbeClusterMode, - desiredRateBps int, - expectedRateBps int, - duration time.Duration, -) ProbeClusterId { - if desiredRateBps <= 0 { - return ProbeClusterIdInvalid +func (p *Prober) AddCluster(mode ProbeClusterMode, pcg ProbeClusterGoal) ProbeClusterInfo { + if pcg.DesiredBps <= 0 { + return ProbeClusterInfoInvalid } clusterId := ProbeClusterId(p.clusterId.Inc()) - cluster := newCluster( - clusterId, - mode, - desiredRateBps, - expectedRateBps, - duration, - p.params.Listener, - ) + cluster := newCluster(clusterId, mode, pcg, p.params.Listener) p.params.Logger.Debugw("cluster added", "cluster", cluster) 
p.pushBackClusterAndMaybeStart(cluster) - return clusterId + return cluster.Info() } func (p *Prober) ClusterDone(info ProbeClusterInfo) { @@ -212,8 +201,8 @@ func (p *Prober) ClusterDone(info ProbeClusterInfo) { return } - if cluster.Id() == info.ProbeClusterId { - cluster.MarkCompleted(info) + if cluster.Id() == info.Id { + cluster.MarkCompleted(info.Result) p.params.Logger.Debugw("cluster done", "cluster", cluster) p.popFrontCluster(cluster) } @@ -329,31 +318,41 @@ func (p ProbeClusterMode) String() string { // --------------------------------------------------------------------------- -type ProbeClusterInfo struct { - ProbeClusterId ProbeClusterId - DesiredBytes int +type ProbeClusterGoal struct { + AvailableBandwidthBps int + ExpectedUsageBps int + DesiredBps int + Duration time.Duration + DesiredBytes int +} + +func (p ProbeClusterGoal) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddInt("AvailableBandwidthBps", p.AvailableBandwidthBps) + e.AddInt("ExpectedUsageBps", p.ExpectedUsageBps) + e.AddInt("DesiredBps", p.DesiredBps) + e.AddDuration("Duration", p.Duration) + e.AddInt("DesiredBytes", p.DesiredBytes) + return nil +} + +type ProbeClusterResult struct { StartTime int64 EndTime int64 BytesProbe int BytesNonProbePrimary int BytesNonProbeRTX int + IsCompleted bool } -var ( - ProbeClusterInfoInvalid = ProbeClusterInfo{ProbeClusterId: ProbeClusterIdInvalid} -) - -func (p ProbeClusterInfo) Bytes() int { +func (p ProbeClusterResult) Bytes() int { return p.BytesProbe + p.BytesNonProbePrimary + p.BytesNonProbeRTX } -func (p ProbeClusterInfo) Duration() time.Duration { +func (p ProbeClusterResult) Duration() time.Duration { return time.Duration(p.EndTime - p.StartTime) } -func (p ProbeClusterInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { - e.AddUint32("ProbeClusterId", uint32(p.ProbeClusterId)) - e.AddInt("DesiredBytes", p.DesiredBytes) +func (p ProbeClusterResult) MarshalLogObject(e zapcore.ObjectEncoder) error { e.AddTime("StartTime", 
time.Unix(0, p.StartTime)) e.AddTime("EndTime", time.Unix(0, p.EndTime)) e.AddDuration("Duration", p.Duration()) @@ -361,21 +360,26 @@ func (p ProbeClusterInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { e.AddInt("BytesNonProbePrimary", p.BytesNonProbePrimary) e.AddInt("BytesNonProbeRTX", p.BytesNonProbeRTX) e.AddInt("Bytes", p.Bytes()) + e.AddBool("IsCompleted", p.IsCompleted) return nil } -// --------------------------------------------------------------------------- - -type clusterBucket struct { - desiredNumProbes int - desiredBytes int - sleepDuration time.Duration +type ProbeClusterInfo struct { + Id ProbeClusterId + CreatedAt time.Time + Goal ProbeClusterGoal + Result ProbeClusterResult } -func (c clusterBucket) MarshalLogObject(e zapcore.ObjectEncoder) error { - e.AddInt("desiredNumProbes", c.desiredNumProbes) - e.AddInt("desiredBytes", c.desiredBytes) - e.AddDuration("sleepDuration", c.sleepDuration) +var ( + ProbeClusterInfoInvalid = ProbeClusterInfo{Id: ProbeClusterIdInvalid} +) + +func (p ProbeClusterInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint32("Id", uint32(p.Id)) + e.AddTime("CreatedAt", p.CreatedAt) + e.AddObject("Goal", p.Goal) + e.AddObject("Result", p.Result) return nil } @@ -384,42 +388,32 @@ func (c clusterBucket) MarshalLogObject(e zapcore.ObjectEncoder) error { type Cluster struct { lock sync.RWMutex - id ProbeClusterId - mode ProbeClusterMode - desiredRateBps int - expectedRateBps int - duration time.Duration - listener ProberListener + info ProbeClusterInfo + mode ProbeClusterMode + listener ProberListener probeSleeps []time.Duration probeIdx int - isComplete bool - probeClusterInfo ProbeClusterInfo + isComplete bool } -func newCluster( - id ProbeClusterId, - mode ProbeClusterMode, - desiredRateBps int, - expectedRateBps int, - duration time.Duration, - listener ProberListener, -) *Cluster { +func newCluster(id ProbeClusterId, mode ProbeClusterMode, pcg ProbeClusterGoal, listener ProberListener) *Cluster { c := 
&Cluster{ - id: id, - mode: mode, - desiredRateBps: desiredRateBps, - expectedRateBps: expectedRateBps, - duration: duration, - listener: listener, + mode: mode, + info: ProbeClusterInfo{ + Id: id, + CreatedAt: mono.Now(), + Goal: pcg, + }, + listener: listener, } c.initProbes() return c } func (c *Cluster) initProbes() { - numProbeBytes := int(math.Round(float64(c.desiredRateBps-c.expectedRateBps)*c.duration.Seconds()/8 + 0.5)) + numProbeBytes := int(math.Round(float64(c.info.Goal.DesiredBps-c.info.Goal.ExpectedUsageBps)*c.info.Goal.Duration.Seconds()/8 + 0.5)) numProbes := (numProbeBytes + cBytesPerProbe - 1) / cBytesPerProbe if numProbes < 1 { numProbes = 1 @@ -428,36 +422,45 @@ func (c *Cluster) initProbes() { c.probeSleeps = make([]time.Duration, numProbes) switch c.mode { case ProbeClusterModeUniform: - interval := c.duration / time.Duration(numProbes) + interval := c.info.Goal.Duration / time.Duration(numProbes) for i := 0; i < numProbes; i++ { c.probeSleeps[i] = interval } case ProbeClusterModeLinearChirp: numIntervals := numProbes * (numProbes + 1) / 2 - interval := c.duration / time.Duration(numIntervals) + interval := c.info.Goal.Duration / time.Duration(numIntervals) for i := 0; i < numProbes; i++ { c.probeSleeps[i] = time.Duration(numProbes-i) * interval } } + + c.info.Goal.DesiredBytes = int(math.Round(float64(c.info.Goal.DesiredBps)*c.info.Goal.Duration.Seconds()/8 + 0.5)) } func (c *Cluster) Start() { if c.listener != nil { - c.listener.OnProbeClusterSwitch(c.id, int(math.Round(float64(c.desiredRateBps)*c.duration.Seconds()/8+0.5))) + c.listener.OnProbeClusterSwitch(c.info) } } func (c *Cluster) Id() ProbeClusterId { - return c.id + return c.info.Id } -func (c *Cluster) MarkCompleted(info ProbeClusterInfo) { +func (c *Cluster) Info() ProbeClusterInfo { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.info +} + +func (c *Cluster) MarkCompleted(result ProbeClusterResult) { c.lock.Lock() defer c.lock.Unlock() c.isComplete = true - 
c.probeClusterInfo = info + c.info.Result = result } func (c *Cluster) Process() time.Duration { @@ -484,16 +487,12 @@ func (c *Cluster) Process() time.Duration { func (c *Cluster) MarshalLogObject(e zapcore.ObjectEncoder) error { if c != nil { - e.AddUint32("id", uint32(c.id)) e.AddString("mode", c.mode.String()) - e.AddInt("desiredRateBps", c.desiredRateBps) - e.AddInt("expectedRateBps", c.expectedRateBps) - e.AddDuration("duration", c.duration) + e.AddObject("info", c.info) e.AddInt("numProbes", len(c.probeSleeps)) e.AddArray("probeSleeps", logger.DurationSlice(c.probeSleeps)) e.AddInt("probeIdx", c.probeIdx) e.AddBool("isComplete", c.isComplete) - e.AddObject("probeClusterInfo", c.probeClusterInfo) } return nil } diff --git a/pkg/sfu/pacer/pacer.go b/pkg/sfu/pacer/pacer.go index c68bbf6e7..f72886df3 100644 --- a/pkg/sfu/pacer/pacer.go +++ b/pkg/sfu/pacer/pacer.go @@ -45,7 +45,7 @@ type Pacer interface { SetBitrate(bitrate int) SetPacerProbeObserverListener(listener PacerProbeObserverListener) - StartProbeCluster(probeClusterId ccutils.ProbeClusterId, desiredBytes int) + StartProbeCluster(pci ccutils.ProbeClusterInfo) EndProbeCluster(probeClusterId ccutils.ProbeClusterId) ccutils.ProbeClusterInfo } diff --git a/pkg/sfu/pacer/probe_observer.go b/pkg/sfu/pacer/probe_observer.go index fd3467faf..4fb9a568f 100644 --- a/pkg/sfu/pacer/probe_observer.go +++ b/pkg/sfu/pacer/probe_observer.go @@ -30,14 +30,8 @@ type ProbeObserver struct { isInProbe atomic.Bool - lock sync.Mutex - clusterStartTime int64 - activeProbeClusterId ccutils.ProbeClusterId - desiredProbeClusterBytes int - bytesNonProbePrimary int - bytesNonProbeRTX int - bytesProbe int - isActiveClusterDone bool + lock sync.Mutex + pci ccutils.ProbeClusterInfo } func NewProbeObserver(logger logger.Logger) *ProbeObserver { @@ -50,12 +44,11 @@ func (po *ProbeObserver) SetPacerProbeObserverListener(listener PacerProbeObserv po.listener = listener } -func (po *ProbeObserver) StartProbeCluster(probeClusterId 
ccutils.ProbeClusterId, desiredBytes int) { +func (po *ProbeObserver) StartProbeCluster(pci ccutils.ProbeClusterInfo) { if po.isInProbe.Load() { po.logger.Warnw( "ignoring start of a new probe cluster when already active", nil, - "probeClusterId", probeClusterId, - "desiredBytes", desiredBytes, + "probeClusterInfo", pci, ) return } @@ -63,13 +56,10 @@ func (po *ProbeObserver) StartProbeCluster(probeClusterId ccutils.ProbeClusterId po.lock.Lock() defer po.lock.Unlock() - po.clusterStartTime = mono.UnixNano() - po.activeProbeClusterId = probeClusterId - po.desiredProbeClusterBytes = desiredBytes - po.bytesNonProbePrimary = 0 - po.bytesNonProbeRTX = 0 - po.bytesProbe = 0 - po.isActiveClusterDone = false + po.pci = pci + po.pci.Result = ccutils.ProbeClusterResult{ + StartTime: mono.UnixNano(), + } po.isInProbe.Store(true) } @@ -89,30 +79,23 @@ func (po *ProbeObserver) EndProbeCluster(probeClusterId ccutils.ProbeClusterId) po.lock.Lock() defer po.lock.Unlock() - if po.activeProbeClusterId != probeClusterId { + if po.pci.Id != probeClusterId { // probe cluster id not active po.logger.Warnw( "ignoring end of a probe cluster of a non-active one", nil, "probeClusterId", probeClusterId, - "active", po.activeProbeClusterId, + "active", po.pci.Id, ) return ccutils.ProbeClusterInfoInvalid } - clusterInfo := ccutils.ProbeClusterInfo{ - ProbeClusterId: po.activeProbeClusterId, - DesiredBytes: po.desiredProbeClusterBytes, - StartTime: po.clusterStartTime, - EndTime: mono.UnixNano(), - BytesProbe: po.bytesProbe, - BytesNonProbePrimary: po.bytesNonProbePrimary, - BytesNonProbeRTX: po.bytesNonProbeRTX, + if po.pci.Result.EndTime == 0 { + po.pci.Result.EndTime = mono.UnixNano() } - po.activeProbeClusterId = ccutils.ProbeClusterIdInvalid po.isInProbe.Store(false) - return clusterInfo + return po.pci } func (po *ProbeObserver) RecordPacket(size int, isRTX bool, probeClusterId ccutils.ProbeClusterId, isProbe bool) { @@ -121,28 +104,29 @@ func (po *ProbeObserver) RecordPacket(size int, 
isRTX bool, probeClusterId ccuti } po.lock.Lock() - if probeClusterId != po.activeProbeClusterId || po.isActiveClusterDone { + if probeClusterId != po.pci.Id || po.pci.Result.EndTime != 0 { po.lock.Unlock() return } if isProbe { - po.bytesProbe += size + po.pci.Result.BytesProbe += size } else { if isRTX { - po.bytesNonProbeRTX += size + po.pci.Result.BytesNonProbeRTX += size } else { - po.bytesNonProbePrimary += size + po.pci.Result.BytesNonProbePrimary += size } } notify := false var clusterId ccutils.ProbeClusterId - if !po.isActiveClusterDone && po.bytesProbe+po.bytesNonProbePrimary+po.bytesNonProbeRTX >= po.desiredProbeClusterBytes { - po.isActiveClusterDone = true + if po.pci.Result.EndTime == 0 && po.pci.Result.Bytes() >= po.pci.Goal.DesiredBytes { + po.pci.Result.EndTime = mono.UnixNano() + po.pci.Result.IsCompleted = true notify = true - clusterId = po.activeProbeClusterId + clusterId = po.pci.Id } po.lock.Unlock() diff --git a/pkg/sfu/streamallocator/probe_controller.go b/pkg/sfu/streamallocator/probe_controller.go index b383d4326..c57d74e2e 100644 --- a/pkg/sfu/streamallocator/probe_controller.go +++ b/pkg/sfu/streamallocator/probe_controller.go @@ -15,6 +15,7 @@ package streamallocator import ( + "fmt" "sync" "time" @@ -22,22 +23,51 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/livekit-server/pkg/sfu/pacer" "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +const ( + cDefaultRTT = float64(0.070) // 70 ms + cRTTSmoothingFactor = float64(0.5) ) // --------------------------------------------------------------------------- +type ProbeControllerState int + +const ( + ProbeControllerStateNone ProbeControllerState = iota + ProbeControllerStateProbing + ProbeControllerStateHangover +) + +func (p ProbeControllerState) String() string { + switch p { + case ProbeControllerStateNone: + return "NONE" + case ProbeControllerStateProbing: + return "PROBING" + case ProbeControllerStateHangover: + 
return "HANGOVER" + default: + return fmt.Sprintf("%d", int(p)) + } +} + +// ------------------------------------------------ + type ProbeControllerConfig struct { BaseInterval time.Duration `yaml:"base_interval,omitempty"` BackoffFactor float64 `yaml:"backoff_factor,omitempty"` MaxInterval time.Duration `yaml:"max_interval,omitempty"` - SettleWait time.Duration `yaml:"settle_wait,omitempty"` - SettleWaitMax time.Duration `yaml:"settle_wait_max,omitempty"` + SettleWaitNumRTT uint32 `yaml:"settle_wait_num_rtt,omitempty"` + SettleWaitMin time.Duration `yaml:"settle_wait_min,omitempty"` + SettleWaitMax time.Duration `yaml:"settle_wait_max,omitempty"` - TrendWait time.Duration `yaml:"trend_wait,omitempty"` + OveragePct int64 `yaml:"overage_pct,omitempty"` + MinBps int64 `yaml:"min_bps,omitempty"` - OveragePct int64 `yaml:"overage_pct,omitempty"` - MinBps int64 `yaml:"min_bps,omitempty"` MinDuration time.Duration `yaml:"min_duration,omitempty"` MaxDuration time.Duration `yaml:"max_duration,omitempty"` DurationIncreaseFactor float64 `yaml:"duration_increase_factor,omitempty"` @@ -49,13 +79,13 @@ var ( BackoffFactor: 1.5, MaxInterval: 2 * time.Minute, - SettleWait: 250 * time.Millisecond, - SettleWaitMax: 10 * time.Second, + SettleWaitNumRTT: 10, + SettleWaitMin: 500 * time.Millisecond, + SettleWaitMax: 10 * time.Second, - TrendWait: 2 * time.Second, + OveragePct: 120, + MinBps: 200_000, - OveragePct: 120, - MinBps: 200_000, MinDuration: 200 * time.Millisecond, MaxDuration: 20 * time.Second, DurationIncreaseFactor: 1.5, @@ -75,23 +105,23 @@ type ProbeControllerParams struct { type ProbeController struct { params ProbeControllerParams - lock sync.RWMutex - probeInterval time.Duration - lastProbeStartTime time.Time - probeGoalBps int64 - probeClusterId ccutils.ProbeClusterId - doneProbeClusterInfo ccutils.ProbeClusterInfo - abortedProbeClusterId ccutils.ProbeClusterId - goalReachedProbeClusterId ccutils.ProbeClusterId - probeTrendObserved bool - probeEndTime time.Time - 
probeDuration time.Duration + lock sync.RWMutex + + state ProbeControllerState + stateSwitchedAt time.Time + + pci ccutils.ProbeClusterInfo + rtt float64 + + probeInterval time.Duration + probeDuration time.Duration + nextProbeEarliestAt time.Time } func NewProbeController(params ProbeControllerParams) *ProbeController { p := &ProbeController{ - params: params, - probeDuration: params.Config.MinDuration, + params: params, + rtt: cDefaultRTT, } p.Reset() @@ -102,260 +132,157 @@ func (p *ProbeController) Reset() { p.lock.Lock() defer p.lock.Unlock() - p.lastProbeStartTime = time.Now() - - p.resetProbeIntervalLocked() - p.resetProbeDurationLocked() - - p.StopProbe() - p.clearProbeLocked() -} - -func (p *ProbeController) ProbeClusterDone(probeClusterId ccutils.ProbeClusterId) { - p.lock.Lock() - defer p.lock.Unlock() - - if p.probeClusterId != probeClusterId { - p.params.Logger.Debugw("not expected probe cluster", "probeClusterId", p.probeClusterId, "resetProbeClusterId", probeClusterId) - } else { - p.doneProbeClusterInfo = p.params.Pacer.EndProbeCluster(probeClusterId) - p.params.Prober.ClusterDone(p.doneProbeClusterInfo) - } -} - -func (p *ProbeController) MaybeFinalizeProbe( - isComplete bool, - trend bwe.ChannelTrend, - lowestEstimate int64, -) (isHandled bool, isNotFailing bool, isGoalReached bool) { - p.lock.Lock() - defer p.lock.Unlock() - - if !p.isInProbeLocked() { - return false, false, false - } - - if p.goalReachedProbeClusterId != ccutils.ProbeClusterIdInvalid { - // finalise goal reached probe cluster - p.finalizeProbeLocked(bwe.ChannelTrendNeutral) - return true, true, true - } - - if (isComplete || p.abortedProbeClusterId != ccutils.ProbeClusterIdInvalid) && - p.probeEndTime.IsZero() && - p.doneProbeClusterInfo.ProbeClusterId != ccutils.ProbeClusterIdInvalid && p.doneProbeClusterInfo.ProbeClusterId == p.probeClusterId { - // ensure any queueing due to probing is flushed - // STREAM-ALLOCATOR-TODO: ProbeControllerConfig.SettleWait should actually be a 
certain number of RTTs. - expectedDuration := float64(0.0) - if lowestEstimate != 0 { - expectedDuration = float64(p.doneProbeClusterInfo.Bytes()*8*1000) / float64(lowestEstimate) - } - queueTime := expectedDuration - float64(p.doneProbeClusterInfo.Duration().Milliseconds()) - if queueTime < 0.0 { - queueTime = 0.0 - } - queueWait := (time.Duration(queueTime) * time.Millisecond) + p.params.Config.SettleWait - if queueWait > p.params.Config.SettleWaitMax { - queueWait = p.params.Config.SettleWaitMax - } - p.probeEndTime = p.lastProbeStartTime.Add(queueWait + p.doneProbeClusterInfo.Duration()) - p.params.Logger.Debugw( - "setting probe end time", - "probeClusterId", p.probeClusterId, - "expectedDuration", expectedDuration, - "queueTime", queueTime, - "queueWait", queueWait, - "probeEndTime", p.probeEndTime, - ) - } - - if !p.probeEndTime.IsZero() && time.Now().After(p.probeEndTime) { - // finalize aborted or non-failing but non-goal-reached probe cluster - return true, p.finalizeProbeLocked(trend), false - } - - return false, false, false -} - -func (p *ProbeController) DoesProbeNeedFinalize() bool { - p.lock.RLock() - defer p.lock.RUnlock() - - return p.abortedProbeClusterId != ccutils.ProbeClusterIdInvalid || p.goalReachedProbeClusterId != ccutils.ProbeClusterIdInvalid -} - -func (p *ProbeController) finalizeProbeLocked(trend bwe.ChannelTrend) (isNotFailing bool) { - aborted := p.probeClusterId == p.abortedProbeClusterId - - p.clearProbeLocked() - - if aborted || trend == bwe.ChannelTrendCongesting { - // failed probe, backoff - p.backoffProbeIntervalLocked() - p.resetProbeDurationLocked() - return false - } - - // reset probe interval and increase probe duration on a upward trending probe - p.resetProbeIntervalLocked() - if trend == bwe.ChannelTrendClearing { - p.increaseProbeDurationLocked() - } - return true -} - -func (p *ProbeController) InitProbe(probeGoalDeltaBps int64, expectedBandwidthUsage int64) (ccutils.ProbeClusterId, int64) { - p.lock.Lock() - defer 
p.lock.Unlock() - - p.lastProbeStartTime = time.Now() - - // overshoot a bit to account for noise (in measurement/estimate etc) - desiredIncreaseBps := (probeGoalDeltaBps * p.params.Config.OveragePct) / 100 - if desiredIncreaseBps < p.params.Config.MinBps { - desiredIncreaseBps = p.params.Config.MinBps - } - p.probeGoalBps = expectedBandwidthUsage + desiredIncreaseBps - - p.doneProbeClusterInfo = ccutils.ProbeClusterInfoInvalid - p.abortedProbeClusterId = ccutils.ProbeClusterIdInvalid - p.goalReachedProbeClusterId = ccutils.ProbeClusterIdInvalid - - p.probeTrendObserved = false - - p.probeEndTime = time.Time{} - - p.probeClusterId = p.params.Prober.AddCluster( - ccutils.ProbeClusterModeUniform, - int(p.probeGoalBps), - int(expectedBandwidthUsage), - p.probeDuration, - ) - - p.pollProbe(p.probeClusterId, expectedBandwidthUsage) - - return p.probeClusterId, p.probeGoalBps -} - -func (p *ProbeController) pollProbe(probeClusterId ccutils.ProbeClusterId, expectedBandwidthUsage int64) { - p.params.BWE.ProbingStart(expectedBandwidthUsage) - - go func() { - for { - p.lock.Lock() - if p.probeClusterId != probeClusterId { - p.lock.Unlock() - return - } - - done := false - - _, trend, _, highestEstimate := p.params.BWE.GetProbeStatus() - if !p.probeTrendObserved && trend != bwe.ChannelTrendNeutral { - p.probeTrendObserved = true - } - - switch { - case trend == bwe.ChannelTrendCongesting: - // stop immediately if the probe is congesting channel more - p.params.Logger.Infow( - "stream allocator: probe: aborting, channel is congesting", - "cluster", probeClusterId, - ) - p.abortProbeLocked() - done = true - break - - case highestEstimate > p.probeGoalBps: - // reached goal, stop probing - p.params.Logger.Infow( - "stream allocator: probe: stopping, goal reached", - "cluster", probeClusterId, - "goal", p.probeGoalBps, - "highestEstimate", highestEstimate, - ) - p.goalReachedProbeClusterId = p.probeClusterId - p.StopProbe() - done = true - break - - case !p.probeTrendObserved && 
time.Since(p.lastProbeStartTime) > p.params.Config.TrendWait: - // - // More of a safety net. - // In rare cases, the estimate gets stuck. Prevent from probe running amok - // STREAM-ALLOCATOR-TODO: Need more testing here to ensure that probe does not cause a lot of damage - // - p.params.Logger.Infow("stream allocator: probe: aborting, no trend", "cluster", probeClusterId) - p.abortProbeLocked() - done = true - break - } - p.lock.Unlock() - - if done { - return - } - - // BWE-TODO: do not hard code sleep time - time.Sleep(50 * time.Millisecond) - } - }() -} - -func (p *ProbeController) clearProbeLocked() { - p.probeClusterId = ccutils.ProbeClusterIdInvalid - p.doneProbeClusterInfo = ccutils.ProbeClusterInfoInvalid - p.abortedProbeClusterId = ccutils.ProbeClusterIdInvalid - p.goalReachedProbeClusterId = ccutils.ProbeClusterIdInvalid -} - -func (p *ProbeController) backoffProbeIntervalLocked() { - p.probeInterval = time.Duration(p.probeInterval.Seconds()*p.params.Config.BackoffFactor) * time.Second - if p.probeInterval > p.params.Config.MaxInterval { - p.probeInterval = p.params.Config.MaxInterval - } -} - -func (p *ProbeController) resetProbeIntervalLocked() { + p.state = ProbeControllerStateNone + p.stateSwitchedAt = mono.Now() + p.pci = ccutils.ProbeClusterInfoInvalid p.probeInterval = p.params.Config.BaseInterval -} - -func (p *ProbeController) resetProbeDurationLocked() { p.probeDuration = p.params.Config.MinDuration } -func (p *ProbeController) increaseProbeDurationLocked() { - p.probeDuration = time.Duration(float64(p.probeDuration.Milliseconds())*p.params.Config.DurationIncreaseFactor) * time.Millisecond - if p.probeDuration > p.params.Config.MaxDuration { - p.probeDuration = p.params.Config.MaxDuration +func (p *ProbeController) UpdateRTT(rtt float64) { + if rtt == 0 { + p.rtt = cDefaultRTT + } else { + if p.rtt == 0 { + p.rtt = rtt + } else { + p.rtt = cRTTSmoothingFactor*rtt + (1.0-cRTTSmoothingFactor)*p.rtt + } } } -func (p *ProbeController) StopProbe() 
{ - p.params.Prober.Reset(p.params.Pacer.EndProbeCluster(p.probeClusterId)) -} - -func (p *ProbeController) AbortProbe() { - p.lock.Lock() - defer p.lock.Unlock() - - p.abortProbeLocked() -} - -func (p *ProbeController) abortProbeLocked() { - p.abortedProbeClusterId = p.probeClusterId - p.StopProbe() -} - -func (p *ProbeController) isInProbeLocked() bool { - return p.probeClusterId != ccutils.ProbeClusterIdInvalid -} - func (p *ProbeController) CanProbe() bool { p.lock.RLock() defer p.lock.RUnlock() - return time.Since(p.lastProbeStartTime) >= p.probeInterval && p.probeClusterId == ccutils.ProbeClusterIdInvalid + return p.state == ProbeControllerStateNone && mono.Now().After(p.nextProbeEarliestAt) +} + +func (p *ProbeController) MaybeInitiateProbe(availableBandwidthBps int64, probeGoalDeltaBps int64, expectedBandwidthUsage int64) (ccutils.ProbeClusterGoal, bool) { + p.lock.RLock() + defer p.lock.RUnlock() + + if p.state != ProbeControllerStateNone { + // already probing or in probe hangover, don't start a new one + return ccutils.ProbeClusterGoal{}, false + } + + if mono.Now().Before(p.nextProbeEarliestAt) { + return ccutils.ProbeClusterGoal{}, false + } + + // overshoot a bit to account for noise (in measurement/estimate etc) + desiredIncreaseBps := (probeGoalDeltaBps * p.params.Config.OveragePct) / 100 + if desiredIncreaseBps < p.params.Config.MinBps { + desiredIncreaseBps = p.params.Config.MinBps + } + return ccutils.ProbeClusterGoal{ + AvailableBandwidthBps: int(availableBandwidthBps), + ExpectedUsageBps: int(expectedBandwidthUsage), + DesiredBps: int(expectedBandwidthUsage + desiredIncreaseBps), + Duration: p.probeDuration, + }, true +} + +func (p *ProbeController) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { + p.lock.Lock() + defer p.lock.Unlock() + + if p.state != ProbeControllerStateNone { + p.params.Logger.Warnw("unexpected probe controller state", nil, "state", p.state) + } + + p.setState(ProbeControllerStateProbing) + p.pci = pci +} + +func (p 
*ProbeController) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + p.lock.Lock() + defer p.lock.Unlock() + + if p.state != ProbeControllerStateProbing { + p.params.Logger.Warnw("unexpected probe controller state", nil, "state", p.state) + } + + if p.pci.Id != pci.Id { + p.params.Logger.Warnw("not expected probe cluster", nil, "expectedId", p.pci.Id, "doneId", pci.Id) + } + + p.pci.Result = pci.Result + p.params.Prober.ClusterDone(pci) + + p.setState(ProbeControllerStateHangover) +} + +func (p *ProbeController) MaybeFinalizeProbe() (ccutils.ProbeClusterInfo, bool) { + p.lock.Lock() + defer p.lock.Unlock() + + if p.state != ProbeControllerStateHangover { + return ccutils.ProbeClusterInfoInvalid, false + } + + settleWait := time.Duration(float64(p.params.Config.SettleWaitNumRTT) * p.rtt * float64(time.Second)) + if settleWait < p.params.Config.SettleWaitMin { + settleWait = p.params.Config.SettleWaitMin + } + if settleWait > p.params.Config.SettleWaitMax { + settleWait = p.params.Config.SettleWaitMax + } + if time.Since(p.stateSwitchedAt) < settleWait { + return ccutils.ProbeClusterInfoInvalid, false + } + + p.setState(ProbeControllerStateNone) + return p.pci, true +} + +func (p *ProbeController) ProbeCongestionSignal(isCongestionClearing bool) { + if !isCongestionClearing { + // wait longer till next probe + p.probeInterval = time.Duration(p.probeInterval.Seconds()*p.params.Config.BackoffFactor) * time.Second + if p.probeInterval > p.params.Config.MaxInterval { + p.probeInterval = p.params.Config.MaxInterval + } + + // revert back to starting with shortest probe + p.probeDuration = p.params.Config.MinDuration + } else { + // probe can be started again after minimal interval as previous congestion signal indicated congestion clearing + p.probeInterval = p.params.Config.BaseInterval + + // can do longer probe after a good probe + p.probeDuration = time.Duration(float64(p.probeDuration.Milliseconds())*p.params.Config.DurationIncreaseFactor) * time.Millisecond + if 
p.probeDuration > p.params.Config.MaxDuration { + p.probeDuration = p.params.Config.MaxDuration + } + } + + if p.pci.CreatedAt.IsZero() { + p.nextProbeEarliestAt = mono.Now().Add(p.probeInterval) + } else { + p.nextProbeEarliestAt = p.pci.CreatedAt.Add(p.probeInterval) + } +} + +func (p *ProbeController) GetActiveProbeClusterId() ccutils.ProbeClusterId { + p.lock.RLock() + defer p.lock.RUnlock() + + if p.state == ProbeControllerStateNone { + return ccutils.ProbeClusterIdInvalid + } + + return p.pci.Id +} + +func (p *ProbeController) setState(state ProbeControllerState) { + if state == p.state { + return + } + + p.state = state + p.stateSwitchedAt = mono.Now() } // ------------------------------------------------ diff --git a/pkg/sfu/streamallocator/streamallocator.go b/pkg/sfu/streamallocator/streamallocator.go index 261ecbb8a..67adcf59d 100644 --- a/pkg/sfu/streamallocator/streamallocator.go +++ b/pkg/sfu/streamallocator/streamallocator.go @@ -50,6 +50,8 @@ const ( FlagAllowOvershootInProbe = true FlagAllowOvershootInCatchup = false FlagAllowOvershootInBoost = true + + cRTTPullInterval = 30 * time.Second ) // --------------------------------------------------------------------------- @@ -82,14 +84,15 @@ const ( streamAllocatorSignalAdjustState streamAllocatorSignalEstimate streamAllocatorSignalPeriodicPing + streamAllocatorSignalProbeClusterSwitch streamAllocatorSignalSendProbe + streamAllocatorSignalPacerProbeObserverClusterComplete streamAllocatorSignalResume streamAllocatorSignalSetAllowPause streamAllocatorSignalSetChannelCapacity // STREAM-ALLOCATOR-DATA streamAllocatorSignalNACK // STREAM-ALLOCATOR-DATA streamAllocatorSignalRTCPReceiverReport streamAllocatorSignalCongestionStateChange - streamAllocatorSignalPacerProbeObserverClusterComplete ) func (s streamAllocatorSignal) String() string { @@ -104,8 +107,12 @@ func (s streamAllocatorSignal) String() string { return "ESTIMATE" case streamAllocatorSignalPeriodicPing: return "PERIODIC_PING" + case 
streamAllocatorSignalProbeClusterSwitch: + return "PROBE_CLUSTER_SWITCH" case streamAllocatorSignalSendProbe: return "SEND_PROBE" + case streamAllocatorSignalPacerProbeObserverClusterComplete: + return "PACER_PROBE_OBSERVER_CLUSTER_COMPLETE" case streamAllocatorSignalResume: return "RESUME" case streamAllocatorSignalSetAllowPause: @@ -120,8 +127,6 @@ func (s streamAllocatorSignal) String() string { */ case streamAllocatorSignalCongestionStateChange: return "CONGESTION_STATE_CHANGE" - case streamAllocatorSignalPacerProbeObserverClusterComplete: - return "PACER_PROBE_OBSERVER_CLUSTER_COMPLETE" default: return fmt.Sprintf("%d", int(s)) } @@ -168,10 +173,11 @@ var ( // --------------------------------------------------------------------------- type StreamAllocatorParams struct { - Config StreamAllocatorConfig - BWE bwe.BWE - Pacer pacer.Pacer - Logger logger.Logger + Config StreamAllocatorConfig + BWE bwe.BWE + Pacer pacer.Pacer + RTTGetter func() (float64, bool) + Logger logger.Logger } type StreamAllocator struct { @@ -188,8 +194,7 @@ type StreamAllocator struct { overriddenChannelCapacity int64 probeController *ProbeController - - prober *ccutils.Prober + prober *ccutils.Prober // STREAM-ALLOCATOR-DATA rateMonitor *RateMonitor @@ -204,6 +209,8 @@ type StreamAllocator struct { eventsQueue *utils.TypedOpsQueue[Event] + lastRTTTime time.Time + isStopped atomic.Bool } @@ -213,7 +220,9 @@ func NewStreamAllocator(params StreamAllocatorParams, enabled bool, allowPause b enabled: enabled, allowPause: allowPause, // STREAM-ALLOCATOR-DATA rateMonitor: NewRateMonitor(), - videoTracks: make(map[livekit.TrackID]*Track), + videoTracks: make(map[livekit.TrackID]*Track), + state: streamAllocatorStateStable, + congestionState: bwe.CongestionStateNone, eventsQueue: utils.NewTypedOpsQueue[Event](utils.OpsQueueParams{ Name: "stream-allocator", MinSize: 64, @@ -237,8 +246,6 @@ func NewStreamAllocator(params StreamAllocatorParams, enabled bool, allowPause b s.params.BWE.SetBWEListener(s) 
s.params.Pacer.SetPacerProbeObserverListener(s) - s.resetState() - return s } @@ -254,7 +261,8 @@ func (s *StreamAllocator) Stop() { // wait for eventsQueue to be done <-s.eventsQueue.Stop() - s.probeController.StopProbe() + + s.maybeStopProbe() } func (s *StreamAllocator) OnStreamStateChange(f func(update *StreamStateUpdate) error) { @@ -341,13 +349,6 @@ func (s *StreamAllocator) SetChannelCapacity(channelCapacity int64) { }) } -func (s *StreamAllocator) resetState() { - s.params.BWE.Reset() - s.probeController.Reset() - - s.state = streamAllocatorStateStable -} - // called when a new REMB is received (receive side bandwidth estimation) func (s *StreamAllocator) OnREMB(downTrack *sfu.DownTrack, remb *rtcp.ReceiverEstimatedMaximumBitrate) { // @@ -532,6 +533,14 @@ func (s *StreamAllocator) OnRTCPReceiverReport(downTrack *sfu.DownTrack, rr rtcp } */ +// called when probe cluster changes +func (s *StreamAllocator) OnProbeClusterSwitch(pci ccutils.ProbeClusterInfo) { + s.postEvent(Event{ + Signal: streamAllocatorSignalProbeClusterSwitch, + Data: pci, + }) +} + // called when prober wants to send packet(s) func (s *StreamAllocator) OnSendProbe(bytesToSend int) { s.postEvent(Event{ @@ -540,15 +549,6 @@ func (s *StreamAllocator) OnSendProbe(bytesToSend int) { }) } -// called when probe cluster changes -func (s *StreamAllocator) OnProbeClusterSwitch(probeClusterId ccutils.ProbeClusterId, desiredBytes int) { - s.params.Pacer.StartProbeCluster(probeClusterId, desiredBytes) - - for _, t := range s.getTracks() { - t.DownTrack().SetProbeClusterId(probeClusterId) - } -} - // called when pacer probe observer observes a cluster completion func (s *StreamAllocator) OnPacerProbeObserverClusterComplete(probeClusterId ccutils.ProbeClusterId) { s.postEvent(Event{ @@ -631,8 +631,12 @@ func (s *StreamAllocator) postEvent(event Event) { event.handleSignalEstimate(event) case streamAllocatorSignalPeriodicPing: event.handleSignalPeriodicPing(event) + case 
streamAllocatorSignalProbeClusterSwitch: + event.handleSignalProbeClusterSwitch(event) case streamAllocatorSignalSendProbe: event.handleSignalSendProbe(event) + case streamAllocatorSignalPacerProbeObserverClusterComplete: + event.handleSignalPacerProbeObserverClusterComplete(event) case streamAllocatorSignalResume: event.handleSignalResume(event) case streamAllocatorSignalSetAllowPause: @@ -647,8 +651,6 @@ func (s *StreamAllocator) postEvent(event Event) { */ case streamAllocatorSignalCongestionStateChange: s.handleSignalCongestionStateChange(event) - case streamAllocatorSignalPacerProbeObserverClusterComplete: - event.handleSignalPacerProbeObserverClusterComplete(event) } }, event) } @@ -688,7 +690,6 @@ func (s *StreamAllocator) handleSignalEstimate(event Event) { s.params.BWE.HandleREMB( receivedEstimate, - s.probeController.DoesProbeNeedFinalize(), // waiting for goal reached OR aborted probe to finalize s.getExpectedBandwidthUsage(), packetDelta, repeatedNackDelta, @@ -696,15 +697,18 @@ func (s *StreamAllocator) handleSignalEstimate(event Event) { } func (s *StreamAllocator) handleSignalPeriodicPing(Event) { - // finalize probe if necessary - isValidSignal, trend, lowestEstimate, highestEstimate := s.params.BWE.GetProbeStatus() - isHandled, isNotFailing, isGoalReached := s.probeController.MaybeFinalizeProbe( - isValidSignal, - trend, - lowestEstimate, - ) - if isHandled { - s.onProbeDone(isNotFailing, isGoalReached, highestEstimate) + // finalize any probe that may have finished/aborted + if pci, ok := s.probeController.MaybeFinalizeProbe(); ok { + isCongestionClearing, channelCapacity := s.params.BWE.ProbeClusterDone(pci) + if isCongestionClearing { + if channelCapacity > s.committedChannelCapacity { + s.committedChannelCapacity = channelCapacity + } + + s.maybeBoostDeficientTracks() + } + + s.probeController.ProbeCongestionSignal(isCongestionClearing) } // probe if necessary and timing is right @@ -712,12 +716,34 @@ func (s *StreamAllocator) 
handleSignalPeriodicPing(Event) { s.maybeProbe() } + if time.Since(s.lastRTTTime) > cRTTPullInterval { + s.lastRTTTime = time.Now() + + if s.params.RTTGetter != nil { + if rtt, ok := s.params.RTTGetter(); ok { + s.probeController.UpdateRTT(rtt) + } + } + } + /* STREAM-ALLOCATOR-DATA s.monitorRate(s.committedChannelCapacity) s.updateTracksHistory() */ } +func (s *StreamAllocator) handleSignalProbeClusterSwitch(event Event) { + pci := event.Data.(ccutils.ProbeClusterInfo) + s.probeController.ProbeClusterStarting(pci) + s.params.BWE.ProbeClusterStarting(pci) + + s.params.Pacer.StartProbeCluster(pci) + + for _, t := range s.getTracks() { + t.DownTrack().SetProbeClusterId(pci.Id) + } +} + func (s *StreamAllocator) handleSignalSendProbe(event Event) { bytesToSend := event.Data.(int) if bytesToSend <= 0 { @@ -735,6 +761,13 @@ func (s *StreamAllocator) handleSignalSendProbe(event Event) { } } +func (s *StreamAllocator) handleSignalPacerProbeObserverClusterComplete(event Event) { + probeClusterId, _ := event.Data.(ccutils.ProbeClusterId) + pci := s.params.Pacer.EndProbeCluster(probeClusterId) + s.probeController.ProbeClusterDone(pci) + s.params.BWE.ProbeClusterDone(pci) +} + func (s *StreamAllocator) handleSignalResume(event Event) { s.videoTracksMu.Lock() track := s.videoTracks[event.TrackID] @@ -791,7 +824,8 @@ func (s *StreamAllocator) handleSignalRTCPReceiverReport(event Event) { func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) { cscd := event.Data.(congestionStateChangeData) if cscd.congestionState != bwe.CongestionStateNone { - s.probeController.AbortProbe() + // end/abort any running probe if channel is not clear + s.maybeStopProbe() } if cscd.congestionState == bwe.CongestionStateEarlyWarning || @@ -805,7 +839,7 @@ func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) { if s.isHolding && cscd.congestionState == bwe.CongestionStateNone && s.state == streamAllocatorStateStable { update := NewStreamStateUpdate() for _, track 
:= range s.getTracks() { - allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, s.isHolding) + allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, false) updateStreamStateChange(track, allocation, update) } s.maybeSendUpdate(update) @@ -815,41 +849,37 @@ func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) { } if cscd.congestionState == bwe.CongestionStateCongested { - s.params.Logger.Infow( - "stream allocator: channel congestion detected, updating channel capacity", - "old(bps)", s.committedChannelCapacity, - "new(bps)", cscd.estimatedAvailableChannelCapacity, - "expectedUsage(bps)", s.getExpectedBandwidthUsage(), - ) - /* STREAM-ALLOCATOR-DATA - s.params.Logger.Debugw( - fmt.Sprintf("stream allocator: channel congestion detected, %s channel capacity: experimental", action), - "rateHistory", s.rateMonitor.GetHistory(), - "expectedQueuing", s.rateMonitor.GetQueuingGuess(), - "trackHistory", s.getTracksHistory(), - ) - */ - s.committedChannelCapacity = cscd.estimatedAvailableChannelCapacity + if s.probeController.GetActiveProbeClusterId() == ccutils.ProbeClusterIdInvalid { + s.params.Logger.Infow( + "stream allocator: channel congestion detected, not updating channel capacity in active probe", + "old(bps)", s.committedChannelCapacity, + "new(bps)", cscd.estimatedAvailableChannelCapacity, + "expectedUsage(bps)", s.getExpectedBandwidthUsage(), + ) + } else { + s.params.Logger.Infow( + "stream allocator: channel congestion detected, updating channel capacity", + "old(bps)", s.committedChannelCapacity, + "new(bps)", cscd.estimatedAvailableChannelCapacity, + "expectedUsage(bps)", s.getExpectedBandwidthUsage(), + ) + /* STREAM-ALLOCATOR-DATA + s.params.Logger.Debugw( + fmt.Sprintf("stream allocator: channel congestion detected, %s channel capacity: experimental", action), + "rateHistory", s.rateMonitor.GetHistory(), + "expectedQueuing", s.rateMonitor.GetQueuingGuess(), + "trackHistory", s.getTracksHistory(), + ) + */ + 
s.committedChannelCapacity = cscd.estimatedAvailableChannelCapacity - - // reset probe to ensure it does not start too soon after a downward trend - // BWE-TODO: maybe probe controller setting should be algorithm specific - // BWE-TODO: for e. g., the reset could be waiting shorter in SSBWE case - // BWE-TODO: a couple of things to consider - // BWE-TODO: 1. Make ProbeController be owned by BWE modules? - // BWE-TODO: 2. Add an interface method to BWE to check if probe controller should be reset? - s.probeController.Reset() - - s.allocateAllTracks() + s.allocateAllTracks() + } } s.congestionState = cscd.congestionState } -func (s *StreamAllocator) handleSignalPacerProbeObserverClusterComplete(event Event) { - probeClusterId, _ := event.Data.(ccutils.ProbeClusterId) - s.probeController.ProbeClusterDone(probeClusterId) -} - func (s *StreamAllocator) setState(state streamAllocatorState) { if s.state == state { return @@ -858,16 +888,14 @@ func (s *StreamAllocator) setState(state streamAllocatorState) { s.params.Logger.Infow("stream allocator: state change", "from", s.state, "to", state) s.state = state - // reset probe to enforce a delay after state change before probing - s.probeController.Reset() + // restart everything when state is stable + if state == streamAllocatorStateStable { + s.maybeStopProbe() - // a fresh start after state transition to get clean data - // BWE-TODO: ssbwe maybe should not reset like this as it might have useful state across - // BWE-TODO: state changes in this module, actually even remotebwe should also manage it - // BWE-TODO: internally, Reset should probably only be used if all managed tracks go away - // BWE-TODO: and we can get a clean start, mimicking existing behaviour till this can be - // BWE-TODO: evaluated more. 
- s.params.BWE.Reset() + s.probeController.Reset() + + s.params.BWE.Reset() + } } func (s *StreamAllocator) adjustState() { @@ -882,8 +910,8 @@ func (s *StreamAllocator) adjustState() { } func (s *StreamAllocator) allocateTrack(track *Track) { - // abort any probe that may be running when a track specific change needs allocation - s.probeController.AbortProbe() + // end/abort any probe that may be running when a track specific change needs allocation + s.maybeStopProbe() // if not deficient, free pass allocate track if !s.enabled || s.state == streamAllocatorStateStable || !track.IsManaged() { @@ -1036,18 +1064,13 @@ func (s *StreamAllocator) allocateTrack(track *Track) { s.adjustState() } -func (s *StreamAllocator) onProbeDone(isNotFailing bool, isGoalReached bool, highestEstimate int64) { - s.params.BWE.ProbingEnd(isNotFailing, isGoalReached) - - if !isNotFailing { - return +func (s *StreamAllocator) maybeStopProbe() { + activeProbeClusterId := s.probeController.GetActiveProbeClusterId() + if activeProbeClusterId != ccutils.ProbeClusterIdInvalid { + pci := s.params.Pacer.EndProbeCluster(activeProbeClusterId) + s.prober.Reset(pci) + s.probeController.ProbeClusterDone(pci) } - - if highestEstimate > s.committedChannelCapacity { - s.committedChannelCapacity = highestEstimate - } - - s.maybeBoostDeficientTracks() } func (s *StreamAllocator) maybeBoostDeficientTracks() { @@ -1268,19 +1291,6 @@ func (s *StreamAllocator) getNackDelta() (uint32, uint32) { return aggPacketDelta, aggRepeatedNackDelta } -func (s *StreamAllocator) initProbe(probeGoalDeltaBps int64) { - expectedBandwidthUsage := s.getExpectedBandwidthUsage() - probeClusterId, probeGoalBps := s.probeController.InitProbe(probeGoalDeltaBps, expectedBandwidthUsage) - s.params.Logger.Debugw( - "stream allocator: starting probe", - "probeClusterId", probeClusterId, - "current usage", expectedBandwidthUsage, - "committed", s.committedChannelCapacity, - "probeGoalDeltaBps", probeGoalDeltaBps, - "goalBps", 
probeGoalBps, - ) -} - func (s *StreamAllocator) maybeProbe() { if s.overriddenChannelCapacity > 0 { // do not probe if channel capacity is overridden @@ -1325,7 +1335,14 @@ func (s *StreamAllocator) maybeProbeWithPadding() { continue } - s.initProbe(transition.BandwidthDelta) + pcg, ok := s.probeController.MaybeInitiateProbe(s.committedChannelCapacity, transition.BandwidthDelta, s.getExpectedBandwidthUsage()) + if ok { + pci := s.prober.AddCluster(ccutils.ProbeClusterModeUniform, pcg) + s.params.Logger.Debugw( + "stream allocator: starting probe", + "probeClusterInfo", pci, + ) + } break } }