diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 176d33403..755734fde 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -128,11 +128,10 @@ import Control.Monad.Reader import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J -import Data.Bifunctor (bimap, first) +import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) -import Data.Containers.ListUtils (nubOrd) import Data.Either (isRight, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) @@ -1134,37 +1133,39 @@ sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do enqueueCommand :: AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> AM () enqueueCommand c corrId connId server aCommand = do withStore c $ \db -> createCommand db corrId connId server aCommand - lift . void $ getAsyncCmdWorker True c server + lift . void $ getAsyncCmdWorker True c connId server -resumeSrvCmds :: AgentClient -> Maybe SMPServer -> AM' () -resumeSrvCmds = void .: getAsyncCmdWorker False +resumeSrvCmds :: AgentClient -> ConnId -> Maybe SMPServer -> AM' () +resumeSrvCmds = void .:. getAsyncCmdWorker False {-# INLINE resumeSrvCmds #-} resumeConnCmds :: AgentClient -> [ConnId] -> AM' () resumeConnCmds c connIds = do - srvs <- nubOrd . concat . rights <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) - mapM_ (resumeSrvCmds c) srvs + connSrvs <- rights . zipWith (second . (,)) connIds <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) + mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs -getAsyncCmdWorker :: Bool -> AgentClient -> Maybe SMPServer -> AM' Worker -getAsyncCmdWorker hasWork c server = - getAgentWorker "async_cmd" hasWork c server (asyncCmdWorkers c) (runCommandProcessing c server) +getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker +getAsyncCmdWorker hasWork c connId server = + getAgentWorker "async_cmd" hasWork c (connId, server) (asyncCmdWorkers c) (runCommandProcessing c connId server) -runCommandProcessing :: AgentClient -> Maybe SMPServer -> Worker -> AM () -runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do +data CommandCompletion = CCMoved | CCCompleted + +runCommandProcessing :: AgentClient -> ConnId -> Maybe SMPServer -> Worker -> AM () +runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do ri <- asks $ messageRetryInterval . config -- different retry interval? forever $ do atomically $ endAgentOperation c AOSndNetwork lift $ waitForWork doWork liftIO $ throwWhenInactive c atomically $ beginAgentOperation c AOSndNetwork - withWork c doWork (`getPendingServerCommand` server_) $ runProcessCmd (riFast ri) + withWork c doWork (\db -> getPendingServerCommand db connId server_) $ runProcessCmd (riFast ri) where runProcessCmd ri cmd = do pending <- newTVarIO [] processCmd ri cmd pending mapM_ (atomically . writeTBQueue subQ) . reverse =<< readTVarIO pending processCmd :: RetryInterval -> PendingCommand -> TVar [ATransmission] -> AM () - processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} pendingCmds = case command of + processCmd ri PendingCommand {cmdId, corrId, userId, command} pendingCmds = case command of AClientCommand cmd -> case cmd of NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) @@ -1190,16 +1191,27 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do AInternalCommand cmd -> case cmd of ICAckDel rId srvMsgId msgId -> withServer $ \srv -> tryWithLock "ICAckDel" $ ack srv rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId) ICAck rId srvMsgId -> withServer $ \srv -> tryWithLock "ICAck" $ ack srv rId srvMsgId - ICAllowSecure _rId senderKey -> withServer' . tryWithLock "ICAllowSecure" $ do + ICAllowSecure _rId senderKey -> withServer' . tryMoveableWithLock "ICAllowSecure" $ do (SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <- withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId) case conn of RcvConnection cData rq -> do mapM_ (secure rq) senderKey mapM_ (connectReplyQueues c cData ownConnInfo Nothing) (L.nonEmpty $ smpReplyQueues senderConf) + pure CCCompleted -- duplex connection is matched to handle SKEY retries - DuplexConnection cData _ (sq :| _) -> - mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf) + DuplexConnection cData _ (sq :| _) -> do + tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case + Right () -> pure CCCompleted + Left e + | temporaryOrHostError e && Just server /= server_ -> do + -- In case the server is different we update server to remove command from this (connId, srv) queue + withStore c $ \db -> updateCommandServer db cmdId server + lift . void $ getAsyncCmdWorker True c connId (Just server) + pure CCMoved + | otherwise -> throwE e + where + server = qServer sq _ -> throwE $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do secure rq senderKey @@ -1272,15 +1284,18 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do withStore c (`getConn` connId) >>= \case SomeConn _ conn@DuplexConnection {} -> a conn _ -> internalErr "command requires duplex connection" - tryCommand action = withRetryInterval ri $ \_ loop -> do + tryCommand action = tryMoveableCommand (action $> CCCompleted) + tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - tryError action >>= \case + tryAgentError action >>= \case Left e | temporaryOrHostError e -> retrySndOp c loop | otherwise -> cmdError e - Right () -> withStore' c (`deleteCommand` cmdId) + Right CCCompleted -> withStore' c (`deleteCommand` cmdId) + Right CCMoved -> pure () -- command processing moved to another command queue tryWithLock name = tryCommand . withConnLock c connId name + tryMoveableWithLock name = tryMoveableCommand . withConnLock c connId name internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command) cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) notify :: forall e. AEntityI e => AEvent e -> AM () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index fbdb53548..8f339cab7 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -313,7 +313,7 @@ data AgentClient = AgentClient removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError, workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), - asyncCmdWorkers :: TMap (Maybe SMPServer) Worker, + asyncCmdWorkers :: TMap (ConnId, Maybe SMPServer) Worker, ntfNetworkOp :: TVar AgentOpState, rcvNetworkOp :: TVar AgentOpState, msgDeliveryOp :: TVar AgentOpState, diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 69b1b07e9..97b32eca8 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -137,6 +137,7 @@ module Simplex.Messaging.Agent.Store.SQLite createCommand, getPendingCommandServers, getPendingServerCommand, + updateCommandServer, deleteCommand, -- Notification device token persistence createNtfToken, @@ -1323,38 +1324,39 @@ getPendingCommandServers db connId = do where smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash -getPendingServerCommand :: DB.Connection -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) -getPendingServerCommand db srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed +getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) +getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed where getCmdId :: IO (Maybe Int64) getCmdId = maybeFirstRow fromOnly $ case srv_ of Nothing -> - DB.query_ + DB.query db [sql| SELECT command_id FROM commands - WHERE host IS NULL AND port IS NULL AND failed = 0 + WHERE conn_id = ? AND host IS NULL AND port IS NULL AND failed = 0 ORDER BY created_at ASC, command_id ASC LIMIT 1 |] + (Only connId) Just (SMPServer host port _) -> DB.query db [sql| SELECT command_id FROM commands - WHERE host = ? AND port = ? AND failed = 0 + WHERE conn_id = ? AND host = ? AND port = ? AND failed = 0 ORDER BY created_at ASC, command_id ASC LIMIT 1 |] - (host, port) + (connId, host, port) getCommand :: Int64 -> IO (Either StoreError PendingCommand) getCommand cmdId = firstRow pendingCommand err $ DB.query db [sql| - SELECT c.corr_id, cs.user_id, c.conn_id, c.command + SELECT c.corr_id, cs.user_id, c.command FROM commands c JOIN connections cs USING (conn_id) WHERE c.command_id = ? @@ -1362,9 +1364,22 @@ getPendingServerCommand db srv_ = getWorkItem "command" getCmdId getCommand mark (Only cmdId) where err = SEInternal $ "command " <> bshow cmdId <> " returned []" - pendingCommand (corrId, userId, connId, command) = PendingCommand {cmdId, corrId, userId, connId, command} + pendingCommand (corrId, userId, command) = PendingCommand {cmdId, corrId, userId, connId, command} markCommandFailed cmdId = DB.execute db "UPDATE commands SET failed = 1 WHERE command_id = ?" (Only cmdId) +updateCommandServer :: DB.Connection -> AsyncCmdId -> SMPServer -> IO (Either StoreError ()) +updateCommandServer db cmdId srv@(SMPServer host port _) = runExceptT $ do + serverKeyHash_ <- ExceptT $ getServerKeyHash_ db srv + liftIO $ + DB.execute + db + [sql| + UPDATE commands + SET host = ?, port = ?, server_key_hash = ? + WHERE command_id = ? + |] + (host, port, serverKeyHash_, cmdId) + deleteCommand :: DB.Connection -> AsyncCmdId -> IO () deleteCommand db cmdId = DB.execute db "DELETE FROM commands WHERE command_id = ?" (Only cmdId) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index f0a25b758..29725f96a 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1063,13 +1063,12 @@ testAllowConnectionClientRestart t = do threadDelay 250000 alice2 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB + runRight_ $ subscribeConnection alice2 bobId + threadDelay 500000 withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \_ -> do runRight $ do ("", "", UP _ _) <- nGet bob - - subscribeConnection alice2 bobId - get alice2 ##> ("", bobId, CON) get bob ##> ("", aliceId, INFO "alice's connInfo") get bob ##> ("", aliceId, CON) diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 95096e800..f876603f5 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -661,30 +661,30 @@ testGetPendingServerCommand :: SQLiteStore -> Expectation testGetPendingServerCommand st = do g <- C.newRandom withTransaction st $ \db -> do - Right Nothing <- getPendingServerCommand db Nothing + Right Nothing <- getPendingServerCommand db "" Nothing Right connId <- createNewConn db g cData1 {connId = ""} SCMInvitation Right () <- createCommand db "1" connId Nothing command corruptCmd db "1" connId Right () <- createCommand db "2" connId Nothing command - Left e <- getPendingServerCommand db Nothing + Left e <- getPendingServerCommand db connId Nothing show e `shouldContain` "bad AgentCmdType" DB.query_ db "SELECT conn_id, corr_id FROM commands WHERE failed = 1" `shouldReturn` [(connId, "1" :: ByteString)] - Right (Just PendingCommand {corrId}) <- getPendingServerCommand db Nothing + Right (Just PendingCommand {corrId}) <- getPendingServerCommand db connId Nothing corrId `shouldBe` "2" Right _ <- updateNewConnRcv db connId rcvQueue1 - Right Nothing <- getPendingServerCommand db $ Just smpServer1 + Right Nothing <- getPendingServerCommand db connId $ Just smpServer1 Right () <- createCommand db "3" connId (Just smpServer1) command corruptCmd db "3" connId Right () <- createCommand db "4" connId (Just smpServer1) command - Left e' <- getPendingServerCommand db (Just smpServer1) + Left e' <- getPendingServerCommand db connId (Just smpServer1) show e' `shouldContain` "bad AgentCmdType" DB.query_ db "SELECT conn_id, corr_id FROM commands WHERE failed = 1" `shouldReturn` [(connId, "1" :: ByteString), (connId, "3" :: ByteString)] - Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1) + Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db connId (Just smpServer1) corrId' `shouldBe` "4" where command = AClientCommand $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe