diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 48c5c83bb..18e290bfd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -684,6 +684,7 @@ runCommandProcessing c@AgentClient {subQ} server = do ri <- asks $ messageRetryInterval . config -- different retry interval? forever $ do atomically $ endAgentOperation c AOSndNetwork + atomically $ throwWhenInactive c cmdId <- atomically $ readTQueue cq atomically $ beginAgentOperation c AOSndNetwork E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case @@ -716,7 +717,9 @@ runCommandProcessing c@AgentClient {subQ} server = do retryCommand loop = do -- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent atomically $ endAgentOperation c AOSndNetwork - atomically $ beginAgentOperation c AOSndNetwork + atomically $ do + throwWhenInactive c + beginAgentOperation c AOSndNetwork loop notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd) withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m () @@ -789,6 +792,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh ri <- asks $ messageRetryInterval . config forever $ do atomically $ endAgentOperation c AOSndNetwork + atomically $ throwWhenInactive c msgId <- atomically $ readTQueue mq atomically $ do beginAgentOperation c AOSndNetwork @@ -883,7 +887,9 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh retrySending loop = do -- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent atomically $ endAgentOperation c AOSndNetwork - atomically $ beginAgentOperation c AOSndNetwork + atomically $ do + throwWhenInactive c + beginAgentOperation c AOSndNetwork loop ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> m () @@ -1199,7 +1205,7 @@ getNextSMPServer c usedSrvs = do subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ - agentOperationBracket c AORcvNetwork $ + agentOperationBracket c AORcvNetwork waitUntilActive $ withAgentLock c (runExceptT $ processSMPTransmission c t) >>= \case Left e -> liftIO $ print e Right _ -> return () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9a9b0e1c0..d20a344a9 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -58,6 +58,8 @@ module Simplex.Messaging.Agent.Client AgentState (..), agentOperations, agentOperationBracket, + waitUntilActive, + throwWhenInactive, beginAgentOperation, endAgentOperation, suspendSendingAndDatabase, @@ -72,7 +74,8 @@ where import Control.Concurrent (forkIO) import Control.Concurrent.Async (Async, uninterruptibleCancel) -import Control.Concurrent.STM (retry, stateTVar) +import Control.Concurrent.STM (retry, stateTVar, throwSTM) +import Control.Exception (AsyncException (..)) import Control.Logger.Simple import Control.Monad.Except import Control.Monad.IO.Unlift @@ -417,6 +420,12 @@ closeAgentClient c = liftIO $ do clear :: Monoid m => (AgentClient -> TVar m) -> IO () clear sel = atomically $ writeTVar (sel c) mempty +waitUntilActive :: AgentClient -> STM () +waitUntilActive c = unlessM (readTVar $ active c) retry + +throwWhenInactive :: AgentClient -> STM () +throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled + closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty) @@ -805,10 +814,10 @@ beginAgentOperation c op = do -- unsafeIOToSTM $ putStrLn $ "beginOperation! " <> show op <> " " <> show (opsInProgress s + 1) writeTVar opVar $! s {opsInProgress = opsInProgress s + 1} -agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> m a -> m a -agentOperationBracket c op action = +agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> (AgentClient -> STM ()) -> m a -> m a +agentOperationBracket c op check action = E.bracket - (atomically $ beginAgentOperation c op) + (atomically $ check c >> beginAgentOperation c op) (\_ -> atomically $ endAgentOperation c op) (const action) @@ -818,7 +827,7 @@ withStore' c action = withStore c $ fmap Right . action withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a withStore c action = do st <- asks store - liftEitherError storeError . agentOperationBracket c AODatabase $ + liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ withTransaction st action `E.catch` handleInternal where handleInternal :: E.SomeException -> IO (Either StoreError a) diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 0a45e38a3..57385984a 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -55,7 +55,7 @@ runNtfSupervisor c = do ns <- asks ntfSupervisor forever $ do cmd@(connId, _) <- atomically . readTBQueue $ ntfSubQ ns - handleError connId . agentOperationBracket c AONtfNetwork $ + handleError connId . agentOperationBracket c AONtfNetwork waitUntilActive $ runExceptT (processNtfSub c cmd) >>= \case Left e -> notifyErr connId e Right _ -> return () @@ -162,7 +162,7 @@ runNtfWorker c srv doWork = do delay <- asks $ ntfWorkerDelay . config forever $ do void . atomically $ readTMVar doWork - agentOperationBracket c AONtfNetwork runNtfOperation + agentOperationBracket c AONtfNetwork throwWhenInactive runNtfOperation threadDelay delay where runNtfOperation :: m () @@ -238,7 +238,7 @@ runNtfSMPWorker c srv doWork = do delay <- asks $ ntfSMPWorkerDelay . config forever $ do void . atomically $ readTMVar doWork - agentOperationBracket c AONtfNetwork runNtfSMPOperation + agentOperationBracket c AONtfNetwork throwWhenInactive runNtfSMPOperation threadDelay delay where runNtfSMPOperation = do