From fe8c355a32e32ec00ab5b48855303a322724bdbd Mon Sep 17 00:00:00 2001 From: Raja Subramanian Date: Mon, 2 May 2022 12:35:20 +0530 Subject: [PATCH] Support participant identity in permissions (#663) * Support participant identity in permissions It is harder for clients to update permissions by SID as remote reconnecting means a new SID for that participant. Using participant identity is a better option. For now, participant SID is also supported. Internally, it will get mapped to identity. Server code uses identity throughout after doing any necessary conversion from SID -> Identity. * Address comments --- go.mod | 3 +- go.sum | 9 +- pkg/rtc/mediatracksubscriptions.go | 25 ++-- pkg/rtc/room.go | 2 +- pkg/rtc/subscribedtrack.go | 17 ++- pkg/rtc/types/interfaces.go | 9 +- .../typesfakes/fake_local_media_track.go | 30 ++-- .../typesfakes/fake_local_participant.go | 22 +-- pkg/rtc/types/typesfakes/fake_media_track.go | 30 ++-- pkg/rtc/types/typesfakes/fake_participant.go | 22 +-- .../types/typesfakes/fake_subscribed_track.go | 65 +++++++++ pkg/rtc/uptrackmanager.go | 119 ++++++++++----- pkg/rtc/uptrackmanager_test.go | 137 +++++++++++++++--- 13 files changed, 348 insertions(+), 142 deletions(-) diff --git a/go.mod b/go.mod index efcaf6a89..ecb28e858 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.13.2-0.20220421193517-fa9efaff8ca5 + github.com/livekit/protocol v0.13.2-0.20220502043729-cef16c8ef304 github.com/mackerelio/go-osstat v0.2.1 github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 @@ -31,7 +31,6 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.11.0 github.com/rs/cors v1.8.2 - github.com/rs/zerolog v1.26.0 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/stretchr/testify v1.7.1 github.com/thoas/go-funk v0.8.0 diff --git a/go.sum b/go.sum index 8f94d8a8f..91eba8ef0 100644 --- a/go.sum +++ b/go.sum @@ -25,7 +25,6 @@ github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XP github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -70,7 +69,6 @@ github.com/go-redis/redis/v8 v8.11.3 h1:GCjoYp8c+yQTJfc0n69iwSiHjvuAdruxl7elnZCx github.com/go-redis/redis/v8 v8.11.3/go.mod h1:xNJ9xDG09FsIPwh3bWdk+0oDWHbtF9rPN0F/oD9XeKc= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -131,8 +129,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA= github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.13.2-0.20220421193517-fa9efaff8ca5 h1:cizw9GNLEQosDETGl69H/0dDGyIyDd8I6nLqjlXd/Ic= -github.com/livekit/protocol v0.13.2-0.20220421193517-fa9efaff8ca5/go.mod h1:BLtSeVmn2rLP37xjzw7gHgaAmkWl3L/L9bPvgSbaOfo= +github.com/livekit/protocol v0.13.2-0.20220502043729-cef16c8ef304 h1:89IfWGgaolVwwifutyhqqWdIGQQTdXjWxLb1jKxGB5s= +github.com/livekit/protocol v0.13.2-0.20220502043729-cef16c8ef304/go.mod h1:BLtSeVmn2rLP37xjzw7gHgaAmkWl3L/L9bPvgSbaOfo= github.com/mackerelio/go-osstat v0.2.1 h1:5AeAcBEutEErAOlDz6WCkEvm6AKYgHTUQrfwm5RbeQc= github.com/mackerelio/go-osstat v0.2.1/go.mod h1:UzRL8dMCCTqG5WdRtsxbuljMpZt9PCAGXqxPst5QtaY= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= @@ -237,9 +235,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rs/cors v1.8.2 h1:KCooALfAYGs415Cwu5ABvv9n9509fSiG5SQJn/AQo4U= github.com/rs/cors v1.8.2/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= -github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.26.0 h1:ORM4ibhEZeTeQlCojCK2kPz1ogAY4bGs4tD+SaAdGaE= -github.com/rs/zerolog v1.26.0/go.mod h1:yBiM87lvSqX8h0Ww4sdzNSkVYZ8dL2xjZJG1lAuGZEo= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/pkg/rtc/mediatracksubscriptions.go b/pkg/rtc/mediatracksubscriptions.go index d13b02f52..d960d2220 100644 --- a/pkg/rtc/mediatracksubscriptions.go +++ b/pkg/rtc/mediatracksubscriptions.go @@ -142,12 +142,13 @@ func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, code } subTrack := NewSubscribedTrack(SubscribedTrackParams{ - PublisherID: t.params.MediaTrack.PublisherID(), - PublisherIdentity: t.params.MediaTrack.PublisherIdentity(), - SubscriberID: subscriberID, - MediaTrack: t.params.MediaTrack, - DownTrack: downTrack, - AdaptiveStream: sub.GetAdaptiveStream(), + PublisherID: t.params.MediaTrack.PublisherID(), + PublisherIdentity: t.params.MediaTrack.PublisherIdentity(), + SubscriberID: subscriberID, + SubscriberIdentity: sub.Identity(), + MediaTrack: t.params.MediaTrack, + DownTrack: downTrack, + AdaptiveStream: sub.GetAdaptiveStream(), }) var transceiver *webrtc.RTPTransceiver @@ -276,14 +277,14 @@ func (t *MediaTrackSubscriptions) ResyncAllSubscribers() { } } -func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID { - var revokedSubscriberIDs []livekit.ParticipantID +func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var revokedSubscriberIdentities []livekit.ParticipantIdentity // LK-TODO: large number of subscribers needs to be solved for this loop for _, subTrack := range t.getAllSubscribedTracks() { found := false - for _, allowedID := range allowedSubscriberIDs { - if subTrack.SubscriberID() == allowedID { + for _, allowedIdentity := range allowedSubscriberIdentities { + if subTrack.SubscriberIdentity() == allowedIdentity { found = true break } @@ -291,11 +292,11 @@ func (t *MediaTrackSubscriptions) RevokeDisallowedSubscribers(allowedSubscriberI if !found { go subTrack.DownTrack().Close() - revokedSubscriberIDs = append(revokedSubscriberIDs, subTrack.SubscriberID()) + revokedSubscriberIdentities = append(revokedSubscriberIdentities, subTrack.SubscriberIdentity()) } } - return revokedSubscriberIDs + return revokedSubscriberIdentities } func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID { diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index ec091bad5..580c2b1a6 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -397,7 +397,7 @@ func (r *Room) SyncState(participant types.LocalParticipant, state *livekit.Sync } func (r *Room) UpdateSubscriptionPermission(participant types.LocalParticipant, subscriptionPermission *livekit.SubscriptionPermission) error { - return participant.UpdateSubscriptionPermission(subscriptionPermission, r.GetParticipantBySid) + return participant.UpdateSubscriptionPermission(subscriptionPermission, r.GetParticipant, r.GetParticipantBySid) } func (r *Room) RemoveDisallowedSubscriptions(sub types.LocalParticipant, disallowedSubscriptions map[livekit.TrackID]livekit.ParticipantID) { diff --git a/pkg/rtc/subscribedtrack.go b/pkg/rtc/subscribedtrack.go index b9bc5d5b1..3d6b0b62c 100644 --- a/pkg/rtc/subscribedtrack.go +++ b/pkg/rtc/subscribedtrack.go @@ -18,12 +18,13 @@ const ( ) type SubscribedTrackParams struct { - PublisherID livekit.ParticipantID - PublisherIdentity livekit.ParticipantIdentity - SubscriberID livekit.ParticipantID - MediaTrack types.MediaTrack - DownTrack *sfu.DownTrack - AdaptiveStream bool + PublisherID livekit.ParticipantID + PublisherIdentity livekit.ParticipantIdentity + SubscriberID livekit.ParticipantID + SubscriberIdentity livekit.ParticipantIdentity + MediaTrack types.MediaTrack + DownTrack *sfu.DownTrack + AdaptiveStream bool } type SubscribedTrack struct { @@ -75,6 +76,10 @@ func (t *SubscribedTrack) SubscriberID() livekit.ParticipantID { return t.params.SubscriberID } +func (t *SubscribedTrack) SubscriberIdentity() livekit.ParticipantIdentity { + return t.params.SubscriberIdentity +} + func (t *SubscribedTrack) DownTrack() *sfu.DownTrack { return t.params.DownTrack } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 229b51b70..151d6477a 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -74,7 +74,11 @@ type Participant interface { SubscriptionPermission() *livekit.SubscriptionPermission // updates from remotes - UpdateSubscriptionPermission(subscriptionPermission *livekit.SubscriptionPermission, resolver func(participantID livekit.ParticipantID) LocalParticipant) error + UpdateSubscriptionPermission( + subscriptionPermission *livekit.SubscriptionPermission, + resolverByIdentity func(participantIdentity livekit.ParticipantIdentity) LocalParticipant, + resolverBySid func(participantID livekit.ParticipantID) LocalParticipant, + ) error UpdateVideoLayers(updateVideoLayers *livekit.UpdateVideoLayers) error UpdateSubscribedQuality(nodeID livekit.NodeID, trackID livekit.TrackID, maxQuality livekit.VideoQuality) error UpdateMediaLoss(nodeID livekit.NodeID, trackID livekit.TrackID, fractionalLoss uint32) error @@ -203,7 +207,7 @@ type MediaTrack interface { RemoveSubscriber(participantID livekit.ParticipantID, resume bool) IsSubscriber(subID livekit.ParticipantID) bool RemoveAllSubscribers() - RevokeDisallowedSubscribers(allowedSubscriberIDs []livekit.ParticipantID) []livekit.ParticipantID + RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity GetAllSubscribers() []livekit.ParticipantID // returns quality information that's appropriate for width & height @@ -234,6 +238,7 @@ type SubscribedTrack interface { PublisherID() livekit.ParticipantID PublisherIdentity() livekit.ParticipantIdentity SubscriberID() livekit.ParticipantID + SubscriberIdentity() livekit.ParticipantIdentity DownTrack() *sfu.DownTrack MediaTrack() MediaTrack IsMuted() bool diff --git a/pkg/rtc/types/typesfakes/fake_local_media_track.go b/pkg/rtc/types/typesfakes/fake_local_media_track.go index f52efe5ed..ec04cdb08 100644 --- a/pkg/rtc/types/typesfakes/fake_local_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -187,16 +187,16 @@ type FakeLocalMediaTrack struct { restartMutex sync.RWMutex restartArgsForCall []struct { } - RevokeDisallowedSubscribersStub func([]livekit.ParticipantID) []livekit.ParticipantID + RevokeDisallowedSubscribersStub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity revokeDisallowedSubscribersMutex sync.RWMutex revokeDisallowedSubscribersArgsForCall []struct { - arg1 []livekit.ParticipantID + arg1 []livekit.ParticipantIdentity } revokeDisallowedSubscribersReturns struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity } revokeDisallowedSubscribersReturnsOnCall map[int]struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity } SdpCidStub func() string sdpCidMutex sync.RWMutex @@ -1206,16 +1206,16 @@ func (fake *FakeLocalMediaTrack) RestartCalls(stub func()) { fake.RestartStub = stub } -func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantID) []livekit.ParticipantID { - var arg1Copy []livekit.ParticipantID +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var arg1Copy []livekit.ParticipantIdentity if arg1 != nil { - arg1Copy = make([]livekit.ParticipantID, len(arg1)) + arg1Copy = make([]livekit.ParticipantIdentity, len(arg1)) copy(arg1Copy, arg1) } fake.revokeDisallowedSubscribersMutex.Lock() ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { - arg1 []livekit.ParticipantID + arg1 []livekit.ParticipantIdentity }{arg1Copy}) stub := fake.RevokeDisallowedSubscribersStub fakeReturns := fake.revokeDisallowedSubscribersReturns @@ -1236,39 +1236,39 @@ func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersCallCount() int { return len(fake.revokeDisallowedSubscribersArgsForCall) } -func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantID) []livekit.ParticipantID) { +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = stub } -func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantID { +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantIdentity { fake.revokeDisallowedSubscribersMutex.RLock() defer fake.revokeDisallowedSubscribersMutex.RUnlock() argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] return argsForCall.arg1 } -func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantID) { +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = nil fake.revokeDisallowedSubscribersReturns = struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }{result1} } -func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantID) { +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = nil if fake.revokeDisallowedSubscribersReturnsOnCall == nil { fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }) } fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }{result1} } diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 0bf3a1cbc..7cbb148d5 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -603,11 +603,12 @@ type FakeLocalParticipant struct { updateSubscribedTrackSettingsReturnsOnCall map[int]struct { result1 error } - UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) error + UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) error updateSubscriptionPermissionMutex sync.RWMutex updateSubscriptionPermissionArgsForCall []struct { arg1 *livekit.SubscriptionPermission - arg2 func(participantID livekit.ParticipantID) types.LocalParticipant + arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant } updateSubscriptionPermissionReturns struct { result1 error @@ -3832,19 +3833,20 @@ func (fake *FakeLocalParticipant) UpdateSubscribedTrackSettingsReturnsOnCall(i i }{result1} } -func (fake *FakeLocalParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 func(participantID livekit.ParticipantID) types.LocalParticipant) error { +func (fake *FakeLocalParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, arg3 func(participantID livekit.ParticipantID) types.LocalParticipant) error { fake.updateSubscriptionPermissionMutex.Lock() ret, specificReturn := fake.updateSubscriptionPermissionReturnsOnCall[len(fake.updateSubscriptionPermissionArgsForCall)] fake.updateSubscriptionPermissionArgsForCall = append(fake.updateSubscriptionPermissionArgsForCall, struct { arg1 *livekit.SubscriptionPermission - arg2 func(participantID livekit.ParticipantID) types.LocalParticipant - }{arg1, arg2}) + arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + }{arg1, arg2, arg3}) stub := fake.UpdateSubscriptionPermissionStub fakeReturns := fake.updateSubscriptionPermissionReturns - fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2}) + fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2, arg3}) fake.updateSubscriptionPermissionMutex.Unlock() if stub != nil { - return stub(arg1, arg2) + return stub(arg1, arg2, arg3) } if specificReturn { return ret.result1 @@ -3858,17 +3860,17 @@ func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionCallCount() int { return len(fake.updateSubscriptionPermissionArgsForCall) } -func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { fake.updateSubscriptionPermissionMutex.Lock() defer fake.updateSubscriptionPermissionMutex.Unlock() fake.UpdateSubscriptionPermissionStub = stub } -func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) { +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) { fake.updateSubscriptionPermissionMutex.RLock() defer fake.updateSubscriptionPermissionMutex.RUnlock() argsForCall := fake.updateSubscriptionPermissionArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionReturns(result1 error) { diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 9d01a3017..4417e5238 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -165,16 +165,16 @@ type FakeMediaTrack struct { restartMutex sync.RWMutex restartArgsForCall []struct { } - RevokeDisallowedSubscribersStub func([]livekit.ParticipantID) []livekit.ParticipantID + RevokeDisallowedSubscribersStub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity revokeDisallowedSubscribersMutex sync.RWMutex revokeDisallowedSubscribersArgsForCall []struct { - arg1 []livekit.ParticipantID + arg1 []livekit.ParticipantIdentity } revokeDisallowedSubscribersReturns struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity } revokeDisallowedSubscribersReturnsOnCall map[int]struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity } SetMutedStub func(bool) setMutedMutex sync.RWMutex @@ -1050,16 +1050,16 @@ func (fake *FakeMediaTrack) RestartCalls(stub func()) { fake.RestartStub = stub } -func (fake *FakeMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantID) []livekit.ParticipantID { - var arg1Copy []livekit.ParticipantID +func (fake *FakeMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var arg1Copy []livekit.ParticipantIdentity if arg1 != nil { - arg1Copy = make([]livekit.ParticipantID, len(arg1)) + arg1Copy = make([]livekit.ParticipantIdentity, len(arg1)) copy(arg1Copy, arg1) } fake.revokeDisallowedSubscribersMutex.Lock() ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { - arg1 []livekit.ParticipantID + arg1 []livekit.ParticipantIdentity }{arg1Copy}) stub := fake.RevokeDisallowedSubscribersStub fakeReturns := fake.revokeDisallowedSubscribersReturns @@ -1080,39 +1080,39 @@ func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCallCount() int { return len(fake.revokeDisallowedSubscribersArgsForCall) } -func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantID) []livekit.ParticipantID) { +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = stub } -func (fake *FakeMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantID { +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantIdentity { fake.revokeDisallowedSubscribersMutex.RLock() defer fake.revokeDisallowedSubscribersMutex.RUnlock() argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] return argsForCall.arg1 } -func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantID) { +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = nil fake.revokeDisallowedSubscribersReturns = struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }{result1} } -func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantID) { +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantIdentity) { fake.revokeDisallowedSubscribersMutex.Lock() defer fake.revokeDisallowedSubscribersMutex.Unlock() fake.RevokeDisallowedSubscribersStub = nil if fake.revokeDisallowedSubscribersReturnsOnCall == nil { fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }) } fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { - result1 []livekit.ParticipantID + result1 []livekit.ParticipantIdentity }{result1} } diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index dbba2c0a6..7c72b558d 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -167,11 +167,12 @@ type FakeParticipant struct { updateSubscribedQualityReturnsOnCall map[int]struct { result1 error } - UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) error + UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) error updateSubscriptionPermissionMutex sync.RWMutex updateSubscriptionPermissionArgsForCall []struct { arg1 *livekit.SubscriptionPermission - arg2 func(participantID livekit.ParticipantID) types.LocalParticipant + arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant } updateSubscriptionPermissionReturns struct { result1 error @@ -1021,19 +1022,20 @@ func (fake *FakeParticipant) UpdateSubscribedQualityReturnsOnCall(i int, result1 }{result1} } -func (fake *FakeParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 func(participantID livekit.ParticipantID) types.LocalParticipant) error { +func (fake *FakeParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, arg3 func(participantID livekit.ParticipantID) types.LocalParticipant) error { fake.updateSubscriptionPermissionMutex.Lock() ret, specificReturn := fake.updateSubscriptionPermissionReturnsOnCall[len(fake.updateSubscriptionPermissionArgsForCall)] fake.updateSubscriptionPermissionArgsForCall = append(fake.updateSubscriptionPermissionArgsForCall, struct { arg1 *livekit.SubscriptionPermission - arg2 func(participantID livekit.ParticipantID) types.LocalParticipant - }{arg1, arg2}) + arg2 func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + }{arg1, arg2, arg3}) stub := fake.UpdateSubscriptionPermissionStub fakeReturns := fake.updateSubscriptionPermissionReturns - fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2}) + fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2, arg3}) fake.updateSubscriptionPermissionMutex.Unlock() if stub != nil { - return stub(arg1, arg2) + return stub(arg1, arg2, arg3) } if specificReturn { return ret.result1 @@ -1047,17 +1049,17 @@ func (fake *FakeParticipant) UpdateSubscriptionPermissionCallCount() int { return len(fake.updateSubscriptionPermissionArgsForCall) } -func (fake *FakeParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { +func (fake *FakeParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { fake.updateSubscriptionPermissionMutex.Lock() defer fake.updateSubscriptionPermissionMutex.Unlock() fake.UpdateSubscriptionPermissionStub = stub } -func (fake *FakeParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, func(participantID livekit.ParticipantID) types.LocalParticipant) { +func (fake *FakeParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, func(participantID livekit.ParticipantID) types.LocalParticipant) { fake.updateSubscriptionPermissionMutex.RLock() defer fake.updateSubscriptionPermissionMutex.RUnlock() argsForCall := fake.updateSubscriptionPermissionArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 } func (fake *FakeParticipant) UpdateSubscriptionPermissionReturns(result1 error) { diff --git a/pkg/rtc/types/typesfakes/fake_subscribed_track.go b/pkg/rtc/types/typesfakes/fake_subscribed_track.go index 5cf5de198..f06924a72 100644 --- a/pkg/rtc/types/typesfakes/fake_subscribed_track.go +++ b/pkg/rtc/types/typesfakes/fake_subscribed_track.go @@ -90,6 +90,16 @@ type FakeSubscribedTrack struct { subscriberIDReturnsOnCall map[int]struct { result1 livekit.ParticipantID } + SubscriberIdentityStub func() livekit.ParticipantIdentity + subscriberIdentityMutex sync.RWMutex + subscriberIdentityArgsForCall []struct { + } + subscriberIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + subscriberIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } UpdateSubscriberSettingsStub func(*livekit.UpdateTrackSettings) updateSubscriberSettingsMutex sync.RWMutex updateSubscriberSettingsArgsForCall []struct { @@ -538,6 +548,59 @@ func (fake *FakeSubscribedTrack) SubscriberIDReturnsOnCall(i int, result1 liveki }{result1} } +func (fake *FakeSubscribedTrack) SubscriberIdentity() livekit.ParticipantIdentity { + fake.subscriberIdentityMutex.Lock() + ret, specificReturn := fake.subscriberIdentityReturnsOnCall[len(fake.subscriberIdentityArgsForCall)] + fake.subscriberIdentityArgsForCall = append(fake.subscriberIdentityArgsForCall, struct { + }{}) + stub := fake.SubscriberIdentityStub + fakeReturns := fake.subscriberIdentityReturns + fake.recordInvocation("SubscriberIdentity", []interface{}{}) + fake.subscriberIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityCallCount() int { + fake.subscriberIdentityMutex.RLock() + defer fake.subscriberIdentityMutex.RUnlock() + return len(fake.subscriberIdentityArgsForCall) +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = stub +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = nil + fake.subscriberIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = nil + if fake.subscriberIdentityReturnsOnCall == nil { + fake.subscriberIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.subscriberIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + func (fake *FakeSubscribedTrack) UpdateSubscriberSettings(arg1 *livekit.UpdateTrackSettings) { fake.updateSubscriberSettingsMutex.Lock() fake.updateSubscriberSettingsArgsForCall = append(fake.updateSubscriberSettingsArgsForCall, struct { @@ -615,6 +678,8 @@ func (fake *FakeSubscribedTrack) Invocations() map[string][][]interface{} { defer fake.setPublisherMutedMutex.RUnlock() fake.subscriberIDMutex.RLock() defer fake.subscriberIDMutex.RUnlock() + fake.subscriberIdentityMutex.RLock() + defer fake.subscriberIdentityMutex.RUnlock() fake.updateSubscriberSettingsMutex.RLock() defer fake.updateSubscriberSettingsMutex.RUnlock() fake.updateVideoLayerMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index f6b1a2530..50489e55c 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -10,6 +10,10 @@ import ( "github.com/livekit/livekit-server/pkg/rtc/types" ) +var ( + ErrSubscriptionPermissionNeedsId = errors.New("either participant identity or SID needed") +) + type UpTrackManagerParams struct { SID livekit.ParticipantID Logger logger.Logger @@ -24,9 +28,9 @@ type UpTrackManager struct { publishedTracks map[livekit.TrackID]types.MediaTrack subscriptionPermission *livekit.SubscriptionPermission // subscriber permission for published tracks - subscriberPermissions map[livekit.ParticipantID]*livekit.TrackPermission // subscriberID => *livekit.TrackPermission + subscriberPermissions map[livekit.ParticipantIdentity]*livekit.TrackPermission // subscriberIdentity => *livekit.TrackPermission // keeps tracks of track specific subscribers who are awaiting permission - pendingSubscriptions map[livekit.TrackID][]livekit.ParticipantID // trackID => []subscriberID + pendingSubscriptions map[livekit.TrackID][]livekit.ParticipantIdentity // trackID => []subscriberIdentity lock sync.RWMutex @@ -39,7 +43,7 @@ func NewUpTrackManager(params UpTrackManagerParams) *UpTrackManager { return &UpTrackManager{ params: params, publishedTracks: make(map[livekit.TrackID]types.MediaTrack), - pendingSubscriptions: make(map[livekit.TrackID][]livekit.ParticipantID), + pendingSubscriptions: make(map[livekit.TrackID][]livekit.ParticipantIdentity), } } @@ -121,8 +125,8 @@ func (u *UpTrackManager) AddSubscriber(sub types.LocalParticipant, params types. n := 0 for _, track := range tracks { trackID := track.ID() - subscriberID := sub.ID() - if !u.hasPermission(trackID, subscriberID) { + subscriberIdentity := sub.Identity() + if !u.hasPermission(trackID, subscriberIdentity) { u.lock.Lock() u.maybeAddPendingSubscription(trackID, sub) u.lock.Unlock() @@ -188,7 +192,8 @@ func (u *UpTrackManager) GetPublishedTracks() []types.MediaTrack { func (u *UpTrackManager) UpdateSubscriptionPermission( subscriptionPermission *livekit.SubscriptionPermission, - resolver func(participantID livekit.ParticipantID) types.LocalParticipant, + resolverByIdentity func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant, + resolverBySid func(participantID livekit.ParticipantID) types.LocalParticipant, ) error { u.lock.Lock() defer u.lock.Unlock() @@ -200,11 +205,15 @@ func (u *UpTrackManager) UpdateSubscriptionPermission( return nil } - u.parseSubscriptionPermissions(subscriptionPermission) + if err := u.parseSubscriptionPermissions(subscriptionPermission, resolverBySid); err != nil { + // do not accept permissions if parse fails + u.subscriptionPermission = nil + return err + } - u.processPendingSubscriptions(resolver) + u.processPendingSubscriptions(resolverByIdentity) - u.maybeRevokeSubscriptions(resolver) + u.maybeRevokeSubscriptions(resolverByIdentity) return nil } @@ -298,29 +307,63 @@ func (u *UpTrackManager) getPublishedTrack(trackID livekit.TrackID) types.MediaT return u.publishedTracks[trackID] } -func (u *UpTrackManager) parseSubscriptionPermissions(subscriptionPermission *livekit.SubscriptionPermission) { +func (u *UpTrackManager) parseSubscriptionPermissions( + subscriptionPermission *livekit.SubscriptionPermission, + resolver func(participantID livekit.ParticipantID) types.LocalParticipant, +) error { // every update overrides the existing // all_participants takes precedence if subscriptionPermission.AllParticipants { // everything is allowed, nothing else to do u.subscriberPermissions = nil - return + return nil } // per participant permissions - u.subscriberPermissions = make(map[livekit.ParticipantID]*livekit.TrackPermission) + u.subscriberPermissions = make(map[livekit.ParticipantIdentity]*livekit.TrackPermission) for _, trackPerms := range subscriptionPermission.TrackPermissions { - u.subscriberPermissions[livekit.ParticipantID(trackPerms.ParticipantSid)] = trackPerms + subscriberIdentity := livekit.ParticipantIdentity(trackPerms.ParticipantIdentity) + if subscriberIdentity == "" { + if trackPerms.ParticipantSid == "" { + u.subscriberPermissions = nil + return ErrSubscriptionPermissionNeedsId + } + + var sub types.LocalParticipant + if resolver != nil { + sub = resolver(livekit.ParticipantID(trackPerms.ParticipantSid)) + } + if sub == nil { + u.params.Logger.Warnw("could not find subscriber for permissions update", nil, "subscriberID", trackPerms.ParticipantSid) + continue + } + + subscriberIdentity = sub.Identity() + } else { + if trackPerms.ParticipantSid != "" { + sub := resolver(livekit.ParticipantID(trackPerms.ParticipantSid)) + if sub != nil && sub.Identity() != subscriberIdentity { + u.params.Logger.Errorw("participant identity mismatch", nil, "expected", subscriberIdentity, "got", sub.Identity()) + } + if sub == nil { + u.params.Logger.Warnw("could not find subscriber for permissions update", nil, "subscriberID", trackPerms.ParticipantSid) + } + } + } + + u.subscriberPermissions[subscriberIdentity] = trackPerms } + + return nil } -func (u *UpTrackManager) hasPermission(trackID livekit.TrackID, subscriberID livekit.ParticipantID) bool { +func (u *UpTrackManager) hasPermission(trackID livekit.TrackID, subscriberIdentity livekit.ParticipantIdentity) bool { if u.subscriberPermissions == nil { return true } - perms, ok := u.subscriberPermissions[subscriberID] + perms, ok := u.subscriberPermissions[subscriberIdentity] if !ok { return false } @@ -338,21 +381,21 @@ func (u *UpTrackManager) hasPermission(trackID livekit.TrackID, subscriberID liv return false } -func (u *UpTrackManager) getAllowedSubscribers(trackID livekit.TrackID) []livekit.ParticipantID { +func (u *UpTrackManager) getAllowedSubscribers(trackID livekit.TrackID) []livekit.ParticipantIdentity { if u.subscriberPermissions == nil { return nil } - allowed := make([]livekit.ParticipantID, 0) - for subscriberID, perms := range u.subscriberPermissions { + allowed := make([]livekit.ParticipantIdentity, 0) + for subscriberIdentity, perms := range u.subscriberPermissions { if perms.AllTracks { - allowed = append(allowed, subscriberID) + allowed = append(allowed, subscriberIdentity) continue } for _, sid := range perms.TrackSids { if livekit.TrackID(sid) == trackID { - allowed = append(allowed, subscriberID) + allowed = append(allowed, subscriberIdentity) break } } @@ -362,27 +405,27 @@ func (u *UpTrackManager) getAllowedSubscribers(trackID livekit.TrackID) []liveki } func (u *UpTrackManager) maybeAddPendingSubscription(trackID livekit.TrackID, sub types.LocalParticipant) { - subscriberID := sub.ID() + subscriberIdentity := sub.Identity() pending := u.pendingSubscriptions[trackID] - for _, sid := range pending { - if sid == subscriberID { + for _, identity := range pending { + if identity == subscriberIdentity { // already pending return } } - u.pendingSubscriptions[trackID] = append(u.pendingSubscriptions[trackID], subscriberID) + u.pendingSubscriptions[trackID] = append(u.pendingSubscriptions[trackID], subscriberIdentity) go sub.SubscriptionPermissionUpdate(u.params.SID, trackID, false) } func (u *UpTrackManager) maybeRemovePendingSubscription(trackID livekit.TrackID, sub types.LocalParticipant) { - subscriberID := sub.ID() + subscriberIdentity := sub.Identity() pending := u.pendingSubscriptions[trackID] n := len(pending) - for idx, sid := range pending { - if sid == subscriberID { + for idx, identity := range pending { + if identity == subscriberIdentity { u.pendingSubscriptions[trackID][idx] = u.pendingSubscriptions[trackID][n-1] u.pendingSubscriptions[trackID] = u.pendingSubscriptions[trackID][:n-1] break @@ -393,34 +436,34 @@ func (u *UpTrackManager) maybeRemovePendingSubscription(trackID livekit.TrackID, } } -func (u *UpTrackManager) processPendingSubscriptions(resolver func(participantID livekit.ParticipantID) types.LocalParticipant) { - updatedPendingSubscriptions := make(map[livekit.TrackID][]livekit.ParticipantID) +func (u *UpTrackManager) processPendingSubscriptions(resolver func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant) { + updatedPendingSubscriptions := make(map[livekit.TrackID][]livekit.ParticipantIdentity) for trackID, pending := range u.pendingSubscriptions { track := u.getPublishedTrack(trackID) if track == nil { continue } - var updatedPending []livekit.ParticipantID - for _, sid := range pending { + var updatedPending []livekit.ParticipantIdentity + for _, identity := range pending { var sub types.LocalParticipant if resolver != nil { - sub = resolver(sid) + sub = resolver(identity) } if sub == nil || sub.State() == livekit.ParticipantInfo_DISCONNECTED { // do not keep this pending subscription as subscriber may be gone continue } - if !u.hasPermission(trackID, sid) { - updatedPending = append(updatedPending, sid) + if !u.hasPermission(trackID, identity) { + updatedPending = append(updatedPending, identity) continue } if err := track.AddSubscriber(sub); err != nil { u.params.Logger.Errorw("error reinstating pending subscription", err) // keep it in pending on error in case the error is transient - updatedPending = append(updatedPending, sid) + updatedPending = append(updatedPending, identity) continue } @@ -433,7 +476,7 @@ func (u *UpTrackManager) processPendingSubscriptions(resolver func(participantID u.pendingSubscriptions = updatedPendingSubscriptions } -func (u *UpTrackManager) maybeRevokeSubscriptions(resolver func(participantID livekit.ParticipantID) types.LocalParticipant) { +func (u *UpTrackManager) maybeRevokeSubscriptions(resolver func(participantIdentity livekit.ParticipantIdentity) types.LocalParticipant) { for _, track := range u.publishedTracks { trackID := track.ID() allowed := u.getAllowedSubscribers(trackID) @@ -443,10 +486,10 @@ func (u *UpTrackManager) maybeRevokeSubscriptions(resolver func(participantID li } revokedSubscribers := track.RevokeDisallowedSubscribers(allowed) - for _, subID := range revokedSubscribers { + for _, subIdentity := range revokedSubscribers { var sub types.LocalParticipant if resolver != nil { - sub = resolver(subID) + sub = resolver(subIdentity) } if sub == nil { continue diff --git a/pkg/rtc/uptrackmanager_test.go b/pkg/rtc/uptrackmanager_test.go index 31d49d211..3f729435b 100644 --- a/pkg/rtc/uptrackmanager_test.go +++ b/pkg/rtc/uptrackmanager_test.go @@ -7,6 +7,7 @@ import ( "github.com/livekit/protocol/livekit" + "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" ) @@ -26,17 +27,34 @@ func TestUpdateSubscriptionPermission(t *testing.T) { subscriptionPermission := &livekit.SubscriptionPermission{ AllParticipants: true, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.Nil(t, um.subscriberPermissions) // nobody is allowed to subscribe subscriptionPermission = &livekit.SubscriptionPermission{ TrackPermissions: []*livekit.TrackPermission{}, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.NotNil(t, um.subscriberPermissions) require.Equal(t, 0, len(um.subscriberPermissions)) + lp1 := &typesfakes.FakeLocalParticipant{} + lp1.IdentityReturns("p1") + lp2 := &typesfakes.FakeLocalParticipant{} + lp2.IdentityReturns("p2") + + sidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp1 + } + + if sid == "p2" { + return lp2 + } + + return nil + } + // allow all tracks for participants perms1 := &livekit.TrackPermission{ ParticipantSid: "p1", @@ -52,23 +70,23 @@ func TestUpdateSubscriptionPermission(t *testing.T) { perms2, }, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, sidResolver) require.Equal(t, 2, len(um.subscriberPermissions)) require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) // allow all tracks for some and restrictive for others perms1 = &livekit.TrackPermission{ - ParticipantSid: "p1", - AllTracks: true, + ParticipantIdentity: "p1", + AllTracks: true, } perms2 = &livekit.TrackPermission{ - ParticipantSid: "p2", - TrackSids: []string{"audio"}, + ParticipantIdentity: "p2", + TrackSids: []string{"audio"}, } perms3 := &livekit.TrackPermission{ - ParticipantSid: "p3", - TrackSids: []string{"video"}, + ParticipantIdentity: "p3", + TrackSids: []string{"video"}, } subscriptionPermission = &livekit.SubscriptionPermission{ TrackPermissions: []*livekit.TrackPermission{ @@ -77,12 +95,83 @@ func TestUpdateSubscriptionPermission(t *testing.T) { perms3, }, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.Equal(t, 3, len(um.subscriberPermissions)) require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) require.EqualValues(t, perms3, um.subscriberPermissions["p3"]) }) + + t.Run("updates subscription permission using both", func(t *testing.T) { + um := NewUpTrackManager(UpTrackManagerParams{}) + + tra := &typesfakes.FakeMediaTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakeMediaTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + lp1 := &typesfakes.FakeLocalParticipant{} + lp1.IdentityReturns("p1") + lp2 := &typesfakes.FakeLocalParticipant{} + lp2.IdentityReturns("p2") + + sidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp1 + } + + if sid == "p2" { + return lp2 + } + + return nil + } + + // allow all tracks for participants + perms1 := &livekit.TrackPermission{ + ParticipantSid: "p1", + ParticipantIdentity: "p1", + AllTracks: true, + } + perms2 := &livekit.TrackPermission{ + ParticipantSid: "p2", + ParticipantIdentity: "p2", + AllTracks: true, + } + subscriptionPermission := &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + }, + } + err := um.UpdateSubscriptionPermission(subscriptionPermission, nil, sidResolver) + require.NoError(t, err) + require.Equal(t, 2, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + + // mismatched identities should fail a permission update + badSidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp2 + } + + if sid == "p2" { + return lp1 + } + + return nil + } + + err = um.UpdateSubscriptionPermission(subscriptionPermission, nil, badSidResolver) + require.NoError(t, err) + require.Equal(t, 2, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + }) } func TestSubscriptionPermission(t *testing.T) { @@ -101,7 +190,7 @@ func TestSubscriptionPermission(t *testing.T) { subscriptionPermission := &livekit.SubscriptionPermission{ AllParticipants: true, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.True(t, um.hasPermission("audio", "p1")) require.True(t, um.hasPermission("audio", "p2")) @@ -109,7 +198,7 @@ func TestSubscriptionPermission(t *testing.T) { subscriptionPermission = &livekit.SubscriptionPermission{ TrackPermissions: []*livekit.TrackPermission{}, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.False(t, um.hasPermission("audio", "p1")) require.False(t, um.hasPermission("audio", "p2")) @@ -117,16 +206,16 @@ func TestSubscriptionPermission(t *testing.T) { subscriptionPermission = &livekit.SubscriptionPermission{ TrackPermissions: []*livekit.TrackPermission{ { - ParticipantSid: "p1", - AllTracks: true, + ParticipantIdentity: "p1", + AllTracks: true, }, { - ParticipantSid: "p2", - AllTracks: true, + ParticipantIdentity: "p2", + AllTracks: true, }, }, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.True(t, um.hasPermission("audio", "p1")) require.True(t, um.hasPermission("video", "p1")) require.True(t, um.hasPermission("audio", "p2")) @@ -148,20 +237,20 @@ func TestSubscriptionPermission(t *testing.T) { subscriptionPermission = &livekit.SubscriptionPermission{ TrackPermissions: []*livekit.TrackPermission{ { - ParticipantSid: "p1", - AllTracks: true, + ParticipantIdentity: "p1", + AllTracks: true, }, { - ParticipantSid: "p2", - TrackSids: []string{"audio"}, + ParticipantIdentity: "p2", + TrackSids: []string{"audio"}, }, { - ParticipantSid: "p3", - TrackSids: []string{"video"}, + ParticipantIdentity: "p3", + TrackSids: []string{"video"}, }, }, } - um.UpdateSubscriptionPermission(subscriptionPermission, nil) + um.UpdateSubscriptionPermission(subscriptionPermission, nil, nil) require.True(t, um.hasPermission("audio", "p1")) require.True(t, um.hasPermission("video", "p1")) require.True(t, um.hasPermission("screen", "p1"))