diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 966dcddc8..bd1f75c53 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -23,7 +23,7 @@ type MessageSource interface { ReadChan() <-chan proto.Message } -type ParticipantCallback func(roomId, participantId, participantName string, requestSource MessageSource, responseSink MessageSink) +type ParticipantCallback func(roomName, identity string, requestSource MessageSource, responseSink MessageSink) // Router allows multiple nodes to coordinate the participant session //counterfeiter:generate . Router @@ -36,11 +36,8 @@ type Router interface { GetNode(nodeId string) (*livekit.Node, error) ListNodes() ([]*livekit.Node, error) - // functions for websocket handler - GetRequestSink(participantId string) (MessageSink, error) - GetResponseSource(participantId string) (MessageSource, error) // participant signal connection is ready to start - StartParticipantSignal(roomName, participantId, participantName string) error + StartParticipantSignal(roomName, identity string) (reqSink MessageSink, resSource MessageSource, err error) // when a new participant's RTC connection is ready to start OnNewParticipantRTC(callback ParticipantCallback) diff --git a/pkg/routing/localrouter.go b/pkg/routing/localrouter.go index 482a932ce..b7695fed1 100644 --- a/pkg/routing/localrouter.go +++ b/pkg/routing/localrouter.go @@ -59,30 +59,26 @@ func (r *LocalRouter) ListNodes() ([]*livekit.Node, error) { }, nil } -func (r *LocalRouter) StartParticipantSignal(roomName, participantId, participantName string) error { +func (r *LocalRouter) StartParticipantSignal(roomName, identity string) (reqSink MessageSink, resSource MessageSource, err error) { // treat it as a new participant connecting if r.onNewParticipant == nil { - return ErrHandlerNotDefined + return nil, nil, ErrHandlerNotDefined } + + // index channels by roomName | identity + key := participantKey(roomName, identity) + reqChan := r.getOrCreateMessageChannel(r.requestChannels, key) + resChan := r.getOrCreateMessageChannel(r.responseChannels, key) + r.onNewParticipant( roomName, - participantId, - participantName, + identity, // request source - r.getOrCreateMessageChannel(r.requestChannels, participantId), + reqChan, // response sink - r.getOrCreateMessageChannel(r.responseChannels, participantId), + resChan, ) - return nil -} - -// for a local router, sink and source are pointing to the same spot -func (r *LocalRouter) GetRequestSink(participantId string) (MessageSink, error) { - return r.getOrCreateMessageChannel(r.requestChannels, participantId), nil -} - -func (r *LocalRouter) GetResponseSource(participantId string) (MessageSource, error) { - return r.getOrCreateMessageChannel(r.responseChannels, participantId), nil + return reqChan, resChan, nil } func (r *LocalRouter) OnNewParticipantRTC(callback ParticipantCallback) { @@ -106,10 +102,10 @@ func (r *LocalRouter) statsWorker() { } } -func (r *LocalRouter) getOrCreateMessageChannel(target map[string]*MessageChannel, participantId string) *MessageChannel { +func (r *LocalRouter) getOrCreateMessageChannel(target map[string]*MessageChannel, key string) *MessageChannel { r.lock.Lock() defer r.lock.Unlock() - mc := target[participantId] + mc := target[key] if mc != nil { return mc @@ -118,10 +114,10 @@ func (r *LocalRouter) getOrCreateMessageChannel(target map[string]*MessageChanne mc = NewMessageChannel() mc.OnClose(func() { r.lock.Lock() - delete(target, participantId) + delete(target, key) r.lock.Unlock() }) - target[participantId] = mc + target[key] = mc return mc } diff --git a/pkg/routing/messagechannel_test.go b/pkg/routing/messagechannel_test.go index 692907ef2..28d4f0f63 100644 --- a/pkg/routing/messagechannel_test.go +++ b/pkg/routing/messagechannel_test.go @@ -24,12 +24,12 @@ func TestMessageChannel_WriteMessageClosed(t *testing.T) { go func() { defer wg.Done() for i := 0; i < 100; i++ { - m.WriteMessage(&livekit.RouterMessage{}) + m.WriteMessage(&livekit.RTCNodeMessage{}) } }() - m.WriteMessage(&livekit.RouterMessage{}) + m.WriteMessage(&livekit.RTCNodeMessage{}) m.Close() - m.WriteMessage(&livekit.RouterMessage{}) + m.WriteMessage(&livekit.RTCNodeMessage{}) wg.Wait() } diff --git a/pkg/routing/redis.go b/pkg/routing/redis.go index e6c5baf17..069a8a292 100644 --- a/pkg/routing/redis.go +++ b/pkg/routing/redis.go @@ -21,38 +21,57 @@ const ( var redisCtx = context.Background() // location of the participant's RTC connection, hash -func participantRTCKey(participantId string) string { - return "participant_rtc:" + participantId +func participantRTCKey(participantKey string) string { + return "participant_rtc:" + participantKey } // location of the participant's Signal connection, hash -func participantSignalKey(participantId string) string { - return "participant_signal:" + participantId +func participantSignalKey(connectionId string) string { + return "participant_signal:" + connectionId } -func nodeChannel(nodeId string) string { - return "node_channel:" + nodeId +func rtcNodeChannel(nodeId string) string { + return "rtc_channel:" + nodeId } -func publishRouterMessage(rc *redis.Client, nodeId string, participantId string, msg proto.Message) error { - rm := &livekit.RouterMessage{ - ParticipantId: participantId, +func signalNodeChannel(nodeId string) string { + return "signal_channel:" + nodeId +} + +func publishRTCMessage(rc *redis.Client, nodeId string, participantKey string, msg proto.Message) error { + rm := &livekit.RTCNodeMessage{ + ParticipantKey: participantKey, } switch o := msg.(type) { case *livekit.StartSession: - rm.Message = &livekit.RouterMessage_StartSession{ + rm.Message = &livekit.RTCNodeMessage_StartSession{ StartSession: o, } case *livekit.SignalRequest: - rm.Message = &livekit.RouterMessage_Request{ + rm.Message = &livekit.RTCNodeMessage_Request{ Request: o, } + default: + return errInvalidRouterMessage + } + data, err := proto.Marshal(rm) + if err != nil { + return err + } + return rc.Publish(redisCtx, rtcNodeChannel(nodeId), data).Err() +} + +func publishSignalMessage(rc *redis.Client, nodeId string, connectionId string, msg proto.Message) error { + rm := &livekit.SignalNodeMessage{ + ConnectionId: connectionId, + } + switch o := msg.(type) { case *livekit.SignalResponse: - rm.Message = &livekit.RouterMessage_Response{ + rm.Message = &livekit.SignalNodeMessage_Response{ Response: o, } case *livekit.EndSession: - rm.Message = &livekit.RouterMessage_EndSession{ + rm.Message = &livekit.SignalNodeMessage_EndSession{ EndSession: o, } default: @@ -62,42 +81,78 @@ func publishRouterMessage(rc *redis.Client, nodeId string, participantId string, if err != nil { return err } - return rc.Publish(redisCtx, nodeChannel(nodeId), data).Err() + return rc.Publish(redisCtx, signalNodeChannel(nodeId), data).Err() } -type RedisSink struct { - rc *redis.Client - nodeId string - participantId string - isClosed utils.AtomicFlag - onClose func() +type RTCNodeSink struct { + rc *redis.Client + nodeId string + participantKey string + isClosed utils.AtomicFlag + onClose func() } -func NewRedisSink(rc *redis.Client, nodeId, participantId string) *RedisSink { - return &RedisSink{ - rc: rc, - nodeId: nodeId, - participantId: participantId, +func NewRTCNodeSink(rc *redis.Client, nodeId, participantKey string) *RTCNodeSink { + return &RTCNodeSink{ + rc: rc, + nodeId: nodeId, + participantKey: participantKey, } } -func (s *RedisSink) WriteMessage(msg proto.Message) error { +func (s *RTCNodeSink) WriteMessage(msg proto.Message) error { if s.isClosed.Get() { return ErrChannelClosed } - return publishRouterMessage(s.rc, s.nodeId, s.participantId, msg) + return publishRTCMessage(s.rc, s.nodeId, s.participantKey, msg) } -func (s *RedisSink) Close() { +func (s *RTCNodeSink) Close() { if !s.isClosed.TrySet(true) { return } - publishRouterMessage(s.rc, s.nodeId, s.participantId, &livekit.EndSession{}) if s.onClose != nil { s.onClose() } } -func (s *RedisSink) OnClose(f func()) { +func (s *RTCNodeSink) OnClose(f func()) { + s.onClose = f +} + +type SignalNodeSink struct { + rc *redis.Client + nodeId string + connectionId string + isClosed utils.AtomicFlag + onClose func() +} + +func NewSignalNodeSink(rc *redis.Client, nodeId, connectionId string) *SignalNodeSink { + return &SignalNodeSink{ + rc: rc, + nodeId: nodeId, + connectionId: connectionId, + } +} + +func (s *SignalNodeSink) WriteMessage(msg proto.Message) error { + if s.isClosed.Get() { + return ErrChannelClosed + } + return publishSignalMessage(s.rc, s.nodeId, s.connectionId, msg) +} + +func (s *SignalNodeSink) Close() { + if !s.isClosed.TrySet(true) { + return + } + publishSignalMessage(s.rc, s.nodeId, s.connectionId, &livekit.EndSession{}) + if s.onClose != nil { + s.onClose() + } +} + +func (s *SignalNodeSink) OnClose(f func()) { s.onClose = f } diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 78a1db6b6..06985b8f5 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -30,8 +30,11 @@ type RedisRouter struct { ctx context.Context once sync.Once - redisSinks map[string]*RedisSink - cancel func() + // map of participantKey => RTCNodeSink + rtcSinks map[string]*RTCNodeSink + // map of connectionId => SignalNodeSink + signalSinks map[string]*SignalNodeSink + cancel func() } func NewRedisRouter(currentNode LocalNode, rc *redis.Client) *RedisRouter { @@ -39,7 +42,8 @@ func NewRedisRouter(currentNode LocalNode, rc *redis.Client) *RedisRouter { LocalRouter: *NewLocalRouter(currentNode), rc: rc, once: sync.Once{}, - redisSinks: make(map[string]*RedisSink), + rtcSinks: make(map[string]*RTCNodeSink), + signalSinks: make(map[string]*SignalNodeSink), } rr.ctx, rr.cancel = context.WithCancel(context.Background()) rr.cr = utils.NewCachedRedis(rr.ctx, rr.rc) @@ -110,55 +114,44 @@ func (r *RedisRouter) ListNodes() ([]*livekit.Node, error) { return nodes, nil } -// for a local router, sink and source are pointing to the same spot -func (r *RedisRouter) GetRequestSink(participantId string) (MessageSink, error) { - // request should go to RTC node - rtcNode, err := r.getParticipantRTCNode(participantId) - if err != nil { - return nil, err - } - - sink := r.getOrCreateRedisSink(rtcNode, participantId) - return sink, nil -} - -func (r *RedisRouter) GetResponseSource(participantId string) (MessageSource, error) { - // a message channel that we'll send data into - source := r.getOrCreateMessageChannel(r.responseChannels, participantId) - return source, nil -} - // signal connection sets up paths to the RTC node, and starts to route messages to that message queue -func (r *RedisRouter) StartParticipantSignal(roomName, participantId, participantName string) error { +func (r *RedisRouter) StartParticipantSignal(roomName, identity string) (reqSink MessageSink, resSource MessageSource, err error) { // find the node where the room is hosted at rtcNode, err := r.GetNodeForRoom(roomName) if err != nil { - return err + return } + // create a new connection id + connectionId := utils.NewGuid("CO_") + pKey := participantKey(roomName, identity) + // map signal & rtc nodes - if err = r.setParticipantSignalNode(participantId, r.currentNode.Id); err != nil { - return err - } - if err := r.setParticipantRTCNode(participantId, rtcNode); err != nil { - return err + if err = r.setParticipantSignalNode(connectionId, r.currentNode.Id); err != nil { + return } - sink, err := r.GetRequestSink(participantId) - if err != nil { - return err - } + sink := r.getOrCreateRTCSink(rtcNode, pKey) // sends a message to start session - return sink.WriteMessage(&livekit.StartSession{ - RoomName: roomName, - ParticipantName: participantName, + err = sink.WriteMessage(&livekit.StartSession{ + RoomName: roomName, + Identity: identity, + // connection id is to allow the RTC node to identify where to route the message back to + ConnectionId: connectionId, }) + if err != nil { + return + } + + // index by connectionId, since there may be multiple connections for the participant + resChan := r.getOrCreateMessageChannel(r.responseChannels, connectionId) + return sink, resChan, nil } -func (r *RedisRouter) startParticipantRTC(roomName, participantId, participantName string) error { +func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantKey string) error { // find the node where the room is hosted at - rtcNode, err := r.GetNodeForRoom(roomName) + rtcNode, err := r.GetNodeForRoom(ss.RoomName) if err != nil { return err } @@ -169,8 +162,12 @@ func (r *RedisRouter) startParticipantRTC(roomName, participantId, participantNa return ErrIncorrectRTCNode } + if err := r.setParticipantRTCNode(participantKey, rtcNode); err != nil { + return err + } + // find signal node to send responses back - signalNode, err := r.getParticipantSignalNode(participantId) + signalNode, err := r.getParticipantSignalNode(ss.ConnectionId) if err != nil { return err } @@ -180,12 +177,12 @@ func (r *RedisRouter) startParticipantRTC(roomName, participantId, participantNa return ErrHandlerNotDefined } - resSink := r.getOrCreateRedisSink(signalNode, participantId) + reqChan := r.getOrCreateMessageChannel(r.requestChannels, participantKey) + resSink := r.getOrCreateSignalSink(signalNode, ss.ConnectionId) r.onNewParticipant( - roomName, - participantId, - participantName, - r.getOrCreateMessageChannel(r.requestChannels, participantId), + ss.RoomName, + ss.Identity, + reqChan, resSink, ) return nil @@ -194,7 +191,8 @@ func (r *RedisRouter) startParticipantRTC(roomName, participantId, participantNa func (r *RedisRouter) Start() error { r.once.Do(func() { go r.statsWorker() - go r.subscribeWorker() + go r.rtcWorker() + go r.signalWorker() }) return nil } @@ -203,48 +201,67 @@ func (r *RedisRouter) Stop() { r.cancel() } -func (r *RedisRouter) setParticipantRTCNode(participantId, nodeId string) error { - r.cr.Expire(participantRTCKey(participantId)) - err := r.rc.Set(r.ctx, participantRTCKey(participantId), nodeId, participantMappingTTL).Err() +func (r *RedisRouter) setParticipantRTCNode(participantKey, nodeId string) error { + r.cr.Expire(participantRTCKey(participantKey)) + err := r.rc.Set(r.ctx, participantRTCKey(participantKey), nodeId, participantMappingTTL).Err() if err != nil { err = errors.Wrap(err, "could not set rtc node") } return err } -func (r *RedisRouter) setParticipantSignalNode(participantId, nodeId string) error { - r.cr.Expire(participantSignalKey(participantId)) - if err := r.rc.Set(r.ctx, participantSignalKey(participantId), nodeId, participantMappingTTL).Err(); err != nil { +func (r *RedisRouter) setParticipantSignalNode(connectionId, nodeId string) error { + r.cr.Expire(participantSignalKey(connectionId)) + if err := r.rc.Set(r.ctx, participantSignalKey(connectionId), nodeId, participantMappingTTL).Err(); err != nil { return errors.Wrap(err, "could not set signal node") } return nil } -func (r *RedisRouter) getOrCreateRedisSink(nodeId string, participantId string) *RedisSink { +func (r *RedisRouter) getOrCreateRTCSink(nodeId string, participantKey string) *RTCNodeSink { r.lock.Lock() defer r.lock.Unlock() - sink := r.redisSinks[participantId] + sink := r.rtcSinks[participantKey] if sink != nil { return sink } - sink = NewRedisSink(r.rc, nodeId, participantId) + sink = NewRTCNodeSink(r.rc, nodeId, participantKey) sink.OnClose(func() { r.lock.Lock() - delete(r.redisSinks, participantId) + delete(r.rtcSinks, participantKey) r.lock.Unlock() }) - r.redisSinks[participantId] = sink + r.rtcSinks[participantKey] = sink return sink } -func (r *RedisRouter) getParticipantRTCNode(participantId string) (string, error) { - return r.cr.CachedGet(participantRTCKey(participantId)) +func (r *RedisRouter) getOrCreateSignalSink(nodeId string, connectionId string) *SignalNodeSink { + r.lock.Lock() + defer r.lock.Unlock() + sink := r.signalSinks[connectionId] + + if sink != nil { + return sink + } + + sink = NewSignalNodeSink(r.rc, nodeId, connectionId) + sink.OnClose(func() { + r.lock.Lock() + delete(r.signalSinks, connectionId) + r.lock.Unlock() + }) + r.signalSinks[connectionId] = sink + return sink } -func (r *RedisRouter) getParticipantSignalNode(participantId string) (nodeId string, err error) { - return r.cr.CachedGet(participantSignalKey(participantId)) +func (r *RedisRouter) getParticipantRTCNode(participantKey string) (string, error) { + return r.cr.CachedGet(participantRTCKey(participantKey)) +} + +func (r *RedisRouter) getParticipantSignalNode(connectionId string) (nodeId string, err error) { + return r.cr.CachedGet(participantSignalKey(connectionId)) } // update node stats and cleanup @@ -259,13 +276,13 @@ func (r *RedisRouter) statsWorker() { } } -func (r *RedisRouter) subscribeWorker() { - sub := r.rc.Subscribe(redisCtx, nodeChannel(r.currentNode.Id)) - +// worker that consumes signal channel and processes +func (r *RedisRouter) signalWorker() { + sub := r.rc.Subscribe(redisCtx, signalNodeChannel(r.currentNode.Id)) defer func() { - logger.Debugw("finishing redis subscribeWorker", "node", r.currentNode.Id) + logger.Debugw("finishing redis signalWorker", "node", r.currentNode.Id) }() - logger.Debugw("starting redis subscribeWorker", "node", r.currentNode.Id) + logger.Debugw("starting redis signalWorker", "node", r.currentNode.Id) for r.ctx.Err() == nil { obj, err := sub.Receive(r.ctx) if err != nil { @@ -283,42 +300,24 @@ func (r *RedisRouter) subscribeWorker() { continue } - rm := livekit.RouterMessage{} + rm := livekit.SignalNodeMessage{} err = proto.Unmarshal([]byte(msg.Payload), &rm) - pId := rm.ParticipantId + connectionId := rm.ConnectionId switch rmb := rm.Message.(type) { - case *livekit.RouterMessage_StartSession: - logger.Infow("received router startSession", "node", r.currentNode.Id, - "participant", pId) - // RTC session should start on this node - err = r.startParticipantRTC(rmb.StartSession.RoomName, pId, rmb.StartSession.ParticipantName) - if err != nil { - logger.Errorw("could not start participant", "error", err) - } - case *livekit.RouterMessage_Request: - // in the event the current node is an RTC node, push to request channels - reqSink := r.getOrCreateMessageChannel(r.requestChannels, pId) - err = reqSink.WriteMessage(rmb.Request) - if err != nil { - logger.Errorw("could not write to request channel", - "participant", pId, - "error", err) - } - - case *livekit.RouterMessage_Response: + case *livekit.SignalNodeMessage_Response: // in the event the current node is an Signal node, push to response channels - resSink := r.getOrCreateMessageChannel(r.responseChannels, pId) + resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) err = resSink.WriteMessage(rmb.Response) if err != nil { logger.Errorw("could not write to response channel", - "participant", pId, + "connectionId", connectionId, "error", err) } - case *livekit.RouterMessage_EndSession: - signalNode, err := r.getParticipantRTCNode(pId) + case *livekit.SignalNodeMessage_EndSession: + signalNode, err := r.getParticipantSignalNode(connectionId) if err != nil { logger.Errorw("could not get participant RTC node", "error", err) @@ -326,9 +325,61 @@ func (r *RedisRouter) subscribeWorker() { } // EndSession can only be initiated on an RTC node, is handled on the signal node if signalNode == r.currentNode.Id { - resSink := r.getOrCreateMessageChannel(r.responseChannels, pId) + resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) resSink.Close() } } } } + +// worker that consumes RTC channel and processes +func (r *RedisRouter) rtcWorker() { + sub := r.rc.Subscribe(redisCtx, rtcNodeChannel(r.currentNode.Id)) + + defer func() { + logger.Debugw("finishing redis rtcWorker", "node", r.currentNode.Id) + }() + logger.Debugw("starting redis rtcWorker", "node", r.currentNode.Id) + for r.ctx.Err() == nil { + obj, err := sub.Receive(r.ctx) + if err != nil { + logger.Warnw("error receiving redis message", "error", err) + // TODO: retry? ignore? at a minimum need to sleep here to retry + time.Sleep(time.Second) + continue + } + if obj == nil { + return + } + + msg, ok := obj.(*redis.Message) + if !ok { + continue + } + + rm := livekit.RTCNodeMessage{} + err = proto.Unmarshal([]byte(msg.Payload), &rm) + pKey := rm.ParticipantKey + + switch rmb := rm.Message.(type) { + case *livekit.RTCNodeMessage_StartSession: + logger.Debugw("received router startSession", "node", r.currentNode.Id, + "participant", pKey) + // RTC session should start on this node + err = r.startParticipantRTC(rmb.StartSession, pKey) + if err != nil { + logger.Errorw("could not start participant", "error", err) + } + + case *livekit.RTCNodeMessage_Request: + // in the event the current node is an RTC node, push to request channels + reqSink := r.getOrCreateMessageChannel(r.requestChannels, pKey) + err = reqSink.WriteMessage(rmb.Request) + if err != nil { + logger.Errorw("could not write to request channel", + "participant", pKey, + "error", err) + } + } + } +} diff --git a/pkg/routing/routingfakes/fake_router.go b/pkg/routing/routingfakes/fake_router.go index c81c9926f..f9ff37f60 100644 --- a/pkg/routing/routingfakes/fake_router.go +++ b/pkg/routing/routingfakes/fake_router.go @@ -46,32 +46,6 @@ type FakeRouter struct { result1 string result2 error } - GetRequestSinkStub func(string) (routing.MessageSink, error) - getRequestSinkMutex sync.RWMutex - getRequestSinkArgsForCall []struct { - arg1 string - } - getRequestSinkReturns struct { - result1 routing.MessageSink - result2 error - } - getRequestSinkReturnsOnCall map[int]struct { - result1 routing.MessageSink - result2 error - } - GetResponseSourceStub func(string) (routing.MessageSource, error) - getResponseSourceMutex sync.RWMutex - getResponseSourceArgsForCall []struct { - arg1 string - } - getResponseSourceReturns struct { - result1 routing.MessageSource - result2 error - } - getResponseSourceReturnsOnCall map[int]struct { - result1 routing.MessageSource - result2 error - } ListNodesStub func() ([]*livekit.Node, error) listNodesMutex sync.RWMutex listNodesArgsForCall []struct { @@ -121,18 +95,21 @@ type FakeRouter struct { startReturnsOnCall map[int]struct { result1 error } - StartParticipantSignalStub func(string, string, string) error + StartParticipantSignalStub func(string, string) (routing.MessageSink, routing.MessageSource, error) startParticipantSignalMutex sync.RWMutex startParticipantSignalArgsForCall []struct { arg1 string arg2 string - arg3 string } startParticipantSignalReturns struct { - result1 error + result1 routing.MessageSink + result2 routing.MessageSource + result3 error } startParticipantSignalReturnsOnCall map[int]struct { - result1 error + result1 routing.MessageSink + result2 routing.MessageSource + result3 error } StopStub func() stopMutex sync.RWMutex @@ -341,134 +318,6 @@ func (fake *FakeRouter) GetNodeForRoomReturnsOnCall(i int, result1 string, resul }{result1, result2} } -func (fake *FakeRouter) GetRequestSink(arg1 string) (routing.MessageSink, error) { - fake.getRequestSinkMutex.Lock() - ret, specificReturn := fake.getRequestSinkReturnsOnCall[len(fake.getRequestSinkArgsForCall)] - fake.getRequestSinkArgsForCall = append(fake.getRequestSinkArgsForCall, struct { - arg1 string - }{arg1}) - stub := fake.GetRequestSinkStub - fakeReturns := fake.getRequestSinkReturns - fake.recordInvocation("GetRequestSink", []interface{}{arg1}) - fake.getRequestSinkMutex.Unlock() - if stub != nil { - return stub(arg1) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeRouter) GetRequestSinkCallCount() int { - fake.getRequestSinkMutex.RLock() - defer fake.getRequestSinkMutex.RUnlock() - return len(fake.getRequestSinkArgsForCall) -} - -func (fake *FakeRouter) GetRequestSinkCalls(stub func(string) (routing.MessageSink, error)) { - fake.getRequestSinkMutex.Lock() - defer fake.getRequestSinkMutex.Unlock() - fake.GetRequestSinkStub = stub -} - -func (fake *FakeRouter) GetRequestSinkArgsForCall(i int) string { - fake.getRequestSinkMutex.RLock() - defer fake.getRequestSinkMutex.RUnlock() - argsForCall := fake.getRequestSinkArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeRouter) GetRequestSinkReturns(result1 routing.MessageSink, result2 error) { - fake.getRequestSinkMutex.Lock() - defer fake.getRequestSinkMutex.Unlock() - fake.GetRequestSinkStub = nil - fake.getRequestSinkReturns = struct { - result1 routing.MessageSink - result2 error - }{result1, result2} -} - -func (fake *FakeRouter) GetRequestSinkReturnsOnCall(i int, result1 routing.MessageSink, result2 error) { - fake.getRequestSinkMutex.Lock() - defer fake.getRequestSinkMutex.Unlock() - fake.GetRequestSinkStub = nil - if fake.getRequestSinkReturnsOnCall == nil { - fake.getRequestSinkReturnsOnCall = make(map[int]struct { - result1 routing.MessageSink - result2 error - }) - } - fake.getRequestSinkReturnsOnCall[i] = struct { - result1 routing.MessageSink - result2 error - }{result1, result2} -} - -func (fake *FakeRouter) GetResponseSource(arg1 string) (routing.MessageSource, error) { - fake.getResponseSourceMutex.Lock() - ret, specificReturn := fake.getResponseSourceReturnsOnCall[len(fake.getResponseSourceArgsForCall)] - fake.getResponseSourceArgsForCall = append(fake.getResponseSourceArgsForCall, struct { - arg1 string - }{arg1}) - stub := fake.GetResponseSourceStub - fakeReturns := fake.getResponseSourceReturns - fake.recordInvocation("GetResponseSource", []interface{}{arg1}) - fake.getResponseSourceMutex.Unlock() - if stub != nil { - return stub(arg1) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeRouter) GetResponseSourceCallCount() int { - fake.getResponseSourceMutex.RLock() - defer fake.getResponseSourceMutex.RUnlock() - return len(fake.getResponseSourceArgsForCall) -} - -func (fake *FakeRouter) GetResponseSourceCalls(stub func(string) (routing.MessageSource, error)) { - fake.getResponseSourceMutex.Lock() - defer fake.getResponseSourceMutex.Unlock() - fake.GetResponseSourceStub = stub -} - -func (fake *FakeRouter) GetResponseSourceArgsForCall(i int) string { - fake.getResponseSourceMutex.RLock() - defer fake.getResponseSourceMutex.RUnlock() - argsForCall := fake.getResponseSourceArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeRouter) GetResponseSourceReturns(result1 routing.MessageSource, result2 error) { - fake.getResponseSourceMutex.Lock() - defer fake.getResponseSourceMutex.Unlock() - fake.GetResponseSourceStub = nil - fake.getResponseSourceReturns = struct { - result1 routing.MessageSource - result2 error - }{result1, result2} -} - -func (fake *FakeRouter) GetResponseSourceReturnsOnCall(i int, result1 routing.MessageSource, result2 error) { - fake.getResponseSourceMutex.Lock() - defer fake.getResponseSourceMutex.Unlock() - fake.GetResponseSourceStub = nil - if fake.getResponseSourceReturnsOnCall == nil { - fake.getResponseSourceReturnsOnCall = make(map[int]struct { - result1 routing.MessageSource - result2 error - }) - } - fake.getResponseSourceReturnsOnCall[i] = struct { - result1 routing.MessageSource - result2 error - }{result1, result2} -} - func (fake *FakeRouter) ListNodes() ([]*livekit.Node, error) { fake.listNodesMutex.Lock() ret, specificReturn := fake.listNodesReturnsOnCall[len(fake.listNodesArgsForCall)] @@ -725,25 +574,24 @@ func (fake *FakeRouter) StartReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeRouter) StartParticipantSignal(arg1 string, arg2 string, arg3 string) error { +func (fake *FakeRouter) StartParticipantSignal(arg1 string, arg2 string) (routing.MessageSink, routing.MessageSource, error) { fake.startParticipantSignalMutex.Lock() ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct { arg1 string arg2 string - arg3 string - }{arg1, arg2, arg3}) + }{arg1, arg2}) stub := fake.StartParticipantSignalStub fakeReturns := fake.startParticipantSignalReturns - fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2, arg3}) + fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2}) fake.startParticipantSignalMutex.Unlock() if stub != nil { - return stub(arg1, arg2, arg3) + return stub(arg1, arg2) } if specificReturn { - return ret.result1 + return ret.result1, ret.result2, ret.result3 } - return fakeReturns.result1 + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 } func (fake *FakeRouter) StartParticipantSignalCallCount() int { @@ -752,40 +600,46 @@ func (fake *FakeRouter) StartParticipantSignalCallCount() int { return len(fake.startParticipantSignalArgsForCall) } -func (fake *FakeRouter) StartParticipantSignalCalls(stub func(string, string, string) error) { +func (fake *FakeRouter) StartParticipantSignalCalls(stub func(string, string) (routing.MessageSink, routing.MessageSource, error)) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = stub } -func (fake *FakeRouter) StartParticipantSignalArgsForCall(i int) (string, string, string) { +func (fake *FakeRouter) StartParticipantSignalArgsForCall(i int) (string, string) { fake.startParticipantSignalMutex.RLock() defer fake.startParticipantSignalMutex.RUnlock() argsForCall := fake.startParticipantSignalArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 + return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeRouter) StartParticipantSignalReturns(result1 error) { +func (fake *FakeRouter) StartParticipantSignalReturns(result1 routing.MessageSink, result2 routing.MessageSource, result3 error) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = nil fake.startParticipantSignalReturns = struct { - result1 error - }{result1} + result1 routing.MessageSink + result2 routing.MessageSource + result3 error + }{result1, result2, result3} } -func (fake *FakeRouter) StartParticipantSignalReturnsOnCall(i int, result1 error) { +func (fake *FakeRouter) StartParticipantSignalReturnsOnCall(i int, result1 routing.MessageSink, result2 routing.MessageSource, result3 error) { fake.startParticipantSignalMutex.Lock() defer fake.startParticipantSignalMutex.Unlock() fake.StartParticipantSignalStub = nil if fake.startParticipantSignalReturnsOnCall == nil { fake.startParticipantSignalReturnsOnCall = make(map[int]struct { - result1 error + result1 routing.MessageSink + result2 routing.MessageSource + result3 error }) } fake.startParticipantSignalReturnsOnCall[i] = struct { - result1 error - }{result1} + result1 routing.MessageSink + result2 routing.MessageSource + result3 error + }{result1, result2, result3} } func (fake *FakeRouter) Stop() { @@ -874,10 +728,6 @@ func (fake *FakeRouter) Invocations() map[string][][]interface{} { defer fake.getNodeMutex.RUnlock() fake.getNodeForRoomMutex.RLock() defer fake.getNodeForRoomMutex.RUnlock() - fake.getRequestSinkMutex.RLock() - defer fake.getRequestSinkMutex.RUnlock() - fake.getResponseSourceMutex.RLock() - defer fake.getResponseSourceMutex.RUnlock() fake.listNodesMutex.RLock() defer fake.listNodesMutex.RUnlock() fake.onNewParticipantRTCMutex.RLock() diff --git a/pkg/routing/utils.go b/pkg/routing/utils.go index 3bbce1c9b..a820f5b16 100644 --- a/pkg/routing/utils.go +++ b/pkg/routing/utils.go @@ -20,3 +20,7 @@ func GetAvailableNodes(nodes []*livekit.Node) []*livekit.Node { return IsAvailable(node) }).([]*livekit.Node) } + +func participantKey(roomName, identity string) string { + return roomName + "|" + identity +} diff --git a/pkg/rtc/errors.go b/pkg/rtc/errors.go index 91729081f..b1e704568 100644 --- a/pkg/rtc/errors.go +++ b/pkg/rtc/errors.go @@ -6,6 +6,7 @@ var ( ErrRoomClosed = errors.New("room has already closed") ErrPermissionDenied = errors.New("no permissions to access the room") ErrMaxParticipantsExceeded = errors.New("room has exceeded its max participants") + ErrAlreadyJoined = errors.New("a participant with the same identity is already in the room") ErrUnexpectedOffer = errors.New("expected answer SDP, received offer") ErrUnexpectedNegotiation = errors.New("client negotiation has not been granted") ) diff --git a/pkg/rtc/participant.go b/pkg/rtc/participant.go index 2942a8e29..2b174d0e4 100644 --- a/pkg/rtc/participant.go +++ b/pkg/rtc/participant.go @@ -75,14 +75,14 @@ func NewPeerConnection(conf *WebRTCConfig) (*webrtc.PeerConnection, error) { return pc, err } -func NewParticipant(participantId, identity string, pc types.PeerConnection, rs routing.MessageSink, receiverConfig ReceiverConfig) (*ParticipantImpl, error) { +func NewParticipant(identity string, pc types.PeerConnection, rs routing.MessageSink, receiverConfig ReceiverConfig) (*ParticipantImpl, error) { // TODO: check to ensure params are valid, id and identity can't be empty me := &webrtc.MediaEngine{} me.RegisterDefaultCodecs() ctx, cancel := context.WithCancel(context.Background()) participant := &ParticipantImpl{ - id: participantId, + id: utils.NewGuid(utils.ParticipantPrefix), identity: identity, peerConn: pc, responseSink: rs, @@ -184,6 +184,14 @@ func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { return info } +func (p *ParticipantImpl) GetResponseSink() routing.MessageSink { + return p.responseSink +} + +func (p *ParticipantImpl) SetResponseSink(sink routing.MessageSink) { + p.responseSink = sink +} + // callbacks for clients func (p *ParticipantImpl) OnTrackPublished(callback func(types.Participant, types.PublishedTrack)) { p.onTrackPublished = callback diff --git a/pkg/rtc/participant_internal_test.go b/pkg/rtc/participant_internal_test.go index 102240d78..f9f2a6566 100644 --- a/pkg/rtc/participant_internal_test.go +++ b/pkg/rtc/participant_internal_test.go @@ -9,7 +9,6 @@ import ( "github.com/livekit/livekit-server/pkg/routing/routingfakes" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" - "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/livekit-server/proto/livekit" ) @@ -117,10 +116,9 @@ func TestDisconnectTiming(t *testing.T) { }) } -func newParticipantForTest(name string) *ParticipantImpl { +func newParticipantForTest(identity string) *ParticipantImpl { p, _ := NewParticipant( - utils.NewGuid(utils.ParticipantPrefix), - name, + identity, &typesfakes.FakePeerConnection{}, &routingfakes.FakeMessageSink{}, ReceiverConfig{}) diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index 2df16e34d..82fab5558 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -16,7 +16,7 @@ type Room struct { livekit.Room config WebRTCConfig lock sync.RWMutex - // map of participantId -> Participant + // map of identity -> Participant participants map[string]types.Participant hasJoined bool isClosed utils.AtomicFlag @@ -32,10 +32,10 @@ func NewRoom(room *livekit.Room, config WebRTCConfig) *Room { } } -func (r *Room) GetParticipant(id string) types.Participant { +func (r *Room) GetParticipant(identity string) types.Participant { r.lock.RLock() defer r.lock.RUnlock() - return r.participants[id] + return r.participants[identity] } func (r *Room) GetParticipants() []types.Participant { @@ -52,6 +52,11 @@ func (r *Room) Join(participant types.Participant) error { r.lock.Lock() defer r.lock.Unlock() + if r.participants[participant.Identity()] != nil { + return ErrAlreadyJoined + + } + if r.MaxParticipants > 0 && int(r.MaxParticipants) == len(r.participants) { return ErrMaxParticipantsExceeded } @@ -75,15 +80,15 @@ func (r *Room) Join(participant types.Participant) error { if err := op.AddSubscriber(p); err != nil { // TODO: log error? or disconnect? logger.Errorw("could not subscribe to participant", - "dstParticipant", p.ID(), - "srcParticipant", op.ID()) + "dest", p.Identity(), + "source", op.Identity()) } } // start the workers once connectivity is established p.Start() } else if p.State() == livekit.ParticipantInfo_DISCONNECTED { // remove participant from room - go r.RemoveParticipant(p.ID()) + go r.RemoveParticipant(p.Identity()) } }) participant.OnTrackUpdated(r.onTrackUpdated) @@ -93,7 +98,7 @@ func (r *Room) Join(participant types.Participant) error { "identity", participant.Identity(), "roomId", r.Sid) - r.participants[participant.ID()] = participant + r.participants[participant.Identity()] = participant // gather other participants and send join response otherParticipants := make([]types.Participant, 0, len(r.participants)) @@ -106,11 +111,11 @@ func (r *Room) Join(participant types.Participant) error { return participant.SendJoinResponse(&r.Room, otherParticipants) } -func (r *Room) RemoveParticipant(id string) { +func (r *Room) RemoveParticipant(identity string) { r.lock.Lock() defer r.lock.Unlock() - if p, ok := r.participants[id]; ok { + if p, ok := r.participants[identity]; ok { // avoid blocking lock go func() { Recover() @@ -119,7 +124,7 @@ func (r *Room) RemoveParticipant(id string) { }() } - delete(r.participants, id) + delete(r.participants, identity) go r.CloseIfEmpty() } @@ -139,11 +144,15 @@ func (r *Room) CloseIfEmpty() { } elapsed := uint32(time.Now().Unix() - r.CreationTime) - logger.Infow("comparing elapsed", "elapsed", elapsed, "timeout", r.EmptyTimeout) if r.hasJoined || (r.EmptyTimeout > 0 && elapsed >= r.EmptyTimeout) { - if r.isClosed.TrySet(true) && r.onClose != nil { - r.onClose() - } + r.Close() + } +} + +func (r *Room) Close() { + logger.Infow("closing room", "room", r.Sid, "name", r.Name) + if r.isClosed.TrySet(true) && r.onClose != nil { + r.onClose() } } @@ -171,14 +180,14 @@ func (r *Room) onTrackAdded(participant types.Participant, track types.Published continue } logger.Debugw("subscribing to new track", - "srcParticipant", participant.ID(), + "source", participant.Identity(), "remoteTrack", track.ID(), - "dstParticipant", existingParticipant.ID()) + "dest", existingParticipant.Identity()) if err := track.AddSubscriber(existingParticipant); err != nil { logger.Errorw("could not subscribe to remoteTrack", - "srcParticipant", participant.ID(), + "source", participant.Identity(), "remoteTrack", track.ID(), - "dstParticipant", existingParticipant.ID()) + "dest", existingParticipant.Identity()) } } } @@ -202,7 +211,7 @@ func (r *Room) broadcastParticipantState(p types.Participant) { err := op.SendParticipantUpdate(updates) if err != nil { logger.Errorw("could not send update to participant", - "participant", p.ID(), + "participant", p.Identity(), "err", err) } } diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index df1caf40b..7951526c5 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -1,6 +1,7 @@ package rtc_test import ( + "fmt" "testing" "time" @@ -62,7 +63,7 @@ func TestRoomJoin(t *testing.T) { disconnectedParticipant := participants[1].(*typesfakes.FakeParticipant) disconnectedParticipant.StateReturns(livekit.ParticipantInfo_DISCONNECTED) - rm.RemoveParticipant(p.ID()) + rm.RemoveParticipant(p.Identity()) p.OnStateChangeArgsForCall(0)(p, livekit.ParticipantInfo_ACTIVE) time.Sleep(defaultDelay) @@ -97,7 +98,7 @@ func TestRoomClosure(t *testing.T) { isClosed = true }) p := rm.GetParticipants()[0] - rm.RemoveParticipant(p.ID()) + rm.RemoveParticipant(p.Identity()) time.Sleep(defaultDelay) @@ -155,12 +156,14 @@ func TestNewTrack(t *testing.T) { } func newRoomWithParticipants(t *testing.T, num int) *rtc.Room { + rm := rtc.NewRoom( - &livekit.Room{Name: "identity"}, + &livekit.Room{Name: "room"}, rtc.WebRTCConfig{}, ) for i := 0; i < num; i++ { - participant := newMockParticipant("") + identity := fmt.Sprintf("p%d", i) + participant := newMockParticipant(identity) err := rm.Join(participant) assert.NoError(t, err) //rm.participants[participant.ID()] = participant diff --git a/pkg/rtc/types/interfaces.go b/pkg/rtc/types/interfaces.go index 9ad86f3db..4aa132f83 100644 --- a/pkg/rtc/types/interfaces.go +++ b/pkg/rtc/types/interfaces.go @@ -7,6 +7,7 @@ import ( "github.com/pion/rtp" "github.com/pion/webrtc/v3" + "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/livekit-server/proto/livekit" @@ -53,6 +54,8 @@ type Participant interface { IsReady() bool ToProto() *livekit.ParticipantInfo RTCPChan() *utils.CalmChannel + GetResponseSink() routing.MessageSink + SetResponseSink(sink routing.MessageSink) AddTrack(clientId, name string, trackType livekit.TrackType) Answer(sdp webrtc.SessionDescription) (answer webrtc.SessionDescription, err error) diff --git a/pkg/rtc/types/typesfakes/fake_participant.go b/pkg/rtc/types/typesfakes/fake_participant.go index ccaa59de0..8d31fb64d 100644 --- a/pkg/rtc/types/typesfakes/fake_participant.go +++ b/pkg/rtc/types/typesfakes/fake_participant.go @@ -4,6 +4,7 @@ package typesfakes import ( "sync" + "github.com/livekit/livekit-server/pkg/routing" "github.com/livekit/livekit-server/pkg/rtc/types" "github.com/livekit/livekit-server/pkg/sfu" "github.com/livekit/livekit-server/pkg/utils" @@ -70,6 +71,16 @@ type FakeParticipant struct { closeReturnsOnCall map[int]struct { result1 error } + GetResponseSinkStub func() routing.MessageSink + getResponseSinkMutex sync.RWMutex + getResponseSinkArgsForCall []struct { + } + getResponseSinkReturns struct { + result1 routing.MessageSink + } + getResponseSinkReturnsOnCall map[int]struct { + result1 routing.MessageSink + } HandleAnswerStub func(webrtc.SessionDescription) error handleAnswerMutex sync.RWMutex handleAnswerArgsForCall []struct { @@ -194,6 +205,11 @@ type FakeParticipant struct { sendParticipantUpdateReturnsOnCall map[int]struct { result1 error } + SetResponseSinkStub func(routing.MessageSink) + setResponseSinkMutex sync.RWMutex + setResponseSinkArgsForCall []struct { + arg1 routing.MessageSink + } SetTrackMutedStub func(string, bool) setTrackMutedMutex sync.RWMutex setTrackMutedArgsForCall []struct { @@ -534,6 +550,59 @@ func (fake *FakeParticipant) CloseReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeParticipant) GetResponseSink() routing.MessageSink { + fake.getResponseSinkMutex.Lock() + ret, specificReturn := fake.getResponseSinkReturnsOnCall[len(fake.getResponseSinkArgsForCall)] + fake.getResponseSinkArgsForCall = append(fake.getResponseSinkArgsForCall, struct { + }{}) + stub := fake.GetResponseSinkStub + fakeReturns := fake.getResponseSinkReturns + fake.recordInvocation("GetResponseSink", []interface{}{}) + fake.getResponseSinkMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetResponseSinkCallCount() int { + fake.getResponseSinkMutex.RLock() + defer fake.getResponseSinkMutex.RUnlock() + return len(fake.getResponseSinkArgsForCall) +} + +func (fake *FakeParticipant) GetResponseSinkCalls(stub func() routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = stub +} + +func (fake *FakeParticipant) GetResponseSinkReturns(result1 routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = nil + fake.getResponseSinkReturns = struct { + result1 routing.MessageSink + }{result1} +} + +func (fake *FakeParticipant) GetResponseSinkReturnsOnCall(i int, result1 routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = nil + if fake.getResponseSinkReturnsOnCall == nil { + fake.getResponseSinkReturnsOnCall = make(map[int]struct { + result1 routing.MessageSink + }) + } + fake.getResponseSinkReturnsOnCall[i] = struct { + result1 routing.MessageSink + }{result1} +} + func (fake *FakeParticipant) HandleAnswer(arg1 webrtc.SessionDescription) error { fake.handleAnswerMutex.Lock() ret, specificReturn := fake.handleAnswerReturnsOnCall[len(fake.handleAnswerArgsForCall)] @@ -1242,6 +1311,38 @@ func (fake *FakeParticipant) SendParticipantUpdateReturnsOnCall(i int, result1 e }{result1} } +func (fake *FakeParticipant) SetResponseSink(arg1 routing.MessageSink) { + fake.setResponseSinkMutex.Lock() + fake.setResponseSinkArgsForCall = append(fake.setResponseSinkArgsForCall, struct { + arg1 routing.MessageSink + }{arg1}) + stub := fake.SetResponseSinkStub + fake.recordInvocation("SetResponseSink", []interface{}{arg1}) + fake.setResponseSinkMutex.Unlock() + if stub != nil { + fake.SetResponseSinkStub(arg1) + } +} + +func (fake *FakeParticipant) SetResponseSinkCallCount() int { + fake.setResponseSinkMutex.RLock() + defer fake.setResponseSinkMutex.RUnlock() + return len(fake.setResponseSinkArgsForCall) +} + +func (fake *FakeParticipant) SetResponseSinkCalls(stub func(routing.MessageSink)) { + fake.setResponseSinkMutex.Lock() + defer fake.setResponseSinkMutex.Unlock() + fake.SetResponseSinkStub = stub +} + +func (fake *FakeParticipant) SetResponseSinkArgsForCall(i int) routing.MessageSink { + fake.setResponseSinkMutex.RLock() + defer fake.setResponseSinkMutex.RUnlock() + argsForCall := fake.setResponseSinkArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeParticipant) SetTrackMuted(arg1 string, arg2 bool) { fake.setTrackMutedMutex.Lock() fake.setTrackMutedArgsForCall = append(fake.setTrackMutedArgsForCall, struct { @@ -1420,6 +1521,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.answerMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.getResponseSinkMutex.RLock() + defer fake.getResponseSinkMutex.RUnlock() fake.handleAnswerMutex.RLock() defer fake.handleAnswerMutex.RUnlock() fake.handleClientNegotiationMutex.RLock() @@ -1452,6 +1555,8 @@ func (fake *FakeParticipant) Invocations() map[string][][]interface{} { defer fake.sendJoinResponseMutex.RUnlock() fake.sendParticipantUpdateMutex.RLock() defer fake.sendParticipantUpdateMutex.RUnlock() + fake.setResponseSinkMutex.RLock() + defer fake.setResponseSinkMutex.RUnlock() fake.setTrackMutedMutex.RLock() defer fake.setTrackMutedMutex.RUnlock() fake.startMutex.RLock() diff --git a/pkg/service/localroomstore.go b/pkg/service/localroomstore.go index a99a0f867..f0b786be0 100644 --- a/pkg/service/localroomstore.go +++ b/pkg/service/localroomstore.go @@ -6,7 +6,6 @@ import ( "github.com/thoas/go-funk" - "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/livekit-server/proto/livekit" ) @@ -16,17 +15,14 @@ type LocalRoomStore struct { rooms map[string]*livekit.Room // map of roomName => roomId roomIds map[string]string - // map of roomName => { participantName: participantId } - participantIds map[string]map[string]string - lock sync.RWMutex + lock sync.RWMutex } func NewLocalRoomStore() *LocalRoomStore { return &LocalRoomStore{ - rooms: make(map[string]*livekit.Room), - roomIds: make(map[string]string), - participantIds: make(map[string]map[string]string), - lock: sync.RWMutex{}, + rooms: make(map[string]*livekit.Room), + roomIds: make(map[string]string), + lock: sync.RWMutex{}, } } @@ -77,22 +73,3 @@ func (p *LocalRoomStore) DeleteRoom(idOrName string) error { delete(p.roomIds, room.Name) return nil } - -func (p *LocalRoomStore) GetParticipantId(room, name string) (string, error) { - p.lock.Lock() - defer p.lock.Unlock() - - roomParticipantIds := p.participantIds[room] - if roomParticipantIds == nil { - p.participantIds[room] = make(map[string]string) - roomParticipantIds = p.participantIds[room] - } - - pId := roomParticipantIds[name] - if pId == "" { - pId = utils.NewGuid(utils.ParticipantPrefix) - roomParticipantIds[name] = pId - } - - return pId, nil -} diff --git a/pkg/service/localroomstore_test.go b/pkg/service/localroomstore_test.go deleted file mode 100644 index 6818df441..000000000 --- a/pkg/service/localroomstore_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package service_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/livekit/livekit-server/pkg/service" -) - -func TestLocalRoomStore_GetParticipantId(t *testing.T) { - s := service.NewLocalRoomStore() - id, _ := s.GetParticipantId("room1", "p1") - assert.NotEmpty(t, id) - - t.Run("diff room, same name returns a new ID", func(t *testing.T) { - id2, _ := s.GetParticipantId("room2", "p1") - assert.NotEmpty(t, id2) - assert.NotEqual(t, id, id2) - }) - - t.Run("same room returns identical id", func(t *testing.T) { - id2, _ := s.GetParticipantId("room1", "p1") - assert.Equal(t, id, id2) - }) - - t.Run("same room with different name", func(t *testing.T) { - id2, _ := s.GetParticipantId("room1", "p2") - assert.NotEqual(t, id, id2) - }) -} diff --git a/pkg/service/redisroomstore.go b/pkg/service/redisroomstore.go index 987d6ce6f..1d2f4067c 100644 --- a/pkg/service/redisroomstore.go +++ b/pkg/service/redisroomstore.go @@ -8,7 +8,6 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/proto" - "github.com/livekit/livekit-server/pkg/utils" "github.com/livekit/livekit-server/proto/livekit" ) @@ -18,12 +17,12 @@ const ( // hash of room_id => room name RoomIdMap = "room_id_map" - - // hash of participant_name => participant_id - // a key for each room, with expiration - RoomParticipantMapPrefix = "participant_map:room:" - - participantMappingTTL = 24 * time.Hour + // + //// hash of participant_name => participant_id + //// a key for each room, with expiration + //RoomParticipantMapPrefix = "participant_map:room:" + // + //participantMappingTTL = 24 * time.Hour ) type RedisRoomStore struct { @@ -112,26 +111,7 @@ func (p *RedisRoomStore) DeleteRoom(idOrName string) error { pp := p.rc.Pipeline() pp.HDel(p.ctx, RoomIdMap, room.Sid) pp.HDel(p.ctx, RoomsKey, room.Name) - pp.HDel(p.ctx, RoomParticipantMapPrefix+room.Name) _, err = pp.Exec(p.ctx) return err } - -func (p *RedisRoomStore) GetParticipantId(room, name string) (string, error) { - key := RoomParticipantMapPrefix + room - - pId, err := p.rc.HGet(p.ctx, key, name).Result() - if err == redis.Nil { - // create - pId = utils.NewGuid(utils.ParticipantPrefix) - pp := p.rc.Pipeline() - pp.HSet(p.ctx, key, name, pId) - pp.Expire(p.ctx, key, participantMappingTTL) - _, err = pp.Exec(p.ctx) - } else if err != nil { - return "", err - } - - return pId, err -} diff --git a/pkg/service/roommanager.go b/pkg/service/roommanager.go index a6355d13f..e6ae438fb 100644 --- a/pkg/service/roommanager.go +++ b/pkg/service/roommanager.go @@ -79,6 +79,7 @@ func (r *RoomManager) CreateRoom(req *livekit.CreateRoomRequest) (*livekit.Room, // DeleteRoom completely deletes all room information, including active sessions, room store, and routing info func (r *RoomManager) DeleteRoom(roomName string) error { + logger.Infow("deleting room state", "room", roomName) r.lock.Lock() delete(r.rooms, roomName) r.lock.Unlock() @@ -124,17 +125,34 @@ func (r *RoomManager) Cleanup() error { } // starts WebRTC session when a new participant is connected, takes place on RTC node -func (r *RoomManager) StartSession(roomName, participantId, participantName string, requestSource routing.MessageSource, responseSink routing.MessageSink) { +func (r *RoomManager) StartSession(roomName, identity string, requestSource routing.MessageSource, responseSink routing.MessageSink) { room, err := r.getOrCreateRoom(roomName) if err != nil { logger.Errorw("could not create room", "error", err) return } + // Use existing peer connection if it's already connected, perhaps from a different signal connection + participant := room.GetParticipant(identity) + if participant != nil { + logger.Debugw("resuming RTC session", + "room", roomName, + "node", r.currentNode.Id, + "participant", identity, + ) + // close previous sink, and link to new one + prevSink := participant.GetResponseSink() + if prevSink != nil { + prevSink.Close() + } + participant.SetResponseSink(responseSink) + return + } + logger.Debugw("starting RTC session", "room", roomName, "node", r.currentNode.Id, - "participant", participantName, + "participant", identity, "num_participants", len(room.GetParticipants()), ) @@ -144,7 +162,7 @@ func (r *RoomManager) StartSession(roomName, participantId, participantName stri return } - participant, err := rtc.NewParticipant(participantId, participantName, pc, responseSink, r.config.Receiver) + participant, err = rtc.NewParticipant(identity, pc, responseSink, r.config.Receiver) if err != nil { logger.Errorw("could not create participant", "error", err) return diff --git a/pkg/service/roomstore.go b/pkg/service/roomstore.go index c8206b7da..b21d0bb73 100644 --- a/pkg/service/roomstore.go +++ b/pkg/service/roomstore.go @@ -14,6 +14,4 @@ type RoomStore interface { GetRoom(idOrName string) (*livekit.Room, error) ListRooms() ([]*livekit.Room, error) DeleteRoom(idOrName string) error - // returns the current participant id in room, or create new one - GetParticipantId(room, name string) (string, error) } diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index 0a0549312..44c6043bf 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -48,7 +48,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { if claims == nil || claims.Video == nil { handleError(w, http.StatusUnauthorized, rtc.ErrPermissionDenied.Error()) } - pName := claims.Identity + identity := claims.Identity onlyName, err := EnsureJoinPermission(r.Context()) if err != nil { @@ -67,35 +67,21 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - participantId, err := s.roomManager.roomStore.GetParticipantId(roomName, pName) - if err != nil { - handleError(w, http.StatusInternalServerError, "could not get participant ID: "+err.Error()) - return - } - - err = s.router.StartParticipantSignal(roomName, participantId, pName) + // this needs to be started first *before* using router functions on this node + reqSink, resSource, err := s.router.StartParticipantSignal(roomName, identity) if err != nil { handleError(w, http.StatusInternalServerError, "could not start session: "+err.Error()) return } - reqSink, err := s.router.GetRequestSink(participantId) - if err != nil { - handleError(w, http.StatusInternalServerError, "could not get request sink"+err.Error()) - return - } + logger.Debugw("started participant signal", "sink", reqSink, "source", resSource) + done := make(chan bool, 1) defer func() { - logger.Infow("WS connection closed", "participant", pName) + logger.Infow("WS connection closed", "participant", identity) reqSink.Close() close(done) }() - resSource, err := s.router.GetResponseSource(participantId) - if err != nil { - handleError(w, http.StatusInternalServerError, "could not get response source"+err.Error()) - return - } - // upgrade only once the basics are good to go conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { @@ -110,12 +96,17 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Infow("new client connected", "room", rm.Sid, "roomName", rm.Name, - "name", pName, - "resSource", fmt.Sprintf("%p", resSource), + "name", identity, ) // handle responses go func() { + defer func() { + // when the source is terminated, this means Participant.Close had been called and RTC connection is done + // we would terminate the signal connection as well + conn.Close() + }() + defer rtc.Recover() for { select { case <-done: diff --git a/pkg/service/servicefakes/fake_room_store.go b/pkg/service/servicefakes/fake_room_store.go index 25f532c7d..8cd943b13 100644 --- a/pkg/service/servicefakes/fake_room_store.go +++ b/pkg/service/servicefakes/fake_room_store.go @@ -31,20 +31,6 @@ type FakeRoomStore struct { deleteRoomReturnsOnCall map[int]struct { result1 error } - GetParticipantIdStub func(string, string) (string, error) - getParticipantIdMutex sync.RWMutex - getParticipantIdArgsForCall []struct { - arg1 string - arg2 string - } - getParticipantIdReturns struct { - result1 string - result2 error - } - getParticipantIdReturnsOnCall map[int]struct { - result1 string - result2 error - } GetRoomStub func(string) (*livekit.Room, error) getRoomMutex sync.RWMutex getRoomArgsForCall []struct { @@ -196,71 +182,6 @@ func (fake *FakeRoomStore) DeleteRoomReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeRoomStore) GetParticipantId(arg1 string, arg2 string) (string, error) { - fake.getParticipantIdMutex.Lock() - ret, specificReturn := fake.getParticipantIdReturnsOnCall[len(fake.getParticipantIdArgsForCall)] - fake.getParticipantIdArgsForCall = append(fake.getParticipantIdArgsForCall, struct { - arg1 string - arg2 string - }{arg1, arg2}) - stub := fake.GetParticipantIdStub - fakeReturns := fake.getParticipantIdReturns - fake.recordInvocation("GetParticipantId", []interface{}{arg1, arg2}) - fake.getParticipantIdMutex.Unlock() - if stub != nil { - return stub(arg1, arg2) - } - if specificReturn { - return ret.result1, ret.result2 - } - return fakeReturns.result1, fakeReturns.result2 -} - -func (fake *FakeRoomStore) GetParticipantIdCallCount() int { - fake.getParticipantIdMutex.RLock() - defer fake.getParticipantIdMutex.RUnlock() - return len(fake.getParticipantIdArgsForCall) -} - -func (fake *FakeRoomStore) GetParticipantIdCalls(stub func(string, string) (string, error)) { - fake.getParticipantIdMutex.Lock() - defer fake.getParticipantIdMutex.Unlock() - fake.GetParticipantIdStub = stub -} - -func (fake *FakeRoomStore) GetParticipantIdArgsForCall(i int) (string, string) { - fake.getParticipantIdMutex.RLock() - defer fake.getParticipantIdMutex.RUnlock() - argsForCall := fake.getParticipantIdArgsForCall[i] - return argsForCall.arg1, argsForCall.arg2 -} - -func (fake *FakeRoomStore) GetParticipantIdReturns(result1 string, result2 error) { - fake.getParticipantIdMutex.Lock() - defer fake.getParticipantIdMutex.Unlock() - fake.GetParticipantIdStub = nil - fake.getParticipantIdReturns = struct { - result1 string - result2 error - }{result1, result2} -} - -func (fake *FakeRoomStore) GetParticipantIdReturnsOnCall(i int, result1 string, result2 error) { - fake.getParticipantIdMutex.Lock() - defer fake.getParticipantIdMutex.Unlock() - fake.GetParticipantIdStub = nil - if fake.getParticipantIdReturnsOnCall == nil { - fake.getParticipantIdReturnsOnCall = make(map[int]struct { - result1 string - result2 error - }) - } - fake.getParticipantIdReturnsOnCall[i] = struct { - result1 string - result2 error - }{result1, result2} -} - func (fake *FakeRoomStore) GetRoom(arg1 string) (*livekit.Room, error) { fake.getRoomMutex.Lock() ret, specificReturn := fake.getRoomReturnsOnCall[len(fake.getRoomArgsForCall)] @@ -388,8 +309,6 @@ func (fake *FakeRoomStore) Invocations() map[string][][]interface{} { defer fake.createRoomMutex.RUnlock() fake.deleteRoomMutex.RLock() defer fake.deleteRoomMutex.RUnlock() - fake.getParticipantIdMutex.RLock() - defer fake.getParticipantIdMutex.RUnlock() fake.getRoomMutex.RLock() defer fake.getRoomMutex.RUnlock() fake.listRoomsMutex.RLock() diff --git a/proto/internal.proto b/proto/internal.proto index 869744b61..f617f2dfb 100644 --- a/proto/internal.proto +++ b/proto/internal.proto @@ -24,20 +24,39 @@ message NodeStats { uint32 num_tracks_out = 6; } -// message for a node through the router -message RouterMessage { +// message to RTC nodes +message RTCNodeMessage { + string participant_key = 1; oneof message { - StartSession start_session = 1; - SignalRequest request = 2; - SignalResponse response = 3; - EndSession end_session = 4; + StartSession start_session = 2; + SignalRequest request = 3; } - string participant_id = 5; } +// message to Signal nodes +message SignalNodeMessage { + string connection_id = 1; + oneof message { + SignalResponse response = 2; + EndSession end_session = 3; + } +} + +//message RouterMessage { +// oneof message { +// StartSession start_session = 1; +// SignalRequest request = 2; +// SignalResponse response = 3; +// EndSession end_session = 4; +// } +// // empty for start session +// string participant_key = 5; +//} + message StartSession { string room_name = 1; - string participant_name = 2; + string identity = 2; + string connection_id = 3; } message EndSession { diff --git a/proto/livekit/internal.pb.go b/proto/livekit/internal.pb.go index 2f6e002b9..48320473b 100644 --- a/proto/livekit/internal.pb.go +++ b/proto/livekit/internal.pb.go @@ -185,23 +185,21 @@ func (x *NodeStats) GetNumTracksOut() uint32 { return 0 } -// message for a node through the router -type RouterMessage struct { +// message to RTC nodes +type RTCNodeMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + ParticipantKey string `protobuf:"bytes,1,opt,name=participant_key,json=participantKey,proto3" json:"participant_key,omitempty"` // Types that are assignable to Message: - // *RouterMessage_StartSession - // *RouterMessage_Request - // *RouterMessage_Response - // *RouterMessage_EndSession - Message isRouterMessage_Message `protobuf_oneof:"message"` - ParticipantId string `protobuf:"bytes,5,opt,name=participant_id,json=participantId,proto3" json:"participant_id,omitempty"` + // *RTCNodeMessage_StartSession + // *RTCNodeMessage_Request + Message isRTCNodeMessage_Message `protobuf_oneof:"message"` } -func (x *RouterMessage) Reset() { - *x = RouterMessage{} +func (x *RTCNodeMessage) Reset() { + *x = RTCNodeMessage{} if protoimpl.UnsafeEnabled { mi := &file_internal_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -209,13 +207,13 @@ func (x *RouterMessage) Reset() { } } -func (x *RouterMessage) String() string { +func (x *RTCNodeMessage) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RouterMessage) ProtoMessage() {} +func (*RTCNodeMessage) ProtoMessage() {} -func (x *RouterMessage) ProtoReflect() protoreflect.Message { +func (x *RTCNodeMessage) ProtoReflect() protoreflect.Message { mi := &file_internal_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -227,94 +225,158 @@ func (x *RouterMessage) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RouterMessage.ProtoReflect.Descriptor instead. -func (*RouterMessage) Descriptor() ([]byte, []int) { +// Deprecated: Use RTCNodeMessage.ProtoReflect.Descriptor instead. +func (*RTCNodeMessage) Descriptor() ([]byte, []int) { return file_internal_proto_rawDescGZIP(), []int{2} } -func (m *RouterMessage) GetMessage() isRouterMessage_Message { +func (x *RTCNodeMessage) GetParticipantKey() string { + if x != nil { + return x.ParticipantKey + } + return "" +} + +func (m *RTCNodeMessage) GetMessage() isRTCNodeMessage_Message { if m != nil { return m.Message } return nil } -func (x *RouterMessage) GetStartSession() *StartSession { - if x, ok := x.GetMessage().(*RouterMessage_StartSession); ok { +func (x *RTCNodeMessage) GetStartSession() *StartSession { + if x, ok := x.GetMessage().(*RTCNodeMessage_StartSession); ok { return x.StartSession } return nil } -func (x *RouterMessage) GetRequest() *SignalRequest { - if x, ok := x.GetMessage().(*RouterMessage_Request); ok { +func (x *RTCNodeMessage) GetRequest() *SignalRequest { + if x, ok := x.GetMessage().(*RTCNodeMessage_Request); ok { return x.Request } return nil } -func (x *RouterMessage) GetResponse() *SignalResponse { - if x, ok := x.GetMessage().(*RouterMessage_Response); ok { +type isRTCNodeMessage_Message interface { + isRTCNodeMessage_Message() +} + +type RTCNodeMessage_StartSession struct { + StartSession *StartSession `protobuf:"bytes,2,opt,name=start_session,json=startSession,proto3,oneof"` +} + +type RTCNodeMessage_Request struct { + Request *SignalRequest `protobuf:"bytes,3,opt,name=request,proto3,oneof"` +} + +func (*RTCNodeMessage_StartSession) isRTCNodeMessage_Message() {} + +func (*RTCNodeMessage_Request) isRTCNodeMessage_Message() {} + +// message to Signal nodes +type SignalNodeMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ConnectionId string `protobuf:"bytes,1,opt,name=connection_id,json=connectionId,proto3" json:"connection_id,omitempty"` + // Types that are assignable to Message: + // *SignalNodeMessage_Response + // *SignalNodeMessage_EndSession + Message isSignalNodeMessage_Message `protobuf_oneof:"message"` +} + +func (x *SignalNodeMessage) Reset() { + *x = SignalNodeMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_internal_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SignalNodeMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SignalNodeMessage) ProtoMessage() {} + +func (x *SignalNodeMessage) ProtoReflect() protoreflect.Message { + mi := &file_internal_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SignalNodeMessage.ProtoReflect.Descriptor instead. +func (*SignalNodeMessage) Descriptor() ([]byte, []int) { + return file_internal_proto_rawDescGZIP(), []int{3} +} + +func (x *SignalNodeMessage) GetConnectionId() string { + if x != nil { + return x.ConnectionId + } + return "" +} + +func (m *SignalNodeMessage) GetMessage() isSignalNodeMessage_Message { + if m != nil { + return m.Message + } + return nil +} + +func (x *SignalNodeMessage) GetResponse() *SignalResponse { + if x, ok := x.GetMessage().(*SignalNodeMessage_Response); ok { return x.Response } return nil } -func (x *RouterMessage) GetEndSession() *EndSession { - if x, ok := x.GetMessage().(*RouterMessage_EndSession); ok { +func (x *SignalNodeMessage) GetEndSession() *EndSession { + if x, ok := x.GetMessage().(*SignalNodeMessage_EndSession); ok { return x.EndSession } return nil } -func (x *RouterMessage) GetParticipantId() string { - if x != nil { - return x.ParticipantId - } - return "" +type isSignalNodeMessage_Message interface { + isSignalNodeMessage_Message() } -type isRouterMessage_Message interface { - isRouterMessage_Message() +type SignalNodeMessage_Response struct { + Response *SignalResponse `protobuf:"bytes,2,opt,name=response,proto3,oneof"` } -type RouterMessage_StartSession struct { - StartSession *StartSession `protobuf:"bytes,1,opt,name=start_session,json=startSession,proto3,oneof"` +type SignalNodeMessage_EndSession struct { + EndSession *EndSession `protobuf:"bytes,3,opt,name=end_session,json=endSession,proto3,oneof"` } -type RouterMessage_Request struct { - Request *SignalRequest `protobuf:"bytes,2,opt,name=request,proto3,oneof"` -} +func (*SignalNodeMessage_Response) isSignalNodeMessage_Message() {} -type RouterMessage_Response struct { - Response *SignalResponse `protobuf:"bytes,3,opt,name=response,proto3,oneof"` -} - -type RouterMessage_EndSession struct { - EndSession *EndSession `protobuf:"bytes,4,opt,name=end_session,json=endSession,proto3,oneof"` -} - -func (*RouterMessage_StartSession) isRouterMessage_Message() {} - -func (*RouterMessage_Request) isRouterMessage_Message() {} - -func (*RouterMessage_Response) isRouterMessage_Message() {} - -func (*RouterMessage_EndSession) isRouterMessage_Message() {} +func (*SignalNodeMessage_EndSession) isSignalNodeMessage_Message() {} type StartSession struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RoomName string `protobuf:"bytes,1,opt,name=room_name,json=roomName,proto3" json:"room_name,omitempty"` - ParticipantName string `protobuf:"bytes,2,opt,name=participant_name,json=participantName,proto3" json:"participant_name,omitempty"` + RoomName string `protobuf:"bytes,1,opt,name=room_name,json=roomName,proto3" json:"room_name,omitempty"` + Identity string `protobuf:"bytes,2,opt,name=identity,proto3" json:"identity,omitempty"` + ConnectionId string `protobuf:"bytes,3,opt,name=connection_id,json=connectionId,proto3" json:"connection_id,omitempty"` } func (x *StartSession) Reset() { *x = StartSession{} if protoimpl.UnsafeEnabled { - mi := &file_internal_proto_msgTypes[3] + mi := &file_internal_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -327,7 +389,7 @@ func (x *StartSession) String() string { func (*StartSession) ProtoMessage() {} func (x *StartSession) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_msgTypes[3] + mi := &file_internal_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -340,7 +402,7 @@ func (x *StartSession) ProtoReflect() protoreflect.Message { // Deprecated: Use StartSession.ProtoReflect.Descriptor instead. func (*StartSession) Descriptor() ([]byte, []int) { - return file_internal_proto_rawDescGZIP(), []int{3} + return file_internal_proto_rawDescGZIP(), []int{4} } func (x *StartSession) GetRoomName() string { @@ -350,9 +412,16 @@ func (x *StartSession) GetRoomName() string { return "" } -func (x *StartSession) GetParticipantName() string { +func (x *StartSession) GetIdentity() string { if x != nil { - return x.ParticipantName + return x.Identity + } + return "" +} + +func (x *StartSession) GetConnectionId() string { + if x != nil { + return x.ConnectionId } return "" } @@ -366,7 +435,7 @@ type EndSession struct { func (x *EndSession) Reset() { *x = EndSession{} if protoimpl.UnsafeEnabled { - mi := &file_internal_proto_msgTypes[4] + mi := &file_internal_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -379,7 +448,7 @@ func (x *EndSession) String() string { func (*EndSession) ProtoMessage() {} func (x *EndSession) ProtoReflect() protoreflect.Message { - mi := &file_internal_proto_msgTypes[4] + mi := &file_internal_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -392,7 +461,7 @@ func (x *EndSession) ProtoReflect() protoreflect.Message { // Deprecated: Use EndSession.ProtoReflect.Descriptor instead. func (*EndSession) Descriptor() ([]byte, []int) { - return file_internal_proto_rawDescGZIP(), []int{4} + return file_internal_proto_rawDescGZIP(), []int{5} } var File_internal_proto protoreflect.FileDescriptor @@ -420,35 +489,41 @@ var file_internal_proto_rawDesc = []byte{ 0x28, 0x0d, 0x52, 0x0b, 0x6e, 0x75, 0x6d, 0x54, 0x72, 0x61, 0x63, 0x6b, 0x73, 0x49, 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x6e, 0x75, 0x6d, 0x5f, 0x74, 0x72, 0x61, 0x63, 0x6b, 0x73, 0x5f, 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x6e, 0x75, 0x6d, 0x54, 0x72, 0x61, 0x63, - 0x6b, 0x73, 0x4f, 0x75, 0x74, 0x22, 0xa2, 0x02, 0x0a, 0x0d, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x72, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3c, 0x0a, 0x0d, 0x73, 0x74, 0x61, 0x72, 0x74, - 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, - 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x53, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, - 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, - 0x52, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x35, 0x0a, 0x08, 0x72, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6c, 0x69, + 0x6b, 0x73, 0x4f, 0x75, 0x74, 0x22, 0xb6, 0x01, 0x0a, 0x0e, 0x52, 0x54, 0x43, 0x4e, 0x6f, 0x64, + 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x61, 0x72, 0x74, + 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0e, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x4b, 0x65, + 0x79, 0x12, 0x3c, 0x0a, 0x0d, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, + 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, + 0x69, 0x74, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x48, + 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x32, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x07, 0x72, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xb2, + 0x01, 0x0a, 0x11, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x4e, 0x6f, 0x64, 0x65, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x35, 0x0a, 0x08, 0x72, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x36, 0x0a, 0x0b, 0x65, 0x6e, 0x64, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2e, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2e, 0x45, 0x6e, 0x64, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0a, 0x65, 0x6e, - 0x64, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x61, 0x72, 0x74, - 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x49, 0x64, 0x42, - 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x56, 0x0a, 0x0c, 0x53, 0x74, - 0x61, 0x72, 0x74, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x72, 0x6f, - 0x6f, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x72, - 0x6f, 0x6f, 0x6d, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x70, 0x61, 0x72, 0x74, 0x69, - 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0f, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, 0x6e, 0x74, 0x4e, 0x61, - 0x6d, 0x65, 0x22, 0x0c, 0x0a, 0x0a, 0x45, 0x6e, 0x64, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, - 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2f, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2d, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6c, 0x69, 0x76, 0x65, - 0x6b, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x64, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x6c, 0x0a, 0x0c, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x72, 0x6f, 0x6f, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x6f, 0x6d, 0x4e, 0x61, 0x6d, 0x65, + 0x12, 0x1a, 0x0a, 0x08, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x12, 0x23, 0x0a, 0x0d, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x64, 0x22, 0x0c, 0x0a, 0x0a, 0x45, 0x6e, 0x64, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x42, + 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, + 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2f, 0x6c, 0x69, 0x76, 0x65, 0x6b, 0x69, 0x74, 0x2d, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6c, 0x69, 0x76, 0x65, 0x6b, + 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -463,22 +538,23 @@ func file_internal_proto_rawDescGZIP() []byte { return file_internal_proto_rawDescData } -var file_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_internal_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_internal_proto_goTypes = []interface{}{ - (*Node)(nil), // 0: livekit.Node - (*NodeStats)(nil), // 1: livekit.NodeStats - (*RouterMessage)(nil), // 2: livekit.RouterMessage - (*StartSession)(nil), // 3: livekit.StartSession - (*EndSession)(nil), // 4: livekit.EndSession - (*SignalRequest)(nil), // 5: livekit.SignalRequest - (*SignalResponse)(nil), // 6: livekit.SignalResponse + (*Node)(nil), // 0: livekit.Node + (*NodeStats)(nil), // 1: livekit.NodeStats + (*RTCNodeMessage)(nil), // 2: livekit.RTCNodeMessage + (*SignalNodeMessage)(nil), // 3: livekit.SignalNodeMessage + (*StartSession)(nil), // 4: livekit.StartSession + (*EndSession)(nil), // 5: livekit.EndSession + (*SignalRequest)(nil), // 6: livekit.SignalRequest + (*SignalResponse)(nil), // 7: livekit.SignalResponse } var file_internal_proto_depIdxs = []int32{ 1, // 0: livekit.Node.stats:type_name -> livekit.NodeStats - 3, // 1: livekit.RouterMessage.start_session:type_name -> livekit.StartSession - 5, // 2: livekit.RouterMessage.request:type_name -> livekit.SignalRequest - 6, // 3: livekit.RouterMessage.response:type_name -> livekit.SignalResponse - 4, // 4: livekit.RouterMessage.end_session:type_name -> livekit.EndSession + 4, // 1: livekit.RTCNodeMessage.start_session:type_name -> livekit.StartSession + 6, // 2: livekit.RTCNodeMessage.request:type_name -> livekit.SignalRequest + 7, // 3: livekit.SignalNodeMessage.response:type_name -> livekit.SignalResponse + 5, // 4: livekit.SignalNodeMessage.end_session:type_name -> livekit.EndSession 5, // [5:5] is the sub-list for method output_type 5, // [5:5] is the sub-list for method input_type 5, // [5:5] is the sub-list for extension type_name @@ -518,7 +594,7 @@ func file_internal_proto_init() { } } file_internal_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouterMessage); i { + switch v := v.(*RTCNodeMessage); i { case 0: return &v.state case 1: @@ -530,7 +606,7 @@ func file_internal_proto_init() { } } file_internal_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*StartSession); i { + switch v := v.(*SignalNodeMessage); i { case 0: return &v.state case 1: @@ -542,6 +618,18 @@ func file_internal_proto_init() { } } file_internal_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StartSession); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_internal_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*EndSession); i { case 0: return &v.state @@ -555,10 +643,12 @@ func file_internal_proto_init() { } } file_internal_proto_msgTypes[2].OneofWrappers = []interface{}{ - (*RouterMessage_StartSession)(nil), - (*RouterMessage_Request)(nil), - (*RouterMessage_Response)(nil), - (*RouterMessage_EndSession)(nil), + (*RTCNodeMessage_StartSession)(nil), + (*RTCNodeMessage_Request)(nil), + } + file_internal_proto_msgTypes[3].OneofWrappers = []interface{}{ + (*SignalNodeMessage_Response)(nil), + (*SignalNodeMessage_EndSession)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -566,7 +656,7 @@ func file_internal_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_internal_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 0, },