connection queue redundancy and rotation (#521)

* rfc: queue rotation

* update rfc

* messages for queue rotation

* allow multiple subscribed queues per connection in Agent/Client.hs

* refactor

* fix module name

* allow multiple queues in duplex connection type

* update commands

* add indices

* addConnectionRcvQueue

* switch connection to another queue (WIP)

* update schema/protocol

* switching queue works, but sending messages after the switch fails

* messages are delivered after rotation

* use connection-scoped queue ID

* rename queue records fields

* refactor using SMPQueue class/instances

* simplify queries

* QKEY: check queue is not secured, refactor

* update rfc

* mark queue as primary in QUSE

* queue rotation errors

* fix async ack

* fix async ACK to send OK

* correction

Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>

* use SWCH command

* rename

* take into account only active queue subscription when determining connection result if at least one queue is active

* remove comment

* only enable notifications for connections with enableNtfs = True

* async test (WIP)

* async queue rotation test

* simplify combining results

* test with 2 servers

* fix unused subscribeConnection

* switch to cabal build

* increase build timeout

* increase delay in async test

* skip queue rotation tests

* build matrix

* step name

* use ubuntu-18.04 in build matrix

* enable rotation tests

Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>
This commit is contained in:
Evgeny Poberezkin
2022-10-29 18:57:01 +01:00
committed by GitHub
parent 19aef52135
commit eb5c1c78cb
21 changed files with 1362 additions and 517 deletions

View File

@@ -11,35 +11,51 @@ on:
jobs:
build:
runs-on: ubuntu-20.04
name: build-${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- os: ubuntu-18.04
- os: ubuntu-20.04
steps:
- name: Clone project
uses: actions/checkout@v2
- name: Setup Stack
- name: Setup Haskell
uses: haskell/actions/setup@v1
with:
ghc-version: "8.10.7"
enable-stack: true
stack-version: "latest"
cabal-version: "latest"
- name: Cache dependencies
uses: actions/cache@v2
with:
path: ~/.stack
key: ${{ hashFiles('stack.yaml') }}
path: |
~/.cabal/store
dist-newstyle
key: ${{ matrix.os }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }}
- name: Build & test
id: build_test
- name: Build
shell: bash
run: cabal build --enable-tests
- name: Test
if: matrix.os == 'ubuntu-18.04'
timeout-minutes: 30
shell: bash
run: cabal test --test-show-details=direct
- name: Prepare binaries
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
shell: bash
run: |
stack build --test --force-dirty
install_root=$(stack path --local-install-root)
mv ${install_root}/bin/smp-server smp-server-ubuntu-20_04-x86-64
mv ${install_root}/bin/ntf-server ntf-server-ubuntu-20_04-x86-64
mv $(cabal list-bin smp-server) smp-server-ubuntu-20_04-x86-64
mv $(cabal list-bin ntf-server) ntf-server-ubuntu-20_04-x86-64
- name: Build changelog
if: startsWith(github.ref, 'refs/tags/v')
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
id: build_changelog
uses: mikepenz/release-changelog-builder-action@v1
with:
@@ -50,19 +66,8 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Extract release candidate
if: startsWith(github.ref, 'refs/tags/v')
id: extract_release_candidate
shell: bash
run: |
if [[ ${GITHUB_REF} == *rc* ]]; then
echo "::set-output name=release_candidate::true"
else
echo "::set-output name=release_candidate::false"
fi
- name: Create release
if: startsWith(github.ref, 'refs/tags/v')
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
uses: softprops/action-gh-release@v1
with:
body: |
@@ -70,7 +75,7 @@ jobs:
Commits:
${{ steps.build_changelog.outputs.changelog }}
prerelease: ${{ steps.extract_release_candidate.outputs.release_candidate }}
prerelease: true
files: |
LICENSE
smp-server-ubuntu-20_04-x86-64

View File

@@ -0,0 +1,62 @@
# SMP queue rotation and redundancy
## Problem
1. Long term usage of the queue allows the servers to analyze long term traffic patterns.
2. If the user changes the configured SMP server(s), the previously created contacts do not migrate to the new server(s).
3. Server can lose messages.
## Solution
Additional messages exchanged by SMP agents to negotiate addition and removal of queues to the connection.
This approach allows for both rotation and redundancy, as changing queue will be done be adding a queue and then removing the existing queue.
The reason for this approach is that otherwise it's non-trivial to switch from one queue to another without losing messages or delivering them out of order, it's easier to have messages delivered via both queues during the switch, however short or long time it is.
### Messages
Additional agent messages required for the protocol:
QADD_ -> "QA"
QKEY_ -> "QK"
QUSE_ -> "QU"
QTEST_ -> "QT"
QDEL_ -> "QD"
QEND_ -> "QE"
`QADD`: add the new queue address(es), the same format as `REPLY` message, encoded as `QA`.
`QKEY`: pass sender's key via existing connection (SMP confirmation message will not be used, to avoid the same "race" of the initial key exchange that would create the risk of intercepting the queue for the attacker), encoded as `QK`.
`QUSE`: instruct the sender to use the new queue with sender's queue ID as parameter, encoded as `QU`.
`QTEST`: send test message to the new connection, encoded as `QT`. Any other message can be sent if available to continue rotation, the absence of this message is not an error.
`QDEL`: instruct the sender to stop using the previous queue, encoded as `QD`
`QEND`: notify the recipient that no messages will be sent to this queue, encoded as `QE`. The recipient will delete this queue.
### Protocol
```
participant A as Alice
participant B as Bob
participant R as Server that has A's receive queue
participant S as Server that has A's send queue (B's receive queue)
participant R' as Server that hosts the new A's receive queue
A ->> R': create new queue
A ->> S ->> B: QADD (R'): snd address of the new queue(s)
B ->> A(R) ->> A: QKEY (R'): sender's key for the new queue(s) (to avoid the race of SMP confirmation for the initial exchange)
A ->> S(R'): secure new queue
A ->> S ->> B: QUSE (R'): instruction to use new queue(s)
B ->> A(R,R') ->> A: QTEST
B ->> A(R,R') ->> A: send all new messages to both queues
A ->> S ->> B: QDEL (R): instruction to delete the old queue
B ->> A(R') -> A: QEND (R): notification that no messages will be sent to the old queue
B ->> R' ->> A: send all new messages to the new queue only
A ->> S(R): DEL: delete the previous queue
```

View File

@@ -54,6 +54,8 @@ library
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
Simplex.Messaging.Agent.TRcvQueues
Simplex.Messaging.Client
Simplex.Messaging.Client.Agent
Simplex.Messaging.Crypto
@@ -83,7 +85,6 @@ library
Simplex.Messaging.Server.Stats
Simplex.Messaging.Server.StoreLog
Simplex.Messaging.TMap
Simplex.Messaging.TMap2
Simplex.Messaging.Transport
Simplex.Messaging.Transport.Client
Simplex.Messaging.Transport.HTTP2
@@ -350,6 +351,7 @@ test-suite smp-server-test
AgentTests.NotificationTests
AgentTests.SchemaDump
AgentTests.SQLiteTests
CoreTests.CryptoTests
CoreTests.EncodingTests
CoreTests.ProtocolErrorTests
CoreTests.VersionRangeTests

View File

@@ -43,6 +43,7 @@ module Simplex.Messaging.Agent
allowConnectionAsync,
acceptContactAsync,
ackMessageAsync,
switchConnectionAsync,
deleteConnectionAsync,
createConnection,
joinConnection,
@@ -57,6 +58,7 @@ module Simplex.Messaging.Agent
resubscribeConnections,
sendMessage,
ackMessage,
switchConnection,
suspendConnection,
deleteConnection,
getConnectionServers,
@@ -89,9 +91,10 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::))
import Data.Foldable (foldl')
import Data.Functor (($>))
import Data.List (deleteFirstsBy)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List (deleteFirstsBy, find)
import Data.List.NonEmpty (NonEmpty (..), (<|))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
@@ -117,7 +120,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, sameSrvAddr)
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util
@@ -171,6 +174,10 @@ acceptContactAsync c corrId enableNtfs = withAgentEnv c .: acceptContactAsync' c
ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c
-- | Switch connection to the new receive queue
switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c
-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
@@ -224,6 +231,10 @@ sendMessage c = withAgentEnv c .:. sendMessage' c
ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> m ()
ackMessage c = withAgentEnv c .: ackMessage' c
-- | Switch connection to the new receive queue
switchConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection c = withAgentEnv c . switchConnection' c
-- | Suspend SMP agent connection (OFF command)
suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
suspendConnection c = withAgentEnv c . suspendConnection' c
@@ -328,6 +339,7 @@ processCommand c (connId, cmd) = case cmd of
SUB -> subscribeConnection' c connId $> (connId, OK)
SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody
ACK msgId -> ackMessage' c connId msgId $> (connId, OK)
SWCH -> switchConnection' c connId $> (connId, OK)
OFF -> suspendConnection' c connId $> (connId, OK)
DEL -> deleteConnection' c connId $> (connId, OK)
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
@@ -375,22 +387,25 @@ acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
_ -> throwError $ CMD PROHIBITED
ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
ackMessageAsync' c corrId connId msgId =
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> enqueueAck rq
SomeConn _ (RcvConnection _ rq) -> enqueueAck rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
ackMessageAsync' c corrId connId msgId = do
SomeConn cType _ <- withStore c (`getConn` connId)
case cType of
SCDuplex -> enqueueAck
SCRcv -> enqueueAck
SCSnd -> throwError $ CONN SIMPLEX
SCContact -> throwError $ CMD PROHIBITED
SCNew -> throwError $ CMD PROHIBITED
where
enqueueAck :: RcvQueue -> m ()
enqueueAck RcvQueue {server} =
enqueueCommand c corrId connId (Just server) $ AClientCommand $ ACK msgId
enqueueAck :: m ()
enqueueAck = do
(RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId
enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
deleteConnectionAsync' c@AgentClient {subQ} corrId connId =
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> enqueueDelete rq
-- TODO *** delete all queues
SomeConn _ (DuplexConnection _ (rq :| _) _) -> enqueueDelete rq
SomeConn _ (RcvConnection _ rq) -> enqueueDelete rq
SomeConn _ (ContactConnection _ rq) -> enqueueDelete rq
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId) >> notifyDeleted
@@ -402,6 +417,13 @@ deleteConnectionAsync' c@AgentClient {subQ} corrId connId =
notifyDeleted :: m ()
notifyDeleted = atomically $ writeTBQueue subQ (corrId, connId, OK)
-- | Add connection to the new receive queue
switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
switchConnectionAsync' c corrId connId =
withStore c (`getConn` connId) >>= \case
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
_ -> throwError $ CMD PROHIBITED
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
newConn c connId asyncMode enableNtfs cMode =
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
@@ -409,9 +431,10 @@ newConn c connId asyncMode enableNtfs cMode =
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
newConnSrv c connId asyncMode enableNtfs cMode srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
(rq, qUri) <- newRcvQueue c srv smpClientVRange
connId' <- setUpConn asyncMode rq $ maxVersion smpAgentVRange
addSubscription c rq connId'
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
let rq = (q :: RcvQueue) {connId = connId'}
addSubscription c rq
when enableNtfs $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId', NSCCreate)
@@ -424,7 +447,7 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
where
setUpConn True rq _ = do
withStore c $ \db -> updateNewConnRcv db connId rq
void . withStore c $ \db -> updateNewConnRcv db connId rq
pure connId
setUpConn False rq connAgentVersion = do
g <- asks idsDrg
@@ -434,8 +457,8 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId asyncMode enableNtfs cReq cInfo = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
getNextSMPServer c [qServer q]
_ -> getSMPServer c
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
@@ -450,11 +473,12 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
(_, rcDHRs) <- liftIO C.generateKeyPair'
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
sq <- newSndQueue qInfo
q <- newSndQueue "" qInfo
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
connId' <- setUpConn asyncMode cData sq rc
let cData' = (cData :: ConnData) {connId = connId'}
connId' <- setUpConn asyncMode cData q rc
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
@@ -466,7 +490,7 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age
where
setUpConn True _ sq rc =
withStore c $ \db -> runExceptT $ do
ExceptT $ updateNewConnSnd db connId sq
void . ExceptT $ updateNewConnSnd db connId sq
liftIO $ createRatchet db connId rc
pure connId
setUpConn False cData sq rc = do
@@ -492,10 +516,10 @@ joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq connId
withStore c $ \db -> upgradeSndConnToDuplex db connId rq
addSubscription c rq
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
when enableNtfs $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
@@ -538,24 +562,27 @@ subscribeConnection' c connId =
withStore c (`getConn` connId) >>= \conn -> do
resumeConnCmds c connId
case conn of
SomeConn _ (DuplexConnection cData rq sq) -> do
resumeMsgDelivery c cData sq
subscribe rq
SomeConn _ (DuplexConnection cData (rq :| rqs) sqs) -> do
mapM_ (resumeMsgDelivery c cData) sqs
subscribe cData rq
mapM_ (\q -> subscribeQueue c q `catchError` \_ -> pure ()) rqs
SomeConn _ (SndConnection cData sq) -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure ()
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
SomeConn _ (RcvConnection _ rq) -> subscribe rq
SomeConn _ (ContactConnection _ rq) -> subscribe rq
SomeConn _ (RcvConnection cData rq) -> subscribe cData rq
SomeConn _ (ContactConnection cData rq) -> subscribe cData rq
SomeConn _ (NewConnection _) -> pure ()
where
subscribe :: RcvQueue -> m ()
subscribe rq = do
subscribeQueue c rq connId
subscribe :: ConnData -> RcvQueue -> m ()
subscribe ConnData {enableNtfs} rq = do
subscribeQueue c rq
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
atomically $ sendNtfSubCommand ns (connId, if enableNtfs then NSCCreate else NSCDelete)
type QSubResult = (QueueStatus, Either AgentErrorType ())
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections' _ [] = pure M.empty
@@ -564,42 +591,61 @@ subscribeConnections' c connIds = do
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
srvRcvQs :: Map SMPServer (Map ConnId RcvQueue) = M.foldlWithKey' addRcvQueue M.empty rcvQs
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
srvRcvQs :: Map SMPServer [RcvQueue] = M.foldl' (foldl' addRcvQueue) M.empty rcvQs
mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs
mapM_ (resumeConnCmds c) $ M.keys cs
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
rcvRs <- connResults . concat <$> mapConcurrently subscribe (M.assocs srvRcvQs)
ns <- asks ntfSupervisor
tkn <- readTVarIO (ntfTkn ns)
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs
let rs = M.unions $ errs' : subRs : rcvRs
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns
let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())])
notifyResultError rs
pure rs
where
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) RcvQueue
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) (NonEmpty RcvQueue)
rcvQueueOrResult = \case
SomeConn _ (DuplexConnection _ rq _) -> Right rq
SomeConn _ (DuplexConnection _ rqs _) -> Right rqs
SomeConn _ (SndConnection _ sq) -> Left $ sndSubResult sq
SomeConn _ (RcvConnection _ rq) -> Right rq
SomeConn _ (ContactConnection _ rq) -> Right rq
SomeConn _ (RcvConnection _ rq) -> Right [rq]
SomeConn _ (ContactConnection _ rq) -> Right [rq]
SomeConn _ (NewConnection _) -> Left (Right ())
sndSubResult :: SndQueue -> Either AgentErrorType ()
sndSubResult sq = case status (sq :: SndQueue) of
Confirmed -> Right ()
Active -> Left $ CONN SIMPLEX
_ -> Left $ INTERNAL "unexpected queue status"
addRcvQueue :: Map SMPServer (Map ConnId RcvQueue) -> ConnId -> RcvQueue -> Map SMPServer (Map ConnId RcvQueue)
addRcvQueue m connId rq@RcvQueue {server} = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId RcvQueue) -> m (Map ConnId (Either AgentErrorType ()))
addRcvQueue :: Map SMPServer [RcvQueue] -> RcvQueue -> Map SMPServer [RcvQueue]
addRcvQueue m rq@RcvQueue {server} = M.alter (Just . maybe [rq] (rq :)) server m
subscribe :: (SMPServer, [RcvQueue]) -> m [(RcvQueue, Either AgentErrorType ())]
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
sendNtfCreate ns rcvRs =
forM_ (concatMap M.assocs rcvRs) $ \case
(connId, Right _) -> atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate)
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
connResults = M.map snd . foldl' addResult M.empty
where
-- collects results by connection ID
addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QSubResult
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
-- combines two results for one connection, by using only Active queues (if there is at least one Active queue)
combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
combineRes r' _ = Just r'
order :: QSubResult -> Int
order (Active, Right _) = 1
order (Active, _) = 2
order (_, Right _) = 3
order _ = 4
sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> m ()
sendNtfCreate ns rcvRs conns =
forM_ (M.assocs rcvRs) $ \case
(connId, Right _) -> forM_ (M.lookup connId conns) $ \case
Right (SomeConn _ conn) -> do
let cmd = if enableNtfs $ connData conn then NSCCreate else NSCDelete
atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd)
_ -> pure ()
_ -> pure ()
sndQueue :: SomeConn -> Maybe (ConnData, SndQueue)
sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue)
sndQueue = \case
SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq)
SomeConn _ (SndConnection cData sq) -> Just (cData, sq)
SomeConn _ (DuplexConnection cData _ sqs) -> Just (cData, sqs)
SomeConn _ (SndConnection cData sq) -> Just (cData, [sq])
_ -> Nothing
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
notifyResultError rs = do
@@ -626,7 +672,7 @@ getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMs
getConnectionMessage' c connId = do
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> getQueueMessage c rq
SomeConn _ (DuplexConnection _ (rq :| _) _) -> getQueueMessage c rq
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq
SomeConn _ (ContactConnection _ rq) -> getQueueMessage c rq
SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX
@@ -663,12 +709,12 @@ getNotificationMessage' c nonce encNtfInfo = do
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData _ sq) -> enqueueMsg cData sq
SomeConn _ (SndConnection cData sq) -> enqueueMsg cData sq
SomeConn _ (DuplexConnection cData _ sqs) -> enqueueMsgs cData sqs
SomeConn _ (SndConnection cData sq) -> enqueueMsgs cData [sq]
_ -> throwError $ CONN SIMPLEX
where
enqueueMsg :: ConnData -> SndQueue -> m AgentMsgId
enqueueMsg cData sq = enqueueMessage c cData sq msgFlags $ A_MSG msg
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId
enqueueMsgs cData sqs = enqueueMessages c cData sqs msgFlags $ A_MSG msg
-- / async command processing v v v
@@ -713,8 +759,8 @@ getPendingCommandQ c server = do
pure cq
runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
runCommandProcessing c@AgentClient {subQ} server = do
cq <- atomically $ getPendingCommandQ c server
runCommandProcessing c@AgentClient {subQ} server_ = do
cq <- atomically $ getPendingCommandQ c server_
ri <- asks $ messageRetryInterval . config -- different retry interval?
forever $ do
atomically $ endAgentOperation c AOSndNetwork
@@ -726,73 +772,108 @@ runCommandProcessing c@AgentClient {subQ} server = do
Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd
where
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
processCmd ri corrId connId cmdId = \case
processCmd ri corrId connId cmdId command = case command of
AClientCommand cmd -> case cmd of
NEW enableNtfs (ACM cMode) -> do
NEW enableNtfs (ACM cMode) -> noServer $ do
usedSrvs <- newTVarIO ([] :: [SMPServer])
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
notify $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do
let initUsed = [smpServer (queueAddress :: SMPQueueAddress)]
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
let initUsed = [qServer q]
usedSrvs <- newTVarIO initUsed
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
notify OK
LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK
DEL -> tryCommand $ deleteConnection' c connId >> notify OK
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
SWCH -> noServer $ tryCommand $ switchConnection' c connId >>= notify . SWITCH SPStarted
DEL -> withServer' . tryCommand $ deleteConnection' c connId >> notify OK
_ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd)
AInternalCommand cmd -> case server of
Just _srv -> case cmd of
ICAckDel _rId srvMsgId msgId -> tryWithLock "ICAckDel" $ ack _rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId)
ICAck _rId srvMsgId -> tryWithLock "ICAck" $ ack _rId srvMsgId
ICAllowSecure _rId senderKey -> tryWithLock "ICAllowSecure" $ do
(SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <-
withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId)
case conn of
RcvConnection cData rq -> do
secure rq senderKey
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
ICDuplexSecure _rId senderKey -> tryWithLock "ICDuplexSecure" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData rq sq -> do
secure rq senderKey
when (duplexHandshake cData == Just True) . void $
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
_ -> throwError $ INTERNAL $ "command requires server " <> show (internalCmdTag cmd)
AInternalCommand cmd -> case cmd of
ICAckDel rId srvMsgId msgId -> withServer $ \srv -> tryWithLock "ICAckDel" $ ack srv rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId)
ICAck rId srvMsgId -> withServer $ \srv -> tryWithLock "ICAck" $ ack srv rId srvMsgId
ICAllowSecure _rId senderKey -> withServer' . tryWithLock "ICAllowSecure" $ do
(SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <-
withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId)
case conn of
RcvConnection cData rq -> do
secure rq senderKey
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do
secure rq senderKey
when (duplexHandshake cData == Just True) . void $
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
ICQSecure rId senderKey ->
withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) ->
case find (sameQueue (srv, rId)) rqs of
Just rq'@RcvQueue {server, sndId, status} -> when (status == Confirmed) $ do
secureQueue c rq' senderKey
withStore' c $ \db -> setRcvQueueStatus db rq' Secured
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)]
_ -> internalErr "ICQSecure: queue address not found in connection"
ICQDelete rId -> do
withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> do
case removeQ (srv, rId) rqs of
Nothing -> internalErr "ICQDelete: queue address not found in connection"
Just (rq'@RcvQueue {primary}, rq'' : rqs')
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
| otherwise -> do
deleteQueue c rq'
withStore' c $ \db -> deleteConnRcvQueue db connId rq'
let conn' = DuplexConnection cData (rq'' :| rqs') sqs
notify $ SWITCH SPCompleted $ connectionStats conn'
_ -> internalErr "ICQDelete: cannot delete the only queue in connection"
where
ack _rId srvMsgId = do
-- TODO get particular queue
rq <- withStore c (`getRcvQueue` connId)
ack srv rId srvMsgId = do
rq <- withStore c $ \db -> getRcvQueue db connId srv rId
ackQueueMessage c rq srvMsgId
secure :: RcvQueue -> SMP.SndPublicVerifyKey -> m ()
secure rq senderKey = do
secureQueue c rq senderKey
withStore' c $ \db -> setRcvQueueStatus db rq Secured
where
withServer a = case server_ of
Just srv -> a srv
_ -> internalErr "command requires server"
withServer' = withServer . const
noServer a = case server_ of
Nothing -> a
_ -> internalErr "command requires no server"
withDuplexConn :: (Connection 'CDuplex -> m ()) -> m ()
withDuplexConn a =
withStore c (`getConn` connId) >>= \case
SomeConn _ conn@DuplexConnection {} -> a conn
_ -> internalErr "command requires duplex connection"
tryCommand action = withRetryInterval ri $ \loop ->
tryError action >>= \case
Left e
| temporaryAgentError e || e == BROKER HOST -> retrySndOp c loop
| otherwise -> notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
| otherwise -> cmdError e
Right () -> withStore' c (`deleteCommand` cmdId)
tryWithLock name = tryCommand . withConnLock c connId name
internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command)
cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd)
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
withNextSrv usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srv <- getNextSMPServer c used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
writeTVar usedSrvs used'
action srv
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
withNextSrv usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srv <- getNextSMPServer c used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
writeTVar usedSrvs used'
action srv
-- ^ ^ ^ async command processing /
enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessages c cData (sq :| sqs) msgFlags aMessage = do
msgId <- enqueueMessage c cData sq msgFlags aMessage
mapM_ (enqueueSavedMessage c cData msgId) $
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
pure msgId
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do
resumeMsgDelivery c cData sq
@@ -813,20 +894,28 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage
msgType = agentMessageType agentMsg
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash}
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
enqueueSavedMessage c cData@ConnData {connId} msgId sq = do
resumeMsgDelivery c cData sq
let mId = InternalId msgId
queuePendingMsgs c sq [mId]
withStore' c $ \db -> createSndMsgDelivery db connId sq mId
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
let qKey = (server, sndId)
unlessM (queueDelivering qKey) $
async (runSmpQueueMsgDelivery c cData sq)
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
unlessM connQueued $
withStore' c (`getPendingMsgs` connId)
unlessM msgsQueued $
withStore' c (\db -> getPendingMsgs db connId sq)
>>= queuePendingMsgs c sq
where
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c)
msgsQueued = atomically $ isJust <$> TM.lookupInsert (server, sndId) True (pendingMsgsQueued c)
queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c sq msgIds = atomically $ do
@@ -853,6 +942,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
forever $ do
atomically $ endAgentOperation c AOSndNetwork
atomically $ throwWhenInactive c
atomically $ throwWhenNoDelivery c sq
msgId <- atomically $ readTQueue mq
atomically $ beginAgentOperation c AOSndNetwork
atomically $ endAgentOperation c AOMsgDelivery
@@ -890,6 +980,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
_ -> connError msgId NOT_ACCEPTED
AM_REPLY_ -> notifyDel msgId $ ERR e
AM_A_MSG_ -> notifyDel msgId $ MERR mId e
AM_QADD_ -> pure ()
AM_QKEY_ -> pure ()
AM_QUSE_ -> pure ()
AM_QTEST_ -> pure ()
AM_QDEL_ -> pure ()
AM_QEND_ -> pure ()
_
-- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server),
-- the message sending would be retried
@@ -910,6 +1006,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
when (isJust rq_) $ removeConfirmations db connId
-- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
AM_CONN_INFO_REPLY -> pure ()
AM_REPLY_ -> pure ()
AM_HELLO_ -> do
withStore' c $ \db -> setSndQueueStatus db sq Active
case rq_ of
@@ -933,11 +1031,18 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
_ -> pure ()
AM_QADD_ -> pure ()
AM_QKEY_ -> pure ()
AM_QUSE_ -> pure ()
AM_QTEST_ ->
withStore' c $ \db -> setSndQueueStatus db sq Active
AM_QDEL_ -> pure ()
AM_QEND_ ->
getConnectionServers' c connId >>= notify . SWITCH SPCompleted
delMsg msgId
where
delMsg :: InternalId -> m ()
delMsg msgId = withStore' c $ \db -> deleteMsg db connId msgId
delMsg msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId
notify :: ACommand 'Agent -> m ()
notify cmd = atomically $ writeTBQueue subQ ("", connId, cmd)
notifyDel :: InternalId -> ACommand 'Agent -> m ()
@@ -954,20 +1059,38 @@ retrySndOp c loop = do
ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> m ()
ackMessage' c connId msgId = withConnLock c connId "ackMessage" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> ack rq
SomeConn _ (RcvConnection _ rq) -> ack rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection {} -> ack
RcvConnection {} -> ack
SndConnection {} -> throwError $ CONN SIMPLEX
ContactConnection {} -> throwError $ CMD PROHIBITED
NewConnection _ -> throwError $ CMD PROHIBITED
where
ack :: RcvQueue -> m ()
ack rq = do
ack :: m ()
ack = do
let mId = InternalId msgId
srvMsgId <- withStore c $ \db -> setMsgUserAck db connId mId
(rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId mId
ackQueueMessage c rq srvMsgId
withStore' c $ \db -> deleteMsg db connId mId
switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
clientVRange <- asks $ smpClientVRange . config
-- try to get the server that is different from all queues, or at least from the primary rcv queue
srv <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextSMPServer c [server] else pure srv
(q, qUri) <- newRcvQueue c connId srv' clientVRange
let rq' = (q :: RcvQueue) {primary = False, nextPrimary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnRcvQueue db connId rq'
addSubscription c rq'
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))]
pure . connectionStats $ DuplexConnection cData (rq <| rq' :| rqs_) sqs
_ -> throwError $ CMD PROHIBITED
ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m ()
ackQueueMessage c rq srvMsgId =
sendAck c rq srvMsgId `catchError` \case
@@ -978,7 +1101,7 @@ ackQueueMessage c rq srvMsgId =
suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> suspendQueue c rq
SomeConn _ (DuplexConnection _ rqs _) -> mapM_ (suspendQueue c) rqs
SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq
SomeConn _ (ContactConnection _ rq) -> suspendQueue c rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
@@ -988,7 +1111,7 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> delete rq
SomeConn _ (DuplexConnection _ rqs _) -> mapM_ delete rqs
SomeConn _ (RcvConnection _ rq) -> delete rq
SomeConn _ (ContactConnection _ rq) -> delete rq
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId)
@@ -1003,15 +1126,17 @@ deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete)
getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
getConnectionServers' c connId = connServers <$> withStore c (`getConn` connId)
where
connServers :: SomeConn -> ConnectionStats
connServers = \case
SomeConn _ (RcvConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
SomeConn _ (SndConnection _ SndQueue {server}) -> ConnectionStats {rcvServers = [], sndServers = [server]}
SomeConn _ (DuplexConnection _ RcvQueue {server = s1} SndQueue {server = s2}) -> ConnectionStats {rcvServers = [s1], sndServers = [s2]}
SomeConn _ (ContactConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
SomeConn _ (NewConnection _) -> ConnectionStats {rcvServers = [], sndServers = []}
getConnectionServers' c connId = do
SomeConn _ conn <- withStore c (`getConn` connId)
pure $ connectionStats conn
connectionStats :: Connection c -> ConnectionStats
connectionStats = \case
RcvConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []}
SndConnection _ sq -> ConnectionStats {rcvServers = [], sndServers = [qServer sq]}
DuplexConnection _ rqs sqs -> ConnectionStats {rcvServers = map qServer $ L.toList rqs, sndServers = map qServer $ L.toList sqs}
ContactConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []}
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
-- | Change servers to be used for creating new queues, in Reader monad
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
@@ -1271,11 +1396,9 @@ pickServer = \case
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
getNextSMPServer c usedSrvs = do
srvs <- readTVarIO $ smpServers c
case L.nonEmpty $ deleteFirstsBy sameAddr (L.toList srvs) usedSrvs of
case L.nonEmpty $ deleteFirstsBy sameSrvAddr (L.toList srvs) usedSrvs of
Just srvs' -> pickServer srvs'
_ -> pickServer srvs
where
sameAddr (SMPServer host port _) (SMPServer host' port' _) = host == host' && port == port'
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
@@ -1288,7 +1411,10 @@ subscriber c@AgentClient {msgQ} = forever $ do
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) =
withStore c (\db -> getRcvConn db srv rId) >>= \case
SomeConn _ conn@(DuplexConnection cData rq _) -> processSMP conn cData rq
-- TODO *** get queue separately?
SomeConn _ conn@(DuplexConnection cData rqs _) -> case find (sameQueue (srv, rId)) rqs of
Just rq -> processSMP conn cData rq
_ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND)
SomeConn _ conn@(RcvConnection cData rq) -> processSMP conn cData rq
SomeConn _ conn@(ContactConnection cData rq) -> processSMP conn cData rq
_ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND)
@@ -1313,7 +1439,20 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) ->
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, nextPrimary, dbReplaceQueueId} = rq
unless primary . withStore' c $ \db -> do
unless (status == Active) $ setRcvQueueStatus db rq Active
when nextPrimary $ setRcvQueuePrimary db connId rq
case (conn, dbReplaceQueueId) of
(DuplexConnection _ rqs sqs, Just dbRcvId) ->
case find (\RcvQueue {dbQueueId} -> dbQueueId == dbRcvId) rqs of
Just RcvQueue {server, sndId} -> do
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QDEL [(server, sndId)]
notify . SWITCH SPTested $ connectionStats conn
_ -> throwError $ INTERNAL "replaced RcvQueue not found in connection"
_ -> pure ()
tryError agentClientMsg >>= \case
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
HELLO -> helloMsg >> ackDel msgId
@@ -1322,6 +1461,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
QADD qs -> qDuplex "QADD" $ qAddMsg qs
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
QDEL qs -> qDuplex "QDEL" $ qDelMsg qs
QEND qs -> qDuplex "QEND" $ qEndMsg qs
where
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex name a = case conn of
DuplexConnection {} -> a conn >> ackDel msgId
_ -> qError $ name <> ": message must be sent to duplex connection"
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
@@ -1350,7 +1502,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
liftIO $ createRcvMsg db connId rcvMsg
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage)
_ -> pure Nothing
_ -> prohibited >> ack
@@ -1434,12 +1586,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
confId <- withStore c $ \db -> do
setHandshakeVersion db connId agentVersion duplexHS
createConfirmation db g newConfirmation
let srvs = map queueServer $ smpReplyQueues senderConf
let srvs = map qServer $ smpReplyQueues senderConf
notify $ CONF confId srvs connInfo
queueServer (SMPQueueInfo _ SMPQueueAddress {smpServer}) = smpServer
_ -> prohibited
-- party accepting connection
(DuplexConnection _ RcvQueue {smpClientVersion = v'} _, Nothing) -> do
(DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do
withStore c (\db -> runExceptT $ agentRatchetDecrypt db connId encConnInfo) >>= parseMessage >>= \case
AgentConnInfo connInfo -> do
notify $ INFO connInfo
@@ -1458,7 +1609,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
_ -> do
withStore' c $ \db -> setRcvQueueStatus db rq Active
case conn of
DuplexConnection _ _ sq@SndQueue {status = sndStatus}
DuplexConnection _ _ (sq@SndQueue {status = sndStatus} :| _)
-- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO
-- this branch is executed by the accepting party in duplexHandshake mode (v2)
-- and by the initiating party in v1
@@ -1471,7 +1622,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
enqueueDuplexHello :: SndQueue -> m ()
enqueueDuplexHello sq = void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
replyMsg :: L.NonEmpty SMPQueueInfo -> m ()
replyMsg :: NonEmpty SMPQueueInfo -> m ()
replyMsg smpQueues = do
logServer "<--" c srv rId "MSG <REPLY>"
case duplexHandshake of
@@ -1482,6 +1633,95 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR)
_ -> prohibited
-- processed by queue sender
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
qAddMsg ((qUri, Just addr) :| _) (DuplexConnection _ rqs sqs@(sq :| sqs_)) = do
clientVRange <- asks $ smpClientVRange . config
case qUri `compatibleVersion` clientVRange of
Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) ->
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
(Just _, _) -> qError "QADD: queue address is already used in connection"
(_, Just _replaced@SndQueue {dbQueueId}) -> do
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
let sq' = (sq_ :: SndQueue) {nextPrimary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnSndQueue db connId sq'
case (sndPublicKey, e2ePubKey) of
(Just sndPubKey, Just dhPublicKey) -> do
logServer "<--" c srv rId $ "MSG <QADD> " <> logSecret (senderId queueAddress)
let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}}
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)]
let conn' = DuplexConnection cData rqs (sq <| sq' :| sqs_)
notify . SWITCH SPStarted $ connectionStats conn'
_ -> qError "absent sender keys"
_ -> qError "QADD: replaced queue address is not found in connection"
_ -> throwError $ AGENT A_VERSION
-- processed by queue recipient
qKeyMsg :: NonEmpty (SMPQueueInfo, SndPublicVerifyKey) -> Connection 'CDuplex -> m ()
qKeyMsg ((qInfo, senderKey) :| _) (DuplexConnection _ rqs _) = do
clientVRange <- asks $ smpClientVRange . config
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case findRQ (smpServer, senderId) rqs of
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
| status' == New || status' == Confirmed -> do
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
let dhSecret = C.dh' dhPublicKey dhPrivKey
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
notify . SWITCH SPConfirmed $ connectionStats conn
| otherwise -> qError "QKEY: queue already secured"
_ -> qError "QKEY: queue address not found in connection"
where
SMPQueueInfo cVer' SMPQueueAddress {smpServer, senderId, dhPublicKey} = qInfo
-- processed by queue sender
-- mark queue as Secured and to start sending messages to it
qUseMsg :: NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> m ()
qUseMsg ((addr, primary) :| _) (DuplexConnection _ _ sqs) =
case removeQ addr sqs of
Just (sq', sqs') -> do
logServer "<--" c srv rId $ "MSG <QUSE> " <> logSecret (snd addr)
withStore' c $ \db -> do
setSndQueueStatus db sq' Secured
when primary $ setSndQueuePrimary db connId sq'
let sq'' = (sq' :: SndQueue) {status = Secured, primary}
void $ enqueueMessages c cData (sq'' :| sqs') SMP.noMsgFlags $ QTEST [addr]
notify . SWITCH SPConfirmed $ connectionStats conn
_ -> qError "QUSE: queue address not found in connection"
-- processed by queue sender
-- remove snd queue from connection and enqueue QEND message
qDelMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
qDelMsg (addr :| _) (DuplexConnection _ rqs sqs) =
case removeQ addr sqs of
Nothing -> logServer "<--" c srv rId "MSG <QDEL>: queue not found (already deleted?)"
Just (sq, sq' : sqs') -> do
logServer "<--" c srv rId $ "MSG <QDEL> " <> logSecret (snd addr)
-- remove the delivery from the map to stop the thread when the delivery loop is complete
atomically $ TM.delete addr $ smpQueueMsgQueues c
withStore' c $ \db -> do
deletePendingMsgs db connId sq
deleteConnSndQueue db connId sq
let sqs'' = sq' :| sqs'
conn' = DuplexConnection cData rqs sqs''
void $ enqueueMessages c cData sqs'' SMP.noMsgFlags $ QEND [addr]
notify . SWITCH SPTested $ connectionStats conn'
_ -> qError "QDEL received to the only queue in connection"
-- received by party initiating switch
-- TODO *** check that the received address matches expectations
qEndMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
qEndMsg (addr@(smpServer, senderId) :| _) (DuplexConnection _ rqs _) =
case findRQ addr rqs of
Just RcvQueue {rcvId} -> do
logServer "<--" c srv rId $ "MSG <QEND> " <> logSecret senderId
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQDelete rcvId
_ -> qError "QEND: queue address not found in connection"
qError :: String -> m ()
qError = throwError . AGENT . A_QUEUE
smpInvitation :: ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
smpInvitation connReq@(CRInvitationUri crData _) cInfo = do
logServer "<--" c srv rId "MSG <KEY>"
@@ -1490,11 +1730,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
g <- asks idsDrg
let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo}
invId <- withStore c $ \db -> createInvitation db g newInv
let srvs = L.map queueServer $ crSmpQueues crData
let srvs = L.map qServer $ crSmpQueues crData
notify $ REQ invId srvs cInfo
_ -> prohibited
where
queueServer (SMPQueueUri _ SMPQueueAddress {smpServer}) = smpServer
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash
@@ -1505,15 +1743,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
| otherwise = MsgError MsgDuplicate -- this case is not possible
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> L.NonEmpty SMPQueueInfo -> m ()
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
clientVRange <- asks $ smpClientVRange . config
case qInfo `proveCompatible` clientVRange of
Nothing -> throwError $ AGENT A_VERSION
Just qInfo' -> do
sq <- newSndQueue qInfo'
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq ownConnInfo Nothing
sq <- newSndQueue connId qInfo'
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
@@ -1570,28 +1808,34 @@ agentRatchetDecrypt db connId encAgentMsg = do
liftIO $ updateRatchet db connId rc' skippedDiff
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => Compatible SMPQueueInfo -> m SndQueue
newSndQueue qInfo =
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue connId qInfo =
asks (cmdSignAlg . config) >>= \case
C.SignAlg a -> newSndQueue_ a qInfo
C.SignAlg a -> newSndQueue_ a connId qInfo
newSndQueue_ ::
(C.SignatureAlgorithm a, C.AlgorithmI a, MonadUnliftIO m) =>
C.SAlgorithm a ->
ConnId ->
Compatible SMPQueueInfo ->
m SndQueue
newSndQueue_ a (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
newSndQueue_ a connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
-- this function assumes clientVersion is compatible - it was tested before
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
pure
SndQueue
{ server = smpServer,
{ connId,
server = smpServer,
sndId = senderId,
sndPublicKey = Just sndPublicKey,
sndPrivateKey,
e2eDhSecret = C.dh' rcvE2ePubDhKey e2ePrivKey,
e2ePubKey = Just e2ePubKey,
status = New,
dbQueueId = 0,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion
}

View File

@@ -51,6 +51,7 @@ module Simplex.Messaging.Agent.Client
suspendQueue,
deleteQueue,
logServer,
logSecret,
removeSubscription,
hasActiveSubscription,
agentClientStore,
@@ -62,6 +63,7 @@ module Simplex.Messaging.Agent.Client
agentOperationBracket,
waitUntilActive,
throwWhenInactive,
throwWhenNoDelivery,
beginAgentOperation,
endAgentOperation,
suspendSendingAndDatabase,
@@ -90,11 +92,12 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight, partitionEithers)
import Data.Functor (($>))
import Data.List.NonEmpty (NonEmpty)
import Data.List (partition, (\\))
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Maybe (isJust, listToMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text.Encoding
@@ -107,6 +110,8 @@ import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction)
import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues)
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
@@ -140,8 +145,6 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.TMap2 (TMap2)
import qualified Simplex.Messaging.TMap2 as TM2
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
@@ -167,11 +170,11 @@ data AgentClient = AgentClient
ntfClients :: TMap NtfServer NtfClientVar,
useNetworkConfig :: TVar NetworkConfig,
subscrConns :: TVar (Set ConnId),
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId),
smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()),
activeSubs :: TRcvQueues,
pendingSubs :: TRcvQueues,
pendingMsgsQueued :: TMap SndQAddr Bool,
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId),
smpQueueMsgDeliveries :: TMap SndQAddr (Async ()),
connCmdsQueued :: TMap ConnId Bool,
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
asyncCmdProcesses :: TMap (Maybe SMPServer) (Async ()),
@@ -229,9 +232,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
ntfClients <- TM.empty
useNetworkConfig <- newTVar netCfg
subscrConns <- newTVar S.empty
activeSubs <- TM2.empty
pendingSubs <- TM2.empty
connMsgsQueued <- TM.empty
activeSubs <- RQ.empty
pendingSubs <- RQ.empty
pendingMsgsQueued <- TM.empty
smpQueueMsgQueues <- TM.empty
smpQueueMsgDeliveries <- TM.empty
connCmdsQueued <- TM.empty
@@ -249,7 +252,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
reconnections <- newTVar []
asyncClients <- newTVar []
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv}
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, pendingMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv}
agentClientStore :: AgentClient -> SQLiteStore
agentClientStore AgentClient {agentEnv = Env {store}} = store
@@ -282,25 +285,25 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
clientDisconnected u client = do
removeClientAndSubs >>= (`forM_` serverDown)
removeClientAndSubs >>= serverDown
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
where
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
where
updateSubs cVar = do
TM2.insert1 srv cVar $ pendingSubs c
readTVar cVar
qs <- RQ.getDelSrvQueues srv $ activeSubs c
mapM_ (`RQ.addQueue` pendingSubs c) qs
cs <- RQ.getConns (activeSubs c)
-- TODO deduplicate conns
let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs
pure (qs, conns)
serverDown :: Map ConnId RcvQueue -> IO ()
serverDown cs = whenM (readTVarIO active) $ do
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
serverDown (qs, conns) = whenM (readTVarIO active) $ do
notifySub "" $ hostEvent DISCONNECT client
let conns = M.keys cs
unless (null conns) $ do
notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unless (null conns) $ notifySub "" $ DOWN srv conns
unless (null qs) $ do
atomically $ mapM_ (releaseGetLock c) qs
unliftIO u reconnectServer
reconnectServer :: m ()
@@ -317,20 +320,22 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
reconnectClient :: m ()
reconnectClient =
withLockMap_ (reconnectLocks c) srv "reconnect" $
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
>>= mapM_ resubscribe
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
where
resubscribe :: Map ConnId RcvQueue -> m ()
resubscribe :: [RcvQueue] -> m ()
resubscribe qs = do
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
(client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs
cs <- atomically . RQ.getConns $ activeSubs c
(client_, rs) <- subscribeQueues c srv qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
unless connected $ mapM_ (notifySub "" . hostEvent CONNECT) client_
unless (M.null oks) $ do
notifySub "" . UP srv $ M.keys oks
let (tempErrs, finalErrs) = M.partition temporaryAgentError errs
liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs
mapM_ throwError . listToMaybe $ M.elems tempErrs
-- TODO deduplicate okConns
let conns = okConns \\ S.toList cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
mapM_ (throwError . snd) $ listToMaybe tempErrs
notifySub :: ConnId -> ACommand 'Agent -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
@@ -425,10 +430,10 @@ closeAgentClient c = liftIO $ do
cancelActions $ asyncClients c
cancelActions $ smpQueueMsgDeliveries c
cancelActions $ asyncCmdProcesses c
atomically . TM2.clear $ activeSubs c
atomically . TM2.clear $ pendingSubs c
atomically . RQ.clear $ activeSubs c
atomically . RQ.clear $ pendingSubs c
clear subscrConns
clear connMsgsQueued
clear pendingMsgsQueued
clear smpQueueMsgQueues
clear connCmdsQueued
clear asyncCmdQueues
@@ -443,6 +448,14 @@ waitUntilActive c = unlessM (readTVar $ active c) retry
throwWhenInactive :: AgentClient -> STM ()
throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled
throwWhenNoDelivery :: AgentClient -> SndQueue -> STM ()
throwWhenNoDelivery c SndQueue {server, sndId} =
unlessM (isJust <$> TM.lookup k (smpQueueMsgQueues c)) $ do
TM.delete k $ smpQueueMsgDeliveries c
throwSTM ThreadKilled
where
k = (server, sndId)
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
@@ -502,19 +515,20 @@ protocolClientError protocolError_ = \case
e@PCESignatureError {} -> INTERNAL $ show e
e@PCEIOError {} -> INTERNAL $ show e
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c srv vRange =
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c connId srv vRange =
asks (cmdSignAlg . config) >>= \case
C.SignAlg a -> newRcvQueue_ a c srv vRange
C.SignAlg a -> newRcvQueue_ a c connId srv vRange
newRcvQueue_ ::
(C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) =>
C.SAlgorithm a ->
AgentClient ->
ConnId ->
SMPServer ->
VersionRange ->
m (RcvQueue, SMPQueueUri)
newRcvQueue_ a c srv vRange = do
newRcvQueue_ a c connId srv vRange = do
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
@@ -524,36 +538,41 @@ newRcvQueue_ a c srv vRange = do
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
let rq =
RcvQueue
{ server = srv,
{ connId,
server = srv,
rcvId,
rcvPrivateKey,
rcvDhSecret = C.dh' rcvPublicDhKey privDhKey,
e2ePrivKey,
e2eDhSecret = Nothing,
sndId = Just sndId,
sndId,
status = New,
dbQueueId = 0,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion = maxVersion vRange,
clientNtfCreds = Nothing
}
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
TM2.insert server connId rq $ pendingSubs c
RQ.addQueue rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
>>= either throwError pure
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq connId r = do
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq r = do
case r of
Left e ->
atomically . unless (temporaryClientError e) $
TM2.delete connId (pendingSubs c)
_ -> addSubscription c rq connId
RQ.deleteQueue rq (pendingSubs c)
_ -> addSubscription c rq
pure r
temporaryClientError :: ProtocolClientError -> Bool
@@ -569,45 +588,46 @@ temporaryAgentError = \case
_ -> False
-- | subscribe multiple queues - all passed queues should be on the same server
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do
-- TODO check server is correct
modifyTVar (subscrConns c) $ S.insert connId
TM2.insert server connId rq $ pendingSubs c
RQ.addQueue rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
(eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of
Left e -> pure $ map (second . const $ Left e) qs_
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
Left e -> pure $ map (,Left e) qs_
Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
let qs2 = L.map (queueCreds . snd) qs'
rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <-
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
_ -> pure (Nothing, M.fromList errs)
let qs2 = L.map queueCreds qs'
liftIO $ do
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
mapM_ (uncurry $ processSubResult c) rs
pure $ map (second . first $ protocolClientError SMP) rs
_ -> pure (Nothing, errs)
where
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
checkQueue rq@RcvQueue {rcvId, server} = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
addSubscription c rq@RcvQueue {connId} = atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
TM2.insert server connId rq $ activeSubs c
TM2.delete connId $ pendingSubs c
RQ.addQueue rq $ activeSubs c
RQ.deleteQueue rq $ pendingSubs c
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c connId = do
modifyTVar (subscrConns c) $ S.delete connId
TM2.delete connId $ activeSubs c
TM2.delete connId $ pendingSubs c
RQ.deleteConn connId $ activeSubs c
RQ.deleteConn connId $ pendingSubs c
getSubscriptions :: AgentClient -> STM (Set ConnId)
getSubscriptions = readTVar . subscrConns

View File

@@ -96,7 +96,8 @@ data AgentConfig = AgentConfig
certificateFile :: FilePath,
e2eEncryptVRange :: VersionRange,
smpAgentVRange :: VersionRange,
smpClientVRange :: VersionRange
smpClientVRange :: VersionRange,
initialClientId :: Int
}
defaultReconnectInterval :: RetryInterval
@@ -142,7 +143,8 @@ defaultAgentConfig =
certificateFile = "/etc/opt/simplex-agent/agent.crt",
e2eEncryptVRange = supportedE2EEncryptVRange,
smpAgentVRange = supportedSMPAgentVRange,
smpClientVRange = supportedSMPClientVRange
smpClientVRange = supportedSMPClientVRange,
initialClientId = 0
}
data Env = Env
@@ -155,12 +157,12 @@ data Env = Env
}
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
newSMPAgentEnv config@AgentConfig {database, yesToMigrations} = do
newSMPAgentEnv config@AgentConfig {database, yesToMigrations, initialClientId} = do
idsDrg <- newTVarIO =<< drgNew
store <- case database of
AgentDB st -> pure st
AgentDBFile {dbFile, dbKey} -> liftIO $ createAgentStore dbFile dbKey yesToMigrations
clientCounter <- newTVarIO 0
clientCounter <- newTVarIO initialClientId
randomServer <- newTVarIO =<< liftIO newStdGen
ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config
return Env {config, store, idsDrg, clientCounter, randomServer, ntfSupervisor}

View File

@@ -73,7 +73,7 @@ processNtfSub c (connId, cmd) = do
NSCCreate -> do
(a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do
a <- liftIO $ getNtfSubscription db connId
q <- ExceptT $ getRcvQueue db connId
q <- ExceptT $ getPrimaryRcvQueue db connId
pure (a, q)
logInfo $ "processNtfSub, NSCCreate - a = " <> tshow a
case a of
@@ -125,7 +125,7 @@ processNtfSub c (connId, cmd) = do
(Just (NtfSubscription {ntfServer}, _)) -> addNtfNTFWorker ntfServer
_ -> pure () -- err "NSCDelete - no subscription"
NSCSmpDelete -> do
withStore' c (`getRcvQueue` connId) >>= \case
withStore' c (`getPrimaryRcvQueue` connId) >>= \case
Right rq@RcvQueue {server = smpServer} -> do
logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq
ts <- liftIO getCurrentTime
@@ -185,7 +185,7 @@ runNtfWorker c srv doWork = do
NSACreate ->
getNtfToken >>= \case
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
RcvQueue {clientNtfCreds} <- withStore c (`getRcvQueue` connId)
RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId)
case clientNtfCreds of
Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do
nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey
@@ -260,7 +260,7 @@ runNtfSMPWorker c srv doWork = do
NSASmpKey ->
getNtfToken >>= \case
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
rq <- withStore c (`getRcvQueue` connId)
rq <- withStore c (`getPrimaryRcvQueue` connId)
C.SignAlg a <- asks (cmdSignAlg . config)
(ntfPublicKey, ntfPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(rcvNtfPubDhKey, rcvNtfPrivDhKey) <- liftIO C.generateKeyPair'
@@ -275,7 +275,7 @@ runNtfSMPWorker c srv doWork = do
NSASmpDelete -> do
rq_ <- withStore' c $ \db -> do
setRcvQueueNtfCreds db connId Nothing
getRcvQueue db connId
getPrimaryRcvQueue db connId
forM_ rq_ $ \rq -> disableQueueNotifications c rq
withStore' c $ \db -> deleteNtfSubscription db connId

View File

@@ -50,15 +50,19 @@ module Simplex.Messaging.Agent.Protocol
MsgHash,
MsgMeta (..),
ConnectionStats (..),
SwitchPhase (..),
SMPConfirmation (..),
AgentMsgEnvelope (..),
AgentMessage (..),
AgentMessageType (..),
APrivHeader (..),
AMessage (..),
SndQAddr,
SMPServer,
pattern SMPServer,
SrvLoc (..),
SMPQueue (..),
sameQAddress,
SMPQueueUri (..),
SMPQueueInfo (..),
SMPQueueAddress (..),
@@ -162,6 +166,7 @@ import Simplex.Messaging.Protocol
legacyEncodeServer,
legacyServerP,
legacyStrEncodeServer,
sameSrvAddr,
pattern SMPServer,
)
import qualified Simplex.Messaging.Protocol as SMP
@@ -247,12 +252,14 @@ data ACommand (p :: AParty) where
DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent
DOWN :: SMPServer -> [ConnId] -> ACommand Agent
UP :: SMPServer -> [ConnId] -> ACommand Agent
SWITCH :: SwitchPhase -> ConnectionStats -> ACommand Agent
SEND :: MsgFlags -> MsgBody -> ACommand Client
MID :: AgentMsgId -> ACommand Agent
SENT :: AgentMsgId -> ACommand Agent
MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent
MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent
ACK :: AgentMsgId -> ACommand Client
SWCH :: ACommand Client
OFF :: ACommand Client
DEL :: ACommand Client
CHK :: ACommand Client
@@ -284,12 +291,14 @@ data ACommandTag (p :: AParty) where
DISCONNECT_ :: ACommandTag Agent
DOWN_ :: ACommandTag Agent
UP_ :: ACommandTag Agent
SWITCH_ :: ACommandTag Agent
SEND_ :: ACommandTag Client
MID_ :: ACommandTag Agent
SENT_ :: ACommandTag Agent
MERR_ :: ACommandTag Agent
MSG_ :: ACommandTag Agent
ACK_ :: ACommandTag Client
SWCH_ :: ACommandTag Client
OFF_ :: ACommandTag Client
DEL_ :: ACommandTag Client
CHK_ :: ACommandTag Client
@@ -320,12 +329,14 @@ aCommandTag = \case
DISCONNECT {} -> DISCONNECT_
DOWN {} -> DOWN_
UP {} -> UP_
SWITCH {} -> SWITCH_
SEND {} -> SEND_
MID _ -> MID_
SENT _ -> SENT_
MERR {} -> MERR_
MSG {} -> MSG_
ACK _ -> ACK_
SWCH -> SWCH_
OFF -> OFF_
DEL -> DEL_
CHK -> CHK_
@@ -334,6 +345,23 @@ aCommandTag = \case
ERR _ -> ERR_
SUSPENDED -> SUSPENDED_
data SwitchPhase = SPStarted | SPConfirmed | SPTested | SPCompleted
deriving (Eq, Show)
instance StrEncoding SwitchPhase where
strEncode = \case
SPStarted -> "started"
SPConfirmed -> "confirmed"
SPTested -> "tested"
SPCompleted -> "completed"
strP =
A.takeTill (== ' ') >>= \case
"started" -> pure SPStarted
"confirmed" -> pure SPConfirmed
"tested" -> pure SPTested
"completed" -> pure SPCompleted
_ -> fail "bad SwitchPhase"
data ConnectionStats = ConnectionStats
{ rcvServers :: [SMPServer],
sndServers :: [SMPServer]
@@ -508,7 +536,18 @@ instance Encoding AgentMessage where
'M' -> AgentMessage <$> smpP <*> smpP
_ -> fail "bad AgentMessage"
data AgentMessageType = AM_CONN_INFO | AM_CONN_INFO_REPLY | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_
data AgentMessageType
= AM_CONN_INFO
| AM_CONN_INFO_REPLY
| AM_HELLO_
| AM_REPLY_
| AM_A_MSG_
| AM_QADD_
| AM_QKEY_
| AM_QUSE_
| AM_QTEST_
| AM_QDEL_
| AM_QEND_
deriving (Eq, Show)
instance Encoding AgentMessageType where
@@ -518,6 +557,12 @@ instance Encoding AgentMessageType where
AM_HELLO_ -> "H"
AM_REPLY_ -> "R"
AM_A_MSG_ -> "M"
AM_QADD_ -> "QA"
AM_QKEY_ -> "QK"
AM_QUSE_ -> "QU"
AM_QTEST_ -> "QT"
AM_QDEL_ -> "QD"
AM_QEND_ -> "QE"
smpP =
A.anyChar >>= \case
'C' -> pure AM_CONN_INFO
@@ -525,6 +570,15 @@ instance Encoding AgentMessageType where
'H' -> pure AM_HELLO_
'R' -> pure AM_REPLY_
'M' -> pure AM_A_MSG_
'Q' ->
A.anyChar >>= \case
'A' -> pure AM_QADD_
'K' -> pure AM_QKEY_
'U' -> pure AM_QUSE_
'T' -> pure AM_QTEST_
'D' -> pure AM_QDEL_
'E' -> pure AM_QEND_
_ -> fail "bad AgentMessageType"
_ -> fail "bad AgentMessageType"
agentMessageType :: AgentMessage -> AgentMessageType
@@ -540,6 +594,12 @@ agentMessageType = \case
-- REPLY is only used in v1
REPLY _ -> AM_REPLY_
A_MSG _ -> AM_A_MSG_
QADD _ -> AM_QADD_
QKEY _ -> AM_QKEY_
QUSE _ -> AM_QUSE_
QTEST _ -> AM_QTEST_
QDEL _ -> AM_QDEL_
QEND _ -> AM_QEND_
data APrivHeader = APrivHeader
{ -- | sequential ID assigned by the sending agent
@@ -554,7 +614,16 @@ instance Encoding APrivHeader where
smpEncode (sndMsgId, prevMsgHash)
smpP = APrivHeader <$> smpP <*> smpP
data AMsgType = HELLO_ | REPLY_ | A_MSG_
data AMsgType
= HELLO_
| REPLY_
| A_MSG_
| QADD_
| QKEY_
| QUSE_
| QTEST_
| QDEL_
| QEND_
deriving (Eq)
instance Encoding AMsgType where
@@ -562,11 +631,26 @@ instance Encoding AMsgType where
HELLO_ -> "H"
REPLY_ -> "R"
A_MSG_ -> "M"
QADD_ -> "QA"
QKEY_ -> "QK"
QUSE_ -> "QU"
QTEST_ -> "QT"
QDEL_ -> "QD"
QEND_ -> "QE"
smpP =
smpP >>= \case
A.anyChar >>= \case
'H' -> pure HELLO_
'R' -> pure REPLY_
'M' -> pure A_MSG_
'Q' ->
A.anyChar >>= \case
'A' -> pure QADD_
'K' -> pure QKEY_
'U' -> pure QUSE_
'T' -> pure QTEST_
'D' -> pure QDEL_
'E' -> pure QEND_
_ -> fail "bad AMsgType"
_ -> fail "bad AMsgType"
-- | Messages sent between SMP agents once SMP queue is secured.
@@ -579,19 +663,45 @@ data AMessage
REPLY (L.NonEmpty SMPQueueInfo)
| -- | agent envelope for the client message
A_MSG MsgBody
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr))
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
QKEY (L.NonEmpty (SMPQueueInfo, SndPublicVerifyKey))
| -- inform that the queues are ready to use (sent by recipient)
QUSE (L.NonEmpty (SndQAddr, Bool))
| -- sent by the sender to test new queues
QTEST (L.NonEmpty SndQAddr)
| -- inform that the queues will be deleted (sent recipient once message received via the new queue)
QDEL (L.NonEmpty SndQAddr)
| -- sent by sender to confirm that no more messages will be sent to the queue
QEND (L.NonEmpty SndQAddr)
deriving (Show)
type SndQAddr = (SMPServer, SMP.SenderId)
instance Encoding AMessage where
smpEncode = \case
HELLO -> smpEncode HELLO_
REPLY smpQueues -> smpEncode (REPLY_, smpQueues)
A_MSG body -> smpEncode (A_MSG_, Tail body)
QADD qs -> smpEncode (QADD_, qs)
QKEY qs -> smpEncode (QKEY_, qs)
QUSE qs -> smpEncode (QUSE_, qs)
QTEST qs -> smpEncode (QTEST_, qs)
QDEL qs -> smpEncode (QDEL_, qs)
QEND qs -> smpEncode (QEND_, qs)
smpP =
smpP
>>= \case
HELLO_ -> pure HELLO
REPLY_ -> REPLY <$> smpP
A_MSG_ -> A_MSG . unTail <$> smpP
QADD_ -> QADD <$> smpP
QKEY_ -> QKEY <$> smpP
QUSE_ -> QUSE <$> smpP
QTEST_ -> QTEST <$> smpP
QDEL_ -> QDEL <$> smpP
QEND_ -> QEND <$> smpP
instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) where
strEncode = \case
@@ -688,6 +798,11 @@ updateSMPServerHosts srv@ProtocolServer {host} = case host of
_ -> srv
_ -> srv
class SMPQueue q where
qServer :: q -> SMPServer
qAddress :: q -> (SMPServer, SMP.QueueId)
sameQueue :: (SMPServer, SMP.QueueId) -> q -> Bool
data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress}
deriving (Eq, Show)
@@ -728,6 +843,34 @@ data SMPQueueAddress = SMPQueueAddress
}
deriving (Eq, Show)
instance SMPQueue SMPQueueUri where
qServer SMPQueueUri {queueAddress} = qServer queueAddress
{-# INLINE qServer #-}
qAddress SMPQueueUri {queueAddress} = qAddress queueAddress
{-# INLINE qAddress #-}
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
instance SMPQueue SMPQueueInfo where
qServer SMPQueueInfo {queueAddress} = qServer queueAddress
{-# INLINE qServer #-}
qAddress SMPQueueInfo {queueAddress} = qAddress queueAddress
{-# INLINE qAddress #-}
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
instance SMPQueue SMPQueueAddress where
qServer SMPQueueAddress {smpServer} = smpServer
{-# INLINE qServer #-}
qAddress SMPQueueAddress {smpServer, senderId} = (smpServer, senderId)
{-# INLINE qAddress #-}
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
sameQAddress :: (SMPServer, SMP.QueueId) -> (SMPServer, SMP.QueueId) -> Bool
sameQAddress (srv, qId) (srv', qId') = sameSrvAddr srv srv' && qId == qId'
{-# INLINE sameQAddress #-}
instance StrEncoding SMPQueueUri where
strEncode (SMPQueueUri vr SMPQueueAddress {smpServer = srv, senderId = qId, dhPublicKey})
| minVersion vr > 1 = strEncode srv <> "/" <> strEncode qId <> "#/?" <> query queryParams
@@ -754,6 +897,13 @@ instance StrEncoding SMPQueueUri where
hs_ <- queryParam_ "srv" query
pure (vr, maybe [] thList_ hs_, dhKey)
instance Encoding SMPQueueUri where
smpEncode (SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}) =
smpEncode (clientVRange, smpServer, senderId, dhPublicKey)
smpP = do
(clientVRange, smpServer, senderId, dhPublicKey) <- smpP
pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}
data ConnectionRequestUri (m :: ConnectionMode) where
CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
-- contact connection request does NOT contain E2E encryption parameters -
@@ -963,6 +1113,8 @@ data SMPAgentError
A_ENCRYPTION
| -- | duplicate message - this error is detected by ratchet decryption - this message will be ignored and not shown
A_DUPLICATE
| -- | error in the message to add/delete/etc queue in connection
A_QUEUE {queueErr :: String}
deriving (Eq, Generic, Read, Show, Exception)
instance ToJSON SMPAgentError where
@@ -978,6 +1130,7 @@ instance StrEncoding AgentErrorType where
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP)
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
<|> "BROKER " *> (BROKER <$> parseRead1)
<|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString)
<|> "AGENT " *> (AGENT <$> parseRead1)
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
strEncode = \case
@@ -988,6 +1141,7 @@ instance StrEncoding AgentErrorType where
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
BROKER e -> "BROKER " <> bshow e
AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e
AGENT e -> "AGENT " <> bshow e
INTERNAL e -> "INTERNAL " <> bshow e
@@ -1029,12 +1183,14 @@ instance StrEncoding ACmdTag where
"DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_
"DOWN" -> pure $ ACmdTag SAgent DOWN_
"UP" -> pure $ ACmdTag SAgent UP_
"SWITCH" -> pure $ ACmdTag SAgent SWITCH_
"SEND" -> pure $ ACmdTag SClient SEND_
"MID" -> pure $ ACmdTag SAgent MID_
"SENT" -> pure $ ACmdTag SAgent SENT_
"MERR" -> pure $ ACmdTag SAgent MERR_
"MSG" -> pure $ ACmdTag SAgent MSG_
"ACK" -> pure $ ACmdTag SClient ACK_
"SWCH" -> pure $ ACmdTag SClient SWCH_
"OFF" -> pure $ ACmdTag SClient OFF_
"DEL" -> pure $ ACmdTag SClient DEL_
"CHK" -> pure $ ACmdTag SClient CHK_
@@ -1062,12 +1218,14 @@ instance APartyI p => StrEncoding (ACommandTag p) where
DISCONNECT_ -> "DISCONNECT"
DOWN_ -> "DOWN"
UP_ -> "UP"
SWITCH_ -> "SWITCH"
SEND_ -> "SEND"
MID_ -> "MID"
SENT_ -> "SENT"
MERR_ -> "MERR"
MSG_ -> "MSG"
ACK_ -> "ACK"
SWCH_ -> "SWCH"
OFF_ -> "OFF"
DEL_ -> "DEL"
CHK_ -> "CHK"
@@ -1097,6 +1255,7 @@ commandP binaryP =
SUB_ -> pure SUB
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
ACK_ -> s (ACK <$> A.decimal)
SWCH_ -> pure SWCH
OFF_ -> pure OFF
DEL_ -> pure DEL
CHK_ -> pure CHK
@@ -1112,6 +1271,7 @@ commandP binaryP =
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
DOWN_ -> s (DOWN <$> strP_ <*> connections)
UP_ -> s (UP <$> strP_ <*> connections)
SWITCH_ -> s (SWITCH <$> strP_ <*> strP)
MID_ -> s (MID <$> A.decimal)
SENT_ -> s (SENT <$> A.decimal)
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
@@ -1139,36 +1299,40 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
-- | Serialize SMP agent command.
serializeCommand :: ACommand p -> ByteString
serializeCommand = \case
NEW ntfs cMode -> B.unwords ["NEW", strEncode ntfs, strEncode cMode]
INV cReq -> "INV " <> strEncode cReq
JOIN ntfs cReq cInfo -> B.unwords ["JOIN", strEncode ntfs, strEncode cReq, serializeBinary cInfo]
CONF confId srvs cInfo -> B.unwords ["CONF", confId, strEncodeList srvs, serializeBinary cInfo]
LET confId cInfo -> B.unwords ["LET", confId, serializeBinary cInfo]
REQ invId srvs cInfo -> B.unwords ["REQ", invId, strEncode srvs, serializeBinary cInfo]
ACPT invId cInfo -> B.unwords ["ACPT", invId, serializeBinary cInfo]
RJCT invId -> "RJCT " <> invId
INFO cInfo -> "INFO " <> serializeBinary cInfo
SUB -> "SUB"
END -> "END"
CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h]
DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h]
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
MID mId -> "MID " <> bshow mId
SENT mId -> "SENT " <> bshow mId
MERR mId e -> B.unwords ["MERR", bshow mId, strEncode e]
MSG msgMeta msgFlags msgBody -> B.unwords ["MSG", serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody]
ACK mId -> "ACK " <> bshow mId
OFF -> "OFF"
DEL -> "DEL"
CHK -> "CHK"
STAT srvs -> "STAT " <> strEncode srvs
CON -> "CON"
ERR e -> "ERR " <> strEncode e
OK -> "OK"
SUSPENDED -> "SUSPENDED"
NEW ntfs cMode -> s (NEW_, ntfs, cMode)
INV cReq -> s (INV_, cReq)
JOIN ntfs cReq cInfo -> s (JOIN_, ntfs, cReq, Str $ serializeBinary cInfo)
CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo]
LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo]
REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo]
ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo]
RJCT invId -> B.unwords [s RJCT_, invId]
INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo]
SUB -> s SUB_
END -> s END_
CONNECT p h -> s (CONNECT_, p, h)
DISCONNECT p h -> s (DISCONNECT_, p, h)
DOWN srv conns -> B.unwords [s DOWN_, s srv, connections conns]
UP srv conns -> B.unwords [s UP_, s srv, connections conns]
SWITCH phase srvs -> s (SWITCH_, phase, srvs)
SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody]
MID mId -> s (MID_, Str $ bshow mId)
SENT mId -> s (SENT_, Str $ bshow mId)
MERR mId e -> s (MERR_, Str $ bshow mId, e)
MSG msgMeta msgFlags msgBody -> B.unwords [s MSG_, serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody]
ACK mId -> s (ACK_, Str $ bshow mId)
SWCH -> s SWCH_
OFF -> s OFF_
DEL -> s DEL_
CHK -> s CHK_
STAT srvs -> s (STAT_, srvs)
CON -> s CON_
ERR e -> s (ERR_, e)
OK -> s OK_
SUSPENDED -> s SUSPENDED_
where
s :: StrEncoding a => a -> ByteString
s = strEncode
showTs :: UTCTime -> ByteString
showTs = B.pack . formatISO8601Millis
connections :: [ConnId] -> ByteString

View File

@@ -5,6 +5,7 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -17,6 +18,9 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import Data.Kind (Type)
import Data.List (find)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Time (UTCTime)
import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
@@ -43,7 +47,8 @@ import Simplex.Messaging.Version
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
data RcvQueue = RcvQueue
{ server :: SMPServer,
{ connId :: ConnId,
server :: SMPServer,
-- | recipient queue ID
rcvId :: SMP.RecipientId,
-- | key used by the recipient to sign transmissions
@@ -55,9 +60,17 @@ data RcvQueue = RcvQueue
-- | public sender's DH key and agreed shared DH secret for simple per-queue e2e
e2eDhSecret :: Maybe C.DhSecretX25519,
-- | sender queue ID
sndId :: Maybe SMP.SenderId,
sndId :: SMP.SenderId,
-- | queue status
status :: QueueStatus,
-- | database queue ID (within connection), can be Nothing for old queues
dbQueueId :: Int64,
-- | True for a primary queue of the connection
primary :: Bool,
-- | True for the next primary queue
nextPrimary :: Bool,
-- | database queue ID to replace, Nothing if this queue is not replacing another, `Just Nothing` is used for replacing old queues
dbReplaceQueueId :: Maybe Int64,
-- | SMP client version
smpClientVersion :: Version,
-- | credentials used in context of notifications
@@ -78,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
data SndQueue = SndQueue
{ server :: SMPServer,
{ connId :: ConnId,
server :: SMPServer,
-- | sender queue ID
sndId :: SMP.SenderId,
-- | key pair used by the sender to sign transmissions
@@ -90,11 +104,52 @@ data SndQueue = SndQueue
e2eDhSecret :: C.DhSecretX25519,
-- | queue status
status :: QueueStatus,
-- | database queue ID (within connection), can be Nothing for old queues
dbQueueId :: Int64,
-- | True for a primary queue of the connection
primary :: Bool,
-- | True for the next primary queue
nextPrimary :: Bool,
-- | ID of the queue this one is replacing
dbReplaceQueueId :: Maybe Int64,
-- | SMP client version
smpClientVersion :: Version
}
deriving (Eq, Show)
instance SMPQueue RcvQueue where
qServer RcvQueue {server} = server
{-# INLINE qServer #-}
qAddress RcvQueue {server, rcvId} = (server, rcvId)
{-# INLINE qAddress #-}
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
instance SMPQueue SndQueue where
qServer SndQueue {server} = server
{-# INLINE qServer #-}
qAddress SndQueue {server, sndId} = (server, sndId)
{-# INLINE qAddress #-}
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
findQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe q
findQ = find . sameQueue
{-# INLINE findQ #-}
removeQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe (q, [q])
removeQ addr qs = case L.break (sameQueue addr) qs of
(_, []) -> Nothing
(qs1, q : qs2) -> Just (q, qs1 <> qs2)
sndAddress :: RcvQueue -> (SMPServer, SMP.SenderId)
sndAddress RcvQueue {server, sndId} = (server, sndId)
{-# INLINE sndAddress #-}
findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
{-# INLINE findRQ #-}
-- * Connection types
-- | Type of a connection.
@@ -114,13 +169,21 @@ data Connection (d :: ConnType) where
NewConnection :: ConnData -> Connection CNew
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
SndConnection :: ConnData -> SndQueue -> Connection CSnd
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex
ContactConnection :: ConnData -> RcvQueue -> Connection CContact
deriving instance Eq (Connection d)
deriving instance Show (Connection d)
connData :: Connection d -> ConnData
connData = \case
NewConnection cData -> cData
RcvConnection cData _ -> cData
SndConnection cData _ -> cData
DuplexConnection cData _ _ -> cData
ContactConnection cData _ -> cData
data SConnType :: ConnType -> Type where
SCNew :: SConnType CNew
SCRcv :: SConnType CRcv
@@ -193,6 +256,7 @@ instance StrEncoding AgentCommand where
data AgentCommandTag
= AClientCommandTag (ACommandTag 'Client)
| AInternalCommandTag InternalCommandTag
deriving (Show)
instance StrEncoding AgentCommandTag where
strEncode = \case
@@ -208,12 +272,16 @@ data InternalCommand
| ICAckDel SMP.RecipientId MsgId InternalId
| ICAllowSecure SMP.RecipientId SMP.SndPublicVerifyKey
| ICDuplexSecure SMP.RecipientId SMP.SndPublicVerifyKey
| ICQSecure SMP.RecipientId SMP.SndPublicVerifyKey
| ICQDelete SMP.RecipientId
data InternalCommandTag
= ICAck_
| ICAckDel_
| ICAllowSecure_
| ICDuplexSecure_
| ICQSecure_
| ICQDelete_
deriving (Show)
instance StrEncoding InternalCommand where
@@ -222,12 +290,16 @@ instance StrEncoding InternalCommand where
ICAckDel rId srvMsgId mId -> strEncode (ICAckDel_, rId, srvMsgId, mId)
ICAllowSecure rId sndKey -> strEncode (ICAllowSecure_, rId, sndKey)
ICDuplexSecure rId sndKey -> strEncode (ICDuplexSecure_, rId, sndKey)
ICQSecure rId senderKey -> strEncode (ICQSecure_, rId, senderKey)
ICQDelete rId -> strEncode (ICQDelete_, rId)
strP =
strP_ >>= \case
ICAck_ -> ICAck <$> strP_ <*> strP
ICAckDel_ -> ICAckDel <$> strP_ <*> strP_ <*> strP
ICAllowSecure_ -> ICAllowSecure <$> strP_ <*> strP
ICDuplexSecure_ -> ICDuplexSecure <$> strP_ <*> strP
strP >>= \case
ICAck_ -> ICAck <$> _strP <*> _strP
ICAckDel_ -> ICAckDel <$> _strP <*> _strP <*> _strP
ICAllowSecure_ -> ICAllowSecure <$> _strP <*> _strP
ICDuplexSecure_ -> ICDuplexSecure <$> _strP <*> _strP
ICQSecure_ -> ICQSecure <$> _strP <*> _strP
ICQDelete_ -> ICQDelete <$> _strP
instance StrEncoding InternalCommandTag where
strEncode = \case
@@ -235,12 +307,16 @@ instance StrEncoding InternalCommandTag where
ICAckDel_ -> "ACK_DEL"
ICAllowSecure_ -> "ALLOW_SECURE"
ICDuplexSecure_ -> "DUPLEX_SECURE"
ICQSecure_ -> "QSECURE"
ICQDelete_ -> "QDELETE"
strP =
A.takeTill (== ' ') >>= \case
"ACK" -> pure ICAck_
"ACK_DEL" -> pure ICAckDel_
"ALLOW_SECURE" -> pure ICAllowSecure_
"DUPLEX_SECURE" -> pure ICDuplexSecure_
"QSECURE" -> pure ICQSecure_
"QDELETE" -> pure ICQDelete_
_ -> fail "bad InternalCommandTag"
agentCommandTag :: AgentCommand -> AgentCommandTag
@@ -254,6 +330,8 @@ internalCmdTag = \case
ICAckDel {} -> ICAckDel_
ICAllowSecure {} -> ICAllowSecure_
ICDuplexSecure {} -> ICDuplexSecure_
ICQSecure {} -> ICQSecure_
ICQDelete _ -> ICQDelete_
-- * Confirmation types

View File

@@ -15,6 +15,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
@@ -38,9 +39,16 @@ module Simplex.Messaging.Agent.Store.SQLite
deleteConn,
upgradeRcvConnToDuplex,
upgradeSndConnToDuplex,
addConnRcvQueue,
addConnSndQueue,
setRcvQueueStatus,
setRcvQueueConfirmedE2E,
setSndQueueStatus,
setRcvQueuePrimary,
setSndQueuePrimary,
deleteConnRcvQueue,
deleteConnSndQueue,
getPrimaryRcvQueue,
getRcvQueue,
setRcvQueueNtfCreds,
-- Confirmations
@@ -60,11 +68,14 @@ module Simplex.Messaging.Agent.Store.SQLite
createRcvMsg,
updateSndIds,
createSndMsg,
createSndMsgDelivery,
getPendingMsgData,
getPendingMsgs,
deletePendingMsgs,
setMsgUserAck,
getLastMsg,
deleteMsg,
deleteSndMsgDelivery,
-- Double ratchet persistence
createRatchetX3dhKeys,
getRatchetX3dhKeys,
@@ -120,10 +131,12 @@ import Data.Function (on)
import Data.Functor (($>))
import Data.IORef
import Data.Int (Int64)
import Data.List (foldl', groupBy)
import Data.List (foldl', groupBy, sortBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, listToMaybe)
import Data.Ord (Down (..))
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
@@ -134,6 +147,7 @@ import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite.Simple.ToField (ToField (..))
import qualified Database.SQLite3 as SQLite3
import Network.Socket (ServiceName)
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration)
@@ -295,45 +309,39 @@ createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
createConn_ gVar cData $ \connId -> do
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
updateNewConnRcv db connId rq@RcvQueue {server} =
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
updateNewConnRcv db connId rq =
getConn db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
where
updateConn :: IO (Either StoreError ())
updateConn = do
upsertServer_ db server
insertRcvQueue_ db connId rq
pure $ Right ()
updateConn :: IO (Either StoreError Int64)
updateConn = Right <$> addConnRcvQueue_ db connId rq
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
updateNewConnSnd db connId sq@SndQueue {server} =
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
updateNewConnSnd db connId sq =
getConn db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn _ SndConnection {}) -> updateConn -- to allow retries
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
where
updateConn :: IO (Either StoreError ())
updateConn = do
upsertServer_ db server
insertSndQueue_ db connId sq
pure $ Right ()
updateConn :: IO (Either StoreError Int64)
updateConn = Right <$> addConnSndQueue_ db connId sq
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
insertRcvQueue_ db connId q
void $ insertRcvQueue_ db connId q
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
createConn_ gVar cData $ \connId -> do
upsertServer_ db server
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
insertSndQueue_ db connId q
void $ insertSndQueue_ db connId q
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError SomeConn)
getRcvConn db ProtocolServer {host, port} rcvId =
@@ -356,25 +364,43 @@ deleteConn db connId =
"DELETE FROM connections WHERE conn_id = :conn_id;"
[":conn_id" := connId]
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
upgradeRcvConnToDuplex db connId sq@SndQueue {server} =
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
upgradeRcvConnToDuplex db connId sq =
getConn db connId $>>= \case
(SomeConn _ RcvConnection {}) -> do
upsertServer_ db server
insertSndQueue_ db connId sq
pure $ Right ()
(SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
upgradeSndConnToDuplex db connId rq@RcvQueue {server} =
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
upgradeSndConnToDuplex db connId rq =
getConn db connId >>= \case
Right (SomeConn _ SndConnection {}) -> do
upsertServer_ db server
insertRcvQueue_ db connId rq
pure $ Right ()
Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
addConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
addConnRcvQueue db connId rq =
getConn db connId >>= \case
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
addConnRcvQueue_ db connId rq@RcvQueue {server} = do
upsertServer_ db server
insertRcvQueue_ db connId rq
addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
addConnSndQueue db connId sq =
getConn db connId >>= \case
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
addConnSndQueue_ db connId sq@SndQueue {server} = do
upsertServer_ db server
insertSndQueue_ db connId sq
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
-- ? return error if queue does not exist?
@@ -418,9 +444,39 @@ setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} stat
|]
[":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId]
getRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue)
getRcvQueue db connId =
maybe (Left SEConnNotFound) Right <$> getRcvQueueByConnId_ db connId
setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO ()
setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do
DB.execute db "UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ? WHERE conn_id = ?" (False, False, connId)
DB.execute
db
"UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?"
(True, False, Nothing :: Maybe Int64, connId, dbQueueId)
setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO ()
setSndQueuePrimary db connId SndQueue {dbQueueId} = do
DB.execute db "UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ? WHERE conn_id = ?" (False, False, connId)
DB.execute
db
"UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?"
(True, False, Nothing :: Maybe Int64, connId, dbQueueId)
deleteConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO ()
deleteConnRcvQueue db connId RcvQueue {dbQueueId} =
DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO ()
deleteConnSndQueue db connId SndQueue {dbQueueId} = do
DB.execute db "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue)
getPrimaryRcvQueue db connId =
maybe (Left SEConnNotFound) (Right . L.head) <$> getRcvQueuesByConnId_ db connId
getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue)
getRcvQueue db connId (SMPServer host port _) rcvId =
firstRow (toRcvQueue connId) SEConnNotFound $
DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ?") (connId, host, port, rcvId)
setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO ()
setRcvQueueNtfCreds db connId clientNtfCreds =
@@ -587,10 +643,10 @@ updateRcvIds db connId = do
updateLastIdsRcv_ db connId internalId internalRcvId
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
createRcvMsg :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
createRcvMsg db connId rcvMsgData = do
createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
createRcvMsg db connId rq rcvMsgData = do
insertRcvMsgBase_ db connId rcvMsgData
insertRcvMsgDetails_ db connId rcvMsgData
insertRcvMsgDetails_ db connId rq rcvMsgData
updateHashRcv_ db connId rcvMsgData
updateSndIds :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
@@ -607,9 +663,13 @@ createSndMsg db connId sndMsgData = do
insertSndMsgDetails_ db connId sndMsgData
updateHashSnd_ db connId sndMsgData
createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO ()
createSndMsgDelivery db connId SndQueue {dbQueueId} msgId =
DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId)
getPendingMsgData :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData))
getPendingMsgData db connId msgId = do
rq_ <- getRcvQueueByConnId_ db connId
rq_ <- L.head <$$> getRcvQueuesByConnId_ db connId
(rq_,) <$$> firstRow pendingMsgData SEMsgNotFound getMsgData_
where
getMsgData_ =
@@ -627,16 +687,23 @@ getPendingMsgData db connId msgId = do
let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_
in PendingMsgData {msgId, msgType, msgFlags, msgBody, internalTs}
getPendingMsgs :: DB.Connection -> ConnId -> IO [InternalId]
getPendingMsgs db connId =
getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId]
getPendingMsgs db connId SndQueue {dbQueueId} =
map fromOnly
<$> DB.query db "SELECT internal_id FROM snd_messages WHERE conn_id = ?" (Only connId)
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError SMP.MsgId)
setMsgUserAck db connId agentMsgId = do
DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
firstRow fromOnly SEMsgNotFound $
DB.query db "SELECT broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
deletePendingMsgs db connId SndQueue {dbQueueId} =
DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId))
setMsgUserAck db connId agentMsgId = runExceptT $ do
liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
(dbRcvId, srvMsgId) <-
ExceptT . firstRow id SEMsgNotFound $
DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
rq <- ExceptT $ getRcvQueueById_ db connId dbRcvId
pure (rq, srvMsgId)
getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg)
getLastMsg db connId msgId =
@@ -662,6 +729,15 @@ deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
deleteMsg db connId msgId =
DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId)
deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO ()
deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId = do
DB.execute
db
"DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?"
(connId, dbQueueId, msgId)
(Only (cnt :: Int) : _) <- DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
when (cnt == 0) $ deleteMsg db connId msgId
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2)
@@ -1202,27 +1278,35 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
insertRcvQueue_ dbConn connId RcvQueue {..} = do
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
insertRcvQueue_ db connId' RcvQueue {..} = do
qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
DB.execute
dbConn
db
[sql|
INSERT INTO rcv_queues
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?);
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, next_rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
((host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion))
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion))
pure qId
-- * createSndConn helpers
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
insertSndQueue_ dbConn connId SndQueue {..} = do
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
insertSndQueue_ db connId' SndQueue {..} = do
qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
DB.execute
dbConn
db
[sql|
INSERT INTO snd_queues
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?);
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, next_snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|]
(host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion)
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion))
pure qId
newQueueId_ :: [Only Int64] -> Int64
newQueueId_ [] = 1
newQueueId_ (Only maxId : _) = maxId + 1
-- * getConn helpers
@@ -1230,65 +1314,79 @@ getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn dbConn connId =
getConnData dbConn connId >>= \case
Nothing -> pure $ Left SEConnNotFound
Just (connData, cMode) -> do
rQ <- getRcvQueueByConnId_ dbConn connId
sQ <- getSndQueueByConnId_ dbConn connId
Just (cData, cMode) -> do
rQ <- getRcvQueuesByConnId_ dbConn connId
sQ <- getSndQueuesByConnId_ dbConn connId
pure $ case (rQ, sQ, cMode) of
(Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection connData)
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData dbConn connId' =
connData
<$> DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
where
connData [(connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake)] = Just (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
connData _ = Nothing
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
getRcvQueueByConnId_ dbConn connId =
listToMaybe . map rcvQueue
-- | returns all connection queues, the first queue is the primary one
getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue))
getRcvQueuesByConnId_ db connId =
L.nonEmpty . sortBy primaryFirst . map (toRcvQueue connId)
<$> DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ?") (Only connId)
where
primaryFirst RcvQueue {primary = p} RcvQueue {primary = p'} = compare (Down p) (Down p')
rcvQueueQuery :: Query
rcvQueueQuery =
[sql|
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
q.rcv_queue_id, q.rcv_primary, q.next_rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|]
toRcvQueue ::
ConnId ->
(C.KeyHash, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (Int64, Bool, Bool, Maybe Int64, Maybe Version)
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
RcvQueue
toRcvQueue connId ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
getRcvQueueById_ db connId dbRcvId =
firstRow (toRcvQueue connId) SEConnNotFound $
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
-- | returns all connection queues, the first queue is the primary one
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
getSndQueuesByConnId_ dbConn connId =
L.nonEmpty . sortBy primaryFirst . map sndQueue
<$> DB.query
dbConn
[sql|
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status, q.smp_client_version,
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_id = ?;
|]
(Only connId)
where
rcvQueue ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion, clientNtfCreds}
getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
getSndQueueByConnId_ dbConn connId =
sndQueue
<$> DB.query
dbConn
[sql|
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.smp_client_version
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.next_snd_primary, q.replace_snd_queue_id, q.smp_client_version
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_id = ?;
|]
(Only connId)
where
sndQueue [(keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion)] =
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion)) =
let server = SMPServer host port keyHash
in Just SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion}
sndQueue _ = Nothing
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion}
primaryFirst SndQueue {primary = p} SndQueue {primary = p'} = compare (Down p) (Down p')
-- * updateRcvIds helpers
@@ -1342,22 +1440,23 @@ insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody,
":msg_body" := msgBody
]
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connId RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do
let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta
DB.executeNamed
dbConn
[sql|
INSERT INTO rcv_messages
( conn_id, internal_rcv_id, internal_id, external_snd_id,
( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id,
broker_id, broker_ts,
internal_hash, external_prev_snd_hash, integrity)
VALUES
(:conn_id,:internal_rcv_id,:internal_id,:external_snd_id,
(:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id,
:broker_id,:broker_ts,
:internal_hash,:external_prev_snd_hash,:integrity);
|]
[ ":conn_id" := connId,
":rcv_queue_id" := dbQueueId,
":internal_rcv_id" := internalRcvId,
":internal_id" := fst recipient,
":external_snd_id" := sndMsgId,

View File

@@ -36,6 +36,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -51,7 +52,8 @@ schemaMigrations =
("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode),
("m20220811_onion_hosts", m20220811_onion_hosts),
("m20220817_connection_ntfs", m20220817_connection_ntfs),
("m20220905_commands", m20220905_commands)
("m20220905_commands", m20220905_commands),
("m20220915_connection_queues", m20220915_connection_queues)
]
-- | The list of migrations in ascending order by date

View File

@@ -0,0 +1,54 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20220915_connection_queues :: Query
m20220915_connection_queues =
[sql|
PRAGMA ignore_check_constraints=ON;
-- rcv_queues
ALTER TABLE rcv_queues ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL);
UPDATE rcv_queues SET rcv_queue_id = 1;
CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues (conn_id, rcv_queue_id);
ALTER TABLE rcv_queues ADD COLUMN rcv_primary INTEGER CHECK (rcv_primary NOT NULL);
UPDATE rcv_queues SET rcv_primary = 1;
ALTER TABLE rcv_queues ADD COLUMN next_rcv_primary INTEGER CHECK (next_rcv_primary NOT NULL);
UPDATE rcv_queues SET next_rcv_primary = 0;
ALTER TABLE rcv_queues ADD COLUMN replace_rcv_queue_id INTEGER NULL;
-- snd_queues
ALTER TABLE snd_queues ADD COLUMN snd_queue_id INTEGER CHECK (snd_queue_id NOT NULL);
UPDATE snd_queues SET snd_queue_id = 1;
CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues (conn_id, snd_queue_id);
ALTER TABLE snd_queues ADD COLUMN snd_primary INTEGER CHECK (snd_primary NOT NULL);
UPDATE snd_queues SET snd_primary = 1;
ALTER TABLE snd_queues ADD COLUMN next_snd_primary INTEGER CHECK (next_snd_primary NOT NULL);
UPDATE snd_queues SET next_snd_primary = 0;
ALTER TABLE snd_queues ADD COLUMN replace_snd_queue_id INTEGER NULL;
-- messages
CREATE TABLE snd_message_deliveries (
snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT,
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
snd_queue_id INTEGER NOT NULL,
internal_id INTEGER NOT NULL,
FOREIGN KEY (conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED
);
CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries (conn_id, snd_queue_id);
ALTER TABLE rcv_messages ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL);
UPDATE rcv_messages SET rcv_queue_id = 1;
PRAGMA ignore_check_constraints=OFF;
|]

View File

@@ -41,6 +41,10 @@ CREATE TABLE rcv_queues(
ntf_private_key BLOB,
ntf_id BLOB,
rcv_ntf_dh_secret BLOB,
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
next_rcv_primary INTEGER CHECK(next_rcv_primary NOT NULL),
replace_rcv_queue_id INTEGER NULL,
PRIMARY KEY(host, port, rcv_id),
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE,
@@ -58,6 +62,10 @@ CREATE TABLE snd_queues(
smp_client_version INTEGER NOT NULL DEFAULT 1,
snd_public_key BLOB,
e2e_pub_key BLOB,
snd_queue_id INTEGER CHECK(snd_queue_id NOT NULL),
snd_primary INTEGER CHECK(snd_primary NOT NULL),
next_snd_primary INTEGER CHECK(next_snd_primary NOT NULL),
replace_snd_queue_id INTEGER NULL,
PRIMARY KEY(host, port, snd_id),
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
@@ -89,6 +97,7 @@ CREATE TABLE rcv_messages(
external_prev_snd_hash BLOB NOT NULL,
integrity BLOB NOT NULL,
user_ack INTEGER NULL DEFAULT 0,
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
PRIMARY KEY(conn_id, internal_rcv_id),
FOREIGN KEY(conn_id, internal_id) REFERENCES messages
ON DELETE CASCADE
@@ -206,3 +215,17 @@ CREATE TABLE commands(
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
);
CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id);
CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id);
CREATE TABLE snd_message_deliveries(
snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT,
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
snd_queue_id INTEGER NOT NULL,
internal_id INTEGER NOT NULL,
FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED
);
CREATE TABLE sqlite_sequence(name,seq);
CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
conn_id,
snd_queue_id
);

View File

@@ -0,0 +1,50 @@
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.Agent.TRcvQueues where
import Control.Concurrent.STM
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Simplex.Messaging.Agent.Protocol (ConnId)
import Simplex.Messaging.Agent.Store (RcvQueue (..))
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
newtype TRcvQueues = TRcvQueues (TMap (SMPServer, RecipientId) RcvQueue)
empty :: STM TRcvQueues
empty = TRcvQueues <$> TM.empty
clear :: TRcvQueues -> STM ()
clear (TRcvQueues qs) = TM.clear qs
deleteConn :: ConnId -> TRcvQueues -> STM ()
deleteConn cId (TRcvQueues qs) = modifyTVar' qs $ M.filter (\rq -> cId /= connId rq)
-- TODO possibly, should be limited to active queues
hasConn :: ConnId -> TRcvQueues -> STM Bool
hasConn cId (TRcvQueues qs) = any (\rq -> cId == connId rq) <$> readTVar qs
-- TODO possibly, should be limited to active queues
getConns :: TRcvQueues -> STM (Set ConnId)
getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs
addQueue :: RcvQueue -> TRcvQueues -> STM ()
addQueue rq@RcvQueue {server, rcvId} (TRcvQueues qs) = TM.insert (server, rcvId) rq qs
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
deleteQueue RcvQueue {server, rcvId} (TRcvQueues qs) = TM.delete (server, rcvId) qs
getSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
where
addQ (removed, qs') rq@RcvQueue {server, rcvId}
| srv == server = (rq : removed, qs')
| otherwise = (removed, M.insert (server, rcvId) rq qs')

View File

@@ -6,6 +6,7 @@ module Simplex.Messaging.Encoding.String
StrEncoding (..),
Str (..),
strP_,
_strP,
strToJSON,
strToJEncoding,
strParseJSON,
@@ -171,6 +172,9 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin
strP_ :: StrEncoding a => Parser a
strP_ = strP <* A.space
_strP :: StrEncoding a => Parser a
_strP = A.space *> strP
strToJSON :: StrEncoding a => a -> J.Value
strToJSON = J.String . decodeLatin1 . strEncode

View File

@@ -115,6 +115,7 @@ module Simplex.Messaging.Protocol
legacyEncodeServer,
legacyServerP,
legacyStrEncodeServer,
sameSrvAddr,
-- * TCP transport functions
tPut,
@@ -572,6 +573,10 @@ pattern NtfServer host port keyHash = ProtocolServer SPNTF host port keyHash
{-# COMPLETE NtfServer #-}
sameSrvAddr :: ProtocolServer p -> ProtocolServer p -> Bool
sameSrvAddr ProtocolServer {host, port} ProtocolServer {host = h', port = p'} = host == h' && port == p'
{-# INLINE sameSrvAddr #-}
data ProtocolType = PSMP | PNTF
deriving (Eq, Ord, Show)

View File

@@ -1,82 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.TMap2
( TMap2,
empty,
clear,
Simplex.Messaging.TMap2.lookup,
lookup1,
member,
insert,
insert1,
delete,
lookupDelete1,
)
where
import Control.Concurrent.STM
import Control.Monad (forM_, (>=>))
import qualified Data.Map.Strict as M
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (whenM, ($>>=))
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
data TMap2 k1 k2 a = TMap2
{ _m1 :: TMap k1 (TMap k2 a),
_m2 :: TMap k2 k1
}
empty :: STM (TMap2 k1 k2 a)
empty = TMap2 <$> TM.empty <*> TM.empty
clear :: TMap2 k1 k2 a -> STM ()
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
lookup k2 TMap2 {_m1, _m2} = do
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
{-# INLINE lookup1 #-}
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
member k2 TMap2 {_m2} = TM.member k2 _m2
{-# INLINE member #-}
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
insert k1 k2 v TMap2 {_m1, _m2} =
TM.lookup k2 _m2 >>= \case
Just k1'
| k1 == k1' -> _insert1
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
_ -> _insert2
where
_insert1 =
TM.lookup k1 _m1 >>= \case
Just m -> TM.insert k2 v m
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
_insert2 = TM.insert k2 k1 _m2 >> _insert1
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
insert1 k1 m' TMap2 {_m1, _m2} =
TM.lookup k1 _m1 >>= \case
Just m -> readTVar m' >>= (`TM.union` m)
_ -> TM.insert k1 m' _m1
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
_delete1 k1 k2 m1 =
TM.lookup k1 m1
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookupDelete1 k1 TMap2 {_m1, _m2} = do
m_ <- TM.lookupDelete k1 _m1
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
pure m_

View File

@@ -21,14 +21,15 @@ import Control.Concurrent (killThread, threadDelay)
import Control.Monad
import Control.Monad.Except (ExceptT, runExceptT)
import Control.Monad.IO.Unlift
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import SMPAgentClient
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers)
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Client (ProtocolClientConfig (..))
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
@@ -115,6 +116,15 @@ functionalAPITests t = do
testAsyncCommandsRestore t
it "should accept connection using async command" $
withSmpServer t testAcceptContactAsync
describe "Queue rotation" $ do
it "should switch delivery to the new queue (1 server)" $
withSmpServer t $ testSwitchConnection initAgentServers
it "should switch delivery to the new queue (2 servers)" $
withSmpServer t . withSmpServerOn t testPort2 $ testSwitchConnection initAgentServers2
it "should switch to new queue asynchronously (1 server)" $
withSmpServer t $ testSwitchAsync initAgentServers
it "should switch to new queue asynchronously (2 servers)" $
withSmpServer t . withSmpServerOn t testPort2 $ testSwitchAsync initAgentServers2
testMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
testMatrix2 t runTest = do
@@ -638,6 +648,76 @@ testAcceptContactAsync = do
baseId = 3
msgId = subtract baseId
testSwitchConnection :: InitialAgentServers -> IO ()
testSwitchConnection servers = do
a <- getSMPAgentClient agentCfg servers
b <- getSMPAgentClient agentCfg {database = testDB2} servers
Right () <- runExceptT $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
switchConnectionAsync a "" bId
phase a bId SPStarted
phase b aId SPStarted
phase a bId SPConfirmed
phase b aId SPConfirmed
phase a bId SPTested
phase b aId SPTested
phase b aId SPCompleted
phase a bId SPCompleted
exchangeGreetingsMsgId 12 a bId b aId
pure ()
phase :: AgentClient -> ByteString -> SwitchPhase -> ExceptT AgentErrorType IO ()
phase c connId p =
get c >>= \(_, connId', msg) -> do
liftIO $ connId `shouldBe` connId'
case msg of
SWITCH p' _ -> liftIO $ p `shouldBe` p'
ERR (AGENT A_DUPLICATE) -> phase c connId p
r -> do
liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r
SWITCH _ _ <- pure r
pure ()
testSwitchAsync :: InitialAgentServers -> IO ()
testSwitchAsync servers = do
Right (aId, bId) <- withA $ \a -> withB $ \b -> runExceptT $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
pure (aId, bId)
let phaseA = withA . phase' bId
phaseB = withB . phase' aId
Right () <- withA $ \a -> runExceptT $ do
subscribeConnection a bId
switchConnectionAsync a "" bId
phase a bId SPStarted
liftIO $ threadDelay 500000
phaseB SPStarted
phaseA SPConfirmed
phaseB SPConfirmed
phaseA SPTested
Right () <- withB $ \b -> runExceptT $ do
subscribeConnection b aId
phase b aId SPTested
phase b aId SPCompleted
phaseA SPCompleted
Right () <- withA $ \a -> withB $ \b -> runExceptT $ do
subscribeConnection a bId
subscribeConnection b aId
exchangeGreetingsMsgId 12 a bId b aId
pure ()
where
withAgent :: AgentConfig -> (AgentClient -> IO a) -> IO a
withAgent cfg' = bracket (getSMPAgentClient cfg' servers) disconnectAgentClient
phase' connId p c = do
Right () <- runExceptT $ do
subscribeConnection c connId
phase c connId p
liftIO $ threadDelay 500000
pure ()
withA = withAgent agentCfg
withB = withAgent agentCfg {database = testDB2, initialClientId = 1}
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4

View File

@@ -2,6 +2,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
@@ -114,7 +115,7 @@ testConcurrentWrites =
runTest st connId = replicateM_ 100 . withTransaction st $ \db -> do
(internalId, internalRcvId, _, _) <- updateRcvIds db connId
let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy"
createRcvMsg db connId rcvMsgData
createRcvMsg db connId rcvQueue1 rcvMsgData
testCompiledThreadsafe :: SpecWith SQLiteStore
testCompiledThreadsafe =
@@ -153,14 +154,19 @@ testDhSecret = "01234567890123456789012345678901"
rcvQueue1 :: RcvQueue
rcvQueue1 =
RcvQueue
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
{ connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "1234",
rcvPrivateKey = testPrivateSignKey,
rcvDhSecret = testDhSecret,
e2ePrivKey = testPrivDhKey,
e2eDhSecret = Nothing,
sndId = Just "2345",
sndId = "2345",
status = New,
dbQueueId = 1,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion = 1,
clientNtfCreds = Nothing
}
@@ -168,13 +174,18 @@ rcvQueue1 =
sndQueue1 :: SndQueue
sndQueue1 =
SndQueue
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
{ connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "3456",
sndPublicKey = Nothing,
sndPrivateKey = testPrivateSignKey,
e2ePubKey = Nothing,
e2eDhSecret = testDhSecret,
status = New,
dbQueueId = 1,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion = 1
}
@@ -187,21 +198,23 @@ testCreateRcvConn =
getConn db "conn1"
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
upgradeRcvConnToDuplex db "conn1" sndQueue1
`shouldReturn` Right ()
`shouldReturn` Right 1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
testCreateRcvConnRandomId :: SpecWith SQLiteStore
testCreateRcvConnRandomId =
it "should create RcvConnection and add SndQueue with random ID" . withStoreTransaction $ \db -> do
g <- newTVarIO =<< drgNew
Right connId <- createRcvConn db g cData1 {connId = ""} rcvQueue1 SCMInvitation
let rq' = (rcvQueue1 :: RcvQueue) {connId}
sq' = (sndQueue1 :: SndQueue) {connId}
getConn db connId
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1))
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rq'))
upgradeRcvConnToDuplex db connId sndQueue1
`shouldReturn` Right ()
`shouldReturn` Right 1
getConn db connId
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq']))
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
testCreateRcvConnDuplicate =
@@ -220,21 +233,23 @@ testCreateSndConn =
getConn db "conn1"
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1))
upgradeSndConnToDuplex db "conn1" rcvQueue1
`shouldReturn` Right ()
`shouldReturn` Right 1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
testCreateSndConnRandomID :: SpecWith SQLiteStore
testCreateSndConnRandomID =
it "should create SndConnection and add RcvQueue with random ID" . withStoreTransaction $ \db -> do
g <- newTVarIO =<< drgNew
Right connId <- createSndConn db g cData1 {connId = ""} sndQueue1
let rq' = (rcvQueue1 :: RcvQueue) {connId}
sq' = (sndQueue1 :: SndQueue) {connId}
getConn db connId
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1))
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sq'))
upgradeSndConnToDuplex db connId rcvQueue1
`shouldReturn` Right ()
`shouldReturn` Right 1
getConn db connId
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq']))
testCreateSndConnDuplicate :: SpecWith SQLiteStore
testCreateSndConnDuplicate =
@@ -287,12 +302,12 @@ testDeleteDuplexConn =
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
deleteConn db "conn1"
`shouldReturn` ()
-- TODO check queues are deleted as well
getConn db "conn1"
`shouldReturn` Left (SEConnNotFound)
`shouldReturn` Left SEConnNotFound
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
testUpgradeRcvConnToDuplex =
@@ -301,13 +316,18 @@ testUpgradeRcvConnToDuplex =
_ <- createSndConn db g cData1 sndQueue1
let anotherSndQueue =
SndQueue
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
{ connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
sndId = "2345",
sndPublicKey = Nothing,
sndPrivateKey = testPrivateSignKey,
e2ePubKey = Nothing,
e2eDhSecret = testDhSecret,
status = New,
dbQueueId = 1,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion = 1
}
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
@@ -323,14 +343,19 @@ testUpgradeSndConnToDuplex =
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
let anotherRcvQueue =
RcvQueue
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
{ connId = "conn1",
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
rcvId = "3456",
rcvPrivateKey = testPrivateSignKey,
rcvDhSecret = testDhSecret,
e2ePrivKey = testPrivDhKey,
e2eDhSecret = Nothing,
sndId = Just "4567",
sndId = "4567",
status = New,
dbQueueId = 1,
primary = True,
nextPrimary = False,
dbReplaceQueueId = Nothing,
smpClientVersion = 1,
clientNtfCreds = Nothing
}
@@ -371,15 +396,17 @@ testSetQueueStatusDuplex =
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
setRcvQueueStatus db rcvQueue1 Secured
`shouldReturn` ()
let rq' = (rcvQueue1 :: RcvQueue) {status = Secured}
sq' = (sndQueue1 :: SndQueue) {status = Confirmed}
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sndQueue1]))
setSndQueueStatus db sndQueue1 Confirmed
`shouldReturn` ()
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}))
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sq']))
hw :: ByteString
hw = encodeUtf8 "Hello world!"
@@ -405,12 +432,12 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
externalPrevSndHash = "hash_from_sender"
}
testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvQueue -> RcvMsgData -> Expectation
testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rq rcvMsgData@RcvMsgData {..} = do
let MsgMeta {recipient = (internalId, _)} = msgMeta
updateRcvIds db connId
`shouldReturn` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
createRcvMsg db connId rcvMsgData
createRcvMsg db connId rq rcvMsgData
`shouldReturn` ()
testCreateRcvMsg :: SpecWith SQLiteStore
@@ -421,8 +448,8 @@ testCreateRcvMsg =
_ <- withTransaction st $ \db -> do
createRcvConn db g cData1 rcvQueue1 SCMInvitation
withTransaction st $ \db -> do
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
testCreateRcvMsg_ db 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
testCreateRcvMsg_ db 1 "hash_dummy" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
mkSndMsgData internalId internalSndId internalHash =
@@ -464,9 +491,9 @@ testCreateRcvAndSndMsgs =
createRcvConn db g cData1 rcvQueue1 SCMInvitation
withTransaction st $ \db -> do
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
testCreateRcvMsg_ db 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
testCreateRcvMsg_ db 1 "rcv_hash_1" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
testCreateSndMsg_ db "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
testCreateRcvMsg_ db 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
testCreateRcvMsg_ db 2 "rcv_hash_2" connId rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
testCreateSndMsg_ db "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
testCreateSndMsg_ db "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"

View File

@@ -1,9 +1,9 @@
module CoreTests.ProtocolErrorTests where
import Simplex.Messaging.Agent.Protocol (AgentErrorType)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol (ErrorType)
import Simplex.Messaging.Encoding.String
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck

View File

@@ -1,6 +1,7 @@
{-# LANGUAGE TypeApplications #-}
import AgentTests (agentTests)
-- import Control.Logger.Simple
import CoreTests.CryptoTests
import CoreTests.EncodingTests
import CoreTests.ProtocolErrorTests
@@ -13,8 +14,13 @@ import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
import System.Environment (setEnv)
import Test.Hspec
-- logCfg :: LogConfig
-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
main :: IO ()
main = do
-- setLogLevel LogInfo -- LogError
-- withGlobalLogging logCfg $ do
createDirectoryIfMissing False "tests/tmp"
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"