From c015e267b0767a457b2f7c5ad8f7dbffb4e28408 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Thu, 4 Feb 2021 00:25:09 -0800 Subject: [PATCH] switch to a single redis subscriber, close properly --- cmd/cli/client/client.go | 18 ++-- pkg/routing/redisrouter.go | 197 ++++++++++++++++-------------------- pkg/rtc/room.go | 4 +- pkg/rtc/room_test.go | 8 +- pkg/service/rtcservice.go | 1 + test/integration_helpers.go | 16 ++- test/scenarios.go | 12 ++- 7 files changed, 125 insertions(+), 131 deletions(-) diff --git a/cmd/cli/client/client.go b/cmd/cli/client/client.go index fb13d50a1..e04a8b7ee 100644 --- a/cmd/cli/client/client.go +++ b/cmd/cli/client/client.go @@ -27,7 +27,6 @@ import ( type RTCClient struct { id string - identity string conn *websocket.Conn PeerConn *webrtc.PeerConnection // sid => track @@ -144,7 +143,7 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) { peerConn.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { logger.Debugw("ICE state has changed", "state", connectionState.String(), - "participant", c.localParticipant.Sid) + "participant", c.localParticipant.Identity) if connectionState == webrtc.ICEConnectionStateConnected { // flush peers c.lock.Lock() @@ -195,8 +194,8 @@ func (c *RTCClient) Run() error { } switch msg := res.Message.(type) { case *livekit.SignalResponse_Join: + c.localParticipant = msg.Join.Participant c.id = msg.Join.Participant.Sid - c.identity = msg.Join.Participant.Identity c.lock.Lock() for _, p := range msg.Join.OtherParticipants { @@ -205,7 +204,6 @@ func (c *RTCClient) Run() error { c.lock.Unlock() logger.Debugw("join accepted, sending offer..", "participant", msg.Join.Participant.Identity) - c.localParticipant = msg.Join.Participant logger.Debugw("other participants", "count", len(msg.Join.OtherParticipants)) // Create an offer to send to the other process @@ -214,7 +212,7 @@ func (c *RTCClient) Run() error { return err } - logger.Debugw("created offer", "offer", offer.SDP) + logger.Debugw("created offer", "participant", c.localParticipant.Identity) // Sets the LocalDescription, and starts our UDP listeners // Note: this will start the gathering of ICE candidates @@ -228,7 +226,6 @@ func (c *RTCClient) Run() error { Offer: rtc.ToProtoSessionDescription(offer), }, } - logger.Debugw("connecting to remote...") if err = c.SendRequest(req); err != nil { return err } @@ -254,7 +251,6 @@ func (c *RTCClient) Run() error { c.lock.Lock() for _, p := range msg.Update.Participants { c.remoteParticipants[p.Sid] = p - logger.Debugw("participant update", "id", p.Identity, "state", p.State.String()) } c.lock.Unlock() case *livekit.SignalResponse_TrackPublished: @@ -277,7 +273,11 @@ func (c *RTCClient) WaitUntilConnected() error { for { select { case <-ctx.Done(): - return errors.New("could not connect after timeout") + id := c.ID() + if c.localParticipant != nil { + id = c.localParticipant.Identity + } + return fmt.Errorf("%s could not connect after timeout", id) case <-time.After(10 * time.Millisecond): if c.iceConnected.Get() { return nil @@ -476,7 +476,7 @@ func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) error { } func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription) error { - logger.Debugw("handling server answer") + logger.Debugw("handling server answer", "participant", c.localParticipant.Identity) // remote answered the offer, establish connection err := c.PeerConn.SetRemoteDescription(desc) if err != nil { diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 3c9d78291..26fc726a0 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -2,7 +2,6 @@ package routing import ( "context" - "sync" "time" "github.com/go-redis/redis/v8" @@ -25,23 +24,24 @@ const ( // Because type RedisRouter struct { LocalRouter - rc *redis.Client - cr *utils.CachedRedis - ctx context.Context - once sync.Once + rc *redis.Client + cr *utils.CachedRedis + ctx context.Context + isStarted utils.AtomicFlag // map of participantKey => RTCNodeSink rtcSinks map[string]*RTCNodeSink // map of connectionId => SignalNodeSink signalSinks map[string]*SignalNodeSink - cancel func() + + pubsub *redis.PubSub + cancel func() } func NewRedisRouter(currentNode LocalNode, rc *redis.Client) *RedisRouter { rr := &RedisRouter{ LocalRouter: *NewLocalRouter(currentNode), rc: rc, - once: sync.Once{}, rtcSinks: make(map[string]*RTCNodeSink), signalSinks: make(map[string]*SignalNodeSink), } @@ -203,15 +203,20 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK } func (r *RedisRouter) Start() error { - r.once.Do(func() { - go r.statsWorker() - go r.rtcWorker() - go r.signalWorker() - }) + if !r.isStarted.TrySet(true) { + return nil + } + go r.statsWorker() + go r.redisWorker() return nil } func (r *RedisRouter) Stop() { + if !r.isStarted.TrySet(false) { + return + } + logger.Debugw("stopping RedisRouter") + r.pubsub.Close() r.cancel() } @@ -282,115 +287,91 @@ func (r *RedisRouter) getParticipantSignalNode(connectionId string) (nodeId stri func (r *RedisRouter) statsWorker() { for r.ctx.Err() == nil { // update every 10 seconds - <-time.After(statsUpdateInterval) - r.currentNode.Stats.UpdatedAt = time.Now().Unix() - if err := r.RegisterNode(); err != nil { - logger.Errorw("could not update node", "error", err) + select { + case <-time.After(statsUpdateInterval): + r.currentNode.Stats.UpdatedAt = time.Now().Unix() + if err := r.RegisterNode(); err != nil { + logger.Errorw("could not update node", "error", err) + } + case <-r.ctx.Done(): + return } } } -// worker that consumes signal channel and processes -func (r *RedisRouter) signalWorker() { - sub := r.rc.Subscribe(redisCtx, signalNodeChannel(r.currentNode.Id)) +// worker that consumes redis messages intended for this node +func (r *RedisRouter) redisWorker() { defer func() { - logger.Debugw("finishing redis signalWorker", "node", r.currentNode.Id) + logger.Debugw("finishing redisWorker", "node", r.currentNode.Id) }() - logger.Debugw("starting redis signalWorker", "node", r.currentNode.Id) - for r.ctx.Err() == nil { - obj, err := sub.Receive(r.ctx) - if err != nil { - logger.Warnw("error receiving redis message", "error", err) - // TODO: retry? ignore? at a minimum need to sleep here to retry - time.Sleep(time.Second) - continue - } - if obj == nil { + logger.Debugw("starting redisWorker", "node", r.currentNode.Id) + + sigChannel := signalNodeChannel(r.currentNode.Id) + rtcChannel := rtcNodeChannel(r.currentNode.Id) + r.pubsub = r.rc.Subscribe(r.ctx, sigChannel, rtcChannel) + for msg := range r.pubsub.Channel() { + if msg == nil { return } - msg, ok := obj.(*redis.Message) - if !ok { - continue - } - - rm := livekit.SignalNodeMessage{} - err = proto.Unmarshal([]byte(msg.Payload), &rm) - connectionId := rm.ConnectionId - - switch rmb := rm.Message.(type) { - - case *livekit.SignalNodeMessage_Response: - // in the event the current node is an Signal node, push to response channels - resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) - err = resSink.WriteMessage(rmb.Response) - if err != nil { - logger.Errorw("could not write to response channel", - "connectionId", connectionId, - "error", err) - } - - case *livekit.SignalNodeMessage_EndSession: - //signalNode, err := r.getParticipantSignalNode(connectionId) - if err != nil { - logger.Errorw("could not get participant RTC node", - "error", err) + if msg.Channel == sigChannel { + sm := livekit.SignalNodeMessage{} + if err := proto.Unmarshal([]byte(msg.Payload), &sm); err != nil { + logger.Errorw("could not unmarshal signal message on sigchan", "error", err) continue } - // EndSession can only be initiated on an RTC node, is handled on the signal node - //if signalNode == r.currentNode.Id { - resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) - resSink.Close() - //} - } - } -} - -// worker that consumes RTC channel and processes -func (r *RedisRouter) rtcWorker() { - sub := r.rc.Subscribe(redisCtx, rtcNodeChannel(r.currentNode.Id)) - - defer func() { - logger.Debugw("finishing redis rtcWorker", "node", r.currentNode.Id) - }() - logger.Debugw("starting redis rtcWorker", "node", r.currentNode.Id) - for r.ctx.Err() == nil { - obj, err := sub.Receive(r.ctx) - if err != nil { - logger.Warnw("error receiving redis message", "error", err) - // TODO: retry? ignore? at a minimum need to sleep here to retry - time.Sleep(time.Second) - continue - } - if obj == nil { - return - } - - msg, ok := obj.(*redis.Message) - if !ok { - continue - } - - rm := livekit.RTCNodeMessage{} - err = proto.Unmarshal([]byte(msg.Payload), &rm) - pKey := rm.ParticipantKey - - switch rmb := rm.Message.(type) { - case *livekit.RTCNodeMessage_StartSession: - // RTC session should start on this node - err = r.startParticipantRTC(rmb.StartSession, pKey) - if err != nil { - logger.Errorw("could not start participant", "error", err) + if err := r.handleSignalMessage(&sm); err != nil { + logger.Errorw("error processing signal message", "error", err) + continue } - - case *livekit.RTCNodeMessage_Request: - requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey) - err = requestChan.WriteMessage(rmb.Request) - if err != nil { - logger.Errorw("could not write to request channel", - "participant", pKey, - "error", err) + } else if msg.Channel == rtcChannel { + rm := livekit.RTCNodeMessage{} + if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil { + logger.Errorw("could not unmarshal RTC message on rtcchan", "error", err) + continue + } + if err := r.handleRTCMessage(&rm); err != nil { + logger.Errorw("error processing RTC message", "error", err) + continue } } } } + +func (r *RedisRouter) handleSignalMessage(sm *livekit.SignalNodeMessage) error { + connectionId := sm.ConnectionId + + switch rmb := sm.Message.(type) { + case *livekit.SignalNodeMessage_Response: + // in the event the current node is an Signal node, push to response channels + resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) + if err := resSink.WriteMessage(rmb.Response); err != nil { + return err + } + + case *livekit.SignalNodeMessage_EndSession: + logger.Debugw("received EndSession, closing signal connection") + resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) + resSink.Close() + } + return nil +} + +func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { + pKey := rm.ParticipantKey + + switch rmb := rm.Message.(type) { + case *livekit.RTCNodeMessage_StartSession: + // RTC session should start on this node + if err := r.startParticipantRTC(rmb.StartSession, pKey); err != nil { + return errors.Wrap(err, "could not start participant") + } + + case *livekit.RTCNodeMessage_Request: + requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey) + if err := requestChan.WriteMessage(rmb.Request); err != nil { + return err + } + } + return nil +} diff --git a/pkg/rtc/room.go b/pkg/rtc/room.go index cda3774c0..f036a2e4e 100644 --- a/pkg/rtc/room.go +++ b/pkg/rtc/room.go @@ -102,7 +102,7 @@ func (r *Room) Join(participant types.Participant) error { "oldState", oldState) r.broadcastParticipantState(p) - if oldState == livekit.ParticipantInfo_JOINING && p.State() == livekit.ParticipantInfo_JOINED { + if p.State() == livekit.ParticipantInfo_ACTIVE { // subscribe participant to existing publishedTracks for _, op := range r.GetParticipants() { if p.ID() == op.ID() { @@ -215,7 +215,7 @@ func (r *Room) onTrackAdded(participant types.Participant, track types.Published // skip publishing participant continue } - if !existingParticipant.IsReady() { + if existingParticipant.State() != livekit.ParticipantInfo_ACTIVE { // not fully joined. don't subscribe yet continue } diff --git a/pkg/rtc/room_test.go b/pkg/rtc/room_test.go index bab8119a5..382a06d70 100644 --- a/pkg/rtc/room_test.go +++ b/pkg/rtc/room_test.go @@ -76,8 +76,8 @@ func TestRoomJoin(t *testing.T) { stateChangeCB := p.OnStateChangeArgsForCall(0) assert.NotNil(t, stateChangeCB) - p.StateReturns(livekit.ParticipantInfo_JOINED) - stateChangeCB(p, livekit.ParticipantInfo_JOINING) + p.StateReturns(livekit.ParticipantInfo_ACTIVE) + stateChangeCB(p, livekit.ParticipantInfo_JOINED) // it should become a subscriber when connectivity changes for _, op := range rm.GetParticipants() { @@ -176,9 +176,9 @@ func TestNewTrack(t *testing.T) { rm := newRoomWithParticipants(t, 3) participants := rm.GetParticipants() p0 := participants[0].(*typesfakes.FakeParticipant) - p0.IsReadyReturns(false) + p0.StateReturns(livekit.ParticipantInfo_JOINED) p1 := participants[1].(*typesfakes.FakeParticipant) - p1.IsReadyReturns(true) + p1.StateReturns(livekit.ParticipantInfo_ACTIVE) pub := participants[2].(*typesfakes.FakeParticipant) diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index dcc708fc6..b2cc367d6 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -118,6 +118,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) { return case msg := <-resSource.ReadChan(): if msg == nil { + logger.Errorw("source closed connection", "participant", identity) return } res, ok := msg.(*livekit.SignalResponse) diff --git a/test/integration_helpers.go b/test/integration_helpers.go index 1ad5e9f8a..c358ce688 100644 --- a/test/integration_helpers.go +++ b/test/integration_helpers.go @@ -9,7 +9,6 @@ import ( "time" "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" "github.com/twitchtv/twirp" "github.com/livekit/livekit-server/cmd/cli/client" @@ -122,17 +121,28 @@ func withTimeout(t *testing.T, description string, f func() bool) bool { func waitUntilConnected(t *testing.T, clients ...*client.RTCClient) { logger.Infow("waiting for clients to become connected") wg := sync.WaitGroup{} + errChan := make(chan error, len(clients)) for i := range clients { c := clients[i] wg.Add(1) go func() { defer wg.Done() - if !assert.NoError(t, c.WaitUntilConnected()) { - t.Fatal("client could not connect", c.ID()) + err := c.WaitUntilConnected() + if err != nil { + errChan <- err } }() } wg.Wait() + close(errChan) + hasError := false + for err := range errChan { + t.Fatal(err) + hasError = true + } + if hasError { + t.FailNow() + } } func createSingleNodeServer() *service.LivekitServer { diff --git a/test/scenarios.go b/test/scenarios.go index e821081dc..5a7b7f3ee 100644 --- a/test/scenarios.go +++ b/test/scenarios.go @@ -1,7 +1,6 @@ package test import ( - "math/rand" "testing" "time" @@ -13,9 +12,11 @@ import ( // a scenario with lots of clients connecting, publishing, and leaving at random periods func scenarioPublishingUponJoining(t *testing.T, ports ...int) { - c1 := createRTCClient("puj_1", ports[rand.Intn(len(ports))]) - c2 := createRTCClient("puj_2", ports[rand.Intn(len(ports))]) - c3 := createRTCClient("puj_3", ports[rand.Intn(len(ports))]) + firstPort := ports[0] + lastPort := ports[len(ports)-1] + c1 := createRTCClient("puj_1", firstPort) + c2 := createRTCClient("puj_2", lastPort) + c3 := createRTCClient("puj_3", firstPort) defer stopClients(c1, c2, c3) waitUntilConnected(t, c1, c2, c3) @@ -63,7 +64,8 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) { } logger.Infow("c2 reconnecting") - c2 = createRTCClient("puj_2", ports[rand.Intn(len(ports))]) + // connect to a diff port + c2 = createRTCClient("puj_2", firstPort) defer c2.Stop() waitUntilConnected(t, c2) writers = publishTracksForClients(t, c2)