Merge branch 'master' into xftp

This commit is contained in:
Evgeny Poberezkin
2023-01-27 18:00:11 +00:00
35 changed files with 1401 additions and 570 deletions
+2 -1
View File
@@ -7,6 +7,7 @@ module Main where
import Control.Logger.Simple
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Server (runSMPAgent)
import Simplex.Messaging.Client (defaultNetworkConfig)
@@ -18,7 +19,7 @@ cfg = defaultAgentConfig
servers :: InitialAgentServers
servers =
InitialAgentServers
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
{ smp = M.fromList [(1, L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"])],
ntf = [],
netCfg = defaultNetworkConfig
}
+6
View File
@@ -1,11 +1,17 @@
packages: .
-- packages: . ../direct-sqlcipher ../sqlcipher-simple
-- packages: . ../hs-socks
source-repository-package
type: git
location: https://github.com/simplex-chat/aeson.git
tag: 3eb66f9a68f103b5f1489382aad89f5712a64db7
source-repository-package
type: git
location: https://github.com/simplex-chat/hs-socks.git
tag: a30cc7a79a08d8108316094f8f2f82a0c5e1ac51
source-repository-package
type: git
location: https://github.com/simplex-chat/direct-sqlcipher.git
+1 -1
View File
@@ -1,5 +1,5 @@
name: simplexmq
version: 4.3.0
version: 4.3.1
synopsis: SimpleXMQ message broker
description: |
This package includes <./docs/Simplex-Messaging-Server.html server>,
+11 -6
View File
@@ -95,10 +95,12 @@ To send the file, the sender will:
- compute SHA512 digest
- pad the file to match the whole number of chunks in size,
- encrypt it with a randomly chosen symmetric key and IV (e.g., using NaCL cryptobox),
- encrypt it with a randomly chosen symmetric key and IV (e.g., using NaCL crypto_secretbox),
- split into fixed size chunks
- upload each chunk to a randomly chosen server.
The sending client should generate more per-recipient keys than the actual number of recipients, possibly rounding up to a power of 2, to conceal the actual number of intended recipients.
Then the sending client will combine addresses of all chunks and other information into "file description", different for each file recipient, that will include:
- an encryption key that was used to encrypt the file (the same for all recipients).
@@ -106,22 +108,24 @@ Then the sending client will combine addresses of all chunks and other informati
- list of chunk descriptions; information for each chunk:
- private Ed25519 key to sign commands for file transfer server.
- chunk address (server host and chunk ID).
- chunk sha512 digest
To reduce the size, chunk descriptions will be grouped by the server host.
This "file description" itself will be sent as a small file. To estimate its size:
This "file description" itself will be sent as a small file over an authenticated channel, to prevent file description modification. To estimate its size:
- each chunk \* redundancy per chunk, assuming chunks are grouped per server:
- chunk number in the file - 8 bytes (including any overhead)
- 1-based chunk number in the file - 8 bytes (including any overhead)
- Ed25519 key (different for each recipient / chunk combination) - 32 bytes \* 4/3 (base64, assuming text encoding)
- chunk ID (different for each recipient) - 64 bytes \* 4/3
- optional (only in the first chunk occurence) chunk sha512 digests - 64 bytes \* 4/3
- server addresses - say, 128 bytes per server
- sha512 digest - 64 bytes \* 4/3
- encryption key - 32 bytes \* 4/3
- IV - 32 bytes \* 4/3
- encoding overhead - say, 256 bytes
For 1gb file, sent via 4 different servers, in 8Mb chunks, with redundancy 2, the size of "file description", assuming text encoding, will be ~34kb (`128 * (8 + 32 + 64) * 2 * 4/3 + 128 * 4 + (64 + 32 + 32) * 4/3 + 256`).
For 1gb file, sent via 4 different servers, in 8Mb chunks, with redundancy 2, the size of "file description", assuming text encoding, will be ~45kb (`128 * (8 + 32 + 64) * 2 * 4/3 + 128 * 64 * 4/3 + 128 * 4 + (64 + 32 + 32) * 4/3 + 256`).
File description format (yml):
@@ -132,11 +136,12 @@ chunk: 8Mb
hash: abc=
key: abc=
iv: abc=
part_hashes: [def=, def=, def=, def=]
parts:
- server: xftp://abc=@example1.com
chunks: [1:abc=:def=, 3:abc=:def=]
chunks: [1:abc=:def=:ghi=, 3:abc=:def=:ghi=]
- server: xftp://abc=@example2.com
chunks: [2:abc=:def=, 4:abc=:def=]
chunks: [2:abc=:def=:ghi=, 4:abc=:def=:ghi=]
- server: xftp://abc=@example3.com
chunks: [1:abc=:def=, 4:abc=:def=]
- server: xftp://abc=@example4.com
+5 -1
View File
@@ -5,7 +5,7 @@ cabal-version: 1.12
-- see: https://github.com/sol/hpack
name: simplexmq
version: 4.3.0
version: 4.3.1
synopsis: SimpleXMQ message broker
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
<./docs/Simplex-Messaging-Client.html client> and
@@ -64,6 +64,10 @@ library
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
Simplex.Messaging.Agent.TAsyncs
Simplex.Messaging.Agent.TRcvQueues
Simplex.Messaging.Client
Simplex.Messaging.Client.Agent
+295 -194
View File
@@ -38,6 +38,8 @@ module Simplex.Messaging.Agent
disconnectAgentClient,
resumeAgentClient,
withConnLock,
createUser,
deleteUser,
createConnectionAsync,
joinConnectionAsync,
allowConnectionAsync,
@@ -45,6 +47,7 @@ module Simplex.Messaging.Agent
ackMessageAsync,
switchConnectionAsync,
deleteConnectionAsync,
deleteConnectionsAsync,
createConnection,
joinConnection,
allowConnection,
@@ -61,6 +64,7 @@ module Simplex.Messaging.Agent
switchConnection,
suspendConnection,
deleteConnection,
deleteConnections,
getConnectionServers,
getConnectionRatchetAdHash,
setSMPServers,
@@ -86,7 +90,7 @@ module Simplex.Messaging.Agent
where
import Control.Concurrent.STM (stateTVar)
import Control.Logger.Simple (logInfo, showText)
import Control.Logger.Simple (logError, logInfo, showText)
import Control.Monad.Except
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader
@@ -97,7 +101,7 @@ import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::))
import Data.Foldable (foldl')
import Data.Functor (($>))
import Data.List (deleteFirstsBy, find)
import Data.List (deleteFirstsBy, find, (\\))
import Data.List.NonEmpty (NonEmpty (..), (<|))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -110,6 +114,7 @@ import Data.Time.Clock.System (systemToUTCTime)
import qualified Database.SQLite.Simple as DB
import Simplex.Messaging.Agent.Client
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Lock (withLock)
import Simplex.Messaging.Agent.NtfSubSupervisor
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
@@ -124,13 +129,13 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, sameSrvAddr, sameSrvAddr')
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, protoServer, sameSrvAddr')
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Random (randomR)
import UnliftIO.Async (async, mapConcurrently, race_)
import UnliftIO.Async (async, race_)
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -143,7 +148,7 @@ getSMPAgentClient cfg initServers = newSMPAgentEnv cfg >>= runReaderT runAgent
where
runAgent = do
c <- getAgentClient initServers
void $ race_ (subscriber c) (runNtfSupervisor c) `forkFinally` const (disconnectAgentClient c)
void $ raceAny_ [subscriber c, runNtfSupervisor c, cleanupManager c] `forkFinally` const (disconnectAgentClient c)
pure c
disconnectAgentClient :: MonadUnliftIO m => AgentClient -> m ()
@@ -155,16 +160,22 @@ disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns}} = do
resumeAgentClient :: MonadIO m => AgentClient -> m ()
resumeAgentClient c = atomically $ writeTVar (active c) True
-- |
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
createUser c = withAgentEnv c . createUser' c
-- | Delete user record optionally deleting all user's connections on SMP servers
deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m ()
deleteUser c = withAgentEnv c .: deleteUser' c
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
createConnectionAsync c userId corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c userId corrId enableNtfs cMode
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnectionAsync c userId corrId enableNtfs = withAgentEnv c .: joinConnAsync c userId corrId enableNtfs
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -183,16 +194,20 @@ switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -
switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c
-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
deleteConnectionAsync c = withAgentEnv c . deleteConnectionAsync' c
-- -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response
deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync c = withAgentEnv c . deleteConnectionsAsync' c
-- | Create SMP agent connection (NEW command)
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
createConnection c enableNtfs cMode clientData = withAgentEnv c $ newConn c "" False enableNtfs cMode clientData
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
createConnection c userId enableNtfs cMode clientData = withAgentEnv c $ newConn c userId "" enableNtfs cMode clientData
-- | Join SMP agent connection (JOIN command)
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnection c userId enableNtfs = withAgentEnv c .: joinConn c userId "" False enableNtfs
-- | Allow connection to continue after CONF notification (LET command)
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -247,6 +262,10 @@ suspendConnection c = withAgentEnv c . suspendConnection' c
deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
deleteConnection c = withAgentEnv c . deleteConnection' c
-- | Delete multiple connections, batching commands when possible
deleteConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnections c = withAgentEnv c . deleteConnections' c
-- | get servers used for connection
getConnectionServers :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
getConnectionServers c = withAgentEnv c . getConnectionServers' c
@@ -256,12 +275,12 @@ getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m By
getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c
-- | Change servers to be used for creating new queues
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers c = withAgentEnv c . setSMPServers' c
setSMPServers :: AgentErrorMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers c = withAgentEnv c .: setSMPServers' c
-- | Test SMP server
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
testSMPServerConnection c = withAgentEnv c . runSMPServerTest c
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
testSMPServerConnection c = withAgentEnv c .: runSMPServerTest c
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers c = withAgentEnv c . setNtfServers' c
@@ -349,8 +368,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
-- | execute any SMP agent command
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
processCommand c (connId, cmd) = case cmd of
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode Nothing
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
@@ -361,29 +380,53 @@ processCommand c (connId, cmd) = case cmd of
OFF -> suspendConnection' c connId $> (connId, OK)
DEL -> deleteConnection' c connId $> (connId, OK)
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
where
-- command interface does not support different users
userId = 1
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c corrId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
connId <- withStore c $ \db -> createNewConn db g cData cMode
createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
createUser' c srvs = do
userId <- withStore' c createUserRecord
atomically $ TM.insert userId srvs $ smpServers c
pure userId
deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m ()
deleteUser' c userId delSMPQueues = do
if delSMPQueues
then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c
else withStore c (`deleteUserRecord` userId)
atomically $ TM.delete userId $ smpServers c
where
delUser =
whenM (withStore' c (`deleteUserWithoutConns` userId)) $
atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId)
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c userId corrId enableNtfs cMode = do
connId <- newConnNoQueues c userId "" enableNtfs cMode
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode)
pure connId
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId
newConnNoQueues c userId connId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
withStore c $ \db -> createNewConn db g cData cMode
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo =
throwError $ CMD PROHIBITED
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -397,9 +440,9 @@ acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> Invitat
acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ ContactConnection {} -> do
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConnAsync c corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
@@ -419,23 +462,21 @@ ackMessageAsync' c corrId connId msgId = do
(RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId
enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
deleteConnectionAsync' c@AgentClient {subQ} corrId connId = withConnLock c connId "deleteConnectionAsync" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection _ (rq :| _) _ -> enqueueDelete rq
RcvConnection _ rq -> enqueueDelete rq
ContactConnection _ rq -> enqueueDelete rq
SndConnection _ _ -> delete
NewConnection _ -> delete
where
enqueueDelete :: RcvQueue -> m ()
enqueueDelete RcvQueue {server} = do
withStore' c $ \db -> setConnDeleted db connId
disableConn c connId
enqueueCommand c corrId connId (Just server) $ AInternalCommand ICDeleteConn
delete :: m ()
delete = withStore' c (`deleteConn` connId) >> atomically (writeTBQueue subQ (corrId, connId, OK))
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId]
deleteConnectionsAsync' :: AgentMonad m => AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure ()
deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync_ onSuccess c connIds = case connIds of
[] -> onSuccess
_ -> do
(_, rqs, connIds') <- prepareDeleteConnections_ getConn c connIds
withStore' c $ forM_ connIds' . setConnDeleted
void . forkIO $
withLock (deleteLock c) "deleteConnectionsAsync" $
deleteConnQueues c True rqs >> onSuccess
-- | Add connection to the new receive queue
switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
@@ -444,46 +485,42 @@ switchConnectionAsync' c corrId connId =
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
_ -> throwError $ CMD PROHIBITED
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
newConn c connId asyncMode enableNtfs cMode clientData =
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode clientData
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
newConn c userId connId enableNtfs cMode clientData =
getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c userId connId enableNtfs cMode clientData srv = do
connId' <- newConnNoQueues c userId connId enableNtfs cMode
newRcvConnSrv c userId connId' enableNtfs cMode clientData srv
newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newRcvConnSrv c userId connId enableNtfs cMode clientData srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
let rq = (q :: RcvQueue) {connId = connId'}
(rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange `catchError` \e -> liftIO (print e) >> throwError e
void . withStore c $ \db -> updateNewConnRcv db connId rq
addSubscription c rq
when enableNtfs $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId', NSCCreate)
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
let crData = ConnReqUriData simplexChat smpAgentVRange [qUri] clientData
case cMode of
SCMContact -> pure (connId', CRContactUri crData)
SCMContact -> pure (connId, CRContactUri crData)
SCMInvitation -> do
(pk1, pk2, e2eRcvParams) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
where
setUpConn True rq _ = do
void . withStore c $ \db -> updateNewConnRcv db connId rq
pure connId
setUpConn False rq connAgentVersion = do
g <- asks idsDrg
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
withStore c $ \db -> createRcvConn db g cData rq cMode
withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2
pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId asyncMode enableNtfs cReq cInfo = do
joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c userId connId asyncMode enableNtfs cReq cInfo = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
getNextSMPServer c [qServer q]
_ -> getSMPServer c
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
getNextSMPServer c userId [qServer q]
_ -> getSMPServer c userId
joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
joinConnSrv c userId connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
case ( qUri `compatibleVersion` smpClientVRange,
e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange,
@@ -493,9 +530,9 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
(_, rcDHRs) <- liftIO C.generateKeyPair'
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
q <- newSndQueue "" qInfo
q <- newSndQueue userId "" qInfo
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
connId' <- setUpConn asyncMode cData q rc
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
@@ -520,23 +557,23 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
liftIO $ createRatchet db connId' rc
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrv c connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
joinConnSrv c userId connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
crAgentVRange `compatibleVersion` aVRange
) of
(Just qInfo, Just vrsn) -> do
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation Nothing srv
sendInvitation c qInfo vrsn cReq cInfo
(connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing srv
sendInvitation c userId qInfo vrsn cReq cInfo
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
joinConnSrv _c _userId _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
throwError $ CMD PROHIBITED
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c userId connId srv $ versionToRange smpClientVersion
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
@@ -564,9 +601,9 @@ acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId
acceptContact' c connId enableNtfs invId ownConnInfo = withConnLock c connId "acceptContact" $ do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ ContactConnection {} -> do
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
joinConn c userId connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
@@ -577,32 +614,16 @@ rejectContact' c contactConnId invId =
withStore c $ \db -> deleteInvitation db contactConnId invId
-- | Subscribe to receive connection messages (SUB command) in Reader monad
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
subscribeConnection' c connId = do
SomeConn _ conn <- withStore c (`getConn` connId)
resumeConnCmds c connId
case conn of
DuplexConnection cData (rq :| rqs) sqs -> do
mapM_ (resumeMsgDelivery c cData) sqs
subscribe cData rq
mapM_ (\q -> subscribeQueue c q `catchError` \_ -> pure ()) rqs
SndConnection cData sq -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure ()
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
RcvConnection cData rq -> subscribe cData rq
ContactConnection cData rq -> subscribe cData rq
NewConnection _ -> pure ()
where
subscribe :: ConnData -> RcvQueue -> m ()
subscribe ConnData {enableNtfs} rq = do
subscribeQueue c rq
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, if enableNtfs then NSCCreate else NSCDelete)
subscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId]
type QSubResult = (QueueStatus, Either AgentErrorType ())
toConnResult :: AgentMonad m => ConnId -> Map ConnId (Either AgentErrorType ()) -> m ()
toConnResult connId rs = case M.lookup connId rs of
Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs)
Just (Left e) -> throwError e
_ -> throwError $ INTERNAL $ "no result for connection " <> B.unpack connId
type QCmdResult = (QueueStatus, Either AgentErrorType ())
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections' _ [] = pure M.empty
@@ -611,10 +632,9 @@ subscribeConnections' c connIds = do
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
srvRcvQs :: Map SMPServer [RcvQueue] = M.foldl' (foldl' addRcvQueue) M.empty rcvQs
mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs
mapM_ (resumeConnCmds c) $ M.keys cs
rcvRs <- connResults . concat <$> mapConcurrently subscribe (M.assocs srvRcvQs)
rcvRs <- connResults <$> subscribeQueues c (concat $ M.elems rcvQs)
ns <- asks ntfSupervisor
tkn <- readTVarIO (ntfTkn ns)
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns
@@ -622,9 +642,9 @@ subscribeConnections' c connIds = do
notifyResultError rs
pure rs
where
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) (NonEmpty RcvQueue)
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
rcvQueueOrResult (SomeConn _ conn) = case conn of
DuplexConnection _ rqs _ -> Right rqs
DuplexConnection _ rqs _ -> Right $ L.toList rqs
SndConnection _ sq -> Left $ sndSubResult sq
RcvConnection _ rq -> Right [rq]
ContactConnection _ rq -> Right [rq]
@@ -634,21 +654,17 @@ subscribeConnections' c connIds = do
Confirmed -> Right ()
Active -> Left $ CONN SIMPLEX
_ -> Left $ INTERNAL "unexpected queue status"
addRcvQueue :: Map SMPServer [RcvQueue] -> RcvQueue -> Map SMPServer [RcvQueue]
addRcvQueue m rq@RcvQueue {server} = M.alter (Just . maybe [rq] (rq :)) server m
subscribe :: (SMPServer, [RcvQueue]) -> m [(RcvQueue, Either AgentErrorType ())]
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
connResults = M.map snd . foldl' addResult M.empty
where
-- collects results by connection ID
addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QSubResult
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
-- combines two results for one connection, by using only Active queues (if there is at least one Active queue)
combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
combineRes r' _ = Just r'
order :: QSubResult -> Int
order :: QCmdResult -> Int
order (Active, Right _) = 1
order (Active, _) = 2
order (_, Right _) = 3
@@ -675,10 +691,7 @@ subscribeConnections' c connIds = do
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
resubscribeConnection' c connId =
unlessM
(atomically $ hasActiveSubscription c connId)
(subscribeConnection' c connId)
resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId]
resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
resubscribeConnections' _ [] = pure M.empty
@@ -791,21 +804,21 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
atomically $ beginAgentOperation c AOSndNetwork
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd
Right cmd -> processCmd (riFast ri) cmdId cmd
where
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
processCmd ri corrId connId cmdId command = case command of
processCmd :: RetryInterval -> AsyncCmdId -> PendingCommand -> m ()
processCmd ri cmdId PendingCommand {corrId, userId, connId, command} = case command of
AClientCommand cmd -> case cmd of
NEW enableNtfs (ACM cMode) -> noServer $ do
usedSrvs <- newTVarIO ([] :: [SMPServer])
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
(_, cReq) <- newConnSrv c connId True enableNtfs cMode Nothing srv
(_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing srv
notify $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
let initUsed = [qServer q]
usedSrvs <- newTVarIO initUsed
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv
notify OK
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
@@ -827,23 +840,8 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
secure rq senderKey
when (duplexHandshake cData == Just True) . void $
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
ICDeleteConn ->
withServer $ \srv -> tryWithLock "ICDeleteConn" $ do
SomeConn _ conn <- withStore c $ \db -> getAnyConn db connId True
case conn of
DuplexConnection _ (rq :| rqs) _ -> delete srv rq $ case rqs of
[] -> notify OK
RcvQueue {server = srv'} : _ -> enqueue srv'
RcvConnection _ rq -> delete srv rq $ notify OK
ContactConnection _ rq -> delete srv rq $ notify OK
_ -> internalErr "command requires connection with rcv queue"
where
delete :: SMPServer -> RcvQueue -> m () -> m ()
delete srv rq@RcvQueue {server} next
| sameSrvAddr srv server = deleteConnQueue c rq >> next
| otherwise = enqueue server
enqueue :: SMPServer -> m ()
enqueue srv = enqueueCommand c corrId connId (Just srv) $ AInternalCommand ICDeleteConn
-- ICDeleteConn is no longer used, but it can be present in old client databases
ICDeleteConn -> withStore' c (`deleteCommand` cmdId)
ICQSecure rId senderKey ->
withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) ->
case find (sameQueue (srv, rId)) rqs of
@@ -860,7 +858,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
| otherwise -> do
deleteQueue c rq'
withStore' c $ \db -> deleteConnRcvQueue db connId rq'
withStore' c $ \db -> deleteConnRcvQueue db rq'
when (enableNtfs cData) $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
@@ -901,10 +899,11 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m ()
withNextSrv usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c used
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
srvs_ <- TM.lookup userId $ smpServers c
let unused = maybe [] ((\\ used) . map protoServer . L.toList) srvs_
used' = if null unused then initUsed else srv : used
writeTVar usedSrvs $! used'
action srvAuth
-- ^ ^ ^ async command processing /
@@ -978,7 +977,7 @@ getPendingMsgQ c SndQueue {server, sndId} = do
pure q
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do
(mq, qLock) <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
@@ -1067,7 +1066,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- and this branch should never be reached as receive is created before the confirmation,
-- so the condition is not necessary here, strictly speaking.
_ -> unless (duplexHandshake == Just True) $ do
srv <- getSMPServer c
srv <- getSMPServer c userId
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
@@ -1142,12 +1141,12 @@ switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
clientVRange <- asks $ smpClientVRange . config
-- try to get the server that is different from all queues, or at least from the primary rcv queue
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextSMPServer c [server] else pure srvAuth
(q, qUri) <- newRcvQueue c connId srv' clientVRange
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth
(q, qUri) <- newRcvQueue c userId connId srv' clientVRange
let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnRcvQueue db connId rq'
addSubscription c rq'
@@ -1173,24 +1172,18 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
NewConnection _ -> throwError $ CMD PROHIBITED
-- | Delete SMP agent connection (DEL command) in Reader monad
-- unlike deleteConnectionAsync, this function does not mark connection as deleted in case of deletion failure
-- currently it is used only in tests
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection _ rqs _ -> mapM_ (deleteConnQueue c) rqs >> disableConn c connId >> deleteConn'
RcvConnection _ rq -> delete rq
ContactConnection _ rq -> delete rq
SndConnection _ _ -> deleteConn'
NewConnection _ -> deleteConn'
where
delete :: RcvQueue -> m ()
delete rq = deleteConnQueue c rq >> disableConn c connId >> deleteConn'
deleteConn' = withStore' c (`deleteConn` connId)
deleteConnection' c connId = toConnResult connId =<< deleteConnections' c [connId]
deleteConnQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
deleteConnQueue c rq@RcvQueue {connId} = do
deleteQueue c rq
withStore' c $ \db -> deleteConnRcvQueue db connId rq
connRcvQueues :: Connection d -> [RcvQueue]
connRcvQueues = \case
DuplexConnection _ rqs _ -> L.toList rqs
RcvConnection _ rq -> [rq]
ContactConnection _ rq -> [rq]
SndConnection _ _ -> []
NewConnection _ -> []
disableConn :: AgentMonad m => AgentClient -> ConnId -> m ()
disableConn c connId = do
@@ -1198,6 +1191,94 @@ disableConn c connId = do
ns <- asks ntfSupervisor
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete)
-- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure.
deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnections' = deleteConnections_ getConn False
deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteDeletedConns = deleteConnections_ getDeletedConn True
prepareDeleteConnections_ ::
forall m.
AgentMonad m =>
(DB.Connection -> ConnId -> IO (Either StoreError SomeConn)) ->
AgentClient ->
[ConnId] ->
m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId])
prepareDeleteConnections_ getConnection c connIds = do
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConnection)
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(delRs, rcvQs) = M.mapEither rcvQueues cs
rqs = concat $ M.elems rcvQs
connIds' = M.keys rcvQs
forM_ connIds' $ disableConn c
withStore' c $ forM_ (M.keys delRs) . deleteConn
pure (errs' <> delRs, rqs, connIds')
where
rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
rcvQueues (SomeConn _ conn) = case connRcvQueues conn of
[] -> Left $ Right ()
rqs -> Right rqs
deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnQueues c ntf rqs = do
rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs)
forM_ (M.assocs rs) $ \case
(connId, Right _) -> withStore' c (`deleteConn` connId) >> notify ("", connId, DEL_CONN)
_ -> pure ()
pure rs
where
deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> m [(RcvQueue, Either AgentErrorType ())]
deleteQueueRecs rs = do
maxErrs <- asks $ deleteErrorCount . config
forM rs $ \(rq, r) -> do
r' <- case r of
Right _ -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq Nothing $> r
Left e
| temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> withStore' c (`incRcvDeleteErrors` rq) $> r
| otherwise -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq (Just e) $> Right ()
pure (rq, r')
notifyRQ rq e_ = notify ("", qConnId rq, DEL_RCVQ (qServer rq) (queueId rq) e_)
notify = when ntf . atomically . writeTBQueue (subQ c)
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
connResults = M.map snd . foldl' addResult M.empty
where
-- collects results by connection ID
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
-- combines two results for one connection, by prioritizing errors in Active queues
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
combineRes r' _ = Just r'
order :: QCmdResult -> Int
order (Active, Left _) = 1
order (_, Left _) = 2
order _ = 3
deleteConnections_ ::
forall m.
AgentMonad m =>
(DB.Connection -> ConnId -> IO (Either StoreError SomeConn)) ->
Bool ->
AgentClient ->
[ConnId] ->
m (Map ConnId (Either AgentErrorType ()))
deleteConnections_ _ _ _ [] = pure M.empty
deleteConnections_ getConnection ntf c connIds = do
(rs, rqs, _) <- prepareDeleteConnections_ getConnection c connIds
rcvRs <- deleteConnQueues c ntf rqs
let rs' = M.union rs rcvRs
notifyResultError rs'
pure rs'
where
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
notifyResultError rs = do
let actual = M.size rs
expected = length connIds
when (actual /= expected) . atomically $
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected)
getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
getConnectionServers' c connId = do
SomeConn _ conn <- withStore c (`getConn` connId)
@@ -1217,8 +1298,8 @@ connectionStats = \case
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
-- | Change servers to be used for creating new queues, in Reader monad
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers' c = atomically . writeTVar (smpServers c)
setSMPServers' :: AgentMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
setSMPServers' c userId srvs = atomically $ TM.insert userId srvs $ smpServers c
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
@@ -1450,15 +1531,16 @@ execAgentStoreSQL' :: AgentMonad m => AgentClient -> Text -> m [Text]
execAgentStoreSQL' c sql = withStore' c (`execSQL` sql)
debugAgentLocks' :: AgentMonad m => AgentClient -> m AgentLocks
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs} = do
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
connLocks <- getLocks cs
srvLocks <- getLocks rs
pure AgentLocks {connLocks, srvLocks}
delLock <- atomically $ tryReadTMVar d
pure AgentLocks {connLocks, srvLocks, delLock}
where
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
getSMPServer :: AgentMonad m => AgentClient -> m SMPServerWithAuth
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth
getSMPServer c userId = withUserServers c userId pickServer
pickServer :: AgentMonad m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth
pickServer = \case
@@ -1467,13 +1549,18 @@ pickServer = \case
gen <- asks randomServer
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServerWithAuth
getNextSMPServer c usedSrvs = do
srvs <- readTVarIO $ smpServers c
getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth
getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs ->
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
Just srvs' -> pickServer srvs'
_ -> pickServer srvs
withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a
withUserServers c userId action =
atomically (TM.lookup userId $ smpServers c) >>= \case
Just srvs -> action srvs
_ -> throwError $ INTERNAL "unknown userId - no SMP servers"
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
t <- atomically $ readTBQueue msgQ
@@ -1482,13 +1569,26 @@ subscriber c@AgentClient {msgQ} = forever $ do
Left e -> liftIO $ print e
Right _ -> return ()
cleanupManager :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
cleanupManager c = do
threadDelay =<< asks (initialCleanupDelay . config)
int <- asks (cleanupInterval . config)
forever $ do
void . runExceptT $
withLock (deleteLock c) "cleanupManager" $ do
void $ withStore' c getDeletedConns >>= deleteDeletedConns c
withStore' c deleteUsersWithoutConns >>= mapM_ notifyUserDeleted
threadDelay int
where
notifyUserDeleted userId = atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId)
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) = do
processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, sessId, rId, cmd) = do
(rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId)
processSMP rq conn $ connData conn
where
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {userId, connId, duplexHandshake} = withConnLock c connId "processSMP" $
case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $
@@ -1587,7 +1687,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
handleNotifyAck :: m () -> m ()
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND)
>>= logServer "<--" c srv rId
where
processEND = \case
@@ -1724,7 +1824,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
(Just _, _) -> qError "QADD: queue address is already used in connection"
(_, Just _replaced@SndQueue {dbQueueId}) -> do
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo
let sq' = (sq_ :: SndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnSndQueue db connId sq'
case (sndPublicKey, e2ePubKey) of
@@ -1796,12 +1896,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
| otherwise = MsgError MsgDuplicate -- this case is not possible
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
clientVRange <- asks $ smpClientVRange . config
case qInfo `proveCompatible` clientVRange of
Nothing -> throwError $ AGENT A_VERSION
Just qInfo' -> do
sq <- newSndQueue connId qInfo'
sq <- newSndQueue userId connId qInfo'
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
@@ -1860,14 +1960,15 @@ agentRatchetDecrypt db connId encAgentMsg = do
liftIO $ updateRatchet db connId rc' skippedDiff
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
C.SignAlg a <- asks $ cmdSignAlg . config
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
pure
SndQueue
{ connId,
{ userId,
connId,
server = smpServer,
sndId = senderId,
sndPublicKey = Just sndPublicKey,
+283 -185
View File
@@ -25,7 +25,6 @@ module Simplex.Messaging.Agent.Client
closeProtocolServerClients,
runSMPServerTest,
newRcvQueue,
subscribeQueue,
subscribeQueues,
getQueueMessage,
decryptSMPMessage,
@@ -54,6 +53,7 @@ module Simplex.Messaging.Agent.Client
sendAck,
suspendQueue,
deleteQueue,
deleteQueues,
logServer,
logSecret,
removeSubscription,
@@ -96,10 +96,10 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight, partitionEithers)
import Data.Either (lefts, partitionEithers)
import Data.Functor (($>))
import Data.List (partition)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List (foldl', partition)
import Data.List.NonEmpty (NonEmpty (..), (<|))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
@@ -117,6 +117,7 @@ import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction)
import Simplex.Messaging.Agent.TAsyncs
import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues)
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
import Simplex.Messaging.Client
@@ -131,6 +132,7 @@ import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
import Simplex.Messaging.Protocol
( AProtocolType (..),
BrokerMsg,
EntityId,
ErrorType,
MsgFlags (..),
MsgId,
@@ -156,7 +158,7 @@ import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Timeout (timeout)
import UnliftIO (async)
import UnliftIO (mapConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -166,15 +168,19 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
type SMPTransportSession = TransportSession SMP.BrokerMsg
type NtfTransportSession = TransportSession NtfResponse
data AgentClient = AgentClient
{ active :: TVar Bool,
rcvQ :: TBQueue (ATransmission 'Client),
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue (ServerTransmission BrokerMsg),
smpServers :: TVar (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPServer SMPClientVar,
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPTransportSession SMPClientVar,
ntfServers :: TVar [NtfServer],
ntfClients :: TMap NtfServer NtfClientVar,
ntfClients :: TMap NtfTransportSession NtfClientVar,
useNetworkConfig :: TVar NetworkConfig,
subscrConns :: TVar (Set ConnId),
activeSubs :: TRcvQueues,
@@ -194,10 +200,12 @@ data AgentClient = AgentClient
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
-- locks to prevent concurrent operations with connection
connLocks :: TMap ConnId Lock,
-- lock to prevent concurrency between periodic and async connection deletions
deleteLock :: Lock,
-- locks to prevent concurrent reconnections to SMP servers
reconnectLocks :: TMap SMPServer Lock,
reconnections :: TVar [Async ()],
asyncClients :: TVar [Async ()],
reconnectLocks :: TMap SMPTransportSession Lock,
reconnections :: TAsyncs,
asyncClients :: TAsyncs,
agentStats :: TMap AgentStatsKey (TVar Int),
clientId :: Int,
agentEnv :: Env
@@ -222,12 +230,18 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
data AgentState = ASActive | ASSuspending | ASSuspended
deriving (Eq, Show)
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String}
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
deriving (Show, Generic)
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
data AgentStatsKey = AgentStatsKey {host :: ByteString, clientTs :: ByteString, cmd :: ByteString, res :: ByteString}
data AgentStatsKey = AgentStatsKey
{ userId :: UserId,
host :: ByteString,
clientTs :: ByteString,
cmd :: ByteString,
res :: ByteString
}
deriving (Eq, Ord, Show)
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
@@ -259,18 +273,54 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
agentState <- newTVar ASActive
getMsgLocks <- TM.empty
connLocks <- TM.empty
deleteLock <- createLock
reconnectLocks <- TM.empty
reconnections <- newTVar []
asyncClients <- newTVar []
reconnections <- newTAsyncs
asyncClients <- newTAsyncs
agentStats <- TM.empty
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, pendingMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, agentStats, clientId, agentEnv}
return
AgentClient
{ active,
rcvQ,
subQ,
msgQ,
smpServers,
smpClients,
ntfServers,
ntfClients,
useNetworkConfig,
subscrConns,
activeSubs,
pendingSubs,
pendingMsgsQueued,
smpQueueMsgQueues,
smpQueueMsgDeliveries,
connCmdsQueued,
asyncCmdQueues,
asyncCmdProcesses,
ntfNetworkOp,
rcvNetworkOp,
msgDeliveryOp,
sndNetworkOp,
databaseOp,
agentState,
getMsgLocks,
connLocks,
deleteLock,
reconnectLocks,
reconnections,
asyncClients,
agentStats,
clientId,
agentEnv
}
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
class ProtocolServerClient msg where
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg)
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
clientProtocolError :: ErrorType -> AgentErrorType
instance ProtocolServerClient BrokerMsg where
@@ -281,19 +331,19 @@ instance ProtocolServerClient NtfResponse where
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
atomically (getClientVar srv smpClients)
atomically (getClientVar tSess smpClients)
>>= either
(newProtocolClient c srv smpClients connectClient reconnectClient)
(waitForProtocolClient c srv)
(newProtocolClient c tSess smpClients connectClient reconnectSMPClient)
(waitForProtocolClient c tSess)
where
connectClient :: m SMPClient
connectClient = do
cfg <- getClientConfig c smpCfg
u <- askUnliftIO
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient tSess cfg (Just msgQ) $ clientDisconnected u)
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
clientDisconnected u client = do
@@ -302,86 +352,84 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
where
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
TM.delete tSess smpClients
qs <- RQ.getDelSessQueues tSess $ activeSubs c
mapM_ (`RQ.addQueue` pendingSubs c) qs
pure (qs, S.toList conns)
let cs = S.fromList $ map qConnId qs
cs' <- RQ.getConns $ activeSubs c
pure (qs, S.toList $ cs `S.difference` cs')
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
incClientStat c client "DISCONNECT" ""
incClientStat c userId client "DISCONNECT" ""
notifySub "" $ hostEvent DISCONNECT client
unless (null conns) $ notifySub "" $ DOWN srv conns
unless (null qs) $ do
atomically $ mapM_ (releaseGetLock c) qs
unliftIO u reconnectServer
unliftIO u $ reconnectServer c tSess
reconnectServer :: m ()
reconnectServer = do
a <- async tryReconnectClient
atomically $ modifyTVar' (reconnections c) (a :)
notifySub :: ConnId -> ACommand 'Agent -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
tryReconnectClient :: m ()
tryReconnectClient = do
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop ->
reconnectClient `catchError` const loop
reconnectClient :: m ()
reconnectClient =
withLockMap_ (reconnectLocks c) srv "reconnect" $
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
where
resubscribe :: [RcvQueue] -> m ()
resubscribe qs = do
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
cs <- atomically . RQ.getConns $ activeSubs c
(client_, rs) <- subscribeQueues c srv qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
unless connected . forM_ client_ $ \cl -> do
incClientStat c cl "CONNECT" ""
notifySub "" $ hostEvent CONNECT cl
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
mapM_ (throwError . snd) $ listToMaybe tempErrs
reconnectServer :: AgentMonad m => AgentClient -> SMPTransportSession -> m ()
reconnectServer c tSess = newAsyncAction tryReconnectSMPClient $ reconnections c
where
tryReconnectSMPClient aId = do
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop ->
reconnectSMPClient c tSess `catchError` const loop
atomically . removeAsyncAction aId $ reconnections c
reconnectSMPClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m ()
reconnectSMPClient c tSess@(_, srv, _) =
withLockMap_ (reconnectLocks c) tSess "reconnect" $
atomically (RQ.getSessQueues tSess $ pendingSubs c) >>= mapM_ resubscribe . L.nonEmpty
where
resubscribe :: NonEmpty RcvQueue -> m ()
resubscribe qs = do
cs <- atomically . RQ.getConns $ activeSubs c
rs <- subscribeQueues c $ L.toList qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
mapM_ (throwError . snd) $ listToMaybe tempErrs
notifySub :: ConnId -> ACommand 'Agent -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} srv = do
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
atomically (getClientVar srv ntfClients)
atomically (getClientVar tSess ntfClients)
>>= either
(newProtocolClient c srv ntfClients connectClient $ pure ())
(waitForProtocolClient c srv)
(newProtocolClient c tSess ntfClients connectClient $ \_ _ -> pure ())
(waitForProtocolClient c tSess)
where
connectClient :: m NtfClient
connectClient = do
cfg <- getClientConfig c ntfCfg
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing clientDisconnected)
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient tSess cfg Nothing clientDisconnected)
clientDisconnected :: NtfClient -> IO ()
clientDisconnected client = do
atomically $ TM.delete srv ntfClients
incClientStat c client "DISCONNECT" ""
atomically $ TM.delete tSess ntfClients
incClientStat c userId client "DISCONNECT" ""
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup tSess clients
where
newClientVar :: STM (TMVar a)
newClientVar = do
var <- newEmptyTMVar
TM.insert srv var clients
TM.insert tSess var clients
pure var
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient c srv clientVar = do
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
waitForProtocolClient c (_, srv, _) clientVar = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
liftEither $ case client_ of
@@ -393,39 +441,38 @@ newProtocolClient ::
forall msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient ->
ProtoServer msg ->
TMap (ProtoServer msg) (ClientVar msg) ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
m (ProtocolClient msg) ->
m () ->
(AgentClient -> TransportSession msg -> m ()) ->
ClientVar msg ->
m (ProtocolClient msg)
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
where
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
tryConnectClient successAction retryAction =
tryError connectClient >>= \r -> case r of
Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
atomically $ putTMVar clientVar r
liftIO $ incClientStat c client "CLIENT" "OK"
liftIO $ incClientStat c userId client "CLIENT" "OK"
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
successAction client
Left e -> do
liftIO $ incServerStat c srv "CLIENT" $ strEncode e
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
if temporaryAgentError e
then retryAction
else atomically $ do
putTMVar clientVar (Left e)
TM.delete srv clients
TM.delete tSess clients
throwError e
tryConnectAsync :: m ()
tryConnectAsync = do
a <- async connectAsync
atomically $ modifyTVar' (asyncClients c) (a :)
connectAsync :: m ()
connectAsync = do
tryConnectAsync = newAsyncAction connectAsync $ asyncClients c
connectAsync :: Int -> m ()
connectAsync aId = do
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop
atomically . removeAsyncAction aId $ asyncClients c
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client
@@ -441,8 +488,8 @@ closeAgentClient c = liftIO $ do
atomically $ writeTVar (active c) False
closeProtocolServerClients c smpClients
closeProtocolServerClients c ntfClients
cancelActions $ reconnections c
cancelActions $ asyncClients c
cancelActions . actions $ reconnections c
cancelActions . actions $ asyncClients c
cancelActions $ smpQueueMsgDeliveries c
cancelActions $ asyncCmdProcesses c
atomically . RQ.clear $ activeSubs c
@@ -471,7 +518,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
where
k = (server, sndId)
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
where
@@ -492,32 +539,45 @@ withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
where
newLock = newEmptyTMVar >>= \l -> TM.insert key l locks $> l
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ c srv statCmd action = do
cl <- getProtocolServerClient c srv
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchError` logServerError cl
where
stat cl = liftIO . incClientStat c cl statCmd
stat cl = liftIO . incClientStat c userId cl statCmd
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
logServerError cl e = do
logServer "<--" c srv "" $ strEncode e
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ c srv qId cmdStr action = do
logServer "-->" c srv qId cmdStr
res <- withClient_ c srv cmdStr action
logServer "<--" c srv qId "OK"
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c srv statKey action = withClient_ c srv statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
withSMPClient c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient c tSess (queueId q) cmdStr action
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
withSMPClient_ c q cmdStr action = do
tSess <- mkSMPTransportSession c q
withLogClient_ c tSess (queueId q) cmdStr action
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
withNtfClient c srv = withLogClient c (0, srv, Nothing)
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
liftClient protocolError_ = liftError . protocolClientError protocolError_
@@ -551,12 +611,13 @@ instance ToJSON SMPTestFailure where
toEncoding = J.genericToEncoding J.defaultOptions
toJSON = J.genericToJSON J.defaultOptions
runSMPServerTest :: AgentMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
runSMPServerTest c (ProtoServerWithAuth srv auth) = do
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
cfg <- getClientConfig c smpCfg
C.SignAlg a <- asks $ cmdSignAlg . config
liftIO $ do
getProtocolClient srv cfg Nothing (\_ -> pure ()) >>= \case
let tSess = (userId, srv, Nothing)
getProtocolClient tSess cfg Nothing (\_ -> pure ()) >>= \case
Right smp -> do
(rKey, rpKey) <- C.generateSignatureKeyPair a
(sKey, _) <- C.generateSignatureKeyPair a
@@ -566,7 +627,7 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
liftError (testErr TSSecureQueue) $ secureSMPQueue smp rpKey rcvId sKey
liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp rpKey rcvId
ok <- tcpTimeout (networkConfig cfg) `timeout` closeProtocolClient smp
incClientStat c smp "TEST" "OK"
incClientStat c userId smp "TEST" "OK"
pure $ either Just (const Nothing) r <|> maybe (Just (SMPTestFailure TSDisconnect $ BROKER addr TIMEOUT)) (const Nothing) ok
Left e -> pure (Just $ testErr TSConnect e)
where
@@ -574,19 +635,36 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
testErr step = SMPTestFailure step . protocolClientError SMP addr
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c
mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg
mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c
mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession
mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
getSessionMode :: AgentMonad m => AgentClient -> m TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
C.SignAlg a <- asks (cmdSignAlg . config)
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
logServer "-->" c srv "" "NEW"
tSess <- mkTransportSession c userId srv connId
QIK {rcvId, sndId, rcvPublicDhKey} <-
withClient c srv "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
let rq =
RcvQueue
{ connId,
{ userId,
connId,
server = srv,
rcvId,
rcvPrivateKey,
@@ -599,20 +677,11 @@ newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
primary = True,
dbReplaceQueueId = Nothing,
smpClientVersion = maxVersion vRange,
clientNtfCreds = Nothing
clientNtfCreds = Nothing,
deleteErrors = 0
}
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ do
modifyTVar' (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
>>= either throwError pure
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq r = do
case r of
@@ -639,34 +708,59 @@ temporaryOrHostError = \case
BROKER _ HOST -> True
e -> temporaryAgentError e
-- | subscribe multiple queues - all passed queues should be on the same server
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
modifyTVar' (subscrConns c) $ S.insert connId
-- | Subscribe to queues. The list of results can have a different order.
subscribeQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
subscribeQueues c qs = do
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
RQ.addQueue rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
Left e -> pure $ map (,Left e) qs_
Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
let qs2 = L.map queueCreds qs'
n = (length qs2 - 1) `div` 90 + 1
liftIO $ incClientStatN c smp n "SUBS" "OK"
liftIO $ do
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
mapM_ (uncurry $ processSubResult c) rs
pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs
_ -> pure (Nothing, errs)
u <- askUnliftIO
(errs <>) <$> sendTSessionBatches "SUB" 90 (subscribeQueues_ u) c qs
where
checkQueue rq@RcvQueue {rcvId, server} = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq
subscribeQueues_ u smp qs' = do
rs <- sendBatch subscribeSMPQueues smp qs'
mapM_ (uncurry $ processSubResult c) rs
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
unliftIO u $ reconnectServer c $ transportSession' smp
pure rs
type BatchResponses e = (NonEmpty (RcvQueue, Either e ()))
-- statBatchSize is not used to batch the commands, only for traffic statistics
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
sendTSessionBatches statCmd statBatchSize action c qs =
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
where
batchQueues :: m [(SMPTransportSession, NonEmpty RcvQueue)]
batchQueues = do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
pure . M.assocs $ foldl' (batch mode) M.empty qs
where
batch mode m rq =
let tSess = mkSMPTSession rq mode
in M.alter (Just . maybe [rq] (rq <|)) tSess m
sendClientBatch :: (SMPTransportSession, NonEmpty RcvQueue) -> m (BatchResponses AgentErrorType)
sendClientBatch (tSess@(userId, srv, _), qs') =
tryError (getSMPServerClient c tSess) >>= \case
Left e -> pure $ L.map (,Left e) qs'
Right smp -> liftIO $ do
logServer "-->" c srv (bshow (length qs') <> " queues") statCmd
rs <- L.map agentError <$> action smp qs'
statBatch
pure rs
where
agentError = second . first $ protocolClientError SMP $ clientServer smp
statBatch =
let n = (length qs - 1) `div` statBatchSize + 1
in incClientStatN c userId smp n (statCmd <> "S") "OK"
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
where
checkQueue rq@RcvQueue {rcvId, server}
| server == srv = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
| otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter")
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
@@ -699,16 +793,17 @@ logSecret :: ByteString -> ByteString
logSecret bs = encode $ B.take 3 bs
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo =
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
tSess <- mkTransportSession c userId smpServer senderId
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
msg <- mkInvitation
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
where
@@ -722,7 +817,7 @@ sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderI
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
atomically createTakeGetLock
(v, msg_) <- withLogClient c server rcvId "GET" $ \smp ->
(v, msg_) <- withSMPClient c rq "GET" $ \smp ->
(thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId
mapM (decryptMeta v) msg_
where
@@ -742,23 +837,23 @@ decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.Enc
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
withLogClient c server rcvId "KEY <key>" $ \smp ->
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
withSMPClient c rq "KEY <key>" $ \smp ->
secureSMPQueue smp rcvPrivateKey rcvId senderKey
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey)
enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
withLogClient c server rcvId "NKEY <nkey>" $ \smp ->
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
withSMPClient c rq "NKEY <nkey>" $ \smp ->
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "NDEL" $ \smp ->
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "NDEL" $ \smp ->
disableSMPQueueNotifications smp rcvPrivateKey rcvId
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
withLogClient c server rcvId "ACK" $ \smp ->
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
withSMPClient c rq "ACK" $ \smp ->
ackSMPMessage smp rcvPrivateKey rcvId msgId
atomically $ releaseGetLock c rq
@@ -767,57 +862,60 @@ releaseGetLock c RcvQueue {server, rcvId} =
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "OFF" $ \smp ->
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
withSMPClient c rq "OFF" $ \smp ->
suspendSMPQueue smp rcvPrivateKey rcvId
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogClient c server rcvId "DEL" $ \smp ->
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do
withSMPClient c rq "DEL" $ \smp ->
deleteSMPQueue smp rcvPrivateKey rcvId
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg =
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
deleteQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
deleteQueues = sendTSessionBatches "DEL" 90 $ sendBatch deleteSMPQueues
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
withClient c ntfServer "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
@@ -948,18 +1046,18 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat c pc = incClientStatN c pc 1
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> ProtocolServer p -> ByteString -> ByteString -> IO ()
incServerStat c ProtocolServer {host} cmd res = do
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
incServerStat c userId ProtocolServer {host} cmd res = do
threadDelay 100000
atomically $ incStat c 1 statsKey
where
statsKey = AgentStatsKey {host = strEncode $ L.head host, clientTs = "", cmd, res}
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: AgentClient -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c pc n cmd res = do
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where
statsKey = AgentStatsKey {host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
statsKey = AgentStatsKey {userId, host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
+9 -1
View File
@@ -32,12 +32,14 @@ import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
import Data.List.NonEmpty (NonEmpty)
import Data.Map (Map)
import Data.Time.Clock (NominalDiffTime, nominalDay)
import Data.Word (Word16)
import Network.Socket
import Numeric.Natural
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Agent.Store.SQLite
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
@@ -59,7 +61,7 @@ import UnliftIO.STM
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
data InitialAgentServers = InitialAgentServers
{ smp :: NonEmpty SMPServerWithAuth,
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
ntf :: [NtfServer],
netCfg :: NetworkConfig
}
@@ -86,6 +88,9 @@ data AgentConfig = AgentConfig
messageRetryInterval :: RetryInterval2,
messageTimeout :: NominalDiffTime,
helloTimeout :: NominalDiffTime,
initialCleanupDelay :: Int,
cleanupInterval :: Int,
deleteErrorCount :: Int,
ntfCron :: Word16,
ntfWorkerDelay :: Int,
ntfSMPWorkerDelay :: Int,
@@ -143,6 +148,9 @@ defaultAgentConfig =
messageRetryInterval = defaultMessageRetryInterval,
messageTimeout = 2 * nominalDay,
helloTimeout = 2 * nominalDay,
initialCleanupDelay = 30 * 1000000, -- 30 seconds
cleanupInterval = 30 * 60 * 1000000, -- 30 minutes
deleteErrorCount = 10,
ntfCron = 20, -- minutes
ntfWorkerDelay = 100000, -- microseconds
ntfSMPWorkerDelay = 500000, -- microseconds
+4
View File
@@ -10,6 +10,10 @@ import UnliftIO.STM
type Lock = TMVar String
createLock :: STM Lock
createLock = newEmptyTMVar
{-# INLINE createLock #-}
withLock :: MonadUnliftIO m => TMVar String -> String -> m a -> m a
withLock lock name =
E.bracket_
+21
View File
@@ -272,6 +272,9 @@ data ACommand (p :: AParty) where
SWCH :: ACommand Client
OFF :: ACommand Client
DEL :: ACommand Client
DEL_RCVQ :: SMPServer -> SMP.RecipientId -> Maybe AgentErrorType -> ACommand Agent
DEL_CONN :: ACommand Agent
DEL_USER :: Int64 -> ACommand Agent
CHK :: ACommand Client
STAT :: ConnectionStats -> ACommand Agent
OK :: ACommand Agent
@@ -311,6 +314,9 @@ data ACommandTag (p :: AParty) where
SWCH_ :: ACommandTag Client
OFF_ :: ACommandTag Client
DEL_ :: ACommandTag Client
DEL_RCVQ_ :: ACommandTag Agent
DEL_CONN_ :: ACommandTag Agent
DEL_USER_ :: ACommandTag Agent
CHK_ :: ACommandTag Client
STAT_ :: ACommandTag Agent
OK_ :: ACommandTag Agent
@@ -349,6 +355,9 @@ aCommandTag = \case
SWCH -> SWCH_
OFF -> OFF_
DEL -> DEL_
DEL_RCVQ {} -> DEL_RCVQ_
DEL_CONN -> DEL_CONN_
DEL_USER _ -> DEL_USER_
CHK -> CHK_
STAT _ -> STAT_
OK -> OK_
@@ -1225,6 +1234,9 @@ instance StrEncoding ACmdTag where
"SWCH" -> pure $ ACmdTag SClient SWCH_
"OFF" -> pure $ ACmdTag SClient OFF_
"DEL" -> pure $ ACmdTag SClient DEL_
"DEL_RCVQ" -> pure $ ACmdTag SAgent DEL_RCVQ_
"DEL_CONN" -> pure $ ACmdTag SAgent DEL_CONN_
"DEL_USER" -> pure $ ACmdTag SAgent DEL_USER_
"CHK" -> pure $ ACmdTag SClient CHK_
"STAT" -> pure $ ACmdTag SAgent STAT_
"OK" -> pure $ ACmdTag SAgent OK_
@@ -1260,6 +1272,9 @@ instance APartyI p => StrEncoding (ACommandTag p) where
SWCH_ -> "SWCH"
OFF_ -> "OFF"
DEL_ -> "DEL"
DEL_RCVQ_ -> "DEL_RCVQ"
DEL_CONN_ -> "DEL_CONN"
DEL_USER_ -> "DEL_USER"
CHK_ -> "CHK"
STAT_ -> "STAT"
OK_ -> "OK"
@@ -1308,6 +1323,9 @@ commandP binaryP =
SENT_ -> s (SENT <$> A.decimal)
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
MSG_ -> s (MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> binaryP)
DEL_RCVQ_ -> s (DEL_RCVQ <$> strP_ <*> strP_ <*> strP)
DEL_CONN_ -> pure DEL_CONN
DEL_USER_ -> s (DEL_USER <$> strP)
STAT_ -> s (STAT <$> strP)
OK_ -> pure OK
ERR_ -> s (ERR <$> strP)
@@ -1356,6 +1374,9 @@ serializeCommand = \case
SWCH -> s SWCH_
OFF -> s OFF_
DEL -> s DEL_
DEL_RCVQ srv rcvId err_ -> s (DEL_RCVQ_, srv, rcvId, err_)
DEL_CONN -> s DEL_CONN_
DEL_USER userId -> s (DEL_USER_, userId)
CHK -> s CHK_
STAT srvs -> s (STAT_, srvs)
CON -> s CON_
+40 -3
View File
@@ -34,6 +34,7 @@ import Simplex.Messaging.Protocol
NotifierId,
NtfPrivateSignKey,
NtfPublicVerifyKey,
QueueId,
RcvDhSecret,
RcvNtfDhSecret,
RcvPrivateSignKey,
@@ -47,7 +48,8 @@ import Simplex.Messaging.Version
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
data RcvQueue = RcvQueue
{ connId :: ConnId,
{ userId :: UserId,
connId :: ConnId,
server :: SMPServer,
-- | recipient queue ID
rcvId :: SMP.RecipientId,
@@ -72,7 +74,8 @@ data RcvQueue = RcvQueue
-- | SMP client version
smpClientVersion :: Version,
-- | credentials used in context of notifications
clientNtfCreds :: Maybe ClientNtfCreds
clientNtfCreds :: Maybe ClientNtfCreds,
deleteErrors :: Int
}
deriving (Eq, Show)
@@ -89,7 +92,8 @@ data ClientNtfCreds = ClientNtfCreds
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
data SndQueue = SndQueue
{ connId :: ConnId,
{ userId :: UserId,
connId :: ConnId,
server :: SMPServer,
-- | sender queue ID
sndId :: SMP.SenderId,
@@ -150,6 +154,27 @@ findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
{-# INLINE findRQ #-}
class SMPQueue q => SMPQueueRec q where
qUserId :: q -> UserId
qConnId :: q -> ConnId
queueId :: q -> QueueId
instance SMPQueueRec RcvQueue where
qUserId = userId
{-# INLINE qUserId #-}
qConnId = connId
{-# INLINE qConnId #-}
queueId = rcvId
{-# INLINE queueId #-}
instance SMPQueueRec SndQueue where
qUserId = userId
{-# INLINE qUserId #-}
qConnId = connId
{-# INLINE qConnId #-}
queueId = sndId
{-# INLINE queueId #-}
-- * Connection types
-- | Type of a connection.
@@ -222,6 +247,7 @@ deriving instance Show SomeConn
data ConnData = ConnData
{ connId :: ConnId,
userId :: UserId,
connAgentVersion :: Version,
enableNtfs :: Bool,
duplexHandshake :: Maybe Bool, -- added in agent protocol v2
@@ -229,8 +255,17 @@ data ConnData = ConnData
}
deriving (Eq, Show)
data PendingCommand = PendingCommand
{ corrId :: ACorrId,
userId :: UserId,
connId :: ConnId,
command :: AgentCommand
}
data AgentCmdType = ACClient | ACInternal
type UserId = Int64
instance StrEncoding AgentCmdType where
strEncode = \case
ACClient -> "CLIENT"
@@ -471,6 +506,8 @@ data StoreError
SEInternal ByteString
| -- | Failed to generate unique random ID
SEUniqueID
| -- | User ID not found
SEUserNotFound
| -- | Connection not found (or both queues absent).
SEConnNotFound
| -- | Connection already used.
+113 -30
View File
@@ -27,6 +27,14 @@ module Simplex.Messaging.Agent.Store.SQLite
sqlString,
execSQL,
-- * Users
createUserRecord,
deleteUserRecord,
setUserDeleted,
deleteUserWithoutConns,
deleteUsersWithoutConns,
checkUser,
-- * Queues and connections
createNewConn,
updateNewConnRcv,
@@ -34,9 +42,10 @@ module Simplex.Messaging.Agent.Store.SQLite
createRcvConn,
createSndConn,
getConn,
getAnyConn,
getDeletedConn,
getConnData,
setConnDeleted,
getDeletedConns,
getRcvConn,
deleteConn,
upgradeRcvConnToDuplex,
@@ -49,6 +58,7 @@ module Simplex.Messaging.Agent.Store.SQLite
setRcvQueuePrimary,
setSndQueuePrimary,
deleteConnRcvQueue,
incRcvDeleteErrors,
deleteConnSndQueue,
getPrimaryRcvQueue,
getRcvQueue,
@@ -286,7 +296,7 @@ withConnection SQLiteStore {dbConnection} =
(atomically . putTMVar dbConnection)
withTransaction :: forall a. SQLiteStore -> (DB.Connection -> IO a) -> IO a
withTransaction st action = withConnection st $ loop 500 2_000_000
withTransaction st action = withConnection st $ loop 500 3_000_000
where
loop :: Int -> Int -> DB.Connection -> IO a
loop t tLim db =
@@ -297,6 +307,60 @@ withTransaction st action = withConnection st $ loop 500 2_000_000
loop (t * 9 `div` 8) (tLim - t) db
else E.throwIO e
createUserRecord :: DB.Connection -> IO UserId
createUserRecord db = do
DB.execute_ db "INSERT INTO users DEFAULT VALUES"
insertedRowId db
checkUser :: DB.Connection -> UserId -> IO (Either StoreError ())
checkUser db userId =
firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $
DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, False)
deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ())
deleteUserRecord db userId = runExceptT $ do
ExceptT $ checkUser db userId
liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId])
setUserDeleted db userId = runExceptT $ do
ExceptT $ checkUser db userId
liftIO $ do
DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (True, userId)
map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId)
deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool
deleteUserWithoutConns db userId = do
userId_ :: Maybe Int64 <-
maybeFirstRow fromOnly $
DB.query
db
[sql|
SELECT user_id FROM users u
WHERE u.user_id = ?
AND u.deleted = ?
AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
|]
(userId, True)
case userId_ of
Just _ -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) $> True
_ -> pure False
deleteUsersWithoutConns :: DB.Connection -> IO [Int64]
deleteUsersWithoutConns db = do
userIds <-
map fromOnly
<$> DB.query
db
[sql|
SELECT user_id FROM users u
WHERE u.deleted = ?
AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
|]
(Only True)
forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only
pure userIds
createConn_ ::
TVar ChaChaDRG ->
ConnData ->
@@ -307,9 +371,9 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
ConnData {connId} -> create connId $> Right connId
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} cMode =
createConn_ gVar cData $ \connId -> do
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
updateNewConnRcv db connId rq =
@@ -332,17 +396,17 @@ updateNewConnSnd db connId sq =
updateConn = Right <$> addConnSndQueue_ db connId sq
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertRcvQueue_ db connId q
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
void $ insertSndQueue_ db connId q
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
@@ -455,8 +519,12 @@ setSndQueuePrimary db connId SndQueue {dbQueueId} = do
"UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?"
(True, Nothing :: Maybe Int64, connId, dbQueueId)
deleteConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO ()
deleteConnRcvQueue db connId RcvQueue {dbQueueId} =
incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO ()
incRcvDeleteErrors db RcvQueue {connId, dbQueueId} =
DB.execute db "UPDATE rcv_queues SET delete_errors = delete_errors + 1 WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
deleteConnRcvQueue :: DB.Connection -> RcvQueue -> IO ()
deleteConnRcvQueue db RcvQueue {connId, dbQueueId} =
DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO ()
@@ -820,13 +888,20 @@ getPendingCommands db connId = do
where
srvCmdId (host, port, keyHash, cmdId) = (SMPServer <$> host <*> port <*> keyHash, cmdId)
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError (ACorrId, ConnId, AgentCommand))
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError PendingCommand)
getPendingCommand db msgId = do
firstRow id SECmdNotFound $
firstRow pendingCommand SECmdNotFound $
DB.query
db
"SELECT corr_id, conn_id, command FROM commands WHERE command_id = ?"
[sql|
SELECT c.corr_id, cs.user_id, c.conn_id, c.command
FROM commands c
JOIN connections cs USING (conn_id)
WHERE c.command_id = ?
|]
(Only msgId)
where
pendingCommand (corrId, userId, connId, command) = PendingCommand {corrId, userId, connId, command}
deleteCommand :: DB.Connection -> AsyncCmdId -> IO ()
deleteCommand db cmdId =
@@ -1307,10 +1382,13 @@ newQueueId_ (Only maxId : _) = maxId + 1
-- * getConn helpers
getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn db connId = getAnyConn db connId False
getConn = getAnyConn False
getAnyConn :: DB.Connection -> ConnId -> Bool -> IO (Either StoreError SomeConn)
getAnyConn dbConn connId deleted' =
getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getDeletedConn = getAnyConn True
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getAnyConn deleted' dbConn connId =
getConnData dbConn connId >>= \case
Nothing -> pure $ Left SEConnNotFound
Just (cData@ConnData {deleted}, cMode)
@@ -1328,13 +1406,16 @@ getAnyConn dbConn connId deleted' =
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData dbConn connId' =
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
maybeFirstRow cData $ DB.query dbConn "SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
where
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
setConnDeleted :: DB.Connection -> ConnId -> IO ()
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
getDeletedConns :: DB.Connection -> IO [ConnId]
getDeletedConns db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True)
-- | returns all connection queues, the first queue is the primary one
getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue))
getRcvQueuesByConnId_ db connId =
@@ -1348,31 +1429,32 @@ getRcvQueuesByConnId_ db connId =
rcvQueueQuery :: Query
rcvQueueQuery =
[sql|
SELECT s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.delete_errors,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN connections c ON q.conn_id = c.conn_id
|]
toRcvQueue ::
(C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (Int64, Bool, Maybe Int64, Maybe Version)
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (Int64, Bool, Maybe Int64, Maybe Version, Int)
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
RcvQueue
toRcvQueue ((keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds, deleteErrors}
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
getRcvQueueById_ db connId dbRcvId =
firstRow toRcvQueue SEConnNotFound $
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ?") (connId, dbRcvId)
-- | returns all connection queues, the first queue is the primary one
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
@@ -1381,16 +1463,17 @@ getSndQueuesByConnId_ dbConn connId =
<$> DB.query
dbConn
[sql|
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN servers s ON q.host = s.host AND q.port = s.port
JOIN connections c ON q.conn_id = c.conn_id
WHERE q.conn_id = ?;
|]
(Only connId)
where
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
sndQueue ((userId, keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
let server = SMPServer host port keyHash
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
in SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
-- the current primary queue is ordered first, the next primary - second
compare (Down p) (Down p') <> compare i i'
@@ -37,6 +37,9 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -53,7 +56,10 @@ schemaMigrations =
("m20220811_onion_hosts", m20220811_onion_hosts),
("m20220817_connection_ntfs", m20220817_connection_ntfs),
("m20220905_commands", m20220905_commands),
("m20220915_connection_queues", m20220915_connection_queues)
("m20220915_connection_queues", m20220915_connection_queues),
("m20230110_users", m20230110_users),
("m20230117_fkey_indexes", m20230117_fkey_indexes),
("m20230120_delete_errors", m20230120_delete_errors)
]
-- | The list of migrations in ascending order by date
@@ -0,0 +1,29 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20230110_users :: Query
m20230110_users =
[sql|
PRAGMA ignore_check_constraints=ON;
CREATE TABLE users (
user_id INTEGER PRIMARY KEY AUTOINCREMENT
);
INSERT INTO users (user_id) VALUES (1);
ALTER TABLE connections ADD COLUMN user_id INTEGER CHECK (user_id NOT NULL)
REFERENCES users ON DELETE CASCADE;
CREATE INDEX idx_connections_user ON connections(user_id);
CREATE INDEX idx_commands_conn_id ON commands(conn_id);
UPDATE connections SET user_id = 1;
PRAGMA ignore_check_constraints=OFF;
|]
@@ -0,0 +1,27 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
-- .lint fkey-indexes
m20230117_fkey_indexes :: Query
m20230117_fkey_indexes =
[sql|
CREATE INDEX idx_commands_host_port ON commands(host, port);
CREATE INDEX idx_conn_confirmations_conn_id ON conn_confirmations(conn_id);
CREATE INDEX idx_conn_invitations_contact_conn_id ON conn_invitations(contact_conn_id);
CREATE INDEX idx_messages_conn_id_internal_snd_id ON messages(conn_id, internal_snd_id);
CREATE INDEX idx_messages_conn_id_internal_rcv_id ON messages(conn_id, internal_rcv_id);
CREATE INDEX idx_messages_conn_id ON messages(conn_id);
CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON ntf_subscriptions(ntf_host, ntf_port);
CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON ntf_subscriptions(smp_host, smp_port);
CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON ntf_tokens(ntf_host, ntf_port);
CREATE INDEX idx_ratchets_conn_id ON ratchets(conn_id);
CREATE INDEX idx_rcv_messages_conn_id_internal_id ON rcv_messages(conn_id, internal_id);
CREATE INDEX idx_skipped_messages_conn_id ON skipped_messages(conn_id);
CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON snd_message_deliveries(conn_id, internal_id);
CREATE INDEX idx_snd_messages_conn_id_internal_id ON snd_messages(conn_id, internal_id);
CREATE INDEX idx_snd_queues_host_port ON snd_queues(host, port);
|]
@@ -0,0 +1,20 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20230120_delete_errors :: Query
m20230120_delete_errors =
[sql|
PRAGMA ignore_check_constraints=ON;
ALTER TABLE rcv_queues ADD COLUMN delete_errors INTEGER DEFAULT 0 CHECK (delete_errors NOT NULL);
UPDATE rcv_queues SET delete_errors = 0;
ALTER TABLE users ADD COLUMN deleted INTEGER DEFAULT 0 CHECK (deleted NOT NULL);
UPDATE users SET deleted = 0;
PRAGMA ignore_check_constraints=OFF;
|]
@@ -22,7 +22,9 @@ CREATE TABLE connections(
,
duplex_handshake INTEGER NULL DEFAULT 0,
enable_ntfs INTEGER,
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
user_id INTEGER CHECK(user_id NOT NULL)
REFERENCES users ON DELETE CASCADE
) WITHOUT ROWID;
CREATE TABLE rcv_queues(
host TEXT NOT NULL,
@@ -45,6 +47,7 @@ CREATE TABLE rcv_queues(
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
replace_rcv_queue_id INTEGER NULL,
delete_errors INTEGER DEFAULT 0 CHECK(delete_errors NOT NULL),
PRIMARY KEY(host, port, rcv_id),
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE,
@@ -228,3 +231,51 @@ CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
conn_id,
snd_queue_id
);
CREATE TABLE users(
user_id INTEGER PRIMARY KEY AUTOINCREMENT
,
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
);
CREATE INDEX idx_connections_user ON connections(user_id);
CREATE INDEX idx_commands_conn_id ON commands(conn_id);
CREATE INDEX idx_commands_host_port ON commands(host, port);
CREATE INDEX idx_conn_confirmations_conn_id ON conn_confirmations(conn_id);
CREATE INDEX idx_conn_invitations_contact_conn_id ON conn_invitations(
contact_conn_id
);
CREATE INDEX idx_messages_conn_id_internal_snd_id ON messages(
conn_id,
internal_snd_id
);
CREATE INDEX idx_messages_conn_id_internal_rcv_id ON messages(
conn_id,
internal_rcv_id
);
CREATE INDEX idx_messages_conn_id ON messages(conn_id);
CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON ntf_subscriptions(
ntf_host,
ntf_port
);
CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON ntf_subscriptions(
smp_host,
smp_port
);
CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON ntf_tokens(
ntf_host,
ntf_port
);
CREATE INDEX idx_ratchets_conn_id ON ratchets(conn_id);
CREATE INDEX idx_rcv_messages_conn_id_internal_id ON rcv_messages(
conn_id,
internal_id
);
CREATE INDEX idx_skipped_messages_conn_id ON skipped_messages(conn_id);
CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON snd_message_deliveries(
conn_id,
internal_id
);
CREATE INDEX idx_snd_messages_conn_id_internal_id ON snd_messages(
conn_id,
internal_id
);
CREATE INDEX idx_snd_queues_host_port ON snd_queues(host, port);
+25
View File
@@ -0,0 +1,25 @@
module Simplex.Messaging.Agent.TAsyncs where
import Control.Concurrent.STM (stateTVar)
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import UnliftIO.Async (Async, async)
import UnliftIO.STM
data TAsyncs = TAsyncs
{ actionId :: TVar Int,
actions :: TMap Int (Async ())
}
newTAsyncs :: STM TAsyncs
newTAsyncs = TAsyncs <$> newTVar 0 <*> TM.empty
newAsyncAction :: MonadUnliftIO m => (Int -> m ()) -> TAsyncs -> m ()
newAsyncAction action as = do
aId <- atomically $ stateTVar (actionId as) $ \i -> (i + 1, i + 1)
a <- async $ action aId
atomically $ TM.insert aId a $ actions as
removeAsyncAction :: Int -> TAsyncs -> STM ()
removeAsyncAction aId = TM.delete aId . actions
+32 -15
View File
@@ -1,18 +1,28 @@
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.Agent.TRcvQueues where
module Simplex.Messaging.Agent.TRcvQueues
( TRcvQueues,
empty,
clear,
deleteConn,
hasConn,
getConns,
addQueue,
deleteQueue,
getSessQueues,
getDelSessQueues,
)
where
import Control.Concurrent.STM
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Simplex.Messaging.Agent.Protocol (ConnId)
import Simplex.Messaging.Agent.Store (RcvQueue (..))
import Simplex.Messaging.Agent.Store (RcvQueue (..), UserId)
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
newtype TRcvQueues = TRcvQueues (TMap (SMPServer, RecipientId) RcvQueue)
newtype TRcvQueues = TRcvQueues (TMap (UserId, SMPServer, RecipientId) RcvQueue)
empty :: STM TRcvQueues
empty = TRcvQueues <$> TM.empty
@@ -30,19 +40,26 @@ getConns :: TRcvQueues -> STM (Set ConnId)
getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs
addQueue :: RcvQueue -> TRcvQueues -> STM ()
addQueue rq@RcvQueue {server, rcvId} (TRcvQueues qs) = TM.insert (server, rcvId) rq qs
addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
deleteQueue RcvQueue {server, rcvId} (TRcvQueues qs) = TM.delete (server, rcvId) qs
deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs
getSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs'
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
where
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId}
| srv == server = ((rq : remQs, S.insert connId remConns), qs')
| otherwise = (removed, M.insert (server, rcvId) rq qs')
addQ (removed, qs') rq
| rq `isSession` tSess = (rq : removed, qs')
| otherwise = (removed, M.insert (qKey rq) rq qs')
isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool
isSession rq (uId, srv, connId_) =
userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_
qKey :: RcvQueue -> (UserId, SMPServer, ConnId)
qKey rq = (userId rq, server rq, connId rq)
+85 -31
View File
@@ -26,12 +26,14 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Client
( -- * Connect (disconnect) client to (from) SMP server
TransportSession,
ProtocolClient (thVersion, sessionId, sessionTs),
SMPClient,
getProtocolClient,
closeProtocolClient,
clientServer,
transportHost',
transportSession',
-- * SMP protocol command functions
createSMPQueue,
@@ -48,12 +50,14 @@ module Simplex.Messaging.Client
ackSMPMessage,
suspendSMPQueue,
deleteSMPQueue,
deleteSMPQueues,
sendProtocolCommand,
-- * Supporting types and client configuration
ProtocolClientError (..),
ProtocolClientConfig (..),
NetworkConfig (..),
TransportSessionMode (..),
defaultClientConfig,
defaultNetworkConfig,
transportClientConfig,
@@ -75,6 +79,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (rights)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (find)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
@@ -110,9 +115,10 @@ data ProtocolClient msg = ProtocolClient
data PClient msg = PClient
{ connected :: TVar Bool,
protocolServer :: ProtoServer msg,
transportSession :: TransportSession msg,
transportHost :: TransportHost,
tcpTimeout :: Int,
pingErrorCount :: TVar Int,
clientCorrId :: TVar Natural,
sentCommands :: TMap CorrId (Request msg),
sndQ :: TBQueue (NonEmpty SentRawTransmission),
@@ -126,7 +132,7 @@ type SMPClient = ProtocolClient SMP.BrokerMsg
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
-- | Type synonym for transmission from some SPM server queue.
type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg)
type ServerTransmission msg = (TransportSession msg, Version, SessionId, QueueId, msg)
data HostMode
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
@@ -152,14 +158,18 @@ data NetworkConfig = NetworkConfig
hostMode :: HostMode,
-- | if above criteria is not met, if the below setting is True return error, otherwise use the first host
requiredHostMode :: Bool,
-- | transport sessions are created per user or per entity
sessionMode :: TransportSessionMode,
-- | timeout for the initial client TCP/TLS connection (microseconds)
tcpConnectTimeout :: Int,
-- | timeout of protocol commands (microseconds)
tcpTimeout :: Int,
-- | TCP keep-alive options, Nothing to skip enabling keep-alive
tcpKeepAlive :: Maybe KeepAliveOpts,
-- | period for SMP ping commands (microseconds)
-- | period for SMP ping commands (microseconds, 0 to disable)
smpPingInterval :: Int,
-- | the count of PING errors after which SMP client terminates (and will be reconnected), 0 to disable
smpPingCount :: Int,
logTLSErrors :: Bool
}
deriving (Eq, Show, Generic, FromJSON)
@@ -168,16 +178,28 @@ instance ToJSON NetworkConfig where
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
data TransportSessionMode = TSMUser | TSMEntity
deriving (Eq, Show, Generic)
instance ToJSON TransportSessionMode where
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
instance FromJSON TransportSessionMode where
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
defaultNetworkConfig :: NetworkConfig
defaultNetworkConfig =
NetworkConfig
{ socksProxy = Nothing,
hostMode = HMOnionViaSocks,
requiredHostMode = False,
sessionMode = TSMUser,
tcpConnectTimeout = 7_500_000,
tcpTimeout = 5_000_000,
tcpKeepAlive = Just defaultKeepAliveOpts,
smpPingInterval = 600_000_000, -- 10min
smpPingCount = 3,
logTLSErrors = False
}
@@ -229,19 +251,29 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
publicHost = find (not . isOnionHost) hosts
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient msg -> String
clientServer = B.unpack . strEncode . protocolServer . client_
clientServer = B.unpack . strEncode . snd3 . transportSession . client_
where
snd3 (_, s, _) = s
transportHost' :: ProtocolClient msg -> TransportHost
transportHost' = transportHost . client_
transportSession' :: ProtocolClient msg -> TransportSession msg
transportSession' = transportSession . client_
type UserId = Int64
-- | Transport session key - includes entity ID if `sessionMode = TSMEntity`.
type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
-- | Connects to 'ProtocolServer' using passed client configuration
-- and queue for messages and notifications.
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
case chooseTransportHost networkConfig (host protocolServer) of
getProtocolClient :: forall msg. Protocol msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
case chooseTransportHost networkConfig (host srv) of
Right useHost ->
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
@@ -251,6 +283,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
mkProtocolClient :: TransportHost -> STM (PClient msg)
mkProtocolClient transportHost = do
connected <- newTVar False
pingErrorCount <- newTVar 0
clientCorrId <- newTVar 0
sentCommands <- TM.empty
sndQ <- newTBQueue qSize
@@ -258,9 +291,10 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
return
PClient
{ connected,
protocolServer,
transportSession,
transportHost,
tcpTimeout,
pingErrorCount,
clientCorrId,
sentCommands,
sndQ,
@@ -272,9 +306,10 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
runClient (port', ATransport t) useHost c = do
cVar <- newEmptyTMVarIO
let tcConfig = transportClientConfig networkConfig
username = proxyUsername transportSession
action <-
async $
runTransportClient tcConfig useHost port' (Just $ keyHash protocolServer) (client t c cVar)
runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar)
`finally` atomically (putTMVar cVar $ Left PCENetworkError)
c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar)
pure $ case c_ of
@@ -282,15 +317,18 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
Just (Left e) -> Left e
Nothing -> Left PCENetworkError
proxyUsername :: TransportSession msg -> ByteString
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
useTransport :: (ServiceName, ATransport)
useTransport = case port protocolServer of
useTransport = case port srv of
"" -> defaultTransport cfg
"80" -> ("80", transport @WS)
p -> (p, transport @TLS)
client :: forall c. Transport c => TProxy c -> PClient msg -> TMVar (Either ProtocolClientError (ProtocolClient msg)) -> c -> IO ()
client _ c cVar h =
runExceptT (protocolClientHandshake @msg h (keyHash protocolServer) smpServerVRange) >>= \case
runExceptT (protocolClientHandshake @msg h (keyHash srv) smpServerVRange) >>= \case
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
Right th@THandle {sessionId, thVersion} -> do
sessionTs <- getCurrentTime
@@ -308,9 +346,15 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
ping :: ProtocolClient msg -> IO ()
ping c = forever $ do
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
threadDelay smpPingInterval
runExceptT $ sendProtocolCommand c Nothing "" protocolPing
runExceptT (sendProtocolCommand c Nothing "" protocolPing) >>= \case
Left PCEResponseTimeout -> do
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
when (maxCnt == 0 || cnt < maxCnt) $ ping c
_ -> ping c -- sendProtocolCommand resets pingErrorCount
where
maxCnt = smpPingCount networkConfig
process :: ProtocolClient msg -> IO ()
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
@@ -412,8 +456,8 @@ writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {protocolServer}} entityId message =
(protocolServer, thVersion, sessionId, entityId, message)
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {transportSession}} entityId message =
(transportSession, thVersion, sessionId, entityId, message)
-- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue
--
@@ -464,13 +508,7 @@ disableSMPQueueNotifications = okSMPCommand NDEL
-- | Disable notifications for multiple queues for push notifications server.
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
disableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient NDEL)) qs
response = \case
Right OK -> Right ()
Right r -> Left . PCEUnexpectedResponse $ bshow r
Left e -> Left e
disableSMPQueuesNtfs = okSMPCommands NDEL
-- | Send SMP message.
--
@@ -501,41 +539,57 @@ suspendSMPQueue = okSMPCommand OFF
-- | Irreversibly delete SMP queue and all messages in it.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
deleteSMPQueue = okSMPCommand DEL
-- | Delete multiple SMP queues batching commands if supported.
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
deleteSMPQueues = okSMPCommands DEL
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
okSMPCommand cmd c pKey qId =
sendSMPCommand c (Just pKey) qId cmd >>= \case
OK -> return ()
r -> throwE . PCEUnexpectedResponse $ bshow r
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either ProtocolClientError ()))
okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
where
aCmd = Cmd sParty cmd
cs = L.map (\(pKey, qId) -> (Just pKey, qId, aCmd)) qs
response = \case
Right OK -> Right ()
Right r -> Left . PCEUnexpectedResponse $ bshow r
Left e -> Left e
-- | Send SMP command
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
-- | Send multiple commands with batching and collect responses
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} cs = do
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
ts <- mapM (runExceptT . mkTransmission c) cs
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
forConcurrently ts $ \case
Right (_, r) -> withTimeout . atomically $ takeTMVar r
Right (_, r) -> withTimeout c $ atomically $ takeTMVar r
Left e -> pure $ Left e
where
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
-- | Send Protocol command
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} pKey qId cmd = do
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}} pKey qId cmd = do
(t, r) <- mkTransmission c (pKey, qId, cmd)
ExceptT $ sendRecv t r
where
-- two separate "atomically" needed to avoid blocking
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout (atomically $ takeTMVar r)
where
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout c (atomically $ takeTMVar r)
withTimeout :: ProtocolClient msg -> IO (Either ProtocolClientError msg) -> IO (Either ProtocolClientError msg)
withTimeout ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} a =
timeout tcpTimeout a >>= \case
Just r -> atomically (writeTVar pingErrorCount 0) >> pure r
_ -> pure $ Left PCEResponseTimeout
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
+1 -1
View File
@@ -160,7 +160,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
void $ tryConnectClient (const reconnectClient) loop
connectClient :: ExceptT ProtocolClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
connectClient = ExceptT $ getProtocolClient (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
clientDisconnected :: SMPClient -> IO ()
clientDisconnected _ = do
@@ -190,7 +190,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
receiveSMP :: M ()
receiveSMP = forever $ do
(srv, _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
((_, srv, _), _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
let smpQueue = SMPQueueNtf srv ntfId
case msg of
SMP.NMSG nmsgNonce encNMsgMeta -> do
@@ -372,6 +372,7 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
t_ <- atomically $ getActiveNtfToken st subTknId
verifyToken' t_ $ verifiedSubCmd s c
else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
NtfCmd SSubscription PING -> pure $ VRVerified $ NtfReqPing corrId entId
NtfCmd SSubscription c -> do
s_ <- atomically $ getNtfSubscription st entId
case s_ of
@@ -529,6 +530,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
incNtfStat subDeleted
pure NROk
PING -> pure NRPong
NtfReqPing corrId entId -> pure (corrId, entId, NRPong)
getId :: M NtfEntityId
getId = getRandomBytes =<< asks (subIdBytes . config)
getRegCode :: M NtfRegCode
@@ -148,6 +148,7 @@ getPushClient s@NtfPushServer {pushClients} pp =
data NtfRequest
= NtfReqNew CorrId ANewNtfEntity
| forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e))
| NtfReqPing CorrId NtfEntityId
data NtfServerClient = NtfServerClient
{ rcvQ :: TBQueue NtfRequest,
+2 -1
View File
@@ -77,6 +77,7 @@ module Simplex.Messaging.Protocol
BasicAuth (..),
SrvLoc (..),
CorrId (..),
EntityId,
QueueId,
RecipientId,
SenderId,
@@ -753,7 +754,7 @@ basicAuth s
where
valid c = isPrint c && not (isSpace c) && c /= '@' && c /= ':' && c /= '/'
data ProtoServerWithAuth p = ProtoServerWithAuth (ProtocolServer p) (Maybe BasicAuth)
data ProtoServerWithAuth p = ProtoServerWithAuth {protoServer :: ProtocolServer p, serverBasicAuth :: Maybe BasicAuth}
deriving (Show)
instance ProtocolTypeI p => IsString (ProtoServerWithAuth p) where
+3 -1
View File
@@ -47,6 +47,7 @@ data ServerStatsData = ServerStatsData
_qCount :: Int,
_msgCount :: Int
}
deriving (Show)
newServerStats :: UTCTime -> STM ServerStats
newServerStats ts = do
@@ -88,7 +89,7 @@ setServerStats s d = do
writeTVar (qDeleted s) $! _qDeleted d
writeTVar (msgSent s) $! _msgSent d
writeTVar (msgRecv s) $! _msgRecv d
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
setPeriodStats (activeQueues s) (_activeQueues d)
writeTVar (msgSentNtf s) $! _msgSentNtf d
writeTVar (msgRecvNtf s) $! _msgRecvNtf d
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
@@ -152,6 +153,7 @@ data PeriodStatsData a = PeriodStatsData
_week :: Set a,
_month :: Set a
}
deriving (Show)
newPeriodStatsData :: PeriodStatsData a
newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty}
+9 -7
View File
@@ -107,15 +107,15 @@ defaultTransportClientConfig :: TransportClientConfig
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
-- | Connect to passed TCP host:port and pass handle to the client.
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} host port keyHash client = do
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
let hostName = B.unpack $ strEncode host
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
connectTCP = case socksProxy of
Just proxy -> connectSocksClient proxy $ hostAddr host
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
_ -> connectTCPClient hostName
c <- liftIO $ do
sock <- connectTCP port
@@ -153,10 +153,12 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
defaultSMPPort :: PortNumber
defaultSMPPort = 5223
connectSocksClient :: SocksProxy -> SocksHostAddress -> ServiceName -> IO Socket
connectSocksClient (SocksProxy addr) hostAddr _port = do
connectSocksClient :: SocksProxy -> Maybe ByteString -> SocksHostAddress -> ServiceName -> IO Socket
connectSocksClient (SocksProxy addr) proxyUsername hostAddr _port = do
let port = if null _port then defaultSMPPort else fromMaybe defaultSMPPort $ readMaybe _port
fst <$> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
fst <$> case proxyUsername of
Just username -> socksConnectAuth (defaultSocksConf addr) (SocksAddress hostAddr port) (SocksCredentials username "")
_ -> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
defaultSocksHost :: HostAddress
defaultSocksHost = tupleToHostAddress (127, 0, 0, 1)
@@ -120,7 +120,7 @@ sendRequest HTTP2Client {reqQ, config} req = do
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> HostName -> ServiceName -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
runHTTP2Client tlsParams caStore tcConfig host port client =
runTLSTransportClient tlsParams caStore tcConfig (THDomainName host) port Nothing $ \c ->
runTLSTransportClient tlsParams caStore tcConfig Nothing (THDomainName host) port Nothing $ \c ->
withTlsConfig c 16384 (`run` client)
where
run = H.run $ ClientConfig "https" (B.pack host) 20
+4 -4
View File
@@ -338,8 +338,8 @@ testServerConnectionAfterError t _ = do
withServer $ do
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
bob <# ("", "", UP server ["alice"])
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
bob <#= \case ("", "alice", Msg "hello") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob <#= \case ("", "alice", Msg "hello") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob #: ("2", "alice", "ACK 4") #> ("2", "alice", OK)
alice #: ("1", "bob", "SEND F 11\nhello again") #> ("1", "bob", MID 5)
alice <# ("", "bob", SENT 5)
@@ -381,8 +381,8 @@ testMsgDeliveryAgentRestart t bob = do
(corrId == "3" && cmd == OK)
|| (corrId == "" && cmd == SENT 5)
_ -> False
bob <# ("", "", UP server ["alice"])
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)
removeFile testStoreLogFile
+240 -48
View File
@@ -7,6 +7,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module AgentTests.FunctionalAPITests
@@ -26,10 +27,11 @@ where
import Control.Concurrent (killThread, threadDelay)
import Control.Monad
import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.IO.Unlift
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight)
import Data.Int (Int64)
import qualified Data.Map as M
import Data.Maybe (isNothing)
@@ -41,7 +43,8 @@ import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Client (SMPTestFailure (..), SMPTestStep (..))
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..))
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig)
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..))
import qualified Simplex.Messaging.Protocol as SMP
@@ -137,6 +140,17 @@ functionalAPITests t = do
testAsyncCommandsRestore t
it "should accept connection using async command" $
withSmpServer t testAcceptContactAsync
it "should delete connections using async command when server connection fails" $
testDeleteConnectionAsync t
describe "Users" $ do
it "should create and delete user with connections" $
withSmpServer t testUsers
it "should create and delete user without connections" $
withSmpServer t testDeleteUserQuietly
it "should create and delete user with connections when server connection fails" $
testUsersNoServer t
it "should connect two users and switch session mode" $
withSmpServer t testTwoUsers
describe "Queue rotation" $ do
describe "should switch delivery to the new queue" $
testServerMatrix2 t testSwitchConnection
@@ -229,8 +243,8 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientTest alice bob baseId = do
runRight_ $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -264,8 +278,8 @@ runAgentClientTest alice bob baseId = do
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientContactTest alice bob baseId = do
runRight_ $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
bobId <- acceptContact alice True invId "alice's connInfo"
("", _, CONF confId _ "alice's connInfo") <- get bob
@@ -303,7 +317,7 @@ noMessages c err = tryGet `shouldReturn` ()
where
tryGet =
10000 `timeout` get c >>= \case
Just _ -> error err
Just msg -> error $ err <> ": " <> show msg
_ -> return ()
testAsyncInitiatingOffline :: IO ()
@@ -311,9 +325,9 @@ testAsyncInitiatingOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
subscribeConnection alice' bobId
("", _, CONF confId _ "bob's connInfo") <- get alice'
@@ -328,8 +342,8 @@ testAsyncJoiningOfflineBeforeActivation = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
disconnectAgentClient bob
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
@@ -345,9 +359,9 @@ testAsyncBothOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
disconnectAgentClient bob
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
subscribeConnection alice' bobId
@@ -366,9 +380,9 @@ testAsyncServerOffline t = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
-- create connection and shutdown the server
(bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
runRight $ createConnection alice True SCMInvitation Nothing
runRight $ createConnection alice 1 True SCMInvitation Nothing
-- connection fails
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo"
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo"
("", "", DOWN srv conns) <- get alice
srv `shouldBe` testSMPServer
conns `shouldBe` [bobId]
@@ -378,7 +392,7 @@ testAsyncServerOffline t = do
liftIO $ do
srv1 `shouldBe` testSMPServer
conns1 `shouldBe` [bobId]
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -392,9 +406,9 @@ testAsyncHelloTimeout = do
alice <- getSMPAgentClient agentCfgV1 initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers
runRight_ $ do
(_, cReq) <- createConnection alice True SCMInvitation Nothing
(_, cReq) <- createConnection alice 1 True SCMInvitation Nothing
disconnectAgentClient alice
aliceId <- joinConnection bob True cReq "bob's connInfo"
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
testDuplicateMessage :: ATransport -> IO ()
@@ -446,9 +460,12 @@ testDuplicateMessage t = do
get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnection alice bob = do
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
makeConnection alice bob = makeConnectionForUsers alice 1 bob 1
makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnectionForUsers alice aliceUserId bob bobUserId = do
(bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing
aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
@@ -462,7 +479,7 @@ testInactiveClientDisconnected t = do
withSmpServerConfigOn t cfg' testPort $ \_ -> do
alice <- getSMPAgentClient agentCfg initAgentServers
runRight_ $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
get alice ##> ("", "", DOWN testSMPServer [connId])
testActiveClientNotDisconnected :: ATransport -> IO ()
@@ -472,7 +489,7 @@ testActiveClientNotDisconnected t = do
alice <- getSMPAgentClient agentCfg initAgentServers
ts <- getSystemTime
runRight_ $ do
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
keepSubscribing alice connId ts
where
keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO ()
@@ -536,10 +553,8 @@ testSuspendingAgentCompleteSending t = do
get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False
("", "", SUSPENDED) <- get b
r <- get a
liftIO $ print r
("", "", UP {}) <- pure r
get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
ackMessage a bId 5
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
ackMessage a bId 6
@@ -572,9 +587,9 @@ testBatchedSubscriptions t = do
conns <- runServers $ do
conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
forM_ (take 10 conns) $ \(aId, bId) -> do
deleteConnection a bId
deleteConnection b aId
let (aIds', bIds') = unzip $ take 10 conns
delete a bIds'
delete b aIds'
liftIO $ threadDelay 1000000
pure conns
("", "", DOWN {}) <- get a
@@ -587,18 +602,37 @@ testBatchedSubscriptions t = do
("", "", UP {}) <- get b
("", "", UP {}) <- get b
liftIO $ threadDelay 1000000
subscribe a $ map snd conns
subscribe b $ map fst conns
forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
let (aIds, bIds) = unzip conns
conns' = drop 10 conns
(aIds', bIds') = unzip conns'
subscribe a bIds
subscribe b aIds
forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
delete a bIds'
delete b aIds'
deleteFail a bIds'
deleteFail b aIds'
where
subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
subscribe c cs = do
r <- subscribeConnections c cs
liftIO $ do
let dc = S.fromList $ take 10 cs
all (== Right ()) (M.withoutKeys r dc) `shouldBe` True
all isRight (M.withoutKeys r dc) `shouldBe` True
all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True
M.keys r `shouldMatchList` cs
delete :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
delete c cs = do
r <- deleteConnections c cs
liftIO $ do
all isRight r `shouldBe` True
M.keys r `shouldMatchList` cs
deleteFail :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
deleteFail c cs = do
r <- deleteConnections c cs
liftIO $ do
all (== Left (CONN NOT_FOUND)) r `shouldBe` True
M.keys r `shouldMatchList` cs
runServers :: ExceptT AgentErrorType IO a -> IO a
runServers a = do
withSmpServerStoreLogOn t testPort $ \t1 -> do
@@ -612,10 +646,10 @@ testAsyncCommands = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
bobId <- createConnectionAsync alice "1" True SCMInvitation
bobId <- createConnectionAsync alice 1 "1" True SCMInvitation
("1", bobId', INV (ACR _ qInfo)) <- get alice
liftIO $ bobId' `shouldBe` bobId
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo"
("2", aliceId', OK) <- get bob
liftIO $ aliceId' `shouldBe` aliceId
("", _, CONF confId _ "bob's connInfo") <- get alice
@@ -645,8 +679,9 @@ testAsyncCommands = do
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
ackMessageAsync alice "7" bobId $ baseId + 4
("7", _, OK) <- get alice
deleteConnectionAsync alice "8" bobId
("8", _, OK) <- get alice
deleteConnectionAsync alice bobId
get alice =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bobId; _ -> False
get alice =##> \case ("", c, DEL_CONN) -> c == bobId; _ -> False
liftIO $ noMessages alice "nothing else should be delivered to alice"
where
baseId = 3
@@ -655,7 +690,7 @@ testAsyncCommands = do
testAsyncCommandsRestore :: ATransport -> IO ()
testAsyncCommandsRestore t = do
alice <- getSMPAgentClient agentCfg initAgentServers
bobId <- runRight $ createConnectionAsync alice "1" True SCMInvitation
bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
disconnectAgentClient alice
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
@@ -670,8 +705,8 @@ testAcceptContactAsync = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(_, qInfo) <- createConnection alice True SCMContact Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, REQ invId _ "bob's connInfo") <- get alice
bobId <- acceptContactAsync alice "1" True invId "alice's connInfo"
("1", bobId', OK) <- get alice
@@ -707,6 +742,89 @@ testAcceptContactAsync = do
baseId = 3
msgId = subtract baseId
testDeleteConnectionAsync :: ATransport -> IO ()
testDeleteConnectionAsync t = do
a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers
connIds <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
(bId1, _inv) <- createConnection a 1 True SCMInvitation Nothing
(bId2, _inv) <- createConnection a 1 True SCMInvitation Nothing
(bId3, _inv) <- createConnection a 1 True SCMInvitation Nothing
pure ([bId1, bId2, bId3] :: [ConnId])
runRight_ $ do
deleteConnectionsAsync a connIds
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
liftIO $ noMessages a "nothing else should be delivered to alice"
testUsers :: IO ()
testUsers = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
auId <- createUser a [noAuthSrv testSMPServer]
(aId', bId') <- makeConnectionForUsers a auId b 1
exchangeGreetingsMsgId 4 a bId' b aId'
deleteUser a auId True
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId'; _ -> False
get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False
get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
exchangeGreetingsMsgId 6 a bId b aId
liftIO $ noMessages a "nothing else should be delivered to alice"
testDeleteUserQuietly :: IO ()
testDeleteUserQuietly = do
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
auId <- createUser a [noAuthSrv testSMPServer]
(aId', bId') <- makeConnectionForUsers a auId b 1
exchangeGreetingsMsgId 4 a bId' b aId'
deleteUser a auId False
exchangeGreetingsMsgId 6 a bId b aId
liftIO $ noMessages a "nothing else should be delivered to alice"
testUsersNoServer :: ATransport -> IO ()
testUsersNoServer t = do
a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
liftIO $ print 1
(aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
auId <- createUser a [noAuthSrv testSMPServer]
(aId', bId') <- makeConnectionForUsers a auId b 1
exchangeGreetingsMsgId 4 a bId' b aId'
pure (aId, bId, auId, aId', bId')
liftIO $ print 2
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
get b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False
liftIO $ print 3
runRight_ $ do
deleteUser a auId True
liftIO $ print 4
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False
liftIO $ print 4.1
get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False
liftIO $ print 4.2
get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
liftIO $ print 5
liftIO $ noMessages a "nothing else should be delivered to alice"
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
liftIO $ print 6
get a =##> \case ("", "", UP _ [c]) -> c == bId; _ -> False
get b =##> \case ("", "", UP _ cs) -> length cs == 2; _ -> False
liftIO $ print 7
exchangeGreetingsMsgId 6 a bId b aId
testSwitchConnection :: InitialAgentServers -> IO ()
testSwitchConnection servers = do
a <- getSMPAgentClient agentCfg servers
@@ -743,21 +861,26 @@ phase c connId d p =
testSwitchAsync :: InitialAgentServers -> IO ()
testSwitchAsync servers = do
liftIO $ print 1
(aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
pure (aId, bId)
liftIO $ print 2
let withA' = session withA bId
withB' = session withB aId
withA' $ \a -> do
switchConnectionAsync a "" bId
phase a bId QDRcv SPStarted
liftIO $ print 3
withB' $ \b -> phase b aId QDSnd SPStarted
withA' $ \a -> phase a bId QDRcv SPConfirmed
liftIO $ print 4
withB' $ \b -> do
phase b aId QDSnd SPConfirmed
phase b aId QDSnd SPCompleted
withA' $ \a -> phase a bId QDRcv SPCompleted
liftIO $ print 5
withA $ \a -> withB $ \b -> runRight_ $ do
subscribeConnection a bId
subscribeConnection b aId
@@ -785,20 +908,22 @@ testSwitchDelete servers = do
disconnectAgentClient b
switchConnectionAsync a "" bId
phase a bId QDRcv SPStarted
deleteConnectionAsync a "1" bId
("1", bId', OK) <- get a
liftIO $ bId `shouldBe` bId'
deleteConnectionAsync a bId
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False
get a =##> \case ("", c, DEL_CONN) -> c == bId; _ -> False
liftIO $ noMessages a "nothing else should be delivered to alice"
testCreateQueueAuth :: (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
testCreateQueueAuth clnt1 clnt2 = do
a <- getClient clnt1
b <- getClient clnt2
runRight $ do
tryError (createConnection a True SCMInvitation Nothing) >>= \case
tryError (createConnection a 1 True SCMInvitation Nothing) >>= \case
Left (SMP AUTH) -> pure 0
Left e -> throwError e
Right (bId, qInfo) ->
tryError (joinConnection b True qInfo "bob's connInfo") >>= \case
tryError (joinConnection b 1 True qInfo "bob's connInfo") >>= \case
Left (SMP AUTH) -> pure 1
Left e -> throwError e
Right aId -> do
@@ -811,7 +936,7 @@ testCreateQueueAuth clnt1 clnt2 = do
pure 2
where
getClient (clntAuth, clntVersion) =
let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]}
let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]}
smpCfg = (defaultClientConfig :: ProtocolClientConfig) {smpServerVRange = mkVersionRange 4 clntVersion}
in getSMPAgentClient agentCfg {smpCfg} servers
@@ -819,7 +944,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
testSMPServerConnectionTest t newQueueBasicAuth srv =
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running
runRight $ testSMPServerConnection a srv
runRight $ testSMPServerConnection a 1 srv
testRatchetAdHash :: IO ()
testRatchetAdHash = do
@@ -831,6 +956,73 @@ testRatchetAdHash = do
ad2 <- getConnectionRatchetAdHash b aId
liftIO $ ad1 `shouldBe` ad2
testTwoUsers :: IO ()
testTwoUsers = do
let nc = netCfg initAgentServers
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
sessionMode nc `shouldBe` TSMUser
runRight_ $ do
(aId1, bId1) <- makeConnectionForUsers a 1 b 1
exchangeGreetings a bId1 b aId1
(aId1', bId1') <- makeConnectionForUsers a 1 b 1
exchangeGreetings a bId1' b aId1'
a `hasClients` 1
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 2
exchangeGreetingsMsgId 6 a bId1 b aId1
exchangeGreetingsMsgId 6 a bId1' b aId1'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 1
aUserId2 <- createUser a [noAuthSrv testSMPServer]
(aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1
exchangeGreetings a bId2 b aId2
(aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1
exchangeGreetings a bId2' b aId2'
a `hasClients` 2
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 4
exchangeGreetingsMsgId 8 a bId1 b aId1
exchangeGreetingsMsgId 8 a bId1' b aId1'
exchangeGreetingsMsgId 6 a bId2 b aId2
exchangeGreetingsMsgId 6 a bId2' b aId2'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 2
exchangeGreetingsMsgId 10 a bId1 b aId1
exchangeGreetingsMsgId 10 a bId1' b aId1'
exchangeGreetingsMsgId 8 a bId2 b aId2
exchangeGreetingsMsgId 8 a bId2' b aId2'
where
hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4
+8 -8
View File
@@ -209,8 +209,8 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
(bobId, aliceId, nonce, message) <- runRight $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")
@@ -270,9 +270,9 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
DeviceToken {} <- registerTestToken bob "bcde" NMInstant apnsQ
-- establish connection
liftIO $ threadDelay 50000
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
liftIO $ threadDelay 1000000
aliceId <- joinConnection bob True qInfo "bob's connInfo"
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
liftIO $ threadDelay 750000
void $ messageNotification apnsQ
("", _, CONF confId _ "bob's connInfo") <- get alice
@@ -321,8 +321,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
runRight_ $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")
@@ -385,8 +385,8 @@ testChangeToken APNSMockServer {apnsQ} = do
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
(aliceId, bobId) <- runRight $ do
-- establish connection
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
aliceId <- joinConnection bob True qInfo "bob's connInfo"
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get bob ##> ("", aliceId, INFO "alice's connInfo")
+13 -7
View File
@@ -140,7 +140,7 @@ testForeignKeysEnabled =
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
cData1 :: ConnData
cData1 = ConnData {connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
testPrivateSignKey :: C.APrivateSignKey
testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
@@ -154,7 +154,8 @@ testDhSecret = "01234567890123456789012345678901"
rcvQueue1 :: RcvQueue
rcvQueue1 =
RcvQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "1234",
rcvPrivateKey = testPrivateSignKey,
@@ -167,13 +168,15 @@ rcvQueue1 =
primary = True,
dbReplaceQueueId = Nothing,
smpClientVersion = 1,
clientNtfCreds = Nothing
clientNtfCreds = Nothing,
deleteErrors = 0
}
sndQueue1 :: SndQueue
sndQueue1 =
SndQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "3456",
sndPublicKey = Nothing,
@@ -314,7 +317,8 @@ testUpgradeRcvConnToDuplex =
_ <- createSndConn db g cData1 sndQueue1
let anotherSndQueue =
SndQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "2345",
sndPublicKey = Nothing,
@@ -340,7 +344,8 @@ testUpgradeSndConnToDuplex =
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
let anotherRcvQueue =
RcvQueue
{ connId = "conn1",
{ userId = 1,
connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "3456",
rcvPrivateKey = testPrivateSignKey,
@@ -353,7 +358,8 @@ testUpgradeSndConnToDuplex =
primary = True,
dbReplaceQueueId = Nothing,
smpClientVersion = 1,
clientNtfCreds = Nothing
clientNtfCreds = Nothing,
deleteErrors = 0
}
upgradeSndConnToDuplex db "conn1" anotherRcvQueue
`shouldReturn` Left (SEBadConnType CRcv)
+1 -1
View File
@@ -69,7 +69,7 @@ ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testNtfClient client = do
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig host ntfTestPort (Just testKeyHash) $ \h ->
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h ->
liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case
Right th -> client th
Left e -> error $ show e
+14 -17
View File
@@ -1,6 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
@@ -11,7 +12,8 @@ import Control.Monad.IO.Unlift
import Crypto.Random
import qualified Data.ByteString.Char8 as B
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Network.Socket (ServiceName)
import NtfClient (ntfTestPort)
import SMPClient
@@ -27,6 +29,7 @@ import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Transport
@@ -173,13 +176,13 @@ testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5
initAgentServers :: InitialAgentServers
initAgentServers =
InitialAgentServers
{ smp = L.fromList [noAuthSrv testSMPServer],
{ smp = userServers [noAuthSrv testSMPServer],
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"],
netCfg = defaultNetworkConfig {tcpTimeout = 500_000}
netCfg = defaultNetworkConfig {tcpTimeout = 500_000, tcpConnectTimeout = 500_000}
}
initAgentServers2 :: InitialAgentServers
initAgentServers2 = initAgentServers {smp = L.fromList [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
agentCfg :: AgentConfig
agentCfg =
@@ -187,17 +190,8 @@ agentCfg =
{ tcpPort = agentTestPort,
tbqSize = 4,
database = testDB,
smpCfg =
defaultClientConfig
{ qSize = 1,
defaultTransport = (testPort, transport @TLS),
networkConfig = defaultNetworkConfig {tcpTimeout = 500_000}
},
ntfCfg =
defaultClientConfig
{ qSize = 1,
defaultTransport = (ntfTestPort, transport @TLS)
},
smpCfg = defaultClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS)},
ntfCfg = defaultClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS)},
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
ntfWorkerDelay = 1000,
ntfSMPWorkerDelay = 1000,
@@ -209,11 +203,14 @@ agentCfg =
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> m () -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
let cfg' = agentCfg {tcpPort = port', database = db'}
initServers' = initAgentServers {smp = L.fromList [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
in serverBracket
(\started -> runSMPAgentBlocking t started cfg' initServers')
afterProcess
userServers :: NonEmpty SMPServerWithAuth -> Map UserId (NonEmpty SMPServerWithAuth)
userServers srvs = M.fromList [(1, srvs)]
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile (dbFile db')
@@ -226,7 +223,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
testSMPAgentClientOn port' client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
runTransportClient defaultTransportClientConfig useHost port' (Just testKeyHash) $ \h -> do
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
line <- liftIO $ getLn h
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
then client h
+1 -1
View File
@@ -57,7 +57,7 @@ testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testSMPClient client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig useHost testPort (Just testKeyHash) $ \h ->
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h ->
liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case
Right th -> client th
Left e -> error $ show e
+33 -2
View File
@@ -11,6 +11,7 @@
module ServerTests where
import AgentTests.NotificationTests (removeFileIfExists)
import Control.Concurrent (ThreadId, killThread, threadDelay)
import Control.Concurrent.STM
import Control.Exception (SomeException, try)
@@ -20,6 +21,7 @@ import Data.Bifunctor (first)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Set as S
import SMPClient
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
@@ -28,6 +30,7 @@ import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..))
import Simplex.Messaging.Transport
import System.Directory (removeFile)
import System.TimeIt (timeItT)
@@ -605,6 +608,10 @@ logSize f =
testRestoreMessages :: ATransport -> Spec
testRestoreMessages at@(ATransport t) =
it "should store messages on exit and restore on start" $ do
removeFileIfExists testStoreLogFile
removeFileIfExists testStoreMsgsFile
removeFileIfExists testServerStatsBackupFile
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
recipientId <- newTVarIO ""
recipientKey <- newTVarIO Nothing
@@ -632,11 +639,15 @@ testRestoreMessages at@(ATransport t) =
Resp "6" _ (ERR QUOTA) <- signSendRecv h sKey ("6", sId, _SEND "hello 6")
pure ()
rId <- readTVarIO recipientId
logSize testStoreLogFile `shouldReturn` 2
logSize testStoreMsgsFile `shouldReturn` 5
logSize testServerStatsBackupFile `shouldReturn` 16
Right stats1 <- strDecode <$> B.readFile testServerStatsBackupFile
checkStats stats1 [rId] 5 1
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
rId <- readTVarIO recipientId
Just rKey <- readTVarIO recipientKey
Just dh <- readTVarIO dhShared
let dec = decryptMsgV3 dh
@@ -650,9 +661,11 @@ testRestoreMessages at@(ATransport t) =
logSize testStoreLogFile `shouldReturn` 1
-- the last message is not removed because it was not ACK'd
logSize testStoreMsgsFile `shouldReturn` 3
logSize testServerStatsBackupFile `shouldReturn` 16
Right stats2 <- strDecode <$> B.readFile testServerStatsBackupFile
checkStats stats2 [rId] 5 3
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
rId <- readTVarIO recipientId
Just rKey <- readTVarIO recipientKey
Just dh <- readTVarIO dhShared
let dec = decryptMsgV3 dh
@@ -667,9 +680,13 @@ testRestoreMessages at@(ATransport t) =
logSize testStoreLogFile `shouldReturn` 1
logSize testStoreMsgsFile `shouldReturn` 0
logSize testServerStatsBackupFile `shouldReturn` 16
Right stats3 <- strDecode <$> B.readFile testServerStatsBackupFile
checkStats stats3 [rId] 5 5
removeFile testStoreLogFile
removeFile testStoreMsgsFile
removeFile testServerStatsBackupFile
where
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
runTest _ test' server = do
@@ -679,6 +696,20 @@ testRestoreMessages at@(ATransport t) =
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
runClient _ test' = testSMPClient test' `shouldReturn` ()
checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation
checkStats s qs sent received = do
_qCreated s `shouldBe` length qs
_qSecured s `shouldBe` length qs
_qDeleted s `shouldBe` 0
_msgSent s `shouldBe` sent
_msgRecv s `shouldBe` received
_msgSentNtf s `shouldBe` 0
_msgRecvNtf s `shouldBe` 0
let PeriodStatsData {_day, _week, _month} = _activeQueues s
S.toList _day `shouldBe` qs
S.toList _week `shouldBe` qs
S.toList _month `shouldBe` qs
testRestoreMessagesV2 :: ATransport -> Spec
testRestoreMessagesV2 at@(ATransport t) =
it "should store messages on exit and restore on start" $ do