From d13a962afd8400e08c4cd56dd8b19a6c7398e2ed Mon Sep 17 00:00:00 2001 From: David Zhao Date: Sun, 7 Feb 2021 21:58:20 -0800 Subject: [PATCH] fixed message channel deadlock --- pkg/routing/errors.go | 1 + pkg/routing/messagechannel.go | 13 ++++++++++--- pkg/routing/redis.go | 5 ++++- pkg/routing/redisrouter.go | 20 ++++++++++++++------ 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pkg/routing/errors.go b/pkg/routing/errors.go index 35e45a654..5737ec6ac 100644 --- a/pkg/routing/errors.go +++ b/pkg/routing/errors.go @@ -9,4 +9,5 @@ var ( ErrIncorrectRTCNode = errors.New("current node isn't the RTC node for the room") errInvalidRouterMessage = errors.New("invalid router message") ErrChannelClosed = errors.New("channel closed") + ErrChannelFull = errors.New("channel is full") ) diff --git a/pkg/routing/messagechannel.go b/pkg/routing/messagechannel.go index e7ac6ed05..07bfb1d9e 100644 --- a/pkg/routing/messagechannel.go +++ b/pkg/routing/messagechannel.go @@ -15,7 +15,7 @@ type MessageChannel struct { func NewMessageChannel() *MessageChannel { return &MessageChannel{ // allow some buffer to avoid blocked writes - msgChan: make(chan proto.Message, 2), + msgChan: make(chan proto.Message, 10), } } @@ -27,8 +27,15 @@ func (m *MessageChannel) WriteMessage(msg proto.Message) error { if m.isClosed.Get() { return ErrChannelClosed } - m.msgChan <- msg - return nil + + select { + case m.msgChan <- msg: + // published + return nil + default: + // channel is full + return ErrChannelFull + } } func (m *MessageChannel) ReadChan() <-chan proto.Message { diff --git a/pkg/routing/redis.go b/pkg/routing/redis.go index 1025149a0..fca335db3 100644 --- a/pkg/routing/redis.go +++ b/pkg/routing/redis.go @@ -59,7 +59,7 @@ func publishRTCMessage(rc *redis.Client, nodeId string, participantKey string, m return err } - //logger.Debugw("publishing to", "rtcChannel", rtcNodeChannel(nodeId), + //logger.Debugw("publishing to rtc", "rtcChannel", rtcNodeChannel(nodeId), // "message", rm.Message) return rc.Publish(redisCtx, rtcNodeChannel(nodeId), data).Err() } @@ -84,6 +84,9 @@ func publishSignalMessage(rc *redis.Client, nodeId string, connectionId string, if err != nil { return err } + + //logger.Debugw("publishing to signal", "signalChannel", signalNodeChannel(nodeId), + // "message", rm.Message) return rc.Publish(redisCtx, signalNodeChannel(nodeId), data).Err() } diff --git a/pkg/routing/redisrouter.go b/pkg/routing/redisrouter.go index 5cba088e7..118703b1a 100644 --- a/pkg/routing/redisrouter.go +++ b/pkg/routing/redisrouter.go @@ -351,21 +351,27 @@ func (r *RedisRouter) redisWorker() { func (r *RedisRouter) handleSignalMessage(sm *livekit.SignalNodeMessage) error { connectionId := sm.ConnectionId + r.lock.Lock() + resSink := r.responseChannels[connectionId] + r.lock.Unlock() + + // if a client closed the channel, then sent more messages after that, + if resSink == nil { + return nil + } + switch rmb := sm.Message.(type) { case *livekit.SignalNodeMessage_Response: //logger.Debugw("forwarding signal message", // "connectionId", connectionId, // "type", fmt.Sprintf("%T", rmb.Response.Message)) - // in the event the current node is an Signal node, push to response channels - resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) if err := resSink.WriteMessage(rmb.Response); err != nil { return err } case *livekit.SignalNodeMessage_EndSession: - logger.Debugw("received EndSession, closing signal connection", - "connectionId", connectionId) - resSink := r.getOrCreateMessageChannel(r.responseChannels, connectionId) + //logger.Debugw("received EndSession, closing signal connection", + // "connectionId", connectionId) resSink.Close() } return nil @@ -382,7 +388,9 @@ func (r *RedisRouter) handleRTCMessage(rm *livekit.RTCNodeMessage) error { } case *livekit.RTCNodeMessage_Request: - requestChan := r.getOrCreateMessageChannel(r.requestChannels, pKey) + r.lock.Lock() + requestChan := r.requestChannels[pKey] + r.lock.Unlock() if err := requestChan.WriteMessage(rmb.Request); err != nil { return err }