From 69fb5e51a2a5baaf6b784361def429a963eb1672 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Wed, 12 Apr 2023 17:30:54 +0530 Subject: [PATCH] Fix stutter in forwarding path when using dependency descriptor (#1600) * Decode chains * clean up * clean up * decode targets only on publisher side * comment out supported codecs * fix test compile * fix another test compile * Adding TODO notes * chainID -> chainIdx * do not need to check for switch up point when using chains, as long as chain integrity is good, can switch * more comments * address comments --- pkg/sfu/buffer/buffer.go | 6 +- pkg/sfu/buffer/dependencydescriptorparser.go | 206 ++++++---- pkg/sfu/buffer/fps.go | 8 +- pkg/sfu/buffer/fps_test.go | 10 +- .../dependencydescriptorextension.go | 52 ++- .../dependencydescriptorreader.go | 6 +- pkg/sfu/downtrack.go | 2 +- pkg/sfu/utils/wraparound.go | 121 ++++++ pkg/sfu/utils/wraparound_test.go | 295 ++++++++++++++ .../dependencydescriptor.go | 364 +++++++++--------- .../selectordecisioncache.go | 114 ++++++ 11 files changed, 901 insertions(+), 283 deletions(-) create mode 100644 pkg/sfu/utils/wraparound.go create mode 100644 pkg/sfu/utils/wraparound_test.go create mode 100644 pkg/sfu/videolayerselector/selectordecisioncache.go diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 361162b37..8b9a7094d 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -42,7 +42,7 @@ type ExtPacket struct { Payload interface{} KeyFrame bool RawPacket []byte - DependencyDescriptor *dd.DependencyDescriptor + DependencyDescriptor *DependencyDescriptorWithDecodeTarget } // Buffer contains all packets @@ -551,7 +551,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPack if err == nil && ddVal != nil { ep.DependencyDescriptor = ddVal ep.VideoLayer = videoLayer - // TODO : notify active decode target change if changed. + // DD-TODO : notify active decode target change if changed. } } switch b.mime { @@ -779,7 +779,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) { return b.audioLevel.GetLevel() } -// TODO : now we rely on stream tracker for layer change, dependency still +// DD-TODO : now we rely on stream tracker for layer change, dependency still // work for that too. Do we keep it unchanged or use both methods? func (b *Buffer) OnMaxLayerChanged(fn func(int32, int32)) { b.maxLayerChangedCB = fn diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go index 755e2cd4d..c52e86c83 100644 --- a/pkg/sfu/buffer/dependencydescriptorparser.go +++ b/pkg/sfu/buffer/dependencydescriptorparser.go @@ -2,6 +2,7 @@ package buffer import ( "fmt" + "sort" "github.com/pion/rtp" @@ -12,91 +13,142 @@ import ( type DependencyDescriptorParser struct { structure *dd.FrameDependencyStructure - ddExt uint8 + ddExtID uint8 logger logger.Logger onMaxLayerChanged func(int32, int32) - decodeTargetLayer []VideoLayer + decodeTargets []DependencyDescriptorDecodeTarget } -func NewDependencyDescriptorParser(ddExt uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser { - logger.Infow("creating dependency descriptor parser", "ddExt", ddExt) +func NewDependencyDescriptorParser(ddExtID uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser { + logger.Infow("creating dependency descriptor parser", "ddExtID", ddExtID) return &DependencyDescriptorParser{ - ddExt: ddExt, + ddExtID: ddExtID, logger: logger, onMaxLayerChanged: onMaxLayerChanged, } } -func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*dd.DependencyDescriptor, VideoLayer, error) { - var videoLayer VideoLayer - if ddBuf := pkt.GetExtension(r.ddExt); ddBuf != nil { - var ddVal dd.DependencyDescriptor - ext := &dd.DependencyDescriptorExtension{ - Descriptor: &ddVal, - Structure: r.structure, - } - _, err := ext.Unmarshal(ddBuf) - if err != nil { - // r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) - return nil, videoLayer, err - } - - if ddVal.FrameDependencies != nil { - videoLayer.Spatial, videoLayer.Temporal = int32(ddVal.FrameDependencies.SpatialId), int32(ddVal.FrameDependencies.TemporalId) - } - if ddVal.AttachedStructure != nil && !ddVal.FirstPacketInFrame { - // r.logger.Debugw("ignoring non-first packet in frame with attached structure") - return nil, videoLayer, nil - } - - if ddVal.AttachedStructure != nil { - var maxSpatial, maxTemporal int32 - r.structure = ddVal.AttachedStructure - r.decodeTargetLayer = r.decodeTargetLayer[:0] - for target := 0; target < r.structure.NumDecodeTargets; target++ { - layer := VideoLayer{0, 0} - for _, t := range r.structure.Templates { - if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent { - if layer.Spatial < int32(t.SpatialId) { - layer.Spatial = int32(t.SpatialId) - } - if layer.Temporal < int32(t.TemporalId) { - layer.Temporal = int32(t.TemporalId) - } - } - } - if layer.Spatial > maxSpatial { - maxSpatial = layer.Spatial - } - if layer.Temporal > maxTemporal { - maxTemporal = layer.Temporal - } - r.decodeTargetLayer = append(r.decodeTargetLayer, layer) - } - r.logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal) - go r.onMaxLayerChanged(maxSpatial, maxTemporal) - } - - if ddVal.AttachedStructure != nil && ddVal.FirstPacketInFrame { - r.logger.Debugw(fmt.Sprintf("parsed dependency descriptor\n%s", ddVal.String())) - } - - if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil { - var maxSpatial, maxTemporal int32 - for dt, layer := range r.decodeTargetLayer { - if *mask&(1< maxSpatial { + maxSpatial = dt.Layer.Spatial + } + if dt.Layer.Temporal > maxTemporal { + maxTemporal = dt.Layer.Temporal + } + if dt.Layer.Spatial <= layer.Spatial && dt.Layer.Temporal <= layer.Temporal { + activeBitMask |= 1 << dt.Target + } + } + if layer.Spatial == maxSpatial && layer.Temporal == maxTemporal { + // all the decode targets are selected + return nil + } + + return &activeBitMask +} + +// ------------------------------------------------------------------------------ diff --git a/pkg/sfu/buffer/fps.go b/pkg/sfu/buffer/fps.go index f8f192227..518b8e38c 100644 --- a/pkg/sfu/buffer/fps.go +++ b/pkg/sfu/buffer/fps.go @@ -7,7 +7,7 @@ import ( "github.com/pion/rtp/codecs" ) -var minFramesForCalculation = [DefaultMaxLayerTemporal + 1]int{8, 15, 40} +var minFramesForCalculation = [...]int{8, 15, 40} type frameInfo struct { seq uint16 @@ -357,7 +357,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { return false } - fn := ep.DependencyDescriptor.FrameNumber + fn := ep.DependencyDescriptor.Descriptor.FrameNumber if f.baseFrame == nil { f.baseFrame = &frameInfo{seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn} f.fnReceived[0] = f.baseFrame @@ -397,7 +397,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { fn: fn, temporal: temporal, spatial: spatial, - frameDiff: ep.DependencyDescriptor.FrameDependencies.FrameDiffs, + frameDiff: ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs, } f.fnReceived[baseDiff] = fi @@ -411,7 +411,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { if chain.Len() == 0 { chain.PushBack(fn) } - for _, fdiff := range ep.DependencyDescriptor.FrameDependencies.FrameDiffs { + for _, fdiff := range ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs { dependFrame := fn - uint16(fdiff) // frame too old, ignore if dependFrame-f.secondFrames[spatial][temporal].fn > 0x8000 { diff --git a/pkg/sfu/buffer/fps_test.go b/pkg/sfu/buffer/fps_test.go index 5f0ff79ce..206041acf 100644 --- a/pkg/sfu/buffer/fps_test.go +++ b/pkg/sfu/buffer/fps_test.go @@ -31,10 +31,12 @@ func (f *testFrameInfo) toVP8() *ExtPacket { func (f *testFrameInfo) toDD() *ExtPacket { return &ExtPacket{ Packet: &rtp.Packet{Header: f.header}, - DependencyDescriptor: &dependencydescriptor.DependencyDescriptor{ - FrameNumber: f.framenumber, - FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{ - FrameDiffs: f.frameDiff, + DependencyDescriptor: &DependencyDescriptorWithDecodeTarget{ + Descriptor: &dependencydescriptor.DependencyDescriptor{ + FrameNumber: f.framenumber, + FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{ + FrameDiffs: f.frameDiff, + }, }, }, VideoLayer: VideoLayer{Spatial: int32(f.spatial), Temporal: int32(f.temporal)}, diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go b/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go index 4d6e339cc..7493f1bfc 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go @@ -9,23 +9,20 @@ import ( // DependencyDescriptorExtension is a extension payload format in // https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension +func formatBitmask(b *uint32) string { + if b == nil { + return "-" + } + return strconv.FormatInt(int64(*b), 2) +} + +// ------------------------------------------------------------------------------ + type DependencyDescriptorExtension struct { Descriptor *DependencyDescriptor Structure *FrameDependencyStructure } -func (d *DependencyDescriptor) MarshalSize() (int, error) { - return d.MarshalSizeWithActiveChains(^uint32(0)) -} - -func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) { - writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d) - if err != nil { - return 0, err - } - return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil -} - func (d *DependencyDescriptorExtension) Marshal() ([]byte, error) { return d.MarshalWithActiveChains(^uint32(0)) } @@ -48,6 +45,8 @@ func (d *DependencyDescriptorExtension) Unmarshal(buf []byte) (int, error) { return reader.Parse() } +// ------------------------------------------------------------------------------ + const ( MaxSpatialIds = 4 MaxTemporalIds = 8 @@ -59,9 +58,11 @@ const ( ExtensionUrl = "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension" ) +// ------------------------------------------------------------------------------ + type DependencyDescriptor struct { - FirstPacketInFrame bool // = true; - LastPacketInFrame bool // = true; + FirstPacketInFrame bool + LastPacketInFrame bool FrameNumber uint16 FrameDependencies *FrameDependencyTemplate Resolution *RenderResolution @@ -69,11 +70,16 @@ type DependencyDescriptor struct { AttachedStructure *FrameDependencyStructure } -func formatBitmask(b *uint32) string { - if b == nil { - return "-" +func (d *DependencyDescriptor) MarshalSize() (int, error) { + return d.MarshalSizeWithActiveChains(^uint32(0)) +} + +func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) { + writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d) + if err != nil { + return 0, err } - return strconv.FormatInt(int64(*b), 2) + return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil } func (d *DependencyDescriptor) String() string { @@ -81,6 +87,8 @@ func (d *DependencyDescriptor) String() string { d.FirstPacketInFrame, d.LastPacketInFrame, d.FrameNumber, *d.FrameDependencies, *d.Resolution, formatBitmask(d.ActiveDecodeTargetsBitmask), d.AttachedStructure) } +// ------------------------------------------------------------------------------ + // Relationship of a frame to a Decode target. type DecodeTargetIndication int @@ -106,6 +114,8 @@ func (i DecodeTargetIndication) String() string { } } +// ------------------------------------------------------------------------------ + type FrameDependencyTemplate struct { SpatialId int TemporalId int @@ -132,6 +142,8 @@ func (t *FrameDependencyTemplate) Clone() *FrameDependencyTemplate { return t2 } +// ------------------------------------------------------------------------------ + type FrameDependencyStructure struct { StructureId int NumDecodeTargets int @@ -156,7 +168,11 @@ func (f *FrameDependencyStructure) String() string { return str } +// ------------------------------------------------------------------------------ + type RenderResolution struct { Width int Height int } + +// ------------------------------------------------------------------------------ diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go index 99f3fd68a..68f00b863 100644 --- a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go +++ b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go @@ -1,6 +1,8 @@ package dependencydescriptor -import "errors" +import ( + "errors" +) type DependencyDescriptorReader struct { // Output. @@ -89,7 +91,6 @@ func (r *DependencyDescriptorReader) readMandatoryFields() error { } func (r *DependencyDescriptorReader) readExtendedFields() error { - templateDependencyStructurePresentFlag, err := r.buffer.ReadBool() if err != nil { return err @@ -384,7 +385,6 @@ func (r *DependencyDescriptorReader) readFrameDtis() error { } func (r *DependencyDescriptorReader) readFrameFdiffs() error { - r.descriptor.FrameDependencies.FrameDiffs = r.descriptor.FrameDependencies.FrameDiffs[:0] for { nexFdiffSize, err := r.buffer.ReadBits(2) diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 256f5defc..430bb882f 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -1448,7 +1448,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { } var extraExtensions []extensionData - if len(meta.ddBytes) > 0 { + if d.dependencyDescriptorID != 0 && len(meta.ddBytes) != 0 { extraExtensions = append(extraExtensions, extensionData{ id: uint8(d.dependencyDescriptorID), payload: meta.ddBytes, diff --git a/pkg/sfu/utils/wraparound.go b/pkg/sfu/utils/wraparound.go new file mode 100644 index 000000000..e6c8bbf2c --- /dev/null +++ b/pkg/sfu/utils/wraparound.go @@ -0,0 +1,121 @@ +package utils + +import ( + "unsafe" +) + +type number interface { + uint16 | uint32 +} + +type extendedNumber interface { + uint32 | uint64 +} + +type WrapAround[T number, ET extendedNumber] struct { + fullRange ET + + initialized bool + start T + highest T + cycles int +} + +func NewWrapAround[T number, ET extendedNumber]() *WrapAround[T, ET] { + var t T + return &WrapAround[T, ET]{ + fullRange: 1 << (unsafe.Sizeof(t) * 8), + } +} + +func (w *WrapAround[T, ET]) Seed(from *WrapAround[T, ET]) { + w.initialized = from.initialized + w.start = from.start + w.highest = from.highest + w.cycles = from.cycles +} + +type wrapAroundUpdateResult[ET extendedNumber] struct { + IsRestart bool + PreExtendedStart ET // valid only if IsRestart = true + PreExtendedHighest ET + ExtendedVal ET +} + +func (w *WrapAround[T, ET]) Update(val T) (result wrapAroundUpdateResult[ET]) { + if !w.initialized { + result.PreExtendedHighest = ET(val) - 1 + result.ExtendedVal = ET(val) + + w.start = val + w.highest = val + w.initialized = true + return + } + + result.PreExtendedHighest = w.GetExtendedHighest() + + gap := val - w.highest + if gap == 0 || gap > T(w.fullRange>>1) { + // duplicate OR out-of-order + result.IsRestart, result.PreExtendedStart, result.ExtendedVal = w.maybeAdjustStart(val) + return + } + + // in-order + if val < w.highest { + w.cycles++ + } + w.highest = val + + result.ExtendedVal = ET(w.cycles)*w.fullRange + ET(val) + return +} + +func (w *WrapAround[T, ET]) ResetHighest(val T) { + w.highest = val +} + +func (w *WrapAround[T, ET]) GetStart() T { + return w.start +} + +func (w *WrapAround[T, ET]) GetExtendedStart() ET { + return ET(w.start) +} + +func (w *WrapAround[T, ET]) GetHighest() T { + return w.highest +} + +func (w *WrapAround[T, ET]) GetExtendedHighest() ET { + return ET(w.cycles)*w.fullRange + ET(w.highest) +} + +func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtendedStart ET, extendedVal ET) { + // re-adjust start if necessary. The conditions are + // 1. Not seen more than half the range yet + // 1. wrap around compared to start and not completed a half cycle, sequences like (10, 65530) in uint16 space + // 2. no wrap around, but out-of-order compared to start and not completed a half cycle , sequences like (10, 9), (65530, 65528) in uint16 space + totalNum := w.GetExtendedHighest() - w.GetExtendedStart() + 1 + if totalNum > (w.fullRange >> 1) { + extendedVal = ET(w.cycles)*w.fullRange + ET(val) + return + } + + cycles := w.cycles + if val-w.start > T(w.fullRange>>1) { + // out-of-order with existing start => a new start + isRestart = true + preExtendedStart = w.GetExtendedStart() + + if val > w.highest { + // wrap around + w.cycles = 1 + cycles = 0 + } + w.start = val + } + extendedVal = ET(cycles)*w.fullRange + ET(val) + return +} diff --git a/pkg/sfu/utils/wraparound_test.go b/pkg/sfu/utils/wraparound_test.go new file mode 100644 index 000000000..828242f87 --- /dev/null +++ b/pkg/sfu/utils/wraparound_test.go @@ -0,0 +1,295 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWrapAroundUint16(t *testing.T) { + w := NewWrapAround[uint16, uint32]() + testCases := []struct { + name string + input uint16 + updated wrapAroundUpdateResult[uint32] + start uint16 + extendedStart uint32 + highest uint16 + extendedHighest uint32 + }{ + // initialize + { + name: "initialize", + input: 10, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: 9, + ExtendedVal: 10, + }, + start: 10, + extendedStart: 10, + highest: 10, + extendedHighest: 10, + }, + // an older number without wrap around should reset start point + { + name: "reset start no wrap around", + input: 8, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: true, + PreExtendedStart: 10, + PreExtendedHighest: 10, + ExtendedVal: 8, + }, + start: 8, + extendedStart: 8, + highest: 10, + extendedHighest: 10, + }, + // an older number with wrap around should reset start point + { + name: "reset start wrap around", + input: (1 << 16) - 6, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: true, + PreExtendedStart: 8, + PreExtendedHighest: 10, + ExtendedVal: (1 << 16) - 6, + }, + start: (1 << 16) - 6, + extendedStart: (1 << 16) - 6, + highest: 10, + extendedHighest: (1 << 16) + 10, + }, + // an older number with wrap around should reset start point again + { + name: "reset start again", + input: (1 << 16) - 12, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: true, + PreExtendedStart: (1 << 16) - 6, + PreExtendedHighest: (1 << 16) + 10, + ExtendedVal: (1 << 16) - 12, + }, + start: (1 << 16) - 12, + extendedStart: (1 << 16) - 12, + highest: 10, + extendedHighest: (1 << 16) + 10, + }, + // duplicate should return same as highest + { + name: "duplicate", + input: 10, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 16) + 10, + ExtendedVal: (1 << 16) + 10, + }, + start: (1 << 16) - 12, + extendedStart: (1 << 16) - 12, + highest: 10, + extendedHighest: (1 << 16) + 10, + }, + // a significant jump in order should not reset start + { + name: "big in-order jump", + input: 1 << 15, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 16) + 10, + ExtendedVal: (1 << 16) + (1 << 15), + }, + start: (1 << 16) - 12, + extendedStart: (1 << 16) - 12, + highest: 1 << 15, + extendedHighest: (1 << 16) + (1 << 15), + }, + // now out-of-order should not reset start as half the range has been seen + { + name: "out-of-order after half range", + input: (1 << 15) - 1, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 16) + (1 << 15), + ExtendedVal: (1 << 16) + (1 << 15) - 1, + }, + start: (1 << 16) - 12, + extendedStart: (1 << 16) - 12, + highest: 1 << 15, + extendedHighest: (1 << 16) + (1 << 15), + }, + // in-order, should update highest + { + name: "in-order", + input: (1 << 15) + 3, + updated: wrapAroundUpdateResult[uint32]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 16) + (1 << 15), + ExtendedVal: (1 << 16) + (1 << 15) + 3, + }, + start: (1 << 16) - 12, + extendedStart: (1 << 16) - 12, + highest: (1 << 15) + 3, + extendedHighest: (1 << 16) + (1 << 15) + 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.updated, w.Update(tc.input)) + require.Equal(t, tc.start, w.GetStart()) + require.Equal(t, tc.extendedStart, w.GetExtendedStart()) + require.Equal(t, tc.highest, w.GetHighest()) + require.Equal(t, tc.extendedHighest, w.GetExtendedHighest()) + }) + } +} + +func TestWrapAroundUint32(t *testing.T) { + w := NewWrapAround[uint32, uint64]() + testCases := []struct { + name string + input uint32 + updated wrapAroundUpdateResult[uint64] + start uint32 + extendedStart uint64 + highest uint32 + extendedHighest uint64 + }{ + // initialize + { + name: "initialize", + input: 10, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: 9, + ExtendedVal: 10, + }, + start: 10, + extendedStart: 10, + highest: 10, + extendedHighest: 10, + }, + // an older number without wrap around should reset start point + { + name: "reset start no wrap around", + input: 8, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: true, + PreExtendedStart: 10, + PreExtendedHighest: 10, + ExtendedVal: 8, + }, + start: 8, + extendedStart: 8, + highest: 10, + extendedHighest: 10, + }, + // an older number with wrap around should reset start point + { + name: "reset start wrap around", + input: (1 << 32) - 6, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: true, + PreExtendedStart: 8, + PreExtendedHighest: 10, + ExtendedVal: (1 << 32) - 6, + }, + start: (1 << 32) - 6, + extendedStart: (1 << 32) - 6, + highest: 10, + extendedHighest: (1 << 32) + 10, + }, + // an older number with wrap around should reset start point again + { + name: "reset start again", + input: (1 << 32) - 12, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: true, + PreExtendedStart: (1 << 32) - 6, + PreExtendedHighest: (1 << 32) + 10, + ExtendedVal: (1 << 32) - 12, + }, + start: (1 << 32) - 12, + extendedStart: (1 << 32) - 12, + highest: 10, + extendedHighest: (1 << 32) + 10, + }, + // duplicate should return same as highest + { + name: "duplicate", + input: 10, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 32) + 10, + ExtendedVal: (1 << 32) + 10, + }, + start: (1 << 32) - 12, + extendedStart: (1 << 32) - 12, + highest: 10, + extendedHighest: (1 << 32) + 10, + }, + // a significant jump in order should not reset start + { + name: "big in-order jump", + input: 1 << 31, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 32) + 10, + ExtendedVal: (1 << 32) + (1 << 31), + }, + start: (1 << 32) - 12, + extendedStart: (1 << 32) - 12, + highest: 1 << 31, + extendedHighest: (1 << 32) + (1 << 31), + }, + // now out-of-order should not reset start as half the range has been seen + { + name: "out-of-order after half range", + input: (1 << 31) - 1, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 32) + (1 << 31), + ExtendedVal: (1 << 32) + (1 << 31) - 1, + }, + start: (1 << 32) - 12, + extendedStart: (1 << 32) - 12, + highest: 1 << 31, + extendedHighest: (1 << 32) + (1 << 31), + }, + // in-order, should update highest + { + name: "in-order", + input: (1 << 31) + 3, + updated: wrapAroundUpdateResult[uint64]{ + IsRestart: false, + PreExtendedStart: 0, + PreExtendedHighest: (1 << 32) + (1 << 31), + ExtendedVal: (1 << 32) + (1 << 31) + 3, + }, + start: (1 << 32) - 12, + extendedStart: (1 << 32) - 12, + highest: (1 << 31) + 3, + extendedHighest: (1 << 32) + (1 << 31) + 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.updated, w.Update(tc.input)) + require.Equal(t, tc.start, w.GetStart()) + require.Equal(t, tc.extendedStart, w.GetExtendedStart()) + require.Equal(t, tc.highest, w.GetHighest()) + require.Equal(t, tc.extendedHighest, w.GetExtendedHighest()) + }) + } +} diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go index 09d46515e..1fa1fc52a 100644 --- a/pkg/sfu/videolayerselector/dependencydescriptor.go +++ b/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -2,39 +2,37 @@ package videolayerselector import ( "fmt" - "sort" "github.com/livekit/livekit-server/pkg/sfu/buffer" - dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" + dede "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor" + "github.com/livekit/livekit-server/pkg/sfu/utils" "github.com/livekit/protocol/logger" ) -type decodeTarget struct { - Target int - Layer buffer.VideoLayer -} - type DependencyDescriptor struct { *Base - // DD-TODO : fields for frame chain detect - // frameNumberWrapper Uint16Wrapper - // expectKeyFrame bool + frameNum *utils.WrapAround[uint16, uint64] + decisions *SelectorDecisionCache - decodeTargets []decodeTarget + needsDecodeTargetBitmask bool activeDecodeTargetsBitmask *uint32 - structure *dd.FrameDependencyStructure + structure *dede.FrameDependencyStructure } func NewDependencyDescriptor(logger logger.Logger) *DependencyDescriptor { return &DependencyDescriptor{ - Base: NewBase(logger), + Base: NewBase(logger), + frameNum: utils.NewWrapAround[uint16, uint64](), + decisions: NewSelectorDecisionCache(256), } } func NewDependencyDescriptorFromNull(vls VideoLayerSelector) *DependencyDescriptor { return &DependencyDescriptor{ - Base: vls.(*Null).Base, + Base: vls.(*Null).Base, + frameNum: utils.NewWrapAround[uint16, uint64](), + decisions: NewSelectorDecisionCache(256), } } @@ -43,21 +41,80 @@ func (d *DependencyDescriptor) IsOvershootOkay() bool { } func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) { - if extPkt.DependencyDescriptor == nil { - // packet don't have dependency descriptor + ddwdt := extPkt.DependencyDescriptor + if ddwdt == nil { + // packet doesn't have dependency descriptor + return + } + + dd := ddwdt.Descriptor + + // a packet is relevant as long as it has DD extension + result.IsRelevant = true + + frameNum := d.frameNum.Update(dd.FrameNumber) + extFrameNum := frameNum.ExtendedVal + + fd := dd.FrameDependencies + incomingLayer := buffer.VideoLayer{ + Spatial: int32(fd.SpatialId), + Temporal: int32(fd.TemporalId), + } + + // early return if this frame is already forwarded or dropped + sd, err := d.decisions.GetDecision(extFrameNum) + if err != nil { + // do not mark as dropped as only error is an old frame + return + } + switch sd { + case selectorDecisionForwarded: + // a packet of an alreadty forwarded frame, maintain decision + result.RTPMarker = extPkt.Packet.Header.Marker || (dd.LastPacketInFrame && d.currentLayer.Spatial == int32(fd.SpatialId)) + result.IsSelected = true + + case selectorDecisionDropped: + // a packet of an alreadty dropped frame, maintain decision return } if !d.currentLayer.IsValid() && !extPkt.KeyFrame { + d.decisions.AddDropped(extFrameNum) return } - result.IsRelevant = true + // check decodability using reference frames + isDecodable := true + for _, fdiff := range fd.FrameDiffs { + if fdiff == 0 { + continue + } - if extPkt.DependencyDescriptor.AttachedStructure != nil { + if sd, _ := d.decisions.GetDecision(extFrameNum - uint64(fdiff)); sd != selectorDecisionForwarded { + isDecodable = false + break + } + } + if !isDecodable { + // DD-TODO START + // Not decodable could happen due to packet loss or out-of-order packets, + // Need to figure out better ways to handle this. + // + // 1. Should definitely check if this frame is not part of current decode target OR discardable. + // In that case, forwarding can proceed without disruption. + // 2. Add a packet queue and try to de-jitter for some time. Safest is to packet copy to local queue on + // all down tracks. + // 3. Force a PLI and wait for a key frame. + // DD-TODO END + d.decisions.AddDropped(extFrameNum) + return + } + + // DD-TODO should not update for out-of-order RTP packets + if dd.AttachedStructure != nil { // update decode target layer and active decode targets // DD-TODO : these targets info can be shared by all the downtracks, no need calculate in every selector - d.updateDependencyStructure(extPkt.DependencyDescriptor.AttachedStructure) + d.updateDependencyStructure(dd.AttachedStructure) } // DD-TODO : we don't have a rtp queue to ensure the order of packets now, @@ -67,133 +124,144 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r // only check DTI of the active decode target. // it is not effeciency, at last we need check frame chain integrity. - activeDecodeTargets := extPkt.DependencyDescriptor.ActiveDecodeTargetsBitmask + activeDecodeTargets := dd.ActiveDecodeTargetsBitmask if activeDecodeTargets != nil { d.logger.Debugw("active decode targets", "activeDecodeTargets", *activeDecodeTargets) } - currentTarget := -1 - for _, dt := range d.decodeTargets { - // find target match with selected layer - if dt.Layer.Spatial <= d.targetLayer.Spatial && dt.Layer.Temporal <= d.targetLayer.Temporal { - if activeDecodeTargets == nil || ((*activeDecodeTargets)&(1< d.targetLayer.Spatial || dt.Layer.Temporal > d.targetLayer.Temporal { + continue + } + + if activeDecodeTargets != nil && ((*activeDecodeTargets)&(1< maxSpatial { - maxSpatial = dt.Layer.Spatial - } - if dt.Layer.Temporal > maxTemporal { - maxTemporal = dt.Layer.Temporal - } - if dt.Layer.Spatial <= targetLayer.Spatial && dt.Layer.Temporal <= targetLayer.Temporal { - activeBitMask |= 1 << dt.Target - } - } - if targetLayer.Spatial == maxSpatial && targetLayer.Temporal == maxTemporal { - // all the decode targets are selected - d.activeDecodeTargetsBitmask = nil - } else { - d.activeDecodeTargetsBitmask = &activeBitMask - } - d.logger.Debugw("setting target", "targetlayer", targetLayer, "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask) + d.needsDecodeTargetBitmask = true } -func (d *DependencyDescriptor) updateDependencyStructure(structure *dd.FrameDependencyStructure) { +func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure) { d.structure = structure - d.decodeTargets = d.decodeTargets[:0] - - for target := 0; target < structure.NumDecodeTargets; target++ { - layer := buffer.VideoLayer{Spatial: 0, Temporal: 0} - for _, t := range structure.Templates { - if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent { - if layer.Spatial < int32(t.SpatialId) { - layer.Spatial = int32(t.SpatialId) - } - if layer.Temporal < int32(t.TemporalId) { - layer.Temporal = int32(t.TemporalId) - } - } - } - d.decodeTargets = append(d.decodeTargets, decodeTarget{target, layer}) - } - - // sort decode target layer by spatial and temporal from high to low - sort.Slice(d.decodeTargets, func(i, j int) bool { - return d.decodeTargets[i].Layer.GreaterThan(d.decodeTargets[j].Layer) - }) - d.logger.Debugw(fmt.Sprintf("update decode targets: %v", d.decodeTargets)) -} - -// DD-TODO : use generic wrapper when updated to go 1.18 -type Uint16Wrapper struct { - lastValue *uint16 - lastUnwrapped int32 -} - -func (w *Uint16Wrapper) Unwrap(value uint16) int32 { - if w.lastValue == nil { - w.lastValue = &value - w.lastUnwrapped = int32(value) - return int32(*w.lastValue) - } - - diff := value - *w.lastValue - w.lastUnwrapped += int32(diff) - if diff == 0x8000 && value < *w.lastValue { - w.lastUnwrapped -= 0x10000 - } else if diff > 0x8000 { - w.lastUnwrapped -= 0x10000 - } - - *w.lastValue = value - return w.lastUnwrapped } diff --git a/pkg/sfu/videolayerselector/selectordecisioncache.go b/pkg/sfu/videolayerselector/selectordecisioncache.go new file mode 100644 index 000000000..b83bf9559 --- /dev/null +++ b/pkg/sfu/videolayerselector/selectordecisioncache.go @@ -0,0 +1,114 @@ +package videolayerselector + +import ( + "fmt" +) + +// ---------------------------------------------------------------------- + +type selectorDecision int + +const ( + selectorDecisionMissing selectorDecision = iota + selectorDecisionDropped + selectorDecisionForwarded + selectorDecisionUnknown +) + +func (s selectorDecision) String() string { + switch s { + case selectorDecisionMissing: + return "MISSING" + case selectorDecisionDropped: + return "DROPPED" + case selectorDecisionForwarded: + return "FORWARDED" + case selectorDecisionUnknown: + return "UNKNOWN" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// ---------------------------------------------------------------------- + +type SelectorDecisionCache struct { + initialized bool + base uint64 + last uint64 + masks []uint64 + numEntries uint64 +} + +func NewSelectorDecisionCache(maxNumElements uint64) *SelectorDecisionCache { + numElements := (maxNumElements*2 + 63) / 64 + return &SelectorDecisionCache{ + masks: make([]uint64, numElements), + numEntries: numElements * 32, // 2 bits per entry + } +} + +func (s *SelectorDecisionCache) AddForwarded(entity uint64) { + s.addEntity(entity, selectorDecisionForwarded) +} + +func (s *SelectorDecisionCache) AddDropped(entity uint64) { + s.addEntity(entity, selectorDecisionDropped) +} + +func (s *SelectorDecisionCache) GetDecision(entity uint64) (selectorDecision, error) { + if !s.initialized || entity > s.last || entity < s.base { + return selectorDecisionUnknown, nil + } + + offset := s.last - entity + if offset >= s.numEntries { + // asking for something too old + return selectorDecisionUnknown, fmt.Errorf("too old, oldest: %d, asking: %d", s.last-s.numEntries+1, entity) + } + + return s.getEntity(entity), nil +} + +func (s *SelectorDecisionCache) addEntity(entity uint64, sd selectorDecision) { + if !s.initialized { + s.initialized = true + s.base = entity + s.last = entity + s.setEntity(entity, sd) + return + } + + if entity <= s.base { + // before base, too old + return + } + + if entity <= s.last { + s.setEntity(entity, sd) + return + } + + for e := s.last + 1; e != entity; e++ { + s.setEntity(e, selectorDecisionMissing) + } + s.setEntity(entity, sd) + s.last = entity +} + +func (s *SelectorDecisionCache) setEntity(entity uint64, sd selectorDecision) { + index, bitpos := s.getPos(entity) + s.masks[index] &= ^(0x3 << bitpos) // clear before bitwise OR + s.masks[index] |= (uint64(sd) & 0x3) << bitpos +} + +func (s *SelectorDecisionCache) getEntity(entity uint64) selectorDecision { + index, bitpos := s.getPos(entity) + return selectorDecision((s.masks[index] >> bitpos) & 0x3) +} + +func (s *SelectorDecisionCache) getPos(entity uint64) (int, int) { + // 2 bits per entity, a uint64 mask can hold 32 entities + offset := (entity - s.base) % s.numEntries + return int(offset >> 5), int(offset&0x1F) * 2 +}