diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 66d66d8a5..536d67b85 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,35 +11,51 @@ on: jobs: build: - runs-on: ubuntu-20.04 + name: build-${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-18.04 + - os: ubuntu-20.04 steps: - name: Clone project uses: actions/checkout@v2 - - name: Setup Stack + - name: Setup Haskell uses: haskell/actions/setup@v1 with: ghc-version: "8.10.7" - enable-stack: true - stack-version: "latest" + cabal-version: "latest" - name: Cache dependencies uses: actions/cache@v2 with: - path: ~/.stack - key: ${{ hashFiles('stack.yaml') }} + path: | + ~/.cabal/store + dist-newstyle + key: ${{ matrix.os }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }} - - name: Build & test - id: build_test + - name: Build + shell: bash + run: cabal build --enable-tests + + - name: Test + if: matrix.os == 'ubuntu-18.04' + timeout-minutes: 30 + shell: bash + run: cabal test --test-show-details=direct + + - name: Prepare binaries + if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04' shell: bash run: | - stack build --test --force-dirty - install_root=$(stack path --local-install-root) - mv ${install_root}/bin/smp-server smp-server-ubuntu-20_04-x86-64 - mv ${install_root}/bin/ntf-server ntf-server-ubuntu-20_04-x86-64 + mv $(cabal list-bin smp-server) smp-server-ubuntu-20_04-x86-64 + mv $(cabal list-bin ntf-server) ntf-server-ubuntu-20_04-x86-64 - name: Build changelog - if: startsWith(github.ref, 'refs/tags/v') + if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04' id: build_changelog uses: mikepenz/release-changelog-builder-action@v1 with: @@ -50,19 +66,8 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Extract release candidate - if: startsWith(github.ref, 'refs/tags/v') - id: extract_release_candidate - shell: bash - run: | - if [[ ${GITHUB_REF} == *rc* ]]; then - echo "::set-output name=release_candidate::true" - else - echo "::set-output name=release_candidate::false" - fi - - name: Create release - if: startsWith(github.ref, 'refs/tags/v') + if: startsWith(github.ref, 'refs/tags/v') && matrix.os == 'ubuntu-20.04' uses: softprops/action-gh-release@v1 with: body: | @@ -70,7 +75,7 @@ jobs: Commits: ${{ steps.build_changelog.outputs.changelog }} - prerelease: ${{ steps.extract_release_candidate.outputs.release_candidate }} + prerelease: true files: | LICENSE smp-server-ubuntu-20_04-x86-64 diff --git a/rfcs/2022-08-14-queue-rotation.md b/rfcs/2022-08-14-queue-rotation.md new file mode 100644 index 000000000..d31e8ec11 --- /dev/null +++ b/rfcs/2022-08-14-queue-rotation.md @@ -0,0 +1,62 @@ +# SMP queue rotation and redundancy + +## Problem + +1. Long term usage of the queue allows the servers to analyze long term traffic patterns. + +2. If the user changes the configured SMP server(s), the previously created contacts do not migrate to the new server(s). + +3. Server can lose messages. + +## Solution + +Additional messages exchanged by SMP agents to negotiate addition and removal of queues to the connection. + +This approach allows for both rotation and redundancy, as changing queue will be done be adding a queue and then removing the existing queue. + +The reason for this approach is that otherwise it's non-trivial to switch from one queue to another without losing messages or delivering them out of order, it's easier to have messages delivered via both queues during the switch, however short or long time it is. + +### Messages + +Additional agent messages required for the protocol: + + QADD_ -> "QA" + QKEY_ -> "QK" + QUSE_ -> "QU" + QTEST_ -> "QT" + QDEL_ -> "QD" + QEND_ -> "QE" + +`QADD`: add the new queue address(es), the same format as `REPLY` message, encoded as `QA`. + +`QKEY`: pass sender's key via existing connection (SMP confirmation message will not be used, to avoid the same "race" of the initial key exchange that would create the risk of intercepting the queue for the attacker), encoded as `QK`. + +`QUSE`: instruct the sender to use the new queue with sender's queue ID as parameter, encoded as `QU`. + +`QTEST`: send test message to the new connection, encoded as `QT`. Any other message can be sent if available to continue rotation, the absence of this message is not an error. + +`QDEL`: instruct the sender to stop using the previous queue, encoded as `QD` + +`QEND`: notify the recipient that no messages will be sent to this queue, encoded as `QE`. The recipient will delete this queue. + +### Protocol + +``` +participant A as Alice +participant B as Bob +participant R as Server that has A's receive queue +participant S as Server that has A's send queue (B's receive queue) +participant R' as Server that hosts the new A's receive queue + +A ->> R': create new queue +A ->> S ->> B: QADD (R'): snd address of the new queue(s) +B ->> A(R) ->> A: QKEY (R'): sender's key for the new queue(s) (to avoid the race of SMP confirmation for the initial exchange) +A ->> S(R'): secure new queue +A ->> S ->> B: QUSE (R'): instruction to use new queue(s) +B ->> A(R,R') ->> A: QTEST +B ->> A(R,R') ->> A: send all new messages to both queues +A ->> S ->> B: QDEL (R): instruction to delete the old queue +B ->> A(R') -> A: QEND (R): notification that no messages will be sent to the old queue +B ->> R' ->> A: send all new messages to the new queue only +A ->> S(R): DEL: delete the previous queue +``` diff --git a/simplexmq.cabal b/simplexmq.cabal index 4fddcc1e3..c371fbd23 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -54,6 +54,8 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues + Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent Simplex.Messaging.Crypto @@ -83,7 +85,6 @@ library Simplex.Messaging.Server.Stats Simplex.Messaging.Server.StoreLog Simplex.Messaging.TMap - Simplex.Messaging.TMap2 Simplex.Messaging.Transport Simplex.Messaging.Transport.Client Simplex.Messaging.Transport.HTTP2 @@ -350,6 +351,7 @@ test-suite smp-server-test AgentTests.NotificationTests AgentTests.SchemaDump AgentTests.SQLiteTests + CoreTests.CryptoTests CoreTests.EncodingTests CoreTests.ProtocolErrorTests CoreTests.VersionRangeTests diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 886b4ccaf..20dc8ffcd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -43,6 +43,7 @@ module Simplex.Messaging.Agent allowConnectionAsync, acceptContactAsync, ackMessageAsync, + switchConnectionAsync, deleteConnectionAsync, createConnection, joinConnection, @@ -57,6 +58,7 @@ module Simplex.Messaging.Agent resubscribeConnections, sendMessage, ackMessage, + switchConnection, suspendConnection, deleteConnection, getConnectionServers, @@ -89,9 +91,10 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::)) +import Data.Foldable (foldl') import Data.Functor (($>)) -import Data.List (deleteFirstsBy) -import Data.List.NonEmpty (NonEmpty (..)) +import Data.List (deleteFirstsBy, find) +import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M @@ -117,7 +120,7 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta, SndPublicVerifyKey, sameSrvAddr) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util @@ -171,6 +174,10 @@ acceptContactAsync c corrId enableNtfs = withAgentEnv c .: acceptContactAsync' c ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m () ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c +-- | Switch connection to the new receive queue +switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m () +switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c + -- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m () deleteConnectionAsync c = withAgentEnv c .: deleteConnectionAsync' c @@ -224,6 +231,10 @@ sendMessage c = withAgentEnv c .:. sendMessage' c ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> m () ackMessage c = withAgentEnv c .: ackMessage' c +-- | Switch connection to the new receive queue +switchConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats +switchConnection c = withAgentEnv c . switchConnection' c + -- | Suspend SMP agent connection (OFF command) suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () suspendConnection c = withAgentEnv c . suspendConnection' c @@ -328,6 +339,7 @@ processCommand c (connId, cmd) = case cmd of SUB -> subscribeConnection' c connId $> (connId, OK) SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody ACK msgId -> ackMessage' c connId msgId $> (connId, OK) + SWCH -> switchConnection' c connId $> (connId, OK) OFF -> suspendConnection' c connId $> (connId, OK) DEL -> deleteConnection' c connId $> (connId, OK) CHK -> (connId,) . STAT <$> getConnectionServers' c connId @@ -375,22 +387,25 @@ acceptContactAsync' c corrId enableNtfs invId ownConnInfo = do _ -> throwError $ CMD PROHIBITED ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m () -ackMessageAsync' c corrId connId msgId = - withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> enqueueAck rq - SomeConn _ (RcvConnection _ rq) -> enqueueAck rq - SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX - SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED - SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED +ackMessageAsync' c corrId connId msgId = do + SomeConn cType _ <- withStore c (`getConn` connId) + case cType of + SCDuplex -> enqueueAck + SCRcv -> enqueueAck + SCSnd -> throwError $ CONN SIMPLEX + SCContact -> throwError $ CMD PROHIBITED + SCNew -> throwError $ CMD PROHIBITED where - enqueueAck :: RcvQueue -> m () - enqueueAck RcvQueue {server} = - enqueueCommand c corrId connId (Just server) $ AClientCommand $ ACK msgId + enqueueAck :: m () + enqueueAck = do + (RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId + enqueueCommand c corrId connId (Just server) . AClientCommand $ ACK msgId deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> m () deleteConnectionAsync' c@AgentClient {subQ} corrId connId = withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> enqueueDelete rq + -- TODO *** delete all queues + SomeConn _ (DuplexConnection _ (rq :| _) _) -> enqueueDelete rq SomeConn _ (RcvConnection _ rq) -> enqueueDelete rq SomeConn _ (ContactConnection _ rq) -> enqueueDelete rq SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId) >> notifyDeleted @@ -402,6 +417,13 @@ deleteConnectionAsync' c@AgentClient {subQ} corrId connId = notifyDeleted :: m () notifyDeleted = atomically $ writeTBQueue subQ (corrId, connId, OK) +-- | Add connection to the new receive queue +switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m () +switchConnectionAsync' c corrId connId = + withStore c (`getConn` connId) >>= \case + SomeConn _ DuplexConnection {} -> enqueueCommand c corrId connId Nothing $ AClientCommand SWCH + _ -> throwError $ CMD PROHIBITED + newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c) newConn c connId asyncMode enableNtfs cMode = getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode @@ -409,9 +431,10 @@ newConn c connId asyncMode enableNtfs cMode = newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c) newConnSrv c connId asyncMode enableNtfs cMode srv = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - (rq, qUri) <- newRcvQueue c srv smpClientVRange - connId' <- setUpConn asyncMode rq $ maxVersion smpAgentVRange - addSubscription c rq connId' + (q, qUri) <- newRcvQueue c "" srv smpClientVRange + connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange + let rq = (q :: RcvQueue) {connId = connId'} + addSubscription c rq when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId', NSCCreate) @@ -424,7 +447,7 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) where setUpConn True rq _ = do - withStore c $ \db -> updateNewConnRcv db connId rq + void . withStore c $ \db -> updateNewConnRcv db connId rq pure connId setUpConn False rq connAgentVersion = do g <- asks idsDrg @@ -434,8 +457,8 @@ newConnSrv c connId asyncMode enableNtfs cMode srv = do joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId joinConn c connId asyncMode enableNtfs cReq cInfo = do srv <- case cReq of - CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ -> - getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)] + CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> + getNextSMPServer c [qServer q] _ -> getSMPServer c joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv @@ -450,11 +473,12 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age (pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams (_, rcDHRs) <- liftIO C.generateKeyPair' let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams - sq <- newSndQueue qInfo + q <- newSndQueue "" qInfo let duplexHS = connAgentVersion /= 1 cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS} - connId' <- setUpConn asyncMode cData sq rc - let cData' = (cData :: ConnData) {connId = connId'} + connId' <- setUpConn asyncMode cData q rc + let sq = (q :: SndQueue) {connId = connId'} + cData' = (cData :: ConnData) {connId = connId'} tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case Right _ -> do unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO @@ -466,7 +490,7 @@ joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ age where setUpConn True _ sq rc = withStore c $ \db -> runExceptT $ do - ExceptT $ updateNewConnSnd db connId sq + void . ExceptT $ updateNewConnSnd db connId sq liftIO $ createRatchet db connId rc pure connId setUpConn False cData sq rc = do @@ -492,10 +516,10 @@ joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do - (rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion + (rq, qUri) <- newRcvQueue c connId srv $ versionToRange smpClientVersion let qInfo = toVersionT qUri smpClientVersion - addSubscription c rq connId - withStore c $ \db -> upgradeSndConnToDuplex db connId rq + addSubscription c rq + void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) @@ -538,24 +562,27 @@ subscribeConnection' c connId = withStore c (`getConn` connId) >>= \conn -> do resumeConnCmds c connId case conn of - SomeConn _ (DuplexConnection cData rq sq) -> do - resumeMsgDelivery c cData sq - subscribe rq + SomeConn _ (DuplexConnection cData (rq :| rqs) sqs) -> do + mapM_ (resumeMsgDelivery c cData) sqs + subscribe cData rq + mapM_ (\q -> subscribeQueue c q `catchError` \_ -> pure ()) rqs SomeConn _ (SndConnection cData sq) -> do resumeMsgDelivery c cData sq case status (sq :: SndQueue) of Confirmed -> pure () Active -> throwError $ CONN SIMPLEX _ -> throwError $ INTERNAL "unexpected queue status" - SomeConn _ (RcvConnection _ rq) -> subscribe rq - SomeConn _ (ContactConnection _ rq) -> subscribe rq + SomeConn _ (RcvConnection cData rq) -> subscribe cData rq + SomeConn _ (ContactConnection cData rq) -> subscribe cData rq SomeConn _ (NewConnection _) -> pure () where - subscribe :: RcvQueue -> m () - subscribe rq = do - subscribeQueue c rq connId + subscribe :: ConnData -> RcvQueue -> m () + subscribe ConnData {enableNtfs} rq = do + subscribeQueue c rq ns <- asks ntfSupervisor - atomically $ sendNtfSubCommand ns (connId, NSCCreate) + atomically $ sendNtfSubCommand ns (connId, if enableNtfs then NSCCreate else NSCDelete) + +type QSubResult = (QueueStatus, Either AgentErrorType ()) subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) subscribeConnections' _ [] = pure M.empty @@ -564,42 +591,61 @@ subscribeConnections' c connIds = do let (errs, cs) = M.mapEither id conns errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs - srvRcvQs :: Map SMPServer (Map ConnId RcvQueue) = M.foldlWithKey' addRcvQueue M.empty rcvQs - mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs + srvRcvQs :: Map SMPServer [RcvQueue] = M.foldl' (foldl' addRcvQueue) M.empty rcvQs + mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs mapM_ (resumeConnCmds c) $ M.keys cs - rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs) + rcvRs <- connResults . concat <$> mapConcurrently subscribe (M.assocs srvRcvQs) ns <- asks ntfSupervisor tkn <- readTVarIO (ntfTkn ns) - when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs - let rs = M.unions $ errs' : subRs : rcvRs + when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns + let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())]) notifyResultError rs pure rs where - rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) RcvQueue + rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) (NonEmpty RcvQueue) rcvQueueOrResult = \case - SomeConn _ (DuplexConnection _ rq _) -> Right rq + SomeConn _ (DuplexConnection _ rqs _) -> Right rqs SomeConn _ (SndConnection _ sq) -> Left $ sndSubResult sq - SomeConn _ (RcvConnection _ rq) -> Right rq - SomeConn _ (ContactConnection _ rq) -> Right rq + SomeConn _ (RcvConnection _ rq) -> Right [rq] + SomeConn _ (ContactConnection _ rq) -> Right [rq] SomeConn _ (NewConnection _) -> Left (Right ()) sndSubResult :: SndQueue -> Either AgentErrorType () sndSubResult sq = case status (sq :: SndQueue) of Confirmed -> Right () Active -> Left $ CONN SIMPLEX _ -> Left $ INTERNAL "unexpected queue status" - addRcvQueue :: Map SMPServer (Map ConnId RcvQueue) -> ConnId -> RcvQueue -> Map SMPServer (Map ConnId RcvQueue) - addRcvQueue m connId rq@RcvQueue {server} = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m - subscribe :: (SMPServer, Map ConnId RcvQueue) -> m (Map ConnId (Either AgentErrorType ())) + addRcvQueue :: Map SMPServer [RcvQueue] -> RcvQueue -> Map SMPServer [RcvQueue] + addRcvQueue m rq@RcvQueue {server} = M.alter (Just . maybe [rq] (rq :)) server m + subscribe :: (SMPServer, [RcvQueue]) -> m [(RcvQueue, Either AgentErrorType ())] subscribe (srv, qs) = snd <$> subscribeQueues c srv qs - sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m () - sendNtfCreate ns rcvRs = - forM_ (concatMap M.assocs rcvRs) $ \case - (connId, Right _) -> atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCCreate) + connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ()) + connResults = M.map snd . foldl' addResult M.empty + where + -- collects results by connection ID + addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QSubResult + addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs + -- combines two results for one connection, by using only Active queues (if there is at least one Active queue) + combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult + combineRes r' (Just r) = Just $ if order r <= order r' then r else r' + combineRes r' _ = Just r' + order :: QSubResult -> Int + order (Active, Right _) = 1 + order (Active, _) = 2 + order (_, Right _) = 3 + order _ = 4 + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> m () + sendNtfCreate ns rcvRs conns = + forM_ (M.assocs rcvRs) $ \case + (connId, Right _) -> forM_ (M.lookup connId conns) $ \case + Right (SomeConn _ conn) -> do + let cmd = if enableNtfs $ connData conn then NSCCreate else NSCDelete + atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd) + _ -> pure () _ -> pure () - sndQueue :: SomeConn -> Maybe (ConnData, SndQueue) + sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue) sndQueue = \case - SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq) - SomeConn _ (SndConnection cData sq) -> Just (cData, sq) + SomeConn _ (DuplexConnection cData _ sqs) -> Just (cData, sqs) + SomeConn _ (SndConnection cData sq) -> Just (cData, [sq]) _ -> Nothing notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m () notifyResultError rs = do @@ -626,7 +672,7 @@ getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMs getConnectionMessage' c connId = do whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> getQueueMessage c rq + SomeConn _ (DuplexConnection _ (rq :| _) _) -> getQueueMessage c rq SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq SomeConn _ (ContactConnection _ rq) -> getQueueMessage c rq SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX @@ -663,12 +709,12 @@ getNotificationMessage' c nonce encNtfInfo = do sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData _ sq) -> enqueueMsg cData sq - SomeConn _ (SndConnection cData sq) -> enqueueMsg cData sq + SomeConn _ (DuplexConnection cData _ sqs) -> enqueueMsgs cData sqs + SomeConn _ (SndConnection cData sq) -> enqueueMsgs cData [sq] _ -> throwError $ CONN SIMPLEX where - enqueueMsg :: ConnData -> SndQueue -> m AgentMsgId - enqueueMsg cData sq = enqueueMessage c cData sq msgFlags $ A_MSG msg + enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId + enqueueMsgs cData sqs = enqueueMessages c cData sqs msgFlags $ A_MSG msg -- / async command processing v v v @@ -713,8 +759,8 @@ getPendingCommandQ c server = do pure cq runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m () -runCommandProcessing c@AgentClient {subQ} server = do - cq <- atomically $ getPendingCommandQ c server +runCommandProcessing c@AgentClient {subQ} server_ = do + cq <- atomically $ getPendingCommandQ c server_ ri <- asks $ messageRetryInterval . config -- different retry interval? forever $ do atomically $ endAgentOperation c AOSndNetwork @@ -726,73 +772,108 @@ runCommandProcessing c@AgentClient {subQ} server = do Right (corrId, connId, cmd) -> processCmd ri corrId connId cmdId cmd where processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> AgentCommand -> m () - processCmd ri corrId connId cmdId = \case + processCmd ri corrId connId cmdId command = case command of AClientCommand cmd -> case cmd of - NEW enableNtfs (ACM cMode) -> do + NEW enableNtfs (ACM cMode) -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) tryCommand . withNextSrv usedSrvs [] $ \srv -> do (_, cReq) <- newConnSrv c connId True enableNtfs cMode srv notify $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do - let initUsed = [smpServer (queueAddress :: SMPQueueAddress)] + JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) connInfo -> noServer $ do + let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do void $ joinConnSrv c connId True enableNtfs cReq connInfo srv notify OK - LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK - ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK - DEL -> tryCommand $ deleteConnection' c connId >> notify OK + LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK + ACK msgId -> withServer' . tryCommand $ ackMessage' c connId msgId >> notify OK + SWCH -> noServer $ tryCommand $ switchConnection' c connId >>= notify . SWITCH SPStarted + DEL -> withServer' . tryCommand $ deleteConnection' c connId >> notify OK _ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd) - AInternalCommand cmd -> case server of - Just _srv -> case cmd of - ICAckDel _rId srvMsgId msgId -> tryWithLock "ICAckDel" $ ack _rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId) - ICAck _rId srvMsgId -> tryWithLock "ICAck" $ ack _rId srvMsgId - ICAllowSecure _rId senderKey -> tryWithLock "ICAllowSecure" $ do - (SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <- - withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId) - case conn of - RcvConnection cData rq -> do - secure rq senderKey - mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf) - _ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) - ICDuplexSecure _rId senderKey -> tryWithLock "ICDuplexSecure" $ do - SomeConn _ conn <- withStore c (`getConn` connId) - case conn of - DuplexConnection cData rq sq -> do - secure rq senderKey - when (duplexHandshake cData == Just True) . void $ - enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO - _ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) - _ -> throwError $ INTERNAL $ "command requires server " <> show (internalCmdTag cmd) + AInternalCommand cmd -> case cmd of + ICAckDel rId srvMsgId msgId -> withServer $ \srv -> tryWithLock "ICAckDel" $ ack srv rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId) + ICAck rId srvMsgId -> withServer $ \srv -> tryWithLock "ICAck" $ ack srv rId srvMsgId + ICAllowSecure _rId senderKey -> withServer' . tryWithLock "ICAllowSecure" $ do + (SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <- + withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId) + case conn of + RcvConnection cData rq -> do + secure rq senderKey + mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf) + _ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd) + ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do + secure rq senderKey + when (duplexHandshake cData == Just True) . void $ + enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO + ICQSecure rId senderKey -> + withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> + case find (sameQueue (srv, rId)) rqs of + Just rq'@RcvQueue {server, sndId, status} -> when (status == Confirmed) $ do + secureQueue c rq' senderKey + withStore' c $ \db -> setRcvQueueStatus db rq' Secured + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)] + _ -> internalErr "ICQSecure: queue address not found in connection" + ICQDelete rId -> do + withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> do + case removeQ (srv, rId) rqs of + Nothing -> internalErr "ICQDelete: queue address not found in connection" + Just (rq'@RcvQueue {primary}, rq'' : rqs') + | primary -> internalErr "ICQDelete: cannot delete primary rcv queue" + | otherwise -> do + deleteQueue c rq' + withStore' c $ \db -> deleteConnRcvQueue db connId rq' + let conn' = DuplexConnection cData (rq'' :| rqs') sqs + notify $ SWITCH SPCompleted $ connectionStats conn' + _ -> internalErr "ICQDelete: cannot delete the only queue in connection" where - ack _rId srvMsgId = do - -- TODO get particular queue - rq <- withStore c (`getRcvQueue` connId) + ack srv rId srvMsgId = do + rq <- withStore c $ \db -> getRcvQueue db connId srv rId ackQueueMessage c rq srvMsgId secure :: RcvQueue -> SMP.SndPublicVerifyKey -> m () secure rq senderKey = do secureQueue c rq senderKey withStore' c $ \db -> setRcvQueueStatus db rq Secured where + withServer a = case server_ of + Just srv -> a srv + _ -> internalErr "command requires server" + withServer' = withServer . const + noServer a = case server_ of + Nothing -> a + _ -> internalErr "command requires no server" + withDuplexConn :: (Connection 'CDuplex -> m ()) -> m () + withDuplexConn a = + withStore c (`getConn` connId) >>= \case + SomeConn _ conn@DuplexConnection {} -> a conn + _ -> internalErr "command requires duplex connection" tryCommand action = withRetryInterval ri $ \loop -> tryError action >>= \case Left e | temporaryAgentError e || e == BROKER HOST -> retrySndOp c loop - | otherwise -> notify (ERR e) >> withStore' c (`deleteCommand` cmdId) + | otherwise -> cmdError e Right () -> withStore' c (`deleteCommand` cmdId) tryWithLock name = tryCommand . withConnLock c connId name + internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command) + cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId) notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd) - withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m () - withNextSrv usedSrvs initUsed action = do - used <- readTVarIO usedSrvs - srv <- getNextSMPServer c used - atomically $ do - srvs <- readTVar $ smpServers c - let used' = if length used + 1 >= L.length srvs then initUsed else srv : used - writeTVar usedSrvs used' - action srv + withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m () + withNextSrv usedSrvs initUsed action = do + used <- readTVarIO usedSrvs + srv <- getNextSMPServer c used + atomically $ do + srvs <- readTVar $ smpServers c + let used' = if length used + 1 >= L.length srvs then initUsed else srv : used + writeTVar usedSrvs used' + action srv -- ^ ^ ^ async command processing / +enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId +enqueueMessages c cData (sq :| sqs) msgFlags aMessage = do + msgId <- enqueueMessage c cData sq msgFlags aMessage + mapM_ (enqueueSavedMessage c cData msgId) $ + filter (\SndQueue {status} -> status == Secured || status == Active) sqs + pure msgId + enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do resumeMsgDelivery c cData sq @@ -813,20 +894,28 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData + liftIO $ createSndMsgDelivery db connId sq internalId pure internalId +enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () +enqueueSavedMessage c cData@ConnData {connId} msgId sq = do + resumeMsgDelivery c cData sq + let mId = InternalId msgId + queuePendingMsgs c sq [mId] + withStore' c $ \db -> createSndMsgDelivery db connId sq mId + resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do let qKey = (server, sndId) unlessM (queueDelivering qKey) $ async (runSmpQueueMsgDelivery c cData sq) >>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c) - unlessM connQueued $ - withStore' c (`getPendingMsgs` connId) + unlessM msgsQueued $ + withStore' c (\db -> getPendingMsgs db connId sq) >>= queuePendingMsgs c sq where queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c) - connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c) + msgsQueued = atomically $ isJust <$> TM.lookupInsert (server, sndId) True (pendingMsgsQueued c) queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m () queuePendingMsgs c sq msgIds = atomically $ do @@ -853,6 +942,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh forever $ do atomically $ endAgentOperation c AOSndNetwork atomically $ throwWhenInactive c + atomically $ throwWhenNoDelivery c sq msgId <- atomically $ readTQueue mq atomically $ beginAgentOperation c AOSndNetwork atomically $ endAgentOperation c AOMsgDelivery @@ -890,6 +980,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh _ -> connError msgId NOT_ACCEPTED AM_REPLY_ -> notifyDel msgId $ ERR e AM_A_MSG_ -> notifyDel msgId $ MERR mId e + AM_QADD_ -> pure () + AM_QKEY_ -> pure () + AM_QUSE_ -> pure () + AM_QTEST_ -> pure () + AM_QDEL_ -> pure () + AM_QEND_ -> pure () _ -- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server), -- the message sending would be retried @@ -910,6 +1006,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh when (isJust rq_) $ removeConfirmations db connId -- TODO possibly notification flag should be ON for one of the parties, to result in contact connected notification unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO + AM_CONN_INFO_REPLY -> pure () + AM_REPLY_ -> pure () AM_HELLO_ -> do withStore' c $ \db -> setSndQueueStatus db sq Active case rq_ of @@ -933,11 +1031,18 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh qInfo <- createReplyQueue c cData sq srv void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo] AM_A_MSG_ -> notify $ SENT mId - _ -> pure () + AM_QADD_ -> pure () + AM_QKEY_ -> pure () + AM_QUSE_ -> pure () + AM_QTEST_ -> + withStore' c $ \db -> setSndQueueStatus db sq Active + AM_QDEL_ -> pure () + AM_QEND_ -> + getConnectionServers' c connId >>= notify . SWITCH SPCompleted delMsg msgId where delMsg :: InternalId -> m () - delMsg msgId = withStore' c $ \db -> deleteMsg db connId msgId + delMsg msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId notify :: ACommand 'Agent -> m () notify cmd = atomically $ writeTBQueue subQ ("", connId, cmd) notifyDel :: InternalId -> ACommand 'Agent -> m () @@ -954,20 +1059,38 @@ retrySndOp c loop = do ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> m () ackMessage' c connId msgId = withConnLock c connId "ackMessage" $ do - withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> ack rq - SomeConn _ (RcvConnection _ rq) -> ack rq - SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX - SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED - SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED + SomeConn _ conn <- withStore c (`getConn` connId) + case conn of + DuplexConnection {} -> ack + RcvConnection {} -> ack + SndConnection {} -> throwError $ CONN SIMPLEX + ContactConnection {} -> throwError $ CMD PROHIBITED + NewConnection _ -> throwError $ CMD PROHIBITED where - ack :: RcvQueue -> m () - ack rq = do + ack :: m () + ack = do let mId = InternalId msgId - srvMsgId <- withStore c $ \db -> setMsgUserAck db connId mId + (rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId mId ackQueueMessage c rq srvMsgId withStore' c $ \db -> deleteMsg db connId mId +switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats +switchConnection' c connId = withConnLock c connId "switchConnection" $ do + SomeConn _ conn <- withStore c (`getConn` connId) + case conn of + DuplexConnection cData rqs@(rq@RcvQueue {server, dbQueueId, sndId} :| rqs_) sqs -> do + clientVRange <- asks $ smpClientVRange . config + -- try to get the server that is different from all queues, or at least from the primary rcv queue + srv <- getNextSMPServer c $ map qServer (L.toList rqs) <> map qServer (L.toList sqs) + srv' <- if srv == server then getNextSMPServer c [server] else pure srv + (q, qUri) <- newRcvQueue c connId srv' clientVRange + let rq' = (q :: RcvQueue) {primary = False, nextPrimary = True, dbReplaceQueueId = Just dbQueueId} + void . withStore c $ \db -> addConnRcvQueue db connId rq' + addSubscription c rq' + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] + pure . connectionStats $ DuplexConnection cData (rq <| rq' :| rqs_) sqs + _ -> throwError $ CMD PROHIBITED + ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m () ackQueueMessage c rq srvMsgId = sendAck c rq srvMsgId `catchError` \case @@ -978,7 +1101,7 @@ ackQueueMessage c rq srvMsgId = suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m () suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> suspendQueue c rq + SomeConn _ (DuplexConnection _ rqs _) -> mapM_ (suspendQueue c) rqs SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq SomeConn _ (ContactConnection _ rq) -> suspendQueue c rq SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX @@ -988,7 +1111,7 @@ suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> delete rq + SomeConn _ (DuplexConnection _ rqs _) -> mapM_ delete rqs SomeConn _ (RcvConnection _ rq) -> delete rq SomeConn _ (ContactConnection _ rq) -> delete rq SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId) @@ -1003,15 +1126,17 @@ deleteConnection' c connId = withConnLock c connId "deleteConnection" $ do atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete) getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats -getConnectionServers' c connId = connServers <$> withStore c (`getConn` connId) - where - connServers :: SomeConn -> ConnectionStats - connServers = \case - SomeConn _ (RcvConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []} - SomeConn _ (SndConnection _ SndQueue {server}) -> ConnectionStats {rcvServers = [], sndServers = [server]} - SomeConn _ (DuplexConnection _ RcvQueue {server = s1} SndQueue {server = s2}) -> ConnectionStats {rcvServers = [s1], sndServers = [s2]} - SomeConn _ (ContactConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []} - SomeConn _ (NewConnection _) -> ConnectionStats {rcvServers = [], sndServers = []} +getConnectionServers' c connId = do + SomeConn _ conn <- withStore c (`getConn` connId) + pure $ connectionStats conn + +connectionStats :: Connection c -> ConnectionStats +connectionStats = \case + RcvConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []} + SndConnection _ sq -> ConnectionStats {rcvServers = [], sndServers = [qServer sq]} + DuplexConnection _ rqs sqs -> ConnectionStats {rcvServers = map qServer $ L.toList rqs, sndServers = map qServer $ L.toList sqs} + ContactConnection _ rq -> ConnectionStats {rcvServers = [qServer rq], sndServers = []} + NewConnection _ -> ConnectionStats {rcvServers = [], sndServers = []} -- | Change servers to be used for creating new queues, in Reader monad setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m () @@ -1271,11 +1396,9 @@ pickServer = \case getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer getNextSMPServer c usedSrvs = do srvs <- readTVarIO $ smpServers c - case L.nonEmpty $ deleteFirstsBy sameAddr (L.toList srvs) usedSrvs of + case L.nonEmpty $ deleteFirstsBy sameSrvAddr (L.toList srvs) usedSrvs of Just srvs' -> pickServer srvs' _ -> pickServer srvs - where - sameAddr (SMPServer host port _) (SMPServer host' port' _) = host == host' && port == port' subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do @@ -1288,7 +1411,10 @@ subscriber c@AgentClient {msgQ} = forever $ do processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cmd) = withStore c (\db -> getRcvConn db srv rId) >>= \case - SomeConn _ conn@(DuplexConnection cData rq _) -> processSMP conn cData rq + -- TODO *** get queue separately? + SomeConn _ conn@(DuplexConnection cData rqs _) -> case find (sameQueue (srv, rId)) rqs of + Just rq -> processSMP conn cData rq + _ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND) SomeConn _ conn@(RcvConnection cData rq) -> processSMP conn cData rq SomeConn _ conn@(ContactConnection cData rq) -> processSMP conn cData rq _ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND) @@ -1313,7 +1439,20 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm _ -> prohibited >> ack (Just e2eDh, Nothing) -> do decryptClientMessage e2eDh clientMsg >>= \case - (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> + (SMP.PHEmpty, AgentMsgEnvelope _ encAgentMsg) -> do + -- primary queue is set as Active in helloMsg, below is to set additional queues Active + let RcvQueue {primary, nextPrimary, dbReplaceQueueId} = rq + unless primary . withStore' c $ \db -> do + unless (status == Active) $ setRcvQueueStatus db rq Active + when nextPrimary $ setRcvQueuePrimary db connId rq + case (conn, dbReplaceQueueId) of + (DuplexConnection _ rqs sqs, Just dbRcvId) -> + case find (\RcvQueue {dbQueueId} -> dbQueueId == dbRcvId) rqs of + Just RcvQueue {server, sndId} -> do + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QDEL [(server, sndId)] + notify . SWITCH SPTested $ connectionStats conn + _ -> throwError $ INTERNAL "replaced RcvQueue not found in connection" + _ -> pure () tryError agentClientMsg >>= \case Right (Just (msgId, msgMeta, aMessage)) -> case aMessage of HELLO -> helloMsg >> ackDel msgId @@ -1322,6 +1461,19 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm A_MSG body -> do logServer "<--" c srv rId "MSG " notify $ MSG msgMeta msgFlags body + QADD qs -> qDuplex "QADD" $ qAddMsg qs + QKEY qs -> qDuplex "QKEY" $ qKeyMsg qs + QUSE qs -> qDuplex "QUSE" $ qUseMsg qs + -- no action needed for QTEST + -- any message in the new queue will mark it active and trigger deletion of the old queue + QTEST _ -> logServer "<--" c srv rId "MSG " >> ackDel msgId + QDEL qs -> qDuplex "QDEL" $ qDelMsg qs + QEND qs -> qDuplex "QEND" $ qEndMsg qs + where + qDuplex :: String -> (Connection 'CDuplex -> m ()) -> m () + qDuplex name a = case conn of + DuplexConnection {} -> a conn >> ackDel msgId + _ -> qError $ name <> ": message must be sent to duplex connection" Right _ -> prohibited >> ack Left e@(AGENT A_DUPLICATE) -> do withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case @@ -1350,7 +1502,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm broker = (srvMsgId, systemToUTCTime srvTs) msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash} - liftIO $ createRcvMsg db connId rcvMsg + liftIO $ createRcvMsg db connId rq rcvMsg pure $ Just (internalId, msgMeta, aMessage) _ -> pure Nothing _ -> prohibited >> ack @@ -1434,12 +1586,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm confId <- withStore c $ \db -> do setHandshakeVersion db connId agentVersion duplexHS createConfirmation db g newConfirmation - let srvs = map queueServer $ smpReplyQueues senderConf + let srvs = map qServer $ smpReplyQueues senderConf notify $ CONF confId srvs connInfo - queueServer (SMPQueueInfo _ SMPQueueAddress {smpServer}) = smpServer _ -> prohibited -- party accepting connection - (DuplexConnection _ RcvQueue {smpClientVersion = v'} _, Nothing) -> do + (DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do withStore c (\db -> runExceptT $ agentRatchetDecrypt db connId encConnInfo) >>= parseMessage >>= \case AgentConnInfo connInfo -> do notify $ INFO connInfo @@ -1458,7 +1609,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm _ -> do withStore' c $ \db -> setRcvQueueStatus db rq Active case conn of - DuplexConnection _ _ sq@SndQueue {status = sndStatus} + DuplexConnection _ _ (sq@SndQueue {status = sndStatus} :| _) -- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO -- this branch is executed by the accepting party in duplexHandshake mode (v2) -- and by the initiating party in v1 @@ -1471,7 +1622,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm enqueueDuplexHello :: SndQueue -> m () enqueueDuplexHello sq = void $ enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO - replyMsg :: L.NonEmpty SMPQueueInfo -> m () + replyMsg :: NonEmpty SMPQueueInfo -> m () replyMsg smpQueues = do logServer "<--" c srv rId "MSG " case duplexHandshake of @@ -1482,6 +1633,95 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm connectReplyQueues c cData ownConnInfo smpQueues `catchError` (notify . ERR) _ -> prohibited + -- processed by queue sender + qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m () + qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported" + qAddMsg ((qUri, Just addr) :| _) (DuplexConnection _ rqs sqs@(sq :| sqs_)) = do + clientVRange <- asks $ smpClientVRange . config + case qUri `compatibleVersion` clientVRange of + Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) -> + case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of + (Just _, _) -> qError "QADD: queue address is already used in connection" + (_, Just _replaced@SndQueue {dbQueueId}) -> do + sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue connId qInfo + let sq' = (sq_ :: SndQueue) {nextPrimary = True, dbReplaceQueueId = Just dbQueueId} + void . withStore c $ \db -> addConnSndQueue db connId sq' + case (sndPublicKey, e2ePubKey) of + (Just sndPubKey, Just dhPublicKey) -> do + logServer "<--" c srv rId $ "MSG " <> logSecret (senderId queueAddress) + let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}} + void . enqueueMessages c cData sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)] + let conn' = DuplexConnection cData rqs (sq <| sq' :| sqs_) + notify . SWITCH SPStarted $ connectionStats conn' + _ -> qError "absent sender keys" + _ -> qError "QADD: replaced queue address is not found in connection" + _ -> throwError $ AGENT A_VERSION + + -- processed by queue recipient + qKeyMsg :: NonEmpty (SMPQueueInfo, SndPublicVerifyKey) -> Connection 'CDuplex -> m () + qKeyMsg ((qInfo, senderKey) :| _) (DuplexConnection _ rqs _) = do + clientVRange <- asks $ smpClientVRange . config + unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION + case findRQ (smpServer, senderId) rqs of + Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'} + | status' == New || status' == Confirmed -> do + logServer "<--" c srv rId $ "MSG " <> logSecret senderId + let dhSecret = C.dh' dhPublicKey dhPrivKey + withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer' + enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey + notify . SWITCH SPConfirmed $ connectionStats conn + | otherwise -> qError "QKEY: queue already secured" + _ -> qError "QKEY: queue address not found in connection" + where + SMPQueueInfo cVer' SMPQueueAddress {smpServer, senderId, dhPublicKey} = qInfo + + -- processed by queue sender + -- mark queue as Secured and to start sending messages to it + qUseMsg :: NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> m () + qUseMsg ((addr, primary) :| _) (DuplexConnection _ _ sqs) = + case removeQ addr sqs of + Just (sq', sqs') -> do + logServer "<--" c srv rId $ "MSG " <> logSecret (snd addr) + withStore' c $ \db -> do + setSndQueueStatus db sq' Secured + when primary $ setSndQueuePrimary db connId sq' + let sq'' = (sq' :: SndQueue) {status = Secured, primary} + void $ enqueueMessages c cData (sq'' :| sqs') SMP.noMsgFlags $ QTEST [addr] + notify . SWITCH SPConfirmed $ connectionStats conn + _ -> qError "QUSE: queue address not found in connection" + + -- processed by queue sender + -- remove snd queue from connection and enqueue QEND message + qDelMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () + qDelMsg (addr :| _) (DuplexConnection _ rqs sqs) = + case removeQ addr sqs of + Nothing -> logServer "<--" c srv rId "MSG : queue not found (already deleted?)" + Just (sq, sq' : sqs') -> do + logServer "<--" c srv rId $ "MSG " <> logSecret (snd addr) + -- remove the delivery from the map to stop the thread when the delivery loop is complete + atomically $ TM.delete addr $ smpQueueMsgQueues c + withStore' c $ \db -> do + deletePendingMsgs db connId sq + deleteConnSndQueue db connId sq + let sqs'' = sq' :| sqs' + conn' = DuplexConnection cData rqs sqs'' + void $ enqueueMessages c cData sqs'' SMP.noMsgFlags $ QEND [addr] + notify . SWITCH SPTested $ connectionStats conn' + _ -> qError "QDEL received to the only queue in connection" + + -- received by party initiating switch + -- TODO *** check that the received address matches expectations + qEndMsg :: NonEmpty (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m () + qEndMsg (addr@(smpServer, senderId) :| _) (DuplexConnection _ rqs _) = + case findRQ addr rqs of + Just RcvQueue {rcvId} -> do + logServer "<--" c srv rId $ "MSG " <> logSecret senderId + enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQDelete rcvId + _ -> qError "QEND: queue address not found in connection" + + qError :: String -> m () + qError = throwError . AGENT . A_QUEUE + smpInvitation :: ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () smpInvitation connReq@(CRInvitationUri crData _) cInfo = do logServer "<--" c srv rId "MSG " @@ -1490,11 +1730,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm g <- asks idsDrg let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo} invId <- withStore c $ \db -> createInvitation db g newInv - let srvs = L.map queueServer $ crSmpQueues crData + let srvs = L.map qServer $ crSmpQueues crData notify $ REQ invId srvs cInfo _ -> prohibited - where - queueServer (SMPQueueUri _ SMPQueueAddress {smpServer}) = smpServer checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash @@ -1505,15 +1743,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, v, sessId, rId, cm | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash | otherwise = MsgError MsgDuplicate -- this case is not possible -connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> L.NonEmpty SMPQueueInfo -> m () +connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m () connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do clientVRange <- asks $ smpClientVRange . config case qInfo `proveCompatible` clientVRange of Nothing -> throwError $ AGENT A_VERSION Just qInfo' -> do - sq <- newSndQueue qInfo' - withStore c $ \db -> upgradeRcvConnToDuplex db connId sq - enqueueConfirmation c cData sq ownConnInfo Nothing + sq <- newSndQueue connId qInfo' + dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq + enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do @@ -1570,28 +1808,34 @@ agentRatchetDecrypt db connId encAgentMsg = do liftIO $ updateRatchet db connId rc' skippedDiff liftEither $ first (SEAgentError . cryptoError) agentMsgBody_ -newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => Compatible SMPQueueInfo -> m SndQueue -newSndQueue qInfo = +newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => ConnId -> Compatible SMPQueueInfo -> m SndQueue +newSndQueue connId qInfo = asks (cmdSignAlg . config) >>= \case - C.SignAlg a -> newSndQueue_ a qInfo + C.SignAlg a -> newSndQueue_ a connId qInfo newSndQueue_ :: (C.SignatureAlgorithm a, C.AlgorithmI a, MonadUnliftIO m) => C.SAlgorithm a -> + ConnId -> Compatible SMPQueueInfo -> m SndQueue -newSndQueue_ a (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do +newSndQueue_ a connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do -- this function assumes clientVersion is compatible - it was tested before (sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair' pure SndQueue - { server = smpServer, + { connId, + server = smpServer, sndId = senderId, sndPublicKey = Just sndPublicKey, sndPrivateKey, e2eDhSecret = C.dh' rcvE2ePubDhKey e2ePrivKey, e2ePubKey = Just e2ePubKey, status = New, + dbQueueId = 0, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion } diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 552a545e2..d336ea41c 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -51,6 +51,7 @@ module Simplex.Messaging.Agent.Client suspendQueue, deleteQueue, logServer, + logSecret, removeSubscription, hasActiveSubscription, agentClientStore, @@ -62,6 +63,7 @@ module Simplex.Messaging.Agent.Client agentOperationBracket, waitUntilActive, throwWhenInactive, + throwWhenNoDelivery, beginAgentOperation, endAgentOperation, suspendSendingAndDatabase, @@ -90,11 +92,12 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight, partitionEithers) import Data.Functor (($>)) -import Data.List.NonEmpty (NonEmpty) +import Data.List (partition, (\\)) +import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (listToMaybe) +import Data.Maybe (isJust, listToMaybe) import Data.Set (Set) import qualified Data.Set as S import Data.Text.Encoding @@ -107,6 +110,8 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction) +import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues) +import qualified Simplex.Messaging.Agent.TRcvQueues as RQ import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C @@ -140,8 +145,6 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.TMap2 (TMap2) -import qualified Simplex.Messaging.TMap2 as TM2 import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version @@ -167,11 +170,11 @@ data AgentClient = AgentClient ntfClients :: TMap NtfServer NtfClientVar, useNetworkConfig :: TVar NetworkConfig, subscrConns :: TVar (Set ConnId), - activeSubs :: TMap2 SMPServer ConnId RcvQueue, - pendingSubs :: TMap2 SMPServer ConnId RcvQueue, - connMsgsQueued :: TMap ConnId Bool, - smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId), - smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()), + activeSubs :: TRcvQueues, + pendingSubs :: TRcvQueues, + pendingMsgsQueued :: TMap SndQAddr Bool, + smpQueueMsgQueues :: TMap SndQAddr (TQueue InternalId), + smpQueueMsgDeliveries :: TMap SndQAddr (Async ()), connCmdsQueued :: TMap ConnId Bool, asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId), asyncCmdProcesses :: TMap (Maybe SMPServer) (Async ()), @@ -229,9 +232,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do ntfClients <- TM.empty useNetworkConfig <- newTVar netCfg subscrConns <- newTVar S.empty - activeSubs <- TM2.empty - pendingSubs <- TM2.empty - connMsgsQueued <- TM.empty + activeSubs <- RQ.empty + pendingSubs <- RQ.empty + pendingMsgsQueued <- TM.empty smpQueueMsgQueues <- TM.empty smpQueueMsgDeliveries <- TM.empty connCmdsQueued <- TM.empty @@ -249,7 +252,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do reconnections <- newTVar [] asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i') - return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv} + return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, pendingMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, connLocks, reconnectLocks, reconnections, asyncClients, clientId, agentEnv} agentClientStore :: AgentClient -> SQLiteStore agentClientStore AgentClient {agentEnv = Env {store}} = store @@ -282,25 +285,25 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do clientDisconnected :: UnliftIO m -> SMPClient -> IO () clientDisconnected u client = do - removeClientAndSubs >>= (`forM_` serverDown) + removeClientAndSubs >>= serverDown logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv where - removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) + removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ do TM.delete srv smpClients - TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs - where - updateSubs cVar = do - TM2.insert1 srv cVar $ pendingSubs c - readTVar cVar + qs <- RQ.getDelSrvQueues srv $ activeSubs c + mapM_ (`RQ.addQueue` pendingSubs c) qs + cs <- RQ.getConns (activeSubs c) + -- TODO deduplicate conns + let conns = map (connId :: RcvQueue -> ConnId) qs \\ S.toList cs + pure (qs, conns) - serverDown :: Map ConnId RcvQueue -> IO () - serverDown cs = whenM (readTVarIO active) $ do + serverDown :: ([RcvQueue], [ConnId]) -> IO () + serverDown (qs, conns) = whenM (readTVarIO active) $ do notifySub "" $ hostEvent DISCONNECT client - let conns = M.keys cs - unless (null conns) $ do - notifySub "" $ DOWN srv conns - atomically $ mapM_ (releaseGetLock c) cs + unless (null conns) $ notifySub "" $ DOWN srv conns + unless (null qs) $ do + atomically $ mapM_ (releaseGetLock c) qs unliftIO u reconnectServer reconnectServer :: m () @@ -317,20 +320,22 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do reconnectClient :: m () reconnectClient = withLockMap_ (reconnectLocks c) srv "reconnect" $ - atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar) - >>= mapM_ resubscribe + atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe where - resubscribe :: Map ConnId RcvQueue -> m () + resubscribe :: [RcvQueue] -> m () resubscribe qs = do connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar) - (client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs + cs <- atomically . RQ.getConns $ activeSubs c + (client_, rs) <- subscribeQueues c srv qs + let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do unless connected $ mapM_ (notifySub "" . hostEvent CONNECT) client_ - unless (M.null oks) $ do - notifySub "" . UP srv $ M.keys oks - let (tempErrs, finalErrs) = M.partition temporaryAgentError errs - liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs - mapM_ throwError . listToMaybe $ M.elems tempErrs + -- TODO deduplicate okConns + let conns = okConns \\ S.toList cs + unless (null conns) $ notifySub "" $ UP srv conns + let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs + liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs + mapM_ (throwError . snd) $ listToMaybe tempErrs notifySub :: ConnId -> ACommand 'Agent -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) @@ -425,10 +430,10 @@ closeAgentClient c = liftIO $ do cancelActions $ asyncClients c cancelActions $ smpQueueMsgDeliveries c cancelActions $ asyncCmdProcesses c - atomically . TM2.clear $ activeSubs c - atomically . TM2.clear $ pendingSubs c + atomically . RQ.clear $ activeSubs c + atomically . RQ.clear $ pendingSubs c clear subscrConns - clear connMsgsQueued + clear pendingMsgsQueued clear smpQueueMsgQueues clear connCmdsQueued clear asyncCmdQueues @@ -443,6 +448,14 @@ waitUntilActive c = unlessM (readTVar $ active c) retry throwWhenInactive :: AgentClient -> STM () throwWhenInactive c = unlessM (readTVar $ active c) $ throwSTM ThreadKilled +throwWhenNoDelivery :: AgentClient -> SndQueue -> STM () +throwWhenNoDelivery c SndQueue {server, sndId} = + unlessM (isJust <$> TM.lookup k (smpQueueMsgQueues c)) $ do + TM.delete k $ smpQueueMsgDeliveries c + throwSTM ThreadKilled + where + k = (server, sndId) + closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (ProtoServer msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty) @@ -502,19 +515,20 @@ protocolClientError protocolError_ = \case e@PCESignatureError {} -> INTERNAL $ show e e@PCEIOError {} -> INTERNAL $ show e -newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri) -newRcvQueue c srv vRange = +newRcvQueue :: AgentMonad m => AgentClient -> ConnId -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri) +newRcvQueue c connId srv vRange = asks (cmdSignAlg . config) >>= \case - C.SignAlg a -> newRcvQueue_ a c srv vRange + C.SignAlg a -> newRcvQueue_ a c connId srv vRange newRcvQueue_ :: (C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) => C.SAlgorithm a -> AgentClient -> + ConnId -> SMPServer -> VersionRange -> m (RcvQueue, SMPQueueUri) -newRcvQueue_ a c srv vRange = do +newRcvQueue_ a c connId srv vRange = do (recipientKey, rcvPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (dhKey, privDhKey) <- liftIO C.generateKeyPair' (e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair' @@ -524,36 +538,41 @@ newRcvQueue_ a c srv vRange = do logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] let rq = RcvQueue - { server = srv, + { connId, + server = srv, rcvId, rcvPrivateKey, rcvDhSecret = C.dh' rcvPublicDhKey privDhKey, e2ePrivKey, e2eDhSecret = Nothing, - sndId = Just sndId, + sndId, status = New, + dbQueueId = 0, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion = maxVersion vRange, clientNtfCreds = Nothing } pure (rq, SMPQueueUri vRange $ SMPQueueAddress srv sndId e2eDhKey) -subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () -subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do +subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () +subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ do modifyTVar (subscrConns c) $ S.insert connId - TM2.insert server connId rq $ pendingSubs c + RQ.addQueue rq $ pendingSubs c withLogClient c server rcvId "SUB" $ \smp -> - liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId) + liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq) >>= either throwError pure -processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) -processSubResult c rq connId r = do +processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) +processSubResult c rq r = do case r of Left e -> atomically . unless (temporaryClientError e) $ - TM2.delete connId (pendingSubs c) - _ -> addSubscription c rq connId + RQ.deleteQueue rq (pendingSubs c) + _ -> addSubscription c rq pure r temporaryClientError :: ProtocolClientError -> Bool @@ -569,45 +588,46 @@ temporaryAgentError = \case _ -> False -- | subscribe multiple queues - all passed queues should be on the same server -subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ())) +subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> [RcvQueue] -> m (Maybe SMPClient, [(RcvQueue, Either AgentErrorType ())]) subscribeQueues c srv qs = do - (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) - forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do + (errs, qs_) <- partitionEithers <$> mapM checkQueue qs + forM_ qs_ $ \rq@RcvQueue {connId, server = _server} -> atomically $ do + -- TODO check server is correct modifyTVar (subscrConns c) $ S.insert connId - TM2.insert server connId rq $ pendingSubs c + RQ.addQueue rq $ pendingSubs c case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) - (eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of - Left e -> pure $ map (second . const $ Left e) qs_ + (eitherToMaybe smp_,) . (errs <>) <$> case smp_ of + Left e -> pure $ map (,Left e) qs_ Right smp -> do logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB" - let qs2 = L.map (queueCreds . snd) qs' - rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <- - liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 - forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r - pure $ map (bimap fst (first $ protocolClientError SMP)) rs' - _ -> pure (Nothing, M.fromList errs) + let qs2 = L.map queueCreds qs' + liftIO $ do + rs <- zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 + mapM_ (uncurry $ processSubResult c) rs + pure $ map (second . first $ protocolClientError SMP) rs + _ -> pure (Nothing, errs) where - checkQueue rq@(connId, RcvQueue {rcvId, server}) = do + checkQueue rq@RcvQueue {rcvId, server} = do prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c - pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq + pure $ if prohibited || srv /= server then Left (rq, Left $ CMD PROHIBITED) else Right rq queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) -addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () -addSubscription c rq@RcvQueue {server} connId = atomically $ do +addSubscription :: MonadIO m => AgentClient -> RcvQueue -> m () +addSubscription c rq@RcvQueue {connId} = atomically $ do modifyTVar (subscrConns c) $ S.insert connId - TM2.insert server connId rq $ activeSubs c - TM2.delete connId $ pendingSubs c + RQ.addQueue rq $ activeSubs c + RQ.deleteQueue rq $ pendingSubs c hasActiveSubscription :: AgentClient -> ConnId -> STM Bool -hasActiveSubscription c connId = TM2.member connId $ activeSubs c +hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c removeSubscription :: AgentClient -> ConnId -> STM () removeSubscription c connId = do modifyTVar (subscrConns c) $ S.delete connId - TM2.delete connId $ activeSubs c - TM2.delete connId $ pendingSubs c + RQ.deleteConn connId $ activeSubs c + RQ.deleteConn connId $ pendingSubs c getSubscriptions :: AgentClient -> STM (Set ConnId) getSubscriptions = readTVar . subscrConns diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 15027d4c6..a75531ac0 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -96,7 +96,8 @@ data AgentConfig = AgentConfig certificateFile :: FilePath, e2eEncryptVRange :: VersionRange, smpAgentVRange :: VersionRange, - smpClientVRange :: VersionRange + smpClientVRange :: VersionRange, + initialClientId :: Int } defaultReconnectInterval :: RetryInterval @@ -142,7 +143,8 @@ defaultAgentConfig = certificateFile = "/etc/opt/simplex-agent/agent.crt", e2eEncryptVRange = supportedE2EEncryptVRange, smpAgentVRange = supportedSMPAgentVRange, - smpClientVRange = supportedSMPClientVRange + smpClientVRange = supportedSMPClientVRange, + initialClientId = 0 } data Env = Env @@ -155,12 +157,12 @@ data Env = Env } newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env -newSMPAgentEnv config@AgentConfig {database, yesToMigrations} = do +newSMPAgentEnv config@AgentConfig {database, yesToMigrations, initialClientId} = do idsDrg <- newTVarIO =<< drgNew store <- case database of AgentDB st -> pure st AgentDBFile {dbFile, dbKey} -> liftIO $ createAgentStore dbFile dbKey yesToMigrations - clientCounter <- newTVarIO 0 + clientCounter <- newTVarIO initialClientId randomServer <- newTVarIO =<< liftIO newStdGen ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config return Env {config, store, idsDrg, clientCounter, randomServer, ntfSupervisor} diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 7f3346f4c..7476e85c7 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -73,7 +73,7 @@ processNtfSub c (connId, cmd) = do NSCCreate -> do (a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do a <- liftIO $ getNtfSubscription db connId - q <- ExceptT $ getRcvQueue db connId + q <- ExceptT $ getPrimaryRcvQueue db connId pure (a, q) logInfo $ "processNtfSub, NSCCreate - a = " <> tshow a case a of @@ -125,7 +125,7 @@ processNtfSub c (connId, cmd) = do (Just (NtfSubscription {ntfServer}, _)) -> addNtfNTFWorker ntfServer _ -> pure () -- err "NSCDelete - no subscription" NSCSmpDelete -> do - withStore' c (`getRcvQueue` connId) >>= \case + withStore' c (`getPrimaryRcvQueue` connId) >>= \case Right rq@RcvQueue {server = smpServer} -> do logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq ts <- liftIO getCurrentTime @@ -185,7 +185,7 @@ runNtfWorker c srv doWork = do NSACreate -> getNtfToken >>= \case Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do - RcvQueue {clientNtfCreds} <- withStore c (`getRcvQueue` connId) + RcvQueue {clientNtfCreds} <- withStore c (`getPrimaryRcvQueue` connId) case clientNtfCreds of Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey @@ -260,7 +260,7 @@ runNtfSMPWorker c srv doWork = do NSASmpKey -> getNtfToken >>= \case Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do - rq <- withStore c (`getRcvQueue` connId) + rq <- withStore c (`getPrimaryRcvQueue` connId) C.SignAlg a <- asks (cmdSignAlg . config) (ntfPublicKey, ntfPrivateKey) <- liftIO $ C.generateSignatureKeyPair a (rcvNtfPubDhKey, rcvNtfPrivDhKey) <- liftIO C.generateKeyPair' @@ -275,7 +275,7 @@ runNtfSMPWorker c srv doWork = do NSASmpDelete -> do rq_ <- withStore' c $ \db -> do setRcvQueueNtfCreds db connId Nothing - getRcvQueue db connId + getPrimaryRcvQueue db connId forM_ rq_ $ \rq -> disableQueueNotifications c rq withStore' c $ \db -> deleteNtfSubscription db connId diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index c23eb9f98..727ae8403 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -50,15 +50,19 @@ module Simplex.Messaging.Agent.Protocol MsgHash, MsgMeta (..), ConnectionStats (..), + SwitchPhase (..), SMPConfirmation (..), AgentMsgEnvelope (..), AgentMessage (..), AgentMessageType (..), APrivHeader (..), AMessage (..), + SndQAddr, SMPServer, pattern SMPServer, SrvLoc (..), + SMPQueue (..), + sameQAddress, SMPQueueUri (..), SMPQueueInfo (..), SMPQueueAddress (..), @@ -162,6 +166,7 @@ import Simplex.Messaging.Protocol legacyEncodeServer, legacyServerP, legacyStrEncodeServer, + sameSrvAddr, pattern SMPServer, ) import qualified Simplex.Messaging.Protocol as SMP @@ -247,12 +252,14 @@ data ACommand (p :: AParty) where DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent DOWN :: SMPServer -> [ConnId] -> ACommand Agent UP :: SMPServer -> [ConnId] -> ACommand Agent + SWITCH :: SwitchPhase -> ConnectionStats -> ACommand Agent SEND :: MsgFlags -> MsgBody -> ACommand Client MID :: AgentMsgId -> ACommand Agent SENT :: AgentMsgId -> ACommand Agent MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent ACK :: AgentMsgId -> ACommand Client + SWCH :: ACommand Client OFF :: ACommand Client DEL :: ACommand Client CHK :: ACommand Client @@ -284,12 +291,14 @@ data ACommandTag (p :: AParty) where DISCONNECT_ :: ACommandTag Agent DOWN_ :: ACommandTag Agent UP_ :: ACommandTag Agent + SWITCH_ :: ACommandTag Agent SEND_ :: ACommandTag Client MID_ :: ACommandTag Agent SENT_ :: ACommandTag Agent MERR_ :: ACommandTag Agent MSG_ :: ACommandTag Agent ACK_ :: ACommandTag Client + SWCH_ :: ACommandTag Client OFF_ :: ACommandTag Client DEL_ :: ACommandTag Client CHK_ :: ACommandTag Client @@ -320,12 +329,14 @@ aCommandTag = \case DISCONNECT {} -> DISCONNECT_ DOWN {} -> DOWN_ UP {} -> UP_ + SWITCH {} -> SWITCH_ SEND {} -> SEND_ MID _ -> MID_ SENT _ -> SENT_ MERR {} -> MERR_ MSG {} -> MSG_ ACK _ -> ACK_ + SWCH -> SWCH_ OFF -> OFF_ DEL -> DEL_ CHK -> CHK_ @@ -334,6 +345,23 @@ aCommandTag = \case ERR _ -> ERR_ SUSPENDED -> SUSPENDED_ +data SwitchPhase = SPStarted | SPConfirmed | SPTested | SPCompleted + deriving (Eq, Show) + +instance StrEncoding SwitchPhase where + strEncode = \case + SPStarted -> "started" + SPConfirmed -> "confirmed" + SPTested -> "tested" + SPCompleted -> "completed" + strP = + A.takeTill (== ' ') >>= \case + "started" -> pure SPStarted + "confirmed" -> pure SPConfirmed + "tested" -> pure SPTested + "completed" -> pure SPCompleted + _ -> fail "bad SwitchPhase" + data ConnectionStats = ConnectionStats { rcvServers :: [SMPServer], sndServers :: [SMPServer] @@ -508,7 +536,18 @@ instance Encoding AgentMessage where 'M' -> AgentMessage <$> smpP <*> smpP _ -> fail "bad AgentMessage" -data AgentMessageType = AM_CONN_INFO | AM_CONN_INFO_REPLY | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_ +data AgentMessageType + = AM_CONN_INFO + | AM_CONN_INFO_REPLY + | AM_HELLO_ + | AM_REPLY_ + | AM_A_MSG_ + | AM_QADD_ + | AM_QKEY_ + | AM_QUSE_ + | AM_QTEST_ + | AM_QDEL_ + | AM_QEND_ deriving (Eq, Show) instance Encoding AgentMessageType where @@ -518,6 +557,12 @@ instance Encoding AgentMessageType where AM_HELLO_ -> "H" AM_REPLY_ -> "R" AM_A_MSG_ -> "M" + AM_QADD_ -> "QA" + AM_QKEY_ -> "QK" + AM_QUSE_ -> "QU" + AM_QTEST_ -> "QT" + AM_QDEL_ -> "QD" + AM_QEND_ -> "QE" smpP = A.anyChar >>= \case 'C' -> pure AM_CONN_INFO @@ -525,6 +570,15 @@ instance Encoding AgentMessageType where 'H' -> pure AM_HELLO_ 'R' -> pure AM_REPLY_ 'M' -> pure AM_A_MSG_ + 'Q' -> + A.anyChar >>= \case + 'A' -> pure AM_QADD_ + 'K' -> pure AM_QKEY_ + 'U' -> pure AM_QUSE_ + 'T' -> pure AM_QTEST_ + 'D' -> pure AM_QDEL_ + 'E' -> pure AM_QEND_ + _ -> fail "bad AgentMessageType" _ -> fail "bad AgentMessageType" agentMessageType :: AgentMessage -> AgentMessageType @@ -540,6 +594,12 @@ agentMessageType = \case -- REPLY is only used in v1 REPLY _ -> AM_REPLY_ A_MSG _ -> AM_A_MSG_ + QADD _ -> AM_QADD_ + QKEY _ -> AM_QKEY_ + QUSE _ -> AM_QUSE_ + QTEST _ -> AM_QTEST_ + QDEL _ -> AM_QDEL_ + QEND _ -> AM_QEND_ data APrivHeader = APrivHeader { -- | sequential ID assigned by the sending agent @@ -554,7 +614,16 @@ instance Encoding APrivHeader where smpEncode (sndMsgId, prevMsgHash) smpP = APrivHeader <$> smpP <*> smpP -data AMsgType = HELLO_ | REPLY_ | A_MSG_ +data AMsgType + = HELLO_ + | REPLY_ + | A_MSG_ + | QADD_ + | QKEY_ + | QUSE_ + | QTEST_ + | QDEL_ + | QEND_ deriving (Eq) instance Encoding AMsgType where @@ -562,11 +631,26 @@ instance Encoding AMsgType where HELLO_ -> "H" REPLY_ -> "R" A_MSG_ -> "M" + QADD_ -> "QA" + QKEY_ -> "QK" + QUSE_ -> "QU" + QTEST_ -> "QT" + QDEL_ -> "QD" + QEND_ -> "QE" smpP = - smpP >>= \case + A.anyChar >>= \case 'H' -> pure HELLO_ 'R' -> pure REPLY_ 'M' -> pure A_MSG_ + 'Q' -> + A.anyChar >>= \case + 'A' -> pure QADD_ + 'K' -> pure QKEY_ + 'U' -> pure QUSE_ + 'T' -> pure QTEST_ + 'D' -> pure QDEL_ + 'E' -> pure QEND_ + _ -> fail "bad AMsgType" _ -> fail "bad AMsgType" -- | Messages sent between SMP agents once SMP queue is secured. @@ -579,19 +663,45 @@ data AMessage REPLY (L.NonEmpty SMPQueueInfo) | -- | agent envelope for the client message A_MSG MsgBody + | -- add queue to connection (sent by recipient), with optional address of the replaced queue + QADD (L.NonEmpty (SMPQueueUri, Maybe SndQAddr)) + | -- key to secure the added queues and agree e2e encryption key (sent by sender) + QKEY (L.NonEmpty (SMPQueueInfo, SndPublicVerifyKey)) + | -- inform that the queues are ready to use (sent by recipient) + QUSE (L.NonEmpty (SndQAddr, Bool)) + | -- sent by the sender to test new queues + QTEST (L.NonEmpty SndQAddr) + | -- inform that the queues will be deleted (sent recipient once message received via the new queue) + QDEL (L.NonEmpty SndQAddr) + | -- sent by sender to confirm that no more messages will be sent to the queue + QEND (L.NonEmpty SndQAddr) deriving (Show) +type SndQAddr = (SMPServer, SMP.SenderId) + instance Encoding AMessage where smpEncode = \case HELLO -> smpEncode HELLO_ REPLY smpQueues -> smpEncode (REPLY_, smpQueues) A_MSG body -> smpEncode (A_MSG_, Tail body) + QADD qs -> smpEncode (QADD_, qs) + QKEY qs -> smpEncode (QKEY_, qs) + QUSE qs -> smpEncode (QUSE_, qs) + QTEST qs -> smpEncode (QTEST_, qs) + QDEL qs -> smpEncode (QDEL_, qs) + QEND qs -> smpEncode (QEND_, qs) smpP = smpP >>= \case HELLO_ -> pure HELLO REPLY_ -> REPLY <$> smpP A_MSG_ -> A_MSG . unTail <$> smpP + QADD_ -> QADD <$> smpP + QKEY_ -> QKEY <$> smpP + QUSE_ -> QUSE <$> smpP + QTEST_ -> QTEST <$> smpP + QDEL_ -> QDEL <$> smpP + QEND_ -> QEND <$> smpP instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) where strEncode = \case @@ -688,6 +798,11 @@ updateSMPServerHosts srv@ProtocolServer {host} = case host of _ -> srv _ -> srv +class SMPQueue q where + qServer :: q -> SMPServer + qAddress :: q -> (SMPServer, SMP.QueueId) + sameQueue :: (SMPServer, SMP.QueueId) -> q -> Bool + data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress} deriving (Eq, Show) @@ -728,6 +843,34 @@ data SMPQueueAddress = SMPQueueAddress } deriving (Eq, Show) +instance SMPQueue SMPQueueUri where + qServer SMPQueueUri {queueAddress} = qServer queueAddress + {-# INLINE qServer #-} + qAddress SMPQueueUri {queueAddress} = qAddress queueAddress + {-# INLINE qAddress #-} + sameQueue addr q = sameQAddress addr (qAddress q) + {-# INLINE sameQueue #-} + +instance SMPQueue SMPQueueInfo where + qServer SMPQueueInfo {queueAddress} = qServer queueAddress + {-# INLINE qServer #-} + qAddress SMPQueueInfo {queueAddress} = qAddress queueAddress + {-# INLINE qAddress #-} + sameQueue addr q = sameQAddress addr (qAddress q) + {-# INLINE sameQueue #-} + +instance SMPQueue SMPQueueAddress where + qServer SMPQueueAddress {smpServer} = smpServer + {-# INLINE qServer #-} + qAddress SMPQueueAddress {smpServer, senderId} = (smpServer, senderId) + {-# INLINE qAddress #-} + sameQueue addr q = sameQAddress addr (qAddress q) + {-# INLINE sameQueue #-} + +sameQAddress :: (SMPServer, SMP.QueueId) -> (SMPServer, SMP.QueueId) -> Bool +sameQAddress (srv, qId) (srv', qId') = sameSrvAddr srv srv' && qId == qId' +{-# INLINE sameQAddress #-} + instance StrEncoding SMPQueueUri where strEncode (SMPQueueUri vr SMPQueueAddress {smpServer = srv, senderId = qId, dhPublicKey}) | minVersion vr > 1 = strEncode srv <> "/" <> strEncode qId <> "#/?" <> query queryParams @@ -754,6 +897,13 @@ instance StrEncoding SMPQueueUri where hs_ <- queryParam_ "srv" query pure (vr, maybe [] thList_ hs_, dhKey) +instance Encoding SMPQueueUri where + smpEncode (SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}) = + smpEncode (clientVRange, smpServer, senderId, dhPublicKey) + smpP = do + (clientVRange, smpServer, senderId, dhPublicKey) <- smpP + pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey} + data ConnectionRequestUri (m :: ConnectionMode) where CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation -- contact connection request does NOT contain E2E encryption parameters - @@ -963,6 +1113,8 @@ data SMPAgentError A_ENCRYPTION | -- | duplicate message - this error is detected by ratchet decryption - this message will be ignored and not shown A_DUPLICATE + | -- | error in the message to add/delete/etc queue in connection + A_QUEUE {queueErr :: String} deriving (Eq, Generic, Read, Show, Exception) instance ToJSON SMPAgentError where @@ -978,6 +1130,7 @@ instance StrEncoding AgentErrorType where <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP) <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) <|> "BROKER " *> (BROKER <$> parseRead1) + <|> "AGENT QUEUE " *> (AGENT . A_QUEUE <$> parseRead A.takeByteString) <|> "AGENT " *> (AGENT <$> parseRead1) <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) strEncode = \case @@ -988,6 +1141,7 @@ instance StrEncoding AgentErrorType where BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e BROKER e -> "BROKER " <> bshow e + AGENT (A_QUEUE e) -> "AGENT QUEUE " <> bshow e AGENT e -> "AGENT " <> bshow e INTERNAL e -> "INTERNAL " <> bshow e @@ -1029,12 +1183,14 @@ instance StrEncoding ACmdTag where "DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_ "DOWN" -> pure $ ACmdTag SAgent DOWN_ "UP" -> pure $ ACmdTag SAgent UP_ + "SWITCH" -> pure $ ACmdTag SAgent SWITCH_ "SEND" -> pure $ ACmdTag SClient SEND_ "MID" -> pure $ ACmdTag SAgent MID_ "SENT" -> pure $ ACmdTag SAgent SENT_ "MERR" -> pure $ ACmdTag SAgent MERR_ "MSG" -> pure $ ACmdTag SAgent MSG_ "ACK" -> pure $ ACmdTag SClient ACK_ + "SWCH" -> pure $ ACmdTag SClient SWCH_ "OFF" -> pure $ ACmdTag SClient OFF_ "DEL" -> pure $ ACmdTag SClient DEL_ "CHK" -> pure $ ACmdTag SClient CHK_ @@ -1062,12 +1218,14 @@ instance APartyI p => StrEncoding (ACommandTag p) where DISCONNECT_ -> "DISCONNECT" DOWN_ -> "DOWN" UP_ -> "UP" + SWITCH_ -> "SWITCH" SEND_ -> "SEND" MID_ -> "MID" SENT_ -> "SENT" MERR_ -> "MERR" MSG_ -> "MSG" ACK_ -> "ACK" + SWCH_ -> "SWCH" OFF_ -> "OFF" DEL_ -> "DEL" CHK_ -> "CHK" @@ -1097,6 +1255,7 @@ commandP binaryP = SUB_ -> pure SUB SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP) ACK_ -> s (ACK <$> A.decimal) + SWCH_ -> pure SWCH OFF_ -> pure OFF DEL_ -> pure DEL CHK_ -> pure CHK @@ -1112,6 +1271,7 @@ commandP binaryP = DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP) DOWN_ -> s (DOWN <$> strP_ <*> connections) UP_ -> s (UP <$> strP_ <*> connections) + SWITCH_ -> s (SWITCH <$> strP_ <*> strP) MID_ -> s (MID <$> A.decimal) SENT_ -> s (SENT <$> A.decimal) MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP) @@ -1139,36 +1299,40 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX -- | Serialize SMP agent command. serializeCommand :: ACommand p -> ByteString serializeCommand = \case - NEW ntfs cMode -> B.unwords ["NEW", strEncode ntfs, strEncode cMode] - INV cReq -> "INV " <> strEncode cReq - JOIN ntfs cReq cInfo -> B.unwords ["JOIN", strEncode ntfs, strEncode cReq, serializeBinary cInfo] - CONF confId srvs cInfo -> B.unwords ["CONF", confId, strEncodeList srvs, serializeBinary cInfo] - LET confId cInfo -> B.unwords ["LET", confId, serializeBinary cInfo] - REQ invId srvs cInfo -> B.unwords ["REQ", invId, strEncode srvs, serializeBinary cInfo] - ACPT invId cInfo -> B.unwords ["ACPT", invId, serializeBinary cInfo] - RJCT invId -> "RJCT " <> invId - INFO cInfo -> "INFO " <> serializeBinary cInfo - SUB -> "SUB" - END -> "END" - CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h] - DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h] - DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns] - UP srv conns -> B.unwords ["UP", strEncode srv, connections conns] - SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody - MID mId -> "MID " <> bshow mId - SENT mId -> "SENT " <> bshow mId - MERR mId e -> B.unwords ["MERR", bshow mId, strEncode e] - MSG msgMeta msgFlags msgBody -> B.unwords ["MSG", serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody] - ACK mId -> "ACK " <> bshow mId - OFF -> "OFF" - DEL -> "DEL" - CHK -> "CHK" - STAT srvs -> "STAT " <> strEncode srvs - CON -> "CON" - ERR e -> "ERR " <> strEncode e - OK -> "OK" - SUSPENDED -> "SUSPENDED" + NEW ntfs cMode -> s (NEW_, ntfs, cMode) + INV cReq -> s (INV_, cReq) + JOIN ntfs cReq cInfo -> s (JOIN_, ntfs, cReq, Str $ serializeBinary cInfo) + CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo] + LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo] + REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo] + ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo] + RJCT invId -> B.unwords [s RJCT_, invId] + INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo] + SUB -> s SUB_ + END -> s END_ + CONNECT p h -> s (CONNECT_, p, h) + DISCONNECT p h -> s (DISCONNECT_, p, h) + DOWN srv conns -> B.unwords [s DOWN_, s srv, connections conns] + UP srv conns -> B.unwords [s UP_, s srv, connections conns] + SWITCH phase srvs -> s (SWITCH_, phase, srvs) + SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody] + MID mId -> s (MID_, Str $ bshow mId) + SENT mId -> s (SENT_, Str $ bshow mId) + MERR mId e -> s (MERR_, Str $ bshow mId, e) + MSG msgMeta msgFlags msgBody -> B.unwords [s MSG_, serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody] + ACK mId -> s (ACK_, Str $ bshow mId) + SWCH -> s SWCH_ + OFF -> s OFF_ + DEL -> s DEL_ + CHK -> s CHK_ + STAT srvs -> s (STAT_, srvs) + CON -> s CON_ + ERR e -> s (ERR_, e) + OK -> s OK_ + SUSPENDED -> s SUSPENDED_ where + s :: StrEncoding a => a -> ByteString + s = strEncode showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis connections :: [ConnId] -> ByteString diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 246953340..24fbbcd63 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -5,6 +5,7 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -17,6 +18,9 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) +import Data.List (find) +import Data.List.NonEmpty (NonEmpty) +import qualified Data.List.NonEmpty as L import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol @@ -43,7 +47,8 @@ import Simplex.Messaging.Version -- | A receive queue. SMP queue through which the agent receives messages from a sender. data RcvQueue = RcvQueue - { server :: SMPServer, + { connId :: ConnId, + server :: SMPServer, -- | recipient queue ID rcvId :: SMP.RecipientId, -- | key used by the recipient to sign transmissions @@ -55,9 +60,17 @@ data RcvQueue = RcvQueue -- | public sender's DH key and agreed shared DH secret for simple per-queue e2e e2eDhSecret :: Maybe C.DhSecretX25519, -- | sender queue ID - sndId :: Maybe SMP.SenderId, + sndId :: SMP.SenderId, -- | queue status status :: QueueStatus, + -- | database queue ID (within connection), can be Nothing for old queues + dbQueueId :: Int64, + -- | True for a primary queue of the connection + primary :: Bool, + -- | True for the next primary queue + nextPrimary :: Bool, + -- | database queue ID to replace, Nothing if this queue is not replacing another, `Just Nothing` is used for replacing old queues + dbReplaceQueueId :: Maybe Int64, -- | SMP client version smpClientVersion :: Version, -- | credentials used in context of notifications @@ -78,7 +91,8 @@ data ClientNtfCreds = ClientNtfCreds -- | A send queue. SMP queue through which the agent sends messages to a recipient. data SndQueue = SndQueue - { server :: SMPServer, + { connId :: ConnId, + server :: SMPServer, -- | sender queue ID sndId :: SMP.SenderId, -- | key pair used by the sender to sign transmissions @@ -90,11 +104,52 @@ data SndQueue = SndQueue e2eDhSecret :: C.DhSecretX25519, -- | queue status status :: QueueStatus, + -- | database queue ID (within connection), can be Nothing for old queues + dbQueueId :: Int64, + -- | True for a primary queue of the connection + primary :: Bool, + -- | True for the next primary queue + nextPrimary :: Bool, + -- | ID of the queue this one is replacing + dbReplaceQueueId :: Maybe Int64, -- | SMP client version smpClientVersion :: Version } deriving (Eq, Show) +instance SMPQueue RcvQueue where + qServer RcvQueue {server} = server + {-# INLINE qServer #-} + qAddress RcvQueue {server, rcvId} = (server, rcvId) + {-# INLINE qAddress #-} + sameQueue addr q = sameQAddress addr (qAddress q) + {-# INLINE sameQueue #-} + +instance SMPQueue SndQueue where + qServer SndQueue {server} = server + {-# INLINE qServer #-} + qAddress SndQueue {server, sndId} = (server, sndId) + {-# INLINE qAddress #-} + sameQueue addr q = sameQAddress addr (qAddress q) + {-# INLINE sameQueue #-} + +findQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe q +findQ = find . sameQueue +{-# INLINE findQ #-} + +removeQ :: SMPQueue q => (SMPServer, SMP.QueueId) -> NonEmpty q -> Maybe (q, [q]) +removeQ addr qs = case L.break (sameQueue addr) qs of + (_, []) -> Nothing + (qs1, q : qs2) -> Just (q, qs1 <> qs2) + +sndAddress :: RcvQueue -> (SMPServer, SMP.SenderId) +sndAddress RcvQueue {server, sndId} = (server, sndId) +{-# INLINE sndAddress #-} + +findRQ :: (SMPServer, SMP.SenderId) -> NonEmpty RcvQueue -> Maybe RcvQueue +findRQ sAddr = find $ sameQAddress sAddr . sndAddress +{-# INLINE findRQ #-} + -- * Connection types -- | Type of a connection. @@ -114,13 +169,21 @@ data Connection (d :: ConnType) where NewConnection :: ConnData -> Connection CNew RcvConnection :: ConnData -> RcvQueue -> Connection CRcv SndConnection :: ConnData -> SndQueue -> Connection CSnd - DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex + DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex ContactConnection :: ConnData -> RcvQueue -> Connection CContact deriving instance Eq (Connection d) deriving instance Show (Connection d) +connData :: Connection d -> ConnData +connData = \case + NewConnection cData -> cData + RcvConnection cData _ -> cData + SndConnection cData _ -> cData + DuplexConnection cData _ _ -> cData + ContactConnection cData _ -> cData + data SConnType :: ConnType -> Type where SCNew :: SConnType CNew SCRcv :: SConnType CRcv @@ -193,6 +256,7 @@ instance StrEncoding AgentCommand where data AgentCommandTag = AClientCommandTag (ACommandTag 'Client) | AInternalCommandTag InternalCommandTag + deriving (Show) instance StrEncoding AgentCommandTag where strEncode = \case @@ -208,12 +272,16 @@ data InternalCommand | ICAckDel SMP.RecipientId MsgId InternalId | ICAllowSecure SMP.RecipientId SMP.SndPublicVerifyKey | ICDuplexSecure SMP.RecipientId SMP.SndPublicVerifyKey + | ICQSecure SMP.RecipientId SMP.SndPublicVerifyKey + | ICQDelete SMP.RecipientId data InternalCommandTag = ICAck_ | ICAckDel_ | ICAllowSecure_ | ICDuplexSecure_ + | ICQSecure_ + | ICQDelete_ deriving (Show) instance StrEncoding InternalCommand where @@ -222,12 +290,16 @@ instance StrEncoding InternalCommand where ICAckDel rId srvMsgId mId -> strEncode (ICAckDel_, rId, srvMsgId, mId) ICAllowSecure rId sndKey -> strEncode (ICAllowSecure_, rId, sndKey) ICDuplexSecure rId sndKey -> strEncode (ICDuplexSecure_, rId, sndKey) + ICQSecure rId senderKey -> strEncode (ICQSecure_, rId, senderKey) + ICQDelete rId -> strEncode (ICQDelete_, rId) strP = - strP_ >>= \case - ICAck_ -> ICAck <$> strP_ <*> strP - ICAckDel_ -> ICAckDel <$> strP_ <*> strP_ <*> strP - ICAllowSecure_ -> ICAllowSecure <$> strP_ <*> strP - ICDuplexSecure_ -> ICDuplexSecure <$> strP_ <*> strP + strP >>= \case + ICAck_ -> ICAck <$> _strP <*> _strP + ICAckDel_ -> ICAckDel <$> _strP <*> _strP <*> _strP + ICAllowSecure_ -> ICAllowSecure <$> _strP <*> _strP + ICDuplexSecure_ -> ICDuplexSecure <$> _strP <*> _strP + ICQSecure_ -> ICQSecure <$> _strP <*> _strP + ICQDelete_ -> ICQDelete <$> _strP instance StrEncoding InternalCommandTag where strEncode = \case @@ -235,12 +307,16 @@ instance StrEncoding InternalCommandTag where ICAckDel_ -> "ACK_DEL" ICAllowSecure_ -> "ALLOW_SECURE" ICDuplexSecure_ -> "DUPLEX_SECURE" + ICQSecure_ -> "QSECURE" + ICQDelete_ -> "QDELETE" strP = A.takeTill (== ' ') >>= \case "ACK" -> pure ICAck_ "ACK_DEL" -> pure ICAckDel_ "ALLOW_SECURE" -> pure ICAllowSecure_ "DUPLEX_SECURE" -> pure ICDuplexSecure_ + "QSECURE" -> pure ICQSecure_ + "QDELETE" -> pure ICQDelete_ _ -> fail "bad InternalCommandTag" agentCommandTag :: AgentCommand -> AgentCommandTag @@ -254,6 +330,8 @@ internalCmdTag = \case ICAckDel {} -> ICAckDel_ ICAllowSecure {} -> ICAllowSecure_ ICDuplexSecure {} -> ICDuplexSecure_ + ICQSecure {} -> ICQSecure_ + ICQDelete _ -> ICQDelete_ -- * Confirmation types diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index ba93d7041..20ec06ede 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -15,6 +15,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -38,9 +39,16 @@ module Simplex.Messaging.Agent.Store.SQLite deleteConn, upgradeRcvConnToDuplex, upgradeSndConnToDuplex, + addConnRcvQueue, + addConnSndQueue, setRcvQueueStatus, setRcvQueueConfirmedE2E, setSndQueueStatus, + setRcvQueuePrimary, + setSndQueuePrimary, + deleteConnRcvQueue, + deleteConnSndQueue, + getPrimaryRcvQueue, getRcvQueue, setRcvQueueNtfCreds, -- Confirmations @@ -60,11 +68,14 @@ module Simplex.Messaging.Agent.Store.SQLite createRcvMsg, updateSndIds, createSndMsg, + createSndMsgDelivery, getPendingMsgData, getPendingMsgs, + deletePendingMsgs, setMsgUserAck, getLastMsg, deleteMsg, + deleteSndMsgDelivery, -- Double ratchet persistence createRatchetX3dhKeys, getRatchetX3dhKeys, @@ -120,10 +131,12 @@ import Data.Function (on) import Data.Functor (($>)) import Data.IORef import Data.Int (Int64) -import Data.List (foldl', groupBy) +import Data.List (foldl', groupBy, sortBy) import Data.List.NonEmpty (NonEmpty (..)) +import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Data.Maybe (fromMaybe, listToMaybe) +import Data.Ord (Down (..)) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -134,6 +147,7 @@ import Database.SQLite.Simple.FromField import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) import qualified Database.SQLite3 as SQLite3 +import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration) @@ -295,45 +309,39 @@ createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandsh createConn_ gVar cData $ \connId -> do DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) -updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ()) -updateNewConnRcv db connId rq@RcvQueue {server} = +updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64) +updateNewConnRcv db connId rq = getConn db connId $>>= \case (SomeConn _ NewConnection {}) -> updateConn (SomeConn _ RcvConnection {}) -> updateConn -- to allow retries (SomeConn c _) -> pure . Left . SEBadConnType $ connType c where - updateConn :: IO (Either StoreError ()) - updateConn = do - upsertServer_ db server - insertRcvQueue_ db connId rq - pure $ Right () + updateConn :: IO (Either StoreError Int64) + updateConn = Right <$> addConnRcvQueue_ db connId rq -updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ()) -updateNewConnSnd db connId sq@SndQueue {server} = +updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64) +updateNewConnSnd db connId sq = getConn db connId $>>= \case (SomeConn _ NewConnection {}) -> updateConn (SomeConn _ SndConnection {}) -> updateConn -- to allow retries (SomeConn c _) -> pure . Left . SEBadConnType $ connType c where - updateConn :: IO (Either StoreError ()) - updateConn = do - upsertServer_ db server - insertSndQueue_ db connId sq - pure $ Right () + updateConn :: IO (Either StoreError Int64) + updateConn = Right <$> addConnSndQueue_ db connId sq createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId) createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode = createConn_ gVar cData $ \connId -> do upsertServer_ db server DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake) - insertRcvQueue_ db connId q + void $ insertRcvQueue_ db connId q createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SndQueue -> IO (Either StoreError ConnId) createSndConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@SndQueue {server} = createConn_ gVar cData $ \connId -> do upsertServer_ db server DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, SCMInvitation, connAgentVersion, enableNtfs, duplexHandshake) - insertSndQueue_ db connId q + void $ insertSndQueue_ db connId q getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError SomeConn) getRcvConn db ProtocolServer {host, port} rcvId = @@ -356,25 +364,43 @@ deleteConn db connId = "DELETE FROM connections WHERE conn_id = :conn_id;" [":conn_id" := connId] -upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ()) -upgradeRcvConnToDuplex db connId sq@SndQueue {server} = +upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64) +upgradeRcvConnToDuplex db connId sq = getConn db connId $>>= \case - (SomeConn _ RcvConnection {}) -> do - upsertServer_ db server - insertSndQueue_ db connId sq - pure $ Right () + (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq (SomeConn c _) -> pure . Left . SEBadConnType $ connType c -upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ()) -upgradeSndConnToDuplex db connId rq@RcvQueue {server} = +upgradeSndConnToDuplex :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64) +upgradeSndConnToDuplex db connId rq = getConn db connId >>= \case - Right (SomeConn _ SndConnection {}) -> do - upsertServer_ db server - insertRcvQueue_ db connId rq - pure $ Right () + Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound +addConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError Int64) +addConnRcvQueue db connId rq = + getConn db connId >>= \case + Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound + +addConnRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64 +addConnRcvQueue_ db connId rq@RcvQueue {server} = do + upsertServer_ db server + insertRcvQueue_ db connId rq + +addConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError Int64) +addConnSndQueue db connId sq = + getConn db connId >>= \case + Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound + +addConnSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64 +addConnSndQueue_ db connId sq@SndQueue {server} = do + upsertServer_ db server + insertSndQueue_ db connId sq + setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status = -- ? return error if queue does not exist? @@ -418,9 +444,39 @@ setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} stat |] [":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId] -getRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue) -getRcvQueue db connId = - maybe (Left SEConnNotFound) Right <$> getRcvQueueByConnId_ db connId +setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO () +setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do + DB.execute db "UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ? WHERE conn_id = ?" (False, False, connId) + DB.execute + db + "UPDATE rcv_queues SET rcv_primary = ?, next_rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?" + (True, False, Nothing :: Maybe Int64, connId, dbQueueId) + +setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO () +setSndQueuePrimary db connId SndQueue {dbQueueId} = do + DB.execute db "UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ? WHERE conn_id = ?" (False, False, connId) + DB.execute + db + "UPDATE snd_queues SET snd_primary = ?, next_snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?" + (True, False, Nothing :: Maybe Int64, connId, dbQueueId) + +deleteConnRcvQueue :: DB.Connection -> ConnId -> RcvQueue -> IO () +deleteConnRcvQueue db connId RcvQueue {dbQueueId} = + DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId) + +deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO () +deleteConnSndQueue db connId SndQueue {dbQueueId} = do + DB.execute db "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + +getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue) +getPrimaryRcvQueue db connId = + maybe (Left SEConnNotFound) (Right . L.head) <$> getRcvQueuesByConnId_ db connId + +getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) +getRcvQueue db connId (SMPServer host port _) rcvId = + firstRow (toRcvQueue connId) SEConnNotFound $ + DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ?") (connId, host, port, rcvId) setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO () setRcvQueueNtfCreds db connId clientNtfCreds = @@ -587,10 +643,10 @@ updateRcvIds db connId = do updateLastIdsRcv_ db connId internalId internalRcvId pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) -createRcvMsg :: DB.Connection -> ConnId -> RcvMsgData -> IO () -createRcvMsg db connId rcvMsgData = do +createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () +createRcvMsg db connId rq rcvMsgData = do insertRcvMsgBase_ db connId rcvMsgData - insertRcvMsgDetails_ db connId rcvMsgData + insertRcvMsgDetails_ db connId rq rcvMsgData updateHashRcv_ db connId rcvMsgData updateSndIds :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash) @@ -607,9 +663,13 @@ createSndMsg db connId sndMsgData = do insertSndMsgDetails_ db connId sndMsgData updateHashSnd_ db connId sndMsgData +createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO () +createSndMsgDelivery db connId SndQueue {dbQueueId} msgId = + DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId) + getPendingMsgData :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData)) getPendingMsgData db connId msgId = do - rq_ <- getRcvQueueByConnId_ db connId + rq_ <- L.head <$$> getRcvQueuesByConnId_ db connId (rq_,) <$$> firstRow pendingMsgData SEMsgNotFound getMsgData_ where getMsgData_ = @@ -627,16 +687,23 @@ getPendingMsgData db connId msgId = do let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ in PendingMsgData {msgId, msgType, msgFlags, msgBody, internalTs} -getPendingMsgs :: DB.Connection -> ConnId -> IO [InternalId] -getPendingMsgs db connId = +getPendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO [InternalId] +getPendingMsgs db connId SndQueue {dbQueueId} = map fromOnly - <$> DB.query db "SELECT internal_id FROM snd_messages WHERE conn_id = ?" (Only connId) + <$> DB.query db "SELECT internal_id FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) -setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError SMP.MsgId) -setMsgUserAck db connId agentMsgId = do - DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId) - firstRow fromOnly SEMsgNotFound $ - DB.query db "SELECT broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId) +deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO () +deletePendingMsgs db connId SndQueue {dbQueueId} = + DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + +setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId)) +setMsgUserAck db connId agentMsgId = runExceptT $ do + liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId) + (dbRcvId, srvMsgId) <- + ExceptT . firstRow id SEMsgNotFound $ + DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId) + rq <- ExceptT $ getRcvQueueById_ db connId dbRcvId + pure (rq, srvMsgId) getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg) getLastMsg db connId msgId = @@ -662,6 +729,15 @@ deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO () deleteMsg db connId msgId = DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) +deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO () +deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId = do + DB.execute + db + "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?" + (connId, dbQueueId, msgId) + (Only (cnt :: Int) : _) <- DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ?" (connId, msgId) + when (cnt == 0) $ deleteMsg db connId msgId + createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO () createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 = DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2) @@ -1202,27 +1278,35 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do -- * createRcvConn helpers -insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () -insertRcvQueue_ dbConn connId RcvQueue {..} = do +insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO Int64 +insertRcvQueue_ db connId' RcvQueue {..} = do + qId <- newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId') DB.execute - dbConn + db [sql| INSERT INTO rcv_queues - (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?); + (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, status, rcv_queue_id, rcv_primary, next_rcv_primary, replace_rcv_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - ((host server, port server, rcvId, connId) :. (rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion)) + ((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion)) + pure qId -- * createSndConn helpers -insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () -insertSndQueue_ dbConn connId SndQueue {..} = do +insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO Int64 +insertSndQueue_ db connId' SndQueue {..} = do + qId <- newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId') DB.execute - dbConn + db [sql| INSERT INTO snd_queues - (host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?); + (host, port, snd_id, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, next_snd_primary, replace_snd_queue_id, smp_client_version) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] - (host server, port server, sndId, connId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion) + ((host server, port server, sndId, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion)) + pure qId + +newQueueId_ :: [Only Int64] -> Int64 +newQueueId_ [] = 1 +newQueueId_ (Only maxId : _) = maxId + 1 -- * getConn helpers @@ -1230,65 +1314,79 @@ getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getConn dbConn connId = getConnData dbConn connId >>= \case Nothing -> pure $ Left SEConnNotFound - Just (connData, cMode) -> do - rQ <- getRcvQueueByConnId_ dbConn connId - sQ <- getSndQueueByConnId_ dbConn connId + Just (cData, cMode) -> do + rQ <- getRcvQueuesByConnId_ dbConn connId + sQ <- getSndQueuesByConnId_ dbConn connId pure $ case (rQ, sQ, cMode) of - (Just rcvQ, Just sndQ, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ) - (Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ) - (Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ) - (Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ) - (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection connData) + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) _ -> Left SEConnNotFound getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) getConnData dbConn connId' = - connData - <$> DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId') + maybeFirstRow cData $ DB.query dbConn "SELECT conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake FROM connections WHERE conn_id = ?;" (Only connId') where - connData [(connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake)] = Just (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode) - connData _ = Nothing + cData (connId, cMode, connAgentVersion, enableNtfs_, duplexHandshake) = (ConnData {connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, duplexHandshake}, cMode) -getRcvQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue) -getRcvQueueByConnId_ dbConn connId = - listToMaybe . map rcvQueue +-- | returns all connection queues, the first queue is the primary one +getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue)) +getRcvQueuesByConnId_ db connId = + L.nonEmpty . sortBy primaryFirst . map (toRcvQueue connId) + <$> DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ?") (Only connId) + where + primaryFirst RcvQueue {primary = p} RcvQueue {primary = p'} = compare (Down p) (Down p') + +rcvQueueQuery :: Query +rcvQueueQuery = + [sql| + SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, + q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status, + q.rcv_queue_id, q.rcv_primary, q.next_rcv_primary, q.replace_rcv_queue_id, q.smp_client_version, + q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret + FROM rcv_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + |] + +toRcvQueue :: + ConnId -> + (C.KeyHash, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateSignKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) + :. (Int64, Bool, Bool, Maybe Int64, Maybe Version) + :. (Maybe SMP.NtfPublicVerifyKey, Maybe SMP.NtfPrivateSignKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> + RcvQueue +toRcvQueue connId ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = + let server = SMPServer host port keyHash + smpClientVersion = fromMaybe 1 smpClientVersion_ + clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of + (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} + _ -> Nothing + in RcvQueue {connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion, clientNtfCreds} + +getRcvQueueById_ :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) +getRcvQueueById_ db connId dbRcvId = + firstRow (toRcvQueue connId) SEConnNotFound $ + DB.query db (rcvQueueQuery <> " WHERE conn_id = ? AND rcv_queue_id = ?") (connId, dbRcvId) + +-- | returns all connection queues, the first queue is the primary one +getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue)) +getSndQueuesByConnId_ dbConn connId = + L.nonEmpty . sortBy primaryFirst . map sndQueue <$> DB.query dbConn [sql| - SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, - q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.status, q.smp_client_version, - q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret - FROM rcv_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.conn_id = ?; - |] - (Only connId) - where - rcvQueue ((keyHash, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion_) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = - let server = SMPServer host port keyHash - smpClientVersion = fromMaybe 1 smpClientVersion_ - clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of - (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - _ -> Nothing - in RcvQueue {server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status, smpClientVersion, clientNtfCreds} - -getSndQueueByConnId_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue) -getSndQueueByConnId_ dbConn connId = - sndQueue - <$> DB.query - dbConn - [sql| - SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.smp_client_version + SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.next_snd_primary, q.replace_snd_queue_id, q.smp_client_version FROM snd_queues q INNER JOIN servers s ON q.host = s.host AND q.port = s.port WHERE q.conn_id = ?; |] (Only connId) where - sndQueue [(keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion)] = + sndQueue ((keyHash, host, port, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion)) = let server = SMPServer host port keyHash - in Just SndQueue {server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, smpClientVersion} - sndQueue _ = Nothing + in SndQueue {connId, server, sndId, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, nextPrimary, dbReplaceQueueId, smpClientVersion} + primaryFirst SndQueue {primary = p} SndQueue {primary = p'} = compare (Down p) (Down p') -- * updateRcvIds helpers @@ -1342,22 +1440,23 @@ insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, ":msg_body" := msgBody ] -insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () -insertRcvMsgDetails_ dbConn connId RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do +insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta DB.executeNamed dbConn [sql| INSERT INTO rcv_messages - ( conn_id, internal_rcv_id, internal_id, external_snd_id, + ( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id, broker_id, broker_ts, internal_hash, external_prev_snd_hash, integrity) VALUES - (:conn_id,:internal_rcv_id,:internal_id,:external_snd_id, + (:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id, :broker_id,:broker_ts, :internal_hash,:external_prev_snd_hash,:integrity); |] [ ":conn_id" := connId, + ":rcv_queue_id" := dbQueueId, ":internal_rcv_id" := internalRcvId, ":internal_id" := fst recipient, ":external_snd_id" := sndMsgId, diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 7a66deb7f..43f99cb94 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -36,6 +36,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues import Simplex.Messaging.Encoding.String import Simplex.Messaging.Transport.Client (TransportHost) @@ -51,7 +52,8 @@ schemaMigrations = ("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode), ("m20220811_onion_hosts", m20220811_onion_hosts), ("m20220817_connection_ntfs", m20220817_connection_ntfs), - ("m20220905_commands", m20220905_commands) + ("m20220905_commands", m20220905_commands), + ("m20220915_connection_queues", m20220915_connection_queues) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220915_connection_queues.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220915_connection_queues.hs new file mode 100644 index 000000000..f50295873 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220915_connection_queues.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20220915_connection_queues :: Query +m20220915_connection_queues = + [sql| +PRAGMA ignore_check_constraints=ON; + +-- rcv_queues +ALTER TABLE rcv_queues ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL); +UPDATE rcv_queues SET rcv_queue_id = 1; +CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues (conn_id, rcv_queue_id); + +ALTER TABLE rcv_queues ADD COLUMN rcv_primary INTEGER CHECK (rcv_primary NOT NULL); +UPDATE rcv_queues SET rcv_primary = 1; + +ALTER TABLE rcv_queues ADD COLUMN next_rcv_primary INTEGER CHECK (next_rcv_primary NOT NULL); +UPDATE rcv_queues SET next_rcv_primary = 0; + +ALTER TABLE rcv_queues ADD COLUMN replace_rcv_queue_id INTEGER NULL; + +-- snd_queues +ALTER TABLE snd_queues ADD COLUMN snd_queue_id INTEGER CHECK (snd_queue_id NOT NULL); +UPDATE snd_queues SET snd_queue_id = 1; +CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues (conn_id, snd_queue_id); + +ALTER TABLE snd_queues ADD COLUMN snd_primary INTEGER CHECK (snd_primary NOT NULL); +UPDATE snd_queues SET snd_primary = 1; + +ALTER TABLE snd_queues ADD COLUMN next_snd_primary INTEGER CHECK (next_snd_primary NOT NULL); +UPDATE snd_queues SET next_snd_primary = 0; + +ALTER TABLE snd_queues ADD COLUMN replace_snd_queue_id INTEGER NULL; + +-- messages +CREATE TABLE snd_message_deliveries ( + snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT, + conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE, + snd_queue_id INTEGER NOT NULL, + internal_id INTEGER NOT NULL, + FOREIGN KEY (conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED +); + +CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries (conn_id, snd_queue_id); + +ALTER TABLE rcv_messages ADD COLUMN rcv_queue_id INTEGER CHECK (rcv_queue_id NOT NULL); +UPDATE rcv_messages SET rcv_queue_id = 1; + +PRAGMA ignore_check_constraints=OFF; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index 5787978bb..1b833e7e1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -41,6 +41,10 @@ CREATE TABLE rcv_queues( ntf_private_key BLOB, ntf_id BLOB, rcv_ntf_dh_secret BLOB, + rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL), + rcv_primary INTEGER CHECK(rcv_primary NOT NULL), + next_rcv_primary INTEGER CHECK(next_rcv_primary NOT NULL), + replace_rcv_queue_id INTEGER NULL, PRIMARY KEY(host, port, rcv_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE, @@ -58,6 +62,10 @@ CREATE TABLE snd_queues( smp_client_version INTEGER NOT NULL DEFAULT 1, snd_public_key BLOB, e2e_pub_key BLOB, + snd_queue_id INTEGER CHECK(snd_queue_id NOT NULL), + snd_primary INTEGER CHECK(snd_primary NOT NULL), + next_snd_primary INTEGER CHECK(next_snd_primary NOT NULL), + replace_snd_queue_id INTEGER NULL, PRIMARY KEY(host, port, snd_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE @@ -89,6 +97,7 @@ CREATE TABLE rcv_messages( external_prev_snd_hash BLOB NOT NULL, integrity BLOB NOT NULL, user_ack INTEGER NULL DEFAULT 0, + rcv_queue_id INTEGER CHECK(rcv_queue_id NOT NULL), PRIMARY KEY(conn_id, internal_rcv_id), FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE @@ -206,3 +215,17 @@ CREATE TABLE commands( FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE ); +CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); +CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id); +CREATE TABLE snd_message_deliveries( + snd_message_delivery_id INTEGER PRIMARY KEY AUTOINCREMENT, + conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE, + snd_queue_id INTEGER NOT NULL, + internal_id INTEGER NOT NULL, + FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED +); +CREATE TABLE sqlite_sequence(name,seq); +CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries( + conn_id, + snd_queue_id +); diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs new file mode 100644 index 000000000..1cc66690f --- /dev/null +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE NamedFieldPuns #-} + +module Simplex.Messaging.Agent.TRcvQueues where + +import Control.Concurrent.STM +import qualified Data.Map.Strict as M +import Data.Set (Set) +import qualified Data.Set as S +import Simplex.Messaging.Agent.Protocol (ConnId) +import Simplex.Messaging.Agent.Store (RcvQueue (..)) +import Simplex.Messaging.Protocol (RecipientId, SMPServer) +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM + +newtype TRcvQueues = TRcvQueues (TMap (SMPServer, RecipientId) RcvQueue) + +empty :: STM TRcvQueues +empty = TRcvQueues <$> TM.empty + +clear :: TRcvQueues -> STM () +clear (TRcvQueues qs) = TM.clear qs + +deleteConn :: ConnId -> TRcvQueues -> STM () +deleteConn cId (TRcvQueues qs) = modifyTVar' qs $ M.filter (\rq -> cId /= connId rq) + +-- TODO possibly, should be limited to active queues +hasConn :: ConnId -> TRcvQueues -> STM Bool +hasConn cId (TRcvQueues qs) = any (\rq -> cId == connId rq) <$> readTVar qs + +-- TODO possibly, should be limited to active queues +getConns :: TRcvQueues -> STM (Set ConnId) +getConns (TRcvQueues qs) = M.foldr' (S.insert . connId) S.empty <$> readTVar qs + +addQueue :: RcvQueue -> TRcvQueues -> STM () +addQueue rq@RcvQueue {server, rcvId} (TRcvQueues qs) = TM.insert (server, rcvId) rq qs + +deleteQueue :: RcvQueue -> TRcvQueues -> STM () +deleteQueue RcvQueue {server, rcvId} (TRcvQueues qs) = TM.delete (server, rcvId) qs + +getSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue] +getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs + where + addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs' + +getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue] +getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty) + where + addQ (removed, qs') rq@RcvQueue {server, rcvId} + | srv == server = (rq : removed, qs') + | otherwise = (removed, M.insert (server, rcvId) rq qs') diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index 2ff374e46..82bff3286 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -6,6 +6,7 @@ module Simplex.Messaging.Encoding.String StrEncoding (..), Str (..), strP_, + _strP, strToJSON, strToJEncoding, strParseJSON, @@ -171,6 +172,9 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin strP_ :: StrEncoding a => Parser a strP_ = strP <* A.space +_strP :: StrEncoding a => Parser a +_strP = A.space *> strP + strToJSON :: StrEncoding a => a -> J.Value strToJSON = J.String . decodeLatin1 . strEncode diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 94f7d581e..38b801872 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -115,6 +115,7 @@ module Simplex.Messaging.Protocol legacyEncodeServer, legacyServerP, legacyStrEncodeServer, + sameSrvAddr, -- * TCP transport functions tPut, @@ -572,6 +573,10 @@ pattern NtfServer host port keyHash = ProtocolServer SPNTF host port keyHash {-# COMPLETE NtfServer #-} +sameSrvAddr :: ProtocolServer p -> ProtocolServer p -> Bool +sameSrvAddr ProtocolServer {host, port} ProtocolServer {host = h', port = p'} = host == h' && port == p' +{-# INLINE sameSrvAddr #-} + data ProtocolType = PSMP | PNTF deriving (Eq, Ord, Show) diff --git a/src/Simplex/Messaging/TMap2.hs b/src/Simplex/Messaging/TMap2.hs deleted file mode 100644 index 69d42e47d..000000000 --- a/src/Simplex/Messaging/TMap2.hs +++ /dev/null @@ -1,82 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} - -module Simplex.Messaging.TMap2 - ( TMap2, - empty, - clear, - Simplex.Messaging.TMap2.lookup, - lookup1, - member, - insert, - insert1, - delete, - lookupDelete1, - ) -where - -import Control.Concurrent.STM -import Control.Monad (forM_, (>=>)) -import qualified Data.Map.Strict as M -import Simplex.Messaging.TMap (TMap) -import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (whenM, ($>>=)) - --- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys. --- It allows direct access via k1 to a group of k2 values and via k2 to one value -data TMap2 k1 k2 a = TMap2 - { _m1 :: TMap k1 (TMap k2 a), - _m2 :: TMap k2 k1 - } - -empty :: STM (TMap2 k1 k2 a) -empty = TMap2 <$> TM.empty <*> TM.empty - -clear :: TMap2 k1 k2 a -> STM () -clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2 - -lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a) -lookup k2 TMap2 {_m1, _m2} = do - TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2 - -lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) -lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1 -{-# INLINE lookup1 #-} - -member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool -member k2 TMap2 {_m2} = TM.member k2 _m2 -{-# INLINE member #-} - -insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM () -insert k1 k2 v TMap2 {_m1, _m2} = - TM.lookup k2 _m2 >>= \case - Just k1' - | k1 == k1' -> _insert1 - | otherwise -> _delete1 k1' k2 _m1 >> _insert2 - _ -> _insert2 - where - _insert1 = - TM.lookup k1 _m1 >>= \case - Just m -> TM.insert k2 v m - _ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1 - _insert2 = TM.insert k2 k1 _m2 >> _insert1 - -insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM () -insert1 k1 m' TMap2 {_m1, _m2} = - TM.lookup k1 _m1 >>= \case - Just m -> readTVar m' >>= (`TM.union` m) - _ -> TM.insert k1 m' _m1 - -delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM () -delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1) - -_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM () -_delete1 k1 k2 m1 = - TM.lookup k1 m1 - >>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1)) - -lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a)) -lookupDelete1 k1 TMap2 {_m1, _m2} = do - m_ <- TM.lookupDelete k1 _m1 - forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet - pure m_ diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 89c9a3769..b3a32985e 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -21,14 +21,15 @@ import Control.Concurrent (killThread, threadDelay) import Control.Monad import Control.Monad.Except (ExceptT, runExceptT) import Control.Monad.IO.Unlift +import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import qualified Data.Map as M import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..), getSystemTime) import SMPAgentClient -import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) +import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) import Simplex.Messaging.Agent -import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) +import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Client (ProtocolClientConfig (..)) import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) @@ -115,6 +116,15 @@ functionalAPITests t = do testAsyncCommandsRestore t it "should accept connection using async command" $ withSmpServer t testAcceptContactAsync + describe "Queue rotation" $ do + it "should switch delivery to the new queue (1 server)" $ + withSmpServer t $ testSwitchConnection initAgentServers + it "should switch delivery to the new queue (2 servers)" $ + withSmpServer t . withSmpServerOn t testPort2 $ testSwitchConnection initAgentServers2 + it "should switch to new queue asynchronously (1 server)" $ + withSmpServer t $ testSwitchAsync initAgentServers + it "should switch to new queue asynchronously (2 servers)" $ + withSmpServer t . withSmpServerOn t testPort2 $ testSwitchAsync initAgentServers2 testMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do @@ -638,6 +648,76 @@ testAcceptContactAsync = do baseId = 3 msgId = subtract baseId +testSwitchConnection :: InitialAgentServers -> IO () +testSwitchConnection servers = do + a <- getSMPAgentClient agentCfg servers + b <- getSMPAgentClient agentCfg {database = testDB2} servers + Right () <- runExceptT $ do + (aId, bId) <- makeConnection a b + exchangeGreetingsMsgId 4 a bId b aId + switchConnectionAsync a "" bId + phase a bId SPStarted + phase b aId SPStarted + phase a bId SPConfirmed + phase b aId SPConfirmed + phase a bId SPTested + phase b aId SPTested + phase b aId SPCompleted + phase a bId SPCompleted + exchangeGreetingsMsgId 12 a bId b aId + pure () + +phase :: AgentClient -> ByteString -> SwitchPhase -> ExceptT AgentErrorType IO () +phase c connId p = + get c >>= \(_, connId', msg) -> do + liftIO $ connId `shouldBe` connId' + case msg of + SWITCH p' _ -> liftIO $ p `shouldBe` p' + ERR (AGENT A_DUPLICATE) -> phase c connId p + r -> do + liftIO . putStrLn $ "expected: " <> show p <> ", received: " <> show r + SWITCH _ _ <- pure r + pure () + +testSwitchAsync :: InitialAgentServers -> IO () +testSwitchAsync servers = do + Right (aId, bId) <- withA $ \a -> withB $ \b -> runExceptT $ do + (aId, bId) <- makeConnection a b + exchangeGreetingsMsgId 4 a bId b aId + pure (aId, bId) + let phaseA = withA . phase' bId + phaseB = withB . phase' aId + Right () <- withA $ \a -> runExceptT $ do + subscribeConnection a bId + switchConnectionAsync a "" bId + phase a bId SPStarted + liftIO $ threadDelay 500000 + phaseB SPStarted + phaseA SPConfirmed + phaseB SPConfirmed + phaseA SPTested + Right () <- withB $ \b -> runExceptT $ do + subscribeConnection b aId + phase b aId SPTested + phase b aId SPCompleted + phaseA SPCompleted + Right () <- withA $ \a -> withB $ \b -> runExceptT $ do + subscribeConnection a bId + subscribeConnection b aId + exchangeGreetingsMsgId 12 a bId b aId + pure () + where + withAgent :: AgentConfig -> (AgentClient -> IO a) -> IO a + withAgent cfg' = bracket (getSMPAgentClient cfg' servers) disconnectAgentClient + phase' connId p c = do + Right () <- runExceptT $ do + subscribeConnection c connId + phase c connId p + liftIO $ threadDelay 500000 + pure () + withA = withAgent agentCfg + withB = withAgent agentCfg {database = testDB2, initialClientId = 1} + exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetings = exchangeGreetingsMsgId 4 diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 84ae3e1aa..80baeeb69 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} @@ -114,7 +115,7 @@ testConcurrentWrites = runTest st connId = replicateM_ 100 . withTransaction st $ \db -> do (internalId, internalRcvId, _, _) <- updateRcvIds db connId let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy" - createRcvMsg db connId rcvMsgData + createRcvMsg db connId rcvQueue1 rcvMsgData testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = @@ -153,14 +154,19 @@ testDhSecret = "01234567890123456789012345678901" rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue - { server = SMPServer "smp.simplex.im" "5223" testKeyHash, + { connId = "conn1", + server = SMPServer "smp.simplex.im" "5223" testKeyHash, rcvId = "1234", rcvPrivateKey = testPrivateSignKey, rcvDhSecret = testDhSecret, e2ePrivKey = testPrivDhKey, e2eDhSecret = Nothing, - sndId = Just "2345", + sndId = "2345", status = New, + dbQueueId = 1, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion = 1, clientNtfCreds = Nothing } @@ -168,13 +174,18 @@ rcvQueue1 = sndQueue1 :: SndQueue sndQueue1 = SndQueue - { server = SMPServer "smp.simplex.im" "5223" testKeyHash, + { connId = "conn1", + server = SMPServer "smp.simplex.im" "5223" testKeyHash, sndId = "3456", sndPublicKey = Nothing, sndPrivateKey = testPrivateSignKey, e2ePubKey = Nothing, e2eDhSecret = testDhSecret, status = New, + dbQueueId = 1, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion = 1 } @@ -187,21 +198,23 @@ testCreateRcvConn = getConn db "conn1" `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1)) upgradeRcvConnToDuplex db "conn1" sndQueue1 - `shouldReturn` Right () + `shouldReturn` Right 1 getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1])) testCreateRcvConnRandomId :: SpecWith SQLiteStore testCreateRcvConnRandomId = it "should create RcvConnection and add SndQueue with random ID" . withStoreTransaction $ \db -> do g <- newTVarIO =<< drgNew Right connId <- createRcvConn db g cData1 {connId = ""} rcvQueue1 SCMInvitation + let rq' = (rcvQueue1 :: RcvQueue) {connId} + sq' = (sndQueue1 :: SndQueue) {connId} getConn db connId - `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1)) + `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rq')) upgradeRcvConnToDuplex db connId sndQueue1 - `shouldReturn` Right () + `shouldReturn` Right 1 getConn db connId - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq'])) testCreateRcvConnDuplicate :: SpecWith SQLiteStore testCreateRcvConnDuplicate = @@ -220,21 +233,23 @@ testCreateSndConn = getConn db "conn1" `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1)) upgradeSndConnToDuplex db "conn1" rcvQueue1 - `shouldReturn` Right () + `shouldReturn` Right 1 getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1])) testCreateSndConnRandomID :: SpecWith SQLiteStore testCreateSndConnRandomID = it "should create SndConnection and add RcvQueue with random ID" . withStoreTransaction $ \db -> do g <- newTVarIO =<< drgNew Right connId <- createSndConn db g cData1 {connId = ""} sndQueue1 + let rq' = (rcvQueue1 :: RcvQueue) {connId} + sq' = (sndQueue1 :: SndQueue) {connId} getConn db connId - `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1)) + `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sq')) upgradeSndConnToDuplex db connId rcvQueue1 - `shouldReturn` Right () + `shouldReturn` Right 1 getConn db connId - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq'] [sq'])) testCreateSndConnDuplicate :: SpecWith SQLiteStore testCreateSndConnDuplicate = @@ -287,12 +302,12 @@ testDeleteDuplexConn = _ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation _ <- upgradeRcvConnToDuplex db "conn1" sndQueue1 getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1])) deleteConn db "conn1" `shouldReturn` () -- TODO check queues are deleted as well getConn db "conn1" - `shouldReturn` Left (SEConnNotFound) + `shouldReturn` Left SEConnNotFound testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = @@ -301,13 +316,18 @@ testUpgradeRcvConnToDuplex = _ <- createSndConn db g cData1 sndQueue1 let anotherSndQueue = SndQueue - { server = SMPServer "smp.simplex.im" "5223" testKeyHash, + { connId = "conn1", + server = SMPServer "smp.simplex.im" "5223" testKeyHash, sndId = "2345", sndPublicKey = Nothing, sndPrivateKey = testPrivateSignKey, e2ePubKey = Nothing, e2eDhSecret = testDhSecret, status = New, + dbQueueId = 1, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion = 1 } upgradeRcvConnToDuplex db "conn1" anotherSndQueue @@ -323,14 +343,19 @@ testUpgradeSndConnToDuplex = _ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation let anotherRcvQueue = RcvQueue - { server = SMPServer "smp.simplex.im" "5223" testKeyHash, + { connId = "conn1", + server = SMPServer "smp.simplex.im" "5223" testKeyHash, rcvId = "3456", rcvPrivateKey = testPrivateSignKey, rcvDhSecret = testDhSecret, e2ePrivKey = testPrivDhKey, e2eDhSecret = Nothing, - sndId = Just "4567", + sndId = "4567", status = New, + dbQueueId = 1, + primary = True, + nextPrimary = False, + dbReplaceQueueId = Nothing, smpClientVersion = 1, clientNtfCreds = Nothing } @@ -371,15 +396,17 @@ testSetQueueStatusDuplex = _ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation _ <- upgradeRcvConnToDuplex db "conn1" sndQueue1 getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rcvQueue1] [sndQueue1])) setRcvQueueStatus db rcvQueue1 Secured `shouldReturn` () + let rq' = (rcvQueue1 :: RcvQueue) {status = Secured} + sq' = (sndQueue1 :: SndQueue) {status = Confirmed} getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1)) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sndQueue1])) setSndQueueStatus db sndQueue1 Confirmed `shouldReturn` () getConn db "conn1" - `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})) + `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq'] [sq'])) hw :: ByteString hw = encodeUtf8 "Hello world!" @@ -405,12 +432,12 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = externalPrevSndHash = "hash_from_sender" } -testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation -testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do +testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvQueue -> RcvMsgData -> Expectation +testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rq rcvMsgData@RcvMsgData {..} = do let MsgMeta {recipient = (internalId, _)} = msgMeta updateRcvIds db connId `shouldReturn` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) - createRcvMsg db connId rcvMsgData + createRcvMsg db connId rq rcvMsgData `shouldReturn` () testCreateRcvMsg :: SpecWith SQLiteStore @@ -421,8 +448,8 @@ testCreateRcvMsg = _ <- withTransaction st $ \db -> do createRcvConn db g cData1 rcvQueue1 SCMInvitation withTransaction st $ \db -> do - testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" - testCreateRcvMsg_ db 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" + testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg_ db 1 "hash_dummy" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData mkSndMsgData internalId internalSndId internalHash = @@ -464,9 +491,9 @@ testCreateRcvAndSndMsgs = createRcvConn db g cData1 rcvQueue1 SCMInvitation withTransaction st $ \db -> do _ <- upgradeRcvConnToDuplex db "conn1" sndQueue1 - testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" - testCreateRcvMsg_ db 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateRcvMsg_ db 0 "" connId rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg_ db 1 "rcv_hash_1" connId rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" testCreateSndMsg_ db "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" - testCreateRcvMsg_ db 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateRcvMsg_ db 2 "rcv_hash_2" connId rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" testCreateSndMsg_ db "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" testCreateSndMsg_ db "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index bb57529f5..867f82a24 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -1,9 +1,9 @@ module CoreTests.ProtocolErrorTests where import Simplex.Messaging.Agent.Protocol (AgentErrorType) +import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol (ErrorType) -import Simplex.Messaging.Encoding.String import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck diff --git a/tests/Test.hs b/tests/Test.hs index 61e09a2b7..d3c52283c 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,6 +1,7 @@ {-# LANGUAGE TypeApplications #-} import AgentTests (agentTests) +-- import Control.Logger.Simple import CoreTests.CryptoTests import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests @@ -13,8 +14,13 @@ import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import System.Environment (setEnv) import Test.Hspec +-- logCfg :: LogConfig +-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} + main :: IO () main = do + -- setLogLevel LogInfo -- LogError + -- withGlobalLogging logCfg $ do createDirectoryIfMissing False "tests/tmp" setEnv "APNS_KEY_ID" "H82WD9K9AQ" setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"