mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 14:16:00 +00:00
connection queue redundancy and rotation (#521)
* rfc: queue rotation * update rfc * messages for queue rotation * allow multiple subscribed queues per connection in Agent/Client.hs * refactor * fix module name * allow multiple queues in duplex connection type * update commands * add indices * addConnectionRcvQueue * switch connection to another queue (WIP) * update schema/protocol * switching queue works, but sending messages after the switch fails * messages are delivered after rotation * use connection-scoped queue ID * rename queue records fields * refactor using SMPQueue class/instances * simplify queries * QKEY: check queue is not secured, refactor * update rfc * mark queue as primary in QUSE * queue rotation errors * fix async ack * fix async ACK to send OK * correction Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> * use SWCH command * rename * take into account only active queue subscription when determining connection result if at least one queue is active * remove comment * only enable notifications for connections with enableNtfs = True * async test (WIP) * async queue rotation test * simplify combining results * test with 2 servers * fix unused subscribeConnection * switch to cabal build * increase build timeout * increase delay in async test * skip queue rotation tests * build matrix * step name * use ubuntu-18.04 in build matrix * enable rotation tests Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
19aef52135
commit
eb5c1c78cb
57
.github/workflows/build.yml
vendored
57
.github/workflows/build.yml
vendored
@@ -11,35 +11,51 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-20.04
|
||||
name: build-${{ matrix.os }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-18.04
|
||||
- os: ubuntu-20.04
|
||||
steps:
|
||||
- name: Clone project
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Setup Stack
|
||||
- name: Setup Haskell
|
||||
uses: haskell/actions/setup@v1
|
||||
with:
|
||||
ghc-version: "8.10.7"
|
||||
enable-stack: true
|
||||
stack-version: "latest"
|
||||
cabal-version: "latest"
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.stack
|
||||
key: ${{ hashFiles('stack.yaml') }}
|
||||
path: |
|
||||
~/.cabal/store
|
||||
dist-newstyle
|
||||
key: ${{ matrix.os }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }}
|
||||
|
||||
- name: Build & test
|
||||
id: build_test
|
||||
- name: Build
|
||||
shell: bash
|
||||
run: cabal build --enable-tests
|
||||
|
||||
- name: Test
|
||||
if: matrix.os == 'ubuntu-18.04'
|
||||
timeout-minutes: 30
|
||||
shell: bash
|
||||
run: cabal test --test-show-details=direct
|
||||
|
||||
- name: Prepare binaries
|
||||
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
|
||||
shell: bash
|
||||
run: |
|
||||
stack build --test --force-dirty
|
||||
install_root=$(stack path --local-install-root)
|
||||
mv ${install_root}/bin/smp-server smp-server-ubuntu-20_04-x86-64
|
||||
mv ${install_root}/bin/ntf-server ntf-server-ubuntu-20_04-x86-64
|
||||
mv $(cabal list-bin smp-server) smp-server-ubuntu-20_04-x86-64
|
||||
mv $(cabal list-bin ntf-server) ntf-server-ubuntu-20_04-x86-64
|
||||
|
||||
- name: Build changelog
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
|
||||
id: build_changelog
|
||||
uses: mikepenz/release-changelog-builder-action@v1
|
||||
with:
|
||||
@@ -50,19 +66,8 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract release candidate
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
id: extract_release_candidate
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ ${GITHUB_REF} == *rc* ]]; then
|
||||
echo "::set-output name=release_candidate::true"
|
||||
else
|
||||
echo "::set-output name=release_candidate::false"
|
||||
fi
|
||||
|
||||
- name: Create release
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04'
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
body: |
|
||||
@@ -70,7 +75,7 @@ jobs:
|
||||
|
||||
Commits:
|
||||
${{ steps.build_changelog.outputs.changelog }}
|
||||
prerelease: ${{ steps.extract_release_candidate.outputs.release_candidate }}
|
||||
prerelease: true
|
||||
files: |
|
||||
LICENSE
|
||||
smp-server-ubuntu-20_04-x86-64
|
||||
|
||||
62
rfcs/2022-08-14-queue-rotation.md
Normal file
62
rfcs/2022-08-14-queue-rotation.md
Normal file
@@ -0,0 +1,62 @@
|
||||
# SMP queue rotation and redundancy
|
||||
|
||||
## Problem
|
||||
|
||||
1. Long term usage of the queue allows the servers to analyze long term traffic patterns.
|
||||
|
||||
2. If the user changes the configured SMP server(s), the previously created contacts do not migrate to the new server(s).
|
||||
|
||||
3. Server can lose messages.
|
||||
|
||||
## Solution
|
||||
|
||||
Additional messages exchanged by SMP agents to negotiate addition and removal of queues to the connection.
|
||||
|
||||
This approach allows for both rotation and redundancy, as changing queue will be done be adding a queue and then removing the existing queue.
|
||||
|
||||
The reason for this approach is that otherwise it's non-trivial to switch from one queue to another without losing messages or delivering them out of order, it's easier to have messages delivered via both queues during the switch, however short or long time it is.
|
||||
|
||||
### Messages
|
||||
|
||||
Additional agent messages required for the protocol:
|
||||
|
||||
QADD_ -> "QA"
|
||||
QKEY_ -> "QK"
|
||||
QUSE_ -> "QU"
|
||||
QTEST_ -> "QT"
|
||||
QDEL_ -> "QD"
|
||||
QEND_ -> "QE"
|
||||
|
||||
`QADD`: add the new queue address(es), the same format as `REPLY` message, encoded as `QA`.
|
||||
|
||||
`QKEY`: pass sender's key via existing connection (SMP confirmation message will not be used, to avoid the same "race" of the initial key exchange that would create the risk of intercepting the queue for the attacker), encoded as `QK`.
|
||||
|
||||
`QUSE`: instruct the sender to use the new queue with sender's queue ID as parameter, encoded as `QU`.
|
||||
|
||||
`QTEST`: send test message to the new connection, encoded as `QT`. Any other message can be sent if available to continue rotation, the absence of this message is not an error.
|
||||
|
||||
`QDEL`: instruct the sender to stop using the previous queue, encoded as `QD`
|
||||
|
||||
`QEND`: notify the recipient that no messages will be sent to this queue, encoded as `QE`. The recipient will delete this queue.
|
||||
|
||||
### Protocol
|
||||
|
||||
```
|
||||
participant A as Alice
|
||||
participant B as Bob
|
||||
participant R as Server that has A's receive queue
|
||||
participant S as Server that has A's send queue (B's receive queue)
|
||||
participant R' as Server that hosts the new A's receive queue
|
||||
|
||||
A ->> R': create new queue
|
||||
A ->> S ->> B: QADD (R'): snd address of the new queue(s)
|
||||
B ->> A(R) ->> A: QKEY (R'): sender's key for the new queue(s) (to avoid the race of SMP confirmation for the initial exchange)
|
||||
A ->> S(R'): secure new queue
|
||||
A ->> S ->> B: QUSE (R'): instruction to use new queue(s)
|
||||
B ->> A(R,R') ->> A: QTEST
|
||||
B ->> A(R,R') ->> A: send all new messages to both queues
|
||||
A ->> S ->> B: QDEL (R): instruction to delete the old queue
|
||||
B ->> A(R') -> A: QEND (R): notification that no messages will be sent to the old queue
|
||||
B ->> R' ->> A: send all new messages to the new queue only
|
||||
A ->> S(R): DEL: delete the previous queue
|
||||
```
|
||||
@@ -54,6 +54,8 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
Simplex.Messaging.Agent.TRcvQueues
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
Simplex.Messaging.Crypto
|
||||
@@ -83,7 +85,6 @@ library
|
||||
Simplex.Messaging.Server.Stats
|
||||
Simplex.Messaging.Server.StoreLog
|
||||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.TMap2
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Client
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
@@ -350,6 +351,7 @@ test-suite smp-server-test
|
||||
AgentTests.NotificationTests
|
||||
AgentTests.SchemaDump
|
||||
AgentTests.SQLiteTests
|
||||
CoreTests.CryptoTests
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.VersionRangeTests
|
||||
|
||||
@@ -43,6 +43,7 @@ module Simplex.Messaging.Agent
|
||||
allowConnectionAsync,
|
||||
acceptContactAsync,
|
||||
ackMessageAsync,
|
||||
switchConnectionAsync,
|
||||
deleteConnectionAsync,
|
||||
createConnection,
|
||||
joinConnection,
|
||||
@@ -57,6 +58,7 @@ module Simplex.Messaging.Agent
|
||||
resubscribeConnections,
|
||||
sendMessage,
|
||||
ackMessage,
|
||||
switchConnection,
|
||||
suspendConnection,
|
||||
deleteConnection,
|
||||
getConnectionServers,
|
||||
@@ -89,9 +91,10 @@ import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:), (.:.), (.::))
|
||||
import Data.Foldable (foldl')
|
||||
import Data.Functor (($>))
|
||||
import Data.List (deleteFirstsBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.List (deleteFirstsBy, find)
|
||||
import Data.List.NonEmpty (NonEmpty (..), (<|))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
@@ -117,7 +120,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, sameSrvAddr)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util
|
||||
@@ -171,6 +174,10 @@ acceptContactAsync c corrId enableNtfs = withAgentEnv c .: acceptContactAsync' c
|
||||
ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c
|
||||
|
||||
-- | Switch connection to the new receive queue
|
||||
switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response
|
||||
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c
|
||||
@@ -224,6 +231,10 @@ sendMessage c = withAgentEnv c .:. sendMessage' c
|
||||
ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessage c = withAgentEnv c .: ackMessage' c
|
||||
|
||||
-- | Switch connection to the new receive queue
|
||||
switchConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
switchConnection c = withAgentEnv c . switchConnection' c
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command)
|
||||
suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
suspendConnection c = withAgentEnv c . suspendConnection' c
|
||||
@@ -328,6 +339,7 @@ processCommand c (connId, cmd) = case cmd of
|
||||
SUB -> subscribeConnection' c connId $> (connId, OK)
|
||||
SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody
|
||||
ACK msgId -> ackMessage' c connId msgId $> (connId, OK)
|
||||
SWCH -> switchConnection' c connId $> (connId, OK)
|
||||
OFF -> suspendConnection' c connId $> (connId, OK)
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
|
||||
@@ -375,22 +387,25 @@ acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessageAsync' c corrId connId msgId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> enqueueAck rq
|
||||
SomeConn _ (RcvConnection _ rq) -> enqueueAck rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
|
||||
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
|
||||
ackMessageAsync' c corrId connId msgId = do
|
||||
SomeConn cType _ <- withStore c (`getConn` connId)
|
||||
case cType of
|
||||
SCDuplex -> enqueueAck
|
||||
SCRcv -> enqueueAck
|
||||
SCSnd -> throwError $ CONN SIMPLEX
|
||||
SCContact -> throwError $ CMD PROHIBITED
|
||||
SCNew -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
enqueueAck :: RcvQueue -> m ()
|
||||
enqueueAck RcvQueue {server} =
|
||||
enqueueCommand c corrId connId (Just server) $ AClientCommand $ ACK msgId
|
||||
enqueueAck :: m ()
|
||||
enqueueAck = do
|
||||
(RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId
|
||||
enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId
|
||||
|
||||
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
deleteConnectionAsync' c@AgentClient {subQ} corrId connId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> enqueueDelete rq
|
||||
-- TODO *** delete all queues
|
||||
SomeConn _ (DuplexConnection _ (rq :| _) _) -> enqueueDelete rq
|
||||
SomeConn _ (RcvConnection _ rq) -> enqueueDelete rq
|
||||
SomeConn _ (ContactConnection _ rq) -> enqueueDelete rq
|
||||
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId) >> notifyDeleted
|
||||
@@ -402,6 +417,13 @@ deleteConnectionAsync' c@AgentClient {subQ} corrId connId =
|
||||
notifyDeleted :: m ()
|
||||
notifyDeleted = atomically $ writeTBQueue subQ (corrId, connId, OK)
|
||||
|
||||
-- | Add connection to the new receive queue
|
||||
switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ()
|
||||
switchConnectionAsync' c corrId connId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId asyncMode enableNtfs cMode =
|
||||
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
|
||||
@@ -409,9 +431,10 @@ newConn c connId asyncMode enableNtfs cMode =
|
||||
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c connId asyncMode enableNtfs cMode srv = do
|
||||
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
|
||||
(rq, qUri) <- newRcvQueue c srv smpClientVRange
|
||||
connId' <- setUpConn asyncMode rq $ maxVersion smpAgentVRange
|
||||
addSubscription c rq connId'
|
||||
(q, qUri) <- newRcvQueue c "" srv smpClientVRange
|
||||
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
|
||||
let rq = (q :: RcvQueue) {connId = connId'}
|
||||
addSubscription c rq
|
||||
when enableNtfs $ do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId', NSCCreate)
|
||||
@@ -424,7 +447,7 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do
|
||||
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
|
||||
where
|
||||
setUpConn True rq _ = do
|
||||
withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
void . withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
pure connId
|
||||
setUpConn False rq connAgentVersion = do
|
||||
g <- asks idsDrg
|
||||
@@ -434,8 +457,8 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId asyncMode enableNtfs cReq cInfo = do
|
||||
srv <- case cReq of
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
|
||||
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
|
||||
getNextSMPServer c [qServer q]
|
||||
_ -> getSMPServer c
|
||||
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
|
||||
|
||||
@@ -450,11 +473,12 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age
|
||||
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
|
||||
sq <- newSndQueue qInfo
|
||||
q <- newSndQueue "" qInfo
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
|
||||
connId' <- setUpConn asyncMode cData sq rc
|
||||
let cData' = (cData :: ConnData) {connId = connId'}
|
||||
connId' <- setUpConn asyncMode cData q rc
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
@@ -466,7 +490,7 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age
|
||||
where
|
||||
setUpConn True _ sq rc =
|
||||
withStore c $ \db -> runExceptT $ do
|
||||
ExceptT $ updateNewConnSnd db connId sq
|
||||
void . ExceptT $ updateNewConnSnd db connId sq
|
||||
liftIO $ createRatchet db connId rc
|
||||
pure connId
|
||||
setUpConn False cData sq rc = do
|
||||
@@ -492,10 +516,10 @@ joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion
|
||||
(rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion
|
||||
let qInfo = toVersionT qUri smpClientVersion
|
||||
addSubscription c rq connId
|
||||
withStore c $ \db -> upgradeSndConnToDuplex db connId rq
|
||||
addSubscription c rq
|
||||
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
|
||||
when enableNtfs $ do
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
|
||||
@@ -538,24 +562,27 @@ subscribeConnection' c connId =
|
||||
withStore c (`getConn` connId) >>= \conn -> do
|
||||
resumeConnCmds c connId
|
||||
case conn of
|
||||
SomeConn _ (DuplexConnection cData rq sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
subscribe rq
|
||||
SomeConn _ (DuplexConnection cData (rq :| rqs) sqs) -> do
|
||||
mapM_ (resumeMsgDelivery c cData) sqs
|
||||
subscribe cData rq
|
||||
mapM_ (\q -> subscribeQueue c q `catchError` \_ -> pure ()) rqs
|
||||
SomeConn _ (SndConnection cData sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure ()
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (ContactConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (RcvConnection cData rq) -> subscribe cData rq
|
||||
SomeConn _ (ContactConnection cData rq) -> subscribe cData rq
|
||||
SomeConn _ (NewConnection _) -> pure ()
|
||||
where
|
||||
subscribe :: RcvQueue -> m ()
|
||||
subscribe rq = do
|
||||
subscribeQueue c rq connId
|
||||
subscribe :: ConnData -> RcvQueue -> m ()
|
||||
subscribe ConnData {enableNtfs} rq = do
|
||||
subscribeQueue c rq
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
|
||||
atomically $ sendNtfSubCommand ns (connId, if enableNtfs then NSCCreate else NSCDelete)
|
||||
|
||||
type QSubResult = (QueueStatus, Either AgentErrorType ())
|
||||
|
||||
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeConnections' _ [] = pure M.empty
|
||||
@@ -564,42 +591,61 @@ subscribeConnections' c connIds = do
|
||||
let (errs, cs) = M.mapEither id conns
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
|
||||
srvRcvQs :: Map SMPServer (Map ConnId RcvQueue) = M.foldlWithKey' addRcvQueue M.empty rcvQs
|
||||
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
|
||||
srvRcvQs :: Map SMPServer [RcvQueue] = M.foldl' (foldl' addRcvQueue) M.empty rcvQs
|
||||
mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs
|
||||
mapM_ (resumeConnCmds c) $ M.keys cs
|
||||
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
rcvRs <- connResults . concat <$> mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
ns <- asks ntfSupervisor
|
||||
tkn <- readTVarIO (ntfTkn ns)
|
||||
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs
|
||||
let rs = M.unions $ errs' : subRs : rcvRs
|
||||
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns
|
||||
let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())])
|
||||
notifyResultError rs
|
||||
pure rs
|
||||
where
|
||||
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) RcvQueue
|
||||
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) (NonEmpty RcvQueue)
|
||||
rcvQueueOrResult = \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> Right rq
|
||||
SomeConn _ (DuplexConnection _ rqs _) -> Right rqs
|
||||
SomeConn _ (SndConnection _ sq) -> Left $ sndSubResult sq
|
||||
SomeConn _ (RcvConnection _ rq) -> Right rq
|
||||
SomeConn _ (ContactConnection _ rq) -> Right rq
|
||||
SomeConn _ (RcvConnection _ rq) -> Right [rq]
|
||||
SomeConn _ (ContactConnection _ rq) -> Right [rq]
|
||||
SomeConn _ (NewConnection _) -> Left (Right ())
|
||||
sndSubResult :: SndQueue -> Either AgentErrorType ()
|
||||
sndSubResult sq = case status (sq :: SndQueue) of
|
||||
Confirmed -> Right ()
|
||||
Active -> Left $ CONN SIMPLEX
|
||||
_ -> Left $ INTERNAL "unexpected queue status"
|
||||
addRcvQueue :: Map SMPServer (Map ConnId RcvQueue) -> ConnId -> RcvQueue -> Map SMPServer (Map ConnId RcvQueue)
|
||||
addRcvQueue m connId rq@RcvQueue {server} = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
|
||||
subscribe :: (SMPServer, Map ConnId RcvQueue) -> m (Map ConnId (Either AgentErrorType ()))
|
||||
addRcvQueue :: Map SMPServer [RcvQueue] -> RcvQueue -> Map SMPServer [RcvQueue]
|
||||
addRcvQueue m rq@RcvQueue {server} = M.alter (Just . maybe [rq] (rq :)) server m
|
||||
subscribe :: (SMPServer, [RcvQueue]) -> m [(RcvQueue, Either AgentErrorType ())]
|
||||
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
|
||||
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
|
||||
sendNtfCreate ns rcvRs =
|
||||
forM_ (concatMap M.assocs rcvRs) $ \case
|
||||
(connId, Right _) -> atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate)
|
||||
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
|
||||
connResults = M.map snd . foldl' addResult M.empty
|
||||
where
|
||||
-- collects results by connection ID
|
||||
addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QSubResult
|
||||
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
|
||||
-- combines two results for one connection, by using only Active queues (if there is at least one Active queue)
|
||||
combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult
|
||||
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
|
||||
combineRes r' _ = Just r'
|
||||
order :: QSubResult -> Int
|
||||
order (Active, Right _) = 1
|
||||
order (Active, _) = 2
|
||||
order (_, Right _) = 3
|
||||
order _ = 4
|
||||
sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> m ()
|
||||
sendNtfCreate ns rcvRs conns =
|
||||
forM_ (M.assocs rcvRs) $ \case
|
||||
(connId, Right _) -> forM_ (M.lookup connId conns) $ \case
|
||||
Right (SomeConn _ conn) -> do
|
||||
let cmd = if enableNtfs $ connData conn then NSCCreate else NSCDelete
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd)
|
||||
_ -> pure ()
|
||||
_ -> pure ()
|
||||
sndQueue :: SomeConn -> Maybe (ConnData, SndQueue)
|
||||
sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue)
|
||||
sndQueue = \case
|
||||
SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq)
|
||||
SomeConn _ (SndConnection cData sq) -> Just (cData, sq)
|
||||
SomeConn _ (DuplexConnection cData _ sqs) -> Just (cData, sqs)
|
||||
SomeConn _ (SndConnection cData sq) -> Just (cData, [sq])
|
||||
_ -> Nothing
|
||||
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
|
||||
notifyResultError rs = do
|
||||
@@ -626,7 +672,7 @@ getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMs
|
||||
getConnectionMessage' c connId = do
|
||||
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> getQueueMessage c rq
|
||||
SomeConn _ (DuplexConnection _ (rq :| _) _) -> getQueueMessage c rq
|
||||
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq
|
||||
SomeConn _ (ContactConnection _ rq) -> getQueueMessage c rq
|
||||
SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX
|
||||
@@ -663,12 +709,12 @@ getNotificationMessage' c nonce encNtfInfo = do
|
||||
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
|
||||
sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection cData _ sq) -> enqueueMsg cData sq
|
||||
SomeConn _ (SndConnection cData sq) -> enqueueMsg cData sq
|
||||
SomeConn _ (DuplexConnection cData _ sqs) -> enqueueMsgs cData sqs
|
||||
SomeConn _ (SndConnection cData sq) -> enqueueMsgs cData [sq]
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
enqueueMsg :: ConnData -> SndQueue -> m AgentMsgId
|
||||
enqueueMsg cData sq = enqueueMessage c cData sq msgFlags $ A_MSG msg
|
||||
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId
|
||||
enqueueMsgs cData sqs = enqueueMessages c cData sqs msgFlags $ A_MSG msg
|
||||
|
||||
-- / async command processing v v v
|
||||
|
||||
@@ -713,8 +759,8 @@ getPendingCommandQ c server = do
|
||||
pure cq
|
||||
|
||||
runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
|
||||
runCommandProcessing c@AgentClient {subQ} server = do
|
||||
cq <- atomically $ getPendingCommandQ c server
|
||||
runCommandProcessing c@AgentClient {subQ} server_ = do
|
||||
cq <- atomically $ getPendingCommandQ c server_
|
||||
ri <- asks $ messageRetryInterval . config -- different retry interval?
|
||||
forever $ do
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
@@ -726,73 +772,108 @@ runCommandProcessing c@AgentClient {subQ} server = do
|
||||
Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd
|
||||
where
|
||||
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m ()
|
||||
processCmd ri corrId connId cmdId = \case
|
||||
processCmd ri corrId connId cmdId command = case command of
|
||||
AClientCommand cmd -> case cmd of
|
||||
NEW enableNtfs (ACM cMode) -> do
|
||||
NEW enableNtfs (ACM cMode) -> noServer $ do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
|
||||
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
|
||||
notify $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do
|
||||
let initUsed = [smpServer (queueAddress :: SMPQueueAddress)]
|
||||
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do
|
||||
let initUsed = [qServer q]
|
||||
usedSrvs <- newTVarIO initUsed
|
||||
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
|
||||
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
|
||||
notify OK
|
||||
LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK
|
||||
DEL -> tryCommand $ deleteConnection' c connId >> notify OK
|
||||
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK
|
||||
SWCH -> noServer $ tryCommand $ switchConnection' c connId >>= notify . SWITCH SPStarted
|
||||
DEL -> withServer' . tryCommand $ deleteConnection' c connId >> notify OK
|
||||
_ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd)
|
||||
AInternalCommand cmd -> case server of
|
||||
Just _srv -> case cmd of
|
||||
ICAckDel _rId srvMsgId msgId -> tryWithLock "ICAckDel" $ ack _rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId)
|
||||
ICAck _rId srvMsgId -> tryWithLock "ICAck" $ ack _rId srvMsgId
|
||||
ICAllowSecure _rId senderKey -> tryWithLock "ICAllowSecure" $ do
|
||||
(SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <-
|
||||
withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId)
|
||||
case conn of
|
||||
RcvConnection cData rq -> do
|
||||
secure rq senderKey
|
||||
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
|
||||
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
|
||||
ICDuplexSecure _rId senderKey -> tryWithLock "ICDuplexSecure" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection cData rq sq -> do
|
||||
secure rq senderKey
|
||||
when (duplexHandshake cData == Just True) . void $
|
||||
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
|
||||
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
|
||||
_ -> throwError $ INTERNAL $ "command requires server " <> show (internalCmdTag cmd)
|
||||
AInternalCommand cmd -> case cmd of
|
||||
ICAckDel rId srvMsgId msgId -> withServer $ \srv -> tryWithLock "ICAckDel" $ ack srv rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId)
|
||||
ICAck rId srvMsgId -> withServer $ \srv -> tryWithLock "ICAck" $ ack srv rId srvMsgId
|
||||
ICAllowSecure _rId senderKey -> withServer' . tryWithLock "ICAllowSecure" $ do
|
||||
(SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <-
|
||||
withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId)
|
||||
case conn of
|
||||
RcvConnection cData rq -> do
|
||||
secure rq senderKey
|
||||
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
|
||||
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
|
||||
ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do
|
||||
secure rq senderKey
|
||||
when (duplexHandshake cData == Just True) . void $
|
||||
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
|
||||
ICQSecure rId senderKey ->
|
||||
withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) ->
|
||||
case find (sameQueue (srv, rId)) rqs of
|
||||
Just rq'@RcvQueue {server, sndId, status} -> when (status == Confirmed) $ do
|
||||
secureQueue c rq' senderKey
|
||||
withStore' c $ \db -> setRcvQueueStatus db rq' Secured
|
||||
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)]
|
||||
_ -> internalErr "ICQSecure: queue address not found in connection"
|
||||
ICQDelete rId -> do
|
||||
withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> do
|
||||
case removeQ (srv, rId) rqs of
|
||||
Nothing -> internalErr "ICQDelete: queue address not found in connection"
|
||||
Just (rq'@RcvQueue {primary}, rq'' : rqs')
|
||||
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
|
||||
| otherwise -> do
|
||||
deleteQueue c rq'
|
||||
withStore' c $ \db -> deleteConnRcvQueue db connId rq'
|
||||
let conn' = DuplexConnection cData (rq'' :| rqs') sqs
|
||||
notify $ SWITCH SPCompleted $ connectionStats conn'
|
||||
_ -> internalErr "ICQDelete: cannot delete the only queue in connection"
|
||||
where
|
||||
ack _rId srvMsgId = do
|
||||
-- TODO get particular queue
|
||||
rq <- withStore c (`getRcvQueue` connId)
|
||||
ack srv rId srvMsgId = do
|
||||
rq <- withStore c $ \db -> getRcvQueue db connId srv rId
|
||||
ackQueueMessage c rq srvMsgId
|
||||
secure :: RcvQueue -> SMP.SndPublicVerifyKey -> m ()
|
||||
secure rq senderKey = do
|
||||
secureQueue c rq senderKey
|
||||
withStore' c $ \db -> setRcvQueueStatus db rq Secured
|
||||
where
|
||||
withServer a = case server_ of
|
||||
Just srv -> a srv
|
||||
_ -> internalErr "command requires server"
|
||||
withServer' = withServer . const
|
||||
noServer a = case server_ of
|
||||
Nothing -> a
|
||||
_ -> internalErr "command requires no server"
|
||||
withDuplexConn :: (Connection 'CDuplex -> m ()) -> m ()
|
||||
withDuplexConn a =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ conn@DuplexConnection {} -> a conn
|
||||
_ -> internalErr "command requires duplex connection"
|
||||
tryCommand action = withRetryInterval ri $ \loop ->
|
||||
tryError action >>= \case
|
||||
Left e
|
||||
| temporaryAgentError e || e == BROKER HOST -> retrySndOp c loop
|
||||
| otherwise -> notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
|
||||
| otherwise -> cmdError e
|
||||
Right () -> withStore' c (`deleteCommand` cmdId)
|
||||
tryWithLock name = tryCommand . withConnLock c connId name
|
||||
internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command)
|
||||
cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
|
||||
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd)
|
||||
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
|
||||
withNextSrv usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srv <- getNextSMPServer c used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
action srv
|
||||
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
|
||||
withNextSrv usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srv <- getNextSMPServer c used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
action srv
|
||||
-- ^ ^ ^ async command processing /
|
||||
|
||||
enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
|
||||
enqueueMessages c cData (sq :| sqs) msgFlags aMessage = do
|
||||
msgId <- enqueueMessage c cData sq msgFlags aMessage
|
||||
mapM_ (enqueueSavedMessage c cData msgId) $
|
||||
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
|
||||
pure msgId
|
||||
|
||||
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
|
||||
enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do
|
||||
resumeMsgDelivery c cData sq
|
||||
@@ -813,20 +894,28 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage
|
||||
msgType = agentMessageType agentMsg
|
||||
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash}
|
||||
liftIO $ createSndMsg db connId msgData
|
||||
liftIO $ createSndMsgDelivery db connId sq internalId
|
||||
pure internalId
|
||||
|
||||
enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
|
||||
enqueueSavedMessage c cData@ConnData {connId} msgId sq = do
|
||||
resumeMsgDelivery c cData sq
|
||||
let mId = InternalId msgId
|
||||
queuePendingMsgs c sq [mId]
|
||||
withStore' c $ \db -> createSndMsgDelivery db connId sq mId
|
||||
|
||||
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
|
||||
let qKey = (server, sndId)
|
||||
unlessM (queueDelivering qKey) $
|
||||
async (runSmpQueueMsgDelivery c cData sq)
|
||||
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
|
||||
unlessM connQueued $
|
||||
withStore' c (`getPendingMsgs` connId)
|
||||
unlessM msgsQueued $
|
||||
withStore' c (\db -> getPendingMsgs db connId sq)
|
||||
>>= queuePendingMsgs c sq
|
||||
where
|
||||
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
|
||||
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c)
|
||||
msgsQueued = atomically $ isJust <$> TM.lookupInsert (server, sndId) True (pendingMsgsQueued c)
|
||||
|
||||
queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m ()
|
||||
queuePendingMsgs c sq msgIds = atomically $ do
|
||||
@@ -853,6 +942,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
forever $ do
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
atomically $ throwWhenInactive c
|
||||
atomically $ throwWhenNoDelivery c sq
|
||||
msgId <- atomically $ readTQueue mq
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
atomically $ endAgentOperation c AOMsgDelivery
|
||||
@@ -890,6 +980,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
_ -> connError msgId NOT_ACCEPTED
|
||||
AM_REPLY_ -> notifyDel msgId $ ERR e
|
||||
AM_A_MSG_ -> notifyDel msgId $ MERR mId e
|
||||
AM_QADD_ -> pure ()
|
||||
AM_QKEY_ -> pure ()
|
||||
AM_QUSE_ -> pure ()
|
||||
AM_QTEST_ -> pure ()
|
||||
AM_QDEL_ -> pure ()
|
||||
AM_QEND_ -> pure ()
|
||||
_
|
||||
-- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server),
|
||||
-- the message sending would be retried
|
||||
@@ -910,6 +1006,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
when (isJust rq_) $ removeConfirmations db connId
|
||||
-- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification
|
||||
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
|
||||
AM_CONN_INFO_REPLY -> pure ()
|
||||
AM_REPLY_ -> pure ()
|
||||
AM_HELLO_ -> do
|
||||
withStore' c $ \db -> setSndQueueStatus db sq Active
|
||||
case rq_ of
|
||||
@@ -933,11 +1031,18 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
_ -> pure ()
|
||||
AM_QADD_ -> pure ()
|
||||
AM_QKEY_ -> pure ()
|
||||
AM_QUSE_ -> pure ()
|
||||
AM_QTEST_ ->
|
||||
withStore' c $ \db -> setSndQueueStatus db sq Active
|
||||
AM_QDEL_ -> pure ()
|
||||
AM_QEND_ ->
|
||||
getConnectionServers' c connId >>= notify . SWITCH SPCompleted
|
||||
delMsg msgId
|
||||
where
|
||||
delMsg :: InternalId -> m ()
|
||||
delMsg msgId = withStore' c $ \db -> deleteMsg db connId msgId
|
||||
delMsg msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId
|
||||
notify :: ACommand 'Agent -> m ()
|
||||
notify cmd = atomically $ writeTBQueue subQ ("", connId, cmd)
|
||||
notifyDel :: InternalId -> ACommand 'Agent -> m ()
|
||||
@@ -954,20 +1059,38 @@ retrySndOp c loop = do
|
||||
|
||||
ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessage' c connId msgId = withConnLock c connId "ackMessage" $ do
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> ack rq
|
||||
SomeConn _ (RcvConnection _ rq) -> ack rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
|
||||
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection {} -> ack
|
||||
RcvConnection {} -> ack
|
||||
SndConnection {} -> throwError $ CONN SIMPLEX
|
||||
ContactConnection {} -> throwError $ CMD PROHIBITED
|
||||
NewConnection _ -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
ack :: RcvQueue -> m ()
|
||||
ack rq = do
|
||||
ack :: m ()
|
||||
ack = do
|
||||
let mId = InternalId msgId
|
||||
srvMsgId <- withStore c $ \db -> setMsgUserAck db connId mId
|
||||
(rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId mId
|
||||
ackQueueMessage c rq srvMsgId
|
||||
withStore' c $ \db -> deleteMsg db connId mId
|
||||
|
||||
switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
switchConnection' c connId = withConnLock c connId "switchConnection" $ do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
case conn of
|
||||
DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
-- try to get the server that is different from all queues, or at least from the primary rcv queue
|
||||
srv <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
|
||||
srv' <- if srv == server then getNextSMPServer c [server] else pure srv
|
||||
(q, qUri) <- newRcvQueue c connId srv' clientVRange
|
||||
let rq' = (q :: RcvQueue) {primary = False, nextPrimary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnRcvQueue db connId rq'
|
||||
addSubscription c rq'
|
||||
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))]
|
||||
pure . connectionStats $ DuplexConnection cData (rq <| rq' :| rqs_) sqs
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m ()
|
||||
ackQueueMessage c rq srvMsgId =
|
||||
sendAck c rq srvMsgId `catchError` \case
|
||||
@@ -978,7 +1101,7 @@ ackQueueMessage c rq srvMsgId =
|
||||
suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> suspendQueue c rq
|
||||
SomeConn _ (DuplexConnection _ rqs _) -> mapM_ (suspendQueue c) rqs
|
||||
SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq
|
||||
SomeConn _ (ContactConnection _ rq) -> suspendQueue c rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
@@ -988,7 +1111,7 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
|
||||
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> delete rq
|
||||
SomeConn _ (DuplexConnection _ rqs _) -> mapM_ delete rqs
|
||||
SomeConn _ (RcvConnection _ rq) -> delete rq
|
||||
SomeConn _ (ContactConnection _ rq) -> delete rq
|
||||
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId)
|
||||
@@ -1003,15 +1126,17 @@ deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete)
|
||||
|
||||
getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
|
||||
getConnectionServers' c connId = connServers <$> withStore c (`getConn` connId)
|
||||
where
|
||||
connServers :: SomeConn -> ConnectionStats
|
||||
connServers = \case
|
||||
SomeConn _ (RcvConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
|
||||
SomeConn _ (SndConnection _ SndQueue {server}) -> ConnectionStats {rcvServers = [], sndServers = [server]}
|
||||
SomeConn _ (DuplexConnection _ RcvQueue {server = s1} SndQueue {server = s2}) -> ConnectionStats {rcvServers = [s1], sndServers = [s2]}
|
||||
SomeConn _ (ContactConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
|
||||
SomeConn _ (NewConnection _) -> ConnectionStats {rcvServers = [], sndServers = []}
|
||||
getConnectionServers' c connId = do
|
||||
SomeConn _ conn <- withStore c (`getConn` connId)
|
||||
pure $ connectionStats conn
|
||||
|
||||
connectionStats :: Connection c -> ConnectionStats
|
||||
connectionStats = \case
|
||||
RcvConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []}
|
||||
SndConnection _ sq -> ConnectionStats {rcvServers = [], sndServers = [qServer sq]}
|
||||
DuplexConnection _ rqs sqs -> ConnectionStats {rcvServers = map qServer $ L.toList rqs, sndServers = map qServer $ L.toList sqs}
|
||||
ContactConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []}
|
||||
NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []}
|
||||
|
||||
-- | Change servers to be used for creating new queues, in Reader monad
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
@@ -1271,11 +1396,9 @@ pickServer = \case
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
|
||||
getNextSMPServer c usedSrvs = do
|
||||
srvs <- readTVarIO $ smpServers c
|
||||
case L.nonEmpty $ deleteFirstsBy sameAddr (L.toList srvs) usedSrvs of
|
||||
case L.nonEmpty $ deleteFirstsBy sameSrvAddr (L.toList srvs) usedSrvs of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pickServer srvs
|
||||
where
|
||||
sameAddr (SMPServer host port _) (SMPServer host' port' _) = host == host' && port == port'
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
@@ -1288,7 +1411,10 @@ subscriber c@AgentClient {msgQ} = forever $ do
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
|
||||
processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) =
|
||||
withStore c (\db -> getRcvConn db srv rId) >>= \case
|
||||
SomeConn _ conn@(DuplexConnection cData rq _) -> processSMP conn cData rq
|
||||
-- TODO *** get queue separately?
|
||||
SomeConn _ conn@(DuplexConnection cData rqs _) -> case find (sameQueue (srv, rId)) rqs of
|
||||
Just rq -> processSMP conn cData rq
|
||||
_ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND)
|
||||
SomeConn _ conn@(RcvConnection cData rq) -> processSMP conn cData rq
|
||||
SomeConn _ conn@(ContactConnection cData rq) -> processSMP conn cData rq
|
||||
_ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND)
|
||||
@@ -1313,7 +1439,20 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
_ -> prohibited >> ack
|
||||
(Just e2eDh, Nothing) -> do
|
||||
decryptClientMessage e2eDh clientMsg >>= \case
|
||||
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) ->
|
||||
(SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do
|
||||
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
|
||||
let RcvQueue {primary, nextPrimary, dbReplaceQueueId} = rq
|
||||
unless primary . withStore' c $ \db -> do
|
||||
unless (status == Active) $ setRcvQueueStatus db rq Active
|
||||
when nextPrimary $ setRcvQueuePrimary db connId rq
|
||||
case (conn, dbReplaceQueueId) of
|
||||
(DuplexConnection _ rqs sqs, Just dbRcvId) ->
|
||||
case find (\RcvQueue {dbQueueId} -> dbQueueId == dbRcvId) rqs of
|
||||
Just RcvQueue {server, sndId} -> do
|
||||
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QDEL [(server, sndId)]
|
||||
notify . SWITCH SPTested $ connectionStats conn
|
||||
_ -> throwError $ INTERNAL "replaced RcvQueue not found in connection"
|
||||
_ -> pure ()
|
||||
tryError agentClientMsg >>= \case
|
||||
Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of
|
||||
HELLO -> helloMsg >> ackDel msgId
|
||||
@@ -1322,6 +1461,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
A_MSG body -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
notify $ MSG msgMeta msgFlags body
|
||||
QADD qs -> qDuplex "QADD" $ qAddMsg qs
|
||||
QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs
|
||||
QUSE qs -> qDuplex "QUSE" $ qUseMsg qs
|
||||
-- no action needed for QTEST
|
||||
-- any message in the new queue will mark it active and trigger deletion of the old queue
|
||||
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
|
||||
QDEL qs -> qDuplex "QDEL" $ qDelMsg qs
|
||||
QEND qs -> qDuplex "QEND" $ qEndMsg qs
|
||||
where
|
||||
qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m ()
|
||||
qDuplex name a = case conn of
|
||||
DuplexConnection {} -> a conn >> ackDel msgId
|
||||
_ -> qError $ name <> ": message must be sent to duplex connection"
|
||||
Right _ -> prohibited >> ack
|
||||
Left e@(AGENT A_DUPLICATE) -> do
|
||||
withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case
|
||||
@@ -1350,7 +1502,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
broker = (srvMsgId, systemToUTCTime srvTs)
|
||||
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
|
||||
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash}
|
||||
liftIO $ createRcvMsg db connId rcvMsg
|
||||
liftIO $ createRcvMsg db connId rq rcvMsg
|
||||
pure $ Just (internalId, msgMeta, aMessage)
|
||||
_ -> pure Nothing
|
||||
_ -> prohibited >> ack
|
||||
@@ -1434,12 +1586,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
confId <- withStore c $ \db -> do
|
||||
setHandshakeVersion db connId agentVersion duplexHS
|
||||
createConfirmation db g newConfirmation
|
||||
let srvs = map queueServer $ smpReplyQueues senderConf
|
||||
let srvs = map qServer $ smpReplyQueues senderConf
|
||||
notify $ CONF confId srvs connInfo
|
||||
queueServer (SMPQueueInfo _ SMPQueueAddress {smpServer}) = smpServer
|
||||
_ -> prohibited
|
||||
-- party accepting connection
|
||||
(DuplexConnection _ RcvQueue {smpClientVersion = v'} _, Nothing) -> do
|
||||
(DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do
|
||||
withStore c (\db -> runExceptT $ agentRatchetDecrypt db connId encConnInfo) >>= parseMessage >>= \case
|
||||
AgentConnInfo connInfo -> do
|
||||
notify $ INFO connInfo
|
||||
@@ -1458,7 +1609,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
_ -> do
|
||||
withStore' c $ \db -> setRcvQueueStatus db rq Active
|
||||
case conn of
|
||||
DuplexConnection _ _ sq@SndQueue {status = sndStatus}
|
||||
DuplexConnection _ _ (sq@SndQueue {status = sndStatus} :| _)
|
||||
-- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO
|
||||
-- this branch is executed by the accepting party in duplexHandshake mode (v2)
|
||||
-- and by the initiating party in v1
|
||||
@@ -1471,7 +1622,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
enqueueDuplexHello :: SndQueue -> m ()
|
||||
enqueueDuplexHello sq = void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
|
||||
|
||||
replyMsg :: L.NonEmpty SMPQueueInfo -> m ()
|
||||
replyMsg :: NonEmpty SMPQueueInfo -> m ()
|
||||
replyMsg smpQueues = do
|
||||
logServer "<--" c srv rId "MSG <REPLY>"
|
||||
case duplexHandshake of
|
||||
@@ -1482,6 +1633,95 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR)
|
||||
_ -> prohibited
|
||||
|
||||
-- processed by queue sender
|
||||
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
|
||||
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
|
||||
qAddMsg ((qUri, Just addr) :| _) (DuplexConnection _ rqs sqs@(sq :| sqs_)) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qUri `compatibleVersion` clientVRange of
|
||||
Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) ->
|
||||
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
|
||||
(Just _, _) -> qError "QADD: queue address is already used in connection"
|
||||
(_, Just _replaced@SndQueue {dbQueueId}) -> do
|
||||
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo
|
||||
let sq' = (sq_ :: SndQueue) {nextPrimary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
void . withStore c $ \db -> addConnSndQueue db connId sq'
|
||||
case (sndPublicKey, e2ePubKey) of
|
||||
(Just sndPubKey, Just dhPublicKey) -> do
|
||||
logServer "<--" c srv rId $ "MSG <QADD> " <> logSecret (senderId queueAddress)
|
||||
let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}}
|
||||
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)]
|
||||
let conn' = DuplexConnection cData rqs (sq <| sq' :| sqs_)
|
||||
notify . SWITCH SPStarted $ connectionStats conn'
|
||||
_ -> qError "absent sender keys"
|
||||
_ -> qError "QADD: replaced queue address is not found in connection"
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
|
||||
-- processed by queue recipient
|
||||
qKeyMsg :: NonEmpty (SMPQueueInfo, SndPublicVerifyKey) -> Connection 'CDuplex -> m ()
|
||||
qKeyMsg ((qInfo, senderKey) :| _) (DuplexConnection _ rqs _) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
|
||||
case findRQ (smpServer, senderId) rqs of
|
||||
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
|
||||
| status' == New || status' == Confirmed -> do
|
||||
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
|
||||
let dhSecret = C.dh' dhPublicKey dhPrivKey
|
||||
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
|
||||
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
|
||||
notify . SWITCH SPConfirmed $ connectionStats conn
|
||||
| otherwise -> qError "QKEY: queue already secured"
|
||||
_ -> qError "QKEY: queue address not found in connection"
|
||||
where
|
||||
SMPQueueInfo cVer' SMPQueueAddress {smpServer, senderId, dhPublicKey} = qInfo
|
||||
|
||||
-- processed by queue sender
|
||||
-- mark queue as Secured and to start sending messages to it
|
||||
qUseMsg :: NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> m ()
|
||||
qUseMsg ((addr, primary) :| _) (DuplexConnection _ _ sqs) =
|
||||
case removeQ addr sqs of
|
||||
Just (sq', sqs') -> do
|
||||
logServer "<--" c srv rId $ "MSG <QUSE> " <> logSecret (snd addr)
|
||||
withStore' c $ \db -> do
|
||||
setSndQueueStatus db sq' Secured
|
||||
when primary $ setSndQueuePrimary db connId sq'
|
||||
let sq'' = (sq' :: SndQueue) {status = Secured, primary}
|
||||
void $ enqueueMessages c cData (sq'' :| sqs') SMP.noMsgFlags $ QTEST [addr]
|
||||
notify . SWITCH SPConfirmed $ connectionStats conn
|
||||
_ -> qError "QUSE: queue address not found in connection"
|
||||
|
||||
-- processed by queue sender
|
||||
-- remove snd queue from connection and enqueue QEND message
|
||||
qDelMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
|
||||
qDelMsg (addr :| _) (DuplexConnection _ rqs sqs) =
|
||||
case removeQ addr sqs of
|
||||
Nothing -> logServer "<--" c srv rId "MSG <QDEL>: queue not found (already deleted?)"
|
||||
Just (sq, sq' : sqs') -> do
|
||||
logServer "<--" c srv rId $ "MSG <QDEL> " <> logSecret (snd addr)
|
||||
-- remove the delivery from the map to stop the thread when the delivery loop is complete
|
||||
atomically $ TM.delete addr $ smpQueueMsgQueues c
|
||||
withStore' c $ \db -> do
|
||||
deletePendingMsgs db connId sq
|
||||
deleteConnSndQueue db connId sq
|
||||
let sqs'' = sq' :| sqs'
|
||||
conn' = DuplexConnection cData rqs sqs''
|
||||
void $ enqueueMessages c cData sqs'' SMP.noMsgFlags $ QEND [addr]
|
||||
notify . SWITCH SPTested $ connectionStats conn'
|
||||
_ -> qError "QDEL received to the only queue in connection"
|
||||
|
||||
-- received by party initiating switch
|
||||
-- TODO *** check that the received address matches expectations
|
||||
qEndMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
|
||||
qEndMsg (addr@(smpServer, senderId) :| _) (DuplexConnection _ rqs _) =
|
||||
case findRQ addr rqs of
|
||||
Just RcvQueue {rcvId} -> do
|
||||
logServer "<--" c srv rId $ "MSG <QEND> " <> logSecret senderId
|
||||
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQDelete rcvId
|
||||
_ -> qError "QEND: queue address not found in connection"
|
||||
|
||||
qError :: String -> m ()
|
||||
qError = throwError . AGENT . A_QUEUE
|
||||
|
||||
smpInvitation :: ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
smpInvitation connReq@(CRInvitationUri crData _) cInfo = do
|
||||
logServer "<--" c srv rId "MSG <KEY>"
|
||||
@@ -1490,11 +1730,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
g <- asks idsDrg
|
||||
let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo}
|
||||
invId <- withStore c $ \db -> createInvitation db g newInv
|
||||
let srvs = L.map queueServer $ crSmpQueues crData
|
||||
let srvs = L.map qServer $ crSmpQueues crData
|
||||
notify $ REQ invId srvs cInfo
|
||||
_ -> prohibited
|
||||
where
|
||||
queueServer (SMPQueueUri _ SMPQueueAddress {smpServer}) = smpServer
|
||||
|
||||
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity
|
||||
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash
|
||||
@@ -1505,15 +1743,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm
|
||||
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> L.NonEmpty SMPQueueInfo -> m ()
|
||||
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
|
||||
connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case qInfo `proveCompatible` clientVRange of
|
||||
Nothing -> throwError $ AGENT A_VERSION
|
||||
Just qInfo' -> do
|
||||
sq <- newSndQueue qInfo'
|
||||
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq ownConnInfo Nothing
|
||||
sq <- newSndQueue connId qInfo'
|
||||
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
|
||||
|
||||
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
|
||||
@@ -1570,28 +1808,34 @@ agentRatchetDecrypt db connId encAgentMsg = do
|
||||
liftIO $ updateRatchet db connId rc' skippedDiff
|
||||
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
|
||||
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue qInfo =
|
||||
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue
|
||||
newSndQueue connId qInfo =
|
||||
asks (cmdSignAlg . config) >>= \case
|
||||
C.SignAlg a -> newSndQueue_ a qInfo
|
||||
C.SignAlg a -> newSndQueue_ a connId qInfo
|
||||
|
||||
newSndQueue_ ::
|
||||
(C.SignatureAlgorithm a, C.AlgorithmI a, MonadUnliftIO m) =>
|
||||
C.SAlgorithm a ->
|
||||
ConnId ->
|
||||
Compatible SMPQueueInfo ->
|
||||
m SndQueue
|
||||
newSndQueue_ a (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
newSndQueue_ a connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
|
||||
-- this function assumes clientVersion is compatible - it was tested before
|
||||
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
pure
|
||||
SndQueue
|
||||
{ server = smpServer,
|
||||
{ connId,
|
||||
server = smpServer,
|
||||
sndId = senderId,
|
||||
sndPublicKey = Just sndPublicKey,
|
||||
sndPrivateKey,
|
||||
e2eDhSecret = C.dh' rcvE2ePubDhKey e2ePrivKey,
|
||||
e2ePubKey = Just e2ePubKey,
|
||||
status = New,
|
||||
dbQueueId = 0,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ module Simplex.Messaging.Agent.Client
|
||||
suspendQueue,
|
||||
deleteQueue,
|
||||
logServer,
|
||||
logSecret,
|
||||
removeSubscription,
|
||||
hasActiveSubscription,
|
||||
agentClientStore,
|
||||
@@ -62,6 +63,7 @@ module Simplex.Messaging.Agent.Client
|
||||
agentOperationBracket,
|
||||
waitUntilActive,
|
||||
throwWhenInactive,
|
||||
throwWhenNoDelivery,
|
||||
beginAgentOperation,
|
||||
endAgentOperation,
|
||||
suspendSendingAndDatabase,
|
||||
@@ -90,11 +92,12 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (isRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.List (partition, (\\))
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (listToMaybe)
|
||||
import Data.Maybe (isJust, listToMaybe)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding
|
||||
@@ -107,6 +110,8 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction)
|
||||
import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues)
|
||||
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -140,8 +145,6 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.TMap2 (TMap2)
|
||||
import qualified Simplex.Messaging.TMap2 as TM2
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
@@ -167,11 +170,11 @@ data AgentClient = AgentClient
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
connMsgsQueued :: TMap ConnId Bool,
|
||||
smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()),
|
||||
activeSubs :: TRcvQueues,
|
||||
pendingSubs :: TRcvQueues,
|
||||
pendingMsgsQueued :: TMap SndQAddr Bool,
|
||||
smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap SndQAddr (Async ()),
|
||||
connCmdsQueued :: TMap ConnId Bool,
|
||||
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
|
||||
asyncCmdProcesses :: TMap (Maybe SMPServer) (Async ()),
|
||||
@@ -229,9 +232,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
ntfClients <- TM.empty
|
||||
useNetworkConfig <- newTVar netCfg
|
||||
subscrConns <- newTVar S.empty
|
||||
activeSubs <- TM2.empty
|
||||
pendingSubs <- TM2.empty
|
||||
connMsgsQueued <- TM.empty
|
||||
activeSubs <- RQ.empty
|
||||
pendingSubs <- RQ.empty
|
||||
pendingMsgsQueued <- TM.empty
|
||||
smpQueueMsgQueues <- TM.empty
|
||||
smpQueueMsgDeliveries <- TM.empty
|
||||
connCmdsQueued <- TM.empty
|
||||
@@ -249,7 +252,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv}
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, pendingMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv}
|
||||
|
||||
agentClientStore :: AgentClient -> SQLiteStore
|
||||
agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
@@ -282,25 +285,25 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
|
||||
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
|
||||
clientDisconnected u client = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
removeClientAndSubs >>= serverDown
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
where
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
TM2.insert1 srv cVar $ pendingSubs c
|
||||
readTVar cVar
|
||||
qs <- RQ.getDelSrvQueues srv $ activeSubs c
|
||||
mapM_ (`RQ.addQueue` pendingSubs c) qs
|
||||
cs <- RQ.getConns (activeSubs c)
|
||||
-- TODO deduplicate conns
|
||||
let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs
|
||||
pure (qs, conns)
|
||||
|
||||
serverDown :: Map ConnId RcvQueue -> IO ()
|
||||
serverDown cs = whenM (readTVarIO active) $ do
|
||||
serverDown :: ([RcvQueue], [ConnId]) -> IO ()
|
||||
serverDown (qs, conns) = whenM (readTVarIO active) $ do
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
let conns = M.keys cs
|
||||
unless (null conns) $ do
|
||||
notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unless (null conns) $ notifySub "" $ DOWN srv conns
|
||||
unless (null qs) $ do
|
||||
atomically $ mapM_ (releaseGetLock c) qs
|
||||
unliftIO u reconnectServer
|
||||
|
||||
reconnectServer :: m ()
|
||||
@@ -317,20 +320,22 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withLockMap_ (reconnectLocks c) srv "reconnect" $
|
||||
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
|
||||
>>= mapM_ resubscribe
|
||||
atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe
|
||||
where
|
||||
resubscribe :: Map ConnId RcvQueue -> m ()
|
||||
resubscribe :: [RcvQueue] -> m ()
|
||||
resubscribe qs = do
|
||||
connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar)
|
||||
(client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs
|
||||
cs <- atomically . RQ.getConns $ activeSubs c
|
||||
(client_, rs) <- subscribeQueues c srv qs
|
||||
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
|
||||
liftIO $ do
|
||||
unless connected $ mapM_ (notifySub "" . hostEvent CONNECT) client_
|
||||
unless (M.null oks) $ do
|
||||
notifySub "" . UP srv $ M.keys oks
|
||||
let (tempErrs, finalErrs) = M.partition temporaryAgentError errs
|
||||
liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs
|
||||
mapM_ throwError . listToMaybe $ M.elems tempErrs
|
||||
-- TODO deduplicate okConns
|
||||
let conns = okConns \\ S.toList cs
|
||||
unless (null conns) $ notifySub "" $ UP srv conns
|
||||
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
|
||||
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
|
||||
mapM_ (throwError . snd) $ listToMaybe tempErrs
|
||||
|
||||
notifySub :: ConnId -> ACommand 'Agent -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
@@ -425,10 +430,10 @@ closeAgentClient c = liftIO $ do
|
||||
cancelActions $ asyncClients c
|
||||
cancelActions $ smpQueueMsgDeliveries c
|
||||
cancelActions $ asyncCmdProcesses c
|
||||
atomically . TM2.clear $ activeSubs c
|
||||
atomically . TM2.clear $ pendingSubs c
|
||||
atomically . RQ.clear $ activeSubs c
|
||||
atomically . RQ.clear $ pendingSubs c
|
||||
clear subscrConns
|
||||
clear connMsgsQueued
|
||||
clear pendingMsgsQueued
|
||||
clear smpQueueMsgQueues
|
||||
clear connCmdsQueued
|
||||
clear asyncCmdQueues
|
||||
@@ -443,6 +448,14 @@ waitUntilActive c = unlessM (readTVar $ active c) retry
|
||||
throwWhenInactive :: AgentClient -> STM ()
|
||||
throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled
|
||||
|
||||
throwWhenNoDelivery :: AgentClient -> SndQueue -> STM ()
|
||||
throwWhenNoDelivery c SndQueue {server, sndId} =
|
||||
unlessM (isJust <$> TM.lookup k (smpQueueMsgQueues c)) $ do
|
||||
TM.delete k $ smpQueueMsgDeliveries c
|
||||
throwSTM ThreadKilled
|
||||
where
|
||||
k = (server, sndId)
|
||||
|
||||
closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients c clientsSel =
|
||||
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
|
||||
@@ -502,19 +515,20 @@ protocolClientError protocolError_ = \case
|
||||
e@PCESignatureError {} -> INTERNAL $ show e
|
||||
e@PCEIOError {} -> INTERNAL $ show e
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c srv vRange =
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c connId srv vRange =
|
||||
asks (cmdSignAlg . config) >>= \case
|
||||
C.SignAlg a -> newRcvQueue_ a c srv vRange
|
||||
C.SignAlg a -> newRcvQueue_ a c connId srv vRange
|
||||
|
||||
newRcvQueue_ ::
|
||||
(C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) =>
|
||||
C.SAlgorithm a ->
|
||||
AgentClient ->
|
||||
ConnId ->
|
||||
SMPServer ->
|
||||
VersionRange ->
|
||||
m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue_ a c srv vRange = do
|
||||
newRcvQueue_ a c connId srv vRange = do
|
||||
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(dhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
@@ -524,36 +538,41 @@ newRcvQueue_ a c srv vRange = do
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
{ server = srv,
|
||||
{ connId,
|
||||
server = srv,
|
||||
rcvId,
|
||||
rcvPrivateKey,
|
||||
rcvDhSecret = C.dh' rcvPublicDhKey privDhKey,
|
||||
e2ePrivKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just sndId,
|
||||
sndId,
|
||||
status = New,
|
||||
dbQueueId = 0,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = maxVersion vRange,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq)
|
||||
>>= either throwError pure
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq connId r = do
|
||||
processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq r = do
|
||||
case r of
|
||||
Left e ->
|
||||
atomically . unless (temporaryClientError e) $
|
||||
TM2.delete connId (pendingSubs c)
|
||||
_ -> addSubscription c rq connId
|
||||
RQ.deleteQueue rq (pendingSubs c)
|
||||
_ -> addSubscription c rq
|
||||
pure r
|
||||
|
||||
temporaryClientError :: ProtocolClientError -> Bool
|
||||
@@ -569,45 +588,46 @@ temporaryAgentError = \case
|
||||
_ -> False
|
||||
|
||||
-- | subscribe multiple queues - all passed queues should be on the same server
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())])
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
|
||||
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue qs
|
||||
forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do
|
||||
-- TODO check server is correct
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
RQ.addQueue rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
(eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (second . const $ Left e) qs_
|
||||
(eitherToMaybe smp_,) . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (,Left e) qs_
|
||||
Right smp -> do
|
||||
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
|
||||
let qs2 = L.map (queueCreds . snd) qs'
|
||||
rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <-
|
||||
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
|
||||
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
|
||||
_ -> pure (Nothing, M.fromList errs)
|
||||
let qs2 = L.map queueCreds qs'
|
||||
liftIO $ do
|
||||
rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
mapM_ (uncurry $ processSubResult c) rs
|
||||
pure $ map (second . first $ protocolClientError SMP) rs
|
||||
_ -> pure (Nothing, errs)
|
||||
where
|
||||
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
|
||||
checkQueue rq@RcvQueue {rcvId, server} = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq
|
||||
pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m ()
|
||||
addSubscription c rq@RcvQueue {connId} = atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
TM2.insert server connId rq $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
RQ.addQueue rq $ activeSubs c
|
||||
RQ.deleteQueue rq $ pendingSubs c
|
||||
|
||||
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
|
||||
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
|
||||
hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c
|
||||
|
||||
removeSubscription :: AgentClient -> ConnId -> STM ()
|
||||
removeSubscription c connId = do
|
||||
modifyTVar (subscrConns c) $ S.delete connId
|
||||
TM2.delete connId $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
RQ.deleteConn connId $ activeSubs c
|
||||
RQ.deleteConn connId $ pendingSubs c
|
||||
|
||||
getSubscriptions :: AgentClient -> STM (Set ConnId)
|
||||
getSubscriptions = readTVar . subscrConns
|
||||
|
||||
@@ -96,7 +96,8 @@ data AgentConfig = AgentConfig
|
||||
certificateFile :: FilePath,
|
||||
e2eEncryptVRange :: VersionRange,
|
||||
smpAgentVRange :: VersionRange,
|
||||
smpClientVRange :: VersionRange
|
||||
smpClientVRange :: VersionRange,
|
||||
initialClientId :: Int
|
||||
}
|
||||
|
||||
defaultReconnectInterval :: RetryInterval
|
||||
@@ -142,7 +143,8 @@ defaultAgentConfig =
|
||||
certificateFile = "/etc/opt/simplex-agent/agent.crt",
|
||||
e2eEncryptVRange = supportedE2EEncryptVRange,
|
||||
smpAgentVRange = supportedSMPAgentVRange,
|
||||
smpClientVRange = supportedSMPClientVRange
|
||||
smpClientVRange = supportedSMPClientVRange,
|
||||
initialClientId = 0
|
||||
}
|
||||
|
||||
data Env = Env
|
||||
@@ -155,12 +157,12 @@ data Env = Env
|
||||
}
|
||||
|
||||
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
|
||||
newSMPAgentEnv config@AgentConfig {database, yesToMigrations} = do
|
||||
newSMPAgentEnv config@AgentConfig {database, yesToMigrations, initialClientId} = do
|
||||
idsDrg <- newTVarIO =<< drgNew
|
||||
store <- case database of
|
||||
AgentDB st -> pure st
|
||||
AgentDBFile {dbFile, dbKey} -> liftIO $ createAgentStore dbFile dbKey yesToMigrations
|
||||
clientCounter <- newTVarIO 0
|
||||
clientCounter <- newTVarIO initialClientId
|
||||
randomServer <- newTVarIO =<< liftIO newStdGen
|
||||
ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config
|
||||
return Env {config, store, idsDrg, clientCounter, randomServer, ntfSupervisor}
|
||||
|
||||
@@ -73,7 +73,7 @@ processNtfSub c (connId, cmd) = do
|
||||
NSCCreate -> do
|
||||
(a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do
|
||||
a <- liftIO $ getNtfSubscription db connId
|
||||
q <- ExceptT $ getRcvQueue db connId
|
||||
q <- ExceptT $ getPrimaryRcvQueue db connId
|
||||
pure (a, q)
|
||||
logInfo $ "processNtfSub, NSCCreate - a = " <> tshow a
|
||||
case a of
|
||||
@@ -125,7 +125,7 @@ processNtfSub c (connId, cmd) = do
|
||||
(Just (NtfSubscription {ntfServer}, _)) -> addNtfNTFWorker ntfServer
|
||||
_ -> pure () -- err "NSCDelete - no subscription"
|
||||
NSCSmpDelete -> do
|
||||
withStore' c (`getRcvQueue` connId) >>= \case
|
||||
withStore' c (`getPrimaryRcvQueue` connId) >>= \case
|
||||
Right rq@RcvQueue {server = smpServer} -> do
|
||||
logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq
|
||||
ts <- liftIO getCurrentTime
|
||||
@@ -185,7 +185,7 @@ runNtfWorker c srv doWork = do
|
||||
NSACreate ->
|
||||
getNtfToken >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
RcvQueue {clientNtfCreds} <- withStore c (`getRcvQueue` connId)
|
||||
RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId)
|
||||
case clientNtfCreds of
|
||||
Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do
|
||||
nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey
|
||||
@@ -260,7 +260,7 @@ runNtfSMPWorker c srv doWork = do
|
||||
NSASmpKey ->
|
||||
getNtfToken >>= \case
|
||||
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
rq <- withStore c (`getRcvQueue` connId)
|
||||
rq <- withStore c (`getPrimaryRcvQueue` connId)
|
||||
C.SignAlg a <- asks (cmdSignAlg . config)
|
||||
(ntfPublicKey, ntfPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(rcvNtfPubDhKey, rcvNtfPrivDhKey) <- liftIO C.generateKeyPair'
|
||||
@@ -275,7 +275,7 @@ runNtfSMPWorker c srv doWork = do
|
||||
NSASmpDelete -> do
|
||||
rq_ <- withStore' c $ \db -> do
|
||||
setRcvQueueNtfCreds db connId Nothing
|
||||
getRcvQueue db connId
|
||||
getPrimaryRcvQueue db connId
|
||||
forM_ rq_ $ \rq -> disableQueueNotifications c rq
|
||||
withStore' c $ \db -> deleteNtfSubscription db connId
|
||||
|
||||
|
||||
@@ -50,15 +50,19 @@ module Simplex.Messaging.Agent.Protocol
|
||||
MsgHash,
|
||||
MsgMeta (..),
|
||||
ConnectionStats (..),
|
||||
SwitchPhase (..),
|
||||
SMPConfirmation (..),
|
||||
AgentMsgEnvelope (..),
|
||||
AgentMessage (..),
|
||||
AgentMessageType (..),
|
||||
APrivHeader (..),
|
||||
AMessage (..),
|
||||
SndQAddr,
|
||||
SMPServer,
|
||||
pattern SMPServer,
|
||||
SrvLoc (..),
|
||||
SMPQueue (..),
|
||||
sameQAddress,
|
||||
SMPQueueUri (..),
|
||||
SMPQueueInfo (..),
|
||||
SMPQueueAddress (..),
|
||||
@@ -162,6 +166,7 @@ import Simplex.Messaging.Protocol
|
||||
legacyEncodeServer,
|
||||
legacyServerP,
|
||||
legacyStrEncodeServer,
|
||||
sameSrvAddr,
|
||||
pattern SMPServer,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
@@ -247,12 +252,14 @@ data ACommand (p :: AParty) where
|
||||
DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent
|
||||
DOWN :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
UP :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
SWITCH :: SwitchPhase -> ConnectionStats -> ACommand Agent
|
||||
SEND :: MsgFlags -> MsgBody -> ACommand Client
|
||||
MID :: AgentMsgId -> ACommand Agent
|
||||
SENT :: AgentMsgId -> ACommand Agent
|
||||
MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent
|
||||
MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent
|
||||
ACK :: AgentMsgId -> ACommand Client
|
||||
SWCH :: ACommand Client
|
||||
OFF :: ACommand Client
|
||||
DEL :: ACommand Client
|
||||
CHK :: ACommand Client
|
||||
@@ -284,12 +291,14 @@ data ACommandTag (p :: AParty) where
|
||||
DISCONNECT_ :: ACommandTag Agent
|
||||
DOWN_ :: ACommandTag Agent
|
||||
UP_ :: ACommandTag Agent
|
||||
SWITCH_ :: ACommandTag Agent
|
||||
SEND_ :: ACommandTag Client
|
||||
MID_ :: ACommandTag Agent
|
||||
SENT_ :: ACommandTag Agent
|
||||
MERR_ :: ACommandTag Agent
|
||||
MSG_ :: ACommandTag Agent
|
||||
ACK_ :: ACommandTag Client
|
||||
SWCH_ :: ACommandTag Client
|
||||
OFF_ :: ACommandTag Client
|
||||
DEL_ :: ACommandTag Client
|
||||
CHK_ :: ACommandTag Client
|
||||
@@ -320,12 +329,14 @@ aCommandTag = \case
|
||||
DISCONNECT {} -> DISCONNECT_
|
||||
DOWN {} -> DOWN_
|
||||
UP {} -> UP_
|
||||
SWITCH {} -> SWITCH_
|
||||
SEND {} -> SEND_
|
||||
MID _ -> MID_
|
||||
SENT _ -> SENT_
|
||||
MERR {} -> MERR_
|
||||
MSG {} -> MSG_
|
||||
ACK _ -> ACK_
|
||||
SWCH -> SWCH_
|
||||
OFF -> OFF_
|
||||
DEL -> DEL_
|
||||
CHK -> CHK_
|
||||
@@ -334,6 +345,23 @@ aCommandTag = \case
|
||||
ERR _ -> ERR_
|
||||
SUSPENDED -> SUSPENDED_
|
||||
|
||||
data SwitchPhase = SPStarted | SPConfirmed | SPTested | SPCompleted
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding SwitchPhase where
|
||||
strEncode = \case
|
||||
SPStarted -> "started"
|
||||
SPConfirmed -> "confirmed"
|
||||
SPTested -> "tested"
|
||||
SPCompleted -> "completed"
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"started" -> pure SPStarted
|
||||
"confirmed" -> pure SPConfirmed
|
||||
"tested" -> pure SPTested
|
||||
"completed" -> pure SPCompleted
|
||||
_ -> fail "bad SwitchPhase"
|
||||
|
||||
data ConnectionStats = ConnectionStats
|
||||
{ rcvServers :: [SMPServer],
|
||||
sndServers :: [SMPServer]
|
||||
@@ -508,7 +536,18 @@ instance Encoding AgentMessage where
|
||||
'M' -> AgentMessage <$> smpP <*> smpP
|
||||
_ -> fail "bad AgentMessage"
|
||||
|
||||
data AgentMessageType = AM_CONN_INFO | AM_CONN_INFO_REPLY | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_
|
||||
data AgentMessageType
|
||||
= AM_CONN_INFO
|
||||
| AM_CONN_INFO_REPLY
|
||||
| AM_HELLO_
|
||||
| AM_REPLY_
|
||||
| AM_A_MSG_
|
||||
| AM_QADD_
|
||||
| AM_QKEY_
|
||||
| AM_QUSE_
|
||||
| AM_QTEST_
|
||||
| AM_QDEL_
|
||||
| AM_QEND_
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding AgentMessageType where
|
||||
@@ -518,6 +557,12 @@ instance Encoding AgentMessageType where
|
||||
AM_HELLO_ -> "H"
|
||||
AM_REPLY_ -> "R"
|
||||
AM_A_MSG_ -> "M"
|
||||
AM_QADD_ -> "QA"
|
||||
AM_QKEY_ -> "QK"
|
||||
AM_QUSE_ -> "QU"
|
||||
AM_QTEST_ -> "QT"
|
||||
AM_QDEL_ -> "QD"
|
||||
AM_QEND_ -> "QE"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'C' -> pure AM_CONN_INFO
|
||||
@@ -525,6 +570,15 @@ instance Encoding AgentMessageType where
|
||||
'H' -> pure AM_HELLO_
|
||||
'R' -> pure AM_REPLY_
|
||||
'M' -> pure AM_A_MSG_
|
||||
'Q' ->
|
||||
A.anyChar >>= \case
|
||||
'A' -> pure AM_QADD_
|
||||
'K' -> pure AM_QKEY_
|
||||
'U' -> pure AM_QUSE_
|
||||
'T' -> pure AM_QTEST_
|
||||
'D' -> pure AM_QDEL_
|
||||
'E' -> pure AM_QEND_
|
||||
_ -> fail "bad AgentMessageType"
|
||||
_ -> fail "bad AgentMessageType"
|
||||
|
||||
agentMessageType :: AgentMessage -> AgentMessageType
|
||||
@@ -540,6 +594,12 @@ agentMessageType = \case
|
||||
-- REPLY is only used in v1
|
||||
REPLY _ -> AM_REPLY_
|
||||
A_MSG _ -> AM_A_MSG_
|
||||
QADD _ -> AM_QADD_
|
||||
QKEY _ -> AM_QKEY_
|
||||
QUSE _ -> AM_QUSE_
|
||||
QTEST _ -> AM_QTEST_
|
||||
QDEL _ -> AM_QDEL_
|
||||
QEND _ -> AM_QEND_
|
||||
|
||||
data APrivHeader = APrivHeader
|
||||
{ -- | sequential ID assigned by the sending agent
|
||||
@@ -554,7 +614,16 @@ instance Encoding APrivHeader where
|
||||
smpEncode (sndMsgId, prevMsgHash)
|
||||
smpP = APrivHeader <$> smpP <*> smpP
|
||||
|
||||
data AMsgType = HELLO_ | REPLY_ | A_MSG_
|
||||
data AMsgType
|
||||
= HELLO_
|
||||
| REPLY_
|
||||
| A_MSG_
|
||||
| QADD_
|
||||
| QKEY_
|
||||
| QUSE_
|
||||
| QTEST_
|
||||
| QDEL_
|
||||
| QEND_
|
||||
deriving (Eq)
|
||||
|
||||
instance Encoding AMsgType where
|
||||
@@ -562,11 +631,26 @@ instance Encoding AMsgType where
|
||||
HELLO_ -> "H"
|
||||
REPLY_ -> "R"
|
||||
A_MSG_ -> "M"
|
||||
QADD_ -> "QA"
|
||||
QKEY_ -> "QK"
|
||||
QUSE_ -> "QU"
|
||||
QTEST_ -> "QT"
|
||||
QDEL_ -> "QD"
|
||||
QEND_ -> "QE"
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
A.anyChar >>= \case
|
||||
'H' -> pure HELLO_
|
||||
'R' -> pure REPLY_
|
||||
'M' -> pure A_MSG_
|
||||
'Q' ->
|
||||
A.anyChar >>= \case
|
||||
'A' -> pure QADD_
|
||||
'K' -> pure QKEY_
|
||||
'U' -> pure QUSE_
|
||||
'T' -> pure QTEST_
|
||||
'D' -> pure QDEL_
|
||||
'E' -> pure QEND_
|
||||
_ -> fail "bad AMsgType"
|
||||
_ -> fail "bad AMsgType"
|
||||
|
||||
-- | Messages sent between SMP agents once SMP queue is secured.
|
||||
@@ -579,19 +663,45 @@ data AMessage
|
||||
REPLY (L.NonEmpty SMPQueueInfo)
|
||||
| -- | agent envelope for the client message
|
||||
A_MSG MsgBody
|
||||
| -- add queue to connection (sent by recipient), with optional address of the replaced queue
|
||||
QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr))
|
||||
| -- key to secure the added queues and agree e2e encryption key (sent by sender)
|
||||
QKEY (L.NonEmpty (SMPQueueInfo, SndPublicVerifyKey))
|
||||
| -- inform that the queues are ready to use (sent by recipient)
|
||||
QUSE (L.NonEmpty (SndQAddr, Bool))
|
||||
| -- sent by the sender to test new queues
|
||||
QTEST (L.NonEmpty SndQAddr)
|
||||
| -- inform that the queues will be deleted (sent recipient once message received via the new queue)
|
||||
QDEL (L.NonEmpty SndQAddr)
|
||||
| -- sent by sender to confirm that no more messages will be sent to the queue
|
||||
QEND (L.NonEmpty SndQAddr)
|
||||
deriving (Show)
|
||||
|
||||
type SndQAddr = (SMPServer, SMP.SenderId)
|
||||
|
||||
instance Encoding AMessage where
|
||||
smpEncode = \case
|
||||
HELLO -> smpEncode HELLO_
|
||||
REPLY smpQueues -> smpEncode (REPLY_, smpQueues)
|
||||
A_MSG body -> smpEncode (A_MSG_, Tail body)
|
||||
QADD qs -> smpEncode (QADD_, qs)
|
||||
QKEY qs -> smpEncode (QKEY_, qs)
|
||||
QUSE qs -> smpEncode (QUSE_, qs)
|
||||
QTEST qs -> smpEncode (QTEST_, qs)
|
||||
QDEL qs -> smpEncode (QDEL_, qs)
|
||||
QEND qs -> smpEncode (QEND_, qs)
|
||||
smpP =
|
||||
smpP
|
||||
>>= \case
|
||||
HELLO_ -> pure HELLO
|
||||
REPLY_ -> REPLY <$> smpP
|
||||
A_MSG_ -> A_MSG . unTail <$> smpP
|
||||
QADD_ -> QADD <$> smpP
|
||||
QKEY_ -> QKEY <$> smpP
|
||||
QUSE_ -> QUSE <$> smpP
|
||||
QTEST_ -> QTEST <$> smpP
|
||||
QDEL_ -> QDEL <$> smpP
|
||||
QEND_ -> QEND <$> smpP
|
||||
|
||||
instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) where
|
||||
strEncode = \case
|
||||
@@ -688,6 +798,11 @@ updateSMPServerHosts srv@ProtocolServer {host} = case host of
|
||||
_ -> srv
|
||||
_ -> srv
|
||||
|
||||
class SMPQueue q where
|
||||
qServer :: q -> SMPServer
|
||||
qAddress :: q -> (SMPServer, SMP.QueueId)
|
||||
sameQueue :: (SMPServer, SMP.QueueId) -> q -> Bool
|
||||
|
||||
data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -728,6 +843,34 @@ data SMPQueueAddress = SMPQueueAddress
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance SMPQueue SMPQueueUri where
|
||||
qServer SMPQueueUri {queueAddress} = qServer queueAddress
|
||||
{-# INLINE qServer #-}
|
||||
qAddress SMPQueueUri {queueAddress} = qAddress queueAddress
|
||||
{-# INLINE qAddress #-}
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
instance SMPQueue SMPQueueInfo where
|
||||
qServer SMPQueueInfo {queueAddress} = qServer queueAddress
|
||||
{-# INLINE qServer #-}
|
||||
qAddress SMPQueueInfo {queueAddress} = qAddress queueAddress
|
||||
{-# INLINE qAddress #-}
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
instance SMPQueue SMPQueueAddress where
|
||||
qServer SMPQueueAddress {smpServer} = smpServer
|
||||
{-# INLINE qServer #-}
|
||||
qAddress SMPQueueAddress {smpServer, senderId} = (smpServer, senderId)
|
||||
{-# INLINE qAddress #-}
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
sameQAddress :: (SMPServer, SMP.QueueId) -> (SMPServer, SMP.QueueId) -> Bool
|
||||
sameQAddress (srv, qId) (srv', qId') = sameSrvAddr srv srv' && qId == qId'
|
||||
{-# INLINE sameQAddress #-}
|
||||
|
||||
instance StrEncoding SMPQueueUri where
|
||||
strEncode (SMPQueueUri vr SMPQueueAddress {smpServer = srv, senderId = qId, dhPublicKey})
|
||||
| minVersion vr > 1 = strEncode srv <> "/" <> strEncode qId <> "#/?" <> query queryParams
|
||||
@@ -754,6 +897,13 @@ instance StrEncoding SMPQueueUri where
|
||||
hs_ <- queryParam_ "srv" query
|
||||
pure (vr, maybe [] thList_ hs_, dhKey)
|
||||
|
||||
instance Encoding SMPQueueUri where
|
||||
smpEncode (SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}) =
|
||||
smpEncode (clientVRange, smpServer, senderId, dhPublicKey)
|
||||
smpP = do
|
||||
(clientVRange, smpServer, senderId, dhPublicKey) <- smpP
|
||||
pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}
|
||||
|
||||
data ConnectionRequestUri (m :: ConnectionMode) where
|
||||
CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
|
||||
-- contact connection request does NOT contain E2E encryption parameters -
|
||||
@@ -963,6 +1113,8 @@ data SMPAgentError
|
||||
A_ENCRYPTION
|
||||
| -- | duplicate message - this error is detected by ratchet decryption - this message will be ignored and not shown
|
||||
A_DUPLICATE
|
||||
| -- | error in the message to add/delete/etc queue in connection
|
||||
A_QUEUE {queueErr :: String}
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON SMPAgentError where
|
||||
@@ -978,6 +1130,7 @@ instance StrEncoding AgentErrorType where
|
||||
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP)
|
||||
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
|
||||
<|> "BROKER " *> (BROKER <$> parseRead1)
|
||||
<|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString)
|
||||
<|> "AGENT " *> (AGENT <$> parseRead1)
|
||||
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
|
||||
strEncode = \case
|
||||
@@ -988,6 +1141,7 @@ instance StrEncoding AgentErrorType where
|
||||
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e
|
||||
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
|
||||
BROKER e -> "BROKER " <> bshow e
|
||||
AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e
|
||||
AGENT e -> "AGENT " <> bshow e
|
||||
INTERNAL e -> "INTERNAL " <> bshow e
|
||||
|
||||
@@ -1029,12 +1183,14 @@ instance StrEncoding ACmdTag where
|
||||
"DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_
|
||||
"DOWN" -> pure $ ACmdTag SAgent DOWN_
|
||||
"UP" -> pure $ ACmdTag SAgent UP_
|
||||
"SWITCH" -> pure $ ACmdTag SAgent SWITCH_
|
||||
"SEND" -> pure $ ACmdTag SClient SEND_
|
||||
"MID" -> pure $ ACmdTag SAgent MID_
|
||||
"SENT" -> pure $ ACmdTag SAgent SENT_
|
||||
"MERR" -> pure $ ACmdTag SAgent MERR_
|
||||
"MSG" -> pure $ ACmdTag SAgent MSG_
|
||||
"ACK" -> pure $ ACmdTag SClient ACK_
|
||||
"SWCH" -> pure $ ACmdTag SClient SWCH_
|
||||
"OFF" -> pure $ ACmdTag SClient OFF_
|
||||
"DEL" -> pure $ ACmdTag SClient DEL_
|
||||
"CHK" -> pure $ ACmdTag SClient CHK_
|
||||
@@ -1062,12 +1218,14 @@ instance APartyI p => StrEncoding (ACommandTag p) where
|
||||
DISCONNECT_ -> "DISCONNECT"
|
||||
DOWN_ -> "DOWN"
|
||||
UP_ -> "UP"
|
||||
SWITCH_ -> "SWITCH"
|
||||
SEND_ -> "SEND"
|
||||
MID_ -> "MID"
|
||||
SENT_ -> "SENT"
|
||||
MERR_ -> "MERR"
|
||||
MSG_ -> "MSG"
|
||||
ACK_ -> "ACK"
|
||||
SWCH_ -> "SWCH"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
CHK_ -> "CHK"
|
||||
@@ -1097,6 +1255,7 @@ commandP binaryP =
|
||||
SUB_ -> pure SUB
|
||||
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
|
||||
ACK_ -> s (ACK <$> A.decimal)
|
||||
SWCH_ -> pure SWCH
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
CHK_ -> pure CHK
|
||||
@@ -1112,6 +1271,7 @@ commandP binaryP =
|
||||
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
|
||||
DOWN_ -> s (DOWN <$> strP_ <*> connections)
|
||||
UP_ -> s (UP <$> strP_ <*> connections)
|
||||
SWITCH_ -> s (SWITCH <$> strP_ <*> strP)
|
||||
MID_ -> s (MID <$> A.decimal)
|
||||
SENT_ -> s (SENT <$> A.decimal)
|
||||
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
|
||||
@@ -1139,36 +1299,40 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
|
||||
-- | Serialize SMP agent command.
|
||||
serializeCommand :: ACommand p -> ByteString
|
||||
serializeCommand = \case
|
||||
NEW ntfs cMode -> B.unwords ["NEW", strEncode ntfs, strEncode cMode]
|
||||
INV cReq -> "INV " <> strEncode cReq
|
||||
JOIN ntfs cReq cInfo -> B.unwords ["JOIN", strEncode ntfs, strEncode cReq, serializeBinary cInfo]
|
||||
CONF confId srvs cInfo -> B.unwords ["CONF", confId, strEncodeList srvs, serializeBinary cInfo]
|
||||
LET confId cInfo -> B.unwords ["LET", confId, serializeBinary cInfo]
|
||||
REQ invId srvs cInfo -> B.unwords ["REQ", invId, strEncode srvs, serializeBinary cInfo]
|
||||
ACPT invId cInfo -> B.unwords ["ACPT", invId, serializeBinary cInfo]
|
||||
RJCT invId -> "RJCT " <> invId
|
||||
INFO cInfo -> "INFO " <> serializeBinary cInfo
|
||||
SUB -> "SUB"
|
||||
END -> "END"
|
||||
CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h]
|
||||
DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h]
|
||||
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
|
||||
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
|
||||
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
|
||||
MID mId -> "MID " <> bshow mId
|
||||
SENT mId -> "SENT " <> bshow mId
|
||||
MERR mId e -> B.unwords ["MERR", bshow mId, strEncode e]
|
||||
MSG msgMeta msgFlags msgBody -> B.unwords ["MSG", serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody]
|
||||
ACK mId -> "ACK " <> bshow mId
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
CHK -> "CHK"
|
||||
STAT srvs -> "STAT " <> strEncode srvs
|
||||
CON -> "CON"
|
||||
ERR e -> "ERR " <> strEncode e
|
||||
OK -> "OK"
|
||||
SUSPENDED -> "SUSPENDED"
|
||||
NEW ntfs cMode -> s (NEW_, ntfs, cMode)
|
||||
INV cReq -> s (INV_, cReq)
|
||||
JOIN ntfs cReq cInfo -> s (JOIN_, ntfs, cReq, Str $ serializeBinary cInfo)
|
||||
CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo]
|
||||
LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo]
|
||||
REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo]
|
||||
ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo]
|
||||
RJCT invId -> B.unwords [s RJCT_, invId]
|
||||
INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo]
|
||||
SUB -> s SUB_
|
||||
END -> s END_
|
||||
CONNECT p h -> s (CONNECT_, p, h)
|
||||
DISCONNECT p h -> s (DISCONNECT_, p, h)
|
||||
DOWN srv conns -> B.unwords [s DOWN_, s srv, connections conns]
|
||||
UP srv conns -> B.unwords [s UP_, s srv, connections conns]
|
||||
SWITCH phase srvs -> s (SWITCH_, phase, srvs)
|
||||
SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody]
|
||||
MID mId -> s (MID_, Str $ bshow mId)
|
||||
SENT mId -> s (SENT_, Str $ bshow mId)
|
||||
MERR mId e -> s (MERR_, Str $ bshow mId, e)
|
||||
MSG msgMeta msgFlags msgBody -> B.unwords [s MSG_, serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody]
|
||||
ACK mId -> s (ACK_, Str $ bshow mId)
|
||||
SWCH -> s SWCH_
|
||||
OFF -> s OFF_
|
||||
DEL -> s DEL_
|
||||
CHK -> s CHK_
|
||||
STAT srvs -> s (STAT_, srvs)
|
||||
CON -> s CON_
|
||||
ERR e -> s (ERR_, e)
|
||||
OK -> s OK_
|
||||
SUSPENDED -> s SUSPENDED_
|
||||
where
|
||||
s :: StrEncoding a => a -> ByteString
|
||||
s = strEncode
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
connections :: [ConnId] -> ByteString
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
@@ -17,6 +18,9 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
import Data.List (find)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Time (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
@@ -43,7 +47,8 @@ import Simplex.Messaging.Version
|
||||
|
||||
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
|
||||
data RcvQueue = RcvQueue
|
||||
{ server :: SMPServer,
|
||||
{ connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | recipient queue ID
|
||||
rcvId :: SMP.RecipientId,
|
||||
-- | key used by the recipient to sign transmissions
|
||||
@@ -55,9 +60,17 @@ data RcvQueue = RcvQueue
|
||||
-- | public sender's DH key and agreed shared DH secret for simple per-queue e2e
|
||||
e2eDhSecret :: Maybe C.DhSecretX25519,
|
||||
-- | sender queue ID
|
||||
sndId :: Maybe SMP.SenderId,
|
||||
sndId :: SMP.SenderId,
|
||||
-- | queue status
|
||||
status :: QueueStatus,
|
||||
-- | database queue ID (within connection), can be Nothing for old queues
|
||||
dbQueueId :: Int64,
|
||||
-- | True for a primary queue of the connection
|
||||
primary :: Bool,
|
||||
-- | True for the next primary queue
|
||||
nextPrimary :: Bool,
|
||||
-- | database queue ID to replace, Nothing if this queue is not replacing another, `Just Nothing` is used for replacing old queues
|
||||
dbReplaceQueueId :: Maybe Int64,
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version,
|
||||
-- | credentials used in context of notifications
|
||||
@@ -78,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds
|
||||
|
||||
-- | A send queue. SMP queue through which the agent sends messages to a recipient.
|
||||
data SndQueue = SndQueue
|
||||
{ server :: SMPServer,
|
||||
{ connId :: ConnId,
|
||||
server :: SMPServer,
|
||||
-- | sender queue ID
|
||||
sndId :: SMP.SenderId,
|
||||
-- | key pair used by the sender to sign transmissions
|
||||
@@ -90,11 +104,52 @@ data SndQueue = SndQueue
|
||||
e2eDhSecret :: C.DhSecretX25519,
|
||||
-- | queue status
|
||||
status :: QueueStatus,
|
||||
-- | database queue ID (within connection), can be Nothing for old queues
|
||||
dbQueueId :: Int64,
|
||||
-- | True for a primary queue of the connection
|
||||
primary :: Bool,
|
||||
-- | True for the next primary queue
|
||||
nextPrimary :: Bool,
|
||||
-- | ID of the queue this one is replacing
|
||||
dbReplaceQueueId :: Maybe Int64,
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance SMPQueue RcvQueue where
|
||||
qServer RcvQueue {server} = server
|
||||
{-# INLINE qServer #-}
|
||||
qAddress RcvQueue {server, rcvId} = (server, rcvId)
|
||||
{-# INLINE qAddress #-}
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
instance SMPQueue SndQueue where
|
||||
qServer SndQueue {server} = server
|
||||
{-# INLINE qServer #-}
|
||||
qAddress SndQueue {server, sndId} = (server, sndId)
|
||||
{-# INLINE qAddress #-}
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
findQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe q
|
||||
findQ = find . sameQueue
|
||||
{-# INLINE findQ #-}
|
||||
|
||||
removeQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe (q, [q])
|
||||
removeQ addr qs = case L.break (sameQueue addr) qs of
|
||||
(_, []) -> Nothing
|
||||
(qs1, q : qs2) -> Just (q, qs1 <> qs2)
|
||||
|
||||
sndAddress :: RcvQueue -> (SMPServer, SMP.SenderId)
|
||||
sndAddress RcvQueue {server, sndId} = (server, sndId)
|
||||
{-# INLINE sndAddress #-}
|
||||
|
||||
findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue
|
||||
findRQ sAddr = find $ sameQAddress sAddr . sndAddress
|
||||
{-# INLINE findRQ #-}
|
||||
|
||||
-- * Connection types
|
||||
|
||||
-- | Type of a connection.
|
||||
@@ -114,13 +169,21 @@ data Connection (d :: ConnType) where
|
||||
NewConnection :: ConnData -> Connection CNew
|
||||
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnData -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex
|
||||
ContactConnection :: ConnData -> RcvQueue -> Connection CContact
|
||||
|
||||
deriving instance Eq (Connection d)
|
||||
|
||||
deriving instance Show (Connection d)
|
||||
|
||||
connData :: Connection d -> ConnData
|
||||
connData = \case
|
||||
NewConnection cData -> cData
|
||||
RcvConnection cData _ -> cData
|
||||
SndConnection cData _ -> cData
|
||||
DuplexConnection cData _ _ -> cData
|
||||
ContactConnection cData _ -> cData
|
||||
|
||||
data SConnType :: ConnType -> Type where
|
||||
SCNew :: SConnType CNew
|
||||
SCRcv :: SConnType CRcv
|
||||
@@ -193,6 +256,7 @@ instance StrEncoding AgentCommand where
|
||||
data AgentCommandTag
|
||||
= AClientCommandTag (ACommandTag 'Client)
|
||||
| AInternalCommandTag InternalCommandTag
|
||||
deriving (Show)
|
||||
|
||||
instance StrEncoding AgentCommandTag where
|
||||
strEncode = \case
|
||||
@@ -208,12 +272,16 @@ data InternalCommand
|
||||
| ICAckDel SMP.RecipientId MsgId InternalId
|
||||
| ICAllowSecure SMP.RecipientId SMP.SndPublicVerifyKey
|
||||
| ICDuplexSecure SMP.RecipientId SMP.SndPublicVerifyKey
|
||||
| ICQSecure SMP.RecipientId SMP.SndPublicVerifyKey
|
||||
| ICQDelete SMP.RecipientId
|
||||
|
||||
data InternalCommandTag
|
||||
= ICAck_
|
||||
| ICAckDel_
|
||||
| ICAllowSecure_
|
||||
| ICDuplexSecure_
|
||||
| ICQSecure_
|
||||
| ICQDelete_
|
||||
deriving (Show)
|
||||
|
||||
instance StrEncoding InternalCommand where
|
||||
@@ -222,12 +290,16 @@ instance StrEncoding InternalCommand where
|
||||
ICAckDel rId srvMsgId mId -> strEncode (ICAckDel_, rId, srvMsgId, mId)
|
||||
ICAllowSecure rId sndKey -> strEncode (ICAllowSecure_, rId, sndKey)
|
||||
ICDuplexSecure rId sndKey -> strEncode (ICDuplexSecure_, rId, sndKey)
|
||||
ICQSecure rId senderKey -> strEncode (ICQSecure_, rId, senderKey)
|
||||
ICQDelete rId -> strEncode (ICQDelete_, rId)
|
||||
strP =
|
||||
strP_ >>= \case
|
||||
ICAck_ -> ICAck <$> strP_ <*> strP
|
||||
ICAckDel_ -> ICAckDel <$> strP_ <*> strP_ <*> strP
|
||||
ICAllowSecure_ -> ICAllowSecure <$> strP_ <*> strP
|
||||
ICDuplexSecure_ -> ICDuplexSecure <$> strP_ <*> strP
|
||||
strP >>= \case
|
||||
ICAck_ -> ICAck <$> _strP <*> _strP
|
||||
ICAckDel_ -> ICAckDel <$> _strP <*> _strP <*> _strP
|
||||
ICAllowSecure_ -> ICAllowSecure <$> _strP <*> _strP
|
||||
ICDuplexSecure_ -> ICDuplexSecure <$> _strP <*> _strP
|
||||
ICQSecure_ -> ICQSecure <$> _strP <*> _strP
|
||||
ICQDelete_ -> ICQDelete <$> _strP
|
||||
|
||||
instance StrEncoding InternalCommandTag where
|
||||
strEncode = \case
|
||||
@@ -235,12 +307,16 @@ instance StrEncoding InternalCommandTag where
|
||||
ICAckDel_ -> "ACK_DEL"
|
||||
ICAllowSecure_ -> "ALLOW_SECURE"
|
||||
ICDuplexSecure_ -> "DUPLEX_SECURE"
|
||||
ICQSecure_ -> "QSECURE"
|
||||
ICQDelete_ -> "QDELETE"
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"ACK" -> pure ICAck_
|
||||
"ACK_DEL" -> pure ICAckDel_
|
||||
"ALLOW_SECURE" -> pure ICAllowSecure_
|
||||
"DUPLEX_SECURE" -> pure ICDuplexSecure_
|
||||
"QSECURE" -> pure ICQSecure_
|
||||
"QDELETE" -> pure ICQDelete_
|
||||
_ -> fail "bad InternalCommandTag"
|
||||
|
||||
agentCommandTag :: AgentCommand -> AgentCommandTag
|
||||
@@ -254,6 +330,8 @@ internalCmdTag = \case
|
||||
ICAckDel {} -> ICAckDel_
|
||||
ICAllowSecure {} -> ICAllowSecure_
|
||||
ICDuplexSecure {} -> ICDuplexSecure_
|
||||
ICQSecure {} -> ICQSecure_
|
||||
ICQDelete _ -> ICQDelete_
|
||||
|
||||
-- * Confirmation types
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
@@ -38,9 +39,16 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
deleteConn,
|
||||
upgradeRcvConnToDuplex,
|
||||
upgradeSndConnToDuplex,
|
||||
addConnRcvQueue,
|
||||
addConnSndQueue,
|
||||
setRcvQueueStatus,
|
||||
setRcvQueueConfirmedE2E,
|
||||
setSndQueueStatus,
|
||||
setRcvQueuePrimary,
|
||||
setSndQueuePrimary,
|
||||
deleteConnRcvQueue,
|
||||
deleteConnSndQueue,
|
||||
getPrimaryRcvQueue,
|
||||
getRcvQueue,
|
||||
setRcvQueueNtfCreds,
|
||||
-- Confirmations
|
||||
@@ -60,11 +68,14 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
createRcvMsg,
|
||||
updateSndIds,
|
||||
createSndMsg,
|
||||
createSndMsgDelivery,
|
||||
getPendingMsgData,
|
||||
getPendingMsgs,
|
||||
deletePendingMsgs,
|
||||
setMsgUserAck,
|
||||
getLastMsg,
|
||||
deleteMsg,
|
||||
deleteSndMsgDelivery,
|
||||
-- Double ratchet persistence
|
||||
createRatchetX3dhKeys,
|
||||
getRatchetX3dhKeys,
|
||||
@@ -120,10 +131,12 @@ import Data.Function (on)
|
||||
import Data.Functor (($>))
|
||||
import Data.IORef
|
||||
import Data.Int (Int64)
|
||||
import Data.List (foldl', groupBy)
|
||||
import Data.List (foldl', groupBy, sortBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe, listToMaybe)
|
||||
import Data.Ord (Down (..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
@@ -134,6 +147,7 @@ import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import qualified Database.SQLite3 as SQLite3
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration)
|
||||
@@ -295,45 +309,39 @@ createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
|
||||
updateNewConnRcv db connId rq@RcvQueue {server} =
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
|
||||
updateNewConnRcv db connId rq =
|
||||
getConn db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
where
|
||||
updateConn :: IO (Either StoreError ())
|
||||
updateConn = do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
pure $ Right ()
|
||||
updateConn :: IO (Either StoreError Int64)
|
||||
updateConn = Right <$> addConnRcvQueue_ db connId rq
|
||||
|
||||
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
|
||||
updateNewConnSnd db connId sq@SndQueue {server} =
|
||||
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
|
||||
updateNewConnSnd db connId sq =
|
||||
getConn db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn _ SndConnection {}) -> updateConn -- to allow retries
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
where
|
||||
updateConn :: IO (Either StoreError ())
|
||||
updateConn = do
|
||||
upsertServer_ db server
|
||||
insertSndQueue_ db connId sq
|
||||
pure $ Right ()
|
||||
updateConn :: IO (Either StoreError Int64)
|
||||
updateConn = Right <$> addConnSndQueue_ db connId sq
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
insertRcvQueue_ db connId q
|
||||
void $ insertRcvQueue_ db connId q
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId)
|
||||
createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
upsertServer_ db server
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
insertSndQueue_ db connId q
|
||||
void $ insertSndQueue_ db connId q
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError SomeConn)
|
||||
getRcvConn db ProtocolServer {host, port} rcvId =
|
||||
@@ -356,25 +364,43 @@ deleteConn db connId =
|
||||
"DELETE FROM connections WHERE conn_id = :conn_id;"
|
||||
[":conn_id" := connId]
|
||||
|
||||
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
|
||||
upgradeRcvConnToDuplex db connId sq@SndQueue {server} =
|
||||
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
|
||||
upgradeRcvConnToDuplex db connId sq =
|
||||
getConn db connId $>>= \case
|
||||
(SomeConn _ RcvConnection {}) -> do
|
||||
upsertServer_ db server
|
||||
insertSndQueue_ db connId sq
|
||||
pure $ Right ()
|
||||
(SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
|
||||
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
|
||||
upgradeSndConnToDuplex db connId rq@RcvQueue {server} =
|
||||
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
|
||||
upgradeSndConnToDuplex db connId rq =
|
||||
getConn db connId >>= \case
|
||||
Right (SomeConn _ SndConnection {}) -> do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
pure $ Right ()
|
||||
Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
addConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64)
|
||||
addConnRcvQueue db connId rq =
|
||||
getConn db connId >>= \case
|
||||
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
||||
addConnRcvQueue_ db connId rq@RcvQueue {server} = do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
|
||||
addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64)
|
||||
addConnSndQueue db connId sq =
|
||||
getConn db connId >>= \case
|
||||
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
||||
addConnSndQueue_ db connId sq@SndQueue {server} = do
|
||||
upsertServer_ db server
|
||||
insertSndQueue_ db connId sq
|
||||
|
||||
setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
|
||||
setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status =
|
||||
-- ? return error if queue does not exist?
|
||||
@@ -418,9 +444,39 @@ setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} stat
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId]
|
||||
|
||||
getRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue)
|
||||
getRcvQueue db connId =
|
||||
maybe (Left SEConnNotFound) Right <$> getRcvQueueByConnId_ db connId
|
||||
setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do
|
||||
DB.execute db "UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ? WHERE conn_id = ?" (False, False, connId)
|
||||
DB.execute
|
||||
db
|
||||
"UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?"
|
||||
(True, False, Nothing :: Maybe Int64, connId, dbQueueId)
|
||||
|
||||
setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
setSndQueuePrimary db connId SndQueue {dbQueueId} = do
|
||||
DB.execute db "UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ? WHERE conn_id = ?" (False, False, connId)
|
||||
DB.execute
|
||||
db
|
||||
"UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?"
|
||||
(True, False, Nothing :: Maybe Int64, connId, dbQueueId)
|
||||
|
||||
deleteConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
deleteConnRcvQueue db connId RcvQueue {dbQueueId} =
|
||||
DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
deleteConnSndQueue db connId SndQueue {dbQueueId} = do
|
||||
DB.execute db "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
|
||||
DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue)
|
||||
getPrimaryRcvQueue db connId =
|
||||
maybe (Left SEConnNotFound) (Right . L.head) <$> getRcvQueuesByConnId_ db connId
|
||||
|
||||
getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue)
|
||||
getRcvQueue db connId (SMPServer host port _) rcvId =
|
||||
firstRow (toRcvQueue connId) SEConnNotFound $
|
||||
DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ?") (connId, host, port, rcvId)
|
||||
|
||||
setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO ()
|
||||
setRcvQueueNtfCreds db connId clientNtfCreds =
|
||||
@@ -587,10 +643,10 @@ updateRcvIds db connId = do
|
||||
updateLastIdsRcv_ db connId internalId internalRcvId
|
||||
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
createRcvMsg :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
createRcvMsg db connId rcvMsgData = do
|
||||
createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
|
||||
createRcvMsg db connId rq rcvMsgData = do
|
||||
insertRcvMsgBase_ db connId rcvMsgData
|
||||
insertRcvMsgDetails_ db connId rcvMsgData
|
||||
insertRcvMsgDetails_ db connId rq rcvMsgData
|
||||
updateHashRcv_ db connId rcvMsgData
|
||||
|
||||
updateSndIds :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
@@ -607,9 +663,13 @@ createSndMsg db connId sndMsgData = do
|
||||
insertSndMsgDetails_ db connId sndMsgData
|
||||
updateHashSnd_ db connId sndMsgData
|
||||
|
||||
createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO ()
|
||||
createSndMsgDelivery db connId SndQueue {dbQueueId} msgId =
|
||||
DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId)
|
||||
|
||||
getPendingMsgData :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData))
|
||||
getPendingMsgData db connId msgId = do
|
||||
rq_ <- getRcvQueueByConnId_ db connId
|
||||
rq_ <- L.head <$$> getRcvQueuesByConnId_ db connId
|
||||
(rq_,) <$$> firstRow pendingMsgData SEMsgNotFound getMsgData_
|
||||
where
|
||||
getMsgData_ =
|
||||
@@ -627,16 +687,23 @@ getPendingMsgData db connId msgId = do
|
||||
let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_
|
||||
in PendingMsgData {msgId, msgType, msgFlags, msgBody, internalTs}
|
||||
|
||||
getPendingMsgs :: DB.Connection -> ConnId -> IO [InternalId]
|
||||
getPendingMsgs db connId =
|
||||
getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId]
|
||||
getPendingMsgs db connId SndQueue {dbQueueId} =
|
||||
map fromOnly
|
||||
<$> DB.query db "SELECT internal_id FROM snd_messages WHERE conn_id = ?" (Only connId)
|
||||
<$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError SMP.MsgId)
|
||||
setMsgUserAck db connId agentMsgId = do
|
||||
DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
|
||||
firstRow fromOnly SEMsgNotFound $
|
||||
DB.query db "SELECT broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
|
||||
deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
deletePendingMsgs db connId SndQueue {dbQueueId} =
|
||||
DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
|
||||
|
||||
setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId))
|
||||
setMsgUserAck db connId agentMsgId = runExceptT $ do
|
||||
liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
|
||||
(dbRcvId, srvMsgId) <-
|
||||
ExceptT . firstRow id SEMsgNotFound $
|
||||
DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
|
||||
rq <- ExceptT $ getRcvQueueById_ db connId dbRcvId
|
||||
pure (rq, srvMsgId)
|
||||
|
||||
getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg)
|
||||
getLastMsg db connId msgId =
|
||||
@@ -662,6 +729,15 @@ deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
|
||||
deleteMsg db connId msgId =
|
||||
DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId)
|
||||
|
||||
deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO ()
|
||||
deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId = do
|
||||
DB.execute
|
||||
db
|
||||
"DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?"
|
||||
(connId, dbQueueId, msgId)
|
||||
(Only (cnt :: Int) : _) <- DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
|
||||
when (cnt == 0) $ deleteMsg db connId msgId
|
||||
|
||||
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
|
||||
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
|
||||
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2)
|
||||
@@ -1202,27 +1278,35 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn connId RcvQueue {..} = do
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64
|
||||
insertRcvQueue_ db connId' RcvQueue {..} = do
|
||||
qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')
|
||||
DB.execute
|
||||
dbConn
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO rcv_queues
|
||||
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?);
|
||||
(host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, next_rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
((host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion))
|
||||
((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion))
|
||||
pure qId
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn connId SndQueue {..} = do
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64
|
||||
insertSndQueue_ db connId' SndQueue {..} = do
|
||||
qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')
|
||||
DB.execute
|
||||
dbConn
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO snd_queues
|
||||
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?);
|
||||
(host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, next_snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
(host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion)
|
||||
((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion))
|
||||
pure qId
|
||||
|
||||
newQueueId_ :: [Only Int64] -> Int64
|
||||
newQueueId_ [] = 1
|
||||
newQueueId_ (Only maxId : _) = maxId + 1
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
@@ -1230,65 +1314,79 @@ getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConn dbConn connId =
|
||||
getConnData dbConn connId >>= \case
|
||||
Nothing -> pure $ Left SEConnNotFound
|
||||
Just (connData, cMode) -> do
|
||||
rQ <- getRcvQueueByConnId_ dbConn connId
|
||||
sQ <- getSndQueueByConnId_ dbConn connId
|
||||
Just (cData, cMode) -> do
|
||||
rQ <- getRcvQueuesByConnId_ dbConn connId
|
||||
sQ <- getSndQueuesByConnId_ dbConn connId
|
||||
pure $ case (rQ, sQ, cMode) of
|
||||
(Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
|
||||
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
|
||||
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
|
||||
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection connData)
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData dbConn connId' =
|
||||
connData
|
||||
<$> DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId')
|
||||
where
|
||||
connData [(connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake)] = Just (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
|
||||
connData _ = Nothing
|
||||
cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode)
|
||||
|
||||
getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
|
||||
getRcvQueueByConnId_ dbConn connId =
|
||||
listToMaybe . map rcvQueue
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue))
|
||||
getRcvQueuesByConnId_ db connId =
|
||||
L.nonEmpty . sortBy primaryFirst . map (toRcvQueue connId)
|
||||
<$> DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ?") (Only connId)
|
||||
where
|
||||
primaryFirst RcvQueue {primary = p} RcvQueue {primary = p'} = compare (Down p) (Down p')
|
||||
|
||||
rcvQueueQuery :: Query
|
||||
rcvQueueQuery =
|
||||
[sql|
|
||||
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status,
|
||||
q.rcv_queue_id, q.rcv_primary, q.next_rcv_primary, q.replace_rcv_queue_id, q.smp_client_version,
|
||||
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
|]
|
||||
|
||||
toRcvQueue ::
|
||||
ConnId ->
|
||||
(C.KeyHash, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
:. (Int64, Bool, Bool, Maybe Int64, Maybe Version)
|
||||
:. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
|
||||
RcvQueue
|
||||
toRcvQueue connId ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
let server = SMPServer host port keyHash
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
|
||||
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
_ -> Nothing
|
||||
in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion, clientNtfCreds}
|
||||
|
||||
getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue)
|
||||
getRcvQueueById_ db connId dbRcvId =
|
||||
firstRow (toRcvQueue connId) SEConnNotFound $
|
||||
DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId)
|
||||
|
||||
-- | returns all connection queues, the first queue is the primary one
|
||||
getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue))
|
||||
getSndQueuesByConnId_ dbConn connId =
|
||||
L.nonEmpty . sortBy primaryFirst . map sndQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret,
|
||||
q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status, q.smp_client_version,
|
||||
q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_id = ?;
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
rcvQueue ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
let server = SMPServer host port keyHash
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
|
||||
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
_ -> Nothing
|
||||
in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion, clientNtfCreds}
|
||||
|
||||
getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
|
||||
getSndQueueByConnId_ dbConn connId =
|
||||
sndQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.smp_client_version
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.next_snd_primary, q.replace_snd_queue_id, q.smp_client_version
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_id = ?;
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
sndQueue [(keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion)] =
|
||||
sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion)) =
|
||||
let server = SMPServer host port keyHash
|
||||
in Just SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion}
|
||||
sndQueue _ = Nothing
|
||||
in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion}
|
||||
primaryFirst SndQueue {primary = p} SndQueue {primary = p'} = compare (Down p) (Down p')
|
||||
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
@@ -1342,22 +1440,23 @@ insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody,
|
||||
":msg_body" := msgBody
|
||||
]
|
||||
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connId RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do
|
||||
let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO rcv_messages
|
||||
( conn_id, internal_rcv_id, internal_id, external_snd_id,
|
||||
( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id,
|
||||
broker_id, broker_ts,
|
||||
internal_hash, external_prev_snd_hash, integrity)
|
||||
VALUES
|
||||
(:conn_id,:internal_rcv_id,:internal_id,:external_snd_id,
|
||||
(:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id,
|
||||
:broker_id,:broker_ts,
|
||||
:internal_hash,:external_prev_snd_hash,:integrity);
|
||||
|]
|
||||
[ ":conn_id" := connId,
|
||||
":rcv_queue_id" := dbQueueId,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":internal_id" := fst recipient,
|
||||
":external_snd_id" := sndMsgId,
|
||||
|
||||
@@ -36,6 +36,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
|
||||
@@ -51,7 +52,8 @@ schemaMigrations =
|
||||
("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode),
|
||||
("m20220811_onion_hosts", m20220811_onion_hosts),
|
||||
("m20220817_connection_ntfs", m20220817_connection_ntfs),
|
||||
("m20220905_commands", m20220905_commands)
|
||||
("m20220905_commands", m20220905_commands),
|
||||
("m20220915_connection_queues", m20220915_connection_queues)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220915_connection_queues :: Query
|
||||
m20220915_connection_queues =
|
||||
[sql|
|
||||
PRAGMA ignore_check_constraints=ON;
|
||||
|
||||
-- rcv_queues
|
||||
ALTER TABLE rcv_queues ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL);
|
||||
UPDATE rcv_queues SET rcv_queue_id = 1;
|
||||
CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues (conn_id, rcv_queue_id);
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN rcv_primary INTEGER CHECK (rcv_primary NOT NULL);
|
||||
UPDATE rcv_queues SET rcv_primary = 1;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN next_rcv_primary INTEGER CHECK (next_rcv_primary NOT NULL);
|
||||
UPDATE rcv_queues SET next_rcv_primary = 0;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN replace_rcv_queue_id INTEGER NULL;
|
||||
|
||||
-- snd_queues
|
||||
ALTER TABLE snd_queues ADD COLUMN snd_queue_id INTEGER CHECK (snd_queue_id NOT NULL);
|
||||
UPDATE snd_queues SET snd_queue_id = 1;
|
||||
CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues (conn_id, snd_queue_id);
|
||||
|
||||
ALTER TABLE snd_queues ADD COLUMN snd_primary INTEGER CHECK (snd_primary NOT NULL);
|
||||
UPDATE snd_queues SET snd_primary = 1;
|
||||
|
||||
ALTER TABLE snd_queues ADD COLUMN next_snd_primary INTEGER CHECK (next_snd_primary NOT NULL);
|
||||
UPDATE snd_queues SET next_snd_primary = 0;
|
||||
|
||||
ALTER TABLE snd_queues ADD COLUMN replace_snd_queue_id INTEGER NULL;
|
||||
|
||||
-- messages
|
||||
CREATE TABLE snd_message_deliveries (
|
||||
snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
snd_queue_id INTEGER NOT NULL,
|
||||
internal_id INTEGER NOT NULL,
|
||||
FOREIGN KEY (conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED
|
||||
);
|
||||
|
||||
CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries (conn_id, snd_queue_id);
|
||||
|
||||
ALTER TABLE rcv_messages ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL);
|
||||
UPDATE rcv_messages SET rcv_queue_id = 1;
|
||||
|
||||
PRAGMA ignore_check_constraints=OFF;
|
||||
|]
|
||||
@@ -41,6 +41,10 @@ CREATE TABLE rcv_queues(
|
||||
ntf_private_key BLOB,
|
||||
ntf_id BLOB,
|
||||
rcv_ntf_dh_secret BLOB,
|
||||
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
|
||||
rcv_primary INTEGER CHECK(rcv_primary NOT NULL),
|
||||
next_rcv_primary INTEGER CHECK(next_rcv_primary NOT NULL),
|
||||
replace_rcv_queue_id INTEGER NULL,
|
||||
PRIMARY KEY(host, port, rcv_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
@@ -58,6 +62,10 @@ CREATE TABLE snd_queues(
|
||||
smp_client_version INTEGER NOT NULL DEFAULT 1,
|
||||
snd_public_key BLOB,
|
||||
e2e_pub_key BLOB,
|
||||
snd_queue_id INTEGER CHECK(snd_queue_id NOT NULL),
|
||||
snd_primary INTEGER CHECK(snd_primary NOT NULL),
|
||||
next_snd_primary INTEGER CHECK(next_snd_primary NOT NULL),
|
||||
replace_snd_queue_id INTEGER NULL,
|
||||
PRIMARY KEY(host, port, snd_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
@@ -89,6 +97,7 @@ CREATE TABLE rcv_messages(
|
||||
external_prev_snd_hash BLOB NOT NULL,
|
||||
integrity BLOB NOT NULL,
|
||||
user_ack INTEGER NULL DEFAULT 0,
|
||||
rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL),
|
||||
PRIMARY KEY(conn_id, internal_rcv_id),
|
||||
FOREIGN KEY(conn_id, internal_id) REFERENCES messages
|
||||
ON DELETE CASCADE
|
||||
@@ -206,3 +215,17 @@ CREATE TABLE commands(
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id);
|
||||
CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id);
|
||||
CREATE TABLE snd_message_deliveries(
|
||||
snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
snd_queue_id INTEGER NOT NULL,
|
||||
internal_id INTEGER NOT NULL,
|
||||
FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED
|
||||
);
|
||||
CREATE TABLE sqlite_sequence(name,seq);
|
||||
CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries(
|
||||
conn_id,
|
||||
snd_queue_id
|
||||
);
|
||||
|
||||
50
src/Simplex/Messaging/Agent/TRcvQueues.hs
Normal file
50
src/Simplex/Messaging/Agent/TRcvQueues.hs
Normal file
@@ -0,0 +1,50 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Agent.TRcvQueues where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId)
|
||||
import Simplex.Messaging.Agent.Store (RcvQueue (..))
|
||||
import Simplex.Messaging.Protocol (RecipientId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
|
||||
newtype TRcvQueues = TRcvQueues (TMap (SMPServer, RecipientId) RcvQueue)
|
||||
|
||||
empty :: STM TRcvQueues
|
||||
empty = TRcvQueues <$> TM.empty
|
||||
|
||||
clear :: TRcvQueues -> STM ()
|
||||
clear (TRcvQueues qs) = TM.clear qs
|
||||
|
||||
deleteConn :: ConnId -> TRcvQueues -> STM ()
|
||||
deleteConn cId (TRcvQueues qs) = modifyTVar' qs $ M.filter (\rq -> cId /= connId rq)
|
||||
|
||||
-- TODO possibly, should be limited to active queues
|
||||
hasConn :: ConnId -> TRcvQueues -> STM Bool
|
||||
hasConn cId (TRcvQueues qs) = any (\rq -> cId == connId rq) <$> readTVar qs
|
||||
|
||||
-- TODO possibly, should be limited to active queues
|
||||
getConns :: TRcvQueues -> STM (Set ConnId)
|
||||
getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs
|
||||
|
||||
addQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
addQueue rq@RcvQueue {server, rcvId} (TRcvQueues qs) = TM.insert (server, rcvId) rq qs
|
||||
|
||||
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
|
||||
deleteQueue RcvQueue {server, rcvId} (TRcvQueues qs) = TM.delete (server, rcvId) qs
|
||||
|
||||
getSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
|
||||
getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
|
||||
where
|
||||
addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs'
|
||||
|
||||
getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue]
|
||||
getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty)
|
||||
where
|
||||
addQ (removed, qs') rq@RcvQueue {server, rcvId}
|
||||
| srv == server = (rq : removed, qs')
|
||||
| otherwise = (removed, M.insert (server, rcvId) rq qs')
|
||||
@@ -6,6 +6,7 @@ module Simplex.Messaging.Encoding.String
|
||||
StrEncoding (..),
|
||||
Str (..),
|
||||
strP_,
|
||||
_strP,
|
||||
strToJSON,
|
||||
strToJEncoding,
|
||||
strParseJSON,
|
||||
@@ -171,6 +172,9 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin
|
||||
strP_ :: StrEncoding a => Parser a
|
||||
strP_ = strP <* A.space
|
||||
|
||||
_strP :: StrEncoding a => Parser a
|
||||
_strP = A.space *> strP
|
||||
|
||||
strToJSON :: StrEncoding a => a -> J.Value
|
||||
strToJSON = J.String . decodeLatin1 . strEncode
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ module Simplex.Messaging.Protocol
|
||||
legacyEncodeServer,
|
||||
legacyServerP,
|
||||
legacyStrEncodeServer,
|
||||
sameSrvAddr,
|
||||
|
||||
-- * TCP transport functions
|
||||
tPut,
|
||||
@@ -572,6 +573,10 @@ pattern NtfServer host port keyHash = ProtocolServer SPNTF host port keyHash
|
||||
|
||||
{-# COMPLETE NtfServer #-}
|
||||
|
||||
sameSrvAddr :: ProtocolServer p -> ProtocolServer p -> Bool
|
||||
sameSrvAddr ProtocolServer {host, port} ProtocolServer {host = h', port = p'} = host == h' && port == p'
|
||||
{-# INLINE sameSrvAddr #-}
|
||||
|
||||
data ProtocolType = PSMP | PNTF
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.TMap2
|
||||
( TMap2,
|
||||
empty,
|
||||
clear,
|
||||
Simplex.Messaging.TMap2.lookup,
|
||||
lookup1,
|
||||
member,
|
||||
insert,
|
||||
insert1,
|
||||
delete,
|
||||
lookupDelete1,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (forM_, (>=>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (whenM, ($>>=))
|
||||
|
||||
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
|
||||
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
|
||||
data TMap2 k1 k2 a = TMap2
|
||||
{ _m1 :: TMap k1 (TMap k2 a),
|
||||
_m2 :: TMap k2 k1
|
||||
}
|
||||
|
||||
empty :: STM (TMap2 k1 k2 a)
|
||||
empty = TMap2 <$> TM.empty <*> TM.empty
|
||||
|
||||
clear :: TMap2 k1 k2 a -> STM ()
|
||||
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
|
||||
|
||||
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
|
||||
lookup k2 TMap2 {_m1, _m2} = do
|
||||
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
|
||||
|
||||
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
|
||||
{-# INLINE lookup1 #-}
|
||||
|
||||
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
|
||||
member k2 TMap2 {_m2} = TM.member k2 _m2
|
||||
{-# INLINE member #-}
|
||||
|
||||
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
|
||||
insert k1 k2 v TMap2 {_m1, _m2} =
|
||||
TM.lookup k2 _m2 >>= \case
|
||||
Just k1'
|
||||
| k1 == k1' -> _insert1
|
||||
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
|
||||
_ -> _insert2
|
||||
where
|
||||
_insert1 =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> TM.insert k2 v m
|
||||
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
|
||||
_insert2 = TM.insert k2 k1 _m2 >> _insert1
|
||||
|
||||
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
|
||||
insert1 k1 m' TMap2 {_m1, _m2} =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> readTVar m' >>= (`TM.union` m)
|
||||
_ -> TM.insert k1 m' _m1
|
||||
|
||||
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
|
||||
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
|
||||
|
||||
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
|
||||
_delete1 k1 k2 m1 =
|
||||
TM.lookup k1 m1
|
||||
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
|
||||
|
||||
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookupDelete1 k1 TMap2 {_m1, _m2} = do
|
||||
m_ <- TM.lookupDelete k1 _m1
|
||||
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
|
||||
pure m_
|
||||
@@ -21,14 +21,15 @@ import Control.Concurrent (killThread, threadDelay)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import SMPAgentClient
|
||||
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
|
||||
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
@@ -115,6 +116,15 @@ functionalAPITests t = do
|
||||
testAsyncCommandsRestore t
|
||||
it "should accept connection using async command" $
|
||||
withSmpServer t testAcceptContactAsync
|
||||
describe "Queue rotation" $ do
|
||||
it "should switch delivery to the new queue (1 server)" $
|
||||
withSmpServer t $ testSwitchConnection initAgentServers
|
||||
it "should switch delivery to the new queue (2 servers)" $
|
||||
withSmpServer t . withSmpServerOn t testPort2 $ testSwitchConnection initAgentServers2
|
||||
it "should switch to new queue asynchronously (1 server)" $
|
||||
withSmpServer t $ testSwitchAsync initAgentServers
|
||||
it "should switch to new queue asynchronously (2 servers)" $
|
||||
withSmpServer t . withSmpServerOn t testPort2 $ testSwitchAsync initAgentServers2
|
||||
|
||||
testMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testMatrix2 t runTest = do
|
||||
@@ -638,6 +648,76 @@ testAcceptContactAsync = do
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testSwitchConnection :: InitialAgentServers -> IO ()
|
||||
testSwitchConnection servers = do
|
||||
a <- getSMPAgentClient agentCfg servers
|
||||
b <- getSMPAgentClient agentCfg {database = testDB2} servers
|
||||
Right () <- runExceptT $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
switchConnectionAsync a "" bId
|
||||
phase a bId SPStarted
|
||||
phase b aId SPStarted
|
||||
phase a bId SPConfirmed
|
||||
phase b aId SPConfirmed
|
||||
phase a bId SPTested
|
||||
phase b aId SPTested
|
||||
phase b aId SPCompleted
|
||||
phase a bId SPCompleted
|
||||
exchangeGreetingsMsgId 12 a bId b aId
|
||||
pure ()
|
||||
|
||||
phase :: AgentClient -> ByteString -> SwitchPhase -> ExceptT AgentErrorType IO ()
|
||||
phase c connId p =
|
||||
get c >>= \(_, connId', msg) -> do
|
||||
liftIO $ connId `shouldBe` connId'
|
||||
case msg of
|
||||
SWITCH p' _ -> liftIO $ p `shouldBe` p'
|
||||
ERR (AGENT A_DUPLICATE) -> phase c connId p
|
||||
r -> do
|
||||
liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r
|
||||
SWITCH _ _ <- pure r
|
||||
pure ()
|
||||
|
||||
testSwitchAsync :: InitialAgentServers -> IO ()
|
||||
testSwitchAsync servers = do
|
||||
Right (aId, bId) <- withA $ \a -> withB $ \b -> runExceptT $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
exchangeGreetingsMsgId 4 a bId b aId
|
||||
pure (aId, bId)
|
||||
let phaseA = withA . phase' bId
|
||||
phaseB = withB . phase' aId
|
||||
Right () <- withA $ \a -> runExceptT $ do
|
||||
subscribeConnection a bId
|
||||
switchConnectionAsync a "" bId
|
||||
phase a bId SPStarted
|
||||
liftIO $ threadDelay 500000
|
||||
phaseB SPStarted
|
||||
phaseA SPConfirmed
|
||||
phaseB SPConfirmed
|
||||
phaseA SPTested
|
||||
Right () <- withB $ \b -> runExceptT $ do
|
||||
subscribeConnection b aId
|
||||
phase b aId SPTested
|
||||
phase b aId SPCompleted
|
||||
phaseA SPCompleted
|
||||
Right () <- withA $ \a -> withB $ \b -> runExceptT $ do
|
||||
subscribeConnection a bId
|
||||
subscribeConnection b aId
|
||||
exchangeGreetingsMsgId 12 a bId b aId
|
||||
pure ()
|
||||
where
|
||||
withAgent :: AgentConfig -> (AgentClient -> IO a) -> IO a
|
||||
withAgent cfg' = bracket (getSMPAgentClient cfg' servers) disconnectAgentClient
|
||||
phase' connId p c = do
|
||||
Right () <- runExceptT $ do
|
||||
subscribeConnection c connId
|
||||
phase c connId p
|
||||
liftIO $ threadDelay 500000
|
||||
pure ()
|
||||
withA = withAgent agentCfg
|
||||
withB = withAgent agentCfg {database = testDB2, initialClientId = 1}
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings = exchangeGreetingsMsgId 4
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
@@ -114,7 +115,7 @@ testConcurrentWrites =
|
||||
runTest st connId = replicateM_ 100 . withTransaction st $ \db -> do
|
||||
(internalId, internalRcvId, _, _) <- updateRcvIds db connId
|
||||
let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy"
|
||||
createRcvMsg db connId rcvMsgData
|
||||
createRcvMsg db connId rcvQueue1 rcvMsgData
|
||||
|
||||
testCompiledThreadsafe :: SpecWith SQLiteStore
|
||||
testCompiledThreadsafe =
|
||||
@@ -153,14 +154,19 @@ testDhSecret = "01234567890123456789012345678901"
|
||||
rcvQueue1 :: RcvQueue
|
||||
rcvQueue1 =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
{ connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "1234",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
rcvDhSecret = testDhSecret,
|
||||
e2ePrivKey = testPrivDhKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just "2345",
|
||||
sndId = "2345",
|
||||
status = New,
|
||||
dbQueueId = 1,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
@@ -168,13 +174,18 @@ rcvQueue1 =
|
||||
sndQueue1 :: SndQueue
|
||||
sndQueue1 =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
{ connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "3456",
|
||||
sndPublicKey = Nothing,
|
||||
sndPrivateKey = testPrivateSignKey,
|
||||
e2ePubKey = Nothing,
|
||||
e2eDhSecret = testDhSecret,
|
||||
status = New,
|
||||
dbQueueId = 1,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1
|
||||
}
|
||||
|
||||
@@ -187,21 +198,23 @@ testCreateRcvConn =
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
|
||||
upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
`shouldReturn` Right ()
|
||||
`shouldReturn` Right 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
|
||||
|
||||
testCreateRcvConnRandomId :: SpecWith SQLiteStore
|
||||
testCreateRcvConnRandomId =
|
||||
it "should create RcvConnection and add SndQueue with random ID" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- createRcvConn db g cData1 {connId = ""} rcvQueue1 SCMInvitation
|
||||
let rq' = (rcvQueue1 :: RcvQueue) {connId}
|
||||
sq' = (sndQueue1 :: SndQueue) {connId}
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1))
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rq'))
|
||||
upgradeRcvConnToDuplex db connId sndQueue1
|
||||
`shouldReturn` Right ()
|
||||
`shouldReturn` Right 1
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq']))
|
||||
|
||||
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateRcvConnDuplicate =
|
||||
@@ -220,21 +233,23 @@ testCreateSndConn =
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1))
|
||||
upgradeSndConnToDuplex db "conn1" rcvQueue1
|
||||
`shouldReturn` Right ()
|
||||
`shouldReturn` Right 1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
|
||||
|
||||
testCreateSndConnRandomID :: SpecWith SQLiteStore
|
||||
testCreateSndConnRandomID =
|
||||
it "should create SndConnection and add RcvQueue with random ID" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- createSndConn db g cData1 {connId = ""} sndQueue1
|
||||
let rq' = (rcvQueue1 :: RcvQueue) {connId}
|
||||
sq' = (sndQueue1 :: SndQueue) {connId}
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sq'))
|
||||
upgradeSndConnToDuplex db connId rcvQueue1
|
||||
`shouldReturn` Right ()
|
||||
`shouldReturn` Right 1
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq']))
|
||||
|
||||
testCreateSndConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateSndConnDuplicate =
|
||||
@@ -287,12 +302,12 @@ testDeleteDuplexConn =
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left (SEConnNotFound)
|
||||
`shouldReturn` Left SEConnNotFound
|
||||
|
||||
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeRcvConnToDuplex =
|
||||
@@ -301,13 +316,18 @@ testUpgradeRcvConnToDuplex =
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
{ connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
sndId = "2345",
|
||||
sndPublicKey = Nothing,
|
||||
sndPrivateKey = testPrivateSignKey,
|
||||
e2ePubKey = Nothing,
|
||||
e2eDhSecret = testDhSecret,
|
||||
status = New,
|
||||
dbQueueId = 1,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1
|
||||
}
|
||||
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
|
||||
@@ -323,14 +343,19 @@ testUpgradeSndConnToDuplex =
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
{ connId = "conn1",
|
||||
server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
rcvId = "3456",
|
||||
rcvPrivateKey = testPrivateSignKey,
|
||||
rcvDhSecret = testDhSecret,
|
||||
e2ePrivKey = testPrivDhKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just "4567",
|
||||
sndId = "4567",
|
||||
status = New,
|
||||
dbQueueId = 1,
|
||||
primary = True,
|
||||
nextPrimary = False,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
@@ -371,15 +396,17 @@ testSetQueueStatusDuplex =
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1]))
|
||||
setRcvQueueStatus db rcvQueue1 Secured
|
||||
`shouldReturn` ()
|
||||
let rq' = (rcvQueue1 :: RcvQueue) {status = Secured}
|
||||
sq' = (sndQueue1 :: SndQueue) {status = Confirmed}
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sndQueue1]))
|
||||
setSndQueueStatus db sndQueue1 Confirmed
|
||||
`shouldReturn` ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}))
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sq']))
|
||||
|
||||
hw :: ByteString
|
||||
hw = encodeUtf8 "Hello world!"
|
||||
@@ -405,12 +432,12 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
|
||||
externalPrevSndHash = "hash_from_sender"
|
||||
}
|
||||
|
||||
testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
|
||||
testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvQueue -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rq rcvMsgData@RcvMsgData {..} = do
|
||||
let MsgMeta {recipient = (internalId, _)} = msgMeta
|
||||
updateRcvIds db connId
|
||||
`shouldReturn` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg db connId rcvMsgData
|
||||
createRcvMsg db connId rq rcvMsgData
|
||||
`shouldReturn` ()
|
||||
|
||||
testCreateRcvMsg :: SpecWith SQLiteStore
|
||||
@@ -421,8 +448,8 @@ testCreateRcvMsg =
|
||||
_ <- withTransaction st $ \db -> do
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
withTransaction st $ \db -> do
|
||||
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg_ db 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg_ db 1 "hash_dummy" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
|
||||
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
|
||||
mkSndMsgData internalId internalSndId internalHash =
|
||||
@@ -464,9 +491,9 @@ testCreateRcvAndSndMsgs =
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
withTransaction st $ \db -> do
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg_ db 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg_ db 1 "rcv_hash_1" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg_ db "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg_ db 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateRcvMsg_ db 2 "rcv_hash_2" connId rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg_ db "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg_ db "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
module CoreTests.ProtocolErrorTests where
|
||||
|
||||
import Simplex.Messaging.Agent.Protocol (AgentErrorType)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (ErrorType)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Test.Hspec
|
||||
import Test.Hspec.QuickCheck (modifyMaxSuccess)
|
||||
import Test.QuickCheck
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
import AgentTests (agentTests)
|
||||
-- import Control.Logger.Simple
|
||||
import CoreTests.CryptoTests
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
@@ -13,8 +14,13 @@ import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
|
||||
import System.Environment (setEnv)
|
||||
import Test.Hspec
|
||||
|
||||
-- logCfg :: LogConfig
|
||||
-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
-- setLogLevel LogInfo -- LogError
|
||||
-- withGlobalLogging logCfg $ do
|
||||
createDirectoryIfMissing False "tests/tmp"
|
||||
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
|
||||
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
|
||||
|
||||
Reference in New Issue
Block a user