From 5a4181b581804ffa68dec8b5694352ccf7b003e4 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 25 Feb 2022 11:57:09 +0530 Subject: [PATCH] Replacing hand rolled ion-sfu atomic with uber/atomic (#465) * Replacing hand rolled ion-sfu atomic with uber/atomic * Remove another hand rolled atomic --- go.mod | 2 +- go.sum | 3 ++- pkg/sfu/atomic.go | 50 ---------------------------------- pkg/sfu/buffer/buffer.go | 21 +++++++-------- pkg/sfu/buffer/helpers.go | 15 ----------- pkg/sfu/buffer/rtcpreader.go | 9 ++++--- pkg/sfu/downtrack.go | 51 ++++++++++++++++++----------------- pkg/sfu/forwarder.go | 7 ++--- pkg/sfu/receiver.go | 18 ++++++------- pkg/sfu/streamtracker.go | 33 ++++++++++++----------- pkg/sfu/streamtracker_test.go | 31 ++++++++++----------- 11 files changed, 90 insertions(+), 150 deletions(-) delete mode 100644 pkg/sfu/atomic.go diff --git a/go.mod b/go.mod index 41f61aa99..406e93972 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/ua-parser/uap-go v0.0.0-20211112212520-00c877edfe0f github.com/urfave/cli/v2 v2.3.0 github.com/urfave/negroni v1.0.0 - go.uber.org/atomic v1.7.0 + go.uber.org/atomic v1.9.0 go.uber.org/zap v1.19.1 google.golang.org/protobuf v1.27.1 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b diff --git a/go.sum b/go.sum index ad098cfb3..2bba105bf 100644 --- a/go.sum +++ b/go.sum @@ -275,8 +275,9 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723 h1:sHOAIxRGBp443oHZIPB+HsUGaksVCXVQENPxwTfQdH4= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= diff --git a/pkg/sfu/atomic.go b/pkg/sfu/atomic.go deleted file mode 100644 index c02acd3ff..000000000 --- a/pkg/sfu/atomic.go +++ /dev/null @@ -1,50 +0,0 @@ -package sfu - -import "sync/atomic" - -type atomicBool int32 - -func (a *atomicBool) set(value bool) (swapped bool) { - if value { - return atomic.SwapInt32((*int32)(a), 1) == 0 - } - return atomic.SwapInt32((*int32)(a), 0) == 1 -} - -func (a *atomicBool) get() bool { - return atomic.LoadInt32((*int32)(a)) != 0 -} - -type atomicUint8 uint32 - -func (a *atomicUint8) set(value uint8) { - atomic.StoreUint32((*uint32)(a), uint32(value)) -} - -func (a *atomicUint8) get() uint8 { - return uint8(atomic.LoadUint32((*uint32)(a))) -} - -type atomicUint32 uint32 - -func (a *atomicUint32) set(value uint32) { - atomic.StoreUint32((*uint32)(a), value) -} - -func (a *atomicUint32) add(value uint32) { - atomic.AddUint32((*uint32)(a), value) -} - -func (a *atomicUint32) get() uint32 { - return atomic.LoadUint32((*uint32)(a)) -} - -type atomicInt64 int64 - -func (a *atomicInt64) set(value int64) { - atomic.StoreInt64((*int64)(a), value) -} - -func (a *atomicInt64) get() int64 { - return atomic.LoadInt64((*int64)(a)) -} diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index 665c21363..b2613fbd7 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -6,7 +6,6 @@ import ( "math/rand" "strings" "sync" - "sync/atomic" "time" "github.com/gammazero/deque" @@ -16,6 +15,7 @@ import ( "github.com/pion/rtp" "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" + "go.uber.org/atomic" ) const ( @@ -54,7 +54,7 @@ type Buffer struct { twccExt uint8 audioExt uint8 bound bool - closed atomicBool + closed atomic.Bool mime string // supported feedbacks @@ -197,7 +197,7 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) { b.Lock() defer b.Unlock() - if b.closed.get() { + if b.closed.Load() { err = io.EOF return } @@ -218,7 +218,7 @@ func (b *Buffer) Write(pkt []byte) (n int, err error) { func (b *Buffer) Read(buff []byte) (n int, err error) { for { - if b.closed.get() { + if b.closed.Load() { err = io.EOF return } @@ -242,7 +242,7 @@ func (b *Buffer) Read(buff []byte) (n int, err error) { func (b *Buffer) ReadExtended() (*ExtPacket, error) { for { - if b.closed.get() { + if b.closed.Load() { return nil, io.EOF } b.Lock() @@ -267,7 +267,7 @@ func (b *Buffer) Close() error { if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeAudio { b.audioPool.Put(b.bucket.src) } - b.closed.set(true) + b.closed.Store(true) b.onClose() b.callbacksQueue.Stop() }) @@ -680,7 +680,7 @@ func (b *Buffer) getRTCP() []rtcp.Packet { func (b *Buffer) GetPacket(buff []byte, sn uint16) (int, error) { b.Lock() defer b.Unlock() - if b.closed.get() { + if b.closed.Load() { return 0, io.EOF } return b.bucket.GetPacket(buff, sn) @@ -744,11 +744,10 @@ func (b *Buffer) GetClockRate() uint32 { // GetSenderReportData returns the rtp, ntp and nanos of the last sender report func (b *Buffer) GetSenderReportData() (rtpTime uint32, ntpTime uint64, lastReceivedTimeInNanosSinceEpoch int64) { - rtpTime = atomic.LoadUint32(&b.lastSRRTPTime) - ntpTime = atomic.LoadUint64(&b.lastSRNTPTime) - lastReceivedTimeInNanosSinceEpoch = atomic.LoadInt64(&b.lastSRRecv) + b.RLock() + defer b.RUnlock() - return rtpTime, ntpTime, lastReceivedTimeInNanosSinceEpoch + return b.lastSRRTPTime, b.lastSRNTPTime, b.lastSRRecv } func (b *Buffer) GetStats() *StreamStatsWithLayers { diff --git a/pkg/sfu/buffer/helpers.go b/pkg/sfu/buffer/helpers.go index d31830c6d..1f14c568c 100644 --- a/pkg/sfu/buffer/helpers.go +++ b/pkg/sfu/buffer/helpers.go @@ -3,7 +3,6 @@ package buffer import ( "encoding/binary" "errors" - "sync/atomic" "github.com/livekit/protocol/logger" ) @@ -14,20 +13,6 @@ var ( errInvalidPacket = errors.New("invalid packet") ) -type atomicBool int32 - -func (a *atomicBool) set(value bool) { - var i int32 - if value { - i = 1 - } - atomic.StoreInt32((*int32)(a), i) -} - -func (a *atomicBool) get() bool { - return atomic.LoadInt32((*int32)(a)) != 0 -} - // VP8 is a helper to get temporal data from VP8 packet header /* VP8 Payload Descriptor diff --git a/pkg/sfu/buffer/rtcpreader.go b/pkg/sfu/buffer/rtcpreader.go index 16a58b923..0fce45b7a 100644 --- a/pkg/sfu/buffer/rtcpreader.go +++ b/pkg/sfu/buffer/rtcpreader.go @@ -2,12 +2,13 @@ package buffer import ( "io" - "sync/atomic" + + "go.uber.org/atomic" ) type RTCPReader struct { ssrc uint32 - closed atomicBool + closed atomic.Bool onPacket atomic.Value // func([]byte) onClose func() } @@ -17,7 +18,7 @@ func NewRTCPReader(ssrc uint32) *RTCPReader { } func (r *RTCPReader) Write(p []byte) (n int, err error) { - if r.closed.get() { + if r.closed.Load() { err = io.EOF return } @@ -32,7 +33,7 @@ func (r *RTCPReader) OnClose(fn func()) { } func (r *RTCPReader) Close() error { - r.closed.set(true) + r.closed.Store(true) r.onClose() return nil } diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 53c759b57..0c3ee7b64 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -16,6 +16,7 @@ import ( "github.com/pion/sdp/v3" "github.com/pion/transport/packetio" "github.com/pion/webrtc/v3" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" @@ -73,7 +74,7 @@ type DownTrack struct { logger logger.Logger id livekit.TrackID peerID livekit.ParticipantID - bound atomicBool + bound atomic.Bool kind webrtc.RTPCodecType mime string ssrc uint32 @@ -103,11 +104,11 @@ type DownTrack struct { connectionStats *connectionquality.ConnectionStats // Debug info - lastPli atomicInt64 - lastRTP atomicInt64 - pktsDropped atomicUint32 + lastPli atomic.Time + lastRTP atomic.Time + pktsDropped atomic.Uint32 - isNACKThrottled atomicBool + isNACKThrottled atomic.Bool // RTCP callbacks onREMB func(dt *DownTrack, remb *rtcp.ReceiverEstimatedMaximumBitrate) @@ -196,7 +197,7 @@ func NewDownTrack( // This asserts that the code requested is supported by the remote peer. // If so it sets up all the state (SSRC and PayloadType) to have a call func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { - if d.bound.get() { + if d.bound.Load() { return webrtc.RTPCodecParameters{}, ErrTrackAlreadyBind } parameters := webrtc.RTPCodecParameters{RTPCodecCapability: d.codec} @@ -216,7 +217,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, if d.onBind != nil { d.onBind() } - d.bound.set(true) + d.bound.Store(true) go d.requestFirstKeyframe() return codec, nil } @@ -226,7 +227,7 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, // Unbind implements the teardown logic when the track is no longer needed. This happens // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { - d.bound.set(false) + d.bound.Store(false) d.receiver.DeleteDownTrack(d.peerID) return nil } @@ -299,20 +300,20 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { } }() - if !d.bound.get() { + if !d.bound.Load() { return nil } - d.lastRTP.set(time.Now().UnixNano()) + d.lastRTP.Store(time.Now()) tp, err := d.forwarder.GetTranslationParams(extPkt, layer) if tp.shouldSendPLI { - d.lastPli.set(time.Now().UnixNano()) + d.lastPli.Store(time.Now()) d.receiver.SendPLI(layer) } if tp.shouldDrop { if tp.isDroppingRelevant { - d.pktsDropped.add(1) + d.pktsDropped.Inc() } return err } @@ -328,7 +329,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { } payload, err = d.translateVP8PacketTo(extPkt.Packet, &incomingVP8, tp.vp8.header, outbuf) if err != nil { - d.pktsDropped.add(1) + d.pktsDropped.Inc() return err } } @@ -342,7 +343,7 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { hdr, err := d.getTranslatedRTPHeader(extPkt, tp.rtp) if err != nil { - d.pktsDropped.add(1) + d.pktsDropped.Inc() return err } @@ -359,11 +360,11 @@ func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) error { d.updatePrimaryStats(pktSize, hdr.Marker) if extPkt.KeyFrame { - d.isNACKThrottled.set(false) + d.isNACKThrottled.Store(false) } } else { d.logger.Errorw("writing rtp packet err", err) - d.pktsDropped.add(1) + d.pktsDropped.Inc() } return err @@ -706,7 +707,7 @@ func (d *DownTrack) Resync() { } func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk { - if !d.bound.get() { + if !d.bound.Load() { return nil } return []rtcp.SourceDescriptionChunk{ @@ -727,7 +728,7 @@ func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChun } func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { - if !d.bound.get() { + if !d.bound.Load() { return nil } @@ -899,9 +900,9 @@ func (d *DownTrack) handleRTCP(bytes []byte) { if pliOnce { targetLayers := d.forwarder.TargetLayers() if targetLayers != InvalidLayers { - d.lastPli.set(time.Now().UnixNano()) + d.lastPli.Store(time.Now()) d.receiver.SendPLI(targetLayers.spatial) - d.isNACKThrottled.set(true) + d.isNACKThrottled.Store(true) pliOnce = false } } @@ -999,7 +1000,7 @@ func (d *DownTrack) retransmitPackets(nacks []uint16) { return } - if FlagStopRTXOnPLI && d.isNACKThrottled.get() { + if FlagStopRTXOnPLI && d.isNACKThrottled.Load() { return } @@ -1182,9 +1183,9 @@ func (d *DownTrack) DebugInfo() map[string]interface{} { "LastTS": rtpMungerParams.lastTS, "TSOffset": rtpMungerParams.tsOffset, "LastMarker": rtpMungerParams.lastMarker, - "LastRTP": d.lastRTP.get(), - "LastPli": d.lastPli.get(), - "PacketsDropped": d.pktsDropped.get(), + "LastRTP": d.lastRTP.Load(), + "LastPli": d.lastPli.Load(), + "PacketsDropped": d.pktsDropped.Load(), } senderReport := d.CreateSenderReport() @@ -1200,7 +1201,7 @@ func (d *DownTrack) DebugInfo() map[string]interface{} { "StreamID": d.streamID, "SSRC": d.ssrc, "MimeType": d.codec.MimeType, - "Bound": d.bound.get(), + "Bound": d.bound.Load(), "Muted": d.forwarder.IsMuted(), "CurrentSpatialLayer": d.forwarder.CurrentLayers().spatial, "Stats": stats, diff --git a/pkg/sfu/forwarder.go b/pkg/sfu/forwarder.go index a5c83c28c..d5db8e562 100644 --- a/pkg/sfu/forwarder.go +++ b/pkg/sfu/forwarder.go @@ -8,6 +8,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/pion/webrtc/v3" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/sfu/buffer" ) @@ -188,7 +189,7 @@ type Forwarder struct { rtpMunger *RTPMunger vp8Munger *VP8Munger - receivedFirstKeyFrame atomicBool + receivedFirstKeyFrame atomic.Bool } func NewForwarder(codec webrtc.RTPCodecCapability, kind webrtc.RTPCodecType, logger logger.Logger) *Forwarder { @@ -1193,7 +1194,7 @@ func (f *Forwarder) GetTranslationParams(extPkt *buffer.ExtPacket, layer int32) } func (f *Forwarder) ReceivedFirstKeyFrame() bool { - return f.receivedFirstKeyFrame.get() + return f.receivedFirstKeyFrame.Load() } // should be called with lock held @@ -1255,7 +1256,7 @@ func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer in if f.currentLayers.spatial == f.maxLayers.spatial { tp.isSwitchingToMaxLayer = true } - f.receivedFirstKeyFrame.set(true) + f.receivedFirstKeyFrame.Store(true) } else { tp.shouldSendPLI = true } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 9b58e477e..ba6232db7 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -4,7 +4,6 @@ import ( "io" "runtime" "sync" - "sync/atomic" "time" "github.com/go-logr/logr" @@ -13,6 +12,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/rs/zerolog/log" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/sfu/buffer" @@ -58,7 +58,7 @@ type WebRTCReceiver struct { isSimulcast bool onCloseHandler func() closeOnce sync.Once - closed atomicBool + closed atomic.Bool useTrackers bool rtcpCh chan []rtcp.Packet @@ -239,7 +239,7 @@ func (w *WebRTCReceiver) Kind() webrtc.RTPCodecType { } func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buffer) { - if w.closed.get() { + if w.closed.Load() { return } @@ -286,7 +286,7 @@ func (w *WebRTCReceiver) SetUpTrackPaused(paused bool) { } func (w *WebRTCReceiver) AddDownTrack(track TrackSender) { - if w.closed.get() { + if w.closed.Load() { return } @@ -351,7 +351,7 @@ func (w *WebRTCReceiver) OnCloseHandler(fn func()) { // DeleteDownTrack removes a DownTrack from a Receiver func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) { - if w.closed.get() { + if w.closed.Load() { return } @@ -368,7 +368,7 @@ func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) { } func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { - if w.closed.get() { + if w.closed.Load() { return } @@ -446,7 +446,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { defer func() { w.closeOnce.Do(func() { - w.closed.set(true) + w.closed.Store(true) w.closeTracks() }) @@ -479,7 +479,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { } } else { // parallel - enables much more efficient multi-core utilization - start := uint64(0) + start := atomic.NewUint64(0) end := uint64(len(downTracks)) // 100µs is enough to amortize the overhead and provide sufficient load balancing. @@ -492,7 +492,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { go func() { defer wg.Done() for { - n := atomic.AddUint64(&start, step) + n := start.Add(step) if n >= end+step { return } diff --git a/pkg/sfu/streamtracker.go b/pkg/sfu/streamtracker.go index 29e7471fd..68f1ccc51 100644 --- a/pkg/sfu/streamtracker.go +++ b/pkg/sfu/streamtracker.go @@ -2,9 +2,10 @@ package sfu import ( "sync" - "sync/atomic" "time" + "go.uber.org/atomic" + "github.com/livekit/protocol/utils" ) @@ -37,9 +38,9 @@ type StreamTracker struct { onStatusChanged func(status StreamStatus) - paused atomicBool - countSinceLast uint32 // number of packets received since last check - generation atomicUint32 + paused atomic.Bool + countSinceLast atomic.Uint32 // number of packets received since last check + generation atomic.Uint32 initMu sync.Mutex initialized bool @@ -102,7 +103,7 @@ func (s *StreamTracker) maybeSetStopped() { func (s *StreamTracker) init() { s.maybeSetActive() - go s.detectWorker(s.generation.get()) + go s.detectWorker(s.generation.Load()) } func (s *StreamTracker) Start() { @@ -114,7 +115,7 @@ func (s *StreamTracker) Stop() { } // bump generation to trigger exit of worker - s.generation.add(1) + s.generation.Inc() } func (s *StreamTracker) Reset() { @@ -123,9 +124,9 @@ func (s *StreamTracker) Reset() { } // bump generation to trigger exit of current worker - s.generation.add(1) + s.generation.Inc() - atomic.StoreUint32(&s.countSinceLast, 0) + s.countSinceLast.Store(0) s.cycleCount = 0 s.initMu.Lock() @@ -138,12 +139,12 @@ func (s *StreamTracker) Reset() { } func (s *StreamTracker) SetPaused(paused bool) { - s.paused.set(paused) + s.paused.Store(paused) } // Observe a packet that's received func (s *StreamTracker) Observe(sn uint16) { - if s.paused.get() { + if s.paused.Load() { return } @@ -154,7 +155,7 @@ func (s *StreamTracker) Observe(sn uint16) { s.initMu.Unlock() s.lastSN = sn - atomic.AddUint32(&s.countSinceLast, 1) + s.countSinceLast.Inc() // declare stream active and start the detection worker go s.init() @@ -168,7 +169,7 @@ func (s *StreamTracker) Observe(sn uint16) { return } s.lastSN = sn - atomic.AddUint32(&s.countSinceLast, 1) + s.countSinceLast.Inc() } func (s *StreamTracker) detectWorker(generation uint32) { @@ -176,7 +177,7 @@ func (s *StreamTracker) detectWorker(generation uint32) { for { <-ticker.C - if generation != s.generation.get() { + if generation != s.generation.Load() { return } @@ -185,11 +186,11 @@ func (s *StreamTracker) detectWorker(generation uint32) { } func (s *StreamTracker) detectChanges() { - if s.paused.get() { + if s.paused.Load() { return } - if atomic.LoadUint32(&s.countSinceLast) >= s.samplesRequired { + if s.countSinceLast.Load() >= s.samplesRequired { s.cycleCount += 1 } else { s.cycleCount = 0 @@ -203,5 +204,5 @@ func (s *StreamTracker) detectChanges() { s.maybeSetActive() } - atomic.StoreUint32(&s.countSinceLast, 0) + s.countSinceLast.Store(0) } diff --git a/pkg/sfu/streamtracker_test.go b/pkg/sfu/streamtracker_test.go index eee4befd8..e45836481 100644 --- a/pkg/sfu/streamtracker_test.go +++ b/pkg/sfu/streamtracker_test.go @@ -6,16 +6,17 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "github.com/livekit/livekit-server/pkg/testutils" ) func TestStreamTracker(t *testing.T) { t.Run("flips to active on first observe", func(t *testing.T) { - callbackCalled := atomicBool(0) + callbackCalled := atomic.NewBool(false) tracker := NewStreamTracker(5, 60, 500*time.Millisecond) tracker.OnStatusChanged(func(status StreamStatus) { - callbackCalled.set(true) + callbackCalled.Store(true) }) require.Equal(t, StreamStatusStopped, tracker.Status()) @@ -23,7 +24,7 @@ func TestStreamTracker(t *testing.T) { tracker.Observe(1) testutils.WithTimeout(t, func() string { - if callbackCalled.get() { + if callbackCalled.Load() { return "" } else { return "first packet didn't activate stream" @@ -31,7 +32,7 @@ func TestStreamTracker(t *testing.T) { }) require.Equal(t, StreamStatusActive, tracker.Status()) - require.True(t, callbackCalled.get()) + require.True(t, callbackCalled.Load()) }) t.Run("flips to inactive immediately", func(t *testing.T) { @@ -47,16 +48,16 @@ func TestStreamTracker(t *testing.T) { } }) - callbackCalled := atomicBool(0) + callbackCalled := atomic.NewBool(false) tracker.OnStatusChanged(func(status StreamStatus) { - callbackCalled.set(true) + callbackCalled.Store(true) }) require.Equal(t, StreamStatusActive, tracker.Status()) // run a single iteration tracker.detectChanges() require.Equal(t, StreamStatusStopped, tracker.Status()) - require.True(t, callbackCalled.get()) + require.True(t, callbackCalled.Load()) }) t.Run("flips back to active after iterations", func(t *testing.T) { @@ -100,10 +101,10 @@ func TestStreamTracker(t *testing.T) { }) t.Run("flips back to active on first observe after reset", func(t *testing.T) { - callbackCalled := atomicUint32(0) + callbackCalled := atomic.NewUint32(0) tracker := NewStreamTracker(5, 60, 500*time.Millisecond) tracker.OnStatusChanged(func(status StreamStatus) { - callbackCalled.add(1) + callbackCalled.Inc() }) require.Equal(t, StreamStatusStopped, tracker.Status()) @@ -111,15 +112,15 @@ func TestStreamTracker(t *testing.T) { tracker.Observe(1) testutils.WithTimeout(t, func() string { - if callbackCalled.get() == 1 { + if callbackCalled.Load() == 1 { return "" } else { - return fmt.Sprintf("expected onStatusChanged to be called once, actual: %d", callbackCalled.get()) + return fmt.Sprintf("expected onStatusChanged to be called once, actual: %d", callbackCalled.Load()) } }) require.Equal(t, StreamStatusActive, tracker.Status()) - require.Equal(t, uint32(1), callbackCalled.get()) + require.Equal(t, uint32(1), callbackCalled.Load()) // observe a few more tracker.Observe(2) @@ -139,14 +140,14 @@ func TestStreamTracker(t *testing.T) { tracker.Observe(1) testutils.WithTimeout(t, func() string { - if callbackCalled.get() == 2 { + if callbackCalled.Load() == 2 { return "" } else { - return fmt.Sprintf("expected onStatusChanged to be called twice, actual %d", callbackCalled.get()) + return fmt.Sprintf("expected onStatusChanged to be called twice, actual %d", callbackCalled.Load()) } }) require.Equal(t, StreamStatusActive, tracker.Status()) - require.Equal(t, uint32(2), callbackCalled.get()) + require.Equal(t, uint32(2), callbackCalled.Load()) }) }