From 52fc53d3257b8efff5ab1081029cfb3bd2af1d1f Mon Sep 17 00:00:00 2001 From: David Zhao Date: Sun, 23 Jan 2022 23:15:49 -0800 Subject: [PATCH] Issue updated tokens to clients. (#365) This ensures client reconnect attempts would be successful for long running rooms. It also fixes inaccurate permissions that were set incorrectly when full reconnections take place. --- go.mod | 2 +- go.sum | 6 +- pkg/routing/interfaces.go | 2 + pkg/routing/redisrouter.go | 15 ++ pkg/rtc/participant.go | 51 ++++- pkg/rtc/types/interfaces.go | 6 +- .../typesfakes/fake_local_participant.go | 179 ++++++++++++++++++ pkg/service/roommanager.go | 46 ++++- pkg/service/rtcservice.go | 1 + pkg/service/wire.go | 6 +- pkg/service/wire_gen.go | 6 +- test/client/client.go | 11 ++ test/multinode_test.go | 40 ++++ 13 files changed, 356 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 26d2ce113..40ff48f96 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.11.11-0.20220122062547-e3f90e29577a + github.com/livekit/protocol v0.11.11 github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 github.com/mitchellh/go-homedir v1.1.0 diff --git a/go.sum b/go.sum index bbea36506..0d57beb05 100644 --- a/go.sum +++ b/go.sum @@ -132,10 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA= github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.11.11-0.20220120192814-bde53c19d1bd h1:FUaZaJafv7ljzw4R9eMpwjrGVo1aSuTswmvP8trNegA= -github.com/livekit/protocol v0.11.11-0.20220120192814-bde53c19d1bd/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= -github.com/livekit/protocol v0.11.11-0.20220122062547-e3f90e29577a h1:+TO/0De0NzkklSvooK6fyqJJ1jOIAEsC5K5VCK3Nqz8= -github.com/livekit/protocol v0.11.11-0.20220122062547-e3f90e29577a/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= +github.com/livekit/protocol v0.11.11 h1:je6yFjRMtDULH1Ir6d6PhX3ii676NGH7bUru7xmqGZ0= +github.com/livekit/protocol v0.11.11/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= github.com/magefile/mage v1.11.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index e55a24368..b5e1e1d1e 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -4,6 +4,7 @@ import ( "context" "github.com/go-redis/redis/v8" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "google.golang.org/protobuf/proto" @@ -36,6 +37,7 @@ type ParticipantInit struct { Hidden bool Recorder bool Client *livekit.ClientInfo + Grants *auth.ClaimGrants } type NewParticipantCallback func(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, requestSource MessageSource, responseSink MessageSink) diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 6fff9a157..55a44dd73 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -2,9 +2,11 @@ package routing import ( "context" + "encoding/json" "time" "github.com/go-redis/redis/v8" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils" @@ -146,6 +148,12 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek sink := NewRTCNodeSink(r.rc, rtcNode.Id, pKey) + // serialize claims + claims, err := json.Marshal(pi.Grants) + if err != nil { + return + } + // sends a message to start session err = sink.WriteMessage(&livekit.StartSession{ RoomName: string(roomName), @@ -160,6 +168,7 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek Hidden: pi.Hidden, Recorder: pi.Recorder, Client: pi.Client, + GrantsJson: string(claims), }) if err != nil { return @@ -238,6 +247,11 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK } } + claims := &auth.ClaimGrants{} + if err := json.Unmarshal([]byte(ss.GrantsJson), claims); err != nil { + return err + } + pi := ParticipantInit{ Identity: livekit.ParticipantIdentity(ss.Identity), Metadata: ss.Metadata, @@ -248,6 +262,7 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK AutoSubscribe: ss.AutoSubscribe, Hidden: ss.Hidden, Recorder: ss.Recorder, + Grants: claims, } reqChan := r.getOrCreateMessageChannel(r.requestChannels, participantKey) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index a0b35015e..7b8a69734 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -10,6 +10,7 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/connectionquality" "github.com/livekit/livekit-server/pkg/sfu/twcc" + "github.com/livekit/protocol/auth" lru "github.com/hashicorp/golang-lru" "github.com/livekit/protocol/livekit" @@ -56,6 +57,7 @@ type ParticipantParams struct { Recorder bool Logger logger.Logger SimTracks map[uint32]SimulcastTrackInfo + Grants *auth.ClaimGrants } type ParticipantImpl struct { @@ -109,9 +111,10 @@ type ParticipantImpl struct { onMetadataUpdate func(types.LocalParticipant) onDataPacket func(types.LocalParticipant, *livekit.DataPacket) - migrateState atomic.Value // types.MigrateState - pendingOffer *webrtc.SessionDescription - onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) + migrateState atomic.Value // types.MigrateState + pendingOffer *webrtc.SessionDescription + onClose func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID) + onClaimsChanged func(participant types.LocalParticipant) } func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { @@ -236,15 +239,45 @@ func (p *ParticipantImpl) ConnectedAt() time.Time { // SetMetadata attaches metadata to the participant func (p *ParticipantImpl) SetMetadata(metadata string) { + p.lock.Lock() + changed := p.metadata != metadata p.metadata = metadata + p.params.Grants.Metadata = metadata + p.lock.Unlock() + + if !changed { + return + } if p.onMetadataUpdate != nil { p.onMetadataUpdate(p) } + if p.onClaimsChanged != nil { + p.onClaimsChanged(p) + } +} + +func (p *ParticipantImpl) ClaimGrants() *auth.ClaimGrants { + p.lock.RLock() + defer p.lock.RUnlock() + return p.params.Grants } func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermission) { + p.lock.Lock() p.permission = permission + + // update grants with this + if p.params.Grants != nil && p.params.Grants.Video != nil { + video := p.params.Grants.Video + video.SetCanSubscribe(permission.CanSubscribe) + video.SetCanPublish(permission.CanPublish) + video.SetCanPublishData(permission.CanPublishData) + } + p.lock.Unlock() + if p.onClaimsChanged != nil { + p.onClaimsChanged(p) + } } func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { @@ -301,6 +334,10 @@ func (p *ParticipantImpl) OnClose(callback func(types.LocalParticipant, map[live p.onClose = callback } +func (p *ParticipantImpl) OnClaimsChanged(callback func(types.LocalParticipant)) { + p.onClaimsChanged = callback +} + // HandleOffer an offer from remote participant, used when clients make the initial connection func (p *ParticipantImpl) HandleOffer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) { p.lock.Lock() @@ -650,6 +687,14 @@ func (p *ParticipantImpl) SendConnectionQualityUpdate(update *livekit.Connection }) } +func (p *ParticipantImpl) SendRefreshToken(token string) error { + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_RefreshToken{ + RefreshToken: token, + }, + }) +} + func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fromAdmin bool) { // when request is coming from admin, send message to current participant if fromAdmin { diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6b425c5cb..094ddc4fd 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -3,6 +3,7 @@ package types import ( "time" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/pion/webrtc/v3" @@ -74,15 +75,14 @@ type LocalParticipant interface { State() livekit.ParticipantInfo_State IsReady() bool - IsRecorder() bool - SubscriberAsPrimary() bool GetResponseSink() routing.MessageSink SetResponseSink(sink routing.MessageSink) // permissions + ClaimGrants() *auth.ClaimGrants SetPermission(permission *livekit.ParticipantPermission) CanPublish() bool CanSubscribe() bool @@ -119,6 +119,7 @@ type LocalParticipant interface { SendRoomUpdate(room *livekit.Room) error SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error SubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) + SendRefreshToken(token string) error // callbacks OnStateChange(func(p LocalParticipant, oldState livekit.ParticipantInfo_State)) @@ -129,6 +130,7 @@ type LocalParticipant interface { OnMetadataUpdate(callback func(LocalParticipant)) OnDataPacket(callback func(LocalParticipant, *livekit.DataPacket)) OnClose(_callback func(LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) + OnClaimsChanged(_callback func(LocalParticipant)) // session migration SetMigrateState(s MigrateState) diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index 0bd850818..db38a1164 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -7,6 +7,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" webrtc "github.com/pion/webrtc/v3" ) @@ -84,6 +85,16 @@ type FakeLocalParticipant struct { canSubscribeReturnsOnCall map[int]struct { result1 bool } + ClaimGrantsStub func() *auth.ClaimGrants + claimGrantsMutex sync.RWMutex + claimGrantsArgsForCall []struct { + } + claimGrantsReturns struct { + result1 *auth.ClaimGrants + } + claimGrantsReturnsOnCall map[int]struct { + result1 *auth.ClaimGrants + } CloseStub func(bool) error closeMutex sync.RWMutex closeArgsForCall []struct { @@ -297,6 +308,11 @@ type FakeLocalParticipant struct { negotiateMutex sync.RWMutex negotiateArgsForCall []struct { } + OnClaimsChangedStub func(func(types.LocalParticipant)) + onClaimsChangedMutex sync.RWMutex + onClaimsChangedArgsForCall []struct { + arg1 func(types.LocalParticipant) + } OnCloseStub func(func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) onCloseMutex sync.RWMutex onCloseArgsForCall []struct { @@ -396,6 +412,17 @@ type FakeLocalParticipant struct { sendParticipantUpdateReturnsOnCall map[int]struct { result1 error } + SendRefreshTokenStub func(string) error + sendRefreshTokenMutex sync.RWMutex + sendRefreshTokenArgsForCall []struct { + arg1 string + } + sendRefreshTokenReturns struct { + result1 error + } + sendRefreshTokenReturnsOnCall map[int]struct { + result1 error + } SendRoomUpdateStub func(*livekit.Room) error sendRoomUpdateMutex sync.RWMutex sendRoomUpdateArgsForCall []struct { @@ -957,6 +984,59 @@ func (fake *FakeLocalParticipant) CanSubscribeReturnsOnCall(i int, result1 bool) }{result1} } +func (fake *FakeLocalParticipant) ClaimGrants() *auth.ClaimGrants { + fake.claimGrantsMutex.Lock() + ret, specificReturn := fake.claimGrantsReturnsOnCall[len(fake.claimGrantsArgsForCall)] + fake.claimGrantsArgsForCall = append(fake.claimGrantsArgsForCall, struct { + }{}) + stub := fake.ClaimGrantsStub + fakeReturns := fake.claimGrantsReturns + fake.recordInvocation("ClaimGrants", []interface{}{}) + fake.claimGrantsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ClaimGrantsCallCount() int { + fake.claimGrantsMutex.RLock() + defer fake.claimGrantsMutex.RUnlock() + return len(fake.claimGrantsArgsForCall) +} + +func (fake *FakeLocalParticipant) ClaimGrantsCalls(stub func() *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = stub +} + +func (fake *FakeLocalParticipant) ClaimGrantsReturns(result1 *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = nil + fake.claimGrantsReturns = struct { + result1 *auth.ClaimGrants + }{result1} +} + +func (fake *FakeLocalParticipant) ClaimGrantsReturnsOnCall(i int, result1 *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = nil + if fake.claimGrantsReturnsOnCall == nil { + fake.claimGrantsReturnsOnCall = make(map[int]struct { + result1 *auth.ClaimGrants + }) + } + fake.claimGrantsReturnsOnCall[i] = struct { + result1 *auth.ClaimGrants + }{result1} +} + func (fake *FakeLocalParticipant) Close(arg1 bool) error { fake.closeMutex.Lock() ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] @@ -2087,6 +2167,38 @@ func (fake *FakeLocalParticipant) NegotiateCalls(stub func()) { fake.NegotiateStub = stub } +func (fake *FakeLocalParticipant) OnClaimsChanged(arg1 func(types.LocalParticipant)) { + fake.onClaimsChangedMutex.Lock() + fake.onClaimsChangedArgsForCall = append(fake.onClaimsChangedArgsForCall, struct { + arg1 func(types.LocalParticipant) + }{arg1}) + stub := fake.OnClaimsChangedStub + fake.recordInvocation("OnClaimsChanged", []interface{}{arg1}) + fake.onClaimsChangedMutex.Unlock() + if stub != nil { + fake.OnClaimsChangedStub(arg1) + } +} + +func (fake *FakeLocalParticipant) OnClaimsChangedCallCount() int { + fake.onClaimsChangedMutex.RLock() + defer fake.onClaimsChangedMutex.RUnlock() + return len(fake.onClaimsChangedArgsForCall) +} + +func (fake *FakeLocalParticipant) OnClaimsChangedCalls(stub func(func(types.LocalParticipant))) { + fake.onClaimsChangedMutex.Lock() + defer fake.onClaimsChangedMutex.Unlock() + fake.OnClaimsChangedStub = stub +} + +func (fake *FakeLocalParticipant) OnClaimsChangedArgsForCall(i int) func(types.LocalParticipant) { + fake.onClaimsChangedMutex.RLock() + defer fake.onClaimsChangedMutex.RUnlock() + argsForCall := fake.onClaimsChangedArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeLocalParticipant) OnClose(arg1 func(types.LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) { fake.onCloseMutex.Lock() fake.onCloseArgsForCall = append(fake.onCloseArgsForCall, struct { @@ -2660,6 +2772,67 @@ func (fake *FakeLocalParticipant) SendParticipantUpdateReturnsOnCall(i int, resu }{result1} } +func (fake *FakeLocalParticipant) SendRefreshToken(arg1 string) error { + fake.sendRefreshTokenMutex.Lock() + ret, specificReturn := fake.sendRefreshTokenReturnsOnCall[len(fake.sendRefreshTokenArgsForCall)] + fake.sendRefreshTokenArgsForCall = append(fake.sendRefreshTokenArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SendRefreshTokenStub + fakeReturns := fake.sendRefreshTokenReturns + fake.recordInvocation("SendRefreshToken", []interface{}{arg1}) + fake.sendRefreshTokenMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendRefreshTokenCallCount() int { + fake.sendRefreshTokenMutex.RLock() + defer fake.sendRefreshTokenMutex.RUnlock() + return len(fake.sendRefreshTokenArgsForCall) +} + +func (fake *FakeLocalParticipant) SendRefreshTokenCalls(stub func(string) error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = stub +} + +func (fake *FakeLocalParticipant) SendRefreshTokenArgsForCall(i int) string { + fake.sendRefreshTokenMutex.RLock() + defer fake.sendRefreshTokenMutex.RUnlock() + argsForCall := fake.sendRefreshTokenArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendRefreshTokenReturns(result1 error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = nil + fake.sendRefreshTokenReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRefreshTokenReturnsOnCall(i int, result1 error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = nil + if fake.sendRefreshTokenReturnsOnCall == nil { + fake.sendRefreshTokenReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendRefreshTokenReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeLocalParticipant) SendRoomUpdate(arg1 *livekit.Room) error { fake.sendRoomUpdateMutex.Lock() ret, specificReturn := fake.sendRoomUpdateReturnsOnCall[len(fake.sendRoomUpdateArgsForCall)] @@ -3625,6 +3798,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.canPublishDataMutex.RUnlock() fake.canSubscribeMutex.RLock() defer fake.canSubscribeMutex.RUnlock() + fake.claimGrantsMutex.RLock() + defer fake.claimGrantsMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() fake.connectedAtMutex.RLock() @@ -3667,6 +3842,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.migrateStateMutex.RUnlock() fake.negotiateMutex.RLock() defer fake.negotiateMutex.RUnlock() + fake.onClaimsChangedMutex.RLock() + defer fake.onClaimsChangedMutex.RUnlock() fake.onCloseMutex.RLock() defer fake.onCloseMutex.RUnlock() fake.onDataPacketMutex.RLock() @@ -3693,6 +3870,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.sendJoinResponseMutex.RUnlock() fake.sendParticipantUpdateMutex.RLock() defer fake.sendParticipantUpdateMutex.RUnlock() + fake.sendRefreshTokenMutex.RLock() + defer fake.sendRefreshTokenMutex.RUnlock() fake.sendRoomUpdateMutex.RLock() defer fake.sendRoomUpdateMutex.RUnlock() fake.sendSpeakerUpdateMutex.RLock() diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index cf543259e..2ac24fc34 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" "github.com/livekit/protocol/utils" @@ -18,7 +19,9 @@ import ( ) const ( - roomPurgeSeconds = 24 * 60 * 60 + roomPurgeSeconds = 24 * 60 * 60 + tokenRefreshInterval = 5 * time.Minute + tokenDefaultTTL = 10 * time.Minute ) // RoomManager manages rooms and its interaction with participants. @@ -241,18 +244,18 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName livekit.RoomNam ThrottleConfig: r.config.RTC.PLIThrottle, CongestionControlConfig: r.config.RTC.CongestionControl, EnabledCodecs: room.Room.EnabledCodecs, + Grants: pi.Grants, Hidden: pi.Hidden, Logger: pLogger, }) - if err != nil { logger.Errorw("could not create participant", err) return } + if pi.Metadata != "" { participant.SetMetadata(pi.Metadata) } - if pi.Permission != nil { participant.SetPermission(pi.Permission) } @@ -292,6 +295,12 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName livekit.RoomNam room.RemoveDisallowedSubscriptions(p, disallowedSubscriptions) }) + participant.OnClaimsChanged(func(participant types.LocalParticipant) { + pLogger.Debugw("refreshing client token after claims change") + if err := r.refreshToken(participant); err != nil { + logger.Errorw("could not refresh token", err) + } + }) go r.rtcSessionWorker(room, participant, requestSource) } @@ -361,6 +370,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.LocalPa participant.Identity(), participant.ID(), ) + lastTokenUpdate := time.Now() for { select { case <-time.After(time.Millisecond * 50): @@ -368,6 +378,15 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.LocalPa if participant.State() == livekit.ParticipantInfo_DISCONNECTED { return } + + if time.Now().Sub(lastTokenUpdate) > tokenRefreshInterval { + pLogger.Debugw("refreshing client token after interval") + // refresh token with the first API Key/secret pair + if err := r.refreshToken(participant); err != nil { + pLogger.Errorw("could not refresh token", err) + } + lastTokenUpdate = time.Now() + } case obj := <-requestSource.ReadChan(): // In single node mode, the request source is directly tied to the signal message channel // this means ICE restart isn't possible in single node mode @@ -503,6 +522,27 @@ func (r *RoomManager) iceServersForRoom(ri *livekit.Room) []*livekit.ICEServer { return iceServers } +func (r *RoomManager) refreshToken(participant types.LocalParticipant) error { + for key, secret := range r.config.Keys { + grants := participant.ClaimGrants() + token := auth.NewAccessToken(key, secret) + token.SetName(grants.Name). + SetIdentity(string(participant.Identity())). + SetValidFor(tokenDefaultTTL). + SetMetadata(grants.Metadata). + AddGrant(grants.Video) + jwt, err := token.ToJWT() + if err == nil { + err = participant.SendRefreshToken(jwt) + } + if err != nil { + return err + } + break + } + return nil +} + func iceServerForStunServers(servers []string) *livekit.ICEServer { iceServer := &livekit.ICEServer{} for _, stunServer := range servers { diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index b2a938d35..12d3e663a 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -117,6 +117,7 @@ func (s *RTCService) validate(r *http.Request) (livekit.RoomName, routing.Partic Hidden: claims.Video.Hidden, Recorder: claims.Video.Recorder, Client: ParseClientInfo(r.Form), + Grants: claims, } if autoSubParam != "" { diff --git a/pkg/service/wire.go b/pkg/service/wire.go index 0f0e20a37..1a3774b90 100644 --- a/pkg/service/wire.go +++ b/pkg/service/wire.go @@ -16,6 +16,7 @@ import ( "github.com/livekit/protocol/utils" "github.com/livekit/protocol/webhook" "github.com/pkg/errors" + "gopkg.in/yaml.v3" "github.com/livekit/livekit-server/pkg/config" "github.com/livekit/livekit-server/pkg/routing" @@ -71,7 +72,10 @@ func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) { defer func() { _ = f.Close() }() - return auth.NewFileBasedKeyProviderFromReader(f) + decoder := yaml.NewDecoder(f) + if err = decoder.Decode(conf.Keys); err != nil { + return nil, err + } } if len(conf.Keys) == 0 { diff --git a/pkg/service/wire_gen.go b/pkg/service/wire_gen.go index dfdf424b0..55cff3f46 100644 --- a/pkg/service/wire_gen.go +++ b/pkg/service/wire_gen.go @@ -18,6 +18,7 @@ import ( "github.com/livekit/protocol/utils" "github.com/livekit/protocol/webhook" "github.com/pkg/errors" + "gopkg.in/yaml.v3" "os" ) @@ -93,7 +94,10 @@ func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) { defer func() { _ = f.Close() }() - return auth.NewFileBasedKeyProviderFromReader(f) + decoder := yaml.NewDecoder(f) + if err = decoder.Decode(conf.Keys); err != nil { + return nil, err + } } if len(conf.Keys) == 0 { diff --git a/test/client/client.go b/test/client/client.go index cc287ff69..22954ab43 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -60,6 +60,7 @@ type RTCClient struct { pendingTrackWriters []*TrackWriter OnConnected func() OnDataReceived func(data []byte, sid string) + refreshToken string // map of livekit.ParticipantID and last packet lastPackets map[livekit.ParticipantID]*rtp.Packet @@ -324,6 +325,10 @@ func (c *RTCClient) Run() error { c.lock.Lock() c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track c.lock.Unlock() + case *livekit.SignalResponse_RefreshToken: + c.lock.Lock() + c.refreshToken = msg.RefreshToken + c.lock.Unlock() } } } @@ -412,6 +417,12 @@ func (c *RTCClient) Stop() { c.cancel() } +func (c *RTCClient) RefreshToken() string { + c.lock.Lock() + defer c.lock.Unlock() + return c.refreshToken +} + func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error { payload, err := proto.Marshal(msg) if err != nil { diff --git a/test/multinode_test.go b/test/multinode_test.go index e506ac4e0..862e3606a 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" "github.com/stretchr/testify/require" @@ -152,3 +153,42 @@ func TestMultiNodeJoinAfterClose(t *testing.T) { scenarioJoinClosedRoom(t) } + +// ensure that token accurately reflects out of band updates +func TestMultiNodeRefreshToken(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeJoinAfterClose") + defer finish() + + // a participant joining with full permissions + c1 := createRTCClient("c1", defaultServerPort, nil) + waitUntilConnected(t, c1) + + // update permissions and metadata + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanSubscribe: true, + }, + Metadata: "metadata", + }) + require.NoError(t, err) + + testutils.WithTimeout(t, "waiting for refresh token", func() bool { + return c1.RefreshToken() != "" + }) + + // parse token to ensure it's correct + verifier, err := auth.ParseAPIToken(c1.RefreshToken()) + require.NoError(t, err) + + grants, err := verifier.Verify(testApiSecret) + require.NoError(t, err) + + require.Equal(t, "metadata", grants.Metadata) + require.False(t, *grants.Video.CanPublish) + require.False(t, *grants.Video.CanPublishData) + require.True(t, *grants.Video.CanSubscribe) +}