Include participant_sid in UpdateSubscription. (#279)

* Include `participant_sid` in `UpdateSubscription`.

Prevents all publisher tracks to find a match.

* generate

* Update protocol version
This commit is contained in:
Raja Subramanian
2021-12-23 09:18:32 +05:30
committed by GitHub
parent 42a9b6657d
commit eae6eff6a3
7 changed files with 73 additions and 24 deletions
+1 -1
View File
@@ -14,7 +14,7 @@ require (
github.com/google/wire v0.5.0
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.4
github.com/livekit/protocol v0.11.1
github.com/livekit/protocol v0.11.3
github.com/magefile/mage v1.11.0
github.com/maxbrunsfeld/counterfeiter/v6 v6.3.0
github.com/mitchellh/go-homedir v1.1.0
+2 -2
View File
@@ -132,8 +132,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lithammer/shortuuid/v3 v3.0.6 h1:pr15YQyvhiSX/qPxncFtqk+v4xLEpOZObbsY/mKrcvA=
github.com/lithammer/shortuuid/v3 v3.0.6/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts=
github.com/livekit/protocol v0.11.1 h1:SY9oZlbHD9s2fNjus5zSmTapI3uOIbb6YLkZYtIJESs=
github.com/livekit/protocol v0.11.1/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg=
github.com/livekit/protocol v0.11.3 h1:Al2oOrRwFNmgpw7dUvvc0s+oju9DoRUWi7g7GwrDiZc=
github.com/livekit/protocol v0.11.3/go.mod h1:YoHW9YbWbPnuVsgwBB4hAINKT+V68jmfh9zXBSSn6Wg=
github.com/magefile/mage v1.11.0 h1:C/55Ywp9BpgVVclD3lRnSYCwXTYxmSppIgLeDYlNuls=
github.com/magefile/mage v1.11.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
+39 -8
View File
@@ -92,6 +92,19 @@ func (r *Room) GetParticipant(identity string) types.Participant {
return r.participants[identity]
}
func (r *Room) GetParticipantBySid(participantID string) types.Participant {
r.lock.RLock()
defer r.lock.RUnlock()
for _, p := range r.participants {
if p.ID() == participantID {
return p
}
}
return nil
}
func (r *Room) GetParticipants() []types.Participant {
r.lock.RLock()
defer r.lock.RUnlock()
@@ -322,16 +335,34 @@ func (r *Room) RemoveParticipant(identity string) {
}
}
func (r *Room) UpdateSubscriptions(participant types.Participant, trackIds []string, subscribe bool) error {
func (r *Room) UpdateSubscriptions(
participant types.Participant,
trackIds []string,
participantTracks []*livekit.ParticipantTracks,
subscribe bool,
) error {
// find all matching tracks
var tracks []types.PublishedTrack
tracks := make(map[string]types.PublishedTrack)
participants := r.GetParticipants()
for _, p := range participants {
for _, sid := range trackIds {
for _, track := range p.GetPublishedTracks() {
if sid == track.ID() {
tracks = append(tracks, track)
}
for _, trackSid := range trackIds {
for _, p := range participants {
track := p.GetPublishedTrack(trackSid)
if track != nil {
tracks[trackSid] = track
break
}
}
}
for _, pt := range participantTracks {
p := r.GetParticipantBySid(pt.ParticipantSid)
if p == nil {
continue
}
for _, trackSid := range pt.TrackSids {
track := p.GetPublishedTrack(trackSid)
if track != nil {
tracks[trackSid] = track
}
}
}
+6 -1
View File
@@ -39,7 +39,12 @@ func HandleParticipantSignal(room types.Room, participant types.Participant, req
case *livekit.SignalRequest_Subscription:
var err error
if participant.CanSubscribe() {
updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe)
updateErr := room.UpdateSubscriptions(
participant,
msg.Subscription.TrackSids,
msg.Subscription.ParticipantTracks,
msg.Subscription.Subscribe,
)
if updateErr != nil {
err = updateErr
}
+1 -1
View File
@@ -94,7 +94,7 @@ type Participant interface {
//counterfeiter:generate . Room
type Room interface {
Name() string
UpdateSubscriptions(participant Participant, trackIDs []string, subscribe bool) error
UpdateSubscriptions(participant Participant, trackIDs []string, participantTracks []*livekit.ParticipantTracks, subscribe bool) error
}
// MediaTrack represents a media track
+18 -10
View File
@@ -5,6 +5,7 @@ import (
"sync"
"github.com/livekit/livekit-server/pkg/rtc/types"
"github.com/livekit/protocol/livekit"
)
type FakeRoom struct {
@@ -18,12 +19,13 @@ type FakeRoom struct {
nameReturnsOnCall map[int]struct {
result1 string
}
UpdateSubscriptionsStub func(types.Participant, []string, bool) error
UpdateSubscriptionsStub func(types.Participant, []string, []*livekit.ParticipantTracks, bool) error
updateSubscriptionsMutex sync.RWMutex
updateSubscriptionsArgsForCall []struct {
arg1 types.Participant
arg2 []string
arg3 bool
arg3 []*livekit.ParticipantTracks
arg4 bool
}
updateSubscriptionsReturns struct {
result1 error
@@ -88,25 +90,31 @@ func (fake *FakeRoom) NameReturnsOnCall(i int, result1 string) {
}{result1}
}
func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 bool) error {
func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 []*livekit.ParticipantTracks, arg4 bool) error {
var arg2Copy []string
if arg2 != nil {
arg2Copy = make([]string, len(arg2))
copy(arg2Copy, arg2)
}
var arg3Copy []*livekit.ParticipantTracks
if arg3 != nil {
arg3Copy = make([]*livekit.ParticipantTracks, len(arg3))
copy(arg3Copy, arg3)
}
fake.updateSubscriptionsMutex.Lock()
ret, specificReturn := fake.updateSubscriptionsReturnsOnCall[len(fake.updateSubscriptionsArgsForCall)]
fake.updateSubscriptionsArgsForCall = append(fake.updateSubscriptionsArgsForCall, struct {
arg1 types.Participant
arg2 []string
arg3 bool
}{arg1, arg2Copy, arg3})
arg3 []*livekit.ParticipantTracks
arg4 bool
}{arg1, arg2Copy, arg3Copy, arg4})
stub := fake.UpdateSubscriptionsStub
fakeReturns := fake.updateSubscriptionsReturns
fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3})
fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3Copy, arg4})
fake.updateSubscriptionsMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
return stub(arg1, arg2, arg3, arg4)
}
if specificReturn {
return ret.result1
@@ -120,17 +128,17 @@ func (fake *FakeRoom) UpdateSubscriptionsCallCount() int {
return len(fake.updateSubscriptionsArgsForCall)
}
func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, bool) error) {
func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, []*livekit.ParticipantTracks, bool) error) {
fake.updateSubscriptionsMutex.Lock()
defer fake.updateSubscriptionsMutex.Unlock()
fake.UpdateSubscriptionsStub = stub
}
func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, bool) {
func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, []*livekit.ParticipantTracks, bool) {
fake.updateSubscriptionsMutex.RLock()
defer fake.updateSubscriptionsMutex.RUnlock()
argsForCall := fake.updateSubscriptionsArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
}
func (fake *FakeRoom) UpdateSubscriptionsReturns(result1 error) {
+6 -1
View File
@@ -437,7 +437,12 @@ func (r *RoomManager) handleRTCMessage(_ context.Context, roomName, identity str
return
}
pLogger.Debugw("updating participant subscriptions")
if err := room.UpdateSubscriptions(participant, rm.UpdateSubscriptions.TrackSids, rm.UpdateSubscriptions.Subscribe); err != nil {
if err := room.UpdateSubscriptions(
participant,
rm.UpdateSubscriptions.TrackSids,
rm.UpdateSubscriptions.ParticipantTracks,
rm.UpdateSubscriptions.Subscribe,
); err != nil {
pLogger.Warnw("could not update subscription", err,
"tracks", rm.UpdateSubscriptions.TrackSids,
"subscribe", rm.UpdateSubscriptions.Subscribe)