switch to a single redis subscriber, close properly

This commit is contained in:
David Zhao
2021-02-04 00:25:09 -08:00
parent 5dec5b1ae2
commit c015e267b0
7 changed files with 125 additions and 131 deletions

View File

@@ -27,7 +27,6 @@ import (
type RTCClient struct {
id string
identity string
conn *websocket.Conn
PeerConn *webrtc.PeerConnection
// sid => track
@@ -144,7 +143,7 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) {
peerConn.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
logger.Debugw("ICE state has changed", "state", connectionState.String(),
"participant", c.localParticipant.Sid)
"participant", c.localParticipant.Identity)
if connectionState == webrtc.ICEConnectionStateConnected {
// flush peers
c.lock.Lock()
@@ -195,8 +194,8 @@ func (c *RTCClient) Run() error {
}
switch msg := res.Message.(type) {
case *livekit.SignalResponse_Join:
c.localParticipant = msg.Join.Participant
c.id = msg.Join.Participant.Sid
c.identity = msg.Join.Participant.Identity
c.lock.Lock()
for _, p := range msg.Join.OtherParticipants {
@@ -205,7 +204,6 @@ func (c *RTCClient) Run() error {
c.lock.Unlock()
logger.Debugw("join accepted, sending offer..", "participant", msg.Join.Participant.Identity)
c.localParticipant = msg.Join.Participant
logger.Debugw("other participants", "count", len(msg.Join.OtherParticipants))
// Create an offer to send to the other process
@@ -214,7 +212,7 @@ func (c *RTCClient) Run() error {
return err
}
logger.Debugw("created offer", "offer", offer.SDP)
logger.Debugw("created offer", "participant", c.localParticipant.Identity)
// Sets the LocalDescription, and starts our UDP listeners
// Note: this will start the gathering of ICE candidates
@@ -228,7 +226,6 @@ func (c *RTCClient) Run() error {
Offer: rtc.ToProtoSessionDescription(offer),
},
}
logger.Debugw("connecting to remote...")
if err = c.SendRequest(req); err != nil {
return err
}
@@ -254,7 +251,6 @@ func (c *RTCClient) Run() error {
c.lock.Lock()
for _, p := range msg.Update.Participants {
c.remoteParticipants[p.Sid] = p
logger.Debugw("participant update", "id", p.Identity, "state", p.State.String())
}
c.lock.Unlock()
case *livekit.SignalResponse_TrackPublished:
@@ -277,7 +273,11 @@ func (c *RTCClient) WaitUntilConnected() error {
for {
select {
case <-ctx.Done():
return errors.New("could not connect after timeout")
id := c.ID()
if c.localParticipant != nil {
id = c.localParticipant.Identity
}
return fmt.Errorf("%s could not connect after timeout", id)
case <-time.After(10 * time.Millisecond):
if c.iceConnected.Get() {
return nil
@@ -476,7 +476,7 @@ func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) error {
}
func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription) error {
logger.Debugw("handling server answer")
logger.Debugw("handling server answer", "participant", c.localParticipant.Identity)
// remote answered the offer, establish connection
err := c.PeerConn.SetRemoteDescription(desc)
if err != nil {

View File

@@ -2,7 +2,6 @@ package routing
import (
"context"
"sync"
"time"
"github.com/go-redis/redis/v8"
@@ -25,23 +24,24 @@ const (
// Because
type RedisRouter struct {
LocalRouter
rc *redis.Client
cr *utils.CachedRedis
ctx context.Context
once sync.Once
rc *redis.Client
cr *utils.CachedRedis
ctx context.Context
isStarted utils.AtomicFlag
// map of participantKey => RTCNodeSink
rtcSinks map[string]*RTCNodeSink
// map of connectionId => SignalNodeSink
signalSinks map[string]*SignalNodeSink
cancel func()
pubsub *redis.PubSub
cancel func()
}
func NewRedisRouter(currentNode LocalNode, rc *redis.Client) *RedisRouter {
rr := &RedisRouter{
LocalRouter: *NewLocalRouter(currentNode),
rc: rc,
once: sync.Once{},
rtcSinks: make(map[string]*RTCNodeSink),
signalSinks: make(map[string]*SignalNodeSink),
}
@@ -203,15 +203,20 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK
}
func (r *RedisRouter) Start() error {
r.once.Do(func() {
go r.statsWorker()
go r.rtcWorker()
go r.signalWorker()
})
if !r.isStarted.TrySet(true) {
return nil
}
go r.statsWorker()
go r.redisWorker()
return nil
}
func (r *RedisRouter) Stop() {
if !r.isStarted.TrySet(false) {
return
}
logger.Debugw("stopping RedisRouter")
r.pubsub.Close()
r.cancel()
}
@@ -282,115 +287,91 @@ func (r *RedisRouter) getParticipantSignalNode(connectionId string) (nodeId stri
func (r *RedisRouter) statsWorker() {
for r.ctx.Err() == nil {
// update every 10 seconds
<-time.After(statsUpdateInterval)
r.currentNode.Stats.UpdatedAt = time.Now().Unix()
if err := r.RegisterNode(); err != nil {
logger.Errorw("could not update node", "error", err)
select {
case <-time.After(statsUpdateInterval):
r.currentNode.Stats.UpdatedAt = time.Now().Unix()
if err := r.RegisterNode(); err != nil {
logger.Errorw("could not update node", "error", err)
}
case <-r.ctx.Done():
return
}
}
}
// worker that consumes signal channel and processes
func (r *RedisRouter) signalWorker() {
sub := r.rc.Subscribe(redisCtx, signalNodeChannel(r.currentNode.Id))
// worker that consumes redis messages intended for this node
func (r *RedisRouter) redisWorker() {
defer func() {
logger.Debugw("finishing redis signalWorker", "node", r.currentNode.Id)
logger.Debugw("finishing redisWorker", "node", r.currentNode.Id)
}()
logger.Debugw("starting redis signalWorker", "node", r.currentNode.Id)
for r.ctx.Err() == nil {
obj, err := sub.Receive(r.ctx)
if err != nil {
logger.Warnw("error receiving redis message", "error", err)
// TODO: retry? ignore? at a minimum need to sleep here to retry
time.Sleep(time.Second)
continue
}
if obj == nil {
logger.Debugw("starting redisWorker", "node", r.currentNode.Id)
sigChannel := signalNodeChannel(r.currentNode.Id)
rtcChannel := rtcNodeChannel(r.currentNode.Id)
r.pubsub = r.rc.Subscribe(r.ctx, sigChannel, rtcChannel)
for msg := range r.pubsub.Channel() {
if msg == nil {
return
}
msg, ok := obj.(*redis.Message)
if !ok {
continue
}
rm := livekit.SignalNodeMessage{}
err = proto.Unmarshal([]byte(msg.Payload), &rm)
connectionId := rm.ConnectionId
switch rmb := rm.Message.(type) {
case *livekit.SignalNodeMessage_Response:
// in the event the current node is a Signal node, push to response channels
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
err = resSink.WriteMessage(rmb.Response)
if err != nil {
logger.Errorw("could not write to response channel",
"connectionId", connectionId,
"error", err)
}
case *livekit.SignalNodeMessage_EndSession:
//signalNode, err := r.getParticipantSignalNode(connectionId)
if err != nil {
logger.Errorw("could not get participant RTC node",
"error", err)
if msg.Channel == sigChannel {
sm := livekit.SignalNodeMessage{}
if err := proto.Unmarshal([]byte(msg.Payload), &sm); err != nil {
logger.Errorw("could not unmarshal signal message on sigchan", "error", err)
continue
}
// EndSession can only be initiated on an RTC node, is handled on the signal node
//if signalNode == r.currentNode.Id {
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
resSink.Close()
//}
}
}
}
// worker that consumes RTC channel and processes
func (r *RedisRouter) rtcWorker() {
sub := r.rc.Subscribe(redisCtx, rtcNodeChannel(r.currentNode.Id))
defer func() {
logger.Debugw("finishing redis rtcWorker", "node", r.currentNode.Id)
}()
logger.Debugw("starting redis rtcWorker", "node", r.currentNode.Id)
for r.ctx.Err() == nil {
obj, err := sub.Receive(r.ctx)
if err != nil {
logger.Warnw("error receiving redis message", "error", err)
// TODO: retry? ignore? at a minimum need to sleep here to retry
time.Sleep(time.Second)
continue
}
if obj == nil {
return
}
msg, ok := obj.(*redis.Message)
if !ok {
continue
}
rm := livekit.RTCNodeMessage{}
err = proto.Unmarshal([]byte(msg.Payload), &rm)
pKey := rm.ParticipantKey
switch rmb := rm.Message.(type) {
case *livekit.RTCNodeMessage_StartSession:
// RTC session should start on this node
err = r.startParticipantRTC(rmb.StartSession, pKey)
if err != nil {
logger.Errorw("could not start participant", "error", err)
if err := r.handleSignalMessage(&sm); err != nil {
logger.Errorw("error processing signal message", "error", err)
continue
}
case *livekit.RTCNodeMessage_Request:
requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey)
err = requestChan.WriteMessage(rmb.Request)
if err != nil {
logger.Errorw("could not write to request channel",
"participant", pKey,
"error", err)
} else if msg.Channel == rtcChannel {
rm := livekit.RTCNodeMessage{}
if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil {
logger.Errorw("could not unmarshal RTC message on rtcchan", "error", err)
continue
}
if err := r.handleRTCMessage(&rm); err != nil {
logger.Errorw("error processing RTC message", "error", err)
continue
}
}
}
}
func (r *RedisRouter) handleSignalMessage(sm *livekit.SignalNodeMessage) error {
connectionId := sm.ConnectionId
switch rmb := sm.Message.(type) {
case *livekit.SignalNodeMessage_Response:
// in the event the current node is a Signal node, push to response channels
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
if err := resSink.WriteMessage(rmb.Response); err != nil {
return err
}
case *livekit.SignalNodeMessage_EndSession:
logger.Debugw("received EndSession, closing signal connection")
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
resSink.Close()
}
return nil
}
func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error {
pKey := rm.ParticipantKey
switch rmb := rm.Message.(type) {
case *livekit.RTCNodeMessage_StartSession:
// RTC session should start on this node
if err := r.startParticipantRTC(rmb.StartSession, pKey); err != nil {
return errors.Wrap(err, "could not start participant")
}
case *livekit.RTCNodeMessage_Request:
requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey)
if err := requestChan.WriteMessage(rmb.Request); err != nil {
return err
}
}
return nil
}

View File

@@ -102,7 +102,7 @@ func (r *Room) Join(participant types.Participant) error {
"oldState", oldState)
r.broadcastParticipantState(p)
if oldState == livekit.ParticipantInfo_JOINING && p.State() == livekit.ParticipantInfo_JOINED {
if p.State() == livekit.ParticipantInfo_ACTIVE {
// subscribe participant to existing publishedTracks
for _, op := range r.GetParticipants() {
if p.ID() == op.ID() {
@@ -215,7 +215,7 @@ func (r *Room) onTrackAdded(participant types.Participant, track types.Published
// skip publishing participant
continue
}
if !existingParticipant.IsReady() {
if existingParticipant.State() != livekit.ParticipantInfo_ACTIVE {
// not fully joined. don't subscribe yet
continue
}

View File

@@ -76,8 +76,8 @@ func TestRoomJoin(t *testing.T) {
stateChangeCB := p.OnStateChangeArgsForCall(0)
assert.NotNil(t, stateChangeCB)
p.StateReturns(livekit.ParticipantInfo_JOINED)
stateChangeCB(p, livekit.ParticipantInfo_JOINING)
p.StateReturns(livekit.ParticipantInfo_ACTIVE)
stateChangeCB(p, livekit.ParticipantInfo_JOINED)
// it should become a subscriber when connectivity changes
for _, op := range rm.GetParticipants() {
@@ -176,9 +176,9 @@ func TestNewTrack(t *testing.T) {
rm := newRoomWithParticipants(t, 3)
participants := rm.GetParticipants()
p0 := participants[0].(*typesfakes.FakeParticipant)
p0.IsReadyReturns(false)
p0.StateReturns(livekit.ParticipantInfo_JOINED)
p1 := participants[1].(*typesfakes.FakeParticipant)
p1.IsReadyReturns(true)
p1.StateReturns(livekit.ParticipantInfo_ACTIVE)
pub := participants[2].(*typesfakes.FakeParticipant)

View File

@@ -118,6 +118,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
case msg := <-resSource.ReadChan():
if msg == nil {
logger.Errorw("source closed connection", "participant", identity)
return
}
res, ok := msg.(*livekit.SignalResponse)

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
"github.com/twitchtv/twirp"
"github.com/livekit/livekit-server/cmd/cli/client"
@@ -122,17 +121,28 @@ func withTimeout(t *testing.T, description string, f func() bool) bool {
func waitUntilConnected(t *testing.T, clients ...*client.RTCClient) {
logger.Infow("waiting for clients to become connected")
wg := sync.WaitGroup{}
errChan := make(chan error, len(clients))
for i := range clients {
c := clients[i]
wg.Add(1)
go func() {
defer wg.Done()
if !assert.NoError(t, c.WaitUntilConnected()) {
t.Fatal("client could not connect", c.ID())
err := c.WaitUntilConnected()
if err != nil {
errChan <- err
}
}()
}
wg.Wait()
close(errChan)
hasError := false
for err := range errChan {
t.Fatal(err)
hasError = true
}
if hasError {
t.FailNow()
}
}
func createSingleNodeServer() *service.LivekitServer {

View File

@@ -1,7 +1,6 @@
package test
import (
"math/rand"
"testing"
"time"
@@ -13,9 +12,11 @@ import (
// a scenario with lots of clients connecting, publishing, and leaving at random periods
func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
c1 := createRTCClient("puj_1", ports[rand.Intn(len(ports))])
c2 := createRTCClient("puj_2", ports[rand.Intn(len(ports))])
c3 := createRTCClient("puj_3", ports[rand.Intn(len(ports))])
firstPort := ports[0]
lastPort := ports[len(ports)-1]
c1 := createRTCClient("puj_1", firstPort)
c2 := createRTCClient("puj_2", lastPort)
c3 := createRTCClient("puj_3", firstPort)
defer stopClients(c1, c2, c3)
waitUntilConnected(t, c1, c2, c3)
@@ -63,7 +64,8 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
}
logger.Infow("c2 reconnecting")
c2 = createRTCClient("puj_2", ports[rand.Intn(len(ports))])
// connect to a diff port
c2 = createRTCClient("puj_2", firstPort)
defer c2.Stop()
waitUntilConnected(t, c2)
writers = publishTracksForClients(t, c2)