diff --git a/simplexmq.cabal b/simplexmq.cabal index a234b3baa..9d84b4cfc 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -83,6 +83,7 @@ library Simplex.Messaging.Server.Stats Simplex.Messaging.Server.StoreLog Simplex.Messaging.TMap + Simplex.Messaging.TMap2 Simplex.Messaging.Transport Simplex.Messaging.Transport.Client Simplex.Messaging.Transport.HTTP2 diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 474241bba..98809205a 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -86,6 +86,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import Data.Composition ((.:), (.:.)) import Data.Functor (($>)) +import Data.List (deleteFirstsBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -366,8 +367,11 @@ ackMessageAsync' c connId msgId = enqueueCommand c connId (Just server) $ ACK msgId newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) -newConn c connId asyncMode enableNtfs cMode = do - srv <- getAnySMPServer c +newConn c connId asyncMode enableNtfs cMode = + getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode + +newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c) +newConnSrv c connId asyncMode enableNtfs cMode srv = do clientVRange <- asks $ smpClientVRange . config (rq, qUri) <- newRcvQueue c srv clientVRange True connId' <- setUpConn asyncMode rq @@ -394,7 +398,15 @@ newConn c connId asyncMode enableNtfs cMode = do withStore c $ \db -> createRcvConn db g cData rq cMode joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do +joinConn c connId asyncMode enableNtfs cReq cInfo = do + srv <- case cReq of + CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ -> + getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)] + _ -> getSMPServer c + joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv + +joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId +joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, @@ -410,7 +422,7 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS} connId' <- setUpConn asyncMode cData sq rc let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case + tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case Right _ -> do unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO pure connId' @@ -431,23 +443,22 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV liftIO $ createRatchet db connId' rc pure connId' _ -> throwError $ AGENT A_VERSION -joinConn c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do +joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, agentVRange `compatibleVersion` aVRange ) of (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConn c connId False enableNtfs SCMInvitation + (connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv sendInvitation c qInfo vrsn cReq cInfo pure connId' _ -> throwError $ AGENT A_VERSION -joinConn _c _connId True _enableNtfs (CRContactUri _) _cInfo = do +joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do throwError $ CMD PROHIBITED -createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo -createReplyQueue c ConnData {connId, enableNtfs} SndQueue {server, smpClientVersion} = do - srv <- getSMPServer c server +createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo +createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do (rq, qUri) <- newRcvQueue c srv (versionToRange smpClientVersion) True let qInfo = toVersionT qUri smpClientVersion addSubscription c rq connId @@ -499,22 +510,22 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma -- | Subscribe to receive connection messages (SUB command) in Reader monad subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () subscribeConnection' c connId = - withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData rq sq _ _) -> do - resumeMsgDelivery c cData sq - subscribe rq - void . forkIO $ doRcvQueueAction c cData rq sq - resumeConnCmds c connId - SomeConn _ (SndConnection cData sq) -> do - resumeMsgDelivery c cData sq - case status (sq :: SndQueue) of - Confirmed -> pure () -- TODO secure queue if this is a new server version - Active -> throwError $ CONN SIMPLEX - _ -> throwError $ INTERNAL "unexpected queue status" - resumeConnCmds c connId - SomeConn _ (RcvConnection _ rq) -> subscribe rq >> resumeConnCmds c connId - SomeConn _ (ContactConnection _ rq) -> subscribe rq >> resumeConnCmds c connId - SomeConn _ (NewConnection _) -> resumeConnCmds c connId + withStore c (`getConn` connId) >>= \conn -> do + resumeConnCmds c connId + case conn of + SomeConn _ (DuplexConnection cData rq sq _ _) -> do + resumeMsgDelivery c cData sq + subscribe rq + void . forkIO $ doRcvQueueAction c cData rq sq + SomeConn _ (SndConnection cData sq) -> do + resumeMsgDelivery c cData sq + case status (sq :: SndQueue) of + Confirmed -> pure () -- TODO secure queue if this is a new server version + Active -> throwError $ CONN SIMPLEX + _ -> throwError $ INTERNAL "unexpected queue status" + SomeConn _ (RcvConnection _ rq) -> subscribe rq + SomeConn _ (ContactConnection _ rq) -> subscribe rq + SomeConn _ (NewConnection _) -> pure () where -- TODO sndQueueAction? subscribe :: RcvQueue -> m () @@ -549,7 +560,7 @@ createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey} pure SMPQueueUri {clientVRange, queueAddress} _ -> do - srv <- getSMPServer c server + srv <- getNextSMPServer c [server] (rq', qUri) <- newRcvQueue c srv clientVRange False withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq' pure qUri @@ -599,7 +610,7 @@ subscribeConnections' c connIds = do (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs - forM_ (M.keys cs) $ resumeConnCmds c + mapM_ (resumeConnCmds c) $ M.keys cs rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) @@ -761,26 +772,35 @@ runCommandProcessing c@AgentClient {subQ} server = do E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case Left (e :: E.SomeException) -> notify "" $ ERR (INTERNAL $ show e) - Right (connId, ACmd _ cmd) -> + Right (connId, ACmd _ cmd) -> do + usedSrvs <- newTVarIO ([] :: [SMPServer]) withRetryInterval ri $ \loop -> do resp <- tryError $ case cmd of - NEW enableNtfs (ACM cMode) -> do - (_, cReq) <- newConn c connId True enableNtfs cMode - notify connId $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq) connInfo -> void $ joinConn c connId True enableNtfs cReq connInfo + NEW enableNtfs (ACM cMode) -> + withNextSrv usedSrvs $ \srv -> do + (_, cReq) <- newConnSrv c connId True enableNtfs cMode srv + notify connId $ INV (ACR cMode cReq) + JOIN enableNtfs (ACR _ cReq) connInfo -> + withNextSrv usedSrvs $ \srv -> + void $ joinConnSrv c connId True enableNtfs cReq connInfo srv LET confId ownCInfo -> allowConnection' c connId confId ownCInfo ACK msgId -> ackMessage' c connId msgId - _ -> notify "" $ ERR (INTERNAL "") + _ -> notify connId $ ERR $ INTERNAL $ "unsupported async command " <> show cmd case resp of - Left _ -> - -- TODO retry NEW and JOIN on different server - -- TODO depending on command, some errors shouldn't be retried - retryCommand loop - Right () -> do - delCmd cmdId + Left e + | temporaryAgentError e || e == BROKER HOST -> retryCommand loop + | otherwise -> notify connId $ ERR e + Right () -> withStore' c (`deleteCommand` cmdId) where - delCmd :: AsyncCmdId -> m () - delCmd cmdId = withStore' c $ \db -> deleteCommand db cmdId + withNextSrv :: TVar [SMPServer] -> (SMPServer -> m ()) -> m () + withNextSrv usedSrvs action = do + used <- readTVarIO usedSrvs + srv <- getNextSMPServer c used + atomically $ do + srvs <- readTVar $ smpServers c + let used' = if length used + 1 >= L.length srvs then [] else srv : used + writeTVar usedSrvs used' + action srv notify :: ConnId -> ACommand 'Agent -> m () notify connId cmd = atomically $ writeTBQueue subQ ("", connId, cmd) retryCommand loop = do @@ -940,7 +960,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- and this branch should never be reached as receive is created before the confirmation, -- so the condition is not necessary here, strictly speaking. _ -> unless (duplexHandshake == Just True) $ do - qInfo <- createReplyQueue c cData sq + srv <- getSMPServer c + qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId AM_QHELLO_ -> do @@ -1270,8 +1291,8 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do -- unsafeIOToSTM $ putStrLn $ "in timeout: suspendSendingAndDatabase" suspendSendingAndDatabase c -getAnySMPServer :: AgentMonad m => AgentClient -> m SMPServer -getAnySMPServer c = readTVarIO (smpServers c) >>= pickServer +getSMPServer :: AgentMonad m => AgentClient -> m SMPServer +getSMPServer c = readTVarIO (smpServers c) >>= pickServer pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer pickServer = \case @@ -1280,14 +1301,14 @@ pickServer = \case gen <- asks randomServer atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) -getSMPServer :: AgentMonad m => AgentClient -> SMPServer -> m SMPServer -getSMPServer c (SMPServer host port _) = do +getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer +getNextSMPServer c usedSrvs = do srvs <- readTVarIO $ smpServers c - case L.nonEmpty $ L.filter different srvs of + case L.nonEmpty $ deleteFirstsBy different (L.toList srvs) usedSrvs of Just srvs' -> pickServer srvs' - _ -> pure $ L.head srvs + _ -> pickServer srvs where - different (SMPServer host' port' _) = host /= host' || port /= port' + different (SMPServer host port _) (SMPServer host' port' _) = host /= host' || port /= port' subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do @@ -1666,8 +1687,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do withStore c $ \db -> upgradeRcvConnToDuplex db connId sq enqueueConfirmation c cData sq ownConnInfo Nothing -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do +confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do aMessage <- mkAgentMessage agentVersion msg <- mkConfirmation aMessage sendConfirmation c sq msg @@ -1681,7 +1702,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e mkAgentMessage :: Version -> m AgentMessage mkAgentMessage 1 = pure $ AgentConnInfo connInfo mkAgentMessage _ = do - qInfo <- createReplyQueue c cData sq + qInfo <- createReplyQueue c cData sq srv pure $ AgentConnInfoReply (qInfo :| []) connInfo enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index b8f3ed954..4936f71dc 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -92,7 +92,6 @@ import Data.Set (Set) import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock (getCurrentTime) -import Data.Tuple (swap) import Data.Word (Word16) import qualified Database.SQLite.Simple as DB import Simplex.Messaging.Agent.Env.SQLite @@ -133,6 +132,8 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.TMap2 (TMap2) +import qualified Simplex.Messaging.TMap2 as TM2 import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version @@ -159,10 +160,9 @@ data AgentClient = AgentClient ntfServers :: TVar [NtfServer], ntfClients :: TMap NtfServer NtfClientVar, useNetworkConfig :: TVar NetworkConfig, - subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), - pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TVar (Set ConnId), - activeSubscrConns :: TMap ConnId SMPServer, + activeSubs :: TMap2 SMPServer ConnId RcvQueue, + pendingSubs :: TMap2 SMPServer ConnId RcvQueue, connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap MsgDeliveryKey (TQueue InternalId), smpQueueMsgDeliveries :: TMap MsgDeliveryKey (Async ()), @@ -215,10 +215,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do ntfServers <- newTVar ntf ntfClients <- TM.empty useNetworkConfig <- newTVar netCfg - subscrSrvrs <- TM.empty - pendingSubscrSrvrs <- TM.empty subscrConns <- newTVar S.empty - activeSubscrConns <- TM.empty + activeSubs <- TM2.empty + pendingSubs <- TM2.empty connMsgsQueued <- TM.empty smpQueueMsgQueues <- TM.empty smpQueueMsgDeliveries <- TM.empty @@ -237,7 +236,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i') lock <- newTMVar () - return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock} agentDbPath :: AgentClient -> FilePath agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath @@ -276,26 +275,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) removeClientAndSubs = atomically $ do TM.delete srv smpClients - TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs + TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs where updateSubs cVar = do - cs <- readTVar cVar - modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs) - addPendingSubs cVar cs - pure cs - - addPendingSubs cVar cs = do - let ps = pendingSubscrSrvrs c - TM.lookup srv ps >>= \case - Just v -> TM.union cs v - _ -> TM.insert srv cVar ps + TM2.insert1 srv cVar $ pendingSubs c + readTVar cVar serverDown :: Map ConnId RcvQueue -> IO () - serverDown cs = unless (M.null cs) $ - whenM (readTVarIO active) $ do - let conns = M.keys cs - notifySub "" $ hostEvent DISCONNECT client - unless (null conns) . notifySub "" $ DOWN srv conns + serverDown cs = whenM (readTVarIO active) $ do + notifySub "" $ hostEvent DISCONNECT client + let conns = M.keys cs + unless (null conns) $ do + notifySub "" $ DOWN srv conns atomically $ mapM_ (releaseGetLock c) cs unliftIO u reconnectServer @@ -313,7 +304,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do reconnectClient :: m () reconnectClient = withAgentLock c $ - atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar) + atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar) >>= mapM_ resubscribe where resubscribe :: Map ConnId RcvQueue -> m () @@ -419,10 +410,9 @@ closeAgentClient c = liftIO $ do cancelActions $ reconnections c cancelActions $ asyncClients c cancelActions $ smpQueueMsgDeliveries c - clear subscrSrvrs - clear pendingSubscrSrvrs + atomically . TM2.clear $ activeSubs c + atomically . TM2.clear $ pendingSubs c clear subscrConns - clear activeSubscrConns clear connMsgsQueued clear smpQueueMsgQueues clear getMsgLocks @@ -534,17 +524,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ do modifyTVar (subscrConns c) $ S.insert connId - addPendingSubscription c rq connId + TM2.insert server connId rq $ pendingSubs c withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId) >>= either throwError pure processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) -processSubResult c rq@RcvQueue {server} connId r = do +processSubResult c rq connId r = do case r of Left e -> atomically . unless (temporaryClientError e) $ - removePendingSubscription c server connId + TM2.delete connId (pendingSubs c) _ -> addSubscription c rq connId pure r @@ -564,9 +554,9 @@ temporaryAgentError = \case subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ())) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) - forM_ qs_ $ \q -> atomically $ do - modifyTVar (subscrConns c) . S.insert $ fst q - uncurry (addPendingSubscription c) $ swap q + forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do + modifyTVar (subscrConns c) $ S.insert connId + TM2.insert server connId rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) @@ -588,35 +578,18 @@ subscribeQueues c srv qs = do addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do - TM.insert connId server $ activeSubscrConns c modifyTVar (subscrConns c) $ S.insert connId - addSubs_ (subscrSrvrs c) rq connId - removePendingSubscription c server connId + TM2.insert server connId rq $ activeSubs c + TM2.delete connId $ pendingSubs c hasActiveSubscription :: AgentClient -> ConnId -> STM Bool -hasActiveSubscription c connId = TM.member connId (activeSubscrConns c) - -addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM () -addPendingSubscription = addSubs_ . pendingSubscrSrvrs - -addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM () -addSubs_ ss rq@RcvQueue {server} connId = - TM.lookup server ss >>= \case - Just m -> TM.insert connId rq m - _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss +hasActiveSubscription c connId = TM2.member connId $ activeSubs c removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do modifyTVar (subscrConns c) $ S.delete connId - server_ <- TM.lookupDelete connId $ activeSubscrConns c - mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_ - -removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () -removePendingSubscription = removeSubs_ . pendingSubscrSrvrs - -removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM () -removeSubs_ ss server connId = - TM.lookup server ss >>= mapM_ (TM.delete connId) + TM2.delete connId $ activeSubs c + TM2.delete connId $ pendingSubs c getSubscriptions :: AgentClient -> STM (Set ConnId) getSubscriptions = readTVar . subscrConns diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 761a41c93..2f6e0cf8a 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.TMap ( TMap, empty, singleton, + clear, Simplex.Messaging.TMap.null, Simplex.Messaging.TMap.lookup, member, @@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a) singleton k v = newTVar $ M.singleton k v {-# INLINE singleton #-} +clear :: TMap k a -> STM () +clear m = writeTVar m M.empty +{-# INLINE clear #-} + null :: TMap k a -> STM Bool null m = M.null <$> readTVar m {-# INLINE null #-} diff --git a/src/Simplex/Messaging/TMap2.hs b/src/Simplex/Messaging/TMap2.hs new file mode 100644 index 000000000..69d42e47d --- /dev/null +++ b/src/Simplex/Messaging/TMap2.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} + +module Simplex.Messaging.TMap2 + ( TMap2, + empty, + clear, + Simplex.Messaging.TMap2.lookup, + lookup1, + member, + insert, + insert1, + delete, + lookupDelete1, + ) +where + +import Control.Concurrent.STM +import Control.Monad (forM_, (>=>)) +import qualified Data.Map.Strict as M +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (whenM, ($>>=)) + +-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys. +-- It allows direct access via k1 to a group of k2 values and via k2 to one value +data TMap2 k1 k2 a = TMap2 + { _m1 :: TMap k1 (TMap k2 a), + _m2 :: TMap k2 k1 + } + +empty :: STM (TMap2 k1 k2 a) +empty = TMap2 <$> TM.empty <*> TM.empty + +clear :: TMap2 k1 k2 a -> STM () +clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2 + +lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a) +lookup k2 TMap2 {_m1, _m2} = do + TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2 + +lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) +lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1 +{-# INLINE lookup1 #-} + +member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool +member k2 TMap2 {_m2} = TM.member k2 _m2 +{-# INLINE member #-} + +insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM () +insert k1 k2 v TMap2 {_m1, _m2} = + TM.lookup k2 _m2 >>= \case + Just k1' + | k1 == k1' -> _insert1 + | otherwise -> _delete1 k1' k2 _m1 >> _insert2 + _ -> _insert2 + where + _insert1 = + TM.lookup k1 _m1 >>= \case + Just m -> TM.insert k2 v m + _ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1 + _insert2 = TM.insert k2 k1 _m2 >> _insert1 + +insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM () +insert1 k1 m' TMap2 {_m1, _m2} = + TM.lookup k1 _m1 >>= \case + Just m -> readTVar m' >>= (`TM.union` m) + _ -> TM.insert k1 m' _m1 + +delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM () +delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1) + +_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM () +_delete1 k1 k2 m1 = + TM.lookup k1 m1 + >>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1)) + +lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) +lookupDelete1 k1 TMap2 {_m1, _m2} = do + m_ <- TM.lookupDelete k1 _m1 + forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet + pure m_