From 8fde8e1344699cdcdc67709595c9285cd06bbef3 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 10 Mar 2023 09:10:52 +0000 Subject: [PATCH] xftp: agent command entities (#676) * xftp: agent command entities * progress event * parameterize agent command by entity * refactor * Eq instance for APartyCmdTag --- src/Simplex/FileTransfer/Agent.hs | 9 +- src/Simplex/Messaging/Agent.hs | 88 ++-- src/Simplex/Messaging/Agent/Client.hs | 18 +- .../Messaging/Agent/NtfSubSupervisor.hs | 8 +- src/Simplex/Messaging/Agent/Protocol.hs | 415 +++++++++++------- src/Simplex/Messaging/Agent/Server.hs | 8 +- src/Simplex/Messaging/Agent/Store.hs | 10 +- tests/AgentTests.hs | 42 +- tests/AgentTests/FunctionalAPITests.hs | 46 +- tests/SMPAgentClient.hs | 4 +- tests/XFTPAgent.hs | 8 +- 11 files changed, 385 insertions(+), 271 deletions(-) diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 334ace1fd..f86bc27db 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -7,6 +7,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Simplex.FileTransfer.Agent ( -- Receiving files @@ -134,7 +135,7 @@ workerInternalError c rcvFileId internalErrStr = do notifyInternalError c rcvFileId internalErrStr notifyInternalError :: (MonadUnliftIO m) => AgentClient -> RcvFileId -> String -> m () -notifyInternalError AgentClient {subQ} rcvFileId internalErrStr = atomically $ writeTBQueue subQ ("", "", FRCVERR rcvFileId $ INTERNAL internalErrStr) +notifyInternalError AgentClient {subQ} rcvFileId internalErrStr = atomically $ writeTBQueue subQ ("", "", APC SAERcvFile $ RFERR rcvFileId $ INTERNAL internalErrStr) runXFTPLocalWorker :: forall m. AgentMonad m => AgentClient -> TMVar () -> m () runXFTPLocalWorker c@AgentClient {subQ} doWork = do @@ -159,10 +160,10 @@ runXFTPLocalWorker c@AgentClient {subQ} doWork = do path <- decrypt encSize chunkPaths whenM (doesPathExist tmpPath) $ removeDirectoryRecursive tmpPath withStore' c $ \db -> updateRcvFileComplete db rcvFileId path - notify $ FRCVD rcvFileId path + notify $ RFDONE rcvFileId path where - notify :: ACommand 'Agent -> m () - notify cmd = atomically $ writeTBQueue subQ ("", "", cmd) + notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify cmd = atomically $ writeTBQueue subQ ("", "", APC (sAEntity @e) cmd) getChunkPaths :: [RcvFileChunk] -> m [FilePath] getChunkPaths [] = pure [] getChunkPaths (RcvFileChunk {chunkTmpPath = Just path} : cs) = do diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 523425be0..91e9579e2 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -135,7 +135,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, protoServer, sameSrvAddr') +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, protoServer, sameSrvAddr') import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util @@ -379,29 +379,31 @@ runAgentClient c = race_ (subscriber c) (client c) client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () client c@AgentClient {rcvQ, subQ} = forever $ do - (corrId, connId, cmd) <- atomically $ readTBQueue rcvQ - runExceptT (processCommand c (connId, cmd)) + (corrId, entId, cmd) <- atomically $ readTBQueue rcvQ + runExceptT (processCommand c (entId, cmd)) >>= atomically . writeTBQueue subQ . \case - Left e -> (corrId, connId, ERR e) - Right (connId', resp) -> (corrId, connId', resp) + Left e -> (corrId, entId, APC SAEConn $ ERR e) + Right (entId', resp) -> (corrId, entId', resp) -- | execute any SMP agent command -processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent) -processCommand c (connId, cmd) = case cmd of - NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing - JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo - LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) - ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo - RJCT invId -> rejectContact' c connId invId $> (connId, OK) - SUB -> subscribeConnection' c connId $> (connId, OK) - SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody - ACK msgId -> ackMessage' c connId msgId $> (connId, OK) - SWCH -> switchConnection' c connId $> (connId, OK) - OFF -> suspendConnection' c connId $> (connId, OK) - DEL -> deleteConnection' c connId $> (connId, OK) - CHK -> (connId,) . STAT <$> getConnectionServers' c connId +processCommand :: forall m. AgentMonad m => AgentClient -> (EntityId, APartyCmd 'Client) -> m (EntityId, APartyCmd 'Agent) +processCommand c (connId, APC e cmd) = + second (APC e) <$> case cmd of + NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing + JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo + LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) + ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo + RJCT invId -> rejectContact' c connId invId $> (connId, OK) + SUB -> subscribeConnection' c connId $> (connId, OK) + SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody + ACK msgId -> ackMessage' c connId msgId $> (connId, OK) + SWCH -> switchConnection' c connId $> (connId, OK) + OFF -> suspendConnection' c connId $> (connId, OK) + DEL -> deleteConnection' c connId $> (connId, OK) + CHK -> (connId,) . STAT <$> getConnectionServers' c connId where -- command interface does not support different users + userId :: UserId userId = 1 createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId @@ -419,12 +421,12 @@ deleteUser' c userId delSMPQueues = do where delUser = whenM (withStore' c (`deleteUserWithoutConns` userId)) $ - atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId) + atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId newConnAsync c userId corrId enableNtfs cMode = do connId <- newConnNoQueues c userId "" enableNtfs cMode - enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode) + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pure connId newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId @@ -443,7 +445,7 @@ joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData let duplexHS = connAgentVersion /= 1 cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo pure connId _ -> throwError $ AGENT A_VERSION joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo = @@ -453,7 +455,7 @@ allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Con allowConnectionAsync' c corrId connId confId ownConnInfo = withStore c (`getConn` connId) >>= \case SomeConn _ (RcvConnection _ RcvQueue {server}) -> - enqueueCommand c corrId connId (Just server) $ AClientCommand $ LET confId ownConnInfo + enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> m ConnId @@ -480,7 +482,7 @@ ackMessageAsync' c corrId connId msgId = do enqueueAck :: m () enqueueAck = do (RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId - enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId + enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId] @@ -502,7 +504,7 @@ deleteConnectionsAsync_ onSuccess c connIds = case connIds of switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m () switchConnectionAsync' c corrId connId = withStore c (`getConn` connId) >>= \case - SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH + SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH _ -> throwError $ CMD PROHIBITED newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c) @@ -708,7 +710,7 @@ subscribeConnections' c connIds = do let actual = M.size rs expected = length connIds when (actual /= expected) . atomically $ - writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) + writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m () resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId] @@ -823,12 +825,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do cmdId <- atomically $ readTQueue cq atomically $ beginAgentOperation c AOSndNetwork E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case - Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e) + Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR $ INTERNAL $ show e) Right cmd -> processCmd (riFast ri) cmdId cmd where processCmd :: RetryInterval -> AsyncCmdId -> PendingCommand -> m () processCmd ri cmdId PendingCommand {corrId, userId, connId, command} = case command of - AClientCommand cmd -> case cmd of + AClientCommand (APC _ cmd) -> case cmd of NEW enableNtfs (ACM cMode) -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) tryCommand . withNextSrv usedSrvs [] $ \srv -> do @@ -915,7 +917,8 @@ runCommandProcessing c@AgentClient {subQ} server_ = do tryWithLock name = tryCommand . withConnLock c connId name internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command) cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) - notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd) + notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m () withNextSrv usedSrvs initUsed action = do used <- readTVarIO usedSrvs @@ -1124,9 +1127,9 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, dupl where delMsg :: InternalId -> m () delMsg msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId - notify :: ACommand 'Agent -> m () - notify cmd = atomically $ writeTBQueue subQ ("", connId, cmd) - notifyDel :: InternalId -> ACommand 'Agent -> m () + notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify cmd = atomically $ writeTBQueue subQ ("", connId, APC (sAEntity @e) cmd) + notifyDel :: AEntityI e => InternalId -> ACommand 'Agent e -> m () notifyDel msgId cmd = notify cmd >> delMsg msgId connError msgId = notifyDel msgId . ERR . CONN qError msgId = notifyDel msgId . ERR . AGENT . A_QUEUE @@ -1245,7 +1248,7 @@ deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] deleteConnQueues c ntf rqs = do rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs) forM_ (M.assocs rs) $ \case - (connId, Right _) -> withStore' c (`deleteConn` connId) >> notify ("", connId, DEL_CONN) + (connId, Right _) -> withStore' c (`deleteConn` connId) >> notify ("", connId, APC SAEConn DEL_CONN) _ -> pure () pure rs where @@ -1259,7 +1262,7 @@ deleteConnQueues c ntf rqs = do | temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> withStore' c (`incRcvDeleteErrors` rq) $> r | otherwise -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq (Just e) $> Right () pure (rq, r') - notifyRQ rq e_ = notify ("", qConnId rq, DEL_RCVQ (qServer rq) (queueId rq) e_) + notifyRQ rq e_ = notify ("", qConnId rq, APC SAEConn $ DEL_RCVQ (qServer rq) (queueId rq) e_) notify = when ntf . atomically . writeTBQueue (subQ c) connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ()) connResults = M.map snd . foldl' addResult M.empty @@ -1297,7 +1300,7 @@ deleteConnections_ getConnections ntf c connIds = do let actual = M.size rs expected = length connIds when (actual /= expected) . atomically $ - writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected) + writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected) getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats getConnectionServers' c connId = do @@ -1513,7 +1516,7 @@ sendNtfConnCommands c cmd = do Just (ConnData {enableNtfs}, _) -> when enableNtfs . atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd) _ -> - atomically $ writeTBQueue (subQ c) ("", connId, ERR $ INTERNAL "no connection data") + atomically $ writeTBQueue (subQ c) ("", connId, APC SAEConn $ ERR $ INTERNAL "no connection data") setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m () setNtfServers' c = atomically . writeTVar (ntfServers c) @@ -1600,7 +1603,7 @@ cleanupManager c = do withStore' c deleteUsersWithoutConns >>= mapM_ notifyUserDeleted threadDelay int where - notifyUserDeleted userId = atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId) + notifyUserDeleted userId = atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, sessId, rId, cmd) = do @@ -1714,7 +1717,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s Just (Right clnt) | sessId == sessionId clnt -> do removeSubscription c connId - writeTBQueue subQ ("", connId, END) + notify' END pure "END" | otherwise -> ignored _ -> ignored @@ -1723,8 +1726,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s logServer "<--" c srv rId $ "unexpected: " <> bshow cmd notify . ERR $ BROKER (B.unpack $ strEncode srv) UNEXPECTED where - notify :: ACommand 'Agent -> m () - notify msg = atomically $ writeTBQueue subQ ("", connId, msg) + notify :: forall e. AEntityI e => ACommand 'Agent e -> m () + notify = atomically . notify' + + notify' :: forall e. AEntityI e => ACommand 'Agent e -> STM () + notify' msg = writeTBQueue subQ ("", connId, APC (sAEntity @e) msg) prohibited :: m () prohibited = notify . ERR $ AGENT A_PROHIBITED @@ -1805,7 +1811,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s -- this branch is executed by the accepting party in duplexHandshake mode (v2) -- and by the initiating party in v1 -- Also see comment where HELLO is sent. - | sndStatus == Active -> atomically $ writeTBQueue subQ ("", connId, CON) + | sndStatus == Active -> notify CON | duplexHandshake == Just True -> enqueueDuplexHello sq | otherwise -> pure () _ -> pure () diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 1ee467578..7ceaedcf4 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -413,8 +413,8 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, atomically $ mapM_ (releaseGetLock c) qs unliftIO u $ reconnectServer c tSess - notifySub :: ConnId -> ACommand 'Agent -> IO () - notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) + notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO () + notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd) reconnectServer :: AgentMonad m => AgentClient -> SMPTransportSession -> m () reconnectServer c tSess = newAsyncAction tryReconnectSMPClient $ reconnections c @@ -441,8 +441,8 @@ reconnectSMPClient c tSess@(_, srv, _) = let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs mapM_ (throwError . snd) $ listToMaybe tempErrs - notifySub :: ConnId -> ACommand 'Agent -> IO () - notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) + notifySub :: ConnId -> ACommand 'Agent 'AEConn -> IO () + notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC SAEConn cmd) getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do @@ -461,7 +461,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d clientDisconnected client = do atomically $ TM.delete tSess ntfClients incClientStat c userId client "DISCONNECT" "" - atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client) + atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv getXFTPServerClient :: forall m. AgentMonad m => AgentClient -> XFTPTransportSession -> m XFTPClient @@ -482,7 +482,7 @@ getXFTPServerClient c@AgentClient {active, xftpClients, useNetworkConfig} tSess@ clientDisconnected client = do atomically $ TM.delete tSess xftpClients incClientStat c userId client "DISCONNECT" "" - atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client) + atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) @@ -522,7 +522,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" atomically $ putTMVar clientVar r liftIO $ incClientStat c userId client "CLIENT" "OK" - atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client) + atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent CONNECT client) successAction client Left e -> do liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e @@ -540,7 +540,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconne withRetryInterval ri $ \_ loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop atomically . removeAsyncAction aId $ asyncClients c -hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> Client msg -> ACommand 'Agent +hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost getClientConfig :: AgentMonad m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig @@ -1059,7 +1059,7 @@ suspendOperation c op endedAction = do notifySuspended :: AgentClient -> STM () notifySuspended c = do -- unsafeIOToSTM $ putStrLn "notifySuspended" - writeTBQueue (subQ c) ("", "", SUSPENDED) + writeTBQueue (subQ c) ("", "", APC SAENone SUSPENDED) writeTVar (agentState c) ASSuspended endOperation :: AgentClient -> AgentOperation -> STM () -> STM () diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 1a96fff16..459dc1206 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -31,8 +31,7 @@ import Data.Text (Text) import Data.Time (UTCTime, addUTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite -import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..)) -import qualified Simplex.Messaging.Agent.Protocol as AP +import Simplex.Messaging.Agent.Protocol (ACommand (..), APartyCmd (..), AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..), SAEntity (..)) import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite @@ -40,7 +39,7 @@ import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol (NtfSubStatus (..), NtfTknStatus (..), SMPQueueNtf (..)) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Protocol +import Simplex.Messaging.Protocol (NtfServer, ProtocolServer, SMPServer, sameSrvAddr) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (tshow, unlessM) @@ -325,8 +324,9 @@ workerInternalError c connId internalErrStr = do withStore' c $ \db -> setNullNtfSubscriptionAction db connId notifyInternalError c connId internalErrStr +-- TODO change error notifyInternalError :: (MonadUnliftIO m) => AgentClient -> ConnId -> String -> m () -notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, AP.ERR $ AP.INTERNAL internalErrStr) +notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, APC SAEConn $ ERR $ INTERNAL internalErrStr) getNtfToken :: AgentMonad m => m (Maybe NtfToken) getNtfToken = do diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index c18e432fb..b73a04d06 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -40,13 +40,19 @@ module Simplex.Messaging.Agent.Protocol -- * SMP agent protocol types ConnInfo, ACommand (..), + APartyCmd (..), ACommandTag (..), aCommandTag, + aPartyCmdTag, ACmd (..), + APartyCmdTag (..), ACmdTag (..), AParty (..), + AEntity (..), SAParty (..), + SAEntity (..), APartyI (..), + AEntityI (..), MsgHash, MsgMeta (..), ConnectionStats (..), @@ -165,6 +171,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol ( AProtocolType, + EntityId, ErrorType, MsgBody, MsgFlags, @@ -211,10 +218,10 @@ e2eEncUserMsgLength = 15856 type ARawTransmission = (ByteString, ByteString, ByteString) -- | Parsed SMP agent protocol transmission. -type ATransmission p = (ACorrId, ConnId, ACommand p) +type ATransmission p = (ACorrId, EntityId, APartyCmd p) -- | SMP agent protocol transmission or transmission error. -type ATransmissionOrError p = (ACorrId, ConnId, Either AgentErrorType (ACommand p)) +type ATransmissionOrError p = (ACorrId, EntityId, Either AgentErrorType (APartyCmd p)) type ACorrId = ByteString @@ -242,100 +249,151 @@ instance APartyI Agent where sAParty = SAgent instance APartyI Client where sAParty = SClient -data ACmd = forall p. APartyI p => ACmd (SAParty p) (ACommand p) +data AEntity = AEConn | AERcvFile | AENone + deriving (Eq, Show) + +data SAEntity :: AEntity -> Type where + SAEConn :: SAEntity AEConn + SAERcvFile :: SAEntity AERcvFile + SAENone :: SAEntity AENone + +deriving instance Show (SAEntity e) + +deriving instance Eq (SAEntity e) + +instance TestEquality SAEntity where + testEquality SAEConn SAEConn = Just Refl + testEquality SAERcvFile SAERcvFile = Just Refl + testEquality SAENone SAENone = Just Refl + testEquality _ _ = Nothing + +class AEntityI (e :: AEntity) where sAEntity :: SAEntity e + +instance AEntityI AEConn where sAEntity = SAEConn + +instance AEntityI AERcvFile where sAEntity = SAERcvFile + +instance AEntityI AENone where sAEntity = SAENone + +data ACmd = forall p e. (APartyI p, AEntityI e) => ACmd (SAParty p) (SAEntity e) (ACommand p e) deriving instance Show ACmd +data APartyCmd p = forall e. AEntityI e => APC (SAEntity e) (ACommand p e) + +instance Eq (APartyCmd p) where + APC e cmd == APC e' cmd' = case testEquality e e' of + Just Refl -> cmd == cmd' + Nothing -> False + +deriving instance Show (APartyCmd p) + type ConnInfo = ByteString -- | Parameterized type for SMP agent protocol commands and responses from all participants. -data ACommand (p :: AParty) where - NEW :: Bool -> AConnectionMode -> ACommand Client -- response INV - INV :: AConnectionRequestUri -> ACommand Agent - JOIN :: Bool -> AConnectionRequestUri -> ConnInfo -> ACommand Client -- response OK - CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake - LET :: ConfirmationId -> ConnInfo -> ACommand Client -- ConnInfo is from client - REQ :: InvitationId -> L.NonEmpty SMPServer -> ConnInfo -> ACommand Agent -- ConnInfo is from sender - ACPT :: InvitationId -> ConnInfo -> ACommand Client -- ConnInfo is from client - RJCT :: InvitationId -> ACommand Client - INFO :: ConnInfo -> ACommand Agent - CON :: ACommand Agent -- notification that connection is established - SUB :: ACommand Client - END :: ACommand Agent - CONNECT :: AProtocolType -> TransportHost -> ACommand Agent - DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent - DOWN :: SMPServer -> [ConnId] -> ACommand Agent - UP :: SMPServer -> [ConnId] -> ACommand Agent - SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent - SEND :: MsgFlags -> MsgBody -> ACommand Client - MID :: AgentMsgId -> ACommand Agent - SENT :: AgentMsgId -> ACommand Agent - MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent - MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent - ACK :: AgentMsgId -> ACommand Client - SWCH :: ACommand Client - OFF :: ACommand Client - DEL :: ACommand Client - DEL_RCVQ :: SMPServer -> SMP.RecipientId -> Maybe AgentErrorType -> ACommand Agent - DEL_CONN :: ACommand Agent - DEL_USER :: Int64 -> ACommand Agent - CHK :: ACommand Client - STAT :: ConnectionStats -> ACommand Agent - OK :: ACommand Agent - ERR :: AgentErrorType -> ACommand Agent - SUSPENDED :: ACommand Agent - FRCVD :: RcvFileId -> FilePath -> ACommand Agent - FRCVERR :: RcvFileId -> AgentErrorType -> ACommand Agent +data ACommand (p :: AParty) (e :: AEntity) where + NEW :: Bool -> AConnectionMode -> ACommand Client AEConn -- response INV + INV :: AConnectionRequestUri -> ACommand Agent AEConn + JOIN :: Bool -> AConnectionRequestUri -> ConnInfo -> ACommand Client AEConn -- response OK + CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake + LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client + REQ :: InvitationId -> L.NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender + ACPT :: InvitationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client + RJCT :: InvitationId -> ACommand Client AEConn + INFO :: ConnInfo -> ACommand Agent AEConn + CON :: ACommand Agent AEConn -- notification that connection is established + SUB :: ACommand Client AEConn + END :: ACommand Agent AEConn + CONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone + DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone + DOWN :: SMPServer -> [ConnId] -> ACommand Agent AEConn + UP :: SMPServer -> [ConnId] -> ACommand Agent AEConn + SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent AEConn + SEND :: MsgFlags -> MsgBody -> ACommand Client AEConn + MID :: AgentMsgId -> ACommand Agent AEConn + SENT :: AgentMsgId -> ACommand Agent AEConn + MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn + MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent AEConn + ACK :: AgentMsgId -> ACommand Client AEConn + SWCH :: ACommand Client AEConn + OFF :: ACommand Client AEConn + DEL :: ACommand Client AEConn + DEL_RCVQ :: SMPServer -> SMP.RecipientId -> Maybe AgentErrorType -> ACommand Agent AEConn + DEL_CONN :: ACommand Agent AEConn + DEL_USER :: Int64 -> ACommand Agent AENone + CHK :: ACommand Client AEConn + STAT :: ConnectionStats -> ACommand Agent AEConn + OK :: ACommand Agent AEConn + ERR :: AgentErrorType -> ACommand Agent AEConn + SUSPENDED :: ACommand Agent AENone + -- XFTP commands and responses + RFPROG :: RcvFileId -> Int -> Int -> ACommand Agent AERcvFile + RFDONE :: RcvFileId -> FilePath -> ACommand Agent AERcvFile + RFERR :: RcvFileId -> AgentErrorType -> ACommand Agent AERcvFile -deriving instance Eq (ACommand p) +deriving instance Eq (ACommand p e) -deriving instance Show (ACommand p) +deriving instance Show (ACommand p e) -data ACmdTag = forall p. APartyI p => ACmdTag (SAParty p) (ACommandTag p) +data ACmdTag = forall p e. (APartyI p, AEntityI e) => ACmdTag (SAParty p) (SAEntity e) (ACommandTag p e) -data ACommandTag (p :: AParty) where - NEW_ :: ACommandTag Client - INV_ :: ACommandTag Agent - JOIN_ :: ACommandTag Client - CONF_ :: ACommandTag Agent - LET_ :: ACommandTag Client - REQ_ :: ACommandTag Agent - ACPT_ :: ACommandTag Client - RJCT_ :: ACommandTag Client - INFO_ :: ACommandTag Agent - CON_ :: ACommandTag Agent - SUB_ :: ACommandTag Client - END_ :: ACommandTag Agent - CONNECT_ :: ACommandTag Agent - DISCONNECT_ :: ACommandTag Agent - DOWN_ :: ACommandTag Agent - UP_ :: ACommandTag Agent - SWITCH_ :: ACommandTag Agent - SEND_ :: ACommandTag Client - MID_ :: ACommandTag Agent - SENT_ :: ACommandTag Agent - MERR_ :: ACommandTag Agent - MSG_ :: ACommandTag Agent - ACK_ :: ACommandTag Client - SWCH_ :: ACommandTag Client - OFF_ :: ACommandTag Client - DEL_ :: ACommandTag Client - DEL_RCVQ_ :: ACommandTag Agent - DEL_CONN_ :: ACommandTag Agent - DEL_USER_ :: ACommandTag Agent - CHK_ :: ACommandTag Client - STAT_ :: ACommandTag Agent - OK_ :: ACommandTag Agent - ERR_ :: ACommandTag Agent - SUSPENDED_ :: ACommandTag Agent - FRCVD_ :: ACommandTag Agent - FRCVERR_ :: ACommandTag Agent +data APartyCmdTag p = forall e. AEntityI e => APCT (SAEntity e) (ACommandTag p e) -deriving instance Eq (ACommandTag p) +instance Eq (APartyCmdTag p) where + APCT e cmd == APCT e' cmd' = case testEquality e e' of + Just Refl -> cmd == cmd' + Nothing -> False -deriving instance Show (ACommandTag p) +deriving instance Show (APartyCmdTag p) -aCommandTag :: ACommand p -> ACommandTag p +data ACommandTag (p :: AParty) (e :: AEntity) where + NEW_ :: ACommandTag Client AEConn + INV_ :: ACommandTag Agent AEConn + JOIN_ :: ACommandTag Client AEConn + CONF_ :: ACommandTag Agent AEConn + LET_ :: ACommandTag Client AEConn + REQ_ :: ACommandTag Agent AEConn + ACPT_ :: ACommandTag Client AEConn + RJCT_ :: ACommandTag Client AEConn + INFO_ :: ACommandTag Agent AEConn + CON_ :: ACommandTag Agent AEConn + SUB_ :: ACommandTag Client AEConn + END_ :: ACommandTag Agent AEConn + CONNECT_ :: ACommandTag Agent AENone + DISCONNECT_ :: ACommandTag Agent AENone + DOWN_ :: ACommandTag Agent AEConn + UP_ :: ACommandTag Agent AEConn + SWITCH_ :: ACommandTag Agent AEConn + SEND_ :: ACommandTag Client AEConn + MID_ :: ACommandTag Agent AEConn + SENT_ :: ACommandTag Agent AEConn + MERR_ :: ACommandTag Agent AEConn + MSG_ :: ACommandTag Agent AEConn + ACK_ :: ACommandTag Client AEConn + SWCH_ :: ACommandTag Client AEConn + OFF_ :: ACommandTag Client AEConn + DEL_ :: ACommandTag Client AEConn + DEL_RCVQ_ :: ACommandTag Agent AEConn + DEL_CONN_ :: ACommandTag Agent AEConn + DEL_USER_ :: ACommandTag Agent AENone + CHK_ :: ACommandTag Client AEConn + STAT_ :: ACommandTag Agent AEConn + OK_ :: ACommandTag Agent AEConn + ERR_ :: ACommandTag Agent AEConn + SUSPENDED_ :: ACommandTag Agent AENone + -- XFTP commands and responses + RFDONE_ :: ACommandTag Agent AERcvFile + RFPROG_ :: ACommandTag Agent AERcvFile + RFERR_ :: ACommandTag Agent AERcvFile + +deriving instance Eq (ACommandTag p e) + +deriving instance Show (ACommandTag p e) + +aPartyCmdTag :: APartyCmd p -> APartyCmdTag p +aPartyCmdTag (APC e cmd) = APCT e $ aCommandTag cmd + +aCommandTag :: ACommand p e -> ACommandTag p e aCommandTag = \case NEW {} -> NEW_ INV _ -> INV_ @@ -371,8 +429,9 @@ aCommandTag = \case OK -> OK_ ERR _ -> ERR_ SUSPENDED -> SUSPENDED_ - FRCVD {} -> FRCVD_ - FRCVERR {} -> FRCVERR_ + RFPROG {} -> RFPROG_ + RFDONE {} -> RFDONE_ + RFERR {} -> RFERR_ data QueueDirection = QDRcv | QDSnd deriving (Eq, Show) @@ -1219,46 +1278,57 @@ dbCommandP :: Parser ACmd dbCommandP = commandP $ A.take =<< (A.decimal <* "\n") instance StrEncoding ACmdTag where - strEncode (ACmdTag _ cmd) = strEncode cmd + strEncode (ACmdTag _ _ cmd) = strEncode cmd strP = A.takeTill (== ' ') >>= \case - "NEW" -> pure $ ACmdTag SClient NEW_ - "INV" -> pure $ ACmdTag SAgent INV_ - "JOIN" -> pure $ ACmdTag SClient JOIN_ - "CONF" -> pure $ ACmdTag SAgent CONF_ - "LET" -> pure $ ACmdTag SClient LET_ - "REQ" -> pure $ ACmdTag SAgent REQ_ - "ACPT" -> pure $ ACmdTag SClient ACPT_ - "RJCT" -> pure $ ACmdTag SClient RJCT_ - "INFO" -> pure $ ACmdTag SAgent INFO_ - "CON" -> pure $ ACmdTag SAgent CON_ - "SUB" -> pure $ ACmdTag SClient SUB_ - "END" -> pure $ ACmdTag SAgent END_ - "CONNECT" -> pure $ ACmdTag SAgent CONNECT_ - "DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_ - "DOWN" -> pure $ ACmdTag SAgent DOWN_ - "UP" -> pure $ ACmdTag SAgent UP_ - "SWITCH" -> pure $ ACmdTag SAgent SWITCH_ - "SEND" -> pure $ ACmdTag SClient SEND_ - "MID" -> pure $ ACmdTag SAgent MID_ - "SENT" -> pure $ ACmdTag SAgent SENT_ - "MERR" -> pure $ ACmdTag SAgent MERR_ - "MSG" -> pure $ ACmdTag SAgent MSG_ - "ACK" -> pure $ ACmdTag SClient ACK_ - "SWCH" -> pure $ ACmdTag SClient SWCH_ - "OFF" -> pure $ ACmdTag SClient OFF_ - "DEL" -> pure $ ACmdTag SClient DEL_ - "DEL_RCVQ" -> pure $ ACmdTag SAgent DEL_RCVQ_ - "DEL_CONN" -> pure $ ACmdTag SAgent DEL_CONN_ - "DEL_USER" -> pure $ ACmdTag SAgent DEL_USER_ - "CHK" -> pure $ ACmdTag SClient CHK_ - "STAT" -> pure $ ACmdTag SAgent STAT_ - "OK" -> pure $ ACmdTag SAgent OK_ - "ERR" -> pure $ ACmdTag SAgent ERR_ - "SUSPENDED" -> pure $ ACmdTag SAgent SUSPENDED_ + "NEW" -> t NEW_ + "INV" -> ct INV_ + "JOIN" -> t JOIN_ + "CONF" -> ct CONF_ + "LET" -> t LET_ + "REQ" -> ct REQ_ + "ACPT" -> t ACPT_ + "RJCT" -> t RJCT_ + "INFO" -> ct INFO_ + "CON" -> ct CON_ + "SUB" -> t SUB_ + "END" -> ct END_ + "CONNECT" -> at SAENone CONNECT_ + "DISCONNECT" -> at SAENone DISCONNECT_ + "DOWN" -> ct DOWN_ + "UP" -> ct UP_ + "SWITCH" -> ct SWITCH_ + "SEND" -> t SEND_ + "MID" -> ct MID_ + "SENT" -> ct SENT_ + "MERR" -> ct MERR_ + "MSG" -> ct MSG_ + "ACK" -> t ACK_ + "SWCH" -> t SWCH_ + "OFF" -> t OFF_ + "DEL" -> t DEL_ + "DEL_RCVQ" -> ct DEL_RCVQ_ + "DEL_CONN" -> ct DEL_CONN_ + "DEL_USER" -> at SAENone DEL_USER_ + "CHK" -> t CHK_ + "STAT" -> ct STAT_ + "OK" -> ct OK_ + "ERR" -> ct ERR_ + "SUSPENDED" -> at SAENone SUSPENDED_ + "RFPROG" -> at SAERcvFile RFPROG_ + "RFDONE" -> at SAERcvFile RFDONE_ + "RFERR" -> at SAERcvFile RFERR_ _ -> fail "bad ACmdTag" + where + t = pure . ACmdTag SClient SAEConn + at e = pure . ACmdTag SAgent e + ct = at SAEConn -instance APartyI p => StrEncoding (ACommandTag p) where +instance APartyI p => StrEncoding (APartyCmdTag p) where + strEncode (APCT _ cmd) = strEncode cmd + strP = (\(ACmdTag _ e t) -> checkParty $ APCT e t) <$?> strP + +instance (APartyI p, AEntityI e) => StrEncoding (ACommandTag p e) where strEncode = \case NEW_ -> "NEW" INV_ -> "INV" @@ -1294,22 +1364,28 @@ instance APartyI p => StrEncoding (ACommandTag p) where OK_ -> "OK" ERR_ -> "ERR" SUSPENDED_ -> "SUSPENDED" - FRCVD_ -> "FRCVD" - FRCVERR_ -> "FRCVERR" - strP = (\(ACmdTag _ t) -> checkParty t) <$?> strP + RFPROG_ -> "RFPROG" + RFDONE_ -> "RFDONE" + RFERR_ -> "RFERR" + strP = (\(APCT _ t) -> checkEntity t) <$?> strP checkParty :: forall t p p'. (APartyI p, APartyI p') => t p' -> Either String (t p) checkParty x = case testEquality (sAParty @p) (sAParty @p') of Just Refl -> Right x Nothing -> Left "bad party" +checkEntity :: forall t e e'. (AEntityI e, AEntityI e') => t e' -> Either String (t e) +checkEntity x = case testEquality (sAEntity @e) (sAEntity @e') of + Just Refl -> Right x + Nothing -> Left "bad entity" + -- | SMP agent command and response parser commandP :: Parser ByteString -> Parser ACmd commandP binaryP = strP >>= \case - ACmdTag SClient cmd -> - ACmd SClient <$> case cmd of + ACmdTag SClient e cmd -> + ACmd SClient e <$> case cmd of NEW_ -> s (NEW <$> strP_ <*> strP) JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> binaryP) LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP) @@ -1322,8 +1398,8 @@ commandP binaryP = OFF_ -> pure OFF DEL_ -> pure DEL CHK_ -> pure CHK - ACmdTag SAgent cmd -> - ACmd SAgent <$> case cmd of + ACmdTag SAgent e cmd -> + ACmd SAgent e <$> case cmd of INV_ -> s (INV <$> strP) CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP) REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP) @@ -1346,8 +1422,9 @@ commandP binaryP = OK_ -> pure OK ERR_ -> s (ERR <$> strP) SUSPENDED_ -> pure SUSPENDED - FRCVD_ -> s (FRCVD <$> A.decimal <* A.space <*> strP) - FRCVERR_ -> s (FRCVERR <$> A.decimal <* A.space <*> strP) + RFPROG_ -> s (RFPROG <$> A.decimal <* A.space <*> A.decimal <* A.space <*> A.decimal) + RFDONE_ -> s (RFDONE <$> A.decimal <* A.space <*> strP) + RFERR_ -> s (RFERR <$> A.decimal <* A.space <*> strP) where s :: Parser a -> Parser a s p = A.space *> p @@ -1365,7 +1442,7 @@ parseCommand :: ByteString -> Either AgentErrorType ACmd parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX -- | Serialize SMP agent command. -serializeCommand :: ACommand p -> ByteString +serializeCommand :: ACommand p e -> ByteString serializeCommand = \case NEW ntfs cMode -> s (NEW_, ntfs, cMode) INV cReq -> s (INV_, cReq) @@ -1401,8 +1478,9 @@ serializeCommand = \case ERR e -> s (ERR_, e) OK -> s OK_ SUSPENDED -> s SUSPENDED_ - FRCVD fId fPath -> s (FRCVD_, Str $ bshow fId, fPath) - FRCVERR fId e -> s (FRCVERR_, Str $ bshow fId, e) + RFPROG fId rcvd total -> s (RFPROG_, Str $ bshow fId, rcvd, total) + RFDONE fId fPath -> s (RFDONE_, Str $ bshow fId, fPath) + RFERR fId e -> s (RFERR_, Str $ bshow fId, e) where s :: StrEncoding a => a -> ByteString s = strEncode @@ -1435,52 +1513,55 @@ tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h -- | Send SMP agent protocol command (or response) to TCP connection. tPut :: (Transport c, MonadIO m) => c -> ATransmission p -> m () -tPut h (corrId, connId, command) = - liftIO $ tPutRaw h (corrId, connId, serializeCommand command) +tPut h (corrId, connId, APC _ cmd) = + liftIO $ tPutRaw h (corrId, connId, serializeCommand cmd) -- | Receive client and agent transmissions from TCP connection. tGet :: forall c m p. (Transport c, MonadIO m) => SAParty p -> c -> m (ATransmissionOrError p) tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody where tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p) - tParseLoadBody t@(corrId, connId, command) = do + tParseLoadBody t@(corrId, entId, command) = do let cmd = parseCommand command >>= fromParty >>= tConnId t fullCmd <- either (return . Left) cmdWithMsgBody cmd - return (corrId, connId, fullCmd) + return (corrId, entId, fullCmd) - fromParty :: ACmd -> Either AgentErrorType (ACommand p) - fromParty (ACmd (p :: p1) cmd) = case testEquality party p of - Just Refl -> Right cmd + fromParty :: ACmd -> Either AgentErrorType (APartyCmd p) + fromParty (ACmd (p :: p1) e cmd) = case testEquality party p of + Just Refl -> Right $ APC e cmd _ -> Left $ CMD PROHIBITED - tConnId :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) - tConnId (_, connId, _) cmd = case cmd of - -- NEW, JOIN and ACPT have optional connId - NEW {} -> Right cmd - JOIN {} -> Right cmd - ACPT {} -> Right cmd - -- ERROR response does not always have connId - ERR _ -> Right cmd - CONNECT {} -> Right cmd - DISCONNECT {} -> Right cmd - DOWN {} -> Right cmd - UP {} -> Right cmd - -- other responses must have connId - _ - | B.null connId -> Left $ CMD NO_CONN - | otherwise -> Right cmd + tConnId :: ARawTransmission -> APartyCmd p -> Either AgentErrorType (APartyCmd p) + tConnId (_, entId, _) (APC e cmd) = + APC e <$> case cmd of + -- NEW, JOIN and ACPT have optional connection ID + NEW {} -> Right cmd + JOIN {} -> Right cmd + ACPT {} -> Right cmd + -- ERROR response does not always have connection ID + ERR _ -> Right cmd + CONNECT {} -> Right cmd + DISCONNECT {} -> Right cmd + DOWN {} -> Right cmd + UP {} -> Right cmd + SUSPENDED {} -> Right cmd + -- other responses must have connection ID + _ + | B.null entId -> Left $ CMD NO_CONN + | otherwise -> Right cmd - cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) - cmdWithMsgBody = \case - SEND msgFlags body -> SEND msgFlags <$$> getBody body - MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body - JOIN ntfs qUri cInfo -> JOIN ntfs qUri <$$> getBody cInfo - CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo - LET confId cInfo -> LET confId <$$> getBody cInfo - REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo - ACPT invId cInfo -> ACPT invId <$$> getBody cInfo - INFO cInfo -> INFO <$$> getBody cInfo - cmd -> pure $ Right cmd + cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) + cmdWithMsgBody (APC e cmd) = + APC e <$$> case cmd of + SEND msgFlags body -> SEND msgFlags <$$> getBody body + MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body + JOIN ntfs qUri cInfo -> JOIN ntfs qUri <$$> getBody cInfo + CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo + LET confId cInfo -> LET confId <$$> getBody cInfo + REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo + ACPT invId cInfo -> ACPT invId <$$> getBody cInfo + INFO cInfo -> INFO <$$> getBody cInfo + _ -> pure $ Right cmd getBody :: ByteString -> m (Either AgentErrorType ByteString) getBody binary = diff --git a/src/Simplex/Messaging/Agent/Server.hs b/src/Simplex/Messaging/Agent/Server.hs index 45b6be121..6ef9701f8 100644 --- a/src/Simplex/Messaging/Agent/Server.hs +++ b/src/Simplex/Messaging/Agent/Server.hs @@ -60,10 +60,10 @@ connectClient h c = race_ (send h c) (receive h c) receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m () receive h c@AgentClient {rcvQ, subQ} = forever $ do - (corrId, connId, cmdOrErr) <- tGet SClient h + (corrId, entId, cmdOrErr) <- tGet SClient h case cmdOrErr of - Right cmd -> write rcvQ (corrId, connId, cmd) - Left e -> write subQ (corrId, connId, ERR e) + Right cmd -> write rcvQ (corrId, entId, cmd) + Left e -> write subQ (corrId, entId, APC SAEConn $ ERR e) where write :: TBQueue (ATransmission p) -> ATransmission p -> m () write q t = do @@ -77,5 +77,5 @@ send h c@AgentClient {subQ} = forever $ do logClient c "<--" t logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m () -logClient AgentClient {clientId} dir (corrId, connId, cmd) = do +logClient AgentClient {clientId} dir (corrId, connId, APC _ cmd) = do logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd] diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 42c69398a..ab6e66820 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -277,20 +277,20 @@ instance StrEncoding AgentCmdType where _ -> fail "bad AgentCmdType" data AgentCommand - = AClientCommand (ACommand 'Client) + = AClientCommand (APartyCmd 'Client) | AInternalCommand InternalCommand instance StrEncoding AgentCommand where strEncode = \case - AClientCommand cmd -> strEncode (ACClient, Str $ serializeCommand cmd) + AClientCommand (APC _ cmd) -> strEncode (ACClient, Str $ serializeCommand cmd) AInternalCommand cmd -> strEncode (ACInternal, cmd) strP = strP_ >>= \case - ACClient -> AClientCommand <$> ((\(ACmd _ cmd) -> checkParty cmd) <$?> dbCommandP) + ACClient -> AClientCommand <$> ((\(ACmd _ e cmd) -> checkParty $ APC e cmd) <$?> dbCommandP) ACInternal -> AInternalCommand <$> strP data AgentCommandTag - = AClientCommandTag (ACommandTag 'Client) + = AClientCommandTag (APartyCmdTag 'Client) | AInternalCommandTag InternalCommandTag deriving (Show) @@ -363,7 +363,7 @@ instance StrEncoding InternalCommandTag where agentCommandTag :: AgentCommand -> AgentCommandTag agentCommandTag = \case - AClientCommand cmd -> AClientCommandTag $ aCommandTag cmd + AClientCommand cmd -> AClientCommandTag $ aPartyCmdTag cmd AInternalCommand cmd -> AInternalCommandTag $ internalCmdTag cmd internalCmdTag :: InternalCommand -> InternalCommandTag diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 77c61004e..43eb0c132 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -5,6 +5,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PostfixOperators #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} module AgentTests (agentTests) where @@ -19,6 +20,7 @@ import Control.Concurrent import Control.Monad (forM_) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Type.Equality import Network.HTTP.Types (urlEncode) import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) @@ -81,43 +83,53 @@ agentTests (ATransport t) = do it "should resume delivering messages after exceeding quota once all messages are received" $ smpAgentTest2_2_1 $ testResumeDeliveryQuotaExceeded t -tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent) -tGetAgent h = do - t@(_, _, cmd) <- tGet SAgent h - case cmd of - Right CONNECT {} -> tGetAgent h - Right DISCONNECT {} -> tGetAgent h - _ -> pure t +type AEntityTransmission p e = (ACorrId, ConnId, ACommand p e) + +type AEntityTransmissionOrError p e = (ACorrId, ConnId, Either AgentErrorType (ACommand p e)) + +tGetAgent :: Transport c => c -> IO (AEntityTransmissionOrError 'Agent 'AEConn) +tGetAgent = tGetAgent' + +tGetAgent' :: forall c e. (Transport c, AEntityI e) => c -> IO (AEntityTransmissionOrError 'Agent e) +tGetAgent' h = do + (corrId, connId, cmdOrErr) <- tGet SAgent h + case cmdOrErr of + Right (APC _ CONNECT {}) -> tGetAgent' h + Right (APC _ DISCONNECT {}) -> tGetAgent' h + Right (APC e cmd) -> case testEquality e (sAEntity @e) of + Just Refl -> pure (corrId, connId, Right cmd) + _ -> error $ "unexpected command " <> show cmd + Left err -> pure (corrId, connId, Left err) -- | receive message to handle `h` -(<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent) +(<#:) :: Transport c => c -> IO (AEntityTransmissionOrError 'Agent 'AEConn) (<#:) = tGetAgent -- | send transmission `t` to handle `h` and get response -(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent) +(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (AEntityTransmissionOrError 'Agent 'AEConn) h #: t = tPutRaw h t >> (<#:) h -- | action and expected response -- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r` -(#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation +(#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> AEntityTransmission 'Agent 'AEConn -> Expectation action #> (corrId, connId, cmd) = action `shouldReturn` (corrId, connId, Right cmd) -- | action and predicate for the response -- `h #:t =#> p` is the test that sends `t` to `h` and validates the response using `p` -(=#>) :: IO (ATransmissionOrError 'Agent) -> (ATransmission 'Agent -> Bool) -> Expectation +(=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation action =#> p = action >>= (`shouldSatisfy` p . correctTransmission) -correctTransmission :: ATransmissionOrError a -> ATransmission a +correctTransmission :: AEntityTransmissionOrError p e -> AEntityTransmission p e correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of Right cmd -> (corrId, connId, cmd) Left e -> error $ show e -- | receive message to handle `h` and validate that it is the expected one -(<#) :: Transport c => c -> ATransmission 'Agent -> Expectation +(<#) :: Transport c => c -> AEntityTransmission 'Agent 'AEConn -> Expectation h <# (corrId, connId, cmd) = (h <#:) `shouldReturn` (corrId, connId, Right cmd) -- | receive message to handle `h` and validate it using predicate `p` -(<#=) :: Transport c => c -> (ATransmission 'Agent -> Bool) -> Expectation +(<#=) :: Transport c => c -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission) -- | test that nothing is delivered to handle `h` during 10ms @@ -129,7 +141,7 @@ h #:# err = tryGet `shouldReturn` () Just _ -> error err _ -> return () -pattern Msg :: MsgBody -> ACommand 'Agent +pattern Msg :: MsgBody -> ACommand 'Agent 'AEConn pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO () diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index efb877de3..2dfc9753b 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -8,6 +8,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} module AgentTests.FunctionalAPITests @@ -19,6 +20,8 @@ module AgentTests.FunctionalAPITests runRight, runRight_, get, + get', + rfGet, (##>), (=##>), pattern Msg, @@ -38,6 +41,7 @@ import qualified Data.Map as M import Data.Maybe (isNothing) import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..), getSystemTime) +import Data.Type.Equality import SMPAgentClient import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) import Simplex.Messaging.Agent @@ -58,21 +62,31 @@ import Simplex.Messaging.Version import Test.Hspec import UnliftIO -(##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> ATransmission 'Agent -> m () +type AEntityTransmission e = (ACorrId, ConnId, ACommand 'Agent e) + +(##>) :: (HasCallStack, MonadIO m) => m (AEntityTransmission e) -> AEntityTransmission e -> m () a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t) -(=##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m () +(=##>) :: (HasCallStack, MonadIO m) => m (AEntityTransmission e) -> (AEntityTransmission e -> Bool) -> m () a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p) -get :: MonadIO m => AgentClient -> m (ATransmission 'Agent) -get c = do - t@(_, _, cmd) <- atomically (readTBQueue $ subQ c) - case cmd of - CONNECT {} -> get c - DISCONNECT {} -> get c - _ -> pure t +get :: MonadIO m => AgentClient -> m (AEntityTransmission 'AEConn) +get = get' @'AEConn -pattern Msg :: MsgBody -> ACommand 'Agent +rfGet :: MonadIO m => AgentClient -> m (AEntityTransmission 'AERcvFile) +rfGet = get' @'AERcvFile + +get' :: forall e m. (MonadIO m, AEntityI e) => AgentClient -> m (AEntityTransmission e) +get' c = do + (corrId, connId, APC e cmd) <- atomically (readTBQueue $ subQ c) + case cmd of + CONNECT {} -> get' c + DISCONNECT {} -> get' c + _ -> case testEquality e (sAEntity @e) of + Just Refl -> pure (corrId, connId, cmd) + _ -> error $ "unexpected command " <> show cmd + +pattern Msg :: MsgBody -> ACommand 'Agent 'AEConn pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody smpCfgV1 :: ProtocolClientConfig @@ -524,7 +538,7 @@ testSuspendingAgent = do get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False ackMessage b aId 4 suspendAgent b 1000000 - get b ##> ("", "", SUSPENDED) + get' b ##> ("", "", SUSPENDED) 5 <- sendMessage a bId SMP.noMsgFlags "hello 2" get a ##> ("", bId, SENT 5) Nothing <- 100000 `timeout` get b @@ -551,11 +565,11 @@ testSuspendingAgentCompleteSending t = do liftIO $ threadDelay 100000 suspendAgent b 5000000 - withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ @AgentErrorType $ do get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False - ("", "", SUSPENDED) <- get b + ("", "", SUSPENDED) <- get' @'AENone b get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False @@ -581,7 +595,7 @@ testSuspendingAgentTimeout t = do 5 <- sendMessage b aId SMP.noMsgFlags "hello too" 6 <- sendMessage b aId SMP.noMsgFlags "how are you?" suspendAgent b 100000 - ("", "", SUSPENDED) <- get b + ("", "", SUSPENDED) <- get' @'AENone b pure () testBatchedSubscriptions :: ATransport -> IO () @@ -777,7 +791,7 @@ testUsers = do deleteUser a auId True get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId'; _ -> False get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False - get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False + get' @'AENone a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False exchangeGreetingsMsgId 6 a bId b aId liftIO $ noMessages a "nothing else should be delivered to alice" @@ -813,7 +827,7 @@ testUsersNoServer t = do deleteUser a auId True get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False - get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False + get' @'AENone a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False liftIO $ noMessages a "nothing else should be delivered to alice" withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do get a =##> \case ("", "", UP _ [c]) -> c == bId; _ -> False diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 2687594d7..cc80cb098 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -66,8 +66,8 @@ smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h get h = do t@(_, _, cmdStr) <- tGetRaw h case parseAll networkCommandP cmdStr of - Right (ACmd SAgent CONNECT {}) -> get h - Right (ACmd SAgent DISCONNECT {}) -> get h + Right (ACmd SAgent _ CONNECT {}) -> get h + Right (ACmd SAgent _ DISCONNECT {}) -> get h _ -> pure t runSmpAgentTest :: forall c a. Transport c => (c -> IO a) -> IO a diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 59f9fcf8e..57b1e2777 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -5,7 +5,7 @@ module XFTPAgent where -import AgentTests.FunctionalAPITests (get, runRight_) +import AgentTests.FunctionalAPITests (get, rfGet, runRight_) import Control.Monad.Except import Data.Bifunctor (first) import qualified Data.ByteString as LB @@ -13,7 +13,7 @@ import SMPAgentClient (agentCfg, initAgentServers) import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), checkParty) import Simplex.Messaging.Agent (disconnectAgentClient, getSMPAgentClient, xftpReceiveFile) -import Simplex.Messaging.Agent.Protocol (ACommand (FRCVD), AgentErrorType (..)) +import Simplex.Messaging.Agent.Protocol (ACommand (..), AgentErrorType (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) import System.Directory (getFileSize) import System.FilePath (()) @@ -48,7 +48,7 @@ testXFTPAgentReceive = withXFTPServer $ do runRight_ $ do fd :: ValidFileDescription 'FPRecipient <- getFileDescription fdRcv fId <- xftpReceiveFile rcp 1 fd recipientFiles - ("", "", FRCVD fId' path) <- get rcp + ("", "", RFDONE fId' path) <- rfGet rcp liftIO $ do fId' `shouldBe` fId LB.readFile path `shouldReturn` file @@ -89,7 +89,7 @@ testXFTPAgentReceiveRestore = do rcp' <- getSMPAgentClient agentCfg initAgentServers withXFTPServerStoreLogOn $ \_ -> do -- receive file using agent - should succeed with server up - ("", "", FRCVD fId' path) <- get rcp' + ("", "", RFDONE fId' path) <- rfGet rcp' liftIO $ do fId' `shouldBe` 1 file <- LB.readFile filePath