Include NACK ratio in congestion control (#443)

* WIP commit

* WIP commit

* WIP commit

* WIP commit

* WIP commit

* WIP commit

* WIP commit

* WIP commit

* Clean up

* Remove debug

* Remove unneeded change

* fix test

* Remove incorrect comment

* WIP commit

* Reset probe after estimate trends down

* WIP commit

* variable name change

* WIP commit

* WIP commit

* out-of-order test

* WIP commit

* Clean up

* more strict probe NACKs
This commit is contained in:
Raja Subramanian
2022-02-18 14:21:30 +05:30
committed by GitHub
parent 7fcb887eb8
commit babbfb37aa
4 changed files with 392 additions and 155 deletions

View File

@@ -96,8 +96,9 @@ type DownTrack struct {
listenerLock sync.RWMutex
closeOnce sync.Once
statsLock sync.RWMutex
stats buffer.StreamStats
statsLock sync.RWMutex
stats buffer.StreamStats
totalRepeatedNACKs uint32
connectionStats *connectionquality.ConnectionStats
@@ -333,7 +334,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
}
if d.sequencer != nil {
meta := d.sequencer.push(extPkt.Packet.SequenceNumber, tp.rtp.sequenceNumber, tp.rtp.timestamp, uint8(layer), extPkt.Head)
meta := d.sequencer.push(extPkt.Packet.SequenceNumber, tp.rtp.sequenceNumber, tp.rtp.timestamp, int8(layer))
if meta != nil && tp.vp8 != nil {
meta.packVP8(tp.vp8.header)
}
@@ -451,12 +452,14 @@ func (d *DownTrack) WritePaddingRTP(bytesToSend int) int {
f(d, size)
}
// LK-TODO-START
// NACK buffer for these probe packets.
// Probably okay to absorb the NACKs for these and ignore them.
//
// Register with sequencer with invalid layer so that NACKs for these can be filtered out.
// Retransmission is probably a sign of network congestion/badness.
// So, retransmitting padding packets is only going to make matters worse.
// LK-TODO-END
//
if d.sequencer != nil {
d.sequencer.push(0, hdr.SequenceNumber, hdr.Timestamp, int8(InvalidLayerSpatial))
}
bytesSent += size
}
@@ -951,9 +954,9 @@ func (d *DownTrack) handleRTCP(bytes []byte) {
d.stats.RTT = rtt
d.stats.Jitter = float64(r.Jitter)
d.statsLock.Unlock()
d.connectionStats.UpdateWindow(r.SSRC, r.LastSequenceNumber, r.TotalLost, rtt, r.Jitter)
d.statsLock.Unlock()
}
if len(rr.Reports) > 0 {
d.listenerLock.RLock()
@@ -966,10 +969,11 @@ func (d *DownTrack) handleRTCP(bytes []byte) {
case *rtcp.TransportLayerNack:
var nacks []uint16
for _, pair := range p.Nacks {
nacks = append(nacks, pair.PacketList()...)
packetList := pair.PacketList()
numNACKs += uint32(len(packetList))
nacks = append(nacks, packetList...)
}
go d.retransmitPackets(nacks)
numNACKs += uint32(len(nacks))
case *rtcp.TransportLayerCC:
if p.MediaSSRC == d.ssrc && d.onTransportCCFeedback != nil {
@@ -984,12 +988,22 @@ func (d *DownTrack) handleRTCP(bytes []byte) {
d.stats.TotalFIRs += numFIRs
d.statsLock.Unlock()
if rttToReport != 0 && d.onRttUpdate != nil {
d.onRttUpdate(d, rttToReport)
if rttToReport != 0 {
if d.sequencer != nil {
d.sequencer.setRTT(rttToReport)
}
if d.onRttUpdate != nil {
d.onRttUpdate(d, rttToReport)
}
}
}
func (d *DownTrack) retransmitPackets(nacks []uint16) {
if d.sequencer == nil {
return
}
if FlagStopRTXOnPLI && d.isNACKThrottled.get() {
return
}
@@ -999,12 +1013,6 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
return
}
if d.sequencer == nil {
return
}
nackedPackets := d.sequencer.getSeqNoPairs(filtered)
var pool *[]byte
defer func() {
if pool != nil {
@@ -1016,18 +1024,32 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
src := PacketFactory.Get().(*[]byte)
defer PacketFactory.Put(src)
for _, meta := range nackedPackets {
numRepeatedNACKs := uint32(0)
for _, meta := range d.sequencer.getPacketsMeta(filtered) {
if meta.layer == int8(InvalidLayerSpatial) {
if meta.nacked > 1 {
numRepeatedNACKs++
}
// padding packet, no RTX for those
continue
}
if disallowedLayers[meta.layer] {
continue
}
if meta.nacked > 1 {
numRepeatedNACKs++
}
if pool != nil {
PacketFactory.Put(pool)
pool = nil
}
pktBuff := *src
n, err := d.receiver.ReadRTP(pktBuff, meta.layer, meta.sourceSeqNo)
n, err := d.receiver.ReadRTP(pktBuff, uint8(meta.layer), meta.sourceSeqNo)
if err != nil {
if err == io.EOF {
break
@@ -1081,6 +1103,10 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
d.updateRtxStats(pktSize)
}
}
d.statsLock.Lock()
d.totalRepeatedNACKs += numRepeatedNACKs
d.statsLock.Unlock()
}
func (d *DownTrack) getSRStats() (uint64, uint32) {
@@ -1210,3 +1236,12 @@ func (d *DownTrack) getTrackStats() map[uint32]*buffer.StreamStatsWithLayers {
return stats
}
func (d *DownTrack) GetNackStats() (totalPackets uint32, totalRepeatedNACKs uint32) {
d.statsLock.RLock()
defer d.statsLock.RUnlock()
totalPackets = d.stats.TotalPrimaryPackets + d.stats.TotalPaddingPackets
totalRepeatedNACKs = d.totalRepeatedNACKs
return
}

View File

@@ -1,6 +1,7 @@
package sfu
import (
"math"
"sync"
"time"
@@ -10,6 +11,8 @@ import (
)
const (
maxPadding = 2000
defaultRtt = 70
ignoreRetransmission = 100 // Ignore packet retransmission after ignoreRetransmission milliseconds
)
@@ -44,8 +47,10 @@ type packetMeta struct {
// the same packet.
// The resolution is 1 ms counting after the sequencer start time.
lastNack uint32
// number of NACKs this packet has received
nacked uint8
// Spatial layer of packet
layer uint8
layer int8
// Information that differs depending on the codec
misc uint64
}
@@ -93,81 +98,101 @@ type sequencer struct {
step int
headSN uint16
startTime int64
rtt uint32
logger logger.Logger
}
func newSequencer(maxTrack int, logger logger.Logger) *sequencer {
return &sequencer{
startTime: time.Now().UnixNano() / 1e6,
max: maxTrack,
seq: make([]packetMeta, maxTrack),
max: maxTrack + maxPadding,
seq: make([]packetMeta, maxTrack+maxPadding),
rtt: defaultRtt,
logger: logger,
}
}
func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer uint8, head bool) *packetMeta {
func (n *sequencer) setRTT(rtt uint32) {
n.Lock()
defer n.Unlock()
if !n.init {
if rtt == 0 {
n.rtt = defaultRtt
} else {
n.rtt = rtt
}
}
func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer int8) *packetMeta {
n.Lock()
defer n.Unlock()
inc := offSn - n.headSN
step := 0
switch {
case !n.init:
n.headSN = offSn
n.init = true
}
step := 0
if head {
inc := offSn - n.headSN
for i := uint16(1); i < inc; i++ {
n.step++
if n.step >= n.max {
n.step = 0
}
case inc == 0:
// duplicate
return nil
case inc < (1 << 15): // in-order packet
n.step += int(inc)
if n.step >= n.max {
n.step -= n.max
}
step = n.step
n.headSN = offSn
} else {
step = n.step - int(n.headSN-offSn)
default: // out-of-order packet
back := int(n.headSN - offSn)
if back >= n.max {
n.logger.Debugw("old packet, can not be sequenced", "head", sn, "received", offSn)
return nil
}
step = n.step - back
if step < 0 {
if step*-1 >= n.max {
n.logger.Debugw("old packet received, can not be sequenced", "head", sn, "received", offSn)
return nil
}
step = n.max + step
step += n.max
}
}
n.seq[n.step] = packetMeta{
n.seq[step] = packetMeta{
sourceSeqNo: sn,
targetSeqNo: offSn,
timestamp: timeStamp,
layer: layer,
}
pm := &n.seq[n.step]
n.step++
if n.step >= n.max {
n.step = 0
}
return pm
return &n.seq[step]
}
func (n *sequencer) getSeqNoPairs(seqNo []uint16) []packetMeta {
func (n *sequencer) getPacketsMeta(seqNo []uint16) []packetMeta {
n.Lock()
meta := make([]packetMeta, 0, 17)
defer n.Unlock()
meta := make([]packetMeta, 0, len(seqNo))
refTime := uint32(time.Now().UnixNano()/1e6 - n.startTime)
for _, sn := range seqNo {
step := n.step - int(n.headSN-sn) - 1
if step < 0 {
if step*-1 >= n.max {
continue
}
step = n.max + step
diff := n.headSN - sn
if diff > (1<<15) || int(diff) >= n.max {
// out-of-order from head (should not happen) or too old
continue
}
step := n.step - int(diff)
if step < 0 {
step += n.max
}
seq := &n.seq[step]
if seq.targetSeqNo == sn {
if seq.lastNack == 0 || refTime-seq.lastNack > ignoreRetransmission {
seq.lastNack = refTime
meta = append(meta, *seq)
}
if seq.targetSeqNo != sn {
continue
}
if seq.lastNack == 0 || refTime-seq.lastNack > uint32(math.Min(float64(ignoreRetransmission), float64(2*n.rtt))) {
seq.nacked++
seq.lastNack = refTime
meta = append(meta, *seq)
}
}
n.Unlock()
return meta
}

View File

@@ -15,33 +15,39 @@ func Test_sequencer(t *testing.T) {
seq := newSequencer(500, logger.Logger(logger.GetLogger()))
off := uint16(15)
for i := uint16(1); i < 520; i++ {
seq.push(i, i+off, 123, 2, true)
for i := uint16(1); i < 518; i++ {
seq.push(i, i+off, 123, 2)
}
// send the last two out-of-order
seq.push(519, 519+off, 123, 2)
seq.push(518, 518+off, 123, 2)
time.Sleep(60 * time.Millisecond)
req := []uint16{57, 58, 62, 63, 513, 514, 515, 516, 517}
res := seq.getSeqNoPairs(req)
res := seq.getPacketsMeta(req)
require.Equal(t, len(req), len(res))
for i, val := range res {
require.Equal(t, val.targetSeqNo, req[i])
require.Equal(t, val.sourceSeqNo, req[i]-off)
require.Equal(t, val.layer, uint8(2))
require.Equal(t, val.layer, int8(2))
}
res = seq.getSeqNoPairs(req)
res = seq.getPacketsMeta(req)
require.Equal(t, 0, len(res))
time.Sleep(150 * time.Millisecond)
res = seq.getSeqNoPairs(req)
res = seq.getPacketsMeta(req)
require.Equal(t, len(req), len(res))
for i, val := range res {
require.Equal(t, val.targetSeqNo, req[i])
require.Equal(t, val.sourceSeqNo, req[i]-off)
require.Equal(t, val.layer, uint8(2))
require.Equal(t, val.layer, int8(2))
}
s := seq.push(521, 521+off, 123, 1, true)
s.sourceSeqNo = 12
m := seq.getSeqNoPairs([]uint16{521 + off})
seq.push(521, 521+off, 123, 1)
m := seq.getPacketsMeta([]uint16{521 + off})
require.Equal(t, 1, len(m))
seq.push(505, 505+off, 123, 1)
m = seq.getPacketsMeta([]uint16{505 + off})
require.Equal(t, 1, len(m))
}
@@ -78,16 +84,16 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) {
n := newSequencer(500, logger.Logger(logger.GetLogger()))
for _, i := range tt.fields.input {
n.push(i, i+tt.fields.offset, 123, 3, true)
n.push(i, i+tt.fields.offset, 123, 3)
}
g := n.getSeqNoPairs(tt.args.seqNo)
g := n.getPacketsMeta(tt.args.seqNo)
var got []uint16
for _, sn := range g {
got = append(got, sn.sourceSeqNo)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("getSeqNoPairs() = %v, want %v", got, tt.want)
t.Errorf("getPacketsMeta() = %v, want %v", got, tt.want)
}
})
}

View File

@@ -23,6 +23,11 @@ const (
NumRequiredEstimatesNonProbe = 8
NumRequiredEstimatesProbe = 3
NackRatioThresholdNonProbe = 0.06
NackRatioThresholdProbe = 0.04
NackRatioAttenuator = 0.4 // how much to attenuate NACK ratio while calculating loss adjusted estimate
ProbeWaitBase = 5 * time.Second
ProbeBackoffFactor = 1.5
ProbeWaitMax = 30 * time.Second
@@ -130,9 +135,7 @@ type StreamAllocator struct {
bwe cc.BandwidthEstimator
lastReceivedEstimate int64
estimator *Estimator
lastReceivedEstimate int64
committedChannelCapacity int64
probeInterval time.Duration
@@ -142,10 +145,12 @@ type StreamAllocator struct {
abortedProbeClusterId ProbeClusterId
probeTrendObserved bool
probeEndTime time.Time
probeEstimator *Estimator
probeChannelObserver *ChannelObserver
prober *Prober
channelObserver *ChannelObserver
audioTracks map[livekit.TrackID]*Track
videoTracks map[livekit.TrackID]*Track
exemptVideoTracksSorted TrackSorter
@@ -161,14 +166,14 @@ type StreamAllocator struct {
func NewStreamAllocator(params StreamAllocatorParams) *StreamAllocator {
s := &StreamAllocator{
params: params,
estimator: NewEstimator("non-probe", params.Logger, NumRequiredEstimatesNonProbe),
audioTracks: make(map[livekit.TrackID]*Track),
videoTracks: make(map[livekit.TrackID]*Track),
params: params,
prober: NewProber(ProberParams{
Logger: params.Logger,
}),
eventCh: make(chan Event, 20),
channelObserver: NewChannelObserver("non-probe", params.Logger, NumRequiredEstimatesNonProbe, NackRatioThresholdNonProbe),
audioTracks: make(map[livekit.TrackID]*Track),
videoTracks: make(map[livekit.TrackID]*Track),
eventCh: make(chan Event, 20),
}
s.resetState()
@@ -639,7 +644,7 @@ func (s *StreamAllocator) handleSignalProbeClusterDone(event *Event) {
// ensure probe queue is flushed
// LK-TODO: ProbeSettleWait should actually be a certain number of RTTs.
lowestEstimate := int64(math.Min(float64(s.committedChannelCapacity), float64(s.probeEstimator.GetLowest())))
lowestEstimate := int64(math.Min(float64(s.committedChannelCapacity), float64(s.probeChannelObserver.GetLowestEstimate())))
expectedDuration := float64(info.BytesSent*8*1000) / float64(lowestEstimate)
queueTime := expectedDuration - float64(info.Duration.Milliseconds())
if queueTime < 0.0 {
@@ -684,8 +689,13 @@ func (s *StreamAllocator) handleNewEstimate(receivedEstimate int64) {
}
func (s *StreamAllocator) handleNewEstimateInProbe() {
trend := s.probeEstimator.AddEstimate(s.lastReceivedEstimate)
if trend != EstimateTrendNeutral {
s.probeChannelObserver.AddEstimate(s.lastReceivedEstimate)
packetDelta, repeatedNackDelta := s.getNackDelta()
s.probeChannelObserver.AddNack(packetDelta, repeatedNackDelta)
trend := s.probeChannelObserver.GetTrend()
if trend != ChannelTrendNeutral {
s.probeTrendObserved = true
}
switch {
@@ -695,15 +705,15 @@ func (s *StreamAllocator) handleNewEstimateInProbe() {
//
// More of a safety net.
// In rare cases, the estimate gets stuck. Prevent from probe running amok
// LK-TODO: Need more testing this here and ensure that probe does not cause a lot of damage
// LK-TODO: Need more testing here to ensure that probe does not cause a lot of damage
//
s.params.Logger.Debugw("probe: aborting, no trend")
s.abortProbe()
case trend == EstimateTrendDownward:
// stop immediately if estimate falls below the previously committed estimate, the probe is congesting channel more
s.params.Logger.Debugw("probe: aborting, estimate is trending downward")
case trend == ChannelTrendCongesting:
// stop immediately if the probe is congesting channel more
s.params.Logger.Debugw("probe: aborting, channel is congesting")
s.abortProbe()
case s.probeEstimator.GetHighest() > s.probeGoalBps:
case s.probeChannelObserver.GetHighestEstimate() > s.probeGoalBps:
// reached goal, stop probing
s.params.Logger.Debugw("probe: stopping, goal reached")
s.stopProbe()
@@ -711,20 +721,36 @@ func (s *StreamAllocator) handleNewEstimateInProbe() {
}
func (s *StreamAllocator) handleNewEstimateInNonProbe() {
trend := s.estimator.AddEstimate(s.lastReceivedEstimate)
if trend != EstimateTrendDownward {
s.channelObserver.AddEstimate(s.lastReceivedEstimate)
packetDelta, repeatedNackDelta := s.getNackDelta()
s.channelObserver.AddNack(packetDelta, repeatedNackDelta)
trend := s.channelObserver.GetTrend()
if trend != ChannelTrendCongesting {
return
}
nackRatio := s.channelObserver.GetNackRatio()
lossAdjustedEstimate := s.lastReceivedEstimate
if nackRatio > NackRatioThresholdNonProbe {
lossAdjustedEstimate = int64(float64(lossAdjustedEstimate) * (1.0 - NackRatioAttenuator*nackRatio))
}
if s.committedChannelCapacity == lossAdjustedEstimate {
return
}
s.params.Logger.Infow(
"estimate trending down, updating channel capacity",
"channel congestion detected, updating channel capacity",
"old(bps)", s.committedChannelCapacity,
"new(bps)", s.lastReceivedEstimate,
"new(bps)", lossAdjustedEstimate,
"lastReceived(bps)", s.lastReceivedEstimate,
"nackRatio", nackRatio,
)
s.committedChannelCapacity = s.lastReceivedEstimate
s.committedChannelCapacity = lossAdjustedEstimate
// reset to get new set of samples for next trend
s.estimator.Reset()
s.channelObserver.Reset()
// reset probe to ensure it does not start too soon after a downward trend
s.resetProbe()
@@ -832,7 +858,7 @@ func (s *StreamAllocator) allocateTrack(track *Track) {
func (s *StreamAllocator) finalizeProbe() {
aborted := s.probeClusterId == s.abortedProbeClusterId
highestEstimateInProbe := s.probeEstimator.GetHighest()
highestEstimateInProbe := s.probeChannelObserver.GetHighestEstimate()
s.clearProbe()
@@ -847,7 +873,7 @@ func (s *StreamAllocator) finalizeProbe() {
// NOTE: With TWCC, it is possible to reset bandwidth estimation to clean state as
// the send side is in full control of bandwidth estimation.
//
s.estimator.Reset()
s.channelObserver.Reset()
if aborted {
// failed probe, backoff
@@ -1018,6 +1044,18 @@ func (s *StreamAllocator) getExpectedBandwidthUsage() int64 {
return expected
}
func (s *StreamAllocator) getNackDelta() (uint32, uint32) {
aggPacketDelta := uint32(0)
aggRepeatedNackDelta := uint32(0)
for _, track := range s.videoTracks {
packetDelta, nackDelta := track.GetNackDelta()
aggPacketDelta += packetDelta
aggRepeatedNackDelta += nackDelta
}
return aggPacketDelta, aggRepeatedNackDelta
}
// LK-TODO: unused till loss based estimation is done, but just a sample impl of weighting audio higher
func (s *StreamAllocator) calculateLoss() float32 {
packetsAudio := uint32(0)
@@ -1062,8 +1100,8 @@ func (s *StreamAllocator) initProbe(goalBps int64) {
s.probeEndTime = time.Time{}
s.probeEstimator = NewEstimator("probe", s.params.Logger, NumRequiredEstimatesProbe)
s.probeEstimator.Seed(s.lastReceivedEstimate)
s.probeChannelObserver = NewChannelObserver("probe", s.params.Logger, NumRequiredEstimatesProbe, NackRatioThresholdProbe)
s.probeChannelObserver.SeedEstimate(s.lastReceivedEstimate)
}
func (s *StreamAllocator) resetProbe() {
@@ -1082,7 +1120,7 @@ func (s *StreamAllocator) clearProbe() {
s.probeEndTime = time.Time{}
s.probeEstimator = nil
s.probeChannelObserver = nil
}
func (s *StreamAllocator) backoffProbeInterval() {
@@ -1275,6 +1313,9 @@ type Track struct {
lastPacketsLost uint32
maxLayers VideoLayers
totalPackets uint32
totalRepeatedNacks uint32
}
func newTrack(downTrack *DownTrack, isManaged bool, publisherID livekit.ParticipantID, logger logger.Logger) *Track {
@@ -1384,6 +1425,18 @@ func (t *Track) DistanceToDesired() int32 {
return t.downTrack.DistanceToDesired()
}
func (t *Track) GetNackDelta() (uint32, uint32) {
totalPackets, totalRepeatedNacks := t.downTrack.GetNackStats()
packetDelta := totalPackets - t.totalPackets
t.totalPackets = totalPackets
nackDelta := totalRepeatedNacks - t.totalRepeatedNacks
t.totalRepeatedNacks = totalRepeatedNacks
return packetDelta, nackDelta
}
// ------------------------------------------------
type TrackSorter []*Track
@@ -1438,116 +1491,234 @@ func (m MinDistanceSorter) Less(i, j int) bool {
// ------------------------------------------------
type EstimateTrend int
type ChannelTrend int
const (
EstimateTrendNeutral EstimateTrend = iota
EstimateTrendUpward
EstimateTrendDownward
ChannelTrendNeutral ChannelTrend = iota
ChannelTrendClearing
ChannelTrendCongesting
)
func (e EstimateTrend) String() string {
switch e {
case EstimateTrendNeutral:
func (c ChannelTrend) String() string {
switch c {
case ChannelTrendNeutral:
return "NEUTRAL"
case EstimateTrendUpward:
return "UPWARD"
case EstimateTrendDownward:
return "DOWNWARD"
case ChannelTrendClearing:
return "CLEARING"
case ChannelTrendCongesting:
return "CONGESTING"
default:
return fmt.Sprintf("%d", int(e))
return fmt.Sprintf("%d", int(c))
}
}
type Estimator struct {
type ChannelObserver struct {
name string
logger logger.Logger
estimateTrend *TrendDetector
nackRatioThreshold float64
packets uint32
repeatedNacks uint32
}
func NewChannelObserver(
name string,
logger logger.Logger,
estimateRequiredSamples int,
nackRatioThreshold float64,
) *ChannelObserver {
return &ChannelObserver{
name: name,
logger: logger,
estimateTrend: NewTrendDetector(name+"-estimate", logger, estimateRequiredSamples),
nackRatioThreshold: nackRatioThreshold,
}
}
func (c *ChannelObserver) Reset() {
c.estimateTrend.Reset()
c.packets = 0
c.repeatedNacks = 0
}
func (c *ChannelObserver) SeedEstimate(estimate int64) {
c.estimateTrend.Seed(estimate)
}
func (c *ChannelObserver) SeedNack(packets uint32, repeatedNacks uint32) {
c.packets = packets
c.repeatedNacks = repeatedNacks
}
func (c *ChannelObserver) AddEstimate(estimate int64) {
c.estimateTrend.AddValue(estimate)
}
func (c *ChannelObserver) AddNack(packets uint32, repeatedNacks uint32) {
c.packets += packets
c.repeatedNacks += repeatedNacks
}
func (c *ChannelObserver) GetLowestEstimate() int64 {
return c.estimateTrend.GetLowest()
}
func (c *ChannelObserver) GetHighestEstimate() int64 {
return c.estimateTrend.GetHighest()
}
func (c *ChannelObserver) GetNackRatio() float64 {
ratio := float64(0.0)
if c.packets != 0 {
ratio = float64(c.repeatedNacks) / float64(c.packets)
if ratio > 1.0 {
ratio = 1.0
}
}
return ratio
}
func (c *ChannelObserver) GetTrend() ChannelTrend {
estimateDirection := c.estimateTrend.GetDirection()
nackRatio := c.GetNackRatio()
switch {
case estimateDirection == TrendDirectionDownward:
c.logger.Debugw("channel observer: estimate is trending downward")
return ChannelTrendCongesting
case nackRatio > c.nackRatioThreshold:
c.logger.Debugw("channel observer: high rate of repeated NACKs", "ratio", nackRatio)
return ChannelTrendCongesting
case estimateDirection == TrendDirectionUpward:
return ChannelTrendClearing
}
return ChannelTrendNeutral
}
// ------------------------------------------------
type TrendDirection int
const (
TrendDirectionNeutral TrendDirection = iota
TrendDirectionUpward
TrendDirectionDownward
)
func (t TrendDirection) String() string {
switch t {
case TrendDirectionNeutral:
return "NEUTRAL"
case TrendDirectionUpward:
return "UPWARD"
case TrendDirectionDownward:
return "DOWNWARD"
default:
return fmt.Sprintf("%d", int(t))
}
}
type TrendDetector struct {
name string
logger logger.Logger
requiredSamples int
estimates []int64
lowestEstimate int64
highestEstimate int64
values []int64
lowestvalue int64
highestvalue int64
direction TrendDirection
}
func NewEstimator(name string, logger logger.Logger, requiredSamples int) *Estimator {
return &Estimator{
func NewTrendDetector(name string, logger logger.Logger, requiredSamples int) *TrendDetector {
return &TrendDetector{
name: name,
logger: logger,
requiredSamples: requiredSamples,
direction: TrendDirectionNeutral,
}
}
func (e *Estimator) Reset() {
e.estimates = nil
e.lowestEstimate = int64(0)
e.highestEstimate = int64(0)
func (t *TrendDetector) Reset() {
t.values = nil
t.lowestvalue = int64(0)
t.highestvalue = int64(0)
}
func (e *Estimator) Seed(estimate int64) {
if len(e.estimates) != 0 {
func (t *TrendDetector) Seed(value int64) {
if len(t.values) != 0 {
return
}
e.estimates = append(e.estimates, estimate)
t.values = append(t.values, value)
}
func (e *Estimator) AddEstimate(estimate int64) EstimateTrend {
if e.lowestEstimate == 0 || estimate < e.lowestEstimate {
e.lowestEstimate = estimate
func (t *TrendDetector) AddValue(value int64) {
if t.lowestvalue == 0 || value < t.lowestvalue {
t.lowestvalue = value
}
if estimate > e.highestEstimate {
e.highestEstimate = estimate
if value > t.highestvalue {
t.highestvalue = value
}
if len(e.estimates) == e.requiredSamples {
e.estimates = e.estimates[1:]
if len(t.values) == t.requiredSamples {
t.values = t.values[1:]
}
e.estimates = append(e.estimates, estimate)
t.values = append(t.values, value)
return e.updateTrend()
t.updateDirection()
}
func (e *Estimator) GetLowest() int64 {
return e.lowestEstimate
func (t *TrendDetector) GetLowest() int64 {
return t.lowestvalue
}
func (e *Estimator) GetHighest() int64 {
return e.highestEstimate
func (t *TrendDetector) GetHighest() int64 {
return t.highestvalue
}
func (e *Estimator) updateTrend() EstimateTrend {
if len(e.estimates) < e.requiredSamples {
return EstimateTrendNeutral
func (t *TrendDetector) GetDirection() TrendDirection {
return t.direction
}
func (t *TrendDetector) updateDirection() {
if len(t.values) < t.requiredSamples {
t.direction = TrendDirectionNeutral
return
}
// using Kendall's Tau to find trend
concordantPairs := 0
discordantPairs := 0
for i := 0; i < len(e.estimates)-1; i++ {
for j := i + 1; j < len(e.estimates); j++ {
if e.estimates[i] < e.estimates[j] {
for i := 0; i < len(t.values)-1; i++ {
for j := i + 1; j < len(t.values); j++ {
if t.values[i] < t.values[j] {
concordantPairs++
} else if e.estimates[i] > e.estimates[j] {
} else if t.values[i] > t.values[j] {
discordantPairs++
}
}
}
if (concordantPairs + discordantPairs) == 0 {
return EstimateTrendNeutral
t.direction = TrendDirectionNeutral
return
}
trend := EstimateTrendNeutral
t.direction = TrendDirectionNeutral
kt := (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs))
switch {
case kt > 0:
trend = EstimateTrendUpward
t.direction = TrendDirectionUpward
case kt < 0:
trend = EstimateTrendDownward
t.direction = TrendDirectionDownward
}
return trend
}
// ------------------------------------------------