diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 08e8d513b..12a1ad97b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -177,7 +177,7 @@ jobs: if: matrix.should_run == true shell: docker exec -t builder sh -eu {0} run: | - chmod -R 777 dist-newstyle ~/.cabal && git config --global --add safe.directory '*' + chmod -fR 777 ~/.cabal ./dist-newstyle || :; git config --global --add safe.directory '*' cabal clean cabal update cabal build --jobs=$(nproc) --enable-tests -fserver_postgres diff --git a/scripts/simplexmq-reproduce-builds.sh b/scripts/simplexmq-reproduce-builds.sh index 54fbfad8e..ec480b369 100755 --- a/scripts/simplexmq-reproduce-builds.sh +++ b/scripts/simplexmq-reproduce-builds.sh @@ -47,7 +47,7 @@ for os in 22.04 24.04; do docker exec \ -t \ builder \ - sh -c 'cabal update && cabal build --jobs=$(nproc) --enable-tests -fserver_postgres && mkdir -p /out && for i in smp-server simplexmq-test; do bin=$(find /project/dist-newstyle -name "$i" -type f -executable) && chmod +x "$bin" && mv "$bin" /out/; done && strip /out/smp-server' + sh -c 'git config --global --add safe.directory \*; cabal update && cabal build --jobs=$(nproc) --enable-tests -fserver_postgres && mkdir -p /out && for i in smp-server simplexmq-test; do bin=$(find /project/dist-newstyle -name "$i" -type f -executable) && chmod +x "$bin" && mv "$bin" /out/; done && strip /out/smp-server' # Copy smp-server postgresql binary and prepare it docker cp \ diff --git a/simplexmq.cabal b/simplexmq.cabal index 476b0a4be..f6ab07e08 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -1,7 +1,7 @@ cabal-version: 1.12 name: simplexmq -version: 6.5.0.0.1 +version: 6.5.0.3 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -103,12 +103,13 @@ library Simplex.Messaging.Agent.Store.AgentStore Simplex.Messaging.Agent.Store.Common Simplex.Messaging.Agent.Store.DB + Simplex.Messaging.Agent.Store.Entity Simplex.Messaging.Agent.Store.Interface Simplex.Messaging.Agent.Store.Migrations Simplex.Messaging.Agent.Store.Migrations.App Simplex.Messaging.Agent.Store.Postgres.Options Simplex.Messaging.Agent.Store.Shared - Simplex.Messaging.Agent.TRcvQueues + Simplex.Messaging.Agent.TSessionSubs Simplex.Messaging.Client Simplex.Messaging.Client.Agent Simplex.Messaging.Compression @@ -130,12 +131,13 @@ library Simplex.Messaging.Notifications.Types Simplex.Messaging.Parsers Simplex.Messaging.Protocol + Simplex.Messaging.Protocol.Types Simplex.Messaging.Server.Expiration Simplex.Messaging.Server.QueueStore.Postgres.Config Simplex.Messaging.Server.QueueStore.QueueInfo Simplex.Messaging.ServiceScheme Simplex.Messaging.Session - Simplex.Messaging.Agent.Store.Entity + Simplex.Messaging.SystemTime Simplex.Messaging.TMap Simplex.Messaging.Transport Simplex.Messaging.Transport.Buffer @@ -163,6 +165,8 @@ library Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250203_msg_bodies Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete + Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe + Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices else exposed-modules: Simplex.Messaging.Agent.Store.SQLite @@ -210,6 +214,8 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250203_msg_bodies Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250322_short_links Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250702_conn_invitations_remove_cascade_delete + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices if flag(client_postgres) || flag(server_postgres) exposed-modules: Simplex.Messaging.Agent.Store.Postgres @@ -475,7 +481,7 @@ test-suite simplexmq-test CoreTests.RetryIntervalTests CoreTests.SOCKSSettings CoreTests.StoreLogTests - CoreTests.TRcvQueuesTests + CoreTests.TSessionSubs CoreTests.UtilTests CoreTests.VersionRangeTests FileDescriptionTests diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index c169e69b0..25de49afc 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -58,8 +58,9 @@ import Simplex.Messaging.Protocol (BlockingInfo, EntityId (..), RcvPublicAuthKey import Simplex.Messaging.Server (controlPortAuth, dummyVerifyCmd, verifyCmdAuthorization) import Simplex.Messaging.Server.Control (CPClientRole (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, ServerEntityStatus (..), getRoundedSystemTime) +import Simplex.Messaging.Server.QueueStore (ServerEntityStatus (..)) import Simplex.Messaging.Server.Stats +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (CertChainPubKey (..), SessionId, THandleAuth (..), THandleParams (..), TransportPeer (..), defaultSupportedParams) @@ -451,7 +452,7 @@ processXFTPRequest HTTP2Body {bodyPart} = \case let rIds = L.map (\(FileRecipient rId _) -> rId) rcps pure $ FRSndIds sId rIds pure $ either FRErr id r - addFileRetry :: FileStore -> FileInfo -> Int -> RoundedSystemTime -> M (Either XFTPErrorType XFTPFileId) + addFileRetry :: FileStore -> FileInfo -> Int -> RoundedFileTime -> M (Either XFTPErrorType XFTPFileId) addFileRetry st file n ts = retryAdd n $ \sId -> runExceptT $ do ExceptT $ addFile st sId file ts EntityActive @@ -579,8 +580,8 @@ deleteOrBlockServerFile_ FileRec {filePath, fileInfo} stat storeAction = runExce liftIO $ atomicModifyIORef'_ (filesCount stats) (subtract 1) liftIO $ atomicModifyIORef'_ (filesSize stats) (subtract $ fromIntegral $ size fileInfo) -getFileTime :: IO RoundedSystemTime -getFileTime = getRoundedSystemTime fileTimePrecision +getFileTime :: IO RoundedFileTime +getFileTime = getRoundedSystemTime expireServerFiles :: Maybe Int -> ExpirationConfig -> M () expireServerFiles itemDelay expCfg = do diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index f59712fc0..eec481a21 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -8,6 +9,7 @@ module Simplex.FileTransfer.Server.Store ( FileStore (..), FileRec (..), FileRecipient (..), + RoundedFileTime, newFileStore, addFile, setFilePath, @@ -33,7 +35,8 @@ import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BlockingInfo, RcvPublicAuthKey, RecipientId, SenderId) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime (..), ServerEntityStatus (..)) +import Simplex.Messaging.Server.QueueStore (ServerEntityStatus (..)) +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (ifM, ($>>=)) @@ -49,10 +52,12 @@ data FileRec = FileRec fileInfo :: FileInfo, filePath :: TVar (Maybe FilePath), recipientIds :: TVar (Set RecipientId), - createdAt :: RoundedSystemTime, + createdAt :: RoundedFileTime, fileStatus :: TVar ServerEntityStatus } +type RoundedFileTime = RoundedSystemTime 3600 + fileTimePrecision :: Int64 fileTimePrecision = 3600 -- truncate creation time to 1 hour @@ -70,14 +75,14 @@ newFileStore = do usedStorage <- newTVarIO 0 pure FileStore {files, recipients, usedStorage} -addFile :: FileStore -> SenderId -> FileInfo -> RoundedSystemTime -> ServerEntityStatus -> STM (Either XFTPErrorType ()) +addFile :: FileStore -> SenderId -> FileInfo -> RoundedFileTime -> ServerEntityStatus -> STM (Either XFTPErrorType ()) addFile FileStore {files} sId fileInfo createdAt status = ifM (TM.member sId files) (pure $ Left DUPLICATE_) $ do f <- newFileRec sId fileInfo createdAt status TM.insert sId f files pure $ Right () -newFileRec :: SenderId -> FileInfo -> RoundedSystemTime -> ServerEntityStatus -> STM FileRec +newFileRec :: SenderId -> FileInfo -> RoundedFileTime -> ServerEntityStatus -> STM FileRec newFileRec senderId fileInfo createdAt status = do recipientIds <- newTVar S.empty filePath <- newTVar Nothing diff --git a/src/Simplex/FileTransfer/Server/StoreLog.hs b/src/Simplex/FileTransfer/Server/StoreLog.hs index c972da281..c82beda29 100644 --- a/src/Simplex/FileTransfer/Server/StoreLog.hs +++ b/src/Simplex/FileTransfer/Server/StoreLog.hs @@ -34,13 +34,13 @@ import Simplex.FileTransfer.Protocol (FileInfo (..)) import Simplex.FileTransfer.Server.Store import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BlockingInfo, RcvPublicAuthKey, RecipientId, SenderId) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, ServerEntityStatus (..)) +import Simplex.Messaging.Server.QueueStore (ServerEntityStatus (..)) import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Util (bshow) import System.IO data FileStoreLogRecord - = AddFile SenderId FileInfo RoundedSystemTime ServerEntityStatus + = AddFile SenderId FileInfo RoundedFileTime ServerEntityStatus | PutFile SenderId FilePath | AddRecipients SenderId (NonEmpty FileRecipient) | DeleteFile SenderId @@ -69,7 +69,7 @@ instance StrEncoding FileStoreLogRecord where logFileStoreRecord :: StoreLog 'WriteMode -> FileStoreLogRecord -> IO () logFileStoreRecord = writeStoreLogRecord -logAddFile :: StoreLog 'WriteMode -> SenderId -> FileInfo -> RoundedSystemTime -> ServerEntityStatus -> IO () +logAddFile :: StoreLog 'WriteMode -> SenderId -> FileInfo -> RoundedFileTime -> ServerEntityStatus -> IO () logAddFile s = logFileStoreRecord s .:: AddFile logPutFile :: StoreLog 'WriteMode -> SenderId -> FilePath -> IO () diff --git a/src/Simplex/FileTransfer/Types.hs b/src/Simplex/FileTransfer/Types.hs index 953080480..aa465a12e 100644 --- a/src/Simplex/FileTransfer/Types.hs +++ b/src/Simplex/FileTransfer/Types.hs @@ -15,6 +15,7 @@ import Data.Text.Encoding (encodeUtf8) import Data.Word (Word32) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description +import Simplex.Messaging.Agent.Store.DB (FromField (..), ToField (..), fromTextField_) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding @@ -22,7 +23,6 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol (XFTPServer) import System.FilePath (()) -import Simplex.Messaging.Agent.Store.DB (FromField (..), ToField (..), fromTextField_) type RcvFileId = ByteString -- Agent entity ID diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 27967bfd6..c19d4aeea 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -13,6 +13,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} @@ -67,8 +68,12 @@ module Simplex.Messaging.Agent allowConnection, acceptContact, rejectContact, + DatabaseDiff (..), + compareConnections, + syncConnections, subscribeConnection, subscribeConnections, + subscribeAllConnections, getConnectionMessages, getNotificationConns, resubscribeConnection, @@ -132,17 +137,21 @@ module Simplex.Messaging.Agent ) where +import Control.Applicative ((<|>)) +import Control.Concurrent.STM (retry) import Control.Logger.Simple import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) +import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson as J +import qualified Data.Aeson.TH as JQ import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Composition ((.:), (.:.), (.::), (.::.)) +import Data.Composition import Data.Either (isRight, partitionEithers, rights) import Data.Foldable (foldl', toList) import Data.Functor (($>)) @@ -150,7 +159,7 @@ import Data.Functor.Identity import Data.Int (Int64) import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IM -import Data.List (find) +import Data.List (find, sortOn) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) @@ -180,9 +189,11 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.AgentStore import Simplex.Messaging.Agent.Store.Common (DBStore) import qualified Simplex.Messaging.Agent.Store.DB as DB +import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Interface (closeDBStore, execSQL, getCurrentMigrations) import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) -import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, nonBlockingWriteTBQueue, temporaryClientError, unexpectedResponse) +import qualified Simplex.Messaging.Agent.TSessionSubs as SS +import Simplex.Messaging.Client (NetworkRequestMode (..), SMPClientError, ServerTransmission (..), ServerTransmissionBatch, nonBlockingWriteTBQueue, smpErrorClientNotice, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -192,7 +203,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..), NtfTokenId, PNMessageData (..), pnMessagesP) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Parsers (parse) +import Simplex.Messaging.Parsers (defaultJSON, parse) import Simplex.Messaging.Protocol ( BrokerMsg, Cmd (..), @@ -217,7 +228,7 @@ import Simplex.Messaging.Protocol ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) -import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.SystemTime import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.Util @@ -226,6 +237,7 @@ import Simplex.RemoteControl.Client import Simplex.RemoteControl.Invitation import Simplex.RemoteControl.Types import System.Mem.Weak (deRefWeak) +import UnliftIO.Async (mapConcurrently) import UnliftIO.Concurrent (forkFinally, forkIO, killThread, mkWeakThreadId, threadDelay) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -240,13 +252,14 @@ getSMPAgentClient = getSMPAgentClient_ 1 {-# INLINE getSMPAgentClient #-} getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> DBStore -> Bool -> IO AgentClient -getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp} store backgroundMode = +getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp, presetServers} store backgroundMode = newSMPAgentEnv cfg store >>= runReaderT runAgent where runAgent = do liftIO $ checkServers "SMP" smp >> checkServers "XFTP" xftp currentTs <- liftIO getCurrentTime - c@AgentClient {acThread} <- liftIO . newAgentClient clientId initServers currentTs =<< ask + notices <- liftIO $ withTransaction store (`getClientNotices` presetServers) `catchAll_` pure [] + c@AgentClient {acThread} <- liftIO . newAgentClient clientId initServers currentTs notices =<< ask t <- runAgentThreads c `forkFinally` const (liftIO $ disconnectAgentClient c) atomically . writeTVar acThread . Just =<< mkWeakThreadId t pure c @@ -368,12 +381,12 @@ deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' {-# INLINE deleteConnectionsAsync #-} -- | Create SMP agent connection (NEW command) -createConnection :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> SConnectionMode c -> Maybe UserLinkData -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) -createConnection c nm userId enableNtfs = withAgentEnv c .::. newConn c nm userId enableNtfs +createConnection :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AE (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) +createConnection c nm userId enableNtfs checkNotices = withAgentEnv c .::. newConn c nm userId enableNtfs checkNotices {-# INLINE createConnection #-} -- | Create or update user's contact connection short link -setConnShortLink :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserLinkData -> Maybe CRClientData -> AE (ConnShortLink c) +setConnShortLink :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserConnLinkData c -> Maybe CRClientData -> AE (ConnShortLink c) setConnShortLink c = withAgentEnv c .::. setConnShortLink' c {-# INLINE setConnShortLink #-} @@ -430,6 +443,24 @@ rejectContact :: AgentClient -> ConfirmationId -> AE () rejectContact c = withAgentEnv c . rejectContact' c {-# INLINE rejectContact #-} +data DatabaseDiff a = DatabaseDiff + { missingIds :: [a], + extraIds :: [a] + } + deriving (Show) + +instance Functor DatabaseDiff where + fmap f DatabaseDiff {missingIds, extraIds} = + DatabaseDiff {missingIds = map f missingIds, extraIds = map f extraIds} + +compareConnections :: AgentClient -> [UserId] -> [ConnId] -> AE (DatabaseDiff UserId, DatabaseDiff ConnId) +compareConnections c = withAgentEnv c .: compareConnections' c +{-# INLINE compareConnections #-} + +syncConnections :: AgentClient -> [UserId] -> [ConnId] -> AE (DatabaseDiff UserId, DatabaseDiff ConnId) +syncConnections c = withAgentEnv c .: syncConnections' c +{-# INLINE syncConnections #-} + -- | Subscribe to receive connection messages (SUB command) subscribeConnection :: AgentClient -> ConnId -> AE (Maybe ClientServiceId) subscribeConnection c = withAgentEnv c . subscribeConnection' c @@ -440,6 +471,10 @@ subscribeConnections :: AgentClient -> [ConnId] -> AE (Map ConnId (Either AgentE subscribeConnections c = withAgentEnv c . subscribeConnections' c {-# INLINE subscribeConnections #-} +-- | Subscribe to all connections +subscribeAllConnections :: AgentClient -> Bool -> Maybe UserId -> AE () +subscribeAllConnections c = withAgentEnv c .: subscribeAllConnections' c + -- | Get messages for connections (GET commands) getConnectionMessages :: AgentClient -> NonEmpty ConnMsgReq -> IO (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) getConnectionMessages c = withAgentEnv' c . getConnectionMessages' c @@ -555,13 +590,14 @@ testProtocolServer c nm userId srv = withAgentEnv' c $ case protocolTypeI @p of -- | set SOCKS5 proxy on/off and optionally set TCP timeouts for fast network setNetworkConfig :: AgentClient -> NetworkConfig -> IO () setNetworkConfig c@AgentClient {useNetworkConfig, proxySessTs} cfg' = do - (spChanged, changed) <- atomically $ do + ts <- getCurrentTime + changed <- atomically $ do (_, cfg) <- readTVar useNetworkConfig let changed = cfg /= cfg' !cfgSlow = slowNetworkConfig cfg' when changed $ writeTVar useNetworkConfig (cfgSlow, cfg') - pure (socksProxy cfg /= socksProxy cfg', changed) - when spChanged $ getCurrentTime >>= atomically . writeTVar proxySessTs + when (socksProxy cfg /= socksProxy cfg') $ writeTVar proxySessTs ts + pure changed when changed $ reconnectAllServers c setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO () @@ -826,29 +862,43 @@ switchConnectionAsync' c corrId connId = rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted enqueueCommand c corrId connId Nothing $ AClientCommand SWCH let rqs' = updatedQs rq1 rqs - pure . connectionStats $ DuplexConnection cData rqs' sqs + connectionStats c $ DuplexConnection cData rqs' sqs _ -> throwE $ CMD PROHIBITED "switchConnectionAsync: not duplex" -newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> SConnectionMode c -> Maybe UserLinkData -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) -newConn c nm userId enableNtfs cMode userData_ clientData pqInitKeys subMode = do +newConn :: ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> Bool -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, (CreatedConnLink c, Maybe ClientServiceId)) +newConn c nm userId enableNtfs checkNotices cMode linkData_ clientData pqInitKeys subMode = do srv <- getSMPServer c userId + when (checkNotices && connMode cMode == CMContact) $ checkClientNotices c srv connId <- newConnNoQueues c userId enableNtfs cMode (CR.connPQEncryption pqInitKeys) - (connId,) <$> newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srv + (connId,) <$> newRcvConnSrv c nm userId connId enableNtfs cMode linkData_ clientData pqInitKeys subMode srv `catchE` \e -> withStore' c (`deleteConnRecord` connId) >> throwE e -setConnShortLink' :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserLinkData -> Maybe CRClientData -> AM (ConnShortLink c) -setConnShortLink' c nm connId cMode userData clientData = +checkClientNotices :: AgentClient -> SMPServerWithAuth -> AM () +checkClientNotices AgentClient {clientNotices, presetServers} (ProtoServerWithAuth srv@(ProtocolServer {host}) _) = do + notices <- readTVarIO clientNotices + unless (M.null notices) $ checkNotices notices =<< liftIO getSystemSeconds + where + srvKey + | isPresetServer srv presetServers = Nothing -- Nothing is used as key for preset servers + | otherwise = Just srv + checkNotices notices ts = + forM_ (M.lookup srvKey notices) $ \expires_ -> + when (maybe True (ts <) expires_) $ + throwError NOTICE {server = safeDecodeUtf8 $ strEncode $ L.head host, preset = isNothing srvKey, expiresAt = roundedToUTCTime <$> expires_} + +setConnShortLink' :: AgentClient -> NetworkRequestMode -> ConnId -> SConnectionMode c -> UserConnLinkData c -> Maybe CRClientData -> AM (ConnShortLink c) +setConnShortLink' c nm connId cMode userLinkData clientData = withConnLock c connId "setConnShortLink" $ do SomeConn _ conn <- withStore c (`getConn` connId) - (rq, lnkId, sl, d) <- case (conn, cMode) of - (ContactConnection _ rq, SCMContact) -> prepareContactLinkData rq - (RcvConnection _ rq, SCMInvitation) -> prepareInvLinkData rq + (rq, lnkId, sl, d) <- case (conn, cMode, userLinkData) of + (ContactConnection _ rq, SCMContact, d@UserContactLinkData {}) -> prepareContactLinkData rq d + (RcvConnection _ rq, SCMInvitation, d@UserInvLinkData {}) -> prepareInvLinkData rq d _ -> throwE $ CMD PROHIBITED "setConnShortLink: invalid connection or mode" addQueueLink c nm rq lnkId d pure sl where - prepareContactLinkData :: RcvQueue -> AM (RcvQueue, SMP.LinkId, ConnShortLink 'CMContact, QueueLinkData) - prepareContactLinkData rq@RcvQueue {shortLink} = do + prepareContactLinkData :: RcvQueue -> UserConnLinkData 'CMContact -> AM (RcvQueue, SMP.LinkId, ConnShortLink 'CMContact, QueueLinkData) + prepareContactLinkData rq@RcvQueue {shortLink} ud = do g <- asks random AgentConfig {smpClientVRange = vr, smpAgentVRange} <- asks config let cslContact = CSLContact SLSServer CCTContact (qServer rq) @@ -856,25 +906,25 @@ setConnShortLink' c nm connId cMode userData clientData = Just ShortLinkCreds {shortLinkId, shortLinkKey, linkPrivSigKey, linkEncFixedData} -> do let (linkId, k) = SL.contactShortLinkKdf shortLinkKey unless (shortLinkId == linkId) $ throwE $ INTERNAL "setConnShortLink: link ID is not derived from link" - d <- liftError id $ SL.encryptUserData g k $ SL.encodeSignUserData SCMContact linkPrivSigKey smpAgentVRange userData + d <- liftError id $ SL.encryptUserData g k $ SL.encodeSignUserData SCMContact linkPrivSigKey smpAgentVRange ud pure (rq, linkId, cslContact shortLinkKey, (linkEncFixedData, d)) Nothing -> do sigKeys@(_, privSigKey) <- atomically $ C.generateKeyPair @'C.Ed25519 g let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMContact} connReq = CRContactUri $ ConnReqUriData SSSimplex smpAgentVRange [qUri] clientData - (linkKey, linkData) = SL.encodeSignLinkData sigKeys smpAgentVRange connReq userData + (linkKey, linkData) = SL.encodeSignLinkData sigKeys smpAgentVRange connReq ud (linkId, k) = SL.contactShortLinkKdf linkKey srvData <- liftError id $ SL.encryptLinkData g k linkData let slCreds = ShortLinkCreds linkId linkKey privSigKey (fst srvData) withStore' c $ \db -> updateShortLinkCreds db rq slCreds pure (rq, linkId, cslContact linkKey, srvData) - prepareInvLinkData :: RcvQueue -> AM (RcvQueue, SMP.LinkId, ConnShortLink 'CMInvitation, QueueLinkData) - prepareInvLinkData rq@RcvQueue {shortLink} = case shortLink of + prepareInvLinkData :: RcvQueue -> UserConnLinkData 'CMInvitation -> AM (RcvQueue, SMP.LinkId, ConnShortLink 'CMInvitation, QueueLinkData) + prepareInvLinkData rq@RcvQueue {shortLink} ud = case shortLink of Just ShortLinkCreds {shortLinkId, shortLinkKey, linkPrivSigKey, linkEncFixedData} -> do g <- asks random AgentConfig {smpAgentVRange} <- asks config let k = SL.invShortLinkKdf shortLinkKey - d <- liftError id $ SL.encryptUserData g k $ SL.encodeSignUserData SCMInvitation linkPrivSigKey smpAgentVRange userData + d <- liftError id $ SL.encryptUserData g k $ SL.encodeSignUserData SCMInvitation linkPrivSigKey smpAgentVRange ud let sl = CSLInvitation SLSServer (qServer rq) shortLinkId shortLinkKey pure (rq, shortLinkId, sl, (linkEncFixedData, d)) Nothing -> throwE $ CMD PROHIBITED "setConnShortLink: no ShortLinkCreds in invitation" @@ -897,8 +947,8 @@ getConnShortLink' c nm userId = \case getInvShortLink db srv linkId >>= \case Just sl@InvShortLink {linkKey = lk} | linkKey == lk -> pure sl _ -> do - (sndPublicKey, sndPrivateKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g - let sl = InvShortLink {server = srv, linkId, linkKey, sndPrivateKey, sndPublicKey, sndId = Nothing} + sndPrivateKey <- atomically $ C.generatePrivateAuthKey C.SEd25519 g + let sl = InvShortLink {server = srv, linkId, linkKey, sndPrivateKey, sndId = Nothing} createInvShortLink db sl pure sl let k = SL.invShortLinkKdf linkKey @@ -939,13 +989,13 @@ changeConnectionUser' c oldUserId connId newUserId = do where updateConn = withStore' c $ \db -> setConnUserId db oldUserId connId newUserId -newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe UserLinkData -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c, Maybe ClientServiceId) -newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKeys subMode srvWithAuth@(ProtoServerWithAuth srv _) = do +newRcvConnSrv :: forall c. ConnectionModeI c => AgentClient -> NetworkRequestMode -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe (UserConnLinkData c) -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (CreatedConnLink c, Maybe ClientServiceId) +newRcvConnSrv c nm userId connId enableNtfs cMode userLinkData_ clientData pqInitKeys subMode srvWithAuth@(ProtoServerWithAuth srv _) = do case (cMode, pqInitKeys) of (SCMContact, CR.IKUsePQ) -> throwE $ CMD PROHIBITED "newRcvConnSrv" _ -> pure () e2eKeys <- atomically . C.generateKeyPair =<< asks random - case userData_ of + case userLinkData_ of Just d -> do (nonce, qUri, cReq, qd) <- prepareLinkData d $ fst e2eKeys (rq, qUri') <- createRcvQueue (Just nonce) qd e2eKeys @@ -963,7 +1013,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e atomically $ incSMPServerStat c userId srv connCreated - rq' <- withStore c $ \db -> updateNewConnRcv db connId rq + rq' <- withStore c $ \db -> updateNewConnRcv db connId rq subMode lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId mapM_ (newQueueNtfSubscription c rq') ntfServer_ pure (rq', qUri) @@ -975,12 +1025,12 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey SCMContact -> pure $ CRContactUri crData SCMInvitation -> do g <- asks random - let pqEnc = CR.initialPQEncryption (isJust userData_) pqInitKeys + let pqEnc = CR.initialPQEncryption (isJust userLinkData_) pqInitKeys (pk1, pk2, pKem, e2eRcvParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) pqEnc withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure $ CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange - prepareLinkData :: UserLinkData -> C.PublicKeyX25519 -> AM (C.CbNonce, SMPQueueUri, ConnectionRequestUri c, ClntQueueReqData) - prepareLinkData userData e2eDhKey = do + prepareLinkData :: UserConnLinkData c -> C.PublicKeyX25519 -> AM (C.CbNonce, SMPQueueUri, ConnectionRequestUri c, ClntQueueReqData) + prepareLinkData userLinkData e2eDhKey = do g <- asks random nonce@(C.CbNonce corrId) <- atomically $ C.randomCbNonce g sigKeys@(_, privSigKey) <- atomically $ C.generateKeyPair @'C.Ed25519 g @@ -990,7 +1040,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey qm = case cMode of SCMContact -> QMContact; SCMInvitation -> QMMessaging qUri = SMPQueueUri vr $ SMPQueueAddress srv sndId e2eDhKey (Just qm) connReq <- createConnReq qUri - let (linkKey, linkData) = SL.encodeSignLinkData sigKeys smpAgentVRange connReq userData + let (linkKey, linkData) = SL.encodeSignLinkData sigKeys smpAgentVRange connReq userLinkData qd <- case cMode of SCMContact -> do let (linkId, k) = SL.contactShortLinkKdf linkKey @@ -1090,8 +1140,8 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup = let Compatible SMPQueueInfo {queueAddress = SMPQueueAddress {smpServer, senderId}} = qInfo invLink_ <- withStore' c $ \db -> getInvShortLinkKeys db smpServer senderId let lnkId_ = fst <$> invLink_ - sndKeys_ = snd <$> invLink_ - (q, _) <- lift $ newSndQueue userId "" qInfo sndKeys_ + sndKey_ = snd <$> invLink_ + (q, _) <- lift $ newSndQueue userId "" qInfo sndKey_ withStore c $ \db -> runExceptT $ do e2eSndParams <- createRatchet_ db g maxSupported pqSupport e2eRcvParams sq' <- maybe (ExceptT $ updateNewConnSnd db connId q) pure sq_ @@ -1215,7 +1265,7 @@ createReplyQueue c nm ConnData {userId, connId, enableNtfs} SndQueue {smpClientV (rq, qUri, tSess, sessId) <- newRcvQueue c nm userId connId srv (versionToRange smpClientVersion) SCMInvitation (isJust ntfServer_) subMode atomically $ incSMPServerStat c userId (qServer rq) connCreated let qInfo = toVersionT qUri smpClientVersion - rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq + rq' <- withStore c $ \db -> upgradeSndConnToDuplex db connId rq subMode lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId mapM_ (newQueueNtfSubscription c rq') ntfServer_ pure (qInfo, clientServiceId rq') @@ -1244,6 +1294,27 @@ rejectContact' c invId = withStore' c $ \db -> deleteInvitation db invId {-# INLINE rejectContact' #-} +syncConnections' :: AgentClient -> [UserId] -> [ConnId] -> AM (DatabaseDiff UserId, DatabaseDiff ConnId) +syncConnections' c userIds connIds = do + r@(DatabaseDiff {extraIds = uIds}, DatabaseDiff {extraIds = cIds}) <- compareConnections' c userIds connIds + forM_ uIds $ \uid -> deleteUser' c uid False + deleteConnectionsAsync' c False cIds + pure r + +compareConnections' :: AgentClient -> [UserId] -> [ConnId] -> AM (DatabaseDiff UserId, DatabaseDiff ConnId) +compareConnections' c userIds connIds = do + knownUserIds <- withStore' c getUserIds + knownConnIds <- withStore' c getConnIds + pure (databaseDiff userIds knownUserIds, databaseDiff connIds knownConnIds) + +databaseDiff :: Ord a => [a] -> [a] -> DatabaseDiff a +databaseDiff passed known = + let passedSet = S.fromList passed + knownSet = S.fromList known + missingIds = S.toList $ passedSet `S.difference` knownSet + extraIds = S.toList $ knownSet `S.difference` passedSet + in DatabaseDiff {missingIds, extraIds} + -- | Subscribe to receive connection messages (SUB command) in Reader monad subscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId] @@ -1263,39 +1334,48 @@ type QSubResult = QCmdResult (Maybe SMP.ServiceId) subscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) subscribeConnections' _ [] = pure M.empty -subscribeConnections' c connIds = do - conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConns` connIds) - let (errs, cs) = M.mapEither id conns - errs' = M.map (Left . storeError) errs - (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs +subscribeConnections' c connIds = subscribeConnections_ c . zip connIds =<< withStore' c (`getConnSubs` connIds) + +subscribeConnections_ :: AgentClient -> [(ConnId, Either StoreError SomeConnSub)] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) +subscribeConnections_ c conns = do + let (subRs, cs) = foldr partitionResultsConns ([], []) conns resumeDelivery cs - resumeConnCmds c $ M.keys cs - rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) + resumeConnCmds c $ map fst cs + rcvRs <- lift $ connResults <$> subscribeQueues c False (concatMap rcvQueues cs) rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor lift $ whenM (liftIO $ hasInstantNotifications ns) . void . forkIO . void $ sendNtfCreate ns rcvRs' cs - let rs = M.unions ([errs', subRs, rcvRs'] :: [Map ConnId (Either AgentErrorType (Maybe ClientServiceId))]) + -- union is left-biased + let rs = rcvRs' `M.union` subRs notifyResultError rs pure rs where - rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType (Maybe ClientServiceId)) [RcvQueue] - rcvQueueOrResult (SomeConn _ conn) = case conn of - DuplexConnection _ rqs _ -> Right $ L.toList rqs - SndConnection _ sq -> Left $ sndSubResult sq - RcvConnection _ rq -> Right [rq] - ContactConnection _ rq -> Right [rq] - NewConnection _ -> Left (Right Nothing) + partitionResultsConns :: (ConnId, Either StoreError SomeConnSub) -> + (Map ConnId (Either AgentErrorType (Maybe ClientServiceId)), [(ConnId, SomeConnSub)]) -> + (Map ConnId (Either AgentErrorType (Maybe ClientServiceId)), [(ConnId, SomeConnSub)]) + partitionResultsConns (connId, conn_) (rs, cs) = case conn_ of + Left e -> (M.insert connId (Left $ storeError e) rs, cs) + Right c'@(SomeConn _ conn) -> case conn of + DuplexConnection {} -> (rs, cs') + SndConnection _ sq -> (M.insert connId (sndSubResult sq) rs, cs') + RcvConnection _ _ -> (rs, cs') + ContactConnection _ _ -> (rs, cs') + NewConnection _ -> (M.insert connId (Right Nothing) rs, cs') + where + cs' = (connId, c') : cs sndSubResult :: SndQueue -> Either AgentErrorType (Maybe ClientServiceId) sndSubResult SndQueue {status} = case status of Confirmed -> Right Nothing Active -> Left $ CONN SIMPLEX "subscribeConnections" _ -> Left $ INTERNAL "unexpected queue status" - connResults :: [(RcvQueue, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) + rcvQueues :: (ConnId, SomeConnSub) -> [RcvQueueSub] + rcvQueues (_, SomeConn _ conn) = connRcvQueues conn + connResults :: [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) connResults = M.map snd . foldl' addResult M.empty where -- collects results by connection ID - addResult :: Map ConnId QSubResult -> (RcvQueue, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QSubResult - addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs + addResult :: Map ConnId QSubResult -> (RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId)) -> Map ConnId QSubResult + addResult rs (RcvQueueSub {connId, status}, r) = M.alter (combineRes (status, r)) connId rs -- combines two results for one connection, by using only Active queues (if there is at least one Active queue) combineRes :: QSubResult -> Maybe QSubResult -> Maybe QSubResult combineRes r' (Just r) = Just $ if order r <= order r' then r else r' @@ -1308,31 +1388,90 @@ subscribeConnections' c connIds = do -- TODO [certs rcv] store associations of queues with client service ID storeClientServiceAssocs :: Map ConnId (Either AgentErrorType (Maybe SMP.ServiceId)) -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) storeClientServiceAssocs = pure . M.map (Nothing <$) - sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> Map ConnId SomeConn -> AM' () + sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> [(ConnId, SomeConnSub)] -> AM' () sendNtfCreate ns rcvRs cs = do let oks = M.keysSet $ M.filter (either temporaryAgentError $ const True) rcvRs - cs' = M.restrictKeys cs oks - (csCreate, csDelete) = M.partition (\(SomeConn _ conn) -> enableNtfs $ toConnData conn) cs' + (csCreate, csDelete) = foldr (groupConnIds oks) ([], []) cs sendNtfCmd NSCCreate csCreate sendNtfCmd NSCSmpDelete csDelete where - sendNtfCmd cmd cs' = forM_ (L.nonEmpty $ M.keys cs') $ \cids -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, cids) - resumeDelivery :: Map ConnId SomeConn -> AM () - resumeDelivery conns = do - conns' <- M.restrictKeys conns . S.fromList <$> withStore' c getConnectionsForDelivery - lift $ mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) conns' - sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue) - sndQueue (SomeConn _ conn) = case conn of - DuplexConnection cData _ sqs -> Just (cData, sqs) - SndConnection cData sq -> Just (cData, [sq]) - _ -> Nothing + groupConnIds oks (connId, SomeConn _ conn) acc@(csCreate, csDelete) + | connId `S.notMember` oks = acc + | enableNtfs = (connId : csCreate, csDelete) + | otherwise = (csCreate, connId : csDelete) + where + ConnData {enableNtfs} = toConnData conn + sendNtfCmd cmd = mapM_ (\cids -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, cids)) . L.nonEmpty + resumeDelivery :: [(ConnId, SomeConnSub)] -> AM () + resumeDelivery conns' = do + deliverTo <- S.fromList <$> withStore' c getConnectionsForDelivery + let conns'' = filter ((`S.member` deliverTo) . fst) conns' + lift $ mapM_ (mapM_ (resumeMsgDelivery c) . sndQueues) conns'' + sndQueues :: (ConnId, SomeConnSub) -> [SndQueue] + sndQueues (_, SomeConn _ conn) = case conn of + DuplexConnection _ _ sqs -> L.toList sqs + SndConnection _ sq -> [sq] + _ -> [] notifyResultError :: Map ConnId (Either AgentErrorType (Maybe ClientServiceId)) -> AM () notifyResultError rs = do let actual = M.size rs - expected = length connIds + expected = length conns when (actual /= expected) . atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) +subscribeAllConnections' :: AgentClient -> Bool -> Maybe UserId -> AM () +subscribeAllConnections' c onlyNeeded activeUserId_ = handleErr $ do + userSrvs <- withStore' c (`getSubscriptionServers` onlyNeeded) + unless (null userSrvs) $ do + maxPending <- asks $ maxPendingSubscriptions . config + currPending <- newTVarIO 0 + let userSrvs' = case activeUserId_ of + Just activeUserId -> sortOn (\(uId, _) -> if uId == activeUserId then 0 else 1 :: Int) userSrvs + Nothing -> userSrvs + rs <- lift $ mapConcurrently (subscribeUserServer maxPending currPending) userSrvs' + let (errs, oks) = partitionEithers rs + logInfo $ "subscribed " <> tshow (sum oks) <> " queues" + forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map ("",) + withStore' c unsetQueuesToSubscribe + resumeAllDelivery + resumeAllCommands c + where + handleErr = (`catchAllErrors` \e -> notifySub' c "" (ERR e) >> throwE e) + subscribeUserServer :: Int -> TVar Int -> (UserId, SMPServer) -> AM' (Either AgentErrorType Int) + subscribeUserServer maxPending currPending (userId, srv) = do + atomically $ whenM ((maxPending <=) <$> readTVar currPending) retry + tryAllErrors' $ do + qs <- withStore' c $ \db -> do + qs <- getUserServerRcvQueueSubs db userId srv onlyNeeded + atomically $ modifyTVar' currPending (+ length qs) -- update before leaving transaction + pure qs + let n = length qs + lift $ subscribe qs `E.finally` atomically (modifyTVar' currPending $ subtract n) + pure n + where + subscribe qs = do + rs <- subscribeUserServerQueues c userId srv qs + -- TODO [certs rcv] storeClientServiceAssocs store associations of queues with client service ID + ns <- asks ntfSupervisor + whenM (liftIO $ hasInstantNotifications ns) $ sendNtfCreate ns rs + sendNtfCreate :: NtfSupervisor -> [(RcvQueueSub, Either AgentErrorType (Maybe SMP.ServiceId))] -> AM' () + sendNtfCreate ns rs = do + let (csCreate, csDelete) = foldl' groupConnIds (S.empty, S.empty) rs + sendNtfCmd NSCCreate csCreate + sendNtfCmd NSCSmpDelete csDelete + where + groupConnIds acc@(!csCreate, !csDelete) (RcvQueueSub {connId, enableNtfs}, r) = case r of + Left e + | not (temporaryAgentError e) -> acc + _ + | enableNtfs -> (S.insert connId csCreate, csDelete) + | otherwise -> (csCreate, S.insert connId csDelete) + sendNtfCmd cmd = mapM_ (\cIds -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, cIds)) . L.nonEmpty . S.toList + resumeAllDelivery :: AM () + resumeAllDelivery = do + sqs <- withStore' c getAllSndQueuesForDelivery + lift $ mapM_ (resumeMsgDelivery c) sqs + resubscribeConnection' :: AgentClient -> ConnId -> AM (Maybe ClientServiceId) resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId] {-# INLINE resubscribeConnection' #-} @@ -1340,10 +1479,17 @@ resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections resubscribeConnections' :: AgentClient -> [ConnId] -> AM (Map ConnId (Either AgentErrorType (Maybe ClientServiceId))) resubscribeConnections' _ [] = pure M.empty resubscribeConnections' c connIds = do - let r = M.fromList . zip connIds . repeat $ Right Nothing - connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds + conns <- zip connIds <$> withStore' c (`getConnSubs` connIds) + let r = M.fromList $ map (,Right Nothing) connIds -- TODO [certs rcv] + conns' <- filterM (fmap not . isActiveConn . snd) conns -- union is left-biased, so results returned by subscribeConnections' take precedence - (`M.union` r) <$> subscribeConnections' c connIds' + (`M.union` r) <$> subscribeConnections_ c conns' + where + isActiveConn :: Either StoreError SomeConnSub -> AM Bool + isActiveConn (Left _) = pure True -- to have results processed by subscribeConnections_ + isActiveConn (Right (SomeConn _ conn)) = case connRcvQueues conn of + [] -> pure True + rqs' -> anyM $ map (atomically . hasActiveSubscription c) rqs' -- TODO [certs rcv] subscribeClientService' :: AgentClient -> ClientServiceId -> AM Int @@ -1355,7 +1501,6 @@ getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage where getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta) getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do - whenM (atomically $ hasActiveSubscription c connId) . throwE $ CMD PROHIBITED "getConnectionMessage: subscribed" SomeConn _ conn <- withStore c (`getConn` connId) rq <- case conn of DuplexConnection _ (rq :| _) _ -> pure rq @@ -1363,6 +1508,7 @@ getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage ContactConnection _ rq -> pure rq SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage" NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection" + whenM (atomically $ hasActiveSubscription c rq) . throwE $ CMD PROHIBITED "getConnectionMessage: subscribed" msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e when (isNothing msg_) $ do atomically $ releaseGetLock c rq @@ -1478,6 +1624,11 @@ resumeConnCmds c connIds = do connSrvs <- withStore' c (`getPendingCommandServers` connIds) lift $ mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs +resumeAllCommands :: AgentClient -> AM () +resumeAllCommands c = do + connSrvs <- withStore' c getAllPendingCommandConns `catchAllErrors` (\e -> liftIO (print e) >> throwE e) + lift $ mapM_ (uncurry $ resumeSrvCmds c) connSrvs + getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c connId server = getAgentWorker "async_cmd" hasWork c (connId, server) (asyncCmdWorkers c) (runCommandProcessing c connId server) @@ -1570,11 +1721,12 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do rq1' <- withStore' c $ \db -> setRcvSwitchStatus db rq1 $ Just RSSendingQUSE let rqs' = updatedQs rq1' rqs conn' = DuplexConnection cData rqs' sqs - notify . SWITCH QDRcv SPSecured $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDRcv SPSecured cStats _ -> internalErr "ICQSecure: no switching queue found" _ -> internalErr "ICQSecure: queue address not found in connection" ICQDelete rId -> do - withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> do + withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData@ConnData {enableNtfs} rqs sqs) -> do case removeQ (srv, rId) rqs of Nothing -> internalErr "ICQDelete: queue address not found in connection" Just (rq'@RcvQueue {primary}, rq'' : rqs') @@ -1589,11 +1741,12 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do where finalizeSwitch = do withStore' c $ \db -> deleteConnRcvQueue db rq' - when (enableNtfs cData) $ do + when enableNtfs $ do ns <- asks ntfSupervisor liftIO $ sendNtfSubCommand ns (NSCCreate, [connId]) let conn' = DuplexConnection cData (rq'' :| rqs') sqs - notify $ SWITCH QDRcv SPCompleted $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDRcv SPCompleted cStats _ -> internalErr "ICQDelete: cannot delete the only queue in connection" where ack srv rId srvMsgId = do @@ -1663,15 +1816,15 @@ enqueueMessage c cData sq msgFlags aMessage = {-# INLINE enqueueMessage #-} -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType (Either AgentErrorType (ConnData, NonEmpty SndQueue), Maybe PQEncryption, MsgFlags, ValueOrRef AMessage)) -> AM' (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType (Either AgentErrorType (ConnData, NonEmpty SndQueue), Maybe PQEncryption, MsgFlags, ValueOrRef AMessage)) -> AM' (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe ([SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do cfg <- asks config (_, reqMids) <- unsafeWithStore c $ \db -> do mapAccumLM (\ids r -> storeSentMsg db cfg ids r `E.catchAny` \e -> (ids,) <$> handleInternal e) IM.empty reqs - forME reqMids $ \((csqs_, _, _, _), InternalId msgId, pqSecr) -> forM csqs_ $ \(cData, sq :| sqs) -> do - submitPendingMsg c cData sq + forME reqMids $ \((csqs_, _, _, _), InternalId msgId, pqSecr) -> forM csqs_ $ \(_, sq :| sqs) -> do + submitPendingMsg c sq let sqs' = filter isActiveSndQ sqs - pure ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) + pure ((msgId, pqSecr), if null sqs' then Nothing else Just (sqs', msgId)) where storeSentMsg :: DB.Connection -> @@ -1719,7 +1872,7 @@ enqueueMessageB c reqs = do -- msgBody is empty, because snd_messages record is linked to snd_message_bodies msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody = "", pqEncryption = pqEnc, internalHash, prevMsgHash, sndMsgPrepData_ = Just SndMsgPrepData {encryptKey = mek, paddedLen, sndMsgBodyId}} liftIO $ createSndMsg db connId msgData - liftIO $ createSndMsgDelivery db connId sq internalId + liftIO $ createSndMsgDelivery db sq internalId pure (req, internalId, pqEnc) handleInternal :: E.SomeException -> IO (Either AgentErrorType b) handleInternal = pure . Left . INTERNAL . show @@ -1731,41 +1884,44 @@ encodeAgentMsgStr aMessage internalSndId prevMsgHash = do agentMsg = AgentMessage privHeader aMessage in smpEncode agentMsg -enqueueSavedMessage :: AgentClient -> ConnData -> AgentMsgId -> SndQueue -> AM' () -enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) +enqueueSavedMessage :: AgentClient -> AgentMsgId -> SndQueue -> AM' () +enqueueSavedMessage c msgId sq = enqueueSavedMessageB c $ Identity ([sq], msgId) {-# INLINE enqueueSavedMessage #-} -enqueueSavedMessageB :: Foldable t => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> AM' () +enqueueSavedMessageB :: Foldable t => AgentClient -> t ([SndQueue], AgentMsgId) -> AM' () enqueueSavedMessageB c reqs = do -- saving to the database is in the start to avoid race conditions when delivery is read from queue before it is saved void $ withStoreBatch' c $ \db -> concatMap (storeDeliveries db) reqs - forM_ reqs $ \(cData, sqs, _) -> - forM sqs $ submitPendingMsg c cData + -- TODO this needs to be optimized to insert them in one query + forM_ reqs $ \(sqs, _) -> forM sqs $ submitPendingMsg c where - storeDeliveries :: DB.Connection -> (ConnData, [SndQueue], AgentMsgId) -> [IO ()] - storeDeliveries db (ConnData {connId}, sqs, msgId) = do + storeDeliveries :: DB.Connection -> ([SndQueue], AgentMsgId) -> [IO ()] + storeDeliveries db (sqs, msgId) = do let mId = InternalId msgId - in map (\sq -> createSndMsgDelivery db connId sq mId) sqs + in map (\sq -> createSndMsgDelivery db sq mId) sqs -resumeMsgDelivery :: AgentClient -> ConnData -> SndQueue -> AM' () -resumeMsgDelivery = void .:. getDeliveryWorker False +resumeMsgDelivery :: AgentClient -> SndQueue -> AM' () +-- hasWork is passed as False to avoid unnecessary write to TMVar: +-- - new worker is always created by "some work to do". +-- - if the worker already exists, there is no need to "push" it again. +resumeMsgDelivery = void .: getDeliveryWorker False {-# INLINE resumeMsgDelivery #-} -getDeliveryWorker :: Bool -> AgentClient -> ConnData -> SndQueue -> AM' (Worker, TMVar ()) -getDeliveryWorker hasWork c cData sq = - getAgentWorker' fst mkLock "msg_delivery" hasWork c (qAddress sq) (smpDeliveryWorkers c) (runSmpQueueMsgDelivery c cData sq) +getDeliveryWorker :: Bool -> AgentClient -> SndQueue -> AM' (Worker, TMVar ()) +getDeliveryWorker hasWork c sq = + getAgentWorker' fst mkLock "msg_delivery" hasWork c (qAddress sq) (smpDeliveryWorkers c) (runSmpQueueMsgDelivery c sq) where mkLock w = do retryLock <- newEmptyTMVar pure (w, retryLock) -submitPendingMsg :: AgentClient -> ConnData -> SndQueue -> AM' () -submitPendingMsg c cData sq = do +submitPendingMsg :: AgentClient -> SndQueue -> AM' () +submitPendingMsg c sq = do atomically $ modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + 1} - void $ getDeliveryWorker True c cData sq + void $ getDeliveryWorker True c sq -runSmpQueueMsgDelivery :: AgentClient -> ConnData -> SndQueue -> (Worker, TMVar ()) -> AM () -runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userId, server, queueMode} (Worker {doWork}, qLock) = do +runSmpQueueMsgDelivery :: AgentClient -> SndQueue -> (Worker, TMVar ()) -> AM () +runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server, queueMode} (Worker {doWork}, qLock) = do AgentConfig {messageRetryInterval = ri, messageTimeout, helloTimeout, quotaExceededTimeout} <- asks config forever $ do atomically $ endAgentOperation c AOSndNetwork @@ -1879,7 +2035,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI AM_QADD_ -> pure () AM_QKEY_ -> do SomeConn _ conn <- withStore c (`getConn` connId) - notify . SWITCH QDSnd SPConfirmed $ connectionStats conn + cStats <- connectionStats c conn + notify $ SWITCH QDSnd SPConfirmed cStats AM_QUSE_ -> pure () AM_QTEST_ -> withConnLock c connId "runSmpQueueMsgDelivery AM_QTEST_" $ do withStore' c $ \db -> setSndQueueStatus db sq Active @@ -1904,7 +2061,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq@SndQueue {userI deleteConnSndQueue db connId sq' let sqs'' = sq'' :| sqs' conn' = DuplexConnection cData' rqs sqs'' - notify . SWITCH QDSnd SPCompleted $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDSnd SPCompleted cStats _ -> internalErr msgId "sent QTEST: there is only one queue in connection" _ -> internalErr msgId "sent QTEST: queue not in connection or not replacing another queue" _ -> internalErr msgId "QTEST sent not in duplex connection" @@ -2010,12 +2168,12 @@ switchDuplexConnection c nm (DuplexConnection cData@ConnData {connId, userId} rq -- The problem is that currently subscription already exists, and we do not support queues with credentials but without subscriptions. (q, qUri, tSess, sessId) <- newRcvQueue c nm userId connId srv' clientVRange SCMInvitation False SMSubscribe let rq' = (q :: NewRcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} - rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' + rq'' <- withStore c $ \db -> addConnRcvQueue db connId rq' SMSubscribe lift $ addNewQueueSubscription c rq'' tSess sessId void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))] rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD let rqs' = updatedQs rq1 rqs <> [rq''] - pure . connectionStats $ DuplexConnection cData rqs' sqs + connectionStats c $ DuplexConnection cData rqs' sqs abortConnectionSwitch' :: AgentClient -> ConnId -> AM ConnectionStats abortConnectionSwitch' c connId = @@ -2035,7 +2193,7 @@ abortConnectionSwitch' c connId = forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId let rqs'' = updatedQs rq' rqs' conn' = DuplexConnection cData rqs'' sqs - pure $ connectionStats conn' + connectionStats c conn' _ -> throwE $ INTERNAL "won't delete all rcv queues in connection" | otherwise -> throwE $ CMD PROHIBITED "abortConnectionSwitch: no rcv queues left" _ -> throwE $ CMD PROHIBITED "abortConnectionSwitch: not allowed" @@ -2052,13 +2210,13 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni AgentConfig {e2eEncryptVRange} <- asks config g <- asks random (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eEncryptVRange) pqSupport' - enqueueRatchetKeyMsgs c cData' sqs e2eParams + enqueueRatchetKeyMsgs c sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted setRatchetX3dhKeys db connId pk1 pk2 pKem let cData'' = cData' {ratchetSyncState = RSStarted} :: ConnData conn' = DuplexConnection cData'' rqs sqs - pure $ connectionStats conn' + connectionStats c conn' | otherwise -> throwE $ CMD PROHIBITED "synchronizeRatchet: not allowed" _ -> throwE $ CMD PROHIBITED "synchronizeRatchet: not duplex" @@ -2097,7 +2255,7 @@ deleteConnection' :: AgentClient -> NetworkRequestMode -> ConnId -> AM () deleteConnection' c nm connId = toConnResult connId =<< deleteConnections' c nm [connId] {-# INLINE deleteConnection' #-} -connRcvQueues :: Connection d -> [RcvQueue] +connRcvQueues :: Connection' d rq sq -> [rq] connRcvQueues = \case DuplexConnection _ rqs _ -> L.toList rqs RcvConnection _ rq -> [rq] @@ -2125,13 +2283,10 @@ prepareDeleteConnections_ :: [ConnId] -> AM (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) prepareDeleteConnections_ getConnections c waitDelivery connIds = do - conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConnections` connIds) - let (errs, cs) = M.mapEither id conns - errs' = M.map (Left . storeError) errs - (delRs, rcvQs) = M.mapEither rcvQueues cs - rqs = concat $ M.elems rcvQs - connIds' = M.keys rcvQs - lift $ forM_ (L.nonEmpty connIds') unsubConnIds + conns <- withStore' c (`getConnections` connIds) + let res@(delRs, rqs, connIds') = foldr partitionResultsConns (M.empty, [], []) $ zip connIds conns + atomically $ removeSubscriptions c connIds' rqs + lift $ forM_ (L.nonEmpty connIds') unsubNtfConnIds -- ! delRs is not used to notify about the result in any of the calling functions, -- ! it is only used to check results count in deleteConnections_; -- ! if it was used to notify about the result, it might be necessary to differentiate @@ -2139,16 +2294,18 @@ prepareDeleteConnections_ getConnections c waitDelivery connIds = do deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing cIds_ <- lift $ L.nonEmpty . catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs)) forM_ cIds_ $ \cIds -> notify ("", "", AEvt SAEConn $ DEL_CONNS cIds) - pure (errs' <> delRs, rqs, connIds') + pure res where - rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue] - rcvQueues (SomeConn _ conn) = case connRcvQueues conn of - [] -> Left $ Right () - rqs -> Right rqs - unsubConnIds :: NonEmpty ConnId -> AM' () - unsubConnIds connIds' = do - forM_ connIds' $ \connId -> - atomically $ removeSubscription c connId + partitionResultsConns :: (ConnId, Either StoreError SomeConn) -> + (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) -> + (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) + partitionResultsConns (connId, conn_) (rs, rqs, cIds) = case conn_ of + Left e -> (M.insert connId (Left $ storeError e) rs, rqs, cIds) + Right (SomeConn _ conn) -> case connRcvQueues conn of + [] -> (M.insert connId (Right ()) rs, rqs, cIds) + rqs' -> (rs, rqs' ++ rqs, connId : cIds) + unsubNtfConnIds :: NonEmpty ConnId -> AM' () + unsubNtfConnIds connIds' = do ns <- asks ntfSupervisor atomically $ writeTBQueue (ntfSubQ ns) (NSCDeleteSub, connIds') notify = atomically . writeTBQueue (subQ c) @@ -2227,34 +2384,62 @@ deleteConnections_ getConnections ntf waitDelivery c nm connIds = do getConnectionServers' :: AgentClient -> ConnId -> AM ConnectionStats getConnectionServers' c connId = do SomeConn _ conn <- withStore c (`getConn` connId) - pure $ connectionStats conn + connectionStats c conn getConnectionRatchetAdHash' :: AgentClient -> ConnId -> AM ByteString getConnectionRatchetAdHash' c connId = do CR.Ratchet {rcAD = Str rcAD} <- withStore c (`getRatchet` connId) pure $ C.sha256Hash rcAD -connectionStats :: Connection c -> ConnectionStats -connectionStats = \case - RcvConnection cData rq -> - (stats cData) {rcvQueuesInfo = [rcvQueueInfo rq]} - SndConnection cData sq -> - (stats cData) {sndQueuesInfo = [sndQueueInfo sq]} - DuplexConnection cData rqs sqs -> - (stats cData) {rcvQueuesInfo = map rcvQueueInfo $ L.toList rqs, sndQueuesInfo = map sndQueueInfo $ L.toList sqs} - ContactConnection cData rq -> - (stats cData) {rcvQueuesInfo = [rcvQueueInfo rq]} +connectionStats :: AgentClient -> Connection c -> AM ConnectionStats +connectionStats c = \case + RcvConnection cData rq -> do + rcvQueuesInfo <- (: []) <$> rcvQueueInfo rq + pure (stats cData) {rcvQueuesInfo, subStatus = connSubStatus rcvQueuesInfo} + SndConnection cData sq -> do + pure (stats cData) {sndQueuesInfo = [sndQueueInfo sq]} + DuplexConnection cData rqs sqs -> do + rcvQueuesInfo <- mapM rcvQueueInfo (L.toList rqs) + pure + (stats cData) + { rcvQueuesInfo, + sndQueuesInfo = map sndQueueInfo $ L.toList sqs, + subStatus = connSubStatus rcvQueuesInfo + } + ContactConnection cData rq -> do + rcvQueuesInfo <- (: []) <$> rcvQueueInfo rq + pure (stats cData) {rcvQueuesInfo, subStatus = connSubStatus rcvQueuesInfo} NewConnection cData -> - stats cData + pure $ stats cData where + stats :: ConnData -> ConnectionStats stats ConnData {connAgentVersion, ratchetSyncState} = ConnectionStats { connAgentVersion, rcvQueuesInfo = [], sndQueuesInfo = [], ratchetSyncState, - ratchetSyncSupported = connAgentVersion >= ratchetSyncSMPAgentVersion + ratchetSyncSupported = connAgentVersion >= ratchetSyncSMPAgentVersion, + subStatus = Nothing } + rcvQueueInfo :: RcvQueue -> AM RcvQueueInfo + rcvQueueInfo rq@RcvQueue {server, status, rcvSwchStatus} = do + subStatus <- atomically checkQueueSubStatus + pure $ RcvQueueInfo {rcvServer = server, status, rcvSwitchStatus = rcvSwchStatus, canAbortSwitch = canAbortRcvSwitch rq, subStatus} + where + checkQueueSubStatus :: STM SubscriptionStatus + checkQueueSubStatus = + ifM (hasActiveSubscription c rq) (pure SSActive) $ + ifM (hasPendingSubscription c rq) (pure SSPending) $ + maybe SSNoSub (SSRemoved . show) <$> hasRemovedSubscription c rq + sndQueueInfo :: SndQueue -> SndQueueInfo + sndQueueInfo SndQueue {server, status, sndSwchStatus} = + SndQueueInfo {sndServer = server, status, sndSwitchStatus = sndSwchStatus} + connSubStatus :: [RcvQueueInfo] -> Maybe SubscriptionStatus + connSubStatus rqs = + let isActive RcvQueueInfo {status} = status == Active + subStatus' RcvQueueInfo {subStatus} = subStatus + in minimum . L.map subStatus' <$> (L.nonEmpty (filter isActive rqs) <|> L.nonEmpty rqs) -- | Change servers to be used for creating new queues. -- This function will set all servers as enabled in case all passed servers are disabled. @@ -2410,8 +2595,8 @@ toggleConnectionNtfs' c connId enable = do _ -> throwE $ CONN SIMPLEX "toggleConnectionNtfs" where toggle :: ConnData -> AM () - toggle cData - | enableNtfs cData == enable = pure () + toggle ConnData {enableNtfs} + | enableNtfs == enable = pure () | otherwise = do withStore' c $ \db -> setConnectionNtfs db connId enable ns <- asks ntfSupervisor @@ -2452,10 +2637,10 @@ sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor connIds <- liftIO $ S.toList <$> getSubscriptions c rs <- withStore' c (`getConnsData` connIds) - let (connIds', cErrs) = enabledNtfConns (zip connIds rs) + let (connIds', errs) = enabledNtfConns (zip connIds rs) forM_ (L.nonEmpty connIds') $ \connIds'' -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, connIds'') - unless (null cErrs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS cErrs) + forM_ (L.nonEmpty errs) $ notifySub c . ERRS where enabledNtfConns :: [(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)]) enabledNtfConns = foldr addEnabledConn ([], []) @@ -2516,7 +2701,7 @@ execAgentStoreSQL :: AgentClient -> Text -> AE [Text] execAgentStoreSQL c sql = withAgentEnv c $ withStore' c (`execSQL` sql) getAgentMigrations :: AgentClient -> AE [UpMigration] -getAgentMigrations c = withAgentEnv c $ map upMigration <$> withStore' c getCurrentMigrations +getAgentMigrations c = withAgentEnv c $ map upMigration <$> withStore' c (getCurrentMigrations Nothing) debugAgentLocks :: AgentClient -> IO AgentLocks debugAgentLocks AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = do @@ -2625,18 +2810,20 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId STEvent msgOrErr -> withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of Right msg -> runProcessSMP rq conn (toConnData conn) msg - Left e -> lift $ notifyErr connId e + Left e -> lift $ do + processClientNotice rq e + notifyErr connId e STResponse (Cmd SRecipient cmd) respOrErr -> withRcvConn entId $ \rq conn -> case cmd of SMP.SUB -> case respOrErr of - Right SMP.OK -> processSubOk rq upConnIds + Right SMP.OK -> liftIO $ processSubOk rq upConnIds -- TODO [certs rcv] associate queue with the service - Right (SMP.SOK serviceId_) -> processSubOk rq upConnIds + Right (SMP.SOK serviceId_) -> liftIO $ processSubOk rq upConnIds Right msg@SMP.MSG {} -> do - processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails + liftIO $ processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails runProcessSMP rq conn (toConnData conn) msg - Right r -> processSubErr rq $ unexpectedResponse r - Left e -> unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported + Right r -> lift $ processSubErr rq $ unexpectedResponse r + Left e -> lift $ unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported SMP.ACK _ -> case respOrErr of Right msg@SMP.MSG {} -> runProcessSMP rq conn (toConnData conn) msg _ -> pure () -- TODO process OK response to ACK @@ -2658,20 +2845,28 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId tryAllErrors' (a rq conn) >>= \case Left e -> notify' connId (ERR e) Right () -> pure () - processSubOk :: RcvQueue -> TVar [ConnId] -> AM () + processSubOk :: RcvQueue -> TVar [ConnId] -> IO () processSubOk rq@RcvQueue {connId} upConnIds = - atomically . whenM (isPendingSub connId) $ do - addSubscription c sessId rq + atomically . whenM (isPendingSub rq) $ do + SS.addActiveSub tSess sessId (rcvQueueSub rq) $ currentSubs c modifyTVar' upConnIds (connId :) - processSubErr :: RcvQueue -> SMPClientError -> AM () + processSubErr :: RcvQueue -> SMPClientError -> AM' () processSubErr rq@RcvQueue {connId} e = do - atomically . whenM (isPendingSub connId) $ - failSubscription c rq e >> incSMPServerStat c userId srv connSubErrs - lift $ notifyErr connId e - isPendingSub connId = do - pending <- (&&) <$> hasPendingSubscription c connId <*> activeClientSession c tSess sessId + atomically . whenM (isPendingSub rq) $ + failSubscription c tSess rq e >> incSMPServerStat c userId srv connSubErrs + processClientNotice rq e + notifyErr connId e + isPendingSub :: RcvQueue -> STM Bool + isPendingSub rq = do + pending <- (&&) <$> SS.hasPendingSub tSess (queueId rq) (currentSubs c) <*> activeClientSession c tSess sessId unless pending $ incSMPServerStat c userId srv connSubIgnored pure pending + processClientNotice rq e = + forM_ (smpErrorClientNotice e) $ \notice_ -> + E.bracket_ + (atomically $ takeTMVar $ clientNoticesLock c) + (atomically $ putTMVar (clientNoticesLock c) ()) + (processClientNotices c tSess [(rcvQueueSub rq, notice_)]) notify' :: forall e m. (AEntityI e, MonadIO m) => ConnId -> AEvent e -> m () notify' connId msg = atomically $ writeTBQueue subQ ("", connId, AEvt (sAEntity @e) msg) notifyErr :: ConnId -> SMPClientError -> AM' () @@ -2766,7 +2961,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId | rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData conn'' = updateConnection cData'' conn' - notify . RSYNC RSOk Nothing $ connectionStats conn'' + cStats <- connectionStats c conn'' + notify $ RSYNC RSOk Nothing cStats withStore' c $ \db -> setConnRatchetSync db connId RSOk pure conn'' | otherwise = pure conn' @@ -2796,7 +2992,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId when (rss `elem` ([RSOk, RSAllowed, RSRequired] :: [RatchetSyncState])) $ do let cData'' = (toConnData conn') {ratchetSyncState = rss'} :: ConnData conn'' = updateConnection cData'' connDuplex - notify . RSYNC rss' (Just e) $ connectionStats conn'' + cStats <- connectionStats c conn'' + notify $ RSYNC rss' (Just e) cStats withStore' c $ \db -> setConnRatchetSync db connId rss' Left e -> do atomically $ incSMPServerStat c userId srv recvErrs @@ -2850,14 +3047,14 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId handleNotifyAck :: AM ACKd -> AM ACKd handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack SMP.END -> - atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False)) + atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c tSess connId rq $> True) (pure False)) >>= notifyEnd where notifyEnd removed | removed = notify END >> logServer "<--" c srv rId "END" | otherwise = logServer "<--" c srv rId "END from disconnected client - ignored" -- Possibly, we need to add some flag to connection that it was deleted - SMP.DELD -> atomically (removeSubscription c connId) >> notify DELD + SMP.DELD -> atomically (removeSubscription c tSess connId rq) >> notify DELD SMP.ERR e -> notify $ ERR $ SMP (B.unpack $ strEncode srv) e r -> unexpected r where @@ -3041,17 +3238,18 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId let (delSqs, keepSqs) = L.partition ((Just dbQueueId ==) . dbReplaceQId) sqs case L.nonEmpty keepSqs of Just sqs' -> do - (sq_@SndQueue {sndPublicKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing + (sq_@SndQueue {sndPrivateKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing sq2 <- withStore c $ \db -> do liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} logServer "<--" c srv rId $ "MSG :" <> logSecret' srvMsgId <> " " <> logSecret (senderId queueAddress) let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}} - void . enqueueMessages c cData' sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPublicKey)] + void . enqueueMessages c cData' sqs SMP.noMsgFlags $ QKEY [(sqInfo', C.toPublic sndPrivateKey)] sq1 <- withStore' c $ \db -> setSndSwitchStatus db sq $ Just SSSendingQKEY let sqs'' = updatedQs sq1 sqs' <> [sq2] conn' = DuplexConnection cData' rqs sqs'' - notify . SWITCH QDSnd SPStarted $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDSnd SPStarted cStats _ -> qError "QADD: won't delete all snd queues in connection" _ -> qError "QADD: replaced queue address is not found in connection" _ -> throwE $ AGENT A_VERSION @@ -3070,7 +3268,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId let dhSecret = C.dh' dhPublicKey dhPrivKey withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer' enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey - notify . SWITCH QDRcv SPConfirmed $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDRcv SPConfirmed cStats | otherwise -> qError "QKEY: queue already secured" _ -> qError "QKEY: queue address not found in connection" where @@ -3095,7 +3294,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId sq1' <- withStore' c $ \db -> setSndSwitchStatus db sq1 $ Just SSSendingQTEST let sqs' = updatedQs sq1' sqs conn' = DuplexConnection cData' rqs sqs' - notify . SWITCH QDSnd SPSecured $ connectionStats conn' + cStats <- connectionStats c conn' + notify $ SWITCH QDSnd SPSecured cStats _ -> qError "QUSE: switching SndQueue not found in connection" _ -> qError "QUSE: switched queue address not found in connection" @@ -3166,17 +3366,19 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId sendReplyKey = do g <- asks random (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion pqSupport - enqueueRatchetKeyMsgs c cData' sqs e2eParams + enqueueRatchetKeyMsgs c sqs e2eParams pure (pk1, pk2, pKem) notifyRatchetSyncError = do let cData'' = cData' {ratchetSyncState = RSRequired} :: ConnData conn'' = updateConnection cData'' conn' - notify $ RSYNC RSRequired (Just RATCHET_SYNC) (connectionStats conn'') + cStats <- connectionStats c conn'' + notify $ RSYNC RSRequired (Just RATCHET_SYNC) cStats notifyAgreed :: AM () notifyAgreed = do let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData conn'' = updateConnection cData'' conn' - notify . RSYNC RSAgreed Nothing $ connectionStats conn'' + cStats <- connectionStats c conn'' + notify $ RSYNC RSAgreed Nothing cStats recreateRatchet :: CR.Ratchet 'C.X448 -> AM () recreateRatchet rc = withStore' c $ \db -> do setConnRatchetSync db connId RSAgreed @@ -3242,7 +3444,7 @@ secureConfirmQueueAsync c cData rq_ sq srv connInfo e2eEncryption_ subMode = do sqSecured <- agentSecureSndQueue c NRMBackground cData sq (qInfo, service) <- mkAgentConfirmation c NRMBackground cData rq_ sq srv connInfo subMode storeConfirmation c cData sq e2eEncryption_ qInfo - lift $ submitPendingMsg c cData sq + lift $ submitPendingMsg c sq pure (sqSecured, service) secureConfirmQueue :: AgentClient -> NetworkRequestMode -> ConnData -> Maybe RcvQueue -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM (SndQueueSecured, Maybe ClientServiceId) @@ -3288,7 +3490,7 @@ mkAgentConfirmation c nm cData rq_ sq srv connInfo subMode = do enqueueConfirmation :: AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AM () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo - lift $ submitPendingMsg c cData sq + lift $ submitPendingMsg c sq storeConfirmation :: AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> AM () storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = do @@ -3304,18 +3506,18 @@ storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq msgType = agentMessageType agentMsg msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash, sndMsgPrepData_ = Nothing} liftIO $ createSndMsg db connId msgData - liftIO $ createSndMsgDelivery db connId sq internalId + liftIO $ createSndMsgDelivery db sq internalId -enqueueRatchetKeyMsgs :: AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM () -enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do - msgId <- enqueueRatchetKey c cData sq e2eEncryption - mapM_ (lift . enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs +enqueueRatchetKeyMsgs :: AgentClient -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM () +enqueueRatchetKeyMsgs c (sq :| sqs) e2eEncryption = do + msgId <- enqueueRatchetKey c sq e2eEncryption + mapM_ (lift . enqueueSavedMessage c msgId) $ filter isActiveSndQ sqs -enqueueRatchetKey :: AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM AgentMsgId -enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do +enqueueRatchetKey :: AgentClient -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> AM AgentMsgId +enqueueRatchetKey c sq@SndQueue {connId} e2eEncryption = do aVRange <- asks $ smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange - lift $ submitPendingMsg c cData sq + lift $ submitPendingMsg c sq pure $ unId msgId where storeRatchetKey :: VersionSMPA -> AM InternalId @@ -3330,7 +3532,7 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do -- this message is e2e encrypted with queue key, not with double ratchet msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption = PQEncOff, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash, sndMsgPrepData_ = Nothing} liftIO $ createSndMsg db connId msgData - liftIO $ createSndMsgDelivery db connId sq internalId + liftIO $ createSndMsgDelivery db sq internalId pure internalId -- encoded AgentMessage -> encoded EncAgentMessage @@ -3361,11 +3563,11 @@ agentRatchetDecrypt' g db connId rc encAgentMsg = do liftIO $ updateRatchet db connId rc' skippedDiff liftEither $ bimap (SEAgentError . cryptoError) (,CR.rcRcvKEM rc') agentMsgBody_ -newSndQueue :: UserId -> ConnId -> Compatible SMPQueueInfo -> Maybe (C.AAuthKeyPair) -> AM' (NewSndQueue, C.PublicKeyX25519) -newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, queueMode, dhPublicKey = rcvE2ePubDhKey})) sndKeys_ = do +newSndQueue :: UserId -> ConnId -> Compatible SMPQueueInfo -> Maybe (C.APrivateAuthKey) -> AM' (NewSndQueue, C.PublicKeyX25519) +newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, queueMode, dhPublicKey = rcvE2ePubDhKey})) sndKey_ = do C.AuthAlg a <- asks $ sndAuthAlg . config g <- asks random - (sndPublicKey, sndPrivateKey) <- maybe (atomically $ C.generateAuthKeyPair a g) pure sndKeys_ + sndPrivateKey <- maybe (atomically $ C.generatePrivateAuthKey a g) pure sndKey_ (e2ePubKey, e2ePrivKey) <- atomically $ C.generateKeyPair g let sq = SndQueue @@ -3374,12 +3576,11 @@ newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAdd server = smpServer, sndId = senderId, queueMode, - sndPublicKey, sndPrivateKey, e2eDhSecret = C.dh' rcvE2ePubDhKey e2ePrivKey, e2ePubKey = Just e2ePubKey, -- setting status to Secured prevents SKEY when queue was already secured with LKEY - status = if isJust sndKeys_ then Secured else New, + status = if isJust sndKey_ then Secured else New, dbQueueId = DBNewEntity, primary = True, dbReplaceQueueId = Nothing, @@ -3387,3 +3588,12 @@ newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAdd smpClientVersion } pure (sq, e2ePubKey) + +$(pure []) + +instance FromJSON a => FromJSON (DatabaseDiff a) where + parseJSON = $(JQ.mkParseJSON defaultJSON ''DatabaseDiff) + +instance ToJSON a => ToJSON (DatabaseDiff a) where + toEncoding = $(JQ.mkToEncoding defaultJSON ''DatabaseDiff) + toJSON = $(JQ.mkToJSON defaultJSON ''DatabaseDiff) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 21c0436ee..217a1682a 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -48,9 +48,10 @@ module Simplex.Messaging.Agent.Client newRcvQueue, newRcvQueue_, subscribeQueues, + subscribeUserServerQueues, + processClientNotices, getQueueMessage, decryptSMPMessage, - addSubscription, failSubscription, addNewQueueSubscription, getSubscriptions, @@ -99,8 +100,10 @@ module Simplex.Messaging.Agent.Client logSecret, logSecret', removeSubscription, + removeSubscriptions, hasActiveSubscription, hasPendingSubscription, + hasRemovedSubscription, hasGetLock, releaseGetLock, activeClientSession, @@ -157,6 +160,8 @@ module Simplex.Messaging.Agent.Client withStoreBatch', unsafeWithStore, storeError, + notifySub, + notifySub', userServers, pickServer, getNextServer, @@ -200,15 +205,17 @@ import Data.Bifunctor (bimap, first, second) import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Composition ((.:), (.:.)) +import Data.Containers.ListUtils (nubOrd) import Data.Either (isRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (find, foldl', partition) +import Data.List (find, foldl') import Data.List.NonEmpty (NonEmpty (..), (<|)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe, mapMaybe) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Set (Set) import qualified Data.Set as S import Data.Text (Text) @@ -231,10 +238,12 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.AgentStore (getClientNotices, updateClientNotices) import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction) import qualified Simplex.Messaging.Agent.Store.DB as DB -import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues (getRcvQueues)) -import qualified Simplex.Messaging.Agent.TRcvQueues as RQ +import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.Agent.TSessionSubs (TSessionSubs) +import qualified Simplex.Messaging.Agent.TSessionSubs as SS import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -281,9 +290,10 @@ import Simplex.Messaging.Protocol senderCanSecure, ) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol.Types import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.Session -import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion, SessionId, THandleParams (sessionId, thVersion), TransportError (..), TransportPeer (..), sndAuthKeySMPVersion, shortLinksSMPVersion, newNtfCredsSMPVersion) @@ -310,8 +320,6 @@ type NtfClientVar = ClientVar NtfResponse type XFTPClientVar = ClientVar FileResponse -type SMPTransportSession = TransportSession SMP.BrokerMsg - type NtfTransportSession = TransportSession NtfResponse type XFTPTransportSession = TransportSession FileResponse @@ -333,12 +341,14 @@ data AgentClient = AgentClient xftpClients :: TMap XFTPTransportSession XFTPClientVar, useNetworkConfig :: TVar (NetworkConfig, NetworkConfig), -- (slow, fast) networks presetDomains :: [HostName], + presetServers :: [SMPServer], userNetworkInfo :: TVar UserNetworkInfo, userNetworkUpdated :: TVar (Maybe UTCTime), subscrConns :: TVar (Set ConnId), - activeSubs :: TRcvQueues (SessionId, RcvQueue), - pendingSubs :: TRcvQueues RcvQueue, - removedSubs :: TMap (UserId, SMPServer, SMP.RecipientId) SMPClientError, + currentSubs :: TSessionSubs, + removedSubs :: TMap (UserId, SMPServer) (TMap SMP.RecipientId SMPClientError), + clientNotices :: TMap (Maybe SMPServer) (Maybe SystemSeconds), + clientNoticesLock :: TMVar (), workerSeq :: TVar Int, smpDeliveryWorkers :: TMap SndQAddr (Worker, TMVar ()), asyncCmdWorkers :: TMap (ConnId, Maybe SMPServer) Worker, @@ -428,7 +438,7 @@ getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do newWorker :: AgentClient -> STM Worker newWorker c = do workerId <- stateTVar (workerSeq c) $ \next -> (next, next + 1) - doWork <- newTMVar () + doWork <- newTMVar () -- new worker is created with "some work to do" (indicated by () in TMVar) action <- newTMVar Nothing restarts <- newTVar $ RestartCount 0 0 pure Worker {workerId, doWork, action, restarts} @@ -484,8 +494,8 @@ data UserNetworkType = UNNone | UNCellular | UNWifi | UNEthernet | UNOther deriving (Eq, Show) -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. -newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Env -> IO AgentClient -newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomains} currentTs agentEnv = do +newAgentClient :: Int -> InitialAgentServers -> UTCTime -> Map (Maybe SMPServer) (Maybe SystemSeconds) -> Env -> IO AgentClient +newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomains, presetServers} currentTs notices agentEnv = do let cfg = config agentEnv qSize = tbqSize cfg proxySessTs <- newTVarIO =<< getCurrentTime @@ -504,9 +514,10 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomai userNetworkInfo <- newTVarIO $ UserNetworkInfo UNOther True userNetworkUpdated <- newTVarIO Nothing subscrConns <- newTVarIO S.empty - activeSubs <- RQ.empty - pendingSubs <- RQ.empty + currentSubs <- SS.emptyIO removedSubs <- TM.emptyIO + clientNotices <- newTVarIO notices + clientNoticesLock <- newTMVarIO () workerSeq <- newTVarIO 0 smpDeliveryWorkers <- TM.emptyIO asyncCmdWorkers <- TM.emptyIO @@ -540,12 +551,14 @@ newAgentClient clientId InitialAgentServers {smp, ntf, xftp, netCfg, presetDomai xftpClients, useNetworkConfig, presetDomains, + presetServers, userNetworkInfo, userNetworkUpdated, subscrConns, - activeSubs, - pendingSubs, + currentSubs, removedSubs, + clientNotices, + clientNoticesLock, workerSeq, smpDeliveryWorkers, asyncCmdWorkers, @@ -700,40 +713,44 @@ smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm liftError (protocolClientError SMP $ B.unpack $ strEncode srv) $ do ts <- readTVarIO proxySessTs smp <- ExceptT $ getProtocolClient g nm tSess cfg presetDomains (Just msgQ) ts $ smpClientDisconnected c tSess env v' prs + atomically $ SS.setSessionId tSess (sessionId $ thParams smp) $ currentSubs c pure SMPConnectedClient {connectedClient = smp, proxiedRelays = prs} smpClientDisconnected :: AgentClient -> SMPTransportSession -> Env -> SMPClientVar -> TMap SMPServer ProxiedRelayVar -> SMPClient -> IO () -smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess@(userId, srv, qId) env v prs client = do +smpClientDisconnected c@AgentClient {active, smpClients, smpProxiedRelays} tSess@(userId, srv, cId) env v prs client = do removeClientAndSubs >>= serverDown logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv where -- we make active subscriptions pending only if the client for tSess was current (in the map) and active, -- because we can have a race condition when a new current client could have already -- made subscriptions active, and the old client would be processing diconnection later. - removeClientAndSubs :: IO ([RcvQueue], [ConnId]) + removeClientAndSubs :: IO ([RcvQueueSub], [ConnId]) removeClientAndSubs = atomically $ do removeSessVar v tSess smpClients ifM (readTVar active) removeSubs (pure ([], [])) where sessId = sessionId $ thParams client removeSubs = do - (qs, cs) <- RQ.getDelSessQueues tSess sessId $ activeSubs c - RQ.batchAddQueues (pendingSubs c) qs + mode <- getSessionMode c + subs <- SS.setSubsPending mode tSess sessId $ currentSubs c + let qs = M.elems subs + cs = nubOrd $ map qConnId qs -- this removes proxied relays that this client created sessions to destSrvs <- M.keys <$> readTVar prs - forM_ destSrvs $ \destSrv -> TM.delete (userId, destSrv, qId) smpProxiedRelays + forM_ destSrvs $ \destSrv -> TM.delete (userId, destSrv, cId) smpProxiedRelays pure (qs, cs) - serverDown :: ([RcvQueue], [ConnId]) -> IO () + serverDown :: ([RcvQueueSub], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do - notifySub "" $ hostEvent' DISCONNECT client - unless (null conns) $ notifySub "" $ DOWN srv conns + notifySub c $ hostEvent' DISCONNECT client + unless (null conns) $ notifySub c $ DOWN srv conns unless (null qs) $ do - atomically $ mapM_ (releaseGetLock c) qs - runReaderT (resubscribeSMPSession c tSess) env - - notifySub :: forall e. AEntityI e => ConnId -> AEvent e -> IO () - notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, AEvt (sAEntity @e) cmd) + releaseGetLocksIO c qs + mode <- getSessionModeIO c + let resubscribe + | (mode == TSMEntity) == isJust cId = resubscribeSMPSession c tSess + | otherwise = void $ subscribeQueues c True qs + runReaderT resubscribe env resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' () resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do @@ -742,7 +759,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do where getWorkerVar ts = ifM - (not <$> RQ.hasSessQueues tSess (pendingSubs c)) + (not <$> SS.hasPendingSubs tSess (currentSubs c)) (pure Nothing) -- prevent race with cleanup and adding pending queues in another call (Just <$> getSessVar workerSeq tSess smpSubWorkers ts) newSubWorker v = do @@ -751,11 +768,11 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do runSubWorker = do ri <- asks $ reconnectInterval . config withRetryForeground ri isForeground (isNetworkOnline c) $ \_ loop -> do - pending <- liftIO $ RQ.getSessQueues tSess $ pendingSubs c - forM_ (L.nonEmpty pending) $ \qs -> do + pending <- atomically $ SS.getPendingSubs tSess $ currentSubs c + unless (M.null pending) $ do liftIO $ waitUntilForeground c liftIO $ waitForUserNetwork c - reconnectSMPClient c tSess qs + handleNotify $ resubscribeSessQueues c tSess $ M.elems pending loop isForeground = (ASForeground ==) <$> readTVar (agentState c) cleanup :: SessionVar (Async ()) -> STM () @@ -764,28 +781,16 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess = do -- Not waiting may result in terminated worker remaining in the map. whenM (isEmptyTMVar $ sessionVar v) retry removeSessVar v tSess smpSubWorkers - -reconnectSMPClient :: AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM' () -reconnectSMPClient c tSess@(_, srv, _) qs = handleNotify $ do - cs <- readTVarIO $ RQ.getConnections $ activeSubs c - (rs, sessId_) <- subscribeQueues c $ L.toList qs - let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs - conns = filter (`M.notMember` cs) okConns - unless (null conns) $ notifySub "" $ UP srv conns - let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs - mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs - forM_ (listToMaybe tempErrs) $ \(connId, e) -> do - when (null okConns && M.null cs && null finalErrs) . liftIO $ - forM_ sessId_ $ \sessId -> do - -- We only close the client session that was used to subscribe. - v_ <- atomically $ ifM (activeClientSession c tSess sessId) (TM.lookupDelete tSess $ smpClients c) (pure Nothing) - mapM_ (closeClient_ c) v_ - notifySub connId $ ERR e - where handleNotify :: AM' () -> AM' () - handleNotify = E.handleAny $ notifySub "" . ERR . INTERNAL . show - notifySub :: forall e. AEntityI e => ConnId -> AEvent e -> AM' () - notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, AEvt (sAEntity @e) cmd) + handleNotify = E.handleAny $ notifySub' c "" . ERR . INTERNAL . show + +notifySub' :: forall e m. (AEntityI e, MonadIO m) => AgentClient -> ConnId -> AEvent e -> m () +notifySub' c connId cmd = liftIO $ nonBlockingWriteTBQueue (subQ c) (B.empty, connId, AEvt (sAEntity @e) cmd) +{-# INLINE notifySub' #-} + +notifySub :: MonadIO m => AgentClient -> AEvent 'AENone -> m () +notifySub c = notifySub' c "" +{-# INLINE notifySub #-} getNtfServerClient :: AgentClient -> NetworkRequestMode -> NtfTransportSession -> AM NtfClient getNtfServerClient c@AgentClient {active, ntfClients, workerSeq, proxySessTs, presetDomains} nm tSess@(_, srv, _) = do @@ -928,8 +933,7 @@ closeAgentClient c = do atomically (swapTVar (smpSubWorkers c) M.empty) >>= mapM_ cancelReconnect clearWorkers smpDeliveryWorkers >>= mapM_ (cancelWorker . fst) clearWorkers asyncCmdWorkers >>= mapM_ cancelWorker - atomically . RQ.clear $ activeSubs c - atomically . RQ.clear $ pendingSubs c + atomically $ SS.clear $ currentSubs c clear subscrConns clear getMsgLocks where @@ -1070,7 +1074,7 @@ withLogClient c nm tSess entId cmdStr action = withLogClient_ c nm tSess entId c withSMPClient :: SMPQueueRec q => AgentClient -> NetworkRequestMode -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> AM a withSMPClient c nm q cmdStr action = do - tSess <- mkSMPTransportSession c q + tSess <- mkSMPTransportSessionIO c q withLogClient c nm tSess (unEntityId $ queueId q) cmdStr $ action . connectedClient sendOrProxySMPMessage :: AgentClient -> NetworkRequestMode -> UserId -> SMPServer -> ConnId -> ByteString -> Maybe SMP.SndPrivateAuthKey -> SMP.SenderId -> MsgFlags -> SMP.MsgBody -> AM (Maybe SMPServer) @@ -1246,7 +1250,7 @@ runSMPServerTest c@AgentClient {presetDomains} nm userId (ProtoServerWithAuth sr SMP.QIK {rcvId, sndId, queueMode} <- liftError (testErr TSCreateQueue) $ createSMPQueue smp nm Nothing rKeys dhKey auth SMSubscribe (QRMessaging Nothing) Nothing liftError (testErr TSSecureQueue) $ case queueMode of - Just QMMessaging -> secureSndSMPQueue smp nm spKey sndId sKey + Just QMMessaging -> secureSndSMPQueue smp nm spKey sndId _ -> secureSMPQueue smp nm rpKey rcvId sKey liftError (testErr TSDeleteQueue) $ deleteSMPQueue smp nm rpKey rcvId ok <- netTimeoutInt (tcpTimeout $ networkConfig cfg) nm `timeout` closeProtocolClient smp @@ -1335,14 +1339,18 @@ getXFTPWorkPath = do maybe getTemporaryDirectory pure workDir mkTransportSession :: MonadIO m => AgentClient -> UserId -> ProtoServer msg -> ByteString -> m (TransportSession msg) -mkTransportSession c userId srv sessEntId = mkTSession userId srv sessEntId <$> getSessionMode c +mkTransportSession c userId srv sessEntId = mkTSession userId srv sessEntId <$> getSessionModeIO c {-# INLINE mkTransportSession #-} mkTSession :: UserId -> ProtoServer msg -> ByteString -> TransportSessionMode -> TransportSession msg mkTSession userId srv sessEntId mode = (userId, srv, if mode == TSMEntity then Just sessEntId else Nothing) {-# INLINE mkTSession #-} -mkSMPTransportSession :: (SMPQueueRec q, MonadIO m) => AgentClient -> q -> m SMPTransportSession +mkSMPTransportSessionIO :: (SMPQueueRec q, MonadIO m) => AgentClient -> q -> m SMPTransportSession +mkSMPTransportSessionIO c q = mkSMPTSession q <$> getSessionModeIO c +{-# INLINE mkSMPTransportSessionIO #-} + +mkSMPTransportSession :: SMPQueueRec q => AgentClient -> q -> STM SMPTransportSession mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c {-# INLINE mkSMPTransportSession #-} @@ -1350,8 +1358,12 @@ mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSessi mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) {-# INLINE mkSMPTSession #-} -getSessionMode :: MonadIO m => AgentClient -> m TransportSessionMode -getSessionMode = fmap sessionMode . getNetworkConfig +getSessionModeIO :: MonadIO m => AgentClient -> m TransportSessionMode +getSessionModeIO = fmap (sessionMode . snd) . readTVarIO . useNetworkConfig +{-# INLINE getSessionModeIO #-} + +getSessionMode :: AgentClient -> STM TransportSessionMode +getSessionMode = fmap (sessionMode . snd) . readTVar . useNetworkConfig {-# INLINE getSessionMode #-} newRcvQueue :: AgentClient -> NetworkRequestMode -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SConnectionMode c -> Bool -> SubscriptionMode -> AM (NewRcvQueue, SMPQueueUri, SMPTransportSession, SessionId) @@ -1405,6 +1417,8 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl shortLink, clientService = ClientService DBNewEntity <$> serviceId, status = New, + enableNtfs, + clientNoticeId = Nothing, dbQueueId = DBNewEntity, primary = True, dbReplaceQueueId = Nothing, @@ -1455,17 +1469,39 @@ newRcvQueue_ c nm userId connId (ProtoServerWithAuth srv auth) vRange cqrd enabl newErr :: String -> AM (Maybe ShortLinkCreds) newErr = throwE . BROKER (B.unpack $ strEncode srv) . UNEXPECTED . ("Create queue: " <>) -processSubResult :: AgentClient -> SessionId -> RcvQueue -> Either SMPClientError (Maybe ServiceId) -> STM () -processSubResult c sessId rq@RcvQueue {userId, server, connId} = \case - Left e -> - unless (temporaryClientError e) $ do - incSMPServerStat c userId server connSubErrs - failSubscription c rq e - Right _serviceId -> -- TODO [certs rcv] store association with the service - ifM - (hasPendingSubscription c connId) - (incSMPServerStat c userId server connSubscribed >> addSubscription c sessId rq) - (incSMPServerStat c userId server connSubIgnored) +processSubResults :: AgentClient -> SMPTransportSession -> SessionId -> NonEmpty (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> STM [(RcvQueueSub, Maybe ClientNotice)] +processSubResults c tSess@(userId, srv, _) sessId rs = do + pendingSubs <- SS.getPendingSubs tSess $ currentSubs c + let (failed, subscribed, notices, ignored) = foldr (partitionResults pendingSubs) (M.empty, [], [], 0) rs + unless (M.null failed) $ do + incSMPServerStat' c userId srv connSubErrs $ M.size failed + failSubscriptions c tSess failed + unless (null subscribed) $ do + incSMPServerStat' c userId srv connSubscribed $ length subscribed + SS.batchAddActiveSubs tSess sessId subscribed $ currentSubs c + unless (ignored == 0) $ incSMPServerStat' c userId srv connSubIgnored ignored + pure notices + where + partitionResults :: + Map SMP.RecipientId RcvQueueSub -> + (RcvQueueSub, Either SMPClientError (Maybe ServiceId)) -> + (Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) -> + (Map SMP.RecipientId SMPClientError, [RcvQueueSub], [(RcvQueueSub, Maybe ClientNotice)], Int) + partitionResults pendingSubs (rq@RcvQueueSub {rcvId, clientNoticeId}, r) acc@(failed, subscribed, notices, ignored) = case r of + Left e -> case smpErrorClientNotice e of + Just notice_ -> (failed', subscribed, (rq, notice_) : notices, ignored) + where + notices' = if isJust notice_ || isJust clientNoticeId then (rq, notice_) : notices else notices + Nothing + | temporaryClientError e -> acc + | otherwise -> (failed', subscribed, notices, ignored) + where + failed' = M.insert rcvId e failed + Right _serviceId -- TODO [certs rcv] store association with the service + | rcvId `M.member` pendingSubs -> (failed, rq : subscribed, notices', ignored) + | otherwise -> (failed, subscribed, notices', ignored + 1) + where + notices' = if isJust clientNoticeId then (rq, Nothing) : notices else notices temporaryAgentError :: AgentErrorType -> Bool temporaryAgentError = \case @@ -1499,46 +1535,120 @@ serverHostError = \case SMP.TRANSPORT TEVersion -> True _ -> False --- | Subscribe to queues. The list of results can have a different order. -subscribeQueues :: AgentClient -> [RcvQueue] -> AM' ([(RcvQueue, Either AgentErrorType (Maybe ServiceId))], Maybe SessionId) -subscribeQueues c qs = do - (errs, qs') <- partitionEithers <$> mapM checkQueue qs - atomically $ do - modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs')) - RQ.batchAddQueues (pendingSubs c) qs' - env <- ask - -- only "checked" queues are subscribed - session <- newTVarIO Nothing - rs <- sendTSessionBatches "SUB" id (subscribeQueues_ env session) c NRMBackground qs' - (errs <> rs,) <$> readTVarIO session +-- | Batch by transport session and subscribe queues. The list of results can have a different order. +subscribeQueues :: AgentClient -> Bool -> [RcvQueueSub] -> AM' [(RcvQueueSub, Either AgentErrorType (Maybe ServiceId))] +subscribeQueues c withEvents qs = do + (errs, qs') <- checkQueues c qs + atomically $ modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs')) + qss <- batchQueues mkSMPTSession qs' <$> getSessionModeIO c + mapM_ (addPendingSubs c) qss + rs <- mapConcurrently (subscribeQueues_ c withEvents) qss + when withEvents $ forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (first qConnId) + pure $ map (second Left) errs <> concatMap L.toList rs + +addPendingSubs :: AgentClient -> (SMPTransportSession, NonEmpty RcvQueueSub) -> AM' () +addPendingSubs c (tSess, qs') = atomically $ SS.batchAddPendingSubs tSess (L.toList qs') $ currentSubs c + +subscribeQueues_ :: AgentClient -> Bool -> (SMPTransportSession, NonEmpty RcvQueueSub) -> AM' (BatchResponses RcvQueueSub AgentErrorType (Maybe ServiceId)) +subscribeQueues_ c withEvents qs'@(tSess@(_, srv, _), _) = do + (rs, active) <- subscribeSessQueues_ c withEvents qs' + if active + then when (hasTempErrors rs) resubscribe $> rs + else do + logWarn "subcription batch result for replaced SMP client, resubscribing" + -- we use BROKER NETWORK error here instead of the original error, so it becomes temporary. + resubscribe $> L.map (second $ Left . toNESubscribeError) rs + where + -- treating host errors as temporary here as well + hasTempErrors = any (either temporaryOrHostError (const False) . snd) + toNESubscribeError = BROKER (B.unpack $ strEncode srv) . NETWORK . NESubscribeError . show + resubscribe = resubscribeSMPSession c tSess + +subscribeUserServerQueues :: AgentClient -> UserId -> SMPServer -> [RcvQueueSub] -> AM' [(RcvQueueSub, Either AgentErrorType (Maybe ServiceId))] +subscribeUserServerQueues c userId srv qs = do + mode <- getSessionModeIO c + if mode == TSMEntity + then subscribeQueues c True qs + else do + let tSess = (userId, srv, Nothing) + (errs, qs_) <- checkQueues c qs + forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (first qConnId) + let errs' = map (second Left) errs + case L.nonEmpty qs_ of + Just qs' -> do + atomically $ modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId $ L.toList qs')) + addPendingSubs c (tSess, qs') + rs <- subscribeQueues_ c True (tSess, qs') + pure $ errs' <> L.toList rs + Nothing -> pure errs' + +-- only "checked" queues are subscribed +checkQueues :: AgentClient -> [RcvQueueSub] -> AM' ([(RcvQueueSub, AgentErrorType)], [RcvQueueSub]) +checkQueues c = fmap partitionEithers . mapM checkQueue where checkQueue rq = do prohibited <- liftIO $ hasGetLock c rq - pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED "subscribeQueues") else Right rq - subscribeQueues_ :: Env -> TVar (Maybe SessionId) -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError (Maybe ServiceId)) - subscribeQueues_ env session smp qs' = do - let (userId, srv, _) = transportSession' smp + pure $ if prohibited then Left (rq, CMD PROHIBITED "checkQueues") else Right rq + +-- This function expects that all queues belong to one transport session, +-- and that they are already added to pending subscriptions. +resubscribeSessQueues :: AgentClient -> SMPTransportSession -> [RcvQueueSub] -> AM' () +resubscribeSessQueues c tSess qs = do + (errs, qs_) <- checkQueues c qs + forM_ (L.nonEmpty qs_) $ \qs' -> void $ subscribeSessQueues_ c True (tSess, qs') + forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (first qConnId) + +subscribeSessQueues_ :: AgentClient -> Bool -> (SMPTransportSession, NonEmpty RcvQueueSub) -> AM' (BatchResponses RcvQueueSub AgentErrorType (Maybe ServiceId), Bool) +subscribeSessQueues_ c withEvents qs = sendClientBatch_ "SUB" False subscribe_ c NRMBackground qs + where + subscribe_ :: SMPClient -> NonEmpty RcvQueueSub -> IO (BatchResponses RcvQueueSub SMPClientError (Maybe ServiceId), Bool) + subscribe_ smp qs' = do + let (userId, srv, _) = tSess atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs' rs <- sendBatch (\smp' _ -> subscribeSMPQueues smp') smp NRMBackground qs' - active <- - atomically $ - ifM + cs_ <- + if withEvents + then Just . S.fromList . map qConnId . M.elems <$> atomically (SS.getActiveSubs tSess $ currentSubs c) + else pure Nothing + active <- E.uninterruptibleMask_ $ do + (active, notices) <- atomically $ do + r@(_, notices) <- ifM (activeClientSession c tSess sessId) - (writeTVar session (Just sessId) >> processSubResults rs $> True) - (incSMPServerStat' c userId srv connSubIgnored (length rs) $> False) - if active - then when (hasTempErrors rs) resubscribe $> rs - else do - logWarn "subcription batch result for replaced SMP client, resubscribing" - -- TODO we probably use PCENetworkError here instead of the original error, so it becomes temporary. - resubscribe $> L.map (second $ Left . PCENetworkError . NESubscribeError . show) rs + ((True,) <$> processSubResults c tSess sessId rs) + ((False, []) <$ incSMPServerStat' c userId srv connSubIgnored (length rs)) + unless (null notices) $ takeTMVar $ clientNoticesLock c + pure r + unless (null notices) $ void $ + (processClientNotices c tSess notices `runReaderT` agentEnv c) + `E.finally` atomically (putTMVar (clientNoticesLock c) ()) + pure active + forM_ cs_ $ \cs -> do + let (errs, okConns) = partitionEithers $ map (\(RcvQueueSub {connId}, r) -> bimap (connId,) (const connId) r) $ L.toList rs + conns = filter (`S.notMember` cs) okConns + unless (null conns) $ notifySub c $ UP srv conns + forM_ (L.nonEmpty errs) $ \errs' -> do + let noFinalErrs = all (temporaryClientError . snd) errs' + addr = B.unpack $ strEncode srv + notifySub c $ ERRS $ L.map (second $ protocolClientError SMP addr) errs' + when (null okConns && S.null cs && noFinalErrs && active) $ liftIO $ do + -- We only close the client session that was used to subscribe. + v_ <- atomically $ ifM (activeClientSession c tSess sessId) (TM.lookupDelete tSess $ smpClients c) (pure Nothing) + mapM_ (closeClient_ c) v_ + pure (rs, active) where tSess = transportSession' smp sessId = sessionId $ thParams smp - hasTempErrors = any (either temporaryClientError (const False) . snd) - processSubResults :: NonEmpty (RcvQueue, Either SMPClientError (Maybe ServiceId)) -> STM () - processSubResults = mapM_ $ uncurry $ processSubResult c sessId - resubscribe = resubscribeSMPSession c tSess `runReaderT` env + +processClientNotices :: AgentClient -> SMPTransportSession -> [(RcvQueueSub, Maybe ClientNotice)] -> AM' () +processClientNotices c@AgentClient {presetServers} tSess notices = do + now <- liftIO getSystemSeconds + tryAllErrors' (withStore' c $ \db -> (,) <$> updateClientNotices db tSess now notices <*> getClientNotices db presetServers) >>= \case + Right (noticeIds, clntNotices) -> atomically $ do + SS.updateClientNotices tSess noticeIds $ currentSubs c + writeTVar (clientNotices c) clntNotices + Left e -> do + logError $ "processClientNotices error: " <> tshow e + notifySub' c "" $ ERR e activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c) @@ -1551,72 +1661,94 @@ type BatchResponses q e r = NonEmpty (q, Either e r) -- Please note: this function does not preserve order of results to be the same as the order of arguments, -- it includes arguments in the results instead. -sendTSessionBatches :: forall q r. ByteString -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r)) -> AgentClient -> NetworkRequestMode -> [q] -> AM' [(q, Either AgentErrorType r)] -sendTSessionBatches statCmd toRQ action c nm qs = - concatMap L.toList <$> (mapConcurrently sendClientBatch =<< batchQueues) - where - batchQueues :: AM' [(SMPTransportSession, NonEmpty q)] - batchQueues = do - mode <- getSessionMode c - pure . M.assocs $ foldr (batch mode) M.empty qs - where - batch mode q m = - let tSess = mkSMPTSession (toRQ q) mode - in M.alter (Just . maybe [q] (q <|)) tSess m - sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r) - sendClientBatch (tSess@(_, srv, _), qs') = - tryAllErrors' (getSMPServerClient c nm tSess) >>= \case - Left e -> pure $ L.map (,Left e) qs' - Right (SMPConnectedClient smp _) -> liftIO $ do - logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd - L.map agentError <$> action smp qs' - where - agentError = second . first $ protocolClientError SMP $ clientServer smp +sendTSessionBatches :: forall q r. ByteString -> (q -> TransportSessionMode -> SMPTransportSession) -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r)) -> AgentClient -> NetworkRequestMode -> [q] -> AM' [(q, Either AgentErrorType r)] +sendTSessionBatches statCmd mkSession action c nm qs = do + qs' <- batchQueues mkSession qs <$> getSessionModeIO c + concatMap L.toList <$> mapConcurrently (sendClientBatch statCmd action c nm) qs' -sendBatch :: (SMPClient -> NetworkRequestMode -> NonEmpty (SMP.RecipientId, SMP.RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError a))) -> SMPClient -> NetworkRequestMode -> NonEmpty RcvQueue -> IO (BatchResponses RcvQueue SMPClientError a) +batchQueues :: (q -> TransportSessionMode -> SMPTransportSession) -> [q] -> TransportSessionMode -> [(SMPTransportSession, NonEmpty q)] +batchQueues mkSession qs mode = M.assocs $ foldr batch M.empty qs + where + batch q m = + let tSess = mkSession q mode + in M.alter (Just . maybe [q] (q <|)) tSess m + +sendClientBatch :: ByteString -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r)) -> AgentClient -> NetworkRequestMode -> (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r) +sendClientBatch statCmd action = fmap fst .:. sendClientBatch_ statCmd () (fmap (,()) .: action) +{-# INLINE sendClientBatch #-} + +sendClientBatch_ :: ByteString -> res -> (SMPClient -> NonEmpty q -> IO (BatchResponses q SMPClientError r, res)) -> AgentClient -> NetworkRequestMode -> (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r, res) +sendClientBatch_ statCmd errRes action c nm (tSess@(_, srv, _), qs') = + tryAllErrors' (getSMPServerClient c nm tSess) >>= \case + Left e -> pure (L.map (,Left e) qs', errRes) + Right (SMPConnectedClient smp _) -> liftIO $ do + logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd + first (L.map agentError) <$> action smp qs' + where + agentError = second . first $ protocolClientError SMP $ B.unpack (strEncode srv) + +sendBatch :: SomeRcvQueue q => (SMPClient -> NetworkRequestMode -> NonEmpty (SMP.RecipientId, SMP.RcvPrivateAuthKey) -> IO (NonEmpty (Either SMPClientError a))) -> SMPClient -> NetworkRequestMode -> NonEmpty q -> IO (BatchResponses q SMPClientError a) sendBatch smpCmdFunc smp nm qs = L.zip qs <$> smpCmdFunc smp nm (L.map queueCreds qs) where - queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvId, rcvPrivateKey) + queueCreds q = (queueId q, rcvAuthKey q) -addSubscription :: AgentClient -> SessionId -> RcvQueue -> STM () -addSubscription c sessId rq@RcvQueue {connId} = do - modifyTVar' (subscrConns c) $ S.insert connId - RQ.addQueue (sessId, rq) $ activeSubs c - RQ.deleteQueue rq $ pendingSubs c +failSubscription :: SomeRcvQueue q => AgentClient -> SMPTransportSession -> q -> SMPClientError -> STM () +failSubscription c tSess rq e = do + let rId = queueId rq + TM.insert rId e =<< getRemovedSubs c (qUserId rq, qServer rq) + SS.deletePendingSub tSess rId $ currentSubs c -failSubscription :: AgentClient -> RcvQueue -> SMPClientError -> STM () -failSubscription c rq e = do - RQ.deleteQueue rq (pendingSubs c) - TM.insert (RQ.qKey rq) e (removedSubs c) +failSubscriptions :: AgentClient -> SMPTransportSession -> Map SMP.RecipientId SMPClientError -> STM () +failSubscriptions c tSess@(uId, srv, _) qs = do + TM.union qs =<< getRemovedSubs c (uId, srv) + SS.batchDeletePendingSubs tSess (M.keysSet qs) $ currentSubs c -addPendingSubscription :: AgentClient -> RcvQueue -> STM () -addPendingSubscription c rq@RcvQueue {connId} = do - modifyTVar' (subscrConns c) $ S.insert connId - RQ.addQueue rq $ pendingSubs c +getRemovedSubs :: AgentClient -> (UserId, SMPServer) -> STM (TMap SMP.RecipientId SMPClientError) +getRemovedSubs AgentClient {removedSubs} k = TM.lookup k removedSubs >>= maybe new pure + where + new = do + s <- newTVar M.empty + TM.insert k s removedSubs + pure s addNewQueueSubscription :: AgentClient -> RcvQueue -> SMPTransportSession -> SessionId -> AM' () -addNewQueueSubscription c rq tSess sessId = do - same <- - atomically $ - ifM - (activeClientSession c tSess sessId) - (True <$ addSubscription c sessId rq) - (False <$ addPendingSubscription c rq) +addNewQueueSubscription c rq' tSess sessId = do + let rq = rcvQueueSub rq' + same <- atomically $ do + modifyTVar' (subscrConns c) $ S.insert $ qConnId rq + active <- activeClientSession c tSess sessId + if active + then SS.addActiveSub tSess sessId rq $ currentSubs c + else SS.addPendingSub tSess rq $ currentSubs c + pure active unless same $ resubscribeSMPSession c tSess -hasActiveSubscription :: AgentClient -> ConnId -> STM Bool -hasActiveSubscription c connId = RQ.hasConn connId $ activeSubs c +hasActiveSubscription :: SomeRcvQueue q => AgentClient -> q -> STM Bool +hasActiveSubscription c rq = do + tSess <- mkSMPTransportSession c rq + SS.hasActiveSub tSess (queueId rq) $ currentSubs c {-# INLINE hasActiveSubscription #-} -hasPendingSubscription :: AgentClient -> ConnId -> STM Bool -hasPendingSubscription c connId = RQ.hasConn connId $ pendingSubs c +hasPendingSubscription :: SomeRcvQueue q => AgentClient -> q -> STM Bool +hasPendingSubscription c rq = do + tSess <- mkSMPTransportSession c rq + SS.hasPendingSub tSess (queueId rq) $ currentSubs c {-# INLINE hasPendingSubscription #-} -removeSubscription :: AgentClient -> ConnId -> STM () -removeSubscription c connId = do +hasRemovedSubscription :: SomeRcvQueue q => AgentClient -> q -> STM (Maybe SMPClientError) +hasRemovedSubscription c rq = do + TM.lookup (qUserId rq, qServer rq) (removedSubs c) $>>= TM.lookup (queueId rq) + +removeSubscription :: SomeRcvQueue q => AgentClient -> SMPTransportSession -> ConnId -> q -> STM () +removeSubscription c tSess connId rq = do modifyTVar' (subscrConns c) $ S.delete connId - RQ.deleteConn connId $ activeSubs c - RQ.deleteConn connId $ pendingSubs c + SS.deleteSub tSess (queueId rq) $ currentSubs c + +removeSubscriptions :: SomeRcvQueue q => AgentClient -> [ConnId] -> [q] -> STM () +removeSubscriptions c connIds qs = do + unless (null connIds) $ modifyTVar' (subscrConns c) (`S.difference` (S.fromList connIds)) + qss <- batchQueues mkSMPTSession qs <$> getSessionMode c + forM_ qss $ \(tSess, qs') -> SS.batchDeleteSubs tSess (L.toList qs') $ currentSubs c getSubscriptions :: AgentClient -> IO (Set ConnId) getSubscriptions = readTVarIO . subscrConns @@ -1644,8 +1776,8 @@ logSecret' = B64.encode . B.take 3 {-# INLINE logSecret' #-} sendConfirmation :: AgentClient -> NetworkRequestMode -> SndQueue -> ByteString -> AM (Maybe SMPServer) -sendConfirmation c nm sq@SndQueue {userId, server, connId, sndId, queueMode, sndPublicKey, sndPrivateKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = do - let (privHdr, spKey) = if senderCanSecure queueMode then (SMP.PHEmpty, Just sndPrivateKey) else (SMP.PHConfirmation sndPublicKey, Nothing) +sendConfirmation c nm sq@SndQueue {userId, server, connId, sndId, queueMode, sndPrivateKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = do + let (privHdr, spKey) = if senderCanSecure queueMode then (SMP.PHEmpty, Just sndPrivateKey) else (SMP.PHConfirmation (C.toPublic sndPrivateKey), Nothing) clientMsg = SMP.ClientMessage privHdr agentConfirmation msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg sendOrProxySMPMessage c nm userId server connId "" spKey sndId (MsgFlags {notification = True}) msg @@ -1691,12 +1823,12 @@ secureQueue c nm rq@RcvQueue {rcvId, rcvPrivateKey} senderKey = secureSMPQueue smp nm rcvPrivateKey rcvId senderKey secureSndQueue :: AgentClient -> NetworkRequestMode -> SndQueue -> AM () -secureSndQueue c nm SndQueue {userId, connId, server, sndId, sndPrivateKey, sndPublicKey} = +secureSndQueue c nm SndQueue {userId, connId, server, sndId, sndPrivateKey} = void $ sendOrProxySMPCommand c nm userId server connId "SKEY " sndId secureViaProxy secureDirectly where -- TODO track statistics - secureViaProxy smp proxySess = proxySecureSndSMPQueue smp nm proxySess sndPrivateKey sndId sndPublicKey - secureDirectly smp = secureSndSMPQueue smp nm sndPrivateKey sndId sndPublicKey + secureViaProxy smp proxySess = proxySecureSndSMPQueue smp nm proxySess sndPrivateKey sndId + secureDirectly smp = secureSndSMPQueue smp nm sndPrivateKey sndId addQueueLink :: AgentClient -> NetworkRequestMode -> RcvQueue -> SMP.LinkId -> QueueLinkData -> AM () addQueueLink c nm rq@RcvQueue {rcvId, rcvPrivateKey} lnkId d = @@ -1707,11 +1839,11 @@ deleteQueueLink c nm rq@RcvQueue {rcvId, rcvPrivateKey} = withSMPClient c nm rq "LDEL" $ \smp -> deleteSMPQueueLink smp nm rcvPrivateKey rcvId secureGetQueueLink :: AgentClient -> NetworkRequestMode -> UserId -> InvShortLink -> AM (SMP.SenderId, QueueLinkData) -secureGetQueueLink c nm userId InvShortLink {server, linkId, sndPrivateKey, sndPublicKey} = +secureGetQueueLink c nm userId InvShortLink {server, linkId, sndPrivateKey} = snd <$> sendOrProxySMPCommand c nm userId server (unEntityId linkId) "LKEY " linkId secureGetViaProxy secureGetDirectly where - secureGetViaProxy smp proxySess = proxySecureGetSMPQueueLink smp nm proxySess sndPrivateKey linkId sndPublicKey - secureGetDirectly smp = secureGetSMPQueueLink smp nm sndPrivateKey linkId sndPublicKey + secureGetViaProxy smp proxySess = proxySecureGetSMPQueueLink smp nm proxySess sndPrivateKey linkId + secureGetDirectly smp = secureGetSMPQueueLink smp nm sndPrivateKey linkId getQueueLink :: AgentClient -> NetworkRequestMode -> UserId -> SMPServer -> SMP.LinkId -> AM (SMP.SenderId, QueueLinkData) getQueueLink c nm userId server lnkId = @@ -1733,7 +1865,7 @@ data EnableQueueNtfReq = EnableQueueNtfReq } enableQueuesNtfs :: AgentClient -> [EnableQueueNtfReq] -> AM' [(EnableQueueNtfReq, Either AgentErrorType (SMP.NotifierId, SMP.RcvNtfPublicDhKey))] -enableQueuesNtfs c = sendTSessionBatches "NKEY" eqnrRq enableQueues_ c NRMBackground +enableQueuesNtfs c = sendTSessionBatches "NKEY" (mkSMPTSession . eqnrRq) enableQueues_ c NRMBackground where enableQueues_ :: SMPClient -> NonEmpty EnableQueueNtfReq -> IO (NonEmpty (EnableQueueNtfReq, Either (ProtocolClientError ErrorType) (SMP.NotifierId, RcvNtfPublicDhKey))) enableQueues_ smp qs' = L.zip qs' <$> enableSMPQueuesNtfs smp (L.map queueCreds qs') @@ -1752,7 +1884,7 @@ disableQueueNotifications c rq@RcvQueue {rcvId, rcvPrivateKey} = type DisableQueueNtfReq = (NtfSubscription, RcvQueue) disableQueuesNtfs :: AgentClient -> [DisableQueueNtfReq] -> AM' [(DisableQueueNtfReq, Either AgentErrorType ())] -disableQueuesNtfs c = sendTSessionBatches "NDEL" snd disableQueues_ c NRMBackground +disableQueuesNtfs c = sendTSessionBatches "NDEL" (mkSMPTSession . snd) disableQueues_ c NRMBackground where disableQueues_ :: SMPClient -> NonEmpty DisableQueueNtfReq -> IO (NonEmpty (DisableQueueNtfReq, Either (ProtocolClientError ErrorType) ())) disableQueues_ smp qs' = L.zip qs' <$> disableSMPQueuesNtfs smp (L.map queueCreds qs') @@ -1764,16 +1896,23 @@ sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = withSMPClient c NRMBackground rq ("ACK:" <> logSecret' msgId) $ \smp -> ackSMPMessage smp rcvPrivateKey rcvId msgId -hasGetLock :: AgentClient -> RcvQueue -> IO Bool -hasGetLock c RcvQueue {server, rcvId} = - TM.memberIO (server, rcvId) $ getMsgLocks c +hasGetLock :: SomeRcvQueue q => AgentClient -> q -> IO Bool +hasGetLock c rq = + TM.memberIO (qServer rq, queueId rq) $ getMsgLocks c {-# INLINE hasGetLock #-} -releaseGetLock :: AgentClient -> RcvQueue -> STM () -releaseGetLock c RcvQueue {server, rcvId} = - TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ()) +releaseGetLock :: SomeRcvQueue q => AgentClient -> q -> STM () +releaseGetLock c rq = + TM.lookup (qServer rq, queueId rq) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ()) {-# INLINE releaseGetLock #-} +releaseGetLocksIO :: SomeRcvQueue q => AgentClient -> [q] -> IO () +releaseGetLocksIO c rqs = do + locks <- readTVarIO $ getMsgLocks c + forM_ rqs $ \rq -> + forM_ (M.lookup ((qServer rq, queueId rq)) locks) $ \lock -> + atomically $ tryPutTMVar lock () + suspendQueue :: AgentClient -> NetworkRequestMode -> RcvQueue -> AM () suspendQueue c nm rq@RcvQueue {rcvId, rcvPrivateKey} = withSMPClient c nm rq "OFF" $ \smp -> @@ -1785,7 +1924,7 @@ deleteQueue c nm rq@RcvQueue {rcvId, rcvPrivateKey} = do deleteSMPQueue smp nm rcvPrivateKey rcvId deleteQueues :: AgentClient -> NetworkRequestMode -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())] -deleteQueues c nm = sendTSessionBatches "DEL" id deleteQueues_ c nm +deleteQueues c nm = sendTSessionBatches "DEL" mkSMPTSession deleteQueues_ c nm where deleteQueues_ smp rqs = do let (userId, srv, _) = transportSession' smp @@ -1876,7 +2015,7 @@ withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr L.map agentError <$> action ntf ntfPrivKey subs where - agentError = first $ protocolClientError NTF $ clientServer ntf + agentError = first $ protocolClientError NTF $ B.unpack (strEncode ntfServer) agentNtfDeleteSubscription :: AgentClient -> NtfSubscriptionId -> NtfToken -> AM () agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} = @@ -2000,9 +2139,7 @@ withWorkItems c doWork getWork action = do forM_ criticalErr $ \err -> do notifyErr (CRITICAL False) err when (all isWorkItemError errs) noWork - unless (null errs) $ - atomically $ - writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS $ map (\e -> ("", INTERNAL $ show e)) errs) + forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (\e -> ("", INTERNAL $ show e)) Left e | isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e | otherwise -> notifyErr INTERNAL e @@ -2313,15 +2450,16 @@ data ServerSessions = ServerSessions getAgentSubsTotal :: AgentClient -> [UserId] -> IO (SMPServerSubs, Bool) getAgentSubsTotal c userIds = do - ssActive <- getSubsCount activeSubs - ssPending <- getSubsCount pendingSubs + (ssActive, ssPending) <- SS.foldSessionSubs addSub (0, 0) $ currentSubs c sess <- hasSession . M.toList =<< readTVarIO (smpClients c) pure (SMPServerSubs {ssActive, ssPending}, sess) where - getSubsCount :: (AgentClient -> TRcvQueues q) -> IO Int - getSubsCount subs = M.foldrWithKey' addSub 0 <$> readTVarIO (getRcvQueues $ subs c) - addSub :: (UserId, SMPServer, SMP.RecipientId) -> q -> Int -> Int - addSub (userId, _, _) _ cnt = if userId `elem` userIds then cnt + 1 else cnt + addSub :: (Int, Int) -> (SMPTransportSession, SS.SessSubs) -> IO (Int, Int) + addSub acc@(!ssActive, !ssPending) ((userId, _, _), s) + | userId `elem` userIds = do + (active, pending) <- SS.mapSubs M.size s + pure (ssActive + active, ssPending + pending) + | otherwise = pure acc hasSession :: [(SMPTransportSession, SMPClientVar)] -> IO Bool hasSession = \case [] -> pure False @@ -2358,13 +2496,12 @@ getAgentServersSummary c@AgentClient {smpServersStats, xftpServersStats, ntfServ ntfServersSessions } where - getServerSubs = do - subs <- M.foldrWithKey' (addSub incActive) M.empty <$> readTVarIO (getRcvQueues $ activeSubs c) - M.foldrWithKey' (addSub incPending) subs <$> readTVarIO (getRcvQueues $ pendingSubs c) + getServerSubs = SS.foldSessionSubs addSub M.empty $ currentSubs c where - addSub f (userId, srv, _) _ = M.alter (Just . f . fromMaybe SMPServerSubs {ssActive = 0, ssPending = 0}) (userId, srv) - incActive ss = ss {ssActive = ssActive ss + 1} - incPending ss = ss {ssPending = ssPending ss + 1} + addSub subs ((userId, srv, _), s) = do + (active, pending) <- SS.mapSubs M.size s + let add ss = ss {ssActive = ssActive ss + active, ssPending = ssPending ss + pending} + pure $ M.alter (Just . add . fromMaybe (SMPServerSubs 0 0)) (userId, srv) subs Env {xftpAgent = XFTPAgent {xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}} = agentEnv getXFTPWorkerSrvs workers = foldM addSrv [] . M.toList =<< readTVarIO workers where @@ -2396,14 +2533,21 @@ data SubscriptionsInfo = SubscriptionsInfo getAgentSubscriptions :: AgentClient -> IO SubscriptionsInfo getAgentSubscriptions c = do - activeSubscriptions <- getSubs activeSubs - pendingSubscriptions <- getSubs pendingSubs - removedSubscriptions <- getRemovedSubs + (activeSubscriptions, pendingSubscriptions) <- SS.foldSessionSubs addSubs ([], []) $ currentSubs c + removedSubscriptions <- getRemoved pure $ SubscriptionsInfo {activeSubscriptions, pendingSubscriptions, removedSubscriptions} where - getSubs :: (AgentClient -> TRcvQueues q) -> IO [SubInfo] - getSubs sel = map (`subInfo` Nothing) . M.keys <$> readTVarIO (getRcvQueues $ sel c) - getRemovedSubs = map (uncurry subInfo . second Just) . M.assocs <$> readTVarIO (removedSubs c) + addSubs :: ([SubInfo], [SubInfo]) -> (SMPTransportSession, SS.SessSubs) -> IO ([SubInfo], [SubInfo]) + addSubs (active, pending) ((userId, srv, _), s) = do + (active', pending') <- SS.mapSubs (map (\rId -> subInfo (userId, srv, rId) Nothing) . M.keys) s + pure (active' ++ active, pending' ++ pending) + getRemoved :: IO [SubInfo] + getRemoved = foldM addSubInfo [] . M.assocs =<< readTVarIO (removedSubs c) + where + addSubInfo :: [SubInfo] -> ((UserId, SMPServer), TMap SMP.RecipientId SMPClientError) -> IO [SubInfo] + addSubInfo ss ((uId, srv), errs) = do + ss' <- map (\(rId, e) -> subInfo (uId, srv, rId) (Just e)) . M.assocs <$> readTVarIO errs + pure $ ss' ++ ss subInfo :: (UserId, SMPServer, SMP.RecipientId) -> Maybe SMPClientError -> SubInfo subInfo (uId, srv, rId) err = SubInfo {userId = uId, server = enc srv, rcvId = enc rId, subError = show <$> err} enc :: StrEncoding a => a -> Text diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index e15ffa48c..57bc11e3c 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -90,7 +90,8 @@ data InitialAgentServers = InitialAgentServers ntf :: [NtfServer], xftp :: Map UserId (NonEmpty (ServerCfg 'PXFTP)), netCfg :: NetworkConfig, - presetDomains :: [HostName] + presetDomains :: [HostName], + presetServers :: [SMPServer] } data ServerCfg p = ServerCfg @@ -166,6 +167,7 @@ data AgentConfig = AgentConfig ntfBatchSize :: Int, ntfSubFirstCheckInterval :: NominalDiffTime, ntfSubCheckInterval :: NominalDiffTime, + maxPendingSubscriptions :: Int, caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, @@ -237,6 +239,7 @@ defaultAgentConfig = ntfBatchSize = 150, ntfSubFirstCheckInterval = nominalDay, ntfSubCheckInterval = 3 * nominalDay, + maxPendingSubscriptions = 35000, -- CA certificate private key is not needed for initialization -- ! we do not generate these caCertificateFile = "/etc/opt/simplex-agent/ca.crt", diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 85fa45c49..fe852ac64 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -513,7 +513,7 @@ notifyInternalError' AgentClient {subQ} internalErrStr = atomically $ writeTBQue {-# INLINE notifyInternalError' #-} notifyErrs :: MonadIO m => AgentClient -> [(ConnId, AgentErrorType)] -> m () -notifyErrs AgentClient {subQ} connErrs = unless (null connErrs) $ atomically $ writeTBQueue subQ ("", "", AEvt SAENone $ ERRS connErrs) +notifyErrs c = mapM_ (notifySub c . ERRS) . L.nonEmpty {-# INLINE notifyErrs #-} getNtfToken :: AM' (Maybe NtfToken) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index d4d302df7..05ebc1b27 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -70,6 +70,7 @@ module Simplex.Messaging.Agent.Protocol MsgMeta (..), RcvQueueInfo (..), SndQueueInfo (..), + SubscriptionStatus (..), ConnectionStats (..), SwitchPhase (..), RcvSwitchStatus (..), @@ -111,6 +112,8 @@ module Simplex.Messaging.Agent.Protocol ServiceScheme, FixedLinkData (..), ConnLinkData (..), + UserConnLinkData (..), + UserContactData (..), UserLinkData (..), OwnerAuth (..), OwnerId, @@ -167,6 +170,7 @@ module Simplex.Messaging.Agent.Protocol updateSMPServerHosts, shortenShortLink, restoreShortLink, + isPresetServer, linkUserData, linkUserData', ) @@ -405,7 +409,7 @@ data AEvent (e :: AEntity) where OK :: AEvent AEConn JOINED :: SndQueueSecured -> Maybe ClientServiceId -> AEvent AEConn ERR :: AgentErrorType -> AEvent AEConn - ERRS :: [(ConnId, AgentErrorType)] -> AEvent AENone + ERRS :: NonEmpty (ConnId, AgentErrorType) -> AEvent AENone SUSPENDED :: AEvent AENone RFPROG :: Int64 -> Int64 -> AEvent AERcvFile RFDONE :: FilePath -> AEvent AERcvFile @@ -643,23 +647,34 @@ instance FromJSON RatchetSyncState where data RcvQueueInfo = RcvQueueInfo { rcvServer :: SMPServer, + status :: QueueStatus, rcvSwitchStatus :: Maybe RcvSwitchStatus, - canAbortSwitch :: Bool + canAbortSwitch :: Bool, + subStatus :: SubscriptionStatus } deriving (Eq, Show) data SndQueueInfo = SndQueueInfo { sndServer :: SMPServer, + status :: QueueStatus, sndSwitchStatus :: Maybe SndSwitchStatus } deriving (Eq, Show) +data SubscriptionStatus + = SSActive + | SSPending + | SSRemoved {subError :: String} + | SSNoSub + deriving (Eq, Ord, Show) + data ConnectionStats = ConnectionStats { connAgentVersion :: VersionSMPA, rcvQueuesInfo :: [RcvQueueInfo], sndQueuesInfo :: [SndQueueInfo], ratchetSyncState :: RatchetSyncState, - ratchetSyncSupported :: Bool + ratchetSyncSupported :: Bool, + subStatus :: Maybe SubscriptionStatus } deriving (Eq, Show) @@ -1612,15 +1627,16 @@ shortenShortLink presetSrvs = \case CSLInvitation sch srv lnkId linkKey -> CSLInvitation sch (shortServer srv) lnkId linkKey CSLContact sch ct srv linkKey -> CSLContact sch ct (shortServer srv) linkKey where - shortServer srv@(SMPServer hs@(h :| _) p kh) = - if isPresetServer then SMPServerOnlyHost h else srv - where - isPresetServer = case findPresetServer srv presetSrvs of - Just (SMPServer hs' p' kh') -> - all (`elem` hs') hs - && (p == p' || (null p' && (p == "443" || p == "5223"))) - && kh == kh' - Nothing -> False + shortServer srv@(SMPServer (h :| _) _ _) = + if isPresetServer srv presetSrvs then SMPServerOnlyHost h else srv + +isPresetServer :: Foldable t => SMPServer -> t SMPServer -> Bool +isPresetServer srv@(SMPServer hs p kh) presetSrvs = case findPresetServer srv presetSrvs of + Just (SMPServer hs' p' kh') -> + all (`elem` hs') hs + && (p == p' || (null p' && (p == "443" || p == "5223"))) + && kh == kh' + Nothing -> False -- explicit bidirectional is used for ghc 8.10.7 compatibility, [h]/[] patterns are not reversible. pattern SMPServerOnlyHost :: TransportHost -> SMPServer @@ -1638,7 +1654,7 @@ restoreShortLink presetSrvs = \case s@(SMPServerOnlyHost _) -> fromMaybe s $ findPresetServer s presetSrvs s -> s -findPresetServer :: SMPServer -> NonEmpty SMPServer -> Maybe SMPServer +findPresetServer :: Foldable t => SMPServer -> t SMPServer -> Maybe SMPServer findPresetServer ProtocolServer {host = h :| _} = find (\ProtocolServer {host = h' :| _} -> h == h') {-# INLINE findPresetServer #-} @@ -1676,25 +1692,30 @@ data FixedLinkData c = FixedLinkData data ConnLinkData c where InvitationLinkData :: VersionRangeSMPA -> UserLinkData -> ConnLinkData 'CMInvitation - ContactLinkData :: - { agentVRange :: VersionRangeSMPA, - -- direct connection via connReq in fixed data is allowed. - direct :: Bool, - -- additional owner keys to sign changes of mutable data. - owners :: [OwnerAuth], - -- alternative addresses of chat relays that receive requests for this contact address. - relays :: [ConnShortLink 'CMContact], - userData :: UserLinkData - } -> ConnLinkData 'CMContact + ContactLinkData :: VersionRangeSMPA -> UserContactData -> ConnLinkData 'CMContact + +data UserContactData = UserContactData + { -- direct connection via connReq in fixed data is allowed. + direct :: Bool, + -- additional owner keys to sign changes of mutable data. + owners :: [OwnerAuth], + -- alternative addresses of chat relays that receive requests for this contact address. + relays :: [ConnShortLink 'CMContact], + userData :: UserLinkData + } newtype UserLinkData = UserLinkData ByteString data AConnLinkData = forall m. ConnectionModeI m => ACLD (SConnectionMode m) (ConnLinkData m) +data UserConnLinkData c where + UserInvLinkData :: UserLinkData -> UserConnLinkData 'CMInvitation + UserContactLinkData :: UserContactData -> UserConnLinkData 'CMContact + linkUserData :: ConnLinkData c -> UserLinkData linkUserData = \case InvitationLinkData _ d -> d - ContactLinkData {userData} -> userData + ContactLinkData _ UserContactData {userData} -> userData {-# INLINE linkUserData #-} linkUserData' :: ConnLinkData c -> ByteString @@ -1735,8 +1756,8 @@ instance ConnectionModeI c => Encoding (FixedLinkData c) where instance ConnectionModeI c => Encoding (ConnLinkData c) where smpEncode = \case InvitationLinkData vr userData -> smpEncode (CMInvitation, vr, userData) - ContactLinkData {agentVRange, direct, owners, relays, userData} -> - B.concat [smpEncode (CMContact, agentVRange, direct), smpEncodeList owners, smpEncodeList relays, smpEncode userData] + ContactLinkData vr UserContactData {direct, owners, relays, userData} -> + B.concat [smpEncode (CMContact, vr, direct), smpEncodeList owners, smpEncodeList relays, smpEncode userData] smpP = (\(ACLD _ d) -> checkConnMode d) <$?> smpP {-# INLINE smpP #-} @@ -1749,11 +1770,12 @@ instance Encoding AConnLinkData where (vr, userData) <- smpP <* A.takeByteString -- ignoring tail for forward compatibility with the future link data encoding pure $ ACLD SCMInvitation $ InvitationLinkData vr userData CMContact -> do - (agentVRange, direct) <- smpP + (vr, direct) <- smpP owners <- smpListP relays <- smpListP userData <- smpP <* A.takeByteString -- ignoring tail for forward compatibility with the future link data encoding - pure $ ACLD SCMContact ContactLinkData {agentVRange, direct, owners, relays, userData} + let cd = UserContactData {direct, owners, relays, userData} + pure $ ACLD SCMContact $ ContactLinkData vr cd instance Encoding UserLinkData where smpEncode (UserLinkData s) = if B.length s <= 254 then smpEncode s else smpEncode ('\255', Large s) @@ -1859,6 +1881,8 @@ data AgentErrorType BROKER {brokerAddress :: String, brokerErr :: BrokerErrorType} | -- | errors of other agents AGENT {agentErr :: SMPAgentError} + | -- | client notice + NOTICE {server :: Text, preset :: Bool, expiresAt :: Maybe UTCTime} | -- | agent implementation or dependency errors INTERNAL {internalErr :: String} | -- | critical agent errors that should be shown to the user, optionally with restart button @@ -2000,6 +2024,10 @@ serializeCommand = \case serializeBinary :: ByteString -> ByteString serializeBinary body = bshow (B.length body) <> "\n" <> body +$(J.deriveJSON (enumJSON fstToLower) ''QueueStatus) + +$(J.deriveJSON (sumTypeJSON $ dropPrefix "SS") ''SubscriptionStatus) + $(J.deriveJSON defaultJSON ''RcvQueueInfo) $(J.deriveJSON defaultJSON ''SndQueueInfo) diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6b866cee6..c054cb267 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -49,7 +49,6 @@ import Simplex.Messaging.Protocol RcvNtfDhSecret, RcvPrivateAuthKey, SndPrivateAuthKey, - SndPublicAuthKey, VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP @@ -89,6 +88,10 @@ data StoredRcvQueue (q :: DBStored) = RcvQueue clientService :: Maybe (StoredClientService q), -- | queue status status :: QueueStatus, + -- | to enable notifications for this queue - this field is duplicated from ConnData + enableNtfs :: Bool, + -- | client notice + clientNoticeId :: Maybe NoticeId, -- | database queue ID (within connection) dbQueueId :: DBEntityId' q, -- | True for a primary or a next primary queue of the connection (next if dbReplaceQueueId is set) @@ -104,6 +107,25 @@ data StoredRcvQueue (q :: DBStored) = RcvQueue } deriving (Show) +data RcvQueueSub = RcvQueueSub + { userId :: UserId, + connId :: ConnId, + server :: SMPServer, + rcvId :: SMP.RecipientId, + rcvPrivateKey :: RcvPrivateAuthKey, + status :: QueueStatus, + enableNtfs :: Bool, + clientNoticeId :: Maybe NoticeId, + dbQueueId :: Int64, + primary :: Bool, + dbReplaceQueueId :: Maybe Int64 + } + deriving (Show) + +rcvQueueSub :: RcvQueue -> RcvQueueSub +rcvQueueSub RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, status, enableNtfs, clientNoticeId, dbQueueId = DBEntityId dbQueueId, primary, dbReplaceQueueId} = + RcvQueueSub {userId, connId, server, rcvId, rcvPrivateKey, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId} + data ShortLinkCreds = ShortLinkCreds { shortLinkId :: SMP.LinkId, shortLinkKey :: LinkKey, @@ -116,10 +138,6 @@ clientServiceId :: RcvQueue -> Maybe ClientServiceId clientServiceId = fmap dbServiceId . clientService {-# INLINE clientServiceId #-} -rcvQueueInfo :: RcvQueue -> RcvQueueInfo -rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} = - RcvQueueInfo {rcvServer = server, rcvSwitchStatus = rcvSwchStatus, canAbortSwitch = canAbortRcvSwitch rq} - rcvSMPQueueAddress :: RcvQueue -> SMPQueueAddress rcvSMPQueueAddress RcvQueue {server, sndId, e2ePrivKey, queueMode} = SMPQueueAddress server sndId (C.publicKey e2ePrivKey) queueMode @@ -155,7 +173,6 @@ data InvShortLink = InvShortLink linkId :: SMP.LinkId, linkKey :: LinkKey, sndPrivateKey :: SndPrivateAuthKey, -- stored to allow retries - sndPublicKey :: SndPublicAuthKey, sndId :: Maybe SMP.SenderId } deriving (Show) @@ -173,9 +190,7 @@ data StoredSndQueue (q :: DBStored) = SndQueue sndId :: SMP.SenderId, -- | sender can secure the queue queueMode :: Maybe QueueMode, - -- | key pair used by the sender to authorize transmissions - -- TODO combine keys to key pair so that types match - sndPublicKey :: SndPublicAuthKey, + -- | sender key used to authorize transmissions sndPrivateKey :: SndPrivateAuthKey, -- | DH public key used to negotiate per-queue e2e encryption e2ePubKey :: Maybe C.PublicKeyX25519, @@ -195,10 +210,6 @@ data StoredSndQueue (q :: DBStored) = SndQueue } deriving (Show) -sndQueueInfo :: SndQueue -> SndQueueInfo -sndQueueInfo SndQueue {server, sndSwchStatus} = - SndQueueInfo {sndServer = server, sndSwitchStatus = sndSwchStatus} - instance SMPQueue RcvQueue where qServer RcvQueue {server} = server {-# INLINE qServer #-} @@ -211,6 +222,12 @@ instance SMPQueue NewRcvQueue where queueId RcvQueue {rcvId} = rcvId {-# INLINE queueId #-} +instance SMPQueue RcvQueueSub where + qServer RcvQueueSub {server} = server + {-# INLINE qServer #-} + queueId RcvQueueSub {rcvId} = rcvId + {-# INLINE queueId #-} + instance SMPQueue SndQueue where qServer SndQueue {server} = server {-# INLINE qServer #-} @@ -250,6 +267,7 @@ class SMPQueue q => SMPQueueRec q where qUserId :: q -> UserId qConnId :: q -> ConnId dbQId :: q -> Int64 + qPrimary :: q -> Bool dbReplaceQId :: q -> Maybe Int64 instance SMPQueueRec RcvQueue where @@ -259,9 +277,23 @@ instance SMPQueueRec RcvQueue where {-# INLINE qConnId #-} dbQId RcvQueue {dbQueueId = DBEntityId qId} = qId {-# INLINE dbQId #-} + qPrimary RcvQueue {primary} = primary + {-# INLINE qPrimary #-} dbReplaceQId RcvQueue {dbReplaceQueueId} = dbReplaceQueueId {-# INLINE dbReplaceQId #-} +instance SMPQueueRec RcvQueueSub where + qUserId RcvQueueSub {userId} = userId + {-# INLINE qUserId #-} + qConnId RcvQueueSub {connId} = connId + {-# INLINE qConnId #-} + dbQId RcvQueueSub {dbQueueId} = dbQueueId + {-# INLINE dbQId #-} + qPrimary RcvQueueSub {primary} = primary + {-# INLINE qPrimary #-} + dbReplaceQId RcvQueueSub {dbReplaceQueueId} = dbReplaceQueueId + {-# INLINE dbReplaceQId #-} + instance SMPQueueRec SndQueue where qUserId SndQueue {userId} = userId {-# INLINE qUserId #-} @@ -269,9 +301,22 @@ instance SMPQueueRec SndQueue where {-# INLINE qConnId #-} dbQId SndQueue {dbQueueId = DBEntityId qId} = qId {-# INLINE dbQId #-} + qPrimary SndQueue {primary} = primary + {-# INLINE qPrimary #-} dbReplaceQId SndQueue {dbReplaceQueueId} = dbReplaceQueueId {-# INLINE dbReplaceQId #-} +class SMPQueueRec q => SomeRcvQueue q where + rcvAuthKey :: q -> RcvPrivateAuthKey + +instance SomeRcvQueue RcvQueue where + rcvAuthKey RcvQueue {rcvPrivateKey} = rcvPrivateKey + {-# INLINE rcvAuthKey #-} + +instance SomeRcvQueue RcvQueueSub where + rcvAuthKey RcvQueueSub {rcvPrivateKey} = rcvPrivateKey + {-# INLINE rcvAuthKey #-} + -- * Connection types -- | Type of a connection. @@ -287,16 +332,18 @@ data ConnType = CNew | CRcv | CSnd | CDuplex | CContact deriving (Eq, Show) -- -- - DuplexConnection is a connection that has both receive and send queues set up, -- typically created by upgrading a receive or a send connection with a missing queue. -data Connection (d :: ConnType) where - NewConnection :: ConnData -> Connection CNew - RcvConnection :: ConnData -> RcvQueue -> Connection CRcv - SndConnection :: ConnData -> SndQueue -> Connection CSnd - DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex - ContactConnection :: ConnData -> RcvQueue -> Connection CContact +data Connection' (d :: ConnType) rq sq where + NewConnection :: ConnData -> Connection' CNew rq sq + RcvConnection :: ConnData -> rq -> Connection' CRcv rq sq + SndConnection :: ConnData -> sq -> Connection' CSnd rq sq + DuplexConnection :: ConnData -> NonEmpty rq -> NonEmpty sq -> Connection' CDuplex rq sq + ContactConnection :: ConnData -> rq -> Connection' CContact rq sq -deriving instance Show (Connection d) +deriving instance (Show rq, Show sq) => Show (Connection' d rq sq) -toConnData :: Connection d -> ConnData +type Connection d = Connection' d RcvQueue SndQueue + +toConnData :: Connection' d rq sq -> ConnData toConnData = \case NewConnection cData -> cData RcvConnection cData _ -> cData @@ -304,7 +351,7 @@ toConnData = \case DuplexConnection cData _ _ -> cData ContactConnection cData _ -> cData -updateConnection :: ConnData -> Connection d -> Connection d +updateConnection :: ConnData -> Connection' d rq sq -> Connection' d rq sq updateConnection cData = \case NewConnection _ -> NewConnection cData RcvConnection _ rq -> RcvConnection cData rq @@ -337,9 +384,13 @@ instance TestEquality SConnType where -- | Connection of an unknown type. -- Used to refer to an arbitrary connection when retrieving from store. -data SomeConn = forall d. SomeConn (SConnType d) (Connection d) +data SomeConn' rq sq = forall d. SomeConn (SConnType d) (Connection' d rq sq) -deriving instance Show SomeConn +deriving instance (Show rq, Show sq) => Show (SomeConn' rq sq) + +type SomeConn = SomeConn' RcvQueue SndQueue + +type SomeConnSub = SomeConn' RcvQueueSub SndQueue data ConnData = ConnData { connId :: ConnId, @@ -353,6 +404,8 @@ data ConnData = ConnData } deriving (Eq, Show) +type NoticeId = Int64 + -- this function should be mirrored in the clients ratchetSyncAllowed :: ConnData -> Bool ratchetSyncAllowed ConnData {ratchetSyncState, connAgentVersion} = diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index 350d3bfe7..ef66eca38 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -28,6 +28,7 @@ module Simplex.Messaging.Agent.Store.AgentStore ( -- * Users createUserRecord, + getUserIds, deleteUserRecord, setUserDeleted, deleteUserWithoutConns, @@ -39,9 +40,16 @@ module Simplex.Messaging.Agent.Store.AgentStore updateNewConnRcv, updateNewConnSnd, createSndConn, + getClientNotices, + updateClientNotices, + getSubscriptionServers, + getUserServerRcvQueueSubs, + unsetQueuesToSubscribe, + getConnIds, getConn, getDeletedConn, getConns, + getConnSubs, getDeletedConns, getConnsData, setConnDeleted, @@ -109,6 +117,7 @@ module Simplex.Messaging.Agent.Store.AgentStore updateSndMsgRcpt, getPendingQueueMsg, getConnectionsForDelivery, + getAllSndQueuesForDelivery, updatePendingMsgRIState, deletePendingMsgs, getExpiredSndMessages, @@ -136,6 +145,7 @@ module Simplex.Messaging.Agent.Store.AgentStore -- Async commands createCommand, getPendingCommandServers, + getAllPendingCommandConns, getPendingServerCommand, updateCommandServer, deleteCommand, @@ -256,6 +266,7 @@ import Data.Int (Int64) import Data.List (foldl', sortBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Ord (Down (..)) @@ -275,6 +286,8 @@ import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.Common import qualified Simplex.Messaging.Agent.Store.DB as DB import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..), FromField (..), ToField (..), blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.Client (SMPTransportSession) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) import Simplex.Messaging.Crypto.Ratchet (PQEncryption (..), PQSupport (..), RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) @@ -286,7 +299,8 @@ import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.Protocol.Types +import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version.Internal @@ -294,7 +308,6 @@ import qualified UnliftIO.Exception as E import UnliftIO.STM #if defined(dbPostgres) import Data.List (sortOn) -import Data.Map.Strict (Map) import Database.PostgreSQL.Simple (In (..), Only (..), Query, SqlError, (:.) (..)) import Database.PostgreSQL.Simple.Errors (constraintViolation) import Database.PostgreSQL.Simple.SqlQQ (sql) @@ -324,6 +337,10 @@ createUserRecord db = do DB.execute_ db "INSERT INTO users DEFAULT VALUES" insertedRowId db +getUserIds :: DB.Connection -> IO [UserId] +getUserIds db = + map fromOnly <$> DB.query_ db "SELECT user_id FROM users WHERE deleted = 0" + checkUser :: DB.Connection -> UserId -> IO (Either StoreError ()) checkUser db userId = firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $ @@ -387,15 +404,15 @@ createNewConn db gVar cData cMode = do fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) -- TODO [certs rcv] store clientServiceId from NewRcvQueue -updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -updateNewConnRcv db connId rq = +updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) +updateNewConnRcv db connId rq subMode = getConn db connId $>>= \case (SomeConn _ NewConnection {}) -> updateConn (SomeConn _ RcvConnection {}) -> updateConn -- to allow retries (SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnRcv" $ connType c where updateConn :: IO (Either StoreError RcvQueue) - updateConn = Right <$> addConnRcvQueue_ db connId rq + updateConn = Right <$> addConnRcvQueue_ db connId rq subMode updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) updateNewConnSnd db connId sq = @@ -477,25 +494,25 @@ upgradeRcvConnToDuplex db connId sq = (SomeConn c _) -> pure . Left . SEBadConnType "upgradeRcvConnToDuplex" $ connType c -- TODO [certs rcv] store clientServiceId from NewRcvQueue -upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -upgradeSndConnToDuplex db connId rq = +upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) +upgradeSndConnToDuplex db connId rq subMode = getConn db connId >>= \case - Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq + Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode Right (SomeConn c _) -> pure . Left . SEBadConnType "upgradeSndConnToDuplex" $ connType c _ -> pure $ Left SEConnNotFound -- TODO [certs rcv] store clientServiceId from NewRcvQueue -addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -addConnRcvQueue db connId rq = +addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) +addConnRcvQueue db connId rq subMode = getConn db connId >>= \case - Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq + Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnRcvQueue" $ connType c _ -> pure $ Left SEConnNotFound -addConnRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> IO RcvQueue -addConnRcvQueue_ db connId rq@RcvQueue {server} = do +addConnRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO RcvQueue +addConnRcvQueue_ db connId rq@RcvQueue {server} subMode = do serverKeyHash_ <- createServer_ db server - insertRcvQueue_ db connId rq serverKeyHash_ + insertRcvQueue_ db connId rq subMode serverKeyHash_ addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) addConnSndQueue db connId sq = @@ -783,13 +800,12 @@ getInvShortLink db server linkId = (host server, port server, linkId) where toInvShortLink :: (LinkKey, C.APrivateAuthKey, Maybe SenderId) -> InvShortLink - toInvShortLink (linkKey, sndPrivateKey@(C.APrivateAuthKey a pk), sndId) = - let sndPublicKey = C.APublicAuthKey a $ C.publicKey pk - in InvShortLink {server, linkId, linkKey, sndPrivateKey, sndPublicKey, sndId} + toInvShortLink (linkKey, sndPrivateKey, sndId) = + InvShortLink {server, linkId, linkKey, sndPrivateKey, sndId} -getInvShortLinkKeys :: DB.Connection -> SMPServer -> SenderId -> IO (Maybe (LinkId, C.AAuthKeyPair)) +getInvShortLinkKeys :: DB.Connection -> SMPServer -> SenderId -> IO (Maybe (LinkId, C.APrivateAuthKey)) getInvShortLinkKeys db srv sndId = - maybeFirstRow toSndKeys $ + maybeFirstRow id $ DB.query db [sql| @@ -798,9 +814,6 @@ getInvShortLinkKeys db srv sndId = WHERE host = ? AND port = ? AND snd_id = ? |] (host srv, port srv, sndId) - where - toSndKeys :: (LinkId, C.APrivateAuthKey) -> (LinkId, C.AAuthKeyPair) - toSndKeys (linkId, privKey@(C.APrivateAuthKey a pk)) = (linkId, (C.APublicAuthKey a $ C.publicKey pk, privKey)) deleteInvShortLink :: DB.Connection -> SMPServer -> LinkId -> IO () deleteInvShortLink db srv lnkId = @@ -887,8 +900,8 @@ createSndMsg db connId sndMsgData@SndMsgData {internalSndId, internalHash} = do insertSndMsgDetails_ db connId sndMsgData updateSndMsgHash db connId internalSndId internalHash -createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO () -createSndMsgDelivery db connId SndQueue {dbQueueId} msgId = +createSndMsgDelivery :: DB.Connection -> SndQueue -> InternalId -> IO () +createSndMsgDelivery db SndQueue {connId, dbQueueId} msgId = DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId) getSndMsgViaRcpt :: DB.Connection -> ConnId -> InternalSndId -> IO (Either StoreError SndMsg) @@ -920,6 +933,15 @@ getConnectionsForDelivery :: DB.Connection -> IO [ConnId] getConnectionsForDelivery db = map fromOnly <$> DB.query_ db "SELECT DISTINCT conn_id FROM snd_message_deliveries WHERE failed = 0" +getAllSndQueuesForDelivery :: DB.Connection -> IO [SndQueue] +getAllSndQueuesForDelivery db = map toSndQueue <$> DB.query_ db (sndQueueQuery <> " " <> delivery) + where + delivery = [sql| + JOIN (SELECT DISTINCT conn_id, snd_queue_id FROM snd_message_deliveries WHERE failed = 0) d + ON d.conn_id = q.conn_id AND d.snd_queue_id = q.snd_queue_id + WHERE c.deleted = 0 + |] + getPendingQueueMsg :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError (Maybe (Maybe RcvQueue, PendingMsgData))) getPendingQueueMsg db connId SndQueue {dbQueueId} = getWorkItem "message" getMsgId getMsgData markMsgFailed @@ -1322,6 +1344,21 @@ getPendingCommandServers db connIds = smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash conns = S.fromList connIds +getAllPendingCommandConns :: DB.Connection -> IO [(ConnId, Maybe SMPServer)] +getAllPendingCommandConns db = + map toResult + <$> DB.query_ + db + [sql| + SELECT DISTINCT c.conn_id, c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) + FROM commands c + JOIN connections cs ON c.conn_id = cs.conn_id + LEFT JOIN servers s ON s.host = c.host AND s.port = c.port + WHERE cs.deleted = 0 + |] + where + toResult (connId, host, port, keyHash) = (connId, SMPServer <$> host <*> port <*> keyHash) + getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed where @@ -1958,8 +1995,8 @@ upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do -- * createRcvConn helpers -insertRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> Maybe C.KeyHash -> IO RcvQueue -insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do +insertRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> Maybe C.KeyHash -> IO RcvQueue +insertRcvQueue_ db connId' rq@RcvQueue {..} subMode serverKeyHash_ = do -- to preserve ID if the queue already exists. -- possibly, it can be done in one query. currQId_ <- maybeFirstRow fromOnly $ DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? AND host = ? AND port = ? AND snd_id = ?" (connId', host server, port server, sndId) @@ -1969,19 +2006,20 @@ insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do [sql| INSERT INTO rcv_queues ( host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, - snd_id, queue_mode, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash, + snd_id, queue_mode, status, to_subscribe, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash, link_id, link_key, link_priv_sig_key, link_enc_fixed_data, ntf_public_key, ntf_private_key, ntf_id, rcv_ntf_dh_secret - ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); |] ( (host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) - :. (sndId, queueMode, status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_) + :. (sndId, queueMode, status, BI toSubscribe, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_) :. (shortLinkId <$> shortLink, shortLinkKey <$> shortLink, linkPrivSigKey <$> shortLink, linkEncFixedData <$> shortLink) :. ntfCredsFields ) -- TODO [certs rcv] save client service pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId, clientService = Nothing} where + toSubscribe = subMode == SMOnlyCreate ntfCredsFields = case clientNtfCreds of Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) @@ -1999,16 +2037,15 @@ insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do db [sql| INSERT INTO snd_queues - (host, port, snd_id, queue_mode, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, + (host, port, snd_id, queue_mode, conn_id, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?) ON CONFLICT (host, port, snd_id) DO UPDATE SET host=EXCLUDED.host, port=EXCLUDED.port, snd_id=EXCLUDED.snd_id, queue_mode=EXCLUDED.queue_mode, conn_id=EXCLUDED.conn_id, - snd_public_key=EXCLUDED.snd_public_key, snd_private_key=EXCLUDED.snd_private_key, e2e_pub_key=EXCLUDED.e2e_pub_key, e2e_dh_secret=EXCLUDED.e2e_dh_secret, @@ -2019,7 +2056,7 @@ insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do smp_client_version=EXCLUDED.smp_client_version, server_key_hash=EXCLUDED.server_key_hash |] - ((host server, port server, sndId, queueMode, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) + ((host server, port server, sndId, queueMode, connId', sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId} @@ -2027,8 +2064,108 @@ newQueueId_ :: [Only Int64] -> DBEntityId newQueueId_ [] = DBEntityId 1 newQueueId_ (Only maxId : _) = DBEntityId (maxId + 1) +-- * subscribe all connections + +getClientNotices :: DB.Connection -> [SMPServer] -> IO (Map (Maybe SMPServer) (Maybe SystemSeconds)) +getClientNotices db presetSrvs = + M.map expiresAt . foldl' addNotice M.empty + <$> DB.query_ + db + [sql| + SELECT n.host, n.port, n.entity_id, COALESCE(n.server_key_hash, s.key_hash), n.created_at, n.notice_ttl + FROM client_notices n + JOIN servers s ON n.host = s.host AND n.port = s.port + WHERE n.protocol = 'smp' + |] + where + expiresAt (createdAt, ttl) = RoundedSystemTime . (createdAt +) <$> ttl + addNotice :: + Map (Maybe SMPServer) (Int64, Maybe Int64) -> + (NonEmpty TransportHost, ServiceName, RecipientId, C.KeyHash, Int64, Maybe Int64) -> + Map (Maybe SMPServer) (Int64, Maybe Int64) + addNotice m (host, port, _, keyHash, createdAt', ttl') = + let srv = SMPServer host port keyHash + srvKey + | isPresetServer srv presetSrvs = Nothing + | otherwise = Just srv + in M.alter (Just . addNoticeHost) srvKey m + where + -- sum of ttls starting from the latest createdAt + addNoticeHost :: Maybe (Int64, Maybe Int64) -> (Int64, Maybe Int64) + addNoticeHost = \case + Just (createdAt, ttl) -> (max createdAt createdAt', (+) <$> ttl <*> ttl') + Nothing -> (createdAt', ttl') + +updateClientNotices :: DB.Connection -> SMPTransportSession -> SystemSeconds -> [(RcvQueueSub, Maybe ClientNotice)] -> IO [(RecipientId, Maybe NoticeId)] +updateClientNotices db (_, srv, _) now = + mapM $ \(rq, notice_) -> maybe (deleteNotice rq) (upsertNotice rq) notice_ + where + deleteNotice RcvQueueSub {rcvId, clientNoticeId} = do + mapM_ (DB.execute db "DELETE FROM client_notices WHERE client_notice_id = ?" . Only) clientNoticeId + pure (rcvId, Nothing) + upsertNotice RcvQueueSub {rcvId, server} ClientNotice {ttl} = + getServerKeyHash_ db server >>= \case + Left _ -> pure (rcvId, Nothing) + Right keyHash_ -> do + noticeId_ <- + maybeFirstRow fromOnly $ + DB.query + db + [sql| + INSERT INTO client_notices(protocol, host, port, entity_id, server_key_hash, notice_ttl, created_at, updated_at) + VALUES ('smp',?,?,?,?,?,?,?) + ON CONFLICT (protocol, host, port, entity_id) + DO UPDATE SET + server_key_hash = EXCLUDED.server_key_hash, + notice_ttl = EXCLUDED.notice_ttl, + updated_at = EXCLUDED.updated_at + RETURNING client_notice_id + |] + (host srv, port srv, rcvId, keyHash_, ttl, now, now) + forM_ noticeId_ $ \noticeId -> do + DB.execute + db + "UPDATE rcv_queues SET client_notice_id = ? WHERE host = ? AND port = ?AND rcv_id = ?" + (noticeId, host srv, port srv, rcvId) + pure (rcvId, noticeId_) + +getSubscriptionServers :: DB.Connection -> Bool -> IO [(UserId, SMPServer)] +getSubscriptionServers db onlyNeeded = + map toUserServer <$> DB.query_ db (select <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0") + where + select = + [sql| + SELECT DISTINCT c.user_id, q.host, q.port, COALESCE(q.server_key_hash, s.key_hash) + FROM rcv_queues q + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id + |] + toSubscribe + | onlyNeeded = " WHERE q.to_subscribe = 1 AND " + | otherwise = " WHERE " + toUserServer :: (UserId, NonEmpty TransportHost, ServiceName, C.KeyHash) -> (UserId, SMPServer) + toUserServer (userId, host, port, keyHash) = (userId, SMPServer host port keyHash) + +getUserServerRcvQueueSubs :: DB.Connection -> UserId -> SMPServer -> Bool -> IO [RcvQueueSub] +getUserServerRcvQueueSubs db userId srv onlyNeeded = + map toRcvQueueSub + <$> DB.query + db + (rcvQueueSubQuery <> toSubscribe <> " c.deleted = 0 AND q.deleted = 0 AND c.user_id = ? AND q.host = ? AND q.port = ?") + (userId, host srv, port srv) + where + toSubscribe + | onlyNeeded = " WHERE q.to_subscribe = 1 AND " + | otherwise = " WHERE " + +unsetQueuesToSubscribe :: DB.Connection -> IO () +unsetQueuesToSubscribe db = DB.execute_ db "UPDATE rcv_queues SET to_subscribe = 0 WHERE to_subscribe = 1" + -- * getConn helpers +getConnIds :: DB.Connection -> IO [ConnId] +getConnIds db = map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted = 0" + getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getConn = getAnyConn False {-# INLINE getConn #-} @@ -2038,11 +2175,18 @@ getDeletedConn = getAnyConn True {-# INLINE getDeletedConn #-} getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getAnyConn deleted' db connId = +getAnyConn = getAnyConn_ getRcvQueuesByConnId_ getSndQueuesByConnId_ +{-# INLINE getAnyConn #-} + +getAnyConn_ :: + (DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) -> + (DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) -> + (Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq))) +getAnyConn_ getRQs getSQs deleted' db connId = getConnData deleted' db connId >>= \case Just (cData, cMode) -> do - rQ <- getRcvQueuesByConnId_ db connId - sQ <- getSndQueuesByConnId_ db connId + rQ <- getRQs db connId + sQ <- getSQs db connId pure $ case (rQ, sQ, cMode) of (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) @@ -2053,36 +2197,34 @@ getAnyConn deleted' db connId = Nothing -> pure $ Left SEConnNotFound getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getConns = getAnyConns_ False +getConns = getAnyConns False {-# INLINE getConns #-} getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getDeletedConns = getAnyConns_ True +getDeletedConns = getAnyConns True {-# INLINE getDeletedConns #-} #if defined(dbPostgres) -getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getAnyConns_ deleted' db connIds = do +getAnyConns :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn)] +getAnyConns = getAnyConns_ getRcvQueuesByConnIds_ getSndQueuesByConnIds_ +{-# INLINE getAnyConns #-} + +getConnSubs :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConnSub] +getConnSubs = getAnyConns_ getRcvQueueSubsByConnIds_ getSndQueuesByConnIds_ False +{-# INLINE getConnSubs #-} + +getAnyConns_ :: + forall rq sq. + (DB.Connection -> [ConnId] -> IO (Map ConnId (NonEmpty rq))) -> + (DB.Connection -> [ConnId] -> IO (Map ConnId (NonEmpty sq))) -> + (Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn' rq sq)]) +getAnyConns_ getRQs getSQs deleted' db connIds = do cs <- getConnsData_ deleted' db connIds let connIds' = M.keys cs - rQs :: Map ConnId (NonEmpty RcvQueue) <- getRcvQueuesByConnIds_ connIds' - sQs :: Map ConnId (NonEmpty SndQueue) <- getSndQueuesByConnIds_ connIds' + rQs :: Map ConnId (NonEmpty rq) <- getRQs db connIds' + sQs :: Map ConnId (NonEmpty sq) <- getSQs db connIds' pure $ map (result cs rQs sQs) connIds where - getRcvQueuesByConnIds_ connIds' = - toQueueMap primaryFirst toRcvQueue - <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds')) - where - primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = - compare (Down p) (Down p') <> compare i i' - getSndQueuesByConnIds_ connIds' = - toQueueMap primaryFirst toSndQueue - <$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds')) - where - primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = - compare (Down p) (Down p') <> compare i i' - toQueueMap primaryFst toQueue = - M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue result cs rQs sQs connId = case M.lookup connId cs of Just (cData, cMode) -> case (M.lookup connId rQs, M.lookup connId sQs, cMode) of (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) @@ -2093,6 +2235,22 @@ getAnyConns_ deleted' db connIds = do _ -> Left SEConnNotFound Nothing -> Left SEConnNotFound +getRcvQueuesByConnIds_ :: DB.Connection -> [ConnId] -> IO (Map ConnId (NonEmpty RcvQueue)) +getRcvQueuesByConnIds_ db connIds' = + toQueueMap toRcvQueue <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds')) + +getSndQueuesByConnIds_ :: DB.Connection -> [ConnId] -> IO (Map ConnId (NonEmpty SndQueue)) +getSndQueuesByConnIds_ db connIds' = + toQueueMap toSndQueue <$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds')) + +getRcvQueueSubsByConnIds_ :: DB.Connection -> [ConnId] -> IO (Map ConnId (NonEmpty RcvQueueSub)) +getRcvQueueSubsByConnIds_ db connIds' = + toQueueMap toRcvQueueSub <$> DB.query db (rcvQueueSubQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds')) + +toQueueMap :: SMPQueueRec q => (a -> q) -> [a] -> Map ConnId (NonEmpty q) +toQueueMap toQueue = + M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFirst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue + getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] getConnsData db connIds = do cs <- getConnsData_ False db connIds @@ -2112,8 +2270,19 @@ getConnsData_ deleted' db connIds = (In connIds, BI deleted') #else -getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db +getAnyConns :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getAnyConns = getAnyConns_ getRcvQueuesByConnId_ getSndQueuesByConnId_ +{-# INLINE getAnyConns #-} + +getConnSubs :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConnSub] +getConnSubs = getAnyConns_ getRcvQueueSubsByConnId_ getSndQueuesByConnId_ False +{-# INLINE getConnSubs #-} + +getAnyConns_ :: + (DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) -> + (DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) -> + (Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn' rq sq)]) +getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' db getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db @@ -2192,16 +2361,16 @@ getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue getRcvQueuesByConnId_ db connId = L.nonEmpty . sortBy primaryFirst . map toRcvQueue <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) - where - primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = - -- the current primary queue is ordered first, the next primary - second - compare (Down p) (Down p') <> compare i i' + +-- the current primary queue is ordered first, the next primary - second +primaryFirst :: SMPQueueRec q => q -> q -> Ordering +primaryFirst q q' = compare (Down (qPrimary q)) (Down (qPrimary q')) <> compare (dbReplaceQId q) (dbReplaceQId q') rcvQueueQuery :: Query rcvQueueQuery = [sql| SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, - q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.queue_mode, q.status, + q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.queue_mode, q.status, c.enable_ntfs, q.client_notice_id, q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors, q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret, q.link_id, q.link_key, q.link_priv_sig_key, q.link_enc_fixed_data @@ -2212,13 +2381,13 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, Maybe QueueMode) - :. (QueueStatus, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) + :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, DBEntityId, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) :. (Maybe SMP.LinkId, Maybe LinkKey, Maybe C.PrivateKeyEd25519, Maybe EncDataBytes) -> RcvQueue toRcvQueue ( (userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode) - :. (status, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) + :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) :. (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_) ) = @@ -2230,8 +2399,30 @@ toRcvQueue shortLink = case (shortLinkId_, shortLinkKey_, linkPrivSigKey_, linkEncFixedData_) of (Just shortLinkId, Just shortLinkKey, Just linkPrivSigKey, Just linkEncFixedData) -> Just ShortLinkCreds {shortLinkId, shortLinkKey, linkPrivSigKey, linkEncFixedData} _ -> Nothing + enableNtfs = maybe True unBI enableNtfs_ -- TODO [certs rcv] read client service - in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, clientService = Nothing, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} + in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, queueMode, shortLink, clientService = Nothing, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} + +-- | returns all connection queue credentials, the first queue is the primary one +getRcvQueueSubsByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueueSub)) +getRcvQueueSubsByConnId_ db connId = + L.nonEmpty . sortBy primaryFirst . map toRcvQueueSub + <$> DB.query db (rcvQueueSubQuery <> " WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) + +rcvQueueSubQuery :: Query +rcvQueueSubQuery = + [sql| + SELECT c.user_id, q.conn_id, q.host, q.port, COALESCE(q.server_key_hash, s.key_hash), q.rcv_id, q.rcv_private_key, q.status, c.enable_ntfs, q.client_notice_id, + q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id + FROM rcv_queues q + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id + |] + +toRcvQueueSub :: (UserId, ConnId, NonEmpty TransportHost, ServiceName, C.KeyHash, SMP.RecipientId, SMP.RcvPrivateAuthKey) :. (QueueStatus, Maybe BoolInt, Maybe NoticeId, Int64, BoolInt, Maybe Int64) -> RcvQueueSub +toRcvQueueSub ((userId, connId, host, port, keyHash, rcvId, rcvPrivateKey) :. (status, enableNtfs_, clientNoticeId, dbQueueId, BI primary, dbReplaceQueueId)) = + let enableNtfs = maybe True unBI enableNtfs_ + in RcvQueueSub {userId, connId, server = SMPServer host port keyHash, rcvId, rcvPrivateKey, status, enableNtfs, clientNoticeId, dbQueueId, primary, dbReplaceQueueId} getRcvQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) getRcvQueueById db connId dbRcvId = @@ -2243,17 +2434,13 @@ getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue getSndQueuesByConnId_ dbConn connId = L.nonEmpty . sortBy primaryFirst . map toSndQueue <$> DB.query dbConn (sndQueueQuery <> " WHERE q.conn_id = ?") (Only connId) - where - primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = - -- the current primary queue is ordered first, the next primary - second - compare (Down p) (Down p') <> compare i i' sndQueueQuery :: Query sndQueueQuery = [sql| SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.snd_id, q.queue_mode, - q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, + q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.switch_status, q.smp_client_version FROM snd_queues q JOIN servers s ON q.host = s.host AND q.port = s.port @@ -2262,17 +2449,16 @@ sndQueueQuery = toSndQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, Maybe QueueMode) - :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) + :. (SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) :. (DBEntityId, BoolInt, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> SndQueue toSndQueue ( (userId, keyHash, connId, host, port, sndId, queueMode) - :. (sndPubKey, sndPrivateKey@(C.APrivateAuthKey a pk), e2ePubKey, e2eDhSecret, status) + :. (sndPrivateKey, e2ePubKey, e2eDhSecret, status) :. (dbQueueId, BI primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) ) = let server = SMPServer host port keyHash - sndPublicKey = fromMaybe (C.APublicAuthKey a (C.publicKey pk)) sndPubKey - in SndQueue {userId, connId, server, sndId, queueMode, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion} + in SndQueue {userId, connId, server, sndId, queueMode, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion} getSndQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError SndQueue) getSndQueueById db connId dbSndId = @@ -2581,6 +2767,7 @@ getRcvFile db rcvFileId = runExceptT $ do SELECT rcv_file_chunk_id, chunk_no, chunk_size, digest, tmp_path FROM rcv_file_chunks WHERE rcv_file_id = ? + ORDER BY chunk_no ASC |] (Only rcvFileId) forM chunks $ \chunk@RcvFileChunk {rcvChunkId} -> do diff --git a/src/Simplex/Messaging/Agent/Store/Postgres.hs b/src/Simplex/Messaging/Agent/Store/Postgres.hs index 18b1a7a2d..807b24a29 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres.hs @@ -8,6 +8,7 @@ module Simplex.Messaging.Agent.Store.Postgres ( DBOpts (..), Migrations.getCurrentMigrations, checkSchemaExists, + migrateDBSchema, createDBStore, closeDBStore, reopenDBStore, @@ -38,18 +39,20 @@ import System.Exit (exitFailure) -- If passed schema does not exist in connectInfo database, it will be created. -- Applies necessary migrations to schema. createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore) -createDBStore opts migrations MigrationConfig {confirm} = do +createDBStore opts migrations migrationConfig = do st <- connectPostgresStore opts - r <- migrateSchema st `onException` closeDBStore st + r <- migrateDBSchema st opts Nothing migrations migrationConfig `onException` closeDBStore st case r of Right () -> pure $ Right st Left e -> closeDBStore st $> Left e - where - migrateSchema st = - let initialize = Migrations.initialize st - getCurrent = withTransaction st Migrations.getCurrentMigrations - dbm = DBMigrate {initialize, getCurrent, run = Migrations.run st, backup = Nothing} - in sharedMigrateSchema dbm (dbNew st) migrations confirm + +migrateDBSchema :: DBStore -> DBOpts -> Maybe Query -> [Migration] -> MigrationConfig -> IO (Either MigrationError ()) +migrateDBSchema st _opts migrationsTable migrations MigrationConfig {confirm} = + let initialize = Migrations.initialize st migrationsTable + getCurrent = withTransaction st $ Migrations.getCurrentMigrations migrationsTable + run = Migrations.run st migrationsTable + dbm = DBMigrate {initialize, getCurrent, run, backup = Nothing} + in sharedMigrateSchema dbm (dbNew st) migrations confirm connectPostgresStore :: DBOpts -> IO DBStore connectPostgresStore DBOpts {connstr, schema, poolSize, createSchema} = do diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs index d6f552937..a258c9a46 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs @@ -14,55 +14,50 @@ where import Control.Exception (throwIO) import Control.Monad (void) import qualified Data.ByteString.Char8 as B +import Data.Maybe (fromMaybe) import qualified Data.Text as T import qualified Data.Text.Encoding as TE import Data.Time.Clock (getCurrentTime) import qualified Database.PostgreSQL.LibPQ as LibPQ -import Database.PostgreSQL.Simple (Only (..)) +import Database.PostgreSQL.Simple (Only (..), Query) import qualified Database.PostgreSQL.Simple as PSQL import Database.PostgreSQL.Simple.Internal (Connection (..)) -import Database.PostgreSQL.Simple.SqlQQ (sql) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Shared import Simplex.Messaging.Util (($>>=)) import UnliftIO.MVar -initialize :: DBStore -> IO () -initialize st = withTransaction' st $ \db -> - void $ - PSQL.execute_ - db - [sql| - CREATE TABLE IF NOT EXISTS migrations ( - name TEXT NOT NULL, - ts TIMESTAMP NOT NULL, - down TEXT, - PRIMARY KEY (name) - ) - |] +initialize :: DBStore -> Maybe Query -> IO () +initialize st migrationsTable = withTransaction' st $ \db -> + void $ PSQL.execute_ db $ + "CREATE TABLE IF NOT EXISTS " + <> fromMaybe "migrations" migrationsTable + <> " (name TEXT NOT NULL PRIMARY KEY, ts TIMESTAMP NOT NULL, down TEXT)" -run :: DBStore -> MigrationsToRun -> IO () -run st = \case +run :: DBStore -> Maybe Query -> MigrationsToRun -> IO () +run st migrationsTable = \case MTRUp [] -> pure () MTRUp ms -> mapM_ runUp ms MTRDown ms -> mapM_ runDown $ reverse ms MTRNone -> pure () where + table = fromMaybe "migrations" migrationsTable runUp Migration {name, up, down} = withTransaction' st $ \db -> do insert db execSQL db up where - insert db = void $ PSQL.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime + insert db = void $ PSQL.execute db ("INSERT INTO " <> table <> " (name, down, ts) VALUES (?,?,?)") . (name,down,) =<< getCurrentTime runDown DownMigration {downName, downQuery} = withTransaction' st $ \db -> do execSQL db downQuery - void $ PSQL.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) + void $ PSQL.execute db ("DELETE FROM " <> table <> " WHERE name = ?") (Only downName) execSQL db query = withMVar (connectionHandle db) $ \pqConn -> LibPQ.exec pqConn (TE.encodeUtf8 query) $>>= LibPQ.resultErrorMessage >>= \case Just e | not (B.null e) -> throwIO $ userError $ B.unpack e _ -> pure () -getCurrentMigrations :: PSQL.Connection -> IO [Migration] -getCurrentMigrations db = map toMigration <$> PSQL.query_ db "SELECT name, down FROM migrations ORDER BY name ASC;" +getCurrentMigrations :: Maybe Query -> PSQL.Connection -> IO [Migration] +getCurrentMigrations migrationsTable db = map toMigration <$> PSQL.query_ db ("SELECT name, down FROM " <> table <> " ORDER BY name ASC;") where + table = fromMaybe "migrations" migrationsTable toMigration (name, down) = Migration {name, up = T.pack "", down} diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs index 565a06760..011d89031 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/App.hs @@ -8,6 +8,8 @@ import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250203_msg_bodies import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete +import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe +import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices import Simplex.Messaging.Agent.Store.Shared (Migration (..)) schemaMigrations :: [(String, Text, Maybe Text)] @@ -15,7 +17,9 @@ schemaMigrations = [ ("20241210_initial", m20241210_initial, Nothing), ("20250203_msg_bodies", m20250203_msg_bodies, Just down_m20250203_msg_bodies), ("20250322_short_links", m20250322_short_links, Just down_m20250322_short_links), - ("20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete) + ("20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete), + ("20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe), + ("20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs index 6b760342f..6f6b5f834 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs @@ -1,15 +1,14 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial where import Data.Text (Text) -import qualified Data.Text as T import Text.RawString.QQ (r) m20241210_initial :: Text m20241210_initial = - T.pack - [r| + [r| CREATE TABLE users( user_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, deleted SMALLINT NOT NULL DEFAULT 0 diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250203_msg_bodies.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250203_msg_bodies.hs index 848566b77..6a0d85e45 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250203_msg_bodies.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250203_msg_bodies.hs @@ -1,15 +1,14 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250203_msg_bodies where import Data.Text (Text) -import qualified Data.Text as T import Text.RawString.QQ (r) m20250203_msg_bodies :: Text m20250203_msg_bodies = - T.pack - [r| + [r| ALTER TABLE snd_messages ADD COLUMN msg_encrypt_key BYTEA; ALTER TABLE snd_messages ADD COLUMN padded_msg_len BIGINT; @@ -25,8 +24,7 @@ CREATE INDEX idx_snd_messages_snd_message_body_id ON snd_messages(snd_message_bo down_m20250203_msg_bodies :: Text down_m20250203_msg_bodies = - T.pack - [r| + [r| DROP INDEX idx_snd_messages_snd_message_body_id; ALTER TABLE snd_messages DROP COLUMN snd_message_body_id; DROP TABLE snd_message_bodies; diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250322_short_links.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250322_short_links.hs index be627ea0a..28394d51b 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250322_short_links.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250322_short_links.hs @@ -1,15 +1,14 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250322_short_links where import Data.Text (Text) -import qualified Data.Text as T import Text.RawString.QQ (r) m20250322_short_links :: Text m20250322_short_links = - T.pack - [r| + [r| ALTER TABLE rcv_queues ADD COLUMN link_id BYTEA; ALTER TABLE rcv_queues ADD COLUMN link_key BYTEA; ALTER TABLE rcv_queues ADD COLUMN link_priv_sig_key BYTEA; @@ -42,8 +41,7 @@ CREATE UNIQUE INDEX idx_inv_short_links_link_id ON inv_short_links(host, port, l down_m20250322_short_links :: Text down_m20250322_short_links = - T.pack - [r| + [r| DROP INDEX idx_rcv_queues_link_id; ALTER TABLE rcv_queues DROP COLUMN link_id; ALTER TABLE rcv_queues DROP COLUMN link_key; diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250702_conn_invitations_remove_cascade_delete.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250702_conn_invitations_remove_cascade_delete.hs index a61a60d5d..8d4628673 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250702_conn_invitations_remove_cascade_delete.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20250702_conn_invitations_remove_cascade_delete.hs @@ -1,15 +1,14 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20250702_conn_invitations_remove_cascade_delete where import Data.Text (Text) -import qualified Data.Text as T import Text.RawString.QQ (r) m20250702_conn_invitations_remove_cascade_delete :: Text m20250702_conn_invitations_remove_cascade_delete = - T.pack - [r| + [r| ALTER TABLE conn_invitations DROP CONSTRAINT conn_invitations_contact_conn_id_fkey; ALTER TABLE conn_invitations ALTER COLUMN contact_conn_id DROP NOT NULL; @@ -23,8 +22,7 @@ ALTER TABLE conn_invitations down_m20250702_conn_invitations_remove_cascade_delete :: Text down_m20250702_conn_invitations_remove_cascade_delete = - T.pack - [r| + [r| ALTER TABLE conn_invitations DROP CONSTRAINT conn_invitations_contact_conn_id_fkey; ALTER TABLE conn_invitations ALTER COLUMN contact_conn_id SET NOT NULL; diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251009_queue_to_subscribe.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251009_queue_to_subscribe.hs new file mode 100644 index 000000000..cac622586 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251009_queue_to_subscribe.hs @@ -0,0 +1,21 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251009_queue_to_subscribe where + +import Data.Text (Text) +import Text.RawString.QQ (r) + +m20251009_queue_to_subscribe :: Text +m20251009_queue_to_subscribe = + [r| +ALTER TABLE rcv_queues ADD COLUMN to_subscribe SMALLINT NOT NULL DEFAULT 0; +CREATE INDEX idx_rcv_queues_to_subscribe ON rcv_queues(to_subscribe); +|] + +down_m20251009_queue_to_subscribe :: Text +down_m20251009_queue_to_subscribe = + [r| +DROP INDEX idx_rcv_queues_to_subscribe; +ALTER TABLE rcv_queues DROP COLUMN to_subscribe; +|] diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251010_client_notices.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251010_client_notices.hs new file mode 100644 index 000000000..ba34c8299 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20251010_client_notices.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20251010_client_notices where + +import Data.Text (Text) +import Text.RawString.QQ (r) + +m20251010_client_notices :: Text +m20251010_client_notices = + [r| +CREATE TABLE client_notices( + client_notice_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + protocol TEXT NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + entity_id BYTEA NOT NULL, + server_key_hash BYTEA, + notice_ttl BIGINT, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL +); + +CREATE UNIQUE INDEX idx_client_notices_entity ON client_notices(protocol, host, port, entity_id); + +ALTER TABLE rcv_queues ADD COLUMN client_notice_id BIGINT +REFERENCES client_notices ON UPDATE RESTRICT ON DELETE SET NULL; + +CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id); +|] + +down_m20251010_client_notices :: Text +down_m20251010_client_notices = + [r| +DROP INDEX idx_rcv_queues_client_notice_id; +ALTER TABLE rcv_queues DROP COLUMN client_notice_id; + +DROP INDEX idx_client_notices_entity; +DROP TABLE client_notices; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 6203357fc..688eae0d2 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -27,6 +27,7 @@ module Simplex.Messaging.Agent.Store.SQLite ( DBOpts (..), Migrations.getCurrentMigrations, + migrateDBSchema, createDBStore, closeDBStore, reopenDBStore, @@ -68,25 +69,27 @@ import UnliftIO.STM -- * SQLite Store implementation createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore) -createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations MigrationConfig {confirm, backupPath} = do +createDBStore opts@DBOpts {dbFilePath, dbKey, keepKey, track} migrations migrationConfig = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing True dbDir st <- connectSQLiteStore dbFilePath dbKey keepKey track - r <- migrateSchema st `onException` closeDBStore st + r <- migrateDBSchema st opts Nothing migrations migrationConfig `onException` closeDBStore st case r of Right () -> pure $ Right st Left e -> closeDBStore st $> Left e where - migrateSchema st = - let initialize = Migrations.initialize st - getCurrent = withTransaction st Migrations.getCurrentMigrations - run = Migrations.run st vacuum - backup = mkBackup <$> backupPath - mkBackup bp = - let f = if null bp then dbFilePath else bp takeFileName dbFilePath - in copyFile dbFilePath $ f <> ".bak" - dbm = DBMigrate {initialize, getCurrent, run, backup} - in sharedMigrateSchema dbm (dbNew st) migrations confirm + +migrateDBSchema :: DBStore -> DBOpts -> Maybe Query -> [Migration] -> MigrationConfig -> IO (Either MigrationError ()) +migrateDBSchema st DBOpts {dbFilePath, vacuum} migrationsTable migrations MigrationConfig {confirm, backupPath} = + let initialize = Migrations.initialize st migrationsTable + getCurrent = withTransaction st $ Migrations.getCurrentMigrations migrationsTable + run = Migrations.run st migrationsTable vacuum + backup = mkBackup <$> backupPath + mkBackup bp = + let f = if null bp then dbFilePath else bp takeFileName dbFilePath + in copyFile dbFilePath $ f <> ".bak" + dbm = DBMigrate {initialize, getCurrent, run, backup} + in sharedMigrateSchema dbm (dbNew st) migrations confirm connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> DB.TrackQueries -> IO DBStore connectSQLiteStore dbFilePath key keepKey track = do diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index fb0523d08..3f3091ee1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -17,12 +17,12 @@ where import Control.Monad (forM_, when) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map.Strict as M +import Data.Maybe (fromMaybe) import Data.Text (Text) import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (getCurrentTime) import Database.SQLite.Simple (Only (..), Query (..)) import qualified Database.SQLite.Simple as SQL -import Database.SQLite.Simple.QQ (sql) import qualified Database.SQLite3 as SQLite3 import Simplex.Messaging.Agent.Protocol (extraSMPServerHosts) import qualified Simplex.Messaging.Agent.Store.DB as DB @@ -32,13 +32,16 @@ import Simplex.Messaging.Agent.Store.Shared import Simplex.Messaging.Encoding.String import Simplex.Messaging.Transport.Client (TransportHost) -getCurrentMigrations :: DB.Connection -> IO [Migration] -getCurrentMigrations DB.Connection {DB.conn} = map toMigration <$> SQL.query_ conn "SELECT name, down FROM migrations ORDER BY name ASC;" +getCurrentMigrations :: Maybe Query -> DB.Connection -> IO [Migration] +getCurrentMigrations migrationsTable DB.Connection {DB.conn} = + map toMigration + <$> SQL.query_ conn ("SELECT name, down FROM " <> table <> " ORDER BY name ASC;") where + table = fromMaybe "migrations" migrationsTable toMigration (name, down) = Migration {name, up = "", down} -run :: DBStore -> Bool -> MigrationsToRun -> IO () -run st vacuum = \case +run :: DBStore -> Maybe Query -> Bool -> MigrationsToRun -> IO () +run st migrationsTable vacuum = \case MTRUp [] -> pure () MTRUp ms -> do mapM_ runUp ms @@ -46,11 +49,12 @@ run st vacuum = \case MTRDown ms -> mapM_ runDown $ reverse ms MTRNone -> pure () where + table = fromMaybe "migrations" migrationsTable runUp Migration {name, up, down} = withTransaction' st $ \db -> do when (name == "m20220811_onion_hosts") $ updateServers db insert db >> execSQL db up' where - insert db = SQL.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime + insert db = SQL.execute db ("INSERT INTO " <> table <> " (name, down, ts) VALUES (?,?,?)") . (name,down,) =<< getCurrentTime up' | dbNew st && name == "m20230110_users" = fromQuery new_m20230110_users | otherwise = up @@ -59,24 +63,19 @@ run st vacuum = \case in SQL.execute db "UPDATE servers SET host = ? WHERE host = ?" (hs, decodeLatin1 $ strEncode h) runDown DownMigration {downName, downQuery} = withTransaction' st $ \db -> do execSQL db downQuery - SQL.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) + SQL.execute db ("DELETE FROM " <> table <> " WHERE name = ?") (Only downName) execSQL db = SQLite3.exec $ SQL.connectionHandle db -initialize :: DBStore -> IO () -initialize st = withTransaction' st $ \db -> do - cs :: [Text] <- map fromOnly <$> SQL.query_ db "SELECT name FROM pragma_table_info('migrations')" +initialize :: DBStore -> Maybe Query -> IO () +initialize st migrationsTable = withTransaction' st $ \db -> do + cs :: [Text] <- map fromOnly <$> SQL.query_ db ("SELECT name FROM pragma_table_info('" <> table <> "')") case cs of [] -> createMigrations db - _ -> when ("down" `notElem` cs) $ SQL.execute_ db "ALTER TABLE migrations ADD COLUMN down TEXT" + _ -> when ("down" `notElem` cs) $ SQL.execute_ db $ "ALTER TABLE " <> table <> " ADD COLUMN down TEXT" where + table = fromMaybe "migrations" migrationsTable createMigrations db = - SQL.execute_ - db - [sql| - CREATE TABLE IF NOT EXISTS migrations ( - name TEXT NOT NULL, - ts TEXT NOT NULL, - down TEXT, - PRIMARY KEY (name) - ); - |] + SQL.execute_ db $ + "CREATE TABLE IF NOT EXISTS " + <> table + <> " (name TEXT NOT NULL PRIMARY KEY, ts TEXT NOT NULL, down TEXT)" diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs index 9d5d65ea7..7371d9584 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/App.hs @@ -44,6 +44,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241224_ratchet_e2e_snd import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250203_msg_bodies import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250322_short_links import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20250702_conn_invitations_remove_cascade_delete +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices import Simplex.Messaging.Agent.Store.Shared (Migration (..)) schemaMigrations :: [(String, Query, Maybe Query)] @@ -87,7 +89,9 @@ schemaMigrations = ("m20241224_ratchet_e2e_snd_params", m20241224_ratchet_e2e_snd_params, Just down_m20241224_ratchet_e2e_snd_params), ("m20250203_msg_bodies", m20250203_msg_bodies, Just down_m20250203_msg_bodies), ("m20250322_short_links", m20250322_short_links, Just down_m20250322_short_links), - ("m20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete) + ("m20250702_conn_invitations_remove_cascade_delete", m20250702_conn_invitations_remove_cascade_delete, Just down_m20250702_conn_invitations_remove_cascade_delete), + ("m20251009_queue_to_subscribe", m20251009_queue_to_subscribe, Just down_m20251009_queue_to_subscribe), + ("m20251010_client_notices", m20251010_client_notices, Just down_m20251010_client_notices) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251009_queue_to_subscribe.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251009_queue_to_subscribe.hs new file mode 100644 index 000000000..76ee208df --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251009_queue_to_subscribe.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251009_queue_to_subscribe where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20251009_queue_to_subscribe :: Query +m20251009_queue_to_subscribe = + [sql| +ALTER TABLE rcv_queues ADD COLUMN to_subscribe INTEGER NOT NULL DEFAULT 0; +CREATE INDEX idx_rcv_queues_to_subscribe ON rcv_queues(to_subscribe); +|] + +down_m20251009_queue_to_subscribe :: Query +down_m20251009_queue_to_subscribe = + [sql| +DROP INDEX idx_rcv_queues_to_subscribe; +ALTER TABLE rcv_queues DROP COLUMN to_subscribe; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251010_client_notices.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251010_client_notices.hs new file mode 100644 index 000000000..dddd92781 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20251010_client_notices.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20251010_client_notices where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20251010_client_notices :: Query +m20251010_client_notices = + [sql| +CREATE TABLE client_notices( + client_notice_id INTEGER PRIMARY KEY AUTOINCREMENT, + protocol TEXT NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + entity_id BLOB NOT NULL, + server_key_hash BLOB, + notice_ttl INTEGER, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL +); + +CREATE UNIQUE INDEX idx_client_notices_entity ON client_notices(protocol, host, port, entity_id); + +ALTER TABLE rcv_queues ADD COLUMN client_notice_id INTEGER +REFERENCES client_notices ON UPDATE RESTRICT ON DELETE SET NULL; + +CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id); +|] + +down_m20251010_client_notices :: Query +down_m20251010_client_notices = + [sql| +DROP INDEX idx_rcv_queues_client_notice_id; +ALTER TABLE rcv_queues DROP COLUMN client_notice_id; + +DROP INDEX idx_client_notices_entity; +DROP TABLE client_notices; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index ad39937cd..d2838a7b0 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -1,8 +1,7 @@ CREATE TABLE migrations( - name TEXT NOT NULL, + name TEXT NOT NULL PRIMARY KEY, ts TEXT NOT NULL, - down TEXT, - PRIMARY KEY(name) + down TEXT ); CREATE TABLE servers( host TEXT NOT NULL, @@ -61,6 +60,9 @@ CREATE TABLE rcv_queues( link_priv_sig_key BLOB, link_enc_fixed_data BLOB, queue_mode TEXT, + to_subscribe INTEGER NOT NULL DEFAULT 0, + client_notice_id INTEGER + REFERENCES client_notices ON UPDATE RESTRICT ON DELETE SET NULL, PRIMARY KEY(host, port, rcv_id), FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE, @@ -437,6 +439,17 @@ CREATE TABLE inv_short_links( snd_id BLOB, FOREIGN KEY(host, port) REFERENCES servers ON DELETE RESTRICT ON UPDATE CASCADE ); +CREATE TABLE client_notices( + client_notice_id INTEGER PRIMARY KEY AUTOINCREMENT, + protocol TEXT NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + entity_id BLOB NOT NULL, + server_key_hash BLOB, + notice_ttl INTEGER, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL +); CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id); @@ -572,3 +585,11 @@ CREATE UNIQUE INDEX idx_inv_short_links_link_id ON inv_short_links( port, link_id ); +CREATE INDEX idx_rcv_queues_to_subscribe ON rcv_queues(to_subscribe); +CREATE UNIQUE INDEX idx_client_notices_entity ON client_notices( + protocol, + host, + port, + entity_id +); +CREATE INDEX idx_rcv_queues_client_notice_id ON rcv_queues(client_notice_id); diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs deleted file mode 100644 index 52d67be70..000000000 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ /dev/null @@ -1,120 +0,0 @@ -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE LambdaCase #-} - -module Simplex.Messaging.Agent.TRcvQueues - ( TRcvQueues (getRcvQueues, getConnections), - Queue (..), - empty, - clear, - deleteConn, - hasConn, - addQueue, - batchAddQueues, - deleteQueue, - hasSessQueues, - getSessQueues, - getDelSessQueues, - ) -where - -import Control.Concurrent.STM -import Data.Foldable (foldl') -import Data.List.NonEmpty (NonEmpty (..), (<|)) -import qualified Data.List.NonEmpty as L -import qualified Data.Map.Strict as M -import Simplex.Messaging.Agent.Protocol (ConnId, UserId) -import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) -import Simplex.Messaging.Protocol (RecipientId, SMPServer) -import Simplex.Messaging.TMap (TMap) -import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport - -class Queue q where - connId' :: q -> ConnId - qKey :: q -> (UserId, SMPServer, RecipientId) - --- the fields in this record have the same data with swapped keys for lookup efficiency, --- and all methods must maintain this invariant. -data TRcvQueues q = TRcvQueues - { getRcvQueues :: TMap (UserId, SMPServer, RecipientId) q, - getConnections :: TMap ConnId (NonEmpty (UserId, SMPServer, RecipientId)) - } - -empty :: IO (TRcvQueues q) -empty = TRcvQueues <$> TM.emptyIO <*> TM.emptyIO - -clear :: TRcvQueues q -> STM () -clear (TRcvQueues qs cs) = TM.clear qs >> TM.clear cs - -deleteConn :: ConnId -> TRcvQueues q -> STM () -deleteConn cId (TRcvQueues qs cs) = - TM.lookupDelete cId cs >>= \case - Just ks -> modifyTVar' qs $ \qs' -> foldl' (flip M.delete) qs' ks - Nothing -> pure () - -hasConn :: ConnId -> TRcvQueues q -> STM Bool -hasConn cId (TRcvQueues _ cs) = TM.member cId cs - -addQueue :: Queue q => q -> TRcvQueues q -> STM () -addQueue rq (TRcvQueues qs cs) = do - TM.insert k rq qs - TM.alter addQ (connId' rq) cs - where - addQ = Just . maybe (k :| []) (k <|) - k = qKey rq - --- Save time by aggregating modifyTVar' -batchAddQueues :: (Foldable t, Queue q) => TRcvQueues q -> t q -> STM () -batchAddQueues (TRcvQueues qs cs) rqs = do - modifyTVar' qs $ \now -> foldl' (\rqs' rq -> M.insert (qKey rq) rq rqs') now rqs - modifyTVar' cs $ \now -> foldl' (\cs' rq -> M.alter (addQ $ qKey rq) (connId' rq) cs') now rqs - where - addQ k = Just . maybe (k :| []) (k <|) - -deleteQueue :: RcvQueue -> TRcvQueues RcvQueue -> STM () -deleteQueue rq (TRcvQueues qs cs) = do - TM.delete k qs - TM.update delQ (connId rq) cs - where - delQ = L.nonEmpty . L.filter (/= k) - k = qKey rq - -hasSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> STM Bool -hasSessQueues tSess (TRcvQueues qs _) = any (`isSession` tSess) <$> readTVar qs - -getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues RcvQueue -> IO [RcvQueue] -getSessQueues tSess (TRcvQueues qs _) = M.foldl' addQ [] <$> readTVarIO qs - where - addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' - -getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> SessionId -> TRcvQueues (SessionId, RcvQueue) -> STM ([RcvQueue], [ConnId]) -getDelSessQueues tSess sessId' (TRcvQueues qs cs) = do - (removedQs, qs'') <- (\qs' -> M.foldl' delQ ([], qs') qs') <$> readTVar qs - writeTVar qs $! qs'' - removedConns <- stateTVar cs $ \cs' -> foldl' delConn ([], cs') removedQs - pure (removedQs, removedConns) - where - delQ acc@(removed, qs') (sessId, rq) - | rq `isSession` tSess && sessId == sessId' = (rq : removed, M.delete (qKey rq) qs') - | otherwise = acc - delConn :: ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, RecipientId))) -> RcvQueue -> ([ConnId], M.Map ConnId (NonEmpty (UserId, SMPServer, RecipientId))) - delConn (removed, cs') rq = M.alterF f cId cs' - where - cId = connId rq - f = \case - Just ks -> case L.nonEmpty $ L.filter (qKey rq /=) ks of - Just ks' -> (removed, Just ks') - Nothing -> (cId : removed, Nothing) - Nothing -> (removed, Nothing) -- "impossible" in invariant holds, because we get keys from the known queues - -isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool -isSession rq (uId, srv, connId_) = - userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ - -instance Queue RcvQueue where - connId' = connId - qKey rq = (userId rq, server rq, rcvId rq) - -instance Queue (SessionId, RcvQueue) where - connId' = connId . snd - qKey = qKey . snd diff --git a/src/Simplex/Messaging/Agent/TSessionSubs.hs b/src/Simplex/Messaging/Agent/TSessionSubs.hs new file mode 100644 index 000000000..cce103fe6 --- /dev/null +++ b/src/Simplex/Messaging/Agent/TSessionSubs.hs @@ -0,0 +1,201 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} + +module Simplex.Messaging.Agent.TSessionSubs + ( TSessionSubs (sessionSubs), + SessSubs (..), + emptyIO, + clear, + hasActiveSub, + hasPendingSub, + addPendingSub, + setSessionId, + addActiveSub, + batchAddActiveSubs, + batchAddPendingSubs, + deletePendingSub, + batchDeletePendingSubs, + deleteSub, + batchDeleteSubs, + hasPendingSubs, + getPendingSubs, + getActiveSubs, + setSubsPending, + updateClientNotices, + foldSessionSubs, + mapSubs, + ) +where + +import Control.Concurrent.STM +import Control.Monad +import Data.Int (Int64) +import Data.List (foldl') +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Maybe (isJust) +import qualified Data.Set as S +import Simplex.Messaging.Agent.Protocol (SMPQueue (..)) +import Simplex.Messaging.Agent.Store (RcvQueueSub (..), SomeRcvQueue) +import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..)) +import Simplex.Messaging.Protocol (RecipientId) +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Transport +import Simplex.Messaging.Util (($>>=)) + +data TSessionSubs = TSessionSubs + { sessionSubs :: TMap SMPTransportSession SessSubs + } + +data SessSubs = SessSubs + { subsSessId :: TVar (Maybe SessionId), + activeSubs :: TMap RecipientId RcvQueueSub, + pendingSubs :: TMap RecipientId RcvQueueSub + } + +emptyIO :: IO TSessionSubs +emptyIO = TSessionSubs <$> TM.emptyIO +{-# INLINE emptyIO #-} + +clear :: TSessionSubs -> STM () +clear = TM.clear . sessionSubs +{-# INLINE clear #-} + +lookupSubs :: SMPTransportSession -> TSessionSubs -> STM (Maybe SessSubs) +lookupSubs tSess = TM.lookup tSess . sessionSubs +{-# INLINE lookupSubs #-} + +getSessSubs :: SMPTransportSession -> TSessionSubs -> STM SessSubs +getSessSubs tSess ss = lookupSubs tSess ss >>= maybe new pure + where + new = do + s <- SessSubs <$> newTVar Nothing <*> newTVar M.empty <*> newTVar M.empty + TM.insert tSess s $ sessionSubs ss + pure s + +hasActiveSub :: SMPTransportSession -> RecipientId -> TSessionSubs -> STM Bool +hasActiveSub = hasQueue_ activeSubs +{-# INLINE hasActiveSub #-} + +hasPendingSub :: SMPTransportSession -> RecipientId -> TSessionSubs -> STM Bool +hasPendingSub = hasQueue_ pendingSubs +{-# INLINE hasPendingSub #-} + +hasQueue_ :: (SessSubs -> TMap RecipientId RcvQueueSub) -> SMPTransportSession -> RecipientId -> TSessionSubs -> STM Bool +hasQueue_ subs tSess rId ss = isJust <$> (lookupSubs tSess ss $>>= TM.lookup rId . subs) +{-# INLINE hasQueue_ #-} + +addPendingSub :: SMPTransportSession -> RcvQueueSub -> TSessionSubs -> STM () +addPendingSub tSess rq ss = getSessSubs tSess ss >>= TM.insert (rcvId rq) rq . pendingSubs + +setSessionId :: SMPTransportSession -> SessionId -> TSessionSubs -> STM () +setSessionId tSess sessId ss = do + s <- getSessSubs tSess ss + readTVar (subsSessId s) >>= \case + Nothing -> writeTVar (subsSessId s) (Just sessId) + Just sessId' -> unless (sessId == sessId') $ void $ setSubsPending_ s $ Just sessId + +addActiveSub :: SMPTransportSession -> SessionId -> RcvQueueSub -> TSessionSubs -> STM () +addActiveSub tSess sessId rq ss = do + s <- getSessSubs tSess ss + sessId' <- readTVar $ subsSessId s + let rId = rcvId rq + if Just sessId == sessId' + then do + TM.insert rId rq $ activeSubs s + TM.delete rId $ pendingSubs s + else TM.insert rId rq $ pendingSubs s + +batchAddActiveSubs :: SMPTransportSession -> SessionId -> [RcvQueueSub] -> TSessionSubs -> STM () +batchAddActiveSubs tSess sessId rqs ss = do + s <- getSessSubs tSess ss + sessId' <- readTVar $ subsSessId s + let qs = M.fromList $ map (\rq -> (rcvId rq, rq)) rqs + if Just sessId == sessId' + then do + TM.union qs $ activeSubs s + modifyTVar' (pendingSubs s) (`M.difference` qs) + else TM.union qs $ pendingSubs s + +batchAddPendingSubs :: SMPTransportSession -> [RcvQueueSub] -> TSessionSubs -> STM () +batchAddPendingSubs tSess rqs ss = do + s <- getSessSubs tSess ss + modifyTVar' (pendingSubs s) $ M.union $ M.fromList $ map (\rq -> (rcvId rq, rq)) rqs + +deletePendingSub :: SMPTransportSession -> RecipientId -> TSessionSubs -> STM () +deletePendingSub tSess rId = lookupSubs tSess >=> mapM_ (TM.delete rId . pendingSubs) + +batchDeletePendingSubs :: SMPTransportSession -> S.Set RecipientId -> TSessionSubs -> STM () +batchDeletePendingSubs tSess rIds = lookupSubs tSess >=> mapM_ (delete . pendingSubs) + where + delete = (`modifyTVar'` (`M.withoutKeys` rIds)) + +deleteSub :: SMPTransportSession -> RecipientId -> TSessionSubs -> STM () +deleteSub tSess rId = lookupSubs tSess >=> mapM_ (\s -> TM.delete rId (activeSubs s) >> TM.delete rId (pendingSubs s)) + +batchDeleteSubs :: SomeRcvQueue q => SMPTransportSession -> [q] -> TSessionSubs -> STM () +batchDeleteSubs tSess rqs = lookupSubs tSess >=> mapM_ (\s -> delete (activeSubs s) >> delete (pendingSubs s)) + where + rIds = S.fromList $ map queueId rqs + delete = (`modifyTVar'` (`M.withoutKeys` rIds)) + +hasPendingSubs :: SMPTransportSession -> TSessionSubs -> STM Bool +hasPendingSubs tSess = lookupSubs tSess >=> maybe (pure False) (fmap (not . null) . readTVar . pendingSubs) + +getPendingSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +getPendingSubs = getSubs_ pendingSubs +{-# INLINE getPendingSubs #-} + +getActiveSubs :: SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +getActiveSubs = getSubs_ activeSubs +{-# INLINE getActiveSubs #-} + +getSubs_ :: (SessSubs -> TMap RecipientId RcvQueueSub) -> SMPTransportSession -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +getSubs_ subs tSess = lookupSubs tSess >=> maybe (pure M.empty) (readTVar . subs) + +setSubsPending :: TransportSessionMode -> SMPTransportSession -> SessionId -> TSessionSubs -> STM (Map RecipientId RcvQueueSub) +setSubsPending mode tSess@(uId, srv, connId_) sessId tss@(TSessionSubs ss) + | entitySession == isJust connId_ = + TM.lookup tSess ss >>= withSessSubs (`setSubsPending_` Nothing) + | otherwise = + TM.lookupDelete tSess ss >>= withSessSubs setPendingChangeMode + where + entitySession = mode == TSMEntity + sessEntId = if entitySession then Just else const Nothing + withSessSubs run = \case + Nothing -> pure M.empty + Just s -> do + sessId' <- readTVar $ subsSessId s + if Just sessId == sessId' then run s else pure M.empty + setPendingChangeMode s = do + subs <- M.union <$> readTVar (activeSubs s) <*> readTVar (pendingSubs s) + unless (null subs) $ + forM_ subs $ \rq -> addPendingSub (uId, srv, sessEntId (connId rq)) rq tss + pure subs + +setSubsPending_ :: SessSubs -> Maybe SessionId -> STM (Map RecipientId RcvQueueSub) +setSubsPending_ s sessId_ = do + writeTVar (subsSessId s) sessId_ + let as = activeSubs s + subs <- readTVar as + unless (null subs) $ do + writeTVar as M.empty + modifyTVar' (pendingSubs s) $ M.union subs + pure subs + +updateClientNotices :: SMPTransportSession -> [(RecipientId, Maybe Int64)] -> TSessionSubs -> STM () +updateClientNotices tSess noticeIds ss = do + s <- getSessSubs tSess ss + modifyTVar' (pendingSubs s) $ \m -> foldl' (\m' (rcvId, clientNoticeId) -> M.adjust (\rq -> rq {clientNoticeId}) rcvId m') m noticeIds + +foldSessionSubs :: (a -> (SMPTransportSession, SessSubs) -> IO a) -> a -> TSessionSubs -> IO a +foldSessionSubs f a = foldM f a . M.assocs <=< readTVarIO . sessionSubs + +mapSubs :: (Map RecipientId RcvQueueSub -> a) -> SessSubs -> IO (a, a) +mapSubs f s = do + active <- readTVarIO $ activeSubs s + pending <- readTVarIO $ pendingSubs s + pure (f active, f pending) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 32e52e3aa..27840b092 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -29,6 +29,7 @@ module Simplex.Messaging.Client ( -- * Connect (disconnect) client to (from) SMP server TransportSession, + SMPTransportSession, ProtocolClient (thParams, sessionTs), SMPClient, ProxiedRelay (..), @@ -39,6 +40,7 @@ module Simplex.Messaging.Client transportHost', transportSession', useWebPort, + isPresetDomain, -- * SMP protocol command functions createSMPQueue, @@ -102,6 +104,7 @@ module Simplex.Messaging.Client temporaryClientError, smpClientServiceError, smpProxyError, + smpErrorClientNotice, textToHostMode, ServerTransmissionBatch, ServerTransmission (..), @@ -156,6 +159,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, sumTypeJSON) import Simplex.Messaging.Protocol +import Simplex.Messaging.Protocol.Types import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -549,6 +553,8 @@ type UserId = Int64 -- Please note that for SMP connection ID is used as entity ID, not queue ID. type TransportSession msg = (UserId, ProtoServer msg, Maybe ByteString) +type SMPTransportSession = TransportSession BrokerMsg + -- | Connects to 'ProtocolServer' using passed client configuration -- and queue for messages and notifications. -- @@ -712,13 +718,16 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS Right _ -> logWarn "SMP client unprocessed event" useWebPort :: NetworkConfig -> [HostName] -> ProtocolServer p -> Bool -useWebPort cfg presetDomains srv = case smpWebPortServers cfg of +useWebPort cfg presetDomains ProtocolServer {host = h :| _} = case smpWebPortServers cfg of SWPAll -> True - SWPPreset -> case srv of - ProtocolServer {host = THDomainName h :| _} -> any (`isSuffixOf` h) presetDomains - _ -> False + SWPPreset -> isPresetDomain presetDomains h SWPOff -> False +isPresetDomain :: [HostName] -> TransportHost -> Bool +isPresetDomain presetDomains = \case + THDomainName h -> any (`isSuffixOf` h) presetDomains + _ -> False + unexpectedResponse :: Show r => r -> ProtocolClientError err unexpectedResponse = PCEUnexpectedResponse . B.pack . take 32 . show @@ -791,6 +800,12 @@ smpProxyError = \case PCECryptoError _ -> CRYPTO PCEIOError _ -> INTERNAL +smpErrorClientNotice :: SMPClientError -> Maybe (Maybe ClientNotice) +smpErrorClientNotice = \case + PCEProtocolError (BLOCKED BlockingInfo {notice}) -> Just notice + _ -> Nothing +{-# INLINE smpErrorClientNotice #-} + -- | Create a new SMP queue. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#create-queue-command @@ -924,12 +939,12 @@ secureSMPQueue c nm rpKey rId senderKey = okSMPCommand (KEY senderKey) c nm rpKe {-# INLINE secureSMPQueue #-} -- | Secure the SMP queue via sender queue ID. -secureSndSMPQueue :: SMPClient -> NetworkRequestMode -> SndPrivateAuthKey -> SenderId -> SndPublicAuthKey -> ExceptT SMPClientError IO () -secureSndSMPQueue c nm spKey sId senderKey = okSMPCommand (SKEY senderKey) c nm spKey sId +secureSndSMPQueue :: SMPClient -> NetworkRequestMode -> SndPrivateAuthKey -> SenderId -> ExceptT SMPClientError IO () +secureSndSMPQueue c nm spKey sId = okSMPCommand (SKEY $ C.toPublic spKey) c nm spKey sId {-# INLINE secureSndSMPQueue #-} -proxySecureSndSMPQueue :: SMPClient -> NetworkRequestMode -> ProxiedRelay -> SndPrivateAuthKey -> SenderId -> SndPublicAuthKey -> ExceptT SMPClientError IO (Either ProxyClientError ()) -proxySecureSndSMPQueue c nm proxiedRelay spKey sId senderKey = proxyOKSMPCommand c nm proxiedRelay (Just spKey) sId (SKEY senderKey) +proxySecureSndSMPQueue :: SMPClient -> NetworkRequestMode -> ProxiedRelay -> SndPrivateAuthKey -> SenderId -> ExceptT SMPClientError IO (Either ProxyClientError ()) +proxySecureSndSMPQueue c nm proxiedRelay spKey sId = proxyOKSMPCommand c nm proxiedRelay (Just spKey) sId (SKEY $ C.toPublic spKey) {-# INLINE proxySecureSndSMPQueue #-} -- | Add or update date for queue link @@ -943,15 +958,15 @@ deleteSMPQueueLink = okSMPCommand LDEL {-# INLINE deleteSMPQueueLink #-} -- | Get 1-time inviation SMP queue link data and secure the queue via queue link ID. -secureGetSMPQueueLink :: SMPClient -> NetworkRequestMode -> SndPrivateAuthKey -> LinkId -> SndPublicAuthKey -> ExceptT SMPClientError IO (SenderId, QueueLinkData) -secureGetSMPQueueLink c nm spKey lnkId senderKey = - sendSMPCommand c nm (Just spKey) lnkId (LKEY senderKey) >>= \case +secureGetSMPQueueLink :: SMPClient -> NetworkRequestMode -> SndPrivateAuthKey -> LinkId -> ExceptT SMPClientError IO (SenderId, QueueLinkData) +secureGetSMPQueueLink c nm spKey lnkId = + sendSMPCommand c nm (Just spKey) lnkId (LKEY $ C.toPublic spKey) >>= \case LNK sId d -> pure (sId, d) r -> throwE $ unexpectedResponse r -proxySecureGetSMPQueueLink :: SMPClient -> NetworkRequestMode -> ProxiedRelay -> SndPrivateAuthKey -> LinkId -> SndPublicAuthKey -> ExceptT SMPClientError IO (Either ProxyClientError (SenderId, QueueLinkData)) -proxySecureGetSMPQueueLink c nm proxiedRelay spKey lnkId senderKey = - proxySMPCommand c nm proxiedRelay (Just spKey) lnkId (LKEY senderKey) >>= \case +proxySecureGetSMPQueueLink :: SMPClient -> NetworkRequestMode -> ProxiedRelay -> SndPrivateAuthKey -> LinkId -> ExceptT SMPClientError IO (Either ProxyClientError (SenderId, QueueLinkData)) +proxySecureGetSMPQueueLink c nm proxiedRelay spKey lnkId = + proxySMPCommand c nm proxiedRelay (Just spKey) lnkId (LKEY $ C.toPublic spKey) >>= \case Right (LNK sId d) -> pure $ Right (sId, d) Right r -> throwE $ unexpectedResponse r Left e -> pure $ Left e diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index ed1363b46..9cc78acb3 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -76,6 +76,7 @@ module Simplex.Messaging.Crypto generateKeyPair, generateSignatureKeyPair, generateAuthKeyPair, + generatePrivateAuthKey, generateDhKeyPair, privateToX509, x509ToPublic, @@ -329,10 +330,10 @@ type PublicKeyX448 = PublicKey X448 -- | GADT for private keys. data PrivateKey (a :: Algorithm) where - PrivateKeyEd25519 :: Ed25519.SecretKey -> Ed25519.PublicKey -> PrivateKey Ed25519 - PrivateKeyEd448 :: Ed448.SecretKey -> Ed448.PublicKey -> PrivateKey Ed448 - PrivateKeyX25519 :: X25519.SecretKey -> X25519.PublicKey -> PrivateKey X25519 - PrivateKeyX448 :: X448.SecretKey -> X448.PublicKey -> PrivateKey X448 + PrivateKeyEd25519 :: Ed25519.SecretKey -> PrivateKey Ed25519 + PrivateKeyEd448 :: Ed448.SecretKey -> PrivateKey Ed448 + PrivateKeyX25519 :: X25519.SecretKey -> PrivateKey X25519 + PrivateKeyX448 :: X448.SecretKey -> PrivateKey X448 deriving instance Eq (PrivateKey a) @@ -616,48 +617,66 @@ class CryptoPrivateKey pk where type PublicKeyType pk toPrivKey :: (forall a. AlgorithmI a => PrivateKey a -> b) -> pk -> b privKey :: APrivateKey -> Either String pk + toPublic :: pk -> PublicKeyType pk instance CryptoPrivateKey APrivateKey where type PublicKeyType APrivateKey = APublicKey toPrivKey f (APrivateKey _ k) = f k + {-# INLINE toPrivKey #-} privKey = Right + {-# INLINE privKey #-} + toPublic (APrivateKey a k) = APublicKey a (toPublic k) + {-# INLINE toPublic #-} instance CryptoPrivateKey APrivateSignKey where type PublicKeyType APrivateSignKey = APublicVerifyKey toPrivKey f (APrivateSignKey _ k) = f k + {-# INLINE toPrivKey #-} privKey (APrivateKey a k) = case signatureAlgorithm a of Just Dict -> Right $ APrivateSignKey a k _ -> Left "key does not support signature algorithms" + toPublic (APrivateSignKey a k) = APublicVerifyKey a (toPublic k) + {-# INLINE toPublic #-} instance CryptoPrivateKey APrivateAuthKey where type PublicKeyType APrivateAuthKey = APublicAuthKey toPrivKey f (APrivateAuthKey _ k) = f k + {-# INLINE toPrivKey #-} privKey (APrivateKey a k) = case authAlgorithm a of Just Dict -> Right $ APrivateAuthKey a k _ -> Left "key does not support auth algorithms" + toPublic (APrivateAuthKey a k) = APublicAuthKey a (toPublic k) + {-# INLINE toPublic #-} instance CryptoPrivateKey APrivateDhKey where type PublicKeyType APrivateDhKey = APublicDhKey toPrivKey f (APrivateDhKey _ k) = f k + {-# INLINE toPrivKey #-} privKey (APrivateKey a k) = case dhAlgorithm a of Just Dict -> Right $ APrivateDhKey a k _ -> Left "key does not support DH algorithm" + toPublic (APrivateDhKey a k) = APublicDhKey a (toPublic k) + {-# INLINE toPublic #-} instance AlgorithmI a => CryptoPrivateKey (PrivateKey a) where type PublicKeyType (PrivateKey a) = PublicKey a toPrivKey = id + {-# INLINE toPrivKey #-} privKey (APrivateKey _ k) = checkAlgorithm k + {-# INLINE privKey #-} + toPublic = publicKey + {-# INLINE toPublic #-} publicKey :: PrivateKey a -> PublicKey a publicKey = \case - PrivateKeyEd25519 _ k -> PublicKeyEd25519 k - PrivateKeyEd448 _ k -> PublicKeyEd448 k - PrivateKeyX25519 _ k -> PublicKeyX25519 k - PrivateKeyX448 _ k -> PublicKeyX448 k + PrivateKeyEd25519 pk -> PublicKeyEd25519 (Ed25519.toPublic pk) + PrivateKeyEd448 pk -> PublicKeyEd448 (Ed448.toPublic pk) + PrivateKeyX25519 pk -> PublicKeyX25519 (X25519.toPublic pk) + PrivateKeyX448 pk -> PublicKeyX448 (X448.toPublic pk) -- | Expand signature private key to a key pair. signatureKeyPair :: APrivateSignKey -> ASignatureKeyPair -signatureKeyPair ak@(APrivateSignKey a k) = (APublicVerifyKey a (publicKey k), ak) +signatureKeyPair ak@(APrivateSignKey a k) = (APublicVerifyKey a (toPublic k), ak) encodePrivKey :: CryptoPrivateKey pk => pk -> ByteString encodePrivKey = toPrivKey $ encodeASNObj . privateToX509 @@ -707,6 +726,9 @@ generateSignatureKeyPair a g = bimap (APublicVerifyKey a) (APrivateSignKey a) <$ generateAuthKeyPair :: (AlgorithmI a, AuthAlgorithm a) => SAlgorithm a -> TVar ChaChaDRG -> STM AAuthKeyPair generateAuthKeyPair a g = bimap (APublicAuthKey a) (APrivateAuthKey a) <$> generateKeyPair g +generatePrivateAuthKey :: (AlgorithmI a, AuthAlgorithm a) => SAlgorithm a -> TVar ChaChaDRG -> STM APrivateAuthKey +generatePrivateAuthKey a g = APrivateAuthKey a <$> generatePrivateKey g + generateDhKeyPair :: (AlgorithmI a, DhAlgorithm a) => SAlgorithm a -> TVar ChaChaDRG -> STM ADhKeyPair generateDhKeyPair a g = bimap (APublicDhKey a) (APrivateDhKey a) <$> generateKeyPair g @@ -714,23 +736,19 @@ generateKeyPair :: forall a. AlgorithmI a => TVar ChaChaDRG -> STM (KeyPair a) generateKeyPair g = stateTVar g (`withDRG` generateKeyPair_) generateKeyPair_ :: forall a. AlgorithmI a => MonadPseudoRandom ChaChaDRG (KeyPair a) -generateKeyPair_ = case sAlgorithm @a of - SEd25519 -> - Ed25519.generateSecretKey >>= \pk -> - let k = Ed25519.toPublic pk - in pure (PublicKeyEd25519 k, PrivateKeyEd25519 pk k) - SEd448 -> - Ed448.generateSecretKey >>= \pk -> - let k = Ed448.toPublic pk - in pure (PublicKeyEd448 k, PrivateKeyEd448 pk k) - SX25519 -> - X25519.generateSecretKey >>= \pk -> - let k = X25519.toPublic pk - in pure (PublicKeyX25519 k, PrivateKeyX25519 pk k) - SX448 -> - X448.generateSecretKey >>= \pk -> - let k = X448.toPublic pk - in pure (PublicKeyX448 k, PrivateKeyX448 pk k) +generateKeyPair_ = do + pk <- generatePrivateKey_ + pure (toPublic pk, pk) + +generatePrivateKey :: forall a. AlgorithmI a => TVar ChaChaDRG -> STM (PrivateKey a) +generatePrivateKey g = stateTVar g (`withDRG` generatePrivateKey_) + +generatePrivateKey_ :: forall a. AlgorithmI a => MonadPseudoRandom ChaChaDRG (PrivateKey a) +generatePrivateKey_ = case sAlgorithm @a of + SEd25519 -> PrivateKeyEd25519 <$> Ed25519.generateSecretKey + SEd448 -> PrivateKeyEd448 <$> Ed448.generateSecretKey + SX25519 -> PrivateKeyX25519 <$> X25519.generateSecretKey + SX448 -> PrivateKeyX448 <$> X448.generateSecretKey instance ToField APrivateSignKey where toField = toField . Binary . encodePrivKey @@ -854,8 +872,8 @@ instance SignatureSize APublicVerifyKey where instance SignatureAlgorithm a => SignatureSize (PrivateKey a) where signatureSize = \case - PrivateKeyEd25519 _ _ -> Ed25519.signatureSize - PrivateKeyEd448 _ _ -> Ed448.signatureSize + PrivateKeyEd25519 _ -> Ed25519.signatureSize + PrivateKeyEd448 _ -> Ed448.signatureSize {-# INLINE signatureSize #-} instance SignatureAlgorithm a => SignatureSize (PublicKey a) where @@ -1155,8 +1173,8 @@ cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError -- -- Used by SMP clients to sign SMP commands and by SMP agents to sign messages. sign' :: SignatureAlgorithm a => PrivateKey a -> ByteString -> Signature a -sign' (PrivateKeyEd25519 pk k) msg = SignatureEd25519 $ Ed25519.sign pk k msg -sign' (PrivateKeyEd448 pk k) msg = SignatureEd448 $ Ed448.sign pk k msg +sign' (PrivateKeyEd25519 pk) msg = SignatureEd25519 $ Ed25519.sign pk (Ed25519.toPublic pk) msg +sign' (PrivateKeyEd448 pk) msg = SignatureEd448 $ Ed448.sign pk (Ed448.toPublic pk) msg {-# INLINE sign' #-} sign :: APrivateSignKey -> ByteString -> ASignature @@ -1260,8 +1278,8 @@ verify (APublicVerifyKey a k) (ASignature a' sig) msg = case testEquality a a' o _ -> False dh' :: DhAlgorithm a => PublicKey a -> PrivateKey a -> DhSecret a -dh' (PublicKeyX25519 k) (PrivateKeyX25519 pk _) = DhSecretX25519 $ X25519.dh k pk -dh' (PublicKeyX448 k) (PrivateKeyX448 pk _) = DhSecretX448 $ X448.dh k pk +dh' (PublicKeyX25519 k) (PrivateKeyX25519 pk) = DhSecretX25519 $ X25519.dh k pk +dh' (PublicKeyX448 k) (PrivateKeyX448 pk) = DhSecretX448 $ X448.dh k pk {-# INLINE dh' #-} -- | NaCl @crypto_box@ encrypt with padding with a shared DH secret and 192-bit nonce. @@ -1465,10 +1483,10 @@ publicToX509 = \case privateToX509 :: PrivateKey a -> X.PrivKey privateToX509 = \case - PrivateKeyEd25519 k _ -> X.PrivKeyEd25519 k - PrivateKeyEd448 k _ -> X.PrivKeyEd448 k - PrivateKeyX25519 k _ -> X.PrivKeyX25519 k - PrivateKeyX448 k _ -> X.PrivKeyX448 k + PrivateKeyEd25519 k -> X.PrivKeyEd25519 k + PrivateKeyEd448 k -> X.PrivKeyEd448 k + PrivateKeyX25519 k -> X.PrivKeyX25519 k + PrivateKeyX448 k -> X.PrivKeyX448 k encodeASNObj :: ASN1Object a => a -> ByteString encodeASNObj k = toStrict . encodeASN1 DER $ toASN1 k [] @@ -1495,10 +1513,10 @@ x509ToPublic' k = x509ToPublic (k, []) >>= pubKey x509ToPrivate :: (X.PrivKey, [ASN1]) -> Either String APrivateKey x509ToPrivate = \case - (X.PrivKeyEd25519 k, []) -> Right . APrivateKey SEd25519 . PrivateKeyEd25519 k $ Ed25519.toPublic k - (X.PrivKeyEd448 k, []) -> Right . APrivateKey SEd448 . PrivateKeyEd448 k $ Ed448.toPublic k - (X.PrivKeyX25519 k, []) -> Right . APrivateKey SX25519 . PrivateKeyX25519 k $ X25519.toPublic k - (X.PrivKeyX448 k, []) -> Right . APrivateKey SX448 . PrivateKeyX448 k $ X448.toPublic k + (X.PrivKeyEd25519 k, []) -> Right $ APrivateKey SEd25519 $ PrivateKeyEd25519 k + (X.PrivKeyEd448 k, []) -> Right $ APrivateKey SEd448 $ PrivateKeyEd448 k + (X.PrivKeyX25519 k, []) -> Right $ APrivateKey SX25519 $ PrivateKeyX25519 k + (X.PrivKeyX448 k, []) -> Right $ APrivateKey SX448 $ PrivateKeyX448 k r -> keyError r x509ToPrivate' :: CryptoPrivateKey k => X.PrivKey -> Either String k diff --git a/src/Simplex/Messaging/Crypto/ShortLink.hs b/src/Simplex/Messaging/Crypto/ShortLink.hs index 34f292c8c..f7c65c1e6 100644 --- a/src/Simplex/Messaging/Crypto/ShortLink.hs +++ b/src/Simplex/Messaging/Crypto/ShortLink.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} @@ -48,20 +49,20 @@ contactShortLinkKdf (LinkKey k) = invShortLinkKdf :: LinkKey -> C.SbKey invShortLinkKdf (LinkKey k) = C.unsafeSbKey $ C.hkdf "" k "SimpleXInvLink" 32 -encodeSignLinkData :: forall c. ConnectionModeI c => C.KeyPairEd25519 -> VersionRangeSMPA -> ConnectionRequestUri c -> UserLinkData -> (LinkKey, (ByteString, ByteString)) +encodeSignLinkData :: ConnectionModeI c => C.KeyPairEd25519 -> VersionRangeSMPA -> ConnectionRequestUri c -> UserConnLinkData c -> (LinkKey, (ByteString, ByteString)) encodeSignLinkData (rootKey, pk) agentVRange connReq userData = let fd = smpEncode FixedLinkData {agentVRange, rootKey, connReq} - md = smpEncode $ connLinkData @c agentVRange userData + md = smpEncode $ connLinkData agentVRange userData in (LinkKey (C.sha3_256 fd), (encodeSign pk fd, encodeSign pk md)) -encodeSignUserData :: forall c. ConnectionModeI c => SConnectionMode c -> C.PrivateKeyEd25519 -> VersionRangeSMPA -> UserLinkData -> ByteString -encodeSignUserData _ pk agentVRange userData = - encodeSign pk $ smpEncode $ connLinkData @c agentVRange userData +encodeSignUserData :: ConnectionModeI c => SConnectionMode c -> C.PrivateKeyEd25519 -> VersionRangeSMPA -> UserConnLinkData c -> ByteString +encodeSignUserData _ pk agentVRange userLinkData = + encodeSign pk $ smpEncode $ connLinkData agentVRange userLinkData -connLinkData :: forall c. ConnectionModeI c => VersionRangeSMPA -> UserLinkData -> ConnLinkData c -connLinkData agentVRange userData = case sConnectionMode @c of - SCMInvitation -> InvitationLinkData agentVRange userData - SCMContact -> ContactLinkData {agentVRange, direct = True, owners = [], relays = [], userData} +connLinkData :: VersionRangeSMPA -> UserConnLinkData c -> ConnLinkData c +connLinkData vr = \case + UserInvLinkData d -> InvitationLinkData vr d + UserContactLinkData d -> ContactLinkData vr d encodeSign :: C.PrivateKeyEd25519 -> ByteString -> ByteString encodeSign pk s = smpEncode (C.sign' pk s) <> s diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index d254aaaa6..a97d86c33 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -10,6 +10,9 @@ module Simplex.Messaging.Encoding.String strToJSON, strToJEncoding, strParseJSON, + textToJSON, + textToEncoding, + textParseJSON, base64urlP, strEncodeList, strListP, @@ -225,9 +228,22 @@ _strP = A.space *> strP strToJSON :: StrEncoding a => a -> J.Value strToJSON = J.String . decodeLatin1 . strEncode +{-# INLINE strToJSON #-} strToJEncoding :: StrEncoding a => a -> J.Encoding strToJEncoding = JE.text . decodeLatin1 . strEncode +{-# INLINE strToJEncoding #-} strParseJSON :: StrEncoding a => String -> J.Value -> JT.Parser a strParseJSON name = J.withText name $ either fail pure . parseAll strP . encodeUtf8 + +textToJSON :: TextEncoding a => a -> J.Value +textToJSON = J.String . textEncode +{-# INLINE textToJSON #-} + +textToEncoding :: TextEncoding a => a -> J.Encoding +textToEncoding = JE.text . textEncode +{-# INLINE textToEncoding #-} + +textParseJSON :: TextEncoding a => String -> J.Value -> JT.Parser a +textParseJSON name = J.withText name $ maybe (fail name) pure . textDecode diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index fe2574eab..7e8acac81 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -67,9 +67,9 @@ import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Server.Control (CPClientRole (..)) import Simplex.Messaging.Server.Env.STM (StartOptions (..)) -import Simplex.Messaging.Server.QueueStore (getSystemDate) import Simplex.Messaging.Server.Stats (PeriodStats (..), PeriodStatCounts (..), periodStatCounts, periodStatDataCounts, updatePeriodStats) import Simplex.Messaging.Session +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import Simplex.Messaging.Transport (ASrvTransport, ATransport (..), THandle (..), THandleAuth (..), THandleParams (..), TProxy, Transport (..), TransportPeer (..), defaultSupportedParams) import Simplex.Messaging.Transport.Buffer (trimCR) diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 63a81ac0f..cb22af000 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -25,7 +25,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Protocol (NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (whenM, ($>>=)) @@ -61,10 +61,10 @@ data NtfTknData = NtfTknData tknDhSecret :: C.DhSecretX25519, tknRegCode :: NtfRegCode, tknCronInterval :: TVar Word16, - tknUpdatedAt :: TVar (Maybe RoundedSystemTime) + tknUpdatedAt :: TVar (Maybe SystemDate) } -mkNtfTknData :: NtfTokenId -> NewNtfEntity 'Token -> C.KeyPairX25519 -> C.DhSecretX25519 -> NtfRegCode -> RoundedSystemTime -> IO NtfTknData +mkNtfTknData :: NtfTokenId -> NewNtfEntity 'Token -> C.KeyPairX25519 -> C.DhSecretX25519 -> NtfRegCode -> SystemDate -> IO NtfTknData mkNtfTknData ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode ts = do tknStatus <- newTVarIO NTRegistered tknCronInterval <- newTVarIO 0 diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs index 226a02dc6..6a53ff4a2 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -1,11 +1,11 @@ {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Notifications.Server.Store.Migrations where import Data.List (sortOn) import Data.Text (Text) -import qualified Data.Text as T import Simplex.Messaging.Agent.Store.Shared import Text.RawString.QQ (r) @@ -23,8 +23,7 @@ ntfServerMigrations = sortOn name $ map migration ntfServerSchemaMigrations m20250417_initial :: Text m20250417_initial = - T.pack - [r| + [r| CREATE TABLE tokens( token_id BYTEA NOT NULL, push_provider TEXT NOT NULL, @@ -83,8 +82,7 @@ CREATE UNIQUE INDEX idx_last_notifications_token_subscription ON last_notificati m20250517_service_cert :: Text m20250517_service_cert = - T.pack - [r| + [r| ALTER TABLE smp_servers ADD COLUMN ntf_service_id BYTEA; ALTER TABLE subscriptions ADD COLUMN ntf_service_assoc BOOLEAN NOT NULL DEFAULT FALSE; @@ -95,8 +93,7 @@ CREATE INDEX idx_subscriptions_smp_server_id_ntf_service_status ON subscriptions down_m20250517_service_cert :: Text down_m20250517_service_cert = - T.pack - [r| + [r| DROP INDEX idx_subscriptions_smp_server_id_ntf_service_status; CREATE INDEX idx_subscriptions_smp_server_id_status ON subscriptions(smp_server_id, status); diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 78891796f..80d946c8b 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -65,10 +65,10 @@ import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Server.StoreLog import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) import Simplex.Messaging.Server.StoreLog (openWriteStoreLog) +import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (anyM, firstRow, maybeFirstRow, toChunks, tshow) import System.Exit (exitFailure) @@ -87,7 +87,7 @@ data NtfPostgresStore = NtfPostgresStore deletedTTL :: Int64 } -mkNtfTknRec :: NtfTokenId -> NewNtfEntity 'Token -> C.PrivateKeyX25519 -> C.DhSecretX25519 -> NtfRegCode -> RoundedSystemTime -> NtfTknRec +mkNtfTknRec :: NtfTokenId -> NewNtfEntity 'Token -> C.PrivateKeyX25519 -> C.DhSecretX25519 -> NtfRegCode -> SystemDate -> NtfTknRec mkNtfTknRec ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhPrivKey tknDhSecret tknRegCode ts = NtfTknRec {ntfTknId, token, tknStatus = NTRegistered, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval = 0, tknUpdatedAt = Just ts} @@ -170,7 +170,7 @@ updateTokenDate st db NtfTknRec {ntfTknId, tknUpdatedAt} = do void $ DB.execute db "UPDATE tokens SET updated_at = ? WHERE token_id = ?" (ts, ntfTknId) withLog "updateTokenDate" st $ \sl -> logUpdateTokenTime sl ntfTknId ts -type NtfTknRow = (NtfTokenId, PushProvider, Binary ByteString, NtfTknStatus, NtfPublicAuthKey, C.PrivateKeyX25519, C.DhSecretX25519, Binary ByteString, Word16, Maybe RoundedSystemTime) +type NtfTknRow = (NtfTokenId, PushProvider, Binary ByteString, NtfTknStatus, NtfPublicAuthKey, C.PrivateKeyX25519, C.DhSecretX25519, Binary ByteString, Word16, Maybe SystemDate) ntfTknQuery :: Query ntfTknQuery = diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Types.hs b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs index 39e303340..abac8d14e 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Types.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs @@ -16,7 +16,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfSubStatus, NtfSubscriptionId, NtfTokenId, NtfTknStatus, SMPQueueNtf) import Simplex.Messaging.Notifications.Server.Store (NtfSubData (..), NtfTknData (..)) import Simplex.Messaging.Protocol (NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) +import Simplex.Messaging.SystemTime data NtfTknRec = NtfTknRec { ntfTknId :: NtfTokenId, @@ -27,7 +27,7 @@ data NtfTknRec = NtfTknRec tknDhSecret :: C.DhSecretX25519, tknRegCode :: NtfRegCode, tknCronInterval :: Word16, - tknUpdatedAt :: Maybe RoundedSystemTime + tknUpdatedAt :: Maybe SystemDate } deriving (Show) diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs index e71ebaf57..7c71ddb08 100644 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs @@ -39,8 +39,8 @@ import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Store import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Protocol (EntityId (..), SMPServer, ServiceId) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.SystemTime import System.IO data NtfStoreLogRecord @@ -49,7 +49,7 @@ data NtfStoreLogRecord | UpdateToken NtfTokenId DeviceToken NtfRegCode | TokenCron NtfTokenId Word16 | DeleteToken NtfTokenId - | UpdateTokenTime NtfTokenId RoundedSystemTime + | UpdateTokenTime NtfTokenId SystemDate | CreateSubscription NtfSubRec | SubscriptionStatus NtfSubscriptionId NtfSubStatus NtfAssociatedService | DeleteSubscription NtfSubscriptionId @@ -103,7 +103,7 @@ logTokenCron s tknId cronInt = logNtfStoreRecord s $ TokenCron tknId cronInt logDeleteToken :: StoreLog 'WriteMode -> NtfTokenId -> IO () logDeleteToken s tknId = logNtfStoreRecord s $ DeleteToken tknId -logUpdateTokenTime :: StoreLog 'WriteMode -> NtfTokenId -> RoundedSystemTime -> IO () +logUpdateTokenTime :: StoreLog 'WriteMode -> NtfTokenId -> SystemDate -> IO () logUpdateTokenTime s tknId t = logNtfStoreRecord s $ UpdateTokenTime tknId t logCreateSubscription :: StoreLog 'WriteMode -> NtfSubRec -> IO () diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 40314ad2a..13ac3f182 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -217,6 +217,7 @@ import Control.Applicative (optional, (<|>)) import Control.Exception (Exception, SomeException, displayException, fromException) import Control.Monad.Except import Data.Aeson (FromJSON (..), ToJSON (..)) +import qualified Data.Aeson as J import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A @@ -224,6 +225,7 @@ import Data.Bifunctor (bimap, first) import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy as LB import Data.Char (isPrint, isSpace) import Data.Constraint (Dict (..)) import Data.Functor (($>)) @@ -249,6 +251,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers +import Simplex.Messaging.Protocol.Types import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.ServiceScheme import Simplex.Messaging.Transport @@ -1588,7 +1591,8 @@ toNetworkError e = maybe (NEConnectError err) fromTLSError (fromException e) _ -> NETLSError err data BlockingInfo = BlockingInfo - { reason :: BlockingReason + { reason :: BlockingReason, + notice :: Maybe ClientNotice } deriving (Eq, Show) @@ -1596,10 +1600,12 @@ data BlockingReason = BRSpam | BRContent deriving (Eq, Show) instance StrEncoding BlockingInfo where - strEncode BlockingInfo {reason} = "reason=" <> strEncode reason + strEncode BlockingInfo {reason, notice} = + "reason=" <> strEncode reason <> maybe "" ((",notice=" <>) . LB.toStrict . J.encode) notice strP = do reason <- "reason=" *> strP - pure BlockingInfo {reason} + notice <- optional $ ",notice=" *> (J.eitherDecodeStrict <$?> A.takeByteString) + pure BlockingInfo {reason, notice} instance Encoding BlockingInfo where smpEncode = strEncode @@ -1843,9 +1849,13 @@ instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where | otherwise -> e END_ INFO info -> e (INFO_, ' ', info) OK -> e OK_ - ERR err -> case err of - BLOCKED _ | v < blockedEntitySMPVersion -> e (ERR_, ' ', AUTH) - _ -> e (ERR_, ' ', err) + ERR err -> e (ERR_, ' ', err') + where + err' = case err of + BLOCKED info + | v < blockedEntitySMPVersion -> AUTH + | v < clientNoticesSMPVersion -> BLOCKED info {notice = Nothing} + _ -> err PONG -> e PONG_ where e :: Encoding a => a -> ByteString diff --git a/src/Simplex/Messaging/Protocol/Types.hs b/src/Simplex/Messaging/Protocol/Types.hs new file mode 100644 index 000000000..0cfd660e3 --- /dev/null +++ b/src/Simplex/Messaging/Protocol/Types.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} + +module Simplex.Messaging.Protocol.Types where + +import qualified Data.Aeson.TH as J +import Data.Int (Int64) +import Simplex.Messaging.Parsers + +data ClientNotice = ClientNotice + { ttl :: Maybe Int64 -- seconds, Nothing - indefinite + } + deriving (Eq, Show) + +$(J.deriveJSON defaultJSON ''ClientNotice) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 7d6e00ab0..ec75a07d4 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -115,6 +115,7 @@ import Simplex.Messaging.Server.QueueStore.QueueInfo import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog (foldLogLines) +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport @@ -992,14 +993,20 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt else do r <- liftIO $ runExceptT $ do (q, QueueRec {status}) <- ExceptT $ getSenderQueue st qId - when (status == EntityActive) $ ExceptT $ blockQueue (queueStore st) q info - pure status + let rId = recipientId q + when (status /= EntityBlocked info) $ do + ExceptT $ blockQueue (queueStore st) q info + liftIO $ + getSubscribedClient rId (queueSubscribers $ subscribers srv) + $>>= readTVarIO + >>= mapM_ (\c -> atomically (writeTBQueue (sndQ c) ([(NoCorrId, rId, ERR $ BLOCKED info)] , []))) + pure (status, EntityBlocked info) case r of Left e -> liftIO $ hPutStrLn h $ "error: " <> show e - Right EntityActive -> do + Right (EntityActive, status') -> do incStat $ qBlocked stats - liftIO $ hPutStrLn h "ok, queue blocked" - Right status -> liftIO $ hPutStrLn h $ "ok, already inactive: " <> show status + liftIO $ hPutStrLn h $ "ok, queue blocked: " <> show status' + Right (_, status') -> liftIO $ hPutStrLn h $ "ok, already inactive: " <> show status' CPUnblock qId -> withUserRole $ unliftIO u $ do st <- asks msgStore r <- liftIO $ runExceptT $ do @@ -1679,7 +1686,7 @@ client -- This is tracked as "subscription" in the client to prevent these -- clients from being able to subscribe. pure s - getMessage_ :: Sub -> Maybe (MsgId, RoundedSystemTime) -> M s (Transmission BrokerMsg) + getMessage_ :: Sub -> Maybe (MsgId, SystemSeconds) -> M s (Transmission BrokerMsg) getMessage_ s delivered_ = do stats <- asks serverStats fmap (either err id) $ liftIO $ runExceptT $ @@ -1805,13 +1812,13 @@ client pure (corrId, entId, maybe OK (MSG . encryptMsg qr) msg_) _ -> pure $ err NO_MSG where - getDelivered :: Sub -> STM (Maybe (ServerSub, RoundedSystemTime)) + getDelivered :: Sub -> STM (Maybe (ServerSub, SystemSeconds)) getDelivered Sub {delivered, subThread} = do readTVar delivered $>>= \(msgId', ts) -> if msgId == msgId' || B.null msgId then writeTVar delivered Nothing $> Just (subThread, ts) else pure Nothing - updateStats :: ServerStats -> Bool -> RoundedSystemTime -> Message -> IO () + updateStats :: ServerStats -> Bool -> SystemSeconds -> Message -> IO () updateStats stats isGet deliveryTime = \case MessageQuota {} -> pure () Message {msgFlags} -> do @@ -2030,7 +2037,7 @@ client msgId' = messageId msg msgTs' = messageTs msg - setDelivered :: Sub -> Message -> RoundedSystemTime -> STM () + setDelivered :: Sub -> Message -> SystemSeconds -> STM () setDelivered Sub {delivered} msg !ts = do let !msgId = messageId msg writeTVar delivered $ Just (msgId, ts) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index b72922f04..24cd6dfcc 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -123,6 +123,7 @@ import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Server.StoreLog.ReadWrite +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ASrvTransport, SMPVersion, THandleParams, TransportPeer (..), VersionRangeSMP) @@ -464,7 +465,7 @@ data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) data Sub = Sub { subThread :: ServerSub, -- Nothing value indicates that sub - delivered :: TVar (Maybe (MsgId, RoundedSystemTime)) + delivered :: TVar (Maybe (MsgId, SystemSeconds)) } newServer :: IO (Server s) @@ -567,6 +568,10 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp forM_ storePaths_ $ \StorePaths {storeLogFile = f} -> loadStoreLog (mkQueue ms True) f $ queueStore ms pure $ StoreMemory ms SSCMemoryJournal {storeLogFile, storeMsgsPath} -> do + logWarn $ + "Journal message store is deprecated and will be removed soon.\n" + <> "Please migrate to in-memory storage using `journal export` command.\n" + <> "After that you can migrate to PostgreSQL using `database import` command." let qsCfg = MQStoreCfg cfg = mkJournalStoreConfig qsCfg storeMsgsPath msgQueueQuota maxJournalMsgCount maxJournalStateLines idleQueueInterval ms <- newMsgStore cfg diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 8e5fd55ee..64d18088d 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -186,6 +186,26 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = storeLogExists <- doesFileExist storeLogFilePath msgsFileExists <- doesFileExist storeMsgsFilePath case (cmd, tables) of + (SCImport, DTAll) + | not schemaExists && storeLogExists && msgsFileExists -> do + storeLogFile <- getRequiredStoreLogFile ini + confirmOrExit + ("WARNING: store log file " <> storeLogFile <> " and message log file " <> storeMsgsFilePath <> " will be imported to PostrgreSQL database: " <> B.unpack connstr <> ", schema: " <> B.unpack schema) + "Store logs not imported" + (sCnt, qCnt) <- importStoreLogToDatabase logPath storeLogFile dbOpts + putStrLn $ "Imported: " <> show sCnt <> " services, " <> show qCnt <> " queues" + putStrLn "Importing messages..." + mCnt <- importMessagesToDatabase storeMsgsFilePath dbOpts + putStrLn $ "Import completed: " <> show mCnt <> " messages" + putStrLn $ case readStoreType ini of + Right (ASType SQSPostgres SMSPostgres) -> "store_queues and store_messages set to `database`, start the server." + Right _ -> "set store_queues and store_messages to `database` in INI file" + Left e -> e <> ", configure storage correctly" + | otherwise -> do + when schemaExists $ putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr + unless storeLogExists $ putStrLn $ storeLogFilePath <> " file does not exist." + unless msgsFileExists $ putStrLn $ storeMsgsFilePath <> " file does not exist." + exitFailure (SCImport, DTQueues) | schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema | schemaExists -> do @@ -224,8 +244,27 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = putStrLn $ "Import completed: " <> show mCnt <> " messages" putStrLn $ case readStoreType ini of Right (ASType SQSPostgres SMSPostgres) -> "store_queues and store_messages set to `database`, start the server." - Right _ -> "set store_queues and store_messages set to `database` in INI file" + Right _ -> "set store_queues and store_messages to `database` in INI file" Left e -> e <> ", configure storage correctly" + (SCExport, DTAll) + | schemaExists && not storeLogExists && not msgsFileExists -> do + confirmOrExit + ("WARNING: PostrgreSQL schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath <> " and to message log file " <> storeMsgsFilePath) + "Database store not exported" + (sCnt, qCnt) <- exportDatabaseToStoreLog logPath dbOpts storeLogFilePath + putStrLn $ "Exported: " <> show sCnt <> " services, " <> show qCnt <> " queues" + putStrLn "Exporting messages..." + let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} + ms <- newMsgStore $ PostgresMsgStoreCfg storeCfg defaultMsgQueueQuota + withFile storeMsgsFilePath WriteMode (try . exportDbMessages True ms) >>= \case + Right mCnt -> putStrLn $ "Export completed: " <> show mCnt <> " messages" + Left (e :: SomeException) -> putStrLn $ "Error exporting messages: " <> show e + closeMsgStore ms + | otherwise -> do + unless schemaExists $ putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr + when storeLogExists $ putStrLn $ storeLogFilePath <> " file already exists." + when msgsFileExists $ putStrLn $ storeMsgsFilePath <> " file already exists." + exitFailure (SCExport, DTQueues) | schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema | not schemaExists -> do @@ -262,6 +301,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = putStrLn $ "Export completed: " <> show mCnt <> " messages" putStrLn "Export queues with `smp-server database export queues`" Left (e :: SomeException) -> putStrLn $ "Error exporting messages: " <> show e + closeMsgStore ms (SCDelete, _) | not schemaExists -> do putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr @@ -760,16 +800,18 @@ data CliCommand data StoreCmd = SCImport | SCExport | SCDelete -data DatabaseTable = DTQueues | DTMessages +data DatabaseTable = DTQueues | DTMessages | DTAll instance StrEncoding DatabaseTable where strEncode = \case DTQueues -> "queues" DTMessages -> "messages" + DTAll -> "all" strP = A.takeTill (== ' ') >>= \case "queues" -> pure DTQueues "messages" -> pure DTMessages + "all" -> pure DTAll _ -> fail "DatabaseTable" cliCommandP :: FilePath -> FilePath -> FilePath -> Parser CliCommand @@ -940,6 +982,7 @@ cliCommandP cfgPath logPath iniFile = ( long "table" <> help "Database tables: queues/messages" <> metavar "TABLE" + <> value DTAll ) parseBasicAuth :: ReadM ServerPassword parseBasicAuth = eitherReader $ fmap ServerPassword . strDecode . B.pack diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index e81c153de..5038c8826 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -84,6 +84,7 @@ import Simplex.Messaging.Server.QueueStore.Postgres #endif import Simplex.Messaging.Server.QueueStore.STM import Simplex.Messaging.Server.QueueStore.Types +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (ifM, tshow, whenM, ($>>=), (<$$>)) diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index 9395f5bac..e05719cf6 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -1,8 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -14,14 +12,13 @@ module Simplex.Messaging.Server.QueueStore where import Control.Applicative (optional, (<|>)) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) -import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) -import Data.Time.Clock.System (SystemTime (..), getSystemTime) import qualified Data.X509 as X import qualified Data.X509.Validation as XV import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol +import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport (SMPServiceRole) #if defined(dbServerPostgres) import Data.Text.Encoding (decodeLatin1, encodeUtf8) @@ -40,7 +37,7 @@ data QueueRec = QueueRec queueData :: Maybe (LinkId, QueueLinkData), notifier :: Maybe NtfCreds, status :: ServerEntityStatus, - updatedAt :: Maybe RoundedSystemTime, + updatedAt :: Maybe SystemDate, rcvServiceId :: Maybe ServiceId } deriving (Show) @@ -67,7 +64,7 @@ data ServiceRec = ServiceRec serviceRole :: SMPServiceRole, serviceCert :: X.CertificateChain, serviceCertHash :: XV.Fingerprint, -- SHA512 hash of long-term service client certificate. See comment for ClientHandshake. - serviceCreatedAt :: RoundedSystemTime + serviceCreatedAt :: SystemDate } deriving (Show) @@ -111,22 +108,3 @@ instance FromField ServerEntityStatus where fromField = fromTextField_ $ eitherT instance ToField ServerEntityStatus where toField = toField . decodeLatin1 . strEncode #endif - -newtype RoundedSystemTime = RoundedSystemTime Int64 - deriving (Eq, Ord, Show) -#if defined(dbServerPostgres) - deriving newtype (FromField, ToField) -#endif - -instance StrEncoding RoundedSystemTime where - strEncode (RoundedSystemTime t) = strEncode t - strP = RoundedSystemTime <$> strP - -getRoundedSystemTime :: Int64 -> IO RoundedSystemTime -getRoundedSystemTime prec = (\t -> RoundedSystemTime $ (systemSeconds t `div` prec) * prec) <$> getSystemTime - -getSystemDate :: IO RoundedSystemTime -getSystemDate = getRoundedSystemTime 86400 - -getSystemSeconds :: IO RoundedSystemTime -getSystemSeconds = RoundedSystemTime . systemSeconds <$> getSystemTime diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index 4a53dcdd4..e86bec07b 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -85,6 +85,7 @@ import Simplex.Messaging.Server.QueueStore.Postgres.Migrations (serverMigrations import Simplex.Messaging.Server.QueueStore.STM (STMService (..), readQueueRecIO) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPServiceRole (..)) @@ -429,7 +430,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where setStatusDB "unblockQueue" st sq EntityActive $ withLog "unblockQueue" st (`logUnblockQueue` recipientId sq) - updateQueueTime :: PostgresQueueStore q -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) + updateQueueTime :: PostgresQueueStore q -> q -> SystemDate -> IO (Either ErrorType QueueRec) updateQueueTime st sq t = withQueueRec sq "updateQueueTime" $ \q@QueueRec {updatedAt} -> if updatedAt == Just t @@ -641,7 +642,7 @@ type QueueRecRow = ( RecipientId, NonEmpty RcvPublicAuthKey, RcvDhSecret, SenderId, Maybe SndPublicAuthKey, Maybe QueueMode, Maybe NotifierId, Maybe NtfPublicAuthKey, Maybe RcvNtfDhSecret, Maybe ServiceId, - ServerEntityStatus, Maybe RoundedSystemTime, Maybe LinkId, Maybe ServiceId + ServerEntityStatus, Maybe SystemDate, Maybe LinkId, Maybe ServiceId ) queueRecToRow :: (RecipientId, QueueRec) -> QueueRecRow :. (Maybe EncDataBytes, Maybe EncDataBytes) @@ -709,11 +710,11 @@ mkNotifier (Just notifierId, Just notifierKey, Just rcvNtfDhSecret) ntfServiceId Just NtfCreds {notifierId, notifierKey, rcvNtfDhSecret, ntfServiceId} mkNotifier _ _ = Nothing -serviceRecToRow :: ServiceRec -> (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, RoundedSystemTime) +serviceRecToRow :: ServiceRec -> (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, SystemDate) serviceRecToRow ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash = XV.Fingerprint fp, serviceCreatedAt} = (serviceId, serviceRole, serviceCert, Binary fp, serviceCreatedAt) -rowToServiceRec :: (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, RoundedSystemTime) -> ServiceRec +rowToServiceRec :: (ServiceId, SMPServiceRole, X.CertificateChain, Binary ByteString, SystemDate) -> ServiceRec rowToServiceRec (serviceId, serviceRole, serviceCert, Binary fp, serviceCreatedAt) = ServiceRec {serviceId, serviceRole, serviceCert, serviceCertHash = XV.Fingerprint fp, serviceCreatedAt} @@ -792,4 +793,8 @@ instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodeP instance ToField EncDataBytes where toField (EncDataBytes s) = toField (Binary s) deriving newtype instance FromField EncDataBytes + +deriving newtype instance ToField (RoundedSystemTime t) + +deriving newtype instance FromField (RoundedSystemTime t) #endif diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs index be14202c6..7ff8b9862 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs @@ -1,11 +1,11 @@ {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} module Simplex.Messaging.Server.QueueStore.Postgres.Migrations where import Data.List (sortOn) import Data.Text (Text) -import qualified Data.Text as T import Simplex.Messaging.Agent.Store.Shared import Text.RawString.QQ (r) @@ -26,8 +26,7 @@ serverMigrations = sortOn name $ map migration serverSchemaMigrations m20250207_initial :: Text m20250207_initial = - T.pack - [r| + [r| CREATE TABLE msg_queues( recipient_id BYTEA NOT NULL, recipient_key BYTEA NOT NULL, @@ -51,24 +50,21 @@ CREATE INDEX idx_msg_queues_deleted_at ON msg_queues (deleted_at); m20250319_updated_index :: Text m20250319_updated_index = - T.pack - [r| + [r| DROP INDEX idx_msg_queues_deleted_at; CREATE INDEX idx_msg_queues_updated_at ON msg_queues (deleted_at, updated_at); |] down_m20250319_updated_index :: Text down_m20250319_updated_index = - T.pack - [r| + [r| DROP INDEX idx_msg_queues_updated_at; CREATE INDEX idx_msg_queues_deleted_at ON msg_queues (deleted_at); |] m20250320_short_links :: Text m20250320_short_links = - T.pack - [r| + [r| ALTER TABLE msg_queues ADD COLUMN queue_mode TEXT, ADD COLUMN link_id BYTEA, @@ -88,8 +84,7 @@ CREATE UNIQUE INDEX idx_msg_queues_link_id ON msg_queues(link_id); down_m20250320_short_links :: Text down_m20250320_short_links = - T.pack - [r| + [r| ALTER TABLE msg_queues ADD COLUMN snd_secure BOOLEAN NOT NULL DEFAULT FALSE; UPDATE msg_queues SET snd_secure = TRUE WHERE queue_mode = 'M'; @@ -124,8 +119,7 @@ ALTER TABLE msg_queues RENAME COLUMN recipient_keys TO recipient_key; m20250514_service_certs :: Text m20250514_service_certs = - T.pack - [r| + [r| CREATE TABLE services( service_id BYTEA NOT NULL, service_role TEXT NOT NULL, @@ -147,8 +141,7 @@ CREATE INDEX idx_msg_queues_ntf_service_id ON msg_queues(ntf_service_id, deleted down_m20250514_service_certs :: Text down_m20250514_service_certs = - T.pack - [r| + [r| DROP INDEX idx_msg_queues_rcv_service_id; DROP INDEX idx_msg_queues_ntf_service_id; @@ -163,8 +156,7 @@ DROP TABLE services; m20250903_store_messages :: Text m20250903_store_messages = - T.pack - [r| + [r| CREATE TABLE messages( message_id BIGINT NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, recipient_id BYTEA NOT NULL REFERENCES msg_queues ON DELETE CASCADE ON UPDATE RESTRICT, @@ -434,8 +426,7 @@ $$; down_m20250903_store_messages :: Text down_m20250903_store_messages = - T.pack - [r| + [r| DROP FUNCTION write_message; DROP FUNCTION try_del_msg; DROP FUNCTION try_del_peek_msg; diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 515a0ee77..ad98698db 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -41,6 +41,7 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPServiceRole (..)) @@ -251,7 +252,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where setStatus (queueRec sq) EntityActive $>> withLog "unblockQueue" st (`logUnblockQueue` recipientId sq) - updateQueueTime :: STMQueueStore q -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) + updateQueueTime :: STMQueueStore q -> q -> SystemDate -> IO (Either ErrorType QueueRec) updateQueueTime st sq t = withQueueRec qr update $>>= log' where qr = queueRec sq diff --git a/src/Simplex/Messaging/Server/QueueStore/Types.hs b/src/Simplex/Messaging/Server/QueueStore/Types.hs index ee155cf91..8de015421 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Types.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Types.hs @@ -14,6 +14,7 @@ import Data.List.NonEmpty (NonEmpty) import Data.Text (Text) import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore +import Simplex.Messaging.SystemTime import Simplex.Messaging.TMap (TMap) class StoreQueueClass q where @@ -41,7 +42,7 @@ class StoreQueueClass q => QueueStoreClass q s where suspendQueue :: s -> q -> IO (Either ErrorType ()) blockQueue :: s -> q -> BlockingInfo -> IO (Either ErrorType ()) unblockQueue :: s -> q -> IO (Either ErrorType ()) - updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) + updateQueueTime :: s -> q -> SystemDate -> IO (Either ErrorType QueueRec) deleteStoreQueue :: s -> q -> IO (Either ErrorType QueueRec) getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId) setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index 4af195295..e60f87815 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -27,7 +27,7 @@ import Data.Time.Clock (UTCTime (..)) import GHC.IORef (atomicSwapIORef) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (EntityId (..)) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime (..)) +import Simplex.Messaging.SystemTime import Simplex.Messaging.Util (atomicModifyIORef'_, tshow, unlessM) data ServerStats = ServerStats @@ -976,7 +976,7 @@ data TimeBuckets = TimeBuckets emptyTimeBuckets :: TimeBuckets emptyTimeBuckets = TimeBuckets 0 0 IM.empty -updateTimeBuckets :: RoundedSystemTime -> RoundedSystemTime -> TimeBuckets -> TimeBuckets +updateTimeBuckets :: SystemSeconds -> SystemSeconds -> TimeBuckets -> TimeBuckets updateTimeBuckets (RoundedSystemTime deliveryTime) (RoundedSystemTime currTime) diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs index 6ea015066..4ceb3cddd 100644 --- a/src/Simplex/Messaging/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -55,9 +55,9 @@ import GHC.IO (catchAny) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol --- import Simplex.Messaging.Server.MsgStore.Types import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.StoreLog.Types +import Simplex.Messaging.SystemTime import Simplex.Messaging.Util (ifM, tshow, unlessM, whenM) import System.Directory (doesFileExist, listDirectory, removeFile, renameFile) import System.IO @@ -75,7 +75,7 @@ data StoreLogRecord | UnblockQueue QueueId | DeleteQueue QueueId | DeleteNotifier QueueId - | UpdateTime QueueId RoundedSystemTime + | UpdateTime QueueId SystemDate | NewService ServiceRec | QueueService RecipientId ASubscriberParty (Maybe ServiceId) deriving (Show) @@ -280,7 +280,7 @@ logDeleteQueue s = writeStoreLogRecord s . DeleteQueue logDeleteNotifier :: StoreLog 'WriteMode -> QueueId -> IO () logDeleteNotifier s = writeStoreLogRecord s . DeleteNotifier -logUpdateQueueTime :: StoreLog 'WriteMode -> QueueId -> RoundedSystemTime -> IO () +logUpdateQueueTime :: StoreLog 'WriteMode -> QueueId -> SystemDate -> IO () logUpdateQueueTime s qId t = writeStoreLogRecord s $ UpdateTime qId t logNewService :: StoreLog 'WriteMode -> ServiceRec -> IO () diff --git a/src/Simplex/Messaging/SystemTime.hs b/src/Simplex/Messaging/SystemTime.hs new file mode 100644 index 000000000..7435a694d --- /dev/null +++ b/src/Simplex/Messaging/SystemTime.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Simplex.Messaging.SystemTime where + +import Data.Aeson (FromJSON, ToJSON) +import Data.Int (Int64) +import Data.Time.Clock (UTCTime) +import Data.Time.Clock.System (SystemTime (..), getSystemTime, systemToUTCTime) +import Data.Typeable (Proxy (..)) +import GHC.TypeLits (KnownNat, Nat, natVal) +import Simplex.Messaging.Agent.Store.DB (FromField (..), ToField (..)) +import Simplex.Messaging.Encoding.String + +newtype RoundedSystemTime (t :: Nat) = RoundedSystemTime {roundedSeconds :: Int64} + deriving (Eq, Ord, Show) + deriving newtype (FromJSON, ToJSON, FromField, ToField) + +type SystemDate = RoundedSystemTime 86400 + +type SystemSeconds = RoundedSystemTime 1 + +instance StrEncoding (RoundedSystemTime t) where + strEncode (RoundedSystemTime t) = strEncode t + strP = RoundedSystemTime <$> strP + +getRoundedSystemTime :: forall t. KnownNat t => IO (RoundedSystemTime t) +getRoundedSystemTime = (\t -> RoundedSystemTime $ (systemSeconds t `div` prec) * prec) <$> getSystemTime + where + prec = fromIntegral $ natVal $ Proxy @t + +getSystemDate :: IO SystemDate +getSystemDate = getRoundedSystemTime +{-# INLINE getSystemDate #-} + +getSystemSeconds :: IO SystemSeconds +getSystemSeconds = RoundedSystemTime . systemSeconds <$> getSystemTime +{-# INLINE getSystemSeconds #-} + +roundedToUTCTime :: RoundedSystemTime t -> UTCTime +roundedToUTCTime = systemToUTCTime . (`MkSystemTime` 0) . roundedSeconds +{-# INLINE roundedToUTCTime #-} diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 0b2eb3b75..e2e912875 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -55,6 +55,7 @@ module Simplex.Messaging.Transport shortLinksSMPVersion, serviceCertsSMPVersion, newNtfCredsSMPVersion, + clientNoticesSMPVersion, simplexMQVersion, smpBlockSize, TransportConfig (..), @@ -168,6 +169,7 @@ smpBlockSize = 16384 -- 15 - short links, with associated data passed in NEW of LSET command (3/30/2025) -- 16 - service certificates (5/31/2025) -- 17 - create notification credentials with NEW (7/12/2025) +-- 18 - support client notices (10/10/2025) data SMPVersion @@ -213,6 +215,9 @@ serviceCertsSMPVersion = VersionSMP 16 newNtfCredsSMPVersion :: VersionSMP newNtfCredsSMPVersion = VersionSMP 17 +clientNoticesSMPVersion :: VersionSMP +clientNoticesSMPVersion = VersionSMP 18 + minClientSMPRelayVersion :: VersionSMP minClientSMPRelayVersion = VersionSMP 6 @@ -220,13 +225,13 @@ minServerSMPRelayVersion :: VersionSMP minServerSMPRelayVersion = VersionSMP 6 currentClientSMPRelayVersion :: VersionSMP -currentClientSMPRelayVersion = VersionSMP 17 +currentClientSMPRelayVersion = VersionSMP 18 legacyServerSMPRelayVersion :: VersionSMP legacyServerSMPRelayVersion = VersionSMP 6 currentServerSMPRelayVersion :: VersionSMP -currentServerSMPRelayVersion = VersionSMP 17 +currentServerSMPRelayVersion = VersionSMP 18 -- Max SMP protocol version to be used in e2e encrypted -- connection between client and server, as defined by SMP proxy. @@ -234,7 +239,7 @@ currentServerSMPRelayVersion = VersionSMP 17 -- to prevent client version fingerprinting by the -- destination relays when clients upgrade at different times. proxiedSMPRelayVersion :: VersionSMP -proxiedSMPRelayVersion = VersionSMP 16 +proxiedSMPRelayVersion = VersionSMP 17 -- minimal supported protocol version is 6 -- TODO remove code that supports sending commands without batching diff --git a/src/Simplex/Messaging/Transport/Credentials.hs b/src/Simplex/Messaging/Transport/Credentials.hs index f610ab943..8e3efe795 100644 --- a/src/Simplex/Messaging/Transport/Credentials.hs +++ b/src/Simplex/Messaging/Transport/Credentials.hs @@ -35,8 +35,8 @@ tlsCredentials credentials = (C.KeyHash rootFP, (X509.CertificateChain certs, pr privateToTls :: C.APrivateSignKey -> TLS.PrivKey privateToTls (C.APrivateSignKey _ k) = case k of - C.PrivateKeyEd25519 secret _ -> TLS.PrivKeyEd25519 secret - C.PrivateKeyEd448 secret _ -> TLS.PrivKeyEd448 secret + C.PrivateKeyEd25519 pk -> TLS.PrivKeyEd25519 pk + C.PrivateKeyEd448 pk -> TLS.PrivKeyEd448 pk type Credentials = (C.ASignatureKeyPair, X509.SignedCertificate) diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 57fb11c21..e9f37b1ae 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -245,6 +245,7 @@ safeDecodeUtf8 :: ByteString -> Text safeDecodeUtf8 = decodeUtf8With onError where onError _ _ = Just '?' +{-# INLINE safeDecodeUtf8 #-} timeoutThrow :: MonadUnliftIO m => e -> Int -> ExceptT e m a -> ExceptT e m a timeoutThrow e ms action = ExceptT (sequence <$> (ms `timeout` runExceptT action)) >>= maybe (throwE e) pure diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index 52e88fdcc..63c493861 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -5,16 +5,16 @@ module AgentTests.EqInstances where import Data.Type.Equality -import Simplex.Messaging.Agent.Protocol (ConnLinkData (..), OwnerAuth (..), UserLinkData (..)) +import Simplex.Messaging.Agent.Protocol (ConnLinkData (..), OwnerAuth (..), UserContactData (..), UserLinkData (..)) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client (ProxiedRelay (..)) -instance Eq SomeConn where +instance (Eq rq, Eq sq) => Eq (SomeConn' rq sq) where SomeConn d c == SomeConn d' c' = case testEquality d d' of Just Refl -> c == c' _ -> False -deriving instance Eq (Connection d) +deriving instance (Eq rq, Eq sq) => Eq (Connection' d rq sq) deriving instance Eq (SConnType d) @@ -22,6 +22,8 @@ deriving instance Eq (StoredRcvQueue s) deriving instance Eq (StoredSndQueue q) +deriving instance Eq RcvQueueSub + deriving instance Eq ClientNtfCreds deriving instance Eq ShortLinkCreds @@ -30,6 +32,10 @@ deriving instance Show (ConnLinkData c) deriving instance Eq (ConnLinkData c) +deriving instance Show UserContactData + +deriving instance Eq UserContactData + deriving instance Show UserLinkData deriving instance Eq UserLinkData diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index a6ee6d7f2..fcdd5be29 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -87,6 +87,8 @@ import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestSte import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), Env (..), InitialAgentServers (..), createAgentStore) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT, INV, JOINED) import qualified Simplex.Messaging.Agent.Protocol as A +import Simplex.Messaging.Agent.Store (Connection' (..), SomeConn' (..), StoredRcvQueue (..)) +import Simplex.Messaging.Agent.Store.AgentStore (getConn) import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction) import Simplex.Messaging.Agent.Store.Interface import qualified Simplex.Messaging.Agent.Store.DB as DB @@ -100,10 +102,12 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF) import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, NetworkError (..), ProtocolServer (..), SubscriptionMode (..), initialSMPClientVersion, srvHostnamesSMPClientVersion, supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol.Types import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), StorePaths (..)) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.MsgStore.Types (SMSType (..), SQSType (..)) import Simplex.Messaging.Server.QueueStore.QueueInfo +import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..)) import Simplex.Messaging.Transport (ASrvTransport, SMPVersion, VersionSMP, authCmdsSMPVersion, currentServerSMPRelayVersion, minClientSMPRelayVersion, minServerSMPRelayVersion, sendingProxySMPVersion, sndAuthKeySMPVersion, alpnSupportedSMPHandshakes, supportedServerSMPRelayVRange) import Simplex.Messaging.Util (bshow, diffToMicroseconds) import Simplex.Messaging.Version (VersionRange (..)) @@ -120,7 +124,7 @@ import Fixtures #endif #if defined(dbServerPostgres) import qualified Database.PostgreSQL.Simple as PSQL -import Simplex.Messaging.Agent.Store (Connection (..), StoredRcvQueue (..), SomeConn (..)) +import Simplex.Messaging.Agent.Store (Connection' (..), StoredRcvQueue (..), SomeConn' (..)) import Simplex.Messaging.Agent.Store.AgentStore (getConn) import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) import Simplex.Messaging.Server.MsgStore.Postgres (PostgresQueue) @@ -278,7 +282,7 @@ inAnyOrder g rs = withFrozenCallStack $ do createConnection :: ConnectionModeI c => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> AE (ConnId, ConnectionRequestUri c) createConnection c userId enableNtfs cMode clientData subMode = do - (connId, (CCLink cReq _, Nothing)) <- A.createConnection c NRMInteractive userId enableNtfs cMode Nothing clientData IKPQOn subMode + (connId, (CCLink cReq _, Nothing)) <- A.createConnection c NRMInteractive userId enableNtfs True cMode Nothing clientData IKPQOn subMode pure (connId, cReq) joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE (ConnId, SndQueueSecured) @@ -308,7 +312,7 @@ deleteConnections c = A.deleteConnections c NRMInteractive getConnShortLink :: AgentClient -> UserId -> ConnShortLink c -> AE (ConnectionRequestUri c, ConnLinkData c) getConnShortLink c = A.getConnShortLink c NRMInteractive -setConnShortLink :: AgentClient -> ConnId -> SConnectionMode c -> UserLinkData -> Maybe CRClientData -> AE (ConnShortLink c) +setConnShortLink :: AgentClient -> ConnId -> SConnectionMode c -> UserConnLinkData c -> Maybe CRClientData -> AE (ConnShortLink c) setConnShortLink c = A.setConnShortLink c NRMInteractive suspendConnection :: AgentClient -> ConnId -> AE () @@ -365,13 +369,13 @@ functionalAPITests ps = do it "should connect after errors" $ testContactErrors ps False it "should connect after errors with client restarts" $ testContactErrors ps True describe "Short connection links" $ do - describe "should connect via 1-time short link" $ testProxyMatrix ps testInviationShortLink - describe "should connect via 1-time short link with async join" $ testProxyMatrix ps testInviationShortLinkAsync + describe "should connect via 1-time short link" $ testProxyMatrix ps testInvitationShortLink + describe "should connect via 1-time short link with async join" $ testProxyMatrix ps testInvitationShortLinkAsync describe "should connect via contact short link" $ testProxyMatrix ps testContactShortLink describe "should add short link to existing contact and connect" $ testProxyMatrix ps testAddContactShortLink - xdescribe "try to create 1-time short link with prev versions" $ testProxyMatrixWithPrev ps testInviationShortLinkPrev + xdescribe "try to create 1-time short link with prev versions" $ testProxyMatrixWithPrev ps testInvitationShortLinkPrev describe "server restart" $ do - it "should get 1-time link data after restart" $ testInviationShortLinkRestart ps + it "should get 1-time link data after restart" $ testInvitationShortLinkRestart ps it "should connect via contact short link after restart" $ testContactShortLinkRestart ps it "should connect via added contact short link after restart" $ testAddContactShortLinkRestart ps it "should create and get short links with the old contact queues" $ testOldContactQueueShortLink ps @@ -435,7 +439,7 @@ functionalAPITests ps = do describe "Batching SMP commands" $ do -- disable this and enable the following test to run tests with coverage it "should subscribe to multiple (200) subscriptions with batching" $ - testBatchedSubscriptions 200 10 ps + testBatchedSubscriptions 200 20 ps skip "faster version of the previous test (200 subscriptions gets very slow with test coverage)" $ it "should subscribe to multiple (6) subscriptions with batching" $ testBatchedSubscriptions 6 3 ps @@ -540,6 +544,10 @@ functionalAPITests ps = do describe "SMP queue info" $ do it "server should respond with queue and subscription information" $ withSmpServer ps testServerQueueInfo +#if !defined(dbServerPostgres) + describe "Client notices" $ do + it "should create client notice" $ testClientNotice ps +#endif testBasicAuth :: (ASrvTransport, AStoreType) -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> SndQueueSecured -> AgentMsgId -> IO Int testBasicAuth (t, msType) allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 sqSecured baseId = do @@ -700,7 +708,7 @@ runAgentClientTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientTestPQ :: HasCallStack => SndQueueSecured -> Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientTestPQ sqSecured viaProxy (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMInvitation Nothing Nothing aPQ SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ (sqSecured', Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured @@ -902,7 +910,7 @@ runAgentClientContactTest pqSupport sqSecured viaProxy alice bob baseId = runAgentClientContactTestPQ :: HasCallStack => SndQueueSecured -> Bool -> PQSupport -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, bPQ) baseId = runRight_ $ do - (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMContact Nothing Nothing aPQ SMSubscribe + (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo bPQ (sqSecuredJoin, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" bPQ SMSubscribe liftIO $ sqSecuredJoin `shouldBe` False -- joining via contact address connection @@ -946,7 +954,7 @@ runAgentClientContactTestPQ sqSecured viaProxy reqPQSupport (alice, aPQ) (bob, b runAgentClientContactTestPQ3 :: HasCallStack => Bool -> (AgentClient, InitialKeys) -> (AgentClient, PQSupport) -> (AgentClient, PQSupport) -> AgentMsgId -> IO () runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId = runRight_ $ do - (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMContact Nothing Nothing aPQ SMSubscribe + (_, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing aPQ SMSubscribe (bAliceId, bobId, abPQEnc) <- connectViaContact bob bPQ qInfo sentMessages abPQEnc alice bobId bob bAliceId (tAliceId, tomId, atPQEnc) <- connectViaContact tom tPQ qInfo @@ -999,7 +1007,7 @@ noMessages_ ingoreQCONT c err = tryGet `shouldReturn` () testRejectContactRequest :: HasCallStack => IO () testRejectContactRequest = withAgentClients2 $ \alice bob -> runRight_ $ do - (_addrConnId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMContact Nothing Nothing IKPQOn SMSubscribe + (_addrConnId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` False -- joining via contact address connection @@ -1321,11 +1329,12 @@ withServer1 ps = withSmpServerStoreLogOn ps testPort . const withServer2 :: (ASrvTransport, AStoreType) -> IO a -> IO a withServer2 (t, ASType qsType _) = withSmpServerConfigOn t (cfgJ2QS qsType) testPort2 . const -testInviationShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () -testInviationShortLink viaProxy a b = +testInvitationShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () +testInvitationShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do let userData = UserLinkData "some user data" - (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe + newLinkData = UserInvLinkData userData + (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1356,17 +1365,19 @@ testJoinConn_ viaProxy sndSecure a bId b connReq = do get b ##> ("", aId, CON) exchangeGreetingsViaProxy viaProxy a bId b aId -testInviationShortLinkPrev :: HasCallStack => Bool -> Bool -> AgentClient -> AgentClient -> IO () -testInviationShortLinkPrev viaProxy sndSecure a b = runRight_ $ do +testInvitationShortLinkPrev :: HasCallStack => Bool -> Bool -> AgentClient -> AgentClient -> IO () +testInvitationShortLinkPrev viaProxy sndSecure a b = runRight_ $ do let userData = UserLinkData "some user data" + newLinkData = UserInvLinkData userData -- can't create short link with previous version - (bId, (CCLink connReq Nothing, Nothing)) <- A.createConnection a NRMInteractive 1 True SCMInvitation (Just userData) Nothing CR.IKPQOn SMSubscribe + (bId, (CCLink connReq Nothing, Nothing)) <- A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKPQOn SMSubscribe testJoinConn_ viaProxy sndSecure a bId b connReq -testInviationShortLinkAsync :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () -testInviationShortLinkAsync viaProxy a b = do +testInvitationShortLinkAsync :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () +testInvitationShortLinkAsync viaProxy a b = do let userData = UserLinkData "some user data" - (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMSubscribe + newLinkData = UserInvLinkData userData + (bId, (CCLink connReq (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMSubscribe (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq @@ -1381,24 +1392,32 @@ testInviationShortLinkAsync viaProxy a b = do get b ##> ("", aId, CON) exchangeGreetingsViaProxy viaProxy a bId b aId +relayLink1 :: ConnShortLink 'CMContact +relayLink1 = either error id $ strDecode "https://localhost/a#4AkRDmhf64tdRlN406g8lJRg5OCmhD6ynIhi6glOcCM?p=7001&c=LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI" + +relayLink2 :: ConnShortLink 'CMContact +relayLink2 = either error id $ strDecode "https://localhost/a#4AkRDmhf64tdRlN406g8lJRg5OCmhD6ynIhi6glOcCM" + testContactShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testContactShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do let userData = UserLinkData "some user data" - (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True SCMContact (Just userData) Nothing CR.IKPQOn SMSubscribe + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + newLinkData = UserContactLinkData userCtData + (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact (Just newLinkData) Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) - (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink + (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData -- same user can get contact link again - (connReq2, connData2) <- runRight $ getConnShortLink b 1 shortLink + (connReq2, ContactLinkData _ userCtData2) <- runRight $ getConnShortLink b 1 shortLink connReq2 `shouldBe` connReq - linkUserData connData2 `shouldBe` userData + userCtData2 `shouldBe` userCtData -- another user can get the same contact link - (connReq3, connData3) <- runRight $ getConnShortLink c 1 shortLink + (connReq3, ContactLinkData _ userCtData3) <- runRight $ getConnShortLink c 1 shortLink connReq3 `shouldBe` connReq - linkUserData connData3 `shouldBe` userData + userCtData3 `shouldBe` userCtData runRight $ do (aId, sndSecure) <- joinConnection b 1 True connReq "bob's connInfo" SMSubscribe liftIO $ sndSecure `shouldBe` False @@ -1414,13 +1433,15 @@ testContactShortLink viaProxy a b = exchangeGreetingsViaProxy viaProxy a bId b aId -- update user data let updatedData = UserLinkData "updated user data" - shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + updatedCtData = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedData} + userLinkData' = UserContactLinkData updatedCtData + shortLink' <- runRight $ setConnShortLink a contactId SCMContact userLinkData' Nothing shortLink' `shouldBe` shortLink - (connReq4, updatedConnData') <- runRight $ getConnShortLink c 1 shortLink + (connReq4, ContactLinkData _ updatedCtData') <- runRight $ getConnShortLink c 1 shortLink connReq4 `shouldBe` connReq - linkUserData updatedConnData' `shouldBe` updatedData + updatedCtData' `shouldBe` updatedCtData -- one more time - shortLink2 <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + shortLink2 <- runRight $ setConnShortLink a contactId SCMContact userLinkData' Nothing shortLink2 `shouldBe` shortLink -- delete short link runRight_ $ deleteConnShortLink a NRMInteractive contactId SCMContact @@ -1430,22 +1451,24 @@ testContactShortLink viaProxy a b = testAddContactShortLink :: HasCallStack => Bool -> AgentClient -> AgentClient -> IO () testAddContactShortLink viaProxy a b = withAgent 3 agentCfg initAgentServers testDB3 $ \c -> do - (contactId, (CCLink connReq0 Nothing, Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe + (contactId, (CCLink connReq0 Nothing, Nothing)) <- runRight $ A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMSubscribe Right connReq <- pure $ smpDecode (smpEncode connReq0) -- let userData = UserLinkData "some user data" - shortLink <- runRight $ setConnShortLink a contactId SCMContact userData Nothing - (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + newLinkData = UserContactLinkData userCtData + shortLink <- runRight $ setConnShortLink a contactId SCMContact newLinkData Nothing + (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData -- same user can get contact link again - (connReq2, connData2) <- runRight $ getConnShortLink b 1 shortLink + (connReq2, ContactLinkData _ userCtData2) <- runRight $ getConnShortLink b 1 shortLink connReq2 `shouldBe` connReq - linkUserData connData2 `shouldBe` userData + userCtData2 `shouldBe` userCtData -- another user can get the same contact link - (connReq3, connData3) <- runRight $ getConnShortLink c 1 shortLink + (connReq3, ContactLinkData _ userCtData3) <- runRight $ getConnShortLink c 1 shortLink connReq3 `shouldBe` connReq - linkUserData connData3 `shouldBe` userData + userCtData3 `shouldBe` userCtData runRight $ do (aId, sndSecure) <- joinConnection b 1 True connReq "bob's connInfo" SMSubscribe liftIO $ sndSecure `shouldBe` False @@ -1461,17 +1484,20 @@ testAddContactShortLink viaProxy a b = exchangeGreetingsViaProxy viaProxy a bId b aId -- update user data let updatedData = UserLinkData "updated user data" - shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + updatedCtData = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedData} + userLinkData' = UserContactLinkData updatedCtData + shortLink' <- runRight $ setConnShortLink a contactId SCMContact userLinkData' Nothing shortLink' `shouldBe` shortLink - (connReq4, updatedConnData') <- runRight $ getConnShortLink c 1 shortLink + (connReq4, ContactLinkData _ updatedCtData') <- runRight $ getConnShortLink c 1 shortLink connReq4 `shouldBe` connReq - linkUserData updatedConnData' `shouldBe` updatedData + updatedCtData' `shouldBe` updatedCtData -testInviationShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () -testInviationShortLinkRestart ps = withAgentClients2 $ \a b -> do +testInvitationShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testInvitationShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" + newLinkData = UserInvLinkData userData (bId, (CCLink connReq (Just shortLink), Nothing)) <- withSmpServer ps $ - runRight $ A.createConnection a NRMInteractive 1 True SCMInvitation (Just userData) Nothing CR.IKUsePQ SMOnlyCreate + runRight $ A.createConnection a NRMInteractive 1 True True SCMInvitation (Just newLinkData) Nothing CR.IKUsePQ SMOnlyCreate withSmpServer ps $ do runRight_ $ subscribeConnection a bId (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink @@ -1482,48 +1508,56 @@ testInviationShortLinkRestart ps = withAgentClients2 $ \a b -> do testContactShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + newLinkData = UserContactLinkData userCtData (contactId, (CCLink connReq0 (Just shortLink), Nothing)) <- withSmpServer ps $ - runRight $ A.createConnection a NRMInteractive 1 True SCMContact (Just userData) Nothing CR.IKPQOn SMOnlyCreate + runRight $ A.createConnection a NRMInteractive 1 True True SCMContact (Just newLinkData) Nothing CR.IKPQOn SMOnlyCreate Right connReq <- pure $ smpDecode (smpEncode connReq0) let updatedData = UserLinkData "updated user data" + updatedCtData = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedData} + updatedLinkData = UserContactLinkData updatedCtData withSmpServer ps $ do - (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink + (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData -- update user data - shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedLinkData Nothing shortLink' `shouldBe` shortLink withSmpServer ps $ do - (connReq4, updatedConnData') <- runRight $ getConnShortLink b 1 shortLink + (connReq4, ContactLinkData _ updatedCtData') <- runRight $ getConnShortLink b 1 shortLink connReq4 `shouldBe` connReq - linkUserData updatedConnData' `shouldBe` updatedData + updatedCtData' `shouldBe` updatedCtData testAddContactShortLinkRestart :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testAddContactShortLinkRestart ps = withAgentClients2 $ \a b -> do let userData = UserLinkData "some user data" + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + newLinkData = UserContactLinkData userCtData ((contactId, (CCLink connReq0 Nothing, Nothing)), shortLink) <- withSmpServer ps $ runRight $ do - r@(contactId, _) <- A.createConnection a NRMInteractive 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate - (r,) <$> setConnShortLink a contactId SCMContact userData Nothing + r@(contactId, _) <- A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate + (r,) <$> setConnShortLink a contactId SCMContact newLinkData Nothing Right connReq <- pure $ smpDecode (smpEncode connReq0) let updatedData = UserLinkData "updated user data" + updatedCtData = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedData} + updatedLinkData = UserContactLinkData updatedCtData withSmpServer ps $ do - (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink + (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData -- update user data - shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedLinkData Nothing shortLink' `shouldBe` shortLink withSmpServer ps $ do - (connReq4, updatedConnData') <- runRight $ getConnShortLink b 1 shortLink + (connReq4, ContactLinkData _ updatedCtData') <- runRight $ getConnShortLink b 1 shortLink connReq4 `shouldBe` connReq - linkUserData updatedConnData' `shouldBe` updatedData + updatedCtData' `shouldBe` updatedCtData testOldContactQueueShortLink :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do (contactId, (CCLink connReq Nothing, Nothing)) <- withSmpServer ps $ runRight $ - A.createConnection a NRMInteractive 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate + A.createConnection a NRMInteractive 1 True True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate -- make it an "old" queue let updateStoreLog f = replaceSubstringInFile f " queue_mode=C" "" #if defined(dbServerPostgres) @@ -1552,19 +1586,23 @@ testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do withSmpServer ps $ do let userData = UserLinkData "some user data" - shortLink <- runRight $ setConnShortLink a contactId SCMContact userData Nothing - (connReq', connData') <- runRight $ getConnShortLink b 1 shortLink + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + userLinkData = UserContactLinkData userCtData + shortLink <- runRight $ setConnShortLink a contactId SCMContact userLinkData Nothing + (connReq', ContactLinkData _ userCtData') <- runRight $ getConnShortLink b 1 shortLink strDecode (strEncode shortLink) `shouldBe` Right shortLink connReq' `shouldBe` connReq - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData -- update user data let updatedData = UserLinkData "updated user data" - shortLink' <- runRight $ setConnShortLink a contactId SCMContact updatedData Nothing + updatedCtData = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedData} + userLinkData' = UserContactLinkData updatedCtData + shortLink' <- runRight $ setConnShortLink a contactId SCMContact userLinkData' Nothing shortLink' `shouldBe` shortLink -- check updated - (connReq'', updatedConnData') <- runRight $ getConnShortLink b 1 shortLink + (connReq'', ContactLinkData _ updatedCtData') <- runRight $ getConnShortLink b 1 shortLink connReq'' `shouldBe` connReq - linkUserData updatedConnData' `shouldBe` updatedData + updatedCtData' `shouldBe` updatedCtData replaceSubstringInFile :: FilePath -> T.Text -> T.Text -> IO () replaceSubstringInFile filePath oldText newText = do @@ -2263,7 +2301,7 @@ makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn True makeConnectionForUsers_ :: HasCallStack => PQSupport -> SndQueueSecured -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqSupport sqSecured alice aliceUserId bob bobUserId = do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive aliceUserId True SCMInvitation Nothing Nothing (IKLinkPQ pqSupport) SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive aliceUserId True True SCMInvitation Nothing Nothing (IKLinkPQ pqSupport) SMSubscribe aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport (sqSecured', Nothing) <- A.joinConnection bob NRMInteractive bobUserId aliceId True qInfo "bob's connInfo" pqSupport SMSubscribe liftIO $ sqSecured' `shouldBe` sqSecured @@ -2391,8 +2429,8 @@ testSuspendingAgentTimeout ps = withAgentClients2 $ \a b -> do pure () testBatchedSubscriptions :: Int -> Int -> (ASrvTransport, AStoreType) -> IO () -testBatchedSubscriptions nCreate nDel ps@(t, ASType qsType _) = - withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do +testBatchedSubscriptions nCreate nDel ps@(t, ASType qsType _) = do + (conns, conns') <- withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do conns <- runServers $ do conns <- replicateM nCreate $ makeConnection_ PQSupportOff True a b forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId @@ -2401,21 +2439,23 @@ testBatchedSubscriptions nCreate nDel ps@(t, ASType qsType _) = delete b aIds' liftIO $ threadDelay 1000000 pure conns - ("", "", DOWN {}) <- nGet a - ("", "", DOWN {}) <- nGet a - ("", "", DOWN {}) <- nGet b - ("", "", DOWN {}) <- nGet b + let conns' = drop nDel conns + (aIds', bIds') = unzip conns' + down a bIds' + down b aIds' + runServers $ do + up a bIds' + up b aIds' + down a bIds' + down b aIds' + pure (conns, conns') + withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do runServers $ do - ("", "", UP {}) <- nGet a - ("", "", UP {}) <- nGet a - ("", "", UP {}) <- nGet b - ("", "", UP {}) <- nGet b liftIO $ threadDelay 1000000 let (aIds, bIds) = unzip conns - conns' = drop nDel conns (aIds', bIds') = unzip conns' - subscribe a bIds - subscribe b aIds + subscribe a bIds' + subscribe b aIds' forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId_ PQEncOff 4 a bId b aId void $ resubscribeConnections a bIds void $ resubscribeConnections b aIds @@ -2425,14 +2465,18 @@ testBatchedSubscriptions nCreate nDel ps@(t, ASType qsType _) = deleteFail a bIds' deleteFail b aIds' where + down c cs = do + ("", "", DOWN _ cs1) <- nGet c + ("", "", DOWN _ cs2) <- nGet c + liftIO $ S.fromList (cs1 ++ cs2) `shouldBe` S.fromList cs + up c cs = do + ("", "", UP _ cs1) <- nGet c + ("", "", UP _ cs2) <- nGet c + liftIO $ S.fromList (cs1 ++ cs2) `shouldBe` S.fromList cs subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO () subscribe c cs = do - r <- subscribeConnections c cs - liftIO $ do - let dc = S.fromList $ take nDel cs - all isRight (M.withoutKeys r dc) `shouldBe` True - all (== Left (CONN NOT_FOUND "")) (M.restrictKeys r dc) `shouldBe` True - M.keys r `shouldMatchList` cs + subscribeAllConnections c False Nothing + liftIO $ up c cs delete :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO () delete c cs = do r <- deleteConnections c cs @@ -2462,8 +2506,10 @@ testBatchedPendingMessages nCreate nMsgs = runRight_ $ forM_ msgConns $ \(_, bId) -> sendMessage a bId SMP.noMsgFlags "hello" replicateM_ nMsgs $ get a =##> \case ("", cId, SENT _) -> isJust $ find ((cId ==) . snd) msgConns; _ -> False withB $ \b -> runRight_ $ do - r <- subscribeConnections b $ map fst conns - liftIO $ all isRight r `shouldBe` True + let aIds = map fst conns + subscribeAllConnections b False Nothing + ("", "", UP _ aIds') <- nGet b + liftIO $ S.fromList aIds' `shouldBe` S.fromList aIds replicateM_ nMsgs $ do ("", cId, Msg' msgId _ "hello") <- get b liftIO $ isJust (find ((cId ==) . fst) msgConns) `shouldBe` True @@ -3569,6 +3615,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do liftIO $ threadDelay 250000 ("", "", DOWN _ _) <- nGet a ("", "", UP _ _) <- nGet a + ("", "", UP _ _) <- nGet a a `hasClients` 2 exchangeGreetingsMsgId 4 a bId1 b aId1 @@ -3595,6 +3642,8 @@ testTwoUsers = withAgentClients2 $ \a b -> do ("", "", DOWN _ _) <- nGet a ("", "", UP _ _) <- nGet a ("", "", UP _ _) <- nGet a + ("", "", UP _ _) <- nGet a + ("", "", UP _ _) <- nGet a a `hasClients` 4 exchangeGreetingsMsgId 6 a bId1 b aId1 exchangeGreetingsMsgId 6 a bId1' b aId1' @@ -3826,6 +3875,76 @@ testServerQueueInfo = do qDelivered <$> qiSub `shouldBe` Just msgId_ pure msgId_ +testClientNotice :: HasCallStack => (ASrvTransport, AStoreType) -> IO () +testClientNotice ps = do + withAgent 1 agentCfg initAgentServers testDB $ \c -> do + (cId, _) <- withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ + A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + ("", "", DOWN _ [_]) <- nGet c + + addNotice c cId $ Just 1 + + (cId', _) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + subscribedWithErrors c 1 + testNotice c True + threadDelay 1000000 + runRight $ A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + ("", "", DOWN _ [_]) <- nGet c + + addNotice c cId' $ Just 1 + + (cId'', _) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + subscribedWithErrors c 1 + testNotice c True + threadDelay 1000000 + testNotice c True + threadDelay 1000000 + runRight $ A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + + addNotice c cId'' $ Just 1 + + withAgent 1 agentCfg initAgentServers testDB $ \c -> do + (cId3, _) <- withSmpServerStoreLogOn ps testPort $ \_ -> do + runRight_ $ subscribeAllConnections c False Nothing + subscribedWithErrors c 3 + testNotice c True + threadDelay 2000000 + testNotice c True + threadDelay 1000000 + runRight $ A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + ("", "", DOWN _ [_]) <- nGet c + + addNotice c cId3 Nothing + + withSmpServerStoreLogOn ps testPort $ \_ -> do + subscribedWithErrors c 1 + testNotice c False + + removeNotice c cId3 + + withAgent 1 agentCfg initAgentServers testDB $ \c -> do + withSmpServerStoreLogOn ps testPort $ \_ -> do + runRight_ $ subscribeAllConnections c False Nothing + subscribedWithErrors c 4 + void $ runRight $ A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + where + addNotice c cId ttl = logNotice c cId $ Just ClientNotice {ttl} + removeNotice c cId = logNotice c cId Nothing + logNotice :: AgentClient -> ConnId -> Maybe ClientNotice -> IO () + logNotice c cId notice = do + Right (SomeConn _ (ContactConnection _ RcvQueue {rcvId})) <- withTransaction (store $ agentEnv c) (`getConn` cId) + withFile testStoreLogFile AppendMode $ \h -> B.hPutStrLn h $ strEncode $ BlockQueue rcvId $ SMP.BlockingInfo SMP.BRContent notice + subscribedWithErrors c n = do + ("", "", ERRS errs) <- nGet c + length errs `shouldBe` n + forM_ errs $ \case + (_, SMP _ (BLOCKED _)) -> pure () + r -> expectationFailure $ "unexpected event: " <> show r + testNotice :: HasCallStack => AgentClient -> Bool -> IO () + testNotice c willExpire = do + NOTICE "localhost" False expiresAt_ <- runLeft $ A.createConnection c NRMInteractive 1 True True SCMContact Nothing Nothing IKPQOn SMSubscribe + isJust expiresAt_ `shouldBe` willExpire + noNetworkDelay :: AgentClient -> IO () noNetworkDelay a = do d <- waitNetwork a diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 257c3f90f..dff79c861 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -204,9 +204,6 @@ cData1 = testPrivateAuthKey :: C.APrivateAuthKey testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" -testPublicAuthKey :: C.APublicAuthKey -testPublicAuthKey = C.APublicAuthKey C.SEd25519 (C.publicKey "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe") - testPrivDhKey :: C.PrivateKeyX25519 testPrivDhKey = "MC4CAQAwBQYDK2VuBCIEINCzbVFaCiYHoYncxNY8tSIfn0pXcIAhLBfFc0m+gOpk" @@ -232,6 +229,8 @@ rcvQueue1 = shortLink = Nothing, clientService = Nothing, status = New, + enableNtfs = True, + clientNoticeId = Nothing, dbQueueId = DBNewEntity, primary = True, dbReplaceQueueId = Nothing, @@ -249,7 +248,6 @@ sndQueue1 = server = smpServer1, sndId = EntityId "3456", queueMode = Just QMMessaging, - sndPublicKey = testPublicAuthKey, sndPrivateKey = testPrivateAuthKey, e2ePubKey = Nothing, e2eDhSecret = testDhSecret, @@ -264,7 +262,7 @@ sndQueue1 = createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) createRcvConn db g cData rq cMode = runExceptT $ do connId <- ExceptT $ createNewConn db g cData cMode - rq' <- ExceptT $ updateNewConnRcv db connId rq + rq' <- ExceptT $ updateNewConnRcv db connId rq SMSubscribe pure (connId, rq') testCreateRcvConn :: SpecWith DBStore @@ -310,7 +308,7 @@ testCreateSndConn = dbQueueId `shouldBe` DBEntityId 1 getConn db "conn1" `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq)) - Right rq@RcvQueue {dbQueueId = dbQueueId'} <- upgradeSndConnToDuplex db "conn1" rcvQueue1 + Right rq@RcvQueue {dbQueueId = dbQueueId'} <- upgradeSndConnToDuplex db "conn1" rcvQueue1 SMSubscribe dbQueueId' `shouldBe` DBEntityId 1 getConn db "conn1" `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq])) @@ -322,7 +320,7 @@ testCreateSndConnRandomID = Right (connId, sq) <- createSndConn db g cData1 {connId = ""} sndQueue1 getConn db connId `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sq)) - Right (rq@RcvQueue {dbQueueId = dbQueueId'}) <- upgradeSndConnToDuplex db connId rcvQueue1 + Right (rq@RcvQueue {dbQueueId = dbQueueId'}) <- upgradeSndConnToDuplex db connId rcvQueue1 SMSubscribe dbQueueId' `shouldBe` DBEntityId 1 getConn db connId `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq] [sq])) @@ -409,7 +407,6 @@ testUpgradeRcvConnToDuplex = server = SMPServer "smp.simplex.im" "5223" testKeyHash, sndId = EntityId "2345", queueMode = Just QMMessaging, - sndPublicKey = testPublicAuthKey, sndPrivateKey = testPrivateAuthKey, e2ePubKey = Nothing, e2eDhSecret = testDhSecret, @@ -422,7 +419,7 @@ testUpgradeRcvConnToDuplex = } upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType "upgradeRcvConnToDuplex" CSnd) - _ <- upgradeSndConnToDuplex db "conn1" rcvQueue1 + _ <- upgradeSndConnToDuplex db "conn1" rcvQueue1 SMSubscribe upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType "upgradeRcvConnToDuplex" CDuplex) @@ -446,6 +443,8 @@ testUpgradeSndConnToDuplex = shortLink = Nothing, clientService = Nothing, status = New, + enableNtfs = True, + clientNoticeId = Nothing, dbQueueId = DBNewEntity, rcvSwchStatus = Nothing, primary = True, @@ -454,10 +453,10 @@ testUpgradeSndConnToDuplex = clientNtfCreds = Nothing, deleteErrors = 0 } - upgradeSndConnToDuplex db "conn1" anotherRcvQueue + upgradeSndConnToDuplex db "conn1" anotherRcvQueue SMSubscribe `shouldReturn` Left (SEBadConnType "upgradeSndConnToDuplex" CRcv) _ <- upgradeRcvConnToDuplex db "conn1" sndQueue1 - upgradeSndConnToDuplex db "conn1" anotherRcvQueue + upgradeSndConnToDuplex db "conn1" anotherRcvQueue SMSubscribe `shouldReturn` Left (SEBadConnType "upgradeSndConnToDuplex" CDuplex) testSetRcvQueueStatus :: SpecWith DBStore @@ -470,7 +469,7 @@ testSetRcvQueueStatus = setRcvQueueStatus db rq Confirmed `shouldReturn` () getConn db "conn1" - `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq {status = Confirmed})) + `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 (rq {status = Confirmed} :: RcvQueue))) testSetSndQueueStatus :: SpecWith DBStore testSetSndQueueStatus = @@ -482,7 +481,7 @@ testSetSndQueueStatus = setSndQueueStatus db sq Confirmed `shouldReturn` () getConn db "conn1" - `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq {status = Confirmed})) + `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 (sq {status = Confirmed} :: SndQueue))) testSetQueueStatusDuplex :: SpecWith DBStore testSetQueueStatusDuplex = @@ -569,7 +568,7 @@ testCreateSndMsg_ db expectedPrevHash connId sq sndMsgData@SndMsgData {..} = do `shouldReturn` Right (internalId, internalSndId, expectedPrevHash) createSndMsg db connId sndMsgData `shouldReturn` () - createSndMsgDelivery db connId sq internalId + createSndMsgDelivery db sq internalId `shouldReturn` () testCreateSndMsg :: SpecWith DBStore @@ -642,7 +641,7 @@ testReopenEncryptedStoreKeepKey = do hasMigrations st getMigrations :: DBStore -> IO Bool -getMigrations st = not . null <$> withTransaction st getCurrentMigrations +getMigrations st = not . null <$> withTransaction st (getCurrentMigrations Nothing) hasMigrations :: DBStore -> Expectation hasMigrations st = getMigrations st `shouldReturn` True @@ -684,7 +683,7 @@ testGetPendingServerCommand st = do Right (Just PendingCommand {corrId}) <- getPendingServerCommand db connId Nothing corrId `shouldBe` "2" - Right _ <- updateNewConnRcv db connId rcvQueue1 + Right _ <- updateNewConnRcv db connId rcvQueue1 SMSubscribe Right Nothing <- getPendingServerCommand db connId $ Just smpServer1 Right () <- createCommand db "3" connId (Just smpServer1) command corruptCmd db "3" connId diff --git a/tests/AgentTests/SchemaDump.hs b/tests/AgentTests/SchemaDump.hs index fdb172883..1f83973e6 100644 --- a/tests/AgentTests/SchemaDump.hs +++ b/tests/AgentTests/SchemaDump.hs @@ -76,14 +76,14 @@ testSchemaMigrations = do putStrLn $ "down migration " <> name m let downMigr = fromJust $ toDownMigration m schema <- getSchema testDB testSchema - Migrations.run st True $ MTRUp [m] + Migrations.run st Nothing True $ MTRUp [m] schema' <- getSchema testDB testSchema schema' `shouldNotBe` schema - Migrations.run st True $ MTRDown [downMigr] + Migrations.run st Nothing True $ MTRDown [downMigr] unless (name m `elem` skipComparisonForDownMigrations) $ do schema'' <- getSchema testDB testSchema schema'' `shouldBe` schema - Migrations.run st True $ MTRUp [m] + Migrations.run st Nothing True $ MTRUp [m] schema''' <- getSchema testDB testSchema schema''' `shouldBe` schema' diff --git a/tests/AgentTests/ServerChoice.hs b/tests/AgentTests/ServerChoice.hs index a3d6337f2..a27678cb6 100644 --- a/tests/AgentTests/ServerChoice.hs +++ b/tests/AgentTests/ServerChoice.hs @@ -64,7 +64,8 @@ initServers = ntf = [testNtfServer], xftp = userServers [testXFTPServer], netCfg = defaultNetworkConfig, - presetDomains = [] + presetDomains = [], + presetServers = [] } testChooseDifferentOperator :: IO () diff --git a/tests/AgentTests/ShortLinkTests.hs b/tests/AgentTests/ShortLinkTests.hs index b38daa7d8..9a56cc655 100644 --- a/tests/AgentTests/ShortLinkTests.hs +++ b/tests/AgentTests/ShortLinkTests.hs @@ -1,4 +1,7 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -8,7 +11,8 @@ import AgentTests.ConnectionRequestTests (contactConnRequest, invConnRequest) import AgentTests.EqInstances () import Control.Concurrent.STM import Control.Monad.Except -import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), ConnectionMode (..), LinkKey (..), SConnectionMode (..), SMPAgentError (..), UserLinkData (..), linkUserData, supportedSMPAgentVRange) +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), ConnLinkData (..), ConnectionMode (..), ConnShortLink (..), LinkKey (..), UserConnLinkData (..), SConnectionMode (..), SMPAgentError (..), UserContactData (..), UserLinkData (..), linkUserData, supportedSMPAgentVRange) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.ShortLink as SL import Test.Hspec hiding (fit, it) @@ -31,7 +35,8 @@ testInvShortLink = do g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange invConnRequest userData + userLinkData = UserInvLinkData userData + (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange invConnRequest userLinkData k = SL.invShortLinkKdf linkKey Right srvData <- runExceptT $ SL.encryptLinkData g k linkData -- decrypt @@ -45,7 +50,8 @@ testInvShortLinkBadDataHash = do g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (_linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange invConnRequest userData + userLinkData = UserInvLinkData userData + (_linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange invConnRequest userLinkData -- different key linkKey <- LinkKey <$> atomically (C.randomBytes 32 g) let k = SL.invShortLinkKdf linkKey @@ -54,19 +60,27 @@ testInvShortLinkBadDataHash = do SL.decryptLinkData @'CMInvitation linkKey k srvData `shouldBe` Left (AGENT (A_LINK "link data hash")) +relayLink1 :: ConnShortLink 'CMContact +relayLink1 = either error id $ strDecode "https://localhost/a#4AkRDmhf64tdRlN406g8lJRg5OCmhD6ynIhi6glOcCM?p=7001&c=LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI" + +relayLink2 :: ConnShortLink 'CMContact +relayLink2 = either error id $ strDecode "https://localhost/a#4AkRDmhf64tdRlN406g8lJRg5OCmhD6ynIhi6glOcCM" + testContactShortLink :: IO () testContactShortLink = do -- encrypt g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userData + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + userLinkData = UserContactLinkData userCtData + (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userLinkData (_linkId, k) = SL.contactShortLinkKdf linkKey Right srvData <- runExceptT $ SL.encryptLinkData g k linkData -- decrypt - Right (connReq, connData') <- pure $ SL.decryptLinkData linkKey k srvData + Right (connReq, ContactLinkData _ userCtData') <- pure $ SL.decryptLinkData @'CMContact linkKey k srvData connReq `shouldBe` contactConnRequest - linkUserData connData' `shouldBe` userData + userCtData' `shouldBe` userCtData testUpdateContactShortLink :: IO () testUpdateContactShortLink = do @@ -74,17 +88,21 @@ testUpdateContactShortLink = do g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userData + userCtData = UserContactData {direct = True, owners = [], relays = [], userData} + userLinkData = UserContactLinkData userCtData + (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userLinkData (_linkId, k) = SL.contactShortLinkKdf linkKey Right (fd, _ud) <- runExceptT $ SL.encryptLinkData g k linkData -- encrypt updated user data let updatedUserData = UserLinkData "updated user data" - signed = SL.encodeSignUserData SCMContact (snd sigKeys) supportedSMPAgentVRange updatedUserData + userCtData' = UserContactData {direct = False, owners = [], relays = [relayLink1, relayLink2], userData = updatedUserData} + userLinkData' = UserContactLinkData userCtData' + signed = SL.encodeSignUserData SCMContact (snd sigKeys) supportedSMPAgentVRange userLinkData' Right ud' <- runExceptT $ SL.encryptUserData g k signed -- decrypt - Right (connReq, connData') <- pure $ SL.decryptLinkData linkKey k (fd, ud') + Right (connReq, ContactLinkData _ userCtData'') <- pure $ SL.decryptLinkData @'CMContact linkKey k (fd, ud') connReq `shouldBe` contactConnRequest - linkUserData connData' `shouldBe` updatedUserData + userCtData'' `shouldBe` userCtData' testContactShortLinkBadDataHash :: IO () testContactShortLinkBadDataHash = do @@ -92,7 +110,8 @@ testContactShortLinkBadDataHash = do g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (_linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userData + userLinkData = UserContactLinkData UserContactData {direct = True, owners = [], relays = [], userData} + (_linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userLinkData -- different key linkKey <- LinkKey <$> atomically (C.randomBytes 32 g) let (_linkId, k) = SL.contactShortLinkKdf linkKey @@ -107,14 +126,16 @@ testContactShortLinkBadSignature = do g <- C.newRandom sigKeys <- atomically $ C.generateKeyPair @'C.Ed25519 g let userData = UserLinkData "some user data" - (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userData + userLinkData = UserContactLinkData UserContactData {direct = True, owners = [], relays = [], userData} + (linkKey, linkData) = SL.encodeSignLinkData sigKeys supportedSMPAgentVRange contactConnRequest userLinkData (_linkId, k) = SL.contactShortLinkKdf linkKey Right (fd, _ud) <- runExceptT $ SL.encryptLinkData g k linkData -- encrypt updated user data let updatedUserData = UserLinkData "updated user data" + userLinkData' = UserContactLinkData UserContactData {direct = True, owners = [], relays = [], userData = updatedUserData} -- another signature key (_, pk) <- atomically $ C.generateKeyPair @'C.Ed25519 g - let signed = SL.encodeSignUserData SCMContact pk supportedSMPAgentVRange updatedUserData + let signed = SL.encodeSignUserData SCMContact pk supportedSMPAgentVRange userLinkData' Right ud' <- runExceptT $ SL.encryptUserData g k signed -- decryption fails SL.decryptLinkData @'CMContact linkKey k (fd, ud') diff --git a/tests/CoreTests/MsgStoreTests.hs b/tests/CoreTests/MsgStoreTests.hs index 3961a9ce0..48fc7810a 100644 --- a/tests/CoreTests/MsgStoreTests.hs +++ b/tests/CoreTests/MsgStoreTests.hs @@ -43,7 +43,6 @@ import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.MsgStore.Types import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.QueueInfo -import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog (closeStoreLog, logCreateQueue) import System.Directory (copyFile, createDirectoryIfMissing, listDirectory, removeFile, renameFile) import System.FilePath (()) @@ -58,6 +57,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Server.MsgStore.Postgres import Simplex.Messaging.Server.QueueStore.Postgres +import Simplex.Messaging.Server.QueueStore.Types import SMPClient (postgressBracket, testServerDBConnectInfo, testStoreDBOpts) #endif diff --git a/tests/CoreTests/StoreLogTests.hs b/tests/CoreTests/StoreLogTests.hs index 3a898ef6a..01966ba05 100644 --- a/tests/CoreTests/StoreLogTests.hs +++ b/tests/CoreTests/StoreLogTests.hs @@ -24,18 +24,22 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol import Simplex.Messaging.Server.Env.STM (readWriteQueueStore) -import Simplex.Messaging.Server.Main import Simplex.Messaging.Server.MsgStore.Journal import Simplex.Messaging.Server.MsgStore.Types import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..)) import Simplex.Messaging.Server.QueueStore.Types import Simplex.Messaging.Server.StoreLog +import Simplex.Messaging.SystemTime import Simplex.Messaging.Transport (SMPServiceRole (..)) import Simplex.Messaging.Transport.Credentials (genCredentials) import Test.Hspec hiding (fit, it) import Util +#if defined(dbServerPostgres) +import Simplex.Messaging.Server.Main +#endif + testPublicAuthKey :: C.APublicAuthKey testPublicAuthKey = C.APublicAuthKey C.SEd25519 (C.publicKey "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe") diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs deleted file mode 100644 index 1bc0b9f3f..000000000 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ /dev/null @@ -1,213 +0,0 @@ -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# OPTIONS_GHC -Wno-orphans #-} - -module CoreTests.TRcvQueuesTests where - -import AgentTests.EqInstances () -import qualified Data.ByteString.Char8 as B -import qualified Data.List.NonEmpty as L -import qualified Data.Map as M -import qualified Data.Set as S -import Data.String (IsString (..)) -import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) -import Simplex.Messaging.Agent.Store (RcvQueue, StoredRcvQueue (..)) -import Simplex.Messaging.Agent.Store.Entity -import qualified Simplex.Messaging.Agent.TRcvQueues as RQ -import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol (EntityId (..), QueueMode (..), RecipientId, SMPServer, pattern NoEntity, pattern VersionSMPC) -import Test.Hspec hiding (fit, it) -import UnliftIO -import Util - -tRcvQueuesTests :: Spec -tRcvQueuesTests = do - describe "connection API" $ do - it "hasConn" hasConnTest - it "hasConn, batch add" hasConnTestBatch - it "hasConn, batch idempotent" batchIdempotentTest - it "deleteConn" deleteConnTest - describe "session API" $ do - it "getSessQueues" getSessQueuesTest - it "getDelSessQueues" getDelSessQueuesTest - describe "queue transfer" $ do - it "getDelSessQueues-batchAddQueues preserves total length" removeSubsTest - -instance IsString EntityId where fromString = EntityId . B.pack - -checkDataInvariant :: RQ.Queue q => RQ.TRcvQueues q -> IO Bool -checkDataInvariant trq = atomically $ do - conns <- readTVar $ RQ.getConnections trq - qs <- readTVar $ RQ.getRcvQueues trq - -- three invariant checks - let inv1 = all (\cId -> (S.fromList . L.toList <$> M.lookup cId conns) == Just (M.keysSet (M.filter (\q -> RQ.connId' q == cId) qs))) (M.keys conns) - inv2 = all (\(k, q) -> maybe False ((k `elem`) . L.toList) (M.lookup (RQ.connId' q) conns)) (M.assocs qs) - inv3 = all (\(k, q) -> RQ.qKey q == k) (M.assocs qs) - pure $ inv1 && inv2 && inv3 - -hasConnTest :: IO () -hasConnTest = do - trq <- RQ.empty - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3" "r3") trq - checkDataInvariant trq `shouldReturn` True - atomically (RQ.hasConn "c1" trq) `shouldReturn` True - atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.hasConn "c3" trq) `shouldReturn` True - atomically (RQ.hasConn "nope" trq) `shouldReturn` False - -hasConnTestBatch :: IO () -hasConnTestBatch = do - trq <- RQ.empty - let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1", dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2", dummyRQ 0 "smp://1234-w==@beta" "c3" "r3"] - atomically $ RQ.batchAddQueues trq qs - checkDataInvariant trq `shouldReturn` True - atomically (RQ.hasConn "c1" trq) `shouldReturn` True - atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.hasConn "c3" trq) `shouldReturn` True - atomically (RQ.hasConn "nope" trq) `shouldReturn` False - -batchIdempotentTest :: IO () -batchIdempotentTest = do - trq <- RQ.empty - let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1", dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2", dummyRQ 0 "smp://1234-w==@beta" "c3" "r3"] - atomically $ RQ.batchAddQueues trq qs - checkDataInvariant trq `shouldReturn` True - qs' <- readTVarIO $ RQ.getRcvQueues trq - cs' <- readTVarIO $ RQ.getConnections trq - atomically $ RQ.batchAddQueues trq qs - checkDataInvariant trq `shouldReturn` True - readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' - fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn` cs' -- connections get duplicated, but that doesn't appear to affect anybody - -deleteConnTest :: IO () -deleteConnTest = do - trq <- RQ.empty - atomically $ do - RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1") trq - RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2") trq - RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3" "r3") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.deleteConn "c1" trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.deleteConn "nope" trq - checkDataInvariant trq `shouldReturn` True - M.keys <$> readTVarIO (RQ.getConnections trq) `shouldReturn` ["c2", "c3"] - -getSessQueuesTest :: IO () -getSessQueuesTest = do - trq <- RQ.empty - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.addQueue (dummyRQ 0 "smp://1234-w==@beta" "c3" "r3") trq - checkDataInvariant trq `shouldReturn` True - atomically $ RQ.addQueue (dummyRQ 1 "smp://1234-w==@beta" "c4" "r4") trq - checkDataInvariant trq `shouldReturn` True - let tSess1 = (0, "smp://1234-w==@alpha", Just "c1") - RQ.getSessQueues tSess1 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1"] - atomically (RQ.hasSessQueues tSess1 trq) `shouldReturn` True - let tSess2 = (1, "smp://1234-w==@alpha", Just "c1") - RQ.getSessQueues tSess2 trq `shouldReturn` [] - atomically (RQ.hasSessQueues tSess2 trq) `shouldReturn` False - let tSess3 = (0, "smp://1234-w==@alpha", Just "nope") - RQ.getSessQueues tSess3 trq `shouldReturn` [] - atomically (RQ.hasSessQueues tSess3 trq) `shouldReturn` False - let tSess4 = (0, "smp://1234-w==@alpha", Nothing) - RQ.getSessQueues tSess4 trq `shouldReturn` [dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2", dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1"] - atomically (RQ.hasSessQueues tSess4 trq) `shouldReturn` True - -getDelSessQueuesTest :: IO () -getDelSessQueuesTest = do - trq <- RQ.empty - let qs = - [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1"), - ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2"), - ("1", dummyRQ 0 "smp://1234-w==@beta" "c3" "r3"), - ("1", dummyRQ 1 "smp://1234-w==@beta" "c4" "r4") - ] - atomically $ RQ.batchAddQueues trq qs - checkDataInvariant trq `shouldReturn` True - -- no user - atomically (RQ.getDelSessQueues (2, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) - checkDataInvariant trq `shouldReturn` True - -- wrong user - atomically (RQ.getDelSessQueues (1, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([], []) - checkDataInvariant trq `shouldReturn` True - -- connections intact - atomically (RQ.hasConn "c1" trq) `shouldReturn` True - atomically (RQ.hasConn "c2" trq) `shouldReturn` True - atomically (RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" trq) `shouldReturn` ([dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2", dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1"], ["c1", "c2"]) - checkDataInvariant trq `shouldReturn` True - -- connections gone - atomically (RQ.hasConn "c1" trq) `shouldReturn` False - atomically (RQ.hasConn "c2" trq) `shouldReturn` False - -- non-matched connections intact - atomically (RQ.hasConn "c3" trq) `shouldReturn` True - atomically (RQ.hasConn "c4" trq) `shouldReturn` True - -removeSubsTest :: IO () -removeSubsTest = do - aq <- RQ.empty - let qs = - [ ("1", dummyRQ 0 "smp://1234-w==@alpha" "c1" "r1"), - ("1", dummyRQ 0 "smp://1234-w==@alpha" "c2" "r2"), - ("1", dummyRQ 0 "smp://1234-w==@beta" "c3" "r3"), - ("1", dummyRQ 1 "smp://1234-w==@beta" "c4" "r4") - ] - atomically $ RQ.batchAddQueues aq qs - - pq <- RQ.empty - atomically (totalSize aq pq) `shouldReturn` (4, 4) - - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst - atomically (totalSize aq pq) `shouldReturn` (4, 4) - - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst - atomically (totalSize aq pq) `shouldReturn` (4, 4) - - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst - atomically (totalSize aq pq) `shouldReturn` (4, 4) - - atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") "1" aq >>= RQ.batchAddQueues pq . map ("1",) . fst - atomically (totalSize aq pq) `shouldReturn` (4, 4) - -totalSize :: RQ.TRcvQueues q -> RQ.TRcvQueues q -> STM (Int, Int) -totalSize a b = do - qsizeA <- M.size <$> readTVar (RQ.getRcvQueues a) - qsizeB <- M.size <$> readTVar (RQ.getRcvQueues b) - csizeA <- M.size <$> readTVar (RQ.getConnections a) - csizeB <- M.size <$> readTVar (RQ.getConnections b) - pure (qsizeA + qsizeB, csizeA + csizeB) - -dummyRQ :: UserId -> SMPServer -> ConnId -> RecipientId -> RcvQueue -dummyRQ userId server connId rcvId = - RcvQueue - { userId, - connId, - server, - rcvId, - rcvPrivateKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe", - rcvDhSecret = "01234567890123456789012345678901", - e2ePrivKey = "MC4CAQAwBQYDK2VuBCIEINCzbVFaCiYHoYncxNY8tSIfn0pXcIAhLBfFc0m+gOpk", - e2eDhSecret = Nothing, - sndId = NoEntity, - queueMode = Just QMMessaging, - shortLink = Nothing, - clientService = Nothing, - status = New, - dbQueueId = DBEntityId 0, - primary = True, - dbReplaceQueueId = Nothing, - rcvSwchStatus = Nothing, - smpClientVersion = VersionSMPC 123, - clientNtfCreds = Nothing, - deleteErrors = 0 - } diff --git a/tests/CoreTests/TSessionSubs.hs b/tests/CoreTests/TSessionSubs.hs new file mode 100644 index 000000000..e3f819332 --- /dev/null +++ b/tests/CoreTests/TSessionSubs.hs @@ -0,0 +1,133 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +module CoreTests.TSessionSubs where + +import AgentTests.EqInstances () +import Control.Monad +import qualified Data.ByteString.Char8 as B +import Data.List (foldl') +import qualified Data.Map as M +import Data.String (IsString (..)) +import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) +import Simplex.Messaging.Agent.Store (RcvQueueSub (..)) +import qualified Simplex.Messaging.Agent.TSessionSubs as SS +import Simplex.Messaging.Client (SMPTransportSession, TransportSessionMode (..)) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Protocol (EntityId (..), RecipientId, SMPServer) +import Simplex.Messaging.Transport (SessionId) +import Test.Hspec hiding (fit, it) +import UnliftIO +import Util + +tSessionSubsTests :: Spec +tSessionSubsTests = it "subscription lifecycle" $ testSessionSubs + +instance IsString EntityId where fromString = EntityId . B.pack + +dumpSessionSubs :: SS.TSessionSubs -> IO (M.Map SMPTransportSession (Maybe SessionId, (M.Map RecipientId RcvQueueSub, M.Map RecipientId RcvQueueSub))) +dumpSessionSubs = + readTVarIO . SS.sessionSubs + >=> mapM (\s -> (,) <$> readTVarIO (SS.subsSessId s) <*> SS.mapSubs id s) + +srv1 :: SMPServer +srv1 = "smp://1234-w==@alpha" + +srv2 :: SMPServer +srv2 = "smp://1234-w==@beta" + +testSessionSubs :: IO () +testSessionSubs = do + ss <- SS.emptyIO + ss' <- SS.emptyIO + let q1 = dummyRQ 1 srv1 "c1" "r1" + q2 = dummyRQ 1 srv1 "c2" "r2" + q3 = dummyRQ 1 srv2 "c3" "r3" + q4 = dummyRQ 1 srv2 "c4" "r4" + tSess1 = (1, srv1, Nothing) + tSess2 = (1, srv2, Nothing) + atomically (SS.addPendingSub tSess1 q1 ss) + atomically (SS.addPendingSub tSess1 q2 ss) + atomically (SS.hasPendingSubs tSess1 ss) `shouldReturn` True + atomically (SS.hasPendingSubs tSess2 ss) `shouldReturn` False + atomically (SS.addPendingSub tSess2 q3 ss) + atomically (SS.hasPendingSubs tSess2 ss) `shouldReturn` True + atomically (SS.batchAddPendingSubs tSess1 [q1, q2] ss') + atomically (SS.batchAddPendingSubs tSess2 [q3] ss') + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1), ("r2", q2)] + atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [] + atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` M.fromList [("r3", q3)] + st <- dumpSessionSubs ss + dumpSessionSubs ss' `shouldReturn` st + countSubs ss `shouldReturn` (0, 3) + atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` True + atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` False + atomically (SS.hasPendingSub tSess1 (rcvId q4) ss) `shouldReturn` False + atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False + -- setting active queue without setting session ID would keep it as pending + atomically $ SS.addActiveSub tSess1 "123" q1 ss + atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` True + atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` False + dumpSessionSubs ss `shouldReturn` st + countSubs ss `shouldReturn` (0, 3) + -- setting active queues + atomically $ SS.setSessionId tSess1 "123" ss + atomically $ SS.addActiveSub tSess1 "123" q1 ss + atomically (SS.hasPendingSub tSess1 (rcvId q1) ss) `shouldReturn` False + atomically (SS.hasActiveSub tSess1 (rcvId q1) ss) `shouldReturn` True + atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1)] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r2", q2)] + countSubs ss `shouldReturn` (1, 2) + atomically $ SS.setSessionId tSess2 "456" ss + atomically $ SS.addActiveSub tSess2 "456" q4 ss + atomically (SS.hasPendingSub tSess2 (rcvId q4) ss) `shouldReturn` False + atomically (SS.hasActiveSub tSess2 (rcvId q4) ss) `shouldReturn` True + atomically (SS.hasActiveSub tSess1 (rcvId q4) ss) `shouldReturn` False -- wrong transport session + atomically (SS.getActiveSubs tSess2 ss) `shouldReturn` M.fromList [("r4", q4)] + atomically (SS.getPendingSubs tSess2 ss) `shouldReturn` M.fromList [("r3", q3)] + countSubs ss `shouldReturn` (2, 2) + -- setting pending queues + st' <- dumpSessionSubs ss + atomically (SS.setSubsPending TSMUser tSess1 "abc" ss) `shouldReturn` M.empty -- wrong session + dumpSessionSubs ss `shouldReturn` st' + atomically (SS.setSubsPending TSMUser tSess1 "123" ss) `shouldReturn` M.fromList [("r1", q1)] + atomically (SS.getActiveSubs tSess1 ss) `shouldReturn` M.fromList [] + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r1", q1), ("r2", q2)] + countSubs ss `shouldReturn` (1, 3) + -- delete subs + atomically $ SS.deletePendingSub tSess1 (rcvId q1) ss + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [("r2", q2)] + countSubs ss `shouldReturn` (1, 2) + atomically $ SS.deleteSub tSess1 (rcvId q2) ss + atomically (SS.getPendingSubs tSess1 ss) `shouldReturn` M.fromList [] + countSubs ss `shouldReturn` (1, 1) + atomically (SS.getActiveSubs tSess2 ss) `shouldReturn` M.fromList [("r4", q4)] + atomically $ SS.deleteSub tSess2 (rcvId q4) ss + atomically (SS.getActiveSubs tSess2 ss) `shouldReturn` M.fromList [] + countSubs ss `shouldReturn` (0, 1) + countSubs ss' `shouldReturn` (0, 3) + atomically $ SS.batchDeleteSubs tSess1 [q1, q2] ss' + countSubs ss' `shouldReturn` (0, 1) + +countSubs :: SS.TSessionSubs -> IO (Int, Int) +countSubs = fmap (foldl' (\(n1, n2) (_, (m1, m2)) -> (n1 + M.size m1, n2 + M.size m2)) (0, 0)) . dumpSessionSubs + +dummyRQ :: UserId -> SMPServer -> ConnId -> RecipientId -> RcvQueueSub +dummyRQ userId server connId rcvId = + RcvQueueSub + { userId, + connId, + server, + rcvId, + rcvPrivateKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe", + status = New, + enableNtfs = False, + clientNoticeId = Nothing, + dbQueueId = 0, + primary = True, + dbReplaceQueueId = Nothing + } diff --git a/tests/PostgresSchemaDump.hs b/tests/PostgresSchemaDump.hs index e9b54d540..3aa0a18c6 100644 --- a/tests/PostgresSchemaDump.hs +++ b/tests/PostgresSchemaDump.hs @@ -44,14 +44,14 @@ postgresSchemaDumpTest migrations skipComparisonForDownMigrations testDBOpts@DBO putStrLn $ "down migration " <> name m let downMigr = fromJust $ toDownMigration m schema <- getSchema testSchemaPath - Migrations.run st $ MTRUp [m] + Migrations.run st Nothing $ MTRUp [m] schema' <- getSchema testSchemaPath schema' `shouldNotBe` schema - Migrations.run st $ MTRDown [downMigr] + Migrations.run st Nothing $ MTRDown [downMigr] unless (name m `elem` skipComparisonForDownMigrations) $ do schema'' <- getSchema testSchemaPath schema'' `shouldBe` schema - Migrations.run st $ MTRUp [m] + Migrations.run st Nothing $ MTRUp [m] schema''' <- getSchema testSchemaPath schema''' `shouldBe` schema' diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index b6b9fd26f..02bee9ae7 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -65,7 +65,8 @@ initAgentServers = ntf = [testNtfServer], xftp = userServers [testXFTPServer], netCfg = defaultNetworkConfig {tcpTimeout = NetworkTimeout 500000 500000, tcpConnectTimeout = NetworkTimeout 500000 500000}, - presetDomains = [] + presetDomains = [], + presetServers = [] } initAgentServers2 :: InitialAgentServers diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 5f1a59fd0..b756ce7c9 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -224,7 +224,7 @@ agentDeliverMessageViaProxy :: (C.AlgorithmI a, C.AuthAlgorithm a) => (NonEmpty agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, bViaProxy) alg msg1 msg2 baseId = withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True @@ -280,7 +280,7 @@ agentDeliverMessagesViaProxyConc agentServers msgs = -- agent connections have to be set up in advance -- otherwise the CONF messages would get mixed with MSG prePair alice bob = do - (bobId, (CCLink qInfo Nothing, Nothing)) <- runExceptT' $ A.createConnection alice NRMInteractive 1 True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- runExceptT' $ A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- runExceptT' $ A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn (sqSecured, Nothing) <- runExceptT' $ A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True @@ -331,7 +331,7 @@ agentViaProxyVersionError = withAgent 1 agentCfg (servers [SMPServer testHost testPort testKeyHash]) testDB $ \alice -> do Left (A.BROKER _ (TRANSPORT TEVersion)) <- withAgent 2 agentCfg (servers [SMPServer testHost2 testPort2 testKeyHash]) testDB2 $ \bob -> runExceptT $ do - (_bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (_bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe pure () @@ -351,7 +351,7 @@ agentViaProxyRetryOffline = do let pqEnc = CR.PQEncOn withServer $ \_ -> do (aliceId, bobId) <- withServer2 $ \_ -> runRight $ do - (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + (bobId, (CCLink qInfo Nothing, Nothing)) <- A.createConnection alice NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe aliceId <- A.prepareConnectionToJoin bob 1 True qInfo PQSupportOn (sqSecured, Nothing) <- A.joinConnection bob NRMInteractive 1 aliceId True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ sqSecured `shouldBe` True diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 53269d6f6..b2c2d997c 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1290,11 +1290,11 @@ testBlockMessageQueue = pure (rId, sId) -- TODO [postgres] block via control port - withFile testStoreLogFile AppendMode $ \h -> B.hPutStrLn h $ strEncode $ BlockQueue rId $ BlockingInfo BRContent + withFile testStoreLogFile AppendMode $ \h -> B.hPutStrLn h $ strEncode $ BlockQueue rId $ BlockingInfo BRContent Nothing withSmpServerStoreLogOn ps testPort $ runTest t $ \h -> do (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g - Resp "dabc" sId2 (ERR (BLOCKED (BlockingInfo BRContent))) <- signSendRecv h sKey ("dabc", sId, SKEY sPub) + Resp "dabc" sId2 (ERR (BLOCKED (BlockingInfo BRContent Nothing))) <- signSendRecv h sKey ("dabc", sId, SKEY sPub) (sId2, sId) #== "same queue ID in response" where runTest :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> IO a) -> ThreadId -> IO a diff --git a/tests/Test.hs b/tests/Test.hs index 364080e0c..3e36e192d 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -15,7 +15,7 @@ import CoreTests.MsgStoreTests import CoreTests.RetryIntervalTests import CoreTests.SOCKSSettings import CoreTests.StoreLogTests -import CoreTests.TRcvQueuesTests +import CoreTests.TSessionSubs import CoreTests.UtilTests import CoreTests.VersionRangeTests import FileDescriptionTests (fileDescriptionTests) @@ -90,7 +90,7 @@ main = do #else describe "Store log tests" storeLogTests #endif - describe "TRcvQueues tests" tRcvQueuesTests + describe "TSessionSubs tests" tSessionSubsTests describe "Util tests" utilTests describe "Agent core tests" agentCoreTests #if defined(dbServerPostgres) @@ -103,8 +103,8 @@ main = do testStoreDBOpts "src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql" around_ (postgressBracket testServerDBConnectInfo) $ do - describe "SMP server via TLS, postgres+jornal message store" $ - before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests + -- xdescribe "SMP server via TLS, postgres+jornal message store" $ + -- before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests describe "SMP server via TLS, postgres-only message store" $ before (pure (transport @TLS, ASType SQSPostgres SMSPostgres)) serverTests #endif @@ -128,19 +128,20 @@ main = do describe "Notifications server (SMP server: jornal store)" $ ntfServerTests (transport @TLS, ASType SQSMemory SMSJournal) around_ (postgressBracket testServerDBConnectInfo) $ do - describe "Notifications server (SMP server: postgres+jornal store)" $ - ntfServerTests (transport @TLS, ASType SQSPostgres SMSJournal) + -- xdescribe "Notifications server (SMP server: postgres+jornal store)" $ + -- ntfServerTests (transport @TLS, ASType SQSPostgres SMSJournal) describe "Notifications server (SMP server: postgres-only store)" $ ntfServerTests (transport @TLS, ASType SQSPostgres SMSPostgres) around_ (postgressBracket testServerDBConnectInfo) $ do - describe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal) - describe "SMP client agent, postgres-only message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSPostgres) - describe "SMP proxy, postgres+jornal message store" $ - before (pure $ ASType SQSPostgres SMSJournal) smpProxyTests + -- xdescribe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal) + describe "SMP client agent, server postgres-only message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSPostgres) + -- xdescribe "SMP proxy, postgres+jornal message store" $ + -- before (pure $ ASType SQSPostgres SMSJournal) smpProxyTests describe "SMP proxy, postgres-only message store" $ before (pure $ ASType SQSPostgres SMSPostgres) smpProxyTests #endif - describe "SMP client agent, jornal message store" $ agentTests (transport @TLS, ASType SQSMemory SMSJournal) + -- xdescribe "SMP client agent, server jornal message store" $ agentTests (transport @TLS, ASType SQSMemory SMSJournal) + describe "SMP client agent, server memory message store" $ agentTests (transport @TLS, ASType SQSMemory SMSMemory) describe "SMP proxy, jornal message store" $ before (pure $ ASType SQSMemory SMSJournal) smpProxyTests describe "XFTP" $ do