From babbfb37aafaa73d06b6ce0a5ea169915253faec Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 18 Feb 2022 14:21:30 +0530 Subject: [PATCH] Include NACK ratio in congestion control (#443) * WIP commit * WIP commit * WIP commit * WIP commit * WIP commit * WIP commit * WIP commit * WIP commit * Clean up * Remove debug * Remove unneeded change * fix test * Remove incorrect comment * WIP commit * Reset probe after estimate trends down * WIP commit * variable name change * WIP commit * WIP commit * out-of-order test * WIP commit * Clean up * more strict probe NACKs --- pkg/sfu/downtrack.go | 75 ++++++--- pkg/sfu/sequencer.go | 111 ++++++++----- pkg/sfu/sequencer_test.go | 32 ++-- pkg/sfu/streamallocator.go | 329 ++++++++++++++++++++++++++++--------- 4 files changed, 392 insertions(+), 155 deletions(-) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 1b97d537c..9b0d7fc4d 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -96,8 +96,9 @@ type DownTrack struct { listenerLock sync.RWMutex closeOnce sync.Once - statsLock sync.RWMutex - stats buffer.StreamStats + statsLock sync.RWMutex + stats buffer.StreamStats + totalRepeatedNACKs uint32 connectionStats *connectionquality.ConnectionStats @@ -333,7 +334,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { } if d.sequencer != nil { - meta := d.sequencer.push(extPkt.Packet.SequenceNumber, tp.rtp.sequenceNumber, tp.rtp.timestamp, uint8(layer), extPkt.Head) + meta := d.sequencer.push(extPkt.Packet.SequenceNumber, tp.rtp.sequenceNumber, tp.rtp.timestamp, int8(layer)) if meta != nil && tp.vp8 != nil { meta.packVP8(tp.vp8.header) } @@ -451,12 +452,14 @@ func (d *DownTrack) WritePaddingRTP(bytesToSend int) int { f(d, size) } - // LK-TODO-START - // NACK buffer for these probe packets. - // Probably okay to absorb the NACKs for these and ignore them. + // + // Register with sequencer with invalid layer so that NACKs for these can be filtered out. // Retransmission is probably a sign of network congestion/badness. // So, retransmitting padding packets is only going to make matters worse. - // LK-TODO-END + // + if d.sequencer != nil { + d.sequencer.push(0, hdr.SequenceNumber, hdr.Timestamp, int8(InvalidLayerSpatial)) + } bytesSent += size } @@ -951,9 +954,9 @@ func (d *DownTrack) handleRTCP(bytes []byte) { d.stats.RTT = rtt d.stats.Jitter = float64(r.Jitter) + d.statsLock.Unlock() d.connectionStats.UpdateWindow(r.SSRC, r.LastSequenceNumber, r.TotalLost, rtt, r.Jitter) - d.statsLock.Unlock() } if len(rr.Reports) > 0 { d.listenerLock.RLock() @@ -966,10 +969,11 @@ func (d *DownTrack) handleRTCP(bytes []byte) { case *rtcp.TransportLayerNack: var nacks []uint16 for _, pair := range p.Nacks { - nacks = append(nacks, pair.PacketList()...) + packetList := pair.PacketList() + numNACKs += uint32(len(packetList)) + nacks = append(nacks, packetList...) } go d.retransmitPackets(nacks) - numNACKs += uint32(len(nacks)) case *rtcp.TransportLayerCC: if p.MediaSSRC == d.ssrc && d.onTransportCCFeedback != nil { @@ -984,12 +988,22 @@ func (d *DownTrack) handleRTCP(bytes []byte) { d.stats.TotalFIRs += numFIRs d.statsLock.Unlock() - if rttToReport != 0 && d.onRttUpdate != nil { - d.onRttUpdate(d, rttToReport) + if rttToReport != 0 { + if d.sequencer != nil { + d.sequencer.setRTT(rttToReport) + } + + if d.onRttUpdate != nil { + d.onRttUpdate(d, rttToReport) + } } } func (d *DownTrack) retransmitPackets(nacks []uint16) { + if d.sequencer == nil { + return + } + if FlagStopRTXOnPLI && d.isNACKThrottled.get() { return } @@ -999,12 +1013,6 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { return } - if d.sequencer == nil { - return - } - - nackedPackets := d.sequencer.getSeqNoPairs(filtered) - var pool *[]byte defer func() { if pool != nil { @@ -1016,18 +1024,32 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { src := PacketFactory.Get().(*[]byte) defer PacketFactory.Put(src) - for _, meta := range nackedPackets { + numRepeatedNACKs := uint32(0) + for _, meta := range d.sequencer.getPacketsMeta(filtered) { + if meta.layer == int8(InvalidLayerSpatial) { + if meta.nacked > 1 { + numRepeatedNACKs++ + } + + // padding packet, no RTX for those + continue + } + if disallowedLayers[meta.layer] { continue } + if meta.nacked > 1 { + numRepeatedNACKs++ + } + if pool != nil { PacketFactory.Put(pool) pool = nil } pktBuff := *src - n, err := d.receiver.ReadRTP(pktBuff, meta.layer, meta.sourceSeqNo) + n, err := d.receiver.ReadRTP(pktBuff, uint8(meta.layer), meta.sourceSeqNo) if err != nil { if err == io.EOF { break @@ -1081,6 +1103,10 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { d.updateRtxStats(pktSize) } } + + d.statsLock.Lock() + d.totalRepeatedNACKs += numRepeatedNACKs + d.statsLock.Unlock() } func (d *DownTrack) getSRStats() (uint64, uint32) { @@ -1210,3 +1236,12 @@ func (d *DownTrack) getTrackStats() map[uint32]*buffer.StreamStatsWithLayers { return stats } + +func (d *DownTrack) GetNackStats() (totalPackets uint32, totalRepeatedNACKs uint32) { + d.statsLock.RLock() + defer d.statsLock.RUnlock() + + totalPackets = d.stats.TotalPrimaryPackets + d.stats.TotalPaddingPackets + totalRepeatedNACKs = d.totalRepeatedNACKs + return +} diff --git a/pkg/sfu/sequencer.go b/pkg/sfu/sequencer.go index 3ea025b97..d0471023d 100644 --- a/pkg/sfu/sequencer.go +++ b/pkg/sfu/sequencer.go @@ -1,6 +1,7 @@ package sfu import ( + "math" "sync" "time" @@ -10,6 +11,8 @@ import ( ) const ( + maxPadding = 2000 + defaultRtt = 70 ignoreRetransmission = 100 // Ignore packet retransmission after ignoreRetransmission milliseconds ) @@ -44,8 +47,10 @@ type packetMeta struct { // the same packet. // The resolution is 1 ms counting after the sequencer start time. lastNack uint32 + // number of NACKs this packet has received + nacked uint8 // Spatial layer of packet - layer uint8 + layer int8 // Information that differs depending on the codec misc uint64 } @@ -93,81 +98,101 @@ type sequencer struct { step int headSN uint16 startTime int64 + rtt uint32 logger logger.Logger } func newSequencer(maxTrack int, logger logger.Logger) *sequencer { return &sequencer{ startTime: time.Now().UnixNano() / 1e6, - max: maxTrack, - seq: make([]packetMeta, maxTrack), + max: maxTrack + maxPadding, + seq: make([]packetMeta, maxTrack+maxPadding), + rtt: defaultRtt, logger: logger, } } -func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer uint8, head bool) *packetMeta { +func (n *sequencer) setRTT(rtt uint32) { n.Lock() defer n.Unlock() - if !n.init { + + if rtt == 0 { + n.rtt = defaultRtt + } else { + n.rtt = rtt + } +} + +func (n *sequencer) push(sn, offSn uint16, timeStamp uint32, layer int8) *packetMeta { + n.Lock() + defer n.Unlock() + + inc := offSn - n.headSN + step := 0 + switch { + case !n.init: n.headSN = offSn n.init = true - } - - step := 0 - if head { - inc := offSn - n.headSN - for i := uint16(1); i < inc; i++ { - n.step++ - if n.step >= n.max { - n.step = 0 - } + case inc == 0: + // duplicate + return nil + case inc < (1 << 15): // in-order packet + n.step += int(inc) + if n.step >= n.max { + n.step -= n.max } step = n.step n.headSN = offSn - } else { - step = n.step - int(n.headSN-offSn) + default: // out-of-order packet + back := int(n.headSN - offSn) + if back >= n.max { + n.logger.Debugw("old packet, can not be sequenced", "head", sn, "received", offSn) + return nil + } + step = n.step - back if step < 0 { - if step*-1 >= n.max { - n.logger.Debugw("old packet received, can not be sequenced", "head", sn, "received", offSn) - return nil - } - step = n.max + step + step += n.max } } - n.seq[n.step] = packetMeta{ + + n.seq[step] = packetMeta{ sourceSeqNo: sn, targetSeqNo: offSn, timestamp: timeStamp, layer: layer, } - pm := &n.seq[n.step] - n.step++ - if n.step >= n.max { - n.step = 0 - } - return pm + return &n.seq[step] } -func (n *sequencer) getSeqNoPairs(seqNo []uint16) []packetMeta { +func (n *sequencer) getPacketsMeta(seqNo []uint16) []packetMeta { n.Lock() - meta := make([]packetMeta, 0, 17) + defer n.Unlock() + + meta := make([]packetMeta, 0, len(seqNo)) refTime := uint32(time.Now().UnixNano()/1e6 - n.startTime) for _, sn := range seqNo { - step := n.step - int(n.headSN-sn) - 1 - if step < 0 { - if step*-1 >= n.max { - continue - } - step = n.max + step + diff := n.headSN - sn + if diff > (1<<15) || int(diff) >= n.max { + // out-of-order from head (should not happen) or too old + continue } + + step := n.step - int(diff) + if step < 0 { + step += n.max + } + seq := &n.seq[step] - if seq.targetSeqNo == sn { - if seq.lastNack == 0 || refTime-seq.lastNack > ignoreRetransmission { - seq.lastNack = refTime - meta = append(meta, *seq) - } + if seq.targetSeqNo != sn { + continue + } + + if seq.lastNack == 0 || refTime-seq.lastNack > uint32(math.Min(float64(ignoreRetransmission), float64(2*n.rtt))) { + seq.nacked++ + seq.lastNack = refTime + meta = append(meta, *seq) } } - n.Unlock() + return meta } diff --git a/pkg/sfu/sequencer_test.go b/pkg/sfu/sequencer_test.go index 1f1831641..7566029c3 100644 --- a/pkg/sfu/sequencer_test.go +++ b/pkg/sfu/sequencer_test.go @@ -15,33 +15,39 @@ func Test_sequencer(t *testing.T) { seq := newSequencer(500, logger.Logger(logger.GetLogger())) off := uint16(15) - for i := uint16(1); i < 520; i++ { - seq.push(i, i+off, 123, 2, true) + for i := uint16(1); i < 518; i++ { + seq.push(i, i+off, 123, 2) } + // send the last two out-of-order + seq.push(519, 519+off, 123, 2) + seq.push(518, 518+off, 123, 2) time.Sleep(60 * time.Millisecond) req := []uint16{57, 58, 62, 63, 513, 514, 515, 516, 517} - res := seq.getSeqNoPairs(req) + res := seq.getPacketsMeta(req) require.Equal(t, len(req), len(res)) for i, val := range res { require.Equal(t, val.targetSeqNo, req[i]) require.Equal(t, val.sourceSeqNo, req[i]-off) - require.Equal(t, val.layer, uint8(2)) + require.Equal(t, val.layer, int8(2)) } - res = seq.getSeqNoPairs(req) + res = seq.getPacketsMeta(req) require.Equal(t, 0, len(res)) time.Sleep(150 * time.Millisecond) - res = seq.getSeqNoPairs(req) + res = seq.getPacketsMeta(req) require.Equal(t, len(req), len(res)) for i, val := range res { require.Equal(t, val.targetSeqNo, req[i]) require.Equal(t, val.sourceSeqNo, req[i]-off) - require.Equal(t, val.layer, uint8(2)) + require.Equal(t, val.layer, int8(2)) } - s := seq.push(521, 521+off, 123, 1, true) - s.sourceSeqNo = 12 - m := seq.getSeqNoPairs([]uint16{521 + off}) + seq.push(521, 521+off, 123, 1) + m := seq.getPacketsMeta([]uint16{521 + off}) + require.Equal(t, 1, len(m)) + + seq.push(505, 505+off, 123, 1) + m = seq.getPacketsMeta([]uint16{505 + off}) require.Equal(t, 1, len(m)) } @@ -78,16 +84,16 @@ func Test_sequencer_getNACKSeqNo(t *testing.T) { n := newSequencer(500, logger.Logger(logger.GetLogger())) for _, i := range tt.fields.input { - n.push(i, i+tt.fields.offset, 123, 3, true) + n.push(i, i+tt.fields.offset, 123, 3) } - g := n.getSeqNoPairs(tt.args.seqNo) + g := n.getPacketsMeta(tt.args.seqNo) var got []uint16 for _, sn := range g { got = append(got, sn.sourceSeqNo) } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getSeqNoPairs() = %v, want %v", got, tt.want) + t.Errorf("getPacketsMeta() = %v, want %v", got, tt.want) } }) } diff --git a/pkg/sfu/streamallocator.go b/pkg/sfu/streamallocator.go index 98a17c5bd..4b8a89179 100644 --- a/pkg/sfu/streamallocator.go +++ b/pkg/sfu/streamallocator.go @@ -23,6 +23,11 @@ const ( NumRequiredEstimatesNonProbe = 8 NumRequiredEstimatesProbe = 3 + NackRatioThresholdNonProbe = 0.06 + NackRatioThresholdProbe = 0.04 + + NackRatioAttenuator = 0.4 // how much to attenuate NACK ratio while calculating loss adjusted estimate + ProbeWaitBase = 5 * time.Second ProbeBackoffFactor = 1.5 ProbeWaitMax = 30 * time.Second @@ -130,9 +135,7 @@ type StreamAllocator struct { bwe cc.BandwidthEstimator - lastReceivedEstimate int64 - estimator *Estimator - + lastReceivedEstimate int64 committedChannelCapacity int64 probeInterval time.Duration @@ -142,10 +145,12 @@ type StreamAllocator struct { abortedProbeClusterId ProbeClusterId probeTrendObserved bool probeEndTime time.Time - probeEstimator *Estimator + probeChannelObserver *ChannelObserver prober *Prober + channelObserver *ChannelObserver + audioTracks map[livekit.TrackID]*Track videoTracks map[livekit.TrackID]*Track exemptVideoTracksSorted TrackSorter @@ -161,14 +166,14 @@ type StreamAllocator struct { func NewStreamAllocator(params StreamAllocatorParams) *StreamAllocator { s := &StreamAllocator{ - params: params, - estimator: NewEstimator("non-probe", params.Logger, NumRequiredEstimatesNonProbe), - audioTracks: make(map[livekit.TrackID]*Track), - videoTracks: make(map[livekit.TrackID]*Track), + params: params, prober: NewProber(ProberParams{ Logger: params.Logger, }), - eventCh: make(chan Event, 20), + channelObserver: NewChannelObserver("non-probe", params.Logger, NumRequiredEstimatesNonProbe, NackRatioThresholdNonProbe), + audioTracks: make(map[livekit.TrackID]*Track), + videoTracks: make(map[livekit.TrackID]*Track), + eventCh: make(chan Event, 20), } s.resetState() @@ -639,7 +644,7 @@ func (s *StreamAllocator) handleSignalProbeClusterDone(event *Event) { // ensure probe queue is flushed // LK-TODO: ProbeSettleWait should actually be a certain number of RTTs. - lowestEstimate := int64(math.Min(float64(s.committedChannelCapacity), float64(s.probeEstimator.GetLowest()))) + lowestEstimate := int64(math.Min(float64(s.committedChannelCapacity), float64(s.probeChannelObserver.GetLowestEstimate()))) expectedDuration := float64(info.BytesSent*8*1000) / float64(lowestEstimate) queueTime := expectedDuration - float64(info.Duration.Milliseconds()) if queueTime < 0.0 { @@ -684,8 +689,13 @@ func (s *StreamAllocator) handleNewEstimate(receivedEstimate int64) { } func (s *StreamAllocator) handleNewEstimateInProbe() { - trend := s.probeEstimator.AddEstimate(s.lastReceivedEstimate) - if trend != EstimateTrendNeutral { + s.probeChannelObserver.AddEstimate(s.lastReceivedEstimate) + + packetDelta, repeatedNackDelta := s.getNackDelta() + s.probeChannelObserver.AddNack(packetDelta, repeatedNackDelta) + + trend := s.probeChannelObserver.GetTrend() + if trend != ChannelTrendNeutral { s.probeTrendObserved = true } switch { @@ -695,15 +705,15 @@ func (s *StreamAllocator) handleNewEstimateInProbe() { // // More of a safety net. // In rare cases, the estimate gets stuck. Prevent from probe running amok - // LK-TODO: Need more testing this here and ensure that probe does not cause a lot of damage + // LK-TODO: Need more testing here to ensure that probe does not cause a lot of damage // s.params.Logger.Debugw("probe: aborting, no trend") s.abortProbe() - case trend == EstimateTrendDownward: - // stop immediately if estimate falls below the previously committed estimate, the probe is congesting channel more - s.params.Logger.Debugw("probe: aborting, estimate is trending downward") + case trend == ChannelTrendCongesting: + // stop immediately if the probe is congesting channel more + s.params.Logger.Debugw("probe: aborting, channel is congesting") s.abortProbe() - case s.probeEstimator.GetHighest() > s.probeGoalBps: + case s.probeChannelObserver.GetHighestEstimate() > s.probeGoalBps: // reached goal, stop probing s.params.Logger.Debugw("probe: stopping, goal reached") s.stopProbe() @@ -711,20 +721,36 @@ func (s *StreamAllocator) handleNewEstimateInProbe() { } func (s *StreamAllocator) handleNewEstimateInNonProbe() { - trend := s.estimator.AddEstimate(s.lastReceivedEstimate) - if trend != EstimateTrendDownward { + s.channelObserver.AddEstimate(s.lastReceivedEstimate) + + packetDelta, repeatedNackDelta := s.getNackDelta() + s.channelObserver.AddNack(packetDelta, repeatedNackDelta) + + trend := s.channelObserver.GetTrend() + if trend != ChannelTrendCongesting { + return + } + + nackRatio := s.channelObserver.GetNackRatio() + lossAdjustedEstimate := s.lastReceivedEstimate + if nackRatio > NackRatioThresholdNonProbe { + lossAdjustedEstimate = int64(float64(lossAdjustedEstimate) * (1.0 - NackRatioAttenuator*nackRatio)) + } + if s.committedChannelCapacity == lossAdjustedEstimate { return } s.params.Logger.Infow( - "estimate trending down, updating channel capacity", + "channel congestion detected, updating channel capacity", "old(bps)", s.committedChannelCapacity, - "new(bps)", s.lastReceivedEstimate, + "new(bps)", lossAdjustedEstimate, + "lastReceived(bps)", s.lastReceivedEstimate, + "nackRatio", nackRatio, ) - s.committedChannelCapacity = s.lastReceivedEstimate + s.committedChannelCapacity = lossAdjustedEstimate // reset to get new set of samples for next trend - s.estimator.Reset() + s.channelObserver.Reset() // reset probe to ensure it does not start too soon after a downward trend s.resetProbe() @@ -832,7 +858,7 @@ func (s *StreamAllocator) allocateTrack(track *Track) { func (s *StreamAllocator) finalizeProbe() { aborted := s.probeClusterId == s.abortedProbeClusterId - highestEstimateInProbe := s.probeEstimator.GetHighest() + highestEstimateInProbe := s.probeChannelObserver.GetHighestEstimate() s.clearProbe() @@ -847,7 +873,7 @@ func (s *StreamAllocator) finalizeProbe() { // NOTE: With TWCC, it is possible to reset bandwidth estimation to clean state as // the send side is in full control of bandwidth estimation. // - s.estimator.Reset() + s.channelObserver.Reset() if aborted { // failed probe, backoff @@ -1018,6 +1044,18 @@ func (s *StreamAllocator) getExpectedBandwidthUsage() int64 { return expected } +func (s *StreamAllocator) getNackDelta() (uint32, uint32) { + aggPacketDelta := uint32(0) + aggRepeatedNackDelta := uint32(0) + for _, track := range s.videoTracks { + packetDelta, nackDelta := track.GetNackDelta() + aggPacketDelta += packetDelta + aggRepeatedNackDelta += nackDelta + } + + return aggPacketDelta, aggRepeatedNackDelta +} + // LK-TODO: unused till loss based estimation is done, but just a sample impl of weighting audio higher func (s *StreamAllocator) calculateLoss() float32 { packetsAudio := uint32(0) @@ -1062,8 +1100,8 @@ func (s *StreamAllocator) initProbe(goalBps int64) { s.probeEndTime = time.Time{} - s.probeEstimator = NewEstimator("probe", s.params.Logger, NumRequiredEstimatesProbe) - s.probeEstimator.Seed(s.lastReceivedEstimate) + s.probeChannelObserver = NewChannelObserver("probe", s.params.Logger, NumRequiredEstimatesProbe, NackRatioThresholdProbe) + s.probeChannelObserver.SeedEstimate(s.lastReceivedEstimate) } func (s *StreamAllocator) resetProbe() { @@ -1082,7 +1120,7 @@ func (s *StreamAllocator) clearProbe() { s.probeEndTime = time.Time{} - s.probeEstimator = nil + s.probeChannelObserver = nil } func (s *StreamAllocator) backoffProbeInterval() { @@ -1275,6 +1313,9 @@ type Track struct { lastPacketsLost uint32 maxLayers VideoLayers + + totalPackets uint32 + totalRepeatedNacks uint32 } func newTrack(downTrack *DownTrack, isManaged bool, publisherID livekit.ParticipantID, logger logger.Logger) *Track { @@ -1384,6 +1425,18 @@ func (t *Track) DistanceToDesired() int32 { return t.downTrack.DistanceToDesired() } +func (t *Track) GetNackDelta() (uint32, uint32) { + totalPackets, totalRepeatedNacks := t.downTrack.GetNackStats() + + packetDelta := totalPackets - t.totalPackets + t.totalPackets = totalPackets + + nackDelta := totalRepeatedNacks - t.totalRepeatedNacks + t.totalRepeatedNacks = totalRepeatedNacks + + return packetDelta, nackDelta +} + // ------------------------------------------------ type TrackSorter []*Track @@ -1438,116 +1491,234 @@ func (m MinDistanceSorter) Less(i, j int) bool { // ------------------------------------------------ -type EstimateTrend int +type ChannelTrend int const ( - EstimateTrendNeutral EstimateTrend = iota - EstimateTrendUpward - EstimateTrendDownward + ChannelTrendNeutral ChannelTrend = iota + ChannelTrendClearing + ChannelTrendCongesting ) -func (e EstimateTrend) String() string { - switch e { - case EstimateTrendNeutral: +func (c ChannelTrend) String() string { + switch c { + case ChannelTrendNeutral: return "NEUTRAL" - case EstimateTrendUpward: - return "UPWARD" - case EstimateTrendDownward: - return "DOWNWARD" + case ChannelTrendClearing: + return "CLEARING" + case ChannelTrendCongesting: + return "CONGESTING" default: - return fmt.Sprintf("%d", int(e)) + return fmt.Sprintf("%d", int(c)) } } -type Estimator struct { +type ChannelObserver struct { + name string + logger logger.Logger + + estimateTrend *TrendDetector + + nackRatioThreshold float64 + packets uint32 + repeatedNacks uint32 +} + +func NewChannelObserver( + name string, + logger logger.Logger, + estimateRequiredSamples int, + nackRatioThreshold float64, +) *ChannelObserver { + return &ChannelObserver{ + name: name, + logger: logger, + estimateTrend: NewTrendDetector(name+"-estimate", logger, estimateRequiredSamples), + nackRatioThreshold: nackRatioThreshold, + } +} + +func (c *ChannelObserver) Reset() { + c.estimateTrend.Reset() + + c.packets = 0 + c.repeatedNacks = 0 +} + +func (c *ChannelObserver) SeedEstimate(estimate int64) { + c.estimateTrend.Seed(estimate) +} + +func (c *ChannelObserver) SeedNack(packets uint32, repeatedNacks uint32) { + c.packets = packets + c.repeatedNacks = repeatedNacks +} + +func (c *ChannelObserver) AddEstimate(estimate int64) { + c.estimateTrend.AddValue(estimate) +} + +func (c *ChannelObserver) AddNack(packets uint32, repeatedNacks uint32) { + c.packets += packets + c.repeatedNacks += repeatedNacks +} + +func (c *ChannelObserver) GetLowestEstimate() int64 { + return c.estimateTrend.GetLowest() +} + +func (c *ChannelObserver) GetHighestEstimate() int64 { + return c.estimateTrend.GetHighest() +} + +func (c *ChannelObserver) GetNackRatio() float64 { + ratio := float64(0.0) + if c.packets != 0 { + ratio = float64(c.repeatedNacks) / float64(c.packets) + if ratio > 1.0 { + ratio = 1.0 + } + } + + return ratio +} + +func (c *ChannelObserver) GetTrend() ChannelTrend { + estimateDirection := c.estimateTrend.GetDirection() + nackRatio := c.GetNackRatio() + + switch { + case estimateDirection == TrendDirectionDownward: + c.logger.Debugw("channel observer: estimate is trending downward") + return ChannelTrendCongesting + case nackRatio > c.nackRatioThreshold: + c.logger.Debugw("channel observer: high rate of repeated NACKs", "ratio", nackRatio) + return ChannelTrendCongesting + case estimateDirection == TrendDirectionUpward: + return ChannelTrendClearing + } + + return ChannelTrendNeutral +} + +// ------------------------------------------------ + +type TrendDirection int + +const ( + TrendDirectionNeutral TrendDirection = iota + TrendDirectionUpward + TrendDirectionDownward +) + +func (t TrendDirection) String() string { + switch t { + case TrendDirectionNeutral: + return "NEUTRAL" + case TrendDirectionUpward: + return "UPWARD" + case TrendDirectionDownward: + return "DOWNWARD" + default: + return fmt.Sprintf("%d", int(t)) + } +} + +type TrendDetector struct { name string logger logger.Logger requiredSamples int - estimates []int64 - lowestEstimate int64 - highestEstimate int64 + values []int64 + lowestvalue int64 + highestvalue int64 + + direction TrendDirection } -func NewEstimator(name string, logger logger.Logger, requiredSamples int) *Estimator { - return &Estimator{ +func NewTrendDetector(name string, logger logger.Logger, requiredSamples int) *TrendDetector { + return &TrendDetector{ name: name, logger: logger, requiredSamples: requiredSamples, + direction: TrendDirectionNeutral, } } -func (e *Estimator) Reset() { - e.estimates = nil - e.lowestEstimate = int64(0) - e.highestEstimate = int64(0) +func (t *TrendDetector) Reset() { + t.values = nil + t.lowestvalue = int64(0) + t.highestvalue = int64(0) } -func (e *Estimator) Seed(estimate int64) { - if len(e.estimates) != 0 { +func (t *TrendDetector) Seed(value int64) { + if len(t.values) != 0 { return } - e.estimates = append(e.estimates, estimate) + t.values = append(t.values, value) } -func (e *Estimator) AddEstimate(estimate int64) EstimateTrend { - if e.lowestEstimate == 0 || estimate < e.lowestEstimate { - e.lowestEstimate = estimate +func (t *TrendDetector) AddValue(value int64) { + if t.lowestvalue == 0 || value < t.lowestvalue { + t.lowestvalue = value } - if estimate > e.highestEstimate { - e.highestEstimate = estimate + if value > t.highestvalue { + t.highestvalue = value } - if len(e.estimates) == e.requiredSamples { - e.estimates = e.estimates[1:] + if len(t.values) == t.requiredSamples { + t.values = t.values[1:] } - e.estimates = append(e.estimates, estimate) + t.values = append(t.values, value) - return e.updateTrend() + t.updateDirection() } -func (e *Estimator) GetLowest() int64 { - return e.lowestEstimate +func (t *TrendDetector) GetLowest() int64 { + return t.lowestvalue } -func (e *Estimator) GetHighest() int64 { - return e.highestEstimate +func (t *TrendDetector) GetHighest() int64 { + return t.highestvalue } -func (e *Estimator) updateTrend() EstimateTrend { - if len(e.estimates) < e.requiredSamples { - return EstimateTrendNeutral +func (t *TrendDetector) GetDirection() TrendDirection { + return t.direction +} + +func (t *TrendDetector) updateDirection() { + if len(t.values) < t.requiredSamples { + t.direction = TrendDirectionNeutral + return } // using Kendall's Tau to find trend concordantPairs := 0 discordantPairs := 0 - for i := 0; i < len(e.estimates)-1; i++ { - for j := i + 1; j < len(e.estimates); j++ { - if e.estimates[i] < e.estimates[j] { + for i := 0; i < len(t.values)-1; i++ { + for j := i + 1; j < len(t.values); j++ { + if t.values[i] < t.values[j] { concordantPairs++ - } else if e.estimates[i] > e.estimates[j] { + } else if t.values[i] > t.values[j] { discordantPairs++ } } } if (concordantPairs + discordantPairs) == 0 { - return EstimateTrendNeutral + t.direction = TrendDirectionNeutral + return } - trend := EstimateTrendNeutral + t.direction = TrendDirectionNeutral kt := (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs)) switch { case kt > 0: - trend = EstimateTrendUpward + t.direction = TrendDirectionUpward case kt < 0: - trend = EstimateTrendDownward + t.direction = TrendDirectionDownward } - - return trend } // ------------------------------------------------