coalesce requests to connect to SMP servers, to have 1 connection per server (#305)

* coalesce requests to connect to SMP servers

* fix possible race condition when creating new SMP client

* one more race condition

* close pending SMP clients
This commit is contained in:
Evgeny Poberezkin
2022-01-20 18:33:02 +00:00
committed by GitHub
parent 305ae94cce
commit 670b3b7974
+36 -11
View File
@@ -36,6 +36,7 @@ module Simplex.Messaging.Agent.Client
)
where
import Control.Concurrent (forkIO)
import Control.Concurrent.Async (Async, async, uninterruptibleCancel)
import Control.Concurrent.STM (stateTVar)
import Control.Logger.Simple
@@ -61,17 +62,19 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util (bshow, liftEitherError, liftError)
import Simplex.Messaging.Util (bshow, liftEitherError, liftError, liftIOEither, tryError)
import Simplex.Messaging.Version
import UnliftIO.Exception (Exception, IOException)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
data AgentClient = AgentClient
{ rcvQ :: TBQueue (ATransmission 'Client),
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue SMPServerTransmission,
smpClients :: TVar (Map SMPServer SMPClient),
smpClients :: TVar (Map SMPServer SMPClientVar),
subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)),
subscrConns :: TVar (Map ConnId SMPServer),
connMsgsQueued :: TVar (Map ConnId Bool),
@@ -118,15 +121,32 @@ instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
readTVarIO smpClients
>>= maybe newSMPClient return . M.lookup srv
atomically getClientVar >>= either newSMPClient waitForSMPClient
where
newSMPClient :: m SMPClient
newSMPClient = do
smp <- connectClient
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically . modifyTVar smpClients $ M.insert srv smp
return smp
getClientVar :: STM (Either SMPClientVar SMPClientVar)
getClientVar = maybe (Left <$> newClientVar) (pure . Right) . M.lookup srv =<< readTVar smpClients
newClientVar :: STM SMPClientVar
newClientVar = do
smpVar <- newEmptyTMVar
modifyTVar smpClients $ M.insert srv smpVar
pure smpVar
waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
waitForSMPClient = liftIOEither . atomically . readTMVar
newSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
newSMPClient smpVar =
tryError connectClient >>= \r -> case r of
Right smp -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar smpVar r
pure smp
Left e -> do
atomically $ do
putTMVar smpVar r
modifyTVar smpClients $ M.delete srv
throwError e
connectClient :: m SMPClient
connectClient = do
@@ -189,7 +209,12 @@ closeAgentClient c = liftIO $ do
cancelActions $ smpQueueMsgDeliveries c
closeSMPServerClients :: AgentClient -> IO ()
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ closeSMPClient
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
where
closeClient smpVar =
atomically (readTMVar smpVar) >>= \case
Right smp -> closeSMPClient smp `E.catch` \(_ :: E.SomeException) -> pure ()
_ -> pure ()
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel