diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 44c843074..6e5112cef 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -72,7 +72,7 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, ProtocolServer (..), Qu import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, ifM, liftEitherError, liftError, tryError) +import Simplex.Messaging.Util (bshow, ifM, liftEitherError, liftError, tryError, unlessM, whenM) import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, forConcurrently) @@ -86,7 +86,8 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient) type NtfClientVar = TMVar (Either AgentErrorType NtfClient) data AgentClient = AgentClient - { rcvQ :: TBQueue (ATransmission 'Client), + { active :: TVar Bool, + rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue (ServerTransmission BrokerMsg), smpServers :: TVar (NonEmpty SMPServer), @@ -110,6 +111,7 @@ data AgentClient = AgentClient newAgentClient :: InitialAgentServers -> Env -> STM AgentClient newAgentClient InitialAgentServers {smp, ntf} agentEnv = do let qSize = tbqSize $ config agentEnv + active <- newTVar True rcvQ <- newTBQueue qSize subQ <- newTBQueue qSize msgQ <- newTBQueue qSize @@ -127,7 +129,7 @@ newAgentClient InitialAgentServers {smp, ntf} agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) lock <- newTMVar () - return AgentClient {rcvQ, subQ, msgQ, smpServers, ntfServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, ntfServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} -- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) @@ -145,7 +147,8 @@ instance ProtocolServerClient NtfResponse where protocolError = NTF getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient -getSMPServerClient c@AgentClient {smpClients, msgQ} srv = +getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do + unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" atomically (getClientVar srv smpClients) >>= either (newProtocolClient c srv smpClients connectClient reconnectClient) @@ -183,7 +186,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = serverDown u cs = unless (M.null cs) $ do let conns = M.keys cs unless (null conns) . notifySub "" $ DOWN srv conns - unliftIO u reconnectServer + whenM (readTVarIO active) $ unliftIO u reconnectServer reconnectServer :: m () reconnectServer = do @@ -226,7 +229,8 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient -getNtfServerClient c@AgentClient {ntfClients} srv = +getNtfServerClient c@AgentClient {active, ntfClients} srv = do + unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" atomically (getClientVar srv ntfClients) >>= either (newProtocolClient c srv ntfClients connectClient $ pure ()) @@ -297,6 +301,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon closeAgentClient :: MonadIO m => AgentClient -> m () closeAgentClient c = liftIO $ do + atomically $ writeTVar (active c) False closeSMPServerClients c cancelActions $ reconnections c cancelActions $ asyncClients c