From 9747243ce2ca84d7c520bbf822b891dbd844f6b5 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Fri, 28 Jan 2022 09:55:10 -0800 Subject: [PATCH] Honor autoSubscribe when subscription permissions are granted later (#381) * Ensure autosubscribe is honored when subscription permissions were granted later * negotiate even if no media has been added * don't double-negotiate --- pkg/rtc/participant.go | 26 +++++---- pkg/rtc/participant_internal_test.go | 53 ++++++++++++++----- pkg/rtc/room.go | 21 +++++++- pkg/rtc/types/interfaces.go | 2 +- pkg/rtc/types/typesfakes/fake_room.go | 76 +++++++++++++++++++++++++++ pkg/service/roommanager.go | 10 ++-- test/client/client.go | 2 +- test/singlenode_test.go | 58 ++++++++++++++++++++ test/webhook_test.go | 18 ++++++- 9 files changed, 231 insertions(+), 35 deletions(-) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 9ffe35479..2186228fa 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -61,13 +61,14 @@ type ParticipantParams struct { } type ParticipantImpl struct { - params ParticipantParams - publisher *PCTransport - subscriber *PCTransport - isClosed utils.AtomicFlag - permission *livekit.ParticipantPermission - state atomic.Value // livekit.ParticipantInfo_State - updateCache *lru.Cache + params ParticipantParams + publisher *PCTransport + subscriber *PCTransport + isClosed utils.AtomicFlag + permission *livekit.ParticipantPermission + state atomic.Value // livekit.ParticipantInfo_State + updateCache *lru.Cache + subscriberAsPrimary bool // reliable and unreliable data channels reliableDC *webrtc.DataChannel @@ -119,7 +120,7 @@ type ParticipantImpl struct { onClaimsChanged func(participant types.LocalParticipant) } -func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { +func NewParticipant(params ParticipantParams, perms *livekit.ParticipantPermission) (*ParticipantImpl, error) { // TODO: check to ensure params are valid, id and identity can't be empty p := &ParticipantImpl{ @@ -134,6 +135,7 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { } p.migrateState.Store(types.MigrateStateInit) p.state.Store(livekit.ParticipantInfo_JOINING) + p.SetPermission(perms) var err error // keep last participants and when updates were sent @@ -182,7 +184,9 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { }) primaryPC := p.publisher.pc - + // primary connection does not change, canSubscribe can change if permission was updated + // after the participant has joined + p.subscriberAsPrimary = p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe() if p.SubscriberAsPrimary() { primaryPC = p.subscriber.pc ordered := true @@ -271,7 +275,7 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio p.permission = permission // update grants with this - if p.params.Grants != nil && p.params.Grants.Video != nil { + if p.params.Grants != nil && p.params.Grants.Video != nil && permission != nil { video := p.params.Grants.Video video.SetCanSubscribe(permission.CanSubscribe) video.SetCanPublish(permission.CanPublish) @@ -788,7 +792,7 @@ func (p *ParticipantImpl) IsRecorder() bool { } func (p *ParticipantImpl) SubscriberAsPrimary() bool { - return p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe() + return p.subscriberAsPrimary } func (p *ParticipantImpl) SubscriberPC() *webrtc.PeerConnection { diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index b10101ede..9e6430d50 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -319,35 +319,56 @@ func TestConnectionQuality(t *testing.T) { func TestSubscriberAsPrimary(t *testing.T) { t.Run("protocol 4 uses subs as primary", func(t *testing.T) { - p := newParticipantForTest("test") - p.SetPermission(&livekit.ParticipantPermission{ - CanSubscribe: true, - CanPublish: true, + p := newParticipantForTestWithOpts("test", &participantOpts{ + permissions: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, }) require.True(t, p.SubscriberAsPrimary()) }) t.Run("protocol 2 uses pub as primary", func(t *testing.T) { - p := newParticipantForTest("test") - p.params.ProtocolVersion = 2 - p.SetPermission(&livekit.ParticipantPermission{ - CanSubscribe: true, - CanPublish: true, + p := newParticipantForTestWithOpts("test", &participantOpts{ + protocolVersion: 2, + permissions: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, }) require.False(t, p.SubscriberAsPrimary()) }) t.Run("publisher only uses pub as primary", func(t *testing.T) { - p := newParticipantForTest("test") + p := newParticipantForTestWithOpts("test", &participantOpts{ + permissions: &livekit.ParticipantPermission{ + CanSubscribe: false, + CanPublish: true, + }, + }) + require.False(t, p.SubscriberAsPrimary()) + + // ensure that it doesn't change after perms p.SetPermission(&livekit.ParticipantPermission{ - CanSubscribe: false, + CanSubscribe: true, CanPublish: true, }) require.False(t, p.SubscriberAsPrimary()) }) } -func newParticipantForTest(identity livekit.ParticipantIdentity) *ParticipantImpl { +type participantOpts struct { + permissions *livekit.ParticipantPermission + protocolVersion types.ProtocolVersion +} + +func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *participantOpts) *ParticipantImpl { + if opts == nil { + opts = &participantOpts{} + } + if opts.protocolVersion == 0 { + opts.protocolVersion = 6 + } conf, _ := config.NewConfig("", nil) // disable mux, it doesn't play too well with unit test conf.RTC.UDPPort = 0 @@ -360,8 +381,12 @@ func newParticipantForTest(identity livekit.ParticipantIdentity) *ParticipantImp Identity: identity, Config: rtcConf, Sink: &routingfakes.FakeMessageSink{}, - ProtocolVersion: 4, + ProtocolVersion: opts.protocolVersion, ThrottleConfig: conf.RTC.PLIThrottle, - }) + }, opts.permissions) return p } + +func newParticipantForTest(identity livekit.ParticipantIdentity) *ParticipantImpl { + return newParticipantForTestWithOpts(identity, nil) +} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index e7c82a3bd..5b5e0c8ad 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -410,6 +410,22 @@ func (r *Room) RemoveDisallowedSubscriptions(sub types.LocalParticipant, disallo } } +func (r *Room) SetParticipantPermission(participant types.LocalParticipant, permission *livekit.ParticipantPermission) error { + hadCanSubscribe := participant.CanSubscribe() + participant.SetPermission(permission) + // when subscribe perms are given, trigger autosub + if !hadCanSubscribe && participant.CanSubscribe() { + if participant.State() == livekit.ParticipantInfo_ACTIVE { + if r.subscribeToExistingTracks(participant) == 0 { + // start negotiating even if there are other media tracks to subscribe + // we'll need to set the participant up to receive data + participant.Negotiate() + } + } + } + return nil +} + func (r *Room) UpdateVideoLayers(participant types.Participant, updateVideoLayers *livekit.UpdateVideoLayers) error { return participant.UpdateVideoLayers(updateVideoLayers) } @@ -651,12 +667,12 @@ func (r *Room) onDataPacket(source types.LocalParticipant, dp *livekit.DataPacke } } -func (r *Room) subscribeToExistingTracks(p types.LocalParticipant) { +func (r *Room) subscribeToExistingTracks(p types.LocalParticipant) int { r.lock.RLock() shouldSubscribe := r.autoSubscribe(p) r.lock.RUnlock() if !shouldSubscribe { - return + return 0 } tracksAdded := 0 @@ -679,6 +695,7 @@ func (r *Room) subscribeToExistingTracks(p types.LocalParticipant) { if tracksAdded > 0 { r.Logger.Debugw("subscribed participants to existing tracks", "tracks", tracksAdded) } + return tracksAdded } // broadcast an update about participant p diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index ad30b88b4..3aa48070a 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -161,7 +161,7 @@ type Room interface { UpdateSubscriptionPermission(participant LocalParticipant, permissions *livekit.SubscriptionPermission) error SyncState(participant LocalParticipant, state *livekit.SyncState) error SimulateScenario(participant LocalParticipant, scenario *livekit.SimulateScenario) error - + SetParticipantPermission(participant LocalParticipant, permission *livekit.ParticipantPermission) error UpdateVideoLayers(participant Participant, updateVideoLayers *livekit.UpdateVideoLayers) error } diff --git a/pkg/rtc/types/typesfakes/fake_room.go b/pkg/rtc/types/typesfakes/fake_room.go index de7970d95..6dcea7862 100644 --- a/pkg/rtc/types/typesfakes/fake_room.go +++ b/pkg/rtc/types/typesfakes/fake_room.go @@ -29,6 +29,18 @@ type FakeRoom struct { nameReturnsOnCall map[int]struct { result1 livekit.RoomName } + SetParticipantPermissionStub func(types.LocalParticipant, *livekit.ParticipantPermission) error + setParticipantPermissionMutex sync.RWMutex + setParticipantPermissionArgsForCall []struct { + arg1 types.LocalParticipant + arg2 *livekit.ParticipantPermission + } + setParticipantPermissionReturns struct { + result1 error + } + setParticipantPermissionReturnsOnCall map[int]struct { + result1 error + } SimulateScenarioStub func(types.LocalParticipant, *livekit.SimulateScenario) error simulateScenarioMutex sync.RWMutex simulateScenarioArgsForCall []struct { @@ -201,6 +213,68 @@ func (fake *FakeRoom) NameReturnsOnCall(i int, result1 livekit.RoomName) { }{result1} } +func (fake *FakeRoom) SetParticipantPermission(arg1 types.LocalParticipant, arg2 *livekit.ParticipantPermission) error { + fake.setParticipantPermissionMutex.Lock() + ret, specificReturn := fake.setParticipantPermissionReturnsOnCall[len(fake.setParticipantPermissionArgsForCall)] + fake.setParticipantPermissionArgsForCall = append(fake.setParticipantPermissionArgsForCall, struct { + arg1 types.LocalParticipant + arg2 *livekit.ParticipantPermission + }{arg1, arg2}) + stub := fake.SetParticipantPermissionStub + fakeReturns := fake.setParticipantPermissionReturns + fake.recordInvocation("SetParticipantPermission", []interface{}{arg1, arg2}) + fake.setParticipantPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) SetParticipantPermissionCallCount() int { + fake.setParticipantPermissionMutex.RLock() + defer fake.setParticipantPermissionMutex.RUnlock() + return len(fake.setParticipantPermissionArgsForCall) +} + +func (fake *FakeRoom) SetParticipantPermissionCalls(stub func(types.LocalParticipant, *livekit.ParticipantPermission) error) { + fake.setParticipantPermissionMutex.Lock() + defer fake.setParticipantPermissionMutex.Unlock() + fake.SetParticipantPermissionStub = stub +} + +func (fake *FakeRoom) SetParticipantPermissionArgsForCall(i int) (types.LocalParticipant, *livekit.ParticipantPermission) { + fake.setParticipantPermissionMutex.RLock() + defer fake.setParticipantPermissionMutex.RUnlock() + argsForCall := fake.setParticipantPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoom) SetParticipantPermissionReturns(result1 error) { + fake.setParticipantPermissionMutex.Lock() + defer fake.setParticipantPermissionMutex.Unlock() + fake.SetParticipantPermissionStub = nil + fake.setParticipantPermissionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoom) SetParticipantPermissionReturnsOnCall(i int, result1 error) { + fake.setParticipantPermissionMutex.Lock() + defer fake.setParticipantPermissionMutex.Unlock() + fake.SetParticipantPermissionStub = nil + if fake.setParticipantPermissionReturnsOnCall == nil { + fake.setParticipantPermissionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setParticipantPermissionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeRoom) SimulateScenario(arg1 types.LocalParticipant, arg2 *livekit.SimulateScenario) error { fake.simulateScenarioMutex.Lock() ret, specificReturn := fake.simulateScenarioReturnsOnCall[len(fake.simulateScenarioArgsForCall)] @@ -530,6 +604,8 @@ func (fake *FakeRoom) Invocations() map[string][][]interface{} { defer fake.iDMutex.RUnlock() fake.nameMutex.RLock() defer fake.nameMutex.RUnlock() + fake.setParticipantPermissionMutex.RLock() + defer fake.setParticipantPermissionMutex.RUnlock() fake.simulateScenarioMutex.RLock() defer fake.simulateScenarioMutex.RUnlock() fake.syncStateMutex.RLock() diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 4f08219a3..5977855df 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -248,7 +248,7 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName livekit.RoomNam Grants: pi.Grants, Hidden: pi.Hidden, Logger: pLogger, - }) + }, pi.Permission) if err != nil { logger.Errorw("could not create participant", err) return @@ -257,9 +257,6 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName livekit.RoomNam if pi.Metadata != "" { participant.SetMetadata(pi.Metadata) } - if pi.Permission != nil { - participant.SetPermission(pi.Permission) - } // join room opts := rtc.ParticipantOptions{ @@ -461,7 +458,10 @@ func (r *RoomManager) handleRTCMessage(_ context.Context, roomName livekit.RoomN participant.SetMetadata(rm.UpdateParticipant.Metadata) } if rm.UpdateParticipant.Permission != nil { - participant.SetPermission(rm.UpdateParticipant.Permission) + err := room.SetParticipantPermission(participant, rm.UpdateParticipant.Permission) + if err != nil { + pLogger.Errorw("could not update permissions", err) + } } case *livekit.RTCNodeMessage_DeleteRoom: for _, p := range room.GetParticipants() { diff --git a/test/client/client.go b/test/client/client.go index 22954ab43..6ab69d4c5 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -89,7 +89,7 @@ type Options struct { } func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { - u, err := url.Parse(host + "/rtc?protocol=3") + u, err := url.Parse(host + "/rtc?protocol=6") if err != nil { return nil, err } diff --git a/test/singlenode_test.go b/test/singlenode_test.go index 1b62439e8..393a8f9de 100644 --- a/test/singlenode_test.go +++ b/test/singlenode_test.go @@ -341,6 +341,10 @@ func TestSingleNodeJoinAfterClose(t *testing.T) { } func TestAutoCreate(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } disableAutoCreate := func(conf *config.Config) { conf.Room.AutoCreate = false } @@ -381,3 +385,57 @@ func TestAutoCreate(t *testing.T) { c1.Stop() }) } + +// don't give user subscribe permissions initially, and ensure autosubscribe is triggered afterwards +func TestSingleNodeUpdateSubscriptionPermissions(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestSingleNodeUpdateSubscriptionPermissions") + defer finish() + + pub := createRTCClient("pub", defaultServerPort, nil) + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanSubscribe(false) + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(grant). + SetIdentity("sub") + token, err := at.ToJWT() + require.NoError(t, err) + sub := createRTCClientWithToken(token, defaultServerPort, nil) + + waitUntilConnected(t, pub, sub) + + writers := publishTracksForClients(t, pub) + defer stopWriters(writers...) + + // wait sub receives tracks + testutils.WithTimeout(t, "waiting for sub to receive track metadata", func() bool { + pubRemote := sub.GetRemoteParticipant(pub.ID()) + if pubRemote == nil { + return false + } + if len(pubRemote.Tracks) != 2 { + return false + } + return true + }) + + // set permissions out of band + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err = roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "sub", + Permission: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, + }) + require.NoError(t, err) + + testutils.WithTimeout(t, "waiting to get subscriptions", func() bool { + tracks := sub.SubscribedTracks()[pub.ID()] + return len(tracks) == 2 + }) +} diff --git a/test/webhook_test.go b/test/webhook_test.go index 4878d116c..7d13aa5a6 100644 --- a/test/webhook_test.go +++ b/test/webhook_test.go @@ -2,6 +2,7 @@ package test import ( "context" + "errors" "fmt" "net" "net/http" @@ -184,7 +185,22 @@ func (s *webhookTestServer) Start() error { return err } go s.server.Serve(l) - return nil + + // wait for webhook server to start + ctx, cancel := context.WithTimeout(context.Background(), testutils.ConnectTimeout) + defer cancel() + for { + select { + case <-ctx.Done(): + return errors.New("could not start webhook server after timeout") + case <-time.After(10 * time.Millisecond): + // ensure we can connect to it + res, err := http.Get(fmt.Sprintf("http://localhost%s", s.server.Addr)) + if err == nil && res.StatusCode == http.StatusOK { + return nil + } + } + } } func (s *webhookTestServer) Stop() {