mirror of
https://github.com/livekit/livekit.git
synced 2026-03-30 22:05:39 +00:00
switch to a single redis subscriber, close properly
This commit is contained in:
@@ -27,7 +27,6 @@ import (
|
||||
|
||||
type RTCClient struct {
|
||||
id string
|
||||
identity string
|
||||
conn *websocket.Conn
|
||||
PeerConn *webrtc.PeerConnection
|
||||
// sid => track
|
||||
@@ -144,7 +143,7 @@ func NewRTCClient(conn *websocket.Conn) (*RTCClient, error) {
|
||||
|
||||
peerConn.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
|
||||
logger.Debugw("ICE state has changed", "state", connectionState.String(),
|
||||
"participant", c.localParticipant.Sid)
|
||||
"participant", c.localParticipant.Identity)
|
||||
if connectionState == webrtc.ICEConnectionStateConnected {
|
||||
// flush peers
|
||||
c.lock.Lock()
|
||||
@@ -195,8 +194,8 @@ func (c *RTCClient) Run() error {
|
||||
}
|
||||
switch msg := res.Message.(type) {
|
||||
case *livekit.SignalResponse_Join:
|
||||
c.localParticipant = msg.Join.Participant
|
||||
c.id = msg.Join.Participant.Sid
|
||||
c.identity = msg.Join.Participant.Identity
|
||||
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Join.OtherParticipants {
|
||||
@@ -205,7 +204,6 @@ func (c *RTCClient) Run() error {
|
||||
c.lock.Unlock()
|
||||
|
||||
logger.Debugw("join accepted, sending offer..", "participant", msg.Join.Participant.Identity)
|
||||
c.localParticipant = msg.Join.Participant
|
||||
logger.Debugw("other participants", "count", len(msg.Join.OtherParticipants))
|
||||
|
||||
// Create an offer to send to the other process
|
||||
@@ -214,7 +212,7 @@ func (c *RTCClient) Run() error {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debugw("created offer", "offer", offer.SDP)
|
||||
logger.Debugw("created offer", "participant", c.localParticipant.Identity)
|
||||
|
||||
// Sets the LocalDescription, and starts our UDP listeners
|
||||
// Note: this will start the gathering of ICE candidates
|
||||
@@ -228,7 +226,6 @@ func (c *RTCClient) Run() error {
|
||||
Offer: rtc.ToProtoSessionDescription(offer),
|
||||
},
|
||||
}
|
||||
logger.Debugw("connecting to remote...")
|
||||
if err = c.SendRequest(req); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -254,7 +251,6 @@ func (c *RTCClient) Run() error {
|
||||
c.lock.Lock()
|
||||
for _, p := range msg.Update.Participants {
|
||||
c.remoteParticipants[p.Sid] = p
|
||||
logger.Debugw("participant update", "id", p.Identity, "state", p.State.String())
|
||||
}
|
||||
c.lock.Unlock()
|
||||
case *livekit.SignalResponse_TrackPublished:
|
||||
@@ -277,7 +273,11 @@ func (c *RTCClient) WaitUntilConnected() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.New("could not connect after timeout")
|
||||
id := c.ID()
|
||||
if c.localParticipant != nil {
|
||||
id = c.localParticipant.Identity
|
||||
}
|
||||
return fmt.Errorf("%s could not connect after timeout", id)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if c.iceConnected.Get() {
|
||||
return nil
|
||||
@@ -476,7 +476,7 @@ func (c *RTCClient) handleOffer(desc webrtc.SessionDescription) error {
|
||||
}
|
||||
|
||||
func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription) error {
|
||||
logger.Debugw("handling server answer")
|
||||
logger.Debugw("handling server answer", "participant", c.localParticipant.Identity)
|
||||
// remote answered the offer, establish connection
|
||||
err := c.PeerConn.SetRemoteDescription(desc)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,7 +2,6 @@ package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -25,23 +24,24 @@ const (
|
||||
// Because
|
||||
type RedisRouter struct {
|
||||
LocalRouter
|
||||
rc *redis.Client
|
||||
cr *utils.CachedRedis
|
||||
ctx context.Context
|
||||
once sync.Once
|
||||
rc *redis.Client
|
||||
cr *utils.CachedRedis
|
||||
ctx context.Context
|
||||
isStarted utils.AtomicFlag
|
||||
|
||||
// map of participantKey => RTCNodeSink
|
||||
rtcSinks map[string]*RTCNodeSink
|
||||
// map of connectionId => SignalNodeSink
|
||||
signalSinks map[string]*SignalNodeSink
|
||||
cancel func()
|
||||
|
||||
pubsub *redis.PubSub
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func NewRedisRouter(currentNode LocalNode, rc *redis.Client) *RedisRouter {
|
||||
rr := &RedisRouter{
|
||||
LocalRouter: *NewLocalRouter(currentNode),
|
||||
rc: rc,
|
||||
once: sync.Once{},
|
||||
rtcSinks: make(map[string]*RTCNodeSink),
|
||||
signalSinks: make(map[string]*SignalNodeSink),
|
||||
}
|
||||
@@ -203,15 +203,20 @@ func (r *RedisRouter) startParticipantRTC(ss *livekit.StartSession, participantK
|
||||
}
|
||||
|
||||
func (r *RedisRouter) Start() error {
|
||||
r.once.Do(func() {
|
||||
go r.statsWorker()
|
||||
go r.rtcWorker()
|
||||
go r.signalWorker()
|
||||
})
|
||||
if !r.isStarted.TrySet(true) {
|
||||
return nil
|
||||
}
|
||||
go r.statsWorker()
|
||||
go r.redisWorker()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRouter) Stop() {
|
||||
if !r.isStarted.TrySet(false) {
|
||||
return
|
||||
}
|
||||
logger.Debugw("stopping RedisRouter")
|
||||
r.pubsub.Close()
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
@@ -282,115 +287,91 @@ func (r *RedisRouter) getParticipantSignalNode(connectionId string) (nodeId stri
|
||||
func (r *RedisRouter) statsWorker() {
|
||||
for r.ctx.Err() == nil {
|
||||
// update every 10 seconds
|
||||
<-time.After(statsUpdateInterval)
|
||||
r.currentNode.Stats.UpdatedAt = time.Now().Unix()
|
||||
if err := r.RegisterNode(); err != nil {
|
||||
logger.Errorw("could not update node", "error", err)
|
||||
select {
|
||||
case <-time.After(statsUpdateInterval):
|
||||
r.currentNode.Stats.UpdatedAt = time.Now().Unix()
|
||||
if err := r.RegisterNode(); err != nil {
|
||||
logger.Errorw("could not update node", "error", err)
|
||||
}
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker that consumes signal channel and processes
|
||||
func (r *RedisRouter) signalWorker() {
|
||||
sub := r.rc.Subscribe(redisCtx, signalNodeChannel(r.currentNode.Id))
|
||||
// worker that consumes redis messages intended for this node
|
||||
func (r *RedisRouter) redisWorker() {
|
||||
defer func() {
|
||||
logger.Debugw("finishing redis signalWorker", "node", r.currentNode.Id)
|
||||
logger.Debugw("finishing redisWorker", "node", r.currentNode.Id)
|
||||
}()
|
||||
logger.Debugw("starting redis signalWorker", "node", r.currentNode.Id)
|
||||
for r.ctx.Err() == nil {
|
||||
obj, err := sub.Receive(r.ctx)
|
||||
if err != nil {
|
||||
logger.Warnw("error receiving redis message", "error", err)
|
||||
// TODO: retry? ignore? at a minimum need to sleep here to retry
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if obj == nil {
|
||||
logger.Debugw("starting redisWorker", "node", r.currentNode.Id)
|
||||
|
||||
sigChannel := signalNodeChannel(r.currentNode.Id)
|
||||
rtcChannel := rtcNodeChannel(r.currentNode.Id)
|
||||
r.pubsub = r.rc.Subscribe(r.ctx, sigChannel, rtcChannel)
|
||||
for msg := range r.pubsub.Channel() {
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, ok := obj.(*redis.Message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rm := livekit.SignalNodeMessage{}
|
||||
err = proto.Unmarshal([]byte(msg.Payload), &rm)
|
||||
connectionId := rm.ConnectionId
|
||||
|
||||
switch rmb := rm.Message.(type) {
|
||||
|
||||
case *livekit.SignalNodeMessage_Response:
|
||||
// in the event the current node is an Signal node, push to response channels
|
||||
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
|
||||
err = resSink.WriteMessage(rmb.Response)
|
||||
if err != nil {
|
||||
logger.Errorw("could not write to response channel",
|
||||
"connectionId", connectionId,
|
||||
"error", err)
|
||||
}
|
||||
|
||||
case *livekit.SignalNodeMessage_EndSession:
|
||||
//signalNode, err := r.getParticipantSignalNode(connectionId)
|
||||
if err != nil {
|
||||
logger.Errorw("could not get participant RTC node",
|
||||
"error", err)
|
||||
if msg.Channel == sigChannel {
|
||||
sm := livekit.SignalNodeMessage{}
|
||||
if err := proto.Unmarshal([]byte(msg.Payload), &sm); err != nil {
|
||||
logger.Errorw("could not unmarshal signal message on sigchan", "error", err)
|
||||
continue
|
||||
}
|
||||
// EndSession can only be initiated on an RTC node, is handled on the signal node
|
||||
//if signalNode == r.currentNode.Id {
|
||||
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
|
||||
resSink.Close()
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker that consumes RTC channel and processes
|
||||
func (r *RedisRouter) rtcWorker() {
|
||||
sub := r.rc.Subscribe(redisCtx, rtcNodeChannel(r.currentNode.Id))
|
||||
|
||||
defer func() {
|
||||
logger.Debugw("finishing redis rtcWorker", "node", r.currentNode.Id)
|
||||
}()
|
||||
logger.Debugw("starting redis rtcWorker", "node", r.currentNode.Id)
|
||||
for r.ctx.Err() == nil {
|
||||
obj, err := sub.Receive(r.ctx)
|
||||
if err != nil {
|
||||
logger.Warnw("error receiving redis message", "error", err)
|
||||
// TODO: retry? ignore? at a minimum need to sleep here to retry
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if obj == nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, ok := obj.(*redis.Message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rm := livekit.RTCNodeMessage{}
|
||||
err = proto.Unmarshal([]byte(msg.Payload), &rm)
|
||||
pKey := rm.ParticipantKey
|
||||
|
||||
switch rmb := rm.Message.(type) {
|
||||
case *livekit.RTCNodeMessage_StartSession:
|
||||
// RTC session should start on this node
|
||||
err = r.startParticipantRTC(rmb.StartSession, pKey)
|
||||
if err != nil {
|
||||
logger.Errorw("could not start participant", "error", err)
|
||||
if err := r.handleSignalMessage(&sm); err != nil {
|
||||
logger.Errorw("error processing signal message", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
case *livekit.RTCNodeMessage_Request:
|
||||
requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey)
|
||||
err = requestChan.WriteMessage(rmb.Request)
|
||||
if err != nil {
|
||||
logger.Errorw("could not write to request channel",
|
||||
"participant", pKey,
|
||||
"error", err)
|
||||
} else if msg.Channel == rtcChannel {
|
||||
rm := livekit.RTCNodeMessage{}
|
||||
if err := proto.Unmarshal([]byte(msg.Payload), &rm); err != nil {
|
||||
logger.Errorw("could not unmarshal RTC message on rtcchan", "error", err)
|
||||
continue
|
||||
}
|
||||
if err := r.handleRTCMessage(&rm); err != nil {
|
||||
logger.Errorw("error processing RTC message", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisRouter) handleSignalMessage(sm *livekit.SignalNodeMessage) error {
|
||||
connectionId := sm.ConnectionId
|
||||
|
||||
switch rmb := sm.Message.(type) {
|
||||
case *livekit.SignalNodeMessage_Response:
|
||||
// in the event the current node is an Signal node, push to response channels
|
||||
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
|
||||
if err := resSink.WriteMessage(rmb.Response); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *livekit.SignalNodeMessage_EndSession:
|
||||
logger.Debugw("received EndSession, closing signal connection")
|
||||
resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId)
|
||||
resSink.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error {
|
||||
pKey := rm.ParticipantKey
|
||||
|
||||
switch rmb := rm.Message.(type) {
|
||||
case *livekit.RTCNodeMessage_StartSession:
|
||||
// RTC session should start on this node
|
||||
if err := r.startParticipantRTC(rmb.StartSession, pKey); err != nil {
|
||||
return errors.Wrap(err, "could not start participant")
|
||||
}
|
||||
|
||||
case *livekit.RTCNodeMessage_Request:
|
||||
requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey)
|
||||
if err := requestChan.WriteMessage(rmb.Request); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (r *Room) Join(participant types.Participant) error {
|
||||
"oldState", oldState)
|
||||
r.broadcastParticipantState(p)
|
||||
|
||||
if oldState == livekit.ParticipantInfo_JOINING && p.State() == livekit.ParticipantInfo_JOINED {
|
||||
if p.State() == livekit.ParticipantInfo_ACTIVE {
|
||||
// subscribe participant to existing publishedTracks
|
||||
for _, op := range r.GetParticipants() {
|
||||
if p.ID() == op.ID() {
|
||||
@@ -215,7 +215,7 @@ func (r *Room) onTrackAdded(participant types.Participant, track types.Published
|
||||
// skip publishing participant
|
||||
continue
|
||||
}
|
||||
if !existingParticipant.IsReady() {
|
||||
if existingParticipant.State() != livekit.ParticipantInfo_ACTIVE {
|
||||
// not fully joined. don't subscribe yet
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -76,8 +76,8 @@ func TestRoomJoin(t *testing.T) {
|
||||
|
||||
stateChangeCB := p.OnStateChangeArgsForCall(0)
|
||||
assert.NotNil(t, stateChangeCB)
|
||||
p.StateReturns(livekit.ParticipantInfo_JOINED)
|
||||
stateChangeCB(p, livekit.ParticipantInfo_JOINING)
|
||||
p.StateReturns(livekit.ParticipantInfo_ACTIVE)
|
||||
stateChangeCB(p, livekit.ParticipantInfo_JOINED)
|
||||
|
||||
// it should become a subscriber when connectivity changes
|
||||
for _, op := range rm.GetParticipants() {
|
||||
@@ -176,9 +176,9 @@ func TestNewTrack(t *testing.T) {
|
||||
rm := newRoomWithParticipants(t, 3)
|
||||
participants := rm.GetParticipants()
|
||||
p0 := participants[0].(*typesfakes.FakeParticipant)
|
||||
p0.IsReadyReturns(false)
|
||||
p0.StateReturns(livekit.ParticipantInfo_JOINED)
|
||||
p1 := participants[1].(*typesfakes.FakeParticipant)
|
||||
p1.IsReadyReturns(true)
|
||||
p1.StateReturns(livekit.ParticipantInfo_ACTIVE)
|
||||
|
||||
pub := participants[2].(*typesfakes.FakeParticipant)
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ func (s *RTCService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case msg := <-resSource.ReadChan():
|
||||
if msg == nil {
|
||||
logger.Errorw("source closed connection", "participant", identity)
|
||||
return
|
||||
}
|
||||
res, ok := msg.(*livekit.SignalResponse)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/twitchtv/twirp"
|
||||
|
||||
"github.com/livekit/livekit-server/cmd/cli/client"
|
||||
@@ -122,17 +121,28 @@ func withTimeout(t *testing.T, description string, f func() bool) bool {
|
||||
func waitUntilConnected(t *testing.T, clients ...*client.RTCClient) {
|
||||
logger.Infow("waiting for clients to become connected")
|
||||
wg := sync.WaitGroup{}
|
||||
errChan := make(chan error, len(clients))
|
||||
for i := range clients {
|
||||
c := clients[i]
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if !assert.NoError(t, c.WaitUntilConnected()) {
|
||||
t.Fatal("client could not connect", c.ID())
|
||||
err := c.WaitUntilConnected()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
hasError := false
|
||||
for err := range errChan {
|
||||
t.Fatal(err)
|
||||
hasError = true
|
||||
}
|
||||
if hasError {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func createSingleNodeServer() *service.LivekitServer {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,9 +12,11 @@ import (
|
||||
|
||||
// a scenario with lots of clients connecting, publishing, and leaving at random periods
|
||||
func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
|
||||
c1 := createRTCClient("puj_1", ports[rand.Intn(len(ports))])
|
||||
c2 := createRTCClient("puj_2", ports[rand.Intn(len(ports))])
|
||||
c3 := createRTCClient("puj_3", ports[rand.Intn(len(ports))])
|
||||
firstPort := ports[0]
|
||||
lastPort := ports[len(ports)-1]
|
||||
c1 := createRTCClient("puj_1", firstPort)
|
||||
c2 := createRTCClient("puj_2", lastPort)
|
||||
c3 := createRTCClient("puj_3", firstPort)
|
||||
defer stopClients(c1, c2, c3)
|
||||
|
||||
waitUntilConnected(t, c1, c2, c3)
|
||||
@@ -63,7 +64,8 @@ func scenarioPublishingUponJoining(t *testing.T, ports ...int) {
|
||||
}
|
||||
|
||||
logger.Infow("c2 reconnecting")
|
||||
c2 = createRTCClient("puj_2", ports[rand.Intn(len(ports))])
|
||||
// connect to a diff port
|
||||
c2 = createRTCClient("puj_2", firstPort)
|
||||
defer c2.Stop()
|
||||
waitUntilConnected(t, c2)
|
||||
writers = publishTracksForClients(t, c2)
|
||||
|
||||
Reference in New Issue
Block a user