mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-25 03:45:23 +00:00
Merge branch 'master' into xftp
This commit is contained in:
@@ -7,6 +7,7 @@ module Main where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgent)
|
||||
import Simplex.Messaging.Client (defaultNetworkConfig)
|
||||
@@ -18,7 +19,7 @@ cfg = defaultAgentConfig
|
||||
servers :: InitialAgentServers
|
||||
servers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
|
||||
{ smp = M.fromList [(1, L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"])],
|
||||
ntf = [],
|
||||
netCfg = defaultNetworkConfig
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
packages: .
|
||||
-- packages: . ../direct-sqlcipher ../sqlcipher-simple
|
||||
-- packages: . ../hs-socks
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/aeson.git
|
||||
tag: 3eb66f9a68f103b5f1489382aad89f5712a64db7
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/hs-socks.git
|
||||
tag: a30cc7a79a08d8108316094f8f2f82a0c5e1ac51
|
||||
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/simplex-chat/direct-sqlcipher.git
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
name: simplexmq
|
||||
version: 4.3.0
|
||||
version: 4.3.1
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: |
|
||||
This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
|
||||
@@ -95,10 +95,12 @@ To send the file, the sender will:
|
||||
|
||||
- compute SHA512 digest
|
||||
- pad the file to match the whole number of chunks in size,
|
||||
- encrypt it with a randomly chosen symmetric key and IV (e.g., using NaCL cryptobox),
|
||||
- encrypt it with a randomly chosen symmetric key and IV (e.g., using NaCL crypto_secretbox),
|
||||
- split into fixed size chunks
|
||||
- upload each chunk to a randomly chosen server.
|
||||
|
||||
The sending client should generate more per-recipient keys than the actual number of recipients, possibly rounding up to a power of 2, to conceal the actual number of intended recipients.
|
||||
|
||||
Then the sending client will combine addresses of all chunks and other information into "file description", different for each file recipient, that will include:
|
||||
|
||||
- an encryption key that was used to encrypt the file (the same for all recipients).
|
||||
@@ -106,22 +108,24 @@ Then the sending client will combine addresses of all chunks and other informati
|
||||
- list of chunk descriptions; information for each chunk:
|
||||
- private Ed25519 key to sign commands for file transfer server.
|
||||
- chunk address (server host and chunk ID).
|
||||
- chunk sha512 digest
|
||||
|
||||
To reduce the size, chunk descriptions will be grouped by the server host.
|
||||
|
||||
This "file description" itself will be sent as a small file. To estimate its size:
|
||||
This "file description" itself will be sent as a small file over an authenticated channel, to prevent file description modification. To estimate its size:
|
||||
|
||||
- each chunk \* redundancy per chunk, assuming chunks are grouped per server:
|
||||
- chunk number in the file - 8 bytes (including any overhead)
|
||||
- 1-based chunk number in the file - 8 bytes (including any overhead)
|
||||
- Ed25519 key (different for each recipient / chunk combination) - 32 bytes \* 4/3 (base64, assuming text encoding)
|
||||
- chunk ID (different for each recipient) - 64 bytes \* 4/3
|
||||
- optional (only in the first chunk occurence) chunk sha512 digests - 64 bytes \* 4/3
|
||||
- server addresses - say, 128 bytes per server
|
||||
- sha512 digest - 64 bytes \* 4/3
|
||||
- encryption key - 32 bytes \* 4/3
|
||||
- IV - 32 bytes \* 4/3
|
||||
- encoding overhead - say, 256 bytes
|
||||
|
||||
For 1gb file, sent via 4 different servers, in 8Mb chunks, with redundancy 2, the size of "file description", assuming text encoding, will be ~34kb (`128 * (8 + 32 + 64) * 2 * 4/3 + 128 * 4 + (64 + 32 + 32) * 4/3 + 256`).
|
||||
For 1gb file, sent via 4 different servers, in 8Mb chunks, with redundancy 2, the size of "file description", assuming text encoding, will be ~45kb (`128 * (8 + 32 + 64) * 2 * 4/3 + 128 * 64 * 4/3 + 128 * 4 + (64 + 32 + 32) * 4/3 + 256`).
|
||||
|
||||
File description format (yml):
|
||||
|
||||
@@ -132,11 +136,12 @@ chunk: 8Mb
|
||||
hash: abc=
|
||||
key: abc=
|
||||
iv: abc=
|
||||
part_hashes: [def=, def=, def=, def=]
|
||||
parts:
|
||||
- server: xftp://abc=@example1.com
|
||||
chunks: [1:abc=:def=, 3:abc=:def=]
|
||||
chunks: [1:abc=:def=:ghi=, 3:abc=:def=:ghi=]
|
||||
- server: xftp://abc=@example2.com
|
||||
chunks: [2:abc=:def=, 4:abc=:def=]
|
||||
chunks: [2:abc=:def=:ghi=, 4:abc=:def=:ghi=]
|
||||
- server: xftp://abc=@example3.com
|
||||
chunks: [1:abc=:def=, 4:abc=:def=]
|
||||
- server: xftp://abc=@example4.com
|
||||
|
||||
+5
-1
@@ -5,7 +5,7 @@ cabal-version: 1.12
|
||||
-- see: https://github.com/sol/hpack
|
||||
|
||||
name: simplexmq
|
||||
version: 4.3.0
|
||||
version: 4.3.1
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
<./docs/Simplex-Messaging-Client.html client> and
|
||||
@@ -64,6 +64,10 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
|
||||
Simplex.Messaging.Agent.TAsyncs
|
||||
Simplex.Messaging.Agent.TRcvQueues
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
|
||||
+295
-194
@@ -38,6 +38,8 @@ module Simplex.Messaging.Agent
|
||||
disconnectAgentClient,
|
||||
resumeAgentClient,
|
||||
withConnLock,
|
||||
createUser,
|
||||
deleteUser,
|
||||
createConnectionAsync,
|
||||
joinConnectionAsync,
|
||||
allowConnectionAsync,
|
||||
@@ -45,6 +47,7 @@ module Simplex.Messaging.Agent
|
||||
ackMessageAsync,
|
||||
switchConnectionAsync,
|
||||
deleteConnectionAsync,
|
||||
deleteConnectionsAsync,
|
||||
createConnection,
|
||||
joinConnection,
|
||||
allowConnection,
|
||||
@@ -61,6 +64,7 @@ module Simplex.Messaging.Agent
|
||||
switchConnection,
|
||||
suspendConnection,
|
||||
deleteConnection,
|
||||
deleteConnections,
|
||||
getConnectionServers,
|
||||
getConnectionRatchetAdHash,
|
||||
setSMPServers,
|
||||
@@ -86,7 +90,7 @@ module Simplex.Messaging.Agent
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Logger.Simple (logInfo, showText)
|
||||
import Control.Logger.Simple (logError, logInfo, showText)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
@@ -97,7 +101,7 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:), (.:.), (.::))
|
||||
import Data.Foldable (foldl')
|
||||
import Data.Functor (($>))
|
||||
import Data.List (deleteFirstsBy, find)
|
||||
import Data.List (deleteFirstsBy, find, (\\))
|
||||
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -110,6 +114,7 @@ import Data.Time.Clock.System (systemToUTCTime)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Lock (withLock)
|
||||
import Simplex.Messaging.Agent.NtfSubSupervisor
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
@@ -124,13 +129,13 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, sameSrvAddr, sameSrvAddr')
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, protoServer, sameSrvAddr')
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import System.Random (randomR)
|
||||
import UnliftIO.Async (async, mapConcurrently, race_)
|
||||
import UnliftIO.Async (async, race_)
|
||||
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -143,7 +148,7 @@ getSMPAgentClient cfg initServers = newSMPAgentEnv cfg >>= runReaderT runAgent
|
||||
where
|
||||
runAgent = do
|
||||
c <- getAgentClient initServers
|
||||
void $ race_ (subscriber c) (runNtfSupervisor c) `forkFinally` const (disconnectAgentClient c)
|
||||
void $ raceAny_ [subscriber c, runNtfSupervisor c, cleanupManager c] `forkFinally` const (disconnectAgentClient c)
|
||||
pure c
|
||||
|
||||
disconnectAgentClient :: MonadUnliftIO m => AgentClient -> m ()
|
||||
@@ -155,16 +160,22 @@ disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns}} = do
|
||||
resumeAgentClient :: MonadIO m => AgentClient -> m ()
|
||||
resumeAgentClient c = atomically $ writeTVar (active c) True
|
||||
|
||||
-- |
|
||||
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
|
||||
|
||||
createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
|
||||
createUser c = withAgentEnv c . createUser' c
|
||||
|
||||
-- | Delete user record optionally deleting all user's connections on SMP servers
|
||||
deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m ()
|
||||
deleteUser c = withAgentEnv c .: deleteUser' c
|
||||
|
||||
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
|
||||
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
|
||||
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
createConnectionAsync c userId corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c userId corrId enableNtfs cMode
|
||||
|
||||
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
|
||||
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
|
||||
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnectionAsync c userId corrId enableNtfs = withAgentEnv c .: joinConnAsync c userId corrId enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
|
||||
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -183,16 +194,20 @@ switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -
|
||||
switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response
|
||||
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
|
||||
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnectionAsync c = withAgentEnv c . deleteConnectionAsync' c
|
||||
|
||||
-- -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response
|
||||
deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> [ConnId] -> m ()
|
||||
deleteConnectionsAsync c = withAgentEnv c . deleteConnectionsAsync' c
|
||||
|
||||
-- | Create SMP agent connection (NEW command)
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection c enableNtfs cMode clientData = withAgentEnv c $ newConn c "" False enableNtfs cMode clientData
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection c userId enableNtfs cMode clientData = withAgentEnv c $ newConn c userId "" enableNtfs cMode clientData
|
||||
|
||||
-- | Join SMP agent connection (JOIN command)
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnection c userId enableNtfs = withAgentEnv c .: joinConn c userId "" False enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command)
|
||||
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -247,6 +262,10 @@ suspendConnection c = withAgentEnv c . suspendConnection' c
|
||||
deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnection c = withAgentEnv c . deleteConnection' c
|
||||
|
||||
-- | Delete multiple connections, batching commands when possible
|
||||
deleteConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
deleteConnections c = withAgentEnv c . deleteConnections' c
|
||||
|
||||
-- | get servers used for connection
|
||||
getConnectionServers :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
getConnectionServers c = withAgentEnv c . getConnectionServers' c
|
||||
@@ -256,12 +275,12 @@ getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m By
|
||||
getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c
|
||||
|
||||
-- | Change servers to be used for creating new queues
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers c = withAgentEnv c . setSMPServers' c
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers c = withAgentEnv c .: setSMPServers' c
|
||||
|
||||
-- | Test SMP server
|
||||
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
testSMPServerConnection c = withAgentEnv c . runSMPServerTest c
|
||||
testSMPServerConnection :: AgentErrorMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
testSMPServerConnection c = withAgentEnv c .: runSMPServerTest c
|
||||
|
||||
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
|
||||
setNtfServers c = withAgentEnv c . setNtfServers' c
|
||||
@@ -349,8 +368,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
-- | execute any SMP agent command
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
|
||||
processCommand c (connId, cmd) = case cmd of
|
||||
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode Nothing
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
|
||||
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo
|
||||
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
|
||||
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
|
||||
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
|
||||
@@ -361,29 +380,53 @@ processCommand c (connId, cmd) = case cmd of
|
||||
OFF -> suspendConnection' c connId $> (connId, OK)
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
|
||||
where
|
||||
-- command interface does not support different users
|
||||
userId = 1
|
||||
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnAsync c corrId enableNtfs cMode = do
|
||||
g <- asks idsDrg
|
||||
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
|
||||
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
connId <- withStore c $ \db -> createNewConn db g cData cMode
|
||||
createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m UserId
|
||||
createUser' c srvs = do
|
||||
userId <- withStore' c createUserRecord
|
||||
atomically $ TM.insert userId srvs $ smpServers c
|
||||
pure userId
|
||||
|
||||
deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m ()
|
||||
deleteUser' c userId delSMPQueues = do
|
||||
if delSMPQueues
|
||||
then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c
|
||||
else withStore c (`deleteUserRecord` userId)
|
||||
atomically $ TM.delete userId $ smpServers c
|
||||
where
|
||||
delUser =
|
||||
whenM (withStore' c (`deleteUserWithoutConns` userId)) $
|
||||
atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId)
|
||||
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnAsync c userId corrId enableNtfs cMode = do
|
||||
connId <- newConnNoQueues c userId "" enableNtfs cMode
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode)
|
||||
pure connId
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
|
||||
newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnNoQueues c userId connId enableNtfs cMode = do
|
||||
g <- asks idsDrg
|
||||
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
|
||||
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
withStore c $ \db -> createNewConn db g cData cMode
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
|
||||
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _cInfo =
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -397,9 +440,9 @@ acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> Invitat
|
||||
acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
|
||||
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
|
||||
withStore c (`getConn` contactConnId) >>= \case
|
||||
SomeConn _ ContactConnection {} -> do
|
||||
SomeConn _ (ContactConnection ConnData {userId} _) -> do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConnAsync c corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -419,23 +462,21 @@ ackMessageAsync' c corrId connId msgId = do
|
||||
(RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId
|
||||
enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId
|
||||
|
||||
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
deleteConnectionAsync' c@AgentClient {subQ} corrId connId = withConnLock c connId "deleteConnectionAsync" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection _ (rq :| _) _ -> enqueueDelete rq
|
||||
RcvConnection _ rq -> enqueueDelete rq
|
||||
ContactConnection _ rq -> enqueueDelete rq
|
||||
SndConnection _ _ -> delete
|
||||
NewConnection _ -> delete
|
||||
where
|
||||
enqueueDelete :: RcvQueue -> m ()
|
||||
enqueueDelete RcvQueue {server} = do
|
||||
withStore' c $ \db -> setConnDeleted db connId
|
||||
disableConn c connId
|
||||
enqueueCommand c corrId connId (Just server) $ AInternalCommand ICDeleteConn
|
||||
delete :: m ()
|
||||
delete = withStore' c (`deleteConn` connId) >> atomically (writeTBQueue subQ (corrId, connId, OK))
|
||||
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId]
|
||||
|
||||
deleteConnectionsAsync' :: AgentMonad m => AgentClient -> [ConnId] -> m ()
|
||||
deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure ()
|
||||
|
||||
deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> [ConnId] -> m ()
|
||||
deleteConnectionsAsync_ onSuccess c connIds = case connIds of
|
||||
[] -> onSuccess
|
||||
_ -> do
|
||||
(_, rqs, connIds') <- prepareDeleteConnections_ getConn c connIds
|
||||
withStore' c $ forM_ connIds' . setConnDeleted
|
||||
void . forkIO $
|
||||
withLock (deleteLock c) "deleteConnectionsAsync" $
|
||||
deleteConnQueues c True rqs >> onSuccess
|
||||
|
||||
-- | Add connection to the new receive queue
|
||||
switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
@@ -444,46 +485,42 @@ switchConnectionAsync' c corrId connId =
|
||||
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId asyncMode enableNtfs cMode clientData =
|
||||
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode clientData
|
||||
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c userId connId enableNtfs cMode clientData =
|
||||
getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData
|
||||
|
||||
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c connId asyncMode enableNtfs cMode clientData srv = do
|
||||
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c userId connId enableNtfs cMode clientData srv = do
|
||||
connId' <- newConnNoQueues c userId connId enableNtfs cMode
|
||||
newRcvConnSrv c userId connId' enableNtfs cMode clientData srv
|
||||
|
||||
newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
|
||||
newRcvConnSrv c userId connId enableNtfs cMode clientData srv = do
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
|
||||
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
|
||||
let rq = (q :: RcvQueue) {connId = connId'}
|
||||
(rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange `catchError` \e -> liftIO (print e) >> throwError e
|
||||
void . withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
addSubscription c rq
|
||||
when enableNtfs $ do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId', NSCCreate)
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
|
||||
let crData = ConnReqUriData simplexChat smpAgentVRange [qUri] clientData
|
||||
case cMode of
|
||||
SCMContact -> pure (connId', CRContactUri crData)
|
||||
SCMContact -> pure (connId, CRContactUri crData)
|
||||
SCMInvitation -> do
|
||||
(pk1, pk2, e2eRcvParams) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
|
||||
withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2
|
||||
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
|
||||
where
|
||||
setUpConn True rq _ = do
|
||||
void . withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
pure connId
|
||||
setUpConn False rq connAgentVersion = do
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
|
||||
withStore c $ \db -> createRcvConn db g cData rq cMode
|
||||
withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2
|
||||
pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
|
||||
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId asyncMode enableNtfs cReq cInfo = do
|
||||
joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c userId connId asyncMode enableNtfs cReq cInfo = do
|
||||
srv <- case cReq of
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
|
||||
getNextSMPServer c [qServer q]
|
||||
_ -> getSMPServer c
|
||||
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
|
||||
getNextSMPServer c userId [qServer q]
|
||||
_ -> getSMPServer c userId
|
||||
joinConnSrv c userId connId asyncMode enableNtfs cReq cInfo srv
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c userId connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) cInfo srv = do
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
case ( qUri `compatibleVersion` smpClientVRange,
|
||||
e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange,
|
||||
@@ -493,9 +530,9 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
|
||||
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
|
||||
q <- newSndQueue "" qInfo
|
||||
q <- newSndQueue userId "" qInfo
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, deleted = False}
|
||||
connId' <- setUpConn asyncMode cData q rc
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
@@ -520,23 +557,23 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri ConnReqUriData {crAge
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnSrv c connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
|
||||
joinConnSrv c userId connId False enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
crAgentVRange `compatibleVersion` aVRange
|
||||
) of
|
||||
(Just qInfo, Just vrsn) -> do
|
||||
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation Nothing srv
|
||||
sendInvitation c qInfo vrsn cReq cInfo
|
||||
(connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing srv
|
||||
sendInvitation c userId qInfo vrsn cReq cInfo
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
joinConnSrv _c _userId _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
|
||||
createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c userId connId srv $ versionToRange smpClientVersion
|
||||
let qInfo = toVersionT qUri smpClientVersion
|
||||
addSubscription c rq
|
||||
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
|
||||
@@ -564,9 +601,9 @@ acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId
|
||||
acceptContact' c connId enableNtfs invId ownConnInfo = withConnLock c connId "acceptContact" $ do
|
||||
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
|
||||
withStore c (`getConn` contactConnId) >>= \case
|
||||
SomeConn _ ContactConnection {} -> do
|
||||
SomeConn _ (ContactConnection ConnData {userId} _) -> do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
joinConn c userId connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -577,32 +614,16 @@ rejectContact' c contactConnId invId =
|
||||
withStore c $ \db -> deleteInvitation db contactConnId invId
|
||||
|
||||
-- | Subscribe to receive connection messages (SUB command) in Reader monad
|
||||
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection' c connId = do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
resumeConnCmds c connId
|
||||
case conn of
|
||||
DuplexConnection cData (rq :| rqs) sqs -> do
|
||||
mapM_ (resumeMsgDelivery c cData) sqs
|
||||
subscribe cData rq
|
||||
mapM_ (\q -> subscribeQueue c q `catchError` \_ -> pure ()) rqs
|
||||
SndConnection cData sq -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure ()
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
RcvConnection cData rq -> subscribe cData rq
|
||||
ContactConnection cData rq -> subscribe cData rq
|
||||
NewConnection _ -> pure ()
|
||||
where
|
||||
subscribe :: ConnData -> RcvQueue -> m ()
|
||||
subscribe ConnData {enableNtfs} rq = do
|
||||
subscribeQueue c rq
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, if enableNtfs then NSCCreate else NSCDelete)
|
||||
subscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId]
|
||||
|
||||
type QSubResult = (QueueStatus, Either AgentErrorType ())
|
||||
toConnResult :: AgentMonad m => ConnId -> Map ConnId (Either AgentErrorType ()) -> m ()
|
||||
toConnResult connId rs = case M.lookup connId rs of
|
||||
Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs)
|
||||
Just (Left e) -> throwError e
|
||||
_ -> throwError $ INTERNAL $ "no result for connection " <> B.unpack connId
|
||||
|
||||
type QCmdResult = (QueueStatus, Either AgentErrorType ())
|
||||
|
||||
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeConnections' _ [] = pure M.empty
|
||||
@@ -611,10 +632,9 @@ subscribeConnections' c connIds = do
|
||||
let (errs, cs) = M.mapEither id conns
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
|
||||
srvRcvQs :: Map SMPServer [RcvQueue] = M.foldl' (foldl' addRcvQueue) M.empty rcvQs
|
||||
mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs
|
||||
mapM_ (resumeConnCmds c) $ M.keys cs
|
||||
rcvRs <- connResults . concat <$> mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
rcvRs <- connResults <$> subscribeQueues c (concat $ M.elems rcvQs)
|
||||
ns <- asks ntfSupervisor
|
||||
tkn <- readTVarIO (ntfTkn ns)
|
||||
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns
|
||||
@@ -622,9 +642,9 @@ subscribeConnections' c connIds = do
|
||||
notifyResultError rs
|
||||
pure rs
|
||||
where
|
||||
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) (NonEmpty RcvQueue)
|
||||
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
|
||||
rcvQueueOrResult (SomeConn _ conn) = case conn of
|
||||
DuplexConnection _ rqs _ -> Right rqs
|
||||
DuplexConnection _ rqs _ -> Right $ L.toList rqs
|
||||
SndConnection _ sq -> Left $ sndSubResult sq
|
||||
RcvConnection _ rq -> Right [rq]
|
||||
ContactConnection _ rq -> Right [rq]
|
||||
@@ -634,21 +654,17 @@ subscribeConnections' c connIds = do
|
||||
Confirmed -> Right ()
|
||||
Active -> Left $ CONN SIMPLEX
|
||||
_ -> Left $ INTERNAL "unexpected queue status"
|
||||
addRcvQueue :: Map SMPServer [RcvQueue] -> RcvQueue -> Map SMPServer [RcvQueue]
|
||||
addRcvQueue m rq@RcvQueue {server} = M.alter (Just . maybe [rq] (rq :)) server m
|
||||
subscribe :: (SMPServer, [RcvQueue]) -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
|
||||
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
|
||||
connResults = M.map snd . foldl' addResult M.empty
|
||||
where
|
||||
-- collects results by connection ID
|
||||
addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QSubResult
|
||||
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
|
||||
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
|
||||
-- combines two results for one connection, by using only Active queues (if there is at least one Active queue)
|
||||
combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult
|
||||
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
|
||||
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
|
||||
combineRes r' _ = Just r'
|
||||
order :: QSubResult -> Int
|
||||
order :: QCmdResult -> Int
|
||||
order (Active, Right _) = 1
|
||||
order (Active, _) = 2
|
||||
order (_, Right _) = 3
|
||||
@@ -675,10 +691,7 @@ subscribeConnections' c connIds = do
|
||||
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
|
||||
resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
resubscribeConnection' c connId =
|
||||
unlessM
|
||||
(atomically $ hasActiveSubscription c connId)
|
||||
(subscribeConnection' c connId)
|
||||
resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId]
|
||||
|
||||
resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
resubscribeConnections' _ [] = pure M.empty
|
||||
@@ -791,21 +804,21 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
|
||||
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
|
||||
Right (corrId, connId, cmd) -> processCmd (riFast ri) corrId connId cmdId cmd
|
||||
Right cmd -> processCmd (riFast ri) cmdId cmd
|
||||
where
|
||||
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
|
||||
processCmd ri corrId connId cmdId command = case command of
|
||||
processCmd :: RetryInterval -> AsyncCmdId -> PendingCommand -> m ()
|
||||
processCmd ri cmdId PendingCommand {corrId, userId, connId, command} = case command of
|
||||
AClientCommand cmd -> case cmd of
|
||||
NEW enableNtfs (ACM cMode) -> noServer $ do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
|
||||
(_, cReq) <- newConnSrv c connId True enableNtfs cMode Nothing srv
|
||||
(_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing srv
|
||||
notify $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
|
||||
let initUsed = [qServer q]
|
||||
usedSrvs <- newTVarIO initUsed
|
||||
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
|
||||
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
|
||||
void $ joinConnSrv c userId connId True enableNtfs cReq connInfo srv
|
||||
notify OK
|
||||
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
|
||||
@@ -827,23 +840,8 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
secure rq senderKey
|
||||
when (duplexHandshake cData == Just True) . void $
|
||||
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
|
||||
ICDeleteConn ->
|
||||
withServer $ \srv -> tryWithLock "ICDeleteConn" $ do
|
||||
SomeConn _ conn <- withStore c $ \db -> getAnyConn db connId True
|
||||
case conn of
|
||||
DuplexConnection _ (rq :| rqs) _ -> delete srv rq $ case rqs of
|
||||
[] -> notify OK
|
||||
RcvQueue {server = srv'} : _ -> enqueue srv'
|
||||
RcvConnection _ rq -> delete srv rq $ notify OK
|
||||
ContactConnection _ rq -> delete srv rq $ notify OK
|
||||
_ -> internalErr "command requires connection with rcv queue"
|
||||
where
|
||||
delete :: SMPServer -> RcvQueue -> m () -> m ()
|
||||
delete srv rq@RcvQueue {server} next
|
||||
| sameSrvAddr srv server = deleteConnQueue c rq >> next
|
||||
| otherwise = enqueue server
|
||||
enqueue :: SMPServer -> m ()
|
||||
enqueue srv = enqueueCommand c corrId connId (Just srv) $ AInternalCommand ICDeleteConn
|
||||
-- ICDeleteConn is no longer used, but it can be present in old client databases
|
||||
ICDeleteConn -> withStore' c (`deleteCommand` cmdId)
|
||||
ICQSecure rId senderKey ->
|
||||
withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) ->
|
||||
case find (sameQueue (srv, rId)) rqs of
|
||||
@@ -860,7 +858,7 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
|
||||
| otherwise -> do
|
||||
deleteQueue c rq'
|
||||
withStore' c $ \db -> deleteConnRcvQueue db connId rq'
|
||||
withStore' c $ \db -> deleteConnRcvQueue db rq'
|
||||
when (enableNtfs cData) $ do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
|
||||
@@ -901,10 +899,11 @@ runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServerWithAuth -> m ()) -> m ()
|
||||
withNextSrv usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c used
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
srvs_ <- TM.lookup userId $ smpServers c
|
||||
let unused = maybe [] ((\\ used) . map protoServer . L.toList) srvs_
|
||||
used' = if null unused then initUsed else srv : used
|
||||
writeTVar usedSrvs $! used'
|
||||
action srvAuth
|
||||
-- ^ ^ ^ async command processing /
|
||||
@@ -978,7 +977,7 @@ getPendingMsgQ c SndQueue {server, sndId} = do
|
||||
pure q
|
||||
|
||||
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do
|
||||
(mq, qLock) <- atomically $ getPendingMsgQ c sq
|
||||
ri <- asks $ messageRetryInterval . config
|
||||
forever $ do
|
||||
@@ -1067,7 +1066,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- and this branch should never be reached as receive is created before the confirmation,
|
||||
-- so the condition is not necessary here, strictly speaking.
|
||||
_ -> unless (duplexHandshake == Just True) $ do
|
||||
srv <- getSMPServer c
|
||||
srv <- getSMPServer c userId
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
@@ -1142,12 +1141,12 @@ switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
|
||||
DuplexConnection cData@ConnData {userId} rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
-- try to get the server that is different from all queues, or at least from the primary rcv queue
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
|
||||
srv' <- if srv == server then getNextSMPServer c [server] else pure srvAuth
|
||||
(q, qUri) <- newRcvQueue c connId srv' clientVRange
|
||||
srvAuth@(ProtoServerWithAuth srv _) <- getNextSMPServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
|
||||
srv' <- if srv == server then getNextSMPServer c userId [server] else pure srvAuth
|
||||
(q, qUri) <- newRcvQueue c userId connId srv' clientVRange
|
||||
let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnRcvQueue db connId rq'
|
||||
addSubscription c rq'
|
||||
@@ -1173,24 +1172,18 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
|
||||
NewConnection _ -> throwError $ CMD PROHIBITED
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) in Reader monad
|
||||
-- unlike deleteConnectionAsync, this function does not mark connection as deleted in case of deletion failure
|
||||
-- currently it is used only in tests
|
||||
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection _ rqs _ -> mapM_ (deleteConnQueue c) rqs >> disableConn c connId >> deleteConn'
|
||||
RcvConnection _ rq -> delete rq
|
||||
ContactConnection _ rq -> delete rq
|
||||
SndConnection _ _ -> deleteConn'
|
||||
NewConnection _ -> deleteConn'
|
||||
where
|
||||
delete :: RcvQueue -> m ()
|
||||
delete rq = deleteConnQueue c rq >> disableConn c connId >> deleteConn'
|
||||
deleteConn' = withStore' c (`deleteConn` connId)
|
||||
deleteConnection' c connId = toConnResult connId =<< deleteConnections' c [connId]
|
||||
|
||||
deleteConnQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteConnQueue c rq@RcvQueue {connId} = do
|
||||
deleteQueue c rq
|
||||
withStore' c $ \db -> deleteConnRcvQueue db connId rq
|
||||
connRcvQueues :: Connection d -> [RcvQueue]
|
||||
connRcvQueues = \case
|
||||
DuplexConnection _ rqs _ -> L.toList rqs
|
||||
RcvConnection _ rq -> [rq]
|
||||
ContactConnection _ rq -> [rq]
|
||||
SndConnection _ _ -> []
|
||||
NewConnection _ -> []
|
||||
|
||||
disableConn :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
disableConn c connId = do
|
||||
@@ -1198,6 +1191,94 @@ disableConn c connId = do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete)
|
||||
|
||||
-- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure.
|
||||
deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
deleteConnections' = deleteConnections_ getConn False
|
||||
|
||||
deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
deleteDeletedConns = deleteConnections_ getDeletedConn True
|
||||
|
||||
prepareDeleteConnections_ ::
|
||||
forall m.
|
||||
AgentMonad m =>
|
||||
(DB.Connection -> ConnId -> IO (Either StoreError SomeConn)) ->
|
||||
AgentClient ->
|
||||
[ConnId] ->
|
||||
m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId])
|
||||
prepareDeleteConnections_ getConnection c connIds = do
|
||||
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConnection)
|
||||
let (errs, cs) = M.mapEither id conns
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(delRs, rcvQs) = M.mapEither rcvQueues cs
|
||||
rqs = concat $ M.elems rcvQs
|
||||
connIds' = M.keys rcvQs
|
||||
forM_ connIds' $ disableConn c
|
||||
withStore' c $ forM_ (M.keys delRs) . deleteConn
|
||||
pure (errs' <> delRs, rqs, connIds')
|
||||
where
|
||||
rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
|
||||
rcvQueues (SomeConn _ conn) = case connRcvQueues conn of
|
||||
[] -> Left $ Right ()
|
||||
rqs -> Right rqs
|
||||
|
||||
deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
deleteConnQueues c ntf rqs = do
|
||||
rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs)
|
||||
forM_ (M.assocs rs) $ \case
|
||||
(connId, Right _) -> withStore' c (`deleteConn` connId) >> notify ("", connId, DEL_CONN)
|
||||
_ -> pure ()
|
||||
pure rs
|
||||
where
|
||||
deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
deleteQueueRecs rs = do
|
||||
maxErrs <- asks $ deleteErrorCount . config
|
||||
forM rs $ \(rq, r) -> do
|
||||
r' <- case r of
|
||||
Right _ -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq Nothing $> r
|
||||
Left e
|
||||
| temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> withStore' c (`incRcvDeleteErrors` rq) $> r
|
||||
| otherwise -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq (Just e) $> Right ()
|
||||
pure (rq, r')
|
||||
notifyRQ rq e_ = notify ("", qConnId rq, DEL_RCVQ (qServer rq) (queueId rq) e_)
|
||||
notify = when ntf . atomically . writeTBQueue (subQ c)
|
||||
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
|
||||
connResults = M.map snd . foldl' addResult M.empty
|
||||
where
|
||||
-- collects results by connection ID
|
||||
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
|
||||
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
|
||||
-- combines two results for one connection, by prioritizing errors in Active queues
|
||||
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
|
||||
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
|
||||
combineRes r' _ = Just r'
|
||||
order :: QCmdResult -> Int
|
||||
order (Active, Left _) = 1
|
||||
order (_, Left _) = 2
|
||||
order _ = 3
|
||||
|
||||
deleteConnections_ ::
|
||||
forall m.
|
||||
AgentMonad m =>
|
||||
(DB.Connection -> ConnId -> IO (Either StoreError SomeConn)) ->
|
||||
Bool ->
|
||||
AgentClient ->
|
||||
[ConnId] ->
|
||||
m (Map ConnId (Either AgentErrorType ()))
|
||||
deleteConnections_ _ _ _ [] = pure M.empty
|
||||
deleteConnections_ getConnection ntf c connIds = do
|
||||
(rs, rqs, _) <- prepareDeleteConnections_ getConnection c connIds
|
||||
rcvRs <- deleteConnQueues c ntf rqs
|
||||
let rs' = M.union rs rcvRs
|
||||
notifyResultError rs'
|
||||
pure rs'
|
||||
where
|
||||
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
|
||||
notifyResultError rs = do
|
||||
let actual = M.size rs
|
||||
expected = length connIds
|
||||
when (actual /= expected) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
|
||||
getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
getConnectionServers' c connId = do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
@@ -1217,8 +1298,8 @@ connectionStats = \case
|
||||
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
|
||||
|
||||
-- | Change servers to be used for creating new queues, in Reader monad
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers' c = atomically . writeTVar (smpServers c)
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> UserId -> NonEmpty SMPServerWithAuth -> m ()
|
||||
setSMPServers' c userId srvs = atomically $ TM.insert userId srvs $ smpServers c
|
||||
|
||||
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
|
||||
registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
|
||||
@@ -1450,15 +1531,16 @@ execAgentStoreSQL' :: AgentMonad m => AgentClient -> Text -> m [Text]
|
||||
execAgentStoreSQL' c sql = withStore' c (`execSQL` sql)
|
||||
|
||||
debugAgentLocks' :: AgentMonad m => AgentClient -> m AgentLocks
|
||||
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs} = do
|
||||
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
|
||||
connLocks <- getLocks cs
|
||||
srvLocks <- getLocks rs
|
||||
pure AgentLocks {connLocks, srvLocks}
|
||||
delLock <- atomically $ tryReadTMVar d
|
||||
pure AgentLocks {connLocks, srvLocks, delLock}
|
||||
where
|
||||
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServerWithAuth
|
||||
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
|
||||
getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth
|
||||
getSMPServer c userId = withUserServers c userId pickServer
|
||||
|
||||
pickServer :: AgentMonad m => NonEmpty SMPServerWithAuth -> m SMPServerWithAuth
|
||||
pickServer = \case
|
||||
@@ -1467,13 +1549,18 @@ pickServer = \case
|
||||
gen <- asks randomServer
|
||||
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
|
||||
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServerWithAuth
|
||||
getNextSMPServer c usedSrvs = do
|
||||
srvs <- readTVarIO $ smpServers c
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> UserId -> [SMPServer] -> m SMPServerWithAuth
|
||||
getNextSMPServer c userId usedSrvs = withUserServers c userId $ \srvs ->
|
||||
case L.nonEmpty $ deleteFirstsBy sameSrvAddr' (L.toList srvs) (map noAuthSrv usedSrvs) of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pickServer srvs
|
||||
|
||||
withUserServers :: AgentMonad m => AgentClient -> UserId -> (NonEmpty SMPServerWithAuth -> m a) -> m a
|
||||
withUserServers c userId action =
|
||||
atomically (TM.lookup userId $ smpServers c) >>= \case
|
||||
Just srvs -> action srvs
|
||||
_ -> throwError $ INTERNAL "unknown userId - no SMP servers"
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
t <- atomically $ readTBQueue msgQ
|
||||
@@ -1482,13 +1569,26 @@ subscriber c@AgentClient {msgQ} = forever $ do
|
||||
Left e -> liftIO $ print e
|
||||
Right _ -> return ()
|
||||
|
||||
cleanupManager :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
cleanupManager c = do
|
||||
threadDelay =<< asks (initialCleanupDelay . config)
|
||||
int <- asks (cleanupInterval . config)
|
||||
forever $ do
|
||||
void . runExceptT $
|
||||
withLock (deleteLock c) "cleanupManager" $ do
|
||||
void $ withStore' c getDeletedConns >>= deleteDeletedConns c
|
||||
withStore' c deleteUsersWithoutConns >>= mapM_ notifyUserDeleted
|
||||
threadDelay int
|
||||
where
|
||||
notifyUserDeleted userId = atomically $ writeTBQueue (subQ c) ("", "", DEL_USER userId)
|
||||
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
|
||||
processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) = do
|
||||
processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, sessId, rId, cmd) = do
|
||||
(rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId)
|
||||
processSMP rq conn $ connData conn
|
||||
where
|
||||
processSMP :: RcvQueue -> Connection c -> ConnData -> m ()
|
||||
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {connId, duplexHandshake} = withConnLock c connId "processSMP" $
|
||||
processSMP rq@RcvQueue {e2ePrivKey, e2eDhSecret, status} conn cData@ConnData {userId, connId, duplexHandshake} = withConnLock c connId "processSMP" $
|
||||
case cmd of
|
||||
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
|
||||
handleNotifyAck $
|
||||
@@ -1587,7 +1687,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
handleNotifyAck :: m () -> m ()
|
||||
handleNotifyAck m = m `catchError` \e -> notify (ERR e) >> ack
|
||||
SMP.END ->
|
||||
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
|
||||
atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND)
|
||||
>>= logServer "<--" c srv rId
|
||||
where
|
||||
processEND = \case
|
||||
@@ -1724,7 +1824,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
|
||||
(Just _, _) -> qError "QADD: queue address is already used in connection"
|
||||
(_, Just _replaced@SndQueue {dbQueueId}) -> do
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo
|
||||
let sq' = (sq_ :: SndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnSndQueue db connId sq'
|
||||
case (sndPublicKey, e2ePubKey) of
|
||||
@@ -1796,12 +1896,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
|
||||
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
|
||||
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qInfo `proveCompatible` clientVRange of
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Just qInfo' -> do
|
||||
sq <- newSndQueue connId qInfo'
|
||||
sq <- newSndQueue userId connId qInfo'
|
||||
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
|
||||
|
||||
@@ -1860,14 +1960,15 @@ agentRatchetDecrypt db connId encAgentMsg = do
|
||||
liftIO $ updateRatchet db connId rc' skippedDiff
|
||||
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
|
||||
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
C.SignAlg a <- asks $ cmdSignAlg . config
|
||||
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
pure
|
||||
SndQueue
|
||||
{ connId,
|
||||
{ userId,
|
||||
connId,
|
||||
server = smpServer,
|
||||
sndId = senderId,
|
||||
sndPublicKey = Just sndPublicKey,
|
||||
|
||||
@@ -25,7 +25,6 @@ module Simplex.Messaging.Agent.Client
|
||||
closeProtocolServerClients,
|
||||
runSMPServerTest,
|
||||
newRcvQueue,
|
||||
subscribeQueue,
|
||||
subscribeQueues,
|
||||
getQueueMessage,
|
||||
decryptSMPMessage,
|
||||
@@ -54,6 +53,7 @@ module Simplex.Messaging.Agent.Client
|
||||
sendAck,
|
||||
suspendQueue,
|
||||
deleteQueue,
|
||||
deleteQueues,
|
||||
logServer,
|
||||
logSecret,
|
||||
removeSubscription,
|
||||
@@ -96,10 +96,10 @@ import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (isRight, partitionEithers)
|
||||
import Data.Either (lefts, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (partition)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.List (foldl', partition)
|
||||
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
@@ -117,6 +117,7 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction)
|
||||
import Simplex.Messaging.Agent.TAsyncs
|
||||
import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues)
|
||||
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
|
||||
import Simplex.Messaging.Client
|
||||
@@ -131,6 +132,7 @@ import Simplex.Messaging.Parsers (dropPrefix, enumJSON, parse)
|
||||
import Simplex.Messaging.Protocol
|
||||
( AProtocolType (..),
|
||||
BrokerMsg,
|
||||
EntityId,
|
||||
ErrorType,
|
||||
MsgFlags (..),
|
||||
MsgId,
|
||||
@@ -156,7 +158,7 @@ import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async)
|
||||
import UnliftIO (mapConcurrently)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -166,15 +168,19 @@ type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
|
||||
|
||||
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
|
||||
|
||||
type SMPTransportSession = TransportSession SMP.BrokerMsg
|
||||
|
||||
type NtfTransportSession = TransportSession NtfResponse
|
||||
|
||||
data AgentClient = AgentClient
|
||||
{ active :: TVar Bool,
|
||||
rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
smpServers :: TVar (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPTransportSession SMPClientVar,
|
||||
ntfServers :: TVar [NtfServer],
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
ntfClients :: TMap NtfTransportSession NtfClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubs :: TRcvQueues,
|
||||
@@ -194,10 +200,12 @@ data AgentClient = AgentClient
|
||||
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
|
||||
-- locks to prevent concurrent operations with connection
|
||||
connLocks :: TMap ConnId Lock,
|
||||
-- lock to prevent concurrency between periodic and async connection deletions
|
||||
deleteLock :: Lock,
|
||||
-- locks to prevent concurrent reconnections to SMP servers
|
||||
reconnectLocks :: TMap SMPServer Lock,
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()],
|
||||
reconnectLocks :: TMap SMPTransportSession Lock,
|
||||
reconnections :: TAsyncs,
|
||||
asyncClients :: TAsyncs,
|
||||
agentStats :: TMap AgentStatsKey (TVar Int),
|
||||
clientId :: Int,
|
||||
agentEnv :: Env
|
||||
@@ -222,12 +230,18 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
|
||||
data AgentState = ASActive | ASSuspending | ASSuspended
|
||||
deriving (Eq, Show)
|
||||
|
||||
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String}
|
||||
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
data AgentStatsKey = AgentStatsKey {host :: ByteString, clientTs :: ByteString, cmd :: ByteString, res :: ByteString}
|
||||
data AgentStatsKey = AgentStatsKey
|
||||
{ userId :: UserId,
|
||||
host :: ByteString,
|
||||
clientTs :: ByteString,
|
||||
cmd :: ByteString,
|
||||
res :: ByteString
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
|
||||
@@ -259,18 +273,54 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
agentState <- newTVar ASActive
|
||||
getMsgLocks <- TM.empty
|
||||
connLocks <- TM.empty
|
||||
deleteLock <- createLock
|
||||
reconnectLocks <- TM.empty
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
reconnections <- newTAsyncs
|
||||
asyncClients <- newTAsyncs
|
||||
agentStats <- TM.empty
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, pendingMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, agentStats, clientId, agentEnv}
|
||||
return
|
||||
AgentClient
|
||||
{ active,
|
||||
rcvQ,
|
||||
subQ,
|
||||
msgQ,
|
||||
smpServers,
|
||||
smpClients,
|
||||
ntfServers,
|
||||
ntfClients,
|
||||
useNetworkConfig,
|
||||
subscrConns,
|
||||
activeSubs,
|
||||
pendingSubs,
|
||||
pendingMsgsQueued,
|
||||
smpQueueMsgQueues,
|
||||
smpQueueMsgDeliveries,
|
||||
connCmdsQueued,
|
||||
asyncCmdQueues,
|
||||
asyncCmdProcesses,
|
||||
ntfNetworkOp,
|
||||
rcvNetworkOp,
|
||||
msgDeliveryOp,
|
||||
sndNetworkOp,
|
||||
databaseOp,
|
||||
agentState,
|
||||
getMsgLocks,
|
||||
connLocks,
|
||||
deleteLock,
|
||||
reconnectLocks,
|
||||
reconnections,
|
||||
asyncClients,
|
||||
agentStats,
|
||||
clientId,
|
||||
agentEnv
|
||||
}
|
||||
|
||||
agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
|
||||
class ProtocolServerClient msg where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg)
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (ProtocolClient msg)
|
||||
clientProtocolError :: ErrorType -> AgentErrorType
|
||||
|
||||
instance ProtocolServerClient BrokerMsg where
|
||||
@@ -281,19 +331,19 @@ instance ProtocolServerClient NtfResponse where
|
||||
getProtocolServerClient = getNtfServerClient
|
||||
clientProtocolError = NTF
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv smpClients)
|
||||
atomically (getClientVar tSess smpClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv smpClients connectClient reconnectClient)
|
||||
(waitForProtocolClient c srv)
|
||||
(newProtocolClient c tSess smpClients connectClient reconnectSMPClient)
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
u <- askUnliftIO
|
||||
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
|
||||
liftEitherError (protocolClientError SMP $ B.unpack $ strEncode srv) (getProtocolClient tSess cfg (Just msgQ) $ clientDisconnected u)
|
||||
|
||||
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
|
||||
clientDisconnected u client = do
|
||||
@@ -302,86 +352,84 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
where
|
||||
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
(qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c
|
||||
TM.delete tSess smpClients
|
||||
qs <- RQ.getDelSessQueues tSess $ activeSubs c
|
||||
mapM_ (`RQ.addQueue` pendingSubs c) qs
|
||||
pure (qs, S.toList conns)
|
||||
let cs = S.fromList $ map qConnId qs
|
||||
cs' <- RQ.getConns $ activeSubs c
|
||||
pure (qs, S.toList $ cs `S.difference` cs')
|
||||
|
||||
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
incClientStat c client "DISCONNECT" ""
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
unless (null conns) $ notifySub "" $ DOWN srv conns
|
||||
unless (null qs) $ do
|
||||
atomically $ mapM_ (releaseGetLock c) qs
|
||||
unliftIO u reconnectServer
|
||||
unliftIO u $ reconnectServer c tSess
|
||||
|
||||
reconnectServer :: m ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections c) (a :)
|
||||
notifySub :: ConnId -> ACommand 'Agent -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
tryReconnectClient :: m ()
|
||||
tryReconnectClient = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
reconnectClient `catchError` const loop
|
||||
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withLockMap_ (reconnectLocks c) srv "reconnect" $
|
||||
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
|
||||
where
|
||||
resubscribe :: [RcvQueue] -> m ()
|
||||
resubscribe qs = do
|
||||
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
|
||||
cs <- atomically . RQ.getConns $ activeSubs c
|
||||
(client_, rs) <- subscribeQueues c srv qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
unless connected . forM_ client_ $ \cl -> do
|
||||
incClientStat c cl "CONNECT" ""
|
||||
notifySub "" $ hostEvent CONNECT cl
|
||||
let conns = S.toList $ S.fromList okConns `S.difference` cs
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
|
||||
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
|
||||
mapM_ (throwError . snd) $ listToMaybe tempErrs
|
||||
reconnectServer :: AgentMonad m => AgentClient -> SMPTransportSession -> m ()
|
||||
reconnectServer c tSess = newAsyncAction tryReconnectSMPClient $ reconnections c
|
||||
where
|
||||
tryReconnectSMPClient aId = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
reconnectSMPClient c tSess `catchError` const loop
|
||||
atomically . removeAsyncAction aId $ reconnections c
|
||||
|
||||
reconnectSMPClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m ()
|
||||
reconnectSMPClient c tSess@(_, srv, _) =
|
||||
withLockMap_ (reconnectLocks c) tSess "reconnect" $
|
||||
atomically (RQ.getSessQueues tSess $ pendingSubs c) >>= mapM_ resubscribe . L.nonEmpty
|
||||
where
|
||||
resubscribe :: NonEmpty RcvQueue -> m ()
|
||||
resubscribe qs = do
|
||||
cs <- atomically . RQ.getConns $ activeSubs c
|
||||
rs <- subscribeQueues c $ L.toList qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
let conns = S.toList $ S.fromList okConns `S.difference` cs
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
|
||||
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
|
||||
mapM_ (throwError . snd) $ listToMaybe tempErrs
|
||||
notifySub :: ConnId -> ACommand 'Agent -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} srv = do
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfTransportSession -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv ntfClients)
|
||||
atomically (getClientVar tSess ntfClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv ntfClients connectClient $ pure ())
|
||||
(waitForProtocolClient c srv)
|
||||
(newProtocolClient c tSess ntfClients connectClient $ \_ _ -> pure ())
|
||||
(waitForProtocolClient c tSess)
|
||||
where
|
||||
connectClient :: m NtfClient
|
||||
connectClient = do
|
||||
cfg <- getClientConfig c ntfCfg
|
||||
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient srv cfg Nothing clientDisconnected)
|
||||
liftEitherError (protocolClientError NTF $ B.unpack $ strEncode srv) (getProtocolClient tSess cfg Nothing clientDisconnected)
|
||||
|
||||
clientDisconnected :: NtfClient -> IO ()
|
||||
clientDisconnected client = do
|
||||
atomically $ TM.delete srv ntfClients
|
||||
incClientStat c client "DISCONNECT" ""
|
||||
atomically $ TM.delete tSess ntfClients
|
||||
incClientStat c userId client "DISCONNECT" ""
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
|
||||
getClientVar :: forall a s. TransportSession s -> TMap (TransportSession s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar tSess clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup tSess clients
|
||||
where
|
||||
newClientVar :: STM (TMVar a)
|
||||
newClientVar = do
|
||||
var <- newEmptyTMVar
|
||||
TM.insert srv var clients
|
||||
TM.insert tSess var clients
|
||||
pure var
|
||||
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient c srv clientVar = do
|
||||
waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient c (_, srv, _) clientVar = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
client_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar clientVar)
|
||||
liftEither $ case client_ of
|
||||
@@ -393,39 +441,38 @@ newProtocolClient ::
|
||||
forall msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
|
||||
AgentClient ->
|
||||
ProtoServer msg ->
|
||||
TMap (ProtoServer msg) (ClientVar msg) ->
|
||||
TransportSession msg ->
|
||||
TMap (TransportSession msg) (ClientVar msg) ->
|
||||
m (ProtocolClient msg) ->
|
||||
m () ->
|
||||
(AgentClient -> TransportSession msg -> m ()) ->
|
||||
ClientVar msg ->
|
||||
m (ProtocolClient msg)
|
||||
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
|
||||
atomically $ putTMVar clientVar r
|
||||
liftIO $ incClientStat c client "CLIENT" "OK"
|
||||
liftIO $ incClientStat c userId client "CLIENT" "OK"
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
|
||||
successAction client
|
||||
Left e -> do
|
||||
liftIO $ incServerStat c srv "CLIENT" $ strEncode e
|
||||
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
|
||||
if temporaryAgentError e
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar clientVar (Left e)
|
||||
TM.delete srv clients
|
||||
TM.delete tSess clients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients c) (a :)
|
||||
connectAsync :: m ()
|
||||
connectAsync = do
|
||||
tryConnectAsync = newAsyncAction connectAsync $ asyncClients c
|
||||
connectAsync :: Int -> m ()
|
||||
connectAsync aId = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c tSess) loop
|
||||
atomically . removeAsyncAction aId $ asyncClients c
|
||||
|
||||
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
|
||||
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client
|
||||
@@ -441,8 +488,8 @@ closeAgentClient c = liftIO $ do
|
||||
atomically $ writeTVar (active c) False
|
||||
closeProtocolServerClients c smpClients
|
||||
closeProtocolServerClients c ntfClients
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
cancelActions . actions $ reconnections c
|
||||
cancelActions . actions $ asyncClients c
|
||||
cancelActions $ smpQueueMsgDeliveries c
|
||||
cancelActions $ asyncCmdProcesses c
|
||||
atomically . RQ.clear $ activeSubs c
|
||||
@@ -471,7 +518,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =
|
||||
where
|
||||
k = (server, sndId)
|
||||
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients c clientsSel =
|
||||
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
|
||||
where
|
||||
@@ -492,32 +539,45 @@ withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId
|
||||
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
|
||||
where
|
||||
newLock = newEmptyTMVar >>= \l -> TM.insert key l locks $> l
|
||||
newLock = createLock >>= \l -> TM.insert key l locks $> l
|
||||
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c srv statCmd action = do
|
||||
cl <- getProtocolServerClient c srv
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
cl <- getProtocolServerClient c tSess
|
||||
(action cl <* stat cl "OK") `catchError` logServerError cl
|
||||
where
|
||||
stat cl = liftIO . incClientStat c cl statCmd
|
||||
stat cl = liftIO . incClientStat c userId cl statCmd
|
||||
logServerError :: ProtocolClient msg -> AgentErrorType -> m a
|
||||
logServerError cl e = do
|
||||
logServer "<--" c srv "" $ strEncode e
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withClient_ c srv cmdStr action
|
||||
logServer "<--" c srv qId "OK"
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
logServer "-->" c srv entId cmdStr
|
||||
res <- withClient_ c tSess cmdStr action
|
||||
logServer "<--" c srv entId "OK"
|
||||
return res
|
||||
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c srv statKey action = withClient_ c srv statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @msg) (clientServer client) $ action client
|
||||
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withSMPClient c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient c tSess (queueId q) cmdStr action
|
||||
|
||||
withSMPClient_ :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withSMPClient_ c q cmdStr action = do
|
||||
tSess <- mkSMPTransportSession c q
|
||||
withLogClient_ c tSess (queueId q) cmdStr action
|
||||
|
||||
withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityId -> ByteString -> (NtfClient -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withNtfClient c srv = withLogClient c (0, srv, Nothing)
|
||||
|
||||
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> HostName -> ExceptT ProtocolClientError IO a -> m a
|
||||
liftClient protocolError_ = liftError . protocolClientError protocolError_
|
||||
@@ -551,12 +611,13 @@ instance ToJSON SMPTestFailure where
|
||||
toEncoding = J.genericToEncoding J.defaultOptions
|
||||
toJSON = J.genericToJSON J.defaultOptions
|
||||
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
runSMPServerTest :: AgentMonad m => AgentClient -> UserId -> SMPServerWithAuth -> m (Maybe SMPTestFailure)
|
||||
runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
|
||||
cfg <- getClientConfig c smpCfg
|
||||
C.SignAlg a <- asks $ cmdSignAlg . config
|
||||
liftIO $ do
|
||||
getProtocolClient srv cfg Nothing (\_ -> pure ()) >>= \case
|
||||
let tSess = (userId, srv, Nothing)
|
||||
getProtocolClient tSess cfg Nothing (\_ -> pure ()) >>= \case
|
||||
Right smp -> do
|
||||
(rKey, rpKey) <- C.generateSignatureKeyPair a
|
||||
(sKey, _) <- C.generateSignatureKeyPair a
|
||||
@@ -566,7 +627,7 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
liftError (testErr TSSecureQueue) $ secureSMPQueue smp rpKey rcvId sKey
|
||||
liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp rpKey rcvId
|
||||
ok <- tcpTimeout (networkConfig cfg) `timeout` closeProtocolClient smp
|
||||
incClientStat c smp "TEST" "OK"
|
||||
incClientStat c userId smp "TEST" "OK"
|
||||
pure $ either Just (const Nothing) r <|> maybe (Just (SMPTestFailure TSDisconnect $ BROKER addr TIMEOUT)) (const Nothing) ok
|
||||
Left e -> pure (Just $ testErr TSConnect e)
|
||||
where
|
||||
@@ -574,19 +635,36 @@ runSMPServerTest c (ProtoServerWithAuth srv auth) = do
|
||||
testErr :: SMPTestStep -> ProtocolClientError -> SMPTestFailure
|
||||
testErr step = SMPTestFailure step . protocolClientError SMP addr
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
|
||||
mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c
|
||||
|
||||
mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg
|
||||
mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
|
||||
|
||||
mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
|
||||
mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c
|
||||
|
||||
mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession
|
||||
mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
|
||||
|
||||
getSessionMode :: AgentMonad m => AgentClient -> m TransportSessionMode
|
||||
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
C.SignAlg a <- asks (cmdSignAlg . config)
|
||||
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
logServer "-->" c srv "" "NEW"
|
||||
tSess <- mkTransportSession c userId srv connId
|
||||
QIK {rcvId, sndId, rcvPublicDhKey} <-
|
||||
withClient c srv "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
|
||||
withClient c tSess "NEW" $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey auth
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
{ connId,
|
||||
{ userId,
|
||||
connId,
|
||||
server = srv,
|
||||
rcvId,
|
||||
rcvPrivateKey,
|
||||
@@ -599,20 +677,11 @@ newRcvQueue c connId (ProtoServerWithAuth srv auth) vRange = do
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = maxVersion vRange,
|
||||
clientNtfCreds = Nothing
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ do
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
|
||||
>>= either throwError pure
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq r = do
|
||||
case r of
|
||||
@@ -639,34 +708,59 @@ temporaryOrHostError = \case
|
||||
BROKER _ HOST -> True
|
||||
e -> temporaryAgentError e
|
||||
|
||||
-- | subscribe multiple queues - all passed queues should be on the same server
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do
|
||||
modifyTVar' (subscrConns c) $ S.insert connId
|
||||
-- | Subscribe to queues. The list of results can have a different order.
|
||||
subscribeQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribeQueues c qs = do
|
||||
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (,Left e) qs_
|
||||
Right smp -> do
|
||||
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
|
||||
let qs2 = L.map queueCreds qs'
|
||||
n = (length qs2 - 1) `div` 90 + 1
|
||||
liftIO $ incClientStatN c smp n "SUBS" "OK"
|
||||
liftIO $ do
|
||||
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
mapM_ (uncurry $ processSubResult c) rs
|
||||
pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs
|
||||
_ -> pure (Nothing, errs)
|
||||
u <- askUnliftIO
|
||||
(errs <>) <$> sendTSessionBatches "SUB" 90 (subscribeQueues_ u) c qs
|
||||
where
|
||||
checkQueue rq@RcvQueue {rcvId, server} = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq
|
||||
subscribeQueues_ u smp qs' = do
|
||||
rs <- sendBatch subscribeSMPQueues smp qs'
|
||||
mapM_ (uncurry $ processSubResult c) rs
|
||||
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
|
||||
unliftIO u $ reconnectServer c $ transportSession' smp
|
||||
pure rs
|
||||
|
||||
type BatchResponses e = (NonEmpty (RcvQueue, Either e ()))
|
||||
|
||||
-- statBatchSize is not used to batch the commands, only for traffic statistics
|
||||
sendTSessionBatches :: forall m. AgentMonad m => ByteString -> Int -> (SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)) -> AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
sendTSessionBatches statCmd statBatchSize action c qs =
|
||||
concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues)
|
||||
where
|
||||
batchQueues :: m [(SMPTransportSession, NonEmpty RcvQueue)]
|
||||
batchQueues = do
|
||||
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
|
||||
pure . M.assocs $ foldl' (batch mode) M.empty qs
|
||||
where
|
||||
batch mode m rq =
|
||||
let tSess = mkSMPTSession rq mode
|
||||
in M.alter (Just . maybe [rq] (rq <|)) tSess m
|
||||
sendClientBatch :: (SMPTransportSession, NonEmpty RcvQueue) -> m (BatchResponses AgentErrorType)
|
||||
sendClientBatch (tSess@(userId, srv, _), qs') =
|
||||
tryError (getSMPServerClient c tSess) >>= \case
|
||||
Left e -> pure $ L.map (,Left e) qs'
|
||||
Right smp -> liftIO $ do
|
||||
logServer "-->" c srv (bshow (length qs') <> " queues") statCmd
|
||||
rs <- L.map agentError <$> action smp qs'
|
||||
statBatch
|
||||
pure rs
|
||||
where
|
||||
agentError = second . first $ protocolClientError SMP $ clientServer smp
|
||||
statBatch =
|
||||
let n = (length qs - 1) `div` statBatchSize + 1
|
||||
in incClientStatN c userId smp n (statCmd <> "S") "OK"
|
||||
|
||||
sendBatch :: (SMPClient -> NonEmpty (SMP.RcvPrivateSignKey, SMP.RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses ProtocolClientError)
|
||||
sendBatch smpCmdFunc smp qs = L.zip qs <$> smpCmdFunc smp (L.map queueCreds qs)
|
||||
where
|
||||
checkQueue rq@RcvQueue {rcvId, server}
|
||||
| server == srv = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
|
||||
| otherwise = pure $ Left (rq, Left $ INTERNAL "queue server does not match parameter")
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
|
||||
@@ -699,16 +793,17 @@ logSecret :: ByteString -> ByteString
|
||||
logSecret bs = encode $ B.take 3 bs
|
||||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withSMPClient_ c sq "SEND <CONF>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo =
|
||||
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
|
||||
tSess <- mkTransportSession c userId smpServer senderId
|
||||
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
|
||||
msg <- mkInvitation
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
|
||||
where
|
||||
@@ -722,7 +817,7 @@ sendInvitation c (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderI
|
||||
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
|
||||
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
atomically createTakeGetLock
|
||||
(v, msg_) <- withLogClient c server rcvId "GET" $ \smp ->
|
||||
(v, msg_) <- withSMPClient c rq "GET" $ \smp ->
|
||||
(thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId
|
||||
mapM (decryptMeta v) msg_
|
||||
where
|
||||
@@ -742,23 +837,23 @@ decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.Enc
|
||||
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
|
||||
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
|
||||
withLogClient c server rcvId "KEY <key>" $ \smp ->
|
||||
secureQueue c rq@RcvQueue {rcvId, rcvPrivateKey} senderKey =
|
||||
withSMPClient c rq "KEY <key>" $ \smp ->
|
||||
secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
||||
|
||||
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey)
|
||||
enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withLogClient c server rcvId "NKEY <nkey>" $ \smp ->
|
||||
enableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withSMPClient c rq "NKEY <nkey>" $ \smp ->
|
||||
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
|
||||
|
||||
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "NDEL" $ \smp ->
|
||||
disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "NDEL" $ \smp ->
|
||||
disableSMPQueueNotifications smp rcvPrivateKey rcvId
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
|
||||
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
|
||||
withLogClient c server rcvId "ACK" $ \smp ->
|
||||
sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = do
|
||||
withSMPClient c rq "ACK" $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId msgId
|
||||
atomically $ releaseGetLock c rq
|
||||
|
||||
@@ -767,57 +862,60 @@ releaseGetLock c RcvQueue {server, rcvId} =
|
||||
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
|
||||
|
||||
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "OFF" $ \smp ->
|
||||
suspendQueue c rq@RcvQueue {rcvId, rcvPrivateKey} =
|
||||
withSMPClient c rq "OFF" $ \smp ->
|
||||
suspendSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "DEL" $ \smp ->
|
||||
deleteQueue c rq@RcvQueue {rcvId, rcvPrivateKey} = do
|
||||
withSMPClient c rq "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
deleteQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
deleteQueues = sendTSessionBatches "DEL" 90 $ sendBatch deleteSMPQueues
|
||||
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withSMPClient_ c sq "SEND <MSG>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
|
||||
|
||||
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
|
||||
withClient c ntfServer "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
withClient c (0, ntfServer, Nothing) "TNEW" $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
|
||||
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
withNtfClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
|
||||
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
withNtfClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
|
||||
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
|
||||
withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
withNtfClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
withNtfClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
|
||||
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
|
||||
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
withNtfClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
|
||||
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId
|
||||
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
|
||||
withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
withNtfClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
|
||||
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
|
||||
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
withNtfClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
|
||||
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
withNtfClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
|
||||
@@ -948,18 +1046,18 @@ incStat AgentClient {agentStats} n k = do
|
||||
Just v -> modifyTVar' v (+ n)
|
||||
_ -> newTVar n >>= \v -> TM.insert k v agentStats
|
||||
|
||||
incClientStat :: AgentClient -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c pc = incClientStatN c pc 1
|
||||
incClientStat :: AgentClient -> UserId -> ProtocolClient msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c userId pc = incClientStatN c userId pc 1
|
||||
|
||||
incServerStat :: AgentClient -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
incServerStat c ProtocolServer {host} cmd res = do
|
||||
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
incServerStat c userId ProtocolServer {host} cmd res = do
|
||||
threadDelay 100000
|
||||
atomically $ incStat c 1 statsKey
|
||||
where
|
||||
statsKey = AgentStatsKey {host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
|
||||
incClientStatN :: AgentClient -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c pc n cmd res = do
|
||||
incClientStatN :: AgentClient -> UserId -> ProtocolClient msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c userId pc n cmd res = do
|
||||
atomically $ incStat c n statsKey
|
||||
where
|
||||
statsKey = AgentStatsKey {host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ transportHost' pc, clientTs = strEncode $ sessionTs pc, cmd, res}
|
||||
|
||||
@@ -32,12 +32,14 @@ import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map (Map)
|
||||
import Data.Time.Clock (NominalDiffTime, nominalDay)
|
||||
import Data.Word (Word16)
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
@@ -59,7 +61,7 @@ import UnliftIO.STM
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
data InitialAgentServers = InitialAgentServers
|
||||
{ smp :: NonEmpty SMPServerWithAuth,
|
||||
{ smp :: Map UserId (NonEmpty SMPServerWithAuth),
|
||||
ntf :: [NtfServer],
|
||||
netCfg :: NetworkConfig
|
||||
}
|
||||
@@ -86,6 +88,9 @@ data AgentConfig = AgentConfig
|
||||
messageRetryInterval :: RetryInterval2,
|
||||
messageTimeout :: NominalDiffTime,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
initialCleanupDelay :: Int,
|
||||
cleanupInterval :: Int,
|
||||
deleteErrorCount :: Int,
|
||||
ntfCron :: Word16,
|
||||
ntfWorkerDelay :: Int,
|
||||
ntfSMPWorkerDelay :: Int,
|
||||
@@ -143,6 +148,9 @@ defaultAgentConfig =
|
||||
messageRetryInterval = defaultMessageRetryInterval,
|
||||
messageTimeout = 2 * nominalDay,
|
||||
helloTimeout = 2 * nominalDay,
|
||||
initialCleanupDelay = 30 * 1000000, -- 30 seconds
|
||||
cleanupInterval = 30 * 60 * 1000000, -- 30 minutes
|
||||
deleteErrorCount = 10,
|
||||
ntfCron = 20, -- minutes
|
||||
ntfWorkerDelay = 100000, -- microseconds
|
||||
ntfSMPWorkerDelay = 500000, -- microseconds
|
||||
|
||||
@@ -10,6 +10,10 @@ import UnliftIO.STM
|
||||
|
||||
type Lock = TMVar String
|
||||
|
||||
createLock :: STM Lock
|
||||
createLock = newEmptyTMVar
|
||||
{-# INLINE createLock #-}
|
||||
|
||||
withLock :: MonadUnliftIO m => TMVar String -> String -> m a -> m a
|
||||
withLock lock name =
|
||||
E.bracket_
|
||||
|
||||
@@ -272,6 +272,9 @@ data ACommand (p :: AParty) where
|
||||
SWCH :: ACommand Client
|
||||
OFF :: ACommand Client
|
||||
DEL :: ACommand Client
|
||||
DEL_RCVQ :: SMPServer -> SMP.RecipientId -> Maybe AgentErrorType -> ACommand Agent
|
||||
DEL_CONN :: ACommand Agent
|
||||
DEL_USER :: Int64 -> ACommand Agent
|
||||
CHK :: ACommand Client
|
||||
STAT :: ConnectionStats -> ACommand Agent
|
||||
OK :: ACommand Agent
|
||||
@@ -311,6 +314,9 @@ data ACommandTag (p :: AParty) where
|
||||
SWCH_ :: ACommandTag Client
|
||||
OFF_ :: ACommandTag Client
|
||||
DEL_ :: ACommandTag Client
|
||||
DEL_RCVQ_ :: ACommandTag Agent
|
||||
DEL_CONN_ :: ACommandTag Agent
|
||||
DEL_USER_ :: ACommandTag Agent
|
||||
CHK_ :: ACommandTag Client
|
||||
STAT_ :: ACommandTag Agent
|
||||
OK_ :: ACommandTag Agent
|
||||
@@ -349,6 +355,9 @@ aCommandTag = \case
|
||||
SWCH -> SWCH_
|
||||
OFF -> OFF_
|
||||
DEL -> DEL_
|
||||
DEL_RCVQ {} -> DEL_RCVQ_
|
||||
DEL_CONN -> DEL_CONN_
|
||||
DEL_USER _ -> DEL_USER_
|
||||
CHK -> CHK_
|
||||
STAT _ -> STAT_
|
||||
OK -> OK_
|
||||
@@ -1225,6 +1234,9 @@ instance StrEncoding ACmdTag where
|
||||
"SWCH" -> pure $ ACmdTag SClient SWCH_
|
||||
"OFF" -> pure $ ACmdTag SClient OFF_
|
||||
"DEL" -> pure $ ACmdTag SClient DEL_
|
||||
"DEL_RCVQ" -> pure $ ACmdTag SAgent DEL_RCVQ_
|
||||
"DEL_CONN" -> pure $ ACmdTag SAgent DEL_CONN_
|
||||
"DEL_USER" -> pure $ ACmdTag SAgent DEL_USER_
|
||||
"CHK" -> pure $ ACmdTag SClient CHK_
|
||||
"STAT" -> pure $ ACmdTag SAgent STAT_
|
||||
"OK" -> pure $ ACmdTag SAgent OK_
|
||||
@@ -1260,6 +1272,9 @@ instance APartyI p => StrEncoding (ACommandTag p) where
|
||||
SWCH_ -> "SWCH"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
DEL_RCVQ_ -> "DEL_RCVQ"
|
||||
DEL_CONN_ -> "DEL_CONN"
|
||||
DEL_USER_ -> "DEL_USER"
|
||||
CHK_ -> "CHK"
|
||||
STAT_ -> "STAT"
|
||||
OK_ -> "OK"
|
||||
@@ -1308,6 +1323,9 @@ commandP binaryP =
|
||||
SENT_ -> s (SENT <$> A.decimal)
|
||||
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
|
||||
MSG_ -> s (MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> binaryP)
|
||||
DEL_RCVQ_ -> s (DEL_RCVQ <$> strP_ <*> strP_ <*> strP)
|
||||
DEL_CONN_ -> pure DEL_CONN
|
||||
DEL_USER_ -> s (DEL_USER <$> strP)
|
||||
STAT_ -> s (STAT <$> strP)
|
||||
OK_ -> pure OK
|
||||
ERR_ -> s (ERR <$> strP)
|
||||
@@ -1356,6 +1374,9 @@ serializeCommand = \case
|
||||
SWCH -> s SWCH_
|
||||
OFF -> s OFF_
|
||||
DEL -> s DEL_
|
||||
DEL_RCVQ srv rcvId err_ -> s (DEL_RCVQ_, srv, rcvId, err_)
|
||||
DEL_CONN -> s DEL_CONN_
|
||||
DEL_USER userId -> s (DEL_USER_, userId)
|
||||
CHK -> s CHK_
|
||||
STAT srvs -> s (STAT_, srvs)
|
||||
CON -> s CON_
|
||||
|
||||
@@ -34,6 +34,7 @@ import Simplex.Messaging.Protocol
|
||||
NotifierId,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
QueueId,
|
||||
RcvDhSecret,
|
||||
RcvNtfDhSecret,
|
||||
RcvPrivateSignKey,
|
||||
@@ -47,7 +48,8 @@ import Simplex.Messaging.Version
|
||||
|
||||
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
|
||||
data RcvQueue = RcvQueue
|
||||
{ connId :: ConnId,
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | recipient queue ID
|
||||
rcvId :: SMP.RecipientId,
|
||||
@@ -72,7 +74,8 @@ data RcvQueue = RcvQueue
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version,
|
||||
-- | credentials used in context of notifications
|
||||
clientNtfCreds :: Maybe ClientNtfCreds
|
||||
clientNtfCreds :: Maybe ClientNtfCreds,
|
||||
deleteErrors :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -89,7 +92,8 @@ data ClientNtfCreds = ClientNtfCreds
|
||||
|
||||
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
|
||||
data SndQueue = SndQueue
|
||||
{ connId :: ConnId,
|
||||
{ userId :: UserId,
|
||||
connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | sender queue ID
|
||||
sndId :: SMP.SenderId,
|
||||
@@ -150,6 +154,27 @@ findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
|
||||
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
|
||||
{-# INLINE findRQ #-}
|
||||
|
||||
class SMPQueue q => SMPQueueRec q where
|
||||
qUserId :: q -> UserId
|
||||
qConnId :: q -> ConnId
|
||||
queueId :: q -> QueueId
|
||||
|
||||
instance SMPQueueRec RcvQueue where
|
||||
qUserId = userId
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId = connId
|
||||
{-# INLINE qConnId #-}
|
||||
queueId = rcvId
|
||||
{-# INLINE queueId #-}
|
||||
|
||||
instance SMPQueueRec SndQueue where
|
||||
qUserId = userId
|
||||
{-# INLINE qUserId #-}
|
||||
qConnId = connId
|
||||
{-# INLINE qConnId #-}
|
||||
queueId = sndId
|
||||
{-# INLINE queueId #-}
|
||||
|
||||
-- * Connection types
|
||||
|
||||
-- | Type of a connection.
|
||||
@@ -222,6 +247,7 @@ deriving instance Show SomeConn
|
||||
|
||||
data ConnData = ConnData
|
||||
{ connId :: ConnId,
|
||||
userId :: UserId,
|
||||
connAgentVersion :: Version,
|
||||
enableNtfs :: Bool,
|
||||
duplexHandshake :: Maybe Bool, -- added in agent protocol v2
|
||||
@@ -229,8 +255,17 @@ data ConnData = ConnData
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data PendingCommand = PendingCommand
|
||||
{ corrId :: ACorrId,
|
||||
userId :: UserId,
|
||||
connId :: ConnId,
|
||||
command :: AgentCommand
|
||||
}
|
||||
|
||||
data AgentCmdType = ACClient | ACInternal
|
||||
|
||||
type UserId = Int64
|
||||
|
||||
instance StrEncoding AgentCmdType where
|
||||
strEncode = \case
|
||||
ACClient -> "CLIENT"
|
||||
@@ -471,6 +506,8 @@ data StoreError
|
||||
SEInternal ByteString
|
||||
| -- | Failed to generate unique random ID
|
||||
SEUniqueID
|
||||
| -- | User ID not found
|
||||
SEUserNotFound
|
||||
| -- | Connection not found (or both queues absent).
|
||||
SEConnNotFound
|
||||
| -- | Connection already used.
|
||||
|
||||
@@ -27,6 +27,14 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
sqlString,
|
||||
execSQL,
|
||||
|
||||
-- * Users
|
||||
createUserRecord,
|
||||
deleteUserRecord,
|
||||
setUserDeleted,
|
||||
deleteUserWithoutConns,
|
||||
deleteUsersWithoutConns,
|
||||
checkUser,
|
||||
|
||||
-- * Queues and connections
|
||||
createNewConn,
|
||||
updateNewConnRcv,
|
||||
@@ -34,9 +42,10 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
createRcvConn,
|
||||
createSndConn,
|
||||
getConn,
|
||||
getAnyConn,
|
||||
getDeletedConn,
|
||||
getConnData,
|
||||
setConnDeleted,
|
||||
getDeletedConns,
|
||||
getRcvConn,
|
||||
deleteConn,
|
||||
upgradeRcvConnToDuplex,
|
||||
@@ -49,6 +58,7 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
setRcvQueuePrimary,
|
||||
setSndQueuePrimary,
|
||||
deleteConnRcvQueue,
|
||||
incRcvDeleteErrors,
|
||||
deleteConnSndQueue,
|
||||
getPrimaryRcvQueue,
|
||||
getRcvQueue,
|
||||
@@ -286,7 +296,7 @@ withConnection SQLiteStore {dbConnection} =
|
||||
(atomically . putTMVar dbConnection)
|
||||
|
||||
withTransaction :: forall a. SQLiteStore -> (DB.Connection -> IO a) -> IO a
|
||||
withTransaction st action = withConnection st $ loop 500 2_000_000
|
||||
withTransaction st action = withConnection st $ loop 500 3_000_000
|
||||
where
|
||||
loop :: Int -> Int -> DB.Connection -> IO a
|
||||
loop t tLim db =
|
||||
@@ -297,6 +307,60 @@ withTransaction st action = withConnection st $ loop 500 2_000_000
|
||||
loop (t * 9 `div` 8) (tLim - t) db
|
||||
else E.throwIO e
|
||||
|
||||
createUserRecord :: DB.Connection -> IO UserId
|
||||
createUserRecord db = do
|
||||
DB.execute_ db "INSERT INTO users DEFAULT VALUES"
|
||||
insertedRowId db
|
||||
|
||||
checkUser :: DB.Connection -> UserId -> IO (Either StoreError ())
|
||||
checkUser db userId =
|
||||
firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $
|
||||
DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, False)
|
||||
|
||||
deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ())
|
||||
deleteUserRecord db userId = runExceptT $ do
|
||||
ExceptT $ checkUser db userId
|
||||
liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId)
|
||||
|
||||
setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId])
|
||||
setUserDeleted db userId = runExceptT $ do
|
||||
ExceptT $ checkUser db userId
|
||||
liftIO $ do
|
||||
DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (True, userId)
|
||||
map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId)
|
||||
|
||||
deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool
|
||||
deleteUserWithoutConns db userId = do
|
||||
userId_ :: Maybe Int64 <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT user_id FROM users u
|
||||
WHERE u.user_id = ?
|
||||
AND u.deleted = ?
|
||||
AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
|
||||
|]
|
||||
(userId, True)
|
||||
case userId_ of
|
||||
Just _ -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) $> True
|
||||
_ -> pure False
|
||||
|
||||
deleteUsersWithoutConns :: DB.Connection -> IO [Int64]
|
||||
deleteUsersWithoutConns db = do
|
||||
userIds <-
|
||||
map fromOnly
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT user_id FROM users u
|
||||
WHERE u.deleted = ?
|
||||
AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id)
|
||||
|]
|
||||
(Only True)
|
||||
forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only
|
||||
pure userIds
|
||||
|
||||
createConn_ ::
|
||||
TVar ChaChaDRG ->
|
||||
ConnData ->
|
||||
@@ -307,9 +371,9 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
|
||||
ConnData {connId} -> create connId $> Right connId
|
||||
|
||||
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
|
||||
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
|
||||
updateNewConnRcv db connId rq =
|
||||
@@ -332,17 +396,17 @@ updateNewConnSnd db connId sq =
|
||||
updateConn = Right <$> addConnSndQueue_ db connId sq
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertRcvQueue_ db connId q
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
void $ insertSndQueue_ db connId q
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
@@ -455,8 +519,12 @@ setSndQueuePrimary db connId SndQueue {dbQueueId} = do
|
||||
"UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?"
|
||||
(True, Nothing :: Maybe Int64, connId, dbQueueId)
|
||||
|
||||
deleteConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
deleteConnRcvQueue db connId RcvQueue {dbQueueId} =
|
||||
incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO ()
|
||||
incRcvDeleteErrors db RcvQueue {connId, dbQueueId} =
|
||||
DB.execute db "UPDATE rcv_queues SET delete_errors = delete_errors + 1 WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
deleteConnRcvQueue :: DB.Connection -> RcvQueue -> IO ()
|
||||
deleteConnRcvQueue db RcvQueue {connId, dbQueueId} =
|
||||
DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
@@ -820,13 +888,20 @@ getPendingCommands db connId = do
|
||||
where
|
||||
srvCmdId (host, port, keyHash, cmdId) = (SMPServer <$> host <*> port <*> keyHash, cmdId)
|
||||
|
||||
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError (ACorrId, ConnId, AgentCommand))
|
||||
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError PendingCommand)
|
||||
getPendingCommand db msgId = do
|
||||
firstRow id SECmdNotFound $
|
||||
firstRow pendingCommand SECmdNotFound $
|
||||
DB.query
|
||||
db
|
||||
"SELECT corr_id, conn_id, command FROM commands WHERE command_id = ?"
|
||||
[sql|
|
||||
SELECT c.corr_id, cs.user_id, c.conn_id, c.command
|
||||
FROM commands c
|
||||
JOIN connections cs USING (conn_id)
|
||||
WHERE c.command_id = ?
|
||||
|]
|
||||
(Only msgId)
|
||||
where
|
||||
pendingCommand (corrId, userId, connId, command) = PendingCommand {corrId, userId, connId, command}
|
||||
|
||||
deleteCommand :: DB.Connection -> AsyncCmdId -> IO ()
|
||||
deleteCommand db cmdId =
|
||||
@@ -1307,10 +1382,13 @@ newQueueId_ (Only maxId : _) = maxId + 1
|
||||
-- * getConn helpers
|
||||
|
||||
getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConn db connId = getAnyConn db connId False
|
||||
getConn = getAnyConn False
|
||||
|
||||
getAnyConn :: DB.Connection -> ConnId -> Bool -> IO (Either StoreError SomeConn)
|
||||
getAnyConn dbConn connId deleted' =
|
||||
getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getDeletedConn = getAnyConn True
|
||||
|
||||
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getAnyConn deleted' dbConn connId =
|
||||
getConnData dbConn connId >>= \case
|
||||
Nothing -> pure $ Left SEConnNotFound
|
||||
Just (cData@ConnData {deleted}, cMode)
|
||||
@@ -1328,13 +1406,16 @@ getAnyConn dbConn connId deleted' =
|
||||
|
||||
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData dbConn connId' =
|
||||
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
maybeFirstRow cData $ DB.query dbConn "SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake, deleted FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
where
|
||||
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
|
||||
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake, deleted) = (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake, deleted}, cMode)
|
||||
|
||||
setConnDeleted :: DB.Connection -> ConnId -> IO ()
|
||||
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
|
||||
|
||||
getDeletedConns :: DB.Connection -> IO [ConnId]
|
||||
getDeletedConns db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True)
|
||||
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue))
|
||||
getRcvQueuesByConnId_ db connId =
|
||||
@@ -1348,31 +1429,32 @@ getRcvQueuesByConnId_ db connId =
|
||||
rcvQueueQuery :: Query
|
||||
rcvQueueQuery =
|
||||
[sql|
|
||||
SELECT s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
SELECT c.user_id, s.key_hash, q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
|
||||
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
|
||||
q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, q.delete_errors,
|
||||
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN connections c ON q.conn_id = c.conn_id
|
||||
|]
|
||||
|
||||
toRcvQueue ::
|
||||
(C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
:. (Int64, Bool, Maybe Int64, Maybe Version)
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
:. (Int64, Bool, Maybe Int64, Maybe Version, Int)
|
||||
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
|
||||
RcvQueue
|
||||
toRcvQueue ((keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
let server = SMPServer host port keyHash
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
|
||||
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
_ -> Nothing
|
||||
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
|
||||
in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion, clientNtfCreds, deleteErrors}
|
||||
|
||||
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
|
||||
getRcvQueueById_ db connId dbRcvId =
|
||||
firstRow toRcvQueue SEConnNotFound $
|
||||
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
|
||||
DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ?") (connId, dbRcvId)
|
||||
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
|
||||
@@ -1381,16 +1463,17 @@ getSndQueuesByConnId_ dbConn connId =
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
SELECT c.user_id, s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
JOIN connections c ON q.conn_id = c.conn_id
|
||||
WHERE q.conn_id = ?;
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
|
||||
sndQueue ((userId, keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, dbReplaceQueueId, smpClientVersion)) =
|
||||
let server = SMPServer host port keyHash
|
||||
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
|
||||
in SndQueue {userId, connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, smpClientVersion}
|
||||
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
|
||||
-- the current primary queue is ordered first, the next primary - second
|
||||
compare (Down p) (Down p') <> compare i i'
|
||||
|
||||
@@ -37,6 +37,9 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
|
||||
@@ -53,7 +56,10 @@ schemaMigrations =
|
||||
("m20220811_onion_hosts", m20220811_onion_hosts),
|
||||
("m20220817_connection_ntfs", m20220817_connection_ntfs),
|
||||
("m20220905_commands", m20220905_commands),
|
||||
("m20220915_connection_queues", m20220915_connection_queues)
|
||||
("m20220915_connection_queues", m20220915_connection_queues),
|
||||
("m20230110_users", m20230110_users),
|
||||
("m20230117_fkey_indexes", m20230117_fkey_indexes),
|
||||
("m20230120_delete_errors", m20230120_delete_errors)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20230110_users :: Query
|
||||
m20230110_users =
|
||||
[sql|
|
||||
PRAGMA ignore_check_constraints=ON;
|
||||
|
||||
CREATE TABLE users (
|
||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT
|
||||
);
|
||||
|
||||
INSERT INTO users (user_id) VALUES (1);
|
||||
|
||||
ALTER TABLE connections ADD COLUMN user_id INTEGER CHECK (user_id NOT NULL)
|
||||
REFERENCES users ON DELETE CASCADE;
|
||||
|
||||
CREATE INDEX idx_connections_user ON connections(user_id);
|
||||
|
||||
CREATE INDEX idx_commands_conn_id ON commands(conn_id);
|
||||
|
||||
UPDATE connections SET user_id = 1;
|
||||
|
||||
PRAGMA ignore_check_constraints=OFF;
|
||||
|]
|
||||
@@ -0,0 +1,27 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
-- .lint fkey-indexes
|
||||
m20230117_fkey_indexes :: Query
|
||||
m20230117_fkey_indexes =
|
||||
[sql|
|
||||
CREATE INDEX idx_commands_host_port ON commands(host, port);
|
||||
CREATE INDEX idx_conn_confirmations_conn_id ON conn_confirmations(conn_id);
|
||||
CREATE INDEX idx_conn_invitations_contact_conn_id ON conn_invitations(contact_conn_id);
|
||||
CREATE INDEX idx_messages_conn_id_internal_snd_id ON messages(conn_id, internal_snd_id);
|
||||
CREATE INDEX idx_messages_conn_id_internal_rcv_id ON messages(conn_id, internal_rcv_id);
|
||||
CREATE INDEX idx_messages_conn_id ON messages(conn_id);
|
||||
CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON ntf_subscriptions(ntf_host, ntf_port);
|
||||
CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON ntf_subscriptions(smp_host, smp_port);
|
||||
CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON ntf_tokens(ntf_host, ntf_port);
|
||||
CREATE INDEX idx_ratchets_conn_id ON ratchets(conn_id);
|
||||
CREATE INDEX idx_rcv_messages_conn_id_internal_id ON rcv_messages(conn_id, internal_id);
|
||||
CREATE INDEX idx_skipped_messages_conn_id ON skipped_messages(conn_id);
|
||||
CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON snd_message_deliveries(conn_id, internal_id);
|
||||
CREATE INDEX idx_snd_messages_conn_id_internal_id ON snd_messages(conn_id, internal_id);
|
||||
CREATE INDEX idx_snd_queues_host_port ON snd_queues(host, port);
|
||||
|]
|
||||
@@ -0,0 +1,20 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20230120_delete_errors :: Query
|
||||
m20230120_delete_errors =
|
||||
[sql|
|
||||
PRAGMA ignore_check_constraints=ON;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN delete_errors INTEGER DEFAULT 0 CHECK (delete_errors NOT NULL);
|
||||
UPDATE rcv_queues SET delete_errors = 0;
|
||||
|
||||
ALTER TABLE users ADD COLUMN deleted INTEGER DEFAULT 0 CHECK (deleted NOT NULL);
|
||||
UPDATE users SET deleted = 0;
|
||||
|
||||
PRAGMA ignore_check_constraints=OFF;
|
||||
|]
|
||||
@@ -22,7 +22,9 @@ CREATE TABLE connections(
|
||||
,
|
||||
duplex_handshake INTEGER NULL DEFAULT 0,
|
||||
enable_ntfs INTEGER,
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
|
||||
user_id INTEGER CHECK(user_id NOT NULL)
|
||||
REFERENCES users ON DELETE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE rcv_queues(
|
||||
host TEXT NOT NULL,
|
||||
@@ -45,6 +47,7 @@ CREATE TABLE rcv_queues(
|
||||
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
|
||||
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
|
||||
replace_rcv_queue_id INTEGER NULL,
|
||||
delete_errors INTEGER DEFAULT 0 CHECK(delete_errors NOT NULL),
|
||||
PRIMARY KEY(host, port, rcv_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
@@ -228,3 +231,51 @@ CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
|
||||
conn_id,
|
||||
snd_queue_id
|
||||
);
|
||||
CREATE TABLE users(
|
||||
user_id INTEGER PRIMARY KEY AUTOINCREMENT
|
||||
,
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL)
|
||||
);
|
||||
CREATE INDEX idx_connections_user ON connections(user_id);
|
||||
CREATE INDEX idx_commands_conn_id ON commands(conn_id);
|
||||
CREATE INDEX idx_commands_host_port ON commands(host, port);
|
||||
CREATE INDEX idx_conn_confirmations_conn_id ON conn_confirmations(conn_id);
|
||||
CREATE INDEX idx_conn_invitations_contact_conn_id ON conn_invitations(
|
||||
contact_conn_id
|
||||
);
|
||||
CREATE INDEX idx_messages_conn_id_internal_snd_id ON messages(
|
||||
conn_id,
|
||||
internal_snd_id
|
||||
);
|
||||
CREATE INDEX idx_messages_conn_id_internal_rcv_id ON messages(
|
||||
conn_id,
|
||||
internal_rcv_id
|
||||
);
|
||||
CREATE INDEX idx_messages_conn_id ON messages(conn_id);
|
||||
CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON ntf_subscriptions(
|
||||
ntf_host,
|
||||
ntf_port
|
||||
);
|
||||
CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON ntf_subscriptions(
|
||||
smp_host,
|
||||
smp_port
|
||||
);
|
||||
CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON ntf_tokens(
|
||||
ntf_host,
|
||||
ntf_port
|
||||
);
|
||||
CREATE INDEX idx_ratchets_conn_id ON ratchets(conn_id);
|
||||
CREATE INDEX idx_rcv_messages_conn_id_internal_id ON rcv_messages(
|
||||
conn_id,
|
||||
internal_id
|
||||
);
|
||||
CREATE INDEX idx_skipped_messages_conn_id ON skipped_messages(conn_id);
|
||||
CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON snd_message_deliveries(
|
||||
conn_id,
|
||||
internal_id
|
||||
);
|
||||
CREATE INDEX idx_snd_messages_conn_id_internal_id ON snd_messages(
|
||||
conn_id,
|
||||
internal_id
|
||||
);
|
||||
CREATE INDEX idx_snd_queues_host_port ON snd_queues(host, port);
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
module Simplex.Messaging.Agent.TAsyncs where
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import UnliftIO.Async (Async, async)
|
||||
import UnliftIO.STM
|
||||
|
||||
data TAsyncs = TAsyncs
|
||||
{ actionId :: TVar Int,
|
||||
actions :: TMap Int (Async ())
|
||||
}
|
||||
|
||||
newTAsyncs :: STM TAsyncs
|
||||
newTAsyncs = TAsyncs <$> newTVar 0 <*> TM.empty
|
||||
|
||||
newAsyncAction :: MonadUnliftIO m => (Int -> m ()) -> TAsyncs -> m ()
|
||||
newAsyncAction action as = do
|
||||
aId <- atomically $ stateTVar (actionId as) $ \i -> (i + 1, i + 1)
|
||||
a <- async $ action aId
|
||||
atomically $ TM.insert aId a $ actions as
|
||||
|
||||
removeAsyncAction :: Int -> TAsyncs -> STM ()
|
||||
removeAsyncAction aId = TM.delete aId . actions
|
||||
@@ -1,18 +1,28 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Agent.TRcvQueues where
|
||||
module Simplex.Messaging.Agent.TRcvQueues
|
||||
( TRcvQueues,
|
||||
empty,
|
||||
clear,
|
||||
deleteConn,
|
||||
hasConn,
|
||||
getConns,
|
||||
addQueue,
|
||||
deleteQueue,
|
||||
getSessQueues,
|
||||
getDelSessQueues,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId)
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue (..))
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue (..), UserId)
|
||||
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
|
||||
newtype TRcvQueues = TRcvQueues (TMap (SMPServer, RecipientId) RcvQueue)
|
||||
newtype TRcvQueues = TRcvQueues (TMap (UserId, SMPServer, RecipientId) RcvQueue)
|
||||
|
||||
empty :: STM TRcvQueues
|
||||
empty = TRcvQueues <$> TM.empty
|
||||
@@ -30,19 +40,26 @@ getConns :: TRcvQueues -> STM (Set ConnId)
|
||||
getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs
|
||||
|
||||
addQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
addQueue rq@RcvQueue {server, rcvId} (TRcvQueues qs) = TM.insert (server, rcvId) rq qs
|
||||
addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs
|
||||
|
||||
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
deleteQueue RcvQueue {server, rcvId} (TRcvQueues qs) = TM.delete (server, rcvId) qs
|
||||
deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs
|
||||
|
||||
getSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
|
||||
getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
|
||||
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
|
||||
getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
|
||||
where
|
||||
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
|
||||
addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs'
|
||||
|
||||
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
|
||||
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
|
||||
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
|
||||
getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
|
||||
where
|
||||
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId}
|
||||
| srv == server = ((rq : remQs, S.insert connId remConns), qs')
|
||||
| otherwise = (removed, M.insert (server, rcvId) rq qs')
|
||||
addQ (removed, qs') rq
|
||||
| rq `isSession` tSess = (rq : removed, qs')
|
||||
| otherwise = (removed, M.insert (qKey rq) rq qs')
|
||||
|
||||
isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool
|
||||
isSession rq (uId, srv, connId_) =
|
||||
userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_
|
||||
|
||||
qKey :: RcvQueue -> (UserId, SMPServer, ConnId)
|
||||
qKey rq = (userId rq, server rq, connId rq)
|
||||
|
||||
@@ -26,12 +26,14 @@
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Client
|
||||
( -- * Connect (disconnect) client to (from) SMP server
|
||||
TransportSession,
|
||||
ProtocolClient (thVersion, sessionId, sessionTs),
|
||||
SMPClient,
|
||||
getProtocolClient,
|
||||
closeProtocolClient,
|
||||
clientServer,
|
||||
transportHost',
|
||||
transportSession',
|
||||
|
||||
-- * SMP protocol command functions
|
||||
createSMPQueue,
|
||||
@@ -48,12 +50,14 @@ module Simplex.Messaging.Client
|
||||
ackSMPMessage,
|
||||
suspendSMPQueue,
|
||||
deleteSMPQueue,
|
||||
deleteSMPQueues,
|
||||
sendProtocolCommand,
|
||||
|
||||
-- * Supporting types and client configuration
|
||||
ProtocolClientError (..),
|
||||
ProtocolClientConfig (..),
|
||||
NetworkConfig (..),
|
||||
TransportSessionMode (..),
|
||||
defaultClientConfig,
|
||||
defaultNetworkConfig,
|
||||
transportClientConfig,
|
||||
@@ -75,6 +79,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (rights)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (find)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
@@ -110,9 +115,10 @@ data ProtocolClient msg = ProtocolClient
|
||||
|
||||
data PClient msg = PClient
|
||||
{ connected :: TVar Bool,
|
||||
protocolServer :: ProtoServer msg,
|
||||
transportSession :: TransportSession msg,
|
||||
transportHost :: TransportHost,
|
||||
tcpTimeout :: Int,
|
||||
pingErrorCount :: TVar Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
sndQ :: TBQueue (NonEmpty SentRawTransmission),
|
||||
@@ -126,7 +132,7 @@ type SMPClient = ProtocolClient SMP.BrokerMsg
|
||||
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg)
|
||||
type ServerTransmission msg = (TransportSession msg, Version, SessionId, QueueId, msg)
|
||||
|
||||
data HostMode
|
||||
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
|
||||
@@ -152,14 +158,18 @@ data NetworkConfig = NetworkConfig
|
||||
hostMode :: HostMode,
|
||||
-- | if above criteria is not met, if the below setting is True return error, otherwise use the first host
|
||||
requiredHostMode :: Bool,
|
||||
-- | transport sessions are created per user or per entity
|
||||
sessionMode :: TransportSessionMode,
|
||||
-- | timeout for the initial client TCP/TLS connection (microseconds)
|
||||
tcpConnectTimeout :: Int,
|
||||
-- | timeout of protocol commands (microseconds)
|
||||
tcpTimeout :: Int,
|
||||
-- | TCP keep-alive options, Nothing to skip enabling keep-alive
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
-- | period for SMP ping commands (microseconds)
|
||||
-- | period for SMP ping commands (microseconds, 0 to disable)
|
||||
smpPingInterval :: Int,
|
||||
-- | the count of PING errors after which SMP client terminates (and will be reconnected), 0 to disable
|
||||
smpPingCount :: Int,
|
||||
logTLSErrors :: Bool
|
||||
}
|
||||
deriving (Eq, Show, Generic, FromJSON)
|
||||
@@ -168,16 +178,28 @@ instance ToJSON NetworkConfig where
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
data TransportSessionMode = TSMUser | TSMEntity
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON TransportSessionMode where
|
||||
toJSON = J.genericToJSON . enumJSON $ dropPrefix "TSM"
|
||||
toEncoding = J.genericToEncoding . enumJSON $ dropPrefix "TSM"
|
||||
|
||||
instance FromJSON TransportSessionMode where
|
||||
parseJSON = J.genericParseJSON . enumJSON $ dropPrefix "TSM"
|
||||
|
||||
defaultNetworkConfig :: NetworkConfig
|
||||
defaultNetworkConfig =
|
||||
NetworkConfig
|
||||
{ socksProxy = Nothing,
|
||||
hostMode = HMOnionViaSocks,
|
||||
requiredHostMode = False,
|
||||
sessionMode = TSMUser,
|
||||
tcpConnectTimeout = 7_500_000,
|
||||
tcpTimeout = 5_000_000,
|
||||
tcpKeepAlive = Just defaultKeepAliveOpts,
|
||||
smpPingInterval = 600_000_000, -- 10min
|
||||
smpPingCount = 3,
|
||||
logTLSErrors = False
|
||||
}
|
||||
|
||||
@@ -229,19 +251,29 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
|
||||
publicHost = find (not . isOnionHost) hosts
|
||||
|
||||
clientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient msg -> String
|
||||
clientServer = B.unpack . strEncode . protocolServer . client_
|
||||
clientServer = B.unpack . strEncode . snd3 . transportSession . client_
|
||||
where
|
||||
snd3 (_, s, _) = s
|
||||
|
||||
transportHost' :: ProtocolClient msg -> TransportHost
|
||||
transportHost' = transportHost . client_
|
||||
|
||||
transportSession' :: ProtocolClient msg -> TransportSession msg
|
||||
transportSession' = transportSession . client_
|
||||
|
||||
type UserId = Int64
|
||||
|
||||
-- | Transport session key - includes entity ID if `sessionMode = TSMEntity`.
|
||||
type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
|
||||
|
||||
-- | Connects to 'ProtocolServer' using passed client configuration
|
||||
-- and queue for messages and notifications.
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host protocolServer) of
|
||||
getProtocolClient :: forall msg. Protocol msg => TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host srv) of
|
||||
Right useHost ->
|
||||
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
|
||||
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
|
||||
@@ -251,6 +283,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
mkProtocolClient :: TransportHost -> STM (PClient msg)
|
||||
mkProtocolClient transportHost = do
|
||||
connected <- newTVar False
|
||||
pingErrorCount <- newTVar 0
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- TM.empty
|
||||
sndQ <- newTBQueue qSize
|
||||
@@ -258,9 +291,10 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
return
|
||||
PClient
|
||||
{ connected,
|
||||
protocolServer,
|
||||
transportSession,
|
||||
transportHost,
|
||||
tcpTimeout,
|
||||
pingErrorCount,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
sndQ,
|
||||
@@ -272,9 +306,10 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
runClient (port', ATransport t) useHost c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
let tcConfig = transportClientConfig networkConfig
|
||||
username = proxyUsername transportSession
|
||||
action <-
|
||||
async $
|
||||
runTransportClient tcConfig useHost port' (Just $ keyHash protocolServer) (client t c cVar)
|
||||
runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar)
|
||||
`finally` atomically (putTMVar cVar $ Left PCENetworkError)
|
||||
c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case c_ of
|
||||
@@ -282,15 +317,18 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left PCENetworkError
|
||||
|
||||
proxyUsername :: TransportSession msg -> ByteString
|
||||
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
|
||||
|
||||
useTransport :: (ServiceName, ATransport)
|
||||
useTransport = case port protocolServer of
|
||||
useTransport = case port srv of
|
||||
"" -> defaultTransport cfg
|
||||
"80" -> ("80", transport @WS)
|
||||
p -> (p, transport @TLS)
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> PClient msg -> TMVar (Either ProtocolClientError (ProtocolClient msg)) -> c -> IO ()
|
||||
client _ c cVar h =
|
||||
runExceptT (protocolClientHandshake @msg h (keyHash protocolServer) smpServerVRange) >>= \case
|
||||
runExceptT (protocolClientHandshake @msg h (keyHash srv) smpServerVRange) >>= \case
|
||||
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
|
||||
Right th@THandle {sessionId, thVersion} -> do
|
||||
sessionTs <- getCurrentTime
|
||||
@@ -308,9 +346,15 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
||||
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: ProtocolClient msg -> IO ()
|
||||
ping c = forever $ do
|
||||
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
|
||||
threadDelay smpPingInterval
|
||||
runExceptT $ sendProtocolCommand c Nothing "" protocolPing
|
||||
runExceptT (sendProtocolCommand c Nothing "" protocolPing) >>= \case
|
||||
Left PCEResponseTimeout -> do
|
||||
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
|
||||
when (maxCnt == 0 || cnt < maxCnt) $ ping c
|
||||
_ -> ping c -- sendProtocolCommand resets pingErrorCount
|
||||
where
|
||||
maxCnt = smpPingCount networkConfig
|
||||
|
||||
process :: ProtocolClient msg -> IO ()
|
||||
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
|
||||
@@ -412,8 +456,8 @@ writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
|
||||
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
|
||||
|
||||
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {protocolServer}} entityId message =
|
||||
(protocolServer, thVersion, sessionId, entityId, message)
|
||||
serverTransmission ProtocolClient {thVersion, sessionId, client_ = PClient {transportSession}} entityId message =
|
||||
(transportSession, thVersion, sessionId, entityId, message)
|
||||
|
||||
-- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue
|
||||
--
|
||||
@@ -464,13 +508,7 @@ disableSMPQueueNotifications = okSMPCommand NDEL
|
||||
|
||||
-- | Disable notifications for multiple queues for push notifications server.
|
||||
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
disableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient NDEL)) qs
|
||||
response = \case
|
||||
Right OK -> Right ()
|
||||
Right r -> Left . PCEUnexpectedResponse $ bshow r
|
||||
Left e -> Left e
|
||||
disableSMPQueuesNtfs = okSMPCommands NDEL
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
@@ -501,41 +539,57 @@ suspendSMPQueue = okSMPCommand OFF
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
|
||||
-- | Delete multiple SMP queues batching commands if supported.
|
||||
deleteSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
deleteSMPQueues = okSMPCommands DEL
|
||||
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
OK -> return ()
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
okSMPCommands :: PartyI p => Command p -> SMPClient -> NonEmpty (C.APrivateSignKey, QueueId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
okSMPCommands cmd c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
aCmd = Cmd sParty cmd
|
||||
cs = L.map (\(pKey, qId) -> (Just pKey, qId, aCmd)) qs
|
||||
response = \case
|
||||
Right OK -> Right ()
|
||||
Right r -> Left . PCEUnexpectedResponse $ bshow r
|
||||
Left e -> Left e
|
||||
|
||||
-- | Send SMP command
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
|
||||
-- | Send multiple commands with batching and collect responses
|
||||
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
|
||||
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} cs = do
|
||||
sendProtocolCommands c@ProtocolClient {client_ = PClient {sndQ}} cs = do
|
||||
ts <- mapM (runExceptT . mkTransmission c) cs
|
||||
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
|
||||
forConcurrently ts $ \case
|
||||
Right (_, r) -> withTimeout . atomically $ takeTMVar r
|
||||
Right (_, r) -> withTimeout c $ atomically $ takeTMVar r
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
|
||||
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ, tcpTimeout}} pKey qId cmd = do
|
||||
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}} pKey qId cmd = do
|
||||
(t, r) <- mkTransmission c (pKey, qId, cmd)
|
||||
ExceptT $ sendRecv t r
|
||||
where
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
|
||||
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout (atomically $ takeTMVar r)
|
||||
where
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout c (atomically $ takeTMVar r)
|
||||
|
||||
withTimeout :: ProtocolClient msg -> IO (Either ProtocolClientError msg) -> IO (Either ProtocolClientError msg)
|
||||
withTimeout ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} a =
|
||||
timeout tcpTimeout a >>= \case
|
||||
Just r -> atomically (writeTVar pingErrorCount 0) >> pure r
|
||||
_ -> pure $ Left PCEResponseTimeout
|
||||
|
||||
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
|
||||
mkTransmission ProtocolClient {sessionId, thVersion, client_ = PClient {clientCorrId, sentCommands}} (pKey, qId, cmd) = do
|
||||
|
||||
@@ -160,7 +160,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
connectClient = ExceptT $ getProtocolClient (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: SMPClient -> IO ()
|
||||
clientDisconnected _ = do
|
||||
|
||||
@@ -190,7 +190,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
|
||||
receiveSMP :: M ()
|
||||
receiveSMP = forever $ do
|
||||
(srv, _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
|
||||
((_, srv, _), _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
|
||||
let smpQueue = SMPQueueNtf srv ntfId
|
||||
case msg of
|
||||
SMP.NMSG nmsgNonce encNMsgMeta -> do
|
||||
@@ -372,6 +372,7 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
t_ <- atomically $ getActiveNtfToken st subTknId
|
||||
verifyToken' t_ $ verifiedSubCmd s c
|
||||
else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
NtfCmd SSubscription PING -> pure $ VRVerified $ NtfReqPing corrId entId
|
||||
NtfCmd SSubscription c -> do
|
||||
s_ <- atomically $ getNtfSubscription st entId
|
||||
case s_ of
|
||||
@@ -529,6 +530,7 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
|
||||
incNtfStat subDeleted
|
||||
pure NROk
|
||||
PING -> pure NRPong
|
||||
NtfReqPing corrId entId -> pure (corrId, entId, NRPong)
|
||||
getId :: M NtfEntityId
|
||||
getId = getRandomBytes =<< asks (subIdBytes . config)
|
||||
getRegCode :: M NtfRegCode
|
||||
|
||||
@@ -148,6 +148,7 @@ getPushClient s@NtfPushServer {pushClients} pp =
|
||||
data NtfRequest
|
||||
= NtfReqNew CorrId ANewNtfEntity
|
||||
| forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e))
|
||||
| NtfReqPing CorrId NtfEntityId
|
||||
|
||||
data NtfServerClient = NtfServerClient
|
||||
{ rcvQ :: TBQueue NtfRequest,
|
||||
|
||||
@@ -77,6 +77,7 @@ module Simplex.Messaging.Protocol
|
||||
BasicAuth (..),
|
||||
SrvLoc (..),
|
||||
CorrId (..),
|
||||
EntityId,
|
||||
QueueId,
|
||||
RecipientId,
|
||||
SenderId,
|
||||
@@ -753,7 +754,7 @@ basicAuth s
|
||||
where
|
||||
valid c = isPrint c && not (isSpace c) && c /= '@' && c /= ':' && c /= '/'
|
||||
|
||||
data ProtoServerWithAuth p = ProtoServerWithAuth (ProtocolServer p) (Maybe BasicAuth)
|
||||
data ProtoServerWithAuth p = ProtoServerWithAuth {protoServer :: ProtocolServer p, serverBasicAuth :: Maybe BasicAuth}
|
||||
deriving (Show)
|
||||
|
||||
instance ProtocolTypeI p => IsString (ProtoServerWithAuth p) where
|
||||
|
||||
@@ -47,6 +47,7 @@ data ServerStatsData = ServerStatsData
|
||||
_qCount :: Int,
|
||||
_msgCount :: Int
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newServerStats :: UTCTime -> STM ServerStats
|
||||
newServerStats ts = do
|
||||
@@ -88,7 +89,7 @@ setServerStats s d = do
|
||||
writeTVar (qDeleted s) $! _qDeleted d
|
||||
writeTVar (msgSent s) $! _msgSent d
|
||||
writeTVar (msgRecv s) $! _msgRecv d
|
||||
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
|
||||
setPeriodStats (activeQueues s) (_activeQueues d)
|
||||
writeTVar (msgSentNtf s) $! _msgSentNtf d
|
||||
writeTVar (msgRecvNtf s) $! _msgRecvNtf d
|
||||
setPeriodStats (activeQueuesNtf s) (_activeQueuesNtf d)
|
||||
@@ -152,6 +153,7 @@ data PeriodStatsData a = PeriodStatsData
|
||||
_week :: Set a,
|
||||
_month :: Set a
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newPeriodStatsData :: PeriodStatsData a
|
||||
newPeriodStatsData = PeriodStatsData {_day = S.empty, _week = S.empty, _month = S.empty}
|
||||
|
||||
@@ -107,15 +107,15 @@ defaultTransportClientConfig :: TransportClientConfig
|
||||
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
|
||||
|
||||
-- | Connect to passed TCP host:port and pass handle to the client.
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} host port keyHash client = do
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
|
||||
let hostName = B.unpack $ strEncode host
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy $ hostAddr host
|
||||
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
|
||||
_ -> connectTCPClient hostName
|
||||
c <- liftIO $ do
|
||||
sock <- connectTCP port
|
||||
@@ -153,10 +153,12 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
|
||||
defaultSMPPort :: PortNumber
|
||||
defaultSMPPort = 5223
|
||||
|
||||
connectSocksClient :: SocksProxy -> SocksHostAddress -> ServiceName -> IO Socket
|
||||
connectSocksClient (SocksProxy addr) hostAddr _port = do
|
||||
connectSocksClient :: SocksProxy -> Maybe ByteString -> SocksHostAddress -> ServiceName -> IO Socket
|
||||
connectSocksClient (SocksProxy addr) proxyUsername hostAddr _port = do
|
||||
let port = if null _port then defaultSMPPort else fromMaybe defaultSMPPort $ readMaybe _port
|
||||
fst <$> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
|
||||
fst <$> case proxyUsername of
|
||||
Just username -> socksConnectAuth (defaultSocksConf addr) (SocksAddress hostAddr port) (SocksCredentials username "")
|
||||
_ -> socksConnect (defaultSocksConf addr) (SocksAddress hostAddr port)
|
||||
|
||||
defaultSocksHost :: HostAddress
|
||||
defaultSocksHost = tupleToHostAddress (127, 0, 0, 1)
|
||||
|
||||
@@ -120,7 +120,7 @@ sendRequest HTTP2Client {reqQ, config} req = do
|
||||
|
||||
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> HostName -> ServiceName -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
|
||||
runHTTP2Client tlsParams caStore tcConfig host port client =
|
||||
runTLSTransportClient tlsParams caStore tcConfig (THDomainName host) port Nothing $ \c ->
|
||||
runTLSTransportClient tlsParams caStore tcConfig Nothing (THDomainName host) port Nothing $ \c ->
|
||||
withTlsConfig c 16384 (`run` client)
|
||||
where
|
||||
run = H.run $ ClientConfig "https" (B.pack host) 20
|
||||
|
||||
+4
-4
@@ -338,8 +338,8 @@ testServerConnectionAfterError t _ = do
|
||||
withServer $ do
|
||||
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
|
||||
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
|
||||
bob <# ("", "", UP server ["alice"])
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
|
||||
bob #: ("2", "alice", "ACK 4") #> ("2", "alice", OK)
|
||||
alice #: ("1", "bob", "SEND F 11\nhello again") #> ("1", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
@@ -381,8 +381,8 @@ testMsgDeliveryAgentRestart t bob = do
|
||||
(corrId == "3" && cmd == OK)
|
||||
|| (corrId == "" && cmd == SENT 5)
|
||||
_ -> False
|
||||
bob <# ("", "", UP server ["alice"])
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
|
||||
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
|
||||
module AgentTests.FunctionalAPITests
|
||||
@@ -26,10 +27,11 @@ where
|
||||
|
||||
import Control.Concurrent (killThread, threadDelay)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT)
|
||||
import Control.Monad.Except (ExceptT, runExceptT, throwError)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (isRight)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (isNothing)
|
||||
@@ -41,7 +43,8 @@ import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client (SMPTestFailure (..), SMPTestStep (..))
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig)
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
@@ -137,6 +140,17 @@ functionalAPITests t = do
|
||||
testAsyncCommandsRestore t
|
||||
it "should accept connection using async command" $
|
||||
withSmpServer t testAcceptContactAsync
|
||||
it "should delete connections using async command when server connection fails" $
|
||||
testDeleteConnectionAsync t
|
||||
describe "Users" $ do
|
||||
it "should create and delete user with connections" $
|
||||
withSmpServer t testUsers
|
||||
it "should create and delete user without connections" $
|
||||
withSmpServer t testDeleteUserQuietly
|
||||
it "should create and delete user with connections when server connection fails" $
|
||||
testUsersNoServer t
|
||||
it "should connect two users and switch session mode" $
|
||||
withSmpServer t testTwoUsers
|
||||
describe "Queue rotation" $ do
|
||||
describe "should switch delivery to the new queue" $
|
||||
testServerMatrix2 t testSwitchConnection
|
||||
@@ -229,8 +243,8 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
|
||||
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientTest alice bob baseId = do
|
||||
runRight_ $ do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -264,8 +278,8 @@ runAgentClientTest alice bob baseId = do
|
||||
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTest alice bob baseId = do
|
||||
runRight_ $ do
|
||||
(_, qInfo) <- createConnection alice True SCMContact Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, REQ invId _ "bob's connInfo") <- get alice
|
||||
bobId <- acceptContact alice True invId "alice's connInfo"
|
||||
("", _, CONF confId _ "alice's connInfo") <- get bob
|
||||
@@ -303,7 +317,7 @@ noMessages c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
Just _ -> error err
|
||||
Just msg -> error $ err <> ": " <> show msg
|
||||
_ -> return ()
|
||||
|
||||
testAsyncInitiatingOffline :: IO ()
|
||||
@@ -311,9 +325,9 @@ testAsyncInitiatingOffline = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice'
|
||||
@@ -328,8 +342,8 @@ testAsyncJoiningOfflineBeforeActivation = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
@@ -345,9 +359,9 @@ testAsyncBothOffline = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(bobId, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
@@ -366,9 +380,9 @@ testAsyncServerOffline t = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
-- create connection and shutdown the server
|
||||
(bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
|
||||
runRight $ createConnection alice True SCMInvitation Nothing
|
||||
runRight $ createConnection alice 1 True SCMInvitation Nothing
|
||||
-- connection fails
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob True cReq "bob's connInfo"
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo"
|
||||
("", "", DOWN srv conns) <- get alice
|
||||
srv `shouldBe` testSMPServer
|
||||
conns `shouldBe` [bobId]
|
||||
@@ -378,7 +392,7 @@ testAsyncServerOffline t = do
|
||||
liftIO $ do
|
||||
srv1 `shouldBe` testSMPServer
|
||||
conns1 `shouldBe` [bobId]
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -392,9 +406,9 @@ testAsyncHelloTimeout = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2, helloTimeout = 1} initAgentServers
|
||||
runRight_ $ do
|
||||
(_, cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(_, cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob True cReq "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
|
||||
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
|
||||
|
||||
testDuplicateMessage :: ATransport -> IO ()
|
||||
@@ -446,9 +460,12 @@ testDuplicateMessage t = do
|
||||
get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False
|
||||
|
||||
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
|
||||
makeConnection alice bob = do
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
makeConnection alice bob = makeConnectionForUsers alice 1 bob 1
|
||||
|
||||
makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId)
|
||||
makeConnectionForUsers alice aliceUserId bob bobUserId = do
|
||||
(bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
@@ -462,7 +479,7 @@ testInactiveClientDisconnected t = do
|
||||
withSmpServerConfigOn t cfg' testPort $ \_ -> do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
runRight_ $ do
|
||||
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
get alice ##> ("", "", DOWN testSMPServer [connId])
|
||||
|
||||
testActiveClientNotDisconnected :: ATransport -> IO ()
|
||||
@@ -472,7 +489,7 @@ testActiveClientNotDisconnected t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
ts <- getSystemTime
|
||||
runRight_ $ do
|
||||
(connId, _cReq) <- createConnection alice True SCMInvitation Nothing
|
||||
(connId, _cReq) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
keepSubscribing alice connId ts
|
||||
where
|
||||
keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO ()
|
||||
@@ -536,10 +553,8 @@ testSuspendingAgentCompleteSending t = do
|
||||
get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False
|
||||
("", "", SUSPENDED) <- get b
|
||||
|
||||
r <- get a
|
||||
liftIO $ print r
|
||||
("", "", UP {}) <- pure r
|
||||
get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False
|
||||
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
|
||||
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
|
||||
ackMessage a bId 5
|
||||
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
|
||||
ackMessage a bId 6
|
||||
@@ -572,9 +587,9 @@ testBatchedSubscriptions t = do
|
||||
conns <- runServers $ do
|
||||
conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b
|
||||
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
|
||||
forM_ (take 10 conns) $ \(aId, bId) -> do
|
||||
deleteConnection a bId
|
||||
deleteConnection b aId
|
||||
let (aIds', bIds') = unzip $ take 10 conns
|
||||
delete a bIds'
|
||||
delete b aIds'
|
||||
liftIO $ threadDelay 1000000
|
||||
pure conns
|
||||
("", "", DOWN {}) <- get a
|
||||
@@ -587,18 +602,37 @@ testBatchedSubscriptions t = do
|
||||
("", "", UP {}) <- get b
|
||||
("", "", UP {}) <- get b
|
||||
liftIO $ threadDelay 1000000
|
||||
subscribe a $ map snd conns
|
||||
subscribe b $ map fst conns
|
||||
forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
|
||||
let (aIds, bIds) = unzip conns
|
||||
conns' = drop 10 conns
|
||||
(aIds', bIds') = unzip conns'
|
||||
subscribe a bIds
|
||||
subscribe b aIds
|
||||
forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
|
||||
delete a bIds'
|
||||
delete b aIds'
|
||||
deleteFail a bIds'
|
||||
deleteFail b aIds'
|
||||
where
|
||||
subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
|
||||
subscribe c cs = do
|
||||
r <- subscribeConnections c cs
|
||||
liftIO $ do
|
||||
let dc = S.fromList $ take 10 cs
|
||||
all (== Right ()) (M.withoutKeys r dc) `shouldBe` True
|
||||
all isRight (M.withoutKeys r dc) `shouldBe` True
|
||||
all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True
|
||||
M.keys r `shouldMatchList` cs
|
||||
delete :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
|
||||
delete c cs = do
|
||||
r <- deleteConnections c cs
|
||||
liftIO $ do
|
||||
all isRight r `shouldBe` True
|
||||
M.keys r `shouldMatchList` cs
|
||||
deleteFail :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
|
||||
deleteFail c cs = do
|
||||
r <- deleteConnections c cs
|
||||
liftIO $ do
|
||||
all (== Left (CONN NOT_FOUND)) r `shouldBe` True
|
||||
M.keys r `shouldMatchList` cs
|
||||
runServers :: ExceptT AgentErrorType IO a -> IO a
|
||||
runServers a = do
|
||||
withSmpServerStoreLogOn t testPort $ \t1 -> do
|
||||
@@ -612,10 +646,10 @@ testAsyncCommands = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
bobId <- createConnectionAsync alice "1" True SCMInvitation
|
||||
bobId <- createConnectionAsync alice 1 "1" True SCMInvitation
|
||||
("1", bobId', INV (ACR _ qInfo)) <- get alice
|
||||
liftIO $ bobId' `shouldBe` bobId
|
||||
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
|
||||
aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo"
|
||||
("2", aliceId', OK) <- get bob
|
||||
liftIO $ aliceId' `shouldBe` aliceId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
@@ -645,8 +679,9 @@ testAsyncCommands = do
|
||||
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
|
||||
ackMessageAsync alice "7" bobId $ baseId + 4
|
||||
("7", _, OK) <- get alice
|
||||
deleteConnectionAsync alice "8" bobId
|
||||
("8", _, OK) <- get alice
|
||||
deleteConnectionAsync alice bobId
|
||||
get alice =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bobId; _ -> False
|
||||
get alice =##> \case ("", c, DEL_CONN) -> c == bobId; _ -> False
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
where
|
||||
baseId = 3
|
||||
@@ -655,7 +690,7 @@ testAsyncCommands = do
|
||||
testAsyncCommandsRestore :: ATransport -> IO ()
|
||||
testAsyncCommandsRestore t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bobId <- runRight $ createConnectionAsync alice "1" True SCMInvitation
|
||||
bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation
|
||||
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
|
||||
disconnectAgentClient alice
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
@@ -670,8 +705,8 @@ testAcceptContactAsync = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(_, qInfo) <- createConnection alice True SCMContact Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, REQ invId _ "bob's connInfo") <- get alice
|
||||
bobId <- acceptContactAsync alice "1" True invId "alice's connInfo"
|
||||
("1", bobId', OK) <- get alice
|
||||
@@ -707,6 +742,89 @@ testAcceptContactAsync = do
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testDeleteConnectionAsync :: ATransport -> IO ()
|
||||
testDeleteConnectionAsync t = do
|
||||
a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers
|
||||
connIds <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
|
||||
(bId1, _inv) <- createConnection a 1 True SCMInvitation Nothing
|
||||
(bId2, _inv) <- createConnection a 1 True SCMInvitation Nothing
|
||||
(bId3, _inv) <- createConnection a 1 True SCMInvitation Nothing
|
||||
pure ([bId1, bId2, bId3] :: [ConnId])
|
||||
runRight_ $ do
|
||||
deleteConnectionsAsync a connIds
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
|
||||
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
|
||||
get a =##> \case ("", c, DEL_CONN) -> c `elem` connIds; _ -> False
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
|
||||
testUsers :: IO ()
|
||||
testUsers = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
auId <- createUser a [noAuthSrv testSMPServer]
|
||||
(aId', bId') <- makeConnectionForUsers a auId b 1
|
||||
exchangeGreetingsMsgId 4 a bId' b aId'
|
||||
deleteUser a auId True
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId'; _ -> False
|
||||
get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False
|
||||
get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
|
||||
exchangeGreetingsMsgId 6 a bId b aId
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
|
||||
testDeleteUserQuietly :: IO ()
|
||||
testDeleteUserQuietly = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
auId <- createUser a [noAuthSrv testSMPServer]
|
||||
(aId', bId') <- makeConnectionForUsers a auId b 1
|
||||
exchangeGreetingsMsgId 4 a bId' b aId'
|
||||
deleteUser a auId False
|
||||
exchangeGreetingsMsgId 6 a bId b aId
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
|
||||
testUsersNoServer :: ATransport -> IO ()
|
||||
testUsersNoServer t = do
|
||||
a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
liftIO $ print 1
|
||||
(aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
auId <- createUser a [noAuthSrv testSMPServer]
|
||||
(aId', bId') <- makeConnectionForUsers a auId b 1
|
||||
exchangeGreetingsMsgId 4 a bId' b aId'
|
||||
pure (aId, bId, auId, aId', bId')
|
||||
liftIO $ print 2
|
||||
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
|
||||
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
|
||||
get b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False
|
||||
liftIO $ print 3
|
||||
runRight_ $ do
|
||||
deleteUser a auId True
|
||||
liftIO $ print 4
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
liftIO $ print 4.1
|
||||
get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False
|
||||
liftIO $ print 4.2
|
||||
get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
|
||||
liftIO $ print 5
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
|
||||
liftIO $ print 6
|
||||
get a =##> \case ("", "", UP _ [c]) -> c == bId; _ -> False
|
||||
get b =##> \case ("", "", UP _ cs) -> length cs == 2; _ -> False
|
||||
liftIO $ print 7
|
||||
exchangeGreetingsMsgId 6 a bId b aId
|
||||
|
||||
testSwitchConnection :: InitialAgentServers -> IO ()
|
||||
testSwitchConnection servers = do
|
||||
a <- getSMPAgentClient agentCfg servers
|
||||
@@ -743,21 +861,26 @@ phase c connId d p =
|
||||
|
||||
testSwitchAsync :: InitialAgentServers -> IO ()
|
||||
testSwitchAsync servers = do
|
||||
liftIO $ print 1
|
||||
(aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
pure (aId, bId)
|
||||
liftIO $ print 2
|
||||
let withA' = session withA bId
|
||||
withB' = session withB aId
|
||||
withA' $ \a -> do
|
||||
switchConnectionAsync a "" bId
|
||||
phase a bId QDRcv SPStarted
|
||||
liftIO $ print 3
|
||||
withB' $ \b -> phase b aId QDSnd SPStarted
|
||||
withA' $ \a -> phase a bId QDRcv SPConfirmed
|
||||
liftIO $ print 4
|
||||
withB' $ \b -> do
|
||||
phase b aId QDSnd SPConfirmed
|
||||
phase b aId QDSnd SPCompleted
|
||||
withA' $ \a -> phase a bId QDRcv SPCompleted
|
||||
liftIO $ print 5
|
||||
withA $ \a -> withB $ \b -> runRight_ $ do
|
||||
subscribeConnection a bId
|
||||
subscribeConnection b aId
|
||||
@@ -785,20 +908,22 @@ testSwitchDelete servers = do
|
||||
disconnectAgentClient b
|
||||
switchConnectionAsync a "" bId
|
||||
phase a bId QDRcv SPStarted
|
||||
deleteConnectionAsync a "1" bId
|
||||
("1", bId', OK) <- get a
|
||||
liftIO $ bId `shouldBe` bId'
|
||||
deleteConnectionAsync a bId
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False
|
||||
get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False
|
||||
get a =##> \case ("", c, DEL_CONN) -> c == bId; _ -> False
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
|
||||
testCreateQueueAuth :: (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
|
||||
testCreateQueueAuth clnt1 clnt2 = do
|
||||
a <- getClient clnt1
|
||||
b <- getClient clnt2
|
||||
runRight $ do
|
||||
tryError (createConnection a True SCMInvitation Nothing) >>= \case
|
||||
tryError (createConnection a 1 True SCMInvitation Nothing) >>= \case
|
||||
Left (SMP AUTH) -> pure 0
|
||||
Left e -> throwError e
|
||||
Right (bId, qInfo) ->
|
||||
tryError (joinConnection b True qInfo "bob's connInfo") >>= \case
|
||||
tryError (joinConnection b 1 True qInfo "bob's connInfo") >>= \case
|
||||
Left (SMP AUTH) -> pure 1
|
||||
Left e -> throwError e
|
||||
Right aId -> do
|
||||
@@ -811,7 +936,7 @@ testCreateQueueAuth clnt1 clnt2 = do
|
||||
pure 2
|
||||
where
|
||||
getClient (clntAuth, clntVersion) =
|
||||
let servers = initAgentServers {smp = [ProtoServerWithAuth testSMPServer clntAuth]}
|
||||
let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]}
|
||||
smpCfg = (defaultClientConfig :: ProtocolClientConfig) {smpServerVRange = mkVersionRange 4 clntVersion}
|
||||
in getSMPAgentClient agentCfg {smpCfg} servers
|
||||
|
||||
@@ -819,7 +944,7 @@ testSMPServerConnectionTest :: ATransport -> Maybe BasicAuth -> SMPServerWithAut
|
||||
testSMPServerConnectionTest t newQueueBasicAuth srv =
|
||||
withSmpServerConfigOn t cfg {newQueueBasicAuth} testPort2 $ \_ -> do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers -- initially passed server is not running
|
||||
runRight $ testSMPServerConnection a srv
|
||||
runRight $ testSMPServerConnection a 1 srv
|
||||
|
||||
testRatchetAdHash :: IO ()
|
||||
testRatchetAdHash = do
|
||||
@@ -831,6 +956,73 @@ testRatchetAdHash = do
|
||||
ad2 <- getConnectionRatchetAdHash b aId
|
||||
liftIO $ ad1 `shouldBe` ad2
|
||||
|
||||
testTwoUsers :: IO ()
|
||||
testTwoUsers = do
|
||||
let nc = netCfg initAgentServers
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
sessionMode nc `shouldBe` TSMUser
|
||||
runRight_ $ do
|
||||
(aId1, bId1) <- makeConnectionForUsers a 1 b 1
|
||||
exchangeGreetings a bId1 b aId1
|
||||
(aId1', bId1') <- makeConnectionForUsers a 1 b 1
|
||||
exchangeGreetings a bId1' b aId1'
|
||||
a `hasClients` 1
|
||||
b `hasClients` 1
|
||||
setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
a `hasClients` 2
|
||||
|
||||
exchangeGreetingsMsgId 6 a bId1 b aId1
|
||||
exchangeGreetingsMsgId 6 a bId1' b aId1'
|
||||
liftIO $ threadDelay 250000
|
||||
setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
a `hasClients` 1
|
||||
|
||||
aUserId2 <- createUser a [noAuthSrv testSMPServer]
|
||||
(aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1
|
||||
exchangeGreetings a bId2 b aId2
|
||||
(aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1
|
||||
exchangeGreetings a bId2' b aId2'
|
||||
a `hasClients` 2
|
||||
b `hasClients` 1
|
||||
setNetworkConfig a nc {sessionMode = TSMEntity}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
a `hasClients` 4
|
||||
exchangeGreetingsMsgId 8 a bId1 b aId1
|
||||
exchangeGreetingsMsgId 8 a bId1' b aId1'
|
||||
exchangeGreetingsMsgId 6 a bId2 b aId2
|
||||
exchangeGreetingsMsgId 6 a bId2' b aId2'
|
||||
liftIO $ threadDelay 250000
|
||||
setNetworkConfig a nc {sessionMode = TSMUser}
|
||||
liftIO $ threadDelay 250000
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", DOWN _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
("", "", UP _ _) <- get a
|
||||
a `hasClients` 2
|
||||
exchangeGreetingsMsgId 10 a bId1 b aId1
|
||||
exchangeGreetingsMsgId 10 a bId1' b aId1'
|
||||
exchangeGreetingsMsgId 8 a bId2 b aId2
|
||||
exchangeGreetingsMsgId 8 a bId2' b aId2'
|
||||
where
|
||||
hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings = exchangeGreetingsMsgId 4
|
||||
|
||||
|
||||
@@ -209,8 +209,8 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
(bobId, aliceId, nonce, message) <- runRight $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -270,9 +270,9 @@ testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
|
||||
DeviceToken {} <- registerTestToken bob "bcde" NMInstant apnsQ
|
||||
-- establish connection
|
||||
liftIO $ threadDelay 50000
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
liftIO $ threadDelay 1000000
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
liftIO $ threadDelay 750000
|
||||
void $ messageNotification apnsQ
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
@@ -321,8 +321,8 @@ testChangeNotificationsMode APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
runRight_ $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -385,8 +385,8 @@ testChangeToken APNSMockServer {apnsQ} = do
|
||||
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
|
||||
(aliceId, bobId) <- runRight $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob True qInfo "bob's connInfo"
|
||||
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
|
||||
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
||||
@@ -140,7 +140,7 @@ testForeignKeysEnabled =
|
||||
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
|
||||
|
||||
cData1 :: ConnData
|
||||
cData1 = ConnData {connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
|
||||
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, duplexHandshake = Nothing, deleted = False}
|
||||
|
||||
testPrivateSignKey :: C.APrivateSignKey
|
||||
testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
|
||||
@@ -154,7 +154,8 @@ testDhSecret = "01234567890123456789012345678901"
|
||||
rcvQueue1 :: RcvQueue
|
||||
rcvQueue1 =
|
||||
RcvQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "1234",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
@@ -167,13 +168,15 @@ rcvQueue1 =
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1,
|
||||
clientNtfCreds = Nothing
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
|
||||
sndQueue1 :: SndQueue
|
||||
sndQueue1 =
|
||||
SndQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "3456",
|
||||
sndPublicKey = Nothing,
|
||||
@@ -314,7 +317,8 @@ testUpgradeRcvConnToDuplex =
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "2345",
|
||||
sndPublicKey = Nothing,
|
||||
@@ -340,7 +344,8 @@ testUpgradeSndConnToDuplex =
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ connId = "conn1",
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "3456",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
@@ -353,7 +358,8 @@ testUpgradeSndConnToDuplex =
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1,
|
||||
clientNtfCreds = Nothing
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
upgradeSndConnToDuplex db "conn1" anotherRcvQueue
|
||||
`shouldReturn` Left (SEBadConnType CRcv)
|
||||
|
||||
+1
-1
@@ -69,7 +69,7 @@ ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testNtfClient client = do
|
||||
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig host ntfTestPort (Just testKeyHash) $ \h ->
|
||||
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h ->
|
||||
liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
+14
-17
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
@@ -11,7 +12,8 @@ import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Network.Socket (ServiceName)
|
||||
import NtfClient (ntfTestPort)
|
||||
import SMPClient
|
||||
@@ -27,6 +29,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Agent.Store (UserId)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -173,13 +176,13 @@ testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5
|
||||
initAgentServers :: InitialAgentServers
|
||||
initAgentServers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList [noAuthSrv testSMPServer],
|
||||
{ smp = userServers [noAuthSrv testSMPServer],
|
||||
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"],
|
||||
netCfg = defaultNetworkConfig {tcpTimeout = 500_000}
|
||||
netCfg = defaultNetworkConfig {tcpTimeout = 500_000, tcpConnectTimeout = 500_000}
|
||||
}
|
||||
|
||||
initAgentServers2 :: InitialAgentServers
|
||||
initAgentServers2 = initAgentServers {smp = L.fromList [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
|
||||
initAgentServers2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer, noAuthSrv testSMPServer2]}
|
||||
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
@@ -187,17 +190,8 @@ agentCfg =
|
||||
{ tcpPort = agentTestPort,
|
||||
tbqSize = 4,
|
||||
database = testDB,
|
||||
smpCfg =
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (testPort, transport @TLS),
|
||||
networkConfig = defaultNetworkConfig {tcpTimeout = 500_000}
|
||||
},
|
||||
ntfCfg =
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (ntfTestPort, transport @TLS)
|
||||
},
|
||||
smpCfg = defaultClientConfig {qSize = 1, defaultTransport = (testPort, transport @TLS)},
|
||||
ntfCfg = defaultClientConfig {qSize = 1, defaultTransport = (ntfTestPort, transport @TLS)},
|
||||
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
|
||||
ntfWorkerDelay = 1000,
|
||||
ntfSMPWorkerDelay = 1000,
|
||||
@@ -209,11 +203,14 @@ agentCfg =
|
||||
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
|
||||
let cfg' = agentCfg {tcpPort = port', database = db'}
|
||||
initServers' = initAgentServers {smp = L.fromList [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
initServers' = initAgentServers {smp = userServers [ProtoServerWithAuth (SMPServer "localhost" smpPort' testKeyHash) Nothing]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg' initServers')
|
||||
afterProcess
|
||||
|
||||
userServers :: NonEmpty SMPServerWithAuth -> Map UserId (NonEmpty SMPServerWithAuth)
|
||||
userServers srvs = M.fromList [(1, srvs)]
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, AgentDatabase) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile (dbFile db')
|
||||
|
||||
@@ -226,7 +223,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m, MonadFail m) => ServiceName -> (c -> m a) -> m a
|
||||
testSMPAgentClientOn port' client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig agentTestHost
|
||||
runTransportClient defaultTransportClientConfig useHost port' (Just testKeyHash) $ \h -> do
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost port' (Just testKeyHash) $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
then client h
|
||||
|
||||
+1
-1
@@ -57,7 +57,7 @@ testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testSMPClient client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig useHost testPort (Just testKeyHash) $ \h ->
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h ->
|
||||
liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
+33
-2
@@ -11,6 +11,7 @@
|
||||
|
||||
module ServerTests where
|
||||
|
||||
import AgentTests.NotificationTests (removeFileIfExists)
|
||||
import Control.Concurrent (ThreadId, killThread, threadDelay)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (SomeException, try)
|
||||
@@ -20,6 +21,7 @@ import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Set as S
|
||||
import SMPClient
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -28,6 +30,7 @@ import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..))
|
||||
import Simplex.Messaging.Transport
|
||||
import System.Directory (removeFile)
|
||||
import System.TimeIt (timeItT)
|
||||
@@ -605,6 +608,10 @@ logSize f =
|
||||
testRestoreMessages :: ATransport -> Spec
|
||||
testRestoreMessages at@(ATransport t) =
|
||||
it "should store messages on exit and restore on start" $ do
|
||||
removeFileIfExists testStoreLogFile
|
||||
removeFileIfExists testStoreMsgsFile
|
||||
removeFileIfExists testServerStatsBackupFile
|
||||
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
recipientId <- newTVarIO ""
|
||||
recipientKey <- newTVarIO Nothing
|
||||
@@ -632,11 +639,15 @@ testRestoreMessages at@(ATransport t) =
|
||||
Resp "6" _ (ERR QUOTA) <- signSendRecv h sKey ("6", sId, _SEND "hello 6")
|
||||
pure ()
|
||||
|
||||
rId <- readTVarIO recipientId
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 2
|
||||
logSize testStoreMsgsFile `shouldReturn` 5
|
||||
logSize testServerStatsBackupFile `shouldReturn` 16
|
||||
Right stats1 <- strDecode <$> B.readFile testServerStatsBackupFile
|
||||
checkStats stats1 [rId] 5 1
|
||||
|
||||
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
let dec = decryptMsgV3 dh
|
||||
@@ -650,9 +661,11 @@ testRestoreMessages at@(ATransport t) =
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
-- the last message is not removed because it was not ACK'd
|
||||
logSize testStoreMsgsFile `shouldReturn` 3
|
||||
logSize testServerStatsBackupFile `shouldReturn` 16
|
||||
Right stats2 <- strDecode <$> B.readFile testServerStatsBackupFile
|
||||
checkStats stats2 [rId] 5 3
|
||||
|
||||
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
let dec = decryptMsgV3 dh
|
||||
@@ -667,9 +680,13 @@ testRestoreMessages at@(ATransport t) =
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
logSize testStoreMsgsFile `shouldReturn` 0
|
||||
logSize testServerStatsBackupFile `shouldReturn` 16
|
||||
Right stats3 <- strDecode <$> B.readFile testServerStatsBackupFile
|
||||
checkStats stats3 [rId] 5 5
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testStoreMsgsFile
|
||||
removeFile testServerStatsBackupFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
@@ -679,6 +696,20 @@ testRestoreMessages at@(ATransport t) =
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation
|
||||
checkStats s qs sent received = do
|
||||
_qCreated s `shouldBe` length qs
|
||||
_qSecured s `shouldBe` length qs
|
||||
_qDeleted s `shouldBe` 0
|
||||
_msgSent s `shouldBe` sent
|
||||
_msgRecv s `shouldBe` received
|
||||
_msgSentNtf s `shouldBe` 0
|
||||
_msgRecvNtf s `shouldBe` 0
|
||||
let PeriodStatsData {_day, _week, _month} = _activeQueues s
|
||||
S.toList _day `shouldBe` qs
|
||||
S.toList _week `shouldBe` qs
|
||||
S.toList _month `shouldBe` qs
|
||||
|
||||
testRestoreMessagesV2 :: ATransport -> Spec
|
||||
testRestoreMessagesV2 at@(ATransport t) =
|
||||
it "should store messages on exit and restore on start" $ do
|
||||
|
||||
Reference in New Issue
Block a user