Mirror of https://github.com/livekit/livekit.git (synced 2026-05-13 01:25:30 +00:00)
Split out load balancer into a separate module (#657)
@@ -0,0 +1,152 @@
package sfu

import (
    "runtime"
    "sync"

    "go.uber.org/atomic"

    "github.com/livekit/livekit-server/pkg/sfu/buffer"
    "github.com/livekit/protocol/livekit"
    "github.com/livekit/protocol/logger"
)

type DownTrackSpreaderParams struct {
    Threshold int
    Logger    logger.Logger
}

type DownTrackSpreader struct {
    params DownTrackSpreaderParams

    downTrackMu sync.RWMutex
    downTracks  []TrackSender
    index       map[livekit.ParticipantID]int
    free        map[int]struct{}
    numProcs    int
}

func NewDownTrackSpreader(params DownTrackSpreaderParams) *DownTrackSpreader {
    d := &DownTrackSpreader{
        params:     params,
        downTracks: make([]TrackSender, 0),
        index:      make(map[livekit.ParticipantID]int),
        free:       make(map[int]struct{}),
        numProcs:   runtime.NumCPU(),
    }

    if runtime.GOMAXPROCS(0) < d.numProcs {
        d.numProcs = runtime.GOMAXPROCS(0)
    }

    return d
}

func (d *DownTrackSpreader) GetDownTracks() []TrackSender {
    d.downTrackMu.RLock()
    defer d.downTrackMu.RUnlock()

    return d.downTracks
}

func (d *DownTrackSpreader) ResetAndGetDownTracks() []TrackSender {
    d.downTrackMu.Lock()
    defer d.downTrackMu.Unlock()

    downTracks := d.downTracks

    d.index = make(map[livekit.ParticipantID]int)
    d.free = make(map[int]struct{})
    d.downTracks = make([]TrackSender, 0)

    return downTracks
}

func (d *DownTrackSpreader) Store(ts TrackSender) {
    d.downTrackMu.Lock()
    defer d.downTrackMu.Unlock()

    peerID := ts.PeerID()
    for idx := range d.free {
        d.index[peerID] = idx
        delete(d.free, idx)
        d.downTracks[idx] = ts
        return
    }

    d.index[peerID] = len(d.downTracks)
    d.downTracks = append(d.downTracks, ts)
}

func (d *DownTrackSpreader) Free(peerID livekit.ParticipantID) {
    d.downTrackMu.Lock()
    defer d.downTrackMu.Unlock()

    idx, ok := d.index[peerID]
    if !ok {
        return
    }

    delete(d.index, peerID)
    d.downTracks[idx] = nil
    d.free[idx] = struct{}{}
}

func (d *DownTrackSpreader) HasDownTrack(peerID livekit.ParticipantID) bool {
    d.downTrackMu.RLock()
    defer d.downTrackMu.RUnlock()

    _, ok := d.index[peerID]
    return ok
}

func (d *DownTrackSpreader) Broadcast(layer int32, pkt *buffer.ExtPacket) {
    d.downTrackMu.RLock()
    downTracks := d.downTracks
    free := d.free
    d.downTrackMu.RUnlock()

    if d.params.Threshold == 0 || len(downTracks)-len(free) < d.params.Threshold {
        // serial - not enough down tracks for parallelization to outweigh overhead
        for _, dt := range downTracks {
            if dt != nil {
                d.writeRTP(layer, dt, pkt)
            }
        }
    } else {
        // parallel - enables much more efficient multi-core utilization
        start := atomic.NewUint64(0)
        end := uint64(len(downTracks))

        // 100µs is enough to amortize the overhead and provide sufficient load balancing.
        // WriteRTP takes about 50µs on average, so we write to 2 down tracks per loop.
        step := uint64(2)

        var wg sync.WaitGroup
        wg.Add(d.numProcs)
        for p := 0; p < d.numProcs; p++ {
            go func() {
                defer wg.Done()
                for {
                    n := start.Add(step)
                    if n >= end+step {
                        return
                    }

                    for i := n - step; i < n && i < end; i++ {
                        if dt := downTracks[i]; dt != nil {
                            d.writeRTP(layer, dt, pkt)
                        }
                    }
                }
            }()
        }
        wg.Wait()
    }
}

func (d *DownTrackSpreader) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) {
    if err := dt.WriteRTP(pkt, layer); err != nil {
        d.params.Logger.Errorw("failed writing to down track", err)
    }
}
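The heart of the new module is Broadcast's work-claiming loop: when enough down tracks are active, worker goroutines (capped at GOMAXPROCS) repeatedly advance a shared atomic counter by a fixed step and write to the down tracks whose indexes they claimed, so one slow WriteRTP does not serialize the rest. With WriteRTP at roughly 50µs, a step of 2 keeps each claim near 100µs, which amortizes the claiming overhead while still balancing load across workers. Below is a minimal, self-contained sketch of the same pattern using only the standard library; fanOut, sink, and the demo values are illustrative names, not LiveKit APIs.

package main

import (
    "fmt"
    "runtime"
    "sync"
    "sync/atomic"
)

// fanOut calls sink(i) for every i in [0, n), spreading the calls across
// worker goroutines. Each worker claims chunks of `step` indexes from a
// shared atomic counter, mirroring the claiming loop in DownTrackSpreader.Broadcast.
func fanOut(n int, step uint64, sink func(i int)) {
    workers := runtime.NumCPU()
    if g := runtime.GOMAXPROCS(0); g < workers {
        workers = g
    }

    var next uint64
    end := uint64(n)

    var wg sync.WaitGroup
    wg.Add(workers)
    for w := 0; w < workers; w++ {
        go func() {
            defer wg.Done()
            for {
                claimed := atomic.AddUint64(&next, step)
                if claimed >= end+step {
                    return
                }
                for i := claimed - step; i < claimed && i < end; i++ {
                    sink(int(i))
                }
            }
        }()
    }
    wg.Wait()
}

func main() {
    // Pretend each index is a down track receiving the same packet.
    fanOut(10, 2, func(i int) { fmt.Println("write to down track", i) })
}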
+14 -104
@@ -3,14 +3,12 @@ package sfu
import (
    "errors"
    "io"
    "runtime"
    "sync"
    "time"

    "github.com/go-logr/logr"
    "github.com/pion/rtcp"
    "github.com/pion/webrtc/v3"
    "github.com/rs/zerolog/log"
    "go.uber.org/atomic"

    "github.com/livekit/protocol/livekit"
@@ -83,15 +81,12 @@ type WebRTCReceiver struct {
    upTrackMu sync.RWMutex
    upTracks  [DefaultMaxLayerSpatial + 1]*webrtc.TrackRemote

-   downTrackMu sync.RWMutex
-   downTracks  []TrackSender
-   index       map[livekit.ParticipantID]int
-   free        map[int]struct{}
-   numProcs    int
    lbThreshold int

    streamTrackerManager *StreamTrackerManager

+   downTrackSpreader *DownTrackSpreader
+
    connectionStats *connectionquality.ConnectionStats

    // update stats
@@ -168,23 +163,20 @@ func NewWebRTCReceiver(
        // LK-TODO: this should be based on VideoLayers protocol message rather than RID based
        isSimulcast:          len(track.RID()) > 0,
        twcc:                 twcc,
-       downTracks:           make([]TrackSender, 0),
-       index:                make(map[livekit.ParticipantID]int),
-       free:                 make(map[int]struct{}),
-       numProcs:             runtime.NumCPU(),
        streamTrackerManager: NewStreamTrackerManager(logger, source),
    }
    w.streamTrackerManager.OnAvailableLayersChanged(w.downTrackLayerChange)
    w.streamTrackerManager.OnBitrateAvailabilityChanged(w.downTrackBitrateAvailabilityChange)

-   if runtime.GOMAXPROCS(0) < w.numProcs {
-       w.numProcs = runtime.GOMAXPROCS(0)
-   }
-
    for _, opt := range opts {
        w = opt(w)
    }

+   w.downTrackSpreader = NewDownTrackSpreader(DownTrackSpreaderParams{
+       Threshold: w.lbThreshold,
+       Logger:    logger,
+   })
+
    w.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{
        CodecType:     w.kind,
        GetDeltaStats: w.getDeltaStats,
@@ -322,10 +314,7 @@ func (w *WebRTCReceiver) AddDownTrack(track TrackSender) error {
        return ErrReceiverClosed
    }

-   w.downTrackMu.RLock()
-   _, ok := w.index[track.PeerID()]
-   w.downTrackMu.RUnlock()
-   if ok {
+   if w.downTrackSpreader.HasDownTrack(track.PeerID()) {
        return ErrDownTrackAlreadyExist
    }
@@ -346,11 +335,7 @@ func (w *WebRTCReceiver) SetMaxExpectedSpatialLayer(layer int32) {
}

func (w *WebRTCReceiver) downTrackLayerChange(layers []int32) {
-   w.downTrackMu.RLock()
-   downTracks := w.downTracks
-   w.downTrackMu.RUnlock()
-
-   for _, dt := range downTracks {
+   for _, dt := range w.downTrackSpreader.GetDownTracks() {
        if dt != nil {
            dt.UpTrackLayersChange(layers)
        }
@@ -358,11 +343,7 @@ func (w *WebRTCReceiver) downTrackLayerChange(layers []int32) {
}

func (w *WebRTCReceiver) downTrackBitrateAvailabilityChange() {
-   w.downTrackMu.RLock()
-   downTracks := w.downTracks
-   w.downTrackMu.RUnlock()
-
-   for _, dt := range downTracks {
+   for _, dt := range w.downTrackSpreader.GetDownTracks() {
        if dt != nil {
            dt.UpTrackBitrateAvailabilityChange()
        }
@@ -384,16 +365,7 @@ func (w *WebRTCReceiver) DeleteDownTrack(peerID livekit.ParticipantID) {
        return
    }

-   w.downTrackMu.Lock()
-   defer w.downTrackMu.Unlock()
-
-   idx, ok := w.index[peerID]
-   if !ok {
-       return
-   }
-   delete(w.index, peerID)
-   w.downTracks[idx] = nil
-   w.free[idx] = struct{}{}
+   w.downTrackSpreader.Free(peerID)
}

func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) {
@@ -564,53 +536,7 @@ func (w *WebRTCReceiver) forwardRTP(layer int32) {
            tracker.Observe(pkt.Packet.SequenceNumber, pkt.TemporalLayer, len(pkt.RawPacket), len(pkt.Packet.Payload))
        }

-       w.downTrackMu.RLock()
-       downTracks := w.downTracks
-       free := w.free
-       w.downTrackMu.RUnlock()
-       if w.lbThreshold == 0 || len(downTracks)-len(free) < w.lbThreshold {
-           // serial - not enough down tracks for parallelization to outweigh overhead
-           for _, dt := range downTracks {
-               if dt != nil {
-                   w.writeRTP(layer, dt, pkt)
-               }
-           }
-       } else {
-           // parallel - enables much more efficient multi-core utilization
-           start := atomic.NewUint64(0)
-           end := uint64(len(downTracks))
-
-           // 100µs is enough to amortize the overhead and provide sufficient load balancing.
-           // WriteRTP takes about 50µs on average, so we write to 2 down tracks per loop.
-           step := uint64(2)
-
-           var wg sync.WaitGroup
-           wg.Add(w.numProcs)
-           for p := 0; p < w.numProcs; p++ {
-               go func() {
-                   defer wg.Done()
-                   for {
-                       n := start.Add(step)
-                       if n >= end+step {
-                           return
-                       }
-
-                       for i := n - step; i < n && i < end; i++ {
-                           if dt := downTracks[i]; dt != nil {
-                               w.writeRTP(layer, dt, pkt)
-                           }
-                       }
-                   }
-               }()
-           }
-           wg.Wait()
-       }
-   }
-}
-
-func (w *WebRTCReceiver) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) {
-   if err := dt.WriteRTP(pkt, layer); err != nil {
-       log.Error().Err(err).Str("id", dt.ID()).Msg("Error writing to down track")
+       w.downTrackSpreader.Broadcast(layer, pkt)
    }
}
@@ -618,16 +544,11 @@ func (w *WebRTCReceiver) writeRTP(layer int32, dt TrackSender, pkt *buffer.ExtPacket) {
func (w *WebRTCReceiver) closeTracks() {
    w.connectionStats.Close()

-   w.downTrackMu.Lock()
-   for _, dt := range w.downTracks {
+   for _, dt := range w.downTrackSpreader.ResetAndGetDownTracks() {
        if dt != nil {
            dt.Close()
        }
    }
-   w.downTracks = make([]TrackSender, 0)
-   w.index = make(map[livekit.ParticipantID]int)
-   w.free = make(map[int]struct{})
-   w.downTrackMu.Unlock()

    if w.onCloseHandler != nil {
        w.onCloseHandler()
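closeTracks now closes whatever ResetAndGetDownTracks hands back: the spreader swaps its slice and maps out while holding the lock and returns the old slice, so the potentially slow Close calls run without the spreader's mutex held. A small stand-alone sketch of that swap-then-iterate idiom; the resettable type and the track strings are illustrative, not LiveKit types.

package main

import (
    "fmt"
    "sync"
)

// resettable holds a slice guarded by a mutex. reset swaps the slice out
// under the lock and hands the old contents back, so callers can do slow
// work (such as closing tracks) without holding the lock.
type resettable struct {
    mu    sync.Mutex
    items []string
}

func (r *resettable) add(v string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.items = append(r.items, v)
}

func (r *resettable) reset() []string {
    r.mu.Lock()
    defer r.mu.Unlock()
    old := r.items
    r.items = nil
    return old
}

func main() {
    var r resettable
    r.add("trackA")
    r.add("trackB")
    for _, v := range r.reset() { // close outside the lock
        fmt.Println("closing", v)
    }
}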
@@ -635,18 +556,7 @@ func (w *WebRTCReceiver) closeTracks() {
}

func (w *WebRTCReceiver) storeDownTrack(track TrackSender) {
-   w.downTrackMu.Lock()
-   defer w.downTrackMu.Unlock()
-
-   for idx := range w.free {
-       w.index[track.PeerID()] = idx
-       w.downTracks[idx] = track
-       delete(w.free, idx)
-       return
-   }
-
-   w.index[track.PeerID()] = len(w.downTracks)
-   w.downTracks = append(w.downTracks, track)
+   w.downTrackSpreader.Store(track)
}

func (w *WebRTCReceiver) DebugInfo() map[string]interface{} {
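The bookkeeping deleted here is what DownTrackSpreader.Store and Free now own: each participant maps to a stable slot index, freed slots become nil holes that the next Store reuses, and the slice is never compacted, so Broadcast can keep iterating by index. A stand-alone sketch of that slot-reuse scheme; slotStore and the string values are illustrative, not the LiveKit types.

package main

import "fmt"

// slotStore mirrors DownTrackSpreader's bookkeeping: a dense slice of values,
// an id-to-index map, and a free set so deleted slots are reused instead of
// compacting the slice.
type slotStore struct {
    values []string
    index  map[string]int
    free   map[int]struct{}
}

func newSlotStore() *slotStore {
    return &slotStore{index: map[string]int{}, free: map[int]struct{}{}}
}

func (s *slotStore) store(id, v string) {
    for idx := range s.free { // reuse any freed slot first
        s.index[id] = idx
        s.values[idx] = v
        delete(s.free, idx)
        return
    }
    s.index[id] = len(s.values)
    s.values = append(s.values, v)
}

func (s *slotStore) freeSlot(id string) {
    idx, ok := s.index[id]
    if !ok {
        return
    }
    delete(s.index, id)
    s.values[idx] = "" // leave a hole; a real iterator would skip empty slots
    s.free[idx] = struct{}{}
}

func main() {
    s := newSlotStore()
    s.store("a", "trackA")
    s.store("b", "trackB")
    s.freeSlot("a")
    s.store("c", "trackC") // reuses slot 0
    fmt.Println(s.values)  // [trackC trackB]
}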
@@ -75,11 +75,10 @@ type StreamTracker struct {
}

func NewStreamTracker(params StreamTrackerParams) *StreamTracker {
-   s := &StreamTracker{
+   return &StreamTracker{
        params: params,
        status: StreamStatusStopped,
    }
-   return s
}

func (s *StreamTracker) OnStatusChanged(f func(status StreamStatus)) {