From f4ad3a983e06a694917decb4283dfee0f3d3dae7 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 11 Jan 2023 13:47:20 +0000 Subject: [PATCH] support users in agent to isolate traffic of different users (#598) * users table, isolate traffic sessions by users or by queues * remove extra indices * corrections Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> --- apps/smp-agent/Main.hs | 3 +- cabal.project | 6 + simplexmq.cabal | 1 + src/Simplex/Messaging/Agent.hs | 182 +++++++------ src/Simplex/Messaging/Agent/Client.hs | 250 +++++++++++------- src/Simplex/Messaging/Agent/Env/SQLite.hs | 4 +- src/Simplex/Messaging/Agent/Store.hs | 33 ++- src/Simplex/Messaging/Agent/Store/SQLite.hs | 54 ++-- .../Agent/Store/SQLite/Migrations.hs | 4 +- .../SQLite/Migrations/M20230110_users.hs | 27 ++ .../Store/SQLite/Migrations/agent_schema.sql | 6 +- src/Simplex/Messaging/Client.hs | 20 +- src/Simplex/Messaging/Client/Agent.hs | 2 +- src/Simplex/Messaging/Protocol.hs | 1 + src/Simplex/Messaging/Transport/Client.hs | 16 +- .../Messaging/Transport/HTTP2/Client.hs | 2 +- tests/AgentTests/FunctionalAPITests.hs | 56 ++-- tests/AgentTests/NotificationTests.hs | 16 +- tests/AgentTests/SQLiteTests.hs | 14 +- tests/NtfClient.hs | 2 +- tests/SMPAgentClient.hs | 16 +- tests/SMPClient.hs | 2 +- 22 files changed, 459 insertions(+), 258 deletions(-) create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20230110_users.hs diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index da43bdec7..0c51b5619 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -7,6 +7,7 @@ module Main where import Control.Logger.Simple import qualified Data.List.NonEmpty as L +import qualified Data.Map.Strict as M import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Server (runSMPAgent) import Simplex.Messaging.Client (defaultNetworkConfig) @@ -18,7 +19,7 @@ cfg = defaultAgentConfig servers :: InitialAgentServers servers = InitialAgentServers - { smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"], + { smp = M.fromList [(1, L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"])], ntf = [], netCfg = defaultNetworkConfig } diff --git a/cabal.project b/cabal.project index 96cf5badd..020cc61f7 100644 --- a/cabal.project +++ b/cabal.project @@ -1,11 +1,17 @@ packages: . -- packages: . ../direct-sqlcipher ../sqlcipher-simple +-- packages: . ../hs-socks source-repository-package type: git location: https://github.com/simplex-chat/aeson.git tag: 3eb66f9a68f103b5f1489382aad89f5712a64db7 +source-repository-package + type: git + location: https://github.com/simplex-chat/hs-socks.git + tag: a30cc7a79a08d8108316094f8f2f82a0c5e1ac51 + source-repository-package type: git location: https://github.com/simplex-chat/direct-sqlcipher.git diff --git a/simplexmq.cabal b/simplexmq.cabal index a2c25d487..b5a6e4549 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -55,6 +55,7 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 444a9a63b..042c25ab9 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -38,6 +38,8 @@ module Simplex.Messaging.Agent disconnectAgentClient, resumeAgentClient, withConnLock, + createUser, + deleteUser, createConnectionAsync, joinConnectionAsync, allowConnectionAsync, @@ -158,13 +160,19 @@ resumeAgentClient c = atomically $ writeTVar (active c) True -- | type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m) +createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId +createUser c = withAgentEnv c . createUser' c + +deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> m () +deleteUser c = withAgentEnv c . deleteUser' c + -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id -createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId -createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode +createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId +createConnectionAsync c userId corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c userId corrId enableNtfs cMode -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs +joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId +joinConnectionAsync c userId corrId enableNtfs = withAgentEnv c .: joinConnAsync c userId corrId enableNtfs -- | Allow connection to continue after CONF notification (LET command), no synchronous response allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () @@ -187,12 +195,12 @@ deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId - deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c -- | Create SMP agent connection (NEW command) -createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c) -createConnection c enableNtfs cMode clientData = withAgentEnv c $ newConn c "" False enableNtfs cMode clientData +createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c) +createConnection c userId enableNtfs cMode clientData = withAgentEnv c $ newConn c userId "" False enableNtfs cMode clientData -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId +joinConnection c userId enableNtfs = withAgentEnv c .: joinConn c userId "" False enableNtfs -- | Allow connection to continue after CONF notification (LET command) allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () @@ -256,12 +264,12 @@ getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m By getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c -- | Change servers to be used for creating new queues -setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m () -setSMPServers c = withAgentEnv c . setSMPServers' c +setSMPServers :: AgentErrorMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m () +setSMPServers c = withAgentEnv c .: setSMPServers' c -- | Test SMP server -testSMPServerConnection :: AgentErrorMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure) -testSMPServerConnection c = withAgentEnv c . runSMPServerTest c +testSMPServerConnection :: AgentErrorMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure) +testSMPServerConnection c = withAgentEnv c .: runSMPServerTest c setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m () setNtfServers c = withAgentEnv c . setNtfServers' c @@ -349,8 +357,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do -- | execute any SMP agent command processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent) processCommand c (connId, cmd) = case cmd of - NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode Nothing - JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo + NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId False enableNtfs cMode Nothing + JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo RJCT invId -> rejectContact' c connId invId $> (connId, OK) @@ -361,29 +369,43 @@ processCommand c (connId, cmd) = case cmd of OFF -> suspendConnection' c connId $> (connId, OK) DEL -> deleteConnection' c connId $> (connId, OK) CHK -> (connId,) . STAT <$> getConnectionServers' c connId + where + -- command interface does not support different users + userId = 1 -newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId -newConnAsync c corrId enableNtfs cMode = do +createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId +createUser' c srvs = do + userId <- withStore' c createUserRecord + atomically $ TM.insert userId srvs $ smpServers c + pure userId + +deleteUser' :: AgentMonad m => AgentClient -> UserId -> m () +deleteUser' c userId = do + withStore c (`deleteUserRecord` userId) + atomically $ TM.delete userId $ smpServers c + +newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId +newConnAsync c userId corrId enableNtfs cMode = do g <- asks idsDrg connAgentVersion <- asks $ maxVersion . smpAgentVRange . config - let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent + let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent connId <- withStore c $ \db -> createNewConn db g cData cMode enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode) pure connId -joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do +joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId +joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do aVRange <- asks $ smpAgentVRange . config case crAgentVRange `compatibleVersion` aVRange of Just (Compatible connAgentVersion) -> do g <- asks idsDrg let duplexHS = connAgentVersion /= 1 - cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False} + cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo pure connId _ -> throwError $ AGENT A_VERSION -joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo = +joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo = throwError $ CMD PROHIBITED allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () @@ -397,9 +419,9 @@ acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> Invitat acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case - SomeConn _ ContactConnection {} -> do + SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c corrId enableNtfs connReq ownConnInfo `catchError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo `catchError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -444,14 +466,14 @@ switchConnectionAsync' c corrId connId = SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH _ -> throwError $ CMD PROHIBITED -newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c) -newConn c connId asyncMode enableNtfs cMode clientData = - getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode clientData +newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c) +newConn c userId connId asyncMode enableNtfs cMode clientData = + getSMPServer c userId >>= newConnSrv c userId connId asyncMode enableNtfs cMode clientData -newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) -newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do +newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - (q, qUri) <- newRcvQueue c "" srv smpClientVRange + (q, qUri) <- newRcvQueue c userId "" srv smpClientVRange connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange let rq = (q :: RcvQueue) {connId = connId'} addSubscription c rq @@ -471,19 +493,19 @@ newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do pure connId setUpConn False rq connAgentVersion = do g <- asks idsDrg - let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent withStore c $ \db -> createRcvConn db g cData rq cMode -joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId -joinConn c connId asyncMode enableNtfs cReq cInfo = do +joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId +joinConn c userId connId asyncMode enableNtfs cReq cInfo = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> - getNextSMPServer c [qServer q] - _ -> getSMPServer c - joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv + getNextSMPServer c userId [qServer q] + _ -> getSMPServer c userId + joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv -joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId -joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do +joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId +joinConnSrv c userId connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config case ( qUri `compatibleVersion` smpClientVRange, e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange, @@ -493,9 +515,9 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge (pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams (_, rcDHRs) <- liftIO C.generateKeyPair' let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams - q <- newSndQueue "" qInfo + q <- newSndQueue userId "" qInfo let duplexHS = connAgentVersion /= 1 - cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False} + cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False} connId' <- setUpConn asyncMode cData q rc let sq = (q :: SndQueue) {connId = connId'} cData' = (cData :: ConnData) {connId = connId'} @@ -520,23 +542,23 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge liftIO $ createRatchet db connId' rc pure connId' _ -> throwError $ AGENT A_VERSION -joinConnSrv c connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do +joinConnSrv c userId connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config case ( qUri `compatibleVersion` clientVRange, crAgentVRange `compatibleVersion` aVRange ) of (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation Nothing srv - sendInvitation c qInfo vrsn cReq cInfo + (connId', cReq) <- newConnSrv c userId connId False enableNtfs SCMInvitation Nothing srv + sendInvitation c userId qInfo vrsn cReq cInfo pure connId' _ -> throwError $ AGENT A_VERSION -joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do +joinConnSrv _c _userId _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do throwError $ CMD PROHIBITED createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> m SMPQueueInfo -createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do - (rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion +createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} srv = do + (rq, qUri) <- newRcvQueue c userId connId srv $ versionToRange smpClientVersion let qInfo = toVersionT qUri smpClientVersion addSubscription c rq void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq @@ -564,9 +586,9 @@ acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId acceptContact' c connId enableNtfs invId ownConnInfo = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case - SomeConn _ ContactConnection {} -> do + SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do + joinConn c userId connId False enableNtfs connReq ownConnInfo `catchError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -799,13 +821,15 @@ runCommandProcessing c@AgentClient {subQ} server_ = do NEW enableNtfs (ACM cMode) -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) tryCommand . withNextSrv usedSrvs [] $ \srv -> do - (_, cReq) <- newConnSrv c connId True enableNtfs cMode Nothing srv + let userId = 1 -- TODO + (_, cReq) <- newConnSrv c userId connId True enableNtfs cMode Nothing srv notify $ INV (ACR cMode cReq) JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do - void $ joinConnSrv c connId True enableNtfs cReq connInfo srv + let userId = 1 -- TODO + void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv notify OK LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK @@ -901,10 +925,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m () withNextSrv usedSrvs initUsed action = do used <- readTVarIO usedSrvs - srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c used + let userId = 1 -- TODO + srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used atomically $ do - srvs <- readTVar $ smpServers c - let used' = if length used + 1 >= L.length srvs then initUsed else srv : used + srvs_ <- TM.lookup userId $ smpServers c + -- TODO this condition does not account for servers change, it has to be changed to see if there are any remaining unused servers configured for the user + let used' = if length used + 1 >= maybe 0 L.length srvs_ then initUsed else srv : used writeTVar usedSrvs used' action srvAuth -- ^ ^ ^ async command processing / @@ -978,7 +1004,7 @@ getPendingMsgQ c SndQueue {server, sndId} = do pure q runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () -runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do +runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do (mq, qLock) <- atomically $ getPendingMsgQ c sq ri <- asks $ messageRetryInterval . config forever $ do @@ -1067,7 +1093,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh -- and this branch should never be reached as receive is created before the confirmation, -- so the condition is not necessary here, strictly speaking. _ -> unless (duplexHandshake == Just True) $ do - srv <- getSMPServer c + srv <- getSMPServer c userId qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId @@ -1142,12 +1168,12 @@ switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats switchConnection' c connId = withConnLock c connId "switchConnection" $ do SomeConn _ conn <- withStore c (`getConn` connId) case conn of - DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do + DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do clientVRange <- asks $ smpClientVRange . config -- try to get the server that is different from all queues, or at least from the primary rcv queue - srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) - srv' <- if srv == server then getNextSMPServer c [server] else pure srvAuth - (q, qUri) <- newRcvQueue c connId srv' clientVRange + srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) + srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth + (q, qUri) <- newRcvQueue c userId connId srv' clientVRange let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} void . withStore c $ \db -> addConnRcvQueue db connId rq' addSubscription c rq' @@ -1217,8 +1243,8 @@ connectionStats = \case NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []} -- | Change servers to be used for creating new queues, in Reader monad -setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m () -setSMPServers' c = atomically . writeTVar (smpServers c) +setSMPServers' :: AgentMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m () +setSMPServers' c userId srvs = atomically $ TM.insert userId srvs $ smpServers c registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus registerNtfToken' c suppliedDeviceToken suppliedNtfMode = @@ -1457,8 +1483,8 @@ debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs} = do where getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) -getSMPServer :: AgentMonad m => AgentClient -> m SMPServerWithAuth -getSMPServer c = readTVarIO (smpServers c) >>= pickServer +getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth +getSMPServer c userId = withUserServers c userId pickServer pickServer :: AgentMonad m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth pickServer = \case @@ -1467,13 +1493,18 @@ pickServer = \case gen <- asks randomServer atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1)) -getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServerWithAuth -getNextSMPServer c usedSrvs = do - srvs <- readTVarIO $ smpServers c +getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth +getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs -> case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of Just srvs' -> pickServer srvs' _ -> pickServer srvs +withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a +withUserServers c userId action = + atomically (TM.lookup userId $ smpServers c) >>= \case + Just srvs -> action srvs + _ -> throwError $ INTERNAL "unknown userId - no SMP servers" + subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ @@ -1488,7 +1519,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm processSMP rq conn $ connData conn where processSMP :: RcvQueue -> Connection c -> ConnData -> m () - processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $ + processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {userId, connId, duplexHandshake} = withConnLock c connId "processSMP" $ case cmd of SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ @@ -1586,8 +1617,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm ackDel = enqueueCmd . ICAckDel rId srvMsgId handleNotifyAck :: m () -> m () handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack - SMP.END -> - atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND) + SMP.END -> do + -- TODO is race condition possible here on session mode change? + tSess <- mkSMPTransportSession c rq + atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND) >>= logServer "<--" c srv rId where processEND = \case @@ -1724,7 +1757,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of (Just _, _) -> qError "QADD: queue address is already used in connection" (_, Just _replaced@SndQueue {dbQueueId}) -> do - sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo + sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo let sq' = (sq_ :: SndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} void . withStore c $ \db -> addConnSndQueue db connId sq' case (sndPublicKey, e2ePubKey) of @@ -1796,12 +1829,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm | otherwise = MsgError MsgDuplicate -- this case is not possible connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m () -connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do +connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of Nothing -> throwError $ AGENT A_VERSION Just qInfo' -> do - sq <- newSndQueue connId qInfo' + sq <- newSndQueue userId connId qInfo' dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing @@ -1860,14 +1893,15 @@ agentRatchetDecrypt db connId encAgentMsg = do liftIO $ updateRatchet db connId rc' skippedDiff liftEither $ first (SEAgentError . cryptoError) agentMsgBody_ -newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue -newSndQueue connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do +newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue +newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do C.SignAlg a <- asks $ cmdSignAlg . config (sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair' pure SndQueue - { connId, + { userId, + connId, server = smpServer, sndId = senderId, sndPublicKey = Just sndPublicKey, diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9a3642767..54c3e11ab 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -48,6 +48,9 @@ module Simplex.Messaging.Agent.Client agentNtfCreateSubscription, agentNtfCheckSubscription, agentNtfDeleteSubscription, + -- TODO possibly, this should not be exported? + -- this is used in END processing + mkSMPTransportSession, agentCbEncrypt, agentCbDecrypt, cryptoError, @@ -131,6 +134,7 @@ import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse) import Simplex.Messaging.Protocol ( AProtocolType (..), BrokerMsg, + EntityId, ErrorType, MsgFlags (..), MsgId, @@ -166,15 +170,22 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient) type NtfClientVar = TMVar (Either AgentErrorType NtfClient) +-- | Transport session key - includes entity ID if `sessionMode = TSMEntity`. +type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) + +type SMPTransportSession = TransportSession SMP.BrokerMsg + +type NtfTransportSession = TransportSession NtfResponse + data AgentClient = AgentClient { active :: TVar Bool, rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue (ServerTransmission BrokerMsg), - smpServers :: TVar (NonEmpty SMPServerWithAuth), - smpClients :: TMap SMPServer SMPClientVar, + smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), + smpClients :: TMap SMPTransportSession SMPClientVar, ntfServers :: TVar [NtfServer], - ntfClients :: TMap NtfServer NtfClientVar, + ntfClients :: TMap NtfTransportSession NtfClientVar, useNetworkConfig :: TVar NetworkConfig, subscrConns :: TVar (Set ConnId), activeSubs :: TRcvQueues, @@ -195,7 +206,7 @@ data AgentClient = AgentClient -- locks to prevent concurrent operations with connection connLocks :: TMap ConnId Lock, -- locks to prevent concurrent reconnections to SMP servers - reconnectLocks :: TMap SMPServer Lock, + reconnectLocks :: TMap SMPTransportSession Lock, reconnections :: TVar [Async ()], asyncClients :: TVar [Async ()], agentStats :: TMap AgentStatsKey (TVar Int), @@ -227,7 +238,13 @@ data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map St instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions -data AgentStatsKey = AgentStatsKey {host :: ByteString, clientTs :: ByteString, cmd :: ByteString, res :: ByteString} +data AgentStatsKey = AgentStatsKey + { userId :: UserId, + host :: ByteString, + clientTs :: ByteString, + cmd :: ByteString, + res :: ByteString + } deriving (Eq, Ord, Show) newAgentClient :: InitialAgentServers -> Env -> STM AgentClient @@ -270,7 +287,7 @@ agentClientStore :: AgentClient -> SQLiteStore agentClientStore AgentClient {agentEnv = Env {store}} = store class ProtocolServerClient msg where - getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg) + getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg) clientProtocolError :: ErrorType -> AgentErrorType instance ProtocolServerClient BrokerMsg where @@ -281,19 +298,20 @@ instance ProtocolServerClient NtfResponse where getProtocolServerClient = getNtfServerClient clientProtocolError = NTF -getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient -getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do +getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient +getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, entityId_) = do unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" - atomically (getClientVar srv smpClients) + atomically (getClientVar tSess smpClients) >>= either - (newProtocolClient c srv smpClients connectClient reconnectClient) - (waitForProtocolClient c srv) + (newProtocolClient c tSess smpClients connectClient reconnectClient) + (waitForProtocolClient c tSess) where connectClient :: m SMPClient connectClient = do cfg <- getClientConfig c smpCfg u <- askUnliftIO - liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) + let proxyUsername = Just $ bshow userId <> maybe "" (":" <>) entityId_ + liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg proxyUsername (Just msgQ) $ clientDisconnected u) clientDisconnected :: UnliftIO m -> SMPClient -> IO () clientDisconnected u client = do @@ -302,14 +320,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do where removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ do - TM.delete srv smpClients + TM.delete tSess smpClients (qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c mapM_ (`RQ.addQueue` pendingSubs c) qs pure (qs, S.toList conns) serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do - incClientStat c client "DISCONNECT" "" + incClientStat c userId client "DISCONNECT" "" notifySub "" $ hostEvent DISCONNECT client unless (null conns) $ notifySub "" $ DOWN srv conns unless (null qs) $ do @@ -329,18 +347,20 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do reconnectClient :: m () reconnectClient = - withLockMap_ (reconnectLocks c) srv "reconnect" $ + withLockMap_ (reconnectLocks c) tSess "reconnect" $ atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe where resubscribe :: [RcvQueue] -> m () resubscribe qs = do - connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar) + connected <- maybe False isRight <$> atomically (TM.lookup tSess smpClients $>>= tryReadTMVar) cs <- atomically . RQ.getConns $ activeSubs c + -- TODO probably tSess should be passed here or nothing + -- maybe this should be another function, not subscribeQueues (client_, rs) <- subscribeQueues c srv qs let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do unless connected . forM_ client_ $ \cl -> do - incClientStat c cl "CONNECT" "" + incClientStat c userId cl "CONNECT" "" notifySub "" $ hostEvent CONNECT cl let conns = S.toList $ S.fromList okConns `S.difference` cs unless (null conns) $ notifySub "" $ UP srv conns @@ -351,37 +371,37 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do notifySub :: ConnId -> ACommand 'Agent -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) -getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient -getNtfServerClient c@AgentClient {active, ntfClients} srv = do +getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient +getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" - atomically (getClientVar srv ntfClients) + atomically (getClientVar tSess ntfClients) >>= either - (newProtocolClient c srv ntfClients connectClient $ pure ()) - (waitForProtocolClient c srv) + (newProtocolClient c tSess ntfClients connectClient $ pure ()) + (waitForProtocolClient c tSess) where connectClient :: m NtfClient connectClient = do cfg <- getClientConfig c ntfCfg - liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing clientDisconnected) + liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing Nothing clientDisconnected) clientDisconnected :: NtfClient -> IO () clientDisconnected client = do - atomically $ TM.delete srv ntfClients - incClientStat c client "DISCONNECT" "" + atomically $ TM.delete tSess ntfClients + incClientStat c userId client "DISCONNECT" "" atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv -getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) -getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients +getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) +getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup tSess clients where newClientVar :: STM (TMVar a) newClientVar = do var <- newEmptyTMVar - TM.insert srv var clients + TM.insert tSess var clients pure var -waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ClientVar msg -> m (ProtocolClient msg) -waitForProtocolClient c srv clientVar = do +waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg) +waitForProtocolClient c (_, srv, _) clientVar = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar) liftEither $ case client_ of @@ -393,30 +413,30 @@ newProtocolClient :: forall msg m. (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> - ProtoServer msg -> - TMap (ProtoServer msg) (ClientVar msg) -> + TransportSession msg -> + TMap (TransportSession msg) (ClientVar msg) -> m (ProtocolClient msg) -> m () -> ClientVar msg -> m (ProtocolClient msg) -newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync +newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync where tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a tryConnectClient successAction retryAction = tryError connectClient >>= \r -> case r of Right client -> do - logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" atomically $ putTMVar clientVar r - liftIO $ incClientStat c client "CLIENT" "OK" + liftIO $ incClientStat c userId client "CLIENT" "OK" atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client) successAction client Left e -> do - liftIO $ incServerStat c srv "CLIENT" $ strEncode e + liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e if temporaryAgentError e then retryAction else atomically $ do putTMVar clientVar (Left e) - TM.delete srv clients + TM.delete tSess clients throwError e tryConnectAsync :: m () tryConnectAsync = do @@ -471,7 +491,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} = where k = (server, sndId) -closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO () +closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty) where @@ -494,30 +514,43 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur where newLock = newEmptyTMVar >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> m a) -> m a -withClient_ c srv statCmd action = do - cl <- getProtocolServerClient c srv +withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a +withClient_ c tSess@(userId, srv, _) statCmd action = do + cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchError` logServerError cl where - stat cl = liftIO . incClientStat c cl statCmd + stat cl = liftIO . incClientStat c userId cl statCmd logServerError :: ProtocolClient msg -> AgentErrorType -> m a logServerError cl e = do logServer "<--" c srv "" $ strEncode e stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a -withLogClient_ c srv qId cmdStr action = do - logServer "-->" c srv qId cmdStr - res <- withClient_ c srv cmdStr action - logServer "<--" c srv qId "OK" +withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a +withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do + logServer "-->" c srv entId cmdStr + res <- withClient_ c tSess cmdStr action + logServer "<--" c srv entId "OK" return res -withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withClient c srv statKey action = withClient_ c srv statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client +withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client -withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client +withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client + +withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a +withSMPClient c q cmdStr action = do + tSess <- mkSMPTransportSession c q + withLogClient c tSess (queueId q) cmdStr action + +withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a +withSMPClient_ c q cmdStr action = do + tSess <- mkSMPTransportSession c q + withLogClient_ c tSess (queueId q) cmdStr action + +withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a +withNtfClient c srv = withLogClient c (0, srv, Nothing) liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a liftClient protocolError_ = liftError . protocolClientError protocolError_ @@ -551,12 +584,13 @@ instance ToJSON SMPTestFailure where toEncoding = J.genericToEncoding J.defaultOptions toJSON = J.genericToJSON J.defaultOptions -runSMPServerTest :: AgentMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure) -runSMPServerTest c (ProtoServerWithAuth srv auth) = do +runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure) +runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do cfg <- getClientConfig c smpCfg C.SignAlg a <- asks $ cmdSignAlg . config liftIO $ do - getProtocolClient srv cfg Nothing (\_ -> pure ()) >>= \case + let proxyUsername = bshow userId + getProtocolClient srv cfg (Just proxyUsername) Nothing (\_ -> pure ()) >>= \case Right smp -> do (rKey, rpKey) <- C.generateSignatureKeyPair a (sKey, _) <- C.generateSignatureKeyPair a @@ -566,7 +600,7 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do liftError (testErr TSSecureQueue) $ secureSMPQueue smp rpKey rcvId sKey liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp rpKey rcvId ok <- tcpTimeout (networkConfig cfg) `timeout` closeProtocolClient smp - incClientStat c smp "TEST" "OK" + incClientStat c userId smp "TEST" "OK" pure $ either Just (const Nothing) r <|> maybe (Just (SMPTestFailure TSDisconnect $ BROKER addr TIMEOUT)) (const Nothing) ok Left e -> pure (Just $ testErr TSConnect e) where @@ -574,19 +608,29 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure testErr step = SMPTestFailure step . protocolClientError SMP addr -newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri) -newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do +mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) +mkTransportSession c userId srv entityId = do + mode <- sessionMode <$> readTVarIO (useNetworkConfig c) + pure (userId, srv, if mode == TSMEntity then Just entityId else Nothing) + +mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession +mkSMPTransportSession c q = mkTransportSession c (qUserId q) (qServer q) (queueId q) + +newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri) +newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do C.SignAlg a <- asks (cmdSignAlg . config) (recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (dhKey, privDhKey) <- liftIO C.generateKeyPair' (e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair' logServer "-->" c srv "" "NEW" + tSess <- mkTransportSession c userId srv connId QIK {rcvId, sndId, rcvPublicDhKey} <- - withClient c srv "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth + withClient c tSess "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] let rq = RcvQueue - { connId, + { userId, + connId, server = srv, rcvId, rcvPrivateKey, @@ -609,7 +653,7 @@ subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do atomically $ do modifyTVar (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c - withLogClient c server rcvId "SUB" $ \smp -> + withSMPClient c rq "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq) >>= either throwError pure @@ -648,14 +692,19 @@ subscribeQueues c srv qs = do RQ.addQueue rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do - smp_ <- tryError (getSMPServerClient c srv) + mode <- sessionMode <$> readTVarIO (useNetworkConfig c) + -- TODO these subscriptions should happen in different sessions if mode is TSMEntity + -- it is also a question whether it is needed to group by server outside if grouping by sessions should happen here anyway + let userId = 1 + tSess = (userId, srv, Nothing) + smp_ <- tryError $ getSMPServerClient c tSess (eitherToMaybe smp_,) . (errs <>) <$> case smp_ of Left e -> pure $ map (,Left e) qs_ Right smp -> do logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB" let qs2 = L.map queueCreds qs' n = (length qs2 - 1) `div` 90 + 1 - liftIO $ incClientStatN c smp n "SUBS" "OK" + liftIO $ incClientStatN c userId smp n "SUBS" "OK" liftIO $ do rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 mapM_ (uncurry $ processSubResult c) rs @@ -699,16 +748,17 @@ logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () -sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = - withLogClient_ c server sndId "SEND " $ \smp -> do +sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = + withSMPClient_ c sq "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" -sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () -sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = - withLogClient_ c smpServer senderId "SEND " $ \smp -> do +sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do + tSess <- mkTransportSession c userId smpServer senderId + withLogClient_ c tSess senderId "SEND " $ \smp -> do msg <- mkInvitation liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg where @@ -722,7 +772,7 @@ sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderI getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta) getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do atomically createTakeGetLock - (v, msg_) <- withLogClient c server rcvId "GET" $ \smp -> + (v, msg_) <- withSMPClient c rq "GET" $ \smp -> (thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId mapM (decryptMeta v) msg_ where @@ -742,23 +792,23 @@ decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.Enc decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId) secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m () -secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey = - withLogClient c server rcvId "KEY " $ \smp -> +secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey = + withSMPClient c rq "KEY " $ \smp -> secureSMPQueue smp rcvPrivateKey rcvId senderKey enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey) -enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey = - withLogClient c server rcvId "NKEY " $ \smp -> +enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey = + withSMPClient c rq "NKEY " $ \smp -> enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m () -disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogClient c server rcvId "NDEL" $ \smp -> +disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} = + withSMPClient c rq "NDEL" $ \smp -> disableSMPQueueNotifications smp rcvPrivateKey rcvId sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m () -sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do - withLogClient c server rcvId "ACK" $ \smp -> +sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do + withSMPClient c rq "ACK" $ \smp -> ackSMPMessage smp rcvPrivateKey rcvId msgId atomically $ releaseGetLock c rq @@ -767,57 +817,57 @@ releaseGetLock c RcvQueue {server, rcvId} = TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ()) suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () -suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogClient c server rcvId "OFF" $ \smp -> +suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = + withSMPClient c rq "OFF" $ \smp -> suspendSMPQueue smp rcvPrivateKey rcvId deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () -deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogClient c server rcvId "DEL" $ \smp -> +deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = + withSMPClient c rq "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId -sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m () -sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg = - withLogClient_ c server sndId "SEND " $ \smp -> do +sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m () +sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg = + withSMPClient_ c sq "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519) agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey = - withClient c ntfServer "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey) + withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey) agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m () agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code = - withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code + withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} = - withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId + withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m () agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token = - withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token + withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m () agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} = - withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId + withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m () agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval = - withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval + withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey = - withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey) + withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey) agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} = - withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId + withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m () agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} = - withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId + withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do @@ -948,18 +998,18 @@ incStat AgentClient {agentStats} n k = do Just v -> modifyTVar v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats -incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO () -incClientStat c pc = incClientStatN c pc 1 +incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO () +incClientStat c userId pc = incClientStatN c userId pc 1 -incServerStat :: AgentClient -> ProtocolServer p -> ByteString -> ByteString -> IO () -incServerStat c ProtocolServer {host} cmd res = do +incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () +incServerStat c userId ProtocolServer {host} cmd res = do threadDelay 100000 atomically $ incStat c 1 statsKey where - statsKey = AgentStatsKey {host = strEncode $ L.head host, clientTs = "", cmd, res} + statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res} -incClientStatN :: AgentClient -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO () -incClientStatN c pc n cmd res = do +incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO () +incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where - statsKey = AgentStatsKey {host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res} + statsKey = AgentStatsKey {userId, host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res} diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 60afe9bcf..f20ed2f00 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -32,12 +32,14 @@ import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random import Data.List.NonEmpty (NonEmpty) +import Data.Map (Map) import Data.Time.Clock (NominalDiffTime, nominalDay) import Data.Word (Word16) import Network.Socket import Numeric.Natural import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval +import Simplex.Messaging.Agent.Store (UserId) import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client @@ -59,7 +61,7 @@ import UnliftIO.STM type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) data InitialAgentServers = InitialAgentServers - { smp :: NonEmpty SMPServerWithAuth, + { smp :: Map UserId (NonEmpty SMPServerWithAuth), ntf :: [NtfServer], netCfg :: NetworkConfig } diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 4b45c07bd..cff4ca125 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -34,6 +34,7 @@ import Simplex.Messaging.Protocol NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, + QueueId, RcvDhSecret, RcvNtfDhSecret, RcvPrivateSignKey, @@ -47,7 +48,8 @@ import Simplex.Messaging.Version -- | A receive queue. SMP queue through which the agent receives messages from a sender. data RcvQueue = RcvQueue - { connId :: ConnId, + { userId :: UserId, + connId :: ConnId, server :: SMPServer, -- | recipient queue ID rcvId :: SMP.RecipientId, @@ -89,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds -- | A send queue. SMP queue through which the agent sends messages to a recipient. data SndQueue = SndQueue - { connId :: ConnId, + { userId :: UserId, + connId :: ConnId, server :: SMPServer, -- | sender queue ID sndId :: SMP.SenderId, @@ -150,6 +153,27 @@ findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue findRQ sAddr = find $ sameQAddress sAddr . sndAddress {-# INLINE findRQ #-} +class SMPQueue q => SMPQueueRec q where + qUserId :: q -> UserId + qConnId :: q -> ConnId + queueId :: q -> QueueId + +instance SMPQueueRec RcvQueue where + qUserId = userId + {-# INLINE qUserId #-} + qConnId = connId + {-# INLINE qConnId #-} + queueId = rcvId + {-# INLINE queueId #-} + +instance SMPQueueRec SndQueue where + qUserId = userId + {-# INLINE qUserId #-} + qConnId = connId + {-# INLINE qConnId #-} + queueId = sndId + {-# INLINE queueId #-} + -- * Connection types -- | Type of a connection. @@ -222,6 +246,7 @@ deriving instance Show SomeConn data ConnData = ConnData { connId :: ConnId, + userId :: UserId, connAgentVersion :: Version, enableNtfs :: Bool, duplexHandshake :: Maybe Bool, -- added in agent protocol v2 @@ -231,6 +256,8 @@ data ConnData = ConnData data AgentCmdType = ACClient | ACInternal +type UserId = Int64 + instance StrEncoding AgentCmdType where strEncode = \case ACClient -> "CLIENT" @@ -471,6 +498,8 @@ data StoreError SEInternal ByteString | -- | Failed to generate unique random ID SEUniqueID + | -- | User ID not found + SEUserNotFound | -- | Connection not found (or both queues absent). SEConnNotFound | -- | Connection already used. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 135e1b43c..d3c2b7acf 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -27,6 +27,10 @@ module Simplex.Messaging.Agent.Store.SQLite sqlString, execSQL, + -- * Users + createUserRecord, + deleteUserRecord, + -- * Queues and connections createNewConn, updateNewConnRcv, @@ -297,6 +301,18 @@ withTransaction st action = withConnection st $ loop 500 2_000_000 loop (t * 9 `div` 8) (tLim - t) db else E.throwIO e +createUserRecord :: DB.Connection -> IO UserId +createUserRecord db = do + DB.execute_ db "INSERT INTO users () VALUES ()" + insertedRowId db + +deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ()) +deleteUserRecord db userId = runExceptT $ do + _ :: Only Int64 <- + ExceptT . firstRow id SEUserNotFound $ + DB.query db "SELECT user_id FROM users WHERE user_id = ?" (Only userId) + liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) + createConn_ :: TVar ChaChaDRG -> ConnData -> @@ -307,9 +323,9 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of ConnData {connId} -> create connId $> Right connId createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId) -createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode = +createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} cMode = createConn_ gVar cData $ \connId -> do - DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) + DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64) updateNewConnRcv db connId rq = @@ -332,17 +348,17 @@ updateNewConnSnd db connId sq = updateConn = Right <$> addConnSndQueue_ db connId sq createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId) -createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode = +createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode = createConn_ gVar cData $ \connId -> do upsertServer_ db server - DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) + DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) void $ insertRcvQueue_ db connId q createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId) -createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = +createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = createConn_ gVar cData $ \connId -> do upsertServer_ db server - DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) + DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) void $ insertSndQueue_ db connId q getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) @@ -1328,9 +1344,9 @@ getAnyConn dbConn connId deleted' = getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) getConnData dbConn connId' = - maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId') + maybeFirstRow cData $ DB.query dbConn "SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId') where - cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode) + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode) setConnDeleted :: DB.Connection -> ConnId -> IO () setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) @@ -1348,31 +1364,32 @@ getRcvQueuesByConnId_ db connId = rcvQueueQuery :: Query rcvQueueQuery = [sql| - SELECT s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, + SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status, q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret FROM rcv_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id |] toRcvQueue :: - (C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) + (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) :. (Int64, Bool, Maybe Int64, Maybe Version) :. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> RcvQueue -toRcvQueue ((keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = +toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = let server = SMPServer host port keyHash smpClientVersion = fromMaybe 1 smpClientVersion_ clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} _ -> Nothing - in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds} + in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds} getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) getRcvQueueById_ db connId dbRcvId = firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId) + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ?") (connId, dbRcvId) -- | returns all connection queues, the first queue is the primary one getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue)) @@ -1381,16 +1398,17 @@ getSndQueuesByConnId_ dbConn connId = <$> DB.query dbConn [sql| - SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version + SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version FROM snd_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id WHERE q.conn_id = ?; |] (Only connId) where - sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) = + sndQueue ((userId, keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) = let server = SMPServer host port keyHash - in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion} + in SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion} primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = -- the current primary queue is ordered first, the next primary - second compare (Down p) (Down p') <> compare i i' diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 43f99cb94..386a0eab1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -37,6 +37,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users import Simplex.Messaging.Encoding.String import Simplex.Messaging.Transport.Client (TransportHost) @@ -53,7 +54,8 @@ schemaMigrations = ("m20220811_onion_hosts", m20220811_onion_hosts), ("m20220817_connection_ntfs", m20220817_connection_ntfs), ("m20220905_commands", m20220905_commands), - ("m20220915_connection_queues", m20220915_connection_queues) + ("m20220915_connection_queues", m20220915_connection_queues), + ("m20230110_users", m20230110_users) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20230110_users.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20230110_users.hs new file mode 100644 index 000000000..a56930ad0 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20230110_users.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20230110_users :: Query +m20230110_users = + [sql| +PRAGMA ignore_check_constraints=ON; + +CREATE TABLE users ( + user_id INTEGER PRIMARY KEY AUTOINCREMENT +); + +INSERT INTO users (user_id) VALUES (1); + +ALTER TABLE connections ADD COLUMN user_id INTEGER CHECK (user_id NOT NULL) +REFERENCES users ON DELETE CASCADE; + +CREATE INDEX idx_connections_user ON connections(user_id); + +UPDATE connections SET user_id = 1; + +PRAGMA ignore_check_constraints=OFF; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index 9c39ffa56..548aaa63c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -22,7 +22,9 @@ CREATE TABLE connections( , duplex_handshake INTEGER NULL DEFAULT 0, enable_ntfs INTEGER, - deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL) + deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL), + user_id INTEGER CHECK(user_id NOT NULL) + REFERENCES users ON DELETE CASCADE ) WITHOUT ROWID; CREATE TABLE rcv_queues( host TEXT NOT NULL, @@ -228,3 +230,5 @@ CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries( conn_id, snd_queue_id ); +CREATE TABLE users(user_id INTEGER PRIMARY KEY AUTOINCREMENT); +CREATE INDEX idx_connections_user ON connections(user_id); diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 611702b69..dcf1e3e10 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -54,6 +54,7 @@ module Simplex.Messaging.Client ProtocolClientError (..), ProtocolClientConfig (..), NetworkConfig (..), + TransportSessionMode (..), defaultClientConfig, defaultNetworkConfig, transportClientConfig, @@ -152,6 +153,8 @@ data NetworkConfig = NetworkConfig hostMode :: HostMode, -- | if above criteria is not met, if the below setting is True return error, otherwise use the first host requiredHostMode :: Bool, + -- | transport sessions are created per user or per entity + sessionMode :: TransportSessionMode, -- | timeout for the initial client TCP/TLS connection (microseconds) tcpConnectTimeout :: Int, -- | timeout of protocol commands (microseconds) @@ -168,12 +171,23 @@ instance ToJSON NetworkConfig where toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True} toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True} +data TransportSessionMode = TSMUser | TSMEntity + deriving (Eq, Show, Generic) + +instance ToJSON TransportSessionMode where + toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM" + toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM" + +instance FromJSON TransportSessionMode where + parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM" + defaultNetworkConfig :: NetworkConfig defaultNetworkConfig = NetworkConfig { socksProxy = Nothing, hostMode = HMOnionViaSocks, requiredHostMode = False, + sessionMode = TSMUser, tcpConnectTimeout = 7_500_000, tcpTimeout = 5_000_000, tcpKeepAlive = Just defaultKeepAliveOpts, @@ -239,8 +253,8 @@ transportHost' = transportHost . client_ -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg)) -getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do +getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe ByteString -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg)) +getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} proxyUsername msgQ disconnected = do case chooseTransportHost networkConfig (host protocolServer) of Right useHost -> (atomically (mkProtocolClient useHost) >>= runClient useTransport useHost) @@ -274,7 +288,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, let tcConfig = transportClientConfig networkConfig action <- async $ - runTransportClient tcConfig useHost port' (Just $ keyHash protocolServer) (client t c cVar) + runTransportClient tcConfig proxyUsername useHost port' (Just $ keyHash protocolServer) (client t c cVar) `finally` atomically (putTMVar cVar $ Left PCENetworkError) c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar) pure $ case c_ of diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index d5f12942c..f2fe97666 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -160,7 +160,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = void $ tryConnectClient (const reconnectClient) loop connectClient :: ExceptT ProtocolClientError IO SMPClient - connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected + connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) Nothing (Just msgQ) clientDisconnected clientDisconnected :: SMPClient -> IO () clientDisconnected _ = do diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 0263e374e..a6897074d 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -77,6 +77,7 @@ module Simplex.Messaging.Protocol BasicAuth (..), SrvLoc (..), CorrId (..), + EntityId, QueueId, RecipientId, SenderId, diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index d6f95f20a..f7af991a4 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -107,15 +107,15 @@ defaultTransportClientConfig :: TransportClientConfig defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True -- | Connect to passed TCP host:port and pass handle to the client. -runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a +runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a runTransportClient = runTLSTransportClient supportedParameters Nothing -runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a -runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} host port keyHash client = do +runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a +runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do let hostName = B.unpack $ strEncode host clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash connectTCP = case socksProxy of - Just proxy -> connectSocksClient proxy $ hostAddr host + Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host _ -> connectTCPClient hostName c <- liftIO $ do sock <- connectTCP port @@ -153,10 +153,12 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err defaultSMPPort :: PortNumber defaultSMPPort = 5223 -connectSocksClient :: SocksProxy -> SocksHostAddress -> ServiceName -> IO Socket -connectSocksClient (SocksProxy addr) hostAddr _port = do +connectSocksClient :: SocksProxy -> Maybe ByteString -> SocksHostAddress -> ServiceName -> IO Socket +connectSocksClient (SocksProxy addr) proxyUsername hostAddr _port = do let port = if null _port then defaultSMPPort else fromMaybe defaultSMPPort $ readMaybe _port - fst <$> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port) + fst <$> case proxyUsername of + Just username -> socksConnectAuth (defaultSocksConf addr) (SocksAddress hostAddr port) (SocksCredentials username "") + _ -> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port) defaultSocksHost :: HostAddress defaultSocksHost = tupleToHostAddress (127, 0, 0, 1) diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index 085af91a6..d0efe2060 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -120,7 +120,7 @@ sendRequest HTTP2Client {reqQ, config} req = do runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> HostName -> ServiceName -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO () runHTTP2Client tlsParams caStore tcConfig host port client = - runTLSTransportClient tlsParams caStore tcConfig (THDomainName host) port Nothing $ \c -> + runTLSTransportClient tlsParams caStore tcConfig Nothing (THDomainName host) port Nothing $ \c -> withTlsConfig c 16384 (`run` client) where run = H.run $ ClientConfig "https" (B.pack host) 20 diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 7f69070c6..d70d5705d 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -218,8 +218,8 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientTest alice bob baseId = do Right () <- runExceptT $ do - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -254,8 +254,8 @@ runAgentClientTest alice bob baseId = do runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientContactTest alice bob baseId = do Right () <- runExceptT $ do - (_, qInfo) <- createConnection alice True SCMContact Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (_, qInfo) <- createConnection alice 1 True SCMContact Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, REQ invId _ "bob's connInfo") <- get alice bobId <- acceptContact alice True invId "alice's connInfo" ("", _, CONF confId _ "alice's connInfo") <- get bob @@ -302,9 +302,9 @@ testAsyncInitiatingOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do - (bobId, cReq) <- createConnection alice True SCMInvitation Nothing + (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing disconnectAgentClient alice - aliceId <- joinConnection bob True cReq "bob's connInfo" + aliceId <- joinConnection bob 1 True cReq "bob's connInfo" alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers subscribeConnection alice' bobId ("", _, CONF confId _ "bob's connInfo") <- get alice' @@ -320,8 +320,8 @@ testAsyncJoiningOfflineBeforeActivation = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" disconnectAgentClient bob ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" @@ -338,9 +338,9 @@ testAsyncBothOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do - (bobId, cReq) <- createConnection alice True SCMInvitation Nothing + (bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing disconnectAgentClient alice - aliceId <- joinConnection bob True cReq "bob's connInfo" + aliceId <- joinConnection bob 1 True cReq "bob's connInfo" disconnectAgentClient bob alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers subscribeConnection alice' bobId @@ -360,9 +360,9 @@ testAsyncServerOffline t = do bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers -- create connection and shutdown the server Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ -> - runExceptT $ createConnection alice True SCMInvitation Nothing + runExceptT $ createConnection alice 1 True SCMInvitation Nothing -- connection fails - Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo" + Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo" ("", "", DOWN srv conns) <- get alice srv `shouldBe` testSMPServer conns `shouldBe` [bobId] @@ -372,7 +372,7 @@ testAsyncServerOffline t = do liftIO $ do srv1 `shouldBe` testSMPServer conns1 `shouldBe` [bobId] - aliceId <- joinConnection bob True cReq "bob's connInfo" + aliceId <- joinConnection bob 1 True cReq "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -387,9 +387,9 @@ testAsyncHelloTimeout = do alice <- getSMPAgentClient agentCfgV1 initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers Right () <- runExceptT $ do - (_, cReq) <- createConnection alice True SCMInvitation Nothing + (_, cReq) <- createConnection alice 1 True SCMInvitation Nothing disconnectAgentClient alice - aliceId <- joinConnection bob True cReq "bob's connInfo" + aliceId <- joinConnection bob 1 True cReq "bob's connInfo" get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED) pure () @@ -444,8 +444,8 @@ testDuplicateMessage t = do makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnection alice bob = do - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -459,7 +459,7 @@ testInactiveClientDisconnected t = do withSmpServerConfigOn t cfg' testPort $ \_ -> do alice <- getSMPAgentClient agentCfg initAgentServers Right () <- runExceptT $ do - (connId, _cReq) <- createConnection alice True SCMInvitation Nothing + (connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing get alice ##> ("", "", DOWN testSMPServer [connId]) pure () @@ -470,7 +470,7 @@ testActiveClientNotDisconnected t = do alice <- getSMPAgentClient agentCfg initAgentServers ts <- getSystemTime Right () <- runExceptT $ do - (connId, _cReq) <- createConnection alice True SCMInvitation Nothing + (connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing keepSubscribing alice connId ts pure () where @@ -619,10 +619,10 @@ testAsyncCommands = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do - bobId <- createConnectionAsync alice "1" True SCMInvitation + bobId <- createConnectionAsync alice 1 "1" True SCMInvitation ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId - aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo" + aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" ("2", aliceId', OK) <- get bob liftIO $ aliceId' `shouldBe` aliceId ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -663,7 +663,7 @@ testAsyncCommands = do testAsyncCommandsRestore :: ATransport -> IO () testAsyncCommandsRestore t = do alice <- getSMPAgentClient agentCfg initAgentServers - Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation + Right bobId <- runExceptT $ createConnectionAsync alice 1 "1" True SCMInvitation liftIO $ noMessages alice "alice doesn't receive INV because server is down" disconnectAgentClient alice alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers @@ -679,8 +679,8 @@ testAcceptContactAsync = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do - (_, qInfo) <- createConnection alice True SCMContact Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (_, qInfo) <- createConnection alice 1 True SCMContact Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, REQ invId _ "bob's connInfo") <- get alice bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" ("1", bobId', OK) <- get alice @@ -808,11 +808,11 @@ testCreateQueueAuth clnt1 clnt2 = do a <- getClient clnt1 b <- getClient clnt2 Right created <- runExceptT $ do - tryError (createConnection a True SCMInvitation Nothing) >>= \case + tryError (createConnection a 1 True SCMInvitation Nothing) >>= \case Left (SMP AUTH) -> pure 0 Left e -> throwError e Right (bId, qInfo) -> - tryError (joinConnection b True qInfo "bob's connInfo") >>= \case + tryError (joinConnection b 1 True qInfo "bob's connInfo") >>= \case Left (SMP AUTH) -> pure 1 Left e -> throwError e Right aId -> do @@ -826,7 +826,7 @@ testCreateQueueAuth clnt1 clnt2 = do pure created where getClient (clntAuth, clntVersion) = - let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]} + let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]} smpCfg = (defaultClientConfig :: ProtocolClientConfig) {smpServerVRange = mkVersionRange 4 clntVersion} in getSMPAgentClient agentCfg {smpCfg} servers @@ -834,7 +834,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut testSMPServerConnectionTest t newQueueBasicAuth srv = withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running - Right r <- runExceptT $ testSMPServerConnection a srv + Right r <- runExceptT $ testSMPServerConnection a 1 srv pure r testRatchetAdHash :: IO () diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index d340a3522..62a8363ef 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -213,8 +213,8 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right (bobId, aliceId, nonce, message) <- runExceptT $ do -- establish connection - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -275,9 +275,9 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do DeviceToken {} <- registerTestToken bob "bcde" NMInstant apnsQ -- establish connection liftIO $ threadDelay 50000 - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing liftIO $ threadDelay 1000000 - aliceId <- joinConnection bob True qInfo "bob's connInfo" + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" liftIO $ threadDelay 750000 void $ messageNotification apnsQ ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -327,8 +327,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right () <- runExceptT $ do -- establish connection - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") @@ -392,8 +392,8 @@ testChangeToken APNSMockServer {apnsQ} = do bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers Right (aliceId, bobId) <- runExceptT $ do -- establish connection - (bobId, qInfo) <- createConnection alice True SCMInvitation Nothing - aliceId <- joinConnection bob True qInfo "bob's connInfo" + (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing + aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get bob ##> ("", aliceId, INFO "alice's connInfo") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 072021461..1636571b6 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -140,7 +140,7 @@ testForeignKeysEnabled = `shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint) cData1 :: ConnData -cData1 = ConnData {connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False} +cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False} testPrivateSignKey :: C.APrivateSignKey testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" @@ -154,7 +154,8 @@ testDhSecret = "01234567890123456789012345678901" rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue - { connId = "conn1", + { userId = 1, + connId = "conn1", server = SMPServer "smp.simplex.im" "5223" testKeyHash, rcvId = "1234", rcvPrivateKey = testPrivateSignKey, @@ -173,7 +174,8 @@ rcvQueue1 = sndQueue1 :: SndQueue sndQueue1 = SndQueue - { connId = "conn1", + { userId = 1, + connId = "conn1", server = SMPServer "smp.simplex.im" "5223" testKeyHash, sndId = "3456", sndPublicKey = Nothing, @@ -314,7 +316,8 @@ testUpgradeRcvConnToDuplex = _ <- createSndConn db g cData1 sndQueue1 let anotherSndQueue = SndQueue - { connId = "conn1", + { userId = 1, + connId = "conn1", server = SMPServer "smp.simplex.im" "5223" testKeyHash, sndId = "2345", sndPublicKey = Nothing, @@ -340,7 +343,8 @@ testUpgradeSndConnToDuplex = _ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation let anotherRcvQueue = RcvQueue - { connId = "conn1", + { userId = 1, + connId = "conn1", server = SMPServer "smp.simplex.im" "5223" testKeyHash, rcvId = "3456", rcvPrivateKey = testPrivateSignKey, diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index c4d614531..7c6c9fb2f 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -69,7 +69,7 @@ ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost - runTransportClient defaultTransportClientConfig host ntfTestPort (Just testKeyHash) $ \h -> + runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case Right th -> client th Left e -> error $ show e diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 16efc39f0..39f8c86d8 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -11,7 +12,8 @@ import Control.Monad.IO.Unlift import Crypto.Random import qualified Data.ByteString.Char8 as B import Data.List.NonEmpty (NonEmpty) -import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M import Network.Socket (ServiceName) import NtfClient (ntfTestPort) import SMPClient @@ -27,6 +29,7 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Server (runSMPAgentBlocking) +import Simplex.Messaging.Agent.Store (UserId) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig) import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Transport @@ -173,13 +176,13 @@ testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5 initAgentServers :: InitialAgentServers initAgentServers = InitialAgentServers - { smp = L.fromList [noAuthSrv testSMPServer], + { smp = userServers [noAuthSrv testSMPServer], ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"], netCfg = defaultNetworkConfig {tcpTimeout = 500_000} } initAgentServers2 :: InitialAgentServers -initAgentServers2 = initAgentServers {smp = L.fromList [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]} +initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]} agentCfg :: AgentConfig agentCfg = @@ -209,11 +212,14 @@ agentCfg = withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> m () -> (ThreadId -> m a) -> m a withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess = let cfg' = agentCfg {tcpPort = port', database = db'} - initServers' = initAgentServers {smp = L.fromList [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]} + initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]} in serverBracket (\started -> runSMPAgentBlocking t started cfg' initServers') afterProcess +userServers :: NonEmpty SMPServerWithAuth -> Map UserId (NonEmpty SMPServerWithAuth) +userServers srvs = M.fromList [(1, srvs)] + withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> (ThreadId -> m a) -> m a withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile (dbFile db') @@ -226,7 +232,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB) testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a testSMPAgentClientOn port' client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost - runTransportClient defaultTransportClientConfig useHost port' (Just testKeyHash) $ \h -> do + runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do line <- liftIO $ getLn h if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion then client h diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 1bca4d060..b005699d3 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -57,7 +57,7 @@ testServerStatsBackupFile = "tests/tmp/smp-server-stats.log" testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a testSMPClient client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost - runTransportClient defaultTransportClientConfig useHost testPort (Just testKeyHash) $ \h -> + runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case Right th -> client th Left e -> error $ show e