support users in agent to isolate traffic of different users (#598)

* users table, isolate traffic sessions by users or by queues

* remove extra indices

* corrections

Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>

Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>
This commit is contained in:
Evgeny Poberezkin
2023-01-11 13:47:20 +00:00
committed by GitHub
parent b328492dc9
commit f4ad3a983e
22 changed files with 459 additions and 258 deletions

View File

@@ -7,6 +7,7 @@ module Main where
import Control.Logger.Simple
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Server (runSMPAgent)
import Simplex.Messaging.Client (defaultNetworkConfig)
@@ -18,7 +19,7 @@ cfg = defaultAgentConfig
servers :: InitialAgentServers
servers =
InitialAgentServers
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
{ smp = M.fromList [(1, L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"])],
ntf = [],
netCfg = defaultNetworkConfig
}

View File

@@ -1,11 +1,17 @@
packages: .
-- packages: . ../direct-sqlcipher ../sqlcipher-simple
-- packages: . ../hs-socks
source-repository-package
type: git
location: https://github.com/simplex-chat/aeson.git
tag: 3eb66f9a68f103b5f1489382aad89f5712a64db7
source-repository-package
type: git
location: https://github.com/simplex-chat/hs-socks.git
tag: a30cc7a79a08d8108316094f8f2f82a0c5e1ac51
source-repository-package
type: git
location: https://github.com/simplex-chat/direct-sqlcipher.git

View File

@@ -55,6 +55,7 @@ library
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
Simplex.Messaging.Agent.TRcvQueues
Simplex.Messaging.Client
Simplex.Messaging.Client.Agent

View File

@@ -38,6 +38,8 @@ module Simplex.Messaging.Agent
disconnectAgentClient,
resumeAgentClient,
withConnLock,
createUser,
deleteUser,
createConnectionAsync,
joinConnectionAsync,
allowConnectionAsync,
@@ -158,13 +160,19 @@ resumeAgentClient c = atomically $ writeTVar (active c) True
-- |
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
createUser c = withAgentEnv c . createUser' c
deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> m ()
deleteUser c = withAgentEnv c . deleteUser' c
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
createConnectionAsync c userId corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c userId corrId enableNtfs cMode
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnectionAsync c userId corrId enableNtfs = withAgentEnv c .: joinConnAsync c userId corrId enableNtfs
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -187,12 +195,12 @@ deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
-- | Create SMP agent connection (NEW command)
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
createConnection c enableNtfs cMode clientData = withAgentEnv c $ newConn c "" False enableNtfs cMode clientData
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
createConnection c userId enableNtfs cMode clientData = withAgentEnv c $ newConn c userId "" False enableNtfs cMode clientData
-- | Join SMP agent connection (JOIN command)
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnection c userId enableNtfs = withAgentEnv c .: joinConn c userId "" False enableNtfs
-- | Allow connection to continue after CONF notification (LET command)
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -256,12 +264,12 @@ getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m By
getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c
-- | Change servers to be used for creating new queues
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers c = withAgentEnv c . setSMPServers' c
setSMPServers :: AgentErrorMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers c = withAgentEnv c .: setSMPServers' c
-- | Test SMP server
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
testSMPServerConnection c = withAgentEnv c . runSMPServerTest c
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
testSMPServerConnection c = withAgentEnv c .: runSMPServerTest c
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers c = withAgentEnv c . setNtfServers' c
@@ -349,8 +357,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
-- | execute any SMP agent command
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
processCommand c (connId, cmd) = case cmd of
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode Nothing
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId False enableNtfs cMode Nothing
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
@@ -361,29 +369,43 @@ processCommand c (connId, cmd) = case cmd of
OFF -> suspendConnection' c connId $> (connId, OK)
DEL -> deleteConnection' c connId $> (connId, OK)
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
where
-- command interface does not support different users
userId = 1
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c corrId enableNtfs cMode = do
createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
createUser' c srvs = do
userId <- withStore' c createUserRecord
atomically $ TM.insert userId srvs $ smpServers c
pure userId
deleteUser' :: AgentMonad m => AgentClient -> UserId -> m ()
deleteUser' c userId = do
withStore c (`deleteUserRecord` userId)
atomically $ TM.delete userId $ smpServers c
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c userId corrId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
connId <- withStore c $ \db -> createNewConn db g cData cMode
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode)
pure connId
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo =
throwError $ CMD PROHIBITED
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -397,9 +419,9 @@ acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> Invitat
acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ ContactConnection {} -> do
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConnAsync c corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
@@ -444,14 +466,14 @@ switchConnectionAsync' c corrId connId =
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
_ -> throwError $ CMD PROHIBITED
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
newConn c connId asyncMode enableNtfs cMode clientData =
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode clientData
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
newConn c userId connId asyncMode enableNtfs cMode clientData =
getSMPServer c userId >>= newConnSrv c userId connId asyncMode enableNtfs cMode clientData
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
(q, qUri) <- newRcvQueue c userId "" srv smpClientVRange
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
let rq = (q :: RcvQueue) {connId = connId'}
addSubscription c rq
@@ -471,19 +493,19 @@ newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
pure connId
setUpConn False rq connAgentVersion = do
g <- asks idsDrg
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
withStore c $ \db -> createRcvConn db g cData rq cMode
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId asyncMode enableNtfs cReq cInfo = do
joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c userId connId asyncMode enableNtfs cReq cInfo = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
getNextSMPServer c [qServer q]
_ -> getSMPServer c
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
getNextSMPServer c userId [qServer q]
_ -> getSMPServer c userId
joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
joinConnSrv c userId connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
case ( qUri `compatibleVersion` smpClientVRange,
e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange,
@@ -493,9 +515,9 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
(_, rcDHRs) <- liftIO C.generateKeyPair'
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
q <- newSndQueue "" qInfo
q <- newSndQueue userId "" qInfo
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
connId' <- setUpConn asyncMode cData q rc
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
@@ -520,23 +542,23 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
liftIO $ createRatchet db connId' rc
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrv c connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
joinConnSrv c userId connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
crAgentVRange `compatibleVersion` aVRange
) of
(Just qInfo, Just vrsn) -> do
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation Nothing srv
sendInvitation c qInfo vrsn cReq cInfo
(connId', cReq) <- newConnSrv c userId connId False enableNtfs SCMInvitation Nothing srv
sendInvitation c userId qInfo vrsn cReq cInfo
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
joinConnSrv _c _userId _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
throwError $ CMD PROHIBITED
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c userId connId srv $ versionToRange smpClientVersion
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
@@ -564,9 +586,9 @@ acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId
acceptContact' c connId enableNtfs invId ownConnInfo = withConnLock c connId "acceptContact" $ do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ ContactConnection {} -> do
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
joinConn c userId connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
@@ -799,13 +821,15 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
NEW enableNtfs (ACM cMode) -> noServer $ do
usedSrvs <- newTVarIO ([] :: [SMPServer])
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
(_, cReq) <- newConnSrv c connId True enableNtfs cMode Nothing srv
let userId = 1 -- TODO
(_, cReq) <- newConnSrv c userId connId True enableNtfs cMode Nothing srv
notify $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
let initUsed = [qServer q]
usedSrvs <- newTVarIO initUsed
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
let userId = 1 -- TODO
void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv
notify OK
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
@@ -901,10 +925,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m ()
withNextSrv usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c used
let userId = 1 -- TODO
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
srvs_ <- TM.lookup userId $ smpServers c
-- TODO this condition does not account for servers change, it has to be changed to see if there are any remaining unused servers configured for the user
let used' = if length used + 1 >= maybe 0 L.length srvs_ then initUsed else srv : used
writeTVar usedSrvs used'
action srvAuth
-- ^ ^ ^ async command processing /
@@ -978,7 +1004,7 @@ getPendingMsgQ c SndQueue {server, sndId} = do
pure q
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do
(mq, qLock) <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
@@ -1067,7 +1093,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- and this branch should never be reached as receive is created before the confirmation,
-- so the condition is not necessary here, strictly speaking.
_ -> unless (duplexHandshake == Just True) $ do
srv <- getSMPServer c
srv <- getSMPServer c userId
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
@@ -1142,12 +1168,12 @@ switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
clientVRange <- asks $ smpClientVRange . config
-- try to get the server that is different from all queues, or at least from the primary rcv queue
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextSMPServer c [server] else pure srvAuth
(q, qUri) <- newRcvQueue c connId srv' clientVRange
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth
(q, qUri) <- newRcvQueue c userId connId srv' clientVRange
let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnRcvQueue db connId rq'
addSubscription c rq'
@@ -1217,8 +1243,8 @@ connectionStats = \case
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
-- | Change servers to be used for creating new queues, in Reader monad
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers' c = atomically . writeTVar (smpServers c)
setSMPServers' :: AgentMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers' c userId srvs = atomically $ TM.insert userId srvs $ smpServers c
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
@@ -1457,8 +1483,8 @@ debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs} = do
where
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
getSMPServer :: AgentMonad m => AgentClient -> m SMPServerWithAuth
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth
getSMPServer c userId = withUserServers c userId pickServer
pickServer :: AgentMonad m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth
pickServer = \case
@@ -1467,13 +1493,18 @@ pickServer = \case
gen <- asks randomServer
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServerWithAuth
getNextSMPServer c usedSrvs = do
srvs <- readTVarIO $ smpServers c
getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth
getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs ->
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
Just srvs' -> pickServer srvs'
_ -> pickServer srvs
withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a
withUserServers c userId action =
atomically (TM.lookup userId $ smpServers c) >>= \case
Just srvs -> action srvs
_ -> throwError $ INTERNAL "unknown userId - no SMP servers"
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
t <- atomically $ readTBQueue msgQ
@@ -1488,7 +1519,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
processSMP rq conn $ connData conn
where
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {userId, connId, duplexHandshake} = withConnLock c connId "processSMP" $
case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $
@@ -1586,8 +1617,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
ackDel = enqueueCmd . ICAckDel rId srvMsgId
handleNotifyAck :: m () -> m ()
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
SMP.END -> do
-- TODO is race condition possible here on session mode change?
tSess <- mkSMPTransportSession c rq
atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND)
>>= logServer "<--" c srv rId
where
processEND = \case
@@ -1724,7 +1757,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
(Just _, _) -> qError "QADD: queue address is already used in connection"
(_, Just _replaced@SndQueue {dbQueueId}) -> do
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo
let sq' = (sq_ :: SndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnSndQueue db connId sq'
case (sndPublicKey, e2ePubKey) of
@@ -1796,12 +1829,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
| otherwise = MsgError MsgDuplicate -- this case is not possible
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
clientVRange <- asks $ smpClientVRange . config
case qInfo `proveCompatible` clientVRange of
Nothing -> throwError $ AGENT A_VERSION
Just qInfo' -> do
sq <- newSndQueue connId qInfo'
sq <- newSndQueue userId connId qInfo'
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
@@ -1860,14 +1893,15 @@ agentRatchetDecrypt db connId encAgentMsg = do
liftIO $ updateRatchet db connId rc' skippedDiff
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
C.SignAlg a <- asks $ cmdSignAlg . config
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
pure
SndQueue
{ connId,
{ userId,
connId,
server = smpServer,
sndId = senderId,
sndPublicKey = Just sndPublicKey,

View File

@@ -48,6 +48,9 @@ module Simplex.Messaging.Agent.Client
agentNtfCreateSubscription,
agentNtfCheckSubscription,
agentNtfDeleteSubscription,
-- TODO possibly, this should not be exported?
-- this is used in END processing
mkSMPTransportSession,
agentCbEncrypt,
agentCbDecrypt,
cryptoError,
@@ -131,6 +134,7 @@ import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
import Simplex.Messaging.Protocol
( AProtocolType (..),
BrokerMsg,
EntityId,
ErrorType,
MsgFlags (..),
MsgId,
@@ -166,15 +170,22 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
-- | Transport session key - includes entity ID if `sessionMode = TSMEntity`.
type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
type SMPTransportSession = TransportSession SMP.BrokerMsg
type NtfTransportSession = TransportSession NtfResponse
data AgentClient = AgentClient
{ active :: TVar Bool,
rcvQ :: TBQueue (ATransmission 'Client),
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue (ServerTransmission BrokerMsg),
smpServers :: TVar (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPServer SMPClientVar,
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPTransportSession SMPClientVar,
ntfServers :: TVar [NtfServer],
ntfClients :: TMap NtfServer NtfClientVar,
ntfClients :: TMap NtfTransportSession NtfClientVar,
useNetworkConfig :: TVar NetworkConfig,
subscrConns :: TVar (Set ConnId),
activeSubs :: TRcvQueues,
@@ -195,7 +206,7 @@ data AgentClient = AgentClient
-- locks to prevent concurrent operations with connection
connLocks :: TMap ConnId Lock,
-- locks to prevent concurrent reconnections to SMP servers
reconnectLocks :: TMap SMPServer Lock,
reconnectLocks :: TMap SMPTransportSession Lock,
reconnections :: TVar [Async ()],
asyncClients :: TVar [Async ()],
agentStats :: TMap AgentStatsKey (TVar Int),
@@ -227,7 +238,13 @@ data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map St
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
data AgentStatsKey = AgentStatsKey {host :: ByteString, clientTs :: ByteString, cmd :: ByteString, res :: ByteString}
data AgentStatsKey = AgentStatsKey
{ userId :: UserId,
host :: ByteString,
clientTs :: ByteString,
cmd :: ByteString,
res :: ByteString
}
deriving (Eq, Ord, Show)
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
@@ -270,7 +287,7 @@ agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
class ProtocolServerClient msg where
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg)
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
clientProtocolError :: ErrorType -> AgentErrorType
instance ProtocolServerClient BrokerMsg where
@@ -281,19 +298,20 @@ instance ProtocolServerClient NtfResponse where
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, entityId_) = do
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
atomically (getClientVar srv smpClients)
atomically (getClientVar tSess smpClients)
>>= either
(newProtocolClient c srv smpClients connectClient reconnectClient)
(waitForProtocolClient c srv)
(newProtocolClient c tSess smpClients connectClient reconnectClient)
(waitForProtocolClient c tSess)
where
connectClient :: m SMPClient
connectClient = do
cfg <- getClientConfig c smpCfg
u <- askUnliftIO
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
let proxyUsername = Just $ bshow userId <> maybe "" (":" <>) entityId_
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg proxyUsername (Just msgQ) $ clientDisconnected u)
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
clientDisconnected u client = do
@@ -302,14 +320,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
where
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.delete tSess smpClients
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
mapM_ (`RQ.addQueue` pendingSubs c) qs
pure (qs, S.toList conns)
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
incClientStat c client "DISCONNECT" ""
incClientStat c userId client "DISCONNECT" ""
notifySub "" $ hostEvent DISCONNECT client
unless (null conns) $ notifySub "" $ DOWN srv conns
unless (null qs) $ do
@@ -329,18 +347,20 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
reconnectClient :: m ()
reconnectClient =
withLockMap_ (reconnectLocks c) srv "reconnect" $
withLockMap_ (reconnectLocks c) tSess "reconnect" $
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
where
resubscribe :: [RcvQueue] -> m ()
resubscribe qs = do
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
connected <- maybe False isRight <$> atomically (TM.lookup tSess smpClients $>>= tryReadTMVar)
cs <- atomically . RQ.getConns $ activeSubs c
-- TODO probably tSess should be passed here or nothing
-- maybe this should be another function, not subscribeQueues
(client_, rs) <- subscribeQueues c srv qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
unless connected . forM_ client_ $ \cl -> do
incClientStat c cl "CONNECT" ""
incClientStat c userId cl "CONNECT" ""
notifySub "" $ hostEvent CONNECT cl
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
@@ -351,37 +371,37 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
notifySub :: ConnId -> ACommand 'Agent -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} srv = do
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
atomically (getClientVar srv ntfClients)
atomically (getClientVar tSess ntfClients)
>>= either
(newProtocolClient c srv ntfClients connectClient $ pure ())
(waitForProtocolClient c srv)
(newProtocolClient c tSess ntfClients connectClient $ pure ())
(waitForProtocolClient c tSess)
where
connectClient :: m NtfClient
connectClient = do
cfg <- getClientConfig c ntfCfg
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing clientDisconnected)
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing Nothing clientDisconnected)
clientDisconnected :: NtfClient -> IO ()
clientDisconnected client = do
atomically $ TM.delete srv ntfClients
incClientStat c client "DISCONNECT" ""
atomically $ TM.delete tSess ntfClients
incClientStat c userId client "DISCONNECT" ""
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup tSess clients
where
newClientVar :: STM (TMVar a)
newClientVar = do
var <- newEmptyTMVar
TM.insert srv var clients
TM.insert tSess var clients
pure var
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient c srv clientVar = do
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient c (_, srv, _) clientVar = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
liftEither $ case client_ of
@@ -393,30 +413,30 @@ newProtocolClient ::
forall msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient ->
ProtoServer msg ->
TMap (ProtoServer msg) (ClientVar msg) ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
m (ProtocolClient msg) ->
m () ->
ClientVar msg ->
m (ProtocolClient msg)
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
where
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
tryConnectClient successAction retryAction =
tryError connectClient >>= \r -> case r of
Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
atomically $ putTMVar clientVar r
liftIO $ incClientStat c client "CLIENT" "OK"
liftIO $ incClientStat c userId client "CLIENT" "OK"
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
successAction client
Left e -> do
liftIO $ incServerStat c srv "CLIENT" $ strEncode e
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
if temporaryAgentError e
then retryAction
else atomically $ do
putTMVar clientVar (Left e)
TM.delete srv clients
TM.delete tSess clients
throwError e
tryConnectAsync :: m ()
tryConnectAsync = do
@@ -471,7 +491,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
where
k = (server, sndId)
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
where
@@ -494,30 +514,43 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur
where
newLock = newEmptyTMVar >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ c srv statCmd action = do
cl <- getProtocolServerClient c srv
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchError` logServerError cl
where
stat cl = liftIO . incClientStat c cl statCmd
stat cl = liftIO . incClientStat c userId cl statCmd
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
logServerError cl e = do
logServer "<--" c srv "" $ strEncode e
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ c srv qId cmdStr action = do
logServer "-->" c srv qId cmdStr
res <- withClient_ c srv cmdStr action
logServer "<--" c srv qId "OK"
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c srv statKey action = withClient_ c srv statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
withSMPClient c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient c tSess (queueId q) cmdStr action
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
withSMPClient_ c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient_ c tSess (queueId q) cmdStr action
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
withNtfClient c srv = withLogClient c (0, srv, Nothing)
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
liftClient protocolError_ = liftError . protocolClientError protocolError_
@@ -551,12 +584,13 @@ instance ToJSON SMPTestFailure where
toEncoding = J.genericToEncoding J.defaultOptions
toJSON = J.genericToJSON J.defaultOptions
runSMPServerTest :: AgentMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
runSMPServerTest c (ProtoServerWithAuth srv auth) = do
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- getClientConfig c smpCfg
C.SignAlg a <- asks $ cmdSignAlg . config
liftIO $ do
getProtocolClient srv cfg Nothing (\_ -> pure ()) >>= \case
let proxyUsername = bshow userId
getProtocolClient srv cfg (Just proxyUsername) Nothing (\_ -> pure ()) >>= \case
Right smp -> do
(rKey, rpKey) <- C.generateSignatureKeyPair a
(sKey, _) <- C.generateSignatureKeyPair a
@@ -566,7 +600,7 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
liftError (testErr TSSecureQueue) $ secureSMPQueue smp rpKey rcvId sKey
liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp rpKey rcvId
ok <- tcpTimeout (networkConfig cfg) `timeout` closeProtocolClient smp
incClientStat c smp "TEST" "OK"
incClientStat c userId smp "TEST" "OK"
pure $ either Just (const Nothing) r <|> maybe (Just (SMPTestFailure TSDisconnect $ BROKER addr TIMEOUT)) (const Nothing) ok
Left e -> pure (Just $ testErr TSConnect e)
where
@@ -574,19 +608,29 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
testErr step = SMPTestFailure step . protocolClientError SMP addr
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
mkTransportSession c userId srv entityId = do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
pure (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
mkSMPTransportSession c q = mkTransportSession c (qUserId q) (qServer q) (queueId q)
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
C.SignAlg a <- asks (cmdSignAlg . config)
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
logServer "-->" c srv "" "NEW"
tSess <- mkTransportSession c userId srv connId
QIK {rcvId, sndId, rcvPublicDhKey} <-
withClient c srv "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
let rq =
RcvQueue
{ connId,
{ userId,
connId,
server = srv,
rcvId,
rcvPrivateKey,
@@ -609,7 +653,7 @@ subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
withSMPClient c rq "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
>>= either throwError pure
@@ -648,14 +692,19 @@ subscribeQueues c srv qs = do
RQ.addQueue rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
-- TODO these subscriptions should happen in different sessions if mode is TSMEntity
-- it is also a question whether it is needed to group by server outside if grouping by sessions should happen here anyway
let userId = 1
tSess = (userId, srv, Nothing)
smp_ <- tryError $ getSMPServerClient c tSess
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
Left e -> pure $ map (,Left e) qs_
Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
let qs2 = L.map queueCreds qs'
n = (length qs2 - 1) `div` 90 + 1
liftIO $ incClientStatN c smp n "SUBS" "OK"
liftIO $ incClientStatN c userId smp n "SUBS" "OK"
liftIO $ do
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
mapM_ (uncurry $ processSubResult c) rs
@@ -699,16 +748,17 @@ logSecret :: ByteString -> ByteString
logSecret bs = encode $ B.take 3 bs
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo =
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
tSess <- mkTransportSession c userId smpServer senderId
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
msg <- mkInvitation
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
where
@@ -722,7 +772,7 @@ sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderI
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
atomically createTakeGetLock
(v, msg_) <- withLogClient c server rcvId "GET" $ \smp ->
(v, msg_) <- withSMPClient c rq "GET" $ \smp ->
(thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId
mapM (decryptMeta v) msg_
where
@@ -742,23 +792,23 @@ decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.Enc
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
withLogClient c server rcvId "KEY <key>" $ \smp ->
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
withSMPClient c rq "KEY <key>" $ \smp ->
secureSMPQueue smp rcvPrivateKey rcvId senderKey
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey)
enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
withLogClient c server rcvId "NKEY <nkey>" $ \smp ->
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
withSMPClient c rq "NKEY <nkey>" $ \smp ->
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "NDEL" $ \smp ->
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "NDEL" $ \smp ->
disableSMPQueueNotifications smp rcvPrivateKey rcvId
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
withLogClient c server rcvId "ACK" $ \smp ->
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
withSMPClient c rq "ACK" $ \smp ->
ackSMPMessage smp rcvPrivateKey rcvId msgId
atomically $ releaseGetLock c rq
@@ -767,57 +817,57 @@ releaseGetLock c RcvQueue {server, rcvId} =
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "OFF" $ \smp ->
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "OFF" $ \smp ->
suspendSMPQueue smp rcvPrivateKey rcvId
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "DEL" $ \smp ->
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "DEL" $ \smp ->
deleteSMPQueue smp rcvPrivateKey rcvId
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg =
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
withClient c ntfServer "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
@@ -948,18 +998,18 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat c pc = incClientStatN c pc 1
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> ProtocolServer p -> ByteString -> ByteString -> IO ()
incServerStat c ProtocolServer {host} cmd res = do
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
incServerStat c userId ProtocolServer {host} cmd res = do
threadDelay 100000
atomically $ incStat c 1 statsKey
where
statsKey = AgentStatsKey {host = strEncode $ L.head host, clientTs = "", cmd, res}
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: AgentClient -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c pc n cmd res = do
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where
statsKey = AgentStatsKey {host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
statsKey = AgentStatsKey {userId, host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}

View File

@@ -32,12 +32,14 @@ import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.List.NonEmpty (NonEmpty)
import Data.Map (Map)
import Data.Time.Clock (NominalDiffTime, nominalDay)
import Data.Word (Word16)
import Network.Socket
import Numeric.Natural
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Agent.Store.SQLite
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
@@ -59,7 +61,7 @@ import UnliftIO.STM
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
data InitialAgentServers = InitialAgentServers
{ smp :: NonEmpty SMPServerWithAuth,
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
ntf :: [NtfServer],
netCfg :: NetworkConfig
}

View File

@@ -34,6 +34,7 @@ import Simplex.Messaging.Protocol
NotifierId,
NtfPrivateSignKey,
NtfPublicVerifyKey,
QueueId,
RcvDhSecret,
RcvNtfDhSecret,
RcvPrivateSignKey,
@@ -47,7 +48,8 @@ import Simplex.Messaging.Version
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
data RcvQueue = RcvQueue
{ connId :: ConnId,
{ userId :: UserId,
connId :: ConnId,
server :: SMPServer,
-- | recipient queue ID
rcvId :: SMP.RecipientId,
@@ -89,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
data SndQueue = SndQueue
{ connId :: ConnId,
{ userId :: UserId,
connId :: ConnId,
server :: SMPServer,
-- | sender queue ID
sndId :: SMP.SenderId,
@@ -150,6 +153,27 @@ findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
{-# INLINE findRQ #-}
class SMPQueue q => SMPQueueRec q where
qUserId :: q -> UserId
qConnId :: q -> ConnId
queueId :: q -> QueueId
instance SMPQueueRec RcvQueue where
qUserId = userId
{-# INLINE qUserId #-}
qConnId = connId
{-# INLINE qConnId #-}
queueId = rcvId
{-# INLINE queueId #-}
instance SMPQueueRec SndQueue where
qUserId = userId
{-# INLINE qUserId #-}
qConnId = connId
{-# INLINE qConnId #-}
queueId = sndId
{-# INLINE queueId #-}
-- * Connection types
-- | Type of a connection.
@@ -222,6 +246,7 @@ deriving instance Show SomeConn
data ConnData = ConnData
{ connId :: ConnId,
userId :: UserId,
connAgentVersion :: Version,
enableNtfs :: Bool,
duplexHandshake :: Maybe Bool, -- added in agent protocol v2
@@ -231,6 +256,8 @@ data ConnData = ConnData
data AgentCmdType = ACClient | ACInternal
type UserId = Int64
instance StrEncoding AgentCmdType where
strEncode = \case
ACClient -> "CLIENT"
@@ -471,6 +498,8 @@ data StoreError
SEInternal ByteString
| -- | Failed to generate unique random ID
SEUniqueID
| -- | User ID not found
SEUserNotFound
| -- | Connection not found (or both queues absent).
SEConnNotFound
| -- | Connection already used.

View File

@@ -27,6 +27,10 @@ module Simplex.Messaging.Agent.Store.SQLite
sqlString,
execSQL,
-- * Users
createUserRecord,
deleteUserRecord,
-- * Queues and connections
createNewConn,
updateNewConnRcv,
@@ -297,6 +301,18 @@ withTransaction st action = withConnection st $ loop 500 2_000_000
loop (t * 9 `div` 8) (tLim - t) db
else E.throwIO e
createUserRecord :: DB.Connection -> IO UserId
createUserRecord db = do
DB.execute_ db "INSERT INTO users () VALUES ()"
insertedRowId db
deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ())
deleteUserRecord db userId = runExceptT $ do
_ :: Only Int64 <-
ExceptT . firstRow id SEUserNotFound $
DB.query db "SELECT user_id FROM users WHERE user_id = ?" (Only userId)
liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
createConn_ ::
TVar ChaChaDRG ->
ConnData ->
@@ -307,9 +323,9 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
ConnData {connId} -> create connId $> Right connId
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} cMode =
createConn_ gVar cData $ \connId -> do
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
updateNewConnRcv db connId rq =
@@ -332,17 +348,17 @@ updateNewConnSnd db connId sq =
updateConn = Right <$> addConnSndQueue_ db connId sq
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertRcvQueue_ db connId q
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertSndQueue_ db connId q
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
@@ -1328,9 +1344,9 @@ getAnyConn dbConn connId deleted' =
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData dbConn connId' =
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
maybeFirstRow cData $ DB.query dbConn "SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
where
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
setConnDeleted :: DB.Connection -> ConnId -> IO ()
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
@@ -1348,31 +1364,32 @@ getRcvQueuesByConnId_ db connId =
rcvQueueQuery :: Query
rcvQueueQuery =
[sql|
SELECT s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN connections c ON q.conn_id = c.conn_id
|]
toRcvQueue ::
(C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (Int64, Bool, Maybe Int64, Maybe Version)
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
RcvQueue
toRcvQueue ((keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
getRcvQueueById_ db connId dbRcvId =
firstRow toRcvQueue SEConnNotFound $
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ?") (connId, dbRcvId)
-- | returns all connection queues, the first queue is the primary one
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
@@ -1381,16 +1398,17 @@ getSndQueuesByConnId_ dbConn connId =
<$> DB.query
dbConn
[sql|
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN connections c ON q.conn_id = c.conn_id
WHERE q.conn_id = ?;
|]
(Only connId)
where
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
sndQueue ((userId, keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
let server = SMPServer host port keyHash
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
in SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
-- the current primary queue is ordered first, the next primary - second
compare (Down p) (Down p') <> compare i i'

View File

@@ -37,6 +37,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -53,7 +54,8 @@ schemaMigrations =
("m20220811_onion_hosts", m20220811_onion_hosts),
("m20220817_connection_ntfs", m20220817_connection_ntfs),
("m20220905_commands", m20220905_commands),
("m20220915_connection_queues", m20220915_connection_queues)
("m20220915_connection_queues", m20220915_connection_queues),
("m20230110_users", m20230110_users)
]
-- | The list of migrations in ascending order by date

View File

@@ -0,0 +1,27 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20230110_users :: Query
m20230110_users =
[sql|
PRAGMA ignore_check_constraints=ON;
CREATE TABLE users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT
);
INSERT INTO users (user_id) VALUES (1);
ALTER TABLE connections ADD COLUMN user_id INTEGER CHECK (user_id NOT NULL)
REFERENCES users ON DELETE CASCADE;
CREATE INDEX idx_connections_user ON connections(user_id);
UPDATE connections SET user_id = 1;
PRAGMA ignore_check_constraints=OFF;
|]

View File

@@ -22,7 +22,9 @@ CREATE TABLE connections(
,
duplex_handshake INTEGER NULL DEFAULT 0,
enable_ntfs INTEGER,
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
user_id INTEGER CHECK(user_id NOT NULL)
REFERENCES users ON DELETE CASCADE
) WITHOUT ROWID;
CREATE TABLE rcv_queues(
host TEXT NOT NULL,
@@ -228,3 +230,5 @@ CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
conn_id,
snd_queue_id
);
CREATE TABLE users(user_id INTEGER PRIMARY KEY AUTOINCREMENT);
CREATE INDEX idx_connections_user ON connections(user_id);

View File

@@ -54,6 +54,7 @@ module Simplex.Messaging.Client
ProtocolClientError (..),
ProtocolClientConfig (..),
NetworkConfig (..),
TransportSessionMode (..),
defaultClientConfig,
defaultNetworkConfig,
transportClientConfig,
@@ -152,6 +153,8 @@ data NetworkConfig = NetworkConfig
hostMode :: HostMode,
-- | if above criteria is not met, if the below setting is True return error, otherwise use the first host
requiredHostMode :: Bool,
-- | transport sessions are created per user or per entity
sessionMode :: TransportSessionMode,
-- | timeout for the initial client TCP/TLS connection (microseconds)
tcpConnectTimeout :: Int,
-- | timeout of protocol commands (microseconds)
@@ -168,12 +171,23 @@ instance ToJSON NetworkConfig where
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
data TransportSessionMode = TSMUser | TSMEntity
deriving (Eq, Show, Generic)
instance ToJSON TransportSessionMode where
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
instance FromJSON TransportSessionMode where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
defaultNetworkConfig :: NetworkConfig
defaultNetworkConfig =
NetworkConfig
{ socksProxy = Nothing,
hostMode = HMOnionViaSocks,
requiredHostMode = False,
sessionMode = TSMUser,
tcpConnectTimeout = 7_500_000,
tcpTimeout = 5_000_000,
tcpKeepAlive = Just defaultKeepAliveOpts,
@@ -239,8 +253,8 @@ transportHost' = transportHost . client_
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe ByteString -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} proxyUsername msgQ disconnected = do
case chooseTransportHost networkConfig (host protocolServer) of
Right useHost ->
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
@@ -274,7 +288,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
let tcConfig = transportClientConfig networkConfig
action <-
async $
runTransportClient tcConfig useHost port' (Just $ keyHash protocolServer) (client t c cVar)
runTransportClient tcConfig proxyUsername useHost port' (Just $ keyHash protocolServer) (client t c cVar)
`finally` atomically (putTMVar cVar $ Left PCENetworkError)
c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar)
pure $ case c_ of

View File

@@ -160,7 +160,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
void $ tryConnectClient (const reconnectClient) loop
connectClient :: ExceptT ProtocolClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) Nothing (Just msgQ) clientDisconnected
clientDisconnected :: SMPClient -> IO ()
clientDisconnected _ = do

View File

@@ -77,6 +77,7 @@ module Simplex.Messaging.Protocol
BasicAuth (..),
SrvLoc (..),
CorrId (..),
EntityId,
QueueId,
RecipientId,
SenderId,

View File

@@ -107,15 +107,15 @@ defaultTransportClientConfig :: TransportClientConfig
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
-- | Connect to passed TCP host:port and pass handle to the client.
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} host port keyHash client = do
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
let hostName = B.unpack $ strEncode host
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
connectTCP = case socksProxy of
Just proxy -> connectSocksClient proxy $ hostAddr host
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
_ -> connectTCPClient hostName
c <- liftIO $ do
sock <- connectTCP port
@@ -153,10 +153,12 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
defaultSMPPort :: PortNumber
defaultSMPPort = 5223
connectSocksClient :: SocksProxy -> SocksHostAddress -> ServiceName -> IO Socket
connectSocksClient (SocksProxy addr) hostAddr _port = do
connectSocksClient :: SocksProxy -> Maybe ByteString -> SocksHostAddress -> ServiceName -> IO Socket
connectSocksClient (SocksProxy addr) proxyUsername hostAddr _port = do
let port = if null _port then defaultSMPPort else fromMaybe defaultSMPPort $ readMaybe _port
fst <$> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
fst <$> case proxyUsername of
Just username -> socksConnectAuth (defaultSocksConf addr) (SocksAddress hostAddr port) (SocksCredentials username "")
_ -> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
defaultSocksHost :: HostAddress
defaultSocksHost = tupleToHostAddress (127, 0, 0, 1)

View File

@@ -120,7 +120,7 @@ sendRequest HTTP2Client {reqQ, config} req = do
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> HostName -> ServiceName -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
runHTTP2Client tlsParams caStore tcConfig host port client =
runTLSTransportClient tlsParams caStore tcConfig (THDomainName host) port Nothing $ \c ->
runTLSTransportClient tlsParams caStore tcConfig Nothing (THDomainName host) port Nothing $ \c ->
withTlsConfig c 16384 (`run` client)
where
run = H.run $ ClientConfig "https" (B.pack host) 20

View File

@@ -218,8 +218,8 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientTest alice bob baseId = do
Right () <- runExceptT $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -254,8 +254,8 @@ runAgentClientTest alice bob baseId = do
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientContactTest alice bob baseId = do
Right () <- runExceptT $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
bobId <- acceptContact alice True invId "alice's connInfo"
("", _, CONF confId _ "alice's connInfo") <- get bob
@@ -302,9 +302,9 @@ testAsyncInitiatingOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
subscribeConnection alice' bobId
("", _, CONF confId _ "bob's connInfo") <- get alice'
@@ -320,8 +320,8 @@ testAsyncJoiningOfflineBeforeActivation = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
disconnectAgentClient bob
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
@@ -338,9 +338,9 @@ testAsyncBothOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
disconnectAgentClient bob
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
subscribeConnection alice' bobId
@@ -360,9 +360,9 @@ testAsyncServerOffline t = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
-- create connection and shutdown the server
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
runExceptT $ createConnection alice True SCMInvitation Nothing
runExceptT $ createConnection alice 1 True SCMInvitation Nothing
-- connection fails
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo"
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo"
("", "", DOWN srv conns) <- get alice
srv `shouldBe` testSMPServer
conns `shouldBe` [bobId]
@@ -372,7 +372,7 @@ testAsyncServerOffline t = do
liftIO $ do
srv1 `shouldBe` testSMPServer
conns1 `shouldBe` [bobId]
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -387,9 +387,9 @@ testAsyncHelloTimeout = do
alice <- getSMPAgentClient agentCfgV1 initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers
Right () <- runExceptT $ do
(_, cReq) <- createConnection alice True SCMInvitation Nothing
(_, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
pure ()
@@ -444,8 +444,8 @@ testDuplicateMessage t = do
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnection alice bob = do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -459,7 +459,7 @@ testInactiveClientDisconnected t = do
withSmpServerConfigOn t cfg' testPort $ \_ -> do
alice <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
get alice ##> ("", "", DOWN testSMPServer [connId])
pure ()
@@ -470,7 +470,7 @@ testActiveClientNotDisconnected t = do
alice <- getSMPAgentClient agentCfg initAgentServers
ts <- getSystemTime
Right () <- runExceptT $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
keepSubscribing alice connId ts
pure ()
where
@@ -619,10 +619,10 @@ testAsyncCommands = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
bobId <- createConnectionAsync alice "1" True SCMInvitation
bobId <- createConnectionAsync alice 1 "1" True SCMInvitation
("1", bobId', INV (ACR _ qInfo)) <- get alice
liftIO $ bobId' `shouldBe` bobId
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo"
("2", aliceId', OK) <- get bob
liftIO $ aliceId' `shouldBe` aliceId
("", _, CONF confId _ "bob's connInfo") <- get alice
@@ -663,7 +663,7 @@ testAsyncCommands = do
testAsyncCommandsRestore :: ATransport -> IO ()
testAsyncCommandsRestore t = do
alice <- getSMPAgentClient agentCfg initAgentServers
Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation
Right bobId <- runExceptT $ createConnectionAsync alice 1 "1" True SCMInvitation
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
disconnectAgentClient alice
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
@@ -679,8 +679,8 @@ testAcceptContactAsync = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
bobId <- acceptContactAsync alice "1" True invId "alice's connInfo"
("1", bobId', OK) <- get alice
@@ -808,11 +808,11 @@ testCreateQueueAuth clnt1 clnt2 = do
a <- getClient clnt1
b <- getClient clnt2
Right created <- runExceptT $ do
tryError (createConnection a True SCMInvitation Nothing) >>= \case
tryError (createConnection a 1 True SCMInvitation Nothing) >>= \case
Left (SMP AUTH) -> pure 0
Left e -> throwError e
Right (bId, qInfo) ->
tryError (joinConnection b True qInfo "bob's connInfo") >>= \case
tryError (joinConnection b 1 True qInfo "bob's connInfo") >>= \case
Left (SMP AUTH) -> pure 1
Left e -> throwError e
Right aId -> do
@@ -826,7 +826,7 @@ testCreateQueueAuth clnt1 clnt2 = do
pure created
where
getClient (clntAuth, clntVersion) =
let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]}
let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]}
smpCfg = (defaultClientConfig :: ProtocolClientConfig) {smpServerVRange = mkVersionRange 4 clntVersion}
in getSMPAgentClient agentCfg {smpCfg} servers
@@ -834,7 +834,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
testSMPServerConnectionTest t newQueueBasicAuth srv =
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running
Right r <- runExceptT $ testSMPServerConnection a srv
Right r <- runExceptT $ testSMPServerConnection a 1 srv
pure r
testRatchetAdHash :: IO ()

View File

@@ -213,8 +213,8 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (bobId, aliceId, nonce, message) <- runExceptT $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")
@@ -275,9 +275,9 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
DeviceToken {} <- registerTestToken bob "bcde" NMInstant apnsQ
-- establish connection
liftIO $ threadDelay 50000
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
liftIO $ threadDelay 1000000
aliceId <- joinConnection bob True qInfo "bob's connInfo"
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
liftIO $ threadDelay 750000
void $ messageNotification apnsQ
("", _, CONF confId _ "bob's connInfo") <- get alice
@@ -327,8 +327,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right () <- runExceptT $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")
@@ -392,8 +392,8 @@ testChangeToken APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Right (aliceId, bobId) <- runExceptT $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")

View File

@@ -140,7 +140,7 @@ testForeignKeysEnabled =
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
cData1 :: ConnData
cData1 = ConnData {connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
testPrivateSignKey :: C.APrivateSignKey
testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
@@ -154,7 +154,8 @@ testDhSecret = "01234567890123456789012345678901"
rcvQueue1 :: RcvQueue
rcvQueue1 =
RcvQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "1234",
rcvPrivateKey = testPrivateSignKey,
@@ -173,7 +174,8 @@ rcvQueue1 =
sndQueue1 :: SndQueue
sndQueue1 =
SndQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "3456",
sndPublicKey = Nothing,
@@ -314,7 +316,8 @@ testUpgradeRcvConnToDuplex =
_ <- createSndConn db g cData1 sndQueue1
let anotherSndQueue =
SndQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "2345",
sndPublicKey = Nothing,
@@ -340,7 +343,8 @@ testUpgradeSndConnToDuplex =
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
let anotherRcvQueue =
RcvQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "3456",
rcvPrivateKey = testPrivateSignKey,

View File

@@ -69,7 +69,7 @@ ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testNtfClient client = do
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig host ntfTestPort (Just testKeyHash) $ \h ->
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h ->
liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case
Right th -> client th
Left e -> error $ show e

View File

@@ -1,6 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
@@ -11,7 +12,8 @@ import Control.Monad.IO.Unlift
import Crypto.Random
import qualified Data.ByteString.Char8 as B
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Network.Socket (ServiceName)
import NtfClient (ntfTestPort)
import SMPClient
@@ -27,6 +29,7 @@ import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Transport
@@ -173,13 +176,13 @@ testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5
initAgentServers :: InitialAgentServers
initAgentServers =
InitialAgentServers
{ smp = L.fromList [noAuthSrv testSMPServer],
{ smp = userServers [noAuthSrv testSMPServer],
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"],
netCfg = defaultNetworkConfig {tcpTimeout = 500_000}
}
initAgentServers2 :: InitialAgentServers
initAgentServers2 = initAgentServers {smp = L.fromList [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
agentCfg :: AgentConfig
agentCfg =
@@ -209,11 +212,14 @@ agentCfg =
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> m () -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
let cfg' = agentCfg {tcpPort = port', database = db'}
initServers' = initAgentServers {smp = L.fromList [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
in serverBracket
(\started -> runSMPAgentBlocking t started cfg' initServers')
afterProcess
userServers :: NonEmpty SMPServerWithAuth -> Map UserId (NonEmpty SMPServerWithAuth)
userServers srvs = M.fromList [(1, srvs)]
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile (dbFile db')
@@ -226,7 +232,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
testSMPAgentClientOn port' client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
runTransportClient defaultTransportClientConfig useHost port' (Just testKeyHash) $ \h -> do
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
line <- liftIO $ getLn h
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
then client h

View File

@@ -57,7 +57,7 @@ testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testSMPClient client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig useHost testPort (Just testKeyHash) $ \h ->
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h ->
liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case
Right th -> client th
Left e -> error $ show e