diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go
index 361162b37..8b9a7094d 100644
--- a/pkg/sfu/buffer/buffer.go
+++ b/pkg/sfu/buffer/buffer.go
@@ -42,7 +42,7 @@ type ExtPacket struct {
Payload interface{}
KeyFrame bool
RawPacket []byte
- DependencyDescriptor *dd.DependencyDescriptor
+ DependencyDescriptor *DependencyDescriptorWithDecodeTarget
}
// Buffer contains all packets
@@ -551,7 +551,7 @@ func (b *Buffer) getExtPacket(rtpPacket *rtp.Packet, arrivalTime int64) *ExtPack
if err == nil && ddVal != nil {
ep.DependencyDescriptor = ddVal
ep.VideoLayer = videoLayer
- // TODO : notify active decode target change if changed.
+ // DD-TODO : notify active decode target change if changed.
}
}
switch b.mime {
@@ -779,7 +779,7 @@ func (b *Buffer) GetAudioLevel() (float64, bool) {
return b.audioLevel.GetLevel()
}
-// TODO : now we rely on stream tracker for layer change, dependency still
+// DD-TODO : now we rely on stream tracker for layer change, dependency still
// work for that too. Do we keep it unchanged or use both methods?
func (b *Buffer) OnMaxLayerChanged(fn func(int32, int32)) {
b.maxLayerChangedCB = fn
diff --git a/pkg/sfu/buffer/dependencydescriptorparser.go b/pkg/sfu/buffer/dependencydescriptorparser.go
index 755e2cd4d..c52e86c83 100644
--- a/pkg/sfu/buffer/dependencydescriptorparser.go
+++ b/pkg/sfu/buffer/dependencydescriptorparser.go
@@ -2,6 +2,7 @@ package buffer
import (
"fmt"
+ "sort"
"github.com/pion/rtp"
@@ -12,91 +13,142 @@ import (
type DependencyDescriptorParser struct {
structure *dd.FrameDependencyStructure
- ddExt uint8
+ ddExtID uint8
logger logger.Logger
onMaxLayerChanged func(int32, int32)
- decodeTargetLayer []VideoLayer
+ decodeTargets []DependencyDescriptorDecodeTarget
}
-func NewDependencyDescriptorParser(ddExt uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser {
- logger.Infow("creating dependency descriptor parser", "ddExt", ddExt)
+func NewDependencyDescriptorParser(ddExtID uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32)) *DependencyDescriptorParser {
+ logger.Infow("creating dependency descriptor parser", "ddExtID", ddExtID)
return &DependencyDescriptorParser{
- ddExt: ddExt,
+ ddExtID: ddExtID,
logger: logger,
onMaxLayerChanged: onMaxLayerChanged,
}
}
-func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*dd.DependencyDescriptor, VideoLayer, error) {
- var videoLayer VideoLayer
- if ddBuf := pkt.GetExtension(r.ddExt); ddBuf != nil {
- var ddVal dd.DependencyDescriptor
- ext := &dd.DependencyDescriptorExtension{
- Descriptor: &ddVal,
- Structure: r.structure,
- }
- _, err := ext.Unmarshal(ddBuf)
- if err != nil {
- // r.logger.Debugw("failed to parse generic dependency descriptor", "err", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf))
- return nil, videoLayer, err
- }
-
- if ddVal.FrameDependencies != nil {
- videoLayer.Spatial, videoLayer.Temporal = int32(ddVal.FrameDependencies.SpatialId), int32(ddVal.FrameDependencies.TemporalId)
- }
- if ddVal.AttachedStructure != nil && !ddVal.FirstPacketInFrame {
- // r.logger.Debugw("ignoring non-first packet in frame with attached structure")
- return nil, videoLayer, nil
- }
-
- if ddVal.AttachedStructure != nil {
- var maxSpatial, maxTemporal int32
- r.structure = ddVal.AttachedStructure
- r.decodeTargetLayer = r.decodeTargetLayer[:0]
- for target := 0; target < r.structure.NumDecodeTargets; target++ {
- layer := VideoLayer{0, 0}
- for _, t := range r.structure.Templates {
- if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent {
- if layer.Spatial < int32(t.SpatialId) {
- layer.Spatial = int32(t.SpatialId)
- }
- if layer.Temporal < int32(t.TemporalId) {
- layer.Temporal = int32(t.TemporalId)
- }
- }
- }
- if layer.Spatial > maxSpatial {
- maxSpatial = layer.Spatial
- }
- if layer.Temporal > maxTemporal {
- maxTemporal = layer.Temporal
- }
- r.decodeTargetLayer = append(r.decodeTargetLayer, layer)
- }
- r.logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal)
- go r.onMaxLayerChanged(maxSpatial, maxTemporal)
- }
-
- if ddVal.AttachedStructure != nil && ddVal.FirstPacketInFrame {
- r.logger.Debugw(fmt.Sprintf("parsed dependency descriptor\n%s", ddVal.String()))
- }
-
- if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil {
- var maxSpatial, maxTemporal int32
- for dt, layer := range r.decodeTargetLayer {
- if *mask&(1<
maxSpatial {
+ maxSpatial = dt.Layer.Spatial
+ }
+ if dt.Layer.Temporal > maxTemporal {
+ maxTemporal = dt.Layer.Temporal
+ }
+ if dt.Layer.Spatial <= layer.Spatial && dt.Layer.Temporal <= layer.Temporal {
+ activeBitMask |= 1 << dt.Target
+ }
+ }
+ if layer.Spatial == maxSpatial && layer.Temporal == maxTemporal {
+ // all the decode targets are selected
+ return nil
+ }
+
+ return &activeBitMask
+}
+
+// ------------------------------------------------------------------------------
diff --git a/pkg/sfu/buffer/fps.go b/pkg/sfu/buffer/fps.go
index f8f192227..518b8e38c 100644
--- a/pkg/sfu/buffer/fps.go
+++ b/pkg/sfu/buffer/fps.go
@@ -7,7 +7,7 @@ import (
"github.com/pion/rtp/codecs"
)
-var minFramesForCalculation = [DefaultMaxLayerTemporal + 1]int{8, 15, 40}
+var minFramesForCalculation = [...]int{8, 15, 40}
type frameInfo struct {
seq uint16
@@ -357,7 +357,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
return false
}
- fn := ep.DependencyDescriptor.FrameNumber
+ fn := ep.DependencyDescriptor.Descriptor.FrameNumber
if f.baseFrame == nil {
f.baseFrame = &frameInfo{seq: ep.Packet.SequenceNumber, ts: ep.Packet.Timestamp, fn: fn}
f.fnReceived[0] = f.baseFrame
@@ -397,7 +397,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
fn: fn,
temporal: temporal,
spatial: spatial,
- frameDiff: ep.DependencyDescriptor.FrameDependencies.FrameDiffs,
+ frameDiff: ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs,
}
f.fnReceived[baseDiff] = fi
@@ -411,7 +411,7 @@ func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool {
if chain.Len() == 0 {
chain.PushBack(fn)
}
- for _, fdiff := range ep.DependencyDescriptor.FrameDependencies.FrameDiffs {
+ for _, fdiff := range ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs {
dependFrame := fn - uint16(fdiff)
// frame too old, ignore
if dependFrame-f.secondFrames[spatial][temporal].fn > 0x8000 {
diff --git a/pkg/sfu/buffer/fps_test.go b/pkg/sfu/buffer/fps_test.go
index 5f0ff79ce..206041acf 100644
--- a/pkg/sfu/buffer/fps_test.go
+++ b/pkg/sfu/buffer/fps_test.go
@@ -31,10 +31,12 @@ func (f *testFrameInfo) toVP8() *ExtPacket {
func (f *testFrameInfo) toDD() *ExtPacket {
return &ExtPacket{
Packet: &rtp.Packet{Header: f.header},
- DependencyDescriptor: &dependencydescriptor.DependencyDescriptor{
- FrameNumber: f.framenumber,
- FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{
- FrameDiffs: f.frameDiff,
+ DependencyDescriptor: &DependencyDescriptorWithDecodeTarget{
+ Descriptor: &dependencydescriptor.DependencyDescriptor{
+ FrameNumber: f.framenumber,
+ FrameDependencies: &dependencydescriptor.FrameDependencyTemplate{
+ FrameDiffs: f.frameDiff,
+ },
},
},
VideoLayer: VideoLayer{Spatial: int32(f.spatial), Temporal: int32(f.temporal)},
diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go b/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go
index 4d6e339cc..7493f1bfc 100644
--- a/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go
+++ b/pkg/sfu/dependencydescriptor/dependencydescriptorextension.go
@@ -9,23 +9,20 @@ import (
// DependencyDescriptorExtension is a extension payload format in
// https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension
+func formatBitmask(b *uint32) string {
+ if b == nil {
+ return "-"
+ }
+ return strconv.FormatInt(int64(*b), 2)
+}
+
+// ------------------------------------------------------------------------------
+
type DependencyDescriptorExtension struct {
Descriptor *DependencyDescriptor
Structure *FrameDependencyStructure
}
-func (d *DependencyDescriptor) MarshalSize() (int, error) {
- return d.MarshalSizeWithActiveChains(^uint32(0))
-}
-
-func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) {
- writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d)
- if err != nil {
- return 0, err
- }
- return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil
-}
-
func (d *DependencyDescriptorExtension) Marshal() ([]byte, error) {
return d.MarshalWithActiveChains(^uint32(0))
}
@@ -48,6 +45,8 @@ func (d *DependencyDescriptorExtension) Unmarshal(buf []byte) (int, error) {
return reader.Parse()
}
+// ------------------------------------------------------------------------------
+
const (
MaxSpatialIds = 4
MaxTemporalIds = 8
@@ -59,9 +58,11 @@ const (
ExtensionUrl = "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension"
)
+// ------------------------------------------------------------------------------
+
type DependencyDescriptor struct {
- FirstPacketInFrame bool // = true;
- LastPacketInFrame bool // = true;
+ FirstPacketInFrame bool
+ LastPacketInFrame bool
FrameNumber uint16
FrameDependencies *FrameDependencyTemplate
Resolution *RenderResolution
@@ -69,11 +70,16 @@ type DependencyDescriptor struct {
AttachedStructure *FrameDependencyStructure
}
-func formatBitmask(b *uint32) string {
- if b == nil {
- return "-"
+func (d *DependencyDescriptor) MarshalSize() (int, error) {
+ return d.MarshalSizeWithActiveChains(^uint32(0))
+}
+
+func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) {
+ writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d)
+ if err != nil {
+ return 0, err
}
- return strconv.FormatInt(int64(*b), 2)
+ return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil
}
func (d *DependencyDescriptor) String() string {
@@ -81,6 +87,8 @@ func (d *DependencyDescriptor) String() string {
d.FirstPacketInFrame, d.LastPacketInFrame, d.FrameNumber, *d.FrameDependencies, *d.Resolution, formatBitmask(d.ActiveDecodeTargetsBitmask), d.AttachedStructure)
}
+// ------------------------------------------------------------------------------
+
// Relationship of a frame to a Decode target.
type DecodeTargetIndication int
@@ -106,6 +114,8 @@ func (i DecodeTargetIndication) String() string {
}
}
+// ------------------------------------------------------------------------------
+
type FrameDependencyTemplate struct {
SpatialId int
TemporalId int
@@ -132,6 +142,8 @@ func (t *FrameDependencyTemplate) Clone() *FrameDependencyTemplate {
return t2
}
+// ------------------------------------------------------------------------------
+
type FrameDependencyStructure struct {
StructureId int
NumDecodeTargets int
@@ -156,7 +168,11 @@ func (f *FrameDependencyStructure) String() string {
return str
}
+// ------------------------------------------------------------------------------
+
type RenderResolution struct {
Width int
Height int
}
+
+// ------------------------------------------------------------------------------
diff --git a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go
index 99f3fd68a..68f00b863 100644
--- a/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go
+++ b/pkg/sfu/dependencydescriptor/dependencydescriptorreader.go
@@ -1,6 +1,8 @@
package dependencydescriptor
-import "errors"
+import (
+ "errors"
+)
type DependencyDescriptorReader struct {
// Output.
@@ -89,7 +91,6 @@ func (r *DependencyDescriptorReader) readMandatoryFields() error {
}
func (r *DependencyDescriptorReader) readExtendedFields() error {
-
templateDependencyStructurePresentFlag, err := r.buffer.ReadBool()
if err != nil {
return err
@@ -384,7 +385,6 @@ func (r *DependencyDescriptorReader) readFrameDtis() error {
}
func (r *DependencyDescriptorReader) readFrameFdiffs() error {
-
r.descriptor.FrameDependencies.FrameDiffs = r.descriptor.FrameDependencies.FrameDiffs[:0]
for {
nexFdiffSize, err := r.buffer.ReadBits(2)
diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go
index 256f5defc..430bb882f 100644
--- a/pkg/sfu/downtrack.go
+++ b/pkg/sfu/downtrack.go
@@ -1448,7 +1448,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
}
var extraExtensions []extensionData
- if len(meta.ddBytes) > 0 {
+ if d.dependencyDescriptorID != 0 && len(meta.ddBytes) != 0 {
extraExtensions = append(extraExtensions, extensionData{
id: uint8(d.dependencyDescriptorID),
payload: meta.ddBytes,
diff --git a/pkg/sfu/utils/wraparound.go b/pkg/sfu/utils/wraparound.go
new file mode 100644
index 000000000..e6c8bbf2c
--- /dev/null
+++ b/pkg/sfu/utils/wraparound.go
@@ -0,0 +1,121 @@
+package utils
+
+import (
+ "unsafe"
+)
+
+type number interface {
+ uint16 | uint32
+}
+
+type extendedNumber interface {
+ uint32 | uint64
+}
+
+type WrapAround[T number, ET extendedNumber] struct {
+ fullRange ET
+
+ initialized bool
+ start T
+ highest T
+ cycles int
+}
+
+func NewWrapAround[T number, ET extendedNumber]() *WrapAround[T, ET] {
+ var t T
+ return &WrapAround[T, ET]{
+ fullRange: 1 << (unsafe.Sizeof(t) * 8),
+ }
+}
+
+func (w *WrapAround[T, ET]) Seed(from *WrapAround[T, ET]) {
+ w.initialized = from.initialized
+ w.start = from.start
+ w.highest = from.highest
+ w.cycles = from.cycles
+}
+
+type wrapAroundUpdateResult[ET extendedNumber] struct {
+ IsRestart bool
+ PreExtendedStart ET // valid only if IsRestart = true
+ PreExtendedHighest ET
+ ExtendedVal ET
+}
+
+func (w *WrapAround[T, ET]) Update(val T) (result wrapAroundUpdateResult[ET]) {
+ if !w.initialized {
+ result.PreExtendedHighest = ET(val) - 1
+ result.ExtendedVal = ET(val)
+
+ w.start = val
+ w.highest = val
+ w.initialized = true
+ return
+ }
+
+ result.PreExtendedHighest = w.GetExtendedHighest()
+
+ gap := val - w.highest
+ if gap == 0 || gap > T(w.fullRange>>1) {
+ // duplicate OR out-of-order
+ result.IsRestart, result.PreExtendedStart, result.ExtendedVal = w.maybeAdjustStart(val)
+ return
+ }
+
+ // in-order
+ if val < w.highest {
+ w.cycles++
+ }
+ w.highest = val
+
+ result.ExtendedVal = ET(w.cycles)*w.fullRange + ET(val)
+ return
+}
+
+func (w *WrapAround[T, ET]) ResetHighest(val T) {
+ w.highest = val
+}
+
+func (w *WrapAround[T, ET]) GetStart() T {
+ return w.start
+}
+
+func (w *WrapAround[T, ET]) GetExtendedStart() ET {
+ return ET(w.start)
+}
+
+func (w *WrapAround[T, ET]) GetHighest() T {
+ return w.highest
+}
+
+func (w *WrapAround[T, ET]) GetExtendedHighest() ET {
+ return ET(w.cycles)*w.fullRange + ET(w.highest)
+}
+
+func (w *WrapAround[T, ET]) maybeAdjustStart(val T) (isRestart bool, preExtendedStart ET, extendedVal ET) {
+ // re-adjust start if necessary. Not more than half the range must have been seen yet, and one of the following must hold
+ // (in both cases the value precedes start by less than half the range):
+ // 1. wrap around compared to start, sequences like (10, 65530) in uint16 space
+ // 2. no wrap around, but out-of-order compared to start, sequences like (10, 9), (65530, 65528) in uint16 space
+ totalNum := w.GetExtendedHighest() - w.GetExtendedStart() + 1
+ if totalNum > (w.fullRange >> 1) {
+ extendedVal = ET(w.cycles)*w.fullRange + ET(val)
+ return
+ }
+
+ cycles := w.cycles
+ if val-w.start > T(w.fullRange>>1) {
+ // out-of-order with existing start => a new start
+ isRestart = true
+ preExtendedStart = w.GetExtendedStart()
+
+ if val > w.highest {
+ // wrap around
+ w.cycles = 1
+ cycles = 0
+ }
+ w.start = val
+ }
+ extendedVal = ET(cycles)*w.fullRange + ET(val)
+ return
+}
diff --git a/pkg/sfu/utils/wraparound_test.go b/pkg/sfu/utils/wraparound_test.go
new file mode 100644
index 000000000..828242f87
--- /dev/null
+++ b/pkg/sfu/utils/wraparound_test.go
@@ -0,0 +1,295 @@
+package utils
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestWrapAroundUint16(t *testing.T) {
+ w := NewWrapAround[uint16, uint32]()
+ testCases := []struct {
+ name string
+ input uint16
+ updated wrapAroundUpdateResult[uint32]
+ start uint16
+ extendedStart uint32
+ highest uint16
+ extendedHighest uint32
+ }{
+ // initialize
+ {
+ name: "initialize",
+ input: 10,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: 9,
+ ExtendedVal: 10,
+ },
+ start: 10,
+ extendedStart: 10,
+ highest: 10,
+ extendedHighest: 10,
+ },
+ // an older number without wrap around should reset start point
+ {
+ name: "reset start no wrap around",
+ input: 8,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: true,
+ PreExtendedStart: 10,
+ PreExtendedHighest: 10,
+ ExtendedVal: 8,
+ },
+ start: 8,
+ extendedStart: 8,
+ highest: 10,
+ extendedHighest: 10,
+ },
+ // an older number with wrap around should reset start point
+ {
+ name: "reset start wrap around",
+ input: (1 << 16) - 6,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: true,
+ PreExtendedStart: 8,
+ PreExtendedHighest: 10,
+ ExtendedVal: (1 << 16) - 6,
+ },
+ start: (1 << 16) - 6,
+ extendedStart: (1 << 16) - 6,
+ highest: 10,
+ extendedHighest: (1 << 16) + 10,
+ },
+ // an older number with wrap around should reset start point again
+ {
+ name: "reset start again",
+ input: (1 << 16) - 12,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: true,
+ PreExtendedStart: (1 << 16) - 6,
+ PreExtendedHighest: (1 << 16) + 10,
+ ExtendedVal: (1 << 16) - 12,
+ },
+ start: (1 << 16) - 12,
+ extendedStart: (1 << 16) - 12,
+ highest: 10,
+ extendedHighest: (1 << 16) + 10,
+ },
+ // duplicate should return same as highest
+ {
+ name: "duplicate",
+ input: 10,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 16) + 10,
+ ExtendedVal: (1 << 16) + 10,
+ },
+ start: (1 << 16) - 12,
+ extendedStart: (1 << 16) - 12,
+ highest: 10,
+ extendedHighest: (1 << 16) + 10,
+ },
+ // a significant jump in order should not reset start
+ {
+ name: "big in-order jump",
+ input: 1 << 15,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 16) + 10,
+ ExtendedVal: (1 << 16) + (1 << 15),
+ },
+ start: (1 << 16) - 12,
+ extendedStart: (1 << 16) - 12,
+ highest: 1 << 15,
+ extendedHighest: (1 << 16) + (1 << 15),
+ },
+ // now out-of-order should not reset start as half the range has been seen
+ {
+ name: "out-of-order after half range",
+ input: (1 << 15) - 1,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 16) + (1 << 15),
+ ExtendedVal: (1 << 16) + (1 << 15) - 1,
+ },
+ start: (1 << 16) - 12,
+ extendedStart: (1 << 16) - 12,
+ highest: 1 << 15,
+ extendedHighest: (1 << 16) + (1 << 15),
+ },
+ // in-order, should update highest
+ {
+ name: "in-order",
+ input: (1 << 15) + 3,
+ updated: wrapAroundUpdateResult[uint32]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 16) + (1 << 15),
+ ExtendedVal: (1 << 16) + (1 << 15) + 3,
+ },
+ start: (1 << 16) - 12,
+ extendedStart: (1 << 16) - 12,
+ highest: (1 << 15) + 3,
+ extendedHighest: (1 << 16) + (1 << 15) + 3,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ require.Equal(t, tc.updated, w.Update(tc.input))
+ require.Equal(t, tc.start, w.GetStart())
+ require.Equal(t, tc.extendedStart, w.GetExtendedStart())
+ require.Equal(t, tc.highest, w.GetHighest())
+ require.Equal(t, tc.extendedHighest, w.GetExtendedHighest())
+ })
+ }
+}
+
+func TestWrapAroundUint32(t *testing.T) {
+ w := NewWrapAround[uint32, uint64]()
+ testCases := []struct {
+ name string
+ input uint32
+ updated wrapAroundUpdateResult[uint64]
+ start uint32
+ extendedStart uint64
+ highest uint32
+ extendedHighest uint64
+ }{
+ // initialize
+ {
+ name: "initialize",
+ input: 10,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: 9,
+ ExtendedVal: 10,
+ },
+ start: 10,
+ extendedStart: 10,
+ highest: 10,
+ extendedHighest: 10,
+ },
+ // an older number without wrap around should reset start point
+ {
+ name: "reset start no wrap around",
+ input: 8,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: true,
+ PreExtendedStart: 10,
+ PreExtendedHighest: 10,
+ ExtendedVal: 8,
+ },
+ start: 8,
+ extendedStart: 8,
+ highest: 10,
+ extendedHighest: 10,
+ },
+ // an older number with wrap around should reset start point
+ {
+ name: "reset start wrap around",
+ input: (1 << 32) - 6,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: true,
+ PreExtendedStart: 8,
+ PreExtendedHighest: 10,
+ ExtendedVal: (1 << 32) - 6,
+ },
+ start: (1 << 32) - 6,
+ extendedStart: (1 << 32) - 6,
+ highest: 10,
+ extendedHighest: (1 << 32) + 10,
+ },
+ // an older number with wrap around should reset start point again
+ {
+ name: "reset start again",
+ input: (1 << 32) - 12,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: true,
+ PreExtendedStart: (1 << 32) - 6,
+ PreExtendedHighest: (1 << 32) + 10,
+ ExtendedVal: (1 << 32) - 12,
+ },
+ start: (1 << 32) - 12,
+ extendedStart: (1 << 32) - 12,
+ highest: 10,
+ extendedHighest: (1 << 32) + 10,
+ },
+ // duplicate should return same as highest
+ {
+ name: "duplicate",
+ input: 10,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 32) + 10,
+ ExtendedVal: (1 << 32) + 10,
+ },
+ start: (1 << 32) - 12,
+ extendedStart: (1 << 32) - 12,
+ highest: 10,
+ extendedHighest: (1 << 32) + 10,
+ },
+ // a significant jump in order should not reset start
+ {
+ name: "big in-order jump",
+ input: 1 << 31,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 32) + 10,
+ ExtendedVal: (1 << 32) + (1 << 31),
+ },
+ start: (1 << 32) - 12,
+ extendedStart: (1 << 32) - 12,
+ highest: 1 << 31,
+ extendedHighest: (1 << 32) + (1 << 31),
+ },
+ // now out-of-order should not reset start as half the range has been seen
+ {
+ name: "out-of-order after half range",
+ input: (1 << 31) - 1,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 32) + (1 << 31),
+ ExtendedVal: (1 << 32) + (1 << 31) - 1,
+ },
+ start: (1 << 32) - 12,
+ extendedStart: (1 << 32) - 12,
+ highest: 1 << 31,
+ extendedHighest: (1 << 32) + (1 << 31),
+ },
+ // in-order, should update highest
+ {
+ name: "in-order",
+ input: (1 << 31) + 3,
+ updated: wrapAroundUpdateResult[uint64]{
+ IsRestart: false,
+ PreExtendedStart: 0,
+ PreExtendedHighest: (1 << 32) + (1 << 31),
+ ExtendedVal: (1 << 32) + (1 << 31) + 3,
+ },
+ start: (1 << 32) - 12,
+ extendedStart: (1 << 32) - 12,
+ highest: (1 << 31) + 3,
+ extendedHighest: (1 << 32) + (1 << 31) + 3,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ require.Equal(t, tc.updated, w.Update(tc.input))
+ require.Equal(t, tc.start, w.GetStart())
+ require.Equal(t, tc.extendedStart, w.GetExtendedStart())
+ require.Equal(t, tc.highest, w.GetHighest())
+ require.Equal(t, tc.extendedHighest, w.GetExtendedHighest())
+ })
+ }
+}
diff --git a/pkg/sfu/videolayerselector/dependencydescriptor.go b/pkg/sfu/videolayerselector/dependencydescriptor.go
index 09d46515e..1fa1fc52a 100644
--- a/pkg/sfu/videolayerselector/dependencydescriptor.go
+++ b/pkg/sfu/videolayerselector/dependencydescriptor.go
@@ -2,39 +2,37 @@ package videolayerselector
import (
"fmt"
- "sort"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
- dd "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
+ dede "github.com/livekit/livekit-server/pkg/sfu/dependencydescriptor"
+ "github.com/livekit/livekit-server/pkg/sfu/utils"
"github.com/livekit/protocol/logger"
)
-type decodeTarget struct {
- Target int
- Layer buffer.VideoLayer
-}
-
type DependencyDescriptor struct {
*Base
- // DD-TODO : fields for frame chain detect
- // frameNumberWrapper Uint16Wrapper
- // expectKeyFrame bool
+ frameNum *utils.WrapAround[uint16, uint64]
+ decisions *SelectorDecisionCache
- decodeTargets []decodeTarget
+ needsDecodeTargetBitmask bool
activeDecodeTargetsBitmask *uint32
- structure *dd.FrameDependencyStructure
+ structure *dede.FrameDependencyStructure
}
func NewDependencyDescriptor(logger logger.Logger) *DependencyDescriptor {
return &DependencyDescriptor{
- Base: NewBase(logger),
+ Base: NewBase(logger),
+ frameNum: utils.NewWrapAround[uint16, uint64](),
+ decisions: NewSelectorDecisionCache(256),
}
}
func NewDependencyDescriptorFromNull(vls VideoLayerSelector) *DependencyDescriptor {
return &DependencyDescriptor{
- Base: vls.(*Null).Base,
+ Base: vls.(*Null).Base,
+ frameNum: utils.NewWrapAround[uint16, uint64](),
+ decisions: NewSelectorDecisionCache(256),
}
}
@@ -43,21 +41,80 @@ func (d *DependencyDescriptor) IsOvershootOkay() bool {
}
func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) {
- if extPkt.DependencyDescriptor == nil {
- // packet don't have dependency descriptor
+ ddwdt := extPkt.DependencyDescriptor
+ if ddwdt == nil {
+ // packet doesn't have dependency descriptor
+ return
+ }
+
+ dd := ddwdt.Descriptor
+
+ // a packet is relevant as long as it has DD extension
+ result.IsRelevant = true
+
+ frameNum := d.frameNum.Update(dd.FrameNumber)
+ extFrameNum := frameNum.ExtendedVal
+
+ fd := dd.FrameDependencies
+ incomingLayer := buffer.VideoLayer{
+ Spatial: int32(fd.SpatialId),
+ Temporal: int32(fd.TemporalId),
+ }
+
+ // early return if this frame is already forwarded or dropped
+ sd, err := d.decisions.GetDecision(extFrameNum)
+ if err != nil {
+ // do not mark as dropped as only error is an old frame
+ return
+ }
+ switch sd {
+ case selectorDecisionForwarded:
+ // a packet of an already forwarded frame, maintain decision
+ result.RTPMarker = extPkt.Packet.Header.Marker || (dd.LastPacketInFrame && d.currentLayer.Spatial == int32(fd.SpatialId))
+ result.IsSelected = true
+
+ case selectorDecisionDropped:
+ // a packet of an already dropped frame, maintain decision
return
}
if !d.currentLayer.IsValid() && !extPkt.KeyFrame {
+ d.decisions.AddDropped(extFrameNum)
return
}
- result.IsRelevant = true
+ // check decodability using reference frames
+ isDecodable := true
+ for _, fdiff := range fd.FrameDiffs {
+ if fdiff == 0 {
+ continue
+ }
- if extPkt.DependencyDescriptor.AttachedStructure != nil {
+ if sd, _ := d.decisions.GetDecision(extFrameNum - uint64(fdiff)); sd != selectorDecisionForwarded {
+ isDecodable = false
+ break
+ }
+ }
+ if !isDecodable {
+ // DD-TODO START
+ // Not decodable could happen due to packet loss or out-of-order packets,
+ // Need to figure out better ways to handle this.
+ //
+ // 1. Should definitely check if this frame is not part of current decode target OR discardable.
+ // In that case, forwarding can proceed without disruption.
+ // 2. Add a packet queue and try to de-jitter for some time. Safest is to packet copy to local queue on
+ // all down tracks.
+ // 3. Force a PLI and wait for a key frame.
+ // DD-TODO END
+ d.decisions.AddDropped(extFrameNum)
+ return
+ }
+
+ // DD-TODO should not update for out-of-order RTP packets
+ if dd.AttachedStructure != nil {
// update decode target layer and active decode targets
// DD-TODO : these targets info can be shared by all the downtracks, no need calculate in every selector
- d.updateDependencyStructure(extPkt.DependencyDescriptor.AttachedStructure)
+ d.updateDependencyStructure(dd.AttachedStructure)
}
// DD-TODO : we don't have a rtp queue to ensure the order of packets now,
@@ -67,133 +124,144 @@ func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (r
// only check DTI of the active decode target.
// it is not effeciency, at last we need check frame chain integrity.
- activeDecodeTargets := extPkt.DependencyDescriptor.ActiveDecodeTargetsBitmask
+ activeDecodeTargets := dd.ActiveDecodeTargetsBitmask
if activeDecodeTargets != nil {
d.logger.Debugw("active decode targets", "activeDecodeTargets", *activeDecodeTargets)
}
- currentTarget := -1
- for _, dt := range d.decodeTargets {
- // find target match with selected layer
- if dt.Layer.Spatial <= d.targetLayer.Spatial && dt.Layer.Temporal <= d.targetLayer.Temporal {
- if activeDecodeTargets == nil || ((*activeDecodeTargets)&(1< d.targetLayer.Spatial || dt.Layer.Temporal > d.targetLayer.Temporal {
+ continue
+ }
+
+ if activeDecodeTargets != nil && ((*activeDecodeTargets)&(1< maxSpatial {
- maxSpatial = dt.Layer.Spatial
- }
- if dt.Layer.Temporal > maxTemporal {
- maxTemporal = dt.Layer.Temporal
- }
- if dt.Layer.Spatial <= targetLayer.Spatial && dt.Layer.Temporal <= targetLayer.Temporal {
- activeBitMask |= 1 << dt.Target
- }
- }
- if targetLayer.Spatial == maxSpatial && targetLayer.Temporal == maxTemporal {
- // all the decode targets are selected
- d.activeDecodeTargetsBitmask = nil
- } else {
- d.activeDecodeTargetsBitmask = &activeBitMask
- }
- d.logger.Debugw("setting target", "targetlayer", targetLayer, "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask)
+ d.needsDecodeTargetBitmask = true
}
-func (d *DependencyDescriptor) updateDependencyStructure(structure *dd.FrameDependencyStructure) {
+func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure) {
d.structure = structure
- d.decodeTargets = d.decodeTargets[:0]
-
- for target := 0; target < structure.NumDecodeTargets; target++ {
- layer := buffer.VideoLayer{Spatial: 0, Temporal: 0}
- for _, t := range structure.Templates {
- if t.DecodeTargetIndications[target] != dd.DecodeTargetNotPresent {
- if layer.Spatial < int32(t.SpatialId) {
- layer.Spatial = int32(t.SpatialId)
- }
- if layer.Temporal < int32(t.TemporalId) {
- layer.Temporal = int32(t.TemporalId)
- }
- }
- }
- d.decodeTargets = append(d.decodeTargets, decodeTarget{target, layer})
- }
-
- // sort decode target layer by spatial and temporal from high to low
- sort.Slice(d.decodeTargets, func(i, j int) bool {
- return d.decodeTargets[i].Layer.GreaterThan(d.decodeTargets[j].Layer)
- })
- d.logger.Debugw(fmt.Sprintf("update decode targets: %v", d.decodeTargets))
-}
-
-// DD-TODO : use generic wrapper when updated to go 1.18
-type Uint16Wrapper struct {
- lastValue *uint16
- lastUnwrapped int32
-}
-
-func (w *Uint16Wrapper) Unwrap(value uint16) int32 {
- if w.lastValue == nil {
- w.lastValue = &value
- w.lastUnwrapped = int32(value)
- return int32(*w.lastValue)
- }
-
- diff := value - *w.lastValue
- w.lastUnwrapped += int32(diff)
- if diff == 0x8000 && value < *w.lastValue {
- w.lastUnwrapped -= 0x10000
- } else if diff > 0x8000 {
- w.lastUnwrapped -= 0x10000
- }
-
- *w.lastValue = value
- return w.lastUnwrapped
}
diff --git a/pkg/sfu/videolayerselector/selectordecisioncache.go b/pkg/sfu/videolayerselector/selectordecisioncache.go
new file mode 100644
index 000000000..b83bf9559
--- /dev/null
+++ b/pkg/sfu/videolayerselector/selectordecisioncache.go
@@ -0,0 +1,114 @@
+package videolayerselector
+
+import (
+ "fmt"
+)
+
+// ----------------------------------------------------------------------
+
+type selectorDecision int
+
+const (
+ selectorDecisionMissing selectorDecision = iota
+ selectorDecisionDropped
+ selectorDecisionForwarded
+ selectorDecisionUnknown
+)
+
+func (s selectorDecision) String() string {
+ switch s {
+ case selectorDecisionMissing:
+ return "MISSING"
+ case selectorDecisionDropped:
+ return "DROPPED"
+ case selectorDecisionForwarded:
+ return "FORWARDED"
+ case selectorDecisionUnknown:
+ return "UNKNOWN"
+ default:
+ return fmt.Sprintf("%d", int(s))
+ }
+}
+
+// ----------------------------------------------------------------------
+
+type SelectorDecisionCache struct {
+ initialized bool
+ base uint64
+ last uint64
+ masks []uint64
+ numEntries uint64
+}
+
+func NewSelectorDecisionCache(maxNumElements uint64) *SelectorDecisionCache {
+ numElements := (maxNumElements*2 + 63) / 64
+ return &SelectorDecisionCache{
+ masks: make([]uint64, numElements),
+ numEntries: numElements * 32, // 2 bits per entry
+ }
+}
+
+func (s *SelectorDecisionCache) AddForwarded(entity uint64) {
+ s.addEntity(entity, selectorDecisionForwarded)
+}
+
+func (s *SelectorDecisionCache) AddDropped(entity uint64) {
+ s.addEntity(entity, selectorDecisionDropped)
+}
+
+func (s *SelectorDecisionCache) GetDecision(entity uint64) (selectorDecision, error) {
+ if !s.initialized || entity > s.last || entity < s.base {
+ return selectorDecisionUnknown, nil
+ }
+
+ offset := s.last - entity
+ if offset >= s.numEntries {
+ // asking for something too old
+ return selectorDecisionUnknown, fmt.Errorf("too old, oldest: %d, asking: %d", s.last-s.numEntries+1, entity)
+ }
+
+ return s.getEntity(entity), nil
+}
+
+func (s *SelectorDecisionCache) addEntity(entity uint64, sd selectorDecision) {
+ if !s.initialized {
+ s.initialized = true
+ s.base = entity
+ s.last = entity
+ s.setEntity(entity, sd)
+ return
+ }
+
+ if entity <= s.base {
+ // before base, too old
+ return
+ }
+
+ if entity <= s.last {
+ s.setEntity(entity, sd)
+ return
+ }
+
+ for e := s.last + 1; e != entity; e++ {
+ s.setEntity(e, selectorDecisionMissing)
+ }
+ s.setEntity(entity, sd)
+ s.last = entity
+}
+
+func (s *SelectorDecisionCache) setEntity(entity uint64, sd selectorDecision) {
+ index, bitpos := s.getPos(entity)
+ s.masks[index] &= ^(0x3 << bitpos) // clear before bitwise OR
+ s.masks[index] |= (uint64(sd) & 0x3) << bitpos
+}
+
+func (s *SelectorDecisionCache) getEntity(entity uint64) selectorDecision {
+ index, bitpos := s.getPos(entity)
+ return selectorDecision((s.masks[index] >> bitpos) & 0x3)
+}
+
+func (s *SelectorDecisionCache) getPos(entity uint64) (int, int) {
+ // 2 bits per entity, a uint64 mask can hold 32 entities
+ offset := (entity - s.base) % s.numEntries
+ return int(offset >> 5), int(offset&0x1F) * 2
+}