mirror of
https://github.com/livekit/livekit.git
synced 2026-03-30 15:35:41 +00:00
Replacing hand rolled ion-sfu atomic with uber/atomic (#465)
* Replacing hand rolled ion-sfu atomic with uber/atomic * Remove another hand rolled atomic
This commit is contained in:
2
go.mod
2
go.mod
@@ -40,7 +40,7 @@ require (
|
||||
github.com/ua-parser/uap-go v0.0.0-20211112212520-00c877edfe0f
|
||||
github.com/urfave/cli/v2 v2.3.0
|
||||
github.com/urfave/negroni v1.0.0
|
||||
go.uber.org/atomic v1.7.0
|
||||
go.uber.org/atomic v1.9.0
|
||||
go.uber.org/zap v1.19.1
|
||||
google.golang.org/protobuf v1.27.1
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
|
||||
3
go.sum
3
go.sum
@@ -275,8 +275,9 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
|
||||
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
|
||||
go.uber.org/goleak v1.1.11-0.20210813005559-691160354723 h1:sHOAIxRGBp443oHZIPB+HsUGaksVCXVQENPxwTfQdH4=
|
||||
go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package sfu
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (a *atomicBool) set(value bool) (swapped bool) {
|
||||
if value {
|
||||
return atomic.SwapInt32((*int32)(a), 1) == 0
|
||||
}
|
||||
return atomic.SwapInt32((*int32)(a), 0) == 1
|
||||
}
|
||||
|
||||
func (a *atomicBool) get() bool {
|
||||
return atomic.LoadInt32((*int32)(a)) != 0
|
||||
}
|
||||
|
||||
type atomicUint8 uint32
|
||||
|
||||
func (a *atomicUint8) set(value uint8) {
|
||||
atomic.StoreUint32((*uint32)(a), uint32(value))
|
||||
}
|
||||
|
||||
func (a *atomicUint8) get() uint8 {
|
||||
return uint8(atomic.LoadUint32((*uint32)(a)))
|
||||
}
|
||||
|
||||
type atomicUint32 uint32
|
||||
|
||||
func (a *atomicUint32) set(value uint32) {
|
||||
atomic.StoreUint32((*uint32)(a), value)
|
||||
}
|
||||
|
||||
func (a *atomicUint32) add(value uint32) {
|
||||
atomic.AddUint32((*uint32)(a), value)
|
||||
}
|
||||
|
||||
func (a *atomicUint32) get() uint32 {
|
||||
return atomic.LoadUint32((*uint32)(a))
|
||||
}
|
||||
|
||||
type atomicInt64 int64
|
||||
|
||||
func (a *atomicInt64) set(value int64) {
|
||||
atomic.StoreInt64((*int64)(a), value)
|
||||
}
|
||||
|
||||
func (a *atomicInt64) get() int64 {
|
||||
return atomic.LoadInt64((*int64)(a))
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gammazero/deque"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/sdp/v3"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -54,7 +54,7 @@ type Buffer struct {
|
||||
twccExt uint8
|
||||
audioExt uint8
|
||||
bound bool
|
||||
closed atomicBool
|
||||
closed atomic.Bool
|
||||
mime string
|
||||
|
||||
// supported feedbacks
|
||||
@@ -197,7 +197,7 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
if b.closed.get() {
|
||||
if b.closed.Load() {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) {
|
||||
|
||||
func (b *Buffer) Read(buff []byte) (n int, err error) {
|
||||
for {
|
||||
if b.closed.get() {
|
||||
if b.closed.Load() {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func (b *Buffer) Read(buff []byte) (n int, err error) {
|
||||
|
||||
func (b *Buffer) ReadExtended() (*ExtPacket, error) {
|
||||
for {
|
||||
if b.closed.get() {
|
||||
if b.closed.Load() {
|
||||
return nil, io.EOF
|
||||
}
|
||||
b.Lock()
|
||||
@@ -267,7 +267,7 @@ func (b *Buffer) Close() error {
|
||||
if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeAudio {
|
||||
b.audioPool.Put(b.bucket.src)
|
||||
}
|
||||
b.closed.set(true)
|
||||
b.closed.Store(true)
|
||||
b.onClose()
|
||||
b.callbacksQueue.Stop()
|
||||
})
|
||||
@@ -680,7 +680,7 @@ func (b *Buffer) getRTCP() []rtcp.Packet {
|
||||
func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
if b.closed.get() {
|
||||
if b.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return b.bucket.GetPacket(buff, sn)
|
||||
@@ -744,11 +744,10 @@ func (b *Buffer) GetClockRate() uint32 {
|
||||
|
||||
// GetSenderReportData returns the rtp, ntp and nanos of the last sender report
|
||||
func (b *Buffer) GetSenderReportData() (rtpTime uint32, ntpTime uint64, lastReceivedTimeInNanosSinceEpoch int64) {
|
||||
rtpTime = atomic.LoadUint32(&b.lastSRRTPTime)
|
||||
ntpTime = atomic.LoadUint64(&b.lastSRNTPTime)
|
||||
lastReceivedTimeInNanosSinceEpoch = atomic.LoadInt64(&b.lastSRRecv)
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
|
||||
return rtpTime, ntpTime, lastReceivedTimeInNanosSinceEpoch
|
||||
return b.lastSRRTPTime, b.lastSRNTPTime, b.lastSRRecv
|
||||
}
|
||||
|
||||
func (b *Buffer) GetStats() *StreamStatsWithLayers {
|
||||
|
||||
@@ -3,7 +3,6 @@ package buffer
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/livekit/protocol/logger"
|
||||
)
|
||||
@@ -14,20 +13,6 @@ var (
|
||||
errInvalidPacket = errors.New("invalid packet")
|
||||
)
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (a *atomicBool) set(value bool) {
|
||||
var i int32
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
atomic.StoreInt32((*int32)(a), i)
|
||||
}
|
||||
|
||||
func (a *atomicBool) get() bool {
|
||||
return atomic.LoadInt32((*int32)(a)) != 0
|
||||
}
|
||||
|
||||
// VP8 is a helper to get temporal data from VP8 packet header
|
||||
/*
|
||||
VP8 Payload Descriptor
|
||||
|
||||
@@ -2,12 +2,13 @@ package buffer
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type RTCPReader struct {
|
||||
ssrc uint32
|
||||
closed atomicBool
|
||||
closed atomic.Bool
|
||||
onPacket atomic.Value // func([]byte)
|
||||
onClose func()
|
||||
}
|
||||
@@ -17,7 +18,7 @@ func NewRTCPReader(ssrc uint32) *RTCPReader {
|
||||
}
|
||||
|
||||
func (r *RTCPReader) Write(p []byte) (n int, err error) {
|
||||
if r.closed.get() {
|
||||
if r.closed.Load() {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
@@ -32,7 +33,7 @@ func (r *RTCPReader) OnClose(fn func()) {
|
||||
}
|
||||
|
||||
func (r *RTCPReader) Close() error {
|
||||
r.closed.set(true)
|
||||
r.closed.Store(true)
|
||||
r.onClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/pion/sdp/v3"
|
||||
"github.com/pion/transport/packetio"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/sfu/buffer"
|
||||
"github.com/livekit/livekit-server/pkg/sfu/connectionquality"
|
||||
@@ -73,7 +74,7 @@ type DownTrack struct {
|
||||
logger logger.Logger
|
||||
id livekit.TrackID
|
||||
peerID livekit.ParticipantID
|
||||
bound atomicBool
|
||||
bound atomic.Bool
|
||||
kind webrtc.RTPCodecType
|
||||
mime string
|
||||
ssrc uint32
|
||||
@@ -103,11 +104,11 @@ type DownTrack struct {
|
||||
connectionStats *connectionquality.ConnectionStats
|
||||
|
||||
// Debug info
|
||||
lastPli atomicInt64
|
||||
lastRTP atomicInt64
|
||||
pktsDropped atomicUint32
|
||||
lastPli atomic.Time
|
||||
lastRTP atomic.Time
|
||||
pktsDropped atomic.Uint32
|
||||
|
||||
isNACKThrottled atomicBool
|
||||
isNACKThrottled atomic.Bool
|
||||
|
||||
// RTCP callbacks
|
||||
onREMB func(dt *DownTrack, remb *rtcp.ReceiverEstimatedMaximumBitrate)
|
||||
@@ -196,7 +197,7 @@ func NewDownTrack(
|
||||
// This asserts that the code requested is supported by the remote peer.
|
||||
// If so it sets up all the state (SSRC and PayloadType) to have a call
|
||||
func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
|
||||
if d.bound.get() {
|
||||
if d.bound.Load() {
|
||||
return webrtc.RTPCodecParameters{}, ErrTrackAlreadyBind
|
||||
}
|
||||
parameters := webrtc.RTPCodecParameters{RTPCodecCapability: d.codec}
|
||||
@@ -216,7 +217,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
|
||||
if d.onBind != nil {
|
||||
d.onBind()
|
||||
}
|
||||
d.bound.set(true)
|
||||
d.bound.Store(true)
|
||||
go d.requestFirstKeyframe()
|
||||
return codec, nil
|
||||
}
|
||||
@@ -226,7 +227,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters,
|
||||
// Unbind implements the teardown logic when the track is no longer needed. This happens
|
||||
// because a track has been stopped.
|
||||
func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error {
|
||||
d.bound.set(false)
|
||||
d.bound.Store(false)
|
||||
d.receiver.DeleteDownTrack(d.peerID)
|
||||
return nil
|
||||
}
|
||||
@@ -299,20 +300,20 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if !d.bound.get() {
|
||||
if !d.bound.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.lastRTP.set(time.Now().UnixNano())
|
||||
d.lastRTP.Store(time.Now())
|
||||
|
||||
tp, err := d.forwarder.GetTranslationParams(extPkt, layer)
|
||||
if tp.shouldSendPLI {
|
||||
d.lastPli.set(time.Now().UnixNano())
|
||||
d.lastPli.Store(time.Now())
|
||||
d.receiver.SendPLI(layer)
|
||||
}
|
||||
if tp.shouldDrop {
|
||||
if tp.isDroppingRelevant {
|
||||
d.pktsDropped.add(1)
|
||||
d.pktsDropped.Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -328,7 +329,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
|
||||
}
|
||||
payload, err = d.translateVP8PacketTo(extPkt.Packet, &incomingVP8, tp.vp8.header, outbuf)
|
||||
if err != nil {
|
||||
d.pktsDropped.add(1)
|
||||
d.pktsDropped.Inc()
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -342,7 +343,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
|
||||
|
||||
hdr, err := d.getTranslatedRTPHeader(extPkt, tp.rtp)
|
||||
if err != nil {
|
||||
d.pktsDropped.add(1)
|
||||
d.pktsDropped.Inc()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -359,11 +360,11 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error {
|
||||
|
||||
d.updatePrimaryStats(pktSize, hdr.Marker)
|
||||
if extPkt.KeyFrame {
|
||||
d.isNACKThrottled.set(false)
|
||||
d.isNACKThrottled.Store(false)
|
||||
}
|
||||
} else {
|
||||
d.logger.Errorw("writing rtp packet err", err)
|
||||
d.pktsDropped.add(1)
|
||||
d.pktsDropped.Inc()
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -706,7 +707,7 @@ func (d *DownTrack) Resync() {
|
||||
}
|
||||
|
||||
func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk {
|
||||
if !d.bound.get() {
|
||||
if !d.bound.Load() {
|
||||
return nil
|
||||
}
|
||||
return []rtcp.SourceDescriptionChunk{
|
||||
@@ -727,7 +728,7 @@ func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChun
|
||||
}
|
||||
|
||||
func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport {
|
||||
if !d.bound.get() {
|
||||
if !d.bound.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -899,9 +900,9 @@ func (d *DownTrack) handleRTCP(bytes []byte) {
|
||||
if pliOnce {
|
||||
targetLayers := d.forwarder.TargetLayers()
|
||||
if targetLayers != InvalidLayers {
|
||||
d.lastPli.set(time.Now().UnixNano())
|
||||
d.lastPli.Store(time.Now())
|
||||
d.receiver.SendPLI(targetLayers.spatial)
|
||||
d.isNACKThrottled.set(true)
|
||||
d.isNACKThrottled.Store(true)
|
||||
pliOnce = false
|
||||
}
|
||||
}
|
||||
@@ -999,7 +1000,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) {
|
||||
return
|
||||
}
|
||||
|
||||
if FlagStopRTXOnPLI && d.isNACKThrottled.get() {
|
||||
if FlagStopRTXOnPLI && d.isNACKThrottled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1182,9 +1183,9 @@ func (d *DownTrack) DebugInfo() map[string]interface{} {
|
||||
"LastTS": rtpMungerParams.lastTS,
|
||||
"TSOffset": rtpMungerParams.tsOffset,
|
||||
"LastMarker": rtpMungerParams.lastMarker,
|
||||
"LastRTP": d.lastRTP.get(),
|
||||
"LastPli": d.lastPli.get(),
|
||||
"PacketsDropped": d.pktsDropped.get(),
|
||||
"LastRTP": d.lastRTP.Load(),
|
||||
"LastPli": d.lastPli.Load(),
|
||||
"PacketsDropped": d.pktsDropped.Load(),
|
||||
}
|
||||
|
||||
senderReport := d.CreateSenderReport()
|
||||
@@ -1200,7 +1201,7 @@ func (d *DownTrack) DebugInfo() map[string]interface{} {
|
||||
"StreamID": d.streamID,
|
||||
"SSRC": d.ssrc,
|
||||
"MimeType": d.codec.MimeType,
|
||||
"Bound": d.bound.get(),
|
||||
"Bound": d.bound.Load(),
|
||||
"Muted": d.forwarder.IsMuted(),
|
||||
"CurrentSpatialLayer": d.forwarder.CurrentLayers().spatial,
|
||||
"Stats": stats,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/livekit/protocol/logger"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/sfu/buffer"
|
||||
)
|
||||
@@ -188,7 +189,7 @@ type Forwarder struct {
|
||||
rtpMunger *RTPMunger
|
||||
vp8Munger *VP8Munger
|
||||
|
||||
receivedFirstKeyFrame atomicBool
|
||||
receivedFirstKeyFrame atomic.Bool
|
||||
}
|
||||
|
||||
func NewForwarder(codec webrtc.RTPCodecCapability, kind webrtc.RTPCodecType, logger logger.Logger) *Forwarder {
|
||||
@@ -1193,7 +1194,7 @@ func (f *Forwarder) GetTranslationParams(extPkt *buffer.ExtPacket, layer int32)
|
||||
}
|
||||
|
||||
func (f *Forwarder) ReceivedFirstKeyFrame() bool {
|
||||
return f.receivedFirstKeyFrame.get()
|
||||
return f.receivedFirstKeyFrame.Load()
|
||||
}
|
||||
|
||||
// should be called with lock held
|
||||
@@ -1255,7 +1256,7 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in
|
||||
if f.currentLayers.spatial == f.maxLayers.spatial {
|
||||
tp.isSwitchingToMaxLayer = true
|
||||
}
|
||||
f.receivedFirstKeyFrame.set(true)
|
||||
f.receivedFirstKeyFrame.Store(true)
|
||||
} else {
|
||||
tp.shouldSendPLI = true
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/config"
|
||||
"github.com/livekit/livekit-server/pkg/sfu/buffer"
|
||||
@@ -58,7 +58,7 @@ type WebRTCReceiver struct {
|
||||
isSimulcast bool
|
||||
onCloseHandler func()
|
||||
closeOnce sync.Once
|
||||
closed atomicBool
|
||||
closed atomic.Bool
|
||||
useTrackers bool
|
||||
|
||||
rtcpCh chan []rtcp.Packet
|
||||
@@ -239,7 +239,7 @@ func (w *WebRTCReceiver) Kind() webrtc.RTPCodecType {
|
||||
}
|
||||
|
||||
func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buffer) {
|
||||
if w.closed.get() {
|
||||
if w.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (w *WebRTCReceiver) SetUpTrackPaused(paused bool) {
|
||||
}
|
||||
|
||||
func (w *WebRTCReceiver) AddDownTrack(track TrackSender) {
|
||||
if w.closed.get() {
|
||||
if w.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -351,7 +351,7 @@ func (w *WebRTCReceiver) OnCloseHandler(fn func()) {
|
||||
|
||||
// DeleteDownTrack removes a DownTrack from a Receiver
|
||||
func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) {
|
||||
if w.closed.get() {
|
||||
if w.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) {
|
||||
}
|
||||
|
||||
func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) {
|
||||
if w.closed.get() {
|
||||
if w.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -446,7 +446,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
|
||||
|
||||
defer func() {
|
||||
w.closeOnce.Do(func() {
|
||||
w.closed.set(true)
|
||||
w.closed.Store(true)
|
||||
w.closeTracks()
|
||||
})
|
||||
|
||||
@@ -479,7 +479,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
|
||||
}
|
||||
} else {
|
||||
// parallel - enables much more efficient multi-core utilization
|
||||
start := uint64(0)
|
||||
start := atomic.NewUint64(0)
|
||||
end := uint64(len(downTracks))
|
||||
|
||||
// 100µs is enough to amortize the overhead and provide sufficient load balancing.
|
||||
@@ -492,7 +492,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
n := atomic.AddUint64(&start, step)
|
||||
n := start.Add(step)
|
||||
if n >= end+step {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ package sfu
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/protocol/utils"
|
||||
)
|
||||
|
||||
@@ -37,9 +38,9 @@ type StreamTracker struct {
|
||||
|
||||
onStatusChanged func(status StreamStatus)
|
||||
|
||||
paused atomicBool
|
||||
countSinceLast uint32 // number of packets received since last check
|
||||
generation atomicUint32
|
||||
paused atomic.Bool
|
||||
countSinceLast atomic.Uint32 // number of packets received since last check
|
||||
generation atomic.Uint32
|
||||
|
||||
initMu sync.Mutex
|
||||
initialized bool
|
||||
@@ -102,7 +103,7 @@ func (s *StreamTracker) maybeSetStopped() {
|
||||
func (s *StreamTracker) init() {
|
||||
s.maybeSetActive()
|
||||
|
||||
go s.detectWorker(s.generation.get())
|
||||
go s.detectWorker(s.generation.Load())
|
||||
}
|
||||
|
||||
func (s *StreamTracker) Start() {
|
||||
@@ -114,7 +115,7 @@ func (s *StreamTracker) Stop() {
|
||||
}
|
||||
|
||||
// bump generation to trigger exit of worker
|
||||
s.generation.add(1)
|
||||
s.generation.Inc()
|
||||
}
|
||||
|
||||
func (s *StreamTracker) Reset() {
|
||||
@@ -123,9 +124,9 @@ func (s *StreamTracker) Reset() {
|
||||
}
|
||||
|
||||
// bump generation to trigger exit of current worker
|
||||
s.generation.add(1)
|
||||
s.generation.Inc()
|
||||
|
||||
atomic.StoreUint32(&s.countSinceLast, 0)
|
||||
s.countSinceLast.Store(0)
|
||||
s.cycleCount = 0
|
||||
|
||||
s.initMu.Lock()
|
||||
@@ -138,12 +139,12 @@ func (s *StreamTracker) Reset() {
|
||||
}
|
||||
|
||||
func (s *StreamTracker) SetPaused(paused bool) {
|
||||
s.paused.set(paused)
|
||||
s.paused.Store(paused)
|
||||
}
|
||||
|
||||
// Observe a packet that's received
|
||||
func (s *StreamTracker) Observe(sn uint16) {
|
||||
if s.paused.get() {
|
||||
if s.paused.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -154,7 +155,7 @@ func (s *StreamTracker) Observe(sn uint16) {
|
||||
s.initMu.Unlock()
|
||||
|
||||
s.lastSN = sn
|
||||
atomic.AddUint32(&s.countSinceLast, 1)
|
||||
s.countSinceLast.Inc()
|
||||
|
||||
// declare stream active and start the detection worker
|
||||
go s.init()
|
||||
@@ -168,7 +169,7 @@ func (s *StreamTracker) Observe(sn uint16) {
|
||||
return
|
||||
}
|
||||
s.lastSN = sn
|
||||
atomic.AddUint32(&s.countSinceLast, 1)
|
||||
s.countSinceLast.Inc()
|
||||
}
|
||||
|
||||
func (s *StreamTracker) detectWorker(generation uint32) {
|
||||
@@ -176,7 +177,7 @@ func (s *StreamTracker) detectWorker(generation uint32) {
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
if generation != s.generation.get() {
|
||||
if generation != s.generation.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -185,11 +186,11 @@ func (s *StreamTracker) detectWorker(generation uint32) {
|
||||
}
|
||||
|
||||
func (s *StreamTracker) detectChanges() {
|
||||
if s.paused.get() {
|
||||
if s.paused.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&s.countSinceLast) >= s.samplesRequired {
|
||||
if s.countSinceLast.Load() >= s.samplesRequired {
|
||||
s.cycleCount += 1
|
||||
} else {
|
||||
s.cycleCount = 0
|
||||
@@ -203,5 +204,5 @@ func (s *StreamTracker) detectChanges() {
|
||||
s.maybeSetActive()
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&s.countSinceLast, 0)
|
||||
s.countSinceLast.Store(0)
|
||||
}
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/livekit/livekit-server/pkg/testutils"
|
||||
)
|
||||
|
||||
func TestStreamTracker(t *testing.T) {
|
||||
t.Run("flips to active on first observe", func(t *testing.T) {
|
||||
callbackCalled := atomicBool(0)
|
||||
callbackCalled := atomic.NewBool(false)
|
||||
tracker := NewStreamTracker(5, 60, 500*time.Millisecond)
|
||||
tracker.OnStatusChanged(func(status StreamStatus) {
|
||||
callbackCalled.set(true)
|
||||
callbackCalled.Store(true)
|
||||
})
|
||||
require.Equal(t, StreamStatusStopped, tracker.Status())
|
||||
|
||||
@@ -23,7 +24,7 @@ func TestStreamTracker(t *testing.T) {
|
||||
tracker.Observe(1)
|
||||
|
||||
testutils.WithTimeout(t, func() string {
|
||||
if callbackCalled.get() {
|
||||
if callbackCalled.Load() {
|
||||
return ""
|
||||
} else {
|
||||
return "first packet didn't activate stream"
|
||||
@@ -31,7 +32,7 @@ func TestStreamTracker(t *testing.T) {
|
||||
})
|
||||
|
||||
require.Equal(t, StreamStatusActive, tracker.Status())
|
||||
require.True(t, callbackCalled.get())
|
||||
require.True(t, callbackCalled.Load())
|
||||
})
|
||||
|
||||
t.Run("flips to inactive immediately", func(t *testing.T) {
|
||||
@@ -47,16 +48,16 @@ func TestStreamTracker(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
callbackCalled := atomicBool(0)
|
||||
callbackCalled := atomic.NewBool(false)
|
||||
tracker.OnStatusChanged(func(status StreamStatus) {
|
||||
callbackCalled.set(true)
|
||||
callbackCalled.Store(true)
|
||||
})
|
||||
require.Equal(t, StreamStatusActive, tracker.Status())
|
||||
|
||||
// run a single iteration
|
||||
tracker.detectChanges()
|
||||
require.Equal(t, StreamStatusStopped, tracker.Status())
|
||||
require.True(t, callbackCalled.get())
|
||||
require.True(t, callbackCalled.Load())
|
||||
})
|
||||
|
||||
t.Run("flips back to active after iterations", func(t *testing.T) {
|
||||
@@ -100,10 +101,10 @@ func TestStreamTracker(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("flips back to active on first observe after reset", func(t *testing.T) {
|
||||
callbackCalled := atomicUint32(0)
|
||||
callbackCalled := atomic.NewUint32(0)
|
||||
tracker := NewStreamTracker(5, 60, 500*time.Millisecond)
|
||||
tracker.OnStatusChanged(func(status StreamStatus) {
|
||||
callbackCalled.add(1)
|
||||
callbackCalled.Inc()
|
||||
})
|
||||
require.Equal(t, StreamStatusStopped, tracker.Status())
|
||||
|
||||
@@ -111,15 +112,15 @@ func TestStreamTracker(t *testing.T) {
|
||||
tracker.Observe(1)
|
||||
|
||||
testutils.WithTimeout(t, func() string {
|
||||
if callbackCalled.get() == 1 {
|
||||
if callbackCalled.Load() == 1 {
|
||||
return ""
|
||||
} else {
|
||||
return fmt.Sprintf("expected onStatusChanged to be called once, actual: %d", callbackCalled.get())
|
||||
return fmt.Sprintf("expected onStatusChanged to be called once, actual: %d", callbackCalled.Load())
|
||||
}
|
||||
})
|
||||
|
||||
require.Equal(t, StreamStatusActive, tracker.Status())
|
||||
require.Equal(t, uint32(1), callbackCalled.get())
|
||||
require.Equal(t, uint32(1), callbackCalled.Load())
|
||||
|
||||
// observe a few more
|
||||
tracker.Observe(2)
|
||||
@@ -139,14 +140,14 @@ func TestStreamTracker(t *testing.T) {
|
||||
tracker.Observe(1)
|
||||
|
||||
testutils.WithTimeout(t, func() string {
|
||||
if callbackCalled.get() == 2 {
|
||||
if callbackCalled.Load() == 2 {
|
||||
return ""
|
||||
} else {
|
||||
return fmt.Sprintf("expected onStatusChanged to be called twice, actual %d", callbackCalled.get())
|
||||
return fmt.Sprintf("expected onStatusChanged to be called twice, actual %d", callbackCalled.Load())
|
||||
}
|
||||
})
|
||||
|
||||
require.Equal(t, StreamStatusActive, tracker.Status())
|
||||
require.Equal(t, uint32(2), callbackCalled.get())
|
||||
require.Equal(t, uint32(2), callbackCalled.Load())
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user