From a9e0598210eaa99a5921d2cd55cb30ca0a90fc04 Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Sun, 9 Jan 2022 10:45:49 +0530 Subject: [PATCH] uptrackmanager reuse (#318) * WIP commit * Remove the double lock * Remove unused variable * WIP commit * Fix test * WIP commit * Split out MediaTrackReceiver * Address comments from David --- pkg/rtc/localparticipant.go | 395 ++++++++++++++++ pkg/rtc/mediatrack.go | 396 +++------------- pkg/rtc/mediatrack_test.go | 4 +- pkg/rtc/mediatrackreceiver.go | 383 ++++++++++++++++ pkg/rtc/mediatracksubscriptions.go | 15 +- pkg/rtc/participant.go | 60 +-- pkg/rtc/participant_internal_test.go | 10 +- pkg/rtc/room.go | 5 + pkg/rtc/types/interfaces.go | 5 +- .../types/typesfakes/fake_published_track.go | 152 ++---- pkg/rtc/uptrackmanager.go | 434 +++--------------- pkg/sfu/receiver.go | 14 +- 12 files changed, 1007 insertions(+), 866 deletions(-) create mode 100644 pkg/rtc/localparticipant.go create mode 100644 pkg/rtc/mediatrackreceiver.go diff --git a/pkg/rtc/localparticipant.go b/pkg/rtc/localparticipant.go new file mode 100644 index 000000000..29d93b9f3 --- /dev/null +++ b/pkg/rtc/localparticipant.go @@ -0,0 +1,395 @@ +package rtc + +import ( + "errors" + "sync" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/twcc" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +type pendingTrackInfo struct { + *livekit.TrackInfo + migrated bool +} + +type LocalParticipantParams struct { + Identity livekit.ParticipantIdentity + SID livekit.ParticipantID + Config *WebRTCConfig + AudioConfig config.AudioConfig + Telemetry telemetry.TelemetryService + ThrottleConfig config.PLIThrottleConfig + Logger logger.Logger +} + +type LocalParticipant struct { + params LocalParticipantParams + rtcpCh chan []rtcp.Packet + pliThrottle *pliThrottle + + // hold reference for MediaTrack + twcc *twcc.Responder + + // client intended to publish, yet to be reconciled + pendingTracksLock sync.RWMutex + pendingTracks map[string]*pendingTrackInfo + + *UptrackManager + + // callbacks & handlers + onWriteRTCP func(pkts []rtcp.Packet) + onTrackPublished func(track types.PublishedTrack) +} + +func NewLocalParticipant(params LocalParticipantParams) *LocalParticipant { + l := &LocalParticipant{ + params: params, + rtcpCh: make(chan []rtcp.Packet, 50), + pliThrottle: newPLIThrottle(params.ThrottleConfig), + pendingTracks: make(map[string]*pendingTrackInfo), + } + + l.setupUptrackManager() + + return l +} + +func (l *LocalParticipant) Start() { + l.UptrackManager.Start() + go l.rtcpSendWorker() +} + +func (l *LocalParticipant) Close() { + l.UptrackManager.Close() + + l.pendingTracksLock.Lock() + l.pendingTracks = make(map[string]*pendingTrackInfo) + l.pendingTracksLock.Unlock() +} + +func (l *LocalParticipant) OnWriteRTCP(f func(pkts []rtcp.Packet)) { + l.onWriteRTCP = f +} + +func (l *LocalParticipant) OnTrackPublished(f func(track types.PublishedTrack)) { + l.onTrackPublished = f +} + +// AddTrack is called when client intends to publish track. +// records track details and lets client know it's ok to proceed +func (l *LocalParticipant) AddTrack(req *livekit.AddTrackRequest) *livekit.TrackInfo { + l.pendingTracksLock.Lock() + defer l.pendingTracksLock.Unlock() + + // if track is already published, reject + if l.pendingTracks[req.Cid] != nil { + return nil + } + + if l.UptrackManager.GetPublishedTrackBySignalCidOrSdpCid(req.Cid) != nil { + return nil + } + + ti := &livekit.TrackInfo{ + Type: req.Type, + Name: req.Name, + Sid: utils.NewGuid(utils.TrackPrefix), + Width: req.Width, + Height: req.Height, + Muted: req.Muted, + DisableDtx: req.DisableDtx, + Source: req.Source, + Layers: req.Layers, + } + l.pendingTracks[req.Cid] = &pendingTrackInfo{TrackInfo: ti} + + return ti +} + +func (l *LocalParticipant) AddMigratedTrack(cid string, ti *livekit.TrackInfo) { + l.pendingTracksLock.Lock() + defer l.pendingTracksLock.Unlock() + + l.pendingTracks[cid] = &pendingTrackInfo{ti, true} +} + +func (l *LocalParticipant) SetTrackMuted(trackID livekit.TrackID, muted bool) { + track := l.UptrackManager.SetTrackMuted(trackID, muted) + if track != nil { + // handled in UptrackManager for a published track, no need to update state of pending track + return + } + + isPending := false + l.pendingTracksLock.RLock() + for _, ti := range l.pendingTracks { + if livekit.TrackID(ti.Sid) == trackID { + ti.Muted = muted + isPending = true + break + } + } + l.pendingTracksLock.RUnlock() + + if !isPending { + l.params.Logger.Warnw("could not locate track", nil, "track", trackID) + } +} + +func (l *LocalParticipant) GetAudioLevel() (level uint8, active bool) { + level = SilentAudioLevel + for _, pt := range l.UptrackManager.GetPublishedTracks() { + tl, ta := pt.GetAudioLevel() + if ta { + active = true + if tl < level { + level = tl + } + } + } + return +} + +func (l *LocalParticipant) GetConnectionQuality() (scores float64, numTracks int) { + for _, pt := range l.UptrackManager.GetPublishedTracks() { + if pt.IsMuted() { + continue + } + scores += pt.GetConnectionScore() + numTracks++ + } + + return +} + +func (l *LocalParticipant) GetDTX() bool { + l.pendingTracksLock.RLock() + defer l.pendingTracksLock.RUnlock() + + // + // Although DTX is set per track, there are cases where + // pending track has to be looked up by kind. This happens + // when clients change track id between signalling and SDP. + // In that case, look at all pending tracks by kind and + // enable DTX even if one has it enabled. + // + // Most of the time in practice, there is going to be one + // audio kind track and hence this is fine. + // + for _, ti := range l.pendingTracks { + if ti.Type == livekit.TrackType_AUDIO { + if !ti.TrackInfo.DisableDtx { + return true + } + } + } + + return false +} + +func (l *LocalParticipant) MediaTrackReceived(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver, mid string) (types.PublishedTrack, bool) { + l.pendingTracksLock.Lock() + newTrack := false + + // use existing mediatrack to handle simulcast + mt, ok := l.UptrackManager.GetPublishedTrackBySdpCid(track.ID()).(*MediaTrack) + if !ok { + signalCid, ti := l.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + if ti == nil { + l.pendingTracksLock.Unlock() + return nil, false + } + + ti.MimeType = track.Codec().MimeType + ti.Mid = mid + + mt = NewMediaTrack(track, MediaTrackParams{ + TrackInfo: ti, + SignalCid: signalCid, + SdpCid: track.ID(), + ParticipantID: l.params.SID, + ParticipantIdentity: l.params.Identity, + RTCPChan: l.rtcpCh, + BufferFactory: l.params.Config.BufferFactory, + ReceiverConfig: l.params.Config.Receiver, + AudioConfig: l.params.AudioConfig, + Telemetry: l.params.Telemetry, + Logger: l.params.Logger, + SubscriberConfig: l.params.Config.Subscriber, + }) + + // add to published and clean up pending + l.UptrackManager.AddPublishedTrack(mt) + delete(l.pendingTracks, signalCid) + + newTrack = true + } + + ssrc := uint32(track.SSRC()) + l.pliThrottle.addTrack(ssrc, track.RID()) + if l.twcc == nil { + l.twcc = twcc.NewTransportWideCCResponder(ssrc) + l.twcc.OnFeedback(func(pkt rtcp.RawPacket) { + if l.onWriteRTCP != nil { + l.onWriteRTCP([]rtcp.Packet{&pkt}) + } + }) + } + l.pendingTracksLock.Unlock() + + mt.AddReceiver(rtpReceiver, track, l.twcc) + + if newTrack { + l.handleTrackPublished(mt) + } + + return mt, newTrack +} + +func (l *LocalParticipant) handleTrackPublished(track types.PublishedTrack) { + if l.onTrackPublished != nil { + l.onTrackPublished(track) + } +} + +func (l *LocalParticipant) UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error { + track := l.UptrackManager.GetPublishedTrack(trackID) + if track == nil { + l.params.Logger.Warnw("could not find track", nil, "trackID", trackID) + return errors.New("could not find track") + } + + if mt, ok := track.(*MediaTrack); ok { + mt.NotifySubscriberNodeMaxQuality(nodeID, maxQuality) + } + + return nil +} + +func (l *LocalParticipant) UpdateMediaLoss(nodeID string, trackID livekit.TrackID, fractionalLoss uint32) error { + track := l.UptrackManager.GetPublishedTrack(trackID) + if track == nil { + l.params.Logger.Warnw("could not find track", nil, "trackID", trackID) + return errors.New("could not find track") + } + + if mt, ok := track.(*MediaTrack); ok { + mt.NotifySubscriberNodeMediaLoss(nodeID, uint8(fractionalLoss)) + } + + return nil +} + +func (l *LocalParticipant) HasPendingMigratedTrack() bool { + l.pendingTracksLock.RLock() + defer l.pendingTracksLock.RUnlock() + + for _, t := range l.pendingTracks { + if t.migrated { + return true + } + } + + return false +} + +func (l *LocalParticipant) DebugInfo() map[string]interface{} { + info := map[string]interface{}{} + pendingTrackInfo := make(map[string]interface{}) + + l.pendingTracksLock.RLock() + for clientID, ti := range l.pendingTracks { + pendingTrackInfo[clientID] = map[string]interface{}{ + "Sid": ti.Sid, + "Type": ti.Type.String(), + "Simulcast": ti.Simulcast, + } + } + l.pendingTracksLock.RUnlock() + + info["PendingTracks"] = pendingTrackInfo + info["UptrackManager"] = l.UptrackManager.DebugInfo() + + return info +} + +func (l *LocalParticipant) setupUptrackManager() { + l.UptrackManager = NewUptrackManager(UptrackManagerParams{ + SID: l.params.SID, + Logger: l.params.Logger, + }) + + l.UptrackManager.OnClose(l.onUptrackManagerClose) +} + +func (l *LocalParticipant) onUptrackManagerClose() { + close(l.rtcpCh) +} + +func (l *LocalParticipant) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { + signalCid := clientId + trackInfo := l.pendingTracks[clientId] + + if trackInfo == nil { + // + // If no match on client id, find first one matching type + // as MediaStreamTrack can change client id when transceiver + // is added to peer connection. + // + for cid, ti := range l.pendingTracks { + if ti.Type == kind { + trackInfo = ti + signalCid = cid + break + } + } + } + + // if still not found, we are done + if trackInfo == nil { + l.params.Logger.Errorw("track info not published prior to track", nil, "clientId", clientId) + } + return signalCid, trackInfo.TrackInfo +} + +func (l *LocalParticipant) rtcpSendWorker() { + defer Recover() + + // read from rtcpChan + for pkts := range l.rtcpCh { + if pkts == nil { + return + } + + fwdPkts := make([]rtcp.Packet, 0, len(pkts)) + for _, pkt := range pkts { + switch pkt.(type) { + case *rtcp.PictureLossIndication: + mediaSSRC := pkt.(*rtcp.PictureLossIndication).MediaSSRC + if l.pliThrottle.canSend(mediaSSRC) { + fwdPkts = append(fwdPkts, pkt) + } + case *rtcp.FullIntraRequest: + mediaSSRC := pkt.(*rtcp.FullIntraRequest).MediaSSRC + if l.pliThrottle.canSend(mediaSSRC) { + fwdPkts = append(fwdPkts, pkt) + } + default: + fwdPkts = append(fwdPkts, pkt) + } + } + + if len(fwdPkts) > 0 && l.onWriteRTCP != nil { + l.onWriteRTCP(fwdPkts) + } + } +} diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index bc3e06d69..ff3297d89 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -2,18 +2,14 @@ package rtc import ( "context" - "errors" - "sort" "sync" "sync/atomic" "time" - "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" @@ -25,38 +21,21 @@ import ( ) const ( - lostUpdateDelta = time.Second + upLostUpdateDelta = time.Second connectionQualityUpdateInterval = 5 * time.Second - layerSelectionTolerance = 0.9 ) // MediaTrack represents a WebRTC track that needs to be forwarded // Implements MediaTrack and PublishedTrack interface type MediaTrack struct { params MediaTrackParams - ssrc webrtc.SSRC - streamID string - codec webrtc.RTPCodecParameters - muted utils.AtomicFlag numUpTracks uint32 - simulcasted utils.AtomicFlag buffer *buffer.Buffer - lock sync.RWMutex - - twcc *twcc.Responder - audioLevelMu sync.RWMutex audioLevel *AudioLevel - receiver sfu.Receiver - layerDimensions sync.Map // livekit.VideoQuality => *livekit.VideoLayer - layerSsrcs [livekit.VideoQuality_HIGH + 1]uint32 - - // track audio fraction lost statsLock sync.Mutex - maxDownFracLost uint8 - maxDownFracLostTs time.Time currentUpFracLost uint32 maxUpFracLost uint8 maxUpFracLostTs time.Time @@ -64,9 +43,9 @@ type MediaTrack struct { done chan struct{} - onClose []func() + *MediaTrackReceiver - *MediaTrackSubscriptions + onMediaLossUpdate func(trackID livekit.TrackID, fractionalLoss uint32) } type MediaTrackParams struct { @@ -76,43 +55,41 @@ type MediaTrackParams struct { ParticipantID livekit.ParticipantID ParticipantIdentity livekit.ParticipantIdentity // channel to send RTCP packets to the source - RTCPChan chan []rtcp.Packet - BufferFactory *buffer.Factory - ReceiverConfig ReceiverConfig - AudioConfig config.AudioConfig - Telemetry telemetry.TelemetryService - Logger logger.Logger - + RTCPChan chan []rtcp.Packet + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig SubscriberConfig DirectionConfig + AudioConfig config.AudioConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger } func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTrack { t := &MediaTrack{ params: params, - ssrc: track.SSRC(), - streamID: track.StreamID(), - codec: track.Codec(), connectionStats: connectionquality.NewConnectionStats(), done: make(chan struct{}), } - t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ - MediaTrack: t, - BufferFactory: params.BufferFactory, - ReceiverConfig: params.ReceiverConfig, - SubscriberConfig: params.SubscriberConfig, - Telemetry: params.Telemetry, - Logger: params.Logger, + t.MediaTrackReceiver = NewMediaTrackReceiver(MediaTrackReceiverParams{ + TrackInfo: params.TrackInfo, + MediaTrack: t, + ParticipantID: params.ParticipantID, + ParticipantIdentity: params.ParticipantIdentity, + BufferFactory: params.BufferFactory, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + AudioConfig: params.AudioConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + }) + t.MediaTrackReceiver.OnMediaLossUpdate(func(fractionalLoss uint8) { + if t.buffer != nil { + // ok to access buffer since receivers are added before subscribers + t.buffer.SetLastFractionLostReport(fractionalLoss) + } }) - if params.TrackInfo.Muted { - t.SetMuted(true) - } - - if params.TrackInfo != nil && t.Kind() == livekit.TrackType_VIDEO { - t.UpdateVideoLayers(params.TrackInfo.Layers) - // LK-TODO: maybe use this or simulcast flag in TrackInfo to set simulcasted here - } // on close signal via closing channel to workers t.AddOnClose(t.closeChan) go t.updateStats() @@ -120,14 +97,8 @@ func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTra return t } -func (t *MediaTrack) TrySetSimulcastSSRC(layer uint8, ssrc uint32) { - if int(layer) < len(t.layerSsrcs) && t.layerSsrcs[layer] == 0 { - t.layerSsrcs[layer] = ssrc - } -} - -func (t *MediaTrack) ID() livekit.TrackID { - return livekit.TrackID(t.params.TrackInfo.Sid) +func (t *MediaTrack) OnMediaLossUpdate(f func(trackID livekit.TrackID, fractionalLoss uint32)) { + t.onMediaLossUpdate = f } func (t *MediaTrack) SignalCid() string { @@ -138,91 +109,11 @@ func (t *MediaTrack) SdpCid() string { return t.params.SdpCid } -func (t *MediaTrack) Kind() livekit.TrackType { - return t.params.TrackInfo.Type -} - -func (t *MediaTrack) Source() livekit.TrackSource { - return t.params.TrackInfo.Source -} - -func (t *MediaTrack) PublisherID() livekit.ParticipantID { - return t.params.ParticipantID -} - -func (t *MediaTrack) PublisherIdentity() livekit.ParticipantIdentity { - return t.params.ParticipantIdentity -} - -func (t *MediaTrack) IsSimulcast() bool { - return t.simulcasted.Get() -} - -func (t *MediaTrack) Name() string { - return t.params.TrackInfo.Name -} - -func (t *MediaTrack) IsMuted() bool { - return t.muted.Get() -} - -func (t *MediaTrack) SetMuted(muted bool) { - t.muted.TrySet(muted) - - t.lock.RLock() - if t.receiver != nil { - t.receiver.SetUpTrackPaused(muted) - } - t.lock.RUnlock() - - t.MediaTrackSubscriptions.SetMuted(muted) -} - -func (t *MediaTrack) AddOnClose(f func()) { - if f == nil { - return - } - t.onClose = append(t.onClose, f) -} - -func (t *MediaTrack) PublishLossPercentage() uint32 { +func (t *MediaTrack) publishLossPercentage() uint32 { return FixedPointToPercent(uint8(atomic.LoadUint32(&t.currentUpFracLost))) } -// AddSubscriber subscribes sub to current mediaTrack -func (t *MediaTrack) AddSubscriber(sub types.Participant) error { - t.lock.Lock() - defer t.lock.Unlock() - - if t.receiver == nil { - // cannot add, no receiver - return errors.New("cannot subscribe without a receiver in place") - } - - // using DownTrack from ion-sfu - streamId := string(t.params.ParticipantID) - if sub.ProtocolVersion().SupportsPackedStreamId() { - // when possible, pack both IDs in streamID to allow new streams to be generated - // react-native-webrtc still uses stream based APIs and require this - streamId = PackStreamID(t.params.ParticipantID, t.ID()) - } - - downTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, t.receiver.Codec(), NewWrappedReceiver(t.receiver, t.ID(), streamId)) - if err != nil { - return err - } - - if downTrack != nil { - if t.Kind() == livekit.TrackType_AUDIO { - downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) - } - - t.receiver.AddDownTrack(downTrack) - } - return nil -} - -func (t *MediaTrack) NumUpTracks() (uint32, uint32) { +func (t *MediaTrack) getNumUpTracks() (uint32, uint32) { numExpected := atomic.LoadUint32(&t.numUpTracks) numSubscribedLayers := t.numSubscribedLayers() @@ -230,21 +121,17 @@ func (t *MediaTrack) NumUpTracks() (uint32, uint32) { numExpected = numSubscribedLayers } - t.lock.RLock() numPublishing := uint32(0) - if t.receiver != nil { - numPublishing = uint32(t.receiver.NumAvailableSpatialLayers()) + receiver := t.Receiver() + if receiver != nil { + numPublishing = uint32(receiver.(sfu.Receiver).NumAvailableSpatialLayers()) } - t.lock.RUnlock() return numPublishing, numExpected } // AddReceiver adds a new RTP receiver to the track func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote, twcc *twcc.Responder) { - t.lock.Lock() - defer t.lock.Unlock() - buff, rtcpReader := t.params.BufferFactory.GetBufferPair(uint32(track.SSRC())) if buff == nil || rtcpReader == nil { logger.Errorw("could not retrieve buffer pair", nil, @@ -288,42 +175,39 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra } }) - if t.receiver == nil { - t.receiver = sfu.NewWebRTCReceiver(receiver, track, t.params.ParticipantID, + if t.Receiver() == nil { + wr := sfu.NewWebRTCReceiver( + receiver, + track, + t.PublisherID(), sfu.WithPliThrottle(0), sfu.WithLoadBalanceThreshold(20), - sfu.WithStreamTrackers()) - t.receiver.SetRTCPCh(t.params.RTCPChan) - t.receiver.OnCloseHandler(func() { + sfu.WithStreamTrackers(), + ) + wr.SetRTCPCh(t.params.RTCPChan) + wr.OnCloseHandler(func() { t.stopMaxQualityTimer() - - t.lock.Lock() - t.receiver = nil - onclose := t.onClose - t.lock.Unlock() - t.RemoveAllSubscribers() - t.params.Telemetry.TrackUnpublished(context.Background(), t.params.ParticipantID, t.ToProto(), uint32(track.SSRC())) - for _, f := range onclose { - f() - } + t.MediaTrackReceiver.Close() + t.params.Telemetry.TrackUnpublished(context.Background(), t.PublisherID(), t.ToProto(), uint32(track.SSRC())) }) - t.params.Telemetry.TrackPublished(context.Background(), t.params.ParticipantID, t.ToProto()) + t.params.Telemetry.TrackPublished(context.Background(), t.PublisherID(), t.ToProto()) if t.Kind() == livekit.TrackType_AUDIO { t.buffer = buff } + t.MediaTrackReceiver.SetupReceiver(wr) t.startMaxQualityTimer() } - t.receiver.AddUpTrack(track, buff) - t.params.Telemetry.AddUpTrack(t.params.ParticipantID, t.ID(), buff) + t.Receiver().(sfu.Receiver).AddUpTrack(track, buff) + t.params.Telemetry.AddUpTrack(t.PublisherID(), t.ID(), buff) atomic.AddUint32(&t.numUpTracks, 1) // LK-TODO: can remove this completely when VideoLayers protocol becomes the default as it has info from client or if we decide to use TrackInfo.Simulcast if atomic.LoadUint32(&t.numUpTracks) > 1 || track.RID() != "" { // cannot only rely on numUpTracks since we fire metadata events immediately after the first layer - t.simulcasted.TrySet(true) + t.MediaTrackReceiver.SetSimulcast(true) } if t.IsSimulcast() { @@ -338,25 +222,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra }) } -func (t *MediaTrack) ToProto() *livekit.TrackInfo { - info := t.params.TrackInfo - info.Muted = t.IsMuted() - info.Simulcast = t.simulcasted.Get() - layers := make([]*livekit.VideoLayer, 0) - t.layerDimensions.Range(func(_, val interface{}) bool { - if layer, ok := val.(*livekit.VideoLayer); ok { - if int(layer.Quality) < len(t.layerSsrcs) { - layer.Ssrc = t.layerSsrcs[layer.Quality] - } - layers = append(layers, layer) - } - return true - }) - info.Layers = layers - - return info -} - func (t *MediaTrack) GetAudioLevel() (level uint8, active bool) { t.audioLevelMu.RLock() defer t.audioLevelMu.RUnlock() @@ -367,61 +232,6 @@ func (t *MediaTrack) GetAudioLevel() (level uint8, active bool) { return t.audioLevel.GetLevel() } -func (t *MediaTrack) UpdateVideoLayers(layers []*livekit.VideoLayer) { - for _, layer := range layers { - t.layerDimensions.Store(layer.Quality, layer) - } - - t.MediaTrackSubscriptions.UpdateVideoLayers() - - // TODO: this might need to trigger a participant update for clients to pick up dimension change -} - -// GetQualityForDimension finds the closest quality to use for desired dimensions -// affords a 20% tolerance on dimension -func (t *MediaTrack) GetQualityForDimension(width, height uint32) livekit.VideoQuality { - quality := livekit.VideoQuality_HIGH - if t.Kind() == livekit.TrackType_AUDIO || t.params.TrackInfo.Height == 0 { - return quality - } - origSize := t.params.TrackInfo.Height - requestedSize := height - if t.params.TrackInfo.Width < t.params.TrackInfo.Height { - // for portrait videos - origSize = t.params.TrackInfo.Width - requestedSize = width - } - - // default sizes representing qualities low - high - layerSizes := []uint32{180, 360, origSize} - var providedSizes []uint32 - t.layerDimensions.Range(func(_, val interface{}) bool { - if layer, ok := val.(*livekit.VideoLayer); ok { - providedSizes = append(providedSizes, layer.Height) - } - return true - }) - if len(providedSizes) > 0 { - layerSizes = providedSizes - // comparing height always - requestedSize = height - sort.Slice(layerSizes, func(i, j int) bool { - return layerSizes[i] < layerSizes[j] - }) - } - - // finds the lowest layer that could satisfy client demands - requestedSize = uint32(float32(requestedSize) * layerSelectionTolerance) - for i, s := range layerSizes { - quality = livekit.VideoQuality(i) - if s >= requestedSize { - break - } - } - - return quality -} - func (t *MediaTrack) handlePublisherFeedback(packets []rtcp.Packet) { var maxLost uint8 var hasReport bool @@ -466,7 +276,7 @@ func (t *MediaTrack) handlePublisherFeedback(packets []rtcp.Packet) { } now := time.Now() - if now.Sub(t.maxUpFracLostTs) > lostUpdateDelta { + if now.Sub(t.maxUpFracLostTs) > upLostUpdateDelta { atomic.StoreUint32(&t.currentUpFracLost, uint32(t.maxUpFracLost)) t.maxUpFracLost = 0 t.maxUpFracLostTs = now @@ -491,77 +301,6 @@ func (t *MediaTrack) handlePublisherFeedback(packets []rtcp.Packet) { t.params.RTCPChan <- packets } -// handles max loss for audio packets -func (t *MediaTrack) handleMaxLossFeedback(_ *sfu.DownTrack, report *rtcp.ReceiverReport) { - t.statsLock.Lock() - for _, rr := range report.Reports { - if t.maxDownFracLost < rr.FractionLost { - t.maxDownFracLost = rr.FractionLost - } - } - t.statsLock.Unlock() - - t.maybeUpdateLoss() -} - -func (t *MediaTrack) NotifySubscriberNodeMediaLoss(_nodeID string, fractionalLoss uint8) { - t.statsLock.Lock() - if t.maxDownFracLost < fractionalLoss { - t.maxDownFracLost = fractionalLoss - } - t.statsLock.Unlock() - - t.maybeUpdateLoss() -} - -func (t *MediaTrack) maybeUpdateLoss() { - var ( - shouldUpdate bool - maxLost uint8 - ) - - t.statsLock.Lock() - now := time.Now() - if now.Sub(t.maxDownFracLostTs) > lostUpdateDelta { - shouldUpdate = true - maxLost = t.maxDownFracLost - t.maxDownFracLost = 0 - t.maxDownFracLostTs = now - } - t.statsLock.Unlock() - - if shouldUpdate && t.buffer != nil { - // ok to access buffer since receivers are added before subscribers - t.buffer.SetLastFractionLostReport(maxLost) - } -} - -func (t *MediaTrack) DebugInfo() map[string]interface{} { - info := map[string]interface{}{ - "ID": t.ID(), - "SSRC": t.ssrc, - "Kind": t.Kind().String(), - "PubMuted": t.muted.Get(), - } - - info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() - - t.lock.RLock() - if t.receiver != nil { - receiverInfo := t.receiver.DebugInfo() - for k, v := range receiverInfo { - info[k] = v - } - } - t.lock.RUnlock() - - return info -} - -func (t *MediaTrack) Receiver() sfu.TrackReceiver { - return t.receiver -} - func (t *MediaTrack) GetConnectionScore() float64 { t.statsLock.Lock() defer t.statsLock.Unlock() @@ -591,45 +330,14 @@ func (t *MediaTrack) updateStats() { func (t *MediaTrack) calculateVideoScore() { var reducedQuality bool - publishing, expected := t.NumUpTracks() + publishing, expected := t.getNumUpTracks() if publishing < expected { reducedQuality = true } - loss := t.PublishLossPercentage() + loss := t.publishLossPercentage() if expected == 0 { loss = 0 } t.connectionStats.Score = connectionquality.Loss2Score(loss, reducedQuality) } - -func (t *MediaTrack) OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error) { - t.MediaTrackSubscriptions.OnSubscribedMaxQualityChange(func(subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) { - if f != nil && !t.IsMuted() { - _ = f(t.ID(), subscribedQualities) - } - - t.lock.RLock() - if t.receiver != nil { - t.receiver.SetMaxExpectedSpatialLayer(SpatialLayerForQuality(maxSubscribedQuality)) - } - t.lock.RUnlock() - }) -} - -//--------------------------- - -func SpatialLayerForQuality(quality livekit.VideoQuality) int32 { - switch quality { - case livekit.VideoQuality_LOW: - return 0 - case livekit.VideoQuality_MEDIUM: - return 1 - case livekit.VideoQuality_HIGH: - return 2 - case livekit.VideoQuality_OFF: - return -1 - default: - return -1 - } -} diff --git a/pkg/rtc/mediatrack_test.go b/pkg/rtc/mediatrack_test.go index f0265ef5e..d22330f5c 100644 --- a/pkg/rtc/mediatrack_test.go +++ b/pkg/rtc/mediatrack_test.go @@ -130,7 +130,7 @@ func TestSubscribedMaxQuality(t *testing.T) { actualTrackID := livekit.TrackID("") actualSubscribedQualities := []*livekit.SubscribedQuality{} - mt.OnSubscribedMaxQualityChange(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error { + mt.OnSubscribedMaxQualityChange(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, _maxSubscribedQuality livekit.VideoQuality) error { actualTrackID = trackID actualSubscribedQualities = subscribedQualities return nil @@ -175,7 +175,7 @@ func TestSubscribedMaxQuality(t *testing.T) { actualTrackID := livekit.TrackID("") actualSubscribedQualities := []*livekit.SubscribedQuality{} - mt.OnSubscribedMaxQualityChange(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error { + mt.OnSubscribedMaxQualityChange(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, _maxSubscribedQuality livekit.VideoQuality) error { actualTrackID = trackID actualSubscribedQualities = subscribedQualities return nil diff --git a/pkg/rtc/mediatrackreceiver.go b/pkg/rtc/mediatrackreceiver.go new file mode 100644 index 000000000..a92c74cd1 --- /dev/null +++ b/pkg/rtc/mediatrackreceiver.go @@ -0,0 +1,383 @@ +package rtc + +import ( + "errors" + "sort" + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/rtc/types" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/pion/rtcp" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +const ( + downLostUpdateDelta = time.Second + layerSelectionTolerance = 0.9 +) + +type MediaTrackReceiver struct { + params MediaTrackReceiverParams + muted utils.AtomicFlag + simulcasted utils.AtomicFlag + + lock sync.RWMutex + receiver sfu.TrackReceiver + layerDimensions sync.Map // livekit.VideoQuality => *livekit.VideoLayer + layerSsrcs [livekit.VideoQuality_HIGH + 1]uint32 + + // track audio fraction lost + downFracLostLock sync.Mutex + maxDownFracLost uint8 + maxDownFracLostTs time.Time + onMediaLossUpdate func(fractionalLoss uint8) + + onClose []func() + + *MediaTrackSubscriptions +} + +type MediaTrackReceiverParams struct { + TrackInfo *livekit.TrackInfo + MediaTrack types.MediaTrack + ParticipantID livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + AudioConfig config.AudioConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger +} + +func NewMediaTrackReceiver(params MediaTrackReceiverParams) *MediaTrackReceiver { + t := &MediaTrackReceiver{ + params: params, + } + + t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ + MediaTrack: params.MediaTrack, + BufferFactory: params.BufferFactory, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + }) + + if params.TrackInfo.Muted { + t.SetMuted(true) + } + + if params.TrackInfo != nil && t.Kind() == livekit.TrackType_VIDEO { + t.UpdateVideoLayers(params.TrackInfo.Layers) + // LK-TODO: maybe use this or simulcast flag in TrackInfo to set simulcasted here + } + + return t +} + +func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver) { + t.lock.Lock() + defer t.lock.Unlock() + + t.receiver = receiver +} + +func (t *MediaTrackReceiver) OnMediaLossUpdate(f func(fractionalLoss uint8)) { + t.onMediaLossUpdate = f +} + +func (t *MediaTrackReceiver) Close() { + t.lock.Lock() + t.receiver = nil + onclose := t.onClose + t.lock.Unlock() + + for _, f := range onclose { + f() + } +} + +func (t *MediaTrackReceiver) ID() livekit.TrackID { + return livekit.TrackID(t.params.TrackInfo.Sid) +} + +func (t *MediaTrackReceiver) Kind() livekit.TrackType { + return t.params.TrackInfo.Type +} + +func (t *MediaTrackReceiver) Source() livekit.TrackSource { + return t.params.TrackInfo.Source +} + +func (t *MediaTrackReceiver) PublisherID() livekit.ParticipantID { + return t.params.ParticipantID +} + +func (t *MediaTrackReceiver) PublisherIdentity() livekit.ParticipantIdentity { + return t.params.ParticipantIdentity +} + +func (t *MediaTrackReceiver) IsSimulcast() bool { + return t.simulcasted.Get() +} + +func (t *MediaTrackReceiver) SetSimulcast(simulcast bool) { + t.simulcasted.TrySet(simulcast) +} + +func (t *MediaTrackReceiver) Name() string { + return t.params.TrackInfo.Name +} + +func (t *MediaTrackReceiver) IsMuted() bool { + return t.muted.Get() +} + +func (t *MediaTrackReceiver) SetMuted(muted bool) { + t.muted.TrySet(muted) + + t.lock.RLock() + if t.receiver != nil { + t.receiver.SetUpTrackPaused(muted) + } + t.lock.RUnlock() + + t.MediaTrackSubscriptions.SetMuted(muted) +} + +func (t *MediaTrackReceiver) AddOnClose(f func()) { + if f == nil { + return + } + + t.lock.Lock() + t.onClose = append(t.onClose, f) + t.lock.Unlock() +} + +// AddSubscriber subscribes sub to current mediaTrack +func (t *MediaTrackReceiver) AddSubscriber(sub types.Participant) error { + t.lock.Lock() + defer t.lock.Unlock() + + if t.receiver == nil { + // cannot add, no receiver + return errors.New("cannot subscribe without a receiver in place") + } + + // using DownTrack from ion-sfu + streamId := string(t.PublisherID()) + if sub.ProtocolVersion().SupportsPackedStreamId() { + // when possible, pack both IDs in streamID to allow new streams to be generated + // react-native-webrtc still uses stream based APIs and require this + streamId = PackStreamID(t.PublisherID(), t.ID()) + } + + downTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, t.receiver.Codec(), NewWrappedReceiver(t.receiver, t.ID(), streamId)) + if err != nil { + return err + } + + if downTrack != nil { + if t.Kind() == livekit.TrackType_AUDIO { + downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) + } + + t.receiver.AddDownTrack(downTrack) + } + return nil +} + +func (t *MediaTrackReceiver) ToProto() *livekit.TrackInfo { + info := t.params.TrackInfo + info.Muted = t.IsMuted() + info.Simulcast = t.IsSimulcast() + layers := make([]*livekit.VideoLayer, 0) + t.layerDimensions.Range(func(_, val interface{}) bool { + if layer, ok := val.(*livekit.VideoLayer); ok { + if int(layer.Quality) < len(t.layerSsrcs) { + layer.Ssrc = t.layerSsrcs[layer.Quality] + } + layers = append(layers, layer) + } + return true + }) + info.Layers = layers + + return info +} + +func (t *MediaTrackReceiver) UpdateVideoLayers(layers []*livekit.VideoLayer) { + for _, layer := range layers { + t.layerDimensions.Store(layer.Quality, layer) + } + + t.MediaTrackSubscriptions.UpdateVideoLayers() + + // TODO: this might need to trigger a participant update for clients to pick up dimension change +} + +func (t *MediaTrackReceiver) TrySetSimulcastSSRC(layer uint8, ssrc uint32) { + if int(layer) < len(t.layerSsrcs) && t.layerSsrcs[layer] == 0 { + t.layerSsrcs[layer] = ssrc + } +} + +// GetQualityForDimension finds the closest quality to use for desired dimensions +// affords a 20% tolerance on dimension +func (t *MediaTrackReceiver) GetQualityForDimension(width, height uint32) livekit.VideoQuality { + quality := livekit.VideoQuality_HIGH + if t.Kind() == livekit.TrackType_AUDIO || t.params.TrackInfo.Height == 0 { + return quality + } + origSize := t.params.TrackInfo.Height + requestedSize := height + if t.params.TrackInfo.Width < t.params.TrackInfo.Height { + // for portrait videos + origSize = t.params.TrackInfo.Width + requestedSize = width + } + + // default sizes representing qualities low - high + layerSizes := []uint32{180, 360, origSize} + var providedSizes []uint32 + t.layerDimensions.Range(func(_, val interface{}) bool { + if layer, ok := val.(*livekit.VideoLayer); ok { + providedSizes = append(providedSizes, layer.Height) + } + return true + }) + if len(providedSizes) > 0 { + layerSizes = providedSizes + // comparing height always + requestedSize = height + sort.Slice(layerSizes, func(i, j int) bool { + return layerSizes[i] < layerSizes[j] + }) + } + + // finds the lowest layer that could satisfy client demands + requestedSize = uint32(float32(requestedSize) * layerSelectionTolerance) + for i, s := range layerSizes { + quality = livekit.VideoQuality(i) + if s >= requestedSize { + break + } + } + + return quality +} + +// handles max loss for audio packets +func (t *MediaTrackReceiver) handleMaxLossFeedback(_ *sfu.DownTrack, report *rtcp.ReceiverReport) { + t.downFracLostLock.Lock() + for _, rr := range report.Reports { + if t.maxDownFracLost < rr.FractionLost { + t.maxDownFracLost = rr.FractionLost + } + } + t.downFracLostLock.Unlock() + + t.maybeUpdateLoss() +} + +func (t *MediaTrackReceiver) NotifySubscriberNodeMediaLoss(_nodeID string, fractionalLoss uint8) { + t.downFracLostLock.Lock() + if t.maxDownFracLost < fractionalLoss { + t.maxDownFracLost = fractionalLoss + } + t.downFracLostLock.Unlock() + + t.maybeUpdateLoss() +} + +func (t *MediaTrackReceiver) maybeUpdateLoss() { + var ( + shouldUpdate bool + maxLost uint8 + ) + + t.downFracLostLock.Lock() + now := time.Now() + if now.Sub(t.maxDownFracLostTs) > downLostUpdateDelta { + shouldUpdate = true + maxLost = t.maxDownFracLost + t.maxDownFracLost = 0 + t.maxDownFracLostTs = now + } + t.downFracLostLock.Unlock() + + if shouldUpdate { + if t.onMediaLossUpdate != nil { + t.onMediaLossUpdate(maxLost) + } + } +} + +func (t *MediaTrackReceiver) DebugInfo() map[string]interface{} { + info := map[string]interface{}{ + "ID": t.ID(), + "Kind": t.Kind().String(), + "PubMuted": t.muted.Get(), + } + + info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() + + t.lock.RLock() + if t.receiver != nil { + receiverInfo := t.receiver.(sfu.Receiver).DebugInfo() + for k, v := range receiverInfo { + info[k] = v + } + } + t.lock.RUnlock() + + return info +} + +func (t *MediaTrackReceiver) Receiver() sfu.TrackReceiver { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.receiver +} + +func (t *MediaTrackReceiver) OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) error) { + t.MediaTrackSubscriptions.OnSubscribedMaxQualityChange(func(subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) { + if f != nil && !t.IsMuted() { + _ = f(t.ID(), subscribedQualities, maxSubscribedQuality) + } + + t.lock.RLock() + if t.receiver != nil { + t.receiver.SetMaxExpectedSpatialLayer(SpatialLayerForQuality(maxSubscribedQuality)) + } + t.lock.RUnlock() + }) +} + +//--------------------------- + +func SpatialLayerForQuality(quality livekit.VideoQuality) int32 { + switch quality { + case livekit.VideoQuality_LOW: + return 0 + case livekit.VideoQuality_MEDIUM: + return 1 + case livekit.VideoQuality_HIGH: + return 2 + case livekit.VideoQuality_OFF: + return -1 + default: + return -1 + } +} diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index 9df177162..4c8fef3c7 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -193,6 +193,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web t.subscribedTracksMu.Unlock() t.maybeNotifyNoSubscribers() + t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) // ignore if the subscribing sub is not connected @@ -251,21 +252,31 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web // stop all forwarders to the client func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.ParticipantID) { subTrack := t.getSubscribedTrack(participantID) + + t.subscribedTracksMu.Lock() + delete(t.subscribedTracks, participantID) + t.subscribedTracksMu.Unlock() + if subTrack != nil { go subTrack.DownTrack().Close() } + + t.maybeNotifyNoSubscribers() } func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { t.params.Logger.Debugw("removing all subscribers", "track", t.params.MediaTrack.ID()) - t.subscribedTracksMu.RLock() + t.subscribedTracksMu.Lock() subscribedTracks := t.subscribedTracks - t.subscribedTracksMu.RUnlock() + t.subscribedTracks = make(map[livekit.ParticipantID]types.SubscribedTrack) + t.subscribedTracksMu.Unlock() for _, subTrack := range subscribedTracks { go subTrack.DownTrack().Close() } + + t.maybeNotifyNoSubscribers() } func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index dcf3c48c2..79798f1b9 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -23,7 +23,6 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" - "github.com/livekit/livekit-server/pkg/sfu/twcc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" "github.com/livekit/livekit-server/version" @@ -74,10 +73,7 @@ type ParticipantImpl struct { // JSON encoded metadata to pass to clients metadata string - // hold reference for MediaTrack - twcc *twcc.Responder - - *UptrackManager + *LocalParticipant // tracks the current participant is subscribed to, map of sid => DownTrack subscribedTracks map[livekit.TrackID]types.SubscribedTrack @@ -189,7 +185,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.subscriber.OnStreamStateChange(p.onStreamStateChange) - p.setupUptrackManager() + p.setupLocalParticipant() return p, nil } @@ -243,7 +239,7 @@ func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { Hidden: p.Hidden(), Recorder: p.IsRecorder(), } - info.Tracks = p.UptrackManager.ToProto() + info.Tracks = p.LocalParticipant.ToProto() return info } @@ -341,7 +337,7 @@ func (p *ParticipantImpl) HandleOffer(sdp webrtc.SessionDescription) (answer web } func (p *ParticipantImpl) AddMigratedTrack(cid string, ti *livekit.TrackInfo) { - p.UptrackManager.AddMigratedTrack(cid, ti) + p.LocalParticipant.AddMigratedTrack(cid, ti) } // AddTrack is called when client intends to publish track. @@ -355,7 +351,7 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { return } - ti := p.UptrackManager.AddTrack(req) + ti := p.LocalParticipant.AddTrack(req) if ti == nil { return } @@ -398,7 +394,7 @@ func (p *ParticipantImpl) AddICECandidate(candidate webrtc.ICECandidateInit, tar func (p *ParticipantImpl) Start() { p.once.Do(func() { - p.UptrackManager.Start() + p.LocalParticipant.Start() go p.downTracksRTCPWorker() }) } @@ -416,7 +412,7 @@ func (p *ParticipantImpl) Close() error { }, }) - p.UptrackManager.Close() + p.LocalParticipant.Close() p.lock.Lock() disallowedSubscriptions := make(map[livekit.TrackID]livekit.ParticipantID) @@ -471,7 +467,7 @@ func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { var pendingOffer *webrtc.SessionDescription p.migrateState.Store(s) if s == types.MigrateStateSync { - if !p.UptrackManager.HasPendingMigratedTrack() { + if !p.LocalParticipant.HasPendingMigratedTrack() { p.migrateState.Store(types.MigrateComplete) } pendingOffer = p.pendingOffer @@ -636,12 +632,12 @@ func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fro }) } - p.UptrackManager.SetTrackMuted(trackID, muted) + p.LocalParticipant.SetTrackMuted(trackID, muted) } func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo { // avg loss across all tracks, weigh published the same as subscribed - scores, numTracks := p.UptrackManager.GetConnectionQuality() + scores, numTracks := p.LocalParticipant.GetConnectionQuality() p.lock.RLock() for _, subTrack := range p.subscribedTracks { @@ -812,8 +808,8 @@ func (p *ParticipantImpl) SubscriptionPermissionUpdate(publisherID livekit.Parti } } -func (p *ParticipantImpl) setupUptrackManager() { - p.UptrackManager = NewUptrackManager(UptrackManagerParams{ +func (p *ParticipantImpl) setupLocalParticipant() { + p.LocalParticipant = NewLocalParticipant(LocalParticipantParams{ Identity: p.params.Identity, SID: p.params.SID, Config: p.params.Config, @@ -823,16 +819,17 @@ func (p *ParticipantImpl) setupUptrackManager() { Logger: p.params.Logger, }) - p.UptrackManager.OnTrackPublished(func(track types.PublishedTrack) { - if !p.UptrackManager.HasPendingMigratedTrack() { + p.LocalParticipant.OnTrackPublished(func(track types.PublishedTrack) { + if !p.LocalParticipant.HasPendingMigratedTrack() { p.SetMigrateState(types.MigrateComplete) } + if p.onTrackPublished != nil { p.onTrackPublished(p, track) } }) - p.UptrackManager.OnTrackUpdated(func(track types.PublishedTrack, onlyIfReady bool) { + p.LocalParticipant.OnTrackUpdated(func(track types.PublishedTrack, onlyIfReady bool) { if onlyIfReady && !p.IsReady() { return } @@ -842,13 +839,13 @@ func (p *ParticipantImpl) setupUptrackManager() { } }) - p.UptrackManager.OnWriteRTCP(func(pkts []rtcp.Packet) { + p.LocalParticipant.OnWriteRTCP(func(pkts []rtcp.Packet) { if err := p.publisher.pc.WriteRTCP(pkts); err != nil { p.params.Logger.Errorw("could not write RTCP to participant", err) } }) - p.UptrackManager.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) + p.LocalParticipant.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) } func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { @@ -939,7 +936,18 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w return } - p.UptrackManager.MediaTrackReceived(track, rtpReceiver, p) + var mid string + for _, tr := range p.publisher.pc.GetTransceivers() { + if tr.Receiver() == rtpReceiver { + mid = tr.Mid() + break + } + } + + publishedTrack, isNewTrack := p.LocalParticipant.MediaTrackReceived(track, rtpReceiver, mid) + if !isNewTrack && publishedTrack != nil && p.IsReady() && p.onTrackUpdated != nil { + p.onTrackUpdated(p, publishedTrack) + } } func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { @@ -1097,7 +1105,7 @@ func (p *ParticipantImpl) configureReceiverDTX() { // multiple audio tracks. At that point, there might be a need to // rely on something like order of tracks. TODO // - enableDTX := p.UptrackManager.GetDTX() + enableDTX := p.LocalParticipant.GetDTX() transceivers := p.publisher.pc.GetTransceivers() for _, transceiver := range transceivers { if transceiver.Kind() != webrtc.RTPCodecTypeAudio { @@ -1170,7 +1178,7 @@ func (p *ParticipantImpl) onStreamStateChange(update *sfu.StreamStateUpdate) err }) } -func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error { +func (p *ParticipantImpl) onSubscribedMaxQualityChange(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, _maxSubscribedQuality livekit.VideoQuality) error { if len(subscribedQualities) == 0 { return nil } @@ -1193,7 +1201,7 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { "State": p.State().String(), } - uptrackManagerInfo := p.UptrackManager.DebugInfo() + localParticipantInfo := p.LocalParticipant.DebugInfo() subscribedTrackInfo := make(map[livekit.TrackID]interface{}) p.lock.RLock() @@ -1204,7 +1212,7 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { } p.lock.RUnlock() - info["UptrackManager"] = uptrackManagerInfo + info["LocalParticipant"] = localParticipantInfo info["SubscribedTracks"] = subscribedTrackInfo return info diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 315a1a3c7..126213992 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -80,7 +80,8 @@ func TestTrackPublishing(t *testing.T) { p.OnTrackPublished(func(p types.Participant, track types.PublishedTrack) { published = true }) - p.UptrackManager.handleTrackPublished(track) + p.LocalParticipant.AddPublishedTrack(track) + p.LocalParticipant.handleTrackPublished(track) require.True(t, published) require.False(t, updated) @@ -202,7 +203,8 @@ func TestDisconnectTiming(t *testing.T) { } }() track := &typesfakes.FakePublishedTrack{} - p.UptrackManager.handleTrackPublished(track) + p.LocalParticipant.AddPublishedTrack(track) + p.LocalParticipant.handleTrackPublished(track) // close channel and then try to Negotiate msg.Close() @@ -220,7 +222,7 @@ func TestMuteSetting(t *testing.T) { t.Run("can set mute when track is pending", func(t *testing.T) { p := newParticipantForTest("test") ti := &livekit.TrackInfo{Sid: "testTrack"} - p.UptrackManager.pendingTracks["cid"] = &pendingTrackInfo{TrackInfo: ti} + p.LocalParticipant.pendingTracks["cid"] = &pendingTrackInfo{TrackInfo: ti} p.SetTrackMuted(livekit.TrackID(ti.Sid), true, false) require.True(t, ti.Muted) @@ -234,7 +236,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.UptrackManager.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti := p.LocalParticipant.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 0886cb8b3..1f4b89fd2 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -635,10 +635,15 @@ func (r *Room) subscribeToExistingTracks(p types.Participant) { // broadcast an update about participant p func (r *Room) broadcastParticipantState(p types.Participant, skipSource bool) { + // + // This is a critical section to ensure that participant update time and + // the corresponding data are paired properly. + // r.lock.Lock() updatedAt := time.Now() updates := ToProtoParticipants([]types.Participant{p}) r.lock.Unlock() + if p.Hidden() { if !skipSource { // send update only to hidden participant diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6feb908ec..946831237 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -169,9 +169,6 @@ type PublishedTrack interface { SdpCid() string ToProto() *livekit.TrackInfo - // returns number of uptracks that are publishing, registered - NumUpTracks() (uint32, uint32) - PublishLossPercentage() uint32 Receiver() sfu.TrackReceiver GetConnectionScore() float64 @@ -179,6 +176,8 @@ type PublishedTrack interface { UpdateVideoLayers(layers []*livekit.VideoLayer) + OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error) + // callbacks AddOnClose(func()) } diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index 084d47362..2c920c45b 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -149,27 +149,10 @@ type FakePublishedTrack struct { arg1 string arg2 uint8 } - NumUpTracksStub func() (uint32, uint32) - numUpTracksMutex sync.RWMutex - numUpTracksArgsForCall []struct { - } - numUpTracksReturns struct { - result1 uint32 - result2 uint32 - } - numUpTracksReturnsOnCall map[int]struct { - result1 uint32 - result2 uint32 - } - PublishLossPercentageStub func() uint32 - publishLossPercentageMutex sync.RWMutex - publishLossPercentageArgsForCall []struct { - } - publishLossPercentageReturns struct { - result1 uint32 - } - publishLossPercentageReturnsOnCall map[int]struct { - result1 uint32 + OnSubscribedMaxQualityChangeStub func(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error) + onSubscribedMaxQualityChangeMutex sync.RWMutex + onSubscribedMaxQualityChangeArgsForCall []struct { + arg1 func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error } PublisherIDStub func() livekit.ParticipantID publisherIDMutex sync.RWMutex @@ -1017,113 +1000,36 @@ func (fake *FakePublishedTrack) NotifySubscriberNodeMediaLossArgsForCall(i int) return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakePublishedTrack) NumUpTracks() (uint32, uint32) { - fake.numUpTracksMutex.Lock() - ret, specificReturn := fake.numUpTracksReturnsOnCall[len(fake.numUpTracksArgsForCall)] - fake.numUpTracksArgsForCall = append(fake.numUpTracksArgsForCall, struct { - }{}) - stub := fake.NumUpTracksStub - fakeReturns := fake.numUpTracksReturns - fake.recordInvocation("NumUpTracks", []interface{}{}) - fake.numUpTracksMutex.Unlock() +func (fake *FakePublishedTrack) OnSubscribedMaxQualityChange(arg1 func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error) { + fake.onSubscribedMaxQualityChangeMutex.Lock() + fake.onSubscribedMaxQualityChangeArgsForCall = append(fake.onSubscribedMaxQualityChangeArgsForCall, struct { + arg1 func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error + }{arg1}) + stub := fake.OnSubscribedMaxQualityChangeStub + fake.recordInvocation("OnSubscribedMaxQualityChange", []interface{}{arg1}) + fake.onSubscribedMaxQualityChangeMutex.Unlock() if stub != nil { - return stub() + fake.OnSubscribedMaxQualityChangeStub(arg1) } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 } -func (fake *FakePublishedTrack) NumUpTracksCallCount() int { - fake.numUpTracksMutex.RLock() - defer fake.numUpTracksMutex.RUnlock() - return len(fake.numUpTracksArgsForCall) +func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeCallCount() int { + fake.onSubscribedMaxQualityChangeMutex.RLock() + defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() + return len(fake.onSubscribedMaxQualityChangeArgsForCall) } -func (fake *FakePublishedTrack) NumUpTracksCalls(stub func() (uint32, uint32)) { - fake.numUpTracksMutex.Lock() - defer fake.numUpTracksMutex.Unlock() - fake.NumUpTracksStub = stub +func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeCalls(stub func(func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error)) { + fake.onSubscribedMaxQualityChangeMutex.Lock() + defer fake.onSubscribedMaxQualityChangeMutex.Unlock() + fake.OnSubscribedMaxQualityChangeStub = stub } -func (fake *FakePublishedTrack) NumUpTracksReturns(result1 uint32, result2 uint32) { - fake.numUpTracksMutex.Lock() - defer fake.numUpTracksMutex.Unlock() - fake.NumUpTracksStub = nil - fake.numUpTracksReturns = struct { - result1 uint32 - result2 uint32 - }{result1, result2} -} - -func (fake *FakePublishedTrack) NumUpTracksReturnsOnCall(i int, result1 uint32, result2 uint32) { - fake.numUpTracksMutex.Lock() - defer fake.numUpTracksMutex.Unlock() - fake.NumUpTracksStub = nil - if fake.numUpTracksReturnsOnCall == nil { - fake.numUpTracksReturnsOnCall = make(map[int]struct { - result1 uint32 - result2 uint32 - }) - } - fake.numUpTracksReturnsOnCall[i] = struct { - result1 uint32 - result2 uint32 - }{result1, result2} -} - -func (fake *FakePublishedTrack) PublishLossPercentage() uint32 { - fake.publishLossPercentageMutex.Lock() - ret, specificReturn := fake.publishLossPercentageReturnsOnCall[len(fake.publishLossPercentageArgsForCall)] - fake.publishLossPercentageArgsForCall = append(fake.publishLossPercentageArgsForCall, struct { - }{}) - stub := fake.PublishLossPercentageStub - fakeReturns := fake.publishLossPercentageReturns - fake.recordInvocation("PublishLossPercentage", []interface{}{}) - fake.publishLossPercentageMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakePublishedTrack) PublishLossPercentageCallCount() int { - fake.publishLossPercentageMutex.RLock() - defer fake.publishLossPercentageMutex.RUnlock() - return len(fake.publishLossPercentageArgsForCall) -} - -func (fake *FakePublishedTrack) PublishLossPercentageCalls(stub func() uint32) { - fake.publishLossPercentageMutex.Lock() - defer fake.publishLossPercentageMutex.Unlock() - fake.PublishLossPercentageStub = stub -} - -func (fake *FakePublishedTrack) PublishLossPercentageReturns(result1 uint32) { - fake.publishLossPercentageMutex.Lock() - defer fake.publishLossPercentageMutex.Unlock() - fake.PublishLossPercentageStub = nil - fake.publishLossPercentageReturns = struct { - result1 uint32 - }{result1} -} - -func (fake *FakePublishedTrack) PublishLossPercentageReturnsOnCall(i int, result1 uint32) { - fake.publishLossPercentageMutex.Lock() - defer fake.publishLossPercentageMutex.Unlock() - fake.PublishLossPercentageStub = nil - if fake.publishLossPercentageReturnsOnCall == nil { - fake.publishLossPercentageReturnsOnCall = make(map[int]struct { - result1 uint32 - }) - } - fake.publishLossPercentageReturnsOnCall[i] = struct { - result1 uint32 - }{result1} +func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeArgsForCall(i int) func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxQuality livekit.VideoQuality) error { + fake.onSubscribedMaxQualityChangeMutex.RLock() + defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() + argsForCall := fake.onSubscribedMaxQualityChangeArgsForCall[i] + return argsForCall.arg1 } func (fake *FakePublishedTrack) PublisherID() livekit.ParticipantID { @@ -1721,10 +1627,8 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() fake.notifySubscriberNodeMediaLossMutex.RLock() defer fake.notifySubscriberNodeMediaLossMutex.RUnlock() - fake.numUpTracksMutex.RLock() - defer fake.numUpTracksMutex.RUnlock() - fake.publishLossPercentageMutex.RLock() - defer fake.publishLossPercentageMutex.RUnlock() + fake.onSubscribedMaxQualityChangeMutex.RLock() + defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() fake.publisherIDMutex.RLock() defer fake.publisherIDMutex.RUnlock() fake.publisherIdentityMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index cf5e3b15f..14e27edba 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -6,46 +6,22 @@ import ( "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" - "github.com/livekit/protocol/utils" - "github.com/pion/rtcp" - "github.com/pion/webrtc/v3" - "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/rtc/types" - "github.com/livekit/livekit-server/pkg/sfu" - "github.com/livekit/livekit-server/pkg/sfu/twcc" - "github.com/livekit/livekit-server/pkg/telemetry" ) type UptrackManagerParams struct { - Identity livekit.ParticipantIdentity - SID livekit.ParticipantID - Config *WebRTCConfig - AudioConfig config.AudioConfig - Telemetry telemetry.TelemetryService - ThrottleConfig config.PLIThrottleConfig - Logger logger.Logger -} - -type pendingTrackInfo struct { - *livekit.TrackInfo - migrated bool + SID livekit.ParticipantID + Logger logger.Logger } type UptrackManager struct { - params UptrackManagerParams - rtcpCh chan []rtcp.Packet - pliThrottle *pliThrottle + params UptrackManagerParams closed bool - // hold reference for MediaTrack - twcc *twcc.Responder - // publishedTracks that participant is publishing publishedTracks map[livekit.TrackID]types.PublishedTrack - // client intended to publish, yet to be reconciled - pendingTracks map[string]*pendingTrackInfo // keeps track of subscriptions that are awaiting permissions subscriptionPermissions map[livekit.ParticipantID]*livekit.TrackPermission // subscriberID => *livekit.TrackPermission // keeps tracks of track specific subscribers who are awaiting permission @@ -54,31 +30,24 @@ type UptrackManager struct { lock sync.RWMutex // callbacks & handlers - onTrackPublished func(track types.PublishedTrack) + onClose func() onTrackUpdated func(track types.PublishedTrack, onlyIfReady bool) - onWriteRTCP func(pkts []rtcp.Packet) - onSubscribedMaxQualityChange func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error + onSubscribedMaxQualityChange func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) error } func NewUptrackManager(params UptrackManagerParams) *UptrackManager { return &UptrackManager{ params: params, - rtcpCh: make(chan []rtcp.Packet, 50), - pliThrottle: newPLIThrottle(params.ThrottleConfig), publishedTracks: make(map[livekit.TrackID]types.PublishedTrack, 0), - pendingTracks: make(map[string]*pendingTrackInfo), pendingSubscriptions: make(map[livekit.TrackID][]livekit.ParticipantID), } } func (u *UptrackManager) Start() { - go u.rtcpSendWorker() } func (u *UptrackManager) Close() { u.lock.Lock() - defer u.lock.Unlock() - u.closed = true // remove all subscribers @@ -86,11 +55,18 @@ func (u *UptrackManager) Close() { t.RemoveAllSubscribers() } - if len(u.publishedTracks) == 0 { - close(u.rtcpCh) + notify := len(u.publishedTracks) == 0 + u.lock.Unlock() + + if notify && u.onClose != nil { + u.onClose() } } +func (u *UptrackManager) OnClose(f func()) { + u.onClose = f +} + func (u *UptrackManager) ToProto() []*livekit.TrackInfo { u.lock.RLock() defer u.lock.RUnlock() @@ -103,59 +79,14 @@ func (u *UptrackManager) ToProto() []*livekit.TrackInfo { return trackInfos } -func (u *UptrackManager) OnTrackPublished(f func(track types.PublishedTrack)) { - u.onTrackPublished = f -} - func (u *UptrackManager) OnTrackUpdated(f func(track types.PublishedTrack, onlyIfReady bool)) { u.onTrackUpdated = f } -func (u *UptrackManager) OnWriteRTCP(f func(pkts []rtcp.Packet)) { - u.onWriteRTCP = f -} - -func (u *UptrackManager) OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error) { +func (u *UptrackManager) OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) error) { u.onSubscribedMaxQualityChange = f } -// AddTrack is called when client intends to publish track. -// records track details and lets client know it's ok to proceed -func (u *UptrackManager) AddTrack(req *livekit.AddTrackRequest) *livekit.TrackInfo { - u.lock.Lock() - defer u.lock.Unlock() - - // if track is already published, reject - if u.pendingTracks[req.Cid] != nil { - return nil - } - - if u.getPublishedTrackBySignalCid(req.Cid) != nil || u.getPublishedTrackBySdpCid(req.Cid) != nil { - return nil - } - - ti := &livekit.TrackInfo{ - Type: req.Type, - Name: req.Name, - Sid: utils.NewGuid(utils.TrackPrefix), - Width: req.Width, - Height: req.Height, - Muted: req.Muted, - DisableDtx: req.DisableDtx, - Source: req.Source, - Layers: req.Layers, - } - u.pendingTracks[req.Cid] = &pendingTrackInfo{TrackInfo: ti} - - return ti -} - -func (u *UptrackManager) AddMigratedTrack(cid string, ti *livekit.TrackInfo) { - u.lock.Lock() - defer u.lock.Unlock() - u.pendingTracks[cid] = &pendingTrackInfo{ti, true} -} - // AddSubscriber subscribes op to all publishedTracks func (u *UptrackManager) AddSubscriber(sub types.Participant, params types.AddSubscriberParams) (int, error) { var tracks []types.PublishedTrack @@ -210,64 +141,24 @@ func (u *UptrackManager) RemoveSubscriber(sub types.Participant, trackID livekit u.lock.Unlock() } -func (u *UptrackManager) SetTrackMuted(trackID livekit.TrackID, muted bool) { - isPending := false +func (u *UptrackManager) SetTrackMuted(trackID livekit.TrackID, muted bool) types.PublishedTrack { u.lock.RLock() - for _, ti := range u.pendingTracks { - if livekit.TrackID(ti.Sid) == trackID { - ti.Muted = muted - isPending = true - } - } track := u.publishedTracks[trackID] u.lock.RUnlock() - if track == nil { - if !isPending { - u.params.Logger.Warnw("could not locate track", nil, "track", trackID) - } - return - } - currentMuted := track.IsMuted() - track.SetMuted(muted) + if track != nil { + currentMuted := track.IsMuted() + track.SetMuted(muted) - if currentMuted != track.IsMuted() && u.onTrackUpdated != nil { - u.params.Logger.Debugw("mute status changed", - "track", trackID, - "muted", track.IsMuted()) - u.onTrackUpdated(track, false) - } -} - -func (u *UptrackManager) GetAudioLevel() (level uint8, active bool) { - u.lock.RLock() - defer u.lock.RUnlock() - - level = SilentAudioLevel - for _, pt := range u.publishedTracks { - tl, ta := pt.GetAudioLevel() - if ta { - active = true - if tl < level { - level = tl - } + if currentMuted != track.IsMuted() && u.onTrackUpdated != nil { + u.params.Logger.Debugw("mute status changed", + "track", trackID, + "muted", track.IsMuted()) + u.onTrackUpdated(track, false) } } - return -} -func (u *UptrackManager) GetConnectionQuality() (scores float64, numTracks int) { - u.lock.RLock() - defer u.lock.RUnlock() - - for _, pt := range u.publishedTracks { - if pt.IsMuted() { - continue - } - scores += pt.GetConnectionScore() - numTracks++ - } - return + return track } func (u *UptrackManager) GetPublishedTrack(trackID livekit.TrackID) types.PublishedTrack { @@ -288,25 +179,6 @@ func (u *UptrackManager) GetPublishedTracks() []types.PublishedTrack { return tracks } -func (u *UptrackManager) GetDTX() bool { - u.lock.RLock() - defer u.lock.RUnlock() - - var trackInfo *livekit.TrackInfo - for _, ti := range u.pendingTracks { - if ti.Type == livekit.TrackType_AUDIO { - trackInfo = ti.TrackInfo - break - } - } - - if trackInfo == nil { - return false - } - - return !trackInfo.DisableDtx -} - func (u *UptrackManager) UpdateSubscriptionPermissions( permissions *livekit.UpdateSubscriptionPermissions, resolver func(participantID livekit.ParticipantID) types.Participant, @@ -333,106 +205,43 @@ func (u *UptrackManager) UpdateVideoLayers(updateVideoLayers *livekit.UpdateVide return nil } -func (u *UptrackManager) UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error { - track := u.GetPublishedTrack(trackID) - if track == nil { - u.params.Logger.Warnw("could not find track", nil, "trackID", trackID) - return errors.New("could not find track") - } +func (u *UptrackManager) AddPublishedTrack(track types.PublishedTrack) { + track.OnSubscribedMaxQualityChange(u.onSubscribedMaxQualityChange) - if mt, ok := track.(*MediaTrack); ok { - mt.NotifySubscriberNodeMaxQuality(nodeID, maxQuality) - } - - return nil -} - -func (u *UptrackManager) UpdateMediaLoss(nodeID string, trackID livekit.TrackID, fractionalLoss uint32) error { - track := u.GetPublishedTrack(trackID) - if track == nil { - u.params.Logger.Warnw("could not find track", nil, "trackID", trackID) - return errors.New("could not find track") - } - - if mt, ok := track.(*MediaTrack); ok { - mt.NotifySubscriberNodeMediaLoss(nodeID, uint8(fractionalLoss)) - } - - return nil -} - -// when a new remoteTrack is created, creates a Track and adds it to room -func (u *UptrackManager) MediaTrackReceived(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver, p *ParticipantImpl) { - var newTrack bool - - // use existing mediatrack to handle simulcast u.lock.Lock() - mt, ok := u.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) - if !ok { - signalCid, ti := u.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) - if ti == nil { - u.lock.Unlock() - return - } - - var mid string - for _, tr := range p.publisher.pc.GetTransceivers() { - if tr.Receiver() == rtpReceiver { - mid = tr.Mid() - break - } - } - ti.MimeType = track.Codec().MimeType - ti.Mid = mid - - mt = NewMediaTrack(track, MediaTrackParams{ - TrackInfo: ti, - SignalCid: signalCid, - SdpCid: track.ID(), - ParticipantID: u.params.SID, - ParticipantIdentity: u.params.Identity, - RTCPChan: u.rtcpCh, - BufferFactory: u.params.Config.BufferFactory, - ReceiverConfig: u.params.Config.Receiver, - AudioConfig: u.params.AudioConfig, - Telemetry: u.params.Telemetry, - Logger: u.params.Logger, - SubscriberConfig: u.params.Config.Subscriber, - }) - for ssrc, t := range p.params.SimTracks { - if t.Mid != mid { - continue - } - mt.TrySetSimulcastSSRC(uint8(sfu.RidToLayer(t.Rid)), ssrc) - } - mt.OnSubscribedMaxQualityChange(u.onSubscribedMaxQualityChange) - - // add to published and clean up pending - u.publishedTracks[mt.ID()] = mt - delete(u.pendingTracks, signalCid) - - newTrack = true - } - - ssrc := uint32(track.SSRC()) - u.pliThrottle.addTrack(ssrc, track.RID()) - if u.twcc == nil { - u.twcc = twcc.NewTransportWideCCResponder(ssrc) - u.twcc.OnFeedback(func(pkt rtcp.RawPacket) { - if u.onWriteRTCP != nil { - u.onWriteRTCP([]rtcp.Packet{&pkt}) - } - }) + if _, ok := u.publishedTracks[track.ID()]; !ok { + u.publishedTracks[track.ID()] = track } u.lock.Unlock() - mt.AddReceiver(rtpReceiver, track, u.twcc) + track.AddOnClose(func() { + notifyClose := false - if newTrack { - u.handleTrackPublished(mt) - } else { - u.onTrackUpdated(mt, true) - } + // cleanup + u.lock.Lock() + trackID := track.ID() + delete(u.publishedTracks, trackID) + delete(u.pendingSubscriptions, trackID) + // not modifying subscription permissions, will get reset on next update from participant + + if u.closed && len(u.publishedTracks) == 0 { + notifyClose = true + } + u.lock.Unlock() + + // only send this when client is in a ready state + if u.onTrackUpdated != nil { + u.onTrackUpdated(track, true) + } + + if notifyClose && u.onClose != nil { + u.onClose() + } + }) +} + +func (u *UptrackManager) RemovePublishedTrack(track types.PublishedTrack) { + track.RemoveAllSubscribers() } // should be called with lock held @@ -440,6 +249,18 @@ func (u *UptrackManager) getPublishedTrack(trackID livekit.TrackID) types.Publis return u.publishedTracks[trackID] } +func (u *UptrackManager) GetPublishedTrackBySignalCidOrSdpCid(clientId string) types.PublishedTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + track := u.getPublishedTrackBySignalCid(clientId) + if track == nil { + track = u.getPublishedTrackBySdpCid(clientId) + } + + return track +} + // should be called with lock held func (u *UptrackManager) getPublishedTrackBySignalCid(clientId string) types.PublishedTrack { for _, publishedTrack := range u.publishedTracks { @@ -451,6 +272,13 @@ func (u *UptrackManager) getPublishedTrackBySignalCid(clientId string) types.Pub return nil } +func (u *UptrackManager) GetPublishedTrackBySdpCid(clientId string) types.PublishedTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.getPublishedTrackBySdpCid(clientId) +} + // should be called with lock held func (u *UptrackManager) getPublishedTrackBySdpCid(clientId string) types.PublishedTrack { for _, publishedTrack := range u.publishedTracks { @@ -462,64 +290,6 @@ func (u *UptrackManager) getPublishedTrackBySdpCid(clientId string) types.Publis return nil } -// should be called with lock held -func (u *UptrackManager) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { - signalCid := clientId - trackInfo := u.pendingTracks[clientId] - - if trackInfo == nil { - // - // If no match on client id, find first one matching type - // as MediaStreamTrack can change client id when transceiver - // is added to peer connection. - // - for cid, ti := range u.pendingTracks { - if ti.Type == kind { - trackInfo = ti - signalCid = cid - break - } - } - } - - // if still not found, we are done - if trackInfo == nil { - u.params.Logger.Errorw("track info not published prior to track", nil, "clientId", clientId) - } - return signalCid, trackInfo.TrackInfo -} - -func (u *UptrackManager) handleTrackPublished(track types.PublishedTrack) { - u.lock.Lock() - if _, ok := u.publishedTracks[track.ID()]; !ok { - u.publishedTracks[track.ID()] = track - } - u.lock.Unlock() - - track.AddOnClose(func() { - // cleanup - u.lock.Lock() - trackID := track.ID() - delete(u.publishedTracks, trackID) - delete(u.pendingSubscriptions, trackID) - // not modifying subscription permissions, will get reset on next update from participant - - // as rtcpCh handles RTCP for all published tracks, close only after all published tracks are closed - if u.closed && len(u.publishedTracks) == 0 { - close(u.rtcpCh) - } - u.lock.Unlock() - // only send this when client is in a ready state - if u.onTrackUpdated != nil { - u.onTrackUpdated(track, true) - } - }) - - if u.onTrackPublished != nil { - u.onTrackPublished(track) - } -} - func (u *UptrackManager) updateSubscriptionPermissions(permissions *livekit.UpdateSubscriptionPermissions) { // every update overrides the existing @@ -679,43 +449,9 @@ func (u *UptrackManager) maybeRevokeSubscriptions(resolver func(participantID li } } -func (u *UptrackManager) rtcpSendWorker() { - defer Recover() - - // read from rtcpChan - for pkts := range u.rtcpCh { - if pkts == nil { - return - } - - fwdPkts := make([]rtcp.Packet, 0, len(pkts)) - for _, pkt := range pkts { - switch pkt.(type) { - case *rtcp.PictureLossIndication: - mediaSSRC := pkt.(*rtcp.PictureLossIndication).MediaSSRC - if u.pliThrottle.canSend(mediaSSRC) { - fwdPkts = append(fwdPkts, pkt) - } - case *rtcp.FullIntraRequest: - mediaSSRC := pkt.(*rtcp.FullIntraRequest).MediaSSRC - if u.pliThrottle.canSend(mediaSSRC) { - fwdPkts = append(fwdPkts, pkt) - } - default: - fwdPkts = append(fwdPkts, pkt) - } - } - - if len(fwdPkts) > 0 && u.onWriteRTCP != nil { - u.onWriteRTCP(fwdPkts) - } - } -} - func (u *UptrackManager) DebugInfo() map[string]interface{} { info := map[string]interface{}{} publishedTrackInfo := make(map[livekit.TrackID]interface{}) - pendingTrackInfo := make(map[string]interface{}) u.lock.RLock() for trackID, track := range u.publishedTracks { @@ -729,29 +465,9 @@ func (u *UptrackManager) DebugInfo() map[string]interface{} { } } } - - for clientID, ti := range u.pendingTracks { - pendingTrackInfo[clientID] = map[string]interface{}{ - "Sid": ti.Sid, - "Type": ti.Type.String(), - "Simulcast": ti.Simulcast, - } - } u.lock.RUnlock() info["PublishedTracks"] = publishedTrackInfo - info["PendingTracks"] = pendingTrackInfo return info } - -func (u *UptrackManager) HasPendingMigratedTrack() bool { - u.lock.RLock() - defer u.lock.RUnlock() - for _, t := range u.pendingTracks { - if t.migrated { - return true - } - } - return false -} diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 2b61fc83b..c16a42b7e 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -17,6 +17,7 @@ import ( "github.com/livekit/protocol/livekit" ) +type AudioLevelHandle func(level uint8, duration uint32) type Bitrates [DefaultMaxLayerSpatial + 1][DefaultMaxLayerTemporal + 1]int64 // TrackReceiver defines an interface receive media from remote peer @@ -30,6 +31,8 @@ type TrackReceiver interface { SendPLI(layer int32) GetSenderReportTime(layer int32) (rtpTS uint32, ntpTS uint64) Codec() webrtc.RTPCodecCapability + SetUpTrackPaused(paused bool) + SetMaxExpectedSpatialLayer(layer int32) } // Receiver defines an interface for a track receivers @@ -49,6 +52,8 @@ type Receiver interface { SendPLI(layer int32) SetRTCPCh(ch chan []rtcp.Packet) + OnAudioLevel(h AudioLevelHandle) + GetSenderReportTime(layer int32) (rtpTS uint32, ntpTS uint64) DebugInfo() map[string]interface{} } @@ -75,8 +80,9 @@ type WebRTCReceiver struct { lastPli atomicInt64 pliThrottle int64 - bufferMu sync.RWMutex - buffers [DefaultMaxLayerSpatial + 1]*buffer.Buffer + bufferMu sync.RWMutex + buffers [DefaultMaxLayerSpatial + 1]*buffer.Buffer + onAudioLevel AudioLevelHandle upTrackMu sync.RWMutex upTracks [DefaultMaxLayerSpatial + 1]*webrtc.TrackRemote @@ -188,6 +194,10 @@ func (w *WebRTCReceiver) Kind() webrtc.RTPCodecType { return w.kind } +func (w *WebRTCReceiver) OnAudioLevel(fn AudioLevelHandle) { + w.onAudioLevel = fn +} + func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buffer) { if w.closed.get() { return