diff --git a/pkg/rtc/mediatrack.go b/pkg/rtc/mediatrack.go index 886520294..f811e4e14 100644 --- a/pkg/rtc/mediatrack.go +++ b/pkg/rtc/mediatrack.go @@ -50,10 +50,10 @@ type MediaTrack struct { simulcasted utils.AtomicFlag buffer *buffer.Buffer - // channel to send RTCP packets to the source lock sync.RWMutex + // map of target participantId -> types.SubscribedTrack - subscribedTracks sync.Map + subscribedTracks sync.Map // participantSid => types.SubscribedTrack twcc *twcc.Responder audioLevel *AudioLevel receiver sfu.Receiver @@ -70,6 +70,7 @@ type MediaTrack struct { connectionStats *connectionquality.ConnectionStats done chan struct{} + // quality level enable/disable maxQualityLock sync.Mutex maxSubscriberQuality map[string]livekit.VideoQuality @@ -86,12 +87,13 @@ type MediaTrackParams struct { SdpCid string ParticipantID string ParticipantIdentity string - RTCPChan chan []rtcp.Packet - BufferFactory *buffer.Factory - ReceiverConfig ReceiverConfig - AudioConfig config.AudioConfig - Telemetry telemetry.TelemetryService - Logger logger.Logger + // channel to send RTCP packets to the source + RTCPChan chan []rtcp.Packet + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + AudioConfig config.AudioConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger } func NewMediaTrack(track *webrtc.TrackRemote, params MediaTrackParams) *MediaTrack { @@ -196,8 +198,10 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { t.lock.Lock() defer t.lock.Unlock() + subscriberID := sub.ID() + // don't subscribe to the same track multiple times - if _, ok := t.subscribedTracks.Load(sub.ID()); ok { + if _, ok := t.subscribedTracks.Load(subscriberID); ok { return nil } @@ -221,14 +225,14 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { Channels: codec.Channels, SDPFmtpLine: codec.SDPFmtpLine, RTCPFeedback: FeedbackTypes, - }, receiver, t.params.BufferFactory, sub.ID(), t.params.ReceiverConfig.PacketBufferSize) + }, receiver, t.params.BufferFactory, subscriberID, t.params.ReceiverConfig.PacketBufferSize) if err != nil { return err } subTrack := NewSubscribedTrack(SubscribedTrackParams{ PublisherID: t.params.ParticipantID, PublisherIdentity: t.params.ParticipantIdentity, - SubscriberID: sub.ID(), + SubscriberID: subscriberID, MediaTrack: t, DownTrack: downTrack, }) @@ -282,19 +286,19 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { go t.sendDownTrackBindingReports(sub) }) downTrack.OnPacketSent(func(_ *sfu.DownTrack, size int) { - t.params.Telemetry.OnDownstreamPacket(sub.ID(), size) + t.params.Telemetry.OnDownstreamPacket(subscriberID, size) }) downTrack.OnPaddingSent(func(_ *sfu.DownTrack, size int) { - t.params.Telemetry.OnDownstreamPacket(sub.ID(), size) + t.params.Telemetry.OnDownstreamPacket(subscriberID, size) }) downTrack.OnRTCP(func(pkts []rtcp.Packet) { - t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, sub.ID(), pkts) + t.params.Telemetry.HandleRTCP(livekit.StreamType_DOWNSTREAM, subscriberID, pkts) }) downTrack.OnCloseHandler(func() { go func() { - t.subscribedTracks.Delete(sub.ID()) - t.params.Telemetry.TrackUnsubscribed(context.Background(), sub.ID(), t.ToProto()) + t.subscribedTracks.Delete(subscriberID) + t.params.Telemetry.TrackUnsubscribed(context.Background(), subscriberID, t.ToProto()) // ignore if the subscribing sub is not connected if sub.SubscriberPC().ConnectionState() == webrtc.PeerConnectionStateClosed { @@ -309,7 +313,7 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { t.params.Logger.Debugw("removing peerconnection track", "track", t.ID(), "subscriber", sub.Identity(), - "subscriberID", sub.ID(), + "subscriberID", subscriberID, "kind", t.Kind(), ) if err := sub.SubscriberPC().RemoveTrack(sender); err != nil { @@ -323,12 +327,12 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { t.params.Logger.Debugw("could not remove remoteTrack from forwarder", "error", err, "subscriber", sub.Identity(), - "subscriberID", sub.ID(), + "subscriberID", subscriberID, ) } } - t.NotifySubscriberMute(sub.ID()) + t.NotifySubscriberMute(subscriberID) sub.RemoveSubscribedTrack(subTrack) sub.Negotiate() }() @@ -337,18 +341,18 @@ func (t *MediaTrack) AddSubscriber(sub types.Participant) error { downTrack.AddReceiverReportListener(t.handleMaxLossFeedback) } - t.subscribedTracks.Store(sub.ID(), subTrack) + t.subscribedTracks.Store(subscriberID, subTrack) subTrack.SetPublisherMuted(t.IsMuted()) t.receiver.AddDownTrack(downTrack) // since sub will lock, run it in a goroutine to avoid deadlocks go func() { - t.NotifySubscriberMaxQuality(sub.ID(), livekit.VideoQuality_HIGH) // start with HIGH, let subscription change it later + t.NotifySubscriberMaxQuality(subscriberID, livekit.VideoQuality_HIGH) // start with HIGH, let subscription change it later sub.AddSubscribedTrack(subTrack) sub.Negotiate() }() - t.params.Telemetry.TrackSubscribed(context.Background(), sub.ID(), t.ToProto()) + t.params.Telemetry.TrackSubscribed(context.Background(), subscriberID, t.ToProto()) return nil } @@ -469,6 +473,35 @@ func (t *MediaTrack) RemoveAllSubscribers() { t.subscribedTracks = sync.Map{} } +func (t *MediaTrack) RevokeDisallowedSubscribers(allowedSubscriberIDs []string) []string { + t.lock.Lock() + defer t.lock.Unlock() + + var revokedSubscriberIDs []string + // LK-TODO: large number of subscribers needs to be solved for this loop + t.subscribedTracks.Range(func(key interface{}, val interface{}) bool { + if subID, ok := key.(string); ok { + found := false + for _, allowedID := range allowedSubscriberIDs { + if subID == allowedID { + found = true + break + } + } + + if !found { + if subTrack, ok := val.(types.SubscribedTrack); ok { + go subTrack.DownTrack().Close() + revokedSubscriberIDs = append(revokedSubscriberIDs, subID) + } + } + } + return true + }) + + return revokedSubscriberIDs +} + func (t *MediaTrack) ToProto() *livekit.TrackInfo { info := t.params.TrackInfo info.Muted = t.IsMuted() @@ -543,8 +576,8 @@ func (t *MediaTrack) GetQualityForDimension(width, height uint32) livekit.VideoQ return quality } -func (t *MediaTrack) getSubscribedTrack(id string) types.SubscribedTrack { - if val, ok := t.subscribedTracks.Load(id); ok { +func (t *MediaTrack) getSubscribedTrack(subscriberID string) types.SubscribedTrack { + if val, ok := t.subscribedTracks.Load(subscriberID); ok { if st, ok := val.(types.SubscribedTrack); ok { return st } diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index b03a5e05d..f0b7fc542 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -57,8 +57,6 @@ type ParticipantImpl struct { isClosed utils.AtomicFlag permission *livekit.ParticipantPermission state atomic.Value // livekit.ParticipantInfo_State - rtcpCh chan []rtcp.Packet - pliThrottle *pliThrottle updateCache *lru.Cache // reliable and unreliable data channels @@ -76,12 +74,12 @@ type ParticipantImpl struct { // hold reference for MediaTrack twcc *twcc.Responder + uptrackManager *UptrackManager + // tracks the current participant is subscribed to, map of sid => DownTrack subscribedTracks map[string]types.SubscribedTrack - // publishedTracks that participant is publishing - publishedTracks map[string]types.PublishedTrack - // client intended to publish, yet to be reconciled - pendingTracks map[string]*livekit.TrackInfo + // keeps track of disallowed tracks + disallowedSubscriptions map[string]string // trackSid -> publisherSid // keep track of other publishers identities that we are subscribed to subscribedTo sync.Map // string => struct{} @@ -95,20 +93,17 @@ type ParticipantImpl struct { onStateChange func(p types.Participant, oldState livekit.ParticipantInfo_State) onMetadataUpdate func(types.Participant) onDataPacket func(types.Participant, *livekit.DataPacket) - onClose func(types.Participant) + onClose func(types.Participant, map[string]string) } func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { // TODO: check to ensure params are valid, id and identity can't be empty p := &ParticipantImpl{ - params: params, - rtcpCh: make(chan []rtcp.Packet, 50), - pliThrottle: newPLIThrottle(params.ThrottleConfig), - subscribedTracks: make(map[string]types.SubscribedTrack), - publishedTracks: make(map[string]types.PublishedTrack, 0), - pendingTracks: make(map[string]*livekit.TrackInfo), - connectedAt: time.Now(), + params: params, + subscribedTracks: make(map[string]types.SubscribedTrack), + disallowedSubscriptions: make(map[string]string), + connectedAt: time.Now(), } p.state.Store(livekit.ParticipantInfo_JOINING) @@ -184,6 +179,8 @@ func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { p.subscriber.OnStreamStateChange(p.onStreamStateChange) + p.setupUptrackManager() + return p, nil } @@ -225,10 +222,6 @@ func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermissio p.permission = permission } -func (p *ParticipantImpl) RTCPChan() chan []rtcp.Packet { - return p.rtcpCh -} - func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { info := &livekit.ParticipantInfo{ Sid: p.params.SID, @@ -239,12 +232,8 @@ func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { Hidden: p.Hidden(), Recorder: p.IsRecorder(), } + info.Tracks = p.uptrackManager.ToProto() - p.lock.RLock() - for _, t := range p.publishedTracks { - info.Tracks = append(info.Tracks, t.ToProto()) - } - p.lock.RUnlock() return info } @@ -282,7 +271,7 @@ func (p *ParticipantImpl) OnDataPacket(callback func(types.Participant, *livekit p.onDataPacket = callback } -func (p *ParticipantImpl) OnClose(callback func(types.Participant)) { +func (p *ParticipantImpl) OnClose(callback func(types.Participant, map[string]string)) { p.onClose = callback } @@ -339,32 +328,15 @@ func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { p.lock.Lock() defer p.lock.Unlock() - // if track is already published, reject - if p.pendingTracks[req.Cid] != nil { - return - } - - if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil { - return - } - if !p.CanPublish() { p.params.Logger.Warnw("no permission to publish track", nil) return } - ti := &livekit.TrackInfo{ - Type: req.Type, - Name: req.Name, - Sid: utils.NewGuid(utils.TrackPrefix), - Width: req.Width, - Height: req.Height, - Muted: req.Muted, - DisableDtx: req.DisableDtx, - Source: req.Source, - Layers: req.Layers, + ti := p.uptrackManager.AddTrack(req) + if ti == nil { + return } - p.pendingTracks[req.Cid] = ti _ = p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_TrackPublished{ @@ -404,7 +376,7 @@ func (p *ParticipantImpl) AddICECandidate(candidate webrtc.ICECandidateInit, tar func (p *ParticipantImpl) Start() { p.once.Do(func() { - go p.rtcpSendWorker() + p.uptrackManager.Start() go p.downTracksRTCPWorker() }) } @@ -422,13 +394,15 @@ func (p *ParticipantImpl) Close() error { }, }) - // remove all downtracks + p.uptrackManager.Close() + p.lock.Lock() - for _, t := range p.publishedTracks { - // skip updates - t.RemoveAllSubscribers() + disallowedSubscriptions := make(map[string]string) + for trackSid, publisherSid := range p.disallowedSubscriptions { + disallowedSubscriptions[trackSid] = publisherSid } + // remove all downtracks var downtracksToClose []*sfu.DownTrack for _, st := range p.subscribedTracks { downtracksToClose = append(downtracksToClose, st.DownTrack()) @@ -447,11 +421,10 @@ func (p *ParticipantImpl) Close() error { onClose := p.onClose p.lock.RUnlock() if onClose != nil { - onClose(p) + onClose(p, disallowedSubscriptions) } p.publisher.Close() p.subscriber.Close() - close(p.rtcpCh) return nil } @@ -470,27 +443,13 @@ func (p *ParticipantImpl) ICERestart() error { }) } -// AddSubscriber subscribes op to all publishedTracks -func (p *ParticipantImpl) AddSubscriber(op types.Participant) (int, error) { - tracks := p.GetPublishedTracks() +// AddSubscriber subscribes op to all publishedTracks or given set of tracks +func (p *ParticipantImpl) AddSubscriber(op types.Participant, params types.AddSubscriberParams) (int, error) { + return p.uptrackManager.AddSubscriber(op, params) +} - if len(tracks) == 0 { - return 0, nil - } - - p.params.Logger.Debugw("subscribing new participant to tracks", - "subscriber", op.Identity(), - "subscriberID", op.ID(), - "numTracks", len(tracks)) - - n := 0 - for _, track := range tracks { - if err := track.AddSubscriber(op); err != nil { - return n, err - } - n += 1 - } - return n, nil +func (p *ParticipantImpl) RemoveSubscriber(op types.Participant, trackSid string) { + p.uptrackManager.RemoveSubscriber(op, trackSid) } // signal connection methods @@ -616,82 +575,31 @@ func (p *ParticipantImpl) SendConnectionQualityUpdate(update *livekit.Connection }) } -func (p *ParticipantImpl) SetTrackMuted(trackId string, muted bool, fromAdmin bool) { - isPending := false - p.lock.RLock() - for _, ti := range p.pendingTracks { - if ti.Sid == trackId { - ti.Muted = muted - isPending = true - } - } - track := p.publishedTracks[trackId] - p.lock.RUnlock() - - if track == nil { - if !isPending { - p.params.Logger.Warnw("could not locate track", nil, "track", trackId) - } - return - } - currentMuted := track.IsMuted() - track.SetMuted(muted) - +func (p *ParticipantImpl) SetTrackMuted(trackSid string, muted bool, fromAdmin bool) { // when request is coming from admin, send message to current participant if fromAdmin { _ = p.writeMessage(&livekit.SignalResponse{ Message: &livekit.SignalResponse_Mute{ Mute: &livekit.MuteTrackRequest{ - Sid: trackId, + Sid: trackSid, Muted: muted, }, }, }) } - if currentMuted != track.IsMuted() && p.onTrackUpdated != nil { - p.params.Logger.Debugw("mute status changed", - "track", trackId, - "muted", track.IsMuted()) - p.onTrackUpdated(p, track) - } + p.uptrackManager.SetTrackMuted(trackSid, muted) } func (p *ParticipantImpl) GetAudioLevel() (level uint8, active bool) { - p.lock.RLock() - defer p.lock.RUnlock() - level = silentAudioLevel - for _, pt := range p.publishedTracks { - if mt, ok := pt.(*MediaTrack); ok { - if mt.audioLevel == nil { - continue - } - tl, ta := mt.audioLevel.GetLevel() - if ta { - active = true - if tl < level { - level = tl - } - } - } - } - return + return p.uptrackManager.GetAudioLevel() } func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo { // avg loss across all tracks, weigh published the same as subscribed - var scores float64 - var numTracks int - p.lock.RLock() - defer p.lock.RUnlock() - for _, pubTrack := range p.publishedTracks { - if pubTrack.IsMuted() { - continue - } - scores += pubTrack.GetConnectionScore() - numTracks++ - } + scores, numTracks := p.uptrackManager.GetConnectionQuality() + p.lock.RLock() for _, subTrack := range p.subscribedTracks { if subTrack.IsMuted() || subTrack.MediaTrack().IsMuted() { continue @@ -699,6 +607,7 @@ func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo scores += subTrack.DownTrack().GetConnectionScore() numTracks++ } + p.lock.RUnlock() avgScore := 5.0 if numTracks > 0 { @@ -759,19 +668,11 @@ func (p *ParticipantImpl) SubscriberPC() *webrtc.PeerConnection { } func (p *ParticipantImpl) GetPublishedTrack(sid string) types.PublishedTrack { - p.lock.RLock() - defer p.lock.RUnlock() - return p.publishedTracks[sid] + return p.uptrackManager.GetPublishedTrack(sid) } func (p *ParticipantImpl) GetPublishedTracks() []types.PublishedTrack { - p.lock.RLock() - defer p.lock.RUnlock() - tracks := make([]types.PublishedTrack, 0, len(p.publishedTracks)) - for _, t := range p.publishedTracks { - tracks = append(tracks, t) - } - return tracks + return p.uptrackManager.GetPublishedTracks() } func (p *ParticipantImpl) GetSubscribedTrack(sid string) types.SubscribedTrack { @@ -830,6 +731,72 @@ func (p *ParticipantImpl) RemoveSubscribedTrack(subTrack types.SubscribedTrack) } } +func (p *ParticipantImpl) UpdateSubscriptionPermissions( + permissions *livekit.UpdateSubscriptionPermissions, + resolver func(participantSid string) types.Participant, +) error { + return p.uptrackManager.UpdateSubscriptionPermissions(permissions, resolver) +} + +func (p *ParticipantImpl) SubscriptionPermissionUpdate(publisherSid string, trackSid string, allowed bool) { + p.lock.Lock() + if allowed { + delete(p.disallowedSubscriptions, trackSid) + } else { + p.disallowedSubscriptions[trackSid] = publisherSid + } + p.lock.Unlock() + + err := p.writeMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_SubscriptionPermissionUpdate{ + SubscriptionPermissionUpdate: &livekit.SubscriptionPermissionUpdate{ + ParticipantSid: publisherSid, + TrackSid: trackSid, + Allowed: allowed, + }, + }, + }) + if err != nil { + p.params.Logger.Errorw("could not send subscription permission update", err) + } +} + +func (p *ParticipantImpl) setupUptrackManager() { + p.uptrackManager = NewUptrackManager(UptrackManagerParams{ + Identity: p.params.Identity, + SID: p.params.SID, + Config: p.params.Config, + AudioConfig: p.params.AudioConfig, + Telemetry: p.params.Telemetry, + ThrottleConfig: p.params.ThrottleConfig, + Logger: p.params.Logger, + }) + + p.uptrackManager.OnTrackPublished(func(track types.PublishedTrack) { + if p.onTrackPublished != nil { + p.onTrackPublished(p, track) + } + }) + + p.uptrackManager.OnTrackUpdated(func(track types.PublishedTrack, onlyIfReady bool) { + if onlyIfReady && !p.IsReady() { + return + } + + if p.onTrackUpdated != nil { + p.onTrackUpdated(p, track) + } + }) + + p.uptrackManager.OnWriteRTCP(func(pkts []rtcp.Packet) { + if err := p.publisher.pc.WriteRTCP(pkts); err != nil { + p.params.Logger.Errorw("could not write RTCP to participant", err) + } + }) + + p.uptrackManager.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) +} + func (p *ParticipantImpl) sendIceCandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) { ci := c.ToJSON() @@ -918,56 +885,7 @@ func (p *ParticipantImpl) onMediaTrack(track *webrtc.TrackRemote, rtpReceiver *w return } - var newTrack bool - - // use existing mediatrack to handle simulcast - p.lock.Lock() - mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) - if !ok { - signalCid, ti := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) - if ti == nil { - p.lock.Unlock() - return - } - - mt = NewMediaTrack(track, MediaTrackParams{ - TrackInfo: ti, - SignalCid: signalCid, - SdpCid: track.ID(), - ParticipantID: p.params.SID, - ParticipantIdentity: p.Identity(), - RTCPChan: p.rtcpCh, - BufferFactory: p.params.Config.BufferFactory, - ReceiverConfig: p.params.Config.Receiver, - AudioConfig: p.params.AudioConfig, - Telemetry: p.params.Telemetry, - Logger: p.params.Logger, - }) - - mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) - - // add to published and clean up pending - p.publishedTracks[mt.ID()] = mt - delete(p.pendingTracks, signalCid) - - newTrack = true - } - - ssrc := uint32(track.SSRC()) - p.pliThrottle.addTrack(ssrc, track.RID()) - if p.twcc == nil { - p.twcc = twcc.NewTransportWideCCResponder(ssrc) - p.twcc.OnFeedback(func(pkt rtcp.RawPacket) { - _ = p.publisher.pc.WriteRTCP([]rtcp.Packet{&pkt}) - }) - } - p.lock.Unlock() - - mt.AddReceiver(rtpReceiver, track, p.twcc) - - if newTrack { - p.handleTrackPublished(mt) - } + p.uptrackManager.MediaTrackReceived(track, rtpReceiver) } func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { @@ -990,52 +908,6 @@ func (p *ParticipantImpl) onDataChannel(dc *webrtc.DataChannel) { } } -// should be called with lock held -func (p *ParticipantImpl) getPublishedTrackBySignalCid(clientId string) types.PublishedTrack { - for _, publishedTrack := range p.publishedTracks { - if publishedTrack.SignalCid() == clientId { - return publishedTrack - } - } - - return nil -} - -// should be called with lock held -func (p *ParticipantImpl) getPublishedTrackBySdpCid(clientId string) types.PublishedTrack { - for _, publishedTrack := range p.publishedTracks { - if publishedTrack.SdpCid() == clientId { - return publishedTrack - } - } - - return nil -} - -// should be called with lock held -func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { - signalCid := clientId - ti := p.pendingTracks[clientId] - - // then find the first one that matches type. with MediaStreamTrack, it's possible for the client id to - // change after being added to SubscriberPC - if ti == nil { - for cid, info := range p.pendingTracks { - if info.Type == kind { - ti = info - signalCid = cid - break - } - } - } - - // if still not found, we are done - if ti == nil { - p.params.Logger.Errorw("track info not published prior to track", nil, "clientId", clientId) - } - return signalCid, ti -} - func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { dp := livekit.DataPacket{} if err := proto.Unmarshal(data, &dp); err != nil { @@ -1058,29 +930,6 @@ func (p *ParticipantImpl) handleDataMessage(kind livekit.DataPacket_Kind, data [ } } -func (p *ParticipantImpl) handleTrackPublished(track types.PublishedTrack) { - p.lock.Lock() - if _, ok := p.publishedTracks[track.ID()]; !ok { - p.publishedTracks[track.ID()] = track - } - p.lock.Unlock() - - track.AddOnClose(func() { - // cleanup - p.lock.Lock() - delete(p.publishedTracks, track.ID()) - p.lock.Unlock() - // only send this when client is in a ready state - if p.IsReady() && p.onTrackUpdated != nil { - p.onTrackUpdated(p, track) - } - }) - - if p.onTrackPublished != nil { - p.onTrackPublished(p, track) - } -} - func (p *ParticipantImpl) handlePrimaryICEStateChange(state webrtc.ICEConnectionState) { if state == webrtc.ICEConnectionStateConnected { prometheus.ServiceOperationCounter.WithLabelValues("ice_connection", "success", "").Add(1) @@ -1158,41 +1007,6 @@ func (p *ParticipantImpl) downTracksRTCPWorker() { } } -func (p *ParticipantImpl) rtcpSendWorker() { - defer Recover() - - // read from rtcpChan - for pkts := range p.rtcpCh { - if pkts == nil { - return - } - - fwdPkts := make([]rtcp.Packet, 0, len(pkts)) - for _, pkt := range pkts { - switch pkt.(type) { - case *rtcp.PictureLossIndication: - mediaSSRC := pkt.(*rtcp.PictureLossIndication).MediaSSRC - if p.pliThrottle.canSend(mediaSSRC) { - fwdPkts = append(fwdPkts, pkt) - } - case *rtcp.FullIntraRequest: - mediaSSRC := pkt.(*rtcp.FullIntraRequest).MediaSSRC - if p.pliThrottle.canSend(mediaSSRC) { - fwdPkts = append(fwdPkts, pkt) - } - default: - fwdPkts = append(fwdPkts, pkt) - } - } - - if len(fwdPkts) > 0 { - if err := p.publisher.pc.WriteRTCP(fwdPkts); err != nil { - p.params.Logger.Errorw("could not write RTCP to participant", err) - } - } - } -} - func (p *ParticipantImpl) configureReceiverDTX() { // // DTX (Discontinuous Transmission) allows audio bandwidth saving @@ -1229,25 +1043,7 @@ func (p *ParticipantImpl) configureReceiverDTX() { // multiple audio tracks. At that point, there might be a need to // rely on something like order of tracks. TODO // - enableDTX := false - - p.lock.RLock() - var pendingTrack *livekit.TrackInfo - for _, track := range p.pendingTracks { - if track.Type == livekit.TrackType_AUDIO { - pendingTrack = track - break - } - } - - if pendingTrack == nil { - p.lock.RUnlock() - return - } - - enableDTX = !pendingTrack.DisableDtx - p.lock.RUnlock() - + enableDTX := p.uptrackManager.GetDTX() transceivers := p.publisher.pc.GetTransceivers() for _, transceiver := range transceivers { if transceiver.Kind() != webrtc.RTPCodecTypeAudio { @@ -1343,41 +1139,19 @@ func (p *ParticipantImpl) DebugInfo() map[string]interface{} { "State": p.State().String(), } - publishedTrackInfo := make(map[string]interface{}) + uptrackManagerInfo := p.uptrackManager.DebugInfo() + subscribedTrackInfo := make(map[string]interface{}) - pendingTrackInfo := make(map[string]interface{}) - p.lock.RLock() - for trackID, track := range p.publishedTracks { - if mt, ok := track.(*MediaTrack); ok { - publishedTrackInfo[trackID] = mt.DebugInfo() - } else { - publishedTrackInfo[trackID] = map[string]interface{}{ - "ID": track.ID(), - "Kind": track.Kind().String(), - "PubMuted": track.IsMuted(), - } - } - } - for _, track := range p.subscribedTracks { dt := track.DownTrack().DebugInfo() dt["SubMuted"] = track.IsMuted() subscribedTrackInfo[track.ID()] = dt } - - for clientID, track := range p.pendingTracks { - pendingTrackInfo[clientID] = map[string]interface{}{ - "Sid": track.Sid, - "Type": track.Type.String(), - "Simulcast": track.Simulcast, - } - } p.lock.RUnlock() - info["PublishedTracks"] = publishedTrackInfo + info["UptrackManager"] = uptrackManagerInfo info["SubscribedTracks"] = subscribedTrackInfo - info["PendingTracks"] = pendingTrackInfo return info } diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 8dbaf2dad..7446ee069 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -1,7 +1,6 @@ package rtc import ( - "github.com/livekit/livekit-server/pkg/sfu/connectionquality" "testing" "time" @@ -14,6 +13,7 @@ import ( "github.com/livekit/livekit-server/pkg/routing/routingfakes" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" ) func TestIsReady(t *testing.T) { @@ -52,7 +52,7 @@ func TestICEStateChange(t *testing.T) { t.Run("onClose gets called when ICE disconnected", func(t *testing.T) { p := newParticipantForTest("test") closeChan := make(chan struct{}) - p.onClose = func(participant types.Participant) { + p.onClose = func(participant types.Participant, disallowedSubscriptions map[string]string) { close(closeChan) } p.handlePrimaryICEStateChange(webrtc.ICEConnectionStateFailed) @@ -80,21 +80,19 @@ func TestTrackPublishing(t *testing.T) { p.OnTrackPublished(func(p types.Participant, track types.PublishedTrack) { published = true }) - p.handleTrackPublished(track) + p.uptrackManager.handleTrackPublished(track) require.True(t, published) require.False(t, updated) - require.Len(t, p.publishedTracks, 1) + require.Len(t, p.uptrackManager.publishedTracks, 1) track.AddOnCloseArgsForCall(0)() - require.Len(t, p.publishedTracks, 0) + require.Len(t, p.uptrackManager.publishedTracks, 0) require.True(t, updated) }) t.Run("sends back trackPublished event", func(t *testing.T) { p := newParticipantForTest("test") - // track := &typesfakes.FakePublishedTrack{} - // track.IDReturns("id") sink := p.params.Sink.(*routingfakes.FakeMessageSink) p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -116,8 +114,6 @@ func TestTrackPublishing(t *testing.T) { t.Run("should not allow adding of duplicate tracks", func(t *testing.T) { p := newParticipantForTest("test") - // track := &typesfakes.FakePublishedTrack{} - // track.IDReturns("id") sink := p.params.Sink.(*routingfakes.FakeMessageSink) p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -140,7 +136,7 @@ func TestTrackPublishing(t *testing.T) { track := &typesfakes.FakePublishedTrack{} track.SignalCidReturns("cid") // directly add to publishedTracks without lock - for testing purpose only - p.publishedTracks["cid"] = track + p.uptrackManager.publishedTracks["cid"] = track p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -157,7 +153,7 @@ func TestTrackPublishing(t *testing.T) { track := &typesfakes.FakePublishedTrack{} track.SdpCidReturns("cid") // directly add to publishedTracks without lock - for testing purpose only - p.publishedTracks["cid"] = track + p.uptrackManager.publishedTracks["cid"] = track p.AddTrack(&livekit.AddTrackRequest{ Cid: "cid", @@ -206,7 +202,7 @@ func TestDisconnectTiming(t *testing.T) { } }() track := &typesfakes.FakePublishedTrack{} - p.handleTrackPublished(track) + p.uptrackManager.handleTrackPublished(track) // close channel and then try to Negotiate msg.Close() @@ -224,11 +220,10 @@ func TestMuteSetting(t *testing.T) { t.Run("can set mute when track is pending", func(t *testing.T) { p := newParticipantForTest("test") ti := &livekit.TrackInfo{Sid: "testTrack"} - p.pendingTracks["cid"] = ti + p.uptrackManager.pendingTracks["cid"] = ti p.SetTrackMuted(ti.Sid, true, false) require.True(t, ti.Muted) - }) t.Run("can publish a muted track", func(t *testing.T) { @@ -239,7 +234,7 @@ func TestMuteSetting(t *testing.T) { Muted: true, }) - _, ti := p.getPendingTrack("cid", livekit.TrackType_AUDIO) + _, ti := p.uptrackManager.getPendingTrack("cid", livekit.TrackType_AUDIO) require.NotNil(t, ti) require.True(t, ti.Muted) }) @@ -287,30 +282,30 @@ func TestConnectionQuality(t *testing.T) { t.Run("smooth sailing", func(t *testing.T) { p := newParticipantForTest("test") - p.publishedTracks["video"] = testPublishedVideoTrack(2, 3, 3) - p.publishedTracks["audio"] = testPublishedAudioTrack(1000, 0) + p.uptrackManager.publishedTracks["video"] = testPublishedVideoTrack(2, 3, 3) + p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 0) require.Equal(t, livekit.ConnectionQuality_EXCELLENT, p.GetConnectionQuality().GetQuality()) }) t.Run("reduced publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.publishedTracks["video"] = testPublishedVideoTrack(3, 2, 3) - p.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) + p.uptrackManager.publishedTracks["video"] = testPublishedVideoTrack(3, 2, 3) + p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) require.Equal(t, livekit.ConnectionQuality_GOOD, p.GetConnectionQuality().GetQuality()) }) t.Run("audio smooth publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.publishedTracks["audio"] = testPublishedAudioTrack(1000, 10) + p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 10) require.Equal(t, livekit.ConnectionQuality_EXCELLENT, p.GetConnectionQuality().GetQuality()) }) t.Run("audio reduced publishing", func(t *testing.T) { p := newParticipantForTest("test") - p.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) + p.uptrackManager.publishedTracks["audio"] = testPublishedAudioTrack(1000, 100) require.Equal(t, livekit.ConnectionQuality_GOOD, p.GetConnectionQuality().GetQuality()) }) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 03227c78d..a0fb8c91e 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -342,13 +342,13 @@ func (r *Room) UpdateSubscriptions( subscribe bool, ) error { // find all matching tracks - tracks := make(map[string]types.PublishedTrack) + trackPublishers := make(map[string]types.Participant) participants := r.GetParticipants() for _, trackSid := range trackIds { for _, p := range participants { track := p.GetPublishedTrack(trackSid) if track != nil { - tracks[trackSid] = track + trackPublishers[trackSid] = p break } } @@ -360,26 +360,38 @@ func (r *Room) UpdateSubscriptions( continue } for _, trackSid := range pt.TrackSids { - track := p.GetPublishedTrack(trackSid) - if track != nil { - tracks[trackSid] = track - } + trackPublishers[trackSid] = p } } // handle subscription changes - for _, track := range tracks { + for trackSid, publisher := range trackPublishers { if subscribe { - if err := track.AddSubscriber(participant); err != nil { + if _, err := publisher.AddSubscriber(participant, types.AddSubscriberParams{TrackSids: []string{trackSid}}); err != nil { return err } } else { - track.RemoveSubscriber(participant.ID()) + publisher.RemoveSubscriber(participant, trackSid) } } return nil } +func (r *Room) UpdateSubscriptionPermissions(participant types.Participant, permissions *livekit.UpdateSubscriptionPermissions) error { + return participant.UpdateSubscriptionPermissions(permissions, r.GetParticipantBySid) +} + +func (r *Room) RemoveDisallowedSubscriptions(sub types.Participant, disallowedSubscriptions map[string]string) { + for trackSid, publisherSid := range disallowedSubscriptions { + pub := r.GetParticipantBySid(publisherSid) + if pub == nil { + continue + } + + pub.RemoveSubscriber(sub, trackSid) + } +} + func (r *Room) IsClosed() bool { select { case <-r.closed: @@ -523,7 +535,7 @@ func (r *Room) onTrackPublished(participant types.Participant, track types.Publi "participants", []string{participant.Identity(), existingParticipant.Identity()}, "pIDs", []string{participant.ID(), existingParticipant.ID()}, "track", track.ID()) - if err := track.AddSubscriber(existingParticipant); err != nil { + if _, err := participant.AddSubscriber(existingParticipant, types.AddSubscriberParams{TrackSids: []string{track.ID()}}); err != nil { r.Logger.Errorw("could not subscribe to remoteTrack", err, "participants", []string{participant.Identity(), existingParticipant.Identity()}, "pIDs", []string{participant.ID(), existingParticipant.ID()}, @@ -595,7 +607,7 @@ func (r *Room) subscribeToExistingTracks(p types.Participant) { // don't send to itself continue } - if n, err := op.AddSubscriber(p); err != nil { + if n, err := op.AddSubscriber(p, types.AddSubscriberParams{AllTracks: true}); err != nil { // TODO: log error? or disconnect? r.Logger.Errorw("could not subscribe to participant", err, "participants", []string{op.Identity(), p.Identity()}, diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index 4c016d1ca..025f3514f 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -95,7 +95,9 @@ func TestRoomJoin(t *testing.T) { mockP := op.(*typesfakes.FakeParticipant) require.NotZero(t, mockP.AddSubscriberCallCount()) // last call should be to add the newest participant - require.Equal(t, p, mockP.AddSubscriberArgsForCall(mockP.AddSubscriberCallCount()-1)) + sub, params := mockP.AddSubscriberArgsForCall(mockP.AddSubscriberCallCount() - 1) + require.Equal(t, p, sub) + require.Equal(t, types.AddSubscriberParams{AllTracks: true}, params) } }) @@ -253,14 +255,16 @@ func TestNewTrack(t *testing.T) { pub := participants[2].(*typesfakes.FakeParticipant) - // p3 adds track + // pub adds track track := newMockTrack(livekit.TrackType_VIDEO, "webcam") trackCB := pub.OnTrackPublishedArgsForCall(0) require.NotNil(t, trackCB) trackCB(pub, track) - // only p2 should've been called - require.Equal(t, 1, track.AddSubscriberCallCount()) - require.Equal(t, p1, track.AddSubscriberArgsForCall(0)) + // only p1 should've been called + require.Equal(t, 1, pub.AddSubscriberCallCount()) + sub, params := pub.AddSubscriberArgsForCall(pub.AddSubscriberCallCount() - 1) + require.Equal(t, p1, sub) + require.Equal(t, types.AddSubscriberParams{TrackSids: []string{track.ID()}}, params) }) } @@ -518,7 +522,9 @@ func TestHiddenParticipants(t *testing.T) { mockP := op.(*typesfakes.FakeParticipant) require.NotZero(t, mockP.AddSubscriberCallCount()) // last call should be to add the newest participant - require.Equal(t, p, mockP.AddSubscriberArgsForCall(mockP.AddSubscriberCallCount()-1)) + sub, params := mockP.AddSubscriberArgsForCall(mockP.AddSubscriberCallCount() - 1) + require.Equal(t, p, sub) + require.Equal(t, types.AddSubscriberParams{AllTracks: true}, params) } }) } diff --git a/pkg/rtc/signalhandler.go b/pkg/rtc/signalhandler.go index 8431f5d56..bd0aed65c 100644 --- a/pkg/rtc/signalhandler.go +++ b/pkg/rtc/signalhandler.go @@ -80,6 +80,12 @@ func HandleParticipantSignal(room types.Room, participant types.Participant, req track.UpdateVideoLayers(msg.UpdateLayers.Layers) case *livekit.SignalRequest_Leave: _ = participant.Close() + case *livekit.SignalRequest_SubscriptionPermissions: + err := room.UpdateSubscriptionPermissions(participant, msg.SubscriptionPermissions) + if err != nil { + pLogger.Warnw("could not update subscription permissions", err, + "permissions", msg.SubscriptionPermissions) + } } return nil } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 6101a1a1c..aed5b8084 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -4,7 +4,6 @@ import ( "time" "github.com/livekit/protocol/livekit" - "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/livekit/livekit-server/pkg/routing" @@ -20,6 +19,11 @@ type WebsocketClient interface { WriteControl(messageType int, data []byte, deadline time.Time) error } +type AddSubscriberParams struct { + AllTracks bool + TrackSids []string +} + //counterfeiter:generate . Participant type Participant interface { ID() string @@ -29,7 +33,6 @@ type Participant interface { IsReady() bool ConnectedAt() time.Time ToProto() *livekit.ParticipantInfo - RTCPChan() chan []rtcp.Packet SetMetadata(metadata string) SetPermission(permission *livekit.ParticipantPermission) GetResponseSink() routing.MessageSink @@ -46,7 +49,8 @@ type Participant interface { HandleOffer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) HandleAnswer(sdp webrtc.SessionDescription) error AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) error - AddSubscriber(op Participant) (int, error) + AddSubscriber(op Participant, params AddSubscriberParams) (int, error) + RemoveSubscriber(op Participant, trackSid string) SendJoinResponse(info *livekit.Room, otherParticipants []*livekit.ParticipantInfo, iceServers []*livekit.ICEServer) error SendParticipantUpdate(participants []*livekit.ParticipantInfo, updatedAt time.Time) error SendSpeakerUpdate(speakers []*livekit.SpeakerInfo) error @@ -72,7 +76,6 @@ type Participant interface { Close() error // callbacks - OnStateChange(func(p Participant, oldState livekit.ParticipantInfo_State)) // OnTrackPublished - remote added a remoteTrack OnTrackPublished(func(Participant, PublishedTrack)) @@ -80,13 +83,16 @@ type Participant interface { OnTrackUpdated(callback func(Participant, PublishedTrack)) OnMetadataUpdate(callback func(Participant)) OnDataPacket(callback func(Participant, *livekit.DataPacket)) - OnClose(func(Participant)) + OnClose(func(Participant, map[string]string)) // package methods AddSubscribedTrack(st SubscribedTrack) RemoveSubscribedTrack(st SubscribedTrack) SubscriberPC() *webrtc.PeerConnection + UpdateSubscriptionPermissions(permissions *livekit.UpdateSubscriptionPermissions, resolver func(participantSid string) Participant) error + SubscriptionPermissionUpdate(publisherSid string, trackSid string, allowed bool) + DebugInfo() map[string]interface{} } @@ -95,6 +101,7 @@ type Participant interface { type Room interface { Name() string UpdateSubscriptions(participant Participant, trackIDs []string, participantTracks []*livekit.ParticipantTracks, subscribe bool) error + UpdateSubscriptionPermissions(participant Participant, permissions *livekit.UpdateSubscriptionPermissions) error } // MediaTrack represents a media track @@ -114,12 +121,13 @@ type MediaTrack interface { RemoveSubscriber(participantId string) IsSubscriber(subId string) bool RemoveAllSubscribers() + RevokeDisallowedSubscribers(allowedSubscriberIDs []string) []string + // returns quality information that's appropriate for width & height GetQualityForDimension(width, height uint32) livekit.VideoQuality NotifySubscriberMute(subscriberID string) NotifySubscriberMaxQuality(subscriberID string, quality livekit.VideoQuality) - OnSubscribedMaxQualityChange(f func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) } // PublishedTrack is the main interface representing a track published to the room diff --git a/pkg/rtc/types/typesfakes/fake_media_track.go b/pkg/rtc/types/typesfakes/fake_media_track.go index 4f41fba30..e3e314034 100644 --- a/pkg/rtc/types/typesfakes/fake_media_track.go +++ b/pkg/rtc/types/typesfakes/fake_media_track.go @@ -104,11 +104,6 @@ type FakeMediaTrack struct { notifySubscriberMuteArgsForCall []struct { arg1 string } - OnSubscribedMaxQualityChangeStub func(func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) - onSubscribedMaxQualityChangeMutex sync.RWMutex - onSubscribedMaxQualityChangeArgsForCall []struct { - arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error - } RemoveAllSubscribersStub func() removeAllSubscribersMutex sync.RWMutex removeAllSubscribersArgsForCall []struct { @@ -118,6 +113,17 @@ type FakeMediaTrack struct { removeSubscriberArgsForCall []struct { arg1 string } + RevokeDisallowedSubscribersStub func([]string) []string + revokeDisallowedSubscribersMutex sync.RWMutex + revokeDisallowedSubscribersArgsForCall []struct { + arg1 []string + } + revokeDisallowedSubscribersReturns struct { + result1 []string + } + revokeDisallowedSubscribersReturnsOnCall map[int]struct { + result1 []string + } SetMutedStub func(bool) setMutedMutex sync.RWMutex setMutedArgsForCall []struct { @@ -656,38 +662,6 @@ func (fake *FakeMediaTrack) NotifySubscriberMuteArgsForCall(i int) string { return argsForCall.arg1 } -func (fake *FakeMediaTrack) OnSubscribedMaxQualityChange(arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) { - fake.onSubscribedMaxQualityChangeMutex.Lock() - fake.onSubscribedMaxQualityChangeArgsForCall = append(fake.onSubscribedMaxQualityChangeArgsForCall, struct { - arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error - }{arg1}) - stub := fake.OnSubscribedMaxQualityChangeStub - fake.recordInvocation("OnSubscribedMaxQualityChange", []interface{}{arg1}) - fake.onSubscribedMaxQualityChangeMutex.Unlock() - if stub != nil { - fake.OnSubscribedMaxQualityChangeStub(arg1) - } -} - -func (fake *FakeMediaTrack) OnSubscribedMaxQualityChangeCallCount() int { - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() - return len(fake.onSubscribedMaxQualityChangeArgsForCall) -} - -func (fake *FakeMediaTrack) OnSubscribedMaxQualityChangeCalls(stub func(func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error)) { - fake.onSubscribedMaxQualityChangeMutex.Lock() - defer fake.onSubscribedMaxQualityChangeMutex.Unlock() - fake.OnSubscribedMaxQualityChangeStub = stub -} - -func (fake *FakeMediaTrack) OnSubscribedMaxQualityChangeArgsForCall(i int) func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error { - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() - argsForCall := fake.onSubscribedMaxQualityChangeArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakeMediaTrack) RemoveAllSubscribers() { fake.removeAllSubscribersMutex.Lock() fake.removeAllSubscribersArgsForCall = append(fake.removeAllSubscribersArgsForCall, struct { @@ -744,6 +718,72 @@ func (fake *FakeMediaTrack) RemoveSubscriberArgsForCall(i int) string { return argsForCall.arg1 } +func (fake *FakeMediaTrack) RevokeDisallowedSubscribers(arg1 []string) []string { + var arg1Copy []string + if arg1 != nil { + arg1Copy = make([]string, len(arg1)) + copy(arg1Copy, arg1) + } + fake.revokeDisallowedSubscribersMutex.Lock() + ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] + fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { + arg1 []string + }{arg1Copy}) + stub := fake.RevokeDisallowedSubscribersStub + fakeReturns := fake.revokeDisallowedSubscribersReturns + fake.recordInvocation("RevokeDisallowedSubscribers", []interface{}{arg1Copy}) + fake.revokeDisallowedSubscribersMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCallCount() int { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + return len(fake.revokeDisallowedSubscribersArgsForCall) +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]string) []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = stub +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []string { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturns(result1 []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + fake.revokeDisallowedSubscribersReturns = struct { + result1 []string + }{result1} +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + if fake.revokeDisallowedSubscribersReturnsOnCall == nil { + fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { + result1 []string + }) + } + fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { + result1 []string + }{result1} +} + func (fake *FakeMediaTrack) SetMuted(arg1 bool) { fake.setMutedMutex.Lock() fake.setMutedArgsForCall = append(fake.setMutedArgsForCall, struct { @@ -889,12 +929,12 @@ func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { defer fake.notifySubscriberMaxQualityMutex.RUnlock() fake.notifySubscriberMuteMutex.RLock() defer fake.notifySubscriberMuteMutex.RUnlock() - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() fake.removeAllSubscribersMutex.RLock() defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() fake.setMutedMutex.RLock() defer fake.setMutedMutex.RUnlock() fake.sourceMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index f19484918..8b693f3b5 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -8,7 +8,6 @@ import ( "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/protocol/livekit" - "github.com/pion/rtcp" webrtc "github.com/pion/webrtc/v3" ) @@ -30,10 +29,11 @@ type FakeParticipant struct { addSubscribedTrackArgsForCall []struct { arg1 types.SubscribedTrack } - AddSubscriberStub func(types.Participant) (int, error) + AddSubscriberStub func(types.Participant, types.AddSubscriberParams) (int, error) addSubscriberMutex sync.RWMutex addSubscriberArgsForCall []struct { arg1 types.Participant + arg2 types.AddSubscriberParams } addSubscriberReturns struct { result1 int @@ -291,10 +291,10 @@ type FakeParticipant struct { negotiateMutex sync.RWMutex negotiateArgsForCall []struct { } - OnCloseStub func(func(types.Participant)) + OnCloseStub func(func(types.Participant, map[string]string)) onCloseMutex sync.RWMutex onCloseArgsForCall []struct { - arg1 func(types.Participant) + arg1 func(types.Participant, map[string]string) } OnDataPacketStub func(func(types.Participant, *livekit.DataPacket)) onDataPacketMutex sync.RWMutex @@ -331,21 +331,17 @@ type FakeParticipant struct { protocolVersionReturnsOnCall map[int]struct { result1 types.ProtocolVersion } - RTCPChanStub func() chan []rtcp.Packet - rTCPChanMutex sync.RWMutex - rTCPChanArgsForCall []struct { - } - rTCPChanReturns struct { - result1 chan []rtcp.Packet - } - rTCPChanReturnsOnCall map[int]struct { - result1 chan []rtcp.Packet - } RemoveSubscribedTrackStub func(types.SubscribedTrack) removeSubscribedTrackMutex sync.RWMutex removeSubscribedTrackArgsForCall []struct { arg1 types.SubscribedTrack } + RemoveSubscriberStub func(types.Participant, string) + removeSubscriberMutex sync.RWMutex + removeSubscriberArgsForCall []struct { + arg1 types.Participant + arg2 string + } SendConnectionQualityUpdateStub func(*livekit.ConnectionQualityUpdate) error sendConnectionQualityUpdateMutex sync.RWMutex sendConnectionQualityUpdateArgsForCall []struct { @@ -481,6 +477,13 @@ type FakeParticipant struct { subscriberPCReturnsOnCall map[int]struct { result1 *webrtc.PeerConnection } + SubscriptionPermissionUpdateStub func(string, string, bool) + subscriptionPermissionUpdateMutex sync.RWMutex + subscriptionPermissionUpdateArgsForCall []struct { + arg1 string + arg2 string + arg3 bool + } ToProtoStub func() *livekit.ParticipantInfo toProtoMutex sync.RWMutex toProtoArgsForCall []struct { @@ -491,6 +494,18 @@ type FakeParticipant struct { toProtoReturnsOnCall map[int]struct { result1 *livekit.ParticipantInfo } + UpdateSubscriptionPermissionsStub func(*livekit.UpdateSubscriptionPermissions, func(participantSid string) types.Participant) error + updateSubscriptionPermissionsMutex sync.RWMutex + updateSubscriptionPermissionsArgsForCall []struct { + arg1 *livekit.UpdateSubscriptionPermissions + arg2 func(participantSid string) types.Participant + } + updateSubscriptionPermissionsReturns struct { + result1 error + } + updateSubscriptionPermissionsReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -589,18 +604,19 @@ func (fake *FakeParticipant) AddSubscribedTrackArgsForCall(i int) types.Subscrib return argsForCall.arg1 } -func (fake *FakeParticipant) AddSubscriber(arg1 types.Participant) (int, error) { +func (fake *FakeParticipant) AddSubscriber(arg1 types.Participant, arg2 types.AddSubscriberParams) (int, error) { fake.addSubscriberMutex.Lock() ret, specificReturn := fake.addSubscriberReturnsOnCall[len(fake.addSubscriberArgsForCall)] fake.addSubscriberArgsForCall = append(fake.addSubscriberArgsForCall, struct { arg1 types.Participant - }{arg1}) + arg2 types.AddSubscriberParams + }{arg1, arg2}) stub := fake.AddSubscriberStub fakeReturns := fake.addSubscriberReturns - fake.recordInvocation("AddSubscriber", []interface{}{arg1}) + fake.recordInvocation("AddSubscriber", []interface{}{arg1, arg2}) fake.addSubscriberMutex.Unlock() if stub != nil { - return stub(arg1) + return stub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2 @@ -614,17 +630,17 @@ func (fake *FakeParticipant) AddSubscriberCallCount() int { return len(fake.addSubscriberArgsForCall) } -func (fake *FakeParticipant) AddSubscriberCalls(stub func(types.Participant) (int, error)) { +func (fake *FakeParticipant) AddSubscriberCalls(stub func(types.Participant, types.AddSubscriberParams) (int, error)) { fake.addSubscriberMutex.Lock() defer fake.addSubscriberMutex.Unlock() fake.AddSubscriberStub = stub } -func (fake *FakeParticipant) AddSubscriberArgsForCall(i int) types.Participant { +func (fake *FakeParticipant) AddSubscriberArgsForCall(i int) (types.Participant, types.AddSubscriberParams) { fake.addSubscriberMutex.RLock() defer fake.addSubscriberMutex.RUnlock() argsForCall := fake.addSubscriberArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeParticipant) AddSubscriberReturns(result1 int, result2 error) { @@ -1974,10 +1990,10 @@ func (fake *FakeParticipant) NegotiateCalls(stub func()) { fake.NegotiateStub = stub } -func (fake *FakeParticipant) OnClose(arg1 func(types.Participant)) { +func (fake *FakeParticipant) OnClose(arg1 func(types.Participant, map[string]string)) { fake.onCloseMutex.Lock() fake.onCloseArgsForCall = append(fake.onCloseArgsForCall, struct { - arg1 func(types.Participant) + arg1 func(types.Participant, map[string]string) }{arg1}) stub := fake.OnCloseStub fake.recordInvocation("OnClose", []interface{}{arg1}) @@ -1993,13 +2009,13 @@ func (fake *FakeParticipant) OnCloseCallCount() int { return len(fake.onCloseArgsForCall) } -func (fake *FakeParticipant) OnCloseCalls(stub func(func(types.Participant))) { +func (fake *FakeParticipant) OnCloseCalls(stub func(func(types.Participant, map[string]string))) { fake.onCloseMutex.Lock() defer fake.onCloseMutex.Unlock() fake.OnCloseStub = stub } -func (fake *FakeParticipant) OnCloseArgsForCall(i int) func(types.Participant) { +func (fake *FakeParticipant) OnCloseArgsForCall(i int) func(types.Participant, map[string]string) { fake.onCloseMutex.RLock() defer fake.onCloseMutex.RUnlock() argsForCall := fake.onCloseArgsForCall[i] @@ -2219,59 +2235,6 @@ func (fake *FakeParticipant) ProtocolVersionReturnsOnCall(i int, result1 types.P }{result1} } -func (fake *FakeParticipant) RTCPChan() chan []rtcp.Packet { - fake.rTCPChanMutex.Lock() - ret, specificReturn := fake.rTCPChanReturnsOnCall[len(fake.rTCPChanArgsForCall)] - fake.rTCPChanArgsForCall = append(fake.rTCPChanArgsForCall, struct { - }{}) - stub := fake.RTCPChanStub - fakeReturns := fake.rTCPChanReturns - fake.recordInvocation("RTCPChan", []interface{}{}) - fake.rTCPChanMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeParticipant) RTCPChanCallCount() int { - fake.rTCPChanMutex.RLock() - defer fake.rTCPChanMutex.RUnlock() - return len(fake.rTCPChanArgsForCall) -} - -func (fake *FakeParticipant) RTCPChanCalls(stub func() chan []rtcp.Packet) { - fake.rTCPChanMutex.Lock() - defer fake.rTCPChanMutex.Unlock() - fake.RTCPChanStub = stub -} - -func (fake *FakeParticipant) RTCPChanReturns(result1 chan []rtcp.Packet) { - fake.rTCPChanMutex.Lock() - defer fake.rTCPChanMutex.Unlock() - fake.RTCPChanStub = nil - fake.rTCPChanReturns = struct { - result1 chan []rtcp.Packet - }{result1} -} - -func (fake *FakeParticipant) RTCPChanReturnsOnCall(i int, result1 chan []rtcp.Packet) { - fake.rTCPChanMutex.Lock() - defer fake.rTCPChanMutex.Unlock() - fake.RTCPChanStub = nil - if fake.rTCPChanReturnsOnCall == nil { - fake.rTCPChanReturnsOnCall = make(map[int]struct { - result1 chan []rtcp.Packet - }) - } - fake.rTCPChanReturnsOnCall[i] = struct { - result1 chan []rtcp.Packet - }{result1} -} - func (fake *FakeParticipant) RemoveSubscribedTrack(arg1 types.SubscribedTrack) { fake.removeSubscribedTrackMutex.Lock() fake.removeSubscribedTrackArgsForCall = append(fake.removeSubscribedTrackArgsForCall, struct { @@ -2304,6 +2267,39 @@ func (fake *FakeParticipant) RemoveSubscribedTrackArgsForCall(i int) types.Subsc return argsForCall.arg1 } +func (fake *FakeParticipant) RemoveSubscriber(arg1 types.Participant, arg2 string) { + fake.removeSubscriberMutex.Lock() + fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { + arg1 types.Participant + arg2 string + }{arg1, arg2}) + stub := fake.RemoveSubscriberStub + fake.recordInvocation("RemoveSubscriber", []interface{}{arg1, arg2}) + fake.removeSubscriberMutex.Unlock() + if stub != nil { + fake.RemoveSubscriberStub(arg1, arg2) + } +} + +func (fake *FakeParticipant) RemoveSubscriberCallCount() int { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + return len(fake.removeSubscriberArgsForCall) +} + +func (fake *FakeParticipant) RemoveSubscriberCalls(stub func(types.Participant, string)) { + fake.removeSubscriberMutex.Lock() + defer fake.removeSubscriberMutex.Unlock() + fake.RemoveSubscriberStub = stub +} + +func (fake *FakeParticipant) RemoveSubscriberArgsForCall(i int) (types.Participant, string) { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + argsForCall := fake.removeSubscriberArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeParticipant) SendConnectionQualityUpdate(arg1 *livekit.ConnectionQualityUpdate) error { fake.sendConnectionQualityUpdateMutex.Lock() ret, specificReturn := fake.sendConnectionQualityUpdateReturnsOnCall[len(fake.sendConnectionQualityUpdateArgsForCall)] @@ -3059,6 +3055,40 @@ func (fake *FakeParticipant) SubscriberPCReturnsOnCall(i int, result1 *webrtc.Pe }{result1} } +func (fake *FakeParticipant) SubscriptionPermissionUpdate(arg1 string, arg2 string, arg3 bool) { + fake.subscriptionPermissionUpdateMutex.Lock() + fake.subscriptionPermissionUpdateArgsForCall = append(fake.subscriptionPermissionUpdateArgsForCall, struct { + arg1 string + arg2 string + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.SubscriptionPermissionUpdateStub + fake.recordInvocation("SubscriptionPermissionUpdate", []interface{}{arg1, arg2, arg3}) + fake.subscriptionPermissionUpdateMutex.Unlock() + if stub != nil { + fake.SubscriptionPermissionUpdateStub(arg1, arg2, arg3) + } +} + +func (fake *FakeParticipant) SubscriptionPermissionUpdateCallCount() int { + fake.subscriptionPermissionUpdateMutex.RLock() + defer fake.subscriptionPermissionUpdateMutex.RUnlock() + return len(fake.subscriptionPermissionUpdateArgsForCall) +} + +func (fake *FakeParticipant) SubscriptionPermissionUpdateCalls(stub func(string, string, bool)) { + fake.subscriptionPermissionUpdateMutex.Lock() + defer fake.subscriptionPermissionUpdateMutex.Unlock() + fake.SubscriptionPermissionUpdateStub = stub +} + +func (fake *FakeParticipant) SubscriptionPermissionUpdateArgsForCall(i int) (string, string, bool) { + fake.subscriptionPermissionUpdateMutex.RLock() + defer fake.subscriptionPermissionUpdateMutex.RUnlock() + argsForCall := fake.subscriptionPermissionUpdateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + func (fake *FakeParticipant) ToProto() *livekit.ParticipantInfo { fake.toProtoMutex.Lock() ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] @@ -3112,6 +3142,68 @@ func (fake *FakeParticipant) ToProtoReturnsOnCall(i int, result1 *livekit.Partic }{result1} } +func (fake *FakeParticipant) UpdateSubscriptionPermissions(arg1 *livekit.UpdateSubscriptionPermissions, arg2 func(participantSid string) types.Participant) error { + fake.updateSubscriptionPermissionsMutex.Lock() + ret, specificReturn := fake.updateSubscriptionPermissionsReturnsOnCall[len(fake.updateSubscriptionPermissionsArgsForCall)] + fake.updateSubscriptionPermissionsArgsForCall = append(fake.updateSubscriptionPermissionsArgsForCall, struct { + arg1 *livekit.UpdateSubscriptionPermissions + arg2 func(participantSid string) types.Participant + }{arg1, arg2}) + stub := fake.UpdateSubscriptionPermissionsStub + fakeReturns := fake.updateSubscriptionPermissionsReturns + fake.recordInvocation("UpdateSubscriptionPermissions", []interface{}{arg1, arg2}) + fake.updateSubscriptionPermissionsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionsCallCount() int { + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() + return len(fake.updateSubscriptionPermissionsArgsForCall) +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionsCalls(stub func(*livekit.UpdateSubscriptionPermissions, func(participantSid string) types.Participant) error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = stub +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionsArgsForCall(i int) (*livekit.UpdateSubscriptionPermissions, func(participantSid string) types.Participant) { + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() + argsForCall := fake.updateSubscriptionPermissionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionsReturns(result1 error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = nil + fake.updateSubscriptionPermissionsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionsReturnsOnCall(i int, result1 error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = nil + if fake.updateSubscriptionPermissionsReturnsOnCall == nil { + fake.updateSubscriptionPermissionsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscriptionPermissionsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeParticipant) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -3185,10 +3277,10 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.onTrackUpdatedMutex.RUnlock() fake.protocolVersionMutex.RLock() defer fake.protocolVersionMutex.RUnlock() - fake.rTCPChanMutex.RLock() - defer fake.rTCPChanMutex.RUnlock() fake.removeSubscribedTrackMutex.RLock() defer fake.removeSubscribedTrackMutex.RUnlock() + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() fake.sendConnectionQualityUpdateMutex.RLock() defer fake.sendConnectionQualityUpdateMutex.RUnlock() fake.sendDataPacketMutex.RLock() @@ -3219,8 +3311,12 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.subscriberMediaEngineMutex.RUnlock() fake.subscriberPCMutex.RLock() defer fake.subscriberPCMutex.RUnlock() + fake.subscriptionPermissionUpdateMutex.RLock() + defer fake.subscriptionPermissionUpdateMutex.RUnlock() fake.toProtoMutex.RLock() defer fake.toProtoMutex.RUnlock() + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/pkg/rtc/types/typesfakes/fake_published_track.go b/pkg/rtc/types/typesfakes/fake_published_track.go index e8c58b7fa..63f5cf1f6 100644 --- a/pkg/rtc/types/typesfakes/fake_published_track.go +++ b/pkg/rtc/types/typesfakes/fake_published_track.go @@ -132,11 +132,6 @@ type FakePublishedTrack struct { result1 uint32 result2 uint32 } - OnSubscribedMaxQualityChangeStub func(func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) - onSubscribedMaxQualityChangeMutex sync.RWMutex - onSubscribedMaxQualityChangeArgsForCall []struct { - arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error - } PublishLossPercentageStub func() uint32 publishLossPercentageMutex sync.RWMutex publishLossPercentageArgsForCall []struct { @@ -166,6 +161,17 @@ type FakePublishedTrack struct { removeSubscriberArgsForCall []struct { arg1 string } + RevokeDisallowedSubscribersStub func([]string) []string + revokeDisallowedSubscribersMutex sync.RWMutex + revokeDisallowedSubscribersArgsForCall []struct { + arg1 []string + } + revokeDisallowedSubscribersReturns struct { + result1 []string + } + revokeDisallowedSubscribersReturnsOnCall map[int]struct { + result1 []string + } SdpCidStub func() string sdpCidMutex sync.RWMutex sdpCidArgsForCall []struct { @@ -875,38 +881,6 @@ func (fake *FakePublishedTrack) NumUpTracksReturnsOnCall(i int, result1 uint32, }{result1, result2} } -func (fake *FakePublishedTrack) OnSubscribedMaxQualityChange(arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) { - fake.onSubscribedMaxQualityChangeMutex.Lock() - fake.onSubscribedMaxQualityChangeArgsForCall = append(fake.onSubscribedMaxQualityChangeArgsForCall, struct { - arg1 func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error - }{arg1}) - stub := fake.OnSubscribedMaxQualityChangeStub - fake.recordInvocation("OnSubscribedMaxQualityChange", []interface{}{arg1}) - fake.onSubscribedMaxQualityChangeMutex.Unlock() - if stub != nil { - fake.OnSubscribedMaxQualityChangeStub(arg1) - } -} - -func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeCallCount() int { - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() - return len(fake.onSubscribedMaxQualityChangeArgsForCall) -} - -func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeCalls(stub func(func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error)) { - fake.onSubscribedMaxQualityChangeMutex.Lock() - defer fake.onSubscribedMaxQualityChangeMutex.Unlock() - fake.OnSubscribedMaxQualityChangeStub = stub -} - -func (fake *FakePublishedTrack) OnSubscribedMaxQualityChangeArgsForCall(i int) func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error { - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() - argsForCall := fake.onSubscribedMaxQualityChangeArgsForCall[i] - return argsForCall.arg1 -} - func (fake *FakePublishedTrack) PublishLossPercentage() uint32 { fake.publishLossPercentageMutex.Lock() ret, specificReturn := fake.publishLossPercentageReturnsOnCall[len(fake.publishLossPercentageArgsForCall)] @@ -1069,6 +1043,72 @@ func (fake *FakePublishedTrack) RemoveSubscriberArgsForCall(i int) string { return argsForCall.arg1 } +func (fake *FakePublishedTrack) RevokeDisallowedSubscribers(arg1 []string) []string { + var arg1Copy []string + if arg1 != nil { + arg1Copy = make([]string, len(arg1)) + copy(arg1Copy, arg1) + } + fake.revokeDisallowedSubscribersMutex.Lock() + ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] + fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { + arg1 []string + }{arg1Copy}) + stub := fake.RevokeDisallowedSubscribersStub + fakeReturns := fake.revokeDisallowedSubscribersReturns + fake.recordInvocation("RevokeDisallowedSubscribers", []interface{}{arg1Copy}) + fake.revokeDisallowedSubscribersMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePublishedTrack) RevokeDisallowedSubscribersCallCount() int { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + return len(fake.revokeDisallowedSubscribersArgsForCall) +} + +func (fake *FakePublishedTrack) RevokeDisallowedSubscribersCalls(stub func([]string) []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = stub +} + +func (fake *FakePublishedTrack) RevokeDisallowedSubscribersArgsForCall(i int) []string { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakePublishedTrack) RevokeDisallowedSubscribersReturns(result1 []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + fake.revokeDisallowedSubscribersReturns = struct { + result1 []string + }{result1} +} + +func (fake *FakePublishedTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []string) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + if fake.revokeDisallowedSubscribersReturnsOnCall == nil { + fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { + result1 []string + }) + } + fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { + result1 []string + }{result1} +} + func (fake *FakePublishedTrack) SdpCid() string { fake.sdpCidMutex.Lock() ret, specificReturn := fake.sdpCidReturnsOnCall[len(fake.sdpCidArgsForCall)] @@ -1379,8 +1419,6 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.notifySubscriberMuteMutex.RUnlock() fake.numUpTracksMutex.RLock() defer fake.numUpTracksMutex.RUnlock() - fake.onSubscribedMaxQualityChangeMutex.RLock() - defer fake.onSubscribedMaxQualityChangeMutex.RUnlock() fake.publishLossPercentageMutex.RLock() defer fake.publishLossPercentageMutex.RUnlock() fake.receiverMutex.RLock() @@ -1389,6 +1427,8 @@ func (fake *FakePublishedTrack) Invocations() map[string][][]interface{} { defer fake.removeAllSubscribersMutex.RUnlock() fake.removeSubscriberMutex.RLock() defer fake.removeSubscriberMutex.RUnlock() + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() fake.sdpCidMutex.RLock() defer fake.sdpCidMutex.RUnlock() fake.setMutedMutex.RLock() diff --git a/pkg/rtc/types/typesfakes/fake_room.go b/pkg/rtc/types/typesfakes/fake_room.go index 2e6759190..958e6accb 100644 --- a/pkg/rtc/types/typesfakes/fake_room.go +++ b/pkg/rtc/types/typesfakes/fake_room.go @@ -19,6 +19,18 @@ type FakeRoom struct { nameReturnsOnCall map[int]struct { result1 string } + UpdateSubscriptionPermissionsStub func(types.Participant, *livekit.UpdateSubscriptionPermissions) error + updateSubscriptionPermissionsMutex sync.RWMutex + updateSubscriptionPermissionsArgsForCall []struct { + arg1 types.Participant + arg2 *livekit.UpdateSubscriptionPermissions + } + updateSubscriptionPermissionsReturns struct { + result1 error + } + updateSubscriptionPermissionsReturnsOnCall map[int]struct { + result1 error + } UpdateSubscriptionsStub func(types.Participant, []string, []*livekit.ParticipantTracks, bool) error updateSubscriptionsMutex sync.RWMutex updateSubscriptionsArgsForCall []struct { @@ -90,6 +102,68 @@ func (fake *FakeRoom) NameReturnsOnCall(i int, result1 string) { }{result1} } +func (fake *FakeRoom) UpdateSubscriptionPermissions(arg1 types.Participant, arg2 *livekit.UpdateSubscriptionPermissions) error { + fake.updateSubscriptionPermissionsMutex.Lock() + ret, specificReturn := fake.updateSubscriptionPermissionsReturnsOnCall[len(fake.updateSubscriptionPermissionsArgsForCall)] + fake.updateSubscriptionPermissionsArgsForCall = append(fake.updateSubscriptionPermissionsArgsForCall, struct { + arg1 types.Participant + arg2 *livekit.UpdateSubscriptionPermissions + }{arg1, arg2}) + stub := fake.UpdateSubscriptionPermissionsStub + fakeReturns := fake.updateSubscriptionPermissionsReturns + fake.recordInvocation("UpdateSubscriptionPermissions", []interface{}{arg1, arg2}) + fake.updateSubscriptionPermissionsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) UpdateSubscriptionPermissionsCallCount() int { + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() + return len(fake.updateSubscriptionPermissionsArgsForCall) +} + +func (fake *FakeRoom) UpdateSubscriptionPermissionsCalls(stub func(types.Participant, *livekit.UpdateSubscriptionPermissions) error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = stub +} + +func (fake *FakeRoom) UpdateSubscriptionPermissionsArgsForCall(i int) (types.Participant, *livekit.UpdateSubscriptionPermissions) { + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() + argsForCall := fake.updateSubscriptionPermissionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoom) UpdateSubscriptionPermissionsReturns(result1 error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = nil + fake.updateSubscriptionPermissionsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoom) UpdateSubscriptionPermissionsReturnsOnCall(i int, result1 error) { + fake.updateSubscriptionPermissionsMutex.Lock() + defer fake.updateSubscriptionPermissionsMutex.Unlock() + fake.UpdateSubscriptionPermissionsStub = nil + if fake.updateSubscriptionPermissionsReturnsOnCall == nil { + fake.updateSubscriptionPermissionsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscriptionPermissionsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 []*livekit.ParticipantTracks, arg4 bool) error { var arg2Copy []string if arg2 != nil { @@ -169,6 +243,8 @@ func (fake *FakeRoom) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.nameMutex.RLock() defer fake.nameMutex.RUnlock() + fake.updateSubscriptionPermissionsMutex.RLock() + defer fake.updateSubscriptionPermissionsMutex.RUnlock() fake.updateSubscriptionsMutex.RLock() defer fake.updateSubscriptionsMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/rtc/uptrackmanager.go b/pkg/rtc/uptrackmanager.go new file mode 100644 index 000000000..9d1719c38 --- /dev/null +++ b/pkg/rtc/uptrackmanager.go @@ -0,0 +1,670 @@ +package rtc + +import ( + "sync" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/twcc" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +type UptrackManagerParams struct { + Identity string + SID string + Config *WebRTCConfig + AudioConfig config.AudioConfig + Telemetry telemetry.TelemetryService + ThrottleConfig config.PLIThrottleConfig + Logger logger.Logger +} + +type UptrackManager struct { + params UptrackManagerParams + rtcpCh chan []rtcp.Packet + pliThrottle *pliThrottle + + // hold reference for MediaTrack + twcc *twcc.Responder + + // publishedTracks that participant is publishing + publishedTracks map[string]types.PublishedTrack + // client intended to publish, yet to be reconciled + pendingTracks map[string]*livekit.TrackInfo + // keeps track of subscriptions that are awaiting permissions + subscriptionPermissions map[string]*livekit.TrackPermission // subscriberID => *livekit.TrackPermission + // keeps tracks of track specific subscribers who are awaiting permission + pendingSubscriptions map[string][]string // trackSid => []subscriberID + + lock sync.RWMutex + + // callbacks & handlers + onTrackPublished func(track types.PublishedTrack) + onTrackUpdated func(track types.PublishedTrack, onlyIfReady bool) + onWriteRTCP func(pkts []rtcp.Packet) + onSubscribedMaxQualityChange func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error +} + +func NewUptrackManager(params UptrackManagerParams) *UptrackManager { + return &UptrackManager{ + params: params, + rtcpCh: make(chan []rtcp.Packet, 50), + pliThrottle: newPLIThrottle(params.ThrottleConfig), + publishedTracks: make(map[string]types.PublishedTrack, 0), + pendingTracks: make(map[string]*livekit.TrackInfo), + pendingSubscriptions: make(map[string][]string), + } +} + +func (u *UptrackManager) Start() { + go u.rtcpSendWorker() +} + +func (u *UptrackManager) Close() { + u.lock.Lock() + defer u.lock.Unlock() + + // remove all subscribers + for _, t := range u.publishedTracks { + // skip updates + t.RemoveAllSubscribers() + } + + close(u.rtcpCh) +} + +func (u *UptrackManager) ToProto() []*livekit.TrackInfo { + u.lock.RLock() + defer u.lock.RUnlock() + + var trackInfos []*livekit.TrackInfo + for _, t := range u.publishedTracks { + trackInfos = append(trackInfos, t.ToProto()) + } + + return trackInfos +} + +func (u *UptrackManager) OnTrackPublished(f func(track types.PublishedTrack)) { + u.onTrackPublished = f +} + +func (u *UptrackManager) OnTrackUpdated(f func(track types.PublishedTrack, onlyIfReady bool)) { + u.onTrackUpdated = f +} + +func (u *UptrackManager) OnWriteRTCP(f func(pkts []rtcp.Packet)) { + u.onWriteRTCP = f +} + +func (u *UptrackManager) OnSubscribedMaxQualityChange(f func(trackSid string, subscribedQualities []*livekit.SubscribedQuality) error) { + u.onSubscribedMaxQualityChange = f +} + +// AddTrack is called when client intends to publish track. +// records track details and lets client know it's ok to proceed +func (u *UptrackManager) AddTrack(req *livekit.AddTrackRequest) *livekit.TrackInfo { + u.lock.Lock() + defer u.lock.Unlock() + + // if track is already published, reject + if u.pendingTracks[req.Cid] != nil { + return nil + } + + if u.getPublishedTrackBySignalCid(req.Cid) != nil || u.getPublishedTrackBySdpCid(req.Cid) != nil { + return nil + } + + ti := &livekit.TrackInfo{ + Type: req.Type, + Name: req.Name, + Sid: utils.NewGuid(utils.TrackPrefix), + Width: req.Width, + Height: req.Height, + Muted: req.Muted, + DisableDtx: req.DisableDtx, + Source: req.Source, + Layers: req.Layers, + } + u.pendingTracks[req.Cid] = ti + + return ti +} + +// AddSubscriber subscribes op to all publishedTracks +func (u *UptrackManager) AddSubscriber(sub types.Participant, params types.AddSubscriberParams) (int, error) { + var tracks []types.PublishedTrack + if params.AllTracks { + tracks = u.GetPublishedTracks() + } else { + for _, trackSid := range params.TrackSids { + track := u.getPublishedTrack(trackSid) + if track == nil { + continue + } + + tracks = append(tracks, track) + } + } + if len(tracks) == 0 { + return 0, nil + } + + u.params.Logger.Debugw("subscribing new participant to tracks", + "subscriber", sub.Identity(), + "subscriberID", sub.ID(), + "numTracks", len(tracks)) + + n := 0 + for _, track := range tracks { + trackSid := track.ID() + subscriberID := sub.ID() + if !u.hasPermission(trackSid, subscriberID) { + u.maybeAddPendingSubscription(trackSid, sub) + continue + } + + if err := track.AddSubscriber(sub); err != nil { + return n, err + } + n += 1 + } + return n, nil +} + +func (u *UptrackManager) RemoveSubscriber(sub types.Participant, trackSid string) { + u.lock.Lock() + defer u.lock.Unlock() + + track := u.getPublishedTrack(trackSid) + if track != nil { + track.RemoveSubscriber(sub.ID()) + } + + u.maybeRemovePendingSubscription(trackSid, sub) +} + +func (u *UptrackManager) SetTrackMuted(trackSid string, muted bool) { + isPending := false + u.lock.RLock() + for _, ti := range u.pendingTracks { + if ti.Sid == trackSid { + ti.Muted = muted + isPending = true + } + } + track := u.publishedTracks[trackSid] + u.lock.RUnlock() + + if track == nil { + if !isPending { + u.params.Logger.Warnw("could not locate track", nil, "track", trackSid) + } + return + } + currentMuted := track.IsMuted() + track.SetMuted(muted) + + if currentMuted != track.IsMuted() && u.onTrackUpdated != nil { + u.params.Logger.Debugw("mute status changed", + "track", trackSid, + "muted", track.IsMuted()) + u.onTrackUpdated(track, false) + } +} + +func (u *UptrackManager) GetAudioLevel() (level uint8, active bool) { + u.lock.RLock() + defer u.lock.RUnlock() + + level = silentAudioLevel + for _, pt := range u.publishedTracks { + if mt, ok := pt.(*MediaTrack); ok { + if mt.audioLevel == nil { + continue + } + tl, ta := mt.audioLevel.GetLevel() + if ta { + active = true + if tl < level { + level = tl + } + } + } + } + return +} + +func (u *UptrackManager) GetConnectionQuality() (scores float64, numTracks int) { + u.lock.RLock() + defer u.lock.RUnlock() + + for _, pt := range u.publishedTracks { + if pt.IsMuted() { + continue + } + scores += pt.GetConnectionScore() + numTracks++ + } + return +} + +func (u *UptrackManager) GetPublishedTrack(sid string) types.PublishedTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.getPublishedTrack(sid) +} + +func (u *UptrackManager) GetPublishedTracks() []types.PublishedTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + tracks := make([]types.PublishedTrack, 0, len(u.publishedTracks)) + for _, t := range u.publishedTracks { + tracks = append(tracks, t) + } + return tracks +} + +func (u *UptrackManager) GetDTX() bool { + u.lock.RLock() + defer u.lock.RUnlock() + + var trackInfo *livekit.TrackInfo + for _, ti := range u.pendingTracks { + if ti.Type == livekit.TrackType_AUDIO { + trackInfo = ti + break + } + } + + if trackInfo == nil { + return false + } + + return !trackInfo.DisableDtx +} + +func (u *UptrackManager) UpdateSubscriptionPermissions( + permissions *livekit.UpdateSubscriptionPermissions, + resolver func(participantSid string) types.Participant, +) error { + u.lock.Lock() + defer u.lock.Unlock() + + u.updateSubscriptionPermissions(permissions) + + u.processPendingSubscriptions(resolver) + + u.maybeRevokeSubscriptions(resolver) + + return nil +} + +// when a new remoteTrack is created, creates a Track and adds it to room +func (u *UptrackManager) MediaTrackReceived(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + var newTrack bool + + // use existing mediatrack to handle simulcast + u.lock.Lock() + mt, ok := u.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) + if !ok { + signalCid, ti := u.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind())) + if ti == nil { + u.lock.Unlock() + return + } + + mt = NewMediaTrack(track, MediaTrackParams{ + TrackInfo: ti, + SignalCid: signalCid, + SdpCid: track.ID(), + ParticipantID: u.params.SID, + ParticipantIdentity: u.params.Identity, + RTCPChan: u.rtcpCh, + BufferFactory: u.params.Config.BufferFactory, + ReceiverConfig: u.params.Config.Receiver, + AudioConfig: u.params.AudioConfig, + Telemetry: u.params.Telemetry, + Logger: u.params.Logger, + }) + mt.OnSubscribedMaxQualityChange(u.onSubscribedMaxQualityChange) + + // add to published and clean up pending + u.publishedTracks[mt.ID()] = mt + delete(u.pendingTracks, signalCid) + + newTrack = true + } + + ssrc := uint32(track.SSRC()) + u.pliThrottle.addTrack(ssrc, track.RID()) + if u.twcc == nil { + u.twcc = twcc.NewTransportWideCCResponder(ssrc) + u.twcc.OnFeedback(func(pkt rtcp.RawPacket) { + if u.onWriteRTCP != nil { + u.onWriteRTCP([]rtcp.Packet{&pkt}) + } + }) + } + u.lock.Unlock() + + mt.AddReceiver(rtpReceiver, track, u.twcc) + + if newTrack { + u.handleTrackPublished(mt) + } +} + +// should be called with lock held +func (u *UptrackManager) getPublishedTrack(sid string) types.PublishedTrack { + return u.publishedTracks[sid] +} + +// should be called with lock held +func (u *UptrackManager) getPublishedTrackBySignalCid(clientId string) types.PublishedTrack { + for _, publishedTrack := range u.publishedTracks { + if publishedTrack.SignalCid() == clientId { + return publishedTrack + } + } + + return nil +} + +// should be called with lock held +func (u *UptrackManager) getPublishedTrackBySdpCid(clientId string) types.PublishedTrack { + for _, publishedTrack := range u.publishedTracks { + if publishedTrack.SdpCid() == clientId { + return publishedTrack + } + } + + return nil +} + +// should be called with lock held +func (u *UptrackManager) getPendingTrack(clientId string, kind livekit.TrackType) (string, *livekit.TrackInfo) { + signalCid := clientId + trackInfo := u.pendingTracks[clientId] + + if trackInfo == nil { + // + // If no match on client id, find first one matching type + // as MediaStreamTrack can change client id when transceiver + // is added to peer connection. + // + for cid, ti := range u.pendingTracks { + if ti.Type == kind { + trackInfo = ti + signalCid = cid + break + } + } + } + + // if still not found, we are done + if trackInfo == nil { + u.params.Logger.Errorw("track info not published prior to track", nil, "clientId", clientId) + } + return signalCid, trackInfo +} + +func (u *UptrackManager) handleTrackPublished(track types.PublishedTrack) { + u.lock.Lock() + if _, ok := u.publishedTracks[track.ID()]; !ok { + u.publishedTracks[track.ID()] = track + } + u.lock.Unlock() + + track.AddOnClose(func() { + // cleanup + u.lock.Lock() + trackSid := track.ID() + delete(u.publishedTracks, trackSid) + delete(u.pendingSubscriptions, trackSid) + // not modifying subscription permissions, will get reset on next update from participant + u.lock.Unlock() + // only send this when client is in a ready state + if u.onTrackUpdated != nil { + u.onTrackUpdated(track, true) + } + }) + + if u.onTrackPublished != nil { + u.onTrackPublished(track) + } +} + +func (u *UptrackManager) updateSubscriptionPermissions(permissions *livekit.UpdateSubscriptionPermissions) { + // every update overrides the existing + + // all_participants takes precedence + if permissions.AllParticipants { + // everything is allowed, nothing else to do + u.subscriptionPermissions = nil + return + } + + // per participant permissions + u.subscriptionPermissions = make(map[string]*livekit.TrackPermission) + for _, trackPerms := range permissions.TrackPermissions { + u.subscriptionPermissions[trackPerms.ParticipantSid] = trackPerms + } +} + +func (u *UptrackManager) hasPermission(trackSid string, subscriberID string) bool { + if u.subscriptionPermissions == nil { + return true + } + + perms, ok := u.subscriptionPermissions[subscriberID] + if !ok { + return false + } + + if perms.AllTracks { + return true + } + + for _, sid := range perms.TrackSids { + if sid == trackSid { + return true + } + } + + return false +} + +func (u *UptrackManager) getAllowedSubscribers(trackSid string) []string { + if u.subscriptionPermissions == nil { + return nil + } + + allowed := []string{} + for subscriberID, perms := range u.subscriptionPermissions { + if perms.AllTracks { + allowed = append(allowed, subscriberID) + continue + } + + for _, sid := range perms.TrackSids { + if sid == trackSid { + allowed = append(allowed, subscriberID) + break + } + } + } + + return allowed +} + +func (u *UptrackManager) maybeAddPendingSubscription(trackSid string, sub types.Participant) { + subscriberID := sub.ID() + + pending := u.pendingSubscriptions[trackSid] + for _, sid := range pending { + if sid == subscriberID { + // already pending + return + } + } + + u.pendingSubscriptions[trackSid] = append(u.pendingSubscriptions[trackSid], subscriberID) + go sub.SubscriptionPermissionUpdate(u.params.SID, trackSid, false) +} + +func (u *UptrackManager) maybeRemovePendingSubscription(trackSid string, sub types.Participant) { + subscriberID := sub.ID() + + pending := u.pendingSubscriptions[trackSid] + n := len(pending) + for idx, sid := range pending { + if sid == subscriberID { + u.pendingSubscriptions[trackSid][idx] = u.pendingSubscriptions[trackSid][n-1] + u.pendingSubscriptions[trackSid] = u.pendingSubscriptions[trackSid][:n-1] + break + } + } + if len(u.pendingSubscriptions[trackSid]) == 0 { + delete(u.pendingSubscriptions, trackSid) + } +} + +func (u *UptrackManager) processPendingSubscriptions(resolver func(participantSid string) types.Participant) { + updatedPendingSubscriptions := make(map[string][]string) + for trackSid, pending := range u.pendingSubscriptions { + track := u.getPublishedTrack(trackSid) + if track == nil { + continue + } + + var updatedPending []string + for _, sid := range pending { + var sub types.Participant + if resolver != nil { + sub = resolver(sid) + } + if sub == nil || sub.State() == livekit.ParticipantInfo_DISCONNECTED { + // do not keep this pending subscription as subscriber may be gone + continue + } + + if !u.hasPermission(trackSid, sid) { + updatedPending = append(updatedPending, sid) + continue + } + + if err := track.AddSubscriber(sub); err != nil { + u.params.Logger.Errorw("error reinstating pending subscription", err) + // keep it in pending on error in case the error is transient + updatedPending = append(updatedPending, sid) + continue + } + + go sub.SubscriptionPermissionUpdate(u.params.SID, trackSid, true) + } + + updatedPendingSubscriptions[trackSid] = updatedPending + } + + u.pendingSubscriptions = updatedPendingSubscriptions +} + +func (u *UptrackManager) maybeRevokeSubscriptions(resolver func(participantSid string) types.Participant) { + for _, track := range u.publishedTracks { + trackSid := track.ID() + allowed := u.getAllowedSubscribers(trackSid) + if allowed == nil { + // no restrictions + continue + } + + revokedSubscribers := track.RevokeDisallowedSubscribers(allowed) + for _, subID := range revokedSubscribers { + var sub types.Participant + if resolver != nil { + sub = resolver(subID) + } + if sub == nil { + continue + } + + u.maybeAddPendingSubscription(trackSid, sub) + } + } +} + +func (u *UptrackManager) rtcpSendWorker() { + defer Recover() + + // read from rtcpChan + for pkts := range u.rtcpCh { + if pkts == nil { + return + } + + fwdPkts := make([]rtcp.Packet, 0, len(pkts)) + for _, pkt := range pkts { + switch pkt.(type) { + case *rtcp.PictureLossIndication: + mediaSSRC := pkt.(*rtcp.PictureLossIndication).MediaSSRC + if u.pliThrottle.canSend(mediaSSRC) { + fwdPkts = append(fwdPkts, pkt) + } + case *rtcp.FullIntraRequest: + mediaSSRC := pkt.(*rtcp.FullIntraRequest).MediaSSRC + if u.pliThrottle.canSend(mediaSSRC) { + fwdPkts = append(fwdPkts, pkt) + } + default: + fwdPkts = append(fwdPkts, pkt) + } + } + + if len(fwdPkts) > 0 && u.onWriteRTCP != nil { + u.onWriteRTCP(fwdPkts) + } + } +} + +func (u *UptrackManager) DebugInfo() map[string]interface{} { + info := map[string]interface{}{} + publishedTrackInfo := make(map[string]interface{}) + pendingTrackInfo := make(map[string]interface{}) + + u.lock.RLock() + for trackSid, track := range u.publishedTracks { + if mt, ok := track.(*MediaTrack); ok { + publishedTrackInfo[trackSid] = mt.DebugInfo() + } else { + publishedTrackInfo[trackSid] = map[string]interface{}{ + "ID": track.ID(), + "Kind": track.Kind().String(), + "PubMuted": track.IsMuted(), + } + } + } + + for clientID, ti := range u.pendingTracks { + pendingTrackInfo[clientID] = map[string]interface{}{ + "Sid": ti.Sid, + "Type": ti.Type.String(), + "Simulcast": ti.Simulcast, + } + } + u.lock.RUnlock() + + info["PublishedTracks"] = publishedTrackInfo + info["PendingTracks"] = pendingTrackInfo + + return info +} diff --git a/pkg/rtc/uptrackmanager_test.go b/pkg/rtc/uptrackmanager_test.go new file mode 100644 index 000000000..8d6601842 --- /dev/null +++ b/pkg/rtc/uptrackmanager_test.go @@ -0,0 +1,195 @@ +package rtc + +import ( + "testing" + + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" + "github.com/livekit/protocol/livekit" + "github.com/stretchr/testify/require" +) + +func TestUpdateSubscriptionPermissions(t *testing.T) { + t.Run("updates permissions", func(t *testing.T) { + um := NewUptrackManager(UptrackManagerParams{}) + + tra := &typesfakes.FakePublishedTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakePublishedTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + // no restrictive permissions + permissions := &livekit.UpdateSubscriptionPermissions{ + AllParticipants: true, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.Nil(t, um.subscriptionPermissions) + + // nobody is allowed to subscribe + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{}, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.NotNil(t, um.subscriptionPermissions) + require.Equal(t, 0, len(um.subscriptionPermissions)) + + // allow all tracks for participants + perms1 := &livekit.TrackPermission{ + ParticipantSid: "p1", + AllTracks: true, + } + perms2 := &livekit.TrackPermission{ + ParticipantSid: "p2", + AllTracks: true, + } + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + }, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.Equal(t, 2, len(um.subscriptionPermissions)) + require.EqualValues(t, perms1, um.subscriptionPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriptionPermissions["p2"]) + + // allow all tracks for some and restrictive for others + perms1 = &livekit.TrackPermission{ + ParticipantSid: "p1", + AllTracks: true, + } + perms2 = &livekit.TrackPermission{ + ParticipantSid: "p2", + TrackSids: []string{"audio"}, + } + perms3 := &livekit.TrackPermission{ + ParticipantSid: "p3", + TrackSids: []string{"video"}, + } + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + perms3, + }, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.Equal(t, 3, len(um.subscriptionPermissions)) + require.EqualValues(t, perms1, um.subscriptionPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriptionPermissions["p2"]) + require.EqualValues(t, perms3, um.subscriptionPermissions["p3"]) + }) +} + +func TestPermissions(t *testing.T) { + t.Run("checks permissions", func(t *testing.T) { + um := NewUptrackManager(UptrackManagerParams{}) + + tra := &typesfakes.FakePublishedTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakePublishedTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + // no restrictive permissions + permissions := &livekit.UpdateSubscriptionPermissions{ + AllParticipants: true, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.True(t, um.hasPermission("audio", "p1")) + require.True(t, um.hasPermission("audio", "p2")) + + // nobody is allowed to subscribe + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{}, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.False(t, um.hasPermission("audio", "p1")) + require.False(t, um.hasPermission("audio", "p2")) + + // allow all tracks for participants + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{ + { + ParticipantSid: "p1", + AllTracks: true, + }, + { + ParticipantSid: "p2", + AllTracks: true, + }, + }, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.True(t, um.hasPermission("audio", "p1")) + require.True(t, um.hasPermission("video", "p1")) + require.True(t, um.hasPermission("audio", "p2")) + require.True(t, um.hasPermission("video", "p2")) + + // add a new track after permissions are set + trs := &typesfakes.FakePublishedTrack{} + trs.IDReturns("screen") + um.publishedTracks["screen"] = trs + + require.True(t, um.hasPermission("audio", "p1")) + require.True(t, um.hasPermission("video", "p1")) + require.True(t, um.hasPermission("screen", "p1")) + require.True(t, um.hasPermission("audio", "p2")) + require.True(t, um.hasPermission("video", "p2")) + require.True(t, um.hasPermission("screen", "p2")) + + // allow all tracks for some and restrictive for others + permissions = &livekit.UpdateSubscriptionPermissions{ + TrackPermissions: []*livekit.TrackPermission{ + { + ParticipantSid: "p1", + AllTracks: true, + }, + { + ParticipantSid: "p2", + TrackSids: []string{"audio"}, + }, + { + ParticipantSid: "p3", + TrackSids: []string{"video"}, + }, + }, + } + um.UpdateSubscriptionPermissions(permissions, nil) + require.True(t, um.hasPermission("audio", "p1")) + require.True(t, um.hasPermission("video", "p1")) + require.True(t, um.hasPermission("screen", "p1")) + + require.True(t, um.hasPermission("audio", "p2")) + require.False(t, um.hasPermission("video", "p2")) + require.False(t, um.hasPermission("screen", "p2")) + + require.False(t, um.hasPermission("audio", "p3")) + require.True(t, um.hasPermission("video", "p3")) + require.False(t, um.hasPermission("screen", "p3")) + + // add a new track after restrictive permissions are set + trw := &typesfakes.FakePublishedTrack{} + trw.IDReturns("watch") + um.publishedTracks["watch"] = trw + + require.True(t, um.hasPermission("audio", "p1")) + require.True(t, um.hasPermission("video", "p1")) + require.True(t, um.hasPermission("screen", "p1")) + require.True(t, um.hasPermission("watch", "p1")) + + require.True(t, um.hasPermission("audio", "p2")) + require.False(t, um.hasPermission("video", "p2")) + require.False(t, um.hasPermission("screen", "p2")) + require.False(t, um.hasPermission("watch", "p2")) + + require.False(t, um.hasPermission("audio", "p3")) + require.True(t, um.hasPermission("video", "p3")) + require.False(t, um.hasPermission("screen", "p3")) + require.False(t, um.hasPermission("watch", "p3")) + }) +} diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 15eb3dead..baf8bab31 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -274,7 +274,7 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName string, pi rout } r.telemetry.ParticipantJoined(ctx, room.Room, participant.ToProto(), pi.Client) - participant.OnClose(func(p types.Participant) { + participant.OnClose(func(p types.Participant, disallowedSubscriptions map[string]string) { if err := r.roomStore.DeleteParticipant(ctx, roomName, p.Identity()); err != nil { pLogger.Errorw("could not delete participant", err) } @@ -286,6 +286,8 @@ func (r *RoomManager) StartSession(ctx context.Context, roomName string, pi rout } } r.telemetry.ParticipantLeft(ctx, room.Room, p.ToProto()) + + room.RemoveDisallowedSubscriptions(p, disallowedSubscriptions) }) go r.rtcSessionWorker(room, participant, requestSource) diff --git a/pkg/sfu/streamallocator.go b/pkg/sfu/streamallocator.go index 4b630f6ff..5efc4cf28 100644 --- a/pkg/sfu/streamallocator.go +++ b/pkg/sfu/streamallocator.go @@ -459,12 +459,14 @@ func (s *StreamAllocator) handleSignalEstimate(event *Event) { s.prevReceivedEstimate = s.receivedEstimate s.receivedEstimate = int64(remb.Bitrate) - if s.prevReceivedEstimate != s.receivedEstimate { - s.logger.Debugw("received new estimate", - "old(bps)", s.prevReceivedEstimate, - "new(bps)", s.receivedEstimate, - ) - } + /* + if s.prevReceivedEstimate != s.receivedEstimate { + s.logger.Debugw("received new estimate", + "old(bps)", s.prevReceivedEstimate, + "new(bps)", s.receivedEstimate, + ) + } + */ if s.maybeCommitEstimate() { s.allocateAllTracks()