From 6adbc56021ca0600f8d9864c477e0668a67822a6 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 9 Sep 2022 12:30:27 +0100 Subject: [PATCH 1/2] try async commands without servers on different servers (#516) * refactor * retry commands with different servers * refactor * remove comment Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> --- src/Simplex/Messaging/Agent.hs | 130 ++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 52 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f705f6666..3e4967e05 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -85,6 +85,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import Data.Composition ((.:), (.:.)) import Data.Functor (($>)) +import Data.List (deleteFirstsBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -359,8 +360,11 @@ ackMessageAsync' c connId msgId = enqueueCommand c connId (Just server) $ ACK msgId newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) -newConn c connId asyncMode enableNtfs cMode = do - srv <- getSMPServer c +newConn c connId asyncMode enableNtfs cMode = + getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode + +newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c) +newConnSrv c connId asyncMode enableNtfs cMode srv = do clientVRange <- asks $ smpClientVRange . config (rq, qUri) <- newRcvQueue c srv clientVRange connId' <- setUpConn asyncMode rq @@ -387,7 +391,11 @@ newConn c connId asyncMode enableNtfs cMode = do withStore c $ \db -> createRcvConn db g cData rq cMode joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do +joinConn c connId asyncMode enableNtfs connReq cInfo = + getSMPServer c >>= joinConnSrv c connId asyncMode enableNtfs connReq cInfo + +joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId +joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, @@ -403,7 +411,7 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS} connId' <- setUpConn asyncMode cData sq rc let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case + tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case Right _ -> do unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO pure connId' @@ -424,23 +432,22 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV liftIO $ createRatchet db connId' rc pure connId' _ -> throwError $ AGENT A_VERSION -joinConn c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do +joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, agentVRange `compatibleVersion` aVRange ) of (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConn c connId False enableNtfs SCMInvitation + (connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv sendInvitation c qInfo vrsn cReq cInfo pure connId' _ -> throwError $ AGENT A_VERSION -joinConn _c _connId True _enableNtfs (CRContactUri _) _cInfo = do +joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do throwError $ CMD PROHIBITED -createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo -createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} = do - srv <- getSMPServer c +createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo +createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do (rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion let qInfo = toVersionT qUri smpClientVersion addSubscription c rq connId @@ -490,21 +497,21 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma -- | Subscribe to receive connection messages (SUB command) in Reader monad subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () subscribeConnection' c connId = - withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData rq sq) -> do - resumeMsgDelivery c cData sq - subscribe rq - resumeConnCmds c connId - SomeConn _ (SndConnection cData sq) -> do - resumeMsgDelivery c cData sq - case status (sq :: SndQueue) of - Confirmed -> pure () - Active -> throwError $ CONN SIMPLEX - _ -> throwError $ INTERNAL "unexpected queue status" - resumeConnCmds c connId - SomeConn _ (RcvConnection _ rq) -> subscribe rq >> resumeConnCmds c connId - SomeConn _ (ContactConnection _ rq) -> subscribe rq >> resumeConnCmds c connId - SomeConn _ (NewConnection _) -> resumeConnCmds c connId + withStore c (`getConn` connId) >>= \conn -> do + resumeConnCmds c connId + case conn of + SomeConn _ (DuplexConnection cData rq sq) -> do + resumeMsgDelivery c cData sq + subscribe rq + SomeConn _ (SndConnection cData sq) -> do + resumeMsgDelivery c cData sq + case status (sq :: SndQueue) of + Confirmed -> pure () + Active -> throwError $ CONN SIMPLEX + _ -> throwError $ INTERNAL "unexpected queue status" + SomeConn _ (RcvConnection _ rq) -> subscribe rq + SomeConn _ (ContactConnection _ rq) -> subscribe rq + SomeConn _ (NewConnection _) -> pure () where subscribe :: RcvQueue -> m () subscribe rq = do @@ -521,7 +528,7 @@ subscribeConnections' c connIds = do (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs - forM_ (M.keys cs) $ resumeConnCmds c + mapM_ (resumeConnCmds c) $ M.keys cs rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) @@ -679,26 +686,35 @@ runCommandProcessing c@AgentClient {subQ} server = do E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case Left (e :: E.SomeException) -> notify "" $ ERR (INTERNAL $ show e) - Right (connId, ACmd _ cmd) -> + Right (connId, ACmd _ cmd) -> do + usedSrvs <- newTVarIO ([] :: [SMPServer]) withRetryInterval ri $ \loop -> do resp <- tryError $ case cmd of - NEW enableNtfs (ACM cMode) -> do - (_, cReq) <- newConn c connId True enableNtfs cMode - notify connId $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq) connInfo -> void $ joinConn c connId True enableNtfs cReq connInfo + NEW enableNtfs (ACM cMode) -> + withNextSrv usedSrvs $ \srv -> do + (_, cReq) <- newConnSrv c connId True enableNtfs cMode srv + notify connId $ INV (ACR cMode cReq) + JOIN enableNtfs (ACR _ cReq) connInfo -> + withNextSrv usedSrvs $ \srv -> + void $ joinConnSrv c connId True enableNtfs cReq connInfo srv LET confId ownCInfo -> allowConnection' c connId confId ownCInfo ACK msgId -> ackMessage' c connId msgId - _ -> notify "" $ ERR (INTERNAL "") + _ -> notify connId $ ERR $ INTERNAL $ "unsupported async command " <> show cmd case resp of - Left _ -> - -- TODO retry NEW and JOIN on different server - -- TODO depending on command, some errors shouldn't be retried - retryCommand loop - Right () -> do - delCmd cmdId + Left e + | temporaryAgentError e || e == BROKER HOST -> retryCommand loop + | otherwise -> notify connId $ ERR e + Right () -> withStore' c (`deleteCommand` cmdId) where - delCmd :: AsyncCmdId -> m () - delCmd cmdId = withStore' c $ \db -> deleteCommand db cmdId + withNextSrv :: TVar [SMPServer] -> (SMPServer -> m ()) -> m () + withNextSrv usedSrvs action = do + used <- readTVarIO usedSrvs + srv <- getNextSMPServer c used + atomically $ do + srvs <- readTVar $ smpServers c + let used' = if length used + 1 >= L.length srvs then [] else srv : used + writeTVar usedSrvs used' + action srv notify :: ConnId -> ACommand 'Agent -> m () notify connId cmd = atomically $ writeTBQueue subQ ("", connId, cmd) retryCommand loop = do @@ -844,7 +860,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- and this branch should never be reached as receive is created before the confirmation, -- so the condition is not necessary here, strictly speaking. _ -> unless (duplexHandshake == Just True) $ do - qInfo <- createReplyQueue c cData sq + srv <- getSMPServer c + qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId _ -> pure () @@ -1155,14 +1172,23 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do suspendSendingAndDatabase c getSMPServer :: AgentMonad m => AgentClient -> m SMPServer -getSMPServer c = do - smpServers <- readTVarIO $ smpServers c - case smpServers of - srv :| [] -> pure srv - servers -> do - gen <- asks randomServer - atomically . stateTVar gen $ - first (servers L.!!) . randomR (0, L.length servers - 1) +getSMPServer c = readTVarIO (smpServers c) >>= pickServer + +pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer +pickServer = \case + srv :| [] -> pure srv + servers -> do + gen <- asks randomServer + atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) + +getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer +getNextSMPServer c usedSrvs = do + srvs <- readTVarIO $ smpServers c + case L.nonEmpty $ deleteFirstsBy different (L.toList srvs) usedSrvs of + Just srvs' -> pickServer srvs' + _ -> pickServer srvs + where + different (SMPServer host port _) (SMPServer host' port' _) = host /= host' || port /= port' subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do @@ -1400,8 +1426,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do withStore c $ \db -> upgradeRcvConnToDuplex db connId sq enqueueConfirmation c cData sq ownConnInfo Nothing -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do +confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do aMessage <- mkAgentMessage agentVersion msg <- mkConfirmation aMessage sendConfirmation c sq msg @@ -1415,7 +1441,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e mkAgentMessage :: Version -> m AgentMessage mkAgentMessage 1 = pure $ AgentConnInfo connInfo mkAgentMessage _ = do - qInfo <- createReplyQueue c cData sq + qInfo <- createReplyQueue c cData sq srv pure $ AgentConnInfoReply (qInfo :| []) connInfo enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () From 42a96d6d0034d8e10b46c6dbddd937a71dc9fb12 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 9 Sep 2022 16:31:57 +0100 Subject: [PATCH 2/2] refactor agent subscriptions with TMap2 (#517) * refactor agent subscriptions with TMap2 * refactor * refactor * comment --- simplexmq.cabal | 1 + src/Simplex/Messaging/Agent/Client.hs | 85 +++++++++------------------ src/Simplex/Messaging/TMap.hs | 5 ++ src/Simplex/Messaging/TMap2.hs | 82 ++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 56 deletions(-) create mode 100644 src/Simplex/Messaging/TMap2.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 397cb4018..88a1e1d02 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -82,6 +82,7 @@ library Simplex.Messaging.Server.Stats Simplex.Messaging.Server.StoreLog Simplex.Messaging.TMap + Simplex.Messaging.TMap2 Simplex.Messaging.Transport Simplex.Messaging.Transport.Client Simplex.Messaging.Transport.HTTP2 diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 596457d3e..4d02a204e 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -90,7 +90,6 @@ import Data.Maybe (listToMaybe) import Data.Set (Set) import qualified Data.Set as S import Data.Text.Encoding -import Data.Tuple (swap) import Data.Word (Word16) import qualified Database.SQLite.Simple as DB import Simplex.Messaging.Agent.Env.SQLite @@ -131,6 +130,8 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.TMap2 (TMap2) +import qualified Simplex.Messaging.TMap2 as TM2 import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version @@ -155,10 +156,9 @@ data AgentClient = AgentClient ntfServers :: TVar [NtfServer], ntfClients :: TMap NtfServer NtfClientVar, useNetworkConfig :: TVar NetworkConfig, - subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), - pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TVar (Set ConnId), - activeSubscrConns :: TMap ConnId SMPServer, + activeSubs :: TMap2 SMPServer ConnId RcvQueue, + pendingSubs :: TMap2 SMPServer ConnId RcvQueue, connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId), smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()), @@ -210,10 +210,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do ntfServers <- newTVar ntf ntfClients <- TM.empty useNetworkConfig <- newTVar netCfg - subscrSrvrs <- TM.empty - pendingSubscrSrvrs <- TM.empty subscrConns <- newTVar S.empty - activeSubscrConns <- TM.empty + activeSubs <- TM2.empty + pendingSubs <- TM2.empty connMsgsQueued <- TM.empty smpQueueMsgQueues <- TM.empty smpQueueMsgDeliveries <- TM.empty @@ -231,7 +230,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i') lock <- newTMVar () - return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} agentDbPath :: AgentClient -> FilePath agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath @@ -270,26 +269,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) removeClientAndSubs = atomically $ do TM.delete srv smpClients - TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs + TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs where updateSubs cVar = do - cs <- readTVar cVar - modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs) - addPendingSubs cVar cs - pure cs - - addPendingSubs cVar cs = do - let ps = pendingSubscrSrvrs c - TM.lookup srv ps >>= \case - Just v -> TM.union cs v - _ -> TM.insert srv cVar ps + TM2.insert1 srv cVar $ pendingSubs c + readTVar cVar serverDown :: Map ConnId RcvQueue -> IO () - serverDown cs = unless (M.null cs) $ - whenM (readTVarIO active) $ do - let conns = M.keys cs - notifySub "" $ hostEvent DISCONNECT client - unless (null conns) . notifySub "" $ DOWN srv conns + serverDown cs = whenM (readTVarIO active) $ do + notifySub "" $ hostEvent DISCONNECT client + let conns = M.keys cs + unless (null conns) $ do + notifySub "" $ DOWN srv conns atomically $ mapM_ (releaseGetLock c) cs unliftIO u reconnectServer @@ -307,7 +298,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do reconnectClient :: m () reconnectClient = withAgentLock c $ - atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar) + atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar) >>= mapM_ resubscribe where resubscribe :: Map ConnId RcvQueue -> m () @@ -413,10 +404,9 @@ closeAgentClient c = liftIO $ do cancelActions $ reconnections c cancelActions $ asyncClients c cancelActions $ smpQueueMsgDeliveries c - clear subscrSrvrs - clear pendingSubscrSrvrs + atomically . TM2.clear $ activeSubs c + atomically . TM2.clear $ pendingSubs c clear subscrConns - clear activeSubscrConns clear connMsgsQueued clear smpQueueMsgQueues clear getMsgLocks @@ -520,17 +510,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ do modifyTVar (subscrConns c) $ S.insert connId - addPendingSubscription c rq connId + TM2.insert server connId rq $ pendingSubs c withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId) >>= either throwError pure processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) -processSubResult c rq@RcvQueue {server} connId r = do +processSubResult c rq connId r = do case r of Left e -> atomically . unless (temporaryClientError e) $ - removePendingSubscription c server connId + TM2.delete connId (pendingSubs c) _ -> addSubscription c rq connId pure r @@ -550,9 +540,9 @@ temporaryAgentError = \case subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ())) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) - forM_ qs_ $ \q -> atomically $ do - modifyTVar (subscrConns c) . S.insert $ fst q - uncurry (addPendingSubscription c) $ swap q + forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do + modifyTVar (subscrConns c) $ S.insert connId + TM2.insert server connId rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) @@ -574,35 +564,18 @@ subscribeQueues c srv qs = do addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do - TM.insert connId server $ activeSubscrConns c modifyTVar (subscrConns c) $ S.insert connId - addSubs_ (subscrSrvrs c) rq connId - removePendingSubscription c server connId + TM2.insert server connId rq $ activeSubs c + TM2.delete connId $ pendingSubs c hasActiveSubscription :: AgentClient -> ConnId -> STM Bool -hasActiveSubscription c connId = TM.member connId (activeSubscrConns c) - -addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM () -addPendingSubscription = addSubs_ . pendingSubscrSrvrs - -addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM () -addSubs_ ss rq@RcvQueue {server} connId = - TM.lookup server ss >>= \case - Just m -> TM.insert connId rq m - _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss +hasActiveSubscription c connId = TM2.member connId $ activeSubs c removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do modifyTVar (subscrConns c) $ S.delete connId - server_ <- TM.lookupDelete connId $ activeSubscrConns c - mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_ - -removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () -removePendingSubscription = removeSubs_ . pendingSubscrSrvrs - -removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM () -removeSubs_ ss server connId = - TM.lookup server ss >>= mapM_ (TM.delete connId) + TM2.delete connId $ activeSubs c + TM2.delete connId $ pendingSubs c getSubscriptions :: AgentClient -> STM (Set ConnId) getSubscriptions = readTVar . subscrConns diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 761a41c93..2f6e0cf8a 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.TMap ( TMap, empty, singleton, + clear, Simplex.Messaging.TMap.null, Simplex.Messaging.TMap.lookup, member, @@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a) singleton k v = newTVar $ M.singleton k v {-# INLINE singleton #-} +clear :: TMap k a -> STM () +clear m = writeTVar m M.empty +{-# INLINE clear #-} + null :: TMap k a -> STM Bool null m = M.null <$> readTVar m {-# INLINE null #-} diff --git a/src/Simplex/Messaging/TMap2.hs b/src/Simplex/Messaging/TMap2.hs new file mode 100644 index 000000000..69d42e47d --- /dev/null +++ b/src/Simplex/Messaging/TMap2.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} + +module Simplex.Messaging.TMap2 + ( TMap2, + empty, + clear, + Simplex.Messaging.TMap2.lookup, + lookup1, + member, + insert, + insert1, + delete, + lookupDelete1, + ) +where + +import Control.Concurrent.STM +import Control.Monad (forM_, (>=>)) +import qualified Data.Map.Strict as M +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (whenM, ($>>=)) + +-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys. +-- It allows direct access via k1 to a group of k2 values and via k2 to one value +data TMap2 k1 k2 a = TMap2 + { _m1 :: TMap k1 (TMap k2 a), + _m2 :: TMap k2 k1 + } + +empty :: STM (TMap2 k1 k2 a) +empty = TMap2 <$> TM.empty <*> TM.empty + +clear :: TMap2 k1 k2 a -> STM () +clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2 + +lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a) +lookup k2 TMap2 {_m1, _m2} = do + TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2 + +lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) +lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1 +{-# INLINE lookup1 #-} + +member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool +member k2 TMap2 {_m2} = TM.member k2 _m2 +{-# INLINE member #-} + +insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM () +insert k1 k2 v TMap2 {_m1, _m2} = + TM.lookup k2 _m2 >>= \case + Just k1' + | k1 == k1' -> _insert1 + | otherwise -> _delete1 k1' k2 _m1 >> _insert2 + _ -> _insert2 + where + _insert1 = + TM.lookup k1 _m1 >>= \case + Just m -> TM.insert k2 v m + _ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1 + _insert2 = TM.insert k2 k1 _m2 >> _insert1 + +insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM () +insert1 k1 m' TMap2 {_m1, _m2} = + TM.lookup k1 _m1 >>= \case + Just m -> readTVar m' >>= (`TM.union` m) + _ -> TM.insert k1 m' _m1 + +delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM () +delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1) + +_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM () +_delete1 k1 k2 m1 = + TM.lookup k1 m1 + >>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1)) + +lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) +lookupDelete1 k1 TMap2 {_m1, _m2} = do + m_ <- TM.lookupDelete k1 _m1 + forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet + pure m_