diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index d66d2184e..e599d8074 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -122,31 +122,34 @@ type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m) -- | Create SMP agent connection (NEW command) createConnection :: AgentErrorMonad m => AgentClient -> m (ConnId, SMPQueueInfo) -createConnection c = (`runReaderT` agentEnv c) $ newConn c "" +createConnection c = withAgentClient c $ newConn c "" -- | Join SMP agent connection (JOIN command) joinConnection :: AgentErrorMonad m => AgentClient -> SMPQueueInfo -> ConnInfo -> m ConnId -joinConnection c = (`runReaderT` agentEnv c) .: joinConn c "" +joinConnection c = withAgentClient c .: joinConn c "" -- | Approve confirmation (LET command) acceptConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () -acceptConnection c = (`runReaderT` agentEnv c) .:. acceptConnection' c +acceptConnection c = withAgentClient c .:. acceptConnection' c -- | Subscribe to receive connection messages (SUB command) subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () -subscribeConnection c = (`runReaderT` agentEnv c) . subscribeConnection' c +subscribeConnection c = withAgentClient c . subscribeConnection' c -- | Send message to the connection (SEND command) sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgBody -> m InternalId -sendMessage c = (`runReaderT` agentEnv c) .: sendMessage' c +sendMessage c = withAgentClient c .: sendMessage' c -- | Suspend SMP agent connection (OFF command) suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () -suspendConnection c = (`runReaderT` agentEnv c) . suspendConnection' c +suspendConnection c = withAgentClient c . suspendConnection' c -- | Delete SMP agent connection (DEL command) deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () -deleteConnection c = (`runReaderT` agentEnv c) . deleteConnection' c +deleteConnection c = withAgentClient c . deleteConnection' c + +withAgentClient :: AgentErrorMonad m => AgentClient -> ReaderT Env m a -> m a +withAgentClient c = withAgentLock c . (`runReaderT` agentEnv c) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient @@ -186,10 +189,16 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> logClient AgentClient {clientId} dir (corrId, connId, cmd) = do logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd] +withAgentLock :: MonadUnliftIO m => AgentClient -> m a -> m a +withAgentLock AgentClient {lock} = + E.bracket_ + (void . atomically $ takeTMVar lock) + (atomically $ putTMVar lock ()) + client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () client c@AgentClient {rcvQ, subQ} = forever $ do (corrId, connId, cmd) <- atomically $ readTBQueue rcvQ - runExceptT (processCommand c (connId, cmd)) + withAgentLock c (runExceptT $ processCommand c (connId, cmd)) >>= atomically . writeTBQueue subQ . \case Left e -> (corrId, connId, ERR e) Right (connId', resp) -> (corrId, connId', resp) @@ -380,7 +389,7 @@ sendControlMessage c sq agentMessage = do subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ - runExceptT (processSMPTransmission c t) >>= \case + withAgentLock c (runExceptT $ processSMPTransmission c t) >>= \case Left e -> liftIO $ print e Right _ -> return () @@ -467,17 +476,14 @@ processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () agentClientMsg externalPrevSndHash sender broker msgBody internalHash = do logServer "<--" c srv rId "MSG " - case status of - Active -> do - internalTs <- liftIO getCurrentTime - (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore (`updateRcvIds` connId) - let integrity = checkMsgIntegrity prevExtSndId (fst sender) prevRcvMsgHash externalPrevSndHash - recipient = (unId internalId, internalTs) - msgMeta = MsgMeta {integrity, recipient, sender, broker} - rcvMsg = RcvMsgData {msgMeta, msgBody, internalRcvId, internalHash, externalPrevSndHash} - withStore $ \st -> createRcvMsg st connId rcvMsg - notify $ MSG msgMeta msgBody - _ -> prohibited + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore (`updateRcvIds` connId) + let integrity = checkMsgIntegrity prevExtSndId (fst sender) prevRcvMsgHash externalPrevSndHash + recipient = (unId internalId, internalTs) + msgMeta = MsgMeta {integrity, recipient, sender, broker} + rcvMsg = RcvMsgData {msgMeta, msgBody, internalRcvId, internalHash, externalPrevSndHash} + withStore $ \st -> createRcvMsg st connId rcvMsg + notify $ MSG msgMeta msgBody checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9ffbd44e0..3f72ebc21 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -73,7 +73,8 @@ data AgentClient = AgentClient activations :: TVar (Map ConnId (Async ())), -- activations of send queues in progress clientId :: Int, agentEnv :: Env, - smpSubscriber :: Async () + smpSubscriber :: Async (), + lock :: TMVar () } newAgentClient :: Env -> STM AgentClient @@ -87,7 +88,8 @@ newAgentClient agentEnv = do subscrConns <- newTVar M.empty activations <- newTVar M.empty clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) - return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, activations, clientId, agentEnv, smpSubscriber = undefined} + lock <- newTMVar () + return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, activations, clientId, agentEnv, smpSubscriber = undefined, lock} -- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index f6d2e2fd9..714501a0c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -186,7 +186,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto getAllConnIds :: SQLiteStore -> m [ConnId] getAllConnIds st = - liftIO . withConnection st $ \db -> + liftIO . withTransaction st $ \db -> concat <$> (DB.query_ db "SELECT conn_alias FROM connections;" :: IO [[ConnId]]) getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn @@ -334,7 +334,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto getAcceptedConfirmation :: SQLiteStore -> ConnId -> m AcceptedConfirmation getAcceptedConfirmation st connId = - liftIOEither . withConnection st $ \db -> + liftIOEither . withTransaction st $ \db -> confirmation <$> DB.query db