diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index ba714228d..e2978d82d 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -56,17 +56,16 @@ type MediaTrackParams struct { ParticipantID livekit.ParticipantID ParticipantIdentity livekit.ParticipantIdentity ParticipantVersion uint32 - // channel to send RTCP packets to the source - RTCPChan chan []rtcp.Packet - BufferFactory *buffer.Factory - ReceiverConfig ReceiverConfig - SubscriberConfig DirectionConfig - PLIThrottleConfig config.PLIThrottleConfig - AudioConfig config.AudioConfig - VideoConfig config.VideoConfig - Telemetry telemetry.TelemetryService - Logger logger.Logger - SimTracks map[uint32]SimulcastTrackInfo + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + PLIThrottleConfig config.PLIThrottleConfig + AudioConfig config.AudioConfig + VideoConfig config.VideoConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger + SimTracks map[uint32]SimulcastTrackInfo + OnRTCP func([]rtcp.Packet) } func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { @@ -252,13 +251,13 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra ti, LoggerWithCodecMime(t.params.Logger, mime), twcc, + t.params.OnRTCP, t.params.VideoConfig.StreamTracker, sfu.WithPliThrottleConfig(t.params.PLIThrottleConfig), sfu.WithAudioConfig(t.params.AudioConfig), sfu.WithLoadBalanceThreshold(20), sfu.WithStreamTrackers(), ) - newWR.SetRTCPCh(t.params.RTCPChan) newWR.OnCloseHandler(func() { t.MediaTrackReceiver.SetClosing() t.MediaTrackReceiver.ClearReceiver(mime, false) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 8c476d5e7..96f251d07 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -158,7 +158,7 @@ type ParticipantImpl struct { disconnectTimer *time.Timer migrationTimer *time.Timer - rtcpCh chan []rtcp.Packet + pubRTCPQueue *sutils.OpsQueue // hold reference for MediaTrack twcc *twcc.Responder @@ -240,7 +240,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { } p := &ParticipantImpl{ params: params, - rtcpCh: make(chan []rtcp.Packet, 100), + pubRTCPQueue: sutils.NewOpsQueue("pub-rtcp", 64, false), pendingTracks: make(map[string]*pendingTrackInfo), pendingPublishingTracks: make(map[livekit.TrackID]*pendingTrackInfo), connectedAt: time.Now(), @@ -1467,7 +1467,7 @@ func (p *ParticipantImpl) onPublisherInitialConnected() { if p.supervisor != nil { p.supervisor.SetPublisherPeerConnectionConnected(true) } - go p.publisherRTCPWorker() + p.pubRTCPQueue.Start() } func (p *ParticipantImpl) onSubscriberInitialConnected() { @@ -2000,7 +2000,6 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv ParticipantID: p.params.SID, ParticipantIdentity: p.params.Identity, ParticipantVersion: p.version.Load(), - RTCPChan: p.rtcpCh, BufferFactory: p.params.Config.BufferFactory, ReceiverConfig: p.params.Config.Receiver, AudioConfig: p.params.AudioConfig, @@ -2010,6 +2009,7 @@ func (p *ParticipantImpl) addMediaTrack(signalCid string, sdpCid string, ti *liv SubscriberConfig: p.params.Config.Subscriber, PLIThrottleConfig: p.params.PLIThrottleConfig, SimTracks: p.params.SimTracks, + OnRTCP: p.postRtcp, }, ti) mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) @@ -2117,7 +2117,7 @@ func (p *ParticipantImpl) hasPendingMigratedTrack() bool { } func (p *ParticipantImpl) onUpTrackManagerClose() { - p.postRtcp(nil) + p.pubRTCPQueue.Stop() } func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo, bool) { @@ -2241,28 +2241,6 @@ func (p *ParticipantImpl) getPublishedTrackBySdpCid(clientId string) types.Media return nil } -func (p *ParticipantImpl) publisherRTCPWorker() { - defer func() { - if r := Recover(p.GetLogger()); r != nil { - os.Exit(1) - } - }() - - // read from rtcpChan - for pkts := range p.rtcpCh { - if pkts == nil { - p.pubLogger.Debugw("exiting publisher RTCP worker") - return - } - - if err := p.TransportManager.WritePublisherRTCP(pkts); err != nil { - if !IsEOF(err) { - p.pubLogger.Errorw("could not write RTCP to participant", err) - } - } - } -} - func (p *ParticipantImpl) DebugInfo() map[string]interface{} { info := map[string]interface{}{ "ID": p.params.SID, @@ -2291,11 +2269,11 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { } func (p *ParticipantImpl) postRtcp(pkts []rtcp.Packet) { - select { - case p.rtcpCh <- pkts: - default: - p.params.Logger.Warnw("rtcp channel full", nil) - } + p.pubRTCPQueue.Enqueue(func() { + if err := p.TransportManager.WritePublisherRTCP(pkts); err != nil && !IsEOF(err) { + p.pubLogger.Errorw("could not write RTCP to participant", err) + } + }) } func (p *ParticipantImpl) setDowntracksConnected() { diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index ed01e1f79..5cadfc267 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -64,22 +64,37 @@ func NewUpTrackManager(params UpTrackManagerParams) *UpTrackManager { func (u *UpTrackManager) Close(willBeResumed bool) { u.lock.Lock() - u.closed = true - notify := len(u.publishedTracks) == 0 - u.lock.Unlock() - - // remove all subscribers - for _, t := range u.GetPublishedTracks() { - t.ClearAllReceivers(willBeResumed) + if u.closed { + u.lock.Unlock() + return } - if notify && u.onClose != nil { - u.onClose() + u.closed = true + + publishedTracks := u.publishedTracks + u.publishedTracks = make(map[livekit.TrackID]types.MediaTrack) + u.lock.Unlock() + + for _, t := range publishedTracks { + t.Close(willBeResumed) + } + + if onClose := u.getOnUpTrackManagerClose(); onClose != nil { + onClose() } } func (u *UpTrackManager) OnUpTrackManagerClose(f func()) { + u.lock.Lock() u.onClose = f + u.lock.Unlock() +} + +func (u *UpTrackManager) getOnUpTrackManagerClose() func() { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.onClose } func (u *UpTrackManager) ToProto() []*livekit.TrackInfo { @@ -247,22 +262,10 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto())) track.AddOnClose(func() { - notifyClose := false - - // cleanup u.lock.Lock() - trackID := track.ID() - delete(u.publishedTracks, trackID) + delete(u.publishedTracks, track.ID()) // not modifying subscription permissions, will get reset on next update from participant - - if u.closed && len(u.publishedTracks) == 0 { - notifyClose = true - } u.lock.Unlock() - - if notifyClose && u.onClose != nil { - u.onClose() - } }) } diff --git a/pkg/sfu/buffer/buffer.go b/pkg/sfu/buffer/buffer.go index f27946064..7e250c23a 100644 --- a/pkg/sfu/buffer/buffer.go +++ b/pkg/sfu/buffer/buffer.go @@ -264,7 +264,7 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapabili b.bound = true } -// Write adds an RTP Packet, out of order, new packet may be arrived later +// Write adds an RTP Packet, ordering is not guaranteed, newer packets may arrive later func (b *Buffer) Write(pkt []byte) (n int, err error) { b.Lock() defer b.Unlock() diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index ae368b42f..b05b2b23c 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -108,7 +108,7 @@ type WebRTCReceiver struct { useTrackers bool trackInfo atomic.Pointer[livekit.TrackInfo] - rtcpCh chan []rtcp.Packet + onRTCP func([]rtcp.Packet) twcc *twcc.Responder @@ -198,6 +198,7 @@ func NewWebRTCReceiver( trackInfo *livekit.TrackInfo, logger logger.Logger, twcc *twcc.Responder, + onRTCP func([]rtcp.Packet), trackersConfig config.StreamTrackersConfig, opts ...ReceiverOpts, ) *WebRTCReceiver { @@ -209,6 +210,7 @@ func NewWebRTCReceiver( codec: track.Codec(), kind: track.Kind(), twcc: twcc, + onRTCP: onRTCP, isSVC: IsSvcCodec(track.Codec().MimeType), isRED: IsRedCodec(track.Codec().MimeType), } @@ -514,10 +516,8 @@ func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { return } - select { - case w.rtcpCh <- packets: - default: - w.logger.Warnw("sendRTCP failed, rtcp channel full", nil) + if w.onRTCP != nil { + w.onRTCP(packets) } } @@ -531,10 +531,6 @@ func (w *WebRTCReceiver) SendPLI(layer int32, force bool) { buff.SendPLI(force) } -func (w *WebRTCReceiver) SetRTCPCh(ch chan []rtcp.Packet) { - w.rtcpCh = ch -} - func (w *WebRTCReceiver) getBuffer(layer int32) *buffer.Buffer { w.bufferMu.RLock() defer w.bufferMu.RUnlock()