From 0898c17e8ada385da467044bc047f97248aa9428 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Thu, 28 Oct 2021 21:01:05 -0700 Subject: [PATCH] Select video quality using provided dimensions (#158) --- pkg/rtc/mediatrack.go | 51 ++++-- pkg/rtc/mediatrack_test.go | 29 ++++ pkg/rtc/participant.go | 109 +++++++------ pkg/rtc/subscribedtrack.go | 27 +++- pkg/rtc/types/interfaces.go | 5 + pkg/rtc/types/typesfakes/fake_participant.go | 148 ++++++++++++++++++ .../types/typesfakes/fake_published_track.go | 76 +++++++++ .../types/typesfakes/fake_subscribed_track.go | 65 ++++++++ pkg/service/roommanager.go | 50 +++++- pkg/service/wire_gen.go | 3 +- 10 files changed, 481 insertions(+), 82 deletions(-) diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 08e0b78cb..e9f83e500 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -52,15 +52,16 @@ type MediaTrack struct { } type MediaTrackParams struct { - TrackInfo *livekit.TrackInfo - SignalCid string - SdpCid string - ParticipantID string - RTCPChan chan []rtcp.Packet - BufferFactory *buffer.Factory - ReceiverConfig ReceiverConfig - AudioConfig config.AudioConfig - Stats *stats.RoomStatsReporter + TrackInfo *livekit.TrackInfo + SignalCid string + SdpCid string + ParticipantID string + ParticipantIdentity string + RTCPChan chan []rtcp.Packet + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + AudioConfig config.AudioConfig + Stats *stats.RoomStatsReporter } func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTrack { @@ -172,7 +173,7 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { if err != nil { return err } - subTrack := NewSubscribedTrack(downTrack) + subTrack := NewSubscribedTrack(t.params.ParticipantIdentity, downTrack) var transceiver *webrtc.RTPTransceiver var sender *webrtc.RTPSender @@ -384,6 +385,36 @@ func (t *MediaTrack) ToProto() *livekit.TrackInfo { return info } +// GetQualityForDimension finds the closest quality to use for desired dimensions +// affords a 10% tolerance on dimension +func (t *MediaTrack) GetQualityForDimension(width, height uint32) livekit.VideoQuality { + quality := livekit.VideoQuality_HIGH + if t.Kind() == livekit.TrackType_AUDIO || t.params.TrackInfo.Height == 0 { + return quality + } + origSize := t.params.TrackInfo.Height + requestedSize := height + if t.params.TrackInfo.Width < t.params.TrackInfo.Height { + // for portrait videos + origSize = t.params.TrackInfo.Width + requestedSize = width + } + + // representing qualities low - high + layerSizes := []uint32{180, 360, origSize} + + // finds the lowest layer that could satisfy client demands + requestedSize = uint32(float32(requestedSize) * 0.9) + for i, s := range layerSizes { + quality = livekit.VideoQuality(i) + if s >= requestedSize { + break + } + } + + return quality +} + // this function assumes caller holds lock func (t *MediaTrack) shouldStartWithBestQuality() bool { return len(t.subscribedTracks) < 10 diff --git a/pkg/rtc/mediatrack_test.go b/pkg/rtc/mediatrack_test.go index 566ca2d76..c17559957 100644 --- a/pkg/rtc/mediatrack_test.go +++ b/pkg/rtc/mediatrack_test.go @@ -41,3 +41,32 @@ func TestTrackInfo(t *testing.T) { mt.simulcasted = true require.True(t, mt.ToProto().Simulcast) } + +func TestGetQualityForDimension(t *testing.T) { + t.Run("landscape source", func(t *testing.T) { + mt := NewMediaTrack(&webrtc.TrackRemote{}, MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 1080, + Height: 720, + }}) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(120, 120)) + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(300, 200)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(200, 250)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(700, 480)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(500, 1000)) + }) + + t.Run("portrait source", func(t *testing.T) { + mt := NewMediaTrack(&webrtc.TrackRemote{}, MediaTrackParams{TrackInfo: &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 540, + Height: 960, + }}) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(200, 400)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(400, 400)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(400, 700)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(600, 900)) + }) +} diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 87ee64c2c..1d40debf8 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -70,8 +70,8 @@ type ParticipantImpl struct { // hold reference for MediaTrack twcc *twcc.Responder - // tracks the current participant is subscribed to, map of otherParticipantId => []DownTrack - subscribedTracks map[string][]types.SubscribedTrack + // tracks the current participant is subscribed to, map of sid => DownTrack + subscribedTracks map[string]types.SubscribedTrack // publishedTracks that participant is publishing publishedTracks map[string]types.PublishedTrack // client intended to publish, yet to be reconciled @@ -97,7 +97,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { id: utils.NewGuid(utils.ParticipantPrefix), rtcpCh: make(chan []rtcp.Packet, 50), pliThrottle: newPLIThrottle(params.ThrottleConfig), - subscribedTracks: make(map[string][]types.SubscribedTrack), + subscribedTracks: make(map[string]types.SubscribedTrack), publishedTracks: make(map[string]types.PublishedTrack, 0), pendingTracks: make(map[string]*livekit.TrackInfo), connectedAt: time.Now(), @@ -361,16 +361,6 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { }) } -func (p *ParticipantImpl) GetPublishedTracks() []types.PublishedTrack { - p.lock.RLock() - defer p.lock.RUnlock() - tracks := make([]types.PublishedTrack, 0, len(p.publishedTracks)) - for _, t := range p.publishedTracks { - tracks = append(tracks, t) - } - return tracks -} - // HandleAnswer handles a client answer response, with subscriber PC, server initiates the // offer and client answers func (p *ParticipantImpl) HandleAnswer(sdp webrtc.SessionDescription) error { @@ -429,10 +419,8 @@ func (p *ParticipantImpl) Close() error { } var downtracksToClose []*sfu.DownTrack - for _, tracks := range p.subscribedTracks { - for _, st := range tracks { - downtracksToClose = append(downtracksToClose, st.DownTrack()) - } + for _, st := range p.subscribedTracks { + downtracksToClose = append(downtracksToClose, st.DownTrack()) } p.lock.Unlock() @@ -686,14 +674,34 @@ func (p *ParticipantImpl) SubscriberPC() *webrtc.PeerConnection { return p.subscriber.pc } +func (p *ParticipantImpl) GetPublishedTrack(sid string) types.PublishedTrack { + p.lock.RLock() + defer p.lock.RUnlock() + return p.publishedTracks[sid] +} + +func (p *ParticipantImpl) GetPublishedTracks() []types.PublishedTrack { + p.lock.RLock() + defer p.lock.RUnlock() + tracks := make([]types.PublishedTrack, 0, len(p.publishedTracks)) + for _, t := range p.publishedTracks { + tracks = append(tracks, t) + } + return tracks +} + +func (p *ParticipantImpl) GetSubscribedTrack(sid string) types.SubscribedTrack { + p.lock.RLock() + defer p.lock.RUnlock() + return p.subscribedTracks[sid] +} + func (p *ParticipantImpl) GetSubscribedTracks() []types.SubscribedTrack { p.lock.RLock() defer p.lock.RUnlock() subscribed := make([]types.SubscribedTrack, 0, len(p.subscribedTracks)) - for _, pTracks := range p.subscribedTracks { - for _, t := range pTracks { - subscribed = append(subscribed, t) - } + for _, st := range p.subscribedTracks { + subscribed = append(subscribed, st) } return subscribed } @@ -703,7 +711,7 @@ func (p *ParticipantImpl) AddSubscribedTrack(pubId string, subTrack types.Subscr logger.Debugw("added subscribedTrack", "pIDs", []string{pubId, p.ID()}, "participant", p.Identity(), "track", subTrack.ID()) p.lock.Lock() - p.subscribedTracks[pubId] = append(p.subscribedTracks[pubId], subTrack) + p.subscribedTracks[subTrack.ID()] = subTrack p.lock.Unlock() } @@ -712,14 +720,8 @@ func (p *ParticipantImpl) RemoveSubscribedTrack(pubId string, subTrack types.Sub logger.Debugw("removed subscribedTrack", "pIDs", []string{pubId, p.ID()}, "participant", p.Identity(), "track", subTrack.ID()) p.lock.Lock() - defer p.lock.Unlock() - tracks := make([]types.SubscribedTrack, 0, len(p.subscribedTracks[pubId])) - for _, tr := range p.subscribedTracks[pubId] { - if tr != subTrack { - tracks = append(tracks, tr) - } - } - p.subscribedTracks[pubId] = tracks + delete(p.subscribedTracks, subTrack.ID()) + p.lock.Unlock() } func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { @@ -834,15 +836,16 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w } mt = NewMediaTrack(track, MediaTrackParams{ - TrackInfo: ti, - SignalCid: signalCid, - SdpCid: track.ID(), - ParticipantID: p.id, - RTCPChan: p.rtcpCh, - BufferFactory: p.params.Config.BufferFactory, - ReceiverConfig: p.params.Config.Receiver, - AudioConfig: p.params.AudioConfig, - Stats: p.params.Stats, + TrackInfo: ti, + SignalCid: signalCid, + SdpCid: track.ID(), + ParticipantID: p.id, + ParticipantIdentity: p.Identity(), + RTCPChan: p.rtcpCh, + BufferFactory: p.params.Config.BufferFactory, + ReceiverConfig: p.params.Config.Receiver, + AudioConfig: p.params.AudioConfig, + Stats: p.params.Stats, }) // add to published and clean up pending @@ -1013,16 +1016,14 @@ func (p *ParticipantImpl) downTracksRTCPWorker() { var srs []rtcp.Packet var sd []rtcp.SourceDescriptionChunk p.lock.RLock() - for _, tracks := range p.subscribedTracks { - for _, subTrack := range tracks { - sr := subTrack.DownTrack().CreateSenderReport() - chunks := subTrack.DownTrack().CreateSourceDescriptionChunks() - if sr == nil || chunks == nil { - continue - } - srs = append(srs, sr) - sd = append(sd, chunks...) + for _, subTrack := range p.subscribedTracks { + sr := subTrack.DownTrack().CreateSenderReport() + chunks := subTrack.DownTrack().CreateSourceDescriptionChunks() + if sr == nil || chunks == nil { + continue } + srs = append(srs, sr) + sd = append(sd, chunks...) } p.lock.RUnlock() @@ -1225,14 +1226,10 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { } } - for pubID, tracks := range p.subscribedTracks { - trackInfo := make([]map[string]interface{}, 0, len(tracks)) - for _, track := range tracks { - dt := track.DownTrack().DebugInfo() - dt["SubMuted"] = track.IsMuted() - trackInfo = append(trackInfo, dt) - } - subscribedTrackInfo[pubID] = trackInfo + for _, track := range p.subscribedTracks { + dt := track.DownTrack().DebugInfo() + dt["SubMuted"] = track.IsMuted() + subscribedTrackInfo[track.ID()] = dt } for clientID, track := range p.pendingTracks { diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index ec9b2364a..eeb5ffdea 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -15,16 +15,18 @@ const ( ) type SubscribedTrack struct { - dt *sfu.DownTrack - subMuted utils.AtomicFlag - pubMuted utils.AtomicFlag - debouncer func(func()) + dt *sfu.DownTrack + publisherIdentity string + subMuted utils.AtomicFlag + pubMuted utils.AtomicFlag + debouncer func(func()) } -func NewSubscribedTrack(dt *sfu.DownTrack) *SubscribedTrack { +func NewSubscribedTrack(publisherIdentity string, dt *sfu.DownTrack) *SubscribedTrack { return &SubscribedTrack{ - dt: dt, - debouncer: debounce.New(subscriptionDebounceInterval), + publisherIdentity: publisherIdentity, + dt: dt, + debouncer: debounce.New(subscriptionDebounceInterval), } } @@ -32,6 +34,10 @@ func (t *SubscribedTrack) ID() string { return t.dt.ID() } +func (t *SubscribedTrack) PublisherIdentity() string { + return t.publisherIdentity +} + func (t *SubscribedTrack) DownTrack() *sfu.DownTrack { return t.dt } @@ -65,6 +71,13 @@ func (t *SubscribedTrack) updateDownTrackMute() { t.dt.Mute(muted) } +// GetQualityForDimension finds the closest quality to use for desired dimensions +// affords a 10% tolerance on dimension +func GetQualityForDimension(width, height uint32) livekit.VideoQuality { + // currently the layers are set to 180p/360p/original res, we should re + return livekit.VideoQuality_HIGH +} + func spatialLayerForQuality(quality livekit.VideoQuality) int32 { switch quality { case livekit.VideoQuality_LOW: diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index b5926580b..bd29ebc7b 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -39,7 +39,9 @@ type Participant interface { ICERestart() error AddTrack(req *livekit.AddTrackRequest) + GetPublishedTrack(sid string) PublishedTrack GetPublishedTracks() []PublishedTrack + GetSubscribedTrack(sid string) SubscribedTrack GetSubscribedTracks() []SubscribedTrack HandleOffer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) HandleAnswer(sdp webrtc.SessionDescription) error @@ -101,6 +103,8 @@ type PublishedTrack interface { RemoveSubscriber(participantId string) IsSubscriber(subId string) bool RemoveAllSubscribers() + // returns quality information that's appropriate for width & height + GetQualityForDimension(width, height uint32) livekit.VideoQuality ToProto() *livekit.TrackInfo // callbacks @@ -110,6 +114,7 @@ type PublishedTrack interface { //counterfeiter:generate . SubscribedTrack type SubscribedTrack interface { ID() string + PublisherIdentity() string DownTrack() *sfu.DownTrack IsMuted() bool SetPublisherMuted(muted bool) diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index bd47e5b7f..d31589fbc 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -121,6 +121,17 @@ type FakeParticipant struct { result1 uint8 result2 bool } + GetPublishedTrackStub func(string) types.PublishedTrack + getPublishedTrackMutex sync.RWMutex + getPublishedTrackArgsForCall []struct { + arg1 string + } + getPublishedTrackReturns struct { + result1 types.PublishedTrack + } + getPublishedTrackReturnsOnCall map[int]struct { + result1 types.PublishedTrack + } GetPublishedTracksStub func() []types.PublishedTrack getPublishedTracksMutex sync.RWMutex getPublishedTracksArgsForCall []struct { @@ -141,6 +152,17 @@ type FakeParticipant struct { getResponseSinkReturnsOnCall map[int]struct { result1 routing.MessageSink } + GetSubscribedTrackStub func(string) types.SubscribedTrack + getSubscribedTrackMutex sync.RWMutex + getSubscribedTrackArgsForCall []struct { + arg1 string + } + getSubscribedTrackReturns struct { + result1 types.SubscribedTrack + } + getSubscribedTrackReturnsOnCall map[int]struct { + result1 types.SubscribedTrack + } GetSubscribedTracksStub func() []types.SubscribedTrack getSubscribedTracksMutex sync.RWMutex getSubscribedTracksArgsForCall []struct { @@ -992,6 +1014,67 @@ func (fake *FakeParticipant) GetAudioLevelReturnsOnCall(i int, result1 uint8, re }{result1, result2} } +func (fake *FakeParticipant) GetPublishedTrack(arg1 string) types.PublishedTrack { + fake.getPublishedTrackMutex.Lock() + ret, specificReturn := fake.getPublishedTrackReturnsOnCall[len(fake.getPublishedTrackArgsForCall)] + fake.getPublishedTrackArgsForCall = append(fake.getPublishedTrackArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetPublishedTrackStub + fakeReturns := fake.getPublishedTrackReturns + fake.recordInvocation("GetPublishedTrack", []interface{}{arg1}) + fake.getPublishedTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetPublishedTrackCallCount() int { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + return len(fake.getPublishedTrackArgsForCall) +} + +func (fake *FakeParticipant) GetPublishedTrackCalls(stub func(string) types.PublishedTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = stub +} + +func (fake *FakeParticipant) GetPublishedTrackArgsForCall(i int) string { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + argsForCall := fake.getPublishedTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) GetPublishedTrackReturns(result1 types.PublishedTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + fake.getPublishedTrackReturns = struct { + result1 types.PublishedTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedTrackReturnsOnCall(i int, result1 types.PublishedTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + if fake.getPublishedTrackReturnsOnCall == nil { + fake.getPublishedTrackReturnsOnCall = make(map[int]struct { + result1 types.PublishedTrack + }) + } + fake.getPublishedTrackReturnsOnCall[i] = struct { + result1 types.PublishedTrack + }{result1} +} + func (fake *FakeParticipant) GetPublishedTracks() []types.PublishedTrack { fake.getPublishedTracksMutex.Lock() ret, specificReturn := fake.getPublishedTracksReturnsOnCall[len(fake.getPublishedTracksArgsForCall)] @@ -1098,6 +1181,67 @@ func (fake *FakeParticipant) GetResponseSinkReturnsOnCall(i int, result1 routing }{result1} } +func (fake *FakeParticipant) GetSubscribedTrack(arg1 string) types.SubscribedTrack { + fake.getSubscribedTrackMutex.Lock() + ret, specificReturn := fake.getSubscribedTrackReturnsOnCall[len(fake.getSubscribedTrackArgsForCall)] + fake.getSubscribedTrackArgsForCall = append(fake.getSubscribedTrackArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetSubscribedTrackStub + fakeReturns := fake.getSubscribedTrackReturns + fake.recordInvocation("GetSubscribedTrack", []interface{}{arg1}) + fake.getSubscribedTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetSubscribedTrackCallCount() int { + fake.getSubscribedTrackMutex.RLock() + defer fake.getSubscribedTrackMutex.RUnlock() + return len(fake.getSubscribedTrackArgsForCall) +} + +func (fake *FakeParticipant) GetSubscribedTrackCalls(stub func(string) types.SubscribedTrack) { + fake.getSubscribedTrackMutex.Lock() + defer fake.getSubscribedTrackMutex.Unlock() + fake.GetSubscribedTrackStub = stub +} + +func (fake *FakeParticipant) GetSubscribedTrackArgsForCall(i int) string { + fake.getSubscribedTrackMutex.RLock() + defer fake.getSubscribedTrackMutex.RUnlock() + argsForCall := fake.getSubscribedTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) GetSubscribedTrackReturns(result1 types.SubscribedTrack) { + fake.getSubscribedTrackMutex.Lock() + defer fake.getSubscribedTrackMutex.Unlock() + fake.GetSubscribedTrackStub = nil + fake.getSubscribedTrackReturns = struct { + result1 types.SubscribedTrack + }{result1} +} + +func (fake *FakeParticipant) GetSubscribedTrackReturnsOnCall(i int, result1 types.SubscribedTrack) { + fake.getSubscribedTrackMutex.Lock() + defer fake.getSubscribedTrackMutex.Unlock() + fake.GetSubscribedTrackStub = nil + if fake.getSubscribedTrackReturnsOnCall == nil { + fake.getSubscribedTrackReturnsOnCall = make(map[int]struct { + result1 types.SubscribedTrack + }) + } + fake.getSubscribedTrackReturnsOnCall[i] = struct { + result1 types.SubscribedTrack + }{result1} +} + func (fake *FakeParticipant) GetSubscribedTracks() []types.SubscribedTrack { fake.getSubscribedTracksMutex.Lock() ret, specificReturn := fake.getSubscribedTracksReturnsOnCall[len(fake.getSubscribedTracksArgsForCall)] @@ -2699,10 +2843,14 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.debugInfoMutex.RUnlock() fake.getAudioLevelMutex.RLock() defer fake.getAudioLevelMutex.RUnlock() + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() fake.getPublishedTracksMutex.RLock() defer fake.getPublishedTracksMutex.RUnlock() fake.getResponseSinkMutex.RLock() defer fake.getResponseSinkMutex.RUnlock() + fake.getSubscribedTrackMutex.RLock() + defer fake.getSubscribedTrackMutex.RUnlock() fake.getSubscribedTracksMutex.RLock() defer fake.getSubscribedTracksMutex.RUnlock() fake.handleAnswerMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index 60aeb7c5b..5ba749019 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -20,6 +20,18 @@ type FakePublishedTrack struct { addSubscriberReturnsOnCall map[int]struct { result1 error } + GetQualityForDimensionStub func(uint32, uint32) livekit.VideoQuality + getQualityForDimensionMutex sync.RWMutex + getQualityForDimensionArgsForCall []struct { + arg1 uint32 + arg2 uint32 + } + getQualityForDimensionReturns struct { + result1 livekit.VideoQuality + } + getQualityForDimensionReturnsOnCall map[int]struct { + result1 livekit.VideoQuality + } IDStub func() string iDMutex sync.RWMutex iDArgsForCall []struct { @@ -189,6 +201,68 @@ func (fake *FakePublishedTrack) AddSubscriberReturnsOnCall(i int, result1 error) }{result1} } +func (fake *FakePublishedTrack) GetQualityForDimension(arg1 uint32, arg2 uint32) livekit.VideoQuality { + fake.getQualityForDimensionMutex.Lock() + ret, specificReturn := fake.getQualityForDimensionReturnsOnCall[len(fake.getQualityForDimensionArgsForCall)] + fake.getQualityForDimensionArgsForCall = append(fake.getQualityForDimensionArgsForCall, struct { + arg1 uint32 + arg2 uint32 + }{arg1, arg2}) + stub := fake.GetQualityForDimensionStub + fakeReturns := fake.getQualityForDimensionReturns + fake.recordInvocation("GetQualityForDimension", []interface{}{arg1, arg2}) + fake.getQualityForDimensionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) GetQualityForDimensionCallCount() int { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + return len(fake.getQualityForDimensionArgsForCall) +} + +func (fake *FakePublishedTrack) GetQualityForDimensionCalls(stub func(uint32, uint32) livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = stub +} + +func (fake *FakePublishedTrack) GetQualityForDimensionArgsForCall(i int) (uint32, uint32) { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + argsForCall := fake.getQualityForDimensionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakePublishedTrack) GetQualityForDimensionReturns(result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + fake.getQualityForDimensionReturns = struct { + result1 livekit.VideoQuality + }{result1} +} + +func (fake *FakePublishedTrack) GetQualityForDimensionReturnsOnCall(i int, result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + if fake.getQualityForDimensionReturnsOnCall == nil { + fake.getQualityForDimensionReturnsOnCall = make(map[int]struct { + result1 livekit.VideoQuality + }) + } + fake.getQualityForDimensionReturnsOnCall[i] = struct { + result1 livekit.VideoQuality + }{result1} +} + func (fake *FakePublishedTrack) ID() string { fake.iDMutex.Lock() ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] @@ -770,6 +844,8 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.addSubscriberMutex.RLock() defer fake.addSubscriberMutex.RUnlock() + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() fake.iDMutex.RLock() defer fake.iDMutex.RUnlock() fake.isMutedMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_subscribed_track.go b/pkg/rtc/types/typesfakes/fake_subscribed_track.go index 69dbc584a..a982a4af2 100644 --- a/pkg/rtc/types/typesfakes/fake_subscribed_track.go +++ b/pkg/rtc/types/typesfakes/fake_subscribed_track.go @@ -40,6 +40,16 @@ type FakeSubscribedTrack struct { isMutedReturnsOnCall map[int]struct { result1 bool } + PublisherIdentityStub func() string + publisherIdentityMutex sync.RWMutex + publisherIdentityArgsForCall []struct { + } + publisherIdentityReturns struct { + result1 string + } + publisherIdentityReturnsOnCall map[int]struct { + result1 string + } SetPublisherMutedStub func(bool) setPublisherMutedMutex sync.RWMutex setPublisherMutedArgsForCall []struct { @@ -214,6 +224,59 @@ func (fake *FakeSubscribedTrack) IsMutedReturnsOnCall(i int, result1 bool) { }{result1} } +func (fake *FakeSubscribedTrack) PublisherIdentity() string { + fake.publisherIdentityMutex.Lock() + ret, specificReturn := fake.publisherIdentityReturnsOnCall[len(fake.publisherIdentityArgsForCall)] + fake.publisherIdentityArgsForCall = append(fake.publisherIdentityArgsForCall, struct { + }{}) + stub := fake.PublisherIdentityStub + fakeReturns := fake.publisherIdentityReturns + fake.recordInvocation("PublisherIdentity", []interface{}{}) + fake.publisherIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) PublisherIdentityCallCount() int { + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() + return len(fake.publisherIdentityArgsForCall) +} + +func (fake *FakeSubscribedTrack) PublisherIdentityCalls(stub func() string) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = stub +} + +func (fake *FakeSubscribedTrack) PublisherIdentityReturns(result1 string) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + fake.publisherIdentityReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherIdentityReturnsOnCall(i int, result1 string) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + if fake.publisherIdentityReturnsOnCall == nil { + fake.publisherIdentityReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.publisherIdentityReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakeSubscribedTrack) SetPublisherMuted(arg1 bool) { fake.setPublisherMutedMutex.Lock() fake.setPublisherMutedArgsForCall = append(fake.setPublisherMutedArgsForCall, struct { @@ -288,6 +351,8 @@ func (fake *FakeSubscribedTrack) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.isMutedMutex.RLock() defer fake.isMutedMutex.RUnlock() + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() fake.setPublisherMutedMutex.RLock() defer fake.setPublisherMutedMutex.RUnlock() fake.updateSubscriberSettingsMutex.RLock() diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index da23c2052..5012af925 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -409,17 +409,51 @@ func (r *LocalRoomManager) rtcSessionWorker(room *rtc.Room, participant types.Pa "subscribe", msg.Subscription.Subscribe) } case *livekit.SignalRequest_TrackSetting: - for _, subTrack := range participant.GetSubscribedTracks() { - for _, sid := range msg.TrackSetting.TrackSids { - if subTrack.ID() != sid { - continue - } - logger.Debugw("updating track settings", + for _, sid := range msg.TrackSetting.TrackSids { + subTrack := participant.GetSubscribedTrack(sid) + if subTrack == nil { + logger.Warnw("unable to find SubscribedTrack", + nil, "participant", participant.Identity(), "pID", participant.ID(), - "settings", msg.TrackSetting) - subTrack.UpdateSubscriberSettings(!msg.TrackSetting.Disabled, msg.TrackSetting.Quality) + "track", sid) + continue } + + // find the source PublishedTrack + publisher := room.GetParticipant(subTrack.PublisherIdentity()) + if publisher == nil { + logger.Warnw("unable to find publisher of SubscribedTrack", + nil, + "participant", participant.Identity(), + "pID", participant.ID(), + "publisher", subTrack.PublisherIdentity(), + "track", sid) + continue + } + + pubTrack := publisher.GetPublishedTrack(sid) + if pubTrack == nil { + logger.Warnw("unable to find PublishedTrack", + nil, + "participant", publisher.Identity(), + "pID", publisher.ID(), + "track", sid) + continue + } + if msg.TrackSetting.Width > 0 { + msg.TrackSetting.Quality = pubTrack.GetQualityForDimension(msg.TrackSetting.Width, msg.TrackSetting.Height) + } + + // find quality for published track + logger.Debugw("updating track settings", + "participant", participant.Identity(), + "pID", participant.ID(), + "settings", msg.TrackSetting) + subTrack.UpdateSubscriberSettings( + !msg.TrackSetting.Disabled, + msg.TrackSetting.Quality, + ) } case *livekit.SignalRequest_Leave: _ = participant.Close() diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index f7d3557a8..48a1b1cd4 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -1,7 +1,8 @@ // Code generated by Wire. DO NOT EDIT. //go:generate go run github.com/google/wire/cmd/wire -//+build !wireinject +//go:build !wireinject +// +build !wireinject package service