From c380431b94424d47ac3029ad88756f6dda544664 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 28 Mar 2022 07:30:29 +0100 Subject: [PATCH 1/3] resubscribe concurrently when subscription is resumed (#339) * resubscribe concurrently when subscription is resumed * use strict modifyTVar, refactor with TMap * add inline * refactor --- simplexmq.cabal | 1 + src/Simplex/Messaging/Agent.hs | 16 ++-- src/Simplex/Messaging/Agent/Client.hs | 86 ++++++++++---------- src/Simplex/Messaging/Client.hs | 15 ++-- src/Simplex/Messaging/Server.hs | 31 ++++--- src/Simplex/Messaging/Server/Env/STM.hs | 20 +++-- src/Simplex/Messaging/Server/MsgStore/STM.hs | 2 +- src/Simplex/Messaging/TMap.hs | 55 +++++++++++++ src/Simplex/Messaging/Transport/Server.hs | 2 +- src/Simplex/Messaging/Util.hs | 4 + 10 files changed, 143 insertions(+), 89 deletions(-) create mode 100644 src/Simplex/Messaging/TMap.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 5762c6271..b863614d6 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -55,6 +55,7 @@ library Simplex.Messaging.Server.QueueStore Simplex.Messaging.Server.QueueStore.STM Simplex.Messaging.Server.StoreLog + Simplex.Messaging.TMap Simplex.Messaging.Transport Simplex.Messaging.Transport.Client Simplex.Messaging.Transport.KeepAlive diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f96e8a723..f1ecad909 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -83,6 +83,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM) import Simplex.Messaging.Version import System.Random (randomR) @@ -361,18 +362,13 @@ resumeMsgDelivery c connId sq@SndQueue {server, sndId} = do let qKey = (connId, server, sndId) unlessM (queueDelivering qKey) $ async (runSmpQueueMsgDelivery c connId sq) - >>= atomically . modifyTVar (smpQueueMsgDeliveries c) . M.insert qKey + >>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c) unlessM connQueued $ withStore (`getPendingMsgs` connId) >>= queuePendingMsgs c connId sq where - queueDelivering qKey = isJust . M.lookup qKey <$> readTVarIO (smpQueueMsgDeliveries c) - connQueued = - atomically $ - isJust - <$> stateTVar - (connMsgsQueued c) - (\m -> (M.lookup connId m, M.insert connId True m)) + queueDelivering qKey = atomically $ isJust <$> TM.lookup qKey (smpQueueMsgDeliveries c) + connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c) queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> [InternalId] -> m () queuePendingMsgs c connId sq msgIds = atomically $ do @@ -382,11 +378,11 @@ queuePendingMsgs c connId sq msgIds = atomically $ do getPendingMsgQ :: AgentClient -> ConnId -> SndQueue -> STM (TQueue InternalId) getPendingMsgQ c connId SndQueue {server, sndId} = do let qKey = (connId, server, sndId) - maybe (newMsgQueue qKey) pure . M.lookup qKey =<< readTVar (smpQueueMsgQueues c) + maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c) where newMsgQueue qKey = do mq <- newTQueue - modifyTVar (smpQueueMsgQueues c) $ M.insert qKey mq + TM.insert qKey mq $ smpQueueMsgQueues c pure mq runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnId -> SndQueue -> m () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 140bcb20b..dfe538844 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -51,8 +51,6 @@ import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (isNothing) -import Data.Set (Set) -import qualified Data.Set as S import Data.Text.Encoding import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol @@ -63,10 +61,12 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError) import Simplex.Messaging.Version import System.Timeout (timeout) -import UnliftIO (async) +import UnliftIO (async, forConcurrently_) import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -78,13 +78,13 @@ data AgentClient = AgentClient subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, smpServers :: TVar (NonEmpty SMPServer), - smpClients :: TVar (Map SMPServer SMPClientVar), - subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)), - pendingSubscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)), - subscrConns :: TVar (Map ConnId SMPServer), - connMsgsQueued :: TVar (Map ConnId Bool), - smpQueueMsgQueues :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId)), - smpQueueMsgDeliveries :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (Async ())), + smpClients :: TMap SMPServer SMPClientVar, + subscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), + pendingSubscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), + subscrConns :: TMap ConnId SMPServer, + connMsgsQueued :: TMap ConnId Bool, + smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId), + smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()), reconnections :: TVar [Async ()], asyncClients :: TVar [Async ()], clientId :: Int, @@ -100,13 +100,13 @@ newAgentClient agentEnv = do subQ <- newTBQueue qSize msgQ <- newTBQueue qSize smpServers <- newTVar $ initialSMPServers (config agentEnv) - smpClients <- newTVar M.empty - subscrSrvrs <- newTVar M.empty - pendingSubscrSrvrs <- newTVar M.empty - subscrConns <- newTVar M.empty - connMsgsQueued <- newTVar M.empty - smpQueueMsgQueues <- newTVar M.empty - smpQueueMsgDeliveries <- newTVar M.empty + smpClients <- TM.empty + subscrSrvrs <- TM.empty + pendingSubscrSrvrs <- TM.empty + subscrConns <- TM.empty + connMsgsQueued <- TM.empty + smpQueueMsgQueues <- TM.empty + smpQueueMsgDeliveries <- TM.empty reconnections <- newTVar [] asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) @@ -133,12 +133,12 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = atomically getClientVar >>= either newSMPClient waitForSMPClient where getClientVar :: STM (Either SMPClientVar SMPClientVar) - getClientVar = maybe (Left <$> newClientVar) (pure . Right) . M.lookup srv =<< readTVar smpClients + getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients newClientVar :: STM SMPClientVar newClientVar = do smpVar <- newEmptyTMVar - modifyTVar smpClients $ M.insert srv smpVar + TM.insert srv smpVar smpClients pure smpVar waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient @@ -165,12 +165,12 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = then retryAction else atomically $ do putTMVar smpVar (Left e) - modifyTVar smpClients $ M.delete srv + TM.delete srv smpClients throwError e tryConnectAsync :: m () tryConnectAsync = do a <- async connectAsync - atomically $ modifyTVar (asyncClients c) (a :) + atomically $ modifyTVar' (asyncClients c) (a :) connectAsync :: m () connectAsync = do ri <- asks $ reconnectInterval . config @@ -193,18 +193,16 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = removeClientSubs :: IO (Maybe (Map ConnId RcvQueue)) removeClientSubs = atomically $ do - modifyTVar smpClients $ M.delete srv - cs <- M.lookup srv <$> readTVar (subscrSrvrs c) - modifyTVar (subscrSrvrs c) $ M.delete srv - modifyTVar (subscrConns c) $ maybe id (deleteKeys . M.keysSet) cs - mapM_ (modifyTVar (pendingSubscrSrvrs c) . addPendingSubs) cs - return cs + TM.delete srv smpClients + cs_ <- TM.lookupDelete srv $ subscrSrvrs c + forM_ cs_ $ \cs -> do + modifyTVar' (TM.tVar $ subscrConns c) (`M.withoutKeys` M.keysSet cs) + modifyTVar' (TM.tVar $ pendingSubscrSrvrs c) $ addPendingSubs cs + return cs_ where addPendingSubs :: Map ConnId RcvQueue -> Map SMPServer (Map ConnId RcvQueue) -> Map SMPServer (Map ConnId RcvQueue) addPendingSubs cs = M.alter (Just . addSubs cs) srv addSubs cs = maybe cs (M.union cs) - deleteKeys :: Ord k => Set k -> Map k a -> Map k a - deleteKeys ks m = S.foldr' M.delete m ks serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () serverDown u cs = unless (M.null cs) $ do @@ -214,7 +212,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectServer :: m () reconnectServer = do a <- async tryReconnectClient - atomically $ modifyTVar (reconnections c) (a :) + atomically $ modifyTVar' (reconnections c) (a :) tryReconnectClient :: m () tryReconnectClient = do @@ -225,9 +223,9 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectClient :: m () reconnectClient = do withAgentLock c . withSMP c srv $ \smp -> do - subs <- readTVarIO $ subscrConns c - cs <- M.lookup srv <$> readTVarIO (pendingSubscrSrvrs c) - forM_ (maybe [] M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) -> + subs <- readTVarIO . TM.tVar $ subscrConns c + cs <- atomically . TM.lookup srv $ pendingSubscrSrvrs c + forConcurrently_ (maybe [] M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) -> when (isNothing $ M.lookup connId subs) $ do subscribeSMPQueue smp rcvPrivateKey rcvId `catchError` \case @@ -245,10 +243,10 @@ closeAgentClient c = liftIO $ do closeSMPServerClients c cancelActions $ reconnections c cancelActions $ asyncClients c - cancelActions $ smpQueueMsgDeliveries c + cancelActions . TM.tVar $ smpQueueMsgDeliveries c closeSMPServerClients :: AgentClient -> IO () -closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) +closeSMPServerClients c = readTVarIO (TM.tVar $ smpClients c) >>= mapM_ (forkIO . closeClient) where closeClient smpVar = atomically (readTMVar smpVar) >>= \case @@ -344,29 +342,29 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do - modifyTVar (subscrConns c) $ M.insert connId server - addSubs_ (subscrSrvrs c) rq connId + TM.insert connId server $ subscrConns c + addSubs_ rq connId $ subscrSrvrs c removePendingSubscription c server connId addPendingSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () addPendingSubscription c rq connId = - atomically $ addSubs_ (pendingSubscrSrvrs c) rq connId + atomically . addSubs_ rq connId $ pendingSubscrSrvrs c -addSubs_ :: TVar (Map SMPServer (Map ConnId RcvQueue)) -> RcvQueue -> ConnId -> STM () -addSubs_ ss rq@RcvQueue {server} connId = modifyTVar ss $ M.alter (Just . addSub) server +addSubs_ :: RcvQueue -> ConnId -> TMap SMPServer (Map ConnId RcvQueue) -> STM () +addSubs_ rq@RcvQueue {server} connId = TM.alter (Just . addSub) server where addSub = maybe (M.singleton connId rq) (M.insert connId rq) removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m () removeSubscription c@AgentClient {subscrConns} connId = atomically $ do - server_ <- stateTVar subscrConns $ \cs -> (M.lookup connId cs, M.delete connId cs) + server_ <- TM.lookupDelete connId subscrConns mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_ removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () -removePendingSubscription c = removeSubs_ (pendingSubscrSrvrs c) +removePendingSubscription = removeSubs_ . pendingSubscrSrvrs -removeSubs_ :: TVar (Map SMPServer (Map ConnId RcvQueue)) -> SMPServer -> ConnId -> STM () -removeSubs_ ss server connId = modifyTVar ss $ M.update delSub server +removeSubs_ :: TMap SMPServer (Map ConnId RcvQueue) -> SMPServer -> ConnId -> STM () +removeSubs_ ss server connId = TM.update delSub server ss where delSub :: Map ConnId RcvQueue -> Maybe (Map ConnId RcvQueue) delSub cs = diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index e6e87337f..151b77463 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -56,13 +56,13 @@ import Control.Monad.Trans.Class import Control.Monad.Trans.Except import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe) import Network.Socket (ServiceName) import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError, clientHandshake) import Simplex.Messaging.Transport.Client (runTransportClient) import Simplex.Messaging.Transport.KeepAlive @@ -83,7 +83,7 @@ data SMPClient = SMPClient smpServer :: SMPServer, tcpTimeout :: Int, clientCorrId :: TVar Natural, - sentCommands :: TVar (Map CorrId Request), + sentCommands :: TMap CorrId Request, sndQ :: TBQueue SentRawTransmission, rcvQ :: TBQueue (SignedTransmission BrokerMsg), msgQ :: TBQueue SMPServerTransmission @@ -137,7 +137,7 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp mkSMPClient = do connected <- newTVar False clientCorrId <- newTVar 0 - sentCommands <- newTVar M.empty + sentCommands <- TM.empty sndQ <- newTBQueue qSize rcvQ <- newTBQueue qSize return @@ -202,11 +202,10 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp if B.null $ bs corrId then sendMsg qId respOrErr else do - cs <- readTVarIO sentCommands - case M.lookup corrId cs of + atomically (TM.lookup corrId sentCommands) >>= \case Nothing -> sendMsg qId respOrErr Just Request {queueId, responseVar} -> atomically $ do - modifyTVar sentCommands $ M.delete corrId + TM.delete corrId sentCommands putTMVar responseVar $ if queueId == qId then case respOrErr of @@ -368,6 +367,6 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeou send :: CorrId -> SentRawTransmission -> STM (TMVar Response) send corrId t = do r <- newEmptyTMVar - modifyTVar sentCommands . M.insert corrId $ Request qId r + TM.insert corrId (Request qId r) sentCommands writeTBQueue sndQ t return r diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 73b5780ef..102830353 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -25,7 +25,6 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking) where -import Control.Concurrent.STM (stateTVar) import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift @@ -47,6 +46,8 @@ import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (QueueStore) import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Server import Simplex.Messaging.Util @@ -92,8 +93,8 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do MonadUnliftIO m' => Server -> (Server -> TBQueue (QueueId, Client)) -> - (Server -> TVar (M.Map QueueId Client)) -> - (Client -> TVar (M.Map QueueId s)) -> + (Server -> TMap QueueId Client) -> + (Client -> TMap QueueId s) -> (s -> m' ()) -> m' () serverThread s subQ subs clientSubs unsub = forever $ do @@ -110,13 +111,13 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do else do yes <- readTVar $ connected c' pure $ if yes then Just (qId, c') else Nothing - stateTVar (subs s) (\cs -> (M.lookup qId cs, M.insert qId clnt cs)) + TM.lookupInsert qId clnt (subs s) >>= fmap join . mapM clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> m' (Maybe s) endPreviousSubscriptions (qId, c) = do void . forkIO . atomically $ writeTBQueue (sndQ c) (CorrId "", qId, END) - atomically . stateTVar (clientSubs c) $ \ss -> (M.lookup qId ss, M.delete qId ss) + atomically $ TM.lookupDelete qId (clientSubs c) runClient :: (Transport c, MonadUnliftIO m, MonadReader Env m) => TProxy c -> c -> m () runClient _ h = do @@ -136,10 +137,10 @@ runClientTransport th@THandle {sessionId} = do clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () clientDisconnected c@Client {subscriptions, connected} = do atomically $ writeTVar connected False - subs <- readTVarIO subscriptions + subs <- readTVarIO $ TM.tVar subscriptions mapM_ cancelSub subs cs <- asks $ subscribers . server - atomically . mapM_ (modifyTVar cs . M.update deleteCurrentClient) $ M.keys subs + atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs where deleteCurrentClient :: Client -> Maybe Client deleteCurrentClient c' @@ -309,21 +310,19 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri getSubscription :: RecipientId -> STM Sub getSubscription rId = do - subs <- readTVar subscriptions - case M.lookup rId subs of + TM.lookup rId subscriptions >>= \case Just s -> tryTakeTMVar (delivered s) $> s Nothing -> do writeTBQueue subscribedQ (rId, clnt) s <- newSubscription - writeTVar subscriptions $ M.insert rId s subs + TM.insert rId s subscriptions return s subscribeNotifications :: m (Transmission BrokerMsg) subscribeNotifications = atomically $ do - subs <- readTVar ntfSubscriptions - when (isNothing $ M.lookup queueId subs) $ do + whenM (isNothing <$> TM.lookup queueId ntfSubscriptions) $ do writeTBQueue ntfSubscribedQ (queueId, clnt) - writeTVar ntfSubscriptions $ M.insert queueId () subs + TM.insert queueId () ntfSubscriptions pure ok acknowledgeMsg :: m (Transmission BrokerMsg) @@ -334,7 +333,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri _ -> return $ err NO_MSG withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a) - withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId + withSub rId f = mapM f =<< TM.lookup rId subscriptions sendMessage :: QueueStore -> MsgBody -> m (Transmission BrokerMsg) sendMessage st msgBody @@ -369,7 +368,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri trySendNotification :: STM () trySendNotification = forM_ (notifier qr) $ \(nId, _) -> - mapM_ (writeNtf nId) . M.lookup nId =<< readTVar notifiers + mapM_ (writeNtf nId) =<< TM.lookup nId notifiers writeNtf :: NotifierId -> Client -> STM () writeNtf nId Client {sndQ = q} = @@ -403,7 +402,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri void setDelivered setSub :: (Sub -> Sub) -> STM () - setSub f = modifyTVar subscriptions $ M.adjust f rId + setSub f = TM.adjust f rId subscriptions setDelivered :: STM (Maybe Bool) setDelivered = withSub rId $ \s -> tryPutTMVar (delivered s) () diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index d1b4c51fc..5fe48baaf 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -21,6 +21,8 @@ import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore (QueueRec (..)) import Simplex.Messaging.Server.QueueStore.STM import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport) import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) import System.IO (IOMode (..)) @@ -53,14 +55,14 @@ data Env = Env data Server = Server { subscribedQ :: TBQueue (RecipientId, Client), - subscribers :: TVar (Map RecipientId Client), + subscribers :: TMap RecipientId Client, ntfSubscribedQ :: TBQueue (NotifierId, Client), - notifiers :: TVar (Map NotifierId Client) + notifiers :: TMap NotifierId Client } data Client = Client - { subscriptions :: TVar (Map RecipientId Sub), - ntfSubscriptions :: TVar (Map NotifierId ()), + { subscriptions :: TMap RecipientId Sub, + ntfSubscriptions :: TMap NotifierId (), rcvQ :: TBQueue (Transmission Cmd), sndQ :: TBQueue (Transmission BrokerMsg), sessionId :: ByteString, @@ -77,15 +79,15 @@ data Sub = Sub newServer :: Natural -> STM Server newServer qSize = do subscribedQ <- newTBQueue qSize - subscribers <- newTVar M.empty + subscribers <- TM.empty ntfSubscribedQ <- newTBQueue qSize - notifiers <- newTVar M.empty + notifiers <- TM.empty return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers} newClient :: Natural -> ByteString -> STM Client newClient qSize sessionId = do - subscriptions <- newTVar M.empty - ntfSubscriptions <- newTVar M.empty + subscriptions <- TM.empty + ntfSubscriptions <- TM.empty rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize connected <- newTVar True @@ -112,7 +114,7 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} restoreQueues queueStore s = do (queues, s') <- liftIO $ readWriteStoreLog s atomically $ - modifyTVar queueStore $ \d -> + modifyTVar' queueStore $ \d -> d { queues, senders = M.foldr' addSender M.empty queues, diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 6d0fb63a0..9ebe55bbf 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -35,7 +35,7 @@ instance MonadMsgStore STMMsgStore MsgQueue STM where delMsgQueue :: STMMsgStore -> RecipientId -> STM () delMsgQueue store rId = - modifyTVar store $ MsgStoreData . M.delete rId . messages + modifyTVar' store $ MsgStoreData . M.delete rId . messages instance MonadMsgQueue MsgQueue STM where isFull :: MsgQueue -> STM Bool diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs new file mode 100644 index 000000000..de9b293f0 --- /dev/null +++ b/src/Simplex/Messaging/TMap.hs @@ -0,0 +1,55 @@ +module Simplex.Messaging.TMap + ( TMap (..), + empty, + Simplex.Messaging.TMap.lookup, + insert, + delete, + lookupInsert, + lookupDelete, + adjust, + update, + alter, + ) +where + +import Control.Concurrent.STM +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M + +newtype TMap k a = TMap {tVar :: TVar (Map k a)} + +empty :: STM (TMap k a) +empty = TMap <$> newTVar M.empty +{-# INLINE empty #-} + +lookup :: Ord k => k -> TMap k a -> STM (Maybe a) +lookup k (TMap m) = M.lookup k <$> readTVar m +{-# INLINE lookup #-} + +insert :: Ord k => k -> a -> TMap k a -> STM () +insert k v (TMap m) = modifyTVar' m $ M.insert k v +{-# INLINE insert #-} + +delete :: Ord k => k -> TMap k a -> STM () +delete k (TMap m) = modifyTVar' m $ M.delete k +{-# INLINE delete #-} + +lookupInsert :: Ord k => k -> a -> TMap k a -> STM (Maybe a) +lookupInsert k v (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv) +{-# INLINE lookupInsert #-} + +lookupDelete :: Ord k => k -> TMap k a -> STM (Maybe a) +lookupDelete k (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv) +{-# INLINE lookupDelete #-} + +adjust :: Ord k => (a -> a) -> k -> TMap k a -> STM () +adjust f k (TMap m) = modifyTVar' m $ M.adjust f k +{-# INLINE adjust #-} + +update :: Ord k => (a -> Maybe a) -> k -> TMap k a -> STM () +update f k (TMap m) = modifyTVar' m $ M.update f k +{-# INLINE update #-} + +alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM () +alter f k (TMap m) = modifyTVar' m $ M.alter f k +{-# INLINE alter #-} diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index d91c596b9..0e68fd83c 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -41,7 +41,7 @@ runTransportServer started port serverParams server = do $ \sock -> forever $ do (connSock, _) <- accept sock tid <- forkIO $ connectClient u connSock `E.catch` \(_ :: E.SomeException) -> pure () - atomically . modifyTVar clients $ S.insert tid + atomically . modifyTVar' clients $ S.insert tid where connectClient :: UnliftIO m -> Socket -> IO () connectClient u connSock = diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 7ea2a523b..ea53bc60e 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -58,6 +58,10 @@ ifM :: Monad m => m Bool -> m a -> m a -> m a ifM ba t f = ba >>= \b -> if b then t else f {-# INLINE ifM #-} +whenM :: Monad m => m Bool -> m () -> m () +whenM b a = ifM b a $ pure () +{-# INLINE whenM #-} + unlessM :: Monad m => m Bool -> m () -> m () unlessM b = ifM b $ pure () {-# INLINE unlessM #-} From 6ef6bedc039856099a552a24f60350e838ae8a6a Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 28 Mar 2022 10:29:21 +0100 Subject: [PATCH 2/3] refactor/optimize server queue/message store (#340) * refactor/optimize server queue/message store * change fst to pattern match * server store - wrap QueueRec into TVar --- src/Simplex/Messaging/Server/Env/STM.hs | 15 +- src/Simplex/Messaging/Server/MsgStore/STM.hs | 21 +-- src/Simplex/Messaging/Server/QueueStore.hs | 1 + .../Messaging/Server/QueueStore/STM.hs | 137 ++++++++---------- src/Simplex/Messaging/TMap.hs | 5 + 5 files changed, 77 insertions(+), 102 deletions(-) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 5fe48baaf..5187c0d73 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -111,15 +111,12 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} return Env {config, server, serverIdentity, queueStore, msgStore, idsDrg, storeLog = s', tlsServerParams} where restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode) - restoreQueues queueStore s = do - (queues, s') <- liftIO $ readWriteStoreLog s - atomically $ - modifyTVar' queueStore $ \d -> - d - { queues, - senders = M.foldr' addSender M.empty queues, - notifiers = M.foldr' addNotifier M.empty queues - } + restoreQueues QueueStore {queues, senders, notifiers} s = do + (qs, s') <- liftIO $ readWriteStoreLog s + atomically $ do + writeTVar (TM.tVar queues) =<< mapM newTVar qs + writeTVar (TM.tVar senders) $ M.foldr' addSender M.empty qs + writeTVar (TM.tVar notifiers) $ M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 9ebe55bbf..86d6db996 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -6,36 +6,31 @@ module Simplex.Messaging.Server.MsgStore.STM where -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as M import Numeric.Natural import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Server.MsgStore +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import UnliftIO.STM newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message} -newtype MsgStoreData = MsgStoreData {messages :: Map RecipientId MsgQueue} - -type STMMsgStore = TVar MsgStoreData +type STMMsgStore = TMap RecipientId MsgQueue newMsgStore :: STM STMMsgStore -newMsgStore = newTVar $ MsgStoreData M.empty +newMsgStore = TM.empty instance MonadMsgStore STMMsgStore MsgQueue STM where getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue - getMsgQueue store rId quota = do - m <- messages <$> readTVar store - maybe (newQ m) return $ M.lookup rId m + getMsgQueue st rId quota = maybe newQ pure =<< TM.lookup rId st where - newQ m' = do + newQ = do q <- MsgQueue <$> newTBQueue quota - writeTVar store . MsgStoreData $ M.insert rId q m' + TM.insert rId q st return q delMsgQueue :: STMMsgStore -> RecipientId -> STM () - delMsgQueue store rId = - modifyTVar' store $ MsgStoreData . M.delete rId . messages + delMsgQueue st rId = TM.delete rId st instance MonadMsgQueue MsgQueue STM where isFull :: MsgQueue -> STM Bool diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index ed859422a..544bf35b9 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -15,6 +15,7 @@ data QueueRec = QueueRec notifier :: Maybe (NotifierId, NtfPublicVerifyKey), status :: QueueStatus } + deriving (Eq, Show) data QueueStatus = QueueActive | QueueOff deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index b3424f6e8..401e7ee30 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} @@ -11,107 +12,83 @@ module Simplex.Messaging.Server.QueueStore.STM where -import Data.Map.Strict (Map) -import qualified Data.Map.Strict as M +import Control.Monad +import Data.Functor (($>)) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (ifM) import UnliftIO.STM -data QueueStoreData = QueueStoreData - { queues :: Map RecipientId QueueRec, - senders :: Map SenderId RecipientId, - notifiers :: Map NotifierId RecipientId +data QueueStore = QueueStore + { queues :: TMap RecipientId (TVar QueueRec), + senders :: TMap SenderId RecipientId, + notifiers :: TMap NotifierId RecipientId } -type QueueStore = TVar QueueStoreData - newQueueStore :: STM QueueStore -newQueueStore = newTVar QueueStoreData {queues = M.empty, senders = M.empty, notifiers = M.empty} +newQueueStore = do + queues <- TM.empty + senders <- TM.empty + notifiers <- TM.empty + pure QueueStore {queues, senders, notifiers} instance MonadQueueStore QueueStore STM where addQueue :: QueueStore -> QueueRec -> STM (Either ErrorType ()) - addQueue store qRec@QueueRec {recipientId = rId, senderId = sId} = do - cs@QueueStoreData {queues, senders} <- readTVar store - if M.member rId queues || M.member sId senders - then return $ Left DUPLICATE_ - else do - writeTVar store $ - cs - { queues = M.insert rId qRec queues, - senders = M.insert sId rId senders - } - return $ Right () + addQueue QueueStore {queues, senders} q@QueueRec {recipientId = rId, senderId = sId} = do + ifM hasId (pure $ Left DUPLICATE_) $ do + qVar <- newTVar q + TM.insert rId qVar queues + TM.insert sId rId senders + pure $ Right () + where + hasId = (||) <$> TM.member rId queues <*> TM.member sId senders getQueue :: QueueStore -> SParty p -> QueueId -> STM (Either ErrorType QueueRec) - getQueue st party qId = do - cs <- readTVar st - pure $ case party of - SRecipient -> getRcpQueue cs qId - SSender -> getPartyQueue cs senders - SNotifier -> getPartyQueue cs notifiers + getQueue QueueStore {queues, senders, notifiers} party qId = + toResult <$> (mapM readTVar =<< getVar) where - getPartyQueue :: - QueueStoreData -> - (QueueStoreData -> Map QueueId RecipientId) -> - Either ErrorType QueueRec - getPartyQueue cs recipientIds = - case M.lookup qId $ recipientIds cs of - Just rId -> getRcpQueue cs rId - Nothing -> Left AUTH + getVar = case party of + SRecipient -> TM.lookup qId queues + SSender -> TM.lookup qId senders >>= get + SNotifier -> TM.lookup qId notifiers >>= get + get = fmap join . mapM (`TM.lookup` queues) secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) - secureQueue store rId sKey = - updateQueues store rId $ \cs c -> - case senderKey c of - Just _ -> (Left AUTH, cs) - _ -> (Right c, cs {queues = M.insert rId c {senderKey = Just sKey} (queues cs)}) + secureQueue QueueStore {queues} rId sKey = + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case senderKey q of + Just _ -> pure Nothing + _ -> writeTVar qVar q {senderKey = Just sKey} $> Just q addQueueNotifier :: QueueStore -> RecipientId -> NotifierId -> NtfPublicVerifyKey -> STM (Either ErrorType QueueRec) - addQueueNotifier store rId nId nKey = do - cs@QueueStoreData {queues, notifiers} <- readTVar store - if M.member nId notifiers - then pure $ Left DUPLICATE_ - else case M.lookup rId queues of - Nothing -> pure $ Left AUTH - Just q -> case notifier q of - Just _ -> pure $ Left AUTH + addQueueNotifier QueueStore {queues, notifiers} rId nId nKey = do + ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $ + withQueue rId queues $ \qVar -> + readTVar qVar >>= \q -> case notifier q of + Just _ -> pure Nothing _ -> do - writeTVar store $ - cs - { queues = M.insert rId q {notifier = Just (nId, nKey)} queues, - notifiers = M.insert nId rId notifiers - } - pure $ Right q + writeTVar qVar q {notifier = Just (nId, nKey)} + TM.insert nId rId notifiers + pure $ Just q suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - suspendQueue store rId = - updateQueues store rId $ \cs c -> - (Right (), cs {queues = M.insert rId c {status = QueueOff} (queues cs)}) + suspendQueue QueueStore {queues} rId = + withQueue rId queues $ \qVar -> modifyTVar' qVar (\q -> q {status = QueueOff}) $> Just () deleteQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ()) - deleteQueue store rId = - updateQueues store rId $ \cs c -> - ( Right (), - cs - { queues = M.delete rId (queues cs), - senders = M.delete (senderId c) (senders cs) - } - ) + deleteQueue QueueStore {queues, senders, notifiers} rId = do + TM.lookupDelete rId queues >>= \case + Just qVar -> + readTVar qVar >>= \q -> do + TM.delete (senderId q) senders + forM_ (notifier q) $ \(nId, _) -> TM.delete nId notifiers + pure $ Right () + _ -> pure $ Left AUTH -updateQueues :: - QueueStore -> - RecipientId -> - (QueueStoreData -> QueueRec -> (Either ErrorType a, QueueStoreData)) -> - STM (Either ErrorType a) -updateQueues store rId update = do - cs <- readTVar store - let conn = getRcpQueue cs rId - either (return . Left) (_update cs) conn - where - _update cs c = do - let (res, cs') = update cs c - writeTVar store cs' - return res +toResult :: Maybe a -> Either ErrorType a +toResult = maybe (Left AUTH) Right -getRcpQueue :: QueueStoreData -> RecipientId -> Either ErrorType QueueRec -getRcpQueue cs rId = maybe (Left AUTH) Right . M.lookup rId $ queues cs +withQueue :: RecipientId -> TMap RecipientId (TVar QueueRec) -> (TVar QueueRec -> STM (Maybe a)) -> STM (Either ErrorType a) +withQueue rId queues f = toResult <$> (TM.lookup rId queues >>= fmap join . mapM f) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index de9b293f0..012adde4b 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.TMap ( TMap (..), empty, Simplex.Messaging.TMap.lookup, + member, insert, delete, lookupInsert, @@ -26,6 +27,10 @@ lookup :: Ord k => k -> TMap k a -> STM (Maybe a) lookup k (TMap m) = M.lookup k <$> readTVar m {-# INLINE lookup #-} +member :: Ord k => k -> TMap k a -> STM Bool +member k (TMap m) = M.member k <$> readTVar m +{-# INLINE member #-} + insert :: Ord k => k -> a -> TMap k a -> STM () insert k v (TMap m) = modifyTVar' m $ M.insert k v {-# INLINE insert #-} From cd22e06b3a4e8a85d09e136b963040981232834d Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 28 Mar 2022 18:49:17 +0100 Subject: [PATCH 3/3] use TMap for subscription maps (#341) * use TMap for subscription maps * refactor * correction --- src/Simplex/Messaging/Agent/Client.hs | 94 +++++++++++++------------ src/Simplex/Messaging/Server.hs | 2 +- src/Simplex/Messaging/Server/Env/STM.hs | 6 +- src/Simplex/Messaging/TMap.hs | 34 +++++---- 4 files changed, 76 insertions(+), 60 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index dfe538844..96bcd4529 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -63,7 +63,7 @@ import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKe import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, whenM) import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, forConcurrently_) @@ -79,8 +79,8 @@ data AgentClient = AgentClient msgQ :: TBQueue SMPServerTransmission, smpServers :: TVar (NonEmpty SMPServer), smpClients :: TMap SMPServer SMPClientVar, - subscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), - pendingSubscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), + subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), + pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TMap ConnId SMPServer, connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId), @@ -188,21 +188,24 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = clientDisconnected :: UnliftIO m -> IO () clientDisconnected u = do - removeClientSubs >>= (`forM_` serverDown u) + removeClientAndSubs >>= (`forM_` serverDown u) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeClientSubs :: IO (Maybe (Map ConnId RcvQueue)) - removeClientSubs = atomically $ do + removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) + removeClientAndSubs = atomically $ do TM.delete srv smpClients - cs_ <- TM.lookupDelete srv $ subscrSrvrs c - forM_ cs_ $ \cs -> do - modifyTVar' (TM.tVar $ subscrConns c) (`M.withoutKeys` M.keysSet cs) - modifyTVar' (TM.tVar $ pendingSubscrSrvrs c) $ addPendingSubs cs - return cs_ + cVar_ <- TM.lookupDelete srv $ subscrSrvrs c + forM cVar_ $ \cVar -> do + cs <- readTVar cVar + modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) + addPendingSubs cVar cs + pure cs where - addPendingSubs :: Map ConnId RcvQueue -> Map SMPServer (Map ConnId RcvQueue) -> Map SMPServer (Map ConnId RcvQueue) - addPendingSubs cs = M.alter (Just . addSubs cs) srv - addSubs cs = maybe cs (M.union cs) + addPendingSubs cVar cs = do + let ps = pendingSubscrSrvrs c + TM.lookup srv ps >>= \case + Just v -> TM.union cs v + _ -> TM.insert srv cVar ps serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () serverDown u cs = unless (M.null cs) $ do @@ -221,19 +224,26 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectClient `catchError` const loop reconnectClient :: m () - reconnectClient = do + reconnectClient = withAgentLock c . withSMP c srv $ \smp -> do - subs <- readTVarIO . TM.tVar $ subscrConns c - cs <- atomically . TM.lookup srv $ pendingSubscrSrvrs c - forConcurrently_ (maybe [] M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) -> - when (isNothing $ M.lookup connId subs) $ do - subscribeSMPQueue smp rcvPrivateKey rcvId - `catchError` \case - e@SMPResponseTimeout -> throwError e - e@SMPNetworkError -> throwError e - e -> liftIO $ notifySub (ERR $ smpClientError e) connId - addSubscription c rq connId - liftIO $ notifySub UP connId + cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c) + forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) -> + whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $ + subscribe_ smp sub `catchError` handleError connId + where + subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO () + subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do + subscribeSMPQueue smp rcvPrivateKey rcvId + addSubscription c rq connId + liftIO $ notifySub UP connId + + handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO () + handleError connId = \case + e@SMPResponseTimeout -> throwError e + e@SMPNetworkError -> throwError e + e -> do + liftIO $ notifySub (ERR $ smpClientError e) connId + atomically $ removePendingSubscription c srv connId notifySub :: ACommand 'Agent -> ConnId -> IO () notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) @@ -243,10 +253,10 @@ closeAgentClient c = liftIO $ do closeSMPServerClients c cancelActions $ reconnections c cancelActions $ asyncClients c - cancelActions . TM.tVar $ smpQueueMsgDeliveries c + cancelActions $ smpQueueMsgDeliveries c closeSMPServerClients :: AgentClient -> IO () -closeSMPServerClients c = readTVarIO (TM.tVar $ smpClients c) >>= mapM_ (forkIO . closeClient) +closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) where closeClient smpVar = atomically (readTMVar smpVar) >>= \case @@ -331,7 +341,7 @@ newRcvQueue_ a c srv = do subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do - addPendingSubscription c rq connId + atomically $ addPendingSubscription c rq connId withLogSMP c server rcvId "SUB" $ \smp -> do liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case Left e -> do @@ -343,17 +353,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do TM.insert connId server $ subscrConns c - addSubs_ rq connId $ subscrSrvrs c + addSubs_ (subscrSrvrs c) rq connId removePendingSubscription c server connId -addPendingSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () -addPendingSubscription c rq connId = - atomically . addSubs_ rq connId $ pendingSubscrSrvrs c +addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM () +addPendingSubscription = addSubs_ . pendingSubscrSrvrs -addSubs_ :: RcvQueue -> ConnId -> TMap SMPServer (Map ConnId RcvQueue) -> STM () -addSubs_ rq@RcvQueue {server} connId = TM.alter (Just . addSub) server - where - addSub = maybe (M.singleton connId rq) (M.insert connId rq) +addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM () +addSubs_ ss rq@RcvQueue {server} connId = + TM.lookup server ss >>= \case + Just m -> TM.insert connId rq m + _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m () removeSubscription c@AgentClient {subscrConns} connId = atomically $ do @@ -363,13 +373,9 @@ removeSubscription c@AgentClient {subscrConns} connId = atomically $ do removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () removePendingSubscription = removeSubs_ . pendingSubscrSrvrs -removeSubs_ :: TMap SMPServer (Map ConnId RcvQueue) -> SMPServer -> ConnId -> STM () -removeSubs_ ss server connId = TM.update delSub server ss - where - delSub :: Map ConnId RcvQueue -> Maybe (Map ConnId RcvQueue) - delSub cs = - let cs' = M.delete connId cs - in if M.null cs' then Nothing else Just cs' +removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM () +removeSubs_ ss server connId = + TM.lookup server ss >>= mapM_ (TM.delete connId) logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 102830353..938ba9aa1 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -137,7 +137,7 @@ runClientTransport th@THandle {sessionId} = do clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () clientDisconnected c@Client {subscriptions, connected} = do atomically $ writeTVar connected False - subs <- readTVarIO $ TM.tVar subscriptions + subs <- readTVarIO subscriptions mapM_ cancelSub subs cs <- asks $ subscribers . server atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 5187c0d73..3c4599a97 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -114,9 +114,9 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} restoreQueues QueueStore {queues, senders, notifiers} s = do (qs, s') <- liftIO $ readWriteStoreLog s atomically $ do - writeTVar (TM.tVar queues) =<< mapM newTVar qs - writeTVar (TM.tVar senders) $ M.foldr' addSender M.empty qs - writeTVar (TM.tVar notifiers) $ M.foldr' addNotifier M.empty qs + writeTVar queues =<< mapM newTVar qs + writeTVar senders $ M.foldr' addSender M.empty qs + writeTVar notifiers $ M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 012adde4b..a6584903e 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -1,6 +1,7 @@ module Simplex.Messaging.TMap - ( TMap (..), + ( TMap, empty, + singleton, Simplex.Messaging.TMap.lookup, member, insert, @@ -10,6 +11,7 @@ module Simplex.Messaging.TMap adjust, update, alter, + union, ) where @@ -17,44 +19,52 @@ import Control.Concurrent.STM import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -newtype TMap k a = TMap {tVar :: TVar (Map k a)} +type TMap k a = TVar (Map k a) empty :: STM (TMap k a) -empty = TMap <$> newTVar M.empty +empty = newTVar M.empty {-# INLINE empty #-} +singleton :: k -> a -> STM (TMap k a) +singleton k v = newTVar $ M.singleton k v +{-# INLINE singleton #-} + lookup :: Ord k => k -> TMap k a -> STM (Maybe a) -lookup k (TMap m) = M.lookup k <$> readTVar m +lookup k m = M.lookup k <$> readTVar m {-# INLINE lookup #-} member :: Ord k => k -> TMap k a -> STM Bool -member k (TMap m) = M.member k <$> readTVar m +member k m = M.member k <$> readTVar m {-# INLINE member #-} insert :: Ord k => k -> a -> TMap k a -> STM () -insert k v (TMap m) = modifyTVar' m $ M.insert k v +insert k v m = modifyTVar' m $ M.insert k v {-# INLINE insert #-} delete :: Ord k => k -> TMap k a -> STM () -delete k (TMap m) = modifyTVar' m $ M.delete k +delete k m = modifyTVar' m $ M.delete k {-# INLINE delete #-} lookupInsert :: Ord k => k -> a -> TMap k a -> STM (Maybe a) -lookupInsert k v (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv) +lookupInsert k v m = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv) {-# INLINE lookupInsert #-} lookupDelete :: Ord k => k -> TMap k a -> STM (Maybe a) -lookupDelete k (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv) +lookupDelete k m = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv) {-# INLINE lookupDelete #-} adjust :: Ord k => (a -> a) -> k -> TMap k a -> STM () -adjust f k (TMap m) = modifyTVar' m $ M.adjust f k +adjust f k m = modifyTVar' m $ M.adjust f k {-# INLINE adjust #-} update :: Ord k => (a -> Maybe a) -> k -> TMap k a -> STM () -update f k (TMap m) = modifyTVar' m $ M.update f k +update f k m = modifyTVar' m $ M.update f k {-# INLINE update #-} alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM () -alter f k (TMap m) = modifyTVar' m $ M.alter f k +alter f k m = modifyTVar' m $ M.alter f k {-# INLINE alter #-} + +union :: Ord k => Map k a -> TMap k a -> STM () +union m' m = modifyTVar' m $ M.union m' +{-# INLINE union #-}