Split out SignalHandler to simplify testing (#250)

This commit is contained in:
David Zhao
2021-12-10 13:12:45 -08:00
committed by GitHub
parent 882f3bdde5
commit d342335d09
5 changed files with 380 additions and 125 deletions
+4
View File
@@ -82,6 +82,10 @@ func NewRoom(room *livekit.Room, config WebRTCConfig, audioConfig *config.AudioC
return r
}
func (r *Room) Name() string {
return r.Room.Name
}
func (r *Room) GetParticipant(identity string) types.Participant {
r.lock.RLock()
defer r.lock.RUnlock()
+109
View File
@@ -0,0 +1,109 @@
package rtc
import (
"github.com/livekit/livekit-server/pkg/rtc/types"
"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/logger"
)
func HandleParticipantSignal(room types.Room, participant types.Participant, req *livekit.SignalRequest) error {
switch msg := req.Message.(type) {
case *livekit.SignalRequest_Offer:
_, err := participant.HandleOffer(FromProtoSessionDescription(msg.Offer))
if err != nil {
logger.Errorw("could not handle offer", err,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
)
return err
}
case *livekit.SignalRequest_AddTrack:
logger.Debugw("add track request",
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
"track", msg.AddTrack.Cid)
participant.AddTrack(msg.AddTrack)
case *livekit.SignalRequest_Answer:
sd := FromProtoSessionDescription(msg.Answer)
if err := participant.HandleAnswer(sd); err != nil {
logger.Errorw("could not handle answer", err,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
)
// connection cannot be successful if we can't answer
return err
}
case *livekit.SignalRequest_Trickle:
candidateInit, err := FromProtoTrickle(msg.Trickle)
if err != nil {
logger.Warnw("could not decode trickle", err,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
)
return nil
}
// logger.Debugw("adding peer candidate", "participant", participant.Identity())
if err := participant.AddICECandidate(candidateInit, msg.Trickle.Target); err != nil {
logger.Warnw("could not handle trickle", err,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
)
}
case *livekit.SignalRequest_Mute:
participant.SetTrackMuted(msg.Mute.Sid, msg.Mute.Muted, false)
case *livekit.SignalRequest_Subscription:
var err error
if participant.CanSubscribe() {
updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe)
if updateErr != nil {
err = updateErr
}
} else {
err = ErrCannotSubscribe
}
if err != nil {
logger.Warnw("could not update subscription", err,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
"tracks", msg.Subscription.TrackSids,
"subscribe", msg.Subscription.Subscribe)
}
case *livekit.SignalRequest_TrackSetting:
for _, sid := range msg.TrackSetting.TrackSids {
subTrack := participant.GetSubscribedTrack(sid)
if subTrack == nil {
logger.Warnw("unable to find SubscribedTrack", nil,
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
"track", sid)
continue
}
// find quality for published track
logger.Debugw("updating track settings",
"room", room.Name(),
"participant", participant.Identity(),
"pID", participant.ID(),
"settings", msg.TrackSetting)
subTrack.UpdateSubscriberSettings(msg.TrackSetting)
}
case *livekit.SignalRequest_UpdateLayers:
track := participant.GetPublishedTrack(msg.UpdateLayers.TrackSid)
if track == nil {
logger.Warnw("could not find published track", nil,
"track", msg.UpdateLayers.TrackSid)
return nil
}
track.UpdateVideoLayers(msg.UpdateLayers.Layers)
case *livekit.SignalRequest_Leave:
_ = participant.Close()
}
return nil
}
+7
View File
@@ -90,6 +90,13 @@ type Participant interface {
DebugInfo() map[string]interface{}
}
// Room is a container of participants, and can provide room level actions
//counterfeiter:generate . Room
type Room interface {
Name() string
UpdateSubscriptions(participant Participant, trackIDs []string, subscribe bool) error
}
// MediaTrack represents a media track
//counterfeiter:generate . MediaTrack
type MediaTrack interface {
+259
View File
@@ -0,0 +1,259 @@
// Code generated by counterfeiter. DO NOT EDIT.
package typesfakes
import (
"sync"
"github.com/livekit/livekit-server/pkg/rtc/types"
)
type FakeRoom struct {
GetParticipantStub func(string) types.Participant
getParticipantMutex sync.RWMutex
getParticipantArgsForCall []struct {
arg1 string
}
getParticipantReturns struct {
result1 types.Participant
}
getParticipantReturnsOnCall map[int]struct {
result1 types.Participant
}
NameStub func() string
nameMutex sync.RWMutex
nameArgsForCall []struct {
}
nameReturns struct {
result1 string
}
nameReturnsOnCall map[int]struct {
result1 string
}
UpdateSubscriptionsStub func(types.Participant, []string, bool) error
updateSubscriptionsMutex sync.RWMutex
updateSubscriptionsArgsForCall []struct {
arg1 types.Participant
arg2 []string
arg3 bool
}
updateSubscriptionsReturns struct {
result1 error
}
updateSubscriptionsReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeRoom) GetParticipant(arg1 string) types.Participant {
fake.getParticipantMutex.Lock()
ret, specificReturn := fake.getParticipantReturnsOnCall[len(fake.getParticipantArgsForCall)]
fake.getParticipantArgsForCall = append(fake.getParticipantArgsForCall, struct {
arg1 string
}{arg1})
stub := fake.GetParticipantStub
fakeReturns := fake.getParticipantReturns
fake.recordInvocation("GetParticipant", []interface{}{arg1})
fake.getParticipantMutex.Unlock()
if stub != nil {
return stub(arg1)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRoom) GetParticipantCallCount() int {
fake.getParticipantMutex.RLock()
defer fake.getParticipantMutex.RUnlock()
return len(fake.getParticipantArgsForCall)
}
func (fake *FakeRoom) GetParticipantCalls(stub func(string) types.Participant) {
fake.getParticipantMutex.Lock()
defer fake.getParticipantMutex.Unlock()
fake.GetParticipantStub = stub
}
func (fake *FakeRoom) GetParticipantArgsForCall(i int) string {
fake.getParticipantMutex.RLock()
defer fake.getParticipantMutex.RUnlock()
argsForCall := fake.getParticipantArgsForCall[i]
return argsForCall.arg1
}
func (fake *FakeRoom) GetParticipantReturns(result1 types.Participant) {
fake.getParticipantMutex.Lock()
defer fake.getParticipantMutex.Unlock()
fake.GetParticipantStub = nil
fake.getParticipantReturns = struct {
result1 types.Participant
}{result1}
}
func (fake *FakeRoom) GetParticipantReturnsOnCall(i int, result1 types.Participant) {
fake.getParticipantMutex.Lock()
defer fake.getParticipantMutex.Unlock()
fake.GetParticipantStub = nil
if fake.getParticipantReturnsOnCall == nil {
fake.getParticipantReturnsOnCall = make(map[int]struct {
result1 types.Participant
})
}
fake.getParticipantReturnsOnCall[i] = struct {
result1 types.Participant
}{result1}
}
func (fake *FakeRoom) Name() string {
fake.nameMutex.Lock()
ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)]
fake.nameArgsForCall = append(fake.nameArgsForCall, struct {
}{})
stub := fake.NameStub
fakeReturns := fake.nameReturns
fake.recordInvocation("Name", []interface{}{})
fake.nameMutex.Unlock()
if stub != nil {
return stub()
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRoom) NameCallCount() int {
fake.nameMutex.RLock()
defer fake.nameMutex.RUnlock()
return len(fake.nameArgsForCall)
}
func (fake *FakeRoom) NameCalls(stub func() string) {
fake.nameMutex.Lock()
defer fake.nameMutex.Unlock()
fake.NameStub = stub
}
func (fake *FakeRoom) NameReturns(result1 string) {
fake.nameMutex.Lock()
defer fake.nameMutex.Unlock()
fake.NameStub = nil
fake.nameReturns = struct {
result1 string
}{result1}
}
func (fake *FakeRoom) NameReturnsOnCall(i int, result1 string) {
fake.nameMutex.Lock()
defer fake.nameMutex.Unlock()
fake.NameStub = nil
if fake.nameReturnsOnCall == nil {
fake.nameReturnsOnCall = make(map[int]struct {
result1 string
})
}
fake.nameReturnsOnCall[i] = struct {
result1 string
}{result1}
}
func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 bool) error {
var arg2Copy []string
if arg2 != nil {
arg2Copy = make([]string, len(arg2))
copy(arg2Copy, arg2)
}
fake.updateSubscriptionsMutex.Lock()
ret, specificReturn := fake.updateSubscriptionsReturnsOnCall[len(fake.updateSubscriptionsArgsForCall)]
fake.updateSubscriptionsArgsForCall = append(fake.updateSubscriptionsArgsForCall, struct {
arg1 types.Participant
arg2 []string
arg3 bool
}{arg1, arg2Copy, arg3})
stub := fake.UpdateSubscriptionsStub
fakeReturns := fake.updateSubscriptionsReturns
fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3})
fake.updateSubscriptionsMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeRoom) UpdateSubscriptionsCallCount() int {
fake.updateSubscriptionsMutex.RLock()
defer fake.updateSubscriptionsMutex.RUnlock()
return len(fake.updateSubscriptionsArgsForCall)
}
func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, bool) error) {
fake.updateSubscriptionsMutex.Lock()
defer fake.updateSubscriptionsMutex.Unlock()
fake.UpdateSubscriptionsStub = stub
}
func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, bool) {
fake.updateSubscriptionsMutex.RLock()
defer fake.updateSubscriptionsMutex.RUnlock()
argsForCall := fake.updateSubscriptionsArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3
}
func (fake *FakeRoom) UpdateSubscriptionsReturns(result1 error) {
fake.updateSubscriptionsMutex.Lock()
defer fake.updateSubscriptionsMutex.Unlock()
fake.UpdateSubscriptionsStub = nil
fake.updateSubscriptionsReturns = struct {
result1 error
}{result1}
}
func (fake *FakeRoom) UpdateSubscriptionsReturnsOnCall(i int, result1 error) {
fake.updateSubscriptionsMutex.Lock()
defer fake.updateSubscriptionsMutex.Unlock()
fake.UpdateSubscriptionsStub = nil
if fake.updateSubscriptionsReturnsOnCall == nil {
fake.updateSubscriptionsReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.updateSubscriptionsReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeRoom) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
fake.getParticipantMutex.RLock()
defer fake.getParticipantMutex.RUnlock()
fake.nameMutex.RLock()
defer fake.nameMutex.RUnlock()
fake.updateSubscriptionsMutex.RLock()
defer fake.updateSubscriptionsMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeRoom) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ types.Room = new(FakeRoom)
+1 -125
View File
@@ -361,7 +361,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici
}
req := obj.(*livekit.SignalRequest)
if err := r.handleSignalRequest(room, participant, req); err != nil {
if err := rtc.HandleParticipantSignal(room, participant, req); err != nil {
// more specific errors are already logged
// treat errors returned as fatal
return
@@ -370,130 +370,6 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici
}
}
func (r *RoomManager) handleSignalRequest(room *rtc.Room, participant types.Participant, req *livekit.SignalRequest) error {
switch msg := req.Message.(type) {
case *livekit.SignalRequest_Offer:
_, err := participant.HandleOffer(rtc.FromProtoSessionDescription(msg.Offer))
if err != nil {
logger.Errorw("could not handle offer", err,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
)
return err
}
case *livekit.SignalRequest_AddTrack:
logger.Debugw("add track request",
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
"track", msg.AddTrack.Cid)
participant.AddTrack(msg.AddTrack)
case *livekit.SignalRequest_Answer:
sd := rtc.FromProtoSessionDescription(msg.Answer)
if err := participant.HandleAnswer(sd); err != nil {
logger.Errorw("could not handle answer", err,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
)
// connection cannot be successful if we can't answer
return err
}
case *livekit.SignalRequest_Trickle:
candidateInit, err := rtc.FromProtoTrickle(msg.Trickle)
if err != nil {
logger.Warnw("could not decode trickle", err,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
)
return nil
}
// logger.Debugw("adding peer candidate", "participant", participant.Identity())
if err := participant.AddICECandidate(candidateInit, msg.Trickle.Target); err != nil {
logger.Warnw("could not handle trickle", err,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
)
}
case *livekit.SignalRequest_Mute:
participant.SetTrackMuted(msg.Mute.Sid, msg.Mute.Muted, false)
case *livekit.SignalRequest_Subscription:
var err error
if participant.CanSubscribe() {
updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe)
if updateErr != nil {
err = updateErr
}
} else {
err = rtc.ErrCannotSubscribe
}
if err != nil {
logger.Warnw("could not update subscription", err,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
"tracks", msg.Subscription.TrackSids,
"subscribe", msg.Subscription.Subscribe)
}
case *livekit.SignalRequest_TrackSetting:
for _, sid := range msg.TrackSetting.TrackSids {
subTrack := participant.GetSubscribedTrack(sid)
if subTrack == nil {
logger.Warnw("unable to find SubscribedTrack", nil,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
"track", sid)
continue
}
// find the source PublishedTrack
publisher := room.GetParticipant(subTrack.PublisherIdentity())
if publisher == nil {
logger.Warnw("unable to find publisher of SubscribedTrack", nil,
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
"publisher", subTrack.PublisherIdentity(),
"track", sid)
continue
}
pubTrack := publisher.GetPublishedTrack(sid)
if pubTrack == nil {
logger.Warnw("unable to find PublishedTrack", nil,
"room", room.Room.Name,
"participant", publisher.Identity(),
"pID", publisher.ID(),
"track", sid)
continue
}
// find quality for published track
logger.Debugw("updating track settings",
"room", room.Room.Name,
"participant", participant.Identity(),
"pID", participant.ID(),
"settings", msg.TrackSetting)
subTrack.UpdateSubscriberSettings(msg.TrackSetting)
}
case *livekit.SignalRequest_UpdateLayers:
track := participant.GetPublishedTrack(msg.UpdateLayers.TrackSid)
if track == nil {
logger.Warnw("could not find published track", nil,
"track", msg.UpdateLayers.TrackSid)
return nil
}
track.UpdateVideoLayers(msg.UpdateLayers.Layers)
case *livekit.SignalRequest_Leave:
_ = participant.Close()
}
return nil
}
// handles RTC messages resulted from Room API calls
func (r *RoomManager) handleRTCMessage(ctx context.Context, roomName, identity string, msg *livekit.RTCNodeMessage) {
r.lock.RLock()