TWCC based congestion control - v0 (#3165)

* file output

* wake under lock

* keep track of RTX bytes separately

* packet group

* Packet group of 50ms

* Minor refactoring

* rate calculator

* send bit rate

* WIP

* comment

* reduce packet infos size

* extended twcc seq num

* fix packet info

* WIP

* queuing delay

* refactor

* config

* callbacks

* fixes

* clean up

* remove debug file, fix rate calculation

* fmt

* fix probes

* format

* notes

* check loss

* tweak detection settings

* 24-bit wrap

* clean up a bit

* limit symbol list to number of packets

* fmt

* clean up

* lost

* fixes

* fmt

* rename

* fixes

* fmt

* use min/max

* hold on early warning of congestion

* make note about need for all optimal allocation on hold release

* estimate trend in congested state

* tweaks

* quantized

* fmt

* TrendDetector generics

* CTR trend

* tweaks

* config

* config

* comments

* clean up

* consistent naming

* participant level setting

* log usage mode

* feedback
This commit is contained in:
Raja Subramanian
2024-11-11 10:24:47 +05:30
committed by GitHub
parent 653857e42b
commit a3f2ca56f9
25 changed files with 1748 additions and 169 deletions
+15 -7
View File
@@ -29,6 +29,7 @@ import (
"github.com/livekit/livekit-server/pkg/metric"
"github.com/livekit/livekit-server/pkg/sfu"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/livekit-server/pkg/sfu/streamallocator"
"github.com/livekit/mediatransportutil/pkg/rtcconfig"
"github.com/livekit/protocol/livekit"
@@ -123,10 +124,15 @@ type TURNServer struct {
}
type CongestionControlConfig struct {
Enabled bool `yaml:"enabled,omitempty"`
AllowPause bool `yaml:"allow_pause,omitempty"`
Enabled bool `yaml:"enabled,omitempty"`
AllowPause bool `yaml:"allow_pause,omitempty"`
StreamAllocator streamallocator.StreamAllocatorConfig `yaml:"stream_allocator,omitempty"`
UseSendSideBWE bool `yaml:"use_send_side_bwe,omitempty"`
UseSendSideBWEInterceptor bool `yaml:"use_send_side_bwe_interceptor,omitempty"`
UseSendSideBWE bool `yaml:"use_send_side_bwe,omitempty"`
SendSideBWE sendsidebwe.SendSideBWEConfig `yaml:"send_side_bwe,omitempty"`
}
type PlayoutDelayConfig struct {
@@ -305,10 +311,12 @@ var DefaultConfig = Config{
StrictACKs: true,
PLIThrottle: sfu.DefaultPLIThrottleConfig,
CongestionControl: CongestionControlConfig{
Enabled: true,
AllowPause: false,
StreamAllocator: streamallocator.DefaultStreamAllocatorConfig,
UseSendSideBWE: false,
Enabled: true,
AllowPause: false,
StreamAllocator: streamallocator.DefaultStreamAllocatorConfig,
UseSendSideBWEInterceptor: false,
UseSendSideBWE: false,
SendSideBWE: sendsidebwe.DefaultSendSideBWEConfig,
},
},
Audio: sfu.DefaultAudioConfig,
+1 -1
View File
@@ -133,7 +133,7 @@ func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) {
},
},
}
if rtcConf.CongestionControl.UseSendSideBWE {
if rtcConf.CongestionControl.UseSendSideBWEInterceptor || rtcConf.CongestionControl.UseSendSideBWE {
subscriberConfig.RTPHeaderExtension.Video = append(subscriberConfig.RTPHeaderExtension.Video, sdp.TransportCCURI)
subscriberConfig.RTCPFeedback.Video = append(subscriberConfig.RTCPFeedback.Video, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC})
} else {
+4
View File
@@ -158,6 +158,8 @@ type ParticipantParams struct {
ForwardStats *sfu.ForwardStats
DisableSenderReportPassThrough bool
MetricConfig metric.MetricConfig
UseSendSideBWEInterceptor bool
UseSendSideBWE bool
}
type ParticipantImpl struct {
@@ -1458,6 +1460,8 @@ func (p *ParticipantImpl) setupTransportManager() error {
PublisherHandler: pth,
SubscriberHandler: sth,
DataChannelStats: p.dataChannelStats,
UseSendSideBWEInterceptor: p.params.UseSendSideBWEInterceptor,
UseSendSideBWE: p.params.UseSendSideBWE,
}
if p.params.SyncStreams && p.params.PlayoutDelay.GetEnabled() && p.params.ClientInfo.isFirefox() {
// we will disable playout delay for Firefox if the user is expecting
+30 -11
View File
@@ -36,23 +36,23 @@ import (
"go.uber.org/atomic"
"go.uber.org/zap/zapcore"
lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor"
lktwcc "github.com/livekit/mediatransportutil/pkg/twcc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/logger/pionlogger"
lksdp "github.com/livekit/protocol/sdp"
"github.com/livekit/livekit-server/pkg/config"
"github.com/livekit/livekit-server/pkg/rtc/transport"
"github.com/livekit/livekit-server/pkg/rtc/types"
sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor"
"github.com/livekit/livekit-server/pkg/sfu/pacer"
pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/livekit-server/pkg/sfu/streamallocator"
sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils"
"github.com/livekit/livekit-server/pkg/telemetry/prometheus"
"github.com/livekit/livekit-server/pkg/utils"
lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor"
lktwcc "github.com/livekit/mediatransportutil/pkg/twcc"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/logger/pionlogger"
lksdp "github.com/livekit/protocol/sdp"
)
const (
@@ -208,7 +208,8 @@ type PCTransport struct {
streamAllocator *streamallocator.StreamAllocator
// only for subscriber PC
pacer pacer.Pacer
sendSideBWE *sendsidebwe.SendSideBWE
pacer pacer.Pacer
previousAnswer *webrtc.SessionDescription
// track id -> description map in previous offer sdp
@@ -254,6 +255,8 @@ type TransportParams struct {
IsSendSide bool
AllowPlayoutDelay bool
DataChannelMaxBufferedAmount uint64
UseSendSideBWEInterceptor bool
UseSendSideBWE bool
}
func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimator cc.BandwidthEstimator)) (*webrtc.PeerConnection, *webrtc.MediaEngine, error) {
@@ -347,7 +350,8 @@ func newPeerConnection(params TransportParams, onBandwidthEstimator func(estimat
ir := &interceptor.Registry{}
if params.IsSendSide {
se.DetachDataChannels()
if params.CongestionControlConfig.UseSendSideBWE {
if params.CongestionControlConfig.UseSendSideBWEInterceptor || params.UseSendSideBWEInterceptor && (!params.CongestionControlConfig.UseSendSideBWE && !params.UseSendSideBWE) {
params.Logger.Infow("using send side BWE - interceptor")
gf, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) {
return gcc.NewSendSideBWE(
gcc.SendSideBWEInitialBitrate(1*1000*1000),
@@ -449,7 +453,19 @@ func NewPCTransport(params TransportParams) (*PCTransport, error) {
}, params.CongestionControlConfig.Enabled, params.CongestionControlConfig.AllowPause)
t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange)
t.streamAllocator.Start()
t.pacer = pacer.NewPassThrough(params.Logger)
if params.CongestionControlConfig.UseSendSideBWE || params.UseSendSideBWE {
params.Logger.Infow("using send side BWE")
t.sendSideBWE = sendsidebwe.NewSendSideBWE(sendsidebwe.SendSideBWEParams{
Config: params.CongestionControlConfig.SendSideBWE,
Logger: params.Logger,
})
t.pacer = pacer.NewNoQueue(params.Logger, t.sendSideBWE)
t.streamAllocator.SetSendSideBWE(t.sendSideBWE)
} else {
t.pacer = pacer.NewPassThrough(params.Logger, nil)
}
}
if err := t.createPeerConnection(); err != nil {
@@ -496,7 +512,7 @@ func (t *PCTransport) createPeerConnection() error {
t.me = me
if bwe != nil && t.streamAllocator != nil {
t.streamAllocator.SetBandwidthEstimator(bwe)
t.streamAllocator.SetSendSideBWEInterceptor(bwe)
}
return nil
@@ -975,6 +991,9 @@ func (t *PCTransport) Close() {
if t.pacer != nil {
t.pacer.Stop()
}
if t.sendSideBWE != nil {
t.sendSideBWE.Stop()
}
_ = t.pc.Close()
+4
View File
@@ -103,6 +103,8 @@ type TransportManagerParams struct {
PublisherHandler transport.Handler
SubscriberHandler transport.Handler
DataChannelStats *telemetry.BytesTrackStats
UseSendSideBWEInterceptor bool
UseSendSideBWE bool
}
type TransportManager struct {
@@ -180,6 +182,8 @@ func NewTransportManager(params TransportManagerParams) (*TransportManager, erro
DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount,
Transport: livekit.SignalTarget_SUBSCRIBER,
Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr},
UseSendSideBWEInterceptor: params.UseSendSideBWEInterceptor,
UseSendSideBWE: params.UseSendSideBWE,
})
if err != nil {
return nil, err
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package streamallocator
package ccutils
import (
"fmt"
@@ -46,26 +46,15 @@ func (t TrendDirection) String() string {
// ------------------------------------------------
type trendDetectorSample struct {
value int64
at time.Time
type trendDetectorNumber interface {
int64 | float64
}
func trendDetectorSampleListToString(samples []trendDetectorSample) string {
samplesStr := ""
if len(samples) > 0 {
firstTime := samples[0].at
samplesStr += "["
for i, sample := range samples {
suffix := ", "
if i == len(samples)-1 {
suffix = ""
}
samplesStr += fmt.Sprintf("%d(%d)%s", sample.value, sample.at.Sub(firstTime).Milliseconds(), suffix)
}
samplesStr += "]"
}
return samplesStr
// ------------------------------------------------
type trendDetectorSample[T trendDetectorNumber] struct {
value T
at time.Time
}
// ------------------------------------------------
@@ -79,26 +68,6 @@ type TrendDetectorConfig struct {
ValidityWindow time.Duration `yaml:"validity_window,omitempty"`
}
var (
DefaultTrendDetectorConfigProbe = TrendDetectorConfig{
RequiredSamples: 3,
RequiredSamplesMin: 3,
DownwardTrendThreshold: 0.0,
DownwardTrendMaxWait: 5 * time.Second,
CollapseThreshold: 0,
ValidityWindow: 10 * time.Second,
}
DefaultTrendDetectorConfigNonProbe = TrendDetectorConfig{
RequiredSamples: 12,
RequiredSamplesMin: 8,
DownwardTrendThreshold: -0.6,
DownwardTrendMaxWait: 5 * time.Second,
CollapseThreshold: 500 * time.Millisecond,
ValidityWindow: 10 * time.Second,
}
)
// ------------------------------------------------
type TrendDetectorParams struct {
@@ -107,35 +76,35 @@ type TrendDetectorParams struct {
Config TrendDetectorConfig
}
type TrendDetector struct {
type TrendDetector[T trendDetectorNumber] struct {
params TrendDetectorParams
startTime time.Time
numSamples int
samples []trendDetectorSample
lowestValue int64
highestValue int64
samples []trendDetectorSample[T]
lowestValue T
highestValue T
direction TrendDirection
}
func NewTrendDetector(params TrendDetectorParams) *TrendDetector {
return &TrendDetector{
func NewTrendDetector[T trendDetectorNumber](params TrendDetectorParams) *TrendDetector[T] {
return &TrendDetector[T]{
params: params,
startTime: time.Now(),
direction: TrendDirectionNeutral,
}
}
func (t *TrendDetector) Seed(value int64) {
func (t *TrendDetector[T]) Seed(value T) {
if len(t.samples) != 0 {
return
}
t.samples = append(t.samples, trendDetectorSample{value: value, at: time.Now()})
t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()})
}
func (t *TrendDetector) AddValue(value int64) {
func (t *TrendDetector[T]) AddValue(value T) {
t.numSamples++
if t.lowestValue == 0 || value < t.lowestValue {
t.lowestValue = value
@@ -156,7 +125,7 @@ func (t *TrendDetector) AddValue(value int64) {
// But, on the flip side, estimate could fall once or twice within a sliding window and stay there.
// In those cases, using a collapse window to record a value even if it is duplicate. By doing that,
// a trend could be detected eventually. It will be delayed, but that is fine with slow changing estimates.
var lastSample *trendDetectorSample
var lastSample *trendDetectorSample[T]
if len(t.samples) != 0 {
lastSample = &t.samples[len(t.samples)-1]
}
@@ -164,38 +133,52 @@ func (t *TrendDetector) AddValue(value int64) {
return
}
t.samples = append(t.samples, trendDetectorSample{value: value, at: time.Now()})
t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()})
t.prune()
t.updateDirection()
}
func (t *TrendDetector) GetLowest() int64 {
func (t *TrendDetector[T]) GetLowest() T {
return t.lowestValue
}
func (t *TrendDetector) GetHighest() int64 {
func (t *TrendDetector[T]) GetHighest() T {
return t.highestValue
}
func (t *TrendDetector) GetDirection() TrendDirection {
func (t *TrendDetector[T]) GetDirection() TrendDirection {
return t.direction
}
func (t *TrendDetector) HasEnoughSamples() bool {
func (t *TrendDetector[T]) HasEnoughSamples() bool {
return t.numSamples >= t.params.Config.RequiredSamples
}
func (t *TrendDetector) ToString() string {
func (t *TrendDetector[T]) ToString() string {
samplesStr := ""
if len(t.samples) > 0 {
firstTime := t.samples[0].at
samplesStr += "["
for i, sample := range t.samples {
suffix := ", "
if i == len(t.samples)-1 {
suffix = ""
}
samplesStr += fmt.Sprintf("%v(%d)%s", sample.value, sample.at.Sub(firstTime).Milliseconds(), suffix)
}
samplesStr += "]"
}
now := time.Now()
elapsed := now.Sub(t.startTime).Seconds()
return fmt.Sprintf("n: %s, t: %+v|%+v|%.2fs, v: %d|%d|%d|%s|%.2f",
return fmt.Sprintf("n: %s, t: %+v|%+v|%.2fs, v: %d|%v|%v|%s|%.2f",
t.params.Name,
t.startTime.Format(time.UnixDate), now.Format(time.UnixDate), elapsed,
t.numSamples, t.lowestValue, t.highestValue, trendDetectorSampleListToString(t.samples), kendallsTau(t.samples),
t.numSamples, t.lowestValue, t.highestValue, samplesStr, t.kendallsTau(),
)
}
func (t *TrendDetector) prune() {
func (t *TrendDetector[T]) prune() {
// prune based on a few rules
// 1. If there are more than required samples
@@ -238,14 +221,14 @@ func (t *TrendDetector) prune() {
}
}
func (t *TrendDetector) updateDirection() {
func (t *TrendDetector[T]) updateDirection() {
if len(t.samples) < t.params.Config.RequiredSamplesMin {
t.direction = TrendDirectionNeutral
return
}
// using Kendall's Tau to find trend
kt := kendallsTau(t.samples)
kt := t.kendallsTau()
t.direction = TrendDirectionNeutral
switch {
@@ -256,17 +239,15 @@ func (t *TrendDetector) updateDirection() {
}
}
// ------------------------------------------------
func kendallsTau(samples []trendDetectorSample) float64 {
func (t *TrendDetector[T]) kendallsTau() float64 {
concordantPairs := 0
discordantPairs := 0
for i := 0; i < len(samples)-1; i++ {
for j := i + 1; j < len(samples); j++ {
if samples[i].value < samples[j].value {
for i := 0; i < len(t.samples)-1; i++ {
for j := i + 1; j < len(t.samples); j++ {
if t.samples[i].value < t.samples[j].value {
concordantPairs++
} else if samples[i].value > samples[j].value {
} else if t.samples[i].value > t.samples[j].value {
discordantPairs++
}
}
+3 -2
View File
@@ -1407,9 +1407,9 @@ func (d *DownTrack) DistanceToDesired() float64 {
return d.forwarder.DistanceToDesired(al, brs)
}
func (d *DownTrack) AllocateOptimal(allowOvershoot bool) VideoAllocation {
func (d *DownTrack) AllocateOptimal(allowOvershoot bool, hold bool) VideoAllocation {
al, brs := d.params.Receiver.GetLayeredBitrate()
allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot)
allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot, hold)
d.postKeyFrameRequestEvent()
d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason)
return allocation
@@ -1966,6 +1966,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
Header: &pkt.Header,
Extensions: extensions,
Payload: payload,
IsRTX: true,
AbsSendTimeExtID: uint8(d.absSendTimeExtID),
TransportWideExtID: uint8(d.transportWideExtID),
WriteStream: d.writeStream,
+44 -14
View File
@@ -714,7 +714,7 @@ func (f *Forwarder) GetOptimalBandwidthNeeded(brs Bitrates) int64 {
return getOptimalBandwidthNeeded(f.muted, f.pubMuted, f.vls.GetMaxSeen().Spatial, brs, f.vls.GetMax())
}
func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allowOvershoot bool) VideoAllocation {
func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allowOvershoot bool, hold bool) VideoAllocation {
f.lock.Lock()
defer f.lock.Unlock()
@@ -755,7 +755,7 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow
}
alloc.TargetLayer = buffer.VideoLayer{
Spatial: int32(math.Min(float64(maxSeenLayer.Spatial), float64(maxSpatial))),
Spatial: min(maxSeenLayer.Spatial, maxSpatial),
Temporal: getMaxTemporal(),
}
}
@@ -783,8 +783,9 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow
// 2. If current is a valid layer, check against currently available layers and continue at current
// if possible. Else, choose the highest available layer as the next target.
// 3. If current is not valid, set next target to be opportunistic.
maxLayerSpatialLimit := int32(math.Min(float64(maxLayer.Spatial), float64(maxSeenLayer.Spatial)))
maxLayerSpatialLimit := min(maxLayer.Spatial, maxSeenLayer.Spatial)
highestAvailableLayer := buffer.InvalidLayerSpatial
lowestAvailableLayer := buffer.InvalidLayerSpatial
requestLayerSpatial := buffer.InvalidLayerSpatial
for _, al := range availableLayers {
if al > requestLayerSpatial && al <= maxLayerSpatialLimit {
@@ -793,6 +794,9 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow
if al > highestAvailableLayer {
highestAvailableLayer = al
}
if lowestAvailableLayer == buffer.InvalidLayerSpatial || al < lowestAvailableLayer {
lowestAvailableLayer = al
}
}
if requestLayerSpatial == buffer.InvalidLayerSpatial && highestAvailableLayer != buffer.InvalidLayerSpatial && allowOvershoot && f.vls.IsOvershootOkay() {
requestLayerSpatial = highestAvailableLayer
@@ -811,20 +815,46 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow
Temporal: getMaxTemporal(),
}
} else {
// current layer has stopped, switch to highest available
alloc.TargetLayer = buffer.VideoLayer{
Spatial: requestLayerSpatial,
Temporal: getMaxTemporal(),
// current layer has stopped, switch to lowest available if `hold`ing, else switch to highest available
if hold {
// if `hold` is requested, may be set due to early warning congestion
// signal, in that case layers are not increased as increasing layers
// will result in more load on the channel
alloc.TargetLayer = buffer.VideoLayer{
Spatial: lowestAvailableLayer,
Temporal: 0,
}
} else {
alloc.TargetLayer = buffer.VideoLayer{
Spatial: requestLayerSpatial,
Temporal: getMaxTemporal(),
}
}
}
alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial
} else {
// opportunistically latch on to anything
opportunisticAlloc()
if requestLayerSpatial == buffer.InvalidLayerSpatial {
alloc.RequestLayerSpatial = maxLayerSpatialLimit
if hold {
// allocate minimal to make the stream active while `hold`ing.
if lowestAvailableLayer == buffer.InvalidLayerSpatial {
alloc.TargetLayer = buffer.VideoLayer{
Spatial: 0,
Temporal: 0,
}
} else {
alloc.TargetLayer = buffer.VideoLayer{
Spatial: lowestAvailableLayer,
Temporal: 0,
}
}
alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial
} else {
alloc.RequestLayerSpatial = requestLayerSpatial
// opportunistically latch on to anything
opportunisticAlloc()
if requestLayerSpatial == buffer.InvalidLayerSpatial {
alloc.RequestLayerSpatial = maxLayerSpatialLimit
} else {
alloc.RequestLayerSpatial = requestLayerSpatial
}
}
}
}
@@ -834,7 +864,7 @@ func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allow
alloc.RequestLayerSpatial = buffer.InvalidLayerSpatial
}
if alloc.TargetLayer.IsValid() {
alloc.BandwidthRequested = optimalBandwidthNeeded
alloc.BandwidthRequested = getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, alloc.TargetLayer)
}
alloc.BandwidthDelta = alloc.BandwidthRequested - getBandwidthNeeded(brs, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested)
alloc.DistanceToDesired = getDistanceToDesired(
@@ -1120,7 +1150,7 @@ func (f *Forwarder) ProvisionalAllocateGetBestWeightedTransition() (VideoTransit
break
}
bandwidthDelta := int64(math.Max(float64(0), float64(existingBandwidthNeeded-f.provisional.bitrates[s][t])))
bandwidthDelta := max(0, existingBandwidthNeeded-f.provisional.bitrates[s][t])
transitionCost := int32(0)
// SVC-TODO: SVC will need a different cost transition
+91 -26
View File
@@ -145,7 +145,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.InvalidLayer,
DistanceToDesired: 0,
}
result := f.AllocateOptimal(nil, bitrates, true)
result := f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
@@ -163,7 +163,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 0,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
@@ -182,7 +182,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 0,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
@@ -201,19 +201,19 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 0,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
f.PubMute(false)
// when max layers changes, target is opportunistic, but requested spatial layer should be at max
f.SetMaxTemporalLayerSeen(3)
f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal)
f.vls.SetMax(buffer.VideoLayer{Spatial: 1, Temporal: 3})
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[1][3],
BandwidthDelta: bitrates[1][3],
BandwidthRequested: bitrates[2][1],
BandwidthDelta: bitrates[2][1],
BandwidthNeeded: bitrates[1][3],
Bitrates: bitrates,
TargetLayer: buffer.DefaultMaxLayer,
@@ -221,7 +221,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: f.vls.GetMax(),
DistanceToDesired: -1,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer())
@@ -232,14 +232,11 @@ func TestForwarderAllocateOptimal(t *testing.T) {
// when feed is dry and current is not valid, should set up for opportunistic forwarding
// NOTE: feed is dry due to availableLayers = nil, some valid bitrates may be passed in here for testing purposes only
disable(f)
expectedTargetLayer := buffer.VideoLayer{
Spatial: 2,
Temporal: buffer.DefaultMaxLayerTemporal,
}
expectedTargetLayer := buffer.DefaultMaxLayer
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[2][1],
BandwidthDelta: bitrates[2][1] - bitrates[1][3],
BandwidthDelta: 0,
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: expectedTargetLayer,
@@ -247,7 +244,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -0.5,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
@@ -270,7 +267,7 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -0.75,
}
result = f.AllocateOptimal(nil, emptyBitrates, true)
result = f.AllocateOptimal(nil, emptyBitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
@@ -289,16 +286,37 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -0.5,
}
result = f.AllocateOptimal([]int32{0, 1}, bitrates, true)
result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer())
// when holding in above scenario, should choose the lowest available layer
expectedTargetLayer = buffer.VideoLayer{
Spatial: 1,
Temporal: 0,
}
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[1][0],
BandwidthDelta: bitrates[1][0] - bitrates[2][1],
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: expectedTargetLayer,
RequestLayerSpatial: 1,
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 1.25,
}
result = f.AllocateOptimal([]int32{1, 2}, bitrates, true, true)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
// opportunistic target if feed is dry and current is not valid, i. e. not forwarding
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[2][1],
BandwidthDelta: 0,
BandwidthDelta: bitrates[2][1] - bitrates[1][0],
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: buffer.DefaultMaxLayer,
@@ -306,26 +324,49 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -0.5,
}
result = f.AllocateOptimal(nil, bitrates, true)
result = f.AllocateOptimal(nil, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer())
// when holding in above scenario, should choose layer 0
expectedTargetLayer = buffer.VideoLayer{
Spatial: 0,
Temporal: 0,
}
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[0][0],
BandwidthDelta: bitrates[0][0] - bitrates[2][1],
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: expectedTargetLayer,
RequestLayerSpatial: 0,
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 2.25,
}
result = f.AllocateOptimal(nil, bitrates, true, true)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
// if feed is not dry and current is not locked, should be opportunistic (with and without overshoot)
f.vls.SetTarget(buffer.InvalidLayer)
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonFeedDry,
BandwidthRequested: 0,
BandwidthDelta: 0 - bitrates[2][1],
BandwidthDelta: 0 - bitrates[0][0],
BandwidthNeeded: 0,
Bitrates: emptyBitrates,
TargetLayer: buffer.DefaultMaxLayer,
RequestLayerSpatial: 1,
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -1.0,
}
result = f.AllocateOptimal([]int32{0, 1}, emptyBitrates, false)
result = f.AllocateOptimal([]int32{0, 1}, emptyBitrates, false, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer())
f.vls.SetTarget(buffer.InvalidLayer)
expectedTargetLayer = buffer.VideoLayer{
@@ -343,9 +384,10 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: -0.5,
}
result = f.AllocateOptimal([]int32{0, 1}, bitrates, true)
result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
// switches request layer to highest available if feed is not dry and current is valid and current is not available
f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 1})
@@ -355,8 +397,8 @@ func TestForwarderAllocateOptimal(t *testing.T) {
}
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[2][1],
BandwidthDelta: 0,
BandwidthRequested: bitrates[1][3],
BandwidthDelta: bitrates[1][3] - bitrates[2][1],
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: expectedTargetLayer,
@@ -364,9 +406,31 @@ func TestForwarderAllocateOptimal(t *testing.T) {
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 0.5,
}
result = f.AllocateOptimal([]int32{1}, bitrates, true)
result = f.AllocateOptimal([]int32{1}, bitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
// when holding in above scenario, should switch to lowest available layer
expectedTargetLayer = buffer.VideoLayer{
Spatial: 0,
Temporal: 0,
}
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonNone,
BandwidthRequested: bitrates[0][0],
BandwidthDelta: bitrates[0][0] - bitrates[1][3],
BandwidthNeeded: bitrates[2][1],
Bitrates: bitrates,
TargetLayer: expectedTargetLayer,
RequestLayerSpatial: 0,
MaxLayer: buffer.DefaultMaxLayer,
DistanceToDesired: 2.25,
}
result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, true)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
// stays the same if feed is not dry and current is valid, available and locked
f.vls.SetMax(buffer.VideoLayer{Spatial: 0, Temporal: 1})
@@ -379,16 +443,17 @@ func TestForwarderAllocateOptimal(t *testing.T) {
expectedResult = VideoAllocation{
PauseReason: VideoPauseReasonFeedDry,
BandwidthRequested: 0,
BandwidthDelta: 0 - bitrates[2][1],
BandwidthDelta: 0 - bitrates[0][0],
Bitrates: emptyBitrates,
TargetLayer: expectedTargetLayer,
RequestLayerSpatial: 0,
MaxLayer: f.vls.GetMax(),
DistanceToDesired: 0.0,
}
result = f.AllocateOptimal([]int32{0}, emptyBitrates, true)
result = f.AllocateOptimal([]int32{0}, emptyBitrates, true, false)
require.Equal(t, expectedResult, result)
require.Equal(t, expectedResult, f.lastAllocation)
require.Equal(t, expectedTargetLayer, f.TargetLayer())
}
func TestForwarderProvisionalAllocate(t *testing.T) {
+31 -9
View File
@@ -19,6 +19,7 @@ import (
"io"
"time"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils/mono"
"github.com/pion/rtp"
@@ -26,11 +27,14 @@ import (
type Base struct {
logger logger.Logger
sendSideBWE *sendsidebwe.SendSideBWE
}
func NewBase(logger logger.Logger) *Base {
func NewBase(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *Base {
return &Base{
logger: logger,
logger: logger,
sendSideBWE: sendSideBWE,
}
}
@@ -47,7 +51,7 @@ func (b *Base) SendPacket(p *Packet) (int, error) {
}
}()
_, err := b.writeRTPHeaderExtensions(p)
err := b.writeRTPHeaderExtensions(p)
if err != nil {
b.logger.Errorw("writing rtp header extensions err", err)
return 0, err
@@ -66,7 +70,7 @@ func (b *Base) SendPacket(p *Packet) (int, error) {
}
// writes RTP header extensions of track
func (b *Base) writeRTPHeaderExtensions(p *Packet) (time.Time, error) {
func (b *Base) writeRTPHeaderExtensions(p *Packet) error {
// clear out extensions that may have been in the forwarded header
p.Header.Extension = false
p.Header.ExtensionProfile = 0
@@ -85,16 +89,34 @@ func (b *Base) writeRTPHeaderExtensions(p *Packet) (time.Time, error) {
sendTime := rtp.NewAbsSendTimeExtension(sendingAt)
b, err := sendTime.Marshal()
if err != nil {
return time.Time{}, err
return err
}
err = p.Header.SetExtension(p.AbsSendTimeExtID, b)
if err != nil {
return time.Time{}, err
if err = p.Header.SetExtension(p.AbsSendTimeExtID, b); err != nil {
return err
}
}
return sendingAt, nil
if p.TransportWideExtID != 0 && b.sendSideBWE != nil {
twccSN := b.sendSideBWE.RecordPacketSendAndGetSequenceNumber(
sendingAt,
p.Header.MarshalSize()+len(p.Payload),
p.IsRTX,
)
twccExt := rtp.TransportCCExtension{
TransportSequence: twccSN,
}
b, err := twccExt.Marshal()
if err != nil {
return err
}
if err = p.Header.SetExtension(p.TransportWideExtID, b); err != nil {
return err
}
}
return nil
}
// ------------------------------------------------
+3 -2
View File
@@ -19,6 +19,7 @@ import (
"time"
"github.com/gammazero/deque"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/protocol/logger"
)
@@ -38,9 +39,9 @@ type LeakyBucket struct {
isStopped bool
}
func NewLeakyBucket(logger logger.Logger, interval time.Duration, bitrate int) *LeakyBucket {
func NewLeakyBucket(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE, interval time.Duration, bitrate int) *LeakyBucket {
l := &LeakyBucket{
Base: NewBase(logger),
Base: NewBase(logger, sendSideBWE),
logger: logger,
interval: interval,
bitrate: bitrate,
+3 -2
View File
@@ -18,6 +18,7 @@ import (
"sync"
"github.com/gammazero/deque"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/protocol/logger"
)
@@ -32,9 +33,9 @@ type NoQueue struct {
isStopped bool
}
func NewNoQueue(logger logger.Logger) *NoQueue {
func NewNoQueue(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *NoQueue {
n := &NoQueue{
Base: NewBase(logger),
Base: NewBase(logger, sendSideBWE),
logger: logger,
wake: make(chan struct{}, 1),
}
+1
View File
@@ -31,6 +31,7 @@ type Packet struct {
Header *rtp.Header
Extensions []ExtensionData
Payload []byte
IsRTX bool
AbsSendTimeExtID uint8
TransportWideExtID uint8
WriteStream webrtc.TrackLocalWriter
+3 -2
View File
@@ -15,6 +15,7 @@
package pacer
import (
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/protocol/logger"
)
@@ -22,9 +23,9 @@ type PassThrough struct {
*Base
}
func NewPassThrough(logger logger.Logger) *PassThrough {
func NewPassThrough(logger logger.Logger, sendSideBWE *sendsidebwe.SendSideBWE) *PassThrough {
return &PassThrough{
Base: NewBase(logger),
Base: NewBase(logger, sendSideBWE),
}
}
+3 -6
View File
@@ -502,7 +502,6 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt
)
return
}
r.extHighestSNFromRR = extHighestSNFromRR
if r.srNewest != nil {
@@ -515,12 +514,10 @@ func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt
}
}
// This is 24-bit max in the protocol. So, technically doesn't need extended type. But, done for consistency.
packetsLostFromRR := r.packetsLostFromRR&0xFFFF_FFFF_0000_0000 + uint64(rr.TotalLost)
if (rr.TotalLost-r.lastRR.TotalLost) < (1<<31) && rr.TotalLost < r.lastRR.TotalLost {
packetsLostFromRR += (1 << 32)
r.packetsLostFromRR = r.packetsLostFromRR&0xFFFF_FFFF_FF00_0000 + uint64(rr.TotalLost)
if ((rr.TotalLost-r.lastRR.TotalLost)&((1<<24)-1)) < (1<<23) && rr.TotalLost < r.lastRR.TotalLost {
r.packetsLostFromRR += (1 << 24)
}
r.packetsLostFromRR = packetsLostFromRR
if isRttChanged {
r.rtt = rtt
+556
View File
@@ -0,0 +1,556 @@
package sendsidebwe
import (
"sync"
"time"
"github.com/frostbyte73/core"
"github.com/gammazero/deque"
"github.com/livekit/livekit-server/pkg/sfu/ccutils"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils/mono"
"github.com/pion/rtcp"
)
// -------------------------------------------------------------------------------
// CongestionSignalConfig sets the thresholds that must be crossed before a
// run of JQR (high queuing delay) packet groups is treated as a congestion
// signal: at least MinNumberOfGroups groups spanning at least MinDuration
// of send time.
type CongestionSignalConfig struct {
	MinNumberOfGroups int           `yaml:"min_number_of_groups,omitempty"`
	MinDuration       time.Duration `yaml:"min_duration,omitempty"`
}

var (
	// DefaultEarlyWarningCongestionSignalConfig triggers quickly
	// (1 group / 100ms) to give an early heads-up of possible congestion.
	DefaultEarlyWarningCongestionSignalConfig = CongestionSignalConfig{
		MinNumberOfGroups: 1,
		MinDuration:       100 * time.Millisecond,
	}

	// DefaultCongestedCongestionSignalConfig requires a longer sustained
	// signal (3 groups / 300ms) before declaring full congestion.
	DefaultCongestedCongestionSignalConfig = CongestionSignalConfig{
		MinNumberOfGroups: 3,
		MinDuration:       300 * time.Millisecond,
	}
)
// -------------------------------------------------------------------------------
// CongestionDetectorConfig configures packet grouping, queuing region
// thresholds, congestion signal triggers, rate measurement windows and
// periodic check cadence of the congestion detector.
type CongestionDetectorConfig struct {
	// packet grouping
	PacketGroup       PacketGroupConfig `yaml:"packet_group,omitempty"`
	PacketGroupMaxAge time.Duration     `yaml:"packet_group_max_age,omitempty"`

	// queuing region classification: a group with propagated queuing delay
	// above JQRMinDelay is JQR (congested), below DQRMaxDelay is DQR (clear),
	// in between is indeterminate (hysteresis band)
	JQRMinDelay time.Duration `yaml:"jqr_min_delay,omitempty"`
	DQRMaxDelay time.Duration `yaml:"dqr_max_delay,omitempty"`

	// congestion signal triggers and hangover (cool down) durations
	EarlyWarning         CongestionSignalConfig `yaml:"early_warning,omitempty"`
	EarlyWarningHangover time.Duration          `yaml:"early_warning_hangover,omitempty"`
	Congested            CongestionSignalConfig `yaml:"congested,omitempty"`
	CongestedHangover    time.Duration          `yaml:"congested_hangover,omitempty"`

	// rate measurement window requirements for capacity estimation
	RateMeasurementWindowFullnessMin float64       `yaml:"rate_measurement_window_fullness_min,omitempty"`
	RateMeasurementWindowDurationMin time.Duration `yaml:"rate_measurement_window_duration_min,omitempty"`
	RateMeasurementWindowDurationMax time.Duration `yaml:"rate_measurement_window_duration_max,omitempty"`

	// cadence of the periodic prune/state machine runs (faster when congested)
	PeriodicCheckInterval          time.Duration `yaml:"periodic_check_interval,omitempty"`
	PeriodicCheckIntervalCongested time.Duration `yaml:"periodic_check_interval_congested,omitempty"`

	// captured traffic ratio (CTR) trend detection while congested
	CongestedCTRTrend   ccutils.TrendDetectorConfig `yaml:"congested_ctr_trend,omitempty"`
	CongestedCTREpsilon float64                     `yaml:"congested_ctr_epsilon,omitempty"`
}

var (
	defaultTrendDetectorConfigCongestedCTR = ccutils.TrendDetectorConfig{
		RequiredSamples:        5,
		RequiredSamplesMin:     2,
		DownwardTrendThreshold: -0.6,
		DownwardTrendMaxWait:   5 * time.Second,
		CollapseThreshold:      500 * time.Millisecond,
		ValidityWindow:         10 * time.Second,
	}

	// DefaultCongestionDetectorConfig is the default congestion detector configuration.
	DefaultCongestionDetectorConfig = CongestionDetectorConfig{
		PacketGroup:                      DefaultPacketGroupConfig,
		PacketGroupMaxAge:                30 * time.Second,
		JQRMinDelay:                      15 * time.Millisecond,
		DQRMaxDelay:                      5 * time.Millisecond,
		EarlyWarning:                     DefaultEarlyWarningCongestionSignalConfig,
		EarlyWarningHangover:             500 * time.Millisecond,
		Congested:                        DefaultCongestedCongestionSignalConfig,
		CongestedHangover:                3 * time.Second,
		RateMeasurementWindowFullnessMin: 0.8,
		RateMeasurementWindowDurationMin: 800 * time.Millisecond,
		RateMeasurementWindowDurationMax: 2 * time.Second,
		PeriodicCheckInterval:            2 * time.Second,
		PeriodicCheckIntervalCongested:   200 * time.Millisecond,
		CongestedCTRTrend:                defaultTrendDetectorConfigCongestedCTR,
		CongestedCTREpsilon:              0.05,
	}
)
// -------------------------------------------------------------------------------
// feedbackReport is a TWCC feedback RTCP packet together with its local
// arrival time.
type feedbackReport struct {
	at     time.Time
	report *rtcp.TransportLayerCC
}

type congestionDetectorParams struct {
	Config CongestionDetectorConfig
	Logger logger.Logger
}

// congestionDetector consumes TWCC feedback reports, aggregates the
// acknowledged packets into packet groups and runs a state machine over
// those groups to detect congestion and estimate available channel capacity.
type congestionDetector struct {
	params congestionDetectorParams

	lock            sync.RWMutex
	feedbackReports deque.Deque[feedbackReport] // queued by HandleRTCP, drained by worker

	*packetTracker
	twccFeedback *twccFeedback

	packetGroups []*packetGroup // appended in send order; last element is the newest group

	wake chan struct{} // capacity-1 wake-up signal for the worker
	stop core.Fuse

	estimatedAvailableChannelCapacity int64 // bits per second
	congestionState                   CongestionState
	congestionStateSwitchedAt         time.Time
	congestedCTRTrend                 *ccutils.TrendDetector[float64] // non-nil only while in congested state

	onCongestionStateChange func(congestionState CongestionState, estimatedAvailableChannelCapacity int64)
}
// newCongestionDetector creates a detector and starts its worker goroutine.
// The initial capacity estimate is deliberately high (100 Mbps) so that
// allocation is not throttled before any feedback has been received.
func newCongestionDetector(params congestionDetectorParams) *congestionDetector {
	c := &congestionDetector{
		params:                            params,
		packetTracker:                     newPacketTracker(packetTrackerParams{Logger: params.Logger}),
		twccFeedback:                      newTWCCFeedback(twccFeedbackParams{Logger: params.Logger}),
		wake:                              make(chan struct{}, 1),
		estimatedAvailableChannelCapacity: 100_000_000,
	}

	c.feedbackReports.SetMinCapacity(3)

	go c.worker()
	return c
}
// Stop shuts down the congestion detector's worker goroutine.
//
// The wake channel is deliberately NOT closed: HandleRTCP sends on it after
// checking the fuse, and closing it here could race with that send (a send
// on a closed channel panics). Breaking the fuse is sufficient - the worker
// exits via stop.Watch() and HandleRTCP drops reports once the fuse is broken.
func (c *congestionDetector) Stop() {
	c.stop.Break()
}
// OnCongestionStateChange registers a callback invoked (outside the lock)
// whenever the congestion state changes, and also while congested when the
// captured traffic ratio trends downward.
func (c *congestionDetector) OnCongestionStateChange(f func(congestionState CongestionState, estimatedAvailableChannelCapacity int64)) {
	c.lock.Lock()
	defer c.lock.Unlock()

	c.onCongestionStateChange = f
}

// GetCongestionState returns the current congestion state.
func (c *congestionDetector) GetCongestionState() CongestionState {
	c.lock.RLock()
	defer c.lock.RUnlock()

	return c.congestionState
}
// updateCongestionState records the new state, logs the transition and
// invokes the registered callback outside the lock. Entering the congested
// state arms the CTR trend detector; leaving it disarms the detector.
func (c *congestionDetector) updateCongestionState(state CongestionState) {
	c.lock.Lock()
	c.params.Logger.Infow(
		"congestion state change",
		"from", c.congestionState,
		"to", state,
		"estimatedAvailableChannelCapacity", c.estimatedAvailableChannelCapacity,
	)

	prevState := c.congestionState
	c.congestionState = state

	// snapshot callback and estimate so the callback runs outside the lock
	onCongestionStateChange := c.onCongestionStateChange
	estimatedAvailableChannelCapacity := c.estimatedAvailableChannelCapacity
	c.lock.Unlock()

	if onCongestionStateChange != nil {
		onCongestionStateChange(state, estimatedAvailableChannelCapacity)
	}

	// when in congested state, monitor changes in captured traffic ratio (CTR)
	// to ensure allocations are in line with latest estimates, it is possible that
	// the estimate is incorrect when congestion starts and the allocation may be
	// sub-optimal and not enough to reduce/relieve congestion, by monitoring CTR
	// on a continuous basis allocations can be adjusted in the direction of
	// reducing/relieving congestion
	if state == CongestionStateCongested && prevState != CongestionStateCongested {
		c.congestedCTRTrend = ccutils.NewTrendDetector[float64](ccutils.TrendDetectorParams{
			Name:   "ssbwe-estimate",
			Logger: c.params.Logger,
			Config: c.params.Config.CongestedCTRTrend,
		})
	} else if state != CongestionStateCongested {
		c.congestedCTRTrend = nil
	}
}
// GetEstimatedAvailableChannelCapacity returns the most recent estimate of
// available channel capacity in bits per second.
func (c *congestionDetector) GetEstimatedAvailableChannelCapacity() int64 {
	c.lock.RLock()
	defer c.lock.RUnlock()

	return c.estimatedAvailableChannelCapacity
}

// HandleRTCP queues a TWCC feedback report for processing by the worker
// goroutine. The send on the wake channel is non-blocking (capacity 1), so
// several queued reports may be coalesced into a single wake-up.
func (c *congestionDetector) HandleRTCP(report *rtcp.TransportLayerCC) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// drop reports once stopped
	if c.stop.IsBroken() {
		return
	}

	c.feedbackReports.PushBack(feedbackReport{mono.Now(), report})

	// notify worker of a new feedback
	select {
	case c.wake <- struct{}{}:
	default:
	}
}
// prunePacketGroups drops all packet groups whose send time window starts
// more than PacketGroupMaxAge before the newest group's start. Groups are
// appended in send order, so the slice is scanned from the oldest end for
// the first group that is still within the age window and everything before
// it is discarded.
//
// (The previous version sliced past the first stale group and returned,
// which removed at most one stale group per invocation.)
func (c *congestionDetector) prunePacketGroups() {
	if len(c.packetGroups) == 0 {
		return
	}

	threshold := c.packetGroups[len(c.packetGroups)-1].MinSendTime() - c.params.Config.PacketGroupMaxAge.Microseconds()
	for idx, pg := range c.packetGroups {
		if pg.MinSendTime() >= threshold {
			// first group inside the age window; every earlier group is stale.
			// the newest group always satisfies this, so idx is always valid.
			c.packetGroups = c.packetGroups[idx:]
			return
		}
	}
}
// isCongestionSignalTriggered walks packet groups from newest to oldest,
// accumulating JQR (high queuing delay) groups into a congestion signal,
// and returns (earlyWarningTriggered, congestedTriggered).
//
// Accumulation stops at the first DQR (low queuing delay) group as it breaks
// the continuity of the signal; indeterminate groups (between the DQR and
// JQR thresholds) neither add to nor break the signal.
func (c *congestionDetector) isCongestionSignalTriggered() (bool, bool) {
	earlyWarningTriggered := false
	congestedTriggered := false

	numGroups := 0
	duration := int64(0)
	for idx := len(c.packetGroups) - 1; idx >= 0; idx-- {
		pg := c.packetGroups[idx]
		pqd, ok := pg.PropagatedQueuingDelay()
		if !ok {
			// group not finalized yet, skip
			continue
		}

		if pqd > c.params.Config.JQRMinDelay.Microseconds() {
			// JQR group builds up congestion signal
			numGroups++
			duration += pg.SendDuration()
		}

		// INDETERMINATE group is treated as a no-op
		if pqd < c.params.Config.DQRMaxDelay.Microseconds() {
			// any DQR group breaks the continuity
			return earlyWarningTriggered, congestedTriggered
		}

		if numGroups >= c.params.Config.EarlyWarning.MinNumberOfGroups && duration >= c.params.Config.EarlyWarning.MinDuration.Microseconds() {
			earlyWarningTriggered = true
		}
		if numGroups >= c.params.Config.Congested.MinNumberOfGroups && duration >= c.params.Config.Congested.MinDuration.Microseconds() {
			congestedTriggered = true
		}

		if earlyWarningTriggered && congestedTriggered {
			break
		}
	}

	return earlyWarningTriggered, congestedTriggered
}
// congestionDetectionStateMachine advances the congestion state based on the
// early-warning/congested signals and the hangover timers:
//
//	NONE -> EARLY_WARNING -> CONGESTED -> CONGESTED_HANGOVER -> NONE
//	EARLY_WARNING -> EARLY_WARNING_HANGOVER -> NONE
//
// It also refreshes the channel capacity estimate before applying a state
// change so the state change callback carries the latest estimate.
func (c *congestionDetector) congestionDetectionStateMachine() {
	state := c.GetCongestionState()
	newState := state

	earlyWarningTriggered, congestedTriggered := c.isCongestionSignalTriggered()

	switch state {
	case CongestionStateNone:
		if congestedTriggered {
			// congested should be reached via early warning first
			c.params.Logger.Warnw("invalid congested state transition", nil, "from", state)
		}
		if earlyWarningTriggered {
			newState = CongestionStateEarlyWarning
		}

	case CongestionStateEarlyWarning:
		if congestedTriggered {
			newState = CongestionStateCongested
		} else if !earlyWarningTriggered {
			newState = CongestionStateEarlyWarningHangover
		}

	case CongestionStateEarlyWarningHangover:
		if congestedTriggered {
			c.params.Logger.Warnw("invalid congested state transition", nil, "from", state)
		}
		if earlyWarningTriggered {
			newState = CongestionStateEarlyWarning
		} else if time.Since(c.congestionStateSwitchedAt) >= c.params.Config.EarlyWarningHangover {
			newState = CongestionStateNone
		}

	case CongestionStateCongested:
		if !congestedTriggered {
			newState = CongestionStateCongestedHangover
		}

	case CongestionStateCongestedHangover:
		if congestedTriggered {
			c.params.Logger.Warnw("invalid congested state transition", nil, "from", state)
		}
		if earlyWarningTriggered {
			newState = CongestionStateEarlyWarning
		} else if time.Since(c.congestionStateSwitchedAt) >= c.params.Config.CongestedHangover {
			newState = CongestionStateNone
		}
	}

	c.estimateAvailableChannelCapacity()

	// update after running the above estimate as state change callback includes the estimated available channel capacity
	if newState != state {
		c.congestionStateSwitchedAt = mono.Now()
		c.updateCongestionState(newState)
	}
}
// updateTrend feeds a captured traffic ratio (CTR) sample into the congested
// trend detector (armed only while in congested state). If CTR is trending
// downward, the state change callback is re-fired so that allocations can
// track the worsening estimate, and the detector is reset to collect a
// fresh set of samples for the next trend.
func (c *congestionDetector) updateTrend(ctr float64) {
	if c.congestedCTRTrend == nil {
		return
	}

	// quantise the CTR to filter out small changes
	c.congestedCTRTrend.AddValue(float64(int((ctr+(c.params.Config.CongestedCTREpsilon/2))/c.params.Config.CongestedCTREpsilon)) * c.params.Config.CongestedCTREpsilon)

	if c.congestedCTRTrend.GetDirection() == ccutils.TrendDirectionDownward {
		c.params.Logger.Infow("captured traffic ratio is trending downward", "channel", c.congestedCTRTrend.ToString())

		// snapshot state under read lock; invoke the callback outside it
		c.lock.RLock()
		state := c.congestionState
		estimatedAvailableChannelCapacity := c.estimatedAvailableChannelCapacity
		onCongestionStateChange := c.onCongestionStateChange
		c.lock.RUnlock()

		if onCongestionStateChange != nil {
			onCongestionStateChange(state, estimatedAvailableChannelCapacity)
		}

		// reset to get new set of samples for next trend
		c.congestedCTRTrend = ccutils.NewTrendDetector[float64](ccutils.TrendDetectorParams{
			Name:   "ssbwe-estimate",
			Logger: c.params.Logger,
			Config: c.params.Config.CongestedCTRTrend,
		})
	}
}
// estimateAvailableChannelCapacity computes the available channel capacity
// (bits per second) from recent packet groups falling inside the rate
// measurement window. Groups that are not full enough are skipped; if the
// aggregated duration is too short, the previous estimate is retained.
func (c *congestionDetector) estimateAvailableChannelCapacity() {
	if len(c.packetGroups) == 0 {
		return
	}

	totalDuration := int64(0)
	totalBytes := 0
	threshold := c.packetGroups[len(c.packetGroups)-1].MinSendTime() - c.params.Config.RateMeasurementWindowDurationMax.Microseconds()
	for idx := len(c.packetGroups) - 1; idx >= 0; idx-- {
		pg := c.packetGroups[idx]
		mst, dur, nbytes, fullness := pg.Traffic()
		if mst < threshold {
			// older than the measurement window; groups are in send order, stop
			break
		}

		if fullness < c.params.Config.RateMeasurementWindowFullnessMin {
			continue
		}

		totalDuration += dur
		totalBytes += nbytes
	}

	if totalDuration >= c.params.Config.RateMeasurementWindowDurationMin.Microseconds() {
		c.lock.Lock()
		// durations are in microseconds, hence the 1e6 scaling to bits per second
		c.estimatedAvailableChannelCapacity = int64(totalBytes) * 8 * 1e6 / totalDuration
		c.lock.Unlock()
	} else {
		c.params.Logger.Infow("not enough data to estimate available channel capacity", "totalDuration", totalDuration)
	}
}
// processFeedbackReport ingests one TWCC feedback report:
//  1. walk the packet status chunks and record the receive time as reported by the remote
//  2. hand each acknowledged/lost packet to a packet group for congestion analysis
//
// losses are not recorded if a feedback report is completely lost.
// RFC recommends treating lost reports by ignoring packets that would have been in it.
// -----------------------------------------------------------------------------------
// | From a congestion control perspective, lost feedback messages are |
// | handled by ignoring packets which would have been reported as lost or |
// | received in the lost feedback messages. This behavior is similar to |
// | how a lost RTCP receiver report is handled. |
// -----------------------------------------------------------------------------------
// Reference: https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-4
func (c *congestionDetector) processFeedbackReport(fbr feedbackReport) {
	recvRefTime, isOutOfOrder := c.twccFeedback.ProcessReport(fbr.report, fbr.at)
	if isOutOfOrder {
		c.params.Logger.Infow("received out-of-order feedback report")
	}

	if len(c.packetGroups) == 0 {
		c.packetGroups = append(
			c.packetGroups,
			newPacketGroup(
				packetGroupParams{
					Config: c.params.Config.PacketGroup,
					Logger: c.params.Logger,
				},
				0,
			),
		)
	}

	pg := c.packetGroups[len(c.packetGroups)-1]

	// trackPacketGroup adds a packet to the newest group, starting a new group
	// (carrying over the propagated queuing delay) when the current one is
	// finalized, and falling back to older groups for out-of-order packets.
	trackPacketGroup := func(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) {
		if pi == nil {
			return
		}

		err := pg.Add(pi, sendDelta, recvDelta, isLost)
		if err == nil {
			return
		}

		if err == errGroupFinalized {
			// previous group ended, start a new group
			c.updateTrend(pg.CapturedTrafficRatio())

			pqd, _ := pg.PropagatedQueuingDelay()
			pg = newPacketGroup(
				packetGroupParams{
					Config: c.params.Config.PacketGroup,
					Logger: c.params.Logger,
				},
				pqd,
			)
			c.packetGroups = append(c.packetGroups, pg)
			if err := pg.Add(pi, sendDelta, recvDelta, isLost); err != nil {
				// a freshly created group should always accept the packet
				c.params.Logger.Warnw("could not add packet to new group", err, "packetInfo", pi)
			}
			return
		}

		// try an older group
		for idx := len(c.packetGroups) - 2; idx >= 0; idx-- {
			opg := c.packetGroups[idx]
			if err := opg.Add(pi, sendDelta, recvDelta, isLost); err == nil {
				return
			} else if err == errGroupFinalized {
				c.params.Logger.Infow("unexpected finalized group", "packetInfo", pi, "packetGroup", opg)
			}
		}
	}

	sequenceNumber := fbr.report.BaseSequenceNumber
	endSequenceNumberExclusive := sequenceNumber + fbr.report.PacketStatusCount
	deltaIdx := 0

	// processSymbol records the receive indication for one packet status symbol.
	// shared by both chunk encodings; returns false on a malformed report.
	processSymbol := func(symbol uint16) bool {
		recvTime := int64(0)
		isLost := false
		if symbol != rtcp.TypeTCCPacketNotReceived {
			if deltaIdx >= len(fbr.report.RecvDeltas) {
				// malformed report: more received symbols than receive deltas
				c.params.Logger.Warnw("feedback report missing receive deltas", nil)
				return false
			}
			recvRefTime += fbr.report.RecvDeltas[deltaIdx].Delta
			deltaIdx++
			recvTime = recvRefTime
		} else {
			isLost = true
		}

		pi, sendDelta, recvDelta := c.packetTracker.RecordPacketIndicationFromRemote(sequenceNumber, recvTime)
		if pi.sendTime != 0 {
			trackPacketGroup(&pi, sendDelta, recvDelta, isLost)
		}
		sequenceNumber++
		return true
	}

	// 1. go through the TWCC feedback report and record receive time as reported by remote
	// 2. process acknowledged packets and group them
	for _, chunk := range fbr.report.PacketChunks {
		if sequenceNumber == endSequenceNumberExclusive {
			break
		}

		switch chunk := chunk.(type) {
		case *rtcp.RunLengthChunk:
			for i := uint16(0); i < chunk.RunLength; i++ {
				if sequenceNumber == endSequenceNumberExclusive {
					break
				}
				if !processSymbol(chunk.PacketStatusSymbol) {
					return
				}
			}

		case *rtcp.StatusVectorChunk:
			for _, symbol := range chunk.SymbolList {
				if sequenceNumber == endSequenceNumberExclusive {
					break
				}
				if !processSymbol(symbol) {
					return
				}
			}
		}
	}

	c.prunePacketGroups()
	c.congestionDetectionStateMachine()
}
// worker is the congestion detector's processing goroutine. It drains queued
// feedback reports when woken, and also runs periodic checks so that state
// can decay (pruning, hangover expiry) even when no feedback arrives. The
// periodic check runs faster while congested.
func (c *congestionDetector) worker() {
	ticker := time.NewTicker(c.params.Config.PeriodicCheckInterval)
	defer ticker.Stop()

	for {
		select {
		case <-c.wake:
			// drain all queued reports; pop under lock, process outside it
			for {
				c.lock.Lock()
				if c.feedbackReports.Len() == 0 {
					c.lock.Unlock()
					break
				}
				fbReport := c.feedbackReports.PopFront()
				c.lock.Unlock()

				c.processFeedbackReport(fbReport)
			}

			// tighten the check cadence while congested
			if c.GetCongestionState() == CongestionStateCongested {
				ticker.Reset(c.params.Config.PeriodicCheckIntervalCongested)
			} else {
				ticker.Reset(c.params.Config.PeriodicCheckInterval)
			}

		case <-ticker.C:
			c.prunePacketGroups()
			c.congestionDetectionStateMachine()

		case <-c.stop.Watch():
			return
		}
	}
}
+338
View File
@@ -0,0 +1,338 @@
package sendsidebwe
import (
"errors"
"math"
"time"
"github.com/livekit/protocol/logger"
"github.com/livekit/protocol/utils"
"go.uber.org/zap/zapcore"
)
// -------------------------------------------------------------
var (
errGroupFinalized = errors.New("packet group is finalized")
errOldPacket = errors.New("packet is older than packet group start")
)
// -------------------------------------------------------------
// PacketGroupConfig bounds a packet group: a group finalizes once it holds
// MinPackets acknowledged packets or spans MaxWindowDuration of send time.
type PacketGroupConfig struct {
	MinPackets        int           `yaml:"min_packets,omitempty"`
	MaxWindowDuration time.Duration `yaml:"max_window_duration,omitempty"`

	// should have at least this fraction of `MinPackets` for loss penalty consideration
	LossPenaltyMinPacketsRatio float64 `yaml:"loss_penalty_min_packet_ratio,omitempty"`
	LossPenaltyFactor          float64 `yaml:"loss_penalty_factor,omitempty"`
}

var (
	// DefaultPacketGroupConfig is the default packet group configuration.
	DefaultPacketGroupConfig = PacketGroupConfig{
		MinPackets:                 20,
		MaxWindowDuration:          500 * time.Millisecond,
		LossPenaltyMinPacketsRatio: 0.5,
		LossPenaltyFactor:          0.25,
	}
)
// -------------------------------------------------------------
// stat accumulates a packet count and byte count.
type stat struct {
	numPackets int
	numBytes   int
}

// add records one packet of the given size (bytes).
func (s *stat) add(size int) {
	s.numPackets++
	s.numBytes += size
}

// remove un-records one packet of the given size (bytes).
func (s *stat) remove(size int) {
	s.numPackets--
	s.numBytes -= size
}

func (s *stat) getNumPackets() int {
	return s.numPackets
}

func (s *stat) getNumBytes() int {
	return s.numBytes
}

// MarshalLogObject implements zapcore.ObjectMarshaler.
func (s stat) MarshalLogObject(e zapcore.ObjectEncoder) error {
	e.AddInt("numPackets", s.numPackets)
	e.AddInt("numBytes", s.numBytes)
	return nil
}
// -------------------------------------------------------------
// classStat tracks packet/byte counts separately for primary and
// retransmission (RTX) traffic.
type classStat struct {
	primary stat
	rtx     stat
}

func (c *classStat) add(size int, isRTX bool) {
	if isRTX {
		c.rtx.add(size)
	} else {
		c.primary.add(size)
	}
}

func (c *classStat) remove(size int, isRTX bool) {
	if isRTX {
		c.rtx.remove(size)
	} else {
		c.primary.remove(size)
	}
}

// numPackets returns the combined (primary + RTX) packet count.
func (c *classStat) numPackets() int {
	return c.primary.getNumPackets() + c.rtx.getNumPackets()
}

// numBytes returns the combined (primary + RTX) byte count.
func (c *classStat) numBytes() int {
	return c.primary.getNumBytes() + c.rtx.getNumBytes()
}

// MarshalLogObject implements zapcore.ObjectMarshaler.
func (c classStat) MarshalLogObject(e zapcore.ObjectEncoder) error {
	e.AddObject("primary", c.primary)
	e.AddObject("rtx", c.rtx)
	return nil
}
// -------------------------------------------------------------
type packetGroupParams struct {
	Config PacketGroupConfig
	Logger logger.Logger
}

// packetGroup aggregates a window of packets (bounded by MinPackets or
// MaxWindowDuration) and their send/receive delta aggregates, from which
// queuing delay and captured traffic ratio are derived.
// All times are in microseconds, relative to the packet tracker's base times.
type packetGroup struct {
	params packetGroupParams

	minSequenceNumber uint64
	maxSequenceNumber uint64

	minSendTime int64
	maxSendTime int64

	minRecvTime int64 // for information only
	maxRecvTime int64 // for information only

	acked classStat
	lost  classStat

	// bitmap of sequence numbers (relative to minSequenceNumber) reported
	// lost, so a later receive indication can undo the loss accounting
	snBitmap *utils.Bitmap[uint64]

	aggregateSendDelta int64
	aggregateRecvDelta int64
	queuingDelay       int64 // queuing delay propagated from the previous group

	isFinalized bool
}

// newPacketGroup starts a group that inherits the propagated queuing delay
// of the previous group.
func newPacketGroup(params packetGroupParams, queuingDelay int64) *packetGroup {
	return &packetGroup{
		params:       params,
		queuingDelay: queuingDelay,
		snBitmap:     utils.NewBitmap[uint64](params.Config.MinPackets),
	}
}
// Add records a packet indication in this group; lost packets are delegated
// to lostPacket. Returns errGroupFinalized or errOldPacket when the packet
// does not belong to this group. The group finalizes once it has exactly
// MinPackets acknowledged packets or spans more than MaxWindowDuration of
// send time.
func (p *packetGroup) Add(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) error {
	if isLost {
		return p.lostPacket(pi)
	}

	if err := p.inGroup(pi.sequenceNumber); err != nil {
		return err
	}

	// 0 acts as the "unset" sentinel for the min fields below
	if p.minSequenceNumber == 0 || pi.sequenceNumber < p.minSequenceNumber {
		p.minSequenceNumber = pi.sequenceNumber
	}
	p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber)

	// the window start is extended to the start of the delta so that
	// the window duration lines up with the delta aggregates
	if p.minSendTime == 0 || (pi.sendTime-sendDelta) < p.minSendTime {
		p.minSendTime = pi.sendTime - sendDelta
	}
	p.maxSendTime = max(p.maxSendTime, pi.sendTime)

	if p.minRecvTime == 0 || (pi.recvTime-recvDelta) < p.minRecvTime {
		p.minRecvTime = pi.recvTime - recvDelta
	}
	p.maxRecvTime = max(p.maxRecvTime, pi.recvTime)

	p.acked.add(int(pi.size), pi.isRTX)
	if p.snBitmap.IsSet(pi.sequenceNumber - p.minSequenceNumber) {
		// an earlier packet reported as lost has been received
		p.snBitmap.Clear(pi.sequenceNumber - p.minSequenceNumber)
		p.lost.remove(int(pi.size), pi.isRTX)
	}

	// note that out-of-order deliveries will amplify the queueing delay.
	// for e.g. a, b, c getting delivered as a, c, b.
	// let us say packets are delivered with interval of `x`
	// send delta aggregate will go up by x ((a, c) = 2x + (c, b) = -1x)
	// recv delta aggregate will go up by 3x ((a, c) = 2x + (c, b) = 1x)
	p.aggregateSendDelta += sendDelta
	p.aggregateRecvDelta += recvDelta

	if p.acked.numPackets() == p.params.Config.MinPackets || (pi.sendTime-p.minSendTime) > p.params.Config.MaxWindowDuration.Microseconds() {
		p.isFinalized = true
	}
	return nil
}
// lostPacket records a packet reported lost by the remote. If the packet was
// received earlier (recvTime set), the loss report is stale and ignored.
// Losses are marked in snBitmap so a later receive indication can undo the
// accounting (see Add).
func (p *packetGroup) lostPacket(pi *packetInfo) error {
	if pi.recvTime != 0 {
		// previously received packet, so not lost
		return nil
	}

	if err := p.inGroup(pi.sequenceNumber); err != nil {
		return err
	}

	// 0 acts as the "unset" sentinel for minSequenceNumber
	if p.minSequenceNumber == 0 || pi.sequenceNumber < p.minSequenceNumber {
		p.minSequenceNumber = pi.sequenceNumber
	}
	p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber)

	p.snBitmap.Set(pi.sequenceNumber - p.minSequenceNumber)
	p.lost.add(int(pi.size), pi.isRTX)
	return nil
}
// MinSendTime returns the earliest send time (µs, relative to the tracker
// base) seen by this group.
func (p *packetGroup) MinSendTime() int64 {
	return p.minSendTime
}

// PropagatedQueuingDelay returns the queuing delay carried forward to the
// next group. The boolean is false until the group is finalized.
func (p *packetGroup) PropagatedQueuingDelay() (int64, bool) {
	if !p.isFinalized {
		return 0, false
	}

	groupDelay := p.aggregateRecvDelta - p.aggregateSendDelta
	if pqd := p.queuingDelay + groupDelay; pqd > 0 {
		return pqd, true
	}
	return max(0, groupDelay), true
}

// SendDuration returns the span of send times covered by this group, or 0
// until the group is finalized.
func (p *packetGroup) SendDuration() int64 {
	if !p.isFinalized {
		return 0
	}
	return p.maxSendTime - p.minSendTime
}
// CapturedTrafficRatio returns the fraction of offered traffic the channel
// could absorb in this group: aggregate send delta over (aggregate receive
// delta + loss penalty), clamped to 1.0.
//
// a penalty is applied for lost packets, the rationale being that packet
// dropping is a strategy to relieve congestion and, had those packets not
// been dropped, they would have increased queuing delay. as the reason for
// the losses cannot be known, a small penalty is added to the receive delta
// aggregate to simulate those packets building up queuing delay.
//
// note that the penalty is applied only when determining rate and not while
// determining queuing region; adding synthetic delays there could cause the
// queuing region to be stuck in JQR.
func (p *packetGroup) CapturedTrafficRatio() float64 {
	if p.aggregateRecvDelta == 0 {
		return 0.0
	}

	ratio := float64(p.aggregateSendDelta) / float64(p.aggregateRecvDelta+p.getLossPenalty())
	return min(1.0, ratio)
}
// Traffic returns (minSendTime, sendDuration, effectiveBytes, fullness) for
// rate measurement. effectiveBytes scales acknowledged bytes by the captured
// traffic ratio; fullness is how complete the group's measurement window is
// relative to the configured packet count and duration bounds.
func (p *packetGroup) Traffic() (int64, int64, int, float64) {
	effectiveBytes := int(float64(p.acked.numBytes()) * p.CapturedTrafficRatio())

	packetFullness := float64(p.acked.numPackets()) / float64(p.params.Config.MinPackets)
	durationFullness := float64(p.maxSendTime-p.minSendTime) / float64(p.params.Config.MaxWindowDuration.Microseconds())

	return p.minSendTime, p.maxSendTime - p.minSendTime, effectiveBytes, max(packetFullness, durationFullness)
}
// MarshalLogObject implements zapcore.ObjectMarshaler. Times are in
// microseconds, bitrates in bits per second. lossRatio is omitted when the
// group has no packets (avoids emitting NaN, which breaks JSON encoders).
func (p *packetGroup) MarshalLogObject(e zapcore.ObjectEncoder) error {
	if p == nil {
		return nil
	}

	e.AddInt64("minSendTime", p.minSendTime)
	e.AddInt64("maxSendTime", p.maxSendTime)
	sendDuration := time.Duration((p.maxSendTime - p.minSendTime) * 1000)
	e.AddDuration("sendDuration", sendDuration)

	e.AddInt64("minRecvTime", p.minRecvTime)
	e.AddInt64("maxRecvTime", p.maxRecvTime)
	recvDuration := time.Duration((p.maxRecvTime - p.minRecvTime) * 1000)
	e.AddDuration("recvDuration", recvDuration)

	e.AddObject("acked", p.acked)
	e.AddObject("lost", p.lost)

	sendBitrate := float64(0)
	if sendDuration != 0 {
		sendBitrate = float64(p.acked.numBytes()*8) / sendDuration.Seconds()
		e.AddFloat64("sendBitrate", sendBitrate)
	}

	recvBitrate := float64(0)
	if recvDuration != 0 {
		recvBitrate = float64(p.acked.numBytes()*8) / recvDuration.Seconds()
		e.AddFloat64("recvBitrate", recvBitrate)
	}

	e.AddInt64("aggregateSendDelta", p.aggregateSendDelta)
	e.AddInt64("aggregateRecvDelta", p.aggregateRecvDelta)
	e.AddInt64("queuingDelay", p.queuingDelay)
	e.AddInt64("groupDelay", p.aggregateRecvDelta-p.aggregateSendDelta)

	// guard against 0/0 producing NaN in the log output
	if totalPackets := p.acked.numPackets() + p.lost.numPackets(); totalPackets != 0 {
		e.AddFloat64("lossRatio", float64(p.lost.numPackets())/float64(totalPackets))
	}
	e.AddInt64("lossPenalty", p.getLossPenalty())

	capturedTrafficRatio := p.CapturedTrafficRatio()
	e.AddFloat64("capturedTrafficRatio", capturedTrafficRatio)
	e.AddFloat64("estimatedAvailableChannelCapacity", sendBitrate*capturedTrafficRatio)

	e.AddBool("isFinalized", p.isFinalized)
	return nil
}
// inGroup reports whether a packet with the given (extended) sequence number
// may still be added to this group: finalized groups reject newer packets,
// and packets older than the group start are rejected outright.
func (p *packetGroup) inGroup(sequenceNumber uint64) error {
	switch {
	case p.isFinalized && sequenceNumber > p.maxSequenceNumber:
		return errGroupFinalized
	case sequenceNumber < p.minSequenceNumber:
		return errOldPacket
	default:
		return nil
	}
}
// getLossPenalty returns a synthetic addition (µs) to the receive delta
// aggregate that accounts for lost packets when computing captured traffic
// ratio. No penalty is applied when the group has too few packets for the
// loss ratio to be meaningful.
func (p *packetGroup) getLossPenalty() int64 {
	if p.aggregateRecvDelta == 0 {
		return 0
	}

	lostPackets := p.lost.numPackets()
	totalPackets := float64(lostPackets + p.acked.numPackets())
	if totalPackets < float64(p.params.Config.MinPackets)*p.params.Config.LossPenaltyMinPacketsRatio {
		return 0
	}

	// Log10 is used to give higher weight for the same loss ratio at higher packet rates,
	// for e.g. with a penalty factor of 0.25
	//  - 10% loss at 20 total packets   = 0.1 * log10(20) * 0.25   = 0.032
	//  - 10% loss at 100 total packets  = 0.1 * log10(100) * 0.25  = 0.05
	//  - 10% loss at 1000 total packets = 0.1 * log10(1000) * 0.25 = 0.075
	lossRatio := float64(lostPackets) / totalPackets
	return int64(float64(p.aggregateRecvDelta) * lossRatio * math.Log10(totalPackets) * p.params.Config.LossPenaltyFactor)
}
+36
View File
@@ -0,0 +1,36 @@
package sendsidebwe
import (
"go.uber.org/zap/zapcore"
)
// packetInfo holds per-packet send/receive state for one TWCC-tracked
// packet. Times are microseconds relative to the tracker's base times;
// a zero sendTime/recvTime means "not (yet) recorded".
type packetInfo struct {
	sequenceNumber uint64
	sendTime       int64
	recvTime       int64
	size           uint16
	isRTX          bool
	// SSBWE-TODO: possibly add the following fields - pertaining to this packet,
	// idea is to be able to figure out probe start/end and check for bitrate in that window
}

// Reset clears the slot for reuse by a new packet with the given sequence number.
func (pi *packetInfo) Reset(sequenceNumber uint64) {
	*pi = packetInfo{sequenceNumber: sequenceNumber}
}

// MarshalLogObject implements zapcore.ObjectMarshaler.
func (pi *packetInfo) MarshalLogObject(e zapcore.ObjectEncoder) error {
	if pi == nil {
		return nil
	}

	e.AddUint64("sequenceNumber", pi.sequenceNumber)
	e.AddInt64("sendTime", pi.sendTime)
	e.AddInt64("recvTime", pi.recvTime)
	e.AddUint16("size", pi.size)
	e.AddBool("isRTX", pi.isRTX)
	return nil
}
+113
View File
@@ -0,0 +1,113 @@
package sendsidebwe
import (
"errors"
"math/rand"
"sync"
"time"
"github.com/livekit/protocol/logger"
)
// -------------------------------------------------------------------------------
var (
errNoPacketInRange = errors.New("no packet in range")
)
// -------------------------------------------------------------------------------
type packetTrackerParams struct {
	Logger logger.Logger
}

// packetTracker assigns TWCC sequence numbers to outgoing packets and keeps
// a fixed-size ring of per-packet state (send/receive times) used to compute
// deltas when feedback arrives. The ring bounds memory; an entry is valid
// only while its 16-bit sequence number still matches (see getPacketInfoExisting).
type packetTracker struct {
	params packetTrackerParams

	lock           sync.Mutex
	sequenceNumber uint64 // next (extended) TWCC sequence number to assign

	baseSendTime int64 // first send time (µs); send times are stored relative to this
	packetInfos  [2048]packetInfo

	baseRecvTime int64       // first remote receive time (µs); recv times are stored relative to this
	piLastRecv   *packetInfo // last received packet, anchor for send/recv deltas
}

func newPacketTracker(params packetTrackerParams) *packetTracker {
	return &packetTracker{
		params:         params,
		sequenceNumber: uint64(rand.Intn(1<<14)) + uint64(1<<15), // a random number in third quartile of sequence number space
	}
}
// SSBWE-TODO: this potentially needs to take isProbe as argument?

// RecordPacketSendAndGetSequenceNumber records a packet send and returns the
// 16-bit TWCC sequence number to place in its transport-wide extension
// (the low 16 bits of the internally tracked extended sequence number).
func (p *packetTracker) RecordPacketSendAndGetSequenceNumber(at time.Time, size int, isRTX bool) uint16 {
	p.lock.Lock()
	defer p.lock.Unlock()

	sendTime := at.UnixMicro()
	if p.baseSendTime == 0 {
		p.baseSendTime = sendTime
	}

	// claim the ring slot for this sequence number, overwriting any older entry
	pi := p.getPacketInfo(uint16(p.sequenceNumber))
	pi.sequenceNumber = p.sequenceNumber
	pi.sendTime = sendTime - p.baseSendTime
	pi.recvTime = 0
	pi.size = uint16(size)
	pi.isRTX = isRTX

	p.sequenceNumber++

	// extreme case of wrap around before receiving any feedback
	if pi == p.piLastRecv {
		p.piLastRecv = nil
	}

	return uint16(pi.sequenceNumber)
}
// RecordPacketIndicationFromRemote applies a receive indication from a TWCC
// feedback report. It returns a copy of the packet's state and the send/recv
// deltas relative to the previously received packet (both zero when there is
// no prior received packet or the packet is unknown). recvTime == 0 means
// the remote reported the packet as not received.
func (p *packetTracker) RecordPacketIndicationFromRemote(sn uint16, recvTime int64) (piRecv packetInfo, sendDelta, recvDelta int64) {
	p.lock.Lock()
	defer p.lock.Unlock()

	pi := p.getPacketInfoExisting(sn)
	if pi == nil {
		return
	}

	if recvTime == 0 {
		// maybe lost OR already received but reported lost in a later report
		piRecv = *pi
		return
	}

	if p.baseRecvTime == 0 {
		p.baseRecvTime = recvTime
		p.piLastRecv = pi
	}

	pi.recvTime = recvTime - p.baseRecvTime
	piRecv = *pi
	if p.piLastRecv != nil {
		sendDelta, recvDelta = pi.sendTime-p.piLastRecv.sendTime, pi.recvTime-p.piLastRecv.recvTime
	}
	p.piLastRecv = pi
	return
}
// getPacketInfo returns the ring-buffer slot for the given TWCC sequence number.
func (p *packetTracker) getPacketInfo(sn uint16) *packetInfo {
	idx := int(sn) % len(p.packetInfos)
	return &p.packetInfos[idx]
}

// getPacketInfoExisting returns the slot only if it still holds the packet
// with the given sequence number, i.e. it has not been overwritten by a
// later packet that mapped to the same slot.
func (p *packetTracker) getPacketInfoExisting(sn uint16) *packetInfo {
	slot := &p.packetInfos[int(sn)%len(p.packetInfos)]
	if uint16(slot.sequenceNumber) != sn {
		return nil
	}
	return slot
}
+105
View File
@@ -0,0 +1,105 @@
package sendsidebwe
import (
"fmt"
"github.com/livekit/protocol/logger"
)
//
// Based on a simplified/modified version of JitterPath paper
// (https://homepage.iis.sinica.edu.tw/papers/lcs/2114-F.pdf)
//
// TWCC feedback is used to calculate delta one-way-delay.
// It is accumulated/propagated to determine in which region
// groups of packets are operating in.
//
// In simplified terms,
// o JQR (Join Queuing Region) is when channel is congested.
// o DQR (Disjoint Queuing Region) is when channel is not.
//
// Packets are grouped and thresholds applied to smooth over
// small variations. For example, in the paper,
// if propagated_queuing_delay + delta_one_way_delay > 0 {
// possibly_operating_in_jqr
// }
// But, in this implementation it is checked at packet group level,
// i.e. using queuing delay and aggregated delta one-way-delay of
// the group and a minimum value threshold is applied before declaring
// that a group is in JQR.
//
// There is also hysteresis to make transitions smoother, i.e. if the
// metric is above a certain threshold it is JQR, it is DQR only if it
// is below a certain value, and groups falling in the gap between those
// two thresholds are treated as indeterminate.
//
// ---------------------------------------------------------------------------
// CongestionState is the stage of the congestion detection state machine.
type CongestionState int

const (
	CongestionStateNone CongestionState = iota
	CongestionStateEarlyWarning
	CongestionStateEarlyWarningHangover
	CongestionStateCongested
	CongestionStateCongestedHangover
)

// congestionStateNames maps each known state to its log-friendly name.
var congestionStateNames = map[CongestionState]string{
	CongestionStateNone:                 "NONE",
	CongestionStateEarlyWarning:         "EARLY_WARNING",
	CongestionStateEarlyWarningHangover: "EARLY_WARNING_HANGOVER",
	CongestionStateCongested:            "CONGESTED",
	CongestionStateCongestedHangover:    "CONGESTED_HANGOVER",
}

// String returns a human-readable name for the state; unknown values render
// as their numeric value.
func (c CongestionState) String() string {
	if name, ok := congestionStateNames[c]; ok {
		return name
	}
	return fmt.Sprintf("%d", int(c))
}
// ---------------------------------------------------------------------------
// SendSideBWEConfig holds configuration for send side bandwidth estimation.
type SendSideBWEConfig struct {
	// CongestionDetector configures the delta one-way-delay based congestion detector.
	CongestionDetector CongestionDetectorConfig `yaml:"congestion_detector,omitempty"`
}

var (
	// DefaultSendSideBWEConfig is the configuration used when none is supplied.
	DefaultSendSideBWEConfig = SendSideBWEConfig{
		CongestionDetector: DefaultCongestionDetectorConfig,
	}
)
// ---------------------------------------------------------------------------
// SendSideBWEParams are the arguments needed to construct a SendSideBWE.
type SendSideBWEParams struct {
	Config SendSideBWEConfig
	Logger logger.Logger
}

// SendSideBWE implements send side bandwidth estimation driven by TWCC
// feedback. It embeds congestionDetector, so the detector's exported
// methods are promoted onto this type.
type SendSideBWE struct {
	params SendSideBWEParams
	*congestionDetector
}
// NewSendSideBWE creates a send side bandwidth estimator whose embedded
// congestion detector is configured from the given params.
func NewSendSideBWE(params SendSideBWEParams) *SendSideBWE {
	detector := newCongestionDetector(congestionDetectorParams{
		Config: params.Config.CongestionDetector,
		Logger: params.Logger,
	})
	return &SendSideBWE{
		params:             params,
		congestionDetector: detector,
	}
}
// Stop halts the embedded congestion detector.
func (s *SendSideBWE) Stop() {
	s.congestionDetector.Stop()
}
// ------------------------------------------------
+116
View File
@@ -0,0 +1,116 @@
package sendsidebwe
import (
"errors"
"time"
"github.com/livekit/protocol/logger"
"github.com/pion/rtcp"
"go.uber.org/zap/zapcore"
)
// ------------------------------------------------------
const (
	// inter-feedback intervals outside [estimate/factor, estimate*factor]
	// are treated as outliers and excluded from the interval estimate
	cOutlierReportFactor = 3

	// smoothing factor for the exponential moving average of the
	// inter-feedback interval (weight given to the previous estimate)
	cEstimatedFeedbackIntervalAlpha = float64(0.9)

	// TWCC reference time is a 24-bit field
	cReferenceTimeMask = (1 << 24) - 1

	// TWCC reference time units
	cReferenceTimeResolution = 64 // 64 ms
)

// ------------------------------------------------------

var (
	// sentinel error indicating a feedback report arrived out-of-order
	errFeedbackReportOutOfOrder = errors.New("feedback report out-of-order")
)
// ------------------------------------------------------
// twccFeedbackParams holds the dependencies needed to construct a twccFeedback.
type twccFeedbackParams struct {
	Logger logger.Logger
}

// twccFeedback tracks TWCC feedback reports: it unwraps the 24-bit reference
// time across wrap-arounds, detects out-of-order reports via the 8-bit
// feedback packet count, and maintains a smoothed estimate of the
// inter-feedback interval.
type twccFeedback struct {
	params twccFeedbackParams

	// lastFeedbackTime is the local arrival time of the last in-order report
	lastFeedbackTime time.Time
	// estimatedFeedbackInterval is a smoothed (EMA) inter-feedback interval
	estimatedFeedbackInterval time.Duration
	// numReports counts all reports processed
	numReports int
	// numReportsOutOfOrder counts reports whose feedback count went backwards
	numReportsOutOfOrder int
	// highestFeedbackCount is the highest 8-bit feedback packet count seen
	highestFeedbackCount uint8

	// cycles accumulates 1 << 24 per reference time wrap-around
	cycles int64
	// highestReferenceTime is the highest 24-bit reference time seen
	highestReferenceTime uint32
}

// newTWCCFeedback creates a feedback tracker; all tracking state starts at
// its zero value and is initialized on the first report.
func newTWCCFeedback(params twccFeedbackParams) *twccFeedback {
	return &twccFeedback{
		params: params,
	}
}
// ProcessReport ingests one TWCC feedback report received at local time `at`.
// It returns the report's reference time converted to microseconds
// (the 24-bit field is in cReferenceTimeResolution ms units and is unwrapped
// across wrap-arounds) and a flag indicating whether the report arrived
// out-of-order relative to the highest feedback packet count seen so far.
func (t *twccFeedback) ProcessReport(report *rtcp.TransportLayerCC, at time.Time) (int64, bool) {
	// SSBWE-REMOVE t.params.Logger.Infow("TWCC feedback", "report", report.String()) // SSBWE-REMOVE
	t.numReports++
	if t.lastFeedbackTime.IsZero() {
		// first report: initialize tracking state; cannot be out-of-order
		t.lastFeedbackTime = at
		t.highestReferenceTime = report.ReferenceTime
		t.highestFeedbackCount = report.FbPktCount
		return (t.cycles + int64(report.ReferenceTime)) * cReferenceTimeResolution * 1000, false
	}

	// FbPktCount is an 8-bit counter; a forward difference of more than half
	// its range means this report is older than the highest seen so far
	isOutOfOrder := false
	if (report.FbPktCount - t.highestFeedbackCount) > (1 << 7) {
		t.numReportsOutOfOrder++
		isOutOfOrder = true
	}

	// reference time wrap around handling
	var referenceTime int64
	if (report.ReferenceTime-t.highestReferenceTime)&cReferenceTimeMask < (1 << 23) {
		// reference time moved forward (masked forward distance is less than
		// half the 24-bit range); a numerically smaller value while moving
		// forward means the 24-bit field wrapped around
		if report.ReferenceTime < t.highestReferenceTime {
			t.cycles += (1 << 24)
		}
		t.highestReferenceTime = report.ReferenceTime
		referenceTime = t.cycles + int64(report.ReferenceTime)
	} else {
		// reference time is behind the highest seen; if the highest value has
		// already wrapped and this older report is from before that wrap,
		// unwrap it using the previous cycle
		cycles := t.cycles
		if report.ReferenceTime > t.highestReferenceTime && cycles >= (1<<24) {
			cycles -= (1 << 24)
		}
		referenceTime = cycles + int64(report.ReferenceTime)
	}

	if !isOutOfOrder {
		// only in-order reports update the inter-feedback interval estimate
		sinceLast := at.Sub(t.lastFeedbackTime)
		// SSBWE-REMOVE t.params.Logger.Infow("report received", "at", at, "sinceLast", sinceLast, "pktCount", report.FbPktCount) // SSBWE-REMOVE
		if t.estimatedFeedbackInterval == 0 {
			t.estimatedFeedbackInterval = sinceLast
		} else {
			// filter out outliers from estimate
			if sinceLast > t.estimatedFeedbackInterval/cOutlierReportFactor && sinceLast < cOutlierReportFactor*t.estimatedFeedbackInterval {
				// smoothed version of inter feedback interval
				t.estimatedFeedbackInterval = time.Duration(cEstimatedFeedbackIntervalAlpha*float64(t.estimatedFeedbackInterval) + (1.0-cEstimatedFeedbackIntervalAlpha)*float64(sinceLast))
			}
		}
		t.lastFeedbackTime = at
		t.highestFeedbackCount = report.FbPktCount
	}

	// 64 ms units -> microseconds
	return referenceTime * cReferenceTimeResolution * 1000, isOutOfOrder
}
// MarshalLogObject adds feedback tracking state to a zap log entry.
// It is safe to call on a nil receiver.
func (t *twccFeedback) MarshalLogObject(e zapcore.ObjectEncoder) error {
	if t == nil {
		return nil
	}

	e.AddTime("lastFeedbackTime", t.lastFeedbackTime)
	e.AddDuration("estimatedFeedbackInterval", t.estimatedFeedbackInterval)
	e.AddInt("numReports", t.numReports)
	e.AddInt("numReportsOutOfOrder", t.numReportsOutOfOrder)
	// cycles is stored as a multiple of 1 << 24; log the wrap count
	e.AddInt64("cycles", t.cycles/(1<<24))
	return nil
}
+28 -8
View File
@@ -16,7 +16,9 @@ package streamallocator
import (
"fmt"
"time"
"github.com/livekit/livekit-server/pkg/sfu/ccutils"
"github.com/livekit/protocol/logger"
)
@@ -69,18 +71,36 @@ func (c ChannelCongestionReason) String() string {
// ------------------------------------------------
type ChannelObserverConfig struct {
Estimate TrendDetectorConfig `yaml:"estimate,omitempty"`
Nack NackTrackerConfig `yaml:"nack,omitempty"`
Estimate ccutils.TrendDetectorConfig `yaml:"estimate,omitempty"`
Nack NackTrackerConfig `yaml:"nack,omitempty"`
}
var (
defaultTrendDetectorConfigProbe = ccutils.TrendDetectorConfig{
RequiredSamples: 3,
RequiredSamplesMin: 3,
DownwardTrendThreshold: 0.0,
DownwardTrendMaxWait: 5 * time.Second,
CollapseThreshold: 0,
ValidityWindow: 10 * time.Second,
}
DefaultChannelObserverConfigProbe = ChannelObserverConfig{
Estimate: DefaultTrendDetectorConfigProbe,
Estimate: defaultTrendDetectorConfigProbe,
Nack: DefaultNackTrackerConfigProbe,
}
defaultTrendDetectorConfigNonProbe = ccutils.TrendDetectorConfig{
RequiredSamples: 12,
RequiredSamplesMin: 8,
DownwardTrendThreshold: -0.6,
DownwardTrendMaxWait: 5 * time.Second,
CollapseThreshold: 500 * time.Millisecond,
ValidityWindow: 10 * time.Second,
}
DefaultChannelObserverConfigNonProbe = ChannelObserverConfig{
Estimate: DefaultTrendDetectorConfigNonProbe,
Estimate: defaultTrendDetectorConfigNonProbe,
Nack: DefaultNackTrackerConfigNonProbe,
}
)
@@ -96,7 +116,7 @@ type ChannelObserver struct {
params ChannelObserverParams
logger logger.Logger
estimateTrend *TrendDetector
estimateTrend *ccutils.TrendDetector[int64]
nackTracker *NackTracker
}
@@ -104,7 +124,7 @@ func NewChannelObserver(params ChannelObserverParams, logger logger.Logger) *Cha
return &ChannelObserver{
params: params,
logger: logger,
estimateTrend: NewTrendDetector(TrendDetectorParams{
estimateTrend: ccutils.NewTrendDetector[int64](ccutils.TrendDetectorParams{
Name: params.Name + "-estimate",
Logger: logger,
Config: params.Config.Estimate,
@@ -155,7 +175,7 @@ func (c *ChannelObserver) GetTrend() (ChannelTrend, ChannelCongestionReason) {
estimateDirection := c.estimateTrend.GetDirection()
switch {
case estimateDirection == TrendDirectionDownward:
case estimateDirection == ccutils.TrendDirectionDownward:
c.logger.Debugw("stream allocator: channel observer: estimate is trending downward", "channel", c.ToString())
return ChannelTrendCongesting, ChannelCongestionReasonEstimate
@@ -163,7 +183,7 @@ func (c *ChannelObserver) GetTrend() (ChannelTrend, ChannelCongestionReason) {
c.logger.Debugw("stream allocator: channel observer: high rate of repeated NACKs", "channel", c.ToString())
return ChannelTrendCongesting, ChannelCongestionReasonLoss
case estimateDirection == TrendDirectionUpward:
case estimateDirection == ccutils.TrendDirectionUpward:
return ChannelTrendClearing, ChannelCongestionReasonNone
}
@@ -18,6 +18,7 @@ import (
"sync"
"time"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/protocol/logger"
)
@@ -73,6 +74,7 @@ type ProbeController struct {
params ProbeControllerParams
lock sync.RWMutex
sendSideBWE *sendsidebwe.SendSideBWE
probeInterval time.Duration
lastProbeStartTime time.Time
probeGoalBps int64
@@ -95,6 +97,13 @@ func NewProbeController(params ProbeControllerParams) *ProbeController {
return p
}
// SetSendSideBWE installs the send side bandwidth estimator used to poll
// probe progress.
func (p *ProbeController) SetSendSideBWE(sendSideBWE *sendsidebwe.SendSideBWE) {
	p.lock.Lock()
	p.sendSideBWE = sendSideBWE
	p.lock.Unlock()
}
func (p *ProbeController) Reset() {
p.lock.Lock()
defer p.lock.Unlock()
@@ -270,9 +279,78 @@ func (p *ProbeController) InitProbe(probeGoalDeltaBps int64, expectedBandwidthUs
time.Duration(float64(p.probeDuration.Milliseconds())*p.params.Config.DurationOverflowFactor)*time.Millisecond,
)
p.pollProbe(p.probeClusterId)
return p.probeClusterId, p.probeGoalBps
}
// SSBWE-TODO: try to do same path for both SSBWE and regular, the congesting part might be different though
// pollProbe monitors an in-flight probe cluster when send side bandwidth
// estimation is in use. It spawns a goroutine that periodically checks probe
// progress and stops the probe when:
//   - the estimate has not risen above its starting value within
//     Config.TrendWait (safety net against a stuck estimate),
//   - the channel shows congestion (early warning or congested), or
//   - the estimate has reached the probe goal.
//
// The goroutine exits on its own when the active probe cluster changes or
// when any of the terminal conditions above fires.
func (p *ProbeController) pollProbe(probeClusterId ProbeClusterId) {
	if p.sendSideBWE == nil {
		return
	}

	startingEstimate := p.sendSideBWE.GetEstimatedAvailableChannelCapacity()
	go func() {
		// SSBWE-TODO: do not hard code sleep time
		const pollInterval = 50 * time.Millisecond

		for {
			p.lock.Lock()
			if p.probeClusterId != probeClusterId {
				// a different probe (or no probe) is active, nothing more to do
				p.lock.Unlock()
				return
			}

			done := false
			congestionState := p.sendSideBWE.GetCongestionState()
			currentEstimate := p.sendSideBWE.GetEstimatedAvailableChannelCapacity()
			switch {
			case currentEstimate <= startingEstimate && time.Since(p.lastProbeStartTime) > p.params.Config.TrendWait:
				//
				// More of a safety net.
				// In rare cases, the estimate gets stuck. Prevent from probe running amok
				// STREAM-ALLOCATOR-TODO: Need more testing here to ensure that probe does not cause a lot of damage
				//
				p.params.Logger.Infow("stream allocator: probe: aborting, no trend", "cluster", probeClusterId)
				p.abortProbeLocked()
				done = true

			case congestionState == sendsidebwe.CongestionStateCongested || congestionState == sendsidebwe.CongestionStateEarlyWarning:
				// stop immediately if the probe is congesting channel more
				p.params.Logger.Infow(
					"stream allocator: probe: aborting, channel is congesting",
					"cluster", probeClusterId,
					"congestionState", congestionState,
				)
				p.abortProbeLocked()
				done = true

			case currentEstimate > p.probeGoalBps:
				// reached goal, stop probing
				p.params.Logger.Infow(
					"stream allocator: probe: stopping, goal reached",
					"cluster", probeClusterId,
					"goal", p.probeGoalBps,
					"current", currentEstimate,
				)
				p.goalReachedProbeClusterId = p.probeClusterId
				// NOTE(review): StopProbe is invoked while p.lock is held; unlike
				// abortProbeLocked, its name does not indicate it expects the lock
				// to be held. Confirm StopProbe does not acquire p.lock itself,
				// otherwise this deadlocks.
				p.StopProbe()
				done = true
			}
			p.lock.Unlock()

			if done {
				return
			}
			time.Sleep(pollInterval)
		}
	}()
}
func (p *ProbeController) clearProbeLocked() {
p.probeClusterId = ProbeClusterIdInvalid
p.doneProbeClusterInfo = ProbeClusterInfo{Id: ProbeClusterIdInvalid}
+92 -10
View File
@@ -30,6 +30,7 @@ import (
"github.com/livekit/livekit-server/pkg/sfu"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
"github.com/livekit/livekit-server/pkg/sfu/sendsidebwe"
"github.com/livekit/livekit-server/pkg/utils"
)
@@ -86,6 +87,7 @@ const (
streamAllocatorSignalSetChannelCapacity
// STREAM-ALLOCATOR-DATA streamAllocatorSignalNACK
// STREAM-ALLOCATOR-DATA streamAllocatorSignalRTCPReceiverReport
streamAllocatorSignalCongestionStateChange
)
func (s streamAllocatorSignal) String() string {
@@ -116,6 +118,8 @@ func (s streamAllocatorSignal) String() string {
case streamAllocatorSignalRTCPReceiverReport:
return "RTCP_RECEIVER_REPORT"
*/
case streamAllocatorSignalCongestionStateChange:
return "CONGESTION_STATE_CHANGE"
default:
return fmt.Sprintf("%d", int(s))
}
@@ -179,7 +183,8 @@ type StreamAllocator struct {
onStreamStateChange func(update *StreamStateUpdate) error
bwe cc.BandwidthEstimator
sendSideBWEInterceptor cc.BandwidthEstimator
sendSideBWE *sendsidebwe.SendSideBWE
enabled bool
allowPause bool
@@ -200,7 +205,8 @@ type StreamAllocator struct {
isAllocateAllPending bool
rembTrackingSSRC uint32
state streamAllocatorState
state streamAllocatorState
isHolding bool
eventsQueue *utils.TypedOpsQueue[Event]
@@ -256,11 +262,19 @@ func (s *StreamAllocator) OnStreamStateChange(f func(update *StreamStateUpdate)
s.onStreamStateChange = f
}
func (s *StreamAllocator) SetBandwidthEstimator(bwe cc.BandwidthEstimator) {
if bwe != nil {
bwe.OnTargetBitrateChange(s.onTargetBitrateChange)
func (s *StreamAllocator) SetSendSideBWEInterceptor(sendSideBWEInterceptor cc.BandwidthEstimator) {
if sendSideBWEInterceptor != nil {
sendSideBWEInterceptor.OnTargetBitrateChange(s.onTargetBitrateChange)
}
s.bwe = bwe
s.sendSideBWEInterceptor = sendSideBWEInterceptor
}
// SetSendSideBWE wires in the send side bandwidth estimator: stores it,
// shares it with the probe controller and, when non-nil, registers for
// congestion state change notifications.
func (s *StreamAllocator) SetSendSideBWE(sendSideBWE *sendsidebwe.SendSideBWE) {
	s.sendSideBWE = sendSideBWE
	s.probeController.SetSendSideBWE(sendSideBWE)
	if sendSideBWE == nil {
		return
	}
	sendSideBWE.OnCongestionStateChange(s.onCongestionStateChange)
}
type AddTrackParams struct {
@@ -428,8 +442,12 @@ func (s *StreamAllocator) OnREMB(downTrack *sfu.DownTrack, remb *rtcp.ReceiverEs
// called when a new transport-cc feedback is received
func (s *StreamAllocator) OnTransportCCFeedback(downTrack *sfu.DownTrack, fb *rtcp.TransportLayerCC) {
if s.bwe != nil {
s.bwe.WriteRTCP([]rtcp.Packet{fb}, nil)
if s.sendSideBWEInterceptor != nil {
s.sendSideBWEInterceptor.WriteRTCP([]rtcp.Packet{fb}, nil)
}
if s.sendSideBWE != nil {
s.sendSideBWE.HandleRTCP(fb)
}
}
@@ -441,6 +459,22 @@ func (s *StreamAllocator) onTargetBitrateChange(bitrate int) {
})
}
// called when congestion state changes (send side bandwidth estimation)
type congestionStateChangeData struct {
congestionState sendsidebwe.CongestionState
estimatedAvailableChannelCapacity int64
}
// onCongestionStateChange posts the congestion state change as an event so it
// is handled by handleSignalCongestionStateChange via the event queue.
func (s *StreamAllocator) onCongestionStateChange(congestionState sendsidebwe.CongestionState, estimatedAvailableChannelCapacity int64) {
	data := congestionStateChangeData{
		congestionState:                   congestionState,
		estimatedAvailableChannelCapacity: estimatedAvailableChannelCapacity,
	}
	s.postEvent(Event{
		Signal: streamAllocatorSignalCongestionStateChange,
		Data:   data,
	})
}
// called when feeding track's layer availability changes
func (s *StreamAllocator) OnAvailableLayersChanged(downTrack *sfu.DownTrack) {
s.maybePostEventAllocateTrack(downTrack)
@@ -637,6 +671,8 @@ func (s *StreamAllocator) postEvent(event Event) {
case streamAllocatorSignalRTCPReceiverReport:
event.s.handleSignalRTCPReceiverReport(event)
*/
case streamAllocatorSignalCongestionStateChange:
s.handleSignalCongestionStateChange(event)
}
}, event)
}
@@ -780,6 +816,45 @@ func (s *StreamAllocator) handleSignalRTCPReceiverReport(event Event) {
}
*/
// handleSignalCongestionStateChange reacts to congestion state transitions
// reported by send side bandwidth estimation: aborts probes on any non-None
// state, holds allocations during early warning, re-allocates optimally when
// the hold is released, and shrinks committed capacity when congested.
func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) {
	cscd := event.Data.(congestionStateChangeData)
	if cscd.congestionState != sendsidebwe.CongestionStateNone {
		// any sign of congestion: a probe would only add load
		s.probeController.AbortProbe()
	}

	if cscd.congestionState == sendsidebwe.CongestionStateEarlyWarning ||
		cscd.congestionState == sendsidebwe.CongestionStateEarlyWarningHangover {
		// hold tracks at their current allocation while in early warning
		s.isHolding = true
	} else {
		s.isHolding = false

		// early warning is done and hold has been released,
		// if there is no congestion, allocate all tracks optimally as
		// some tracks may have been held at sub-optimal allocation
		// during early warning hold
		if cscd.congestionState == sendsidebwe.CongestionStateNone && s.state == streamAllocatorStateStable {
			update := NewStreamStateUpdate()
			for _, track := range s.getTracks() {
				// s.isHolding is false here, so tracks are free to move
				allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, s.isHolding)
				updateStreamStateChange(track, allocation, update)
			}
			s.maybeSendUpdate(update)
		}
	}

	if cscd.congestionState == sendsidebwe.CongestionStateCongested {
		// adopt the estimator's available capacity and re-allocate all tracks within it
		s.params.Logger.Infow(
			"stream allocator: channel congestion detected, updating channel capacity",
			"old(bps)", s.committedChannelCapacity,
			"new(bps)", cscd.estimatedAvailableChannelCapacity,
			"expectedUsage(bps)", s.getExpectedBandwidthUsage(),
		)
		s.committedChannelCapacity = cscd.estimatedAvailableChannelCapacity
		s.allocateAllTracks()
	}
}
func (s *StreamAllocator) setState(state streamAllocatorState) {
if s.state == state {
return
@@ -905,7 +980,7 @@ func (s *StreamAllocator) allocateTrack(track *Track) {
// if not deficient, free pass allocate track
if !s.enabled || s.state == streamAllocatorStateStable || !track.IsManaged() {
update := NewStreamStateUpdate()
allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal)
allocation := track.AllocateOptimal(FlagAllowOvershootWhileOptimal, s.isHolding)
updateStreamStateChange(track, allocation, update)
s.maybeSendUpdate(update)
return
@@ -1055,6 +1130,9 @@ func (s *StreamAllocator) allocateTrack(track *Track) {
func (s *StreamAllocator) onProbeDone(isNotFailing bool, isGoalReached bool) {
highestEstimateInProbe := s.channelObserver.GetHighestEstimate()
if s.sendSideBWE != nil {
highestEstimateInProbe = s.sendSideBWE.GetEstimatedAvailableChannelCapacity()
}
//
// Reset estimator at the end of a probe irrespective of probe result to get fresh readings.
@@ -1161,7 +1239,7 @@ func (s *StreamAllocator) allocateAllTracks() {
continue
}
allocation := track.AllocateOptimal(FlagAllowOvershootExemptTrackWhileDeficient)
allocation := track.AllocateOptimal(FlagAllowOvershootExemptTrackWhileDeficient, false)
updateStreamStateChange(track, allocation, update)
// STREAM-ALLOCATOR-TODO: optimistic allocation before bitrate is available will return 0. How to account for that?
@@ -1198,6 +1276,7 @@ func (s *StreamAllocator) allocateAllTracks() {
for _, track := range sorted {
_, usedChannelCapacity := track.ProvisionalAllocate(availableChannelCapacity, layer, s.allowPause, FlagAllowOvershootWhileDeficient)
s.params.Logger.Infow("debug allocated", "trackID", track.ID(), "usedChannelCapacity", usedChannelCapacity, "availableChannelCapacity", availableChannelCapacity) // REMOVE
availableChannelCapacity -= usedChannelCapacity
if availableChannelCapacity < 0 {
availableChannelCapacity = 0
@@ -1357,6 +1436,9 @@ func (s *StreamAllocator) maybeProbe() {
if !s.probeController.CanProbe() {
return
}
if s.sendSideBWE != nil && s.sendSideBWE.GetCongestionState() != sendsidebwe.CongestionStateNone {
return
}
switch s.params.Config.ProbeMode {
case ProbeModeMedia:
+2 -2
View File
@@ -154,8 +154,8 @@ func (t *Track) WritePaddingRTP(bytesToSend int) int {
return t.downTrack.WritePaddingRTP(bytesToSend, false, false)
}
func (t *Track) AllocateOptimal(allowOvershoot bool) sfu.VideoAllocation {
return t.downTrack.AllocateOptimal(allowOvershoot)
// AllocateOptimal delegates to the down track to compute its optimal
// allocation. allowOvershoot and hold are forwarded as-is; hold is set by
// the stream allocator during congestion early-warning hold (presumably to
// pin the allocation at its current level — confirm semantics in DownTrack).
func (t *Track) AllocateOptimal(allowOvershoot bool, hold bool) sfu.VideoAllocation {
	return t.downTrack.AllocateOptimal(allowOvershoot, hold)
}
func (t *Track) ProvisionalAllocatePrepare() {