diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 216ae5b81..73f644201 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -160,6 +160,7 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra receiver, track, t.PublisherID(), + t.params.TrackInfo.Source, t.params.Logger, sfu.WithPliThrottle(t.params.PLIThrottleConfig), sfu.WithLoadBalanceThreshold(20), diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 78a5e98fd..a2e7727f1 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -138,6 +138,7 @@ func NewWebRTCReceiver( receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote, pid livekit.ParticipantID, + source livekit.TrackSource, logger logger.Logger, opts ...ReceiverOpts, ) *WebRTCReceiver { @@ -155,7 +156,7 @@ func NewWebRTCReceiver( index: make(map[livekit.ParticipantID]int), free: make(map[int]struct{}), numProcs: runtime.NumCPU(), - streamTrackerManager: NewStreamTrackerManager(logger), + streamTrackerManager: NewStreamTrackerManager(logger, source), } w.streamTrackerManager.OnAvailableLayersChanged(w.downTrackLayerChange) diff --git a/pkg/sfu/streamallocator.go b/pkg/sfu/streamallocator.go index 3c0acff35..410d47523 100644 --- a/pkg/sfu/streamallocator.go +++ b/pkg/sfu/streamallocator.go @@ -1224,8 +1224,6 @@ func (t *Track) SetPriority(priority uint8) bool { switch t.source { case livekit.TrackSource_SCREEN_SHARE: priority = PriorityDefaultScreenshare - case livekit.TrackSource_SCREEN_SHARE_AUDIO: - priority = PriorityDefaultScreenshare default: priority = PriorityDefaultVideo } @@ -1248,7 +1246,7 @@ func (t *Track) DownTrack() *DownTrack { } func (t *Track) IsManaged() bool { - return (t.source != livekit.TrackSource_SCREEN_SHARE && t.source != livekit.TrackSource_SCREEN_SHARE_AUDIO) || t.isSimulcast + return t.source != livekit.TrackSource_SCREEN_SHARE || t.isSimulcast } func (t *Track) ID() livekit.TrackID { diff --git a/pkg/sfu/streamtracker.go b/pkg/sfu/streamtracker.go index 13e385dd7..4b5d48300 100644 --- a/pkg/sfu/streamtracker.go +++ b/pkg/sfu/streamtracker.go @@ -27,14 +27,22 @@ const ( StreamStatusActive StreamStatus = 1 ) +type StreamTrackerParams struct { + // number of samples needed per cycle + SamplesRequired uint32 + + // number of cycles needed to be active + CyclesRequired uint32 + + CycleDuration time.Duration + + Logger logger.Logger +} + // StreamTracker keeps track of packet flow and ensures a particular up track is consistently producing // It runs its own goroutine for detection, and fires OnStatusChanged callback type StreamTracker struct { - // number of samples needed per cycle - samplesRequired uint32 - // number of cycles needed to be active - cyclesRequired uint64 - cycleDuration time.Duration + params StreamTrackerParams onStatusChanged func(status StreamStatus) @@ -49,7 +57,7 @@ type StreamTracker struct { status StreamStatus // only access within detectWorker - cycleCount uint64 + cycleCount uint32 // only access by the same goroutine as Observe lastSN uint16 @@ -59,13 +67,11 @@ type StreamTracker struct { isStopped atomic.Bool } -func NewStreamTracker(logger logger.Logger, samplesRequired uint32, cyclesRequired uint64, cycleDuration time.Duration) *StreamTracker { +func NewStreamTracker(params StreamTrackerParams) *StreamTracker { s := &StreamTracker{ - samplesRequired: samplesRequired, - cyclesRequired: cyclesRequired, - cycleDuration: cycleDuration, - status: StreamStatusStopped, - callbacksQueue: utils.NewOpsQueue(logger), + params: params, + status: StreamStatusStopped, + callbacksQueue: utils.NewOpsQueue(params.Logger), } return s } @@ -181,7 +187,7 @@ func (s *StreamTracker) Observe(sn uint16) { } func (s *StreamTracker) detectWorker(generation uint32) { - ticker := time.NewTicker(s.cycleDuration) + ticker := time.NewTicker(s.params.CycleDuration) for { <-ticker.C @@ -198,7 +204,7 @@ func (s *StreamTracker) detectChanges() { return } - if s.countSinceLast.Load() >= s.samplesRequired { + if s.countSinceLast.Load() >= s.params.SamplesRequired { s.cycleCount += 1 } else { s.cycleCount = 0 @@ -207,7 +213,7 @@ func (s *StreamTracker) detectChanges() { if s.cycleCount == 0 { // flip to stopped s.maybeSetStopped() - } else if s.cycleCount >= s.cyclesRequired { + } else if s.cycleCount >= s.params.CyclesRequired { // flip to active s.maybeSetActive() } diff --git a/pkg/sfu/streamtracker_test.go b/pkg/sfu/streamtracker_test.go index 04beb0525..6ba5c5d83 100644 --- a/pkg/sfu/streamtracker_test.go +++ b/pkg/sfu/streamtracker_test.go @@ -12,10 +12,19 @@ import ( "github.com/livekit/protocol/logger" ) +func newStreamTracker(samplesRequired uint32, cyclesRequired uint32, cycleDuration time.Duration) *StreamTracker { + return NewStreamTracker(StreamTrackerParams{ + SamplesRequired: samplesRequired, + CyclesRequired: cyclesRequired, + CycleDuration: cycleDuration, + Logger: logger.Logger(logger.GetLogger()), + }) +} + func TestStreamTracker(t *testing.T) { t.Run("flips to active on first observe", func(t *testing.T) { callbackCalled := atomic.NewBool(false) - tracker := NewStreamTracker(logger.Logger(logger.GetLogger()), 5, 60, 500*time.Millisecond) + tracker := newStreamTracker(5, 60, 500*time.Millisecond) tracker.Start() tracker.OnStatusChanged(func(status StreamStatus) { callbackCalled.Store(true) @@ -40,7 +49,7 @@ func TestStreamTracker(t *testing.T) { }) t.Run("flips to inactive immediately", func(t *testing.T) { - tracker := NewStreamTracker(logger.Logger(logger.GetLogger()), 5, 60, 500*time.Millisecond) + tracker := newStreamTracker(5, 60, 500*time.Millisecond) tracker.Start() require.Equal(t, StreamStatusStopped, tracker.Status()) @@ -76,7 +85,7 @@ func TestStreamTracker(t *testing.T) { }) t.Run("flips back to active after iterations", func(t *testing.T) { - tracker := NewStreamTracker(logger.Logger(logger.GetLogger()), 1, 2, 500*time.Millisecond) + tracker := newStreamTracker(1, 2, 500*time.Millisecond) tracker.Start() require.Equal(t, StreamStatusStopped, tracker.Status()) @@ -103,7 +112,7 @@ func TestStreamTracker(t *testing.T) { }) t.Run("does not change to inactive when paused", func(t *testing.T) { - tracker := NewStreamTracker(logger.Logger(logger.GetLogger()), 5, 60, 500*time.Millisecond) + tracker := newStreamTracker(5, 60, 500*time.Millisecond) tracker.Start() tracker.Observe(1) testutils.WithTimeout(t, func() string { @@ -123,7 +132,7 @@ func TestStreamTracker(t *testing.T) { t.Run("flips back to active on first observe after reset", func(t *testing.T) { callbackCalled := atomic.NewUint32(0) - tracker := NewStreamTracker(logger.Logger(logger.GetLogger()), 5, 60, 500*time.Millisecond) + tracker := newStreamTracker(5, 60, 500*time.Millisecond) tracker.Start() tracker.OnStatusChanged(func(status StreamStatus) { callbackCalled.Inc() diff --git a/pkg/sfu/streamtrackermanager.go b/pkg/sfu/streamtrackermanager.go index 55df1ed2b..1bb83d661 100644 --- a/pkg/sfu/streamtrackermanager.go +++ b/pkg/sfu/streamtrackermanager.go @@ -5,11 +5,52 @@ import ( "sync" "time" + "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) +var ( + ConfigVideo = []StreamTrackerParams{ + { + SamplesRequired: 1, + CyclesRequired: 4, + CycleDuration: 500 * time.Millisecond, + }, + { + SamplesRequired: 5, + CyclesRequired: 60, + CycleDuration: 500 * time.Millisecond, + }, + { + SamplesRequired: 5, + CyclesRequired: 60, + CycleDuration: 500 * time.Millisecond, + }, + } + + // be very forgiving for screen share to account for cases like static screen where there could be only one packet per second + ConfigScreenshare = []StreamTrackerParams{ + { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + } +) + type StreamTrackerManager struct { logger logger.Logger + source livekit.TrackSource lock sync.RWMutex @@ -21,9 +62,10 @@ type StreamTrackerManager struct { onAvailableLayersChanged func(availableLayers []int32) } -func NewStreamTrackerManager(logger logger.Logger) *StreamTrackerManager { +func NewStreamTrackerManager(logger logger.Logger, source livekit.TrackSource) *StreamTrackerManager { return &StreamTrackerManager{ logger: logger, + source: source, maxExpectedLayer: DefaultMaxLayerSpatial, } } @@ -33,16 +75,22 @@ func (s *StreamTrackerManager) OnAvailableLayersChanged(f func(availableLayers [ } func (s *StreamTrackerManager) AddTracker(layer int32) { - cycleDuration := 500 * time.Millisecond - samplesRequired := uint32(5) - cyclesRequired := uint64(60) // 30s of continuous stream - if layer == 0 { - // be very forgiving for base layer to account for cases like static screen share where there could be only one packet per second - samplesRequired = 1 - cyclesRequired = 1 // 1 packet in 2 seconds - cycleDuration = 2 * time.Second + var params StreamTrackerParams + if s.source == livekit.TrackSource_SCREEN_SHARE { + if int(layer) >= len(ConfigScreenshare) { + return + } + + params = ConfigScreenshare[layer] + } else { + if int(layer) >= len(ConfigVideo) { + return + } + + params = ConfigVideo[layer] } - tracker := NewStreamTracker(s.logger, samplesRequired, cyclesRequired, cycleDuration) + params.Logger = s.logger + tracker := NewStreamTracker(params) tracker.OnStatusChanged(func(status StreamStatus) { if status == StreamStatusStopped { s.removeAvailableLayer(layer)