diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index dfe538844..96bcd4529 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -63,7 +63,7 @@ import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKe import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, whenM) import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, forConcurrently_) @@ -79,8 +79,8 @@ data AgentClient = AgentClient msgQ :: TBQueue SMPServerTransmission, smpServers :: TVar (NonEmpty SMPServer), smpClients :: TMap SMPServer SMPClientVar, - subscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), - pendingSubscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue), + subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), + pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TMap ConnId SMPServer, connMsgsQueued :: TMap ConnId Bool, smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId), @@ -188,21 +188,24 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = clientDisconnected :: UnliftIO m -> IO () clientDisconnected u = do - removeClientSubs >>= (`forM_` serverDown u) + removeClientAndSubs >>= (`forM_` serverDown u) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeClientSubs :: IO (Maybe (Map ConnId RcvQueue)) - removeClientSubs = atomically $ do + removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) + removeClientAndSubs = atomically $ do TM.delete srv smpClients - cs_ <- TM.lookupDelete srv $ subscrSrvrs c - forM_ cs_ $ \cs -> do - modifyTVar' (TM.tVar $ subscrConns c) (`M.withoutKeys` M.keysSet cs) - modifyTVar' (TM.tVar $ pendingSubscrSrvrs c) $ addPendingSubs cs - return cs_ + cVar_ <- TM.lookupDelete srv $ subscrSrvrs c + forM cVar_ $ \cVar -> do + cs <- readTVar cVar + modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) + addPendingSubs cVar cs + pure cs where - addPendingSubs :: Map ConnId RcvQueue -> Map SMPServer (Map ConnId RcvQueue) -> Map SMPServer (Map ConnId RcvQueue) - addPendingSubs cs = M.alter (Just . addSubs cs) srv - addSubs cs = maybe cs (M.union cs) + addPendingSubs cVar cs = do + let ps = pendingSubscrSrvrs c + TM.lookup srv ps >>= \case + Just v -> TM.union cs v + _ -> TM.insert srv cVar ps serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () serverDown u cs = unless (M.null cs) $ do @@ -221,19 +224,26 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectClient `catchError` const loop reconnectClient :: m () - reconnectClient = do + reconnectClient = withAgentLock c . withSMP c srv $ \smp -> do - subs <- readTVarIO . TM.tVar $ subscrConns c - cs <- atomically . TM.lookup srv $ pendingSubscrSrvrs c - forConcurrently_ (maybe [] M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) -> - when (isNothing $ M.lookup connId subs) $ do - subscribeSMPQueue smp rcvPrivateKey rcvId - `catchError` \case - e@SMPResponseTimeout -> throwError e - e@SMPNetworkError -> throwError e - e -> liftIO $ notifySub (ERR $ smpClientError e) connId - addSubscription c rq connId - liftIO $ notifySub UP connId + cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c) + forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) -> + whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $ + subscribe_ smp sub `catchError` handleError connId + where + subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO () + subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do + subscribeSMPQueue smp rcvPrivateKey rcvId + addSubscription c rq connId + liftIO $ notifySub UP connId + + handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO () + handleError connId = \case + e@SMPResponseTimeout -> throwError e + e@SMPNetworkError -> throwError e + e -> do + liftIO $ notifySub (ERR $ smpClientError e) connId + atomically $ removePendingSubscription c srv connId notifySub :: ACommand 'Agent -> ConnId -> IO () notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) @@ -243,10 +253,10 @@ closeAgentClient c = liftIO $ do closeSMPServerClients c cancelActions $ reconnections c cancelActions $ asyncClients c - cancelActions . TM.tVar $ smpQueueMsgDeliveries c + cancelActions $ smpQueueMsgDeliveries c closeSMPServerClients :: AgentClient -> IO () -closeSMPServerClients c = readTVarIO (TM.tVar $ smpClients c) >>= mapM_ (forkIO . closeClient) +closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) where closeClient smpVar = atomically (readTMVar smpVar) >>= \case @@ -331,7 +341,7 @@ newRcvQueue_ a c srv = do subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do - addPendingSubscription c rq connId + atomically $ addPendingSubscription c rq connId withLogSMP c server rcvId "SUB" $ \smp -> do liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case Left e -> do @@ -343,17 +353,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do TM.insert connId server $ subscrConns c - addSubs_ rq connId $ subscrSrvrs c + addSubs_ (subscrSrvrs c) rq connId removePendingSubscription c server connId -addPendingSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () -addPendingSubscription c rq connId = - atomically . addSubs_ rq connId $ pendingSubscrSrvrs c +addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM () +addPendingSubscription = addSubs_ . pendingSubscrSrvrs -addSubs_ :: RcvQueue -> ConnId -> TMap SMPServer (Map ConnId RcvQueue) -> STM () -addSubs_ rq@RcvQueue {server} connId = TM.alter (Just . addSub) server - where - addSub = maybe (M.singleton connId rq) (M.insert connId rq) +addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM () +addSubs_ ss rq@RcvQueue {server} connId = + TM.lookup server ss >>= \case + Just m -> TM.insert connId rq m + _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m () removeSubscription c@AgentClient {subscrConns} connId = atomically $ do @@ -363,13 +373,9 @@ removeSubscription c@AgentClient {subscrConns} connId = atomically $ do removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM () removePendingSubscription = removeSubs_ . pendingSubscrSrvrs -removeSubs_ :: TMap SMPServer (Map ConnId RcvQueue) -> SMPServer -> ConnId -> STM () -removeSubs_ ss server connId = TM.update delSub server ss - where - delSub :: Map ConnId RcvQueue -> Maybe (Map ConnId RcvQueue) - delSub cs = - let cs' = M.delete connId cs - in if M.null cs' then Nothing else Just cs' +removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM () +removeSubs_ ss server connId = + TM.lookup server ss >>= mapM_ (TM.delete connId) logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 102830353..938ba9aa1 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -137,7 +137,7 @@ runClientTransport th@THandle {sessionId} = do clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m () clientDisconnected c@Client {subscriptions, connected} = do atomically $ writeTVar connected False - subs <- readTVarIO $ TM.tVar subscriptions + subs <- readTVarIO subscriptions mapM_ cancelSub subs cs <- asks $ subscribers . server atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 5187c0d73..3c4599a97 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -114,9 +114,9 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} restoreQueues QueueStore {queues, senders, notifiers} s = do (qs, s') <- liftIO $ readWriteStoreLog s atomically $ do - writeTVar (TM.tVar queues) =<< mapM newTVar qs - writeTVar (TM.tVar senders) $ M.foldr' addSender M.empty qs - writeTVar (TM.tVar notifiers) $ M.foldr' addNotifier M.empty qs + writeTVar queues =<< mapM newTVar qs + writeTVar senders $ M.foldr' addSender M.empty qs + writeTVar notifiers $ M.foldr' addNotifier M.empty qs pure s' addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index 012adde4b..a6584903e 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -1,6 +1,7 @@ module Simplex.Messaging.TMap - ( TMap (..), + ( TMap, empty, + singleton, Simplex.Messaging.TMap.lookup, member, insert, @@ -10,6 +11,7 @@ module Simplex.Messaging.TMap adjust, update, alter, + union, ) where @@ -17,44 +19,52 @@ import Control.Concurrent.STM import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -newtype TMap k a = TMap {tVar :: TVar (Map k a)} +type TMap k a = TVar (Map k a) empty :: STM (TMap k a) -empty = TMap <$> newTVar M.empty +empty = newTVar M.empty {-# INLINE empty #-} +singleton :: k -> a -> STM (TMap k a) +singleton k v = newTVar $ M.singleton k v +{-# INLINE singleton #-} + lookup :: Ord k => k -> TMap k a -> STM (Maybe a) -lookup k (TMap m) = M.lookup k <$> readTVar m +lookup k m = M.lookup k <$> readTVar m {-# INLINE lookup #-} member :: Ord k => k -> TMap k a -> STM Bool -member k (TMap m) = M.member k <$> readTVar m +member k m = M.member k <$> readTVar m {-# INLINE member #-} insert :: Ord k => k -> a -> TMap k a -> STM () -insert k v (TMap m) = modifyTVar' m $ M.insert k v +insert k v m = modifyTVar' m $ M.insert k v {-# INLINE insert #-} delete :: Ord k => k -> TMap k a -> STM () -delete k (TMap m) = modifyTVar' m $ M.delete k +delete k m = modifyTVar' m $ M.delete k {-# INLINE delete #-} lookupInsert :: Ord k => k -> a -> TMap k a -> STM (Maybe a) -lookupInsert k v (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv) +lookupInsert k v m = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv) {-# INLINE lookupInsert #-} lookupDelete :: Ord k => k -> TMap k a -> STM (Maybe a) -lookupDelete k (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv) +lookupDelete k m = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv) {-# INLINE lookupDelete #-} adjust :: Ord k => (a -> a) -> k -> TMap k a -> STM () -adjust f k (TMap m) = modifyTVar' m $ M.adjust f k +adjust f k m = modifyTVar' m $ M.adjust f k {-# INLINE adjust #-} update :: Ord k => (a -> Maybe a) -> k -> TMap k a -> STM () -update f k (TMap m) = modifyTVar' m $ M.update f k +update f k m = modifyTVar' m $ M.update f k {-# INLINE update #-} alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM () -alter f k (TMap m) = modifyTVar' m $ M.alter f k +alter f k m = modifyTVar' m $ M.alter f k {-# INLINE alter #-} + +union :: Ord k => Map k a -> TMap k a -> STM () +union m' m = modifyTVar' m $ M.union m' +{-# INLINE union #-}