mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 16:26:02 +00:00
support users in agent to isolate traffic of different users (#598)
* users table, isolate traffic sessions by users or by queues * remove extra indices * corrections Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b328492dc9
commit
f4ad3a983e
@@ -7,6 +7,7 @@ module Main where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgent)
|
||||
import Simplex.Messaging.Client (defaultNetworkConfig)
|
||||
@@ -18,7 +19,7 @@ cfg = defaultAgentConfig
|
||||
servers :: InitialAgentServers
|
||||
servers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
|
||||
{ smp = M.fromList [(1, L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"])],
|
||||
ntf = [],
|
||||
netCfg = defaultNetworkConfig
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
packages: .
|
||||
-- packages: . ../direct-sqlcipher ../sqlcipher-simple
|
||||
-- packages: . ../hs-socks
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/aeson.git
|
||||
tag: 3eb66f9a68f103b5f1489382aad89f5712a64db7
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/hs-socks.git
|
||||
tag: a30cc7a79a08d8108316094f8f2f82a0c5e1ac51
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/direct-sqlcipher.git
|
||||
|
||||
@@ -55,6 +55,7 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
Simplex.Messaging.Agent.TRcvQueues
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
|
||||
@@ -38,6 +38,8 @@ module Simplex.Messaging.Agent
|
||||
disconnectAgentClient,
|
||||
resumeAgentClient,
|
||||
withConnLock,
|
||||
createUser,
|
||||
deleteUser,
|
||||
createConnectionAsync,
|
||||
joinConnectionAsync,
|
||||
allowConnectionAsync,
|
||||
@@ -158,13 +160,19 @@ resumeAgentClient c = atomically $ writeTVar (active c) True
|
||||
-- |
|
||||
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
|
||||
|
||||
createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
|
||||
createUser c = withAgentEnv c . createUser' c
|
||||
|
||||
deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> m ()
|
||||
deleteUser c = withAgentEnv c . deleteUser' c
|
||||
|
||||
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
|
||||
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
|
||||
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
createConnectionAsync c userId corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c userId corrId enableNtfs cMode
|
||||
|
||||
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
|
||||
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
|
||||
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnectionAsync c userId corrId enableNtfs = withAgentEnv c .: joinConnAsync c userId corrId enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
|
||||
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -187,12 +195,12 @@ deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -
|
||||
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
|
||||
|
||||
-- | Create SMP agent connection (NEW command)
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection c enableNtfs cMode clientData = withAgentEnv c $ newConn c "" False enableNtfs cMode clientData
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection c userId enableNtfs cMode clientData = withAgentEnv c $ newConn c userId "" False enableNtfs cMode clientData
|
||||
|
||||
-- | Join SMP agent connection (JOIN command)
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnection c userId enableNtfs = withAgentEnv c .: joinConn c userId "" False enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command)
|
||||
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -256,12 +264,12 @@ getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m By
|
||||
getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c
|
||||
|
||||
-- | Change servers to be used for creating new queues
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers c = withAgentEnv c . setSMPServers' c
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers c = withAgentEnv c .: setSMPServers' c
|
||||
|
||||
-- | Test SMP server
|
||||
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
testSMPServerConnection c = withAgentEnv c . runSMPServerTest c
|
||||
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
testSMPServerConnection c = withAgentEnv c .: runSMPServerTest c
|
||||
|
||||
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
|
||||
setNtfServers c = withAgentEnv c . setNtfServers' c
|
||||
@@ -349,8 +357,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
-- | execute any SMP agent command
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
|
||||
processCommand c (connId, cmd) = case cmd of
|
||||
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode Nothing
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
|
||||
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId False enableNtfs cMode Nothing
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo
|
||||
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
|
||||
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
|
||||
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
|
||||
@@ -361,29 +369,43 @@ processCommand c (connId, cmd) = case cmd of
|
||||
OFF -> suspendConnection' c connId $> (connId, OK)
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
|
||||
where
|
||||
-- command interface does not support different users
|
||||
userId = 1
|
||||
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnAsync c corrId enableNtfs cMode = do
|
||||
createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
|
||||
createUser' c srvs = do
|
||||
userId <- withStore' c createUserRecord
|
||||
atomically $ TM.insert userId srvs $ smpServers c
|
||||
pure userId
|
||||
|
||||
deleteUser' :: AgentMonad m => AgentClient -> UserId -> m ()
|
||||
deleteUser' c userId = do
|
||||
withStore c (`deleteUserRecord` userId)
|
||||
atomically $ TM.delete userId $ smpServers c
|
||||
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnAsync c userId corrId enableNtfs cMode = do
|
||||
g <- asks idsDrg
|
||||
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
|
||||
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
connId <- withStore c $ \db -> createNewConn db g cData cMode
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode)
|
||||
pure connId
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
|
||||
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo =
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -397,9 +419,9 @@ acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> Invitat
|
||||
acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
|
||||
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
|
||||
withStore c (`getConn` contactConnId) >>= \case
|
||||
SomeConn _ ContactConnection {} -> do
|
||||
SomeConn _ (ContactConnection ConnData {userId} _) -> do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConnAsync c corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -444,14 +466,14 @@ switchConnectionAsync' c corrId connId =
|
||||
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId asyncMode enableNtfs cMode clientData =
|
||||
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode clientData
|
||||
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c userId connId asyncMode enableNtfs cMode clientData =
|
||||
getSMPServer c userId >>= newConnSrv c userId connId asyncMode enableNtfs cMode clientData
|
||||
|
||||
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
|
||||
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
|
||||
(q, qUri) <- newRcvQueue c userId "" srv smpClientVRange
|
||||
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
|
||||
let rq = (q :: RcvQueue) {connId = connId'}
|
||||
addSubscription c rq
|
||||
@@ -471,19 +493,19 @@ newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
|
||||
pure connId
|
||||
setUpConn False rq connAgentVersion = do
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
withStore c $ \db -> createRcvConn db g cData rq cMode
|
||||
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId asyncMode enableNtfs cReq cInfo = do
|
||||
joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c userId connId asyncMode enableNtfs cReq cInfo = do
|
||||
srv <- case cReq of
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
|
||||
getNextSMPServer c [qServer q]
|
||||
_ -> getSMPServer c
|
||||
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
|
||||
getNextSMPServer c userId [qServer q]
|
||||
_ -> getSMPServer c userId
|
||||
joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c userId connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
case ( qUri `compatibleVersion` smpClientVRange,
|
||||
e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange,
|
||||
@@ -493,9 +515,9 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
|
||||
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
|
||||
q <- newSndQueue "" qInfo
|
||||
q <- newSndQueue userId "" qInfo
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
connId' <- setUpConn asyncMode cData q rc
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
@@ -520,23 +542,23 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnSrv c connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
|
||||
joinConnSrv c userId connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
crAgentVRange `compatibleVersion` aVRange
|
||||
) of
|
||||
(Just qInfo, Just vrsn) -> do
|
||||
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation Nothing srv
|
||||
sendInvitation c qInfo vrsn cReq cInfo
|
||||
(connId', cReq) <- newConnSrv c userId connId False enableNtfs SCMInvitation Nothing srv
|
||||
sendInvitation c userId qInfo vrsn cReq cInfo
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
joinConnSrv _c _userId _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
|
||||
createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c userId connId srv $ versionToRange smpClientVersion
|
||||
let qInfo = toVersionT qUri smpClientVersion
|
||||
addSubscription c rq
|
||||
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
|
||||
@@ -564,9 +586,9 @@ acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId
|
||||
acceptContact' c connId enableNtfs invId ownConnInfo = withConnLock c connId "acceptContact" $ do
|
||||
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
|
||||
withStore c (`getConn` contactConnId) >>= \case
|
||||
SomeConn _ ContactConnection {} -> do
|
||||
SomeConn _ (ContactConnection ConnData {userId} _) -> do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
joinConn c userId connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -799,13 +821,15 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
NEW enableNtfs (ACM cMode) -> noServer $ do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
|
||||
(_, cReq) <- newConnSrv c connId True enableNtfs cMode Nothing srv
|
||||
let userId = 1 -- TODO
|
||||
(_, cReq) <- newConnSrv c userId connId True enableNtfs cMode Nothing srv
|
||||
notify $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
|
||||
let initUsed = [qServer q]
|
||||
usedSrvs <- newTVarIO initUsed
|
||||
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
|
||||
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
|
||||
let userId = 1 -- TODO
|
||||
void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv
|
||||
notify OK
|
||||
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
|
||||
@@ -901,10 +925,12 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m ()
|
||||
withNextSrv usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c used
|
||||
let userId = 1 -- TODO
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
srvs_ <- TM.lookup userId $ smpServers c
|
||||
-- TODO this condition does not account for servers change, it has to be changed to see if there are any remaining unused servers configured for the user
|
||||
let used' = if length used + 1 >= maybe 0 L.length srvs_ then initUsed else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
action srvAuth
|
||||
-- ^ ^ ^ async command processing /
|
||||
@@ -978,7 +1004,7 @@ getPendingMsgQ c SndQueue {server, sndId} = do
|
||||
pure q
|
||||
|
||||
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do
|
||||
(mq, qLock) <- atomically $ getPendingMsgQ c sq
|
||||
ri <- asks $ messageRetryInterval . config
|
||||
forever $ do
|
||||
@@ -1067,7 +1093,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- and this branch should never be reached as receive is created before the confirmation,
|
||||
-- so the condition is not necessary here, strictly speaking.
|
||||
_ -> unless (duplexHandshake == Just True) $ do
|
||||
srv <- getSMPServer c
|
||||
srv <- getSMPServer c userId
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
@@ -1142,12 +1168,12 @@ switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
|
||||
DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
-- try to get the server that is different from all queues, or at least from the primary rcv queue
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
|
||||
srv' <- if srv == server then getNextSMPServer c [server] else pure srvAuth
|
||||
(q, qUri) <- newRcvQueue c connId srv' clientVRange
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
|
||||
srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth
|
||||
(q, qUri) <- newRcvQueue c userId connId srv' clientVRange
|
||||
let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnRcvQueue db connId rq'
|
||||
addSubscription c rq'
|
||||
@@ -1217,8 +1243,8 @@ connectionStats = \case
|
||||
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
|
||||
|
||||
-- | Change servers to be used for creating new queues, in Reader monad
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers' c = atomically . writeTVar (smpServers c)
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers' c userId srvs = atomically $ TM.insert userId srvs $ smpServers c
|
||||
|
||||
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
|
||||
registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
|
||||
@@ -1457,8 +1483,8 @@ debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs} = do
|
||||
where
|
||||
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServerWithAuth
|
||||
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
|
||||
getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth
|
||||
getSMPServer c userId = withUserServers c userId pickServer
|
||||
|
||||
pickServer :: AgentMonad m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth
|
||||
pickServer = \case
|
||||
@@ -1467,13 +1493,18 @@ pickServer = \case
|
||||
gen <- asks randomServer
|
||||
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
|
||||
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServerWithAuth
|
||||
getNextSMPServer c usedSrvs = do
|
||||
srvs <- readTVarIO $ smpServers c
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth
|
||||
getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs ->
|
||||
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pickServer srvs
|
||||
|
||||
withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a
|
||||
withUserServers c userId action =
|
||||
atomically (TM.lookup userId $ smpServers c) >>= \case
|
||||
Just srvs -> action srvs
|
||||
_ -> throwError $ INTERNAL "unknown userId - no SMP servers"
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
t <- atomically $ readTBQueue msgQ
|
||||
@@ -1488,7 +1519,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
processSMP rq conn $ connData conn
|
||||
where
|
||||
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
|
||||
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
|
||||
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {userId, connId, duplexHandshake} = withConnLock c connId "processSMP" $
|
||||
case cmd of
|
||||
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
|
||||
handleNotifyAck $
|
||||
@@ -1586,8 +1617,10 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
ackDel = enqueueCmd . ICAckDel rId srvMsgId
|
||||
handleNotifyAck :: m () -> m ()
|
||||
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
|
||||
SMP.END ->
|
||||
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
|
||||
SMP.END -> do
|
||||
-- TODO is race condition possible here on session mode change?
|
||||
tSess <- mkSMPTransportSession c rq
|
||||
atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND)
|
||||
>>= logServer "<--" c srv rId
|
||||
where
|
||||
processEND = \case
|
||||
@@ -1724,7 +1757,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
|
||||
(Just _, _) -> qError "QADD: queue address is already used in connection"
|
||||
(_, Just _replaced@SndQueue {dbQueueId}) -> do
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo
|
||||
let sq' = (sq_ :: SndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnSndQueue db connId sq'
|
||||
case (sndPublicKey, e2ePubKey) of
|
||||
@@ -1796,12 +1829,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
|
||||
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
|
||||
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qInfo `proveCompatible` clientVRange of
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Just qInfo' -> do
|
||||
sq <- newSndQueue connId qInfo'
|
||||
sq <- newSndQueue userId connId qInfo'
|
||||
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
|
||||
|
||||
@@ -1860,14 +1893,15 @@ agentRatchetDecrypt db connId encAgentMsg = do
|
||||
liftIO $ updateRatchet db connId rc' skippedDiff
|
||||
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
|
||||
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
C.SignAlg a <- asks $ cmdSignAlg . config
|
||||
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
pure
|
||||
SndQueue
|
||||
{ connId,
|
||||
{ userId,
|
||||
connId,
|
||||
server = smpServer,
|
||||
sndId = senderId,
|
||||
sndPublicKey = Just sndPublicKey,
|
||||
|
||||
@@ -48,6 +48,9 @@ module Simplex.Messaging.Agent.Client
|
||||
agentNtfCreateSubscription,
|
||||
agentNtfCheckSubscription,
|
||||
agentNtfDeleteSubscription,
|
||||
-- TODO possibly, this should not be exported?
|
||||
-- this is used in END processing
|
||||
mkSMPTransportSession,
|
||||
agentCbEncrypt,
|
||||
agentCbDecrypt,
|
||||
cryptoError,
|
||||
@@ -131,6 +134,7 @@ import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
|
||||
import Simplex.Messaging.Protocol
|
||||
( AProtocolType (..),
|
||||
BrokerMsg,
|
||||
EntityId,
|
||||
ErrorType,
|
||||
MsgFlags (..),
|
||||
MsgId,
|
||||
@@ -166,15 +170,22 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
|
||||
|
||||
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
|
||||
|
||||
-- | Transport session key - includes entity ID if `sessionMode = TSMEntity`.
|
||||
type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
|
||||
|
||||
type SMPTransportSession = TransportSession SMP.BrokerMsg
|
||||
|
||||
type NtfTransportSession = TransportSession NtfResponse
|
||||
|
||||
data AgentClient = AgentClient
|
||||
{ active :: TVar Bool,
|
||||
rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
smpServers :: TVar (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPTransportSession SMPClientVar,
|
||||
ntfServers :: TVar [NtfServer],
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
ntfClients :: TMap NtfTransportSession NtfClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubs :: TRcvQueues,
|
||||
@@ -195,7 +206,7 @@ data AgentClient = AgentClient
|
||||
-- locks to prevent concurrent operations with connection
|
||||
connLocks :: TMap ConnId Lock,
|
||||
-- locks to prevent concurrent reconnections to SMP servers
|
||||
reconnectLocks :: TMap SMPServer Lock,
|
||||
reconnectLocks :: TMap SMPTransportSession Lock,
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()],
|
||||
agentStats :: TMap AgentStatsKey (TVar Int),
|
||||
@@ -227,7 +238,13 @@ data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map St
|
||||
|
||||
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
data AgentStatsKey = AgentStatsKey {host :: ByteString, clientTs :: ByteString, cmd :: ByteString, res :: ByteString}
|
||||
data AgentStatsKey = AgentStatsKey
|
||||
{ userId :: UserId,
|
||||
host :: ByteString,
|
||||
clientTs :: ByteString,
|
||||
cmd :: ByteString,
|
||||
res :: ByteString
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
|
||||
@@ -270,7 +287,7 @@ agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
|
||||
class ProtocolServerClient msg where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg)
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
|
||||
clientProtocolError :: ErrorType -> AgentErrorType
|
||||
|
||||
instance ProtocolServerClient BrokerMsg where
|
||||
@@ -281,19 +298,20 @@ instance ProtocolServerClient NtfResponse where
|
||||
getProtocolServerClient = getNtfServerClient
|
||||
clientProtocolError = NTF
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, entityId_) = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv smpClients)
|
||||
atomically (getClientVar tSess smpClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv smpClients connectClient reconnectClient)
|
||||
(waitForProtocolClient c srv)
|
||||
(newProtocolClient c tSess smpClients connectClient reconnectClient)
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
u <- askUnliftIO
|
||||
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
|
||||
let proxyUsername = Just $ bshow userId <> maybe "" (":" <>) entityId_
|
||||
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg proxyUsername (Just msgQ) $ clientDisconnected u)
|
||||
|
||||
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
|
||||
clientDisconnected u client = do
|
||||
@@ -302,14 +320,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
where
|
||||
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.delete tSess smpClients
|
||||
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
|
||||
mapM_ (`RQ.addQueue` pendingSubs c) qs
|
||||
pure (qs, S.toList conns)
|
||||
|
||||
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
incClientStat c client "DISCONNECT" ""
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
unless (null conns) $ notifySub "" $ DOWN srv conns
|
||||
unless (null qs) $ do
|
||||
@@ -329,18 +347,20 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withLockMap_ (reconnectLocks c) srv "reconnect" $
|
||||
withLockMap_ (reconnectLocks c) tSess "reconnect" $
|
||||
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
|
||||
where
|
||||
resubscribe :: [RcvQueue] -> m ()
|
||||
resubscribe qs = do
|
||||
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
|
||||
connected <- maybe False isRight <$> atomically (TM.lookup tSess smpClients $>>= tryReadTMVar)
|
||||
cs <- atomically . RQ.getConns $ activeSubs c
|
||||
-- TODO probably tSess should be passed here or nothing
|
||||
-- maybe this should be another function, not subscribeQueues
|
||||
(client_, rs) <- subscribeQueues c srv qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
unless connected . forM_ client_ $ \cl -> do
|
||||
incClientStat c cl "CONNECT" ""
|
||||
incClientStat c userId cl "CONNECT" ""
|
||||
notifySub "" $ hostEvent CONNECT cl
|
||||
let conns = S.toList $ S.fromList okConns `S.difference` cs
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
@@ -351,37 +371,37 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
notifySub :: ConnId -> ACommand 'Agent -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} srv = do
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv ntfClients)
|
||||
atomically (getClientVar tSess ntfClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv ntfClients connectClient $ pure ())
|
||||
(waitForProtocolClient c srv)
|
||||
(newProtocolClient c tSess ntfClients connectClient $ pure ())
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: m NtfClient
|
||||
connectClient = do
|
||||
cfg <- getClientConfig c ntfCfg
|
||||
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing clientDisconnected)
|
||||
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing Nothing clientDisconnected)
|
||||
|
||||
clientDisconnected :: NtfClient -> IO ()
|
||||
clientDisconnected client = do
|
||||
atomically $ TM.delete srv ntfClients
|
||||
incClientStat c client "DISCONNECT" ""
|
||||
atomically $ TM.delete tSess ntfClients
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
|
||||
getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup tSess clients
|
||||
where
|
||||
newClientVar :: STM (TMVar a)
|
||||
newClientVar = do
|
||||
var <- newEmptyTMVar
|
||||
TM.insert srv var clients
|
||||
TM.insert tSess var clients
|
||||
pure var
|
||||
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient c srv clientVar = do
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient c (_, srv, _) clientVar = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
|
||||
liftEither $ case client_ of
|
||||
@@ -393,30 +413,30 @@ newProtocolClient ::
|
||||
forall msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
|
||||
AgentClient ->
|
||||
ProtoServer msg ->
|
||||
TMap (ProtoServer msg) (ClientVar msg) ->
|
||||
TransportSession msg ->
|
||||
TMap (TransportSession msg) (ClientVar msg) ->
|
||||
m (ProtocolClient msg) ->
|
||||
m () ->
|
||||
ClientVar msg ->
|
||||
m (ProtocolClient msg)
|
||||
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
|
||||
atomically $ putTMVar clientVar r
|
||||
liftIO $ incClientStat c client "CLIENT" "OK"
|
||||
liftIO $ incClientStat c userId client "CLIENT" "OK"
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
|
||||
successAction client
|
||||
Left e -> do
|
||||
liftIO $ incServerStat c srv "CLIENT" $ strEncode e
|
||||
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
|
||||
if temporaryAgentError e
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar clientVar (Left e)
|
||||
TM.delete srv clients
|
||||
TM.delete tSess clients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
@@ -471,7 +491,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
|
||||
where
|
||||
k = (server, sndId)
|
||||
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients c clientsSel =
|
||||
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
|
||||
where
|
||||
@@ -494,30 +514,43 @@ withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pur
|
||||
where
|
||||
newLock = newEmptyTMVar >>= \l -> TM.insert key l locks $> l
|
||||
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c srv statCmd action = do
|
||||
cl <- getProtocolServerClient c srv
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
cl <- getProtocolServerClient c tSess
|
||||
(action cl <* stat cl "OK") `catchError` logServerError cl
|
||||
where
|
||||
stat cl = liftIO . incClientStat c cl statCmd
|
||||
stat cl = liftIO . incClientStat c userId cl statCmd
|
||||
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
|
||||
logServerError cl e = do
|
||||
logServer "<--" c srv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withClient_ c srv cmdStr action
|
||||
logServer "<--" c srv qId "OK"
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
logServer "-->" c srv entId cmdStr
|
||||
res <- withClient_ c tSess cmdStr action
|
||||
logServer "<--" c srv entId "OK"
|
||||
return res
|
||||
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c srv statKey action = withClient_ c srv statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withSMPClient c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient c tSess (queueId q) cmdStr action
|
||||
|
||||
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withSMPClient_ c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient_ c tSess (queueId q) cmdStr action
|
||||
|
||||
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withNtfClient c srv = withLogClient c (0, srv, Nothing)
|
||||
|
||||
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
|
||||
liftClient protocolError_ = liftError . protocolClientError protocolError_
|
||||
@@ -551,12 +584,13 @@ instance ToJSON SMPTestFailure where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
toJSON = J.genericToJSON J.defaultOptions
|
||||
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
C.SignAlg a <- asks $ cmdSignAlg . config
|
||||
liftIO $ do
|
||||
getProtocolClient srv cfg Nothing (\_ -> pure ()) >>= \case
|
||||
let proxyUsername = bshow userId
|
||||
getProtocolClient srv cfg (Just proxyUsername) Nothing (\_ -> pure ()) >>= \case
|
||||
Right smp -> do
|
||||
(rKey, rpKey) <- C.generateSignatureKeyPair a
|
||||
(sKey, _) <- C.generateSignatureKeyPair a
|
||||
@@ -566,7 +600,7 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
liftError (testErr TSSecureQueue) $ secureSMPQueue smp rpKey rcvId sKey
|
||||
liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp rpKey rcvId
|
||||
ok <- tcpTimeout (networkConfig cfg) `timeout` closeProtocolClient smp
|
||||
incClientStat c smp "TEST" "OK"
|
||||
incClientStat c userId smp "TEST" "OK"
|
||||
pure $ either Just (const Nothing) r <|> maybe (Just (SMPTestFailure TSDisconnect $ BROKER addr TIMEOUT)) (const Nothing) ok
|
||||
Left e -> pure (Just $ testErr TSConnect e)
|
||||
where
|
||||
@@ -574,19 +608,29 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
|
||||
testErr step = SMPTestFailure step . protocolClientError SMP addr
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
|
||||
mkTransportSession c userId srv entityId = do
|
||||
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
|
||||
pure (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
|
||||
|
||||
mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
|
||||
mkSMPTransportSession c q = mkTransportSession c (qUserId q) (qServer q) (queueId q)
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
C.SignAlg a <- asks (cmdSignAlg . config)
|
||||
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
logServer "-->" c srv "" "NEW"
|
||||
tSess <- mkTransportSession c userId srv connId
|
||||
QIK {rcvId, sndId, rcvPublicDhKey} <-
|
||||
withClient c srv "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
|
||||
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
{ connId,
|
||||
{ userId,
|
||||
connId,
|
||||
server = srv,
|
||||
rcvId,
|
||||
rcvPrivateKey,
|
||||
@@ -609,7 +653,7 @@ subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
|
||||
atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
withSMPClient c rq "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
|
||||
>>= either throwError pure
|
||||
|
||||
@@ -648,14 +692,19 @@ subscribeQueues c srv qs = do
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
|
||||
-- TODO these subscriptions should happen in different sessions if mode is TSMEntity
|
||||
-- it is also a question whether it is needed to group by server outside if grouping by sessions should happen here anyway
|
||||
let userId = 1
|
||||
tSess = (userId, srv, Nothing)
|
||||
smp_ <- tryError $ getSMPServerClient c tSess
|
||||
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (,Left e) qs_
|
||||
Right smp -> do
|
||||
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
|
||||
let qs2 = L.map queueCreds qs'
|
||||
n = (length qs2 - 1) `div` 90 + 1
|
||||
liftIO $ incClientStatN c smp n "SUBS" "OK"
|
||||
liftIO $ incClientStatN c userId smp n "SUBS" "OK"
|
||||
liftIO $ do
|
||||
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
mapM_ (uncurry $ processSubResult c) rs
|
||||
@@ -699,16 +748,17 @@ logSecret :: ByteString -> ByteString
|
||||
logSecret bs = encode $ B.take 3 bs
|
||||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo =
|
||||
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
|
||||
tSess <- mkTransportSession c userId smpServer senderId
|
||||
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
|
||||
msg <- mkInvitation
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
|
||||
where
|
||||
@@ -722,7 +772,7 @@ sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderI
|
||||
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
|
||||
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
atomically createTakeGetLock
|
||||
(v, msg_) <- withLogClient c server rcvId "GET" $ \smp ->
|
||||
(v, msg_) <- withSMPClient c rq "GET" $ \smp ->
|
||||
(thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId
|
||||
mapM (decryptMeta v) msg_
|
||||
where
|
||||
@@ -742,23 +792,23 @@ decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.Enc
|
||||
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
|
||||
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
|
||||
withLogClient c server rcvId "KEY <key>" $ \smp ->
|
||||
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
|
||||
withSMPClient c rq "KEY <key>" $ \smp ->
|
||||
secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
||||
|
||||
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey)
|
||||
enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withLogClient c server rcvId "NKEY <nkey>" $ \smp ->
|
||||
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withSMPClient c rq "NKEY <nkey>" $ \smp ->
|
||||
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
|
||||
|
||||
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "NDEL" $ \smp ->
|
||||
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "NDEL" $ \smp ->
|
||||
disableSMPQueueNotifications smp rcvPrivateKey rcvId
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
|
||||
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
|
||||
withLogClient c server rcvId "ACK" $ \smp ->
|
||||
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
|
||||
withSMPClient c rq "ACK" $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId msgId
|
||||
atomically $ releaseGetLock c rq
|
||||
|
||||
@@ -767,57 +817,57 @@ releaseGetLock c RcvQueue {server, rcvId} =
|
||||
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
|
||||
|
||||
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "OFF" $ \smp ->
|
||||
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "OFF" $ \smp ->
|
||||
suspendSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "DEL" $ \smp ->
|
||||
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
|
||||
|
||||
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
|
||||
withClient c ntfServer "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
|
||||
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
|
||||
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
|
||||
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
|
||||
withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
|
||||
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
|
||||
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
|
||||
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId
|
||||
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
|
||||
withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
|
||||
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
|
||||
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
|
||||
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
|
||||
@@ -948,18 +998,18 @@ incStat AgentClient {agentStats} n k = do
|
||||
Just v -> modifyTVar v (+ n)
|
||||
_ -> newTVar n >>= \v -> TM.insert k v agentStats
|
||||
|
||||
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c pc = incClientStatN c pc 1
|
||||
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c userId pc = incClientStatN c userId pc 1
|
||||
|
||||
incServerStat :: AgentClient -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
incServerStat c ProtocolServer {host} cmd res = do
|
||||
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
incServerStat c userId ProtocolServer {host} cmd res = do
|
||||
threadDelay 100000
|
||||
atomically $ incStat c 1 statsKey
|
||||
where
|
||||
statsKey = AgentStatsKey {host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
|
||||
incClientStatN :: AgentClient -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c pc n cmd res = do
|
||||
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c userId pc n cmd res = do
|
||||
atomically $ incStat c n statsKey
|
||||
where
|
||||
statsKey = AgentStatsKey {host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
|
||||
|
||||
@@ -32,12 +32,14 @@ import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map (Map)
|
||||
import Data.Time.Clock (NominalDiffTime, nominalDay)
|
||||
import Data.Word (Word16)
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
@@ -59,7 +61,7 @@ import UnliftIO.STM
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
data InitialAgentServers = InitialAgentServers
|
||||
{ smp :: NonEmpty SMPServerWithAuth,
|
||||
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
|
||||
ntf :: [NtfServer],
|
||||
netCfg :: NetworkConfig
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ import Simplex.Messaging.Protocol
|
||||
NotifierId,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
QueueId,
|
||||
RcvDhSecret,
|
||||
RcvNtfDhSecret,
|
||||
RcvPrivateSignKey,
|
||||
@@ -47,7 +48,8 @@ import Simplex.Messaging.Version
|
||||
|
||||
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
|
||||
data RcvQueue = RcvQueue
|
||||
{ connId :: ConnId,
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | recipient queue ID
|
||||
rcvId :: SMP.RecipientId,
|
||||
@@ -89,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds
|
||||
|
||||
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
|
||||
data SndQueue = SndQueue
|
||||
{ connId :: ConnId,
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | sender queue ID
|
||||
sndId :: SMP.SenderId,
|
||||
@@ -150,6 +153,27 @@ findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
|
||||
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
|
||||
{-# INLINE findRQ #-}
|
||||
|
||||
class SMPQueue q => SMPQueueRec q where
|
||||
qUserId :: q -> UserId
|
||||
qConnId :: q -> ConnId
|
||||
queueId :: q -> QueueId
|
||||
|
||||
instance SMPQueueRec RcvQueue where
|
||||
qUserId = userId
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId = connId
|
||||
{-# INLINE qConnId #-}
|
||||
queueId = rcvId
|
||||
{-# INLINE queueId #-}
|
||||
|
||||
instance SMPQueueRec SndQueue where
|
||||
qUserId = userId
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId = connId
|
||||
{-# INLINE qConnId #-}
|
||||
queueId = sndId
|
||||
{-# INLINE queueId #-}
|
||||
|
||||
-- * Connection types
|
||||
|
||||
-- | Type of a connection.
|
||||
@@ -222,6 +246,7 @@ deriving instance Show SomeConn
|
||||
|
||||
data ConnData = ConnData
|
||||
{ connId :: ConnId,
|
||||
userId :: UserId,
|
||||
connAgentVersion :: Version,
|
||||
enableNtfs :: Bool,
|
||||
duplexHandshake :: Maybe Bool, -- added in agent protocol v2
|
||||
@@ -231,6 +256,8 @@ data ConnData = ConnData
|
||||
|
||||
data AgentCmdType = ACClient | ACInternal
|
||||
|
||||
type UserId = Int64
|
||||
|
||||
instance StrEncoding AgentCmdType where
|
||||
strEncode = \case
|
||||
ACClient -> "CLIENT"
|
||||
@@ -471,6 +498,8 @@ data StoreError
|
||||
SEInternal ByteString
|
||||
| -- | Failed to generate unique random ID
|
||||
SEUniqueID
|
||||
| -- | User ID not found
|
||||
SEUserNotFound
|
||||
| -- | Connection not found (or both queues absent).
|
||||
SEConnNotFound
|
||||
| -- | Connection already used.
|
||||
|
||||
@@ -27,6 +27,10 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
sqlString,
|
||||
execSQL,
|
||||
|
||||
-- * Users
|
||||
createUserRecord,
|
||||
deleteUserRecord,
|
||||
|
||||
-- * Queues and connections
|
||||
createNewConn,
|
||||
updateNewConnRcv,
|
||||
@@ -297,6 +301,18 @@ withTransaction st action = withConnection st $ loop 500 2_000_000
|
||||
loop (t * 9 `div` 8) (tLim - t) db
|
||||
else E.throwIO e
|
||||
|
||||
createUserRecord :: DB.Connection -> IO UserId
|
||||
createUserRecord db = do
|
||||
DB.execute_ db "INSERT INTO users () VALUES ()"
|
||||
insertedRowId db
|
||||
|
||||
deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ())
|
||||
deleteUserRecord db userId = runExceptT $ do
|
||||
_ :: Only Int64 <-
|
||||
ExceptT . firstRow id SEUserNotFound $
|
||||
DB.query db "SELECT user_id FROM users WHERE user_id = ?" (Only userId)
|
||||
liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
|
||||
|
||||
createConn_ ::
|
||||
TVar ChaChaDRG ->
|
||||
ConnData ->
|
||||
@@ -307,9 +323,9 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
|
||||
ConnData {connId} -> create connId $> Right connId
|
||||
|
||||
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
|
||||
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
|
||||
updateNewConnRcv db connId rq =
|
||||
@@ -332,17 +348,17 @@ updateNewConnSnd db connId sq =
|
||||
updateConn = Right <$> addConnSndQueue_ db connId sq
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertRcvQueue_ db connId q
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertSndQueue_ db connId q
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
@@ -1328,9 +1344,9 @@ getAnyConn dbConn connId deleted' =
|
||||
|
||||
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData dbConn connId' =
|
||||
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
maybeFirstRow cData $ DB.query dbConn "SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
where
|
||||
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
|
||||
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
|
||||
|
||||
setConnDeleted :: DB.Connection -> ConnId -> IO ()
|
||||
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
|
||||
@@ -1348,31 +1364,32 @@ getRcvQueuesByConnId_ db connId =
|
||||
rcvQueueQuery :: Query
|
||||
rcvQueueQuery =
|
||||
[sql|
|
||||
SELECT s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
|
||||
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
|
||||
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN connections c ON q.conn_id = c.conn_id
|
||||
|]
|
||||
|
||||
toRcvQueue ::
|
||||
(C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
:. (Int64, Bool, Maybe Int64, Maybe Version)
|
||||
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
|
||||
RcvQueue
|
||||
toRcvQueue ((keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
let server = SMPServer host port keyHash
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
|
||||
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
_ -> Nothing
|
||||
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
|
||||
in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
|
||||
|
||||
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
|
||||
getRcvQueueById_ db connId dbRcvId =
|
||||
firstRow toRcvQueue SEConnNotFound $
|
||||
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
|
||||
DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ?") (connId, dbRcvId)
|
||||
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
|
||||
@@ -1381,16 +1398,17 @@ getSndQueuesByConnId_ dbConn connId =
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN connections c ON q.conn_id = c.conn_id
|
||||
WHERE q.conn_id = ?;
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
|
||||
sndQueue ((userId, keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
|
||||
let server = SMPServer host port keyHash
|
||||
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
|
||||
in SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
|
||||
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
|
||||
-- the current primary queue is ordered first, the next primary - second
|
||||
compare (Down p) (Down p') <> compare i i'
|
||||
|
||||
@@ -37,6 +37,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
|
||||
@@ -53,7 +54,8 @@ schemaMigrations =
|
||||
("m20220811_onion_hosts", m20220811_onion_hosts),
|
||||
("m20220817_connection_ntfs", m20220817_connection_ntfs),
|
||||
("m20220905_commands", m20220905_commands),
|
||||
("m20220915_connection_queues", m20220915_connection_queues)
|
||||
("m20220915_connection_queues", m20220915_connection_queues),
|
||||
("m20230110_users", m20230110_users)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20230110_users :: Query
|
||||
m20230110_users =
|
||||
[sql|
|
||||
PRAGMA ignore_check_constraints=ON;
|
||||
|
||||
CREATE TABLE users (
|
||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT
|
||||
);
|
||||
|
||||
INSERT INTO users (user_id) VALUES (1);
|
||||
|
||||
ALTER TABLE connections ADD COLUMN user_id INTEGER CHECK (user_id NOT NULL)
|
||||
REFERENCES users ON DELETE CASCADE;
|
||||
|
||||
CREATE INDEX idx_connections_user ON connections(user_id);
|
||||
|
||||
UPDATE connections SET user_id = 1;
|
||||
|
||||
PRAGMA ignore_check_constraints=OFF;
|
||||
|]
|
||||
@@ -22,7 +22,9 @@ CREATE TABLE connections(
|
||||
,
|
||||
duplex_handshake INTEGER NULL DEFAULT 0,
|
||||
enable_ntfs INTEGER,
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
|
||||
user_id INTEGER CHECK(user_id NOT NULL)
|
||||
REFERENCES users ON DELETE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE rcv_queues(
|
||||
host TEXT NOT NULL,
|
||||
@@ -228,3 +230,5 @@ CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
|
||||
conn_id,
|
||||
snd_queue_id
|
||||
);
|
||||
CREATE TABLE users(user_id INTEGER PRIMARY KEY AUTOINCREMENT);
|
||||
CREATE INDEX idx_connections_user ON connections(user_id);
|
||||
|
||||
@@ -54,6 +54,7 @@ module Simplex.Messaging.Client
|
||||
ProtocolClientError (..),
|
||||
ProtocolClientConfig (..),
|
||||
NetworkConfig (..),
|
||||
TransportSessionMode (..),
|
||||
defaultClientConfig,
|
||||
defaultNetworkConfig,
|
||||
transportClientConfig,
|
||||
@@ -152,6 +153,8 @@ data NetworkConfig = NetworkConfig
|
||||
hostMode :: HostMode,
|
||||
-- | if above criteria is not met, if the below setting is True return error, otherwise use the first host
|
||||
requiredHostMode :: Bool,
|
||||
-- | transport sessions are created per user or per entity
|
||||
sessionMode :: TransportSessionMode,
|
||||
-- | timeout for the initial client TCP/TLS connection (microseconds)
|
||||
tcpConnectTimeout :: Int,
|
||||
-- | timeout of protocol commands (microseconds)
|
||||
@@ -168,12 +171,23 @@ instance ToJSON NetworkConfig where
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
data TransportSessionMode = TSMUser | TSMEntity
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON TransportSessionMode where
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
|
||||
|
||||
instance FromJSON TransportSessionMode where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
|
||||
|
||||
defaultNetworkConfig :: NetworkConfig
|
||||
defaultNetworkConfig =
|
||||
NetworkConfig
|
||||
{ socksProxy = Nothing,
|
||||
hostMode = HMOnionViaSocks,
|
||||
requiredHostMode = False,
|
||||
sessionMode = TSMUser,
|
||||
tcpConnectTimeout = 7_500_000,
|
||||
tcpTimeout = 5_000_000,
|
||||
tcpKeepAlive = Just defaultKeepAliveOpts,
|
||||
@@ -239,8 +253,8 @@ transportHost' = transportHost . client_
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe ByteString -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} proxyUsername msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host protocolServer) of
|
||||
Right useHost ->
|
||||
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
|
||||
@@ -274,7 +288,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
let tcConfig = transportClientConfig networkConfig
|
||||
action <-
|
||||
async $
|
||||
runTransportClient tcConfig useHost port' (Just $ keyHash protocolServer) (client t c cVar)
|
||||
runTransportClient tcConfig proxyUsername useHost port' (Just $ keyHash protocolServer) (client t c cVar)
|
||||
`finally` atomically (putTMVar cVar $ Left PCENetworkError)
|
||||
c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case c_ of
|
||||
|
||||
@@ -160,7 +160,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) Nothing (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: SMPClient -> IO ()
|
||||
clientDisconnected _ = do
|
||||
|
||||
@@ -77,6 +77,7 @@ module Simplex.Messaging.Protocol
|
||||
BasicAuth (..),
|
||||
SrvLoc (..),
|
||||
CorrId (..),
|
||||
EntityId,
|
||||
QueueId,
|
||||
RecipientId,
|
||||
SenderId,
|
||||
|
||||
@@ -107,15 +107,15 @@ defaultTransportClientConfig :: TransportClientConfig
|
||||
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
|
||||
|
||||
-- | Connect to passed TCP host:port and pass handle to the client.
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} host port keyHash client = do
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
|
||||
let hostName = B.unpack $ strEncode host
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy $ hostAddr host
|
||||
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
|
||||
_ -> connectTCPClient hostName
|
||||
c <- liftIO $ do
|
||||
sock <- connectTCP port
|
||||
@@ -153,10 +153,12 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
|
||||
defaultSMPPort :: PortNumber
|
||||
defaultSMPPort = 5223
|
||||
|
||||
connectSocksClient :: SocksProxy -> SocksHostAddress -> ServiceName -> IO Socket
|
||||
connectSocksClient (SocksProxy addr) hostAddr _port = do
|
||||
connectSocksClient :: SocksProxy -> Maybe ByteString -> SocksHostAddress -> ServiceName -> IO Socket
|
||||
connectSocksClient (SocksProxy addr) proxyUsername hostAddr _port = do
|
||||
let port = if null _port then defaultSMPPort else fromMaybe defaultSMPPort $ readMaybe _port
|
||||
fst <$> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
|
||||
fst <$> case proxyUsername of
|
||||
Just username -> socksConnectAuth (defaultSocksConf addr) (SocksAddress hostAddr port) (SocksCredentials username "")
|
||||
_ -> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
|
||||
|
||||
defaultSocksHost :: HostAddress
|
||||
defaultSocksHost = tupleToHostAddress (127, 0, 0, 1)
|
||||
|
||||
@@ -120,7 +120,7 @@ sendRequest HTTP2Client {reqQ, config} req = do
|
||||
|
||||
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> HostName -> ServiceName -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
|
||||
runHTTP2Client tlsParams caStore tcConfig host port client =
|
||||
runTLSTransportClient tlsParams caStore tcConfig (THDomainName host) port Nothing $ \c ->
|
||||
runTLSTransportClient tlsParams caStore tcConfig Nothing (THDomainName host) port Nothing $ \c ->
|
||||
withTlsConfig c 16384 (`run` client)
|
||||
where
|
||||
run = H.run $ ClientConfig "https" (B.pack host) 20
|
||||
|
||||
@@ -218,8 +218,8 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
|
||||
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientTest alice bob baseId = do
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -254,8 +254,8 @@ runAgentClientTest alice bob baseId = do
|
||||
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTest alice bob baseId = do
|
||||
Right () <- runExceptT $ do
|
||||
(_, qInfo) <- createConnection alice True SCMContact Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, REQ invId _ "bob's connInfo") <- get alice
|
||||
bobId <- acceptContact alice True invId "alice's connInfo"
|
||||
("", _, CONF confId _ "alice's connInfo") <- get bob
|
||||
@@ -302,9 +302,9 @@ testAsyncInitiatingOffline = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice'
|
||||
@@ -320,8 +320,8 @@ testAsyncJoiningOfflineBeforeActivation = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
@@ -338,9 +338,9 @@ testAsyncBothOffline = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
@@ -360,9 +360,9 @@ testAsyncServerOffline t = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
-- create connection and shutdown the server
|
||||
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
|
||||
runExceptT $ createConnection alice True SCMInvitation Nothing
|
||||
runExceptT $ createConnection alice 1 True SCMInvitation Nothing
|
||||
-- connection fails
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo"
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo"
|
||||
("", "", DOWN srv conns) <- get alice
|
||||
srv `shouldBe` testSMPServer
|
||||
conns `shouldBe` [bobId]
|
||||
@@ -372,7 +372,7 @@ testAsyncServerOffline t = do
|
||||
liftIO $ do
|
||||
srv1 `shouldBe` testSMPServer
|
||||
conns1 `shouldBe` [bobId]
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -387,9 +387,9 @@ testAsyncHelloTimeout = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(_, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
|
||||
pure ()
|
||||
|
||||
@@ -444,8 +444,8 @@ testDuplicateMessage t = do
|
||||
|
||||
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
|
||||
makeConnection alice bob = do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -459,7 +459,7 @@ testInactiveClientDisconnected t = do
|
||||
withSmpServerConfigOn t cfg' testPort $ \_ -> do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
get alice ##> ("", "", DOWN testSMPServer [connId])
|
||||
pure ()
|
||||
|
||||
@@ -470,7 +470,7 @@ testActiveClientNotDisconnected t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
ts <- getSystemTime
|
||||
Right () <- runExceptT $ do
|
||||
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
keepSubscribing alice connId ts
|
||||
pure ()
|
||||
where
|
||||
@@ -619,10 +619,10 @@ testAsyncCommands = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
bobId <- createConnectionAsync alice "1" True SCMInvitation
|
||||
bobId <- createConnectionAsync alice 1 "1" True SCMInvitation
|
||||
("1", bobId', INV (ACR _ qInfo)) <- get alice
|
||||
liftIO $ bobId' `shouldBe` bobId
|
||||
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
|
||||
aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo"
|
||||
("2", aliceId', OK) <- get bob
|
||||
liftIO $ aliceId' `shouldBe` aliceId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
@@ -663,7 +663,7 @@ testAsyncCommands = do
|
||||
testAsyncCommandsRestore :: ATransport -> IO ()
|
||||
testAsyncCommandsRestore t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation
|
||||
Right bobId <- runExceptT $ createConnectionAsync alice 1 "1" True SCMInvitation
|
||||
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
|
||||
disconnectAgentClient alice
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
@@ -679,8 +679,8 @@ testAcceptContactAsync = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, qInfo) <- createConnection alice True SCMContact Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, REQ invId _ "bob's connInfo") <- get alice
|
||||
bobId <- acceptContactAsync alice "1" True invId "alice's connInfo"
|
||||
("1", bobId', OK) <- get alice
|
||||
@@ -808,11 +808,11 @@ testCreateQueueAuth clnt1 clnt2 = do
|
||||
a <- getClient clnt1
|
||||
b <- getClient clnt2
|
||||
Right created <- runExceptT $ do
|
||||
tryError (createConnection a True SCMInvitation Nothing) >>= \case
|
||||
tryError (createConnection a 1 True SCMInvitation Nothing) >>= \case
|
||||
Left (SMP AUTH) -> pure 0
|
||||
Left e -> throwError e
|
||||
Right (bId, qInfo) ->
|
||||
tryError (joinConnection b True qInfo "bob's connInfo") >>= \case
|
||||
tryError (joinConnection b 1 True qInfo "bob's connInfo") >>= \case
|
||||
Left (SMP AUTH) -> pure 1
|
||||
Left e -> throwError e
|
||||
Right aId -> do
|
||||
@@ -826,7 +826,7 @@ testCreateQueueAuth clnt1 clnt2 = do
|
||||
pure created
|
||||
where
|
||||
getClient (clntAuth, clntVersion) =
|
||||
let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]}
|
||||
let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]}
|
||||
smpCfg = (defaultClientConfig :: ProtocolClientConfig) {smpServerVRange = mkVersionRange 4 clntVersion}
|
||||
in getSMPAgentClient agentCfg {smpCfg} servers
|
||||
|
||||
@@ -834,7 +834,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
|
||||
testSMPServerConnectionTest t newQueueBasicAuth srv =
|
||||
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running
|
||||
Right r <- runExceptT $ testSMPServerConnection a srv
|
||||
Right r <- runExceptT $ testSMPServerConnection a 1 srv
|
||||
pure r
|
||||
|
||||
testRatchetAdHash :: IO ()
|
||||
|
||||
@@ -213,8 +213,8 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right (bobId, aliceId, nonce, message) <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -275,9 +275,9 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
|
||||
DeviceToken {} <- registerTestToken bob "bcde" NMInstant apnsQ
|
||||
-- establish connection
|
||||
liftIO $ threadDelay 50000
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
liftIO $ threadDelay 1000000
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
liftIO $ threadDelay 750000
|
||||
void $ messageNotification apnsQ
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
@@ -327,8 +327,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -392,8 +392,8 @@ testChangeToken APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
Right (aliceId, bobId) <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
||||
@@ -140,7 +140,7 @@ testForeignKeysEnabled =
|
||||
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
|
||||
|
||||
cData1 :: ConnData
|
||||
cData1 = ConnData {connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
|
||||
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
|
||||
|
||||
testPrivateSignKey :: C.APrivateSignKey
|
||||
testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
|
||||
@@ -154,7 +154,8 @@ testDhSecret = "01234567890123456789012345678901"
|
||||
rcvQueue1 :: RcvQueue
|
||||
rcvQueue1 =
|
||||
RcvQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "1234",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
@@ -173,7 +174,8 @@ rcvQueue1 =
|
||||
sndQueue1 :: SndQueue
|
||||
sndQueue1 =
|
||||
SndQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "3456",
|
||||
sndPublicKey = Nothing,
|
||||
@@ -314,7 +316,8 @@ testUpgradeRcvConnToDuplex =
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "2345",
|
||||
sndPublicKey = Nothing,
|
||||
@@ -340,7 +343,8 @@ testUpgradeSndConnToDuplex =
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "3456",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
|
||||
@@ -69,7 +69,7 @@ ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testNtfClient client = do
|
||||
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig host ntfTestPort (Just testKeyHash) $ \h ->
|
||||
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h ->
|
||||
liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
@@ -11,7 +12,8 @@ import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Network.Socket (ServiceName)
|
||||
import NtfClient (ntfTestPort)
|
||||
import SMPClient
|
||||
@@ -27,6 +29,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -173,13 +176,13 @@ testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5
|
||||
initAgentServers :: InitialAgentServers
|
||||
initAgentServers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList [noAuthSrv testSMPServer],
|
||||
{ smp = userServers [noAuthSrv testSMPServer],
|
||||
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"],
|
||||
netCfg = defaultNetworkConfig {tcpTimeout = 500_000}
|
||||
}
|
||||
|
||||
initAgentServers2 :: InitialAgentServers
|
||||
initAgentServers2 = initAgentServers {smp = L.fromList [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
|
||||
initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
|
||||
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
@@ -209,11 +212,14 @@ agentCfg =
|
||||
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
|
||||
let cfg' = agentCfg {tcpPort = port', database = db'}
|
||||
initServers' = initAgentServers {smp = L.fromList [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg' initServers')
|
||||
afterProcess
|
||||
|
||||
userServers :: NonEmpty SMPServerWithAuth -> Map UserId (NonEmpty SMPServerWithAuth)
|
||||
userServers srvs = M.fromList [(1, srvs)]
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile (dbFile db')
|
||||
|
||||
@@ -226,7 +232,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
|
||||
testSMPAgentClientOn port' client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
|
||||
runTransportClient defaultTransportClientConfig useHost port' (Just testKeyHash) $ \h -> do
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
then client h
|
||||
|
||||
@@ -57,7 +57,7 @@ testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testSMPClient client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig useHost testPort (Just testKeyHash) $ \h ->
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h ->
|
||||
liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
Reference in New Issue
Block a user