diff --git a/go.mod b/go.mod index 7bd70560e..c65c66ee4 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,8 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.11.14 + // TODO: replace with merged protocol version + github.com/livekit/protocol v0.11.15-0.20220320074808-41056286643d github.com/mackerelio/go-osstat v0.2.1 github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 diff --git a/go.sum b/go.sum index 93db648f3..322eab64c 100644 --- a/go.sum +++ b/go.sum @@ -132,8 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA= github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.11.14 h1:KmFPWNMtrKMhwhdPZHMQ9Dj2DFH4XLzdvv1gTJlJJKM= -github.com/livekit/protocol v0.11.14/go.mod h1:3pHsWUtQmWaH8mG0cXrQWpbf3Vo+kj0U+In77CEXu90= +github.com/livekit/protocol v0.11.15-0.20220320074808-41056286643d h1:dzljbpcV2ZuglBXRd9PsEJZv6oByYRBYaAQsG9XGRuA= +github.com/livekit/protocol v0.11.15-0.20220320074808-41056286643d/go.mod h1:3pHsWUtQmWaH8mG0cXrQWpbf3Vo+kj0U+In77CEXu90= github.com/mackerelio/go-osstat v0.2.1 h1:5AeAcBEutEErAOlDz6WCkEvm6AKYgHTUQrfwm5RbeQc= github.com/mackerelio/go-osstat v0.2.1/go.mod h1:UzRL8dMCCTqG5WdRtsxbuljMpZt9PCAGXqxPst5QtaY= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 20c4bc4f5..6e9e3e37c 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -30,12 +30,8 @@ type MessageSource interface { type ParticipantInit struct { Identity livekit.ParticipantIdentity Name livekit.ParticipantName - Metadata string Reconnect bool - Permission *livekit.ParticipantPermission AutoSubscribe bool - Hidden bool - Recorder bool Client *livekit.ClientInfo Grants *auth.ClaimGrants } diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 916ee7f66..1e0b7d285 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -159,15 +159,11 @@ func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livek err = sink.WriteMessage(&livekit.StartSession{ RoomName: string(roomName), Identity: string(pi.Identity), - Metadata: pi.Metadata, Name: string(pi.Name), // connection id is to allow the RTC node to identify where to route the message back to ConnectionId: string(connectionID), Reconnect: pi.Reconnect, - Permission: pi.Permission, AutoSubscribe: pi.AutoSubscribe, - Hidden: pi.Hidden, - Recorder: pi.Recorder, Client: pi.Client, GrantsJson: string(claims), }) @@ -255,14 +251,10 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK pi := ParticipantInit{ Identity: livekit.ParticipantIdentity(ss.Identity), - Metadata: ss.Metadata, Name: livekit.ParticipantName(ss.Name), Reconnect: ss.Reconnect, - Permission: ss.Permission, Client: ss.Client, AutoSubscribe: ss.AutoSubscribe, - Hidden: ss.Hidden, - Recorder: ss.Recorder, Grants: claims, } diff --git a/pkg/rtc/errors.go b/pkg/rtc/errors.go index 320491270..bc04f5e69 100644 --- a/pkg/rtc/errors.go +++ b/pkg/rtc/errors.go @@ -11,4 +11,7 @@ var ( ErrUnexpectedOffer = errors.New("expected answer SDP, received offer") ErrDataChannelUnavailable = errors.New("data channel is not available") ErrCannotSubscribe = errors.New("participant does not have permission to subscribe") + ErrEmptyIdentity = errors.New("participant identity cannot be empty") + ErrEmptyParticipantID = errors.New("participant ID cannot be empty") + ErrMissingGrants = errors.New("VideoGrant is missing") ) diff --git a/pkg/rtc/helper_test.go b/pkg/rtc/helper_test.go index d1a07ba61..a8cf523a9 100644 --- a/pkg/rtc/helper_test.go +++ b/pkg/rtc/helper_test.go @@ -27,8 +27,8 @@ func newMockParticipant(identity livekit.ParticipantIdentity, protocol types.Pro p.SetMetadataStub = func(m string) { var f func(participant types.LocalParticipant) - if p.OnMetadataUpdateCallCount() > 0 { - f = p.OnMetadataUpdateArgsForCall(p.OnMetadataUpdateCallCount() - 1) + if p.OnParticipantUpdateCallCount() > 0 { + f = p.OnParticipantUpdateArgsForCall(p.OnParticipantUpdateCallCount() - 1) } if f != nil { f(p) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 46d1ebe0f..5ec0a5db3 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -2,7 +2,6 @@ package rtc import ( "context" - "fmt" "io" "strings" "sync" @@ -27,7 +26,6 @@ import ( "github.com/livekit/livekit-server/pkg/sfu/twcc" "github.com/livekit/livekit-server/pkg/telemetry" "github.com/livekit/livekit-server/pkg/telemetry/prometheus" - "github.com/livekit/livekit-server/version" ) const ( @@ -57,8 +55,6 @@ type ParticipantParams struct { PLIThrottleConfig config.PLIThrottleConfig CongestionControlConfig config.CongestionControlConfig EnabledCodecs []*livekit.Codec - Hidden bool - Recorder bool Logger logger.Logger SimTracks map[uint32]SimulcastTrackInfo Grants *auth.ClaimGrants @@ -71,12 +67,12 @@ type ParticipantImpl struct { publisher *PCTransport subscriber *PCTransport isClosed atomic.Bool - permission *livekit.ParticipantPermission state atomic.Value // livekit.ParticipantInfo_State updateCache *lru.Cache resSink atomic.Value // routing.MessageSink resSinkValid atomic.Bool subscriberAsPrimary bool + grants atomic.Value // *auth.ClaimGrants // reliable and unreliable data channels reliableDC *webrtc.DataChannel @@ -116,11 +112,11 @@ type ParticipantImpl struct { version atomic.Uint32 // callbacks & handlers - onTrackPublished func(types.LocalParticipant, types.MediaTrack) - onTrackUpdated func(types.LocalParticipant, types.MediaTrack) - onStateChange func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) - onMetadataUpdate func(types.LocalParticipant) - onDataPacket func(types.LocalParticipant, *livekit.DataPacket) + onTrackPublished func(types.LocalParticipant, types.MediaTrack) + onTrackUpdated func(types.LocalParticipant, types.MediaTrack) + onStateChange func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State) + onParticipantUpdate func(types.LocalParticipant) + onDataPacket func(types.LocalParticipant, *livekit.DataPacket) migrateState atomic.Value // types.MigrateState pendingOffer *webrtc.SessionDescription @@ -131,8 +127,16 @@ type ParticipantImpl struct { activeCounter atomic.Int32 } -func NewParticipant(params ParticipantParams, perms *livekit.ParticipantPermission) (*ParticipantImpl, error) { - // TODO: check to ensure params are valid, id and identity can't be empty +func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { + if params.Identity == "" { + return nil, ErrEmptyIdentity + } + if params.SID == "" { + return nil, ErrEmptyParticipantID + } + if params.Grants == nil || params.Grants.Video == nil { + return nil, ErrMissingGrants + } p := &ParticipantImpl{ params: params, rtcpCh: make(chan []rtcp.Packet, 50), @@ -146,7 +150,7 @@ func NewParticipant(params ParticipantParams, perms *livekit.ParticipantPermissi p.version.Store(params.InitialVersion) p.migrateState.Store(types.MigrateStateInit) p.state.Store(livekit.ParticipantInfo_JOINING) - p.SetPermission(perms) + p.grants.Store(params.Grants) p.SetResponseSink(params.Sink) var err error @@ -272,17 +276,17 @@ func (p *ParticipantImpl) ConnectedAt() time.Time { // SetMetadata attaches metadata to the participant func (p *ParticipantImpl) SetMetadata(metadata string) { - p.lock.Lock() - changed := p.params.Grants.Metadata != metadata - p.params.Grants.Metadata = metadata - p.lock.Unlock() + grants := p.ClaimGrants() + changed := grants.Metadata != metadata + grants.Metadata = metadata + p.grants.Store(grants) if !changed { return } - if p.onMetadataUpdate != nil { - p.onMetadataUpdate(p) + if p.onParticipantUpdate != nil { + p.onParticipantUpdate(p) } if p.onClaimsChanged != nil { p.onClaimsChanged(p) @@ -290,66 +294,73 @@ func (p *ParticipantImpl) SetMetadata(metadata string) { } func (p *ParticipantImpl) ClaimGrants() *auth.ClaimGrants { - p.lock.RLock() - defer p.lock.RUnlock() - return p.params.Grants + return p.grants.Load().(*auth.ClaimGrants) } -func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermission) { - p.lock.Lock() - p.permission = permission - - // update grants with this - if p.params.Grants != nil && p.params.Grants.Video != nil && permission != nil { - video := p.params.Grants.Video - video.SetCanSubscribe(permission.CanSubscribe) - video.SetCanPublish(permission.CanPublish) - video.SetCanPublishData(permission.CanPublishData) +func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermission) bool { + if permission == nil { + return false + } + grants := p.ClaimGrants() + video := grants.Video + hasChanged := video.GetCanSubscribe() != permission.CanSubscribe || + video.GetCanPublish() != permission.CanPublish || + video.GetCanPublishData() != permission.CanPublishData || + video.Hidden != permission.Hidden || + video.Recorder != permission.Recorder + + if !hasChanged { + return false + } + + video.SetCanSubscribe(permission.CanSubscribe) + video.SetCanPublish(permission.CanPublish) + video.SetCanPublishData(permission.CanPublishData) + video.Hidden = permission.Hidden + video.Recorder = permission.Recorder + p.grants.Store(grants) + + // publish permission has been revoked then remove all published tracks + if !video.GetCanPublish() { + for _, track := range p.GetPublishedTracks() { + p.RemovePublishedTrack(track) + if p.ProtocolVersion().SupportsUnpublish() { + p.sendTrackUnpublished(track.ID()) + } else { + // for older clients that don't support unpublish, mute to avoid them sending data + p.sendTrackMuted(track.ID(), true) + } + } + } + + if p.onParticipantUpdate != nil { + p.onParticipantUpdate(p) } - p.lock.Unlock() if p.onClaimsChanged != nil { p.onClaimsChanged(p) } + return true } func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { + grants := p.ClaimGrants() info := &livekit.ParticipantInfo{ - Sid: string(p.params.SID), - Identity: string(p.params.Identity), - Name: string(p.params.Name), - State: p.State(), - JoinedAt: p.ConnectedAt().Unix(), - Hidden: p.Hidden(), - Recorder: p.IsRecorder(), - Version: p.version.Inc(), + Sid: string(p.params.SID), + Identity: string(p.params.Identity), + Name: string(p.params.Name), + State: p.State(), + JoinedAt: p.ConnectedAt().Unix(), + Version: p.version.Inc(), + Permission: grants.Video.ToPermission(), } info.Tracks = p.UpTrackManager.ToProto() if p.params.Grants != nil { - info.Metadata = p.params.Grants.Metadata + info.Metadata = grants.Metadata } return info } -func (p *ParticipantImpl) GetResponseSink() routing.MessageSink { - if !p.resSinkValid.Load() { - return nil - } - sink := p.resSink.Load() - if s, ok := sink.(routing.MessageSink); ok { - return s - } - return nil -} - -func (p *ParticipantImpl) SetResponseSink(sink routing.MessageSink) { - p.resSinkValid.Store(sink != nil) - if sink != nil { - // cannot store nil into atomic.Value - p.resSink.Store(sink) - } -} - func (p *ParticipantImpl) SubscriberMediaEngine() *webrtc.MediaEngine { return p.subscriber.me } @@ -368,8 +379,8 @@ func (p *ParticipantImpl) OnTrackUpdated(callback func(types.LocalParticipant, t p.onTrackUpdated = callback } -func (p *ParticipantImpl) OnMetadataUpdate(callback func(types.LocalParticipant)) { - p.onMetadataUpdate = callback +func (p *ParticipantImpl) OnParticipantUpdate(callback func(types.LocalParticipant)) { + p.onParticipantUpdate = callback } func (p *ParticipantImpl) OnDataPacket(callback func(types.LocalParticipant, *livekit.DataPacket)) { @@ -617,164 +628,6 @@ func (p *ParticipantImpl) ICERestart() error { // // signal connection methods // -func (p *ParticipantImpl) SendJoinResponse( - roomInfo *livekit.Room, - otherParticipants []*livekit.ParticipantInfo, - iceServers []*livekit.ICEServer, - region string, -) error { - // send Join response - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Join{ - Join: &livekit.JoinResponse{ - Room: roomInfo, - Participant: p.ToProto(), - OtherParticipants: otherParticipants, - ServerVersion: version.Version, - ServerRegion: region, - IceServers: iceServers, - // indicates both server and client support subscriber as primary - SubscriberPrimary: p.SubscriberAsPrimary(), - ClientConfiguration: p.params.ClientConf, - }, - }, - }) -} - -func (p *ParticipantImpl) SendParticipantUpdate(participantsToUpdate []*livekit.ParticipantInfo) error { - p.updateLock.Lock() - validUpdates := make([]*livekit.ParticipantInfo, 0, len(participantsToUpdate)) - for _, pi := range participantsToUpdate { - isValid := true - if val, ok := p.updateCache.Get(pi.Sid); ok { - if lastVersion, ok := val.(uint32); ok { - // this is a message delivered out of order, a more recent version of the message had already been - // sent. - if pi.Version < lastVersion { - p.params.Logger.Debugw("skipping outdated participant update", "version", pi.Version, "lastVersion", lastVersion) - isValid = false - } - } - } - if isValid { - p.updateCache.Add(pi.Sid, pi.Version) - validUpdates = append(validUpdates, pi) - } - } - p.updateLock.Unlock() - - if len(validUpdates) == 0 { - return nil - } - - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Update{ - Update: &livekit.ParticipantUpdate{ - Participants: validUpdates, - }, - }, - }) -} - -// SendSpeakerUpdate notifies participant changes to speakers. only send members that have changed since last update -func (p *ParticipantImpl) SendSpeakerUpdate(speakers []*livekit.SpeakerInfo) error { - if !p.IsReady() { - return nil - } - - var scopedSpeakers []*livekit.SpeakerInfo - for _, s := range speakers { - participantID := livekit.ParticipantID(s.Sid) - if p.isSubscribedTo(participantID) || participantID == p.ID() { - scopedSpeakers = append(scopedSpeakers, s) - } - } - - if len(scopedSpeakers) == 0 { - return nil - } - - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_SpeakersChanged{ - SpeakersChanged: &livekit.SpeakersChanged{ - Speakers: scopedSpeakers, - }, - }, - }) -} - -func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket) error { - if p.State() != livekit.ParticipantInfo_ACTIVE { - return ErrDataChannelUnavailable - } - - data, err := proto.Marshal(dp) - if err != nil { - return err - } - - var dc *webrtc.DataChannel - if dp.Kind == livekit.DataPacket_RELIABLE { - if p.SubscriberAsPrimary() { - dc = p.reliableDCSub - } else { - dc = p.reliableDC - } - } else { - if p.SubscriberAsPrimary() { - dc = p.lossyDCSub - } else { - dc = p.lossyDC - } - } - - if dc == nil { - return ErrDataChannelUnavailable - } - return dc.Send(data) -} - -func (p *ParticipantImpl) SendRoomUpdate(room *livekit.Room) error { - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_RoomUpdate{ - RoomUpdate: &livekit.RoomUpdate{ - Room: room, - }, - }, - }) -} - -func (p *ParticipantImpl) SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error { - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_ConnectionQuality{ - ConnectionQuality: update, - }, - }) -} - -func (p *ParticipantImpl) SendRefreshToken(token string) error { - return p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_RefreshToken{ - RefreshToken: token, - }, - }) -} - -func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fromAdmin bool) { - // when request is coming from admin, send message to current participant - if fromAdmin { - _ = p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Mute{ - Mute: &livekit.MuteTrackRequest{ - Sid: string(trackID), - Muted: muted, - }, - }, - }) - } - - p.setTrackMuted(trackID, muted) -} func (p *ParticipantImpl) GetAudioLevel() (level uint8, active bool) { level = SilentAudioLevel @@ -830,23 +683,23 @@ func (p *ParticipantImpl) GetSubscribedParticipants() []livekit.ParticipantID { } func (p *ParticipantImpl) CanPublish() bool { - return p.permission == nil || p.permission.CanPublish + return p.ClaimGrants().Video.GetCanPublish() } func (p *ParticipantImpl) CanSubscribe() bool { - return p.permission == nil || p.permission.CanSubscribe + return p.ClaimGrants().Video.GetCanSubscribe() } func (p *ParticipantImpl) CanPublishData() bool { - return p.permission == nil || p.permission.CanPublishData + return p.ClaimGrants().Video.GetCanPublishData() } func (p *ParticipantImpl) Hidden() bool { - return p.params.Hidden + return p.ClaimGrants().Video.Hidden } func (p *ParticipantImpl) IsRecorder() bool { - return p.params.Recorder + return p.ClaimGrants().Video.Recorder } func (p *ParticipantImpl) SubscriberAsPrimary() bool { @@ -988,15 +841,6 @@ func (p *ParticipantImpl) UpdateRTT(rtt uint32) { } } -// closes signal connection to notify client to resume/reconnect -func (p *ParticipantImpl) closeSignalConnection() { - sink := p.GetResponseSink() - if sink != nil { - sink.Close() - p.SetResponseSink(nil) - } -} - func (p *ParticipantImpl) setupUpTrackManager() { p.UpTrackManager = NewUpTrackManager(UpTrackManagerParams{ SID: p.params.SID, @@ -1016,21 +860,6 @@ func (p *ParticipantImpl) setupUpTrackManager() { p.UpTrackManager.OnUpTrackManagerClose(p.onUpTrackManagerClose) } -func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { - ci := c.ToJSON() - - // write candidate - p.params.Logger.Debugw("sending ice candidates", - "candidate", c.String(), "target", target) - trickle := ToProtoTrickle(ci) - trickle.Target = target - _ = p.writeMessage(&livekit.SignalResponse{ - Message: &livekit.SignalResponse_Trickle{ - Trickle: trickle, - }, - }) -} - func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { oldState := p.State() if state == oldState { @@ -1049,23 +878,6 @@ func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { } } -func (p *ParticipantImpl) writeMessage(msg *livekit.SignalResponse) error { - if p.State() == livekit.ParticipantInfo_DISCONNECTED { - return nil - } - sink := p.GetResponseSink() - if sink == nil { - return nil - } - err := sink.WriteMessage(msg) - if err != nil { - p.params.Logger.Warnw("could not send message to participant", err, - "message", fmt.Sprintf("%T", msg.Message)) - return err - } - return nil -} - // when the server has an offer for participant func (p *ParticipantImpl) onOffer(offer webrtc.SessionDescription) { if p.State() == livekit.ParticipantInfo_DISCONNECTED { @@ -1405,6 +1217,15 @@ func (p *ParticipantImpl) addPendingTrack(req *livekit.AddTrackRequest) *livekit return ti } +func (p *ParticipantImpl) SetTrackMuted(trackID livekit.TrackID, muted bool, fromAdmin bool) { + // when request is coming from admin, send message to current participant + if fromAdmin { + p.sendTrackMuted(trackID, muted) + } + + p.setTrackMuted(trackID, muted) +} + func (p *ParticipantImpl) setTrackMuted(trackID livekit.TrackID, muted bool) { track := p.UpTrackManager.SetPublishedTrackMuted(trackID, muted) if track != nil { diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index be6211087..0837baa87 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -6,6 +6,7 @@ import ( "github.com/livekit/protocol/auth" "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" "github.com/stretchr/testify/require" "github.com/livekit/livekit-server/pkg/config" @@ -344,16 +345,23 @@ func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *p if err != nil { panic(err) } + grants := &auth.ClaimGrants{ + Video: &auth.VideoGrant{}, + } + if opts.permissions != nil { + grants.Video.SetCanPublish(opts.permissions.CanPublish) + grants.Video.SetCanPublishData(opts.permissions.CanPublishData) + grants.Video.SetCanSubscribe(opts.permissions.CanSubscribe) + } p, _ := NewParticipant(ParticipantParams{ + SID: livekit.ParticipantID(utils.NewGuid(utils.ParticipantPrefix)), Identity: identity, Config: rtcConf, Sink: &routingfakes.FakeMessageSink{}, ProtocolVersion: opts.protocolVersion, PLIThrottleConfig: conf.RTC.PLIThrottle, - Grants: &auth.ClaimGrants{ - Video: &auth.VideoGrant{}, - }, - }, opts.permissions) + Grants: grants, + }) return p } diff --git a/pkg/rtc/participant_signal.go b/pkg/rtc/participant_signal.go new file mode 100644 index 000000000..5197758c3 --- /dev/null +++ b/pkg/rtc/participant_signal.go @@ -0,0 +1,235 @@ +package rtc + +import ( + "fmt" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/version" + "github.com/livekit/protocol/livekit" + "github.com/pion/webrtc/v3" + "google.golang.org/protobuf/proto" +) + +func (p *ParticipantImpl) GetResponseSink() routing.MessageSink { + if !p.resSinkValid.Load() { + return nil + } + sink := p.resSink.Load() + if s, ok := sink.(routing.MessageSink); ok { + return s + } + return nil +} + +func (p *ParticipantImpl) SetResponseSink(sink routing.MessageSink) { + p.resSinkValid.Store(sink != nil) + if sink != nil { + // cannot store nil into atomic.Value + p.resSink.Store(sink) + } +} + +func (p *ParticipantImpl) SendJoinResponse( + roomInfo *livekit.Room, + otherParticipants []*livekit.ParticipantInfo, + iceServers []*livekit.ICEServer, + region string, +) error { + // send Join response + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Join{ + Join: &livekit.JoinResponse{ + Room: roomInfo, + Participant: p.ToProto(), + OtherParticipants: otherParticipants, + ServerVersion: version.Version, + ServerRegion: region, + IceServers: iceServers, + // indicates both server and client support subscriber as primary + SubscriberPrimary: p.SubscriberAsPrimary(), + ClientConfiguration: p.params.ClientConf, + }, + }, + }) +} + +func (p *ParticipantImpl) SendParticipantUpdate(participantsToUpdate []*livekit.ParticipantInfo) error { + p.updateLock.Lock() + validUpdates := make([]*livekit.ParticipantInfo, 0, len(participantsToUpdate)) + for _, pi := range participantsToUpdate { + isValid := true + if val, ok := p.updateCache.Get(pi.Sid); ok { + if lastVersion, ok := val.(uint32); ok { + // this is a message delivered out of order, a more recent version of the message had already been + // sent. + if pi.Version < lastVersion { + p.params.Logger.Debugw("skipping outdated participant update", "version", pi.Version, "lastVersion", lastVersion) + isValid = false + } + } + } + if isValid { + p.updateCache.Add(pi.Sid, pi.Version) + validUpdates = append(validUpdates, pi) + } + } + p.updateLock.Unlock() + + if len(validUpdates) == 0 { + return nil + } + + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Update{ + Update: &livekit.ParticipantUpdate{ + Participants: validUpdates, + }, + }, + }) +} + +// SendSpeakerUpdate notifies participant changes to speakers. only send members that have changed since last update +func (p *ParticipantImpl) SendSpeakerUpdate(speakers []*livekit.SpeakerInfo) error { + if !p.IsReady() { + return nil + } + + var scopedSpeakers []*livekit.SpeakerInfo + for _, s := range speakers { + participantID := livekit.ParticipantID(s.Sid) + if p.isSubscribedTo(participantID) || participantID == p.ID() { + scopedSpeakers = append(scopedSpeakers, s) + } + } + + if len(scopedSpeakers) == 0 { + return nil + } + + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_SpeakersChanged{ + SpeakersChanged: &livekit.SpeakersChanged{ + Speakers: scopedSpeakers, + }, + }, + }) +} + +func (p *ParticipantImpl) SendDataPacket(dp *livekit.DataPacket) error { + if p.State() != livekit.ParticipantInfo_ACTIVE { + return ErrDataChannelUnavailable + } + + data, err := proto.Marshal(dp) + if err != nil { + return err + } + + var dc *webrtc.DataChannel + if dp.Kind == livekit.DataPacket_RELIABLE { + if p.SubscriberAsPrimary() { + dc = p.reliableDCSub + } else { + dc = p.reliableDC + } + } else { + if p.SubscriberAsPrimary() { + dc = p.lossyDCSub + } else { + dc = p.lossyDC + } + } + + if dc == nil { + return ErrDataChannelUnavailable + } + return dc.Send(data) +} + +func (p *ParticipantImpl) SendRoomUpdate(room *livekit.Room) error { + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_RoomUpdate{ + RoomUpdate: &livekit.RoomUpdate{ + Room: room, + }, + }, + }) +} + +func (p *ParticipantImpl) SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error { + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_ConnectionQuality{ + ConnectionQuality: update, + }, + }) +} + +func (p *ParticipantImpl) SendRefreshToken(token string) error { + return p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_RefreshToken{ + RefreshToken: token, + }, + }) +} + +func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { + ci := c.ToJSON() + + // write candidate + p.params.Logger.Debugw("sending ice candidates", + "candidate", c.String(), "target", target) + trickle := ToProtoTrickle(ci) + trickle.Target = target + _ = p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Trickle{ + Trickle: trickle, + }, + }) +} + +func (p *ParticipantImpl) sendTrackMuted(trackID livekit.TrackID, muted bool) { + _ = p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Mute{ + Mute: &livekit.MuteTrackRequest{ + Sid: string(trackID), + Muted: muted, + }, + }, + }) +} + +func (p *ParticipantImpl) sendTrackUnpublished(trackID livekit.TrackID) { + _ = p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackUnpublished{ + TrackUnpublished: &livekit.TrackUnpublishedResponse{ + TrackSid: string(trackID), + }, + }, + }) +} + +func (p *ParticipantImpl) writeMessage(msg *livekit.SignalResponse) error { + if p.State() == livekit.ParticipantInfo_DISCONNECTED { + return nil + } + sink := p.GetResponseSink() + if sink == nil { + return nil + } + err := sink.WriteMessage(msg) + if err != nil { + p.params.Logger.Warnw("could not send message to participant", err, + "message", fmt.Sprintf("%T", msg.Message)) + return err + } + return nil +} + +// closes signal connection to notify client to resume/reconnect +func (p *ParticipantImpl) closeSignalConnection() { + sink := p.GetResponseSink() + if sink != nil { + sink.Close() + p.SetResponseSink(nil) + } +} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index bfdd876a1..4c2b19a68 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -222,7 +222,7 @@ func (r *Room) Join(participant types.LocalParticipant, opts *ParticipantOptions } }) participant.OnTrackUpdated(r.onTrackUpdated) - participant.OnMetadataUpdate(r.onParticipantMetadataUpdate) + participant.OnParticipantUpdate(r.onParticipantUpdate) participant.OnDataPacket(r.onDataPacket) r.Logger.Infow("new participant joined", "pID", participant.ID(), @@ -329,7 +329,7 @@ func (r *Room) RemoveParticipant(identity livekit.ParticipantIdentity) { p.OnTrackUpdated(nil) p.OnTrackPublished(nil) p.OnStateChange(nil) - p.OnMetadataUpdate(nil) + p.OnParticipantUpdate(nil) p.OnDataPacket(nil) // close participant as well @@ -423,6 +423,7 @@ func (r *Room) SetParticipantPermission(participant types.LocalParticipant, perm } } } + return nil } @@ -632,7 +633,7 @@ func (r *Room) onTrackUpdated(p types.LocalParticipant, _ types.MediaTrack) { } } -func (r *Room) onParticipantMetadataUpdate(p types.LocalParticipant) { +func (r *Room) onParticipantUpdate(p types.LocalParticipant) { r.broadcastParticipantState(p, false) if r.onParticipantChanged != nil { r.onParticipantChanged(p) diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 3ab60b931..3c7d6b799 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -272,22 +272,19 @@ func TestNewTrack(t *testing.T) { func TestActiveSpeakers(t *testing.T) { t.Parallel() - getActiveSpeakerUpdates := func(p *typesfakes.FakeLocalParticipant) []*livekit.ActiveSpeakerUpdate { - var updates []*livekit.ActiveSpeakerUpdate - numCalls := p.SendDataPacketCallCount() + getActiveSpeakerUpdates := func(p *typesfakes.FakeLocalParticipant) [][]*livekit.SpeakerInfo { + var updates [][]*livekit.SpeakerInfo + numCalls := p.SendSpeakerUpdateCallCount() for i := 0; i < numCalls; i++ { - dp := p.SendDataPacketArgsForCall(i) - switch val := dp.Value.(type) { - case *livekit.DataPacket_Speaker: - updates = append(updates, val.Speaker) - } + infos := p.SendSpeakerUpdateArgsForCall(i) + updates = append(updates, infos) } return updates } audioUpdateDuration := (audioUpdateInterval + 10) * time.Millisecond t.Run("participant should not be getting audio updates (protocol 2)", func(t *testing.T) { - rm := newRoomWithParticipants(t, testRoomOpts{num: 1, protocol: types.DefaultProtocol}) + rm := newRoomWithParticipants(t, testRoomOpts{num: 1, protocol: 2}) defer rm.Close() p := rm.GetParticipants()[0].(*typesfakes.FakeLocalParticipant) require.Empty(t, rm.GetActiveSpeakers()) @@ -298,7 +295,7 @@ func TestActiveSpeakers(t *testing.T) { require.Empty(t, updates) }) - t.Run("speakers should be sorted by loudness (protocol 0)", func(t *testing.T) { + t.Run("speakers should be sorted by loudness", func(t *testing.T) { rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) defer rm.Close() participants := rm.GetParticipants() @@ -313,8 +310,8 @@ func TestActiveSpeakers(t *testing.T) { require.Equal(t, string(p2.ID()), speakers[1].Sid) }) - t.Run("participants are getting audio updates (protocol 2)", func(t *testing.T) { - rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: types.DefaultProtocol}) + t.Run("participants are getting audio updates (protocol 3+)", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3}) defer rm.Close() participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) @@ -342,15 +339,18 @@ func TestActiveSpeakers(t *testing.T) { testutils.WithTimeout(t, func() string { updates := getActiveSpeakerUpdates(p) lastUpdate := updates[len(updates)-1] - if len(lastUpdate.Speakers) != 0 { - return fmt.Sprintf("expected no speakers, but found %d", len(lastUpdate.Speakers)) + if len(lastUpdate) == 0 { + return "did not get updates of speaker going quiet" + } + if lastUpdate[0].Active { + return "speaker should not have been active" } return "" }) }) t.Run("audio level is smoothed", func(t *testing.T) { - rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: types.DefaultProtocol, audioSmoothIntervals: 3}) + rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3, audioSmoothIntervals: 3}) defer rm.Close() participants := rm.GetParticipants() p := participants[0].(*typesfakes.FakeLocalParticipant) @@ -363,7 +363,7 @@ func TestActiveSpeakers(t *testing.T) { if len(updates) == 0 { return "no speaker updates received" } - lastSpeakers := updates[len(updates)-1].Speakers + lastSpeakers := updates[len(updates)-1] if len(lastSpeakers) == 0 { return "no speakers in the update" } @@ -378,7 +378,7 @@ func TestActiveSpeakers(t *testing.T) { if len(updates) == 0 { return "no updates received" } - lastSpeakers := updates[len(updates)-1].Speakers + lastSpeakers := updates[len(updates)-1] if len(lastSpeakers) == 0 { return "no speakers found" } @@ -389,14 +389,13 @@ func TestActiveSpeakers(t *testing.T) { }) p.GetAudioLevelReturns(127, false) - testutils.WithTimeout(t, func() string { updates := getActiveSpeakerUpdates(op) if len(updates) == 0 { return "no speaker updates received" } - lastSpeakers := updates[len(updates)-1].Speakers - if len(lastSpeakers) == 0 { + lastSpeakers := updates[len(updates)-1] + if len(lastSpeakers) == 1 && !lastSpeakers[0].Active { return "" } return "speakers didn't go back to zero" diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 7a635a0e9..8c7dd6a70 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -100,7 +100,7 @@ type LocalParticipant interface { // permissions ClaimGrants() *auth.ClaimGrants - SetPermission(permission *livekit.ParticipantPermission) + SetPermission(permission *livekit.ParticipantPermission) bool CanPublish() bool CanSubscribe() bool CanPublishData() bool @@ -144,7 +144,8 @@ type LocalParticipant interface { OnTrackPublished(func(LocalParticipant, MediaTrack)) // OnTrackUpdated - one of its publishedTracks changed in status OnTrackUpdated(callback func(LocalParticipant, MediaTrack)) - OnMetadataUpdate(callback func(LocalParticipant)) + // OnParticipantUpdate - metadata or permission is updated + OnParticipantUpdate(callback func(LocalParticipant)) OnDataPacket(callback func(LocalParticipant, *livekit.DataPacket)) OnClose(_callback func(LocalParticipant, map[livekit.TrackID]livekit.ParticipantID)) OnClaimsChanged(_callback func(LocalParticipant)) diff --git a/pkg/rtc/types/protocol_version.go b/pkg/rtc/types/protocol_version.go index ca9ea6d80..accfb437c 100644 --- a/pkg/rtc/types/protocol_version.go +++ b/pkg/rtc/types/protocol_version.go @@ -2,7 +2,7 @@ package types type ProtocolVersion int -const DefaultProtocol = 2 +const DefaultProtocol = 6 func (v ProtocolVersion) SupportsPackedStreamId() bool { return v > 0 @@ -43,3 +43,7 @@ func (v ProtocolVersion) SupportsSessionMigrate() bool { func (v ProtocolVersion) SupportsICELite() bool { return v > 5 } + +func (v ProtocolVersion) SupportsUnpublish() bool { + return v > 6 +} diff --git a/pkg/rtc/types/typesfakes/fake_local_participant.go b/pkg/rtc/types/typesfakes/fake_local_participant.go index c8782c0a1..2e19ecd6a 100644 --- a/pkg/rtc/types/typesfakes/fake_local_participant.go +++ b/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -317,9 +317,9 @@ type FakeLocalParticipant struct { onDataPacketArgsForCall []struct { arg1 func(types.LocalParticipant, *livekit.DataPacket) } - OnMetadataUpdateStub func(func(types.LocalParticipant)) - onMetadataUpdateMutex sync.RWMutex - onMetadataUpdateArgsForCall []struct { + OnParticipantUpdateStub func(func(types.LocalParticipant)) + onParticipantUpdateMutex sync.RWMutex + onParticipantUpdateArgsForCall []struct { arg1 func(types.LocalParticipant) } OnStateChangeStub func(func(p types.LocalParticipant, oldState livekit.ParticipantInfo_State)) @@ -455,11 +455,17 @@ type FakeLocalParticipant struct { setMigrateStateArgsForCall []struct { arg1 types.MigrateState } - SetPermissionStub func(*livekit.ParticipantPermission) + SetPermissionStub func(*livekit.ParticipantPermission) bool setPermissionMutex sync.RWMutex setPermissionArgsForCall []struct { arg1 *livekit.ParticipantPermission } + setPermissionReturns struct { + result1 bool + } + setPermissionReturnsOnCall map[int]struct { + result1 bool + } SetPreviousAnswerStub func(*webrtc.SessionDescription) setPreviousAnswerMutex sync.RWMutex setPreviousAnswerArgsForCall []struct { @@ -2239,35 +2245,35 @@ func (fake *FakeLocalParticipant) OnDataPacketArgsForCall(i int) func(types.Loca return argsForCall.arg1 } -func (fake *FakeLocalParticipant) OnMetadataUpdate(arg1 func(types.LocalParticipant)) { - fake.onMetadataUpdateMutex.Lock() - fake.onMetadataUpdateArgsForCall = append(fake.onMetadataUpdateArgsForCall, struct { +func (fake *FakeLocalParticipant) OnParticipantUpdate(arg1 func(types.LocalParticipant)) { + fake.onParticipantUpdateMutex.Lock() + fake.onParticipantUpdateArgsForCall = append(fake.onParticipantUpdateArgsForCall, struct { arg1 func(types.LocalParticipant) }{arg1}) - stub := fake.OnMetadataUpdateStub - fake.recordInvocation("OnMetadataUpdate", []interface{}{arg1}) - fake.onMetadataUpdateMutex.Unlock() + stub := fake.OnParticipantUpdateStub + fake.recordInvocation("OnParticipantUpdate", []interface{}{arg1}) + fake.onParticipantUpdateMutex.Unlock() if stub != nil { - fake.OnMetadataUpdateStub(arg1) + fake.OnParticipantUpdateStub(arg1) } } -func (fake *FakeLocalParticipant) OnMetadataUpdateCallCount() int { - fake.onMetadataUpdateMutex.RLock() - defer fake.onMetadataUpdateMutex.RUnlock() - return len(fake.onMetadataUpdateArgsForCall) +func (fake *FakeLocalParticipant) OnParticipantUpdateCallCount() int { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + return len(fake.onParticipantUpdateArgsForCall) } -func (fake *FakeLocalParticipant) OnMetadataUpdateCalls(stub func(func(types.LocalParticipant))) { - fake.onMetadataUpdateMutex.Lock() - defer fake.onMetadataUpdateMutex.Unlock() - fake.OnMetadataUpdateStub = stub +func (fake *FakeLocalParticipant) OnParticipantUpdateCalls(stub func(func(types.LocalParticipant))) { + fake.onParticipantUpdateMutex.Lock() + defer fake.onParticipantUpdateMutex.Unlock() + fake.OnParticipantUpdateStub = stub } -func (fake *FakeLocalParticipant) OnMetadataUpdateArgsForCall(i int) func(types.LocalParticipant) { - fake.onMetadataUpdateMutex.RLock() - defer fake.onMetadataUpdateMutex.RUnlock() - argsForCall := fake.onMetadataUpdateArgsForCall[i] +func (fake *FakeLocalParticipant) OnParticipantUpdateArgsForCall(i int) func(types.LocalParticipant) { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + argsForCall := fake.onParticipantUpdateArgsForCall[i] return argsForCall.arg1 } @@ -3043,17 +3049,23 @@ func (fake *FakeLocalParticipant) SetMigrateStateArgsForCall(i int) types.Migrat return argsForCall.arg1 } -func (fake *FakeLocalParticipant) SetPermission(arg1 *livekit.ParticipantPermission) { +func (fake *FakeLocalParticipant) SetPermission(arg1 *livekit.ParticipantPermission) bool { fake.setPermissionMutex.Lock() + ret, specificReturn := fake.setPermissionReturnsOnCall[len(fake.setPermissionArgsForCall)] fake.setPermissionArgsForCall = append(fake.setPermissionArgsForCall, struct { arg1 *livekit.ParticipantPermission }{arg1}) stub := fake.SetPermissionStub + fakeReturns := fake.setPermissionReturns fake.recordInvocation("SetPermission", []interface{}{arg1}) fake.setPermissionMutex.Unlock() if stub != nil { - fake.SetPermissionStub(arg1) + return stub(arg1) } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 } func (fake *FakeLocalParticipant) SetPermissionCallCount() int { @@ -3062,7 +3074,7 @@ func (fake *FakeLocalParticipant) SetPermissionCallCount() int { return len(fake.setPermissionArgsForCall) } -func (fake *FakeLocalParticipant) SetPermissionCalls(stub func(*livekit.ParticipantPermission)) { +func (fake *FakeLocalParticipant) SetPermissionCalls(stub func(*livekit.ParticipantPermission) bool) { fake.setPermissionMutex.Lock() defer fake.setPermissionMutex.Unlock() fake.SetPermissionStub = stub @@ -3075,6 +3087,29 @@ func (fake *FakeLocalParticipant) SetPermissionArgsForCall(i int) *livekit.Parti return argsForCall.arg1 } +func (fake *FakeLocalParticipant) SetPermissionReturns(result1 bool) { + fake.setPermissionMutex.Lock() + defer fake.setPermissionMutex.Unlock() + fake.SetPermissionStub = nil + fake.setPermissionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SetPermissionReturnsOnCall(i int, result1 bool) { + fake.setPermissionMutex.Lock() + defer fake.setPermissionMutex.Unlock() + fake.SetPermissionStub = nil + if fake.setPermissionReturnsOnCall == nil { + fake.setPermissionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.setPermissionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeLocalParticipant) SetPreviousAnswer(arg1 *webrtc.SessionDescription) { fake.setPreviousAnswerMutex.Lock() fake.setPreviousAnswerArgsForCall = append(fake.setPreviousAnswerArgsForCall, struct { @@ -3959,8 +3994,8 @@ func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { defer fake.onCloseMutex.RUnlock() fake.onDataPacketMutex.RLock() defer fake.onDataPacketMutex.RUnlock() - fake.onMetadataUpdateMutex.RLock() - defer fake.onMetadataUpdateMutex.RUnlock() + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() fake.onStateChangeMutex.RLock() defer fake.onStateChangeMutex.RUnlock() fake.onTrackPublishedMutex.RLock() diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go index 65bfe3aaa..44b19ab4b 100644 --- a/pkg/rtc/uptrackmanager.go +++ b/pkg/rtc/uptrackmanager.go @@ -283,6 +283,9 @@ func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack) { track.RemoveAllSubscribers() + u.lock.Lock() + delete(u.publishedTracks, track.ID()) + u.lock.Unlock() } // should be called with lock held diff --git a/pkg/service/auth.go b/pkg/service/auth.go index b24c06c1d..014641800 100644 --- a/pkg/service/auth.go +++ b/pkg/service/auth.go @@ -83,7 +83,8 @@ func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, } func GetGrants(ctx context.Context) *auth.ClaimGrants { - claims, ok := ctx.Value(grantsKey{}).(*auth.ClaimGrants) + val := ctx.Value(grantsKey{}) + claims, ok := val.(*auth.ClaimGrants) if !ok { return nil } diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 170fc35fc..7298ed85e 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -252,10 +252,9 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName livekit.RoomNam CongestionControlConfig: r.config.RTC.CongestionControl, EnabledCodecs: room.Room.EnabledCodecs, Grants: pi.Grants, - Hidden: pi.Hidden, Logger: pLogger, ClientConf: clientConf, - }, pi.Permission) + }) if err != nil { logger.Errorw("could not create participant", err) return @@ -472,7 +471,8 @@ func (r *RoomManager) handleRTCMessage(_ context.Context, roomName livekit.RoomN if participant == nil { return } - pLogger.Debugw("updating participant") + pLogger.Debugw("updating participant", "metadata", rm.UpdateParticipant.Metadata, + "permission", rm.UpdateParticipant.Permission) if rm.UpdateParticipant.Metadata != "" { participant.SetMetadata(rm.UpdateParticipant.Metadata) } diff --git a/pkg/service/roomservice.go b/pkg/service/roomservice.go index f2908916c..7ff5c343e 100644 --- a/pkg/service/roomservice.go +++ b/pkg/service/roomservice.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/thoas/go-funk" "github.com/twitchtv/twirp" + "google.golang.org/protobuf/proto" "github.com/livekit/livekit-server/pkg/routing" ) @@ -229,7 +230,10 @@ func (s *RoomService) UpdateParticipant(ctx context.Context, req *livekit.Update if err != nil { return err } - if participant.Metadata != req.Metadata { + if req.Metadata != "" && participant.Metadata != req.Metadata { + return ErrOperationFailed + } + if req.Permission != nil && !proto.Equal(req.Permission, participant.Permission) { return ErrOperationFailed } return nil diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 143040807..57006987b 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -96,7 +96,7 @@ func (s *RTCService) validate(r *http.Request) (livekit.RoomName, routing.Partic // this is new connection for existing participant - with publish only permissions if publishParam != "" { // Make sure grant has CanPublish set, - if claims.Video.CanPublish != nil && !*claims.Video.CanPublish { + if !claims.Video.GetCanPublish() { return "", routing.ParticipantInit{}, http.StatusUnauthorized, rtc.ErrPermissionDenied } // Make sure by default subscribe is off @@ -117,9 +117,6 @@ func (s *RTCService) validate(r *http.Request) (livekit.RoomName, routing.Partic Identity: livekit.ParticipantIdentity(claims.Identity), Name: livekit.ParticipantName(claims.Name), AutoSubscribe: true, - Metadata: claims.Metadata, - Hidden: claims.Video.Hidden, - Recorder: claims.Video.Recorder, Client: s.ParseClientInfo(r), Grants: claims, } @@ -127,7 +124,6 @@ func (s *RTCService) validate(r *http.Request) (livekit.RoomName, routing.Partic if autoSubParam != "" { pi.AutoSubscribe = boolValue(autoSubParam) } - pi.Permission = permissionFromGrant(claims.Video) return roomName, pi, http.StatusOK, nil } diff --git a/pkg/service/utils.go b/pkg/service/utils.go index c7c4c0475..c47d1b61b 100644 --- a/pkg/service/utils.go +++ b/pkg/service/utils.go @@ -4,8 +4,6 @@ import ( "net/http" "regexp" - "github.com/livekit/protocol/auth" - "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -24,21 +22,3 @@ func IsValidDomain(domain string) bool { domainRegexp := regexp.MustCompile(`^(?i)[a-z0-9-]+(\.[a-z0-9-]+)+\.?$`) return domainRegexp.MatchString(domain) } - -func permissionFromGrant(claim *auth.VideoGrant) *livekit.ParticipantPermission { - p := &livekit.ParticipantPermission{ - CanSubscribe: true, - CanPublish: true, - CanPublishData: true, - } - if claim.CanPublish != nil { - p.CanPublish = *claim.CanPublish - } - if claim.CanSubscribe != nil { - p.CanSubscribe = *claim.CanSubscribe - } - if claim.CanPublishData != nil { - p.CanPublishData = *claim.CanPublishData - } - return p -} diff --git a/test/client/client.go b/test/client/client.go index 3ecee6497..c8b294b3d 100644 --- a/test/client/client.go +++ b/test/client/client.go @@ -36,6 +36,7 @@ type RTCClient struct { subscriber *rtc.PCTransport // sid => track localTracks map[string]webrtc.TrackLocal + trackSenders map[string]*webrtc.RTPSender lock sync.Mutex wsLock sync.Mutex ctx context.Context @@ -89,7 +90,7 @@ type Options struct { } func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { - u, err := url.Parse(host + "/rtc?protocol=6") + u, err := url.Parse(host + "/rtc?protocol=7") if err != nil { return nil, err } @@ -115,6 +116,7 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { c := &RTCClient{ conn: conn, localTracks: make(map[string]webrtc.TrackLocal), + trackSenders: make(map[string]*webrtc.RTPSender), pendingPublishedTracks: make(map[string]*livekit.TrackInfo), subscribedTracks: make(map[livekit.ParticipantID][]*webrtc.TrackRemote), remoteParticipants: make(map[livekit.ParticipantID]*livekit.ParticipantInfo), @@ -329,6 +331,19 @@ func (c *RTCClient) Run() error { c.lock.Lock() c.refreshToken = msg.RefreshToken c.lock.Unlock() + case *livekit.SignalResponse_TrackUnpublished: + sid := msg.TrackUnpublished.TrackSid + c.lock.Lock() + sender := c.trackSenders[sid] + if sender != nil { + if err := c.publisher.PeerConnection().RemoveTrack(sender); err != nil { + logger.Errorw("Could not unpublish track", err) + } + c.publisher.Negotiate() + } + delete(c.trackSenders, sid) + delete(c.localTracks, sid) + c.lock.Unlock() } } } @@ -477,11 +492,13 @@ func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string) c.lock.Lock() defer c.lock.Unlock() - c.localTracks[ti.Sid] = track - if _, err = c.publisher.PeerConnection().AddTrack(track); err != nil { + sender, err := c.publisher.PeerConnection().AddTrack(track) + if err != nil { return } + c.localTracks[ti.Sid] = track + c.trackSenders[ti.Sid] = sender c.publisher.Negotiate() writer = NewTrackWriter(c.ctx, track, path) diff --git a/test/multinode_test.go b/test/multinode_test.go index 3ed875163..5250a83f7 100644 --- a/test/multinode_test.go +++ b/test/multinode_test.go @@ -200,3 +200,51 @@ func TestMultiNodeRefreshToken(t *testing.T) { return "" }) } + +func TestMultiNodeRevokePublishPermission(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeRevokePublishPermission") + defer finish() + + c1 := createRTCClient("c1", defaultServerPort, nil) + c2 := createRTCClient("c2", secondServerPort, nil) + waitUntilConnected(t, c1, c2) + + // c1 publishes a track for c2 + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 did not receive c1's tracks" + } + return "" + }) + + // revoke permission + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanPublishData: true, + CanSubscribe: true, + }, + }) + require.NoError(t, err) + + // ensure c1 no longer has track published, c2 no longer see track under C1 + testutils.WithTimeout(t, func() string { + if len(c1.GetPublishedTrackIDs()) != 0 { + return "c1 did not unpublish tracks" + } + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + if remoteC1 == nil { + return "c2 doesn't know about c1" + } + if len(remoteC1.Tracks) != 0 { + return "c2 still has c1's tracks" + } + return "" + }) +}