From 9a7ea7a2fa5c4579b30368fc09216de31f8a528a Mon Sep 17 00:00:00 2001 From: David Zhao Date: Thu, 9 Feb 2023 17:27:33 -0800 Subject: [PATCH] Close previous request channels when during initial retry (#1409) So we don't leave abandoned requests hanging on the media instance --- pkg/routing/interfaces.go | 2 + pkg/routing/redis.go | 8 +++ pkg/routing/routingfakes/fake_message_sink.go | 65 +++++++++++++++++++ .../routingfakes/fake_message_source.go | 65 +++++++++++++++++++ pkg/service/rtcservice.go | 3 + 5 files changed, 143 insertions(+) diff --git a/pkg/routing/interfaces.go b/pkg/routing/interfaces.go index 9db0eb9ac..8363cc892 100644 --- a/pkg/routing/interfaces.go +++ b/pkg/routing/interfaces.go @@ -20,6 +20,7 @@ import ( //counterfeiter:generate . MessageSink type MessageSink interface { WriteMessage(msg proto.Message) error + IsClosed() bool Close() } @@ -27,6 +28,7 @@ type MessageSink interface { type MessageSource interface { // ReadChan exposes a one way channel to make it easier to use with select ReadChan() <-chan proto.Message + IsClosed() bool Close() } diff --git a/pkg/routing/redis.go b/pkg/routing/redis.go index b31548d96..6e7928261 100644 --- a/pkg/routing/redis.go +++ b/pkg/routing/redis.go @@ -129,6 +129,10 @@ func (s *RTCNodeSink) Close() { } } +func (s *RTCNodeSink) IsClosed() bool { + return s.isClosed.Load() +} + func (s *RTCNodeSink) OnClose(f func()) { s.onClose = f } @@ -166,6 +170,10 @@ func (s *SignalNodeSink) Close() { } } +func (s *SignalNodeSink) IsClosed() bool { + return s.isClosed.Load() +} + func (s *SignalNodeSink) OnClose(f func()) { s.onClose = f } diff --git a/pkg/routing/routingfakes/fake_message_sink.go b/pkg/routing/routingfakes/fake_message_sink.go index 00c5fba85..ec53c4309 100644 --- a/pkg/routing/routingfakes/fake_message_sink.go +++ b/pkg/routing/routingfakes/fake_message_sink.go @@ -13,6 +13,16 @@ type FakeMessageSink struct { closeMutex sync.RWMutex closeArgsForCall []struct { } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } WriteMessageStub func(protoreflect.ProtoMessage) error writeMessageMutex sync.RWMutex writeMessageArgsForCall []struct { @@ -52,6 +62,59 @@ func (fake *FakeMessageSink) CloseCalls(stub func()) { fake.CloseStub = stub } +func (fake *FakeMessageSink) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSink) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeMessageSink) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeMessageSink) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSink) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeMessageSink) WriteMessage(arg1 protoreflect.ProtoMessage) error { fake.writeMessageMutex.Lock() ret, specificReturn := fake.writeMessageReturnsOnCall[len(fake.writeMessageArgsForCall)] @@ -118,6 +181,8 @@ func (fake *FakeMessageSink) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() fake.writeMessageMutex.RLock() defer fake.writeMessageMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/routing/routingfakes/fake_message_source.go b/pkg/routing/routingfakes/fake_message_source.go index 5a9e7b34a..acfe7606c 100644 --- a/pkg/routing/routingfakes/fake_message_source.go +++ b/pkg/routing/routingfakes/fake_message_source.go @@ -13,6 +13,16 @@ type FakeMessageSource struct { closeMutex sync.RWMutex closeArgsForCall []struct { } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } ReadChanStub func() <-chan protoreflect.ProtoMessage readChanMutex sync.RWMutex readChanArgsForCall []struct { @@ -51,6 +61,59 @@ func (fake *FakeMessageSource) CloseCalls(stub func()) { fake.CloseStub = stub } +func (fake *FakeMessageSource) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSource) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeMessageSource) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeMessageSource) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSource) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeMessageSource) ReadChan() <-chan protoreflect.ProtoMessage { fake.readChanMutex.Lock() ret, specificReturn := fake.readChanReturnsOnCall[len(fake.readChanArgsForCall)] @@ -109,6 +172,8 @@ func (fake *FakeMessageSource) Invocations() map[string][][]interface{} { defer fake.invocationsMutex.RUnlock() fake.closeMutex.RLock() defer fake.closeMutex.RUnlock() + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() fake.readChanMutex.RLock() defer fake.readChanMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/pkg/service/rtcservice.go b/pkg/service/rtcservice.go index a78c952ee..6f744d81c 100644 --- a/pkg/service/rtcservice.go +++ b/pkg/service/rtcservice.go @@ -458,6 +458,9 @@ func (s *RTCService) startConnection(ctx context.Context, roomName livekit.RoomN // instead of waiting forever on the WebSocket cr.InitialResponse, err = readInitialResponse(cr.ResponseSource, timeout) if err != nil { + // close the connection to avoid leaking + cr.RequestSink.Close() + cr.ResponseSource.Close() return cr, err } return cr, nil