From ecf9590d565cc147d80dc9dc22b6bc8b3a1ab82e Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Fri, 7 Jan 2022 01:46:15 +0530 Subject: [PATCH] More abstractions (#316) --- pkg/rtc/audiolevel.go | 12 +-- pkg/rtc/mediatrack.go | 4 +- pkg/rtc/mediatracksubscriptions.go | 56 +++++--------- pkg/rtc/participant.go | 69 +++++------------ pkg/rtc/participant_internal_test.go | 28 +++---- pkg/rtc/room.go | 12 ++- pkg/rtc/signalhandler.go | 9 +-- pkg/rtc/types/interfaces.go | 8 ++ pkg/rtc/types/typesfakes/fake_participant.go | 74 ++++++++++++++++++ .../types/typesfakes/fake_published_track.go | 70 +++++++++++++++++ pkg/rtc/types/typesfakes/fake_room.go | 76 +++++++++++++++++++ pkg/rtc/uptrackmanager.go | 51 +++++++------ 12 files changed, 326 insertions(+), 143 deletions(-) diff --git a/pkg/rtc/audiolevel.go b/pkg/rtc/audiolevel.go index 803ab6ea6..b34eef6f1 100644 --- a/pkg/rtc/audiolevel.go +++ b/pkg/rtc/audiolevel.go @@ -8,7 +8,7 @@ import ( const ( // duration of audio frames for observe window observeDuration = 500 // ms - silentAudioLevel = 127 + SilentAudioLevel = 127 ) // keeps track of audio level for a participant @@ -29,8 +29,8 @@ func NewAudioLevel(activeLevel uint8, minPercentile uint8) *AudioLevel { l := &AudioLevel{ levelThreshold: activeLevel, minActiveDuration: uint32(minPercentile) * observeDuration / 100, - currentLevel: silentAudioLevel, - observeLevel: silentAudioLevel, + currentLevel: SilentAudioLevel, + observeLevel: SilentAudioLevel, } return l } @@ -52,9 +52,9 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { level := uint32(l.observeLevel) - uint32(20*math.Log10(float64(l.activeDuration)/float64(observeDuration))) atomic.StoreUint32(&l.currentLevel, level) } else { - atomic.StoreUint32(&l.currentLevel, silentAudioLevel) + atomic.StoreUint32(&l.currentLevel, SilentAudioLevel) } - l.observeLevel = silentAudioLevel + l.observeLevel = SilentAudioLevel l.activeDuration = 0 l.observedDuration = 0 } @@ -63,7 +63,7 @@ func (l *AudioLevel) Observe(level uint8, durationMs uint32) { // returns current audio level, 0 (loudest) to 127 (silent) func (l *AudioLevel) GetLevel() (uint8, bool) { level := uint8(atomic.LoadUint32(&l.currentLevel)) - active := level != silentAudioLevel + active := level != SilentAudioLevel return level, active } diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 883c78f4a..05a660e09 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -101,7 +101,7 @@ func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTra ReceiverConfig: params.ReceiverConfig, SubscriberConfig: params.SubscriberConfig, Telemetry: params.Telemetry, - Logger: ¶ms.Logger, + Logger: params.Logger, }) if params.TrackInfo.Muted { @@ -345,7 +345,7 @@ func (t *MediaTrack) GetAudioLevel() (level uint8, active bool) { defer t.audioLevelMu.RUnlock() if t.audioLevel == nil { - return silentAudioLevel, false + return SilentAudioLevel, false } return t.audioLevel.GetLevel() } diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index e159ea394..675c4fb2b 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -49,7 +49,7 @@ type MediaTrackSubscriptionsParams struct { Telemetry telemetry.TelemetryService - Logger *logger.Logger + Logger logger.Logger } func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrackSubscriptions { @@ -177,19 +177,13 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web go t.sendDownTrackBindingReports(sub) }) downTrack.OnPacketSent(func(_ *sfu.DownTrack, size int) { - if t.params.Telemetry != nil { - t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) - } + t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) }) downTrack.OnPaddingSent(func(_ *sfu.DownTrack, size int) { - if t.params.Telemetry != nil { - t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) - } + t.params.Telemetry.OnDownstreamPacket(subscriberID, t.params.MediaTrack.ID(), size) }) downTrack.OnRTCP(func(pkts []rtcp.Packet) { - if t.params.Telemetry != nil { - t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, subscriberID, t.params.MediaTrack.ID(), pkts) - } + t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, subscriberID, t.params.MediaTrack.ID(), pkts) }) downTrack.OnCloseHandler(func() { @@ -199,9 +193,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web t.subscribedTracksMu.Unlock() t.maybeNotifyNoSubscribers() - if t.params.Telemetry != nil { - t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) - } + t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) // ignore if the subscribing sub is not connected if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { @@ -213,14 +205,12 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web if sender == nil { return } - if t.params.Logger != nil { - t.params.Logger.Debugw("removing peerconnection track", - "track", t.params.MediaTrack.ID(), - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - "kind", t.params.MediaTrack.Kind(), - ) - } + t.params.Logger.Debugw("removing peerconnection track", + "track", t.params.MediaTrack.ID(), + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + "kind", t.params.MediaTrack.Kind(), + ) if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { if err == webrtc.ErrConnectionClosed { // sub closing, can skip removing subscribedtracks @@ -229,13 +219,11 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web if _, ok := err.(*rtcerr.InvalidStateError); !ok { // most of these are safe to ignore, since the track state might have already // been set to Inactive - if t.params.Logger != nil { - t.params.Logger.Debugw("could not remove remoteTrack from forwarder", - "error", err, - "subscriber", sub.Identity(), - "subscriberID", subscriberID, - ) - } + t.params.Logger.Debugw("could not remove remoteTrack from forwarder", + "error", err, + "subscriber", sub.Identity(), + "subscriberID", subscriberID, + ) } } @@ -255,9 +243,7 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.Participant, codec web sub.Negotiate() }() - if t.params.Telemetry != nil { - t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) - } + t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.params.MediaTrack.ToProto()) return downTrack, nil } @@ -271,9 +257,7 @@ func (t *MediaTrackSubscriptions) RemoveSubscriber(participantID livekit.Partici } func (t *MediaTrackSubscriptions) RemoveAllSubscribers() { - if t.params.Logger != nil { - t.params.Logger.Debugw("removing all subscribers", "track", t.params.MediaTrack.ID()) - } + t.params.Logger.Debugw("removing all subscribers", "track", t.params.MediaTrack.ID()) t.subscribedTracksMu.RLock() subscribedTracks := t.subscribedTracks @@ -353,9 +337,7 @@ func (t *MediaTrackSubscriptions) sendDownTrackBindingReports(sub types.Particip i := 0 for { if err := sub.SubscriberPC().WriteRTCP(batch); err != nil { - if t.params.Logger != nil { - t.params.Logger.Errorw("could not write RTCP", err) - } + t.params.Logger.Errorw("could not write RTCP", err) return } if i > 5 { diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 9c7b697d9..296f4d394 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -76,7 +76,7 @@ type ParticipantImpl struct { // hold reference for MediaTrack twcc *twcc.Responder - uptrackManager *UptrackManager + *UptrackManager // tracks the current participant is subscribed to, map of sid => DownTrack subscribedTracks map[livekit.TrackID]types.SubscribedTrack @@ -237,7 +237,7 @@ func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { Hidden: p.Hidden(), Recorder: p.IsRecorder(), } - info.Tracks = p.uptrackManager.ToProto() + info.Tracks = p.UptrackManager.ToProto() return info } @@ -338,7 +338,7 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { return } - ti := p.uptrackManager.AddTrack(req) + ti := p.UptrackManager.AddTrack(req) if ti == nil { return } @@ -381,7 +381,7 @@ func (p *ParticipantImpl) AddICECandidate(candidate webrtc.ICECandidateInit, tar func (p *ParticipantImpl) Start() { p.once.Do(func() { - p.uptrackManager.Start() + p.UptrackManager.Start() go p.downTracksRTCPWorker() }) } @@ -399,7 +399,7 @@ func (p *ParticipantImpl) Close() error { }, }) - p.uptrackManager.Close() + p.UptrackManager.Close() p.lock.Lock() disallowedSubscriptions := make(map[livekit.TrackID]livekit.ParticipantID) @@ -448,17 +448,9 @@ func (p *ParticipantImpl) ICERestart() error { }) } -// AddSubscriber subscribes op to all publishedTracks or given set of tracks -func (p *ParticipantImpl) AddSubscriber(op types.Participant, params types.AddSubscriberParams) (int, error) { - return p.uptrackManager.AddSubscriber(op, params) -} - -func (p *ParticipantImpl) RemoveSubscriber(op types.Participant, trackID livekit.TrackID) { - p.uptrackManager.RemoveSubscriber(op, trackID) -} - +// // signal connection methods - +// func (p *ParticipantImpl) SendJoinResponse( roomInfo *livekit.Room, otherParticipants []*livekit.ParticipantInfo, @@ -594,16 +586,12 @@ func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fro }) } - p.uptrackManager.SetTrackMuted(trackID, muted) -} - -func (p *ParticipantImpl) GetAudioLevel() (level uint8, active bool) { - return p.uptrackManager.GetAudioLevel() + p.UptrackManager.SetTrackMuted(trackID, muted) } func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo { // avg loss across all tracks, weigh published the same as subscribed - scores, numTracks := p.uptrackManager.GetConnectionQuality() + scores, numTracks := p.UptrackManager.GetConnectionQuality() p.lock.RLock() for _, subTrack := range p.subscribedTracks { @@ -673,14 +661,6 @@ func (p *ParticipantImpl) SubscriberPC() *webrtc.PeerConnection { return p.subscriber.pc } -func (p *ParticipantImpl) GetPublishedTrack(sid livekit.TrackID) types.PublishedTrack { - return p.uptrackManager.GetPublishedTrack(sid) -} - -func (p *ParticipantImpl) GetPublishedTracks() []types.PublishedTrack { - return p.uptrackManager.GetPublishedTracks() -} - func (p *ParticipantImpl) GetSubscribedTrack(sid livekit.TrackID) types.SubscribedTrack { p.lock.RLock() defer p.lock.RUnlock() @@ -759,13 +739,6 @@ func (p *ParticipantImpl) RemoveSubscribedTrack(subTrack types.SubscribedTrack) } } -func (p *ParticipantImpl) UpdateSubscriptionPermissions( - permissions *livekit.UpdateSubscriptionPermissions, - resolver func(participantID livekit.ParticipantID) types.Participant, -) error { - return p.uptrackManager.UpdateSubscriptionPermissions(permissions, resolver) -} - func (p *ParticipantImpl) SubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) { p.lock.Lock() if allowed { @@ -789,16 +762,8 @@ func (p *ParticipantImpl) SubscriptionPermissionUpdate(publisherID livekit.Parti } } -func (p *ParticipantImpl) UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error { - return p.uptrackManager.UpdateSubscribedQuality(nodeID, trackID, maxQuality) -} - -func (p *ParticipantImpl) UpdateMediaLoss(nodeID string, trackID livekit.TrackID, fractionalLoss uint32) error { - return p.uptrackManager.UpdateMediaLoss(nodeID, trackID, fractionalLoss) -} - func (p *ParticipantImpl) setupUptrackManager() { - p.uptrackManager = NewUptrackManager(UptrackManagerParams{ + p.UptrackManager = NewUptrackManager(UptrackManagerParams{ Identity: p.params.Identity, SID: p.params.SID, Config: p.params.Config, @@ -808,13 +773,13 @@ func (p *ParticipantImpl) setupUptrackManager() { Logger: p.params.Logger, }) - p.uptrackManager.OnTrackPublished(func(track types.PublishedTrack) { + p.UptrackManager.OnTrackPublished(func(track types.PublishedTrack) { if p.onTrackPublished != nil { p.onTrackPublished(p, track) } }) - p.uptrackManager.OnTrackUpdated(func(track types.PublishedTrack, onlyIfReady bool) { + p.UptrackManager.OnTrackUpdated(func(track types.PublishedTrack, onlyIfReady bool) { if onlyIfReady && !p.IsReady() { return } @@ -824,13 +789,13 @@ func (p *ParticipantImpl) setupUptrackManager() { } }) - p.uptrackManager.OnWriteRTCP(func(pkts []rtcp.Packet) { + p.UptrackManager.OnWriteRTCP(func(pkts []rtcp.Packet) { if err := p.publisher.pc.WriteRTCP(pkts); err != nil { p.params.Logger.Errorw("could not write RTCP to participant", err) } }) - p.uptrackManager.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) + p.UptrackManager.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) } func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { @@ -921,7 +886,7 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w return } - p.uptrackManager.MediaTrackReceived(track, rtpReceiver) + p.UptrackManager.MediaTrackReceived(track, rtpReceiver) } func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { @@ -1079,7 +1044,7 @@ func (p *ParticipantImpl) configureReceiverDTX() { // multiple audio tracks. At that point, there might be a need to // rely on something like order of tracks. TODO // - enableDTX := p.uptrackManager.GetDTX() + enableDTX := p.UptrackManager.GetDTX() transceivers := p.publisher.pc.GetTransceivers() for _, transceiver := range transceivers { if transceiver.Kind() != webrtc.RTPCodecTypeAudio { @@ -1175,7 +1140,7 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { "State": p.State().String(), } - uptrackManagerInfo := p.uptrackManager.DebugInfo() + uptrackManagerInfo := p.UptrackManager.DebugInfo() subscribedTrackInfo := make(map[livekit.TrackID]interface{}) p.lock.RLock() diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index ba0d28350..8cb591523 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -80,14 +80,14 @@ func TestTrackPublishing(t *testing.T) { p.OnTrackPublished(func(p types.Participant, track types.PublishedTrack) { published = true }) - p.uptrackManager.handleTrackPublished(track) + p.UptrackManager.handleTrackPublished(track) require.True(t, published) require.False(t, updated) - require.Len(t, p.uptrackManager.publishedTracks, 1) + require.Len(t, p.UptrackManager.publishedTracks, 1) track.AddOnCloseArgsForCall(0)() - require.Len(t, p.uptrackManager.publishedTracks, 0) + require.Len(t, p.UptrackManager.publishedTracks, 0) require.True(t, updated) }) @@ -136,7 +136,7 @@ func TestTrackPublishing(t *testing.T) { track := &typesfakes.FakePublishedTrack{} track.SignalCidReturns("cid") // directly add to publishedTracks without lock - for testing purpose only - p.uptrackManager.publishedTracks["cid"] = track + p.UptrackManager.publishedTracks["cid"] = track p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -153,7 +153,7 @@ func TestTrackPublishing(t *testing.T) { track := &typesfakes.FakePublishedTrack{} track.SdpCidReturns("cid") // directly add to publishedTracks without lock - for testing purpose only - p.uptrackManager.publishedTracks["cid"] = track + p.UptrackManager.publishedTracks["cid"] = track p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -202,7 +202,7 @@ func TestDisconnectTiming(t *testing.T) { } }() track := &typesfakes.FakePublishedTrack{} - p.uptrackManager.handleTrackPublished(track) + p.UptrackManager.handleTrackPublished(track) // close channel and then try to Negotiate msg.Close() @@ -220,7 +220,7 @@ func TestMuteSetting(t *testing.T) { t.Run("can set mute when track is pending", func(t *testing.T) { p := newParticipantForTest("test") ti := &livekit.TrackInfo{Sid: "testTrack"} - p.uptrackManager.pendingTracks["cid"] = ti + p.UptrackManager.pendingTracks["cid"] = ti p.SetTrackMuted(livekit.TrackID(ti.Sid), true, false) require.True(t, ti.Muted) @@ -234,7 +234,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.uptrackManager.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti := p.UptrackManager.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) @@ -282,30 +282,30 @@ func TestConnectionQuality(t *testing.T) { t.Run("smooth sailing", func(t *testing.T) { p := newParticipantForTest("test") - p.uptrackManager.publishedTracks["video"] = testPublishedVideoTrack(2, 3, 3) - p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 0) + p.UptrackManager.publishedTracks["video"] = testPublishedVideoTrack(2, 3, 3) + p.UptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 0) require.Equal(t, livekit.ConnectionQuality_EXCELLENT, p.GetConnectionQuality().GetQuality()) }) t.Run("reduced publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.uptrackManager.publishedTracks["video"] = testPublishedVideoTrack(3, 2, 3) - p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) + p.UptrackManager.publishedTracks["video"] = testPublishedVideoTrack(3, 2, 3) + p.UptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) require.Equal(t, livekit.ConnectionQuality_GOOD, p.GetConnectionQuality().GetQuality()) }) t.Run("audio smooth publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 10) + p.UptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 10) require.Equal(t, livekit.ConnectionQuality_EXCELLENT, p.GetConnectionQuality().GetQuality()) }) t.Run("audio reduced publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) + p.UptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) require.Equal(t, livekit.ConnectionQuality_GOOD, p.GetConnectionQuality().GetQuality()) }) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 3446f363d..a50224493 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -392,6 +392,10 @@ func (r *Room) RemoveDisallowedSubscriptions(sub types.Participant, disallowedSu } } +func (r *Room) UpdateVideoLayers(participant types.Participant, updateVideoLayers *livekit.UpdateVideoLayers) error { + return participant.UpdateVideoLayers(updateVideoLayers) +} + func (r *Room) IsClosed() bool { select { case <-r.closed: @@ -607,14 +611,16 @@ func (r *Room) subscribeToExistingTracks(p types.Participant) { // don't send to itself continue } - if n, err := op.AddSubscriber(p, types.AddSubscriberParams{AllTracks: true}); err != nil { + + // subscribe to all + n, err := op.AddSubscriber(p, types.AddSubscriberParams{AllTracks: true}) + if err != nil { // TODO: log error? or disconnect? r.Logger.Errorw("could not subscribe to participant", err, "participants", []livekit.ParticipantIdentity{op.Identity(), p.Identity()}, "pIDs", []livekit.ParticipantID{op.ID(), p.ID()}) - } else { - tracksAdded += n } + tracksAdded += n } if tracksAdded > 0 { r.Logger.Debugw("subscribed participants to existing tracks", "tracks", tracksAdded) diff --git a/pkg/rtc/signalhandler.go b/pkg/rtc/signalhandler.go index 9723ec981..d6d556a61 100644 --- a/pkg/rtc/signalhandler.go +++ b/pkg/rtc/signalhandler.go @@ -71,13 +71,12 @@ func HandleParticipantSignal(room types.Room, participant types.Participant, req subTrack.UpdateSubscriberSettings(msg.TrackSetting) } case *livekit.SignalRequest_UpdateLayers: - track := participant.GetPublishedTrack(livekit.TrackID(msg.UpdateLayers.TrackSid)) - if track == nil { - pLogger.Warnw("could not find published track", nil, - "track", msg.UpdateLayers.TrackSid) + err := room.UpdateVideoLayers(participant, msg.UpdateLayers) + if err != nil { + pLogger.Warnw("could not update video layers", err, + "update", msg.UpdateLayers) return nil } - track.UpdateVideoLayers(msg.UpdateLayers.Layers) case *livekit.SignalRequest_Leave: _ = participant.Close() case *livekit.SignalRequest_SubscriptionPermissions: diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index a4dac0e1e..cdb312bb1 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -93,6 +93,8 @@ type Participant interface { UpdateSubscriptionPermissions(permissions *livekit.UpdateSubscriptionPermissions, resolver func(participantID livekit.ParticipantID) Participant) error SubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) + UpdateVideoLayers(updateVideoLayers *livekit.UpdateVideoLayers) error + UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error UpdateMediaLoss(nodeID string, trackID livekit.TrackID, fractionalLoss uint32) error @@ -106,6 +108,8 @@ type Room interface { Name() livekit.RoomName UpdateSubscriptions(participant Participant, trackIDs []livekit.TrackID, participantTracks []*livekit.ParticipantTracks, subscribe bool) error UpdateSubscriptionPermissions(participant Participant, permissions *livekit.UpdateSubscriptionPermissions) error + + UpdateVideoLayers(participant Participant, updateVideoLayers *livekit.UpdateVideoLayers) error } // MediaTrack represents a media track @@ -157,6 +161,10 @@ type PublishedTrack interface { Receiver() sfu.TrackReceiver GetConnectionScore() float64 + GetAudioLevel() (level uint8, active bool) + + UpdateVideoLayers(layers []*livekit.VideoLayer) + // callbacks AddOnClose(func()) } diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index c5dea5305..790cbb0b8 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -532,6 +532,17 @@ type FakeParticipant struct { updateSubscriptionPermissionsReturnsOnCall map[int]struct { result1 error } + UpdateVideoLayersStub func(*livekit.UpdateVideoLayers) error + updateVideoLayersMutex sync.RWMutex + updateVideoLayersArgsForCall []struct { + arg1 *livekit.UpdateVideoLayers + } + updateVideoLayersReturns struct { + result1 error + } + updateVideoLayersReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -3356,6 +3367,67 @@ func (fake *FakeParticipant) UpdateSubscriptionPermissionsReturnsOnCall(i int, r }{result1} } +func (fake *FakeParticipant) UpdateVideoLayers(arg1 *livekit.UpdateVideoLayers) error { + fake.updateVideoLayersMutex.Lock() + ret, specificReturn := fake.updateVideoLayersReturnsOnCall[len(fake.updateVideoLayersArgsForCall)] + fake.updateVideoLayersArgsForCall = append(fake.updateVideoLayersArgsForCall, struct { + arg1 *livekit.UpdateVideoLayers + }{arg1}) + stub := fake.UpdateVideoLayersStub + fakeReturns := fake.updateVideoLayersReturns + fake.recordInvocation("UpdateVideoLayers", []interface{}{arg1}) + fake.updateVideoLayersMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) UpdateVideoLayersCallCount() int { + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() + return len(fake.updateVideoLayersArgsForCall) +} + +func (fake *FakeParticipant) UpdateVideoLayersCalls(stub func(*livekit.UpdateVideoLayers) error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = stub +} + +func (fake *FakeParticipant) UpdateVideoLayersArgsForCall(i int) *livekit.UpdateVideoLayers { + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() + argsForCall := fake.updateVideoLayersArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) UpdateVideoLayersReturns(result1 error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = nil + fake.updateVideoLayersReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) UpdateVideoLayersReturnsOnCall(i int, result1 error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = nil + if fake.updateVideoLayersReturnsOnCall == nil { + fake.updateVideoLayersReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateVideoLayersReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeParticipant) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -3473,6 +3545,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.updateSubscribedQualityMutex.RUnlock() fake.updateSubscriptionPermissionsMutex.RLock() defer fake.updateSubscriptionPermissionsMutex.RUnlock() + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index c29239aa8..7a159a2d5 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -26,6 +26,18 @@ type FakePublishedTrack struct { addSubscriberReturnsOnCall map[int]struct { result1 error } + GetAudioLevelStub func() (uint8, bool) + getAudioLevelMutex sync.RWMutex + getAudioLevelArgsForCall []struct { + } + getAudioLevelReturns struct { + result1 uint8 + result2 bool + } + getAudioLevelReturnsOnCall map[int]struct { + result1 uint8 + result2 bool + } GetConnectionScoreStub func() float64 getConnectionScoreMutex sync.RWMutex getConnectionScoreArgsForCall []struct { @@ -346,6 +358,62 @@ func (fake *FakePublishedTrack) AddSubscriberReturnsOnCall(i int, result1 error) }{result1} } +func (fake *FakePublishedTrack) GetAudioLevel() (uint8, bool) { + fake.getAudioLevelMutex.Lock() + ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] + fake.getAudioLevelArgsForCall = append(fake.getAudioLevelArgsForCall, struct { + }{}) + stub := fake.GetAudioLevelStub + fakeReturns := fake.getAudioLevelReturns + fake.recordInvocation("GetAudioLevel", []interface{}{}) + fake.getAudioLevelMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakePublishedTrack) GetAudioLevelCallCount() int { + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() + return len(fake.getAudioLevelArgsForCall) +} + +func (fake *FakePublishedTrack) GetAudioLevelCalls(stub func() (uint8, bool)) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = stub +} + +func (fake *FakePublishedTrack) GetAudioLevelReturns(result1 uint8, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + fake.getAudioLevelReturns = struct { + result1 uint8 + result2 bool + }{result1, result2} +} + +func (fake *FakePublishedTrack) GetAudioLevelReturnsOnCall(i int, result1 uint8, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + if fake.getAudioLevelReturnsOnCall == nil { + fake.getAudioLevelReturnsOnCall = make(map[int]struct { + result1 uint8 + result2 bool + }) + } + fake.getAudioLevelReturnsOnCall[i] = struct { + result1 uint8 + result2 bool + }{result1, result2} +} + func (fake *FakePublishedTrack) GetConnectionScore() float64 { fake.getConnectionScoreMutex.Lock() ret, specificReturn := fake.getConnectionScoreReturnsOnCall[len(fake.getConnectionScoreArgsForCall)] @@ -1564,6 +1632,8 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.addOnCloseMutex.RUnlock() fake.addSubscriberMutex.RLock() defer fake.addSubscriberMutex.RUnlock() + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() fake.getConnectionScoreMutex.RLock() defer fake.getConnectionScoreMutex.RUnlock() fake.getQualityForDimensionMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_room.go b/pkg/rtc/types/typesfakes/fake_room.go index 940d4bdf6..0309b83c0 100644 --- a/pkg/rtc/types/typesfakes/fake_room.go +++ b/pkg/rtc/types/typesfakes/fake_room.go @@ -45,6 +45,18 @@ type FakeRoom struct { updateSubscriptionsReturnsOnCall map[int]struct { result1 error } + UpdateVideoLayersStub func(types.Participant, *livekit.UpdateVideoLayers) error + updateVideoLayersMutex sync.RWMutex + updateVideoLayersArgsForCall []struct { + arg1 types.Participant + arg2 *livekit.UpdateVideoLayers + } + updateVideoLayersReturns struct { + result1 error + } + updateVideoLayersReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -238,6 +250,68 @@ func (fake *FakeRoom) UpdateSubscriptionsReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeRoom) UpdateVideoLayers(arg1 types.Participant, arg2 *livekit.UpdateVideoLayers) error { + fake.updateVideoLayersMutex.Lock() + ret, specificReturn := fake.updateVideoLayersReturnsOnCall[len(fake.updateVideoLayersArgsForCall)] + fake.updateVideoLayersArgsForCall = append(fake.updateVideoLayersArgsForCall, struct { + arg1 types.Participant + arg2 *livekit.UpdateVideoLayers + }{arg1, arg2}) + stub := fake.UpdateVideoLayersStub + fakeReturns := fake.updateVideoLayersReturns + fake.recordInvocation("UpdateVideoLayers", []interface{}{arg1, arg2}) + fake.updateVideoLayersMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) UpdateVideoLayersCallCount() int { + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() + return len(fake.updateVideoLayersArgsForCall) +} + +func (fake *FakeRoom) UpdateVideoLayersCalls(stub func(types.Participant, *livekit.UpdateVideoLayers) error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = stub +} + +func (fake *FakeRoom) UpdateVideoLayersArgsForCall(i int) (types.Participant, *livekit.UpdateVideoLayers) { + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() + argsForCall := fake.updateVideoLayersArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoom) UpdateVideoLayersReturns(result1 error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = nil + fake.updateVideoLayersReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoom) UpdateVideoLayersReturnsOnCall(i int, result1 error) { + fake.updateVideoLayersMutex.Lock() + defer fake.updateVideoLayersMutex.Unlock() + fake.UpdateVideoLayersStub = nil + if fake.updateVideoLayersReturnsOnCall == nil { + fake.updateVideoLayersReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateVideoLayersReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeRoom) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -247,6 +321,8 @@ func (fake *FakeRoom) Invocations() map[string][][]interface{} { defer fake.updateSubscriptionPermissionsMutex.RUnlock() fake.updateSubscriptionsMutex.RLock() defer fake.updateSubscriptionsMutex.RUnlock() + fake.updateVideoLayersMutex.RLock() + defer fake.updateVideoLayersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 3207ed303..a89dd0733 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -150,6 +150,7 @@ func (u *UptrackManager) AddSubscriber(sub types.Participant, params types.AddSu if params.AllTracks { tracks = u.GetPublishedTracks() } else { + u.lock.RLock() for _, trackID := range params.TrackIDs { track := u.getPublishedTrack(trackID) if track == nil { @@ -158,6 +159,7 @@ func (u *UptrackManager) AddSubscriber(sub types.Participant, params types.AddSu tracks = append(tracks, track) } + u.lock.RUnlock() } if len(tracks) == 0 { return 0, nil @@ -186,15 +188,14 @@ func (u *UptrackManager) AddSubscriber(sub types.Participant, params types.AddSu } func (u *UptrackManager) RemoveSubscriber(sub types.Participant, trackID livekit.TrackID) { - u.lock.Lock() - defer u.lock.Unlock() - - track := u.getPublishedTrack(trackID) + track := u.GetPublishedTrack(trackID) if track != nil { track.RemoveSubscriber(sub.ID()) } + u.lock.Lock() u.maybeRemovePendingSubscription(trackID, sub) + u.lock.Unlock() } func (u *UptrackManager) SetTrackMuted(trackID livekit.TrackID, muted bool) { @@ -230,15 +231,13 @@ func (u *UptrackManager) GetAudioLevel() (level uint8, active bool) { u.lock.RLock() defer u.lock.RUnlock() - level = silentAudioLevel + level = SilentAudioLevel for _, pt := range u.publishedTracks { - if mt, ok := pt.(*MediaTrack); ok { - tl, ta := mt.GetAudioLevel() - if ta { - active = true - if tl < level { - level = tl - } + tl, ta := pt.GetAudioLevel() + if ta { + active = true + if tl < level { + level = tl } } } @@ -259,11 +258,11 @@ func (u *UptrackManager) GetConnectionQuality() (scores float64, numTracks int) return } -func (u *UptrackManager) GetPublishedTrack(sid livekit.TrackID) types.PublishedTrack { +func (u *UptrackManager) GetPublishedTrack(trackID livekit.TrackID) types.PublishedTrack { u.lock.RLock() defer u.lock.RUnlock() - return u.getPublishedTrack(sid) + return u.getPublishedTrack(trackID) } func (u *UptrackManager) GetPublishedTracks() []types.PublishedTrack { @@ -312,11 +311,18 @@ func (u *UptrackManager) UpdateSubscriptionPermissions( return nil } -func (u *UptrackManager) UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error { - u.lock.RLock() - defer u.lock.RUnlock() +func (u *UptrackManager) UpdateVideoLayers(updateVideoLayers *livekit.UpdateVideoLayers) error { + track := u.GetPublishedTrack(livekit.TrackID(updateVideoLayers.TrackSid)) + if track == nil { + return errors.New("could not find published track") + } - track := u.getPublishedTrack(trackID) + track.UpdateVideoLayers(updateVideoLayers.Layers) + return nil +} + +func (u *UptrackManager) UpdateSubscribedQuality(nodeID string, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error { + track := u.GetPublishedTrack(trackID) if track == nil { u.params.Logger.Warnw("could not find track", nil, "trackID", trackID) return errors.New("could not find track") @@ -330,10 +336,7 @@ func (u *UptrackManager) UpdateSubscribedQuality(nodeID string, trackID livekit. } func (u *UptrackManager) UpdateMediaLoss(nodeID string, trackID livekit.TrackID, fractionalLoss uint32) error { - u.lock.RLock() - defer u.lock.RUnlock() - - track := u.getPublishedTrack(trackID) + track := u.GetPublishedTrack(trackID) if track == nil { u.params.Logger.Warnw("could not find track", nil, "trackID", trackID) return errors.New("could not find track") @@ -405,8 +408,8 @@ func (u *UptrackManager) MediaTrackReceived(track *webrtc.TrackRemote, rtpReceiv } // should be called with lock held -func (u *UptrackManager) getPublishedTrack(sid livekit.TrackID) types.PublishedTrack { - return u.publishedTracks[sid] +func (u *UptrackManager) getPublishedTrack(trackID livekit.TrackID) types.PublishedTrack { + return u.publishedTracks[trackID] } // should be called with lock held