Fix stutter in forwarding path when using dependency descriptor (#1600)

* Decode chains

* clean up

* clean up

* decode targets only on publisher side

* comment out supported codecs

* fix test compile

* fix another test compile

* Adding TODO notes

* chainID -> chainIdx

* do not need to check for switch up point when using chains, as long as chain integrity is good, can switch

* more comments

* address comments
This commit is contained in:
Raja Subramanian
2023-04-12 17:30:54 +05:30
committed by GitHub
parent 29e26931e0
commit 69fb5e51a2
11 changed files with 901 additions and 283 deletions
+3 -3
View File
@@ -42,7 +42,7 @@ type ExtPacket struct {
Payload interface{}
KeyFrame bool
RawPacket []byte
DependencyDescriptor *dd.DependencyDescriptor
DependencyDescriptor *DependencyDescriptorWithDecodeTarget
}
// Buffer contains all packets
@@ -551,7 +551,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPack
if err == nil && ddVal != nil {
ep.DependencyDescriptor = ddVal
ep.VideoLayer = videoLayer
// TODO : notify active decode target change if changed.
// DD-TODO : notify active decode target change if changed.
}
}
switch b.mime {
@@ -779,7 +779,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) {
return b.audioLevel.GetLevel()
}
// TODO : now we rely on stream tracker for layer change, dependency still
// DD-TODO : now we rely on stream tracker for layer change, dependency still
// work for that too. Do we keep it unchanged or use both methods?
func (b *Buffer) OnMaxLayerChanged(fn func(int32, int32)) {
b.maxLayerChangedCB = fn
+129 -77
View File
@@ -2,6 +2,7 @@ package buffer
import (
"fmt"
"sort"
"github.com/pion/rtp"
@@ -12,91 +13,142 @@ import (
type DependencyDescriptorParser struct {
structure *dd.FrameDependencyStructure
ddExt uint8
ddExtID uint8
logger logger.Logger
onMaxLayerChanged func(int32, int32)
decodeTargetLayer []VideoLayer
decodeTargets []DependencyDescriptorDecodeTarget
}
func NewDependencyDescriptorParser(ddExt uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser {
logger.Infow("creating dependency descriptor parser", "ddExt", ddExt)
func NewDependencyDescriptorParser(ddExtID uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser {
logger.Infow("creating dependency descriptor parser", "ddExtID", ddExtID)
return &DependencyDescriptorParser{
ddExt: ddExt,
ddExtID: ddExtID,
logger: logger,
onMaxLayerChanged: onMaxLayerChanged,
}
}
func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*dd.DependencyDescriptor, VideoLayer, error) {
var videoLayer VideoLayer
if ddBuf := pkt.GetExtension(r.ddExt); ddBuf != nil {
var ddVal dd.DependencyDescriptor
ext := &dd.DependencyDescriptorExtension{
Descriptor: &ddVal,
Structure: r.structure,
}
_, err := ext.Unmarshal(ddBuf)
if err != nil {
// r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
return nil, videoLayer, err
}
if ddVal.FrameDependencies != nil {
videoLayer.Spatial, videoLayer.Temporal = int32(ddVal.FrameDependencies.SpatialId), int32(ddVal.FrameDependencies.TemporalId)
}
if ddVal.AttachedStructure != nil && !ddVal.FirstPacketInFrame {
// r.logger.Debugw("ignoring non-first packet in frame with attached structure")
return nil, videoLayer, nil
}
if ddVal.AttachedStructure != nil {
var maxSpatial, maxTemporal int32
r.structure = ddVal.AttachedStructure
r.decodeTargetLayer = r.decodeTargetLayer[:0]
for target := 0; target < r.structure.NumDecodeTargets; target++ {
layer := VideoLayer{0, 0}
for _, t := range r.structure.Templates {
if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent {
if layer.Spatial < int32(t.SpatialId) {
layer.Spatial = int32(t.SpatialId)
}
if layer.Temporal < int32(t.TemporalId) {
layer.Temporal = int32(t.TemporalId)
}
}
}
if layer.Spatial > maxSpatial {
maxSpatial = layer.Spatial
}
if layer.Temporal > maxTemporal {
maxTemporal = layer.Temporal
}
r.decodeTargetLayer = append(r.decodeTargetLayer, layer)
}
r.logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal)
go r.onMaxLayerChanged(maxSpatial, maxTemporal)
}
if ddVal.AttachedStructure != nil && ddVal.FirstPacketInFrame {
r.logger.Debugw(fmt.Sprintf("parsed dependency descriptor\n%s", ddVal.String()))
}
if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil {
var maxSpatial, maxTemporal int32
for dt, layer := range r.decodeTargetLayer {
if *mask&(1<<dt) != uint32(dd.DecodeTargetNotPresent) {
if maxSpatial < layer.Spatial {
maxSpatial = layer.Spatial
}
if maxTemporal < layer.Temporal {
maxTemporal = layer.Temporal
}
}
}
r.logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal)
r.onMaxLayerChanged(maxSpatial, maxTemporal)
}
return &ddVal, videoLayer, nil
}
return nil, videoLayer, nil
type DependencyDescriptorWithDecodeTarget struct {
Descriptor *dd.DependencyDescriptor
DecodeTargets []DependencyDescriptorDecodeTarget
}
func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*DependencyDescriptorWithDecodeTarget, VideoLayer, error) {
// DD-TODO: make sure out-of-order RTP packets do not update decode targets
var videoLayer VideoLayer
ddBuf := pkt.GetExtension(r.ddExtID)
if ddBuf == nil {
return nil, videoLayer, nil
}
var ddVal dd.DependencyDescriptor
ext := &dd.DependencyDescriptorExtension{
Descriptor: &ddVal,
Structure: r.structure,
}
_, err := ext.Unmarshal(ddBuf)
if err != nil {
// r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
return nil, videoLayer, err
}
if ddVal.FrameDependencies != nil {
videoLayer.Spatial, videoLayer.Temporal = int32(ddVal.FrameDependencies.SpatialId), int32(ddVal.FrameDependencies.TemporalId)
}
if ddVal.AttachedStructure != nil && !ddVal.FirstPacketInFrame {
// r.logger.Debugw("ignoring non-first packet in frame with attached structure")
return nil, videoLayer, nil
}
if ddVal.AttachedStructure != nil {
r.structure = ddVal.AttachedStructure
r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure)
if len(r.decodeTargets) != 0 {
r.logger.Debugw(fmt.Sprintf("update decode targets: %v", r.decodeTargets))
r.onMaxLayerChanged(r.decodeTargets[0].Layer.Spatial, r.decodeTargets[0].Layer.Temporal)
}
}
if ddVal.AttachedStructure != nil && ddVal.FirstPacketInFrame {
r.logger.Debugw(fmt.Sprintf("parsed dependency descriptor\n%s", ddVal.String()))
}
if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil {
var maxSpatial, maxTemporal int32
for _, dt := range r.decodeTargets {
if *mask&(1<<dt.Target) != uint32(dd.DecodeTargetNotPresent) {
if maxSpatial < dt.Layer.Spatial {
maxSpatial = dt.Layer.Spatial
}
if maxTemporal < dt.Layer.Temporal {
maxTemporal = dt.Layer.Temporal
}
}
}
r.logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal)
r.onMaxLayerChanged(maxSpatial, maxTemporal)
}
withDecodeTargets := &DependencyDescriptorWithDecodeTarget{
Descriptor: &ddVal,
DecodeTargets: r.decodeTargets,
}
return withDecodeTargets, videoLayer, nil
}
// ------------------------------------------------------------------------------
type DependencyDescriptorDecodeTarget struct {
Target int
Layer VideoLayer
}
func ProcessFrameDependencyStructure(structure *dd.FrameDependencyStructure) []DependencyDescriptorDecodeTarget {
decodeTargets := make([]DependencyDescriptorDecodeTarget, 0, structure.NumDecodeTargets)
for target := 0; target < structure.NumDecodeTargets; target++ {
layer := VideoLayer{Spatial: 0, Temporal: 0}
for _, t := range structure.Templates {
if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent {
if layer.Spatial < int32(t.SpatialId) {
layer.Spatial = int32(t.SpatialId)
}
if layer.Temporal < int32(t.TemporalId) {
layer.Temporal = int32(t.TemporalId)
}
}
}
decodeTargets = append(decodeTargets, DependencyDescriptorDecodeTarget{target, layer})
}
// sort decode target layer by spatial and temporal from high to low
sort.Slice(decodeTargets, func(i, j int) bool {
return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer)
})
return decodeTargets
}
func GetActiveDecodeTargetBitmask(layer VideoLayer, decodeTargets []DependencyDescriptorDecodeTarget) *uint32 {
activeBitMask := uint32(0)
var maxSpatial, maxTemporal int32
for _, dt := range decodeTargets {
if dt.Layer.Spatial > maxSpatial {
maxSpatial = dt.Layer.Spatial
}
if dt.Layer.Temporal > maxTemporal {
maxTemporal = dt.Layer.Temporal
}
if dt.Layer.Spatial <= layer.Spatial && dt.Layer.Temporal <= layer.Temporal {
activeBitMask |= 1 << dt.Target
}
}
if layer.Spatial == maxSpatial && layer.Temporal == maxTemporal {
// all the decode targets are selected
return nil
}
return &activeBitMask
}
// ------------------------------------------------------------------------------
+4 -4
View File
@@ -7,7 +7,7 @@ import (
"github.com/pion/rtp/codecs"
)
var minFramesForCalculation = [DefaultMaxLayerTemporal + 1]int{8, 15, 40}
var minFramesForCalculation = [...]int{8, 15, 40}
type frameInfo struct {
seq uint16
@@ -357,7 +357,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
return false
}
fn := ep.DependencyDescriptor.FrameNumber
fn := ep.DependencyDescriptor.Descriptor.FrameNumber
if f.baseFrame == nil {
f.baseFrame = &frameInfo{seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn}
f.fnReceived[0] = f.baseFrame
@@ -397,7 +397,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
fn: fn,
temporal: temporal,
spatial: spatial,
frameDiff: ep.DependencyDescriptor.FrameDependencies.FrameDiffs,
frameDiff: ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs,
}
f.fnReceived[baseDiff] = fi
@@ -411,7 +411,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
if chain.Len() == 0 {
chain.PushBack(fn)
}
for _, fdiff := range ep.DependencyDescriptor.FrameDependencies.FrameDiffs {
for _, fdiff := range ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs {
dependFrame := fn - uint16(fdiff)
// frame too old, ignore
if dependFrame-f.secondFrames[spatial][temporal].fn > 0x8000 {
+6 -4
View File
@@ -31,10 +31,12 @@ func (f *testFrameInfo) toVP8() *ExtPacket {
func (f *testFrameInfo) toDD() *ExtPacket {
return &ExtPacket{
Packet: &rtp.Packet{Header: f.header},
DependencyDescriptor: &dependencydescriptor.DependencyDescriptor{
FrameNumber: f.framenumber,
FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{
FrameDiffs: f.frameDiff,
DependencyDescriptor: &DependencyDescriptorWithDecodeTarget{
Descriptor: &dependencydescriptor.DependencyDescriptor{
FrameNumber: f.framenumber,
FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{
FrameDiffs: f.frameDiff,
},
},
},
VideoLayer: VideoLayer{Spatial: int32(f.spatial), Temporal: int32(f.temporal)},
@@ -9,23 +9,20 @@ import (
// DependencyDescriptorExtension is a extension payload format in
// https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension
func formatBitmask(b *uint32) string {
if b == nil {
return "-"
}
return strconv.FormatInt(int64(*b), 2)
}
// ------------------------------------------------------------------------------
type DependencyDescriptorExtension struct {
Descriptor *DependencyDescriptor
Structure *FrameDependencyStructure
}
func (d *DependencyDescriptor) MarshalSize() (int, error) {
return d.MarshalSizeWithActiveChains(^uint32(0))
}
func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) {
writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d)
if err != nil {
return 0, err
}
return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil
}
func (d *DependencyDescriptorExtension) Marshal() ([]byte, error) {
return d.MarshalWithActiveChains(^uint32(0))
}
@@ -48,6 +45,8 @@ func (d *DependencyDescriptorExtension) Unmarshal(buf []byte) (int, error) {
return reader.Parse()
}
// ------------------------------------------------------------------------------
const (
MaxSpatialIds = 4
MaxTemporalIds = 8
@@ -59,9 +58,11 @@ const (
ExtensionUrl = "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension"
)
// ------------------------------------------------------------------------------
type DependencyDescriptor struct {
FirstPacketInFrame bool // = true;
LastPacketInFrame bool // = true;
FirstPacketInFrame bool
LastPacketInFrame bool
FrameNumber uint16
FrameDependencies *FrameDependencyTemplate
Resolution *RenderResolution
@@ -69,11 +70,16 @@ type DependencyDescriptor struct {
AttachedStructure *FrameDependencyStructure
}
func formatBitmask(b *uint32) string {
if b == nil {
return "-"
func (d *DependencyDescriptor) MarshalSize() (int, error) {
return d.MarshalSizeWithActiveChains(^uint32(0))
}
func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) {
writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d)
if err != nil {
return 0, err
}
return strconv.FormatInt(int64(*b), 2)
return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil
}
func (d *DependencyDescriptor) String() string {
@@ -81,6 +87,8 @@ func (d *DependencyDescriptor) String() string {
d.FirstPacketInFrame, d.LastPacketInFrame, d.FrameNumber, *d.FrameDependencies, *d.Resolution, formatBitmask(d.ActiveDecodeTargetsBitmask), d.AttachedStructure)
}
// ------------------------------------------------------------------------------
// Relationship of a frame to a Decode target.
type DecodeTargetIndication int
@@ -106,6 +114,8 @@ func (i DecodeTargetIndication) String() string {
}
}
// ------------------------------------------------------------------------------
type FrameDependencyTemplate struct {
SpatialId int
TemporalId int
@@ -132,6 +142,8 @@ func (t *FrameDependencyTemplate) Clone() *FrameDependencyTemplate {
return t2
}
// ------------------------------------------------------------------------------
type FrameDependencyStructure struct {
StructureId int
NumDecodeTargets int
@@ -156,7 +168,11 @@ func (f *FrameDependencyStructure) String() string {
return str
}
// ------------------------------------------------------------------------------
type RenderResolution struct {
Width int
Height int
}
// ------------------------------------------------------------------------------
@@ -1,6 +1,8 @@
package dependencydescriptor
import "errors"
import (
"errors"
)
type DependencyDescriptorReader struct {
// Output.
@@ -89,7 +91,6 @@ func (r *DependencyDescriptorReader) readMandatoryFields() error {
}
func (r *DependencyDescriptorReader) readExtendedFields() error {
templateDependencyStructurePresentFlag, err := r.buffer.ReadBool()
if err != nil {
return err
@@ -384,7 +385,6 @@ func (r *DependencyDescriptorReader) readFrameDtis() error {
}
func (r *DependencyDescriptorReader) readFrameFdiffs() error {
r.descriptor.FrameDependencies.FrameDiffs = r.descriptor.FrameDependencies.FrameDiffs[:0]
for {
nexFdiffSize, err := r.buffer.ReadBits(2)
+1 -1
View File
@@ -1448,7 +1448,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
}
var extraExtensions []extensionData
if len(meta.ddBytes) > 0 {
if d.dependencyDescriptorID != 0 && len(meta.ddBytes) != 0 {
extraExtensions = append(extraExtensions, extensionData{
id: uint8(d.dependencyDescriptorID),
payload: meta.ddBytes,
+121
View File
@@ -0,0 +1,121 @@
package utils
import (
"unsafe"
)
type number interface {
uint16 | uint32
}
type extendedNumber interface {
uint32 | uint64
}
type WrapAround[T number, ET extendedNumber] struct {
fullRange ET
initialized bool
start T
highest T
cycles int
}
func NewWrapAround[T number, ET extendedNumber]() *WrapAround[T, ET] {
var t T
return &WrapAround[T, ET]{
fullRange: 1 << (unsafe.Sizeof(t) * 8),
}
}
func (w *WrapAround[T, ET]) Seed(from *WrapAround[T, ET]) {
w.initialized = from.initialized
w.start = from.start
w.highest = from.highest
w.cycles = from.cycles
}
type wrapAroundUpdateResult[ET extendedNumber] struct {
IsRestart bool
PreExtendedStart ET // valid only if IsRestart = true
PreExtendedHighest ET
ExtendedVal ET
}
func (w *WrapAround[T, ET]) Update(val T) (result wrapAroundUpdateResult[ET]) {
if !w.initialized {
result.PreExtendedHighest = ET(val) - 1
result.ExtendedVal = ET(val)
w.start = val
w.highest = val
w.initialized = true
return
}
result.PreExtendedHighest = w.GetExtendedHighest()
gap := val - w.highest
if gap == 0 || gap > T(w.fullRange>>1) {
// duplicate OR out-of-order
result.IsRestart, result.PreExtendedStart, result.ExtendedVal = w.maybeAdjustStart(val)
return
}
// in-order
if val < w.highest {
w.cycles++
}
w.highest = val
result.ExtendedVal = ET(w.cycles)*w.fullRange + ET(val)
return
}
func (w *WrapAround[T, ET]) ResetHighest(val T) {
w.highest = val
}
func (w *WrapAround[T, ET]) GetStart() T {
return w.start
}
func (w *WrapAround[T, ET]) GetExtendedStart() ET {
return ET(w.start)
}
func (w *WrapAround[T, ET]) GetHighest() T {
return w.highest
}
func (w *WrapAround[T, ET]) GetExtendedHighest() ET {
return ET(w.cycles)*w.fullRange + ET(w.highest)
}
func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtendedStart ET, extendedVal ET) {
// re-adjust start if necessary. The conditions are
// 1. Not seen more than half the range yet
// 1. wrap around compared to start and not completed a half cycle, sequences like (10, 65530) in uint16 space
// 2. no wrap around, but out-of-order compared to start and not completed a half cycle , sequences like (10, 9), (65530, 65528) in uint16 space
totalNum := w.GetExtendedHighest() - w.GetExtendedStart() + 1
if totalNum > (w.fullRange >> 1) {
extendedVal = ET(w.cycles)*w.fullRange + ET(val)
return
}
cycles := w.cycles
if val-w.start > T(w.fullRange>>1) {
// out-of-order with existing start => a new start
isRestart = true
preExtendedStart = w.GetExtendedStart()
if val > w.highest {
// wrap around
w.cycles = 1
cycles = 0
}
w.start = val
}
extendedVal = ET(cycles)*w.fullRange + ET(val)
return
}
+295
View File
@@ -0,0 +1,295 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestWrapAroundUint16(t *testing.T) {
w := NewWrapAround[uint16, uint32]()
testCases := []struct {
name string
input uint16
updated wrapAroundUpdateResult[uint32]
start uint16
extendedStart uint32
highest uint16
extendedHighest uint32
}{
// initialize
{
name: "initialize",
input: 10,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: 9,
ExtendedVal: 10,
},
start: 10,
extendedStart: 10,
highest: 10,
extendedHighest: 10,
},
// an older number without wrap around should reset start point
{
name: "reset start no wrap around",
input: 8,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: true,
PreExtendedStart: 10,
PreExtendedHighest: 10,
ExtendedVal: 8,
},
start: 8,
extendedStart: 8,
highest: 10,
extendedHighest: 10,
},
// an older number with wrap around should reset start point
{
name: "reset start wrap around",
input: (1 << 16) - 6,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: true,
PreExtendedStart: 8,
PreExtendedHighest: 10,
ExtendedVal: (1 << 16) - 6,
},
start: (1 << 16) - 6,
extendedStart: (1 << 16) - 6,
highest: 10,
extendedHighest: (1 << 16) + 10,
},
// an older number with wrap around should reset start point again
{
name: "reset start again",
input: (1 << 16) - 12,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: true,
PreExtendedStart: (1 << 16) - 6,
PreExtendedHighest: (1 << 16) + 10,
ExtendedVal: (1 << 16) - 12,
},
start: (1 << 16) - 12,
extendedStart: (1 << 16) - 12,
highest: 10,
extendedHighest: (1 << 16) + 10,
},
// duplicate should return same as highest
{
name: "duplicate",
input: 10,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 16) + 10,
ExtendedVal: (1 << 16) + 10,
},
start: (1 << 16) - 12,
extendedStart: (1 << 16) - 12,
highest: 10,
extendedHighest: (1 << 16) + 10,
},
// a significant jump in order should not reset start
{
name: "big in-order jump",
input: 1 << 15,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 16) + 10,
ExtendedVal: (1 << 16) + (1 << 15),
},
start: (1 << 16) - 12,
extendedStart: (1 << 16) - 12,
highest: 1 << 15,
extendedHighest: (1 << 16) + (1 << 15),
},
// now out-of-order should not reset start as half the range has been seen
{
name: "out-of-order after half range",
input: (1 << 15) - 1,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 16) + (1 << 15),
ExtendedVal: (1 << 16) + (1 << 15) - 1,
},
start: (1 << 16) - 12,
extendedStart: (1 << 16) - 12,
highest: 1 << 15,
extendedHighest: (1 << 16) + (1 << 15),
},
// in-order, should update highest
{
name: "in-order",
input: (1 << 15) + 3,
updated: wrapAroundUpdateResult[uint32]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 16) + (1 << 15),
ExtendedVal: (1 << 16) + (1 << 15) + 3,
},
start: (1 << 16) - 12,
extendedStart: (1 << 16) - 12,
highest: (1 << 15) + 3,
extendedHighest: (1 << 16) + (1 << 15) + 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.updated, w.Update(tc.input))
require.Equal(t, tc.start, w.GetStart())
require.Equal(t, tc.extendedStart, w.GetExtendedStart())
require.Equal(t, tc.highest, w.GetHighest())
require.Equal(t, tc.extendedHighest, w.GetExtendedHighest())
})
}
}
func TestWrapAroundUint32(t *testing.T) {
w := NewWrapAround[uint32, uint64]()
testCases := []struct {
name string
input uint32
updated wrapAroundUpdateResult[uint64]
start uint32
extendedStart uint64
highest uint32
extendedHighest uint64
}{
// initialize
{
name: "initialize",
input: 10,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: 9,
ExtendedVal: 10,
},
start: 10,
extendedStart: 10,
highest: 10,
extendedHighest: 10,
},
// an older number without wrap around should reset start point
{
name: "reset start no wrap around",
input: 8,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: true,
PreExtendedStart: 10,
PreExtendedHighest: 10,
ExtendedVal: 8,
},
start: 8,
extendedStart: 8,
highest: 10,
extendedHighest: 10,
},
// an older number with wrap around should reset start point
{
name: "reset start wrap around",
input: (1 << 32) - 6,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: true,
PreExtendedStart: 8,
PreExtendedHighest: 10,
ExtendedVal: (1 << 32) - 6,
},
start: (1 << 32) - 6,
extendedStart: (1 << 32) - 6,
highest: 10,
extendedHighest: (1 << 32) + 10,
},
// an older number with wrap around should reset start point again
{
name: "reset start again",
input: (1 << 32) - 12,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: true,
PreExtendedStart: (1 << 32) - 6,
PreExtendedHighest: (1 << 32) + 10,
ExtendedVal: (1 << 32) - 12,
},
start: (1 << 32) - 12,
extendedStart: (1 << 32) - 12,
highest: 10,
extendedHighest: (1 << 32) + 10,
},
// duplicate should return same as highest
{
name: "duplicate",
input: 10,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 32) + 10,
ExtendedVal: (1 << 32) + 10,
},
start: (1 << 32) - 12,
extendedStart: (1 << 32) - 12,
highest: 10,
extendedHighest: (1 << 32) + 10,
},
// a significant jump in order should not reset start
{
name: "big in-order jump",
input: 1 << 31,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 32) + 10,
ExtendedVal: (1 << 32) + (1 << 31),
},
start: (1 << 32) - 12,
extendedStart: (1 << 32) - 12,
highest: 1 << 31,
extendedHighest: (1 << 32) + (1 << 31),
},
// now out-of-order should not reset start as half the range has been seen
{
name: "out-of-order after half range",
input: (1 << 31) - 1,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 32) + (1 << 31),
ExtendedVal: (1 << 32) + (1 << 31) - 1,
},
start: (1 << 32) - 12,
extendedStart: (1 << 32) - 12,
highest: 1 << 31,
extendedHighest: (1 << 32) + (1 << 31),
},
// in-order, should update highest
{
name: "in-order",
input: (1 << 31) + 3,
updated: wrapAroundUpdateResult[uint64]{
IsRestart: false,
PreExtendedStart: 0,
PreExtendedHighest: (1 << 32) + (1 << 31),
ExtendedVal: (1 << 32) + (1 << 31) + 3,
},
start: (1 << 32) - 12,
extendedStart: (1 << 32) - 12,
highest: (1 << 31) + 3,
extendedHighest: (1 << 32) + (1 << 31) + 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.updated, w.Update(tc.input))
require.Equal(t, tc.start, w.GetStart())
require.Equal(t, tc.extendedStart, w.GetExtendedStart())
require.Equal(t, tc.highest, w.GetHighest())
require.Equal(t, tc.extendedHighest, w.GetExtendedHighest())
})
}
}
+191 -173
View File
@@ -2,39 +2,37 @@ package videolayerselector
import (
"fmt"
"sort"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
dede "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
"github.com/livekit/livekit-server/pkg/sfu/utils"
"github.com/livekit/protocol/logger"
)
type decodeTarget struct {
Target int
Layer buffer.VideoLayer
}
type DependencyDescriptor struct {
*Base
// DD-TODO : fields for frame chain detect
// frameNumberWrapper Uint16Wrapper
// expectKeyFrame bool
frameNum *utils.WrapAround[uint16, uint64]
decisions *SelectorDecisionCache
decodeTargets []decodeTarget
needsDecodeTargetBitmask bool
activeDecodeTargetsBitmask *uint32
structure *dd.FrameDependencyStructure
structure *dede.FrameDependencyStructure
}
func NewDependencyDescriptor(logger logger.Logger) *DependencyDescriptor {
return &DependencyDescriptor{
Base: NewBase(logger),
Base: NewBase(logger),
frameNum: utils.NewWrapAround[uint16, uint64](),
decisions: NewSelectorDecisionCache(256),
}
}
func NewDependencyDescriptorFromNull(vls VideoLayerSelector) *DependencyDescriptor {
return &DependencyDescriptor{
Base: vls.(*Null).Base,
Base: vls.(*Null).Base,
frameNum: utils.NewWrapAround[uint16, uint64](),
decisions: NewSelectorDecisionCache(256),
}
}
@@ -43,21 +41,80 @@ func (d *DependencyDescriptor) IsOvershootOkay() bool {
}
func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) {
if extPkt.DependencyDescriptor == nil {
// packet don't have dependency descriptor
ddwdt := extPkt.DependencyDescriptor
if ddwdt == nil {
// packet doesn't have dependency descriptor
return
}
dd := ddwdt.Descriptor
// a packet is relevant as long as it has DD extension
result.IsRelevant = true
frameNum := d.frameNum.Update(dd.FrameNumber)
extFrameNum := frameNum.ExtendedVal
fd := dd.FrameDependencies
incomingLayer := buffer.VideoLayer{
Spatial: int32(fd.SpatialId),
Temporal: int32(fd.TemporalId),
}
// early return if this frame is already forwarded or dropped
sd, err := d.decisions.GetDecision(extFrameNum)
if err != nil {
// do not mark as dropped as only error is an old frame
return
}
switch sd {
case selectorDecisionForwarded:
// a packet of an alreadty forwarded frame, maintain decision
result.RTPMarker = extPkt.Packet.Header.Marker || (dd.LastPacketInFrame && d.currentLayer.Spatial == int32(fd.SpatialId))
result.IsSelected = true
case selectorDecisionDropped:
// a packet of an alreadty dropped frame, maintain decision
return
}
if !d.currentLayer.IsValid() && !extPkt.KeyFrame {
d.decisions.AddDropped(extFrameNum)
return
}
result.IsRelevant = true
// check decodability using reference frames
isDecodable := true
for _, fdiff := range fd.FrameDiffs {
if fdiff == 0 {
continue
}
if extPkt.DependencyDescriptor.AttachedStructure != nil {
if sd, _ := d.decisions.GetDecision(extFrameNum - uint64(fdiff)); sd != selectorDecisionForwarded {
isDecodable = false
break
}
}
if !isDecodable {
// DD-TODO START
// Not decodable could happen due to packet loss or out-of-order packets,
// Need to figure out better ways to handle this.
//
// 1. Should definitely check if this frame is not part of current decode target OR discardable.
// In that case, forwarding can proceed without disruption.
// 2. Add a packet queue and try to de-jitter for some time. Safest is to packet copy to local queue on
// all down tracks.
// 3. Force a PLI and wait for a key frame.
// DD-TODO END
d.decisions.AddDropped(extFrameNum)
return
}
// DD-TODO should not update for out-of-order RTP packets
if dd.AttachedStructure != nil {
// update decode target layer and active decode targets
// DD-TODO : these targets info can be shared by all the downtracks, no need calculate in every selector
d.updateDependencyStructure(extPkt.DependencyDescriptor.AttachedStructure)
d.updateDependencyStructure(dd.AttachedStructure)
}
// DD-TODO : we don't have a rtp queue to ensure the order of packets now,
@@ -67,133 +124,144 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
// only check DTI of the active decode target.
// it is not effeciency, at last we need check frame chain integrity.
activeDecodeTargets := extPkt.DependencyDescriptor.ActiveDecodeTargetsBitmask
activeDecodeTargets := dd.ActiveDecodeTargetsBitmask
if activeDecodeTargets != nil {
d.logger.Debugw("active decode targets", "activeDecodeTargets", *activeDecodeTargets)
}
currentTarget := -1
for _, dt := range d.decodeTargets {
// find target match with selected layer
if dt.Layer.Spatial <= d.targetLayer.Spatial && dt.Layer.Temporal <= d.targetLayer.Temporal {
if activeDecodeTargets == nil || ((*activeDecodeTargets)&(1<<dt.Target) != 0) {
// DD-TODO : check frame chain integrity
currentTarget = dt.Target
// d.logger.Debugw("select target", "target", currentTarget, "layer", dt.Target, "dtis", extPkt.DependencyDescriptor.FrameDependencies.DecodeTargetIndications)
break
// find decode target closest to targetLayer
highestDecodeTarget := buffer.DependencyDescriptorDecodeTarget{
Target: -1,
Layer: buffer.InvalidLayer,
}
for _, dt := range ddwdt.DecodeTargets {
if dt.Layer.Spatial > d.targetLayer.Spatial || dt.Layer.Temporal > d.targetLayer.Temporal {
continue
}
if activeDecodeTargets != nil && ((*activeDecodeTargets)&(1<<dt.Target) == 0) {
continue
}
if len(d.structure.DecodeTargetProtectedByChain) == 0 {
highestDecodeTarget = dt
//d.logger.Debugw("select target", "highestDecodeTarget", highestDecodeTarget, "dtis", fd.DecodeTargetIndications)
break
}
if len(d.structure.DecodeTargetProtectedByChain) < dt.Target {
// look for lower target
continue
}
chainIdx := d.structure.DecodeTargetProtectedByChain[dt.Target]
if len(fd.ChainDiffs) < chainIdx {
// look for lower target
continue
}
prevFrameInChain := extFrameNum - uint64(fd.ChainDiffs[chainIdx])
if prevFrameInChain != 0 && prevFrameInChain != extFrameNum {
if sd, err := d.decisions.GetDecision(prevFrameInChain); err != nil || sd != selectorDecisionForwarded {
// look for lower target
continue
}
}
highestDecodeTarget = dt
//d.logger.Debugw("select target", "highestDecodeTarget", highestDecodeTarget, "dtis", fd.DecodeTargetIndications)
break
}
if currentTarget < 0 {
//d.logger.Debugw(fmt.Sprintf("drop packet for no target found, decodeTargets %v, tagetLayer %v, s:%d, t:%d",
if highestDecodeTarget.Target < 0 {
// no active decode target, do not select
//d.logger.Debugw(fmt.Sprintf("drop packet for no target found, decodeTargets %v, tagetLayer %v, incoming %v",
//d.decodeTargets,
//d.targetLayer,
//extPkt.DependencyDescriptor.FrameDependencies.SpatialId,
//extPkt.DependencyDescriptor.FrameDependencies.TemporalId,
//incomingLayer,
//))
// no active decode target, do not select
d.decisions.AddDropped(extFrameNum)
return
}
dtis := extPkt.DependencyDescriptor.FrameDependencies.DecodeTargetIndications
if len(dtis) < currentTarget {
dtis := fd.DecodeTargetIndications
if len(dtis) < highestDecodeTarget.Target {
// dtis error, dependency descriptor might lost
d.logger.Debugw(fmt.Sprintf("drop packet for dtis error, dtis %v, currentTarget %d, s:%d, t:%d",
d.logger.Debugw(fmt.Sprintf("drop packet for dtis error, dtis %v, highestDecodeTarget %+v, incoming: %v",
dtis,
currentTarget,
extPkt.DependencyDescriptor.FrameDependencies.SpatialId,
extPkt.DependencyDescriptor.FrameDependencies.TemporalId,
highestDecodeTarget,
incomingLayer,
))
d.decisions.AddDropped(extFrameNum)
return
}
// DD-TODO : if bandwidth in congest, could drop the 'Discardable' packet
dti := dtis[currentTarget]
if dti == dd.DecodeTargetNotPresent {
//d.logger.Debugw(fmt.Sprintf("drop packet for decode target not present, dtis %v, currentTarget %d, s:%d, t:%d",
dti := dtis[highestDecodeTarget.Target]
if dti == dede.DecodeTargetNotPresent {
//d.logger.Debugw(fmt.Sprintf("drop packet for decode target not present, dtis %v, highestDecodeTarget %d, incoming %v, fn: %d/%d",
//dtis,
//currentTarget,
//extPkt.DependencyDescriptor.FrameDependencies.SpatialId,
//extPkt.DependencyDescriptor.FrameDependencies.TemporalId,
//highestDecodeTarget,
//incomingLayer,
//dd.FrameNumber,
//extFrameNum,
//))
d.decisions.AddDropped(extFrameNum)
return
}
if dti == dd.DecodeTargetSwitch {
// dependency descriptor decode target switch is enabled at all potential switch points.
// So, setting current layer on every switch point will change current layer a lot.
//
// However `currentLayer` is not needed for layer selection in this selector.
// But, it is needed to signal things in the selector checks outside of this selector.
//
// The following cases are handled
// 1. To detect resumption
// 2. To detect target achieved so that key frame requests can be stopped
// 3. To detect reaching max spatial layer - checked when current hits target
if d.currentLayer != highestDecodeTarget.Layer {
if !d.currentLayer.IsValid() {
result.IsResuming = true
d.currentLayer = buffer.VideoLayer{
Spatial: int32(extPkt.DependencyDescriptor.FrameDependencies.SpatialId),
Temporal: int32(extPkt.DependencyDescriptor.FrameDependencies.TemporalId),
}
d.logger.Infow(
"resuming at layer",
"current", d.currentLayer,
"current", incomingLayer,
"target", d.targetLayer,
"max", d.maxLayer,
"layer", extPkt.DependencyDescriptor.FrameDependencies.SpatialId,
"layer", fd.SpatialId,
"req", d.requestSpatial,
"maxSeen", d.maxSeenLayer,
"feed", extPkt.Packet.SSRC,
)
}
if d.currentLayer != d.targetLayer {
if d.currentLayer.Spatial != d.targetLayer.Spatial && int32(extPkt.DependencyDescriptor.FrameDependencies.SpatialId) == d.targetLayer.Spatial {
d.currentLayer.Spatial = d.targetLayer.Spatial
if d.currentLayer.Spatial == d.requestSpatial {
result.IsSwitchingToRequestSpatial = true
}
if d.currentLayer.Spatial == d.maxLayer.Spatial {
result.IsSwitchingToMaxSpatial = true
d.logger.Infow(
"reached max layer",
"current", d.currentLayer,
"target", d.targetLayer,
"max", d.maxLayer,
"layer", extPkt.DependencyDescriptor.FrameDependencies.SpatialId,
"req", d.requestSpatial,
"maxSeen", d.maxSeenLayer,
"feed", extPkt.Packet.SSRC,
)
}
}
if d.currentLayer.Temporal != d.targetLayer.Temporal && int32(extPkt.DependencyDescriptor.FrameDependencies.TemporalId) == d.targetLayer.Temporal {
d.currentLayer.Temporal = d.targetLayer.Temporal
}
d.currentLayer = highestDecodeTarget.Layer
if d.currentLayer.Spatial == d.requestSpatial {
result.IsSwitchingToRequestSpatial = true
}
if d.currentLayer.Spatial == d.maxLayer.Spatial {
result.IsSwitchingToMaxSpatial = true
d.logger.Infow(
"reached max layer",
"current", d.currentLayer,
"target", d.targetLayer,
"max", d.maxLayer,
"layer", fd.SpatialId,
"req", d.requestSpatial,
"maxSeen", d.maxSeenLayer,
"feed", extPkt.Packet.SSRC,
)
}
}
// DD-TODO : add frame to forwarded queue if entire frame is forwarded
// d.logger.Debugw("select packet", "target", currentTarget, "targetLayer", d.targetLayer)
ddExtension := &dd.DependencyDescriptorExtension{
Descriptor: extPkt.DependencyDescriptor,
ddExtension := &dede.DependencyDescriptorExtension{
Descriptor: dd,
Structure: d.structure,
}
if extPkt.DependencyDescriptor.AttachedStructure == nil && d.activeDecodeTargetsBitmask != nil {
// clone and override activebitmask
ddClone := *ddExtension.Descriptor
ddClone.ActiveDecodeTargetsBitmask = d.activeDecodeTargetsBitmask
ddExtension.Descriptor = &ddClone
// d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
if dd.AttachedStructure == nil {
if d.needsDecodeTargetBitmask {
d.needsDecodeTargetBitmask = false
d.activeDecodeTargetsBitmask = buffer.GetActiveDecodeTargetBitmask(d.targetLayer, ddwdt.DecodeTargets)
d.logger.Debugw("setting decode target bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
}
if d.activeDecodeTargetsBitmask != nil {
// clone and override activebitmask
ddClone := *ddExtension.Descriptor
ddClone.ActiveDecodeTargetsBitmask = d.activeDecodeTargetsBitmask
ddExtension.Descriptor = &ddClone
// d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
}
}
bytes, err := ddExtension.Marshal()
if err != nil {
@@ -202,83 +270,33 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
result.DependencyDescriptorExtension = bytes
}
result.RTPMarker = extPkt.Packet.Header.Marker || (extPkt.DependencyDescriptor.LastPacketInFrame && d.targetLayer.Spatial == int32(extPkt.DependencyDescriptor.FrameDependencies.SpatialId))
// DD-TODO START
// Ideally should add this frame only on the last packet of the frame and if all packets of the frame have been selected.
// But, adding on any packet so that any out-of-order packets within a frame can be fowarded.
// But, that could result in decodability/chain integrity to erroneously pass (i. e. in the case of lost packet in this
// frame, this frame is not decodable and hence the chain is broken).
//
// Note that packets can get lost in the forwarded path also. That will be handled by receiver sending PLI.
//
// Within SFU, there is more work to do to ensure integrity of forwarded packets/frames to adhere to the complete design
// goal of dependency descriptor
// DD-TODO END
d.decisions.AddForwarded(extFrameNum)
result.RTPMarker = extPkt.Packet.Header.Marker || (dd.LastPacketInFrame && d.currentLayer.Spatial == int32(fd.SpatialId))
result.IsSelected = true
return
}
func (d *DependencyDescriptor) SetTarget(targetLayer buffer.VideoLayer) {
if targetLayer == d.targetLayer {
return
}
d.Base.SetTarget(targetLayer)
activeBitMask := uint32(0)
var maxSpatial, maxTemporal int32
for _, dt := range d.decodeTargets {
if dt.Layer.Spatial > maxSpatial {
maxSpatial = dt.Layer.Spatial
}
if dt.Layer.Temporal > maxTemporal {
maxTemporal = dt.Layer.Temporal
}
if dt.Layer.Spatial <= targetLayer.Spatial && dt.Layer.Temporal <= targetLayer.Temporal {
activeBitMask |= 1 << dt.Target
}
}
if targetLayer.Spatial == maxSpatial && targetLayer.Temporal == maxTemporal {
// all the decode targets are selected
d.activeDecodeTargetsBitmask = nil
} else {
d.activeDecodeTargetsBitmask = &activeBitMask
}
d.logger.Debugw("setting target", "targetlayer", targetLayer, "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
d.needsDecodeTargetBitmask = true
}
func (d *DependencyDescriptor) updateDependencyStructure(structure *dd.FrameDependencyStructure) {
func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure) {
d.structure = structure
d.decodeTargets = d.decodeTargets[:0]
for target := 0; target < structure.NumDecodeTargets; target++ {
layer := buffer.VideoLayer{Spatial: 0, Temporal: 0}
for _, t := range structure.Templates {
if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent {
if layer.Spatial < int32(t.SpatialId) {
layer.Spatial = int32(t.SpatialId)
}
if layer.Temporal < int32(t.TemporalId) {
layer.Temporal = int32(t.TemporalId)
}
}
}
d.decodeTargets = append(d.decodeTargets, decodeTarget{target, layer})
}
// sort decode target layer by spatial and temporal from high to low
sort.Slice(d.decodeTargets, func(i, j int) bool {
return d.decodeTargets[i].Layer.GreaterThan(d.decodeTargets[j].Layer)
})
d.logger.Debugw(fmt.Sprintf("update decode targets: %v", d.decodeTargets))
}
// DD-TODO : use generic wrapper when updated to go 1.18
type Uint16Wrapper struct {
lastValue *uint16
lastUnwrapped int32
}
func (w *Uint16Wrapper) Unwrap(value uint16) int32 {
if w.lastValue == nil {
w.lastValue = &value
w.lastUnwrapped = int32(value)
return int32(*w.lastValue)
}
diff := value - *w.lastValue
w.lastUnwrapped += int32(diff)
if diff == 0x8000 && value < *w.lastValue {
w.lastUnwrapped -= 0x10000
} else if diff > 0x8000 {
w.lastUnwrapped -= 0x10000
}
*w.lastValue = value
return w.lastUnwrapped
}
@@ -0,0 +1,114 @@
package videolayerselector
import (
"fmt"
)
// ----------------------------------------------------------------------
type selectorDecision int
const (
selectorDecisionMissing selectorDecision = iota
selectorDecisionDropped
selectorDecisionForwarded
selectorDecisionUnknown
)
func (s selectorDecision) String() string {
switch s {
case selectorDecisionMissing:
return "MISSING"
case selectorDecisionDropped:
return "DROPPED"
case selectorDecisionForwarded:
return "FORWARDED"
case selectorDecisionUnknown:
return "UNKNOWN"
default:
return fmt.Sprintf("%d", int(s))
}
}
// ----------------------------------------------------------------------
type SelectorDecisionCache struct {
initialized bool
base uint64
last uint64
masks []uint64
numEntries uint64
}
func NewSelectorDecisionCache(maxNumElements uint64) *SelectorDecisionCache {
numElements := (maxNumElements*2 + 63) / 64
return &SelectorDecisionCache{
masks: make([]uint64, numElements),
numEntries: numElements * 32, // 2 bits per entry
}
}
func (s *SelectorDecisionCache) AddForwarded(entity uint64) {
s.addEntity(entity, selectorDecisionForwarded)
}
func (s *SelectorDecisionCache) AddDropped(entity uint64) {
s.addEntity(entity, selectorDecisionDropped)
}
func (s *SelectorDecisionCache) GetDecision(entity uint64) (selectorDecision, error) {
if !s.initialized || entity > s.last || entity < s.base {
return selectorDecisionUnknown, nil
}
offset := s.last - entity
if offset >= s.numEntries {
// asking for something too old
return selectorDecisionUnknown, fmt.Errorf("too old, oldest: %d, asking: %d", s.last-s.numEntries+1, entity)
}
return s.getEntity(entity), nil
}
func (s *SelectorDecisionCache) addEntity(entity uint64, sd selectorDecision) {
if !s.initialized {
s.initialized = true
s.base = entity
s.last = entity
s.setEntity(entity, sd)
return
}
if entity <= s.base {
// before base, too old
return
}
if entity <= s.last {
s.setEntity(entity, sd)
return
}
for e := s.last + 1; e != entity; e++ {
s.setEntity(e, selectorDecisionMissing)
}
s.setEntity(entity, sd)
s.last = entity
}
func (s *SelectorDecisionCache) setEntity(entity uint64, sd selectorDecision) {
index, bitpos := s.getPos(entity)
s.masks[index] &= ^(0x3 << bitpos) // clear before bitwise OR
s.masks[index] |= (uint64(sd) & 0x3) << bitpos
}
func (s *SelectorDecisionCache) getEntity(entity uint64) selectorDecision {
index, bitpos := s.getPos(entity)
return selectorDecision((s.masks[index] >> bitpos) & 0x3)
}
func (s *SelectorDecisionCache) getPos(entity uint64) (int, int) {
// 2 bits per entity, a uint64 mask can hold 32 entities
offset := (entity - s.base) % s.numEntries
return int(offset >> 5), int(offset&0x1F) * 2
}