From d4962daf119c28acae7ad7ea71c1f56a0a55cff2 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 18 Sep 2022 14:02:20 +0100 Subject: [PATCH] internal async commands (#530) * internal async commands * rename * remove GADT from AgentCommand --- src/Simplex/Messaging/Agent.hs | 82 ++++++++++++++++------------ src/Simplex/Messaging/Agent/Store.hs | 76 +++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 43 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4ebfcdbed..7b3752fbc 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -326,7 +326,7 @@ newConnAsync c corrId enableNtfs cMode = do connAgentVersion <- asks $ maxVersion . smpAgentVRange . config let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent connId <- withStore c $ \db -> createNewConn db g cData cMode - enqueueCommand c corrId connId Nothing $ NEW enableNtfs (ACM cMode) + enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode) pure connId joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId @@ -338,7 +338,7 @@ joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri (ConnReqUriData _ age let duplexHS = connAgentVersion /= 1 cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo + enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo pure connId _ -> throwError $ AGENT A_VERSION joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo = @@ -348,7 +348,7 @@ allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Con allowConnectionAsync' c corrId connId confId ownConnInfo = withStore c (`getConn` connId) >>= \case SomeConn _ (RcvConnection _ RcvQueue {server}) -> - enqueueCommand c corrId connId (Just server) $ LET confId ownConnInfo + enqueueCommand c corrId connId (Just server) $ AClientCommand $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m () @@ -362,7 +362,7 @@ ackMessageAsync' c corrId connId msgId = where enqueueAck :: RcvQueue -> m () enqueueAck RcvQueue {server} = do - enqueueCommand c corrId connId (Just server) $ ACK msgId + enqueueCommand c corrId connId (Just server) $ AClientCommand $ ACK msgId newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) newConn c connId asyncMode enableNtfs cMode = @@ -643,10 +643,10 @@ sendMessage' c connId msgFlags msg = -- / async command processing v v v -enqueueCommand :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> ACommand 'Client -> m () +enqueueCommand :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> m () enqueueCommand c corrId connId server aCommand = do resumeSrvCmds c server - commandId <- withStore' c $ \db -> createCommand db corrId connId server $ AClientCommand aCommand + commandId <- withStore' c $ \db -> createCommand db corrId connId server aCommand queuePendingCommands c server [commandId] resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m () @@ -694,24 +694,35 @@ runCommandProcessing c@AgentClient {subQ} server = do atomically $ beginAgentOperation c AOSndNetwork E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e) - Right (corrId, connId, AClientCommand cmd) -> processCmd ri corrId connId cmdId cmd + Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd where - processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> ACommand 'Client -> m () + processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m () processCmd ri corrId connId cmdId = \case - NEW enableNtfs (ACM cMode) -> do - usedSrvs <- newTVarIO ([] :: [SMPServer]) - tryCommand . withNextSrv usedSrvs [] $ \srv -> do - (_, cReq) <- newConnSrv c connId True enableNtfs cMode srv - notify $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do - let initUsed = [smpServer (queueAddress :: SMPQueueAddress)] - usedSrvs <- newTVarIO initUsed - tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do - void $ joinConnSrv c connId True enableNtfs cReq connInfo srv - notify OK - LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK - ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK - cmd -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd) + AClientCommand cmd -> case cmd of + NEW enableNtfs (ACM cMode) -> do + usedSrvs <- newTVarIO ([] :: [SMPServer]) + tryCommand . withNextSrv usedSrvs [] $ \srv -> do + (_, cReq) <- newConnSrv c connId True enableNtfs cMode srv + notify $ INV (ACR cMode cReq) + JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do + let initUsed = [smpServer (queueAddress :: SMPQueueAddress)] + usedSrvs <- newTVarIO initUsed + tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do + void $ joinConnSrv c connId True enableNtfs cReq connInfo srv + notify OK + LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK + ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK + _ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd) + AInternalCommand cmd -> case server of + Just _srv -> case cmd of + ICAckDel _rId srvMsgId msgId -> tryCommand $ ack _rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId) + ICAck _rId srvMsgId -> tryCommand $ ack _rId srvMsgId + _ -> notify $ ERR $ INTERNAL $ "command requires server " <> show (internalCmdTag cmd) + where + ack _rId srvMsgId = do + -- TODO get particular queue + rq <- withStore c (`getRcvQueue` connId) + ackQueueMessage c rq srvMsgId where tryCommand action = withRetryInterval ri $ \loop -> tryError action >>= \case @@ -907,11 +918,15 @@ ackMessage' c connId msgId = do ack rq = do let mId = InternalId msgId srvMsgId <- withStore c $ \db -> setMsgUserAck db connId mId - sendAck c rq srvMsgId `catchError` \case - SMP SMP.NO_MSG -> pure () - e -> throwError e + ackQueueMessage c rq srvMsgId withStore' c $ \db -> deleteMsg db connId mId +ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m () +ackQueueMessage c rq srvMsgId = + sendAck c rq srvMsgId `catchError` \case + SMP SMP.NO_MSG -> pure () + e -> throwError e + -- | Suspend SMP agent connection (OFF command) in Reader monad suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m () suspendConnection' c connId = @@ -1246,8 +1261,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> tryError agentClientMsg >>= \case Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of - HELLO -> helloMsg >> ack >> withStore' c (\db -> deleteMsg db connId msgId) - REPLY cReq -> replyMsg cReq >> ack >> withStore' c (\db -> deleteMsg db connId msgId) + HELLO -> helloMsg >> ackDel msgId + REPLY cReq -> replyMsg cReq >> ackDel msgId -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command A_MSG body -> do logServer "<--" c srv rId "MSG " @@ -1256,9 +1271,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm Left e@(AGENT A_DUPLICATE) -> do withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} - | userAck -> do - ack - withStore' c $ \db -> deleteMsg db connId internalId + | userAck -> ackDel internalId | otherwise -> do liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case AgentMessage _ (A_MSG body) -> do @@ -1289,10 +1302,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm _ -> prohibited >> ack where ack :: m () - ack = - sendAck c rq srvMsgId `catchError` \case - SMP SMP.NO_MSG -> pure () - e -> throwError e + ack = enqueueCmd $ ICAck rId srvMsgId + ackDel :: InternalId -> m () + ackDel = enqueueCmd . ICAckDel rId srvMsgId + enqueueCmd :: InternalCommand -> m () + enqueueCmd = enqueueCommand c "" connId (Just srv) . AInternalCommand handleNotifyAck :: m () -> m () handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack SMP.END -> diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index fb13fac06..91f5916fd 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -6,7 +6,9 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module Simplex.Messaging.Agent.Store where @@ -164,29 +166,81 @@ data ConnData = ConnData } deriving (Eq, Show) -data AgentCommand = AClientCommand (ACommand 'Client) +data AgentCmdType = ACClient | ACInternal + +instance StrEncoding AgentCmdType where + strEncode = \case + ACClient -> "CLIENT" + ACInternal -> "INTERNAL" + strP = + A.takeTill (== ' ') >>= \case + "CLIENT" -> pure ACClient + "INTERNAL" -> pure ACInternal + _ -> fail "bad AgentCmdType" + +data AgentCommand + = AClientCommand (ACommand 'Client) + | AInternalCommand InternalCommand instance StrEncoding AgentCommand where strEncode = \case - AClientCommand cmd -> "CLIENT " <> serializeCommand cmd + AClientCommand cmd -> strEncode (ACClient, Str $ serializeCommand cmd) + AInternalCommand cmd -> strEncode (ACInternal, cmd) strP = - A.takeTill (== ' ') >>= \case - "CLIENT" -> AClientCommand <$> (A.space *> ((\(ACmd _ cmd) -> checkParty cmd) <$?> dbCommandP)) - _ -> fail "bad AgentCommand" + strP_ >>= \case + ACClient -> AClientCommand <$> ((\(ACmd _ cmd) -> checkParty cmd) <$?> dbCommandP) + ACInternal -> AInternalCommand <$> strP -data AgentCommandTag = AClientCommandTag (ACommandTag 'Client) +data AgentCommandTag + = AClientCommandTag (ACommandTag 'Client) + | AInternalCommandTag InternalCommandTag instance StrEncoding AgentCommandTag where strEncode = \case - AClientCommandTag t -> "CLIENT " <> strEncode t + AClientCommandTag t -> strEncode (ACClient, t) + AInternalCommandTag t -> strEncode (ACInternal, t) + strP = + strP_ >>= \case + ACClient -> AClientCommandTag <$> strP + ACInternal -> AInternalCommandTag <$> strP + +data InternalCommand + = ICAck SMP.RecipientId MsgId + | ICAckDel SMP.RecipientId MsgId InternalId + +data InternalCommandTag + = ICAck_ + | ICAckDel_ + deriving (Show) + +instance StrEncoding InternalCommand where + strEncode = \case + ICAck rId srvMsgId -> strEncode (ICAck_, rId, srvMsgId) + ICAckDel rId srvMsgId mId -> strEncode (ICAckDel_, rId, srvMsgId, mId) + strP = + strP_ >>= \case + ICAck_ -> ICAck <$> strP_ <*> strP + ICAckDel_ -> ICAckDel <$> strP_ <*> strP_ <*> strP + +instance StrEncoding InternalCommandTag where + strEncode = \case + ICAck_ -> "ACK" + ICAckDel_ -> "ACK_DEL" strP = A.takeTill (== ' ') >>= \case - "CLIENT" -> AClientCommandTag <$> (A.space *> strP) - _ -> fail "bad AgentCommandTag" + "ACK" -> pure ICAck_ + "ACK_DEL" -> pure ICAckDel_ + _ -> fail "bad InternalCommandTag" agentCommandTag :: AgentCommand -> AgentCommandTag agentCommandTag = \case AClientCommand cmd -> AClientCommandTag $ aCommandTag cmd + AInternalCommand cmd -> AInternalCommandTag $ internalCmdTag cmd + +internalCmdTag :: InternalCommand -> InternalCommandTag +internalCmdTag = \case + ICAck {} -> ICAck_ + ICAckDel {} -> ICAckDel_ -- * Confirmation types @@ -302,6 +356,10 @@ data MsgBase = MsgBase newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show) +instance StrEncoding InternalId where + strEncode = strEncode . unId + strP = InternalId <$> strP + type InternalTs = UTCTime type AsyncCmdId = Int64