From ebb75ced121353aa4ec2b64cf1582af86015f0e2 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin Date: Sat, 13 Apr 2024 18:33:12 +0100 Subject: [PATCH] extract SessionVar from AgentClient to reuse (#1099) --- simplexmq.cabal | 1 + src/Simplex/Messaging/Agent/Client.hs | 68 ++++++++---------------- src/Simplex/Messaging/Client/Agent.hs | 76 +++++++++++++++------------ src/Simplex/Messaging/Session.hs | 38 ++++++++++++++ 4 files changed, 104 insertions(+), 79 deletions(-) create mode 100644 src/Simplex/Messaging/Session.hs diff --git a/simplexmq.cabal b/simplexmq.cabal index 7c6433727..6d04f9775 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -147,6 +147,7 @@ library Simplex.Messaging.Server.Stats Simplex.Messaging.Server.StoreLog Simplex.Messaging.ServiceScheme + Simplex.Messaging.Session Simplex.Messaging.TMap Simplex.Messaging.Transport Simplex.Messaging.Transport.Buffer diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 7041f10c6..4fbcae425 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -152,7 +152,6 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Composition ((.:.)) import Data.Either (lefts, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) @@ -227,6 +226,7 @@ import Simplex.Messaging.Protocol sameSrvAddr', ) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion) @@ -241,11 +241,6 @@ import UnliftIO.Directory (doesFileExist, getTemporaryDirectory, removeFile) import qualified UnliftIO.Exception as E import UnliftIO.STM -data SessionVar a = SessionVar - { sessionVar :: TMVar a, - sessionVarId :: Int - } - type ClientVar msg = SessionVar (Either AgentErrorType (Client msg)) type SMPClientVar = ClientVar SMP.BrokerMsg @@ -550,9 +545,9 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where clientSessionTs = X.xftpSessionTs getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPClient -getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do +getSMPServerClient c@AgentClient {active, smpClients, msgQ, workerSeq} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE - atomically (getTSessVar c tSess smpClients) + atomically (getSessVar workerSeq tSess smpClients) >>= either newClient (waitForProtocolClient c tSess) where -- we resubscribe only on newClient error, but not on waitForProtocolClient error, @@ -581,7 +576,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], []) where - currentActiveClient = (&&) <$> removeTSessVar' v tSess smpClients <*> readTVar active + currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active removeSubs = do (qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c RQ.batchAddQueues (pendingSubs c) qs @@ -600,14 +595,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' () -resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess = +resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = atomically getWorkerVar >>= mapM_ (either newSubWorker (\_ -> pure ())) where getWorkerVar = ifM (null <$> getPending) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call - (Just <$> getTSessVar c tSess smpSubWorkers) + (Just <$> getSessVar workerSeq tSess smpSubWorkers) newSubWorker v = do a <- async $ void (E.tryAny runSubWorker) >> atomically (cleanup v) atomically $ putTMVar (sessionVar v) a @@ -626,7 +621,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess = -- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar. -- Not waiting may result in terminated worker remaining in the map. whenM (isEmptyTMVar $ sessionVar v) retry - removeTSessVar v tSess smpSubWorkers + removeSessVar v tSess smpSubWorkers reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM () reconnectSMPClient tc c tSess@(_, srv, _) qs = do @@ -660,9 +655,9 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient -getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do +getNtfServerClient c@AgentClient {active, ntfClients, workerSeq} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE - atomically (getTSessVar c tSess ntfClients) + atomically (getSessVar workerSeq tSess ntfClients) >>= either (newProtocolClient c tSess ntfClients connectClient) (waitForProtocolClient c tSess) @@ -677,15 +672,15 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d clientDisconnected :: NtfClientVar -> NtfClient -> IO () clientDisconnected v client = do - atomically $ removeTSessVar v tSess ntfClients + atomically $ removeSessVar v tSess ntfClients incClientStat c userId client "DISCONNECT" "" atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient -getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) = do +getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INACTIVE - atomically (getTSessVar c tSess xftpClients) + atomically (getSessVar workerSeq tSess xftpClients) >>= either (newProtocolClient c tSess xftpClients connectClient) (waitForProtocolClient c tSess) @@ -701,32 +696,11 @@ getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) = clientDisconnected :: XFTPClientVar -> XFTPClient -> IO () clientDisconnected v client = do - atomically $ removeTSessVar v tSess xftpClients + atomically $ removeSessVar v tSess xftpClients incClientStat c userId client "DISCONNECT" "" atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv -getTSessVar :: forall a s. AgentClient -> TransportSession s -> TMap (TransportSession s) (SessionVar a) -> STM (Either (SessionVar a) (SessionVar a)) -getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lookup tSess vs - where - newSessionVar :: STM (SessionVar a) - newSessionVar = do - sessionVar <- newEmptyTMVar - sessionVarId <- stateTVar (workerSeq c) $ \next -> (next, next + 1) - let v = SessionVar {sessionVar, sessionVarId} - TM.insert tSess v vs - pure v - -removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM () -removeTSessVar = void .:. removeTSessVar' -{-# INLINE removeTSessVar #-} - -removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool -removeTSessVar' v tSess vs = - TM.lookup tSess vs >>= \case - Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True - _ -> pure False - waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg) waitForProtocolClient c (_, srv, _) v = do NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c @@ -757,7 +731,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = Left e -> do liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e atomically $ do - removeTSessVar v tSess clients + removeSessVar v tSess clients putTMVar (sessionVar v) (Left e) throwError e -- signal error to caller @@ -781,10 +755,11 @@ getNetworkConfig c = do waitForUserNetwork :: AgentClient -> AM' () waitForUserNetwork AgentClient {userNetworkState} = - (offline <$> readTVarIO userNetworkState) >>= mapM_ waitWhileOffline + readTVarIO userNetworkState >>= mapM_ waitWhileOffline . offline where waitWhileOffline UNSOffline {offlineDelay = d} = - unlessM (liftIO $ waitOnline d False) $ do -- network delay reached, increase delay + unlessM (liftIO $ waitOnline d False) $ do + -- network delay reached, increase delay ts' <- liftIO getCurrentTime ni <- asks $ userNetworkInterval . config atomically $ do @@ -794,7 +769,7 @@ waitForUserNetwork AgentClient {userNetworkState} = -- and to reset `offlineDelay` if network went `on` and `off` again. writeTVar userNetworkState $! let d'' = nextRetryDelay (diffToMicroseconds $ diffUTCTime ts' ts) (min d d') ni - in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}} + in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}} waitOnline :: Int64 -> Bool -> IO Bool waitOnline t online' | t <= 0 = pure online' @@ -866,9 +841,10 @@ closeClient c clientSel tSess = closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c - E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case - Just (Right client) -> closeProtocolServerClient client `catchAll_` pure () - _ -> pure () + E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ + tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case + Just (Right client) -> closeProtocolServerClient client `catchAll_` pure () + _ -> pure () closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO () closeXFTPServerClient c userId server (FileDigest chunkDigest) = diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 8e21aada1..4b925c6f6 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -40,6 +40,7 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer) +import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport @@ -50,7 +51,7 @@ import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM -type SMPClientVar = TMVar (Either SMPClientError SMPClient) +type SMPClientVar = SessionVar (Either SMPClientError SMPClient) data SMPClientAgentEvent = CAConnected SMPServer @@ -100,7 +101,8 @@ data SMPClientAgent = SMPClientAgent srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey), reconnections :: TVar [Async ()], - asyncClients :: TVar [Async ()] + asyncClients :: TVar [Async ()], + workerSeq :: TVar Int } newtype InternalException e = InternalException {unInternalException :: e} @@ -115,9 +117,10 @@ instance Exception e => MonadUnliftIO (ExceptT e IO) where ExceptT . fmap (first unInternalException) . E.try $ withRunInIO $ \run -> inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT) - -- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`, - -- the last two lines could be replaced with: - -- inner $ either (E.throwIO . InternalException) pure <=< runExceptT + +-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`, +-- the last two lines could be replaced with: +-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where {-# INLINE withRunInIO #-} @@ -136,46 +139,53 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg pendingSrvSubs <- TM.empty reconnections <- newTVar [] asyncClients <- newTVar [] - pure SMPClientAgent {agentCfg, msgQ, agentQ, randomDrg, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients} + workerSeq <- newTVar 0 + pure + SMPClientAgent + { agentCfg, + msgQ, + agentQ, + randomDrg, + smpClients, + srvSubs, + pendingSrvSubs, + reconnections, + asyncClients, + workerSeq + } getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient -getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} srv = +getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv = atomically getClientVar >>= either newSMPClient waitForSMPClient where getClientVar :: STM (Either SMPClientVar SMPClientVar) - getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients - - newClientVar :: STM SMPClientVar - newClientVar = do - smpVar <- newEmptyTMVar - TM.insert srv smpVar smpClients - pure smpVar + getClientVar = getSessVar workerSeq srv smpClients waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient - waitForSMPClient smpVar = do + waitForSMPClient v = do let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg - smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar) + smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) liftEither $ case smpClient_ of Just (Right smpClient) -> Right smpClient Just (Left e) -> Left e Nothing -> Left PCEResponseTimeout newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient - newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync) + newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync) where tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a tryConnectClient successAction retryAction = - tryE connectClient >>= \r -> case r of + tryE (connectClient v) >>= \r -> case r of Right smp -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - atomically $ putTMVar smpVar r + atomically $ putTMVar (sessionVar v) r successAction smp Left e -> do if e == PCENetworkError || e == PCEResponseTimeout then retryAction else atomically $ do - putTMVar smpVar (Left e) - TM.delete srv smpClients + putTMVar (sessionVar v) (Left e) + removeSessVar v srv smpClients throwE e tryConnectAsync :: IO () tryConnectAsync = do @@ -186,17 +196,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> void $ tryConnectClient (const reconnectClient) loop - connectClient :: ExceptT SMPClientError IO SMPClient - connectClient = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected + connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient + connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v) - clientDisconnected :: SMPClient -> IO () - clientDisconnected _ = do - removeClientAndSubs >>= (`forM_` serverDown) + clientDisconnected :: SMPClientVar -> SMPClient -> IO () + clientDisconnected v _ = do + removeClientAndSubs v >>= (`forM_` serverDown) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateAuthKey)) - removeClientAndSubs = atomically $ do - TM.delete srv smpClients + removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey)) + removeClientAndSubs v = atomically $ do + removeSessVar v srv smpClients TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs where updateSubs sVar = do @@ -207,7 +217,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr addPendingSubs sVar ss = do let ps = pendingSrvSubs ca TM.lookup srv ps >>= \case - Just v -> TM.union ss v + Just ss' -> TM.union ss ss' _ -> TM.insert srv sVar ps serverDown :: Map SMPSub C.APrivateAuthKey -> IO () @@ -268,10 +278,10 @@ closeSMPClientAgent c = do cancelActions $ asyncClients c closeSMPServerClients :: SMPClientAgent -> IO () -closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) +closeSMPServerClients c = atomically (smpClients c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient) where - closeClient smpVar = - atomically (readTMVar smpVar) >>= \case + closeClient v = + atomically (readTMVar $ sessionVar v) >>= \case Right smp -> closeProtocolClient smp `catchAll_` pure () _ -> pure () diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs new file mode 100644 index 000000000..7a219e106 --- /dev/null +++ b/src/Simplex/Messaging/Session.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Session where + +import Control.Concurrent.STM +import Control.Monad +import Data.Composition ((.:.)) +import Data.Functor (($>)) +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM + +data SessionVar a = SessionVar + { sessionVar :: TMVar a, + sessionVarId :: Int + } + +getSessVar :: forall k a. Ord k => TVar Int -> k -> TMap k (SessionVar a) -> STM (Either (SessionVar a) (SessionVar a)) +getSessVar sessSeq sessKey vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lookup sessKey vs + where + newSessionVar :: STM (SessionVar a) + newSessionVar = do + sessionVar <- newEmptyTMVar + sessionVarId <- stateTVar sessSeq $ \next -> (next, next + 1) + let v = SessionVar {sessionVar, sessionVarId} + TM.insert sessKey v vs + pure v + +removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM () +removeSessVar = void .:. removeSessVar' +{-# INLINE removeSessVar #-} + +removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool +removeSessVar' v sessKey vs = + TM.lookup sessKey vs >>= \case + Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True + _ -> pure False