diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 65a353696..2e0107d74 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -42,6 +42,7 @@ module Simplex.Messaging.Agent joinConnectionAsync, allowConnectionAsync, ackMessageAsync, + deleteConnectionAsync, createConnection, joinConnection, allowConnection, @@ -76,7 +77,7 @@ module Simplex.Messaging.Agent ) where -import Control.Concurrent.STM (flushTBQueue, stateTVar) +import Control.Concurrent.STM (stateTVar) import Control.Logger.Simple (logInfo, showText) import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) @@ -163,6 +164,10 @@ allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m () ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c +-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response +deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m () +deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c + -- | Create SMP agent connection (NEW command) createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) createConnection c enableNtfs cMode = withAgentEnv c $ newConn c "" False enableNtfs cMode @@ -361,9 +366,24 @@ ackMessageAsync' c corrId connId msgId = SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED where enqueueAck :: RcvQueue -> m () - enqueueAck RcvQueue {server} = do + enqueueAck RcvQueue {server} = enqueueCommand c corrId connId (Just server) $ AClientCommand $ ACK msgId +deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m () +deleteConnectionAsync' c@AgentClient {subQ} corrId connId = + withStore c (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ rq _) -> enqueueDelete rq + SomeConn _ (RcvConnection _ rq) -> enqueueDelete rq + SomeConn _ (ContactConnection _ rq) -> enqueueDelete rq + SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId) >> notifyDeleted + SomeConn _ (NewConnection _) -> withStore' c (`deleteConn` connId) >> notifyDeleted + where + enqueueDelete :: RcvQueue -> m () + enqueueDelete RcvQueue {server} = + enqueueCommand c corrId connId (Just server) $ AClientCommand DEL + notifyDeleted :: m () + notifyDeleted = atomically $ writeTBQueue subQ (corrId, connId, OK) + newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) newConn c connId asyncMode enableNtfs cMode = getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode @@ -706,6 +726,7 @@ runCommandProcessing c@AgentClient {subQ} server = do notify OK LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK + DEL -> tryCommand $ deleteConnection' c connId >> notify OK _ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd) AInternalCommand cmd -> case server of Just _srv -> case cmd of diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index b21c00493..fe4414d8d 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -429,7 +429,7 @@ testActiveClientNotDisconnected t = do where keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO () keepSubscribing alice connId ts = do - ts' <- liftIO $ getSystemTime + ts' <- liftIO getSystemTime if milliseconds ts' - milliseconds ts < 2200 then do -- keep sending SUB for 2.2 seconds @@ -603,7 +603,9 @@ testAsyncCommands = do get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False ackMessageAsync alice "7" bobId $ baseId + 4 ("7", _, OK) <- get alice - pure () + deleteConnectionAsync alice "8" bobId + ("8", _, OK) <- get alice + liftIO $ noMessages alice "nothing else should be delivered to alice" pure () where baseId = 3