diff --git a/go.mod b/go.mod index 98c1b4f2e..3663352a3 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/google/wire v0.5.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/golang-lru v0.5.4 - github.com/livekit/protocol v0.11.1 + github.com/livekit/protocol v0.11.3 github.com/magefile/mage v1.11.0 github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0 github.com/mitchellh/go-homedir v1.1.0 diff --git a/go.sum b/go.sum index 4a7814f6a..fda65bff3 100644 --- a/go.sum +++ b/go.sum @@ -132,8 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA= github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= -github.com/livekit/protocol v0.11.1 h1:SY9oZlbHD9s2fNjus5zSmTapI3uOIbb6YLkZYtIJESs= -github.com/livekit/protocol v0.11.1/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= +github.com/livekit/protocol v0.11.3 h1:Al2oOrRwFNmgpw7dUvvc0s+oju9DoRUWi7g7GwrDiZc= +github.com/livekit/protocol v0.11.3/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg= github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls= github.com/magefile/mage v1.11.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 420b6a300..03227c78d 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -92,6 +92,19 @@ func (r *Room) GetParticipant(identity string) types.Participant { return r.participants[identity] } +func (r *Room) GetParticipantBySid(participantID string) types.Participant { + r.lock.RLock() + defer r.lock.RUnlock() + + for _, p := range r.participants { + if p.ID() == participantID { + return p + } + } + + return nil +} + func (r *Room) GetParticipants() []types.Participant { r.lock.RLock() defer r.lock.RUnlock() @@ -322,16 +335,34 @@ func (r *Room) RemoveParticipant(identity string) { } } -func (r *Room) UpdateSubscriptions(participant types.Participant, trackIds []string, subscribe bool) error { +func (r *Room) UpdateSubscriptions( + participant types.Participant, + trackIds []string, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, +) error { // find all matching tracks - var tracks []types.PublishedTrack + tracks := make(map[string]types.PublishedTrack) participants := r.GetParticipants() - for _, p := range participants { - for _, sid := range trackIds { - for _, track := range p.GetPublishedTracks() { - if sid == track.ID() { - tracks = append(tracks, track) - } + for _, trackSid := range trackIds { + for _, p := range participants { + track := p.GetPublishedTrack(trackSid) + if track != nil { + tracks[trackSid] = track + break + } + } + } + + for _, pt := range participantTracks { + p := r.GetParticipantBySid(pt.ParticipantSid) + if p == nil { + continue + } + for _, trackSid := range pt.TrackSids { + track := p.GetPublishedTrack(trackSid) + if track != nil { + tracks[trackSid] = track } } } diff --git a/pkg/rtc/signalhandler.go b/pkg/rtc/signalhandler.go index bf76061b7..8431f5d56 100644 --- a/pkg/rtc/signalhandler.go +++ b/pkg/rtc/signalhandler.go @@ -39,7 +39,12 @@ func HandleParticipantSignal(room types.Room, participant types.Participant, req case *livekit.SignalRequest_Subscription: var err error if participant.CanSubscribe() { - updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe) + updateErr := room.UpdateSubscriptions( + participant, + msg.Subscription.TrackSids, + msg.Subscription.ParticipantTracks, + msg.Subscription.Subscribe, + ) if updateErr != nil { err = updateErr } diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 9283bca58..6101a1a1c 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -94,7 +94,7 @@ type Participant interface { //counterfeiter:generate . Room type Room interface { Name() string - UpdateSubscriptions(participant Participant, trackIDs []string, subscribe bool) error + UpdateSubscriptions(participant Participant, trackIDs []string, participantTracks []*livekit.ParticipantTracks, subscribe bool) error } // MediaTrack represents a media track diff --git a/pkg/rtc/types/typesfakes/fake_room.go b/pkg/rtc/types/typesfakes/fake_room.go index 46cc27c64..2e6759190 100644 --- a/pkg/rtc/types/typesfakes/fake_room.go +++ b/pkg/rtc/types/typesfakes/fake_room.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" ) type FakeRoom struct { @@ -18,12 +19,13 @@ type FakeRoom struct { nameReturnsOnCall map[int]struct { result1 string } - UpdateSubscriptionsStub func(types.Participant, []string, bool) error + UpdateSubscriptionsStub func(types.Participant, []string, []*livekit.ParticipantTracks, bool) error updateSubscriptionsMutex sync.RWMutex updateSubscriptionsArgsForCall []struct { arg1 types.Participant arg2 []string - arg3 bool + arg3 []*livekit.ParticipantTracks + arg4 bool } updateSubscriptionsReturns struct { result1 error @@ -88,25 +90,31 @@ func (fake *FakeRoom) NameReturnsOnCall(i int, result1 string) { }{result1} } -func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 bool) error { +func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 []*livekit.ParticipantTracks, arg4 bool) error { var arg2Copy []string if arg2 != nil { arg2Copy = make([]string, len(arg2)) copy(arg2Copy, arg2) } + var arg3Copy []*livekit.ParticipantTracks + if arg3 != nil { + arg3Copy = make([]*livekit.ParticipantTracks, len(arg3)) + copy(arg3Copy, arg3) + } fake.updateSubscriptionsMutex.Lock() ret, specificReturn := fake.updateSubscriptionsReturnsOnCall[len(fake.updateSubscriptionsArgsForCall)] fake.updateSubscriptionsArgsForCall = append(fake.updateSubscriptionsArgsForCall, struct { arg1 types.Participant arg2 []string - arg3 bool - }{arg1, arg2Copy, arg3}) + arg3 []*livekit.ParticipantTracks + arg4 bool + }{arg1, arg2Copy, arg3Copy, arg4}) stub := fake.UpdateSubscriptionsStub fakeReturns := fake.updateSubscriptionsReturns - fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3}) + fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3Copy, arg4}) fake.updateSubscriptionsMutex.Unlock() if stub != nil { - return stub(arg1, arg2, arg3) + return stub(arg1, arg2, arg3, arg4) } if specificReturn { return ret.result1 @@ -120,17 +128,17 @@ func (fake *FakeRoom) UpdateSubscriptionsCallCount() int { return len(fake.updateSubscriptionsArgsForCall) } -func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, bool) error) { +func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, []*livekit.ParticipantTracks, bool) error) { fake.updateSubscriptionsMutex.Lock() defer fake.updateSubscriptionsMutex.Unlock() fake.UpdateSubscriptionsStub = stub } -func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, bool) { +func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, []*livekit.ParticipantTracks, bool) { fake.updateSubscriptionsMutex.RLock() defer fake.updateSubscriptionsMutex.RUnlock() argsForCall := fake.updateSubscriptionsArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 } func (fake *FakeRoom) UpdateSubscriptionsReturns(result1 error) { diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index 9e84b1295..15eb3dead 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -437,7 +437,12 @@ func (r *RoomManager) handleRTCMessage(_ context.Context, roomName, identity str return } pLogger.Debugw("updating participant subscriptions") - if err := room.UpdateSubscriptions(participant, rm.UpdateSubscriptions.TrackSids, rm.UpdateSubscriptions.Subscribe); err != nil { + if err := room.UpdateSubscriptions( + participant, + rm.UpdateSubscriptions.TrackSids, + rm.UpdateSubscriptions.ParticipantTracks, + rm.UpdateSubscriptions.Subscribe, + ); err != nil { pLogger.Warnw("could not update subscription", err, "tracks", rm.UpdateSubscriptions.TrackSids, "subscribe", rm.UpdateSubscriptions.Subscribe)