From 9ca85454ed595219477ecee7d2ef76559789843d Mon Sep 17 00:00:00 2001 From: boks1971 Date: Tue, 4 Jan 2022 16:40:47 +0530 Subject: [PATCH] Refactor media track subscriptions - To enable re-use of common bits - Add max quality from other nodes --- go.mod | 2 +- go.sum | 4 +- pkg/rtc/mediatrack.go | 486 ++--------------- pkg/rtc/mediatrack_test.go | 6 +- pkg/rtc/mediatracksubscriptions.go | 499 ++++++++++++++++++ pkg/rtc/subscribedtrack.go | 2 +- pkg/rtc/types/interfaces.go | 7 +- pkg/rtc/types/typesfakes/fake_media_track.go | 253 ++++++++- .../types/typesfakes/fake_published_track.go | 188 ++++++- pkg/sfu/downtrack.go | 10 +- 10 files changed, 959 insertions(+), 498 deletions(-) create mode 100644 pkg/rtc/mediatracksubscriptions.go diff --git a/go.mod b/go.mod index 4f80fdf53..79609685d 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.11.8-0.20220103045453-c441eb5f03c8 + github.com/livekit/protocol v0.11.8-0.20220104065946-2c4c8d7764ed github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 github.com/mitchellh/go-homedir v1.1.0 diff --git a/go.sum b/go.sum index a33092c72..6499baf25 100644 --- a/go.sum +++ b/go.sum @@ -132,8 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA= github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.11.8-0.20220103045453-c441eb5f03c8 h1:eo4OOUKLgNguaHEmcItTfk2IzVC2s2r6/WXqFfj5HjE= -github.com/livekit/protocol v0.11.8-0.20220103045453-c441eb5f03c8/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= +github.com/livekit/protocol v0.11.8-0.20220104065946-2c4c8d7764ed h1:6vxJ62pwuhXtEjqvsANTIoEcTgHR9laMa9tR3Xr0fAM= +github.com/livekit/protocol v0.11.8-0.20220104065946-2c4c8d7764ed/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= github.com/magefile/mage v1.11.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index c9ac7491f..8fb9e5442 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu/connectionquality" "github.com/livekit/protocol/livekit" @@ -15,10 +16,8 @@ import ( "github.com/livekit/protocol/utils" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" - "github.com/pion/webrtc/v3/pkg/rtcerr" "github.com/livekit/livekit-server/pkg/config" - "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/sfu/buffer" "github.com/livekit/livekit-server/pkg/sfu/twcc" @@ -29,7 +28,6 @@ const ( lostUpdateDelta = time.Second connectionQualityUpdateInterval = 5 * time.Second layerSelectionTolerance = 0.9 - initialQualityUpdateWait = 10 * time.Second ) // MediaTrack represents a WebRTC track that needs to be forwarded @@ -46,13 +44,11 @@ type MediaTrack struct { lock sync.RWMutex - // map of target participantID -> types.SubscribedTrack - subscribedTracks sync.Map // participantID => types.SubscribedTrack - twcc *twcc.Responder - audioLevel *AudioLevel - receiver sfu.Receiver - lastPLI time.Time - layerDimensions sync.Map // quality => *livekit.VideoLayer + twcc *twcc.Responder + audioLevel *AudioLevel + receiver sfu.Receiver + lastPLI time.Time + layerDimensions sync.Map // quality => *livekit.VideoLayer // track audio fraction lost statsLock sync.Mutex @@ -65,15 +61,9 @@ type MediaTrack struct { done chan struct{} - // quality level enable/disable - maxQualityLock sync.RWMutex - maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality - maxSubscribedQuality livekit.VideoQuality - allSubscribersMuted bool - onSubscribedMaxQualityChange func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error - maxQualityTimer *time.Timer - onClose []func() + + *MediaTrackSubscriptions } type MediaTrackParams struct { @@ -95,15 +85,23 @@ type MediaTrackParams struct { func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTrack { t := &MediaTrack{ - params: params, - ssrc: track.SSRC(), - streamID: track.StreamID(), - codec: track.Codec(), - connectionStats: connectionquality.NewConnectionStats(), - done: make(chan struct{}), - maxSubscriberQuality: make(map[livekit.ParticipantID]livekit.VideoQuality), + params: params, + ssrc: track.SSRC(), + streamID: track.StreamID(), + codec: track.Codec(), + connectionStats: connectionquality.NewConnectionStats(), + done: make(chan struct{}), } + t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ + MediaTrack: t, + BufferFactory: params.BufferFactory, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + }) + if params.TrackInfo.Muted { t.SetMuted(true) } @@ -139,6 +137,14 @@ func (t *MediaTrack) Source() livekit.TrackSource { return t.params.TrackInfo.Source } +func (t *MediaTrack) ParticipantID() livekit.ParticipantID { + return t.params.ParticipantID +} + +func (t *MediaTrack) ParticipantIdentity() livekit.ParticipantIdentity { + return t.params.ParticipantIdentity +} + func (t *MediaTrack) IsSimulcast() bool { return t.simulcasted.Get() } @@ -160,18 +166,7 @@ func (t *MediaTrack) SetMuted(muted bool) { } t.lock.RUnlock() - // mute all subscribed tracks - t.subscribedTracks.Range(func(_, value interface{}) bool { - if st, ok := value.(types.SubscribedTrack); ok { - st.SetPublisherMuted(muted) - } - return true - }) - - // update quality based on subscription if unmuting - if !muted { - t.updateQualityChange() - } + t.MediaTrackSubscriptions.SetMuted(muted) } func (t *MediaTrack) AddOnClose(f func()) { @@ -181,11 +176,6 @@ func (t *MediaTrack) AddOnClose(f func()) { t.onClose = append(t.onClose, f) } -func (t *MediaTrack) IsSubscriber(subID livekit.ParticipantID) bool { - _, ok := t.subscribedTracks.Load(subID) - return ok -} - func (t *MediaTrack) PublishLossPercentage() uint32 { return FixedPointToPercent(uint8(atomic.LoadUint32(&t.currentUpFracLost))) } @@ -195,19 +185,11 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { t.lock.Lock() defer t.lock.Unlock() - subscriberID := sub.ID() - - // don't subscribe to the same track multiple times - if _, ok := t.subscribedTracks.Load(subscriberID); ok { - return nil - } - if t.receiver == nil { // cannot add, no receiver return errors.New("cannot subscribe without a receiver in place") } - codec := t.receiver.Codec() // using DownTrack from ion-sfu streamId := string(t.params.ParticipantID) if sub.ProtocolVersion().SupportsPackedStreamId() { @@ -216,163 +198,27 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { streamId = PackStreamID(t.params.ParticipantID, t.ID()) } - receiver := NewWrappedReceiver(t.receiver, t.ID(), streamId) - - var rtcpFeedback []webrtc.RTCPFeedback - switch t.Kind() { - case livekit.TrackType_AUDIO: - rtcpFeedback = t.params.SubscriberConfig.RTCPFeedback.Audio - case livekit.TrackType_VIDEO: - rtcpFeedback = t.params.SubscriberConfig.RTCPFeedback.Video - } - downTrack, err := sfu.NewDownTrack(webrtc.RTPCodecCapability{ - MimeType: codec.MimeType, - ClockRate: codec.ClockRate, - Channels: codec.Channels, - SDPFmtpLine: codec.SDPFmtpLine, - RTCPFeedback: rtcpFeedback, - }, receiver, t.params.BufferFactory, subscriberID, t.params.ReceiverConfig.PacketBufferSize) + downTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, t.receiver.Codec(), NewWrappedReceiver(t.receiver, t.ID(), streamId)) if err != nil { return err } - subTrack := NewSubscribedTrack(SubscribedTrackParams{ - PublisherID: t.params.ParticipantID, - PublisherIdentity: t.params.ParticipantIdentity, - SubscriberID: subscriberID, - MediaTrack: t, - DownTrack: downTrack, - }) - var transceiver *webrtc.RTPTransceiver - var sender *webrtc.RTPSender - if sub.ProtocolVersion().SupportsTransceiverReuse() { - // - // AddTrack will create a new transceiver or re-use an unused one - // if the attributes match. This prevents SDP from bloating - // because of dormant transceivers building up. - // - sender, err = sub.SubscriberPC().AddTrack(downTrack) - if err != nil { - return err + if downTrack != nil { + if t.Kind() == livekit.TrackType_AUDIO { + downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) } - // as there is no way to get transceiver from sender, search - for _, tr := range sub.SubscriberPC().GetTransceivers() { - if tr.Sender() == sender { - transceiver = tr - break - } - } - if transceiver == nil { - // cannot add, no transceiver - return errors.New("cannot subscribe without a transceiver in place") - } - } else { - transceiver, err = sub.SubscriberPC().AddTransceiverFromTrack(downTrack, webrtc.RTPTransceiverInit{ - Direction: webrtc.RTPTransceiverDirectionSendonly, - }) - if err != nil { - return err - } - - sender = transceiver.Sender() - if sender == nil { - // cannot add, no sender - return errors.New("cannot subscribe without a sender in place") - } + t.receiver.AddDownTrack(downTrack) } - - sendParameters := sender.GetParameters() - downTrack.SetRTPHeaderExtensions(sendParameters.HeaderExtensions) - - downTrack.SetTransceiver(transceiver) - // when outtrack is bound, start loop to send reports - downTrack.OnBind(func() { - go subTrack.Bound() - go t.sendDownTrackBindingReports(sub) - }) - downTrack.OnPacketSent(func(_ *sfu.DownTrack, size int) { - t.params.Telemetry.OnDownstreamPacket(subscriberID, t.ID(), size) - }) - downTrack.OnPaddingSent(func(_ *sfu.DownTrack, size int) { - t.params.Telemetry.OnDownstreamPacket(subscriberID, t.ID(), size) - }) - downTrack.OnRTCP(func(pkts []rtcp.Packet) { - t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, subscriberID, t.ID(), pkts) - }) - - downTrack.OnCloseHandler(func() { - go func() { - t.subscribedTracks.Delete(subscriberID) - t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.ToProto()) - - // ignore if the subscribing sub is not connected - if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { - return - } - - // if the source has been terminated, we'll need to terminate all of the subscribedtracks - // however, if the dest sub has disconnected, then we can skip - if sender == nil { - return - } - t.params.Logger.Debugw("removing peerconnection track", - "track", t.ID(), - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - "kind", t.Kind(), - ) - if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { - if err == webrtc.ErrConnectionClosed { - // sub closing, can skip removing subscribedtracks - return - } - if _, ok := err.(*rtcerr.InvalidStateError); !ok { - // most of these are safe to ignore, since the track state might have already - // been set to Inactive - t.params.Logger.Debugw("could not remove remoteTrack from forwarder", - "error", err, - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - ) - } - } - - t.NotifySubscriberMute(subscriberID) - sub.RemoveSubscribedTrack(subTrack) - sub.Negotiate() - }() - }) - if t.Kind() == livekit.TrackType_AUDIO { - downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) - } - - t.subscribedTracks.Store(subscriberID, subTrack) - subTrack.SetPublisherMuted(t.IsMuted()) - - t.receiver.AddDownTrack(downTrack) - // since sub will lock, run it in a goroutine to avoid deadlocks - go func() { - t.NotifySubscriberMaxQuality(subscriberID, livekit.VideoQuality_HIGH) // start with HIGH, let subscription change it later - sub.AddSubscribedTrack(subTrack) - sub.Negotiate() - }() - - t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.ToProto()) return nil } func (t *MediaTrack) NumUpTracks() (uint32, uint32) { numExpected := atomic.LoadUint32(&t.numUpTracks) - t.maxQualityLock.RLock() - maxSubscribed := uint32(0) - if !t.allSubscribersMuted { - maxSubscribed = uint32(SpatialLayerForQuality(t.maxSubscribedQuality) + 1) - } - t.maxQualityLock.RUnlock() - if maxSubscribed < numExpected { - numExpected = maxSubscribed + numSubscribed := t.numSubscribed() + if numSubscribed < numExpected { + numExpected = numSubscribed } t.lock.RLock() @@ -471,57 +317,6 @@ func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.Tra }) } -// RemoveSubscriber removes participant from subscription -// stop all forwarders to the client -func (t *MediaTrack) RemoveSubscriber(participantID livekit.ParticipantID) { - subTrack := t.getSubscribedTrack(participantID) - if subTrack != nil { - go subTrack.DownTrack().Close() - } -} - -func (t *MediaTrack) RemoveAllSubscribers() { - t.params.Logger.Debugw("removing all subscribers", "track", t.ID()) - t.lock.Lock() - defer t.lock.Unlock() - t.subscribedTracks.Range(func(_, val interface{}) bool { - if subTrack, ok := val.(types.SubscribedTrack); ok { - go subTrack.DownTrack().Close() - } - return true - }) - t.subscribedTracks = sync.Map{} -} - -func (t *MediaTrack) RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID { - t.lock.Lock() - defer t.lock.Unlock() - - var revokedSubscriberIDs []livekit.ParticipantID - // LK-TODO: large number of subscribers needs to be solved for this loop - t.subscribedTracks.Range(func(key interface{}, val interface{}) bool { - if subID, ok := key.(livekit.ParticipantID); ok { - found := false - for _, allowedID := range allowedSubscriberIDs { - if subID == allowedID { - found = true - break - } - } - - if !found { - if subTrack, ok := val.(types.SubscribedTrack); ok { - go subTrack.DownTrack().Close() - revokedSubscriberIDs = append(revokedSubscriberIDs, subID) - } - } - } - return true - }) - - return revokedSubscriberIDs -} - func (t *MediaTrack) ToProto() *livekit.TrackInfo { info := t.params.TrackInfo info.Muted = t.IsMuted() @@ -542,12 +337,9 @@ func (t *MediaTrack) UpdateVideoLayers(layers []*livekit.VideoLayer) { for _, layer := range layers { t.layerDimensions.Store(layer.Quality, layer) } - t.subscribedTracks.Range(func(_, val interface{}) bool { - if st, ok := val.(types.SubscribedTrack); ok { - st.UpdateVideoLayer() - } - return true - }) + + t.MediaTrackSubscriptions.UpdateVideoLayers() + // TODO: this might need to trigger a participant update for clients to pick up dimension change } @@ -596,53 +388,6 @@ func (t *MediaTrack) GetQualityForDimension(width, height uint32) livekit.VideoQ return quality } -func (t *MediaTrack) getSubscribedTrack(subscriberID livekit.ParticipantID) types.SubscribedTrack { - if val, ok := t.subscribedTracks.Load(subscriberID); ok { - if st, ok := val.(types.SubscribedTrack); ok { - return st - } - } - return nil -} - -// TODO: send for all downtracks from the source participant -// https://tools.ietf.org/html/rfc7941 -func (t *MediaTrack) sendDownTrackBindingReports(sub types.Participant) { - var sd []rtcp.SourceDescriptionChunk - - subTrack := t.getSubscribedTrack(sub.ID()) - if subTrack == nil { - return - } - - chunks := subTrack.DownTrack().CreateSourceDescriptionChunks() - if chunks == nil { - return - } - sd = append(sd, chunks...) - - pkts := []rtcp.Packet{ - &rtcp.SourceDescription{Chunks: sd}, - } - - go func() { - defer RecoverSilent() - batch := pkts - i := 0 - for { - if err := sub.SubscriberPC().WriteRTCP(batch); err != nil { - t.params.Logger.Errorw("could not write RTCP", err) - return - } - if i > 5 { - return - } - i++ - time.Sleep(20 * time.Millisecond) - } - }() -} - func (t *MediaTrack) handlePublisherFeedback(packets []rtcp.Packet) { var maxLost uint8 var hasReport bool @@ -748,24 +493,16 @@ func (t *MediaTrack) DebugInfo() map[string]interface{} { "PubMuted": t.muted.Get(), } - subscribedTrackInfo := make([]map[string]interface{}, 0) - t.subscribedTracks.Range(func(_, val interface{}) bool { - if track, ok := val.(*SubscribedTrack); ok { - dt := track.DownTrack().DebugInfo() - dt["PubMuted"] = track.pubMuted.Get() - dt["SubMuted"] = track.subMuted.Get() - subscribedTrackInfo = append(subscribedTrackInfo, dt) - } - return true - }) - info["DownTracks"] = subscribedTrackInfo + info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() + t.lock.RLock() if t.receiver != nil { receiverInfo := t.receiver.DebugInfo() for k, v := range receiverInfo { info[k] = v } } + t.lock.RUnlock() return info } @@ -816,132 +553,17 @@ func (t *MediaTrack) calculateVideoScore() { } func (t *MediaTrack) OnSubscribedMaxQualityChange(f func(trackID livekit.TrackID, subscribedQualities []*livekit.SubscribedQuality) error) { - t.onSubscribedMaxQualityChange = f -} - -func (t *MediaTrack) NotifySubscriberMute(subscriberID livekit.ParticipantID) { - if t.Kind() != livekit.TrackType_VIDEO { - return - } - - t.maxQualityLock.Lock() - _, ok := t.maxSubscriberQuality[subscriberID] - if !ok { - t.maxQualityLock.Unlock() - return - } - - delete(t.maxSubscriberQuality, subscriberID) - t.maxQualityLock.Unlock() - - t.updateQualityChange() -} - -func (t *MediaTrack) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) { - if t.Kind() != livekit.TrackType_VIDEO { - return - } - - t.maxQualityLock.Lock() - maxQuality, ok := t.maxSubscriberQuality[subscriberID] - if ok && maxQuality == quality { - t.maxQualityLock.Unlock() - return - } - - t.maxSubscriberQuality[subscriberID] = quality - t.maxQualityLock.Unlock() - - t.updateQualityChange() -} - -func (t *MediaTrack) startMaxQualityTimer() { - t.maxQualityLock.Lock() - defer t.maxQualityLock.Unlock() - - if t.Kind() != livekit.TrackType_VIDEO { - return - } - - t.maxQualityTimer = time.AfterFunc(initialQualityUpdateWait, func() { - t.stopMaxQualityTimer() - t.updateQualityChange() - }) -} - -func (t *MediaTrack) stopMaxQualityTimer() { - t.maxQualityLock.Lock() - defer t.maxQualityLock.Unlock() - - if t.maxQualityTimer != nil { - t.maxQualityTimer.Stop() - t.maxQualityTimer = nil - } -} - -func (t *MediaTrack) updateQualityChange() { - if t.Kind() != livekit.TrackType_VIDEO || t.IsMuted() { - return - } - - var subscribedQualities []*livekit.SubscribedQuality - - t.maxQualityLock.Lock() - allSubscribersMuted := false - maxSubscribedQuality := livekit.VideoQuality_LOW - if len(t.maxSubscriberQuality) == 0 { - allSubscribersMuted = true - } else { - for _, subQuality := range t.maxSubscriberQuality { - if subQuality > maxSubscribedQuality { - maxSubscribedQuality = subQuality - } + t.MediaTrackSubscriptions.OnSubscribedMaxQualityChange(func(subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) { + if f != nil && !t.IsMuted() { + _ = f(t.ID(), subscribedQualities) } - } - notifyMaxExpected := false - maxExpectedSpatialLayer := int32(-1) - if allSubscribersMuted { - if !t.allSubscribersMuted { - notifyMaxExpected = true - maxExpectedSpatialLayer = sfu.InvalidLayerSpatial - - t.allSubscribersMuted = true - - subscribedQualities = []*livekit.SubscribedQuality{ - {Quality: livekit.VideoQuality_LOW, Enabled: false}, - {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, - {Quality: livekit.VideoQuality_HIGH, Enabled: false}, - } - } - } else { - if t.allSubscribersMuted || maxSubscribedQuality != t.maxSubscribedQuality { - t.allSubscribersMuted = false - notifyMaxExpected = true - maxExpectedSpatialLayer = SpatialLayerForQuality(maxSubscribedQuality) - t.maxSubscribedQuality = maxSubscribedQuality - - for q := livekit.VideoQuality_LOW; q <= livekit.VideoQuality_HIGH; q++ { - subscribedQualities = append(subscribedQualities, &livekit.SubscribedQuality{ - Quality: q, - Enabled: q <= t.maxSubscribedQuality, - }) - } - } - } - t.maxQualityLock.Unlock() - - if notifyMaxExpected { t.lock.RLock() if t.receiver != nil { - t.receiver.SetMaxExpectedSpatialLayer(maxExpectedSpatialLayer) + t.receiver.SetMaxExpectedSpatialLayer(SpatialLayerForQuality(maxSubscribedQuality)) } t.lock.RUnlock() - } - - if len(subscribedQualities) != 0 && t.onSubscribedMaxQualityChange != nil { - _ = t.onSubscribedMaxQualityChange(t.ID(), subscribedQualities) - } + }) } //--------------------------- @@ -952,7 +574,11 @@ func SpatialLayerForQuality(quality livekit.VideoQuality) int32 { return 0 case livekit.VideoQuality_MEDIUM: return 1 - default: + case livekit.VideoQuality_HIGH: return 2 + case livekit.VideoQuality_OFF: + return -1 + default: + return -1 } } diff --git a/pkg/rtc/mediatrack_test.go b/pkg/rtc/mediatrack_test.go index cf3daa6c7..f0265ef5e 100644 --- a/pkg/rtc/mediatrack_test.go +++ b/pkg/rtc/mediatrack_test.go @@ -137,7 +137,7 @@ func TestSubscribedMaxQuality(t *testing.T) { }) // mute all subscribers - mt.NotifySubscriberMute("s1") + mt.NotifySubscriberMaxQuality("s1", livekit.VideoQuality_OFF) expectedSubscribedQualities := []*livekit.SubscribedQuality{ &livekit.SubscribedQuality{Quality: livekit.VideoQuality_LOW, Enabled: false}, @@ -216,7 +216,7 @@ func TestSubscribedMaxQuality(t *testing.T) { require.EqualValues(t, expectedSubscribedQualities, actualSubscribedQualities) // muting "s2" only should not disable all qualities - mt.NotifySubscriberMute("s2") + mt.NotifySubscriberMaxQuality("s2", livekit.VideoQuality_OFF) expectedSubscribedQualities = []*livekit.SubscribedQuality{ &livekit.SubscribedQuality{Quality: livekit.VideoQuality_LOW, Enabled: true}, @@ -227,7 +227,7 @@ func TestSubscribedMaxQuality(t *testing.T) { require.EqualValues(t, expectedSubscribedQualities, actualSubscribedQualities) // muting "s1" also should disable all qualities - mt.NotifySubscriberMute("s1") + mt.NotifySubscriberMaxQuality("s1", livekit.VideoQuality_OFF) expectedSubscribedQualities = []*livekit.SubscribedQuality{ &livekit.SubscribedQuality{Quality: livekit.VideoQuality_LOW, Enabled: false}, diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go new file mode 100644 index 000000000..ebf6dfb80 --- /dev/null +++ b/pkg/rtc/mediatracksubscriptions.go @@ -0,0 +1,499 @@ +package rtc + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" + "github.com/pion/webrtc/v3/pkg/rtcerr" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +const ( + initialQualityUpdateWait = 10 * time.Second +) + +// MediaTrackSubscriptions manages subscriptions of a media track +type MediaTrackSubscriptions struct { + params MediaTrackSubscriptionsParams + + subscribedTracks sync.Map // participantID => types.SubscribedTrack + + // quality level enable/disable + maxQualityLock sync.RWMutex + maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality + maxSubscriberNodeQuality map[string]livekit.VideoQuality // nodeID => livekit.VideoQuality + maxSubscribedQuality livekit.VideoQuality + onSubscribedMaxQualityChange func(subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality) + maxQualityTimer *time.Timer +} + +type MediaTrackSubscriptionsParams struct { + MediaTrack types.MediaTrack + + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + + Telemetry telemetry.TelemetryService + + Logger logger.Logger +} + +func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrackSubscriptions { + t := &MediaTrackSubscriptions{ + params: params, + maxSubscriberQuality: make(map[livekit.ParticipantID]livekit.VideoQuality), + } + + return t +} + +func (t *MediaTrackSubscriptions) SetMuted(muted bool) { + // mute all subscribed tracks + t.subscribedTracks.Range(func(_, value interface{}) bool { + if st, ok := value.(types.SubscribedTrack); ok { + st.SetPublisherMuted(muted) + } + return true + }) + + // update quality based on subscription if unmuting + if !muted { + t.updateQualityChange() + } +} + +func (t *MediaTrackSubscriptions) IsSubscriber(subID livekit.ParticipantID) bool { + _, ok := t.subscribedTracks.Load(subID) + return ok +} + +// AddSubscriber subscribes sub to current mediaTrack +func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec webrtc.RTPCodecCapability, wr WrappedReceiver) (*sfu.DownTrack, error) { + subscriberID := sub.ID() + + // don't subscribe to the same track multiple times + if _, ok := t.subscribedTracks.Load(subscriberID); ok { + return nil, nil + } + + var rtcpFeedback []webrtc.RTCPFeedback + switch t.params.MediaTrack.Kind() { + case livekit.TrackType_AUDIO: + rtcpFeedback = t.params.SubscriberConfig.RTCPFeedback.Audio + case livekit.TrackType_VIDEO: + rtcpFeedback = t.params.SubscriberConfig.RTCPFeedback.Video + } + downTrack, err := sfu.NewDownTrack(webrtc.RTPCodecCapability{ + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: rtcpFeedback, + }, wr, t.params.BufferFactory, subscriberID, t.params.ReceiverConfig.PacketBufferSize) + if err != nil { + return nil, err + } + subTrack := NewSubscribedTrack(SubscribedTrackParams{ + PublisherID: t.params.MediaTrack.ParticipantID(), + PublisherIdentity: t.params.MediaTrack.ParticipantIdentity(), + SubscriberID: subscriberID, + MediaTrack: t.params.MediaTrack, + DownTrack: downTrack, + }) + + var transceiver *webrtc.RTPTransceiver + var sender *webrtc.RTPSender + if sub.ProtocolVersion().SupportsTransceiverReuse() { + // + // AddTrack will create a new transceiver or re-use an unused one + // if the attributes match. This prevents SDP from bloating + // because of dormant transceivers building up. + // + sender, err = sub.SubscriberPC().AddTrack(downTrack) + if err != nil { + return nil, err + } + + // as there is no way to get transceiver from sender, search + for _, tr := range sub.SubscriberPC().GetTransceivers() { + if tr.Sender() == sender { + transceiver = tr + break + } + } + if transceiver == nil { + // cannot add, no transceiver + return nil, errors.New("cannot subscribe without a transceiver in place") + } + } else { + transceiver, err = sub.SubscriberPC().AddTransceiverFromTrack(downTrack, webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }) + if err != nil { + return nil, err + } + + sender = transceiver.Sender() + if sender == nil { + // cannot add, no sender + return nil, errors.New("cannot subscribe without a sender in place") + } + } + + sendParameters := sender.GetParameters() + downTrack.SetRTPHeaderExtensions(sendParameters.HeaderExtensions) + + downTrack.SetTransceiver(transceiver) + // when outtrack is bound, start loop to send reports + downTrack.OnBind(func() { + go subTrack.Bound() + go t.sendDownTrackBindingReports(sub) + }) + downTrack.OnPacketSent(func(_ *sfu.DownTrack, size int) { + if t.params.Telemetry != nil { + t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) + } + }) + downTrack.OnPaddingSent(func(_ *sfu.DownTrack, size int) { + if t.params.Telemetry != nil { + t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) + } + }) + downTrack.OnRTCP(func(pkts []rtcp.Packet) { + if t.params.Telemetry != nil { + t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, subscriberID, t.params.MediaTrack.ID(), pkts) + } + }) + + downTrack.OnCloseHandler(func() { + go func() { + t.subscribedTracks.Delete(subscriberID) + if t.params.Telemetry != nil { + t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) + } + + // ignore if the subscribing sub is not connected + if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { + return + } + + // if the source has been terminated, we'll need to terminate all of the subscribedtracks + // however, if the dest sub has disconnected, then we can skip + if sender == nil { + return + } + t.params.Logger.Debugw("removing peerconnection track", + "track", t.params.MediaTrack.ID(), + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + "kind", t.params.MediaTrack.Kind(), + ) + if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { + if err == webrtc.ErrConnectionClosed { + // sub closing, can skip removing subscribedtracks + return + } + if _, ok := err.(*rtcerr.InvalidStateError); !ok { + // most of these are safe to ignore, since the track state might have already + // been set to Inactive + t.params.Logger.Debugw("could not remove remoteTrack from forwarder", + "error", err, + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + ) + } + } + + t.NotifySubscriberMaxQuality(subscriberID, livekit.VideoQuality_OFF) + sub.RemoveSubscribedTrack(subTrack) + sub.Negotiate() + }() + }) + + t.subscribedTracks.Store(subscriberID, subTrack) + subTrack.SetPublisherMuted(t.params.MediaTrack.IsMuted()) + + // since sub will lock, run it in a goroutine to avoid deadlocks + go func() { + t.NotifySubscriberMaxQuality(subscriberID, livekit.VideoQuality_HIGH) // start with HIGH, let subscription change it later + sub.AddSubscribedTrack(subTrack) + sub.Negotiate() + }() + + if t.params.Telemetry != nil { + t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) + } + return downTrack, nil +} + +// RemoveSubscriber removes participant from subscription +// stop all forwarders to the client +func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.ParticipantID) { + subTrack := t.getSubscribedTrack(participantID) + if subTrack != nil { + go subTrack.DownTrack().Close() + } +} + +func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { + t.params.Logger.Debugw("removing all subscribers", "track", t.params.MediaTrack.ID()) + + t.subscribedTracks.Range(func(_, val interface{}) bool { + if subTrack, ok := val.(types.SubscribedTrack); ok { + go subTrack.DownTrack().Close() + } + return true + }) + t.subscribedTracks = sync.Map{} +} + +func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID { + var revokedSubscriberIDs []livekit.ParticipantID + // LK-TODO: large number of subscribers needs to be solved for this loop + t.subscribedTracks.Range(func(key interface{}, val interface{}) bool { + if subID, ok := key.(livekit.ParticipantID); ok { + found := false + for _, allowedID := range allowedSubscriberIDs { + if subID == allowedID { + found = true + break + } + } + + if !found { + if subTrack, ok := val.(types.SubscribedTrack); ok { + go subTrack.DownTrack().Close() + revokedSubscriberIDs = append(revokedSubscriberIDs, subID) + } + } + } + return true + }) + + return revokedSubscriberIDs +} + +func (t *MediaTrackSubscriptions) UpdateVideoLayers() { + t.subscribedTracks.Range(func(_, val interface{}) bool { + if st, ok := val.(types.SubscribedTrack); ok { + st.UpdateVideoLayer() + } + return true + }) +} + +func (t *MediaTrackSubscriptions) getSubscribedTrack(subscriberID livekit.ParticipantID) types.SubscribedTrack { + if val, ok := t.subscribedTracks.Load(subscriberID); ok { + if st, ok := val.(types.SubscribedTrack); ok { + return st + } + } + return nil +} + +// TODO: send for all downtracks from the source participant +// https://tools.ietf.org/html/rfc7941 +func (t *MediaTrackSubscriptions) sendDownTrackBindingReports(sub types.Participant) { + var sd []rtcp.SourceDescriptionChunk + + subTrack := t.getSubscribedTrack(sub.ID()) + if subTrack == nil { + return + } + + chunks := subTrack.DownTrack().CreateSourceDescriptionChunks() + if chunks == nil { + return + } + sd = append(sd, chunks...) + + pkts := []rtcp.Packet{ + &rtcp.SourceDescription{Chunks: sd}, + } + + go func() { + defer RecoverSilent() + batch := pkts + i := 0 + for { + if err := sub.SubscriberPC().WriteRTCP(batch); err != nil { + t.params.Logger.Errorw("could not write RTCP", err) + return + } + if i > 5 { + return + } + i++ + time.Sleep(20 * time.Millisecond) + } + }() +} + +func (t *MediaTrackSubscriptions) DebugInfo() []map[string]interface{} { + subscribedTrackInfo := make([]map[string]interface{}, 0) + t.subscribedTracks.Range(func(_, val interface{}) bool { + if track, ok := val.(*SubscribedTrack); ok { + dt := track.DownTrack().DebugInfo() + dt["PubMuted"] = track.pubMuted.Get() + dt["SubMuted"] = track.subMuted.Get() + subscribedTrackInfo = append(subscribedTrackInfo, dt) + } + return true + }) + + return subscribedTrackInfo +} + +func (t *MediaTrackSubscriptions) OnSubscribedMaxQualityChange(f func(subscribedQualities []*livekit.SubscribedQuality, maxSubscribedQuality livekit.VideoQuality)) { + t.onSubscribedMaxQualityChange = f +} + +func (t *MediaTrackSubscriptions) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) { + if t.params.MediaTrack.Kind() != livekit.TrackType_VIDEO { + return + } + + t.maxQualityLock.Lock() + if quality == livekit.VideoQuality_OFF { + _, ok := t.maxSubscriberQuality[subscriberID] + if !ok { + t.maxQualityLock.Unlock() + return + } + + delete(t.maxSubscriberQuality, subscriberID) + } else { + maxQuality, ok := t.maxSubscriberQuality[subscriberID] + if ok && maxQuality == quality { + t.maxQualityLock.Unlock() + return + } + + t.maxSubscriberQuality[subscriberID] = quality + } + t.maxQualityLock.Unlock() + + t.updateQualityChange() +} + +func (t *MediaTrackSubscriptions) NotifySubscriberNodeMaxQuality(nodeID string, quality livekit.VideoQuality) { + if t.params.MediaTrack.Kind() != livekit.TrackType_VIDEO { + return + } + + t.maxQualityLock.Lock() + if quality == livekit.VideoQuality_OFF { + _, ok := t.maxSubscriberNodeQuality[nodeID] + if !ok { + t.maxQualityLock.Unlock() + return + } + + delete(t.maxSubscriberNodeQuality, nodeID) + } else { + maxQuality, ok := t.maxSubscriberNodeQuality[nodeID] + if ok && maxQuality == quality { + t.maxQualityLock.Unlock() + return + } + + t.maxSubscriberNodeQuality[nodeID] = quality + } + t.maxQualityLock.Unlock() + + t.updateQualityChange() +} + +func (t *MediaTrackSubscriptions) startMaxQualityTimer() { + t.maxQualityLock.Lock() + defer t.maxQualityLock.Unlock() + + if t.params.MediaTrack.Kind() != livekit.TrackType_VIDEO { + return + } + + t.maxQualityTimer = time.AfterFunc(initialQualityUpdateWait, func() { + t.stopMaxQualityTimer() + t.updateQualityChange() + }) +} + +func (t *MediaTrackSubscriptions) stopMaxQualityTimer() { + t.maxQualityLock.Lock() + defer t.maxQualityLock.Unlock() + + if t.maxQualityTimer != nil { + t.maxQualityTimer.Stop() + t.maxQualityTimer = nil + } +} + +func (t *MediaTrackSubscriptions) updateQualityChange() { + if t.params.MediaTrack.Kind() != livekit.TrackType_VIDEO { + return + } + + t.maxQualityLock.Lock() + maxSubscribedQuality := livekit.VideoQuality_OFF + for _, subQuality := range t.maxSubscriberQuality { + if maxSubscribedQuality == livekit.VideoQuality_OFF || subQuality > maxSubscribedQuality { + maxSubscribedQuality = subQuality + } + } + + for _, subQuality := range t.maxSubscriberNodeQuality { + if maxSubscribedQuality == livekit.VideoQuality_OFF || subQuality > maxSubscribedQuality { + maxSubscribedQuality = subQuality + } + } + + if maxSubscribedQuality == t.maxSubscribedQuality { + t.maxQualityLock.Unlock() + return + } + + t.maxSubscribedQuality = maxSubscribedQuality + + var subscribedQualities []*livekit.SubscribedQuality + if t.maxSubscribedQuality == livekit.VideoQuality_OFF { + subscribedQualities = []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + } + } else { + for q := livekit.VideoQuality_LOW; q <= livekit.VideoQuality_HIGH; q++ { + subscribedQualities = append(subscribedQualities, &livekit.SubscribedQuality{ + Quality: q, + Enabled: q <= t.maxSubscribedQuality, + }) + } + } + t.maxQualityLock.Unlock() + + if t.onSubscribedMaxQualityChange != nil { + t.onSubscribedMaxQualityChange(subscribedQualities, maxSubscribedQuality) + } +} + +func (t *MediaTrackSubscriptions) numSubscribed() uint32 { + t.maxQualityLock.RLock() + numSubscribed := uint32(0) + if t.maxSubscribedQuality != livekit.VideoQuality_OFF { + numSubscribed = uint32(SpatialLayerForQuality(t.maxSubscribedQuality) + 1) + } + t.maxQualityLock.RUnlock() + + return numSubscribed +} diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index 90a7b2199..7a6ef25ea 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -99,7 +99,7 @@ func (t *SubscribedTrack) UpdateVideoLayer() { return } if t.subMuted.Get() { - t.MediaTrack().NotifySubscriberMute(t.params.SubscriberID) + t.MediaTrack().NotifySubscriberMaxQuality(t.params.SubscriberID, livekit.VideoQuality_OFF) return } settings, ok := t.settings.Load().(*livekit.UpdateTrackSettings) diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index a6405f692..442f73822 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -116,6 +116,11 @@ type MediaTrack interface { Source() livekit.TrackSource IsSimulcast() bool + ParticipantID() livekit.ParticipantID + ParticipantIdentity() livekit.ParticipantIdentity + + ToProto() *livekit.TrackInfo + // subscribers AddSubscriber(participant Participant) error RemoveSubscriber(participantID livekit.ParticipantID) @@ -126,8 +131,8 @@ type MediaTrack interface { // returns quality information that's appropriate for width & height GetQualityForDimension(width, height uint32) livekit.VideoQuality - NotifySubscriberMute(subscriberID livekit.ParticipantID) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) + NotifySubscriberNodeMaxQuality(nodeID string, quality livekit.VideoQuality) } // PublishedTrack is the main interface representing a track published to the room diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 128a750f9..dca8d3967 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -99,10 +99,31 @@ type FakeMediaTrack struct { arg1 livekit.ParticipantID arg2 livekit.VideoQuality } - NotifySubscriberMuteStub func(livekit.ParticipantID) - notifySubscriberMuteMutex sync.RWMutex - notifySubscriberMuteArgsForCall []struct { - arg1 livekit.ParticipantID + NotifySubscriberNodeMaxQualityStub func(string, livekit.VideoQuality) + notifySubscriberNodeMaxQualityMutex sync.RWMutex + notifySubscriberNodeMaxQualityArgsForCall []struct { + arg1 string + arg2 livekit.VideoQuality + } + ParticipantIDStub func() livekit.ParticipantID + participantIDMutex sync.RWMutex + participantIDArgsForCall []struct { + } + participantIDReturns struct { + result1 livekit.ParticipantID + } + participantIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + ParticipantIdentityStub func() livekit.ParticipantIdentity + participantIdentityMutex sync.RWMutex + participantIdentityArgsForCall []struct { + } + participantIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + participantIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity } RemoveAllSubscribersStub func() removeAllSubscribersMutex sync.RWMutex @@ -139,6 +160,16 @@ type FakeMediaTrack struct { sourceReturnsOnCall map[int]struct { result1 livekit.TrackSource } + ToProtoStub func() *livekit.TrackInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.TrackInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } UpdateVideoLayersStub func([]*livekit.VideoLayer) updateVideoLayersMutex sync.RWMutex updateVideoLayersArgsForCall []struct { @@ -630,36 +661,143 @@ func (fake *FakeMediaTrack) NotifySubscriberMaxQualityArgsForCall(i int) (liveki return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeMediaTrack) NotifySubscriberMute(arg1 livekit.ParticipantID) { - fake.notifySubscriberMuteMutex.Lock() - fake.notifySubscriberMuteArgsForCall = append(fake.notifySubscriberMuteArgsForCall, struct { - arg1 livekit.ParticipantID - }{arg1}) - stub := fake.NotifySubscriberMuteStub - fake.recordInvocation("NotifySubscriberMute", []interface{}{arg1}) - fake.notifySubscriberMuteMutex.Unlock() +func (fake *FakeMediaTrack) NotifySubscriberNodeMaxQuality(arg1 string, arg2 livekit.VideoQuality) { + fake.notifySubscriberNodeMaxQualityMutex.Lock() + fake.notifySubscriberNodeMaxQualityArgsForCall = append(fake.notifySubscriberNodeMaxQualityArgsForCall, struct { + arg1 string + arg2 livekit.VideoQuality + }{arg1, arg2}) + stub := fake.NotifySubscriberNodeMaxQualityStub + fake.recordInvocation("NotifySubscriberNodeMaxQuality", []interface{}{arg1, arg2}) + fake.notifySubscriberNodeMaxQualityMutex.Unlock() if stub != nil { - fake.NotifySubscriberMuteStub(arg1) + fake.NotifySubscriberNodeMaxQualityStub(arg1, arg2) } } -func (fake *FakeMediaTrack) NotifySubscriberMuteCallCount() int { - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() - return len(fake.notifySubscriberMuteArgsForCall) +func (fake *FakeMediaTrack) NotifySubscriberNodeMaxQualityCallCount() int { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + return len(fake.notifySubscriberNodeMaxQualityArgsForCall) } -func (fake *FakeMediaTrack) NotifySubscriberMuteCalls(stub func(livekit.ParticipantID)) { - fake.notifySubscriberMuteMutex.Lock() - defer fake.notifySubscriberMuteMutex.Unlock() - fake.NotifySubscriberMuteStub = stub +func (fake *FakeMediaTrack) NotifySubscriberNodeMaxQualityCalls(stub func(string, livekit.VideoQuality)) { + fake.notifySubscriberNodeMaxQualityMutex.Lock() + defer fake.notifySubscriberNodeMaxQualityMutex.Unlock() + fake.NotifySubscriberNodeMaxQualityStub = stub } -func (fake *FakeMediaTrack) NotifySubscriberMuteArgsForCall(i int) livekit.ParticipantID { - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() - argsForCall := fake.notifySubscriberMuteArgsForCall[i] - return argsForCall.arg1 +func (fake *FakeMediaTrack) NotifySubscriberNodeMaxQualityArgsForCall(i int) (string, livekit.VideoQuality) { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + argsForCall := fake.notifySubscriberNodeMaxQualityArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeMediaTrack) ParticipantID() livekit.ParticipantID { + fake.participantIDMutex.Lock() + ret, specificReturn := fake.participantIDReturnsOnCall[len(fake.participantIDArgsForCall)] + fake.participantIDArgsForCall = append(fake.participantIDArgsForCall, struct { + }{}) + stub := fake.ParticipantIDStub + fakeReturns := fake.participantIDReturns + fake.recordInvocation("ParticipantID", []interface{}{}) + fake.participantIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) ParticipantIDCallCount() int { + fake.participantIDMutex.RLock() + defer fake.participantIDMutex.RUnlock() + return len(fake.participantIDArgsForCall) +} + +func (fake *FakeMediaTrack) ParticipantIDCalls(stub func() livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = stub +} + +func (fake *FakeMediaTrack) ParticipantIDReturns(result1 livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = nil + fake.participantIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) ParticipantIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = nil + if fake.participantIDReturnsOnCall == nil { + fake.participantIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.participantIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) ParticipantIdentity() livekit.ParticipantIdentity { + fake.participantIdentityMutex.Lock() + ret, specificReturn := fake.participantIdentityReturnsOnCall[len(fake.participantIdentityArgsForCall)] + fake.participantIdentityArgsForCall = append(fake.participantIdentityArgsForCall, struct { + }{}) + stub := fake.ParticipantIdentityStub + fakeReturns := fake.participantIdentityReturns + fake.recordInvocation("ParticipantIdentity", []interface{}{}) + fake.participantIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) ParticipantIdentityCallCount() int { + fake.participantIdentityMutex.RLock() + defer fake.participantIdentityMutex.RUnlock() + return len(fake.participantIdentityArgsForCall) +} + +func (fake *FakeMediaTrack) ParticipantIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = stub +} + +func (fake *FakeMediaTrack) ParticipantIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = nil + fake.participantIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeMediaTrack) ParticipantIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = nil + if fake.participantIdentityReturnsOnCall == nil { + fake.participantIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.participantIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} } func (fake *FakeMediaTrack) RemoveAllSubscribers() { @@ -869,6 +1007,59 @@ func (fake *FakeMediaTrack) SourceReturnsOnCall(i int, result1 livekit.TrackSour }{result1} } +func (fake *FakeMediaTrack) ToProto() *livekit.TrackInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeMediaTrack) ToProtoCalls(stub func() *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeMediaTrack) ToProtoReturns(result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeMediaTrack) ToProtoReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + func (fake *FakeMediaTrack) UpdateVideoLayers(arg1 []*livekit.VideoLayer) { var arg1Copy []*livekit.VideoLayer if arg1 != nil { @@ -927,8 +1118,12 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.nameMutex.RUnlock() fake.notifySubscriberMaxQualityMutex.RLock() defer fake.notifySubscriberMaxQualityMutex.RUnlock() - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + fake.participantIDMutex.RLock() + defer fake.participantIDMutex.RUnlock() + fake.participantIdentityMutex.RLock() + defer fake.participantIdentityMutex.RUnlock() fake.removeAllSubscribersMutex.RLock() defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() @@ -939,6 +1134,8 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.setMutedMutex.RUnlock() fake.sourceMutex.RLock() defer fake.sourceMutex.RUnlock() + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() fake.updateVideoLayersMutex.RLock() defer fake.updateVideoLayersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index 1cd1ab58c..8dbe9cc45 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -115,10 +115,11 @@ type FakePublishedTrack struct { arg1 livekit.ParticipantID arg2 livekit.VideoQuality } - NotifySubscriberMuteStub func(livekit.ParticipantID) - notifySubscriberMuteMutex sync.RWMutex - notifySubscriberMuteArgsForCall []struct { - arg1 livekit.ParticipantID + NotifySubscriberNodeMaxQualityStub func(string, livekit.VideoQuality) + notifySubscriberNodeMaxQualityMutex sync.RWMutex + notifySubscriberNodeMaxQualityArgsForCall []struct { + arg1 string + arg2 livekit.VideoQuality } NumUpTracksStub func() (uint32, uint32) numUpTracksMutex sync.RWMutex @@ -132,6 +133,26 @@ type FakePublishedTrack struct { result1 uint32 result2 uint32 } + ParticipantIDStub func() livekit.ParticipantID + participantIDMutex sync.RWMutex + participantIDArgsForCall []struct { + } + participantIDReturns struct { + result1 livekit.ParticipantID + } + participantIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + ParticipantIdentityStub func() livekit.ParticipantIdentity + participantIdentityMutex sync.RWMutex + participantIdentityArgsForCall []struct { + } + participantIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + participantIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } PublishLossPercentageStub func() uint32 publishLossPercentageMutex sync.RWMutex publishLossPercentageArgsForCall []struct { @@ -793,36 +814,37 @@ func (fake *FakePublishedTrack) NotifySubscriberMaxQualityArgsForCall(i int) (li return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakePublishedTrack) NotifySubscriberMute(arg1 livekit.ParticipantID) { - fake.notifySubscriberMuteMutex.Lock() - fake.notifySubscriberMuteArgsForCall = append(fake.notifySubscriberMuteArgsForCall, struct { - arg1 livekit.ParticipantID - }{arg1}) - stub := fake.NotifySubscriberMuteStub - fake.recordInvocation("NotifySubscriberMute", []interface{}{arg1}) - fake.notifySubscriberMuteMutex.Unlock() +func (fake *FakePublishedTrack) NotifySubscriberNodeMaxQuality(arg1 string, arg2 livekit.VideoQuality) { + fake.notifySubscriberNodeMaxQualityMutex.Lock() + fake.notifySubscriberNodeMaxQualityArgsForCall = append(fake.notifySubscriberNodeMaxQualityArgsForCall, struct { + arg1 string + arg2 livekit.VideoQuality + }{arg1, arg2}) + stub := fake.NotifySubscriberNodeMaxQualityStub + fake.recordInvocation("NotifySubscriberNodeMaxQuality", []interface{}{arg1, arg2}) + fake.notifySubscriberNodeMaxQualityMutex.Unlock() if stub != nil { - fake.NotifySubscriberMuteStub(arg1) + fake.NotifySubscriberNodeMaxQualityStub(arg1, arg2) } } -func (fake *FakePublishedTrack) NotifySubscriberMuteCallCount() int { - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() - return len(fake.notifySubscriberMuteArgsForCall) +func (fake *FakePublishedTrack) NotifySubscriberNodeMaxQualityCallCount() int { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + return len(fake.notifySubscriberNodeMaxQualityArgsForCall) } -func (fake *FakePublishedTrack) NotifySubscriberMuteCalls(stub func(livekit.ParticipantID)) { - fake.notifySubscriberMuteMutex.Lock() - defer fake.notifySubscriberMuteMutex.Unlock() - fake.NotifySubscriberMuteStub = stub +func (fake *FakePublishedTrack) NotifySubscriberNodeMaxQualityCalls(stub func(string, livekit.VideoQuality)) { + fake.notifySubscriberNodeMaxQualityMutex.Lock() + defer fake.notifySubscriberNodeMaxQualityMutex.Unlock() + fake.NotifySubscriberNodeMaxQualityStub = stub } -func (fake *FakePublishedTrack) NotifySubscriberMuteArgsForCall(i int) livekit.ParticipantID { - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() - argsForCall := fake.notifySubscriberMuteArgsForCall[i] - return argsForCall.arg1 +func (fake *FakePublishedTrack) NotifySubscriberNodeMaxQualityArgsForCall(i int) (string, livekit.VideoQuality) { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + argsForCall := fake.notifySubscriberNodeMaxQualityArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakePublishedTrack) NumUpTracks() (uint32, uint32) { @@ -881,6 +903,112 @@ func (fake *FakePublishedTrack) NumUpTracksReturnsOnCall(i int, result1 uint32, }{result1, result2} } +func (fake *FakePublishedTrack) ParticipantID() livekit.ParticipantID { + fake.participantIDMutex.Lock() + ret, specificReturn := fake.participantIDReturnsOnCall[len(fake.participantIDArgsForCall)] + fake.participantIDArgsForCall = append(fake.participantIDArgsForCall, struct { + }{}) + stub := fake.ParticipantIDStub + fakeReturns := fake.participantIDReturns + fake.recordInvocation("ParticipantID", []interface{}{}) + fake.participantIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) ParticipantIDCallCount() int { + fake.participantIDMutex.RLock() + defer fake.participantIDMutex.RUnlock() + return len(fake.participantIDArgsForCall) +} + +func (fake *FakePublishedTrack) ParticipantIDCalls(stub func() livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = stub +} + +func (fake *FakePublishedTrack) ParticipantIDReturns(result1 livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = nil + fake.participantIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakePublishedTrack) ParticipantIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.participantIDMutex.Lock() + defer fake.participantIDMutex.Unlock() + fake.ParticipantIDStub = nil + if fake.participantIDReturnsOnCall == nil { + fake.participantIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.participantIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakePublishedTrack) ParticipantIdentity() livekit.ParticipantIdentity { + fake.participantIdentityMutex.Lock() + ret, specificReturn := fake.participantIdentityReturnsOnCall[len(fake.participantIdentityArgsForCall)] + fake.participantIdentityArgsForCall = append(fake.participantIdentityArgsForCall, struct { + }{}) + stub := fake.ParticipantIdentityStub + fakeReturns := fake.participantIdentityReturns + fake.recordInvocation("ParticipantIdentity", []interface{}{}) + fake.participantIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) ParticipantIdentityCallCount() int { + fake.participantIdentityMutex.RLock() + defer fake.participantIdentityMutex.RUnlock() + return len(fake.participantIdentityArgsForCall) +} + +func (fake *FakePublishedTrack) ParticipantIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = stub +} + +func (fake *FakePublishedTrack) ParticipantIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = nil + fake.participantIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakePublishedTrack) ParticipantIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.participantIdentityMutex.Lock() + defer fake.participantIdentityMutex.Unlock() + fake.ParticipantIdentityStub = nil + if fake.participantIdentityReturnsOnCall == nil { + fake.participantIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.participantIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + func (fake *FakePublishedTrack) PublishLossPercentage() uint32 { fake.publishLossPercentageMutex.Lock() ret, specificReturn := fake.publishLossPercentageReturnsOnCall[len(fake.publishLossPercentageArgsForCall)] @@ -1415,10 +1543,14 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.nameMutex.RUnlock() fake.notifySubscriberMaxQualityMutex.RLock() defer fake.notifySubscriberMaxQualityMutex.RUnlock() - fake.notifySubscriberMuteMutex.RLock() - defer fake.notifySubscriberMuteMutex.RUnlock() + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() fake.numUpTracksMutex.RLock() defer fake.numUpTracksMutex.RUnlock() + fake.participantIDMutex.RLock() + defer fake.participantIDMutex.RUnlock() + fake.participantIdentityMutex.RLock() + defer fake.participantIdentityMutex.RUnlock() fake.publishLossPercentageMutex.RLock() defer fake.publishLossPercentageMutex.RUnlock() fake.receiverMutex.RLock() diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index 2fa6cb61f..f9d995670 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -94,7 +94,7 @@ type DownTrack struct { receiver TrackReceiver transceiver *webrtc.RTPTransceiver writeStream webrtc.TrackLocalWriter - onCloseHandler func() + onCloseHandlers []func() onBind func() receiverReportListeners []ReceiverReportListener listenerLock sync.RWMutex @@ -438,8 +438,8 @@ func (d *DownTrack) Close() { d.closeOnce.Do(func() { Logger.V(1).Info("Closing sender", "peer_id", d.peerID, "kind", d.kind) - if d.onCloseHandler != nil { - d.onCloseHandler() + for _, f := range d.onCloseHandlers { + f() } close(d.done) }) @@ -485,7 +485,9 @@ func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) { // OnCloseHandler method to be called on remote tracked removed func (d *DownTrack) OnCloseHandler(fn func()) { - d.onCloseHandler = fn + if fn != nil { + d.onCloseHandlers = append(d.onCloseHandlers, fn) + } } func (d *DownTrack) OnBind(fn func()) {