From 670b3b79749bfb48a04ee40b8c441e9ca68ad41a Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Thu, 20 Jan 2022 18:33:02 +0000 Subject: [PATCH] coalesce requests to connect to SMP servers, to have 1 connection per server (#305) * coalesce requests to connect to SMP servers * fix possible race condition when creating new SMP client * one more race condition * close pending SMP clients --- src/Simplex/Messaging/Agent/Client.hs | 47 ++++++++++++++++++++------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 6bf99c0f5..e596602e0 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -36,6 +36,7 @@ module Simplex.Messaging.Agent.Client ) where +import Control.Concurrent (forkIO) import Control.Concurrent.Async (Async, async, uninterruptibleCancel) import Control.Concurrent.STM (stateTVar) import Control.Logger.Simple @@ -61,17 +62,19 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Util (bshow, liftEitherError, liftError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError, liftIOEither, tryError) import Simplex.Messaging.Version import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import UnliftIO.STM +type SMPClientVar = TMVar (Either AgentErrorType SMPClient) + data AgentClient = AgentClient { rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, - smpClients :: TVar (Map SMPServer SMPClient), + smpClients :: TVar (Map SMPServer SMPClientVar), subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)), subscrConns :: TVar (Map ConnId SMPServer), connMsgsQueued :: TVar (Map ConnId Bool), @@ -118,15 +121,32 @@ instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient c@AgentClient {smpClients, msgQ} srv = - readTVarIO smpClients - >>= maybe newSMPClient return . M.lookup srv + atomically getClientVar >>= either newSMPClient waitForSMPClient where - newSMPClient :: m SMPClient - newSMPClient = do - smp <- connectClient - logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - atomically . modifyTVar smpClients $ M.insert srv smp - return smp + getClientVar :: STM (Either SMPClientVar SMPClientVar) + getClientVar = maybe (Left <$> newClientVar) (pure . Right) . M.lookup srv =<< readTVar smpClients + + newClientVar :: STM SMPClientVar + newClientVar = do + smpVar <- newEmptyTMVar + modifyTVar smpClients $ M.insert srv smpVar + pure smpVar + + waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient + waitForSMPClient = liftIOEither . atomically . readTMVar + + newSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient + newSMPClient smpVar = + tryError connectClient >>= \r -> case r of + Right smp -> do + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv + atomically $ putTMVar smpVar r + pure smp + Left e -> do + atomically $ do + putTMVar smpVar r + modifyTVar smpClients $ M.delete srv + throwError e connectClient :: m SMPClient connectClient = do @@ -189,7 +209,12 @@ closeAgentClient c = liftIO $ do cancelActions $ smpQueueMsgDeliveries c closeSMPServerClients :: AgentClient -> IO () -closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ closeSMPClient +closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) + where + closeClient smpVar = + atomically (readTMVar smpVar) >>= \case + Right smp -> closeSMPClient smp `E.catch` \(_ :: E.SomeException) -> pure () + _ -> pure () cancelActions :: Foldable f => TVar (f (Async ())) -> IO () cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel