From a83bd5c2f61c24338d4e1bad9bd3b458d46ff896 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Tue, 26 Apr 2022 15:46:30 +0530 Subject: [PATCH] Split out load balancer into a separate module (#657) --- pkg/sfu/downtrackspreader.go | 152 +++++++++++++++++++++++++++++++++++ pkg/sfu/receiver.go | 118 ++++----------------------- pkg/sfu/streamtracker.go | 3 +- 3 files changed, 167 insertions(+), 106 deletions(-) create mode 100644 pkg/sfu/downtrackspreader.go diff --git a/pkg/sfu/downtrackspreader.go b/pkg/sfu/downtrackspreader.go new file mode 100644 index 000000000..f7f1e828e --- /dev/null +++ b/pkg/sfu/downtrackspreader.go @@ -0,0 +1,152 @@ +package sfu + +import ( + "runtime" + "sync" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type DownTrackSpreaderParams struct { + Threshold int + Logger logger.Logger +} + +type DownTrackSpreader struct { + params DownTrackSpreaderParams + + downTrackMu sync.RWMutex + downTracks []TrackSender + index map[livekit.ParticipantID]int + free map[int]struct{} + numProcs int +} + +func NewDownTrackSpreader(params DownTrackSpreaderParams) *DownTrackSpreader { + d := &DownTrackSpreader{ + params: params, + downTracks: make([]TrackSender, 0), + index: make(map[livekit.ParticipantID]int), + free: make(map[int]struct{}), + numProcs: runtime.NumCPU(), + } + + if runtime.GOMAXPROCS(0) < d.numProcs { + d.numProcs = runtime.GOMAXPROCS(0) + } + + return d +} + +func (d *DownTrackSpreader) GetDownTracks() []TrackSender { + d.downTrackMu.RLock() + defer d.downTrackMu.RUnlock() + + return d.downTracks +} + +func (d *DownTrackSpreader) ResetAndGetDownTracks() []TrackSender { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + downTracks := d.downTracks + + d.index = make(map[livekit.ParticipantID]int) + d.free = make(map[int]struct{}) + d.downTracks = make([]TrackSender, 0) + + return downTracks +} + +func (d *DownTrackSpreader) Store(ts TrackSender) { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + peerID := ts.PeerID() + for idx := range d.free { + d.index[peerID] = idx + delete(d.free, idx) + d.downTracks[idx] = ts + return + } + + d.index[peerID] = len(d.downTracks) + d.downTracks = append(d.downTracks, ts) +} + +func (d *DownTrackSpreader) Free(peerID livekit.ParticipantID) { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + idx, ok := d.index[peerID] + if !ok { + return + } + + delete(d.index, peerID) + d.downTracks[idx] = nil + d.free[idx] = struct{}{} +} + +func (d *DownTrackSpreader) HasDownTrack(peerID livekit.ParticipantID) bool { + d.downTrackMu.RLock() + defer d.downTrackMu.RUnlock() + + _, ok := d.index[peerID] + return ok +} + +func (d *DownTrackSpreader) Broadcast(layer int32, pkt *buffer.ExtPacket) { + d.downTrackMu.RLock() + downTracks := d.downTracks + free := d.free + d.downTrackMu.RUnlock() + + if d.params.Threshold == 0 || len(downTracks)-len(free) < d.params.Threshold { + // serial - not enough down tracks for parallelization to outweigh overhead + for _, dt := range downTracks { + if dt != nil { + d.writeRTP(layer, dt, pkt) + } + } + } else { + // parallel - enables much more efficient multi-core utilization + start := atomic.NewUint64(0) + end := uint64(len(downTracks)) + + // 100µs is enough to amortize the overhead and provide sufficient load balancing. + // WriteRTP takes about 50µs on average, so we write to 2 down tracks per loop. + step := uint64(2) + + var wg sync.WaitGroup + wg.Add(d.numProcs) + for p := 0; p < d.numProcs; p++ { + go func() { + defer wg.Done() + for { + n := start.Add(step) + if n >= end+step { + return + } + + for i := n - step; i < n && i < end; i++ { + if dt := downTracks[i]; dt != nil { + d.writeRTP(layer, dt, pkt) + } + } + } + }() + } + wg.Wait() + } +} + +func (d *DownTrackSpreader) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) { + if err := dt.WriteRTP(pkt, layer); err != nil { + d.params.Logger.Errorw("failed writing to down track", err) + } +} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 87d509d37..2fdaa4c99 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -3,14 +3,12 @@ package sfu import ( "errors" "io" - "runtime" "sync" "time" "github.com/go-logr/logr" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" - "github.com/rs/zerolog/log" "go.uber.org/atomic" "github.com/livekit/protocol/livekit" @@ -83,15 +81,12 @@ type WebRTCReceiver struct { upTrackMu sync.RWMutex upTracks [DefaultMaxLayerSpatial + 1]*webrtc.TrackRemote - downTrackMu sync.RWMutex - downTracks []TrackSender - index map[livekit.ParticipantID]int - free map[int]struct{} - numProcs int lbThreshold int streamTrackerManager *StreamTrackerManager + downTrackSpreader *DownTrackSpreader + connectionStats *connectionquality.ConnectionStats // update stats @@ -168,23 +163,20 @@ func NewWebRTCReceiver( // LK-TODO: this should be based on VideoLayers protocol message rather than RID based isSimulcast: len(track.RID()) > 0, twcc: twcc, - downTracks: make([]TrackSender, 0), - index: make(map[livekit.ParticipantID]int), - free: make(map[int]struct{}), - numProcs: runtime.NumCPU(), streamTrackerManager: NewStreamTrackerManager(logger, source), } w.streamTrackerManager.OnAvailableLayersChanged(w.downTrackLayerChange) w.streamTrackerManager.OnBitrateAvailabilityChanged(w.downTrackBitrateAvailabilityChange) - if runtime.GOMAXPROCS(0) < w.numProcs { - w.numProcs = runtime.GOMAXPROCS(0) - } - for _, opt := range opts { w = opt(w) } + w.downTrackSpreader = NewDownTrackSpreader(DownTrackSpreaderParams{ + Threshold: w.lbThreshold, + Logger: logger, + }) + w.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{ CodecType: w.kind, GetDeltaStats: w.getDeltaStats, @@ -322,10 +314,7 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error { return ErrReceiverClosed } - w.downTrackMu.RLock() - _, ok := w.index[track.PeerID()] - w.downTrackMu.RUnlock() - if ok { + if w.downTrackSpreader.HasDownTrack(track.PeerID()) { return ErrDownTrackAlreadyExist } @@ -346,11 +335,7 @@ func (w *WebRTCReceiver) SetMaxExpectedSpatialLayer(layer int32) { } func (w *WebRTCReceiver) downTrackLayerChange(layers []int32) { - w.downTrackMu.RLock() - downTracks := w.downTracks - w.downTrackMu.RUnlock() - - for _, dt := range downTracks { + for _, dt := range w.downTrackSpreader.GetDownTracks() { if dt != nil { dt.UpTrackLayersChange(layers) } @@ -358,11 +343,7 @@ func (w *WebRTCReceiver) downTrackLayerChange(layers []int32) { } func (w *WebRTCReceiver) downTrackBitrateAvailabilityChange() { - w.downTrackMu.RLock() - downTracks := w.downTracks - w.downTrackMu.RUnlock() - - for _, dt := range downTracks { + for _, dt := range w.downTrackSpreader.GetDownTracks() { if dt != nil { dt.UpTrackBitrateAvailabilityChange() } @@ -384,16 +365,7 @@ func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) { return } - w.downTrackMu.Lock() - defer w.downTrackMu.Unlock() - - idx, ok := w.index[peerID] - if !ok { - return - } - delete(w.index, peerID) - w.downTracks[idx] = nil - w.free[idx] = struct{}{} + w.downTrackSpreader.Free(peerID) } func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { @@ -564,53 +536,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) { tracker.Observe(pkt.Packet.SequenceNumber, pkt.TemporalLayer, len(pkt.RawPacket), len(pkt.Packet.Payload)) } - w.downTrackMu.RLock() - downTracks := w.downTracks - free := w.free - w.downTrackMu.RUnlock() - if w.lbThreshold == 0 || len(downTracks)-len(free) < w.lbThreshold { - // serial - not enough down tracks for parallelization to outweigh overhead - for _, dt := range downTracks { - if dt != nil { - w.writeRTP(layer, dt, pkt) - } - } - } else { - // parallel - enables much more efficient multi-core utilization - start := atomic.NewUint64(0) - end := uint64(len(downTracks)) - - // 100µs is enough to amortize the overhead and provide sufficient load balancing. - // WriteRTP takes about 50µs on average, so we write to 2 down tracks per loop. - step := uint64(2) - - var wg sync.WaitGroup - wg.Add(w.numProcs) - for p := 0; p < w.numProcs; p++ { - go func() { - defer wg.Done() - for { - n := start.Add(step) - if n >= end+step { - return - } - - for i := n - step; i < n && i < end; i++ { - if dt := downTracks[i]; dt != nil { - w.writeRTP(layer, dt, pkt) - } - } - } - }() - } - wg.Wait() - } - } -} - -func (w *WebRTCReceiver) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) { - if err := dt.WriteRTP(pkt, layer); err != nil { - log.Error().Err(err).Str("id", dt.ID()).Msg("Error writing to down track") + w.downTrackSpreader.Broadcast(layer, pkt) } } @@ -618,16 +544,11 @@ func (w *WebRTCReceiver) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPa func (w *WebRTCReceiver) closeTracks() { w.connectionStats.Close() - w.downTrackMu.Lock() - for _, dt := range w.downTracks { + for _, dt := range w.downTrackSpreader.ResetAndGetDownTracks() { if dt != nil { dt.Close() } } - w.downTracks = make([]TrackSender, 0) - w.index = make(map[livekit.ParticipantID]int) - w.free = make(map[int]struct{}) - w.downTrackMu.Unlock() if w.onCloseHandler != nil { w.onCloseHandler() @@ -635,18 +556,7 @@ func (w *WebRTCReceiver) closeTracks() { } func (w *WebRTCReceiver) storeDownTrack(track TrackSender) { - w.downTrackMu.Lock() - defer w.downTrackMu.Unlock() - - for idx := range w.free { - w.index[track.PeerID()] = idx - w.downTracks[idx] = track - delete(w.free, idx) - return - } - - w.index[track.PeerID()] = len(w.downTracks) - w.downTracks = append(w.downTracks, track) + w.downTrackSpreader.Store(track) } func (w *WebRTCReceiver) DebugInfo() map[string]interface{} { diff --git a/pkg/sfu/streamtracker.go b/pkg/sfu/streamtracker.go index cd21b1c38..0a929e1e0 100644 --- a/pkg/sfu/streamtracker.go +++ b/pkg/sfu/streamtracker.go @@ -75,11 +75,10 @@ type StreamTracker struct { } func NewStreamTracker(params StreamTrackerParams) *StreamTracker { - s := &StreamTracker{ + return &StreamTracker{ params: params, status: StreamStatusStopped, } - return s } func (s *StreamTracker) OnStatusChanged(f func(status StreamStatus)) {