Files
livekit/pkg/sfu/streamtracker/streamtracker.go
Raja Subramanian 74891f30fe Protect against incorrect temporal layer. (#4327)
Seeing some tracks with temporal layer higher than array bounds.
Protect against it and log some info to understand better.
2026-02-19 23:57:58 +05:30

317 lines
6.5 KiB
Go

// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package streamtracker
import (
"fmt"
"sync"
"time"
"go.uber.org/atomic"
"github.com/livekit/livekit-server/pkg/sfu/buffer"
"github.com/livekit/protocol/logger"
)
// ------------------------------------------------------------
// StreamStatus is the status value held and reported by StreamTracker.
type StreamStatus int32

const (
	// StreamStatusStopped is the zero value of StreamStatus.
	StreamStatusStopped StreamStatus = iota
	// StreamStatusActive follows StreamStatusStopped in the enumeration.
	StreamStatusActive
)

// String returns "stopped" or "active" for the known values and
// "unknown: <n>" (n being the numeric value) for anything else.
func (s StreamStatus) String() string {
	if s == StreamStatusStopped {
		return "stopped"
	}
	if s == StreamStatusActive {
		return "active"
	}
	return fmt.Sprintf("unknown: %d", int(s))
}
// ------------------------------------------------------------
// StreamTrackerParams carries the dependencies and settings passed to
// NewStreamTracker.
type StreamTrackerParams struct {
	// StreamTrackerImpl is the implementation StreamTracker delegates to for
	// Start, Stop, Reset, per-packet Observe, periodic CheckStatus and the
	// status check interval.
	StreamTrackerImpl StreamTrackerImpl
	// BitrateReportInterval is the period of the ticker that triggers bitrate
	// reports in the worker goroutine.
	BitrateReportInterval time.Duration
	// Logger is used to warn when Observe sees a temporal layer beyond the
	// per-layer counters.
	Logger logger.Logger
}
// StreamTracker tracks the status of a stream and its per temporal layer
// bitrate. Status decisions are delegated to params.StreamTrackerImpl; this
// type handles locking, change notification and bitrate accounting.
type StreamTracker struct {
	// params is the configuration supplied to NewStreamTracker.
	params StreamTrackerParams
	// onStatusChanged, when non-nil, is called by maybeNotifyStatus with the
	// new status.
	// NOTE(review): assigned by OnStatusChanged without holding lock; presumably
	// expected to be set before packets are observed - confirm.
	onStatusChanged func(status StreamStatus)
	// onBitrateAvailable, when non-nil, is called by bitrateReport when some
	// layer's bitrate switches between zero and non-zero.
	// NOTE(review): assigned by OnBitrateAvailable without holding lock - confirm
	// it is set before the worker runs.
	onBitrateAvailable func()
	// lock is held while accessing paused, status, lastNotifiedStatus,
	// lastBitrateReport, bytesForBitrate, bitrate and isStopped.
	lock sync.RWMutex
	// paused is set by SetPaused; Observe ignores packets while it is true.
	paused bool
	// generation is incremented by Stop, resetLocked and SetPaused(true). A
	// worker exits once the value no longer matches the one it was started with.
	generation atomic.Uint32
	// status is the current stream status.
	status StreamStatus
	// lastNotifiedStatus is the status most recently recorded by
	// maybeNotifyStatus; it is compared against status to detect changes.
	lastNotifiedStatus StreamStatus
	// lastBitrateReport is the time of the last bitrate report, or the time the
	// stream last became active in Observe.
	lastBitrateReport time.Time
	// bytesForBitrate accumulates packet sizes per temporal layer since the
	// last bitrate report.
	bytesForBitrate [buffer.DefaultMaxLayerTemporal + 1]int64
	// bitrate holds the per temporal layer bitrate (bytes * 8 / elapsed seconds)
	// computed by the last bitrate report.
	bitrate [buffer.DefaultMaxLayerTemporal + 1]int64
	// isStopped is set by Stop; once set, Stop, Reset and Observe return early.
	isStopped bool
}
// NewStreamTracker returns a StreamTracker configured with params whose
// initial status is StreamStatusStopped.
func NewStreamTracker(params StreamTrackerParams) *StreamTracker {
	tracker := new(StreamTracker)
	tracker.params = params
	tracker.status = StreamStatusStopped
	return tracker
}
// OnStatusChanged registers f as the callback that maybeNotifyStatus invokes
// with the new status when the status differs from the last one notified.
// The callback field is assigned without holding s.lock.
func (s *StreamTracker) OnStatusChanged(f func(status StreamStatus)) {
	s.onStatusChanged = f
}
// OnBitrateAvailable registers f as the callback that bitrateReport invokes
// when any layer's bitrate switches between zero and non-zero.
// The callback field is assigned without holding s.lock.
func (s *StreamTracker) OnBitrateAvailable(f func()) {
	s.onBitrateAvailable = f
}
// Status returns the current stream status, read under the read lock.
func (s *StreamTracker) Status() StreamStatus {
	s.lock.RLock()
	current := s.status
	s.lock.RUnlock()

	return current
}
// setStatusLocked stores status as the current status. It does not take
// s.lock itself; callers in this file invoke it while holding the lock.
func (s *StreamTracker) setStatusLocked(status StreamStatus) {
	s.status = status
}
// maybeNotifyStatus invokes the onStatusChanged callback when the current
// status differs from the last notified one. The comparison and the update of
// lastNotifiedStatus happen under the lock; the callback runs after the lock
// has been released.
func (s *StreamTracker) maybeNotifyStatus() {
	s.lock.Lock()
	current := s.status
	changed := current != s.lastNotifiedStatus
	if changed {
		s.lastNotifiedStatus = current
	}
	s.lock.Unlock()

	if !changed || s.onStatusChanged == nil {
		return
	}
	s.onStatusChanged(current)
}
// Start starts the underlying tracker implementation while holding the lock.
func (s *StreamTracker) Start() {
	s.lock.Lock()
	defer s.lock.Unlock()

	impl := s.params.StreamTrackerImpl
	impl.Start()
}
// Stop marks the tracker as stopped and stops the underlying implementation.
// Only the first call has any effect; later calls return immediately.
func (s *StreamTracker) Stop() {
	s.lock.Lock()
	defer s.lock.Unlock()

	if !s.isStopped {
		s.isStopped = true

		// a new generation makes the running worker exit on its next tick
		s.generation.Inc()

		s.params.StreamTrackerImpl.Stop()
	}
}
// Reset returns the tracker to its initial state via resetLocked and then
// notifies listeners if the status changed. It does nothing once the tracker
// has been stopped.
func (s *StreamTracker) Reset() {
	s.lock.Lock()
	stopped := s.isStopped
	if !stopped {
		s.resetLocked()
	}
	s.lock.Unlock()

	// notification takes the lock itself, so it must run after the unlock
	if !stopped {
		s.maybeNotifyStatus()
	}
}
// resetLocked sets the status to stopped, zeroes the per-layer byte counters
// and bitrates, and resets the underlying implementation. It does not take
// s.lock; callers in this file hold it.
func (s *StreamTracker) resetLocked() {
	// a new generation makes the current worker exit on its next tick
	s.generation.Inc()

	s.setStatusLocked(StreamStatusStopped)

	clear(s.bytesForBitrate[:])
	clear(s.bitrate[:])

	s.params.StreamTrackerImpl.Reset()
}
// SetPaused records the paused flag. Pausing bumps the generation and marks
// the stream stopped; un-pausing performs a full reset via resetLocked. In
// both cases listeners are notified afterwards if the status changed.
func (s *StreamTracker) SetPaused(paused bool) {
	s.lock.Lock()
	s.paused = paused
	if paused {
		// a new generation makes the current worker exit on its next tick
		s.generation.Inc()
		s.setStatusLocked(StreamStatusStopped)
	} else {
		s.resetLocked()
	}
	s.lock.Unlock()

	s.maybeNotifyStatus()
}
// Observe feeds one packet into the tracker.
//
// Packets are ignored while the tracker is stopped or paused, or when
// payloadSize is zero. Otherwise the marker bit and timestamp are passed to the
// underlying implementation; if it reports a change to active, the status is
// set to active, the bitrate report clock is restarted and a worker goroutine
// is launched for the current generation. pktSize is added to the byte counter
// of temporalLayer when that layer is within bounds; a layer beyond the
// counters is logged as a warning and a negative layer is skipped silently.
// Listeners are notified after the lock is released if the implementation
// reported any status change.
func (s *StreamTracker) Observe(
	temporalLayer int32,
	pktSize int,
	payloadSize int,
	hasMarker bool,
	ts uint32,
	_ *buffer.ExtDependencyDescriptor,
) {
	s.lock.Lock()
	if ignore := s.isStopped || s.paused || payloadSize == 0; ignore {
		s.lock.Unlock()
		return
	}

	statusChange := s.params.StreamTrackerImpl.Observe(hasMarker, ts)
	if statusChange == StreamStatusChangeActive {
		s.setStatusLocked(StreamStatusActive)
		s.lastBitrateReport = time.Now()
		go s.worker(s.generation.Load())
	}

	layer := int(temporalLayer)
	switch {
	case layer < 0:
		// negative layers are not accounted and not logged
	case layer < len(s.bytesForBitrate):
		s.bytesForBitrate[layer] += int64(pktSize)
	default:
		s.params.Logger.Warnw(
			"invalid temporal layer", nil,
			"temporalLayer", temporalLayer,
			"pktSize", pktSize,
			"payloadSize", payloadSize,
			"hasMarker", hasMarker,
			"ts", ts,
		)
	}
	s.lock.Unlock()

	if statusChange != StreamStatusChangeNone {
		s.maybeNotifyStatus()
	}
}
// BitrateTemporalCumulative returns a fresh slice with one entry per temporal
// layer. A layer with a non-zero bitrate reports its own bitrate plus the
// bitrates of all lower layers; a layer with zero bitrate reports zero. Every
// entry above the first zero entry is zeroed as well.
func (s *StreamTracker) BitrateTemporalCumulative() []int64 {
	s.lock.RLock()
	defer s.lock.RUnlock()

	cumulative := make([]int64, len(s.bitrate))

	// running sum of the raw bitrates of the layers below the current one
	var lowerSum int64
	for layer, raw := range s.bitrate {
		if raw != 0 {
			cumulative[layer] = lowerSum + raw
		}
		lowerSum += raw
	}

	// a gap invalidates everything above it
	for layer, value := range cumulative {
		if value == 0 {
			clear(cumulative[layer+1:])
			break
		}
	}

	return cumulative
}
// worker runs until the tracker's generation differs from the generation it
// was started with. It drives two periodic tasks: a status check at the
// implementation's check interval and a bitrate report every
// BitrateReportInterval. The generation is re-checked on every tick before the
// task runs. Both intervals must be positive, since time.NewTicker panics
// otherwise.
func (s *StreamTracker) worker(generation uint32) {
	statusTicker := time.NewTicker(s.params.StreamTrackerImpl.GetCheckInterval())
	defer statusTicker.Stop()

	bitrateTicker := time.NewTicker(s.params.BitrateReportInterval)
	defer bitrateTicker.Stop()

	for {
		var onTick func()
		select {
		case <-statusTicker.C:
			onTick = s.updateStatus
		case <-bitrateTicker.C:
			onTick = s.bitrateReport
		}

		if s.generation.Load() != generation {
			return
		}
		onTick()
	}
}
// updateStatus asks the underlying implementation for a status change, applies
// it under the lock, and then notifies listeners if the status changed. Any
// result other than stopped or active leaves the status untouched.
func (s *StreamTracker) updateStatus() {
	s.lock.Lock()
	change := s.params.StreamTrackerImpl.CheckStatus()
	if change == StreamStatusChangeStopped {
		s.setStatusLocked(StreamStatusStopped)
	} else if change == StreamStatusChangeActive {
		s.setStatusLocked(StreamStatusActive)
	}
	s.lock.Unlock()

	s.maybeNotifyStatus()
}
// bitrateReport converts the bytes accumulated per temporal layer since the
// last report into a bitrate (bytes * 8 / elapsed seconds), stores it, and
// zeroes the byte counters. If any layer's bitrate switched between zero and
// non-zero, the onBitrateAvailable callback is invoked after the lock is
// released.
//
// When no time has elapsed since the last report, the report is skipped and the
// byte counters are kept for the next tick: dividing by zero would produce
// Inf/NaN, and converting that to int64 gives an implementation-dependent
// value.
func (s *StreamTracker) bitrateReport() {
	// run this even if paused so the bitrate drains to zero when no packets are coming in
	s.lock.Lock()
	now := time.Now()
	elapsed := now.Sub(s.lastBitrateReport).Seconds()
	if elapsed <= 0 {
		s.lock.Unlock()
		return
	}
	s.lastBitrateReport = now

	bitrateAvailabilityChanged := false
	for i := range s.bytesForBitrate {
		bitrate := int64(float64(s.bytesForBitrate[i]*8) / elapsed)
		if (s.bitrate[i] == 0 && bitrate > 0) || (s.bitrate[i] > 0 && bitrate == 0) {
			bitrateAvailabilityChanged = true
		}
		s.bitrate[i] = bitrate
		s.bytesForBitrate[i] = 0
	}
	s.lock.Unlock()

	if bitrateAvailabilityChanged && s.onBitrateAvailable != nil {
		s.onBitrateAvailable()
	}
}