diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 98e22813c..5fc4ca5d6 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -82,6 +82,10 @@ func NewRoom(room *livekit.Room, config WebRTCConfig, audioConfig *config.AudioC return r } +func (r *Room) Name() string { + return r.Room.Name +} + func (r *Room) GetParticipant(identity string) types.Participant { r.lock.RLock() defer r.lock.RUnlock() diff --git a/pkg/rtc/signalhandler.go b/pkg/rtc/signalhandler.go new file mode 100644 index 000000000..e41902b3f --- /dev/null +++ b/pkg/rtc/signalhandler.go @@ -0,0 +1,109 @@ +package rtc + +import ( + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +func HandleParticipantSignal(room types.Room, participant types.Participant, req *livekit.SignalRequest) error { + switch msg := req.Message.(type) { + case *livekit.SignalRequest_Offer: + _, err := participant.HandleOffer(FromProtoSessionDescription(msg.Offer)) + if err != nil { + logger.Errorw("could not handle offer", err, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + ) + return err + } + case *livekit.SignalRequest_AddTrack: + logger.Debugw("add track request", + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + "track", msg.AddTrack.Cid) + participant.AddTrack(msg.AddTrack) + case *livekit.SignalRequest_Answer: + sd := FromProtoSessionDescription(msg.Answer) + if err := participant.HandleAnswer(sd); err != nil { + logger.Errorw("could not handle answer", err, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + ) + // connection cannot be successful if we can't answer + return err + } + case *livekit.SignalRequest_Trickle: + candidateInit, err := FromProtoTrickle(msg.Trickle) + if err != nil { + logger.Warnw("could not decode trickle", err, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + ) + return nil + } + // logger.Debugw("adding peer candidate", "participant", participant.Identity()) + if err := participant.AddICECandidate(candidateInit, msg.Trickle.Target); err != nil { + logger.Warnw("could not handle trickle", err, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + ) + } + case *livekit.SignalRequest_Mute: + participant.SetTrackMuted(msg.Mute.Sid, msg.Mute.Muted, false) + case *livekit.SignalRequest_Subscription: + var err error + if participant.CanSubscribe() { + updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe) + if updateErr != nil { + err = updateErr + } + } else { + err = ErrCannotSubscribe + } + if err != nil { + logger.Warnw("could not update subscription", err, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + "tracks", msg.Subscription.TrackSids, + "subscribe", msg.Subscription.Subscribe) + } + case *livekit.SignalRequest_TrackSetting: + for _, sid := range msg.TrackSetting.TrackSids { + subTrack := participant.GetSubscribedTrack(sid) + if subTrack == nil { + logger.Warnw("unable to find SubscribedTrack", nil, + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + "track", sid) + continue + } + + // find quality for published track + logger.Debugw("updating track settings", + "room", room.Name(), + "participant", participant.Identity(), + "pID", participant.ID(), + "settings", msg.TrackSetting) + subTrack.UpdateSubscriberSettings(msg.TrackSetting) + } + case *livekit.SignalRequest_UpdateLayers: + track := participant.GetPublishedTrack(msg.UpdateLayers.TrackSid) + if track == nil { + logger.Warnw("could not find published track", nil, + "track", msg.UpdateLayers.TrackSid) + return nil + } + track.UpdateVideoLayers(msg.UpdateLayers.Layers) + case *livekit.SignalRequest_Leave: + _ = participant.Close() + } + return nil +} diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index f6093fbab..ce6b1ea77 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -90,6 +90,13 @@ type Participant interface { DebugInfo() map[string]interface{} } +// Room is a container of participants, and can provide room level actions +//counterfeiter:generate . Room +type Room interface { + Name() string + UpdateSubscriptions(participant Participant, trackIDs []string, subscribe bool) error +} + // MediaTrack represents a media track //counterfeiter:generate . MediaTrack type MediaTrack interface { diff --git a/pkg/rtc/types/typesfakes/fake_room.go b/pkg/rtc/types/typesfakes/fake_room.go new file mode 100644 index 000000000..78172f3ce --- /dev/null +++ b/pkg/rtc/types/typesfakes/fake_room.go @@ -0,0 +1,259 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +type FakeRoom struct { + GetParticipantStub func(string) types.Participant + getParticipantMutex sync.RWMutex + getParticipantArgsForCall []struct { + arg1 string + } + getParticipantReturns struct { + result1 types.Participant + } + getParticipantReturnsOnCall map[int]struct { + result1 types.Participant + } + NameStub func() string + nameMutex sync.RWMutex + nameArgsForCall []struct { + } + nameReturns struct { + result1 string + } + nameReturnsOnCall map[int]struct { + result1 string + } + UpdateSubscriptionsStub func(types.Participant, []string, bool) error + updateSubscriptionsMutex sync.RWMutex + updateSubscriptionsArgsForCall []struct { + arg1 types.Participant + arg2 []string + arg3 bool + } + updateSubscriptionsReturns struct { + result1 error + } + updateSubscriptionsReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRoom) GetParticipant(arg1 string) types.Participant { + fake.getParticipantMutex.Lock() + ret, specificReturn := fake.getParticipantReturnsOnCall[len(fake.getParticipantArgsForCall)] + fake.getParticipantArgsForCall = append(fake.getParticipantArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetParticipantStub + fakeReturns := fake.getParticipantReturns + fake.recordInvocation("GetParticipant", []interface{}{arg1}) + fake.getParticipantMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) GetParticipantCallCount() int { + fake.getParticipantMutex.RLock() + defer fake.getParticipantMutex.RUnlock() + return len(fake.getParticipantArgsForCall) +} + +func (fake *FakeRoom) GetParticipantCalls(stub func(string) types.Participant) { + fake.getParticipantMutex.Lock() + defer fake.getParticipantMutex.Unlock() + fake.GetParticipantStub = stub +} + +func (fake *FakeRoom) GetParticipantArgsForCall(i int) string { + fake.getParticipantMutex.RLock() + defer fake.getParticipantMutex.RUnlock() + argsForCall := fake.getParticipantArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRoom) GetParticipantReturns(result1 types.Participant) { + fake.getParticipantMutex.Lock() + defer fake.getParticipantMutex.Unlock() + fake.GetParticipantStub = nil + fake.getParticipantReturns = struct { + result1 types.Participant + }{result1} +} + +func (fake *FakeRoom) GetParticipantReturnsOnCall(i int, result1 types.Participant) { + fake.getParticipantMutex.Lock() + defer fake.getParticipantMutex.Unlock() + fake.GetParticipantStub = nil + if fake.getParticipantReturnsOnCall == nil { + fake.getParticipantReturnsOnCall = make(map[int]struct { + result1 types.Participant + }) + } + fake.getParticipantReturnsOnCall[i] = struct { + result1 types.Participant + }{result1} +} + +func (fake *FakeRoom) Name() string { + fake.nameMutex.Lock() + ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)] + fake.nameArgsForCall = append(fake.nameArgsForCall, struct { + }{}) + stub := fake.NameStub + fakeReturns := fake.nameReturns + fake.recordInvocation("Name", []interface{}{}) + fake.nameMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) NameCallCount() int { + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + return len(fake.nameArgsForCall) +} + +func (fake *FakeRoom) NameCalls(stub func() string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = stub +} + +func (fake *FakeRoom) NameReturns(result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + fake.nameReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeRoom) NameReturnsOnCall(i int, result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + if fake.nameReturnsOnCall == nil { + fake.nameReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.nameReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeRoom) UpdateSubscriptions(arg1 types.Participant, arg2 []string, arg3 bool) error { + var arg2Copy []string + if arg2 != nil { + arg2Copy = make([]string, len(arg2)) + copy(arg2Copy, arg2) + } + fake.updateSubscriptionsMutex.Lock() + ret, specificReturn := fake.updateSubscriptionsReturnsOnCall[len(fake.updateSubscriptionsArgsForCall)] + fake.updateSubscriptionsArgsForCall = append(fake.updateSubscriptionsArgsForCall, struct { + arg1 types.Participant + arg2 []string + arg3 bool + }{arg1, arg2Copy, arg3}) + stub := fake.UpdateSubscriptionsStub + fakeReturns := fake.updateSubscriptionsReturns + fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3}) + fake.updateSubscriptionsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) UpdateSubscriptionsCallCount() int { + fake.updateSubscriptionsMutex.RLock() + defer fake.updateSubscriptionsMutex.RUnlock() + return len(fake.updateSubscriptionsArgsForCall) +} + +func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.Participant, []string, bool) error) { + fake.updateSubscriptionsMutex.Lock() + defer fake.updateSubscriptionsMutex.Unlock() + fake.UpdateSubscriptionsStub = stub +} + +func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.Participant, []string, bool) { + fake.updateSubscriptionsMutex.RLock() + defer fake.updateSubscriptionsMutex.RUnlock() + argsForCall := fake.updateSubscriptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRoom) UpdateSubscriptionsReturns(result1 error) { + fake.updateSubscriptionsMutex.Lock() + defer fake.updateSubscriptionsMutex.Unlock() + fake.UpdateSubscriptionsStub = nil + fake.updateSubscriptionsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoom) UpdateSubscriptionsReturnsOnCall(i int, result1 error) { + fake.updateSubscriptionsMutex.Lock() + defer fake.updateSubscriptionsMutex.Unlock() + fake.UpdateSubscriptionsStub = nil + if fake.updateSubscriptionsReturnsOnCall == nil { + fake.updateSubscriptionsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscriptionsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRoom) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.getParticipantMutex.RLock() + defer fake.getParticipantMutex.RUnlock() + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + fake.updateSubscriptionsMutex.RLock() + defer fake.updateSubscriptionsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRoom) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.Room = new(FakeRoom) diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index f09312698..49ce5c168 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -361,7 +361,7 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici } req := obj.(*livekit.SignalRequest) - if err := r.handleSignalRequest(room, participant, req); err != nil { + if err := rtc.HandleParticipantSignal(room, participant, req); err != nil { // more specific errors are already logged // treat errors returned as fatal return @@ -370,130 +370,6 @@ func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.Partici } } -func (r *RoomManager) handleSignalRequest(room *rtc.Room, participant types.Participant, req *livekit.SignalRequest) error { - switch msg := req.Message.(type) { - case *livekit.SignalRequest_Offer: - _, err := participant.HandleOffer(rtc.FromProtoSessionDescription(msg.Offer)) - if err != nil { - logger.Errorw("could not handle offer", err, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - ) - return err - } - case *livekit.SignalRequest_AddTrack: - logger.Debugw("add track request", - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - "track", msg.AddTrack.Cid) - participant.AddTrack(msg.AddTrack) - case *livekit.SignalRequest_Answer: - sd := rtc.FromProtoSessionDescription(msg.Answer) - if err := participant.HandleAnswer(sd); err != nil { - logger.Errorw("could not handle answer", err, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - ) - // connection cannot be successful if we can't answer - return err - } - case *livekit.SignalRequest_Trickle: - candidateInit, err := rtc.FromProtoTrickle(msg.Trickle) - if err != nil { - logger.Warnw("could not decode trickle", err, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - ) - return nil - } - // logger.Debugw("adding peer candidate", "participant", participant.Identity()) - if err := participant.AddICECandidate(candidateInit, msg.Trickle.Target); err != nil { - logger.Warnw("could not handle trickle", err, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - ) - } - case *livekit.SignalRequest_Mute: - participant.SetTrackMuted(msg.Mute.Sid, msg.Mute.Muted, false) - case *livekit.SignalRequest_Subscription: - var err error - if participant.CanSubscribe() { - updateErr := room.UpdateSubscriptions(participant, msg.Subscription.TrackSids, msg.Subscription.Subscribe) - if updateErr != nil { - err = updateErr - } - } else { - err = rtc.ErrCannotSubscribe - } - if err != nil { - logger.Warnw("could not update subscription", err, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - "tracks", msg.Subscription.TrackSids, - "subscribe", msg.Subscription.Subscribe) - } - case *livekit.SignalRequest_TrackSetting: - for _, sid := range msg.TrackSetting.TrackSids { - subTrack := participant.GetSubscribedTrack(sid) - if subTrack == nil { - logger.Warnw("unable to find SubscribedTrack", nil, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - "track", sid) - continue - } - - // find the source PublishedTrack - publisher := room.GetParticipant(subTrack.PublisherIdentity()) - if publisher == nil { - logger.Warnw("unable to find publisher of SubscribedTrack", nil, - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - "publisher", subTrack.PublisherIdentity(), - "track", sid) - continue - } - - pubTrack := publisher.GetPublishedTrack(sid) - if pubTrack == nil { - logger.Warnw("unable to find PublishedTrack", nil, - "room", room.Room.Name, - "participant", publisher.Identity(), - "pID", publisher.ID(), - "track", sid) - continue - } - - // find quality for published track - logger.Debugw("updating track settings", - "room", room.Room.Name, - "participant", participant.Identity(), - "pID", participant.ID(), - "settings", msg.TrackSetting) - subTrack.UpdateSubscriberSettings(msg.TrackSetting) - } - case *livekit.SignalRequest_UpdateLayers: - track := participant.GetPublishedTrack(msg.UpdateLayers.TrackSid) - if track == nil { - logger.Warnw("could not find published track", nil, - "track", msg.UpdateLayers.TrackSid) - return nil - } - track.UpdateVideoLayers(msg.UpdateLayers.Layers) - case *livekit.SignalRequest_Leave: - _ = participant.Close() - } - return nil -} - // handles RTC messages resulted from Room API calls func (r *RoomManager) handleRTCMessage(ctx context.Context, roomName, identity string, msg *livekit.RTCNodeMessage) { r.lock.RLock()