From a3f2ca56f95aac4c33512a74daab0b80a15a091d Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Mon, 11 Nov 2024 10:24:47 +0530 Subject: [PATCH] TWCC based congestion control - v0 (#3165) * file output * wake under lock * keep track of RTX bytes separately * packet group * Packet group of 50ms * Minor refactoring * rate calculator * send bit rate * WIP * comment * reduce packet infos size * extended twcc seq num * fix packet info * WIP * queuing delay * refactor * config * callbacks * fixes * clean up * remove debug file, fix rate calculation * fmt * fix probes * format * notes * check loss * tweak detection settings * 24-bit wrap * clean up a bit * limit symbol list to number of packets * fmt * clean up * lost * fixes * fmt * rename * fixes * fmt * use min/max * hold on early warning of congestion * make note about need for all optimal allocation on hold release * estimate trend in congested state * tweaks * quantized * fmt * TrendDetector generics * CTR trend * tweaks * config * config * comments * clean up * consistent naming * pariticpant level setting * log usage mode * feedback --- pkg/config/config.go | 22 +- pkg/rtc/config.go | 2 +- pkg/rtc/participant.go | 4 + pkg/rtc/transport.go | 41 +- pkg/rtc/transportmanager.go | 4 + .../trenddetector.go | 115 ++-- pkg/sfu/downtrack.go | 5 +- pkg/sfu/forwarder.go | 58 +- pkg/sfu/forwarder_test.go | 117 +++- pkg/sfu/pacer/base.go | 40 +- pkg/sfu/pacer/leaky_bucket.go | 5 +- pkg/sfu/pacer/no_queue.go | 5 +- pkg/sfu/pacer/pacer.go | 1 + pkg/sfu/pacer/pass_through.go | 5 +- pkg/sfu/rtpstats/rtpstats_sender.go | 9 +- pkg/sfu/sendsidebwe/congestion_detector.go | 556 ++++++++++++++++++ pkg/sfu/sendsidebwe/packet_group.go | 338 +++++++++++ pkg/sfu/sendsidebwe/packet_info.go | 36 ++ pkg/sfu/sendsidebwe/packet_tracker.go | 113 ++++ pkg/sfu/sendsidebwe/send_side_bwe.go | 105 ++++ pkg/sfu/sendsidebwe/twcc_feedback.go | 116 ++++ pkg/sfu/streamallocator/channelobserver.go | 36 +- pkg/sfu/streamallocator/probe_controller.go | 78 
+++ pkg/sfu/streamallocator/streamallocator.go | 102 +++- pkg/sfu/streamallocator/track.go | 4 +- 25 files changed, 1748 insertions(+), 169 deletions(-) rename pkg/sfu/{streamallocator => ccutils}/trenddetector.go (72%) create mode 100644 pkg/sfu/sendsidebwe/congestion_detector.go create mode 100644 pkg/sfu/sendsidebwe/packet_group.go create mode 100644 pkg/sfu/sendsidebwe/packet_info.go create mode 100644 pkg/sfu/sendsidebwe/packet_tracker.go create mode 100644 pkg/sfu/sendsidebwe/send_side_bwe.go create mode 100644 pkg/sfu/sendsidebwe/twcc_feedback.go diff --git a/pkg/config/config.go b/pkg/config/config.go index 85b377da7..a966ab288 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -29,6 +29,7 @@ import ( "github.com/livekit/livekit-server/pkg/metric" "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/livekit" @@ -123,10 +124,15 @@ type TURNServer struct { } type CongestionControlConfig struct { - Enabled bool `yaml:"enabled,omitempty"` - AllowPause bool `yaml:"allow_pause,omitempty"` + Enabled bool `yaml:"enabled,omitempty"` + AllowPause bool `yaml:"allow_pause,omitempty"` + StreamAllocator streamallocator.StreamAllocatorConfig `yaml:"stream_allocator,omitempty"` - UseSendSideBWE bool `yaml:"use_send_side_bwe,omitempty"` + + UseSendSideBWEInterceptor bool `yaml:"use_send_side_bwe_interceptor,omitempty"` + + UseSendSideBWE bool `yaml:"use_send_side_bwe,omitempty"` + SendSideBWE sendsidebwe.SendSideBWEConfig `yaml:"send_side_bwe,omitempty"` } type PlayoutDelayConfig struct { @@ -305,10 +311,12 @@ var DefaultConfig = Config{ StrictACKs: true, PLIThrottle: sfu.DefaultPLIThrottleConfig, CongestionControl: CongestionControlConfig{ - Enabled: true, - AllowPause: false, - StreamAllocator: streamallocator.DefaultStreamAllocatorConfig, - UseSendSideBWE: false, + 
Enabled: true, + AllowPause: false, + StreamAllocator: streamallocator.DefaultStreamAllocatorConfig, + UseSendSideBWEInterceptor: false, + UseSendSideBWE: false, + SendSideBWE: sendsidebwe.DefaultSendSideBWEConfig, }, }, Audio: sfu.DefaultAudioConfig, diff --git a/pkg/rtc/config.go b/pkg/rtc/config.go index 5c608d1f2..60cf06bc0 100644 --- a/pkg/rtc/config.go +++ b/pkg/rtc/config.go @@ -133,7 +133,7 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { }, }, } - if rtcConf.CongestionControl.UseSendSideBWE { + if rtcConf.CongestionControl.UseSendSideBWEInterceptor || rtcConf.CongestionControl.UseSendSideBWE { subscriberConfig.RTPHeaderExtension.Video = append(subscriberConfig.RTPHeaderExtension.Video, sdp.TransportCCURI) subscriberConfig.RTCPFeedback.Video = append(subscriberConfig.RTCPFeedback.Video, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}) } else { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index d5ea78462..28ab0034f 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -158,6 +158,8 @@ type ParticipantParams struct { ForwardStats *sfu.ForwardStats DisableSenderReportPassThrough bool MetricConfig metric.MetricConfig + UseSendSideBWEInterceptor bool + UseSendSideBWE bool } type ParticipantImpl struct { @@ -1458,6 +1460,8 @@ func (p *ParticipantImpl) setupTransportManager() error { PublisherHandler: pth, SubscriberHandler: sth, DataChannelStats: p.dataChannelStats, + UseSendSideBWEInterceptor: p.params.UseSendSideBWEInterceptor, + UseSendSideBWE: p.params.UseSendSideBWE, } if p.params.SyncStreams && p.params.PlayoutDelay.GetEnabled() && p.params.ClientInfo.isFirefox() { // we will disable playout delay for Firefox if the user is expecting diff --git a/pkg/rtc/transport.go b/pkg/rtc/transport.go index 644096721..c4eb31a4f 100644 --- a/pkg/rtc/transport.go +++ b/pkg/rtc/transport.go @@ -36,23 +36,23 @@ import ( "go.uber.org/atomic" "go.uber.org/zap/zapcore" - lkinterceptor 
"github.com/livekit/mediatransportutil/pkg/interceptor" - lktwcc "github.com/livekit/mediatransportutil/pkg/twcc" - "github.com/livekit/protocol/livekit" - "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/logger/pionlogger" - lksdp "github.com/livekit/protocol/sdp" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/transport" "github.com/livekit/livekit-server/pkg/rtc/types" sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" "github.com/livekit/livekit-server/pkg/sfu/pacer" pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/livekit-server/pkg/sfu/streamallocator" sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/livekit-server/pkg/utils" + lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor" + lktwcc "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/logger/pionlogger" + lksdp "github.com/livekit/protocol/sdp" ) const ( @@ -208,7 +208,8 @@ type PCTransport struct { streamAllocator *streamallocator.StreamAllocator // only for subscriber PC - pacer pacer.Pacer + sendSideBWE *sendsidebwe.SendSideBWE + pacer pacer.Pacer previousAnswer *webrtc.SessionDescription // track id -> description map in previous offer sdp @@ -254,6 +255,8 @@ type TransportParams struct { IsSendSide bool AllowPlayoutDelay bool DataChannelMaxBufferedAmount uint64 + UseSendSideBWEInterceptor bool + UseSendSideBWE bool } func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimator cc.BandwidthEstimator)) (*webrtc.PeerConnection, *webrtc.MediaEngine, error) { @@ -347,7 +350,8 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat ir := &interceptor.Registry{} if 
params.IsSendSide { se.DetachDataChannels() - if params.CongestionControlConfig.UseSendSideBWE { + if params.CongestionControlConfig.UseSendSideBWEInterceptor || params.UseSendSideBWEInterceptor && (!params.CongestionControlConfig.UseSendSideBWE && !params.UseSendSideBWE) { + params.Logger.Infow("using send side BWE - interceptor") gf, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { return gcc.NewSendSideBWE( gcc.SendSideBWEInitialBitrate(1*1000*1000), @@ -449,7 +453,19 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) { }, params.CongestionControlConfig.Enabled, params.CongestionControlConfig.AllowPause) t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange) t.streamAllocator.Start() - t.pacer = pacer.NewPassThrough(params.Logger) + + if params.CongestionControlConfig.UseSendSideBWE || params.UseSendSideBWE { + params.Logger.Infow("using send side BWE") + t.sendSideBWE = sendsidebwe.NewSendSideBWE(sendsidebwe.SendSideBWEParams{ + Config: params.CongestionControlConfig.SendSideBWE, + Logger: params.Logger, + }) + t.pacer = pacer.NewNoQueue(params.Logger, t.sendSideBWE) + + t.streamAllocator.SetSendSideBWE(t.sendSideBWE) + } else { + t.pacer = pacer.NewPassThrough(params.Logger, nil) + } } if err := t.createPeerConnection(); err != nil { @@ -496,7 +512,7 @@ func (t *PCTransport) createPeerConnection() error { t.me = me if bwe != nil && t.streamAllocator != nil { - t.streamAllocator.SetBandwidthEstimator(bwe) + t.streamAllocator.SetSendSideBWEInterceptor(bwe) } return nil @@ -975,6 +991,9 @@ func (t *PCTransport) Close() { if t.pacer != nil { t.pacer.Stop() } + if t.sendSideBWE != nil { + t.sendSideBWE.Stop() + } _ = t.pc.Close() diff --git a/pkg/rtc/transportmanager.go b/pkg/rtc/transportmanager.go index cc69e514c..661cd2894 100644 --- a/pkg/rtc/transportmanager.go +++ b/pkg/rtc/transportmanager.go @@ -103,6 +103,8 @@ type TransportManagerParams struct { PublisherHandler transport.Handler SubscriberHandler 
transport.Handler DataChannelStats *telemetry.BytesTrackStats + UseSendSideBWEInterceptor bool + UseSendSideBWE bool } type TransportManager struct { @@ -180,6 +182,8 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, Transport: livekit.SignalTarget_SUBSCRIBER, Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, + UseSendSideBWEInterceptor: params.UseSendSideBWEInterceptor, + UseSendSideBWE: params.UseSendSideBWE, }) if err != nil { return nil, err diff --git a/pkg/sfu/streamallocator/trenddetector.go b/pkg/sfu/ccutils/trenddetector.go similarity index 72% rename from pkg/sfu/streamallocator/trenddetector.go rename to pkg/sfu/ccutils/trenddetector.go index 164ca419e..4806387bc 100644 --- a/pkg/sfu/streamallocator/trenddetector.go +++ b/pkg/sfu/ccutils/trenddetector.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package streamallocator +package ccutils import ( "fmt" @@ -46,26 +46,15 @@ func (t TrendDirection) String() string { // ------------------------------------------------ -type trendDetectorSample struct { - value int64 - at time.Time +type trendDetectorNumber interface { + int64 | float64 } -func trendDetectorSampleListToString(samples []trendDetectorSample) string { - samplesStr := "" - if len(samples) > 0 { - firstTime := samples[0].at - samplesStr += "[" - for i, sample := range samples { - suffix := ", " - if i == len(samples)-1 { - suffix = "" - } - samplesStr += fmt.Sprintf("%d(%d)%s", sample.value, sample.at.Sub(firstTime).Milliseconds(), suffix) - } - samplesStr += "]" - } - return samplesStr +// ------------------------------------------------ + +type trendDetectorSample[T trendDetectorNumber] struct { + value T + at time.Time } // ------------------------------------------------ @@ -79,26 +68,6 @@ type TrendDetectorConfig struct { ValidityWindow time.Duration `yaml:"validity_window,omitempty"` } -var ( - DefaultTrendDetectorConfigProbe = TrendDetectorConfig{ - RequiredSamples: 3, - RequiredSamplesMin: 3, - DownwardTrendThreshold: 0.0, - DownwardTrendMaxWait: 5 * time.Second, - CollapseThreshold: 0, - ValidityWindow: 10 * time.Second, - } - - DefaultTrendDetectorConfigNonProbe = TrendDetectorConfig{ - RequiredSamples: 12, - RequiredSamplesMin: 8, - DownwardTrendThreshold: -0.6, - DownwardTrendMaxWait: 5 * time.Second, - CollapseThreshold: 500 * time.Millisecond, - ValidityWindow: 10 * time.Second, - } -) - // ------------------------------------------------ type TrendDetectorParams struct { @@ -107,35 +76,35 @@ type TrendDetectorParams struct { Config TrendDetectorConfig } -type TrendDetector struct { +type TrendDetector[T trendDetectorNumber] struct { params TrendDetectorParams startTime time.Time numSamples int - samples []trendDetectorSample - lowestValue int64 - highestValue int64 + samples []trendDetectorSample[T] + lowestValue T + highestValue T 
direction TrendDirection } -func NewTrendDetector(params TrendDetectorParams) *TrendDetector { - return &TrendDetector{ +func NewTrendDetector[T trendDetectorNumber](params TrendDetectorParams) *TrendDetector[T] { + return &TrendDetector[T]{ params: params, startTime: time.Now(), direction: TrendDirectionNeutral, } } -func (t *TrendDetector) Seed(value int64) { +func (t *TrendDetector[T]) Seed(value T) { if len(t.samples) != 0 { return } - t.samples = append(t.samples, trendDetectorSample{value: value, at: time.Now()}) + t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()}) } -func (t *TrendDetector) AddValue(value int64) { +func (t *TrendDetector[T]) AddValue(value T) { t.numSamples++ if t.lowestValue == 0 || value < t.lowestValue { t.lowestValue = value @@ -156,7 +125,7 @@ func (t *TrendDetector) AddValue(value int64) { // But, on the flip side, estimate could fall once or twice within a sliding window and stay there. // In those cases, using a collapse window to record a value even if it is duplicate. By doing that, // a trend could be detected eventually. It will be delayed, but that is fine with slow changing estimates. 
- var lastSample *trendDetectorSample + var lastSample *trendDetectorSample[T] if len(t.samples) != 0 { lastSample = &t.samples[len(t.samples)-1] } @@ -164,38 +133,52 @@ func (t *TrendDetector) AddValue(value int64) { return } - t.samples = append(t.samples, trendDetectorSample{value: value, at: time.Now()}) + t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()}) t.prune() t.updateDirection() } -func (t *TrendDetector) GetLowest() int64 { +func (t *TrendDetector[T]) GetLowest() T { return t.lowestValue } -func (t *TrendDetector) GetHighest() int64 { +func (t *TrendDetector[T]) GetHighest() T { return t.highestValue } -func (t *TrendDetector) GetDirection() TrendDirection { +func (t *TrendDetector[T]) GetDirection() TrendDirection { return t.direction } -func (t *TrendDetector) HasEnoughSamples() bool { +func (t *TrendDetector[T]) HasEnoughSamples() bool { return t.numSamples >= t.params.Config.RequiredSamples } -func (t *TrendDetector) ToString() string { +func (t *TrendDetector[T]) ToString() string { + samplesStr := "" + if len(t.samples) > 0 { + firstTime := t.samples[0].at + samplesStr += "[" + for i, sample := range t.samples { + suffix := ", " + if i == len(t.samples)-1 { + suffix = "" + } + samplesStr += fmt.Sprintf("%v(%d)%s", sample.value, sample.at.Sub(firstTime).Milliseconds(), suffix) + } + samplesStr += "]" + } + now := time.Now() elapsed := now.Sub(t.startTime).Seconds() - return fmt.Sprintf("n: %s, t: %+v|%+v|%.2fs, v: %d|%d|%d|%s|%.2f", + return fmt.Sprintf("n: %s, t: %+v|%+v|%.2fs, v: %d|%v|%v|%s|%.2f", t.params.Name, t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed, - t.numSamples, t.lowestValue, t.highestValue, trendDetectorSampleListToString(t.samples), kendallsTau(t.samples), + t.numSamples, t.lowestValue, t.highestValue, samplesStr, t.kendallsTau(), ) } -func (t *TrendDetector) prune() { +func (t *TrendDetector[T]) prune() { // prune based on a few rules // 1. 
If there are more than required samples @@ -238,14 +221,14 @@ func (t *TrendDetector) prune() { } } -func (t *TrendDetector) updateDirection() { +func (t *TrendDetector[T]) updateDirection() { if len(t.samples) < t.params.Config.RequiredSamplesMin { t.direction = TrendDirectionNeutral return } // using Kendall's Tau to find trend - kt := kendallsTau(t.samples) + kt := t.kendallsTau() t.direction = TrendDirectionNeutral switch { @@ -256,17 +239,15 @@ func (t *TrendDetector) updateDirection() { } } -// ------------------------------------------------ - -func kendallsTau(samples []trendDetectorSample) float64 { +func (t *TrendDetector[T]) kendallsTau() float64 { concordantPairs := 0 discordantPairs := 0 - for i := 0; i < len(samples)-1; i++ { - for j := i + 1; j < len(samples); j++ { - if samples[i].value < samples[j].value { + for i := 0; i < len(t.samples)-1; i++ { + for j := i + 1; j < len(t.samples); j++ { + if t.samples[i].value < t.samples[j].value { concordantPairs++ - } else if samples[i].value > samples[j].value { + } else if t.samples[i].value > t.samples[j].value { discordantPairs++ } } diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 109e77fb1..2455fc8c5 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -1407,9 +1407,9 @@ func (d *DownTrack) DistanceToDesired() float64 { return d.forwarder.DistanceToDesired(al, brs) } -func (d *DownTrack) AllocateOptimal(allowOvershoot bool) VideoAllocation { +func (d *DownTrack) AllocateOptimal(allowOvershoot bool, hold bool) VideoAllocation { al, brs := d.params.Receiver.GetLayeredBitrate() - allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot) + allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot, hold) d.postKeyFrameRequestEvent() d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) return allocation @@ -1966,6 +1966,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { Header: &pkt.Header, Extensions: 
extensions, Payload: payload, + IsRTX: true, AbsSendTimeExtID: uint8(d.absSendTimeExtID), TransportWideExtID: uint8(d.transportWideExtID), WriteStream: d.writeStream, diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index 883097590..fcdeeda00 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -714,7 +714,7 @@ func (f *Forwarder) GetOptimalBandwidthNeeded(brs Bitrates) int64 { return getOptimalBandwidthNeeded(f.muted, f.pubMuted, f.vls.GetMaxSeen().Spatial, brs, f.vls.GetMax()) } -func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allowOvershoot bool) VideoAllocation { +func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allowOvershoot bool, hold bool) VideoAllocation { f.lock.Lock() defer f.lock.Unlock() @@ -755,7 +755,7 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow } alloc.TargetLayer = buffer.VideoLayer{ - Spatial: int32(math.Min(float64(maxSeenLayer.Spatial), float64(maxSpatial))), + Spatial: min(maxSeenLayer.Spatial, maxSpatial), Temporal: getMaxTemporal(), } } @@ -783,8 +783,9 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow // 2. If current is a valid layer, check against currently available layers and continue at current // if possible. Else, choose the highest available layer as the next target. // 3. If current is not valid, set next target to be opportunistic. 
- maxLayerSpatialLimit := int32(math.Min(float64(maxLayer.Spatial), float64(maxSeenLayer.Spatial))) + maxLayerSpatialLimit := min(maxLayer.Spatial, maxSeenLayer.Spatial) highestAvailableLayer := buffer.InvalidLayerSpatial + lowestAvailableLayer := buffer.InvalidLayerSpatial requestLayerSpatial := buffer.InvalidLayerSpatial for _, al := range availableLayers { if al > requestLayerSpatial && al <= maxLayerSpatialLimit { @@ -793,6 +794,9 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow if al > highestAvailableLayer { highestAvailableLayer = al } + if lowestAvailableLayer == buffer.InvalidLayerSpatial || al < lowestAvailableLayer { + lowestAvailableLayer = al + } } if requestLayerSpatial == buffer.InvalidLayerSpatial && highestAvailableLayer != buffer.InvalidLayerSpatial && allowOvershoot && f.vls.IsOvershootOkay() { requestLayerSpatial = highestAvailableLayer @@ -811,20 +815,46 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow Temporal: getMaxTemporal(), } } else { - // current layer has stopped, switch to highest available - alloc.TargetLayer = buffer.VideoLayer{ - Spatial: requestLayerSpatial, - Temporal: getMaxTemporal(), + // current layer has stopped, switch to lowest available if `hold`ing, else switch to highest available + if hold { + // if `hold` is requested, may be set due to early warning congestion + // signal, in that case layers are not increased as increasing layers + // will result in more load on the channel + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: lowestAvailableLayer, + Temporal: 0, + } + } else { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: requestLayerSpatial, + Temporal: getMaxTemporal(), + } } } alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial } else { - // opportunistically latch on to anything - opportunisticAlloc() - if requestLayerSpatial == buffer.InvalidLayerSpatial { - alloc.RequestLayerSpatial = maxLayerSpatialLimit + if hold { + // allocate 
minimal to make the stream active while `hold`ing. + if lowestAvailableLayer == buffer.InvalidLayerSpatial { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + } else { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: lowestAvailableLayer, + Temporal: 0, + } + } + alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial } else { - alloc.RequestLayerSpatial = requestLayerSpatial + // opportunistically latch on to anything + opportunisticAlloc() + if requestLayerSpatial == buffer.InvalidLayerSpatial { + alloc.RequestLayerSpatial = maxLayerSpatialLimit + } else { + alloc.RequestLayerSpatial = requestLayerSpatial + } } } } @@ -834,7 +864,7 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow alloc.RequestLayerSpatial = buffer.InvalidLayerSpatial } if alloc.TargetLayer.IsValid() { - alloc.BandwidthRequested = optimalBandwidthNeeded + alloc.BandwidthRequested = getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, alloc.TargetLayer) } alloc.BandwidthDelta = alloc.BandwidthRequested - getBandwidthNeeded(brs, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested) alloc.DistanceToDesired = getDistanceToDesired( @@ -1120,7 +1150,7 @@ func (f *Forwarder) ProvisionalAllocateGetBestWeightedTransition() (VideoTransit break } - bandwidthDelta := int64(math.Max(float64(0), float64(existingBandwidthNeeded-f.provisional.bitrates[s][t]))) + bandwidthDelta := max(0, existingBandwidthNeeded-f.provisional.bitrates[s][t]) transitionCost := int32(0) // SVC-TODO: SVC will need a different cost transition diff --git a/pkg/sfu/forwarder_test.go b/pkg/sfu/forwarder_test.go index 608ab9155..f9c8a15e5 100644 --- a/pkg/sfu/forwarder_test.go +++ b/pkg/sfu/forwarder_test.go @@ -145,7 +145,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.InvalidLayer, DistanceToDesired: 0, } - result := f.AllocateOptimal(nil, bitrates, true) + result := f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, 
expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) @@ -163,7 +163,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: 0, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) @@ -182,7 +182,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: 0, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) @@ -201,19 +201,19 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: 0, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) f.PubMute(false) // when max layers changes, target is opportunistic, but requested spatial layer should be at max - f.SetMaxTemporalLayerSeen(3) + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) f.vls.SetMax(buffer.VideoLayer{Spatial: 1, Temporal: 3}) expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonNone, - BandwidthRequested: bitrates[1][3], - BandwidthDelta: bitrates[1][3], + BandwidthRequested: bitrates[2][1], + BandwidthDelta: bitrates[2][1], BandwidthNeeded: bitrates[1][3], Bitrates: bitrates, TargetLayer: buffer.DefaultMaxLayer, @@ -221,7 +221,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: f.vls.GetMax(), DistanceToDesired: -1, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) require.Equal(t, buffer.DefaultMaxLayer, 
f.TargetLayer()) @@ -232,14 +232,11 @@ func TestForwarderAllocateOptimal(t *testing.T) { // when feed is dry and current is not valid, should set up for opportunistic forwarding // NOTE: feed is dry due to availableLayers = nil, some valid bitrates may be passed in here for testing purposes only disable(f) - expectedTargetLayer := buffer.VideoLayer{ - Spatial: 2, - Temporal: buffer.DefaultMaxLayerTemporal, - } + expectedTargetLayer := buffer.DefaultMaxLayer expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonNone, BandwidthRequested: bitrates[2][1], - BandwidthDelta: bitrates[2][1] - bitrates[1][3], + BandwidthDelta: 0, BandwidthNeeded: bitrates[2][1], Bitrates: bitrates, TargetLayer: expectedTargetLayer, @@ -247,7 +244,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -0.5, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) require.Equal(t, expectedTargetLayer, f.TargetLayer()) @@ -270,7 +267,7 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -0.75, } - result = f.AllocateOptimal(nil, emptyBitrates, true) + result = f.AllocateOptimal(nil, emptyBitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) require.Equal(t, expectedTargetLayer, f.TargetLayer()) @@ -289,16 +286,37 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -0.5, } - result = f.AllocateOptimal([]int32{0, 1}, bitrates, true) + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + // when holding in above scenario, should choose the lowest available layer + 
expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[1][0], + BandwidthDelta: bitrates[1][0] - bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 1.25, + } + result = f.AllocateOptimal([]int32{1, 2}, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + // opportunistic target if feed is dry and current is not valid, i. e. not forwarding expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonNone, BandwidthRequested: bitrates[2][1], - BandwidthDelta: 0, + BandwidthDelta: bitrates[2][1] - bitrates[1][0], BandwidthNeeded: bitrates[2][1], Bitrates: bitrates, TargetLayer: buffer.DefaultMaxLayer, @@ -306,26 +324,49 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -0.5, } - result = f.AllocateOptimal(nil, bitrates, true) + result = f.AllocateOptimal(nil, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + // when holding in above scenario, should choose layer 0 + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[0][0], + BandwidthDelta: bitrates[0][0] - bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 0, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result = f.AllocateOptimal(nil, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, 
expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + // if feed is not dry and current is not locked, should be opportunistic (with and without overshoot) f.vls.SetTarget(buffer.InvalidLayer) expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonFeedDry, BandwidthRequested: 0, - BandwidthDelta: 0 - bitrates[2][1], + BandwidthDelta: 0 - bitrates[0][0], + BandwidthNeeded: 0, Bitrates: emptyBitrates, TargetLayer: buffer.DefaultMaxLayer, RequestLayerSpatial: 1, MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -1.0, } - result = f.AllocateOptimal([]int32{0, 1}, emptyBitrates, false) + result = f.AllocateOptimal([]int32{0, 1}, emptyBitrates, false, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) f.vls.SetTarget(buffer.InvalidLayer) expectedTargetLayer = buffer.VideoLayer{ @@ -343,9 +384,10 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: buffer.DefaultMaxLayer, DistanceToDesired: -0.5, } - result = f.AllocateOptimal([]int32{0, 1}, bitrates, true) + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) // switches request layer to highest available if feed is not dry and current is valid and current is not available f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 1}) @@ -355,8 +397,8 @@ func TestForwarderAllocateOptimal(t *testing.T) { } expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonNone, - BandwidthRequested: bitrates[2][1], - BandwidthDelta: 0, + BandwidthRequested: bitrates[1][3], + BandwidthDelta: bitrates[1][3] - bitrates[2][1], BandwidthNeeded: bitrates[2][1], Bitrates: bitrates, TargetLayer: expectedTargetLayer, @@ -364,9 +406,31 @@ func TestForwarderAllocateOptimal(t *testing.T) { MaxLayer: 
buffer.DefaultMaxLayer, DistanceToDesired: 0.5, } - result = f.AllocateOptimal([]int32{1}, bitrates, true) + result = f.AllocateOptimal([]int32{1}, bitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // when holding in above scenario, should switch to lowest available layer + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[0][0], + BandwidthDelta: bitrates[0][0] - bitrates[1][3], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 0, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) // stays the same if feed is not dry and current is valid, available and locked f.vls.SetMax(buffer.VideoLayer{Spatial: 0, Temporal: 1}) @@ -379,16 +443,17 @@ func TestForwarderAllocateOptimal(t *testing.T) { expectedResult = VideoAllocation{ PauseReason: VideoPauseReasonFeedDry, BandwidthRequested: 0, - BandwidthDelta: 0 - bitrates[2][1], + BandwidthDelta: 0 - bitrates[0][0], Bitrates: emptyBitrates, TargetLayer: expectedTargetLayer, RequestLayerSpatial: 0, MaxLayer: f.vls.GetMax(), DistanceToDesired: 0.0, } - result = f.AllocateOptimal([]int32{0}, emptyBitrates, true) + result = f.AllocateOptimal([]int32{0}, emptyBitrates, true, false) require.Equal(t, expectedResult, result) require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) } func TestForwarderProvisionalAllocate(t *testing.T) { diff --git a/pkg/sfu/pacer/base.go b/pkg/sfu/pacer/base.go index 69bfe5272..13d083745 100644 --- 
a/pkg/sfu/pacer/base.go +++ b/pkg/sfu/pacer/base.go @@ -19,6 +19,7 @@ import ( "io" "time" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils/mono" "github.com/pion/rtp" @@ -26,11 +27,14 @@ import ( type Base struct { logger logger.Logger + + sendSideBWE *sendsidebwe.SendSideBWE } -func NewBase(logger logger.Logger) *Base { +func NewBase(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *Base { return &Base{ - logger: logger, + logger: logger, + sendSideBWE: sendSideBWE, } } @@ -47,7 +51,7 @@ func (b *Base) SendPacket(p *Packet) (int, error) { } }() - _, err := b.writeRTPHeaderExtensions(p) + err := b.writeRTPHeaderExtensions(p) if err != nil { b.logger.Errorw("writing rtp header extensions err", err) return 0, err @@ -66,7 +70,7 @@ func (b *Base) SendPacket(p *Packet) (int, error) { } // writes RTP header extensions of track -func (b *Base) writeRTPHeaderExtensions(p *Packet) (time.Time, error) { +func (b *Base) writeRTPHeaderExtensions(p *Packet) error { // clear out extensions that may have been in the forwarded header p.Header.Extension = false p.Header.ExtensionProfile = 0 @@ -85,16 +89,34 @@ func (b *Base) writeRTPHeaderExtensions(p *Packet) (time.Time, error) { sendTime := rtp.NewAbsSendTimeExtension(sendingAt) b, err := sendTime.Marshal() if err != nil { - return time.Time{}, err + return err } - err = p.Header.SetExtension(p.AbsSendTimeExtID, b) - if err != nil { - return time.Time{}, err + if err = p.Header.SetExtension(p.AbsSendTimeExtID, b); err != nil { + return err } } - return sendingAt, nil + if p.TransportWideExtID != 0 && b.sendSideBWE != nil { + twccSN := b.sendSideBWE.RecordPacketSendAndGetSequenceNumber( + sendingAt, + p.Header.MarshalSize()+len(p.Payload), + p.IsRTX, + ) + twccExt := rtp.TransportCCExtension{ + TransportSequence: twccSN, + } + b, err := twccExt.Marshal() + if err != nil { + return err + } + + if err = 
p.Header.SetExtension(p.TransportWideExtID, b); err != nil { + return err + } + } + + return nil } // ------------------------------------------------ diff --git a/pkg/sfu/pacer/leaky_bucket.go b/pkg/sfu/pacer/leaky_bucket.go index 9ac2a1350..90d8fa2ea 100644 --- a/pkg/sfu/pacer/leaky_bucket.go +++ b/pkg/sfu/pacer/leaky_bucket.go @@ -19,6 +19,7 @@ import ( "time" "github.com/gammazero/deque" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/protocol/logger" ) @@ -38,9 +39,9 @@ type LeakyBucket struct { isStopped bool } -func NewLeakyBucket(logger logger.Logger, interval time.Duration, bitrate int) *LeakyBucket { +func NewLeakyBucket(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE, interval time.Duration, bitrate int) *LeakyBucket { l := &LeakyBucket{ - Base: NewBase(logger), + Base: NewBase(logger, sendSideBWE), logger: logger, interval: interval, bitrate: bitrate, diff --git a/pkg/sfu/pacer/no_queue.go b/pkg/sfu/pacer/no_queue.go index 927236394..fc3bfd4a4 100644 --- a/pkg/sfu/pacer/no_queue.go +++ b/pkg/sfu/pacer/no_queue.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/gammazero/deque" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/protocol/logger" ) @@ -32,9 +33,9 @@ type NoQueue struct { isStopped bool } -func NewNoQueue(logger logger.Logger) *NoQueue { +func NewNoQueue(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *NoQueue { n := &NoQueue{ - Base: NewBase(logger), + Base: NewBase(logger, sendSideBWE), logger: logger, wake: make(chan struct{}, 1), } diff --git a/pkg/sfu/pacer/pacer.go b/pkg/sfu/pacer/pacer.go index 83e79a376..75c9cca43 100644 --- a/pkg/sfu/pacer/pacer.go +++ b/pkg/sfu/pacer/pacer.go @@ -31,6 +31,7 @@ type Packet struct { Header *rtp.Header Extensions []ExtensionData Payload []byte + IsRTX bool AbsSendTimeExtID uint8 TransportWideExtID uint8 WriteStream webrtc.TrackLocalWriter diff --git a/pkg/sfu/pacer/pass_through.go b/pkg/sfu/pacer/pass_through.go index 
8c33d808f..fba06c792 100644 --- a/pkg/sfu/pacer/pass_through.go +++ b/pkg/sfu/pacer/pass_through.go @@ -15,6 +15,7 @@ package pacer import ( + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/protocol/logger" ) @@ -22,9 +23,9 @@ type PassThrough struct { *Base } -func NewPassThrough(logger logger.Logger) *PassThrough { +func NewPassThrough(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *PassThrough { return &PassThrough{ - Base: NewBase(logger), + Base: NewBase(logger, sendSideBWE), } } diff --git a/pkg/sfu/rtpstats/rtpstats_sender.go b/pkg/sfu/rtpstats/rtpstats_sender.go index 0c76f4352..ca9b355fb 100644 --- a/pkg/sfu/rtpstats/rtpstats_sender.go +++ b/pkg/sfu/rtpstats/rtpstats_sender.go @@ -502,7 +502,6 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt ) return } - r.extHighestSNFromRR = extHighestSNFromRR if r.srNewest != nil { @@ -515,12 +514,10 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt } } - // This is 24-bit max in the protocol. So, technically doesn't need extended type. But, done for consistency. 
- packetsLostFromRR := r.packetsLostFromRR&0xFFFF_FFFF_0000_0000 + uint64(rr.TotalLost) - if (rr.TotalLost-r.lastRR.TotalLost) < (1<<31) && rr.TotalLost < r.lastRR.TotalLost { - packetsLostFromRR += (1 << 32) + r.packetsLostFromRR = r.packetsLostFromRR&0xFFFF_FFFF_FF00_0000 + uint64(rr.TotalLost) + if ((rr.TotalLost-r.lastRR.TotalLost)&((1<<24)-1)) < (1<<23) && rr.TotalLost < r.lastRR.TotalLost { + r.packetsLostFromRR += (1 << 24) } - r.packetsLostFromRR = packetsLostFromRR if isRttChanged { r.rtt = rtt diff --git a/pkg/sfu/sendsidebwe/congestion_detector.go b/pkg/sfu/sendsidebwe/congestion_detector.go new file mode 100644 index 000000000..051d80148 --- /dev/null +++ b/pkg/sfu/sendsidebwe/congestion_detector.go @@ -0,0 +1,556 @@ +package sendsidebwe + +import ( + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/gammazero/deque" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "github.com/pion/rtcp" +) + +// ------------------------------------------------------------------------------- + +type CongestionSignalConfig struct { + MinNumberOfGroups int `yaml:"min_number_of_groups,omitempty"` + MinDuration time.Duration `yaml:"min_duration,omitempty"` +} + +var ( + DefaultEarlyWarningCongestionSignalConfig = CongestionSignalConfig{ + MinNumberOfGroups: 1, + MinDuration: 100 * time.Millisecond, + } + + DefaultCongestedCongestionSignalConfig = CongestionSignalConfig{ + MinNumberOfGroups: 3, + MinDuration: 300 * time.Millisecond, + } +) + +// ------------------------------------------------------------------------------- + +type CongestionDetectorConfig struct { + PacketGroup PacketGroupConfig `yaml:"packet_group,omitempty"` + PacketGroupMaxAge time.Duration `yaml:"packet_group_max_age,omitempty"` + + JQRMinDelay time.Duration `yaml:"jqr_min_delay,omitempty"` + DQRMaxDelay time.Duration `yaml:"dqr_max_delay,omitempty"` + + EarlyWarning CongestionSignalConfig 
`yaml:"early_warning,omitempty"` + EarlyWarningHangover time.Duration `yaml:"early_warning_hangover,omitempty"` + Congested CongestionSignalConfig `yaml:"congested,omitempty"` + CongestedHangover time.Duration `yaml:"congested_hangover,omitempty"` + + RateMeasurementWindowFullnessMin float64 `yaml:"rate_measurement_window_fullness_min,omitempty"` + RateMeasurementWindowDurationMin time.Duration `yaml:"rate_measurement_window_duration_min,omitempty"` + RateMeasurementWindowDurationMax time.Duration `yaml:"rate_measurement_window_duration_max,omitempty"` + + PeriodicCheckInterval time.Duration `yaml:"periodic_check_interval,omitempty"` + PeriodicCheckIntervalCongested time.Duration `yaml:"periodic_check_interval_congested,omitempty"` + CongestedCTRTrend ccutils.TrendDetectorConfig `yaml:"congested_ctr_trend,omitempty"` + CongestedCTREpsilon float64 `yaml:"congested_ctr_epsilon,omitempty"` +} + +var ( + defaultTrendDetectorConfigCongestedCTR = ccutils.TrendDetectorConfig{ + RequiredSamples: 5, + RequiredSamplesMin: 2, + DownwardTrendThreshold: -0.6, + DownwardTrendMaxWait: 5 * time.Second, + CollapseThreshold: 500 * time.Millisecond, + ValidityWindow: 10 * time.Second, + } + + DefaultCongestionDetectorConfig = CongestionDetectorConfig{ + PacketGroup: DefaultPacketGroupConfig, + PacketGroupMaxAge: 30 * time.Second, + JQRMinDelay: 15 * time.Millisecond, + DQRMaxDelay: 5 * time.Millisecond, + EarlyWarning: DefaultEarlyWarningCongestionSignalConfig, + EarlyWarningHangover: 500 * time.Millisecond, + Congested: DefaultCongestedCongestionSignalConfig, + CongestedHangover: 3 * time.Second, + RateMeasurementWindowFullnessMin: 0.8, + RateMeasurementWindowDurationMin: 800 * time.Millisecond, + RateMeasurementWindowDurationMax: 2 * time.Second, + PeriodicCheckInterval: 2 * time.Second, + PeriodicCheckIntervalCongested: 200 * time.Millisecond, + CongestedCTRTrend: defaultTrendDetectorConfigCongestedCTR, + CongestedCTREpsilon: 0.05, + } +) + +// 
------------------------------------------------------------------------------- + +type feedbackReport struct { + at time.Time + report *rtcp.TransportLayerCC +} + +type congestionDetectorParams struct { + Config CongestionDetectorConfig + Logger logger.Logger +} + +type congestionDetector struct { + params congestionDetectorParams + + lock sync.RWMutex + feedbackReports deque.Deque[feedbackReport] + + *packetTracker + twccFeedback *twccFeedback + + packetGroups []*packetGroup + + wake chan struct{} + stop core.Fuse + + estimatedAvailableChannelCapacity int64 + congestionState CongestionState + congestionStateSwitchedAt time.Time + congestedCTRTrend *ccutils.TrendDetector[float64] + + onCongestionStateChange func(congestionState CongestionState, estimatedAvailableChannelCapacity int64) +} + +func newCongestionDetector(params congestionDetectorParams) *congestionDetector { + c := &congestionDetector{ + params: params, + packetTracker: newPacketTracker(packetTrackerParams{Logger: params.Logger}), + twccFeedback: newTWCCFeedback(twccFeedbackParams{Logger: params.Logger}), + wake: make(chan struct{}, 1), + estimatedAvailableChannelCapacity: 100_000_000, + } + + c.feedbackReports.SetMinCapacity(3) + + go c.worker() + return c +} + +func (c *congestionDetector) Stop() { + c.stop.Once(func() { + close(c.wake) + }) +} + +func (c *congestionDetector) OnCongestionStateChange(f func(congestionState CongestionState, estimatedAvailableChannelCapacity int64)) { + c.lock.Lock() + defer c.lock.Unlock() + + c.onCongestionStateChange = f +} + +func (c *congestionDetector) GetCongestionState() CongestionState { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.congestionState +} + +func (c *congestionDetector) updateCongestionState(state CongestionState) { + c.lock.Lock() + c.params.Logger.Infow( + "congestion state change", + "from", c.congestionState, + "to", state, + "estimatedAvailableChannelCapacity", c.estimatedAvailableChannelCapacity, + ) + + prevState := 
c.congestionState + c.congestionState = state + + onCongestionStateChange := c.onCongestionStateChange + estimatedAvailableChannelCapacity := c.estimatedAvailableChannelCapacity + c.lock.Unlock() + + if onCongestionStateChange != nil { + onCongestionStateChange(state, estimatedAvailableChannelCapacity) + } + + // when in congested state, monitor changes in captured traffic ratio (CTR) + // to ensure allocations are in line with latest estimates, it is possible that + // the estimate is incorrect when congestion starts and the allocation may be + // sub-optimal and not enough to reduce/relieve congestion, by monitoring CTR + // on a continuous basis allocations can be adjusted in the direction of + // reducing/relieving congestion + if state == CongestionStateCongested && prevState != CongestionStateCongested { + c.congestedCTRTrend = ccutils.NewTrendDetector[float64](ccutils.TrendDetectorParams{ + Name: "ssbwe-estimate", + Logger: c.params.Logger, + Config: c.params.Config.CongestedCTRTrend, + }) + } else if state != CongestionStateCongested { + c.congestedCTRTrend = nil + } +} + +func (c *congestionDetector) GetEstimatedAvailableChannelCapacity() int64 { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.estimatedAvailableChannelCapacity +} + +func (c *congestionDetector) HandleRTCP(report *rtcp.TransportLayerCC) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.stop.IsBroken() { + return + } + + c.feedbackReports.PushBack(feedbackReport{mono.Now(), report}) + + // notify worker of a new feedback + select { + case c.wake <- struct{}{}: + default: + } +} + +func (c *congestionDetector) prunePacketGroups() { + if len(c.packetGroups) == 0 { + return + } + + threshold := c.packetGroups[len(c.packetGroups)-1].MinSendTime() - c.params.Config.PacketGroupMaxAge.Microseconds() + for idx, pg := range c.packetGroups { + if mst := pg.MinSendTime(); mst < threshold { + c.packetGroups = c.packetGroups[idx+1:] + return + } + } +} + +func (c *congestionDetector) 
isCongestionSignalTriggered() (bool, bool) { + earlyWarningTriggered := false + congestedTriggered := false + + numGroups := 0 + duration := int64(0) + for idx := len(c.packetGroups) - 1; idx >= 0; idx-- { + pg := c.packetGroups[idx] + pqd, ok := pg.PropagatedQueuingDelay() + if !ok { + continue + } + + if pqd > c.params.Config.JQRMinDelay.Microseconds() { + // JQR group builds up congestion signal + numGroups++ + duration += pg.SendDuration() + } + + // INDETERMINATE group is treated as a no-op + + if pqd < c.params.Config.DQRMaxDelay.Microseconds() { + // any DQR group breaks the continuity + return earlyWarningTriggered, congestedTriggered + } + + if numGroups >= c.params.Config.EarlyWarning.MinNumberOfGroups && duration >= c.params.Config.EarlyWarning.MinDuration.Microseconds() { + earlyWarningTriggered = true + } + if numGroups >= c.params.Config.Congested.MinNumberOfGroups && duration >= c.params.Config.Congested.MinDuration.Microseconds() { + congestedTriggered = true + } + if earlyWarningTriggered && congestedTriggered { + break + } + } + + return earlyWarningTriggered, congestedTriggered +} + +func (c *congestionDetector) congestionDetectionStateMachine() { + state := c.GetCongestionState() + newState := state + + earlyWarningTriggered, congestedTriggered := c.isCongestionSignalTriggered() + + switch state { + case CongestionStateNone: + if congestedTriggered { + c.params.Logger.Warnw("invalid congested state transition", nil, "from", state) + } + if earlyWarningTriggered { + newState = CongestionStateEarlyWarning + } + + case CongestionStateEarlyWarning: + if congestedTriggered { + newState = CongestionStateCongested + } else if !earlyWarningTriggered { + newState = CongestionStateEarlyWarningHangover + } + + case CongestionStateEarlyWarningHangover: + if congestedTriggered { + c.params.Logger.Warnw("invalid congested state transition", nil, "from", state) + } + if earlyWarningTriggered { + newState = CongestionStateEarlyWarning + } else if 
time.Since(c.congestionStateSwitchedAt) >= c.params.Config.EarlyWarningHangover { + newState = CongestionStateNone + } + + case CongestionStateCongested: + if !congestedTriggered { + newState = CongestionStateCongestedHangover + } + + case CongestionStateCongestedHangover: + if congestedTriggered { + c.params.Logger.Warnw("invalid congested state transition", nil, "from", state) + } + if earlyWarningTriggered { + newState = CongestionStateEarlyWarning + } else if time.Since(c.congestionStateSwitchedAt) >= c.params.Config.CongestedHangover { + newState = CongestionStateNone + } + } + + c.estimateAvailableChannelCapacity() + + // update after running the above estimate as state change callback includes the estimated available channel capacity + if newState != state { + c.congestionStateSwitchedAt = mono.Now() + c.updateCongestionState(newState) + } +} + +func (c *congestionDetector) updateTrend(ctr float64) { + if c.congestedCTRTrend == nil { + return + } + + // quantise the CTR to filter out small changes + c.congestedCTRTrend.AddValue(float64(int((ctr+(c.params.Config.CongestedCTREpsilon/2))/c.params.Config.CongestedCTREpsilon)) * c.params.Config.CongestedCTREpsilon) + + if c.congestedCTRTrend.GetDirection() == ccutils.TrendDirectionDownward { + c.params.Logger.Infow("captured traffic ratio is trending downward", "channel", c.congestedCTRTrend.ToString()) + + c.lock.RLock() + state := c.congestionState + estimatedAvailableChannelCapacity := c.estimatedAvailableChannelCapacity + onCongestionStateChange := c.onCongestionStateChange + c.lock.RUnlock() + + if onCongestionStateChange != nil { + onCongestionStateChange(state, estimatedAvailableChannelCapacity) + } + + // reset to get new set of samples for next trend + c.congestedCTRTrend = ccutils.NewTrendDetector[float64](ccutils.TrendDetectorParams{ + Name: "ssbwe-estimate", + Logger: c.params.Logger, + Config: c.params.Config.CongestedCTRTrend, + }) + } +} + +func (c *congestionDetector) 
estimateAvailableChannelCapacity() { + if len(c.packetGroups) == 0 { + return + } + + totalDuration := int64(0) + totalBytes := 0 + threshold := c.packetGroups[len(c.packetGroups)-1].MinSendTime() - c.params.Config.RateMeasurementWindowDurationMax.Microseconds() + for idx := len(c.packetGroups) - 1; idx >= 0; idx-- { + pg := c.packetGroups[idx] + mst, dur, nbytes, fullness := pg.Traffic() + if mst < threshold { + break + } + + if fullness < c.params.Config.RateMeasurementWindowFullnessMin { + continue + } + + totalDuration += dur + totalBytes += nbytes + } + + if totalDuration >= c.params.Config.RateMeasurementWindowDurationMin.Microseconds() { + c.lock.Lock() + c.estimatedAvailableChannelCapacity = int64(totalBytes) * 8 * 1e6 / totalDuration + c.lock.Unlock() + } else { + c.params.Logger.Infow("not enough data to estimate available channel capacity", "totalDuration", totalDuration) + } +} + +func (c *congestionDetector) processFeedbackReport(fbr feedbackReport) { + recvRefTime, isOutOfOrder := c.twccFeedback.ProcessReport(fbr.report, fbr.at) + if isOutOfOrder { + c.params.Logger.Infow("received out-of-order feedback report") + } + + if len(c.packetGroups) == 0 { + c.packetGroups = append( + c.packetGroups, + newPacketGroup( + packetGroupParams{ + Config: c.params.Config.PacketGroup, + Logger: c.params.Logger, + }, + 0, + ), + ) + } + + pg := c.packetGroups[len(c.packetGroups)-1] + trackPacketGroup := func(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) { + if pi == nil { + return + } + + err := pg.Add(pi, sendDelta, recvDelta, isLost) + if err == nil { + return + } + + if err == errGroupFinalized { + // previous group ended, start a new group + c.updateTrend(pg.CapturedTrafficRatio()) + + // SSBWE-REMOVE c.params.Logger.Infow("packet group done", "group", pg, "numGroups", len(c.packetGroups)) // SSBWE-REMOVE + pqd, _ := pg.PropagatedQueuingDelay() + pg = newPacketGroup( + packetGroupParams{ + Config: c.params.Config.PacketGroup, + Logger: c.params.Logger, 
+ }, + pqd, + ) + c.packetGroups = append(c.packetGroups, pg) + + pg.Add(pi, sendDelta, recvDelta, isLost) + return + } + + // try an older group + for idx := len(c.packetGroups) - 2; idx >= 0; idx-- { + opg := c.packetGroups[idx] + if err := opg.Add(pi, sendDelta, recvDelta, isLost); err == nil { + return + } else if err == errGroupFinalized { + c.params.Logger.Infow("unpected finalized group", "packetInfo", pi, "packetGroup", opg) + } + } + } + + // 1. go through the TWCC feedback report and record receive time as reported by remote + // 2. process acknowledged packets and group them + // + // losses are not recorded if a feedback report is completely lost. + // RFC recommends treating lost reports by ignoring packets that would have been in it. + // ----------------------------------------------------------------------------------- + // | From a congestion control perspective, lost feedback messages are | + // | handled by ignoring packets which would have been reported as lost or | + // | received in the lost feedback messages. This behavior is similar to | + // | how a lost RTCP receiver report is handled. 
| + // ----------------------------------------------------------------------------------- + // Reference: https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-4 + sequenceNumber := fbr.report.BaseSequenceNumber + endSequenceNumberExclusive := sequenceNumber + fbr.report.PacketStatusCount + deltaIdx := 0 + for _, chunk := range fbr.report.PacketChunks { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + switch chunk := chunk.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + recvTime := int64(0) + isLost := false + if chunk.PacketStatusSymbol != rtcp.TypeTCCPacketNotReceived { + recvRefTime += fbr.report.RecvDeltas[deltaIdx].Delta + deltaIdx++ + + recvTime = recvRefTime + } else { + isLost = true + } + pi, sendDelta, recvDelta := c.packetTracker.RecordPacketIndicationFromRemote(sequenceNumber, recvTime) + if pi.sendTime != 0 { + trackPacketGroup(&pi, sendDelta, recvDelta, isLost) + } + sequenceNumber++ + } + + case *rtcp.StatusVectorChunk: + for _, symbol := range chunk.SymbolList { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + recvTime := int64(0) + isLost := false + if symbol != rtcp.TypeTCCPacketNotReceived { + recvRefTime += fbr.report.RecvDeltas[deltaIdx].Delta + deltaIdx++ + + recvTime = recvRefTime + } else { + isLost = true + } + pi, sendDelta, recvDelta := c.packetTracker.RecordPacketIndicationFromRemote(sequenceNumber, recvTime) + if pi.sendTime != 0 { + trackPacketGroup(&pi, sendDelta, recvDelta, isLost) + } + sequenceNumber++ + } + } + } + + c.prunePacketGroups() + c.congestionDetectionStateMachine() +} + +func (c *congestionDetector) worker() { + ticker := time.NewTicker(c.params.Config.PeriodicCheckInterval) + defer ticker.Stop() + + for { + select { + case <-c.wake: + for { + c.lock.Lock() + if c.feedbackReports.Len() == 0 { + c.lock.Unlock() + break + } + fbReport := 
c.feedbackReports.PopFront() + c.lock.Unlock() + + c.processFeedbackReport(fbReport) + } + + if c.GetCongestionState() == CongestionStateCongested { + ticker.Reset(c.params.Config.PeriodicCheckIntervalCongested) + } else { + ticker.Reset(c.params.Config.PeriodicCheckInterval) + } + + case <-ticker.C: + c.prunePacketGroups() + c.congestionDetectionStateMachine() + + case <-c.stop.Watch(): + return + } + } +} diff --git a/pkg/sfu/sendsidebwe/packet_group.go b/pkg/sfu/sendsidebwe/packet_group.go new file mode 100644 index 000000000..636eef189 --- /dev/null +++ b/pkg/sfu/sendsidebwe/packet_group.go @@ -0,0 +1,338 @@ +package sendsidebwe + +import ( + "errors" + "math" + "time" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------------- + +var ( + errGroupFinalized = errors.New("packet group is finalized") + errOldPacket = errors.New("packet is older than packet group start") +) + +// ------------------------------------------------------------- + +type PacketGroupConfig struct { + MinPackets int `yaml:"min_packets,omitempty"` + MaxWindowDuration time.Duration `yaml:"max_window_duration,omitempty"` + + // should have at least this fraction of `MinPackets` for loss penalty consideration + LossPenaltyMinPacketsRatio float64 `yaml:"loss_penalty_min_packet_ratio,omitempty"` + LossPenaltyFactor float64 `yaml:"loss_penalty_factor,omitempty"` +} + +var ( + DefaultPacketGroupConfig = PacketGroupConfig{ + MinPackets: 20, + MaxWindowDuration: 500 * time.Millisecond, + LossPenaltyMinPacketsRatio: 0.5, + LossPenaltyFactor: 0.25, + } +) + +// ------------------------------------------------------------- + +type stat struct { + numPackets int + numBytes int +} + +func (s *stat) add(size int) { + s.numPackets++ + s.numBytes += size +} + +func (s *stat) remove(size int) { + s.numPackets-- + s.numBytes -= size +} + +func (s *stat) getNumPackets() int { + return 
s.numPackets +} + +func (s *stat) getNumBytes() int { + return s.numBytes +} + +func (s stat) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddInt("numPackets", s.numPackets) + e.AddInt("numBytes", s.numBytes) + return nil +} + +// ------------------------------------------------------------- + +type classStat struct { + primary stat + rtx stat +} + +func (c *classStat) add(size int, isRTX bool) { + if isRTX { + c.rtx.add(size) + } else { + c.primary.add(size) + } +} + +func (c *classStat) remove(size int, isRTX bool) { + if isRTX { + c.rtx.remove(size) + } else { + c.primary.remove(size) + } +} + +func (c *classStat) numPackets() int { + return c.primary.getNumPackets() + c.rtx.getNumPackets() +} + +func (c *classStat) numBytes() int { + return c.primary.getNumBytes() + c.rtx.getNumBytes() +} + +func (c classStat) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddObject("primary", c.primary) + e.AddObject("rtx", c.rtx) + return nil +} + +// ------------------------------------------------------------- + +type packetGroupParams struct { + Config PacketGroupConfig + Logger logger.Logger +} + +type packetGroup struct { + params packetGroupParams + + minSequenceNumber uint64 + maxSequenceNumber uint64 + + minSendTime int64 + maxSendTime int64 + + minRecvTime int64 // for information only + maxRecvTime int64 // for information only + + acked classStat + lost classStat + snBitmap *utils.Bitmap[uint64] + + aggregateSendDelta int64 + aggregateRecvDelta int64 + queuingDelay int64 + + isFinalized bool +} + +func newPacketGroup(params packetGroupParams, queuingDelay int64) *packetGroup { + return &packetGroup{ + params: params, + queuingDelay: queuingDelay, + snBitmap: utils.NewBitmap[uint64](params.Config.MinPackets), + } +} + +func (p *packetGroup) Add(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) error { + if isLost { + return p.lostPacket(pi) + } + + if err := p.inGroup(pi.sequenceNumber); err != nil { + return err + } + + if p.minSequenceNumber 
== 0 || pi.sequenceNumber < p.minSequenceNumber { + p.minSequenceNumber = pi.sequenceNumber + } + p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber) + + if p.minSendTime == 0 || (pi.sendTime-sendDelta) < p.minSendTime { + p.minSendTime = pi.sendTime - sendDelta + } + p.maxSendTime = max(p.maxSendTime, pi.sendTime) + + if p.minRecvTime == 0 || (pi.recvTime-recvDelta) < p.minRecvTime { + p.minRecvTime = pi.recvTime - recvDelta + } + p.maxRecvTime = max(p.maxRecvTime, pi.recvTime) + + p.acked.add(int(pi.size), pi.isRTX) + if p.snBitmap.IsSet(pi.sequenceNumber - p.minSequenceNumber) { + // an earlier packet reported as lost has been received + p.snBitmap.Clear(pi.sequenceNumber - p.minSequenceNumber) + p.lost.remove(int(pi.size), pi.isRTX) + } + + // note that out-of-order deliveries will amplify the queueing delay. + // for e.g. a, b, c getting delivered as a, c, b. + // let us say packets are delivered with interval of `x` + // send delta aggregate will go up by x((a, c) = 2x + (c, b) -1x) + // recv delta aggregate will go up by 3x((a, c) = 2x + (c, b) 1x) + p.aggregateSendDelta += sendDelta + p.aggregateRecvDelta += recvDelta + + if p.acked.numPackets() == p.params.Config.MinPackets || (pi.sendTime-p.minSendTime) > p.params.Config.MaxWindowDuration.Microseconds() { + p.isFinalized = true + } + return nil +} + +func (p *packetGroup) lostPacket(pi *packetInfo) error { + if pi.recvTime != 0 { + // previously received packet, so not lost + return nil + } + + if err := p.inGroup(pi.sequenceNumber); err != nil { + return err + } + + if p.minSequenceNumber == 0 || pi.sequenceNumber < p.minSequenceNumber { + p.minSequenceNumber = pi.sequenceNumber + } + p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber) + p.snBitmap.Set(pi.sequenceNumber - p.minSequenceNumber) + + p.lost.add(int(pi.size), pi.isRTX) + return nil +} + +func (p *packetGroup) MinSendTime() int64 { + return p.minSendTime +} + +func (p *packetGroup) PropagatedQueuingDelay() (int64, 
bool) { + if !p.isFinalized { + return 0, false + } + + if p.queuingDelay+p.aggregateRecvDelta-p.aggregateSendDelta > 0 { + return p.queuingDelay + p.aggregateRecvDelta - p.aggregateSendDelta, true + } + + return max(0, p.aggregateRecvDelta-p.aggregateSendDelta), true +} + +func (p *packetGroup) SendDuration() int64 { + if !p.isFinalized { + return 0 + } + + return p.maxSendTime - p.minSendTime +} + +func (p *packetGroup) CapturedTrafficRatio() float64 { + capturedTrafficRatio := float64(0.0) + if p.aggregateRecvDelta != 0 { + // apply a penalty for lost packets, + // the rationale being packet dropping is a strategy to relieve congestion + // and if they were not dropped, they would have increased queuing delay, + // as it is not possible to know the reason for the losses, + // apply a small penalty to receive delta aggregate to simulate those packets + // build up queuing delay. + // + // note that it is applied only for determining rate and + // not while determining queuing region, adding synthetic delays + // like this could cause queuing region to be stuck in JQR + capturedTrafficRatio = float64(p.aggregateSendDelta) / float64(p.aggregateRecvDelta+p.getLossPenalty()) + } + return min(1.0, capturedTrafficRatio) +} + +func (p *packetGroup) Traffic() (int64, int64, int, float64) { + numBytes := int(float64(p.acked.numBytes()) * p.CapturedTrafficRatio()) + + fullness := max( + float64(p.acked.numPackets())/float64(p.params.Config.MinPackets), + float64(p.maxSendTime-p.minSendTime)/float64(p.params.Config.MaxWindowDuration.Microseconds()), + ) + + return p.minSendTime, p.maxSendTime - p.minSendTime, numBytes, fullness +} + +func (p *packetGroup) MarshalLogObject(e zapcore.ObjectEncoder) error { + if p == nil { + return nil + } + + e.AddInt64("minSendTime", p.minSendTime) + e.AddInt64("maxSendTime", p.maxSendTime) + sendDuration := time.Duration((p.maxSendTime - p.minSendTime) * 1000) + e.AddDuration("sendDuration", sendDuration) + + e.AddInt64("minRecvTime", 
p.minRecvTime) + e.AddInt64("maxRecvTime", p.maxRecvTime) + recvDuration := time.Duration((p.maxRecvTime - p.minRecvTime) * 1000) + e.AddDuration("recvDuration", recvDuration) + + e.AddObject("acked", p.acked) + e.AddObject("lost", p.lost) + + sendBitrate := float64(0) + if sendDuration != 0 { + sendBitrate = float64(p.acked.numBytes()*8) / sendDuration.Seconds() + e.AddFloat64("sendBitrate", sendBitrate) + } + + recvBitrate := float64(0) + if recvDuration != 0 { + recvBitrate = float64(p.acked.numBytes()*8) / recvDuration.Seconds() + e.AddFloat64("recvBitrate", recvBitrate) + } + + e.AddInt64("aggregateSendDelta", p.aggregateSendDelta) + e.AddInt64("aggregateRecvDelta", p.aggregateRecvDelta) + e.AddInt64("queuingDelay", p.queuingDelay) + e.AddInt64("groupDelay", p.aggregateRecvDelta-p.aggregateSendDelta) + e.AddFloat64("lossRatio", float64(p.lost.numPackets())/float64(p.acked.numPackets()+p.lost.numPackets())) + e.AddInt64("lossPenalty", p.getLossPenalty()) + capturedTrafficRatio := p.CapturedTrafficRatio() + e.AddFloat64("capturedTrafficRatio", capturedTrafficRatio) + e.AddFloat64("estimatedAvailableChannelCapacity", sendBitrate*capturedTrafficRatio) + + e.AddBool("isFinalized", p.isFinalized) + return nil +} + +func (p *packetGroup) inGroup(sequenceNumber uint64) error { + if p.isFinalized && sequenceNumber > p.maxSequenceNumber { + return errGroupFinalized + } + + if sequenceNumber < p.minSequenceNumber { + return errOldPacket + } + + return nil +} + +func (p *packetGroup) getLossPenalty() int64 { + if p.aggregateRecvDelta == 0 { + return 0 + } + + lostPackets := p.lost.numPackets() + totalPackets := float64(lostPackets + p.acked.numPackets()) + if totalPackets < float64(p.params.Config.MinPackets)*p.params.Config.LossPenaltyMinPacketsRatio { + return 0 + } + + // Log10 is used to give higher weight for the same loss ratio at higher packet rates, + // for e.g. 
with a penalty factor of 0.25 + // - 10% loss at 20 total packets = 0.1 * log10(20) * 0.25 = 0.032 + // - 10% loss at 100 total packets = 0.1 * log10(100) * 0.25 = 0.05 + // - 10% loss at 1000 total packets = 0.1 * log10(1000) * 0.25 = 0.075 + lossRatio := float64(lostPackets) / totalPackets + return int64(float64(p.aggregateRecvDelta) * lossRatio * math.Log10(totalPackets) * p.params.Config.LossPenaltyFactor) +} diff --git a/pkg/sfu/sendsidebwe/packet_info.go b/pkg/sfu/sendsidebwe/packet_info.go new file mode 100644 index 000000000..58bff7a58 --- /dev/null +++ b/pkg/sfu/sendsidebwe/packet_info.go @@ -0,0 +1,36 @@ +package sendsidebwe + +import ( + "go.uber.org/zap/zapcore" +) + +type packetInfo struct { + sequenceNumber uint64 + sendTime int64 + recvTime int64 + size uint16 + isRTX bool + // SSBWE-TODO: possibly add the following fields - pertaining to this packet, + // idea is to be able to figure out probe start/end and check for bitrate in that window +} + +func (pi *packetInfo) Reset(sequenceNumber uint64) { + pi.sequenceNumber = sequenceNumber + pi.sendTime = 0 + pi.recvTime = 0 + pi.size = 0 + pi.isRTX = false +} + +func (pi *packetInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + if pi == nil { + return nil + } + + e.AddUint64("sequenceNumber", pi.sequenceNumber) + e.AddInt64("sendTime", pi.sendTime) + e.AddInt64("recvTime", pi.recvTime) + e.AddUint16("size", pi.size) + e.AddBool("isRTX", pi.isRTX) + return nil +} diff --git a/pkg/sfu/sendsidebwe/packet_tracker.go b/pkg/sfu/sendsidebwe/packet_tracker.go new file mode 100644 index 000000000..183181265 --- /dev/null +++ b/pkg/sfu/sendsidebwe/packet_tracker.go @@ -0,0 +1,113 @@ +package sendsidebwe + +import ( + "errors" + "math/rand" + "sync" + "time" + + "github.com/livekit/protocol/logger" +) + +// ------------------------------------------------------------------------------- + +var ( + errNoPacketInRange = errors.New("no packet in range") +) + +// 
------------------------------------------------------------------------------- + +type packetTrackerParams struct { + Logger logger.Logger +} + +type packetTracker struct { + params packetTrackerParams + + lock sync.Mutex + + sequenceNumber uint64 + + baseSendTime int64 + packetInfos [2048]packetInfo + + baseRecvTime int64 + piLastRecv *packetInfo +} + +func newPacketTracker(params packetTrackerParams) *packetTracker { + return &packetTracker{ + params: params, + sequenceNumber: uint64(rand.Intn(1<<14)) + uint64(1<<15), // a random number in third quartile of sequence number space + } +} + +// SSBWE-TODO: this potentially needs to take isProbe as argument? +func (p *packetTracker) RecordPacketSendAndGetSequenceNumber(at time.Time, size int, isRTX bool) uint16 { + p.lock.Lock() + defer p.lock.Unlock() + + sendTime := at.UnixMicro() + if p.baseSendTime == 0 { + p.baseSendTime = sendTime + } + + pi := p.getPacketInfo(uint16(p.sequenceNumber)) + pi.sequenceNumber = p.sequenceNumber + pi.sendTime = sendTime - p.baseSendTime + pi.recvTime = 0 + pi.size = uint16(size) + pi.isRTX = isRTX + // SSBWE-REMOVE p.params.Logger.Infow("packet sent", "packetInfo", pi) // SSBWE-REMOVE + + p.sequenceNumber++ + + // extreme case of wrap around before receiving any feedback + if pi == p.piLastRecv { + p.piLastRecv = nil + } + + return uint16(pi.sequenceNumber) +} + +func (p *packetTracker) RecordPacketIndicationFromRemote(sn uint16, recvTime int64) (piRecv packetInfo, sendDelta, recvDelta int64) { + p.lock.Lock() + defer p.lock.Unlock() + + pi := p.getPacketInfoExisting(sn) + if pi == nil { + return + } + + if recvTime == 0 { + // maybe lost OR already received but reported lost in a later report + piRecv = *pi + return + } + + if p.baseRecvTime == 0 { + p.baseRecvTime = recvTime + p.piLastRecv = pi + } + + pi.recvTime = recvTime - p.baseRecvTime + piRecv = *pi + if p.piLastRecv != nil { + sendDelta, recvDelta = pi.sendTime-p.piLastRecv.sendTime, pi.recvTime-p.piLastRecv.recvTime + } + 
p.piLastRecv = pi + return +} + +func (p *packetTracker) getPacketInfo(sn uint16) *packetInfo { + return &p.packetInfos[int(sn)%len(p.packetInfos)] +} + +func (p *packetTracker) getPacketInfoExisting(sn uint16) *packetInfo { + pi := &p.packetInfos[int(sn)%len(p.packetInfos)] + if uint16(pi.sequenceNumber) == sn { + return pi + } + + return nil +} diff --git a/pkg/sfu/sendsidebwe/send_side_bwe.go b/pkg/sfu/sendsidebwe/send_side_bwe.go new file mode 100644 index 000000000..f48f9af56 --- /dev/null +++ b/pkg/sfu/sendsidebwe/send_side_bwe.go @@ -0,0 +1,105 @@ +package sendsidebwe + +import ( + "fmt" + + "github.com/livekit/protocol/logger" +) + +// +// Based on a simplified/modified version of JitterPath paper +// (https://homepage.iis.sinica.edu.tw/papers/lcs/2114-F.pdf) +// +// TWCC feedback is used to calculate delta one-way-delay. +// It is accumulated/propagated to determine in which region +// groups of packets are operating in. +// +// In simplified terms, +// o JQR (Join Queuing Region) is when channel is congested. +// o DQR (Disjoint Queuing Region) is when channel is not. +// +// Packets are grouped and thresholds applied to smooth over +// small variations. For example, in the paper, +// if propagated_queuing_delay + delta_one_way_delay > 0 { +// possibly_operating_in_jqr +// } +// But, in this implementation it is checked at packet group level, +// i. e. using queuing delay and aggregated delta one-way-delay of +// the group and a minimum value threshold is applied before declaring +// that a group is in JQR. +// +// There is also hysteresis to make transitions smoother, i.e. if the +// metric is above a certain threshold, it is JQR and it is DQR only if it +// is below a certain value and the gap in between those two thresholds +// are treated as indeterminate groups. 
+// + +// --------------------------------------------------------------------------- + +type CongestionState int + +const ( + CongestionStateNone CongestionState = iota + CongestionStateEarlyWarning + CongestionStateEarlyWarningHangover + CongestionStateCongested + CongestionStateCongestedHangover +) + +func (c CongestionState) String() string { + switch c { + case CongestionStateNone: + return "NONE" + case CongestionStateEarlyWarning: + return "EARLY_WARNING" + case CongestionStateEarlyWarningHangover: + return "EARLY_WARNING_HANGOVER" + case CongestionStateCongested: + return "CONGESTED" + case CongestionStateCongestedHangover: + return "CONGESTED_HANGOVER" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// --------------------------------------------------------------------------- + +type SendSideBWEConfig struct { + CongestionDetector CongestionDetectorConfig `yaml:"congestion_detector,omitempty"` +} + +var ( + DefaultSendSideBWEConfig = SendSideBWEConfig{ + CongestionDetector: DefaultCongestionDetectorConfig, + } +) + +// --------------------------------------------------------------------------- + +type SendSideBWEParams struct { + Config SendSideBWEConfig + Logger logger.Logger +} + +type SendSideBWE struct { + params SendSideBWEParams + + *congestionDetector +} + +func NewSendSideBWE(params SendSideBWEParams) *SendSideBWE { + return &SendSideBWE{ + params: params, + congestionDetector: newCongestionDetector(congestionDetectorParams{ + Config: params.Config.CongestionDetector, + Logger: params.Logger, + }), + } +} + +func (s *SendSideBWE) Stop() { + s.congestionDetector.Stop() +} + +// ------------------------------------------------ diff --git a/pkg/sfu/sendsidebwe/twcc_feedback.go b/pkg/sfu/sendsidebwe/twcc_feedback.go new file mode 100644 index 000000000..23eca2637 --- /dev/null +++ b/pkg/sfu/sendsidebwe/twcc_feedback.go @@ -0,0 +1,116 @@ +package sendsidebwe + +import ( + "errors" + "time" + + "github.com/livekit/protocol/logger" + 
"github.com/pion/rtcp" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------ + +const ( + cOutlierReportFactor = 3 + cEstimatedFeedbackIntervalAlpha = float64(0.9) + + cReferenceTimeMask = (1 << 24) - 1 + cReferenceTimeResolution = 64 // 64 ms +) + +// ------------------------------------------------------ + +var ( + errFeedbackReportOutOfOrder = errors.New("feedback report out-of-order") +) + +// ------------------------------------------------------ + +type twccFeedbackParams struct { + Logger logger.Logger +} + +type twccFeedback struct { + params twccFeedbackParams + + lastFeedbackTime time.Time + estimatedFeedbackInterval time.Duration + numReports int + numReportsOutOfOrder int + + highestFeedbackCount uint8 + + cycles int64 + highestReferenceTime uint32 +} + +func newTWCCFeedback(params twccFeedbackParams) *twccFeedback { + return &twccFeedback{ + params: params, + } +} + +func (t *twccFeedback) ProcessReport(report *rtcp.TransportLayerCC, at time.Time) (int64, bool) { + // SSBWE-REMOVE t.params.Logger.Infow("TWCC feedback", "report", report.String()) // SSBWE-REMOVE + t.numReports++ + if t.lastFeedbackTime.IsZero() { + t.lastFeedbackTime = at + t.highestReferenceTime = report.ReferenceTime + t.highestFeedbackCount = report.FbPktCount + return (t.cycles + int64(report.ReferenceTime)) * cReferenceTimeResolution * 1000, false + } + + isOutOfOrder := false + if (report.FbPktCount - t.highestFeedbackCount) > (1 << 7) { + t.numReportsOutOfOrder++ + isOutOfOrder = true + } + + // reference time wrap around handling + var referenceTime int64 + if (report.ReferenceTime-t.highestReferenceTime)&cReferenceTimeMask < (1 << 23) { + if report.ReferenceTime < t.highestReferenceTime { + t.cycles += (1 << 24) + } + t.highestReferenceTime = report.ReferenceTime + referenceTime = t.cycles + int64(report.ReferenceTime) + } else { + cycles := t.cycles + if report.ReferenceTime > t.highestReferenceTime && cycles >= (1<<24) { + cycles -= (1 << 
24) + } + referenceTime = cycles + int64(report.ReferenceTime) + } + + if !isOutOfOrder { + sinceLast := at.Sub(t.lastFeedbackTime) + // SSBWE-REMOVE t.params.Logger.Infow("report received", "at", at, "sinceLast", sinceLast, "pktCount", report.FbPktCount) // SSBWE-REMOVE + if t.estimatedFeedbackInterval == 0 { + t.estimatedFeedbackInterval = sinceLast + } else { + // filter out outliers from estimate + if sinceLast > t.estimatedFeedbackInterval/cOutlierReportFactor && sinceLast < cOutlierReportFactor*t.estimatedFeedbackInterval { + // smoothed version of inter feedback interval + t.estimatedFeedbackInterval = time.Duration(cEstimatedFeedbackIntervalAlpha*float64(t.estimatedFeedbackInterval) + (1.0-cEstimatedFeedbackIntervalAlpha)*float64(sinceLast)) + } + } + t.lastFeedbackTime = at + t.highestFeedbackCount = report.FbPktCount + } + + return referenceTime * cReferenceTimeResolution * 1000, isOutOfOrder +} + +func (t *twccFeedback) MarshalLogObject(e zapcore.ObjectEncoder) error { + if t == nil { + return nil + } + + e.AddTime("lastFeedbackTime", t.lastFeedbackTime) + e.AddDuration("estimatedFeedbackInterval", t.estimatedFeedbackInterval) + e.AddInt("numReports", t.numReports) + e.AddInt("numReportsOutOfOrder", t.numReportsOutOfOrder) + e.AddInt64("cycles", t.cycles/(1<<24)) + return nil +} diff --git a/pkg/sfu/streamallocator/channelobserver.go b/pkg/sfu/streamallocator/channelobserver.go index 70b67a3cf..a9a0d2221 100644 --- a/pkg/sfu/streamallocator/channelobserver.go +++ b/pkg/sfu/streamallocator/channelobserver.go @@ -16,7 +16,9 @@ package streamallocator import ( "fmt" + "time" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" "github.com/livekit/protocol/logger" ) @@ -69,18 +71,36 @@ func (c ChannelCongestionReason) String() string { // ------------------------------------------------ type ChannelObserverConfig struct { - Estimate TrendDetectorConfig `yaml:"estimate,omitempty"` - Nack NackTrackerConfig `yaml:"nack,omitempty"` + Estimate 
ccutils.TrendDetectorConfig `yaml:"estimate,omitempty"` + Nack NackTrackerConfig `yaml:"nack,omitempty"` } var ( + defaultTrendDetectorConfigProbe = ccutils.TrendDetectorConfig{ + RequiredSamples: 3, + RequiredSamplesMin: 3, + DownwardTrendThreshold: 0.0, + DownwardTrendMaxWait: 5 * time.Second, + CollapseThreshold: 0, + ValidityWindow: 10 * time.Second, + } + DefaultChannelObserverConfigProbe = ChannelObserverConfig{ - Estimate: DefaultTrendDetectorConfigProbe, + Estimate: defaultTrendDetectorConfigProbe, Nack: DefaultNackTrackerConfigProbe, } + defaultTrendDetectorConfigNonProbe = ccutils.TrendDetectorConfig{ + RequiredSamples: 12, + RequiredSamplesMin: 8, + DownwardTrendThreshold: -0.6, + DownwardTrendMaxWait: 5 * time.Second, + CollapseThreshold: 500 * time.Millisecond, + ValidityWindow: 10 * time.Second, + } + DefaultChannelObserverConfigNonProbe = ChannelObserverConfig{ - Estimate: DefaultTrendDetectorConfigNonProbe, + Estimate: defaultTrendDetectorConfigNonProbe, Nack: DefaultNackTrackerConfigNonProbe, } ) @@ -96,7 +116,7 @@ type ChannelObserver struct { params ChannelObserverParams logger logger.Logger - estimateTrend *TrendDetector + estimateTrend *ccutils.TrendDetector[int64] nackTracker *NackTracker } @@ -104,7 +124,7 @@ func NewChannelObserver(params ChannelObserverParams, logger logger.Logger) *Cha return &ChannelObserver{ params: params, logger: logger, - estimateTrend: NewTrendDetector(TrendDetectorParams{ + estimateTrend: ccutils.NewTrendDetector[int64](ccutils.TrendDetectorParams{ Name: params.Name + "-estimate", Logger: logger, Config: params.Config.Estimate, @@ -155,7 +175,7 @@ func (c *ChannelObserver) GetTrend() (ChannelTrend, ChannelCongestionReason) { estimateDirection := c.estimateTrend.GetDirection() switch { - case estimateDirection == TrendDirectionDownward: + case estimateDirection == ccutils.TrendDirectionDownward: c.logger.Debugw("stream allocator: channel observer: estimate is trending downward", "channel", c.ToString()) return 
ChannelTrendCongesting, ChannelCongestionReasonEstimate @@ -163,7 +183,7 @@ func (c *ChannelObserver) GetTrend() (ChannelTrend, ChannelCongestionReason) { c.logger.Debugw("stream allocator: channel observer: high rate of repeated NACKs", "channel", c.ToString()) return ChannelTrendCongesting, ChannelCongestionReasonLoss - case estimateDirection == TrendDirectionUpward: + case estimateDirection == ccutils.TrendDirectionUpward: return ChannelTrendClearing, ChannelCongestionReasonNone } diff --git a/pkg/sfu/streamallocator/probe_controller.go b/pkg/sfu/streamallocator/probe_controller.go index c2b21b8e1..c76a17fc3 100644 --- a/pkg/sfu/streamallocator/probe_controller.go +++ b/pkg/sfu/streamallocator/probe_controller.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/protocol/logger" ) @@ -73,6 +74,7 @@ type ProbeController struct { params ProbeControllerParams lock sync.RWMutex + sendSideBWE *sendsidebwe.SendSideBWE probeInterval time.Duration lastProbeStartTime time.Time probeGoalBps int64 @@ -95,6 +97,13 @@ func NewProbeController(params ProbeControllerParams) *ProbeController { return p } +func (p *ProbeController) SetSendSideBWE(sendSideBWE *sendsidebwe.SendSideBWE) { + p.lock.Lock() + defer p.lock.Unlock() + + p.sendSideBWE = sendSideBWE +} + func (p *ProbeController) Reset() { p.lock.Lock() defer p.lock.Unlock() @@ -270,9 +279,78 @@ func (p *ProbeController) InitProbe(probeGoalDeltaBps int64, expectedBandwidthUs time.Duration(float64(p.probeDuration.Milliseconds())*p.params.Config.DurationOverflowFactor)*time.Millisecond, ) + p.pollProbe(p.probeClusterId) + return p.probeClusterId, p.probeGoalBps } +// SSBWE-TODO: try to do same path for both SSBWE and regular, the congesting part might be different though +func (p *ProbeController) pollProbe(probeClusterId ProbeClusterId) { + if p.sendSideBWE == nil { + return + } + + startingEstimate := p.sendSideBWE.GetEstimatedAvailableChannelCapacity() + 
+ go func() { + for { + p.lock.Lock() + if p.probeClusterId != probeClusterId { + p.lock.Unlock() + return + } + + done := false + congestionState := p.sendSideBWE.GetCongestionState() + currentEstimate := p.sendSideBWE.GetEstimatedAvailableChannelCapacity() + switch { + case currentEstimate <= startingEstimate && time.Since(p.lastProbeStartTime) > p.params.Config.TrendWait: + // + // More of a safety net. + // In rare cases, the estimate gets stuck. Prevent from probe running amok + // STREAM-ALLOCATOR-TODO: Need more testing here to ensure that probe does not cause a lot of damage + // + p.params.Logger.Infow("stream allocator: probe: aborting, no trend", "cluster", probeClusterId) + p.abortProbeLocked() + done = true + break + + case congestionState == sendsidebwe.CongestionStateCongested || congestionState == sendsidebwe.CongestionStateEarlyWarning: + // stop immediately if the probe is congesting channel more + p.params.Logger.Infow( + "stream allocator: probe: aborting, channel is congesting", + "cluster", probeClusterId, + "congestionState", congestionState, + ) + p.abortProbeLocked() + done = true + break + + case currentEstimate > p.probeGoalBps: + // reached goal, stop probing + p.params.Logger.Infow( + "stream allocator: probe: stopping, goal reached", + "cluster", probeClusterId, + "goal", p.probeGoalBps, + "current", currentEstimate, + ) + p.goalReachedProbeClusterId = p.probeClusterId + p.StopProbe() + done = true + break + } + p.lock.Unlock() + + if done { + return + } + + // SSBWE-TODO: do not hard code sleep time + time.Sleep(50 * time.Millisecond) + } + }() +} + func (p *ProbeController) clearProbeLocked() { p.probeClusterId = ProbeClusterIdInvalid p.doneProbeClusterInfo = ProbeClusterInfo{Id: ProbeClusterIdInvalid} diff --git a/pkg/sfu/streamallocator/streamallocator.go b/pkg/sfu/streamallocator/streamallocator.go index c7d422500..a88b0a329 100644 --- a/pkg/sfu/streamallocator/streamallocator.go +++ b/pkg/sfu/streamallocator/streamallocator.go @@ 
-30,6 +30,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/sendsidebwe" "github.com/livekit/livekit-server/pkg/utils" ) @@ -86,6 +87,7 @@ const ( streamAllocatorSignalSetChannelCapacity // STREAM-ALLOCATOR-DATA streamAllocatorSignalNACK // STREAM-ALLOCATOR-DATA streamAllocatorSignalRTCPReceiverReport + streamAllocatorSignalCongestionStateChange ) func (s streamAllocatorSignal) String() string { @@ -116,6 +118,8 @@ func (s streamAllocatorSignal) String() string { case streamAllocatorSignalRTCPReceiverReport: return "RTCP_RECEIVER_REPORT" */ + case streamAllocatorSignalCongestionStateChange: + return "CONGESTION_STATE_CHANGE" default: return fmt.Sprintf("%d", int(s)) } @@ -179,7 +183,8 @@ type StreamAllocator struct { onStreamStateChange func(update *StreamStateUpdate) error - bwe cc.BandwidthEstimator + sendSideBWEInterceptor cc.BandwidthEstimator + sendSideBWE *sendsidebwe.SendSideBWE enabled bool allowPause bool @@ -200,7 +205,8 @@ type StreamAllocator struct { isAllocateAllPending bool rembTrackingSSRC uint32 - state streamAllocatorState + state streamAllocatorState + isHolding bool eventsQueue *utils.TypedOpsQueue[Event] @@ -256,11 +262,19 @@ func (s *StreamAllocator) OnStreamStateChange(f func(update *StreamStateUpdate) s.onStreamStateChange = f } -func (s *StreamAllocator) SetBandwidthEstimator(bwe cc.BandwidthEstimator) { - if bwe != nil { - bwe.OnTargetBitrateChange(s.onTargetBitrateChange) +func (s *StreamAllocator) SetSendSideBWEInterceptor(sendSideBWEInterceptor cc.BandwidthEstimator) { + if sendSideBWEInterceptor != nil { + sendSideBWEInterceptor.OnTargetBitrateChange(s.onTargetBitrateChange) } - s.bwe = bwe + s.sendSideBWEInterceptor = sendSideBWEInterceptor +} + +func (s *StreamAllocator) SetSendSideBWE(sendSideBWE *sendsidebwe.SendSideBWE) { + if sendSideBWE != nil { + sendSideBWE.OnCongestionStateChange(s.onCongestionStateChange) + } + 
s.sendSideBWE = sendSideBWE + s.probeController.SetSendSideBWE(sendSideBWE) } type AddTrackParams struct { @@ -428,8 +442,12 @@ func (s *StreamAllocator) OnREMB(downTrack *sfu.DownTrack, remb *rtcp.ReceiverEs // called when a new transport-cc feedback is received func (s *StreamAllocator) OnTransportCCFeedback(downTrack *sfu.DownTrack, fb *rtcp.TransportLayerCC) { - if s.bwe != nil { - s.bwe.WriteRTCP([]rtcp.Packet{fb}, nil) + if s.sendSideBWEInterceptor != nil { + s.sendSideBWEInterceptor.WriteRTCP([]rtcp.Packet{fb}, nil) + } + + if s.sendSideBWE != nil { + s.sendSideBWE.HandleRTCP(fb) } } @@ -441,6 +459,22 @@ func (s *StreamAllocator) onTargetBitrateChange(bitrate int) { }) } +// called when congestion state changes (send side bandwidth estimation) +type congestionStateChangeData struct { + congestionState sendsidebwe.CongestionState + estimatedAvailableChannelCapacity int64 +} + +func (s *StreamAllocator) onCongestionStateChange(congestionState sendsidebwe.CongestionState, estimatedAvailableChannelCapacity int64) { + s.postEvent(Event{ + Signal: streamAllocatorSignalCongestionStateChange, + Data: congestionStateChangeData{ + congestionState: congestionState, + estimatedAvailableChannelCapacity: estimatedAvailableChannelCapacity, + }, + }) +} + // called when feeding track's layer availability changes func (s *StreamAllocator) OnAvailableLayersChanged(downTrack *sfu.DownTrack) { s.maybePostEventAllocateTrack(downTrack) @@ -637,6 +671,8 @@ func (s *StreamAllocator) postEvent(event Event) { case streamAllocatorSignalRTCPReceiverReport: event.s.handleSignalRTCPReceiverReport(event) */ + case streamAllocatorSignalCongestionStateChange: + s.handleSignalCongestionStateChange(event) } }, event) } @@ -780,6 +816,45 @@ func (s *StreamAllocator) handleSignalRTCPReceiverReport(event Event) { } */ +func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) { + cscd := event.Data.(congestionStateChangeData) + if cscd.congestionState != 
sendsidebwe.CongestionStateNone { + s.probeController.AbortProbe() + } + + if cscd.congestionState == sendsidebwe.CongestionStateEarlyWarning || + cscd.congestionState == sendsidebwe.CongestionStateEarlyWarningHangover { + s.isHolding = true + } else { + s.isHolding = false + + // early warning is done and hold has been released, + // if there is no congestion, allocate all tracks optimally as + // some tracks may have been held at sub-optimal allocation + // during early warning hold + if cscd.congestionState == sendsidebwe.CongestionStateNone && s.state == streamAllocatorStateStable { + update := NewStreamStateUpdate() + for _, track := range s.getTracks() { + allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, s.isHolding) + updateStreamStateChange(track, allocation, update) + } + s.maybeSendUpdate(update) + } + } + + if cscd.congestionState == sendsidebwe.CongestionStateCongested { + s.params.Logger.Infow( + "stream allocator: channel congestion detected, updating channel capacity", + "old(bps)", s.committedChannelCapacity, + "new(bps)", cscd.estimatedAvailableChannelCapacity, + "expectedUsage(bps)", s.getExpectedBandwidthUsage(), + ) + s.committedChannelCapacity = cscd.estimatedAvailableChannelCapacity + + s.allocateAllTracks() + } +} + func (s *StreamAllocator) setState(state streamAllocatorState) { if s.state == state { return @@ -905,7 +980,7 @@ func (s *StreamAllocator) allocateTrack(track *Track) { // if not deficient, free pass allocate track if !s.enabled || s.state == streamAllocatorStateStable || !track.IsManaged() { update := NewStreamStateUpdate() - allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal) + allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, s.isHolding) updateStreamStateChange(track, allocation, update) s.maybeSendUpdate(update) return @@ -1055,6 +1130,9 @@ func (s *StreamAllocator) allocateTrack(track *Track) { func (s *StreamAllocator) onProbeDone(isNotFailing bool, isGoalReached bool) { 
highestEstimateInProbe := s.channelObserver.GetHighestEstimate() + if s.sendSideBWE != nil { + highestEstimateInProbe = s.sendSideBWE.GetEstimatedAvailableChannelCapacity() + } // // Reset estimator at the end of a probe irrespective of probe result to get fresh readings. @@ -1161,7 +1239,7 @@ func (s *StreamAllocator) allocateAllTracks() { continue } - allocation := track.AllocateOptimal(FlagAllowOvershootExemptTrackWhileDeficient) + allocation := track.AllocateOptimal(FlagAllowOvershootExemptTrackWhileDeficient, false) updateStreamStateChange(track, allocation, update) // STREAM-ALLOCATOR-TODO: optimistic allocation before bitrate is available will return 0. How to account for that? @@ -1198,6 +1276,7 @@ func (s *StreamAllocator) allocateAllTracks() { for _, track := range sorted { _, usedChannelCapacity := track.ProvisionalAllocate(availableChannelCapacity, layer, s.allowPause, FlagAllowOvershootWhileDeficient) + s.params.Logger.Infow("debug allocated", "trackID", track.ID(), "usedChannelCapacity", usedChannelCapacity, "availableChannelCapacity", availableChannelCapacity) // REMOVE availableChannelCapacity -= usedChannelCapacity if availableChannelCapacity < 0 { availableChannelCapacity = 0 @@ -1357,6 +1436,9 @@ func (s *StreamAllocator) maybeProbe() { if !s.probeController.CanProbe() { return } + if s.sendSideBWE != nil && s.sendSideBWE.GetCongestionState() != sendsidebwe.CongestionStateNone { + return + } switch s.params.Config.ProbeMode { case ProbeModeMedia: diff --git a/pkg/sfu/streamallocator/track.go b/pkg/sfu/streamallocator/track.go index 1a0669585..77b560109 100644 --- a/pkg/sfu/streamallocator/track.go +++ b/pkg/sfu/streamallocator/track.go @@ -154,8 +154,8 @@ func (t *Track) WritePaddingRTP(bytesToSend int) int { return t.downTrack.WritePaddingRTP(bytesToSend, false, false) } -func (t *Track) AllocateOptimal(allowOvershoot bool) sfu.VideoAllocation { - return t.downTrack.AllocateOptimal(allowOvershoot) +func (t *Track) AllocateOptimal(allowOvershoot 
bool, hold bool) sfu.VideoAllocation { + return t.downTrack.AllocateOptimal(allowOvershoot, hold) } func (t *Track) ProvisionalAllocatePrepare() {