diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 58ed41137..08e8d513b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -58,11 +58,11 @@ jobs: # ============================= build: - name: "ubuntu-${{ matrix.os }}, GHC: ${{ matrix.ghc }}" + name: "ubuntu-${{ matrix.os }}-${{ matrix.arch }}, GHC: ${{ matrix.ghc }}" needs: maybe-release env: apps: "smp-server xftp-server ntf-server xftp" - runs-on: ubuntu-${{ matrix.os }} + runs-on: ${{ matrix.runner }} services: postgres: image: postgres:15 @@ -81,16 +81,34 @@ jobs: matrix: include: - os: 22.04 + os_underscore: 22_04 + arch: x86-64 + runner: "ubuntu-22.04" ghc: "8.10.7" - platform_name: 22_04-8.10.7 should_run: ${{ !(github.ref == 'refs/heads/stable' || startsWith(github.ref, 'refs/tags/v')) }} - os: 22.04 + os_underscore: 22_04 + arch: x86-64 + runner: "ubuntu-22.04" ghc: "9.6.3" - platform_name: 22_04-x86-64 should_run: true - os: 24.04 + os_underscore: 24_04 + arch: x86-64 + runner: "ubuntu-24.04" + ghc: "9.6.3" + should_run: true + - os: 22.04 + os_underscore: 22_04 + arch: aarch64 + runner: "ubuntu-22.04-arm" + ghc: "9.6.3" + should_run: true + - os: 24.04 + os_underscore: 24_04 + arch: aarch64 + runner: "ubuntu-24.04-arm" ghc: "9.6.3" - platform_name: 24_04-x86-64 should_run: true steps: - name: Clone project @@ -127,11 +145,7 @@ jobs: context: . load: true file: Dockerfile.build - tags: build/${{ matrix.platform_name }}:latest - cache-from: | - type=gha - type=gha,scope=master - cache-to: type=gha,mode=max + tags: build/${{ matrix.os }}:latest build-args: | TAG=${{ matrix.os }} GHC=${{ matrix.ghc }} @@ -143,23 +157,28 @@ jobs: path: | ~/.cabal/store dist-newstyle - key: ${{ matrix.os }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }} + key: ubuntu-${{ matrix.os }}-${{ matrix.arch }}-ghc${{ matrix.ghc }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }} - name: Start container if: matrix.should_run == true shell: bash run: | docker run -t -d \ + --device /dev/fuse \ + --cap-add SYS_ADMIN \ + --security-opt apparmor:unconfined \ --name builder \ -v ~/.cabal:/root/.cabal \ -v /home/runner/work/_temp:/home/runner/work/_temp \ -v ${{ github.workspace }}:/project \ - build/${{ matrix.platform_name }}:latest + build/${{ matrix.os }}:latest - name: Build smp-server (postgresql) and tests if: matrix.should_run == true shell: docker exec -t builder sh -eu {0} run: | + chmod -R 777 dist-newstyle ~/.cabal && git config --global --add safe.directory '*' + cabal clean cabal update cabal build --jobs=$(nproc) --enable-tests -fserver_postgres mkdir -p /out @@ -181,7 +200,7 @@ jobs: id: prepare-postgres shell: bash run: | - name="smp-server-postgres-ubuntu-${{ matrix.platform_name }}" + name="smp-server-postgres-ubuntu-${{ matrix.os_underscore }}-${{ matrix.arch }}" docker cp builder:/out/smp-server $name path="${{ github.workspace }}/$name" @@ -213,9 +232,9 @@ jobs: printf 'bins< bins.output printf 'hashes< hashes.output for i in ${{ env.apps }}; do - mv ./out/$i ./$i-ubuntu-${{ matrix.platform_name }} + name="$i-ubuntu-${{ matrix.os_underscore }}-${{ matrix.arch }}" - name="$i-ubuntu-${{ matrix.platform_name }}" + mv ./out/$i ./$name path="${{ github.workspace }}/$name" hash="SHA2-256($name)= $(openssl sha256 $path | cut -d' ' -f 2)" @@ -246,7 +265,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Test - if: matrix.should_run == true + if: matrix.should_run == true && matrix.arch == 'x86-64' timeout-minutes: 120 shell: bash env: diff --git a/simplexmq.cabal b/simplexmq.cabal index fc282f30a..476b0a4be 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -1,7 +1,7 @@ cabal-version: 1.12 name: simplexmq -version: 6.4.4.1 +version: 6.5.0.0.1 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -267,6 +267,7 @@ library Simplex.Messaging.Notifications.Server.Store.Postgres Simplex.Messaging.Notifications.Server.Store.Types Simplex.Messaging.Notifications.Server.StoreLog + Simplex.Messaging.Server.MsgStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres.Migrations other-modules: diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index cfee308bc..c3723e7a9 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -75,7 +75,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String (strDecode, strEncode) import Simplex.Messaging.Protocol (ProtocolServer, ProtocolType (..), XFTPServer) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM) +import Simplex.Messaging.Util (allFinally, catchAll_, catchAllErrors, liftError, tshow, unlessM, whenM) import System.FilePath (takeFileName, ()) import UnliftIO import UnliftIO.Directory @@ -198,10 +198,10 @@ runXFTPRcvWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv downloadAttempts downloadFileChunk fc replica approvedRelays - `catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c (fromMaybe rcvFileEntityId redirectEntityId_) (RFWARN e) liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay @@ -280,7 +280,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do runXFTPOperation AgentConfig {rcvFilesTTL} = withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $ \f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath, redirect} -> - decryptFile f `catchAgentError` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath + decryptFile f `catchAllErrors` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath decryptFile :: RcvFile -> AM () decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do let CryptoFile savePath cfArgs = saveFile @@ -307,7 +307,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do liftIO $ waitUntilForeground c withStore' c (`updateRcvFileComplete` rcvFileId) -- proceed with redirect - yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath) + yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `allFinally` (lift $ toFSFilePath fsSavePath >>= removePath) next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of -- TODO switch to another error constructor Left _ -> throwE . FILE $ REDIRECT "decode error" @@ -399,7 +399,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do runXFTPOperation cfg@AgentConfig {sndFilesTTL} = withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $ \f@SndFile {sndFileId, sndFileEntityId, prefixPath} -> - prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath + prepareFile cfg f `catchAllErrors` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath prepareFile :: AgentConfig -> SndFile -> AM () prepareFile _ SndFile {prefixPath = Nothing} = throwE $ INTERNAL "no prefix path" @@ -468,11 +468,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do liftIO $ waitForUserNetwork c let triedAllSrvs = n > userSrvCount createWithNextSrv triedHosts - `catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e + `catchAllErrors` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e where -- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload retryLoop loop triedAllSrvs e = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ assertAgentForeground c loop @@ -508,10 +508,10 @@ runXFTPSndWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv uploadAttempts uploadFileChunk cfg fc replica - `catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e liftIO $ closeXFTPServerClient c userId server digest withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay @@ -681,10 +681,10 @@ runXFTPDelWorker c srv Worker {doWork} = do liftIO $ waitForUserNetwork c atomically $ incXFTPServerStat c userId srv deleteAttempts deleteChunkReplica - `catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e + `catchAllErrors` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e where retryLoop loop e replicaDelay = do - flip catchAgentError (\_ -> pure ()) $ do + flip catchAllErrors (\_ -> pure ()) $ do when (serverHostError e) $ notify c "" $ SFWARN e liftIO $ closeXFTPServerClient c userId server chunkDigest withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index dac4cc1b3..62f06b7d3 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -59,6 +59,8 @@ import Simplex.Messaging.Protocol RecipientId, SenderId, pattern NoEntity, + NetworkError (..), + toNetworkError, ) import Simplex.Messaging.Transport (ALPN, CertChainPubKey (..), HandshakeError (..), THandleAuth (..), THandleParams (..), TransportError (..), TransportPeer (..), defaultSupportedParams) import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost) @@ -191,7 +193,7 @@ xftpHTTP2Config transportConfig XFTPClientConfig {xftpNetworkConfig = NetworkCon xftpClientError :: HTTP2ClientError -> XFTPClientError xftpClientError = \case HCResponseTimeout -> PCEResponseTimeout - HCNetworkError -> PCENetworkError + HCNetworkError e -> PCENetworkError e HCIOError e -> PCEIOError e sendXFTPCommand :: forall p. FilePartyI p => XFTPClient -> C.APrivateAuthKey -> XFTPFileId -> FileCommand p -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body) @@ -261,9 +263,9 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec { ExceptT (sequence <$> (t `timeout` (download cbState `catches` errors))) >>= maybe (throwE PCEResponseTimeout) pure where errors = - [ Handler $ \(_e :: H.HTTP2Error) -> pure $ Left PCENetworkError, - Handler $ \(e :: IOException) -> pure $ Left (PCEIOError e), - Handler $ \(_e :: SomeException) -> pure $ Left PCENetworkError + [ Handler $ \(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ displayException e, + Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError e, + Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e ] download cbState = runExceptT . withExceptT PCEResponseError $ diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 60ea5d69c..27967bfd6 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -139,7 +139,7 @@ import Control.Monad.Reader import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J -import Data.Bifunctor (bimap, first, second) +import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) @@ -284,13 +284,13 @@ saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServ xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss} - tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case + tryAllErrors' (withStore' c (`updateServersStats` stats)) >>= \case Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right () -> pure () restoreServersStats :: AgentClient -> AM' () restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do - tryAgentError' (withStore c getServersStats) >>= \case + tryAllErrors' (withStore c getServersStats) >>= \case Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e) Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do @@ -774,7 +774,7 @@ acceptContactAsync' :: AgentClient -> UserId -> ACorrId -> Bool -> InvitationId acceptContactAsync' c userId corrId enableNtfs invId ownConnInfo pqSupport subMode = do Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContactAsync'" invId withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAllErrors` \err -> do withStore' c (`unacceptInvitation` invId) throwE err @@ -961,7 +961,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey createRcvQueue nonce_ qd e2eKeys = do AgentConfig {smpClientVRange = vr} <- asks config ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing - (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e + (rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e atomically $ incSMPServerStat c userId srv connCreated rq' <- withStore c $ \db -> updateNewConnRcv db connId rq lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId @@ -1269,7 +1269,7 @@ subscribeConnections' c connIds = do errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs resumeDelivery cs - lift $ resumeConnCmds c $ M.keys cs + resumeConnCmds c $ M.keys cs rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor @@ -1351,7 +1351,7 @@ subscribeClientService' = undefined -- requesting messages sequentially, to reduce memory usage getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta))) -getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage +getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage where getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta) getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do @@ -1363,7 +1363,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage ContactConnection _ rq -> pure rq SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage" NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection" - msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e + msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e when (isNothing msg_) $ do atomically $ releaseGetLock c rq forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs @@ -1473,10 +1473,10 @@ resumeSrvCmds :: AgentClient -> ConnId -> Maybe SMPServer -> AM' () resumeSrvCmds = void .:. getAsyncCmdWorker False {-# INLINE resumeSrvCmds #-} -resumeConnCmds :: AgentClient -> [ConnId] -> AM' () +resumeConnCmds :: AgentClient -> [ConnId] -> AM () resumeConnCmds c connIds = do - connSrvs <- rights . zipWith (second . (,)) connIds <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) - mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs + connSrvs <- withStore' c (`getPendingCommandServers` connIds) + lift $ mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c connId server = @@ -1534,7 +1534,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do pure CCCompleted -- duplex connection is matched to handle SKEY retries DuplexConnection cData _ (sq :| _) -> do - tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case + tryAllErrors (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case Right () -> pure CCCompleted Left e | temporaryOrHostError e && Just server /= server_ -> do @@ -1621,7 +1621,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - tryAgentError action >>= \case + tryAllErrors action >>= \case Left e | temporaryOrHostError e -> retrySndOp c loop | otherwise -> cmdError e @@ -2065,7 +2065,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM () ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do atomically $ incSMPServerStat c userId server ackAttempts - tryAgentError (sendAck c rq srvMsgId) >>= \case + tryAllErrors (sendAck c rq srvMsgId) >>= \case Right _ -> sendMsgNtf ackMsgs Left (SMP _ SMP.NO_MSG) -> sendMsgNtf ackNoMsgErrs Left e -> do @@ -2076,7 +2076,7 @@ ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do atomically $ incSMPServerStat c userId server stat whenM (liftIO $ hasGetLock c rq) $ do atomically $ releaseGetLock c rq - brokerTs_ <- eitherToMaybe <$> tryAgentError (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId) + brokerTs_ <- eitherToMaybe <$> tryAllErrors (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId) atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ MSGNTF srvMsgId brokerTs_) -- | Suspend SMP agent connection (OFF command) in Reader monad @@ -2307,7 +2307,7 @@ registerNtfToken' c nm suppliedDeviceToken suppliedNtfMode = replaceToken :: NtfTokenId -> AM NtfTknStatus replaceToken tknId = do ns <- asks ntfSupervisor - tryReplace ns `catchAgentError` \e -> + tryReplace ns `catchAllErrors` \e -> if temporaryOrHostError e then throwE e else do @@ -2451,23 +2451,23 @@ sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM () sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor connIds <- liftIO $ S.toList <$> getSubscriptions c - rs <- lift $ withStoreBatch' c (\db -> map (getConnData db) connIds) + rs <- withStore' c (`getConnsData` connIds) let (connIds', cErrs) = enabledNtfConns (zip connIds rs) forM_ (L.nonEmpty connIds') $ \connIds'' -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, connIds'') unless (null cErrs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS cErrs) where - enabledNtfConns :: [(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)]) + enabledNtfConns :: [(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)]) enabledNtfConns = foldr addEnabledConn ([], []) where addEnabledConn :: - (ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode))) -> + (ConnId, Either StoreError (Maybe (ConnData, ConnectionMode))) -> ([ConnId], [(ConnId, AgentErrorType)]) -> ([ConnId], [(ConnId, AgentErrorType)]) addEnabledConn cData_ (cIds, errs) = case cData_ of (_, Right (Just (ConnData {connId, enableNtfs}, _))) -> if enableNtfs then (connId : cIds, errs) else (cIds, errs) (connId, Right Nothing) -> (cIds, (connId, INTERNAL "no connection data") : errs) - (connId, Left e) -> (cIds, (connId, e) : errs) + (connId, Left e) -> (cIds, (connId, INTERNAL (show e)) : errs) setNtfServers :: AgentClient -> [NtfServer] -> IO () setNtfServers c = atomically . writeTVar (ntfServers c) @@ -2564,7 +2564,7 @@ cleanupManager c@AgentClient {subQ} = do where run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' () run err a = do - waitActive . runExceptT $ a `catchAgentError` (notify "" . err) + waitActive . runExceptT $ a `catchAllErrors` (notify "" . err) step <- asks $ cleanupStepInterval . config liftIO $ threadDelay step -- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active @@ -2578,33 +2578,33 @@ cleanupManager c@AgentClient {subQ} = do deleteRcvFilesExpired = do rcvFilesTTL <- asks $ rcvFilesTTL . config rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL) - forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesDeleted = do rcvDeleted <- withStore' c getCleanupRcvFilesDeleted - forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`deleteRcvFile'` dbId) deleteRcvFilesTmpPaths = do rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths - forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do + forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`updateRcvFileNoTmpPath` dbId) deleteSndFilesExpired = do sndFilesTTL <- asks $ sndFilesTTL . config sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL) - forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesDeleted = do sndDeleted <- withStore' c getCleanupSndFilesDeleted - forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift . forM_ p $ removePath <=< toFSFilePath withStore' c (`deleteSndFile'` dbId) deleteSndFilesPrefixPaths = do sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths - forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do + forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do lift $ removePath =<< toFSFilePath p withStore' c (`updateSndFileNoPrefixPath` dbId) deleteExpiredReplicasForDeletion = do @@ -2652,10 +2652,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId where withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' () withRcvConn rId a = do - tryAgentError' (withStore c $ \db -> getRcvConn db srv rId) >>= \case + tryAllErrors' (withStore c $ \db -> getRcvConn db srv rId) >>= \case Left e -> notify' "" (ERR e) Right (rq@RcvQueue {connId}, SomeConn _ conn) -> - tryAgentError' (a rq conn) >>= \case + tryAllErrors' (a rq conn) >>= \case Left e -> notify' connId (ERR e) Right () -> pure () processSubOk :: RcvQueue -> TVar [ConnId] -> AM () @@ -2739,7 +2739,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId _ -> pure () let encryptedMsgHash = C.sha256Hash encAgentMessage g <- asks random - tryAgentError (agentClientMsg g encryptedMsgHash) >>= \case + tryAllErrors (agentClientMsg g encryptedMsgHash) >>= \case Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do conn'' <- resetRatchetSync case aMessage of @@ -2848,7 +2848,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId ackDel :: InternalId -> AM ACKd ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd handleNotifyAck :: AM ACKd -> AM ACKd - handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack + handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack SMP.END -> atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False)) >>= notifyEnd @@ -3006,7 +3006,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do logServer "<--" c srv rId $ "MSG :" <> logSecret' srvMsgId - rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing + rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAllErrors` \e -> notify (ERR e) $> Nothing case L.nonEmpty . catMaybes $ L.toList rs of Just rs' -> notify (RCVD msgMeta rs') $> ACKPending Nothing -> ack diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 39b3534c0..21c0436ee 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -130,6 +130,7 @@ module Simplex.Messaging.Agent.Client hasWorkToDo, hasWorkToDo', withWork, + withWork_, withWorkItems, agentOperations, agentOperationBracket, @@ -249,6 +250,7 @@ import Simplex.Messaging.Protocol EntityId (..), ServiceId, ErrorType, + NetworkError (..), MsgFlags (..), MsgId, NtfServer, @@ -371,12 +373,12 @@ data SMPConnectedClient = SMPConnectedClient type ProxiedRelayVar = SessionVar (Either AgentErrorType ProxiedRelay) -getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker +getAgentWorker :: (Ord k, Show k, AnyError e, MonadUnliftIO m) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT e m ()) -> m Worker getAgentWorker = getAgentWorker' id pure {-# INLINE getAgentWorker #-} -getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a -getAgentWorker' toW fromW name hasWork c key ws work = do +getAgentWorker' :: forall a k e m. (Ord k, Show k, AnyError e, MonadUnliftIO m) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT e m ()) -> m a +getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w where getWorker = TM.lookup key ws @@ -389,12 +391,12 @@ getAgentWorker' toW fromW name hasWork c key ws work = do | otherwise = pure w runWorker w = runWorkerAsync (toW w) runWork where - runWork :: AM' () - runWork = tryAgentError' (work w) >>= restartOrDelete - restartOrDelete :: Either AgentErrorType () -> AM' () + runWork :: m () + runWork = tryAllErrors' (work w) >>= restartOrDelete + restartOrDelete :: Either e () -> m () restartOrDelete e_ = do t <- liftIO getSystemTime - maxRestarts <- asks $ maxWorkerRestartsPerMin . config + let maxRestarts = maxWorkerRestartsPerMin $ config agentEnv -- worker may terminate because it was deleted from the map (getWorker returns Nothing), then it won't restart restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts) when restart runWork @@ -431,7 +433,7 @@ newWorker c = do restarts <- newTVar $ RestartCount 0 0 pure Worker {workerId, doWork, action, restarts} -runWorkerAsync :: Worker -> AM' () -> AM' () +runWorkerAsync :: MonadUnliftIO m => Worker -> m () -> m () runWorkerAsync Worker {action} work = E.bracket (atomically $ takeTMVar action) -- get current action, locking to avoid race conditions @@ -665,7 +667,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq pure (clnt, sess) newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv = - tryAgentError (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case + tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case Right sess -> do atomically $ putTMVar (sessionVar rv) (Right sess) pure $ Right sess @@ -688,7 +690,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v = newProtocolClient c tSess smpClients connectClient v - `catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e + `catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e where connectClient :: SMPClientVar -> AM SMPConnectedClient connectClient v' = do @@ -866,7 +868,7 @@ newProtocolClient :: ClientVar msg -> AM (Client msg) newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = - tryAgentError (connectClient v) >>= \case + tryAllErrors (connectClient v) >>= \case Right client -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")" atomically $ putTMVar (sessionVar v) (Right client) @@ -1027,7 +1029,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> NetworkRequestMode -> TransportSession msg -> (Client msg -> AM a) -> AM a withClient_ c nm tSess@(_, srv, _) action = do cl <- getProtocolServerClient c nm tSess - action cl `catchAgentError` logServerError + action cl `catchAllErrors` logServerError where logServerError :: AgentErrorType -> AM a logServerError e = do @@ -1040,7 +1042,7 @@ withProxySession c nm proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = d logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr case sess_ of Right sess -> do - r <- action (cl, sess) `catchAgentError` logServerError cl + r <- action (cl, sess) `catchAllErrors` logServerError cl logServer ("<-- " <> proxySrv cl <> " <") c destSrv entId "OK" pure r Left e -> logServerError cl e @@ -1117,7 +1119,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn unknownServer = liftIO $ maybe True (\srvs -> all (`S.notMember` knownHosts srvs) destHosts) <$> TM.lookupIO userId (smpServers c) sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer, a) sendViaProxy proxySrv_ destSess@(_, _, connId_) = do - r <- tryAgentError . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do + r <- tryAllErrors . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess let proxySrv = protocolClientServer' smp case r' of @@ -1164,7 +1166,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn | otherwise -> throwE e sendDirectly tSess = withLogClient_ c nm tSess (unEntityId entId) ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do - tryAgentError (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case + tryAllErrors (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case Right r -> r <$ atomically (incSMPServerStat c userId destSrv sentDirect) Left e -> throwE e @@ -1198,12 +1200,12 @@ protocolClientError protocolError_ host = \case PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e PCEUnexpectedResponse e -> BROKER host $ UNEXPECTED $ B.unpack e PCEResponseTimeout -> BROKER host TIMEOUT - PCENetworkError -> BROKER host NETWORK + PCENetworkError e -> BROKER host $ NETWORK e PCEIncompatibleHost -> BROKER host HOST PCETransportError e -> BROKER host $ TRANSPORT e e@PCECryptoError {} -> INTERNAL $ show e PCEServiceUnavailable {} -> BROKER host NO_SERVICE - PCEIOError {} -> BROKER host NETWORK + PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e data ProtocolTestStep = TSConnect @@ -1477,7 +1479,7 @@ temporaryAgentError = \case _ -> False where tempBrokerError = \case - NETWORK -> True + NETWORK _ -> True TIMEOUT -> True _ -> False @@ -1517,7 +1519,7 @@ subscribeQueues c qs = do subscribeQueues_ env session smp qs' = do let (userId, srv, _) = transportSession' smp atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs' - rs <- sendBatch (\smp' _ -> subscribeSMPQueues smp') smp NRMBackground qs' + rs <- sendBatch (\smp' _ -> subscribeSMPQueues smp') smp NRMBackground qs' active <- atomically $ ifM @@ -1528,7 +1530,8 @@ subscribeQueues c qs = do then when (hasTempErrors rs) resubscribe $> rs else do logWarn "subcription batch result for replaced SMP client, resubscribing" - resubscribe $> L.map (second $ \_ -> Left PCENetworkError) rs + -- TODO we probably use PCENetworkError here instead of the original error, so it becomes temporary. + resubscribe $> L.map (second $ Left . PCENetworkError . NESubscribeError . show) rs where tSess = transportSession' smp sessId = sessionId $ thParams smp @@ -1562,7 +1565,7 @@ sendTSessionBatches statCmd toRQ action c nm qs = in M.alter (Just . maybe [q] (q <|)) tSess m sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r) sendClientBatch (tSess@(_, srv, _), qs') = - tryAgentError' (getSMPServerClient c nm tSess) >>= \case + tryAllErrors' (getSMPServerClient c nm tSess) >>= \case Left e -> pure $ L.map (,Left e) qs' Right (SMPConnectedClient smp _) -> liftIO $ do logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd @@ -1867,7 +1870,7 @@ withNtfBatch :: AM' (NonEmpty (Either AgentErrorType r)) withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do let tSess = (0, ntfServer, Nothing) - tryAgentError' (getNtfServerClient c NRMBackground tSess) >>= \case + tryAllErrors' (getNtfServerClient c NRMBackground tSess) >>= \case Left e -> pure $ L.map (\_ -> Left e) subs Right ntf -> liftIO $ do logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr @@ -1968,40 +1971,42 @@ waitForWork = void . atomically . readTMVar {-# INLINE waitForWork #-} withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM () -withWork c doWork getWork action = - withStore' c getWork >>= \case +withWork c doWork = withWork_ c doWork . withStore' c +{-# INLINE withWork #-} + +withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m () +withWork_ c doWork getWork action = + getWork >>= \case Right (Just r) -> action r Right Nothing -> noWork -- worker is stopped here (noWork) because the next iteration is likely to produce the same result - Left e@SEWorkItemError {} -> noWork >> notifyErr (CRITICAL False) e - Left e -> notifyErr INTERNAL e + Left e + | isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e + | otherwise -> notifyErr INTERNAL e where noWork = liftIO $ noWorkToDo doWork notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e) -withWorkItems :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError [Either StoreError a])) -> (NonEmpty a -> AM ()) -> AM () +withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m () withWorkItems c doWork getWork action = do - withStore' c getWork >>= \case + getWork >>= \case Right [] -> noWork Right rs -> do let (errs, items) = partitionEithers rs case L.nonEmpty items of Just items' -> action items' Nothing -> do - let criticalErr = find workItemError errs + let criticalErr = find isWorkItemError errs forM_ criticalErr $ \err -> do notifyErr (CRITICAL False) err - when (all workItemError errs) noWork + when (all isWorkItemError errs) noWork unless (null errs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS $ map (\e -> ("", INTERNAL $ show e)) errs) Left e - | workItemError e -> noWork >> notifyErr (CRITICAL False) e + | isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e | otherwise -> notifyErr INTERNAL e where - workItemError = \case - SEWorkItemError {} -> True - _ -> False noWork = liftIO $ noWorkToDo doWork notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e) diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 0c10d8cd4..e15ffa48c 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -27,11 +27,6 @@ module Simplex.Messaging.Agent.Env.SQLite serverHosts, defaultAgentConfig, defaultReconnectInterval, - tryAgentError, - tryAgentError', - catchAgentError, - catchAgentError', - agentFinally, Env (..), newSMPAgentEnv, createAgentStore, @@ -45,7 +40,6 @@ module Simplex.Messaging.Agent.Env.SQLite where import Control.Concurrent (ThreadId) -import Control.Exception (BlockedIndefinitelyOnSTM (..), SomeException, fromException) import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader @@ -70,7 +64,7 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store (createStore) import Simplex.Messaging.Agent.Store.Common (DBStore) import Simplex.Messaging.Agent.Store.Interface (DBOpts) -import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationError (..)) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange) @@ -83,7 +77,6 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors') import System.Mem.Weak (Weak) import System.Random (StdGen, newStdGen) import UnliftIO.STM @@ -273,7 +266,7 @@ newSMPAgentEnv config store = do multicastSubscribers <- newTMVarIO 0 pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} -createAgentStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createAgentStore :: DBOpts -> MigrationConfig -> IO (Either MigrationError DBStore) createAgentStore = createStore data NtfSupervisor = NtfSupervisor @@ -312,33 +305,6 @@ newXFTPAgent = do xftpDelWorkers <- TM.emptyIO pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers} -tryAgentError :: AM a -> AM (Either AgentErrorType a) -tryAgentError = tryAllErrors mkInternal -{-# INLINE tryAgentError #-} - --- unlike runExceptT, this ensures we catch IO exceptions as well -tryAgentError' :: AM a -> AM' (Either AgentErrorType a) -tryAgentError' = tryAllErrors' mkInternal -{-# INLINE tryAgentError' #-} - -catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a -catchAgentError = catchAllErrors mkInternal -{-# INLINE catchAgentError #-} - -catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a -catchAgentError' = catchAllErrors' mkInternal -{-# INLINE catchAgentError' #-} - -agentFinally :: AM a -> AM b -> AM a -agentFinally = allFinally mkInternal -{-# INLINE agentFinally #-} - -mkInternal :: SomeException -> AgentErrorType -mkInternal e = case fromException e of - Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction" - _ -> INTERNAL $ show e -{-# INLINE mkInternal #-} - data Worker = Worker { workerId :: Int, doWork :: TMVar (), diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index 7546c03dd..85fa45c49 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -52,7 +52,7 @@ import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Protocol (NtfServer, sameSrvAddr) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, whenM) +import Simplex.Messaging.Util (catchAllErrors, diffToMicroseconds, threadDelay', tryAllErrors, tshow, whenM) import System.Random (randomR) import UnliftIO import UnliftIO.Concurrent (forkIO) @@ -217,7 +217,7 @@ runNtfWorker c srv Worker {doWork} = runNtfOperation :: AM () runNtfOperation = do ntfBatchSize <- asks $ ntfBatchSize . config - withWorkItems c doWork (\db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do + withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do logInfo $ "runNtfWorker - length nextSubs = " <> tshow (length nextSubs) currTs <- liftIO getCurrentTime let (creates, checks, deletes, rotates) = splitActions currTs nextSubs @@ -357,7 +357,7 @@ runNtfWorker c srv Worker {doWork} = runCatching :: (NtfSubscription -> AM (Maybe NtfSubscription)) -> NtfSubscription -> AM' (Maybe NtfSubscription) runCatching action sub@NtfSubscription {connId} = fromRight Nothing - <$> runExceptT (action sub `catchAgentError` \e -> workerInternalError c connId (show e) $> Nothing) + <$> runExceptT (action sub `catchAllErrors` \e -> workerInternalError c connId (show e) $> Nothing) -- deleteNtfSub is only used in NSADelete and NSARotate, so also deprecated deleteNtfSub :: NtfSubscription -> AM () -> AM (Maybe NtfSubscription) deleteNtfSub sub@NtfSubscription {userId, ntfSubId} continue = case ntfSubId of @@ -365,7 +365,7 @@ runNtfWorker c srv Worker {doWork} = lift getNtfToken >>= \case Just tkn@NtfToken {ntfServer} -> do atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts - tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case + tryAllErrors (agentNtfDeleteSubscription c nSubId tkn) >>= \case Right _ -> do atomically $ incNtfServerStat c userId ntfServer ntfDeleted continue' @@ -385,7 +385,7 @@ runNtfSMPWorker c srv Worker {doWork} = forever $ do runNtfSMPOperation :: AM () runNtfSMPOperation = do ntfBatchSize <- asks $ ntfBatchSize . config - withWorkItems c doWork (\db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do + withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do logInfo $ "runNtfSMPWorker - length nextSubs = " <> tshow (length nextSubs) let (creates, deletes) = splitActions nextSubs retrySubActions c creates createNotifierKeys @@ -567,7 +567,7 @@ runNtfTknDelWorker c srv Worker {doWork} = withRetryInterval ri $ \_ loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - processTknToDelete nextTknToDelete `catchAgentError` retryTmpError loop nextTknToDelete + processTknToDelete nextTknToDelete `catchAllErrors` retryTmpError loop nextTknToDelete retryTmpError :: AM () -> NtfTokenToDelete -> AgentErrorType -> AM () retryTmpError loop (tknDbId, _, _) e = do logError $ "ntf tkn del error: " <> tshow e diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index f1d5f2ec8..d4d302df7 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -173,6 +173,7 @@ module Simplex.Messaging.Agent.Protocol where import Control.Applicative (optional, (<|>)) +import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException) import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?)) import qualified Data.Aeson as J' import qualified Data.Aeson.Encoding as JE @@ -1866,6 +1867,12 @@ data AgentErrorType INACTIVE deriving (Eq, Show, Exception) +instance AnyError AgentErrorType where + fromSomeException e = case fromException e of + Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction" + _ -> INTERNAL $ show e + {-# INLINE fromSomeException #-} + -- | SMP agent protocol command or response error. data CommandErrorType = -- | command is prohibited in this context diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 71e35e7ec..6b866cee6 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -29,10 +29,11 @@ import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) +import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Agent.Store.Common import Simplex.Messaging.Agent.Store.Interface (createDBStore) import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations) -import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationError (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (MsgEncryptKeyX448, PQEncryption, PQSupport, RatchetX448) import Simplex.Messaging.Encoding.String @@ -52,9 +53,9 @@ import Simplex.Messaging.Protocol VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Agent.Store.Entity +import Simplex.Messaging.Util (AnyError (..), bshow) -createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore :: DBOpts -> MigrationConfig -> IO (Either MigrationError DBStore) createStore dbOpts = createDBStore dbOpts appMigrations -- * Queue types @@ -692,7 +693,20 @@ data StoreError | -- | XFTP Deleted snd chunk replica not found. SEDeletedSndChunkReplicaNotFound | -- | Error when reading work item that suspends worker - do not use! - SEWorkItemError ByteString + SEWorkItemError {errContext :: String} | -- | Servers stats not found. SEServersStatsNotFound deriving (Eq, Show, Exception) + +instance AnyError StoreError where + fromSomeException = SEInternal . bshow + +class (Show e, AnyError e) => AnyStoreError e where + isWorkItemError :: e -> Bool + mkWorkItemError :: String -> e + +instance AnyStoreError StoreError where + isWorkItemError = \case + SEWorkItemError {} -> True + _ -> False + mkWorkItemError errContext = SEWorkItemError {errContext} diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index e10f48c8f..350d3bfe7 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -43,7 +43,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getDeletedConn, getConns, getDeletedConns, - getConnData, + getConnsData, setConnDeleted, setConnUserId, setConnAgentVersion, @@ -237,6 +237,8 @@ module Simplex.Messaging.Agent.Store.AgentStore firstRow', maybeFirstRow, fromOnlyBI, + getWorkItem, + getWorkItems, ) where @@ -255,8 +257,9 @@ import Data.List (foldl', sortBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Ord (Down (..)) +import qualified Data.Set as S import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Word (Word32) @@ -285,12 +288,14 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, firstRow, firstRow', ifM, maybeFirstRow, tshow, ($>>=), (<$$>)) +import Simplex.Messaging.Util import Simplex.Messaging.Version.Internal import qualified UnliftIO.Exception as E import UnliftIO.STM #if defined(dbPostgres) -import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..)) +import Data.List (sortOn) +import Data.Map.Strict (Map) +import Database.PostgreSQL.Simple (In (..), Only (..), Query, SqlError, (:.) (..)) import Database.PostgreSQL.Simple.Errors (constraintViolation) import Database.PostgreSQL.Simple.SqlQQ (sql) #else @@ -424,15 +429,12 @@ deleteConnRecord :: DB.Connection -> ConnId -> IO () deleteConnRecord db connId = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool -checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" - (host server, port server, sndId, New) - ) +checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = + maybeFirstRow' False fromOnlyBI $ + DB.query + db + "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" + (host server, port server, sndId, New) getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do @@ -966,28 +968,25 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} = _ -> Left $ SEInternal "unexpected snd msg data" markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId) -getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a)) +getWorkItem :: (Show i, AnyStoreError e) => String -> IO (Maybe i) -> (i -> IO (Either e a)) -> (i -> IO ()) -> IO (Either e (Maybe a)) getWorkItem itemName getId getItem markFailed = runExceptT $ handleWrkErr itemName "getId" getId >>= mapM (tryGetItem itemName getItem markFailed) -getWorkItems :: Show i => ByteString -> IO [i] -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError [Either StoreError a]) +getWorkItems :: (Show i, AnyStoreError e) => String -> IO [i] -> (i -> IO (Either e a)) -> (i -> IO ()) -> IO (Either e [Either e a]) getWorkItems itemName getIds getItem markFailed = runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed) -tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a -tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e +tryGetItem :: (Show i, AnyStoreError e) => String -> (i -> IO (Either e a)) -> (i -> IO ()) -> i -> ExceptT e IO a +tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchAllErrors` \e -> mark >> throwE e where - mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId - -catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a -catchStoreError = catchAllErrors (SEInternal . bshow) + mark = handleWrkErr itemName ("markFailed ID " <> show itemId) $ markFailed itemId -- Errors caught by this function will suspend worker as if there is no more work, -handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a +handleWrkErr :: forall e a. AnyStoreError e => String -> String -> IO a -> ExceptT e IO a handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action where - mkError :: E.SomeException -> StoreError - mkError e = SEWorkItemError $ itemName <> " " <> opName <> " error: " <> bshow e + mkError :: E.SomeException -> e + mkError e = mkWorkItemError $ itemName <> " " <> opName <> " error: " <> show e updatePendingMsgRIState :: DB.Connection -> ConnId -> InternalId -> RI2State -> IO () updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} = @@ -1073,15 +1072,12 @@ toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool -checkRcvMsgHashExists db connId hash = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, Binary hash) - ) +checkRcvMsgHashExists db connId hash = + maybeFirstRow' False fromOnlyBI $ + DB.query + db + "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" + (connId, Binary hash) getRcvMsgBrokerTs :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Either StoreError BrokerTs) getRcvMsgBrokerTs db connId msgId = @@ -1305,21 +1301,26 @@ insertedRowId db = fromOnly . head <$> DB.query_ db q q = "SELECT last_insert_rowid()" #endif -getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer] -getPendingCommandServers db connId = do +getPendingCommandServers :: DB.Connection -> [ConnId] -> IO [(ConnId, NonEmpty (Maybe SMPServer))] +getPendingCommandServers db connIds = -- TODO review whether this can break if, e.g., the server has another key hash. - map smpServer - <$> DB.query + mapMaybe connServers . groupOn' rowConnId + <$> DB.query_ db [sql| - SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) + SELECT DISTINCT c.conn_id, c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) FROM commands c LEFT JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE conn_id = ? + ORDER BY c.conn_id |] - (Only connId) where + rowConnId (Only connId :. _) = connId + connServers rs = + let connId = rowConnId $ L.head rs + srvs = L.map (\(_ :. r) -> smpServer r) rs + in if connId `S.member` conns then Just (connId, srvs) else Nothing smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash + conns = S.fromList connIds getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed @@ -2037,21 +2038,19 @@ getDeletedConn = getAnyConn True {-# INLINE getDeletedConn #-} getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getAnyConn deleted' dbConn connId = - getConnData dbConn connId >>= \case +getAnyConn deleted' db connId = + getConnData deleted' db connId >>= \case + Just (cData, cMode) -> do + rQ <- getRcvQueuesByConnId_ db connId + sQ <- getSndQueuesByConnId_ db connId + pure $ case (rQ, sQ, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound Nothing -> pure $ Left SEConnNotFound - Just (cData@ConnData {deleted}, cMode) - | deleted /= deleted' -> pure $ Left SEConnNotFound - | otherwise -> do - rQ <- getRcvQueuesByConnId_ dbConn connId - sQ <- getSndQueuesByConnId_ dbConn connId - pure $ case (rQ, sQ, cMode) of - (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) - (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) - (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) - (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) - (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) - _ -> Left SEConnNotFound getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getConns = getAnyConns_ False @@ -2061,28 +2060,84 @@ getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getDeletedConns = getAnyConns_ True {-# INLINE getDeletedConns #-} +#if defined(dbPostgres) getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db +getAnyConns_ deleted' db connIds = do + cs <- getConnsData_ deleted' db connIds + let connIds' = M.keys cs + rQs :: Map ConnId (NonEmpty RcvQueue) <- getRcvQueuesByConnIds_ connIds' + sQs :: Map ConnId (NonEmpty SndQueue) <- getSndQueuesByConnIds_ connIds' + pure $ map (result cs rQs sQs) connIds where - handleDBError :: E.SomeException -> IO (Either StoreError SomeConn) - handleDBError = pure . Left . SEInternal . bshow + getRcvQueuesByConnIds_ connIds' = + toQueueMap primaryFirst toRcvQueue + <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds')) + where + primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = + compare (Down p) (Down p') <> compare i i' + getSndQueuesByConnIds_ connIds' = + toQueueMap primaryFirst toSndQueue + <$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds')) + where + primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = + compare (Down p) (Down p') <> compare i i' + toQueueMap primaryFst toQueue = + M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue + result cs rQs sQs connId = case M.lookup connId cs of + Just (cData, cMode) -> case (M.lookup connId rQs, M.lookup connId sQs, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound + Nothing -> Left SEConnNotFound -getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) -getConnData db connId' = - maybeFirstRow cData $ +getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] +getConnsData db connIds = do + cs <- getConnsData_ False db connIds + pure $ map (Right . (`M.lookup` cs)) connIds + +getConnsData_ :: Bool -> DB.Connection -> [ConnId] -> IO (Map ConnId (ConnData, ConnectionMode)) +getConnsData_ deleted' db connIds = + M.fromList . map ((\c@(ConnData {connId}, _) -> (connId, c)) . rowToConnData) <$> DB.query db [sql| - SELECT - user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support FROM connections - WHERE conn_id = ? + WHERE conn_id IN ? AND deleted = ? |] - (Only connId') - where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) + (In connIds, BI deleted') + +#else +getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db + +getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] +getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db + +handleDBError :: E.SomeException -> IO (Either StoreError a) +handleDBError = pure . Left . SEInternal . bshow +#endif + +getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) +getConnData deleted' db connId' = + maybeFirstRow rowToConnData $ + DB.query + db + [sql| + SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support + FROM connections + WHERE conn_id = ? AND deleted = ? + |] + (connId', BI deleted') + +rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode) +rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () setConnDeleted db waitDelivery connId @@ -2120,15 +2175,12 @@ addProcessedRatchetKeyHash db connId hash = DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary hash) checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool -checkRatchetKeyHashExists db connId hash = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, Binary hash) - ) +checkRatchetKeyHashExists db connId hash = + maybeFirstRow' False fromOnlyBI $ + DB.query + db + "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" + (connId, Binary hash) deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () deleteRatchetKeyHashesExpired db ttl = do @@ -2906,8 +2958,8 @@ deleteSndFile' db sndFileId = getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool getSndFileDeleted db sndFileId = - fromMaybe True - <$> maybeFirstRow fromOnlyBI (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)) + maybeFirstRow' True fromOnlyBI $ + DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId) createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO () createSndFileReplica db SndFileChunk {sndChunkId} = createSndFileReplica_ db sndChunkId diff --git a/src/Simplex/Messaging/Agent/Store/Migrations.hs b/src/Simplex/Messaging/Agent/Store/Migrations.hs index f6b6c2df3..27c35b790 100644 --- a/src/Simplex/Messaging/Agent/Store/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/Migrations.hs @@ -15,7 +15,7 @@ where import Control.Monad import Data.Char (toLower) import Data.Functor (($>)) -import Data.Maybe (isNothing, mapMaybe) +import Data.Maybe (isJust, isNothing, mapMaybe) import Simplex.Messaging.Agent.Store.Shared import System.Exit (exitFailure) import System.IO (hFlush, stdout) @@ -37,7 +37,7 @@ data DBMigrate = DBMigrate { initialize :: IO (), getCurrent :: IO [Migration], run :: MigrationsToRun -> IO (), - backup :: IO () + backup :: Maybe (IO ()) } sharedMigrateSchema :: DBMigrate -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError ()) @@ -54,20 +54,20 @@ sharedMigrateSchema dbm dbNew' migrations confirmMigrations = do | otherwise -> case confirmMigrations of MCYesUp -> runWithBackup ms MCYesUpDown -> runWithBackup ms - MCConsole -> confirm err >> runWithBackup ms + MCConsole -> confirm' err >> runWithBackup ms MCError -> pure $ Left err where err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums) Right ms@(MTRDown dms) -> case confirmMigrations of MCYesUpDown -> runWithBackup ms - MCConsole -> confirm err >> runWithBackup ms + MCConsole -> confirm' err >> runWithBackup ms MCYesUp -> pure $ Left err MCError -> pure $ Left err where err = MEDowngrade $ map downName dms where - runWithBackup ms = backup dbm >> run dbm ms $> Right () - confirm err = confirmOrExit $ migrationErrorDescription err + runWithBackup ms = sequence (backup dbm) >> run dbm ms $> Right () + confirm' err = confirmOrExit $ migrationErrorDescription (isJust $ backup dbm) err confirmOrExit :: String -> IO () confirmOrExit s = do diff --git a/src/Simplex/Messaging/Agent/Store/Postgres.hs b/src/Simplex/Messaging/Agent/Store/Postgres.hs index 075e4be48..18b1a7a2d 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres.hs @@ -30,15 +30,15 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations import Simplex.Messaging.Agent.Store.Postgres.Common import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB -import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..)) import Simplex.Messaging.Util (ifM, safeDecodeUtf8) import System.Exit (exitFailure) -- | Create a new Postgres DBStore with the given connection string, schema name and migrations. -- If passed schema does not exist in connectInfo database, it will be created. -- Applies necessary migrations to schema. -createDBStore :: DBOpts -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) -createDBStore opts migrations confirmMigrations = do +createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore) +createDBStore opts migrations MigrationConfig {confirm} = do st <- connectPostgresStore opts r <- migrateSchema st `onException` closeDBStore st case r of @@ -48,15 +48,16 @@ createDBStore opts migrations confirmMigrations = do migrateSchema st = let initialize = Migrations.initialize st getCurrent = withTransaction st Migrations.getCurrentMigrations - dbm = DBMigrate {initialize, getCurrent, run = Migrations.run st, backup = pure ()} - in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations + dbm = DBMigrate {initialize, getCurrent, run = Migrations.run st, backup = Nothing} + in sharedMigrateSchema dbm (dbNew st) migrations confirm connectPostgresStore :: DBOpts -> IO DBStore connectPostgresStore DBOpts {connstr, schema, poolSize, createSchema} = do dbPriorityPool <- newDBStorePool poolSize dbPool <- newDBStorePool poolSize dbClosed <- newTVarIO True - let st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPriorityPool, dbPool, dbNew = False, dbClosed} + let dbConnect = fst <$> connectDB connstr schema False + st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPriorityPool, dbPool, dbConnect, dbNew = False, dbClosed} dbNew <- connectStore st createSchema pure st {dbNew} diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs index 3ca0a755e..fac2c1c10 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs @@ -2,6 +2,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.Store.Postgres.Common @@ -19,7 +20,7 @@ where import Control.Concurrent.MVar import Control.Concurrent.STM -import Control.Exception (bracket) +import qualified Control.Exception as E import Data.ByteString (ByteString) import qualified Database.PostgreSQL.Simple as PSQL import Numeric.Natural (Natural) @@ -32,11 +33,7 @@ data DBStore = DBStore dbPoolSize :: Int, dbPriorityPool :: DBStorePool, dbPool :: DBStorePool, - -- dbPoolSize :: Int, - -- dbPool :: TBQueue PSQL.Connection, - -- -- MVar is needed for fair pool distribution, without STM retry contention. - -- -- Only one thread can be blocked on STM read. - -- dbSem :: MVar (), + dbConnect :: IO PSQL.Connection, dbClosed :: TVar Bool, dbNew :: Bool } @@ -55,15 +52,23 @@ data DBStorePool = DBStorePool } withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a -withConnectionPriority DBStore {dbPriorityPool, dbPool} priority = - withConnectionPool $ if priority then dbPriorityPool else dbPool +withConnectionPriority DBStore {dbPriorityPool, dbPool, dbConnect} priority = + withConnectionPool (if priority then dbPriorityPool else dbPool) dbConnect {-# INLINE withConnectionPriority #-} -withConnectionPool :: DBStorePool -> (PSQL.Connection -> IO a) -> IO a -withConnectionPool DBStorePool {dbPoolConns, dbSem} = - bracket - (withMVar dbSem $ \_ -> atomically $ readTBQueue dbPoolConns) - (atomically . writeTBQueue dbPoolConns) +withConnectionPool :: DBStorePool -> IO PSQL.Connection -> (PSQL.Connection -> IO a) -> IO a +withConnectionPool DBStorePool {dbPoolConns, dbSem} dbConnect action = + E.mask $ \restore -> do + conn <- withMVar dbSem $ \_ -> atomically $ readTBQueue dbPoolConns + r <- restore (action conn) `E.onException` reset conn + atomically $ writeTBQueue dbPoolConns conn + pure r + where + reset conn = do + conn' <- E.try dbConnect >>= \case + Right conn' -> PSQL.close conn >> pure conn' + Left (_ :: E.SomeException) -> pure conn + atomically $ writeTBQueue dbPoolConns conn' withConnection :: DBStore -> (PSQL.Connection -> IO a) -> IO a withConnection st = withConnectionPriority st False diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index c724c031b..6203357fc 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -57,18 +57,18 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Agent.Store.SQLite.Common import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..)) import Simplex.Messaging.Util (ifM, safeDecodeUtf8) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) -import System.FilePath (takeDirectory) +import System.FilePath (takeDirectory, takeFileName, ()) import UnliftIO.Exception (bracketOnError, onException) import UnliftIO.MVar import UnliftIO.STM -- * SQLite Store implementation -createDBStore :: DBOpts -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) -createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations confirmMigrations = do +createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore) +createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations MigrationConfig {confirm, backupPath} = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing True dbDir st <- connectSQLiteStore dbFilePath dbKey keepKey track @@ -81,9 +81,12 @@ createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations conf let initialize = Migrations.initialize st getCurrent = withTransaction st Migrations.getCurrentMigrations run = Migrations.run st vacuum - backup = copyFile dbFilePath (dbFilePath <> ".bak") + backup = mkBackup <$> backupPath + mkBackup bp = + let f = if null bp then dbFilePath else bp takeFileName dbFilePath + in copyFile dbFilePath $ f <> ".bak" dbm = DBMigrate {initialize, getCurrent, run, backup} - in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations + in sharedMigrateSchema dbm (dbNew st) migrations confirm connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> DB.TrackQueries -> IO DBStore connectSQLiteStore dbFilePath key keepKey track = do diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index 7da6b2ca2..2620e561b 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -52,7 +52,7 @@ import Simplex.Messaging.Util (diffToMicroseconds, tshow) newtype BoolInt = BI {unBI :: Bool} deriving newtype (FromField, ToField) -newtype Binary = Binary {fromBinary :: ByteString} +newtype Binary a = Binary {fromBinary :: a} deriving newtype (FromField, ToField) data Connection = Connection diff --git a/src/Simplex/Messaging/Agent/Store/Shared.hs b/src/Simplex/Messaging/Agent/Store/Shared.hs index 3921bf586..67edbb42b 100644 --- a/src/Simplex/Messaging/Agent/Store/Shared.hs +++ b/src/Simplex/Messaging/Agent/Store/Shared.hs @@ -9,6 +9,7 @@ module Simplex.Messaging.Agent.Store.Shared DownMigration (..), MTRError (..), mtrErrorDescription, + MigrationConfig (..), MigrationConfirmation (..), MigrationError (..), UpMigration (..), @@ -55,13 +56,15 @@ data MigrationError | MigrationError {mtrError :: MTRError} deriving (Eq, Show) -migrationErrorDescription :: MigrationError -> String -migrationErrorDescription = \case +migrationErrorDescription :: Bool -> MigrationError -> String +migrationErrorDescription withBackup = \case MEUpgrade ums -> - "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map upName ums) + "The app has a newer version than the database.\nConfirm to " <> backupStr <> "upgrade using these migrations: " <> intercalate ", " (map upName ums) MEDowngrade dms -> - "Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms + "Database version is newer than the app.\nConfirm to " <> backupStr <> "downgrade using these migrations: " <> intercalate ", " dms MigrationError err -> mtrErrorDescription err + where + backupStr = if withBackup then "back up and " else "" data UpMigration = UpMigration {upName :: String, withDown :: Bool} deriving (Eq, Show) @@ -69,6 +72,11 @@ data UpMigration = UpMigration {upName :: String, withDown :: Bool} upMigration :: Migration -> UpMigration upMigration Migration {name, down} = UpMigration name $ isJust down +data MigrationConfig = MigrationConfig + { confirm :: MigrationConfirmation, + backupPath :: Maybe FilePath -- Nothing - no backup, empty string - the same folder + } + data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index afc19e8d6..32e52e3aa 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -597,12 +597,14 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS socksCreds = clientSocksCredentials networkConfig proxySessTs transportSession tId <- runTransportClient tcConfig socksCreds useHost port' (Just $ keyHash srv) (client t c cVar) - `forkFinally` \_ -> void (atomically . tryPutTMVar cVar $ Left PCENetworkError) + `forkFinally` \r -> + let err = either toNetworkError (const NEFailedError) r + in void $ atomically $ tryPutTMVar cVar $ Left $ PCENetworkError err c_ <- netTimeoutInt tcpConnectTimeout nm `timeout` atomically (takeTMVar cVar) case c_ of Just (Right c') -> mkWeakThreadId tId >>= \tId' -> pure $ Right c' {action = Just tId'} Just (Left e) -> pure $ Left e - Nothing -> killThread tId $> Left PCENetworkError + Nothing -> killThread tId $> Left (PCENetworkError NETimeoutError) useTransport :: (ServiceName, ATransport 'TClient) useTransport = case port srv of @@ -743,7 +745,7 @@ data ProtocolClientError err PCEResponseTimeout | -- | Failure to establish TCP connection. -- Forwarded to the agent client as `ERR BROKER NETWORK`. - PCENetworkError + PCENetworkError NetworkError | -- | No host compatible with network configuration PCEIncompatibleHost | -- | Service is unavailable for command that requires service connection @@ -761,7 +763,7 @@ type SMPClientError = ProtocolClientError ErrorType temporaryClientError :: ProtocolClientError err -> Bool temporaryClientError = \case - PCENetworkError -> True + PCENetworkError _ -> True PCEResponseTimeout -> True PCEIOError _ -> True _ -> False @@ -782,7 +784,7 @@ smpProxyError = \case PCEResponseError e -> PROXY $ BROKER $ RESPONSE $ B.unpack $ strEncode e PCEUnexpectedResponse e -> PROXY $ BROKER $ UNEXPECTED $ B.unpack e PCEResponseTimeout -> PROXY $ BROKER TIMEOUT - PCENetworkError -> PROXY $ BROKER NETWORK + PCENetworkError e -> PROXY $ BROKER $ NETWORK e PCEIncompatibleHost -> PROXY $ BROKER HOST PCEServiceUnavailable -> PROXY $ BROKER $ NO_SERVICE -- for completeness, it cannot happen. PCETransportError t -> PROXY $ BROKER $ TRANSPORT t diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 31b611c17..604960360 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -391,7 +391,7 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE where logSMPError :: SMPClientError -> ExceptT SMPClientError IO a logSMPError e = do - logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e + logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode srv) <> "): " <> tshow e throwE e subscribeQueuesNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO () diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 3494eaaf5..fe2574eab 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -613,7 +613,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = PCEIncompatibleHost -> Just $ NSErr "IncompatibleHost" PCEServiceUnavailable -> Just NSService -- this error should not happen on individual subscriptions PCEResponseTimeout -> Nothing - PCENetworkError -> Nothing + PCENetworkError _ -> Nothing PCEIOError _ -> Nothing where -- Note on moving to PostgreSQL: the idea of logging errors without e is removed here diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index b6f23047f..78891796f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -54,7 +54,8 @@ import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common -import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..)) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import qualified Simplex.Messaging.Crypto as C @@ -63,7 +64,6 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Server.StoreLog -import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) @@ -76,6 +76,8 @@ import System.IO (IOMode (..), hFlush, stdout, withFile) import Text.Hex (decodeHex) #if !defined(dbPostgres) +import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (eitherToMaybe) #endif @@ -98,7 +100,7 @@ data NtfEntityRec (e :: NtfEntity) where newNtfDbStore :: PostgresStoreCfg -> IO NtfPostgresStore newNtfDbStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do - dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations confirmMigrations + dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations (MigrationConfig confirmMigrations Nothing) dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath pure NtfPostgresStore {dbStore, dbStoreLog, deletedTTL} where diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 1e778deac..40314ad2a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -81,6 +81,7 @@ module Simplex.Messaging.Protocol CommandError (..), ProxyError (..), BrokerErrorType (..), + NetworkError (..), BlockingInfo (..), BlockingReason (..), RawTransmission, @@ -168,6 +169,7 @@ module Simplex.Messaging.Protocol noMsgFlags, messageId, messageTs, + toNetworkError, -- * Parse and serialize ProtocolMsgTag (..), @@ -212,7 +214,7 @@ module Simplex.Messaging.Protocol where import Control.Applicative (optional, (<|>)) -import Control.Exception (Exception) +import Control.Exception (Exception, SomeException, displayException, fromException) import Control.Monad.Except import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J @@ -241,6 +243,7 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import qualified GHC.TypeLits as TE import qualified GHC.TypeLits as Type import Network.Socket (ServiceName) +import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.DB (Binary (..), FromField (..), ToField (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -1555,7 +1558,7 @@ data BrokerErrorType | -- | unexpected response UNEXPECTED {respErr :: String} | -- | network error - NETWORK + NETWORK {networkError :: NetworkError} | -- | no compatible server host (e.g. onion when public is required, or vice versa) HOST | -- | service unavailable client-side - used in agent errors @@ -1566,6 +1569,24 @@ data BrokerErrorType TIMEOUT deriving (Eq, Read, Show, Exception) +data NetworkError + = NEConnectError {connectError :: String} + | NETLSError {tlsError :: String} + | NEUnknownCAError + | NEFailedError + | NETimeoutError + | NESubscribeError {subscribeError :: String} + deriving (Eq, Read, Show) + +toNetworkError :: SomeException -> NetworkError +toNetworkError e = maybe (NEConnectError err) fromTLSError (fromException e) + where + err = displayException e + fromTLSError :: TLS.TLSException -> NetworkError + fromTLSError = \case + TLS.HandshakeFailed (TLS.Error_Protocol _ TLS.UnknownCa) -> NEUnknownCAError + _ -> NETLSError err + data BlockingInfo = BlockingInfo { reason :: BlockingReason } @@ -2001,7 +2022,7 @@ instance Encoding BrokerErrorType where RESPONSE e -> "RESPONSE " <> smpEncode e UNEXPECTED e -> "UNEXPECTED " <> smpEncode e TRANSPORT e -> "TRANSPORT " <> smpEncode e - NETWORK -> "NETWORK" + NETWORK _e -> "NETWORK" -- TODO once all upgrade: "NETWORK " <> smpEncode e TIMEOUT -> "TIMEOUT" HOST -> "HOST" NO_SERVICE -> "NO_SERVICE" @@ -2010,7 +2031,7 @@ instance Encoding BrokerErrorType where "RESPONSE" -> RESPONSE <$> _smpP "UNEXPECTED" -> UNEXPECTED <$> _smpP "TRANSPORT" -> TRANSPORT <$> _smpP - "NETWORK" -> pure NETWORK + "NETWORK" -> NETWORK <$> (_smpP <|> pure NEFailedError) "TIMEOUT" -> pure TIMEOUT "HOST" -> pure HOST "NO_SERVICE" -> pure NO_SERVICE @@ -2021,7 +2042,7 @@ instance StrEncoding BrokerErrorType where RESPONSE e -> "RESPONSE " <> encodeUtf8 (T.pack e) UNEXPECTED e -> "UNEXPECTED " <> encodeUtf8 (T.pack e) TRANSPORT e -> "TRANSPORT " <> smpEncode e - NETWORK -> "NETWORK" + NETWORK _e -> "NETWORK" -- TODO once all upgrade: "NETWORK " <> strEncode e TIMEOUT -> "TIMEOUT" HOST -> "HOST" NO_SERVICE -> "NO_SERVICE" @@ -2030,13 +2051,50 @@ instance StrEncoding BrokerErrorType where "RESPONSE" -> RESPONSE <$> _textP "UNEXPECTED" -> UNEXPECTED <$> _textP "TRANSPORT" -> TRANSPORT <$> _smpP - "NETWORK" -> pure NETWORK + "NETWORK" -> NETWORK <$> (_strP <|> pure NEFailedError) "TIMEOUT" -> pure TIMEOUT "HOST" -> pure HOST "NO_SERVICE" -> pure NO_SERVICE _ -> fail "bad BrokerErrorType" - where - _textP = A.space *> (T.unpack . safeDecodeUtf8 <$> A.takeByteString) + +instance Encoding NetworkError where + smpEncode = \case + NEConnectError e -> "CONNECT " <> smpEncode e + NETLSError e -> "TLS " <> smpEncode e + NEUnknownCAError -> "UNKNOWNCA" + NEFailedError -> "FAILED" + NETimeoutError -> "TIMEOUT" + NESubscribeError e -> "SUBSCRIBE " <> smpEncode e + smpP = + A.takeTill (== ' ') >>= \case + "CONNECT" -> NEConnectError <$> _smpP + "TLS" -> NETLSError <$> _smpP + "UNKNOWNCA" -> pure NEUnknownCAError + "FAILED" -> pure NEFailedError + "TIMEOUT" -> pure NETimeoutError + "SUBSCRIBE" -> NESubscribeError <$> _smpP + _ -> fail "bad NetworkError" + +instance StrEncoding NetworkError where + strEncode = \case + NEConnectError e -> "CONNECT " <> encodeUtf8 (T.pack e) + NETLSError e -> "TLS " <> encodeUtf8 (T.pack e) + NEUnknownCAError -> "UNKNOWNCA" + NEFailedError -> "FAILED" + NETimeoutError -> "TIMEOUT" + NESubscribeError e -> "SUBSCRIBE " <> encodeUtf8 (T.pack e) + strP = + A.takeTill (== ' ') >>= \case + "CONNECT" -> NEConnectError <$> _textP + "TLS" -> NETLSError <$> _textP + "UNKNOWNCA" -> pure NEUnknownCAError + "FAILED" -> pure NEFailedError + "TIMEOUT" -> pure NETimeoutError + "SUBSCRIBE" -> NESubscribeError <$> _textP + _ -> fail "bad NetworkError" + +_textP :: Parser String +_textP = A.space *> (T.unpack . safeDecodeUtf8 <$> A.takeByteString) -- | Send signed SMP transmission to TCP transport. tPut :: Transport c => THandle v c p -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] @@ -2200,6 +2258,12 @@ $(J.deriveJSON defaultJSON ''MsgFlags) $(J.deriveJSON (sumTypeJSON id) ''CommandError) +$(J.deriveToJSON (sumTypeJSON $ dropPrefix "NE") ''NetworkError) + +instance FromJSON NetworkError where + parseJSON = $(J.mkParseJSON (sumTypeJSON $ dropPrefix "NE") ''NetworkError) + omittedField = Just NEFailedError + $(J.deriveJSON (sumTypeJSON id) ''BrokerErrorType) $(J.deriveJSON defaultJSON ''BlockingInfo) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 9ca6856ee..7d6e00ab0 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -105,7 +105,7 @@ import Simplex.Messaging.Server.Control import Simplex.Messaging.Server.Env.STM as Env import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.MsgStore -import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue) +import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue (..), getJournalQueueMessages) import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.MsgStore.Types import Simplex.Messaging.Server.NtfStore @@ -132,12 +132,17 @@ import UnliftIO.Directory (doesFileExist, renameFile) import UnliftIO.Exception import UnliftIO.IO import UnliftIO.STM + #if MIN_VERSION_base(4,18,0) import Data.List (sort) import GHC.Conc (listThreads, threadStatus) import GHC.Conc.Sync (threadLabel) #endif +#if defined(dbServerPostgres) +import Simplex.Messaging.Server.MsgStore.Postgres (exportDbMessages, getDbMessageStats) +#endif + -- | Runs an SMP server using passed configuration. -- -- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs @@ -477,7 +482,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt atomicWriteIORef (msgCount stats) stored atomicModifyIORef'_ (msgExpired stats) (+ expired) printMessageStats "STORE: messages" msgStats - Left e -> logError $ "STORE: withAllMsgQueues, error expiring messages, " <> tshow e + Left e -> logError $ "STORE: expireOldMessages, error expiring messages, " <> tshow e expireNtfsThread :: ServerConfig s -> M s () expireNtfsThread ServerConfig {notificationExpiration = expCfg} = do @@ -1848,10 +1853,10 @@ client Right body -> do when (isJust (queueData qr) && isSecuredMsgQueue qr) $ void $ liftIO $ deleteQueueLinkData (queueStore ms) q - ServerConfig {messageExpiration, msgIdBytes} <- asks config + ServerConfig {messageExpiration, expireMessagesOnSend, msgIdBytes} <- asks config msgId <- randomId' msgIdBytes msg_ <- liftIO $ runExceptT $ do - expireMessages messageExpiration stats + when expireMessagesOnSend $ mapM_ (expireMessages stats) messageExpiration msg <- liftIO $ mkMessage msgId body writeMsg ms q True msg case msg_ of @@ -1875,9 +1880,9 @@ client msgTs <- getSystemTime pure $! Message msgId msgTs msgFlags body - expireMessages :: Maybe ExpirationConfig -> ServerStats -> ExceptT ErrorType IO () - expireMessages msgExp stats = do - deleted <- maybe (pure 0) (deleteExpiredMsgs ms q <=< liftIO . expireBeforeEpoch) msgExp + expireMessages :: ServerStats -> ExpirationConfig -> ExceptT ErrorType IO () + expireMessages stats msgExp = do + deleted <- deleteExpiredMsgs ms q =<< liftIO (expireBeforeEpoch msgExp) liftIO $ when (deleted > 0) $ atomicModifyIORef'_ (msgExpired stats) (+ deleted) -- The condition for delivery of the message is: @@ -2104,27 +2109,42 @@ randomId = fmap EntityId . randomId' {-# INLINE randomId #-} saveServerMessages :: Bool -> MsgStore s -> IO () -saveServerMessages drainMsgs = \case - StoreMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of +saveServerMessages drainMsgs ms = case ms of + StoreMemory STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of Just f -> exportMessages False ms f drainMsgs Nothing -> logNote "undelivered messages are not saved" StoreJournal _ -> logNote "closed journal message storage" +#if defined(dbServerPostgres) + StoreDatabase _ -> logNote "closed postgres message storage" +#endif -exportMessages :: MsgStoreClass s => Bool -> s -> FilePath -> Bool -> IO () -exportMessages tty ms f drainMsgs = do +exportMessages :: forall s. MsgStoreClass s => Bool -> MsgStore s -> FilePath -> Bool -> IO () +exportMessages tty st f drainMsgs = do logNote $ "saving messages to file " <> T.pack f - liftIO $ withFile f WriteMode $ \h -> - tryAny (unsafeWithAllMsgQueues tty True ms $ saveQueueMsgs h) >>= \case - Right (Sum total) -> logNote $ "messages saved: " <> tshow total + run $ case st of + StoreMemory ms -> exportMessages_ ms $ getMsgs ms + StoreJournal ms -> exportMessages_ ms $ getJournalMsgs ms +#if defined(dbServerPostgres) + StoreDatabase ms -> exportDbMessages tty ms +#endif + where + exportMessages_ ms get = fmap (\(Sum n) -> n) . unsafeWithAllMsgQueues tty ms . saveQueueMsgs get + run :: (Handle -> IO Int) -> IO () + run a = liftIO $ withFile f WriteMode $ tryAny . a >=> \case + Right n -> logNote $ "messages saved: " <> tshow n Left e -> do logError $ "error exporting messages: " <> tshow e exitFailure - where - saveQueueMsgs h q = do - msgs <- - unsafeRunStore q "saveQueueMsgs" $ - getQueueMessages_ drainMsgs q =<< getMsgQueue ms q False - BLD.hPutBuilder h $ encodeMessages (recipientId q) msgs + getJournalMsgs ms q = + readTVarIO (msgQueue' q) >>= \case + Just _ -> getMsgs ms q + Nothing -> getJournalQueueMessages ms q + getMsgs :: MsgStoreClass s' => s' -> StoreQueue s' -> IO [Message] + getMsgs ms q = unsafeRunStore q "saveQueueMsgs" $ getQueueMessages_ drainMsgs q =<< getMsgQueue ms q False + saveQueueMsgs :: (StoreQueue s -> IO [Message]) -> Handle -> StoreQueue s -> IO (Sum Int) + saveQueueMsgs get h q = do + msgs <- get q + unless (null msgs) $ BLD.hPutBuilder h $ encodeMessages (recipientId q) msgs pure $ Sum $ length msgs encodeMessages rId = mconcat . map (\msg -> BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n') @@ -2140,6 +2160,9 @@ processServerMessages StartOptions {skipWarnings} = do Just f -> ifM (doesFileExist f) (Just <$> importMessages False ms f old_ skipWarnings) (pure Nothing) Nothing -> pure Nothing StoreJournal ms -> processJournalMessages old_ expire ms +#if defined(dbServerPostgres) + StoreDatabase ms -> processDbMessages old_ expire ms +#endif processJournalMessages :: forall s. Maybe Int64 -> Bool -> JournalMsgStore s -> IO (Maybe MessageStats) processJournalMessages old_ expire ms | expire = Just <$> case old_ of @@ -2151,7 +2174,7 @@ processServerMessages StartOptions {skipWarnings} = do run processValidateQueue | otherwise = logWarn "skipping message expiration" $> Nothing where - run a = unsafeWithAllMsgQueues False False ms a `catchAny` \_ -> exitFailure + run a = unsafeWithAllMsgQueues False ms a `catchAny` \_ -> exitFailure processExpireQueue :: Int64 -> JournalQueue s -> IO MessageStats processExpireQueue old q = unsafeRunStore q "processExpireQueue" $ do mq <- getMsgQueue ms q False @@ -2162,6 +2185,17 @@ processServerMessages StartOptions {skipWarnings} = do processValidateQueue q = unsafeRunStore q "processValidateQueue" $ do storedMsgsCount <- getQueueSize_ =<< getMsgQueue ms q False pure newMessageStats {storedMsgsCount, storedQueues = 1} +#if defined(dbServerPostgres) + processDbMessages old_ expire ms + | expire = Just <$> case old_ of + Just old -> do + -- TODO [messages] expire messages from all queues, not only recent + logNote "expiring database store messages..." + now <- systemSeconds <$> getSystemTime + expireOldMessages False ms now (now - old) + Nothing -> getDbMessageStats ms + | otherwise = logWarn "skipping message expiration" $> Nothing +#endif importMessages :: forall s. MsgStoreClass s => Bool -> s -> FilePath -> Maybe Int64 -> Bool -> IO MessageStats importMessages tty ms f old_ skipWarnings = do diff --git a/src/Simplex/Messaging/Server/CLI.hs b/src/Simplex/Messaging/Server/CLI.hs index 04db0231c..89077ea47 100644 --- a/src/Simplex/Messaging/Server/CLI.hs +++ b/src/Simplex/Messaging/Server/CLI.hs @@ -33,7 +33,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), ProtocolServer (..), ProtocolTypeI) -import Simplex.Messaging.Server.Env.STM (ServerStoreCfg (..), StartOptions (..), StorePaths (..)) +import Simplex.Messaging.Server.Env.STM (ServerStoreCfg (..), StartOptions (..), dbStoreCfg, storeLogFile') import Simplex.Messaging.Server.Main.GitCommit import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) import Simplex.Messaging.Transport (ASrvTransport, ATransport (..), TLS, Transport (..), simplexMQVersion) @@ -414,12 +414,13 @@ printServerTransports protocol ts = do \Set `port` in smp-server.ini section [TRANSPORT] to `5223,443`\n" printSMPServerConfig :: [(ServiceName, ASrvTransport, AddHTTP)] -> ServerStoreCfg s -> IO () -printSMPServerConfig transports = \case - SSCMemory sp_ -> printServerConfig "SMP" transports $ (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_ - SSCMemoryJournal {storeLogFile} -> printServerConfig "SMP" transports $ Just storeLogFile - SSCDatabaseJournal {storeCfg = PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}}} -> do - B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema - printServerTransports "SMP" transports +printSMPServerConfig transports st = case dbStoreCfg st of + Just cfg -> printDBConfig cfg + Nothing -> printServerConfig "SMP" transports $ storeLogFile' st + where + printDBConfig PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}} = do + B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema + printServerTransports "SMP" transports deleteDirIfExists :: FilePath -> IO () deleteDirIfExists path = whenM (doesDirectoryExist path) $ removeDirectoryRecursive path diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index d0f1a84fd..b72922f04 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -72,7 +72,10 @@ module Simplex.Messaging.Server.Env.STM defaultIdleQueueInterval, journalMsgStoreDepth, readWriteQueueStore, + noPostgresExitStr, noPostgresExit, + dbStoreCfg, + storeLogFile', ) where @@ -131,6 +134,10 @@ import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM +#if defined(dbServerPostgres) +import Simplex.Messaging.Server.MsgStore.Postgres +#endif + data ServerConfig s = ServerConfig { transports :: [(ServiceName, ASrvTransport, AddHTTP)], smpHandshakeTimeout :: Int, @@ -153,6 +160,7 @@ data ServerConfig s = ServerConfig -- | time after which the messages can be removed from the queues and check interval, seconds messageExpiration :: Maybe ExpirationConfig, expireMessagesOnStart :: Bool, + expireMessagesOnSend :: Bool, -- | interval of inactivity after which journal queue is closed idleQueueInterval :: Int64, -- | notification expiration interval (seconds) @@ -274,14 +282,25 @@ fromMsgStore :: MsgStore s -> s fromMsgStore = \case StoreMemory s -> s StoreJournal s -> s +#if defined(dbServerPostgres) + StoreDatabase s -> s +#endif {-# INLINE fromMsgStore #-} type family SupportedStore (qs :: QSType) (ms :: MSType) :: Constraint where SupportedStore 'QSMemory 'MSMemory = () SupportedStore 'QSMemory 'MSJournal = () - SupportedStore 'QSPostgres 'MSJournal = () + SupportedStore 'QSMemory 'MSPostgres = + (Int ~ Bool, TypeError ('TE.Text "Storing messages in Postgres DB with queues in memory is not supported")) SupportedStore 'QSPostgres 'MSMemory = - (Int ~ Bool, TypeError ('TE.Text "Storing messages in memory with Postgres DB is not supported")) + (Int ~ Bool, TypeError ('TE.Text "Storing messages in memory with queues in Postgres DB is not supported")) + SupportedStore 'QSPostgres 'MSJournal = () +#if defined(dbServerPostgres) + SupportedStore 'QSPostgres 'MSPostgres = () +#else + SupportedStore 'QSPostgres 'MSPostgres = + (Int ~ Bool, TypeError ('TE.Text "Server compiled without server_postgres flag")) +#endif data AStoreType = forall qs ms. (SupportedStore qs ms, MsgStoreClass (MsgStoreType qs ms)) => @@ -291,16 +310,43 @@ data ServerStoreCfg s where SSCMemory :: Maybe StorePaths -> ServerStoreCfg STMMsgStore SSCMemoryJournal :: {storeLogFile :: FilePath, storeMsgsPath :: FilePath} -> ServerStoreCfg (JournalMsgStore 'QSMemory) SSCDatabaseJournal :: {storeCfg :: PostgresStoreCfg, storeMsgsPath' :: FilePath} -> ServerStoreCfg (JournalMsgStore 'QSPostgres) +#if defined(dbServerPostgres) + SSCDatabase :: PostgresStoreCfg -> ServerStoreCfg PostgresMsgStore +#endif + +dbStoreCfg :: ServerStoreCfg s -> Maybe PostgresStoreCfg +dbStoreCfg = \case + SSCMemory _ -> Nothing + SSCMemoryJournal {} -> Nothing + SSCDatabaseJournal {storeCfg} -> Just storeCfg +#if defined(dbServerPostgres) + SSCDatabase cfg -> Just cfg +#endif + +storeLogFile' :: ServerStoreCfg s -> Maybe FilePath +storeLogFile' = \case + SSCMemory sp_ -> (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_ + SSCMemoryJournal {storeLogFile} -> Just storeLogFile + SSCDatabaseJournal {storeCfg = PostgresStoreCfg {dbStoreLogPath}} -> dbStoreLogPath +#if defined(dbServerPostgres) + SSCDatabase (PostgresStoreCfg {dbStoreLogPath}) -> dbStoreLogPath +#endif data StorePaths = StorePaths {storeLogFile :: FilePath, storeMsgsFile :: Maybe FilePath} type family MsgStoreType (qs :: QSType) (ms :: MSType) where MsgStoreType 'QSMemory 'MSMemory = STMMsgStore MsgStoreType qs 'MSJournal = JournalMsgStore qs +#if defined(dbServerPostgres) + MsgStoreType 'QSPostgres 'MSPostgres = PostgresMsgStore +#endif data MsgStore s where StoreMemory :: STMMsgStore -> MsgStore STMMsgStore StoreJournal :: JournalMsgStore qs -> MsgStore (JournalMsgStore qs) +#if defined(dbServerPostgres) + StoreDatabase :: PostgresMsgStore -> MsgStore PostgresMsgStore +#endif data Server s = Server { clients :: ServerClients s, @@ -532,8 +578,12 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp qsCfg = PQStoreCfg (storeCfg {confirmMigrations} :: PostgresStoreCfg) cfg = mkJournalStoreConfig qsCfg storeMsgsPath' msgQueueQuota maxJournalMsgCount maxJournalStateLines idleQueueInterval when compactLog $ compactDbStoreLog $ dbStoreLogPath storeCfg - ms <- newMsgStore cfg - pure $ StoreJournal ms + StoreJournal <$> newMsgStore cfg + SSCDatabase storeCfg -> do + let StartOptions {compactLog, confirmMigrations} = startOptions config + cfg = PostgresMsgStoreCfg storeCfg {confirmMigrations} msgQueueQuota + when compactLog $ compactDbStoreLog $ dbStoreLogPath storeCfg + StoreDatabase <$> newMsgStore cfg #else SSCDatabaseJournal {} -> noPostgresExit #endif @@ -627,10 +677,12 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp _ -> SPMMessages noPostgresExit :: IO a -noPostgresExit = do - putStrLn "Error: server binary is compiled without support for PostgreSQL database." - putStrLn "Please download `smp-server-postgres` or re-compile with `cabal build -fserver_postgres`." - exitFailure +noPostgresExit = putStrLn noPostgresExitStr >> exitFailure + +noPostgresExitStr :: String +noPostgresExitStr = + "Error: server binary is compiled without support for PostgreSQL database.\n" + <> "Please download `smp-server-postgres` or re-compile with `cabal build -fserver_postgres`." mkJournalStoreConfig :: QStoreCfg s -> FilePath -> Int -> Int -> Int -> Int64 -> JournalStoreConfig s mkJournalStoreConfig queueStoreCfg storePath msgQueueQuota maxJournalMsgCount maxJournalStateLines idleQueueInterval = diff --git a/src/Simplex/Messaging/Server/Information.hs b/src/Simplex/Messaging/Server/Information.hs index a94148dbe..ec832c48d 100644 --- a/src/Simplex/Messaging/Server/Information.hs +++ b/src/Simplex/Messaging/Server/Information.hs @@ -14,7 +14,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Int (Int64) import Data.Maybe (isJust) import Data.Text (Text) -import Simplex.Messaging.Agent.Protocol (ConnectionLink, ConnectionMode (..), ConnectionRequestUri) +import Simplex.Messaging.Agent.Protocol (ConnectionLink, ConnectionMode (..)) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON) diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index f94cd9682..8e5fd55ee 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -18,9 +18,10 @@ module Simplex.Messaging.Server.Main where import Control.Concurrent.STM -import Control.Exception (finally) +import Control.Exception (SomeException, finally, try) import Control.Logger.Simple import Control.Monad +import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAlpha, isAscii, toUpper) @@ -60,11 +61,11 @@ import Simplex.Messaging.Transport (supportedProxyClientSMPRelayVRange, alpnSupp import Simplex.Messaging.Transport.Client (TransportHost (..), defaultSocksProxy) import Simplex.Messaging.Transport.HTTP2 (httpALPN) import Simplex.Messaging.Transport.Server (ServerCredentials (..), mkTransportServerConfig) -import Simplex.Messaging.Util (eitherToMaybe, ifM) +import Simplex.Messaging.Util (eitherToMaybe, ifM, unlessM) import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist) import System.Exit (exitFailure) import System.FilePath (combine) -import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) +import System.IO (BufferMode (..), IOMode (..), hSetBuffering, stderr, stdout, withFile) import Text.Read (readMaybe) #if defined(dbServerPostgres) @@ -73,6 +74,7 @@ import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) import Simplex.Messaging.Server.MsgStore.Types (QSType (..)) import Simplex.Messaging.Server.MsgStore.Journal (postgresQueueStore) +import Simplex.Messaging.Server.MsgStore.Postgres import Simplex.Messaging.Server.QueueStore.Postgres (batchInsertQueues, batchInsertServices, foldQueueRecs, foldServiceRecs) import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..)) import Simplex.Messaging.Server.QueueStore.Types @@ -129,6 +131,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = printMessageStats "Messages" msgStats putStrLn $ case readStoreType ini of Right (ASType SQSMemory SMSMemory) -> "store_messages set to `memory`, update it to `journal` in INI file" + Right (ASType SQSPostgres SMSPostgres) -> "store_messages set to `database`, update it to `journal` in INI file" Right (ASType _ SMSJournal) -> "store_messages set to `journal`" Left e -> e <> ", configure storage correctly" SCExport @@ -140,19 +143,31 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = confirmOrExit ("WARNING: journal directory " <> storeMsgsJournalDir <> " will be exported to message log file " <> storeMsgsFilePath) "Journal not exported" - ms <- newJournalMsgStore logPath MQStoreCfg - -- TODO [postgres] in case postgres configured, queues must be read from database - readQueueStore True (mkQueue ms False) storeLogFile $ stmQueueStore ms - exportMessages True ms storeMsgsFilePath False - putStrLn "Export completed" case readStoreType ini of - Right (ASType SQSMemory SMSMemory) -> putStrLn "store_messages set to `memory`, start the server." - Right (ASType SQSMemory SMSJournal) -> putStrLn "store_messages set to `journal`, update it to `memory` in INI file" - Right (ASType SQSPostgres SMSJournal) -> + Right (ASType SQSMemory msType) -> do + ms <- newJournalMsgStore logPath MQStoreCfg + readQueueStore True (mkQueue ms False) storeLogFile $ stmQueueStore ms + exportMessages True (StoreJournal ms) storeMsgsFilePath False + putStrLn "Export completed" + putStrLn $ case msType of + SMSMemory -> "store_messages set to `memory`, start the server." + SMSJournal -> "store_messages set to `journal`, update it to `memory` in INI file" #if defined(dbServerPostgres) + Right (ASType SQSPostgres SMSJournal) -> do + let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath + dbOpts@DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts + unlessM (checkSchemaExists connstr schema) $ do + putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr + exitFailure + ms <- newJournalMsgStore logPath $ PQStoreCfg PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini} + exportMessages True (StoreJournal ms) storeMsgsFilePath False + putStrLn "Export completed" putStrLn "store_messages set to `journal`, store_queues is set to `database`.\nExport queues to store log to use memory storage for messages (`smp-server database export`)." + Right (ASType SQSPostgres SMSPostgres) -> do + putStrLn $ "Messages can be exported with `dabatase export --table messages`." + exitFailure #else - noPostgresExit + Right (ASType SQSPostgres SMSJournal) -> noPostgresExit #endif Left e -> putStrLn $ e <> ", configure storage correctly" SCDelete @@ -166,11 +181,12 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = deleteDirIfExists storeMsgsJournalDir putStrLn $ "Deleted all messages in journal " <> storeMsgsJournalDir #if defined(dbServerPostgres) - Database cmd dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do + Database cmd tables dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do schemaExists <- checkSchemaExists connstr schema storeLogExists <- doesFileExist storeLogFilePath - case cmd of - SCImport + msgsFileExists <- doesFileExist storeMsgsFilePath + case (cmd, tables) of + (SCImport, DTQueues) | schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema | schemaExists -> do putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr @@ -188,12 +204,29 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = putStrLn $ case readStoreType ini of Right (ASType SQSMemory SMSMemory) -> setToDbStr <> "\nstore_messages set to `memory`, import messages to journal to use PostgreSQL database for queues (`smp-server journal import`)" Right (ASType SQSMemory SMSJournal) -> setToDbStr - Right (ASType SQSPostgres SMSJournal) -> "store_queues set to `database`, start the server." + Right (ASType SQSPostgres _) -> "store_queues set to `database`, start the server." Left e -> e <> ", configure storage correctly" where setToDbStr :: String setToDbStr = "store_queues set to `memory`, update it to `database` in INI file" - SCExport + (SCImport, DTMessages) + | not schemaExists -> do + putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr + exitFailure + | not msgsFileExists -> do + putStrLn $ storeMsgsFilePath <> " file does not exist." + exitFailure + | otherwise -> do + confirmOrExit + ("WARNING: message log file " <> storeMsgsFilePath <> " will be imported to PostrgreSQL database " <> B.unpack connstr <> ", schema: " <> B.unpack schema) + "Message records not imported" + mCnt <- importMessagesToDatabase storeMsgsFilePath dbOpts + putStrLn $ "Import completed: " <> show mCnt <> " messages" + putStrLn $ case readStoreType ini of + Right (ASType SQSPostgres SMSPostgres) -> "store_queues and store_messages set to `database`, start the server." + Right _ -> "set store_queues and store_messages set to `database` in INI file" + Left e -> e <> ", configure storage correctly" + (SCExport, DTQueues) | schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema | not schemaExists -> do putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr @@ -203,15 +236,33 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = exitFailure | otherwise -> do confirmOrExit - ("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) + ("WARNING: PostrgreSQL schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) "Queue records not exported" (sCnt, qCnt) <- exportDatabaseToStoreLog logPath dbOpts storeLogFilePath putStrLn $ "Export completed: " <> show sCnt <> " services, " <> show qCnt <> " queues" putStrLn $ case readStoreType ini of - Right (ASType SQSPostgres SMSJournal) -> "store_queues set to `database`, update it to `memory` in INI file." + Right (ASType SQSPostgres _) -> "store_queues or store_messages set to `database`, update it to `memory` in INI file." Right (ASType SQSMemory _) -> "store_queues set to `memory`, start the server" Left e -> e <> ", configure storage correctly" - SCDelete + (SCExport, DTMessages) + | not schemaExists -> do + putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr + exitFailure + | msgsFileExists -> do + putStrLn $ storeMsgsFilePath <> " file already exists." + exitFailure + | otherwise -> do + confirmOrExit + ("WARNING: Messages from PostrgreSQL schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to message log file " <> storeMsgsFilePath) + "Message records not exported" + let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} + ms <- newMsgStore $ PostgresMsgStoreCfg storeCfg defaultMsgQueueQuota + withFile storeMsgsFilePath WriteMode (try . exportDbMessages True ms) >>= \case + Right mCnt -> do + putStrLn $ "Export completed: " <> show mCnt <> " messages" + putStrLn "Export queues with `smp-server database export queues`" + Left (e :: SomeException) -> putStrLn $ "Error exporting messages: " <> show e + (SCDelete, _) | not schemaExists -> do putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr exitFailure @@ -245,8 +296,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = readStoreType ini = case (iniStoreQueues, iniStoreMessage) of ("memory", "memory") -> Right $ ASType SQSMemory SMSMemory ("memory", "journal") -> Right $ ASType SQSMemory SMSJournal + ("memory", "database") -> Left "Database and memory storage are not compatible." + ("database", "memory") -> Left "Database and memory storage are not compatible." ("database", "journal") -> Right $ ASType SQSPostgres SMSJournal - ("database", "memory") -> Left "Using PostgreSQL database requires journal memory storage." +#if defined(dbServerPostgres) + ("database", "database") -> Right $ ASType SQSPostgres SMSPostgres +#else + ("database", "database") -> Left noPostgresExitStr +#endif (q, m) -> Left $ T.unpack $ "Invalid storage settings: store_queues: " <> q <> ", store_messages: " <> m where iniStoreQueues = fromRight "memory" $ lookupValue "STORE_LOG" "store_queues" ini @@ -396,6 +453,12 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini defaultDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini} in SSCDatabaseJournal {storeCfg, storeMsgsPath' = storeMsgsJournalDir} +#if defined(dbServerPostgres) + iniStoreCfg SQSPostgres SMSPostgres = + let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath + storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini defaultDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini} + in SSCDatabase storeCfg +#endif serverConfig :: ServerStoreCfg s -> ServerConfig s serverConfig serverStoreCfg = ServerConfig @@ -428,6 +491,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = { ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini }, expireMessagesOnStart = fromMaybe True $ iniOnOff "STORE_LOG" "expire_messages_on_start" ini, + expireMessagesOnSend = fromMaybe True $ iniOnOff "STORE_LOG" "expire_messages_on_send" ini, idleQueueInterval = defaultIdleQueueInterval, notificationExpiration = defaultNtfExpiration @@ -504,6 +568,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = msgsFileExists <- doesFileExist storeMsgsFilePath storeLogExists <- doesFileExist storeLogFilePath case mode of +#if defined(dbServerPostgres) + ASType SQSPostgres SMSPostgres + | msgsFileExists || msgsDirExists -> do + putStrLn $ "Error: " <> storeMsgsFilePath <> " file or " <> storeMsgsJournalDir <> " directory are present." + putStrLn "Configure memory storage." + exitFailure + | otherwise -> checkDbStorage ini storeLogExists +#endif ASType qs SMSJournal | msgsFileExists && msgsDirExists -> exitConfigureMsgStorage | msgsFileExists -> do @@ -516,28 +588,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = SQSMemory -> unless (storeLogExists) $ putStrLn $ "store_queues is `memory`, " <> storeLogFilePath <> " file will be created." #if defined(dbServerPostgres) - SQSPostgres -> do - let DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts - schemaExists <- checkSchemaExists connstr schema - case enableDbStoreLog' ini of - Just () - | not schemaExists -> noDatabaseSchema connstr schema - | not storeLogExists -> do - putStrLn $ "Error: db_store_log is `on`, " <> storeLogFilePath <> " does not exist" - exitFailure - | otherwise -> pure () - Nothing - | storeLogExists && schemaExists -> exitConfigureQueueStore connstr schema - | storeLogExists -> do - putStrLn $ "Error: store_queues is `database` with " <> storeLogFilePath <> " file present." - putStrLn "Set store_queues to `memory` or use `smp-server database import` to migrate." - exitFailure - | not schemaExists -> noDatabaseSchema connstr schema - | otherwise -> pure () - where - noDatabaseSchema connstr schema = do - putStrLn $ "Error: store_queues is `database`, create schema " <> B.unpack schema <> " in PostgreSQL database " <> B.unpack connstr - exitFailure + SQSPostgres -> checkDbStorage ini storeLogExists #else SQSPostgres -> noPostgresExit #endif @@ -555,6 +606,29 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = exitFailure #if defined(dbServerPostgres) + checkDbStorage ini storeLogExists = do + let DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts + schemaExists <- checkSchemaExists connstr schema + case enableDbStoreLog' ini of + Just () + | not schemaExists -> noDatabaseSchema connstr schema + | not storeLogExists -> do + putStrLn $ "Error: db_store_log is `on`, " <> storeLogFilePath <> " does not exist" + exitFailure + | otherwise -> pure () + Nothing + | storeLogExists && schemaExists -> exitConfigureQueueStore connstr schema + | storeLogExists -> do + putStrLn $ "Error: store_queues is `database` with " <> storeLogFilePath <> " file present." + putStrLn "Set store_queues to `memory` or use `smp-server database import` to migrate." + exitFailure + | not schemaExists -> noDatabaseSchema connstr schema + | otherwise -> pure () + where + noDatabaseSchema connstr schema = do + putStrLn $ "Error: store_queues is `database`, create schema " <> B.unpack schema <> " in PostgreSQL database " <> B.unpack connstr + exitFailure + exitConfigureQueueStore connstr schema = do putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")." putStrLn "Configure queue storage." @@ -575,13 +649,28 @@ importStoreLogToDatabase logPath storeLogFile dbOpts = do renameFile storeLogFile $ storeLogFile <> ".bak" pure (sCnt, qCnt) +importMessagesToDatabase :: FilePath -> DBOpts -> IO Int64 +importMessagesToDatabase msgsLogFile dbOpts = do + let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} + ms <- newMsgStore $ PostgresMsgStoreCfg storeCfg defaultMsgQueueQuota + mCnt <- getDbMessageCount ms + when (mCnt > 0) $ do + confirmOrExit ("WARNING: the database contains messages, they will be deleted.") "Message records not imported" + deleteAllMessages ms + inserted <- batchInsertMessages True msgsLogFile $ queueStore ms + mCnt' <- getDbMessageCount ms + unless (inserted == mCnt') $ putStrLn $ "WARNING: inserted " <> show inserted <> " rows, table has " <> show mCnt' <> " messages." + updateQueueCounts ms + renameFile msgsLogFile $ msgsLogFile <> ".bak" + pure mCnt' + exportDatabaseToStoreLog :: FilePath -> DBOpts -> FilePath -> IO (Int, Int) exportDatabaseToStoreLog logPath dbOpts storeLogFilePath = do let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL} ps <- newJournalMsgStore logPath $ PQStoreCfg storeCfg sl <- openWriteStoreLog False storeLogFilePath Sum sCnt <- foldServiceRecs (postgresQueueStore ps) $ \sr -> logNewService sl sr $> Sum (1 :: Int) - Sum qCnt <- foldQueueRecs True True (postgresQueueStore ps) Nothing $ \(rId, qr) -> logCreateQueue sl rId qr $> Sum (1 :: Int) + Sum qCnt <- foldQueueRecs True True (postgresQueueStore ps) $ \(rId, qr) -> logCreateQueue sl rId qr $> Sum (1 :: Int) closeStoreLog sl pure (sCnt, qCnt) #endif @@ -667,10 +756,22 @@ data CliCommand | Start StartOptions | Delete | Journal StoreCmd - | Database StoreCmd DBOpts + | Database StoreCmd DatabaseTable DBOpts data StoreCmd = SCImport | SCExport | SCDelete +data DatabaseTable = DTQueues | DTMessages + +instance StrEncoding DatabaseTable where + strEncode = \case + DTQueues -> "queues" + DTMessages -> "messages" + strP = + A.takeTill (== ' ') >>= \case + "queues" -> pure DTQueues + "messages" -> pure DTMessages + _ -> fail "DatabaseTable" + cliCommandP :: FilePath -> FilePath -> FilePath -> Parser CliCommand cliCommandP cfgPath logPath iniFile = hsubparser @@ -679,7 +780,7 @@ cliCommandP cfgPath logPath iniFile = <> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) <> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files")) <> command "journal" (info (Journal <$> journalCmdP) (progDesc "Import/export messages to/from journal storage")) - <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultDBOpts) (progDesc "Import/export queues to/from PostgreSQL database storage")) + <> command "database" (info (Database <$> databaseCmdP <*> dbTableP <*> dbOptsP defaultDBOpts) (progDesc "Import/export queues to/from PostgreSQL database storage")) ) where initP :: Parser InitOptions @@ -833,6 +934,13 @@ cliCommandP cfgPath logPath iniFile = <> command "export" (info (pure SCExport) (progDesc $ "Export " <> dest <> " to " <> src)) <> command "delete" (info (pure SCDelete) (progDesc $ "Delete " <> dest)) ) + dbTableP = + option + strParse + ( long "table" + <> help "Database tables: queues/messages" + <> metavar "TABLE" + ) parseBasicAuth :: ReadM ServerPassword parseBasicAuth = eitherReader $ fmap ServerPassword . strDecode . B.pack entityP :: String -> String -> String -> Parser (Maybe Entity, Maybe Text) diff --git a/src/Simplex/Messaging/Server/Main/Init.hs b/src/Simplex/Messaging/Server/Main/Init.hs index 2823df74f..d17018de2 100644 --- a/src/Simplex/Messaging/Server/Main/Init.hs +++ b/src/Simplex/Messaging/Server/Main/Init.hs @@ -87,7 +87,8 @@ iniFileContent cfgPath logPath opts host basicAuth controlPortPwds = <> ("restore_messages: " <> onOff enableStoreLog <> "\n\n") <> "# Messages and notifications expiration periods.\n" <> ("expire_messages_days: " <> tshow defMsgExpirationDays <> "\n") - <> "expire_messages_on_start: on\n" + <> "expire_messages_on_start: on\n\ + \expire_messages_on_send: off\n" <> ("expire_ntfs_hours: " <> tshow defNtfExpirationHours <> "\n\n") <> "# Log daily server statistics to CSV file\n" <> ("log_stats: " <> onOff logStats <> "\n\n") diff --git a/src/Simplex/Messaging/Server/MsgStore/Journal.hs b/src/Simplex/Messaging/Server/MsgStore/Journal.hs index 78f9c1393..e81c153de 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Journal.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Journal.hs @@ -24,7 +24,7 @@ module Simplex.Messaging.Server.MsgStore.Journal ( JournalMsgStore (random, expireBackupsBefore), QStore (..), QStoreCfg (..), - JournalQueue, + JournalQueue (msgQueue'), -- msgQueue' is used in tests JournalMsgQueue (queue, state), JMQueue (queueDirectory, statePath), JournalStoreConfig (..), @@ -38,6 +38,7 @@ module Simplex.Messaging.Server.MsgStore.Journal msgQueueStatePath, readQueueState, newMsgQueueState, + getJournalQueueMessages, newJournalId, appendState, queueLogFileName, @@ -58,7 +59,7 @@ import Control.Monad.Trans.Except import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (fromRight) +import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) import Data.List (sort) @@ -290,13 +291,10 @@ newtype StoreIO (s :: QSType) a = StoreIO {unStoreIO :: IO a} deriving newtype (Functor, Applicative, Monad) instance StoreQueueClass (JournalQueue s) where - type MsgQueue (JournalQueue s) = JournalMsgQueue s recipientId = recipientId' {-# INLINE recipientId #-} queueRec = queueRec' {-# INLINE queueRec #-} - msgQueue = msgQueue' - {-# INLINE msgQueue #-} withQueueLock :: JournalQueue s -> Text -> IO a -> IO a withQueueLock JournalQueue {recipientId', queueLock, sharedLock} = withLockWaitShared recipientId' queueLock sharedLock @@ -309,7 +307,7 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where newQueueStore = \case MQStoreCfg -> MQStore <$> newQueueStore @(JournalQueue s) () #if defined(dbServerPostgres) - PQStoreCfg cfg -> PQStore <$> newQueueStore @(JournalQueue s) cfg + PQStoreCfg cfg -> PQStore <$> newQueueStore @(JournalQueue s) (cfg, True) #endif closeQueueStore = withQS (closeQueueStore @(JournalQueue s)) @@ -378,6 +376,7 @@ makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do instance MsgStoreClass (JournalMsgStore s) where type StoreMonad (JournalMsgStore s) = StoreIO s + type MsgQueue (JournalMsgStore s) = JournalMsgQueue s type QueueStore (JournalMsgStore s) = QStore s type StoreQueue (JournalMsgStore s) = JournalQueue s type MsgStoreConfig (JournalMsgStore s) = JournalStoreConfig s @@ -405,11 +404,11 @@ instance MsgStoreClass (JournalMsgStore s) where -- This function can only be used in server CLI commands or before server is started. -- It does not cache queues and is NOT concurrency safe. - unsafeWithAllMsgQueues :: Monoid a => Bool -> Bool -> JournalMsgStore s -> (JournalQueue s -> IO a) -> IO a - unsafeWithAllMsgQueues tty withData ms action = case queueStore_ ms of + unsafeWithAllMsgQueues :: Monoid a => Bool -> JournalMsgStore s -> (JournalQueue s -> IO a) -> IO a + unsafeWithAllMsgQueues tty ms action = case queueStore_ ms of MQStore st -> withLoadedQueues st run #if defined(dbServerPostgres) - PQStore st -> foldQueueRecs tty withData st Nothing $ uncurry (mkQueue ms False) >=> run + PQStore st -> foldQueueRecs False tty st $ uncurry (mkQueue ms False) >=> run #endif where run q = do @@ -421,7 +420,7 @@ instance MsgStoreClass (JournalMsgStore s) where expireOldMessages :: Bool -> JournalMsgStore s -> Int64 -> Int64 -> IO MessageStats expireOldMessages tty ms now ttl = case queueStore_ ms of MQStore st -> - withLoadedQueues st $ \q -> run $ isolateQueue q "deleteExpiredMsgs" $ do + withLoadedQueues st $ \q -> run $ isolateQueue ms q "deleteExpiredMsgs" $ do StoreIO (readTVarIO $ queueRec q) >>= \case Just QueueRec {updatedAt = Just (RoundedSystemTime t)} | t > veryOld -> expireQueueMsgs ms now old q @@ -429,7 +428,7 @@ instance MsgStoreClass (JournalMsgStore s) where #if defined(dbServerPostgres) PQStore st -> do let JournalMsgStore {queueLocks, sharedLock} = ms - foldQueueRecs tty False st (Just veryOld) $ \(rId, qr) -> do + foldRecentQueueRecs veryOld tty st $ \(rId, qr) -> do q <- mkQueue ms False rId qr withSharedWaitLock rId queueLocks sharedLock $ run $ tryStore' "deleteExpiredMsgs" rId $ getLoadedQueue q >>= unStoreIO . expireQueueMsgs ms now old @@ -485,7 +484,7 @@ instance MsgStoreClass (JournalMsgStore s) where where newQ = do let dir = msgQueueDirectory ms rId - statePath = msgQueueStatePath dir $ B.unpack (strEncode rId) + statePath = msgQueueStatePath dir rId queue = JMQueue {queueDirectory = dir, statePath} q <- ifM (doesDirectoryExist dir) (openMsgQueue ms queue forWrite) (createQ queue) atomically $ writeTVar msgQueue' $ Just q @@ -563,8 +562,9 @@ instance MsgStoreClass (JournalMsgStore s) where where getSize = maybe (pure (-1)) (fmap size . readTVarIO . state) + -- drainMsgs is never True with Journal storage getQueueMessages_ :: Bool -> JournalQueue s -> JournalMsgQueue s -> StoreIO s [Message] - getQueueMessages_ drainMsgs q' q = StoreIO (run []) + getQueueMessages_ drainMsgs q' q = StoreIO $ if drainMsgs then run [] else readTVarIO (state q) >>= runFast where run msgs = readTVarIO (handles q) >>= maybe (pure []) (getMsg msgs) getMsg msgs hs = chooseReadJournal q' q drainMsgs hs >>= maybe (pure msgs) readMsg @@ -573,9 +573,19 @@ instance MsgStoreClass (JournalMsgStore s) where (msg, len) <- hGetMsgAt h $ bytePos rs updateReadPos q' q drainMsgs len hs (msg :) <$> run msgs + runFast MsgQueueState {writeState = ws, readState = rs, size} + | size > 0 = + readTVarIO (handles q) >>= \case + Just (MsgQueueHandles _ rh wh_) -> do + msgs <- getJournalRange rh (bytePos rs) (byteCount rs) + case wh_ of + Just wh -> (msgs ++) <$> getJournalRange wh 0 (bytePos ws) + Nothing -> pure msgs + Nothing -> pure [] + | otherwise = pure [] writeMsg :: JournalMsgStore s -> JournalQueue s -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool)) - writeMsg ms q' logState msg = isolateQueue q' "writeMsg" $ do + writeMsg ms q' logState msg = isolateQueue ms q' "writeMsg" $ do q <- getMsgQueue ms q' True StoreIO $ (`E.finally` updateActiveAt q') $ do st@MsgQueueState {canWrite, size} <- readTVarIO (state q) @@ -649,8 +659,8 @@ instance MsgStoreClass (JournalMsgStore s) where $>>= \len -> readTVarIO handles $>>= \hs -> updateReadPos q mq logState len hs $> Just () - isolateQueue :: JournalQueue s -> Text -> StoreIO s a -> ExceptT ErrorType IO a - isolateQueue sq op = tryStore' op (recipientId' sq) . withQueueLock sq op . unStoreIO + isolateQueue :: JournalMsgStore s -> JournalQueue s -> Text -> StoreIO s a -> ExceptT ErrorType IO a + isolateQueue _ sq op = tryStore' op (recipientId' sq) . withQueueLock sq op . unStoreIO unsafeRunStore :: JournalQueue s -> Text -> StoreIO s a -> IO a unsafeRunStore sq op a = @@ -795,8 +805,8 @@ msgQueueDirectory JournalMsgStore {config = JournalStoreConfig {storePath, pathP let (seg, s') = B.splitAt 2 s in seg : splitSegments (n - 1) s' -msgQueueStatePath :: FilePath -> String -> FilePath -msgQueueStatePath dir queueId = dir (queueLogFileName <> "." <> queueId <> logFileExt) +msgQueueStatePath :: FilePath -> RecipientId -> FilePath +msgQueueStatePath dir rId = dir (queueLogFileName <> "." <> B.unpack (strEncode rId) <> logFileExt) createNewJournal :: FilePath -> ByteString -> IO Handle createNewJournal dir journalId = do @@ -965,10 +975,11 @@ deleteQueue_ ms q = pure r where rId = recipientId q - remove r@(_, mq_) = do + remove qr = do + mq_ <- atomically $ swapTVar (msgQueue' q) Nothing mapM_ (closeMsgQueueHandles ms) mq_ removeQueueDirectory ms rId - pure r + pure (qr, mq_) closeMsgQueue :: JournalMsgStore s -> JournalQueue s -> IO () closeMsgQueue ms JournalQueue {msgQueue'} = atomically (swapTVar msgQueue' Nothing) >>= mapM_ (closeMsgQueueHandles ms) @@ -1019,3 +1030,33 @@ hClose h = closeOnException :: Handle -> IO a -> IO a closeOnException h a = a `E.onException` hClose h + +getJournalQueueMessages :: JournalMsgStore s -> JournalQueue s -> IO [Message] +getJournalQueueMessages ms q = + readQueueState ms (msgQueueStatePath dir rId) >>= \case + (Just MsgQueueState {readState = rs, writeState = ws, size}, _) | size > 0 -> do + msgs <- getMsgs (journalId rs) (bytePos rs) (byteCount rs) + if journalId rs == journalId ws + then pure msgs + else (msgs ++) <$> getMsgs (journalId ws) 0 (bytePos ws) + _ -> pure [] + where + rId = recipientId' q + dir = msgQueueDirectory ms rId + getMsgs jId from to = + IO.withFile (journalFilePath dir jId) ReadWriteMode $ \h' -> + getJournalRange h' from to + +getJournalRange :: Handle -> Int64 -> Int64 -> IO [Message] +getJournalRange h from to + | to > from = do + IO.hSeek h AbsoluteSeek $ fromIntegral from + parseMsgs =<< B.hGet h (fromIntegral $ to - from) + | otherwise = pure [] + where + parseMsgs s = do + let (errs, msgs) = partitionEithers $ map strDecode $ B.lines s + unless (null errs) $ do + f <- IO.hShow h + putStrLn $ "Error reading " <> show (length errs) <> " messages from " <> f + pure msgs diff --git a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs new file mode 100644 index 000000000..a0eb1d1ca --- /dev/null +++ b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs @@ -0,0 +1,386 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} + +module Simplex.Messaging.Server.MsgStore.Postgres + ( PostgresMsgStore, + PostgresMsgStoreCfg (..), + PostgresQueue, + exportDbMessages, + getDbMessageStats, + getDbMessageCount, + deleteAllMessages, + batchInsertMessages, + updateQueueCounts, + ) +where + +import Control.Concurrent.STM +import qualified Control.Exception as E +import Control.Monad +import Control.Monad.Reader +import Control.Monad.Trans.Except +import qualified Data.ByteString as B +import qualified Data.ByteString.Builder as BB +import qualified Data.ByteString.Lazy as LB +import Data.Functor (($>)) +import Data.IORef +import Data.Int (Int64) +import Data.List (intersperse) +import qualified Data.Map.Strict as M +import Data.Text (Text) +import Data.Time.Clock.System (SystemTime (..)) +import Database.PostgreSQL.Simple (Binary (..), Only (..), (:.) (..)) +import qualified Database.PostgreSQL.Simple as DB +import qualified Database.PostgreSQL.Simple.Copy as DB +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +import Simplex.Messaging.Agent.Store.Postgres.Common +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Protocol +import Simplex.Messaging.Server.MsgStore +import Simplex.Messaging.Server.MsgStore.Types +import Simplex.Messaging.Server.QueueStore +import Simplex.Messaging.Server.QueueStore.Postgres +import Simplex.Messaging.Server.QueueStore.Types +import Simplex.Messaging.Server.StoreLog (foldLogLines) +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Util (maybeFirstRow, maybeFirstRow', (<$$>)) +import System.IO (Handle, hFlush, stdout) + +data PostgresMsgStore = PostgresMsgStore + { config :: PostgresMsgStoreCfg, + queueStore_ :: PostgresQueueStore' + } + +data PostgresMsgStoreCfg = PostgresMsgStoreCfg + { queueStoreCfg :: PostgresStoreCfg, + quota :: Int + } + +type PostgresQueueStore' = PostgresQueueStore PostgresQueue + +data PostgresQueue = PostgresQueue + { recipientId' :: RecipientId, + queueRec' :: TVar (Maybe QueueRec) + } + +instance StoreQueueClass PostgresQueue where + recipientId = recipientId' + {-# INLINE recipientId #-} + queueRec = queueRec' + {-# INLINE queueRec #-} + withQueueLock PostgresQueue {} _ = id -- TODO [messages] maybe it's just transaction? + {-# INLINE withQueueLock #-} + +newtype DBTransaction = DBTransaction {dbConn :: DB.Connection} + +type DBStoreIO a = ReaderT DBTransaction IO a + +instance MsgStoreClass PostgresMsgStore where + type StoreMonad PostgresMsgStore = ReaderT DBTransaction IO + type MsgQueue PostgresMsgStore = () + type QueueStore PostgresMsgStore = PostgresQueueStore' + type StoreQueue PostgresMsgStore = PostgresQueue + type MsgStoreConfig PostgresMsgStore = PostgresMsgStoreCfg + + newMsgStore :: PostgresMsgStoreCfg -> IO PostgresMsgStore + newMsgStore config = do + queueStore_ <- newQueueStore @PostgresQueue (queueStoreCfg config, False) + pure PostgresMsgStore {config, queueStore_} + + closeMsgStore :: PostgresMsgStore -> IO () + closeMsgStore = closeQueueStore @PostgresQueue . queueStore_ + + withActiveMsgQueues _ _ = error "withActiveMsgQueues not used" + + unsafeWithAllMsgQueues _ _ _ = error "unsafeWithAllMsgQueues not used" + + expireOldMessages :: Bool -> PostgresMsgStore -> Int64 -> Int64 -> IO MessageStats + expireOldMessages _tty ms now ttl = + maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db -> + DB.query db "CALL expire_old_messages(?,?,?,0,0,0)" (oldQueue, oldMsg, batchSize) + where + st = dbStore $ queueStore_ ms + oldQueue = 0 :: Int64 -- expire all queues + oldMsg = now - ttl + batchSize = 10000 :: Int + toMessageStats (expiredMsgsCount, storedMsgsCount, storedQueues) = + MessageStats {expiredMsgsCount, storedMsgsCount, storedQueues} + + logQueueStates _ = error "logQueueStates not used" + + logQueueState _ = error "logQueueState not used" + + queueStore = queueStore_ + {-# INLINE queueStore #-} + + loadedQueueCounts :: PostgresMsgStore -> IO LoadedQueueCounts + loadedQueueCounts ms = do + loadedQueueCount <- M.size <$> readTVarIO queues + loadedNotifierCount <- M.size <$> readTVarIO notifiers + notifierLockCount <- M.size <$> readTVarIO notifierLocks + pure LoadedQueueCounts {loadedQueueCount, loadedNotifierCount, openJournalCount = 0, queueLockCount = 0, notifierLockCount} + where + PostgresQueueStore {queues, notifiers, notifierLocks} = queueStore_ ms + + mkQueue :: PostgresMsgStore -> Bool -> RecipientId -> QueueRec -> IO PostgresQueue + mkQueue _ _keepLock rId qr = PostgresQueue rId <$> newTVarIO (Just qr) + {-# INLINE mkQueue #-} + + getMsgQueue _ _ _ = pure () + {-# INLINE getMsgQueue #-} + + getPeekMsgQueue :: PostgresMsgStore -> PostgresQueue -> DBStoreIO (Maybe ((), Message)) + getPeekMsgQueue _ q = ((),) <$$> tryPeekMsg_ q () + + withIdleMsgQueue :: Int64 -> PostgresMsgStore -> PostgresQueue -> (() -> DBStoreIO a) -> DBStoreIO (Maybe a, Int) + withIdleMsgQueue _ _ _ _ = error "withIdleMsgQueue not used" + + deleteQueue :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType QueueRec) + deleteQueue ms q = deleteStoreQueue (queueStore_ ms) q + {-# INLINE deleteQueue #-} + + deleteQueueSize :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType (QueueRec, Int)) + deleteQueueSize ms q = runExceptT $ do + size <- getQueueSize ms q + qr <- ExceptT $ deleteStoreQueue (queueStore_ ms) q + pure (qr, size) + + getQueueMessages_ _ _ _ = error "getQueueMessages_ not used" + + writeMsg :: PostgresMsgStore -> PostgresQueue -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool)) + writeMsg ms q _ msg = + uninterruptibleMask_ $ + withDB' "writeMsg" (queueStore_ ms) $ \db -> do + let (msgQuota, ntf, body) = case msg of + Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body') + MessageQuota {} -> (True, False, B.empty) + toResult <$> + DB.query + db + "SELECT quota_written, was_empty FROM write_message(?,?,?,?,?,?,?)" + (recipientId' q, Binary (messageId msg), systemSeconds (messageTs msg), msgQuota, ntf, Binary body, quota) + where + toResult = \case + ((msgQuota, wasEmpty) : _) -> if msgQuota then Nothing else Just (msg, wasEmpty) + [] -> Nothing + PostgresMsgStore {config = PostgresMsgStoreCfg {quota}} = ms + + setOverQuota_ :: PostgresQueue -> IO () -- can ONLY be used while restoring messages, not while server running + setOverQuota_ _ = error "TODO setOverQuota_" -- TODO [messages] + + getQueueSize_ :: () -> DBStoreIO Int + getQueueSize_ _ = error "getQueueSize_ not used" + + getQueueSize :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO Int + getQueueSize ms q = + withDB' "getQueueSize" (queueStore_ ms) $ \db -> + maybeFirstRow' 0 fromOnly $ + DB.query db "SELECT msg_queue_size FROM msg_queues WHERE recipient_id = ? AND deleted_at IS NULL" (Only (recipientId' q)) + + tryPeekMsg_ :: PostgresQueue -> () -> DBStoreIO (Maybe Message) + tryPeekMsg_ q _ = do + db <- asks dbConn + liftIO $ maybeFirstRow toMessage $ + DB.query + db + [sql| + SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + FROM messages + WHERE recipient_id = ? + ORDER BY message_id ASC LIMIT 1 + |] + (Only (recipientId' q)) + + tryDeleteMsg_ :: PostgresQueue -> () -> Bool -> DBStoreIO () + tryDeleteMsg_ _q _ _ = error "tryDeleteMsg_ not used" -- do + + isolateQueue :: PostgresMsgStore -> PostgresQueue -> Text -> DBStoreIO a -> ExceptT ErrorType IO a + isolateQueue ms _q op a = uninterruptibleMask_ $ withDB' op (queueStore_ ms) $ runReaderT a . DBTransaction + + unsafeRunStore _ _ _ = error "unsafeRunStore not used" + + tryPeekMsg :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO (Maybe Message) + tryPeekMsg ms q = isolateQueue ms q "tryPeekMsg" $ tryPeekMsg_ q () + {-# INLINE tryPeekMsg #-} + + tryDelMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message) + tryDelMsg ms q msgId = + uninterruptibleMask_ $ + withDB' "tryDelMsg" (queueStore_ ms) $ \db -> + maybeFirstRow toMessage $ + DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_msg(?, ?)" (recipientId' q, Binary msgId) + + tryDelPeekMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message) + tryDelPeekMsg ms q msgId = + uninterruptibleMask_ $ + withDB' "tryDelPeekMsg" (queueStore_ ms) $ \db -> + toResult . map toMessage + <$> DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_peek_msg(?, ?)" (recipientId' q, Binary msgId) + where + toResult = \case + [] -> (Nothing, Nothing) + [msg] + | messageId msg == msgId -> (Just msg, Nothing) + | otherwise -> (Nothing, Just msg) + deleted : next : _ -> (Just deleted, Just next) + + deleteExpiredMsgs :: PostgresMsgStore -> PostgresQueue -> Int64 -> ExceptT ErrorType IO Int + deleteExpiredMsgs ms q old = + uninterruptibleMask_ $ + maybeFirstRow' 0 (fromIntegral @Int64 . fromOnly) $ withDB' "deleteExpiredMsgs" (queueStore_ ms) $ \db -> + DB.query db "SELECT delete_expired_msgs(?, ?)" (recipientId' q, old) + +uninterruptibleMask_ :: ExceptT ErrorType IO a -> ExceptT ErrorType IO a +uninterruptibleMask_ = ExceptT . E.uninterruptibleMask_ . runExceptT +{-# INLINE uninterruptibleMask_ #-} + +toMessage :: (Binary MsgId, Int64, Bool, Bool, Binary MsgBody) -> Message +toMessage (Binary msgId, ts, msgQuota, ntf, Binary body) + | msgQuota = MessageQuota {msgId, msgTs} + | otherwise = Message {msgId, msgTs, msgFlags = MsgFlags ntf, msgBody = C.unsafeMaxLenBS body} -- TODO [messages] unsafeMaxLenBS? + where + msgTs = MkSystemTime ts 0 + +exportDbMessages :: Bool -> PostgresMsgStore -> Handle -> IO Int +exportDbMessages tty ms h = do + rows <- newIORef [] + n <- withConnection st $ \db -> DB.foldWithOptions_ opts db query 0 $ \i r -> do + let i' = i + 1 + if i' `mod` 1000 > 0 + then modifyIORef rows (r :) + else do + readIORef rows >>= writeMessages . (r :) + writeIORef rows [] + when tty $ putStr (progress i' <> "\r") >> hFlush stdout + pure i' + readIORef rows >>= \rs -> unless (null rs) $ writeMessages rs + when tty $ putStrLn $ progress n + pure n + where + st = dbStore $ queueStore_ ms + opts = DB.defaultFoldOptions {DB.fetchQuantity = DB.Fixed 1000} + query = + [sql| + SELECT recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + FROM messages + ORDER BY recipient_id, message_id ASC + |] + writeMessages = BB.hPutBuilder h . encodeMessages . reverse + encodeMessages = mconcat . map (\(Only rId :. msg) -> BB.byteString (strEncode $ MLRv3 rId $ toMessage msg) <> BB.char8 '\n') + progress i = "Processed: " <> show i <> " records" + +getDbMessageStats :: PostgresMsgStore -> IO MessageStats +getDbMessageStats ms = + maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db -> + DB.query_ + db + [sql| + SELECT + (SELECT COUNT (1) FROM msg_queues WHERE deleted_at IS NULL), + (SELECT COUNT (1) FROM messages m JOIN msg_queues q USING recipient_id WHERE deleted_at IS NULL) + |] + where + st = dbStore $ queueStore_ ms + toMessageStats (storedQueues, storedMsgsCount) = + MessageStats {storedQueues, storedMsgsCount, expiredMsgsCount = 0} + +getDbMessageCount :: PostgresMsgStore -> IO Int64 +getDbMessageCount ms = + maybeFirstRow' 0 fromOnly $ + withConnection (dbStore $ queueStore_ ms) (`DB.query_` "SELECT COUNT(*) FROM messages") + +deleteAllMessages :: PostgresMsgStore -> IO () +deleteAllMessages ms = + withConnection (dbStore $ queueStore_ ms) $ \db -> do + void $ DB.execute_ db "TRUNCATE messages" + void $ DB.execute_ + db + [sql| + UPDATE msg_queues + SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE + WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE + |] + +updateQueueCounts :: PostgresMsgStore -> IO () +updateQueueCounts ms = + withConnection (dbStore $ queueStore_ ms) $ \db -> do + void $ DB.execute_ + db + [sql| + CREATE TEMP TABLE queue_stats AS + SELECT recipient_id, + COUNT(*) AS size, + SUM(CASE WHEN msg_quota THEN 1 ELSE 0 END) AS quota_count + FROM messages + GROUP BY recipient_id + |] + void $ DB.execute_ + db + [sql| + UPDATE msg_queues + SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE + WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE + |] + void $ DB.execute_ + db + [sql| + UPDATE msg_queues q + SET msg_queue_size = s.size, + msg_can_write = s.quota_count = 0, + msg_queue_expire = s.size > s.quota_count + FROM queue_stats s + WHERE q.recipient_id = s.recipient_id + |] + void $ DB.execute_ db "DROP TABLE queue_stats" + +batchInsertMessages :: StoreQueueClass q => Bool -> FilePath -> PostgresQueueStore q -> IO Int64 +batchInsertMessages tty f toStore = do + putStrLn "Importing messages..." + let st = dbStore toStore + (_, inserted) <- + withTransaction st $ \db -> do + DB.copy_ + db + [sql| + COPY messages (recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body) + FROM STDIN WITH (FORMAT CSV) + |] + foldLogLines tty f (putMessage db) (0 :: Int, 0) >>= (DB.putCopyEnd db $>) + pure inserted + where + putMessage db (!i, !cnt) _eof s = do + let i' = i + 1 + cnt' <- case strDecode s of + Right (MLRv3 rId msg) -> (cnt + 1) <$ DB.putCopyData db (messageRecToText rId msg) + Left e -> cnt <$ putStrLn ("Error parsing line " <> show i' <> ": " <> e) + pure (i', cnt') + +messageRecToText :: RecipientId -> Message -> B.ByteString +messageRecToText rId msg = + LB.toStrict $ BB.toLazyByteString $ mconcat tabFields <> BB.char7 '\n' + where + tabFields = BB.char7 ',' `intersperse` fields + fields = + [ renderField (toField rId), + renderField (toField $ Binary (messageId msg)), + renderField (toField $ systemSeconds (messageTs msg)), + renderField (toField msgQuota), + renderField (toField ntf), + renderField (toField $ Binary body) + ] + (msgQuota, ntf, body) = case msg of + Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body') + MessageQuota {} -> (True, False, B.empty) diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index ed24e85a4..73e1bf398 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -57,18 +57,16 @@ data STMStoreConfig = STMStoreConfig } instance StoreQueueClass STMQueue where - type MsgQueue STMQueue = STMMsgQueue recipientId = recipientId' {-# INLINE recipientId #-} queueRec = queueRec' {-# INLINE queueRec #-} - msgQueue = msgQueue' - {-# INLINE msgQueue #-} withQueueLock _ _ = id {-# INLINE withQueueLock #-} instance MsgStoreClass STMMsgStore where type StoreMonad STMMsgStore = STM + type MsgQueue STMMsgStore = STMMsgQueue type QueueStore STMMsgStore = STMQueueStore STMQueue type StoreQueue STMMsgStore = STMQueue type MsgStoreConfig STMMsgStore = STMStoreConfig @@ -82,7 +80,7 @@ instance MsgStoreClass STMMsgStore where {-# INLINE closeMsgStore #-} withActiveMsgQueues = withLoadedQueues . queueStore_ {-# INLINE withActiveMsgQueues #-} - unsafeWithAllMsgQueues _ _ = withLoadedQueues . queueStore_ + unsafeWithAllMsgQueues _ = withLoadedQueues . queueStore_ {-# INLINE unsafeWithAllMsgQueues #-} expireOldMessages :: Bool -> STMMsgStore -> Int64 -> Int64 -> IO MessageStats @@ -129,10 +127,10 @@ instance MsgStoreClass STMMsgStore where Nothing -> pure (Nothing, 0) deleteQueue :: STMMsgStore -> STMQueue -> IO (Either ErrorType QueueRec) - deleteQueue ms q = fst <$$> deleteStoreQueue (queueStore_ ms) q + deleteQueue ms q = fst <$$> deleteQueue_ ms q deleteQueueSize :: STMMsgStore -> STMQueue -> IO (Either ErrorType (QueueRec, Int)) - deleteQueueSize ms q = deleteStoreQueue (queueStore_ ms) q >>= mapM (traverse getSize) + deleteQueueSize ms q = deleteQueue_ ms q >>= mapM (traverse getSize) -- traverse operates on the second tuple element where getSize = maybe (pure 0) (\STMMsgQueue {size} -> readTVarIO size) @@ -179,10 +177,15 @@ instance MsgStoreClass STMMsgStore where Just _ -> modifyTVar' size (subtract 1) _ -> pure () - isolateQueue :: STMQueue -> Text -> STM a -> ExceptT ErrorType IO a - isolateQueue _ _ = liftIO . atomically + isolateQueue :: STMMsgStore -> STMQueue -> Text -> STM a -> ExceptT ErrorType IO a + isolateQueue _ _ _ = liftIO . atomically {-# INLINE isolateQueue #-} unsafeRunStore :: STMQueue -> Text -> STM a -> IO a unsafeRunStore _ _ = atomically {-# INLINE unsafeRunStore #-} + +deleteQueue_ :: STMMsgStore -> STMQueue -> IO (Either ErrorType (QueueRec, Maybe STMMsgQueue)) +deleteQueue_ ms q = deleteStoreQueue (queueStore_ ms) q >>= mapM remove + where + remove qr = (qr,) <$> atomically (swapTVar (msgQueue' q) Nothing) diff --git a/src/Simplex/Messaging/Server/MsgStore/Types.hs b/src/Simplex/Messaging/Server/MsgStore/Types.hs index e2d139ffb..98c12d4be 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Types.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Types.hs @@ -34,14 +34,15 @@ import Simplex.Messaging.Util ((<$$>), ($>>=)) class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => MsgStoreClass s where type StoreMonad s = (m :: Type -> Type) | m -> s type MsgStoreConfig s = c | c -> s + type MsgQueue s = q | q -> s type StoreQueue s = q | q -> s type QueueStore s = qs | qs -> s newMsgStore :: MsgStoreConfig s -> IO s closeMsgStore :: s -> IO () withActiveMsgQueues :: Monoid a => s -> (StoreQueue s -> IO a) -> IO a -- This function can only be used in server CLI commands or before server is started. - -- tty, withData, store - unsafeWithAllMsgQueues :: Monoid a => Bool -> Bool -> s -> (StoreQueue s -> IO a) -> IO a + -- tty, store + unsafeWithAllMsgQueues :: Monoid a => Bool -> s -> (StoreQueue s -> IO a) -> IO a -- tty, store, now, ttl expireOldMessages :: Bool -> s -> Int64 -> Int64 -> IO MessageStats logQueueStates :: s -> IO () @@ -51,29 +52,62 @@ class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => M -- message store methods mkQueue :: s -> Bool -> RecipientId -> QueueRec -> IO (StoreQueue s) - getMsgQueue :: s -> StoreQueue s -> Bool -> StoreMonad s (MsgQueue (StoreQueue s)) - getPeekMsgQueue :: s -> StoreQueue s -> StoreMonad s (Maybe (MsgQueue (StoreQueue s), Message)) + getMsgQueue :: s -> StoreQueue s -> Bool -> StoreMonad s (MsgQueue s) + getPeekMsgQueue :: s -> StoreQueue s -> StoreMonad s (Maybe (MsgQueue s, Message)) -- the journal queue will be closed after action if it was initially closed or idle longer than interval in config - withIdleMsgQueue :: Int64 -> s -> StoreQueue s -> (MsgQueue (StoreQueue s) -> StoreMonad s a) -> StoreMonad s (Maybe a, Int) + withIdleMsgQueue :: Int64 -> s -> StoreQueue s -> (MsgQueue s -> StoreMonad s a) -> StoreMonad s (Maybe a, Int) deleteQueue :: s -> StoreQueue s -> IO (Either ErrorType QueueRec) deleteQueueSize :: s -> StoreQueue s -> IO (Either ErrorType (QueueRec, Int)) - getQueueMessages_ :: Bool -> StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s [Message] + getQueueMessages_ :: Bool -> StoreQueue s -> MsgQueue s -> StoreMonad s [Message] writeMsg :: s -> StoreQueue s -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool)) setOverQuota_ :: StoreQueue s -> IO () -- can ONLY be used while restoring messages, not while server running - getQueueSize_ :: MsgQueue (StoreQueue s) -> StoreMonad s Int - tryPeekMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s (Maybe Message) - tryDeleteMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> Bool -> StoreMonad s () - isolateQueue :: StoreQueue s -> Text -> StoreMonad s a -> ExceptT ErrorType IO a + getQueueSize_ :: MsgQueue s -> StoreMonad s Int + tryPeekMsg_ :: StoreQueue s -> MsgQueue s -> StoreMonad s (Maybe Message) + tryDeleteMsg_ :: StoreQueue s -> MsgQueue s -> Bool -> StoreMonad s () + isolateQueue :: s -> StoreQueue s -> Text -> StoreMonad s a -> ExceptT ErrorType IO a unsafeRunStore :: StoreQueue s -> Text -> StoreMonad s a -> IO a -data MSType = MSMemory | MSJournal + -- default implementations are overridden for PostgreSQL storage of messages + tryPeekMsg :: s -> StoreQueue s -> ExceptT ErrorType IO (Maybe Message) + tryPeekMsg st q = snd <$$> withPeekMsgQueue st q "tryPeekMsg" pure + {-# INLINE tryPeekMsg #-} + + tryDelMsg :: s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message) + tryDelMsg st q msgId' = + withPeekMsgQueue st q "tryDelMsg" $ + maybe (pure Nothing) $ \(mq, msg) -> + if + | messageId msg == msgId' -> + tryDeleteMsg_ q mq True $> Just msg + | otherwise -> pure Nothing + + -- atomic delete (== read) last and peek next message if available + tryDelPeekMsg :: s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message) + tryDelPeekMsg st q msgId' = + withPeekMsgQueue st q "tryDelPeekMsg" $ + maybe (pure (Nothing, Nothing)) $ \(mq, msg) -> + if + | messageId msg == msgId' -> (Just msg,) <$> (tryDeleteMsg_ q mq True >> tryPeekMsg_ q mq) + | otherwise -> pure (Nothing, Just msg) + + deleteExpiredMsgs :: s -> StoreQueue s -> Int64 -> ExceptT ErrorType IO Int + deleteExpiredMsgs st q old = + isolateQueue st q "deleteExpiredMsgs" $ + getMsgQueue st q False >>= deleteExpireMsgs_ old q + + getQueueSize :: s -> StoreQueue s -> ExceptT ErrorType IO Int + getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst) + {-# INLINE getQueueSize #-} + +data MSType = MSMemory | MSJournal | MSPostgres data QSType = QSMemory | QSPostgres data SMSType :: MSType -> Type where SMSMemory :: SMSType 'MSMemory SMSJournal :: SMSType 'MSJournal + SMSPostgres :: SMSType 'MSPostgres data SQSType :: QSType -> Type where SQSMemory :: SQSType 'QSMemory @@ -84,6 +118,7 @@ data MessageStats = MessageStats expiredMsgsCount :: Int, storedQueues :: Int } + deriving (Show) instance Monoid MessageStats where mempty = MessageStats 0 0 0 @@ -126,48 +161,19 @@ readQueueRec :: StoreQueueClass q => q -> IO (Either ErrorType (q, QueueRec)) readQueueRec q = maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q) {-# INLINE readQueueRec #-} -getQueueSize :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO Int -getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst) -{-# INLINE getQueueSize #-} - -tryPeekMsg :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO (Maybe Message) -tryPeekMsg st q = snd <$$> withPeekMsgQueue st q "tryPeekMsg" pure -{-# INLINE tryPeekMsg #-} - -tryDelMsg :: MsgStoreClass s => s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message) -tryDelMsg st q msgId' = - withPeekMsgQueue st q "tryDelMsg" $ - maybe (pure Nothing) $ \(mq, msg) -> - if - | messageId msg == msgId' -> - tryDeleteMsg_ q mq True $> Just msg - | otherwise -> pure Nothing - --- atomic delete (== read) last and peek next message if available -tryDelPeekMsg :: MsgStoreClass s => s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message) -tryDelPeekMsg st q msgId' = - withPeekMsgQueue st q "tryDelPeekMsg" $ - maybe (pure (Nothing, Nothing)) $ \(mq, msg) -> - if - | messageId msg == msgId' -> (Just msg,) <$> (tryDeleteMsg_ q mq True >> tryPeekMsg_ q mq) - | otherwise -> pure (Nothing, Just msg) - -- The action is called with Nothing when it is known that the queue is empty -withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> Text -> (Maybe (MsgQueue (StoreQueue s), Message) -> StoreMonad s a) -> ExceptT ErrorType IO a -withPeekMsgQueue st q op a = isolateQueue q op $ getPeekMsgQueue st q >>= a +withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> Text -> (Maybe (MsgQueue s, Message) -> StoreMonad s a) -> ExceptT ErrorType IO a +withPeekMsgQueue st q op a = isolateQueue st q op $ getPeekMsgQueue st q >>= a {-# INLINE withPeekMsgQueue #-} -deleteExpiredMsgs :: MsgStoreClass s => s -> StoreQueue s -> Int64 -> ExceptT ErrorType IO Int -deleteExpiredMsgs st q old = - isolateQueue q "deleteExpiredMsgs" $ - getMsgQueue st q False >>= deleteExpireMsgs_ old q - +-- not used with PostgreSQL message store expireQueueMsgs :: MsgStoreClass s => s -> Int64 -> Int64 -> StoreQueue s -> StoreMonad s MessageStats expireQueueMsgs st now old q = do (expired_, stored) <- withIdleMsgQueue now st q $ deleteExpireMsgs_ old q pure MessageStats {storedMsgsCount = stored, expiredMsgsCount = fromMaybe 0 expired_, storedQueues = 1} -deleteExpireMsgs_ :: MsgStoreClass s => Int64 -> StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s Int +-- not used with PostgreSQL message store +deleteExpireMsgs_ :: MsgStoreClass s => Int64 -> StoreQueue s -> MsgQueue s -> StoreMonad s Int deleteExpireMsgs_ old q mq = do n <- loop 0 logQueueState q diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index 5ed0754ec..4a53dcdd4 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -25,9 +25,13 @@ module Simplex.Messaging.Server.QueueStore.Postgres batchInsertQueues, foldServiceRecs, foldQueueRecs, + foldRecentQueueRecs, handleDuplicate, withLog_, + withDB, withDB', + assertUpdated, + renderField, ) where @@ -70,6 +74,7 @@ import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) @@ -83,7 +88,7 @@ import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (SMPServiceRole (..)) -import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>)) +import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, maybeFirstRow', tshow, (<$$>)) import System.Exit (exitFailure) import System.IO (IOMode (..), hFlush, stdout) import UnliftIO.STM @@ -104,15 +109,18 @@ data PostgresQueueStore q = PostgresQueueStore notifiers :: TMap NotifierId RecipientId, notifierLocks :: TMap NotifierId Lock, serviceLocks :: TMap CertFingerprint Lock, - deletedTTL :: Int64 + deletedTTL :: Int64, + useCache :: Bool } -instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where - type QueueStoreCfg (PostgresQueueStore q) = PostgresStoreCfg +type UseQueueCache = Bool - newQueueStore :: PostgresStoreCfg -> IO (PostgresQueueStore q) - newQueueStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do - dbStore <- either err pure =<< createDBStore dbOpts serverMigrations confirmMigrations +instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where + type QueueStoreCfg (PostgresQueueStore q) = (PostgresStoreCfg, UseQueueCache) + + newQueueStore :: (PostgresStoreCfg, UseQueueCache) -> IO (PostgresQueueStore q) + newQueueStore (PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL}, useCache) = do + dbStore <- either err pure =<< createDBStore dbOpts serverMigrations (MigrationConfig confirmMigrations Nothing) dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath queues <- TM.emptyIO senders <- TM.emptyIO @@ -120,7 +128,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where notifiers <- TM.emptyIO notifierLocks <- TM.emptyIO serviceLocks <- TM.emptyIO - pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, serviceLocks, deletedTTL} + pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, serviceLocks, deletedTTL, useCache} where err e = do logError $ "STORE: newQueueStore, error opening PostgreSQL database, " <> tshow e @@ -167,28 +175,35 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where void $ withDB "addQueue_" st $ \db -> E.try (DB.execute db insertQueueQuery $ queueRecToRow (rId, qr)) >>= bimapM handleDuplicate pure - atomically $ TM.insert rId sq queues - atomically $ TM.insert (senderId qr) rId senders - forM_ (notifier qr) $ \NtfCreds {notifierId = nId} -> atomically $ TM.insert nId rId notifiers - forM_ (queueData qr) $ \(lnkId, _) -> atomically $ TM.insert lnkId rId links + when useCache $ do + atomically $ TM.insert rId sq queues + atomically $ TM.insert (senderId qr) rId senders + forM_ (notifier qr) $ \NtfCreds {notifierId = nId} -> atomically $ TM.insert nId rId notifiers + forM_ (queueData qr) $ \(lnkId, _) -> atomically $ TM.insert lnkId rId links withLog "addStoreQueue" st $ \s -> logCreateQueue s rId qr pure sq where - PostgresQueueStore {queues, senders, links, notifiers} = st + PostgresQueueStore {queues, senders, links, notifiers, useCache} = st -- Not doing duplicate checks in maps as the probability of duplicates is very low. -- It needs to be reconsidered when IDs are supplied by the users. -- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier] -- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q) - getQueue_ st mkQ party qId = case party of - SRecipient -> getRcvQueue qId - SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue - SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue - -- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server - SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>)) + getQueue_ st mkQ party qId + | useCache = case party of + SRecipient -> getRcvQueue qId + SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue + SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue + -- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server + SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>)) + | otherwise = case party of + SRecipient -> loadQueueNoCache " WHERE recipient_id = ?" + SSender -> loadQueueNoCache " WHERE sender_id = ?" + SSenderLink -> loadQueueNoCache " WHERE link_id = ?" + SNotifier -> loadQueueNoCache " WHERE notifier_id = ?" where - PostgresQueueStore {queues, senders, links, notifiers} = st + PostgresQueueStore {queues, senders, links, notifiers, useCache} = st getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right) loadRcvQueue = do (rId, qRec) <- loadQueue " WHERE recipient_id = ?" @@ -205,6 +220,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where liftIO $ TM.lookupIO rId queues -- checking recipient map first >>= maybe (cacheQueue rId qRec cacheSender) (atomically (cacheSender rId) $>) + loadQueueNoCache cond = mask $ loadQueue cond >>= liftIO . uncurry (mkQ True) mask = E.uninterruptibleMask_ . runExceptT cacheSender rId = TM.insert qId rId senders loadQueue condition = @@ -227,20 +243,27 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where pure sq getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q] - getQueues_ st mkQ party qIds = case party of - SRecipient -> do - qs <- readTVarIO queues - let qs' = map (\qId -> get qs qId qId) qIds - E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue - SNotifier -> do - ns <- readTVarIO notifiers - qs <- readTVarIO queues - let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds - E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) -> - forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query - (nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs) + getQueues_ st mkQ party qIds + | null qIds = pure [] + | useCache = case party of + SRecipient -> do + qs <- readTVarIO queues + let qs' = map (\qId -> get qs qId qId) qIds + E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue + SNotifier -> do + ns <- readTVarIO notifiers + qs <- readTVarIO queues + let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds + E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) -> + forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query + (nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs) + | otherwise = E.uninterruptibleMask_ $ case party of + SRecipient -> loadQueuesNoCache " WHERE recipient_id IN ?" $ \(rId, qRec) -> + Just . (rId,) <$> mkQ False rId qRec + SNotifier -> loadQueuesNoCache " WHERE notifier_id IN ?" $ \(rId, qRec) -> + forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> (nId,) <$> mkQ False rId qRec where - PostgresQueueStore {queues, notifiers} = st + PostgresQueueStore {queues, notifiers, useCache} = st get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a get m qId = maybe (Left qId) Right . (`M.lookup` m) loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q] @@ -249,15 +272,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where if null qIds' then pure $ map (first (const INTERNAL)) qs' else do - qs_ <- - runExceptT $ fmap M.fromList $ - withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds'))) - >>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec) + qs_ <- dbLoadQueues qIds' cond mkCacheQueue pure $ map (result qs_) qs' where result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q result _ (Right q) = Right q result qs_ (Left qId) = maybe (Left AUTH) Right . M.lookup qId =<< qs_ + dbLoadQueues qIds' cond mkQueue' = + runExceptT $ fmap M.fromList $ + withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds'))) + >>= liftIO . fmap catMaybes . mapM (mkQueue' . rowToQueueRec) cacheRcvQueue (rId, qRec) = do sq <- mkQ True rId qRec sq' <- withQueueLock sq "getQueue_" $ atomically $ @@ -266,6 +290,12 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where Just sq' -> pure sq' Nothing -> sq <$ TM.insert rId sq queues pure $ Just (rId, sq') + loadQueuesNoCache cond mkQueue' = do + qs_ <- dbLoadQueues qIds cond mkQueue' + pure $ map (result qs_) qIds + where + result :: Either ErrorType (M.Map QueueId q) -> QueueId -> Either ErrorType q + result qs_ qId = maybe (Left AUTH) Right . M.lookup qId =<< qs_ getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData) getQueueLinkData st sq lnkId = runExceptT $ do @@ -331,19 +361,23 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where addQueueNotifier :: PostgresQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NtfCreds)) addQueueNotifier st sq ntfCreds@NtfCreds {notifierId = nId, notifierKey, rcvNtfDhSecret} = withQueueRec sq "addQueueNotifier" $ \q -> - ExceptT $ withLockMap (notifierLocks st) nId "addQueueNotifier" $ - ifM (TM.memberIO nId notifiers) (pure $ Left DUPLICATE_) $ runExceptT $ do - assertUpdated $ withDB "addQueueNotifier" st $ \db -> - E.try (update db) >>= bimapM handleDuplicate pure - nc_ <- forM (notifier q) $ \nc@NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> nc - let !q' = q {notifier = Just ntfCreds} - atomically $ writeTVar (queueRec sq) $ Just q' - -- cache queue notifier ID – after notifier is added ntf server will likely subscribe + checkCachedNotifier $ do + assertUpdated $ withDB "addQueueNotifier" st $ \db -> + E.try (update db) >>= bimapM handleDuplicate pure + nc_ <- forM (notifier q) $ \nc@NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> nc + let !q' = q {notifier = Just ntfCreds} + atomically $ writeTVar (queueRec sq) $ Just q' + when useCache $ do atomically $ TM.insert nId rId notifiers - withLog "addQueueNotifier" st $ \s -> logAddNotifier s rId ntfCreds - pure nc_ + withLog "addQueueNotifier" st $ \s -> logAddNotifier s rId ntfCreds + pure nc_ where - PostgresQueueStore {notifiers} = st + checkCachedNotifier add + | useCache = + ExceptT $ withLockMap (notifierLocks st) nId "addQueueNotifier" $ + ifM (TM.memberIO nId notifiers) (pure $ Left DUPLICATE_) $ runExceptT add + | otherwise = add + PostgresQueueStore {notifiers, useCache} = st rId = recipientId sq update db = DB.execute @@ -359,13 +393,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where deleteQueueNotifier st sq = withQueueRec sq "deleteQueueNotifier" $ \q -> ExceptT $ fmap sequence $ forM (notifier q) $ \nc@NtfCreds {notifierId = nId} -> - withLockMap (notifierLocks st) nId "deleteQueueNotifier" $ runExceptT $ do + withNotifierLock nId $ runExceptT $ do assertUpdated $ withDB' "deleteQueueNotifier" st update - atomically $ TM.delete nId $ notifiers st + when (useCache st) $ atomically $ TM.delete nId $ notifiers st atomically $ writeTVar (queueRec sq) $ Just q {notifier = Nothing} withLog "deleteQueueNotifier" st (`logDeleteNotifier` rId) pure nc where + withNotifierLock nId + | useCache st = withLockMap (notifierLocks st) nId "deleteQueueNotifier" + | otherwise = id rId = recipientId sq update db = DB.execute @@ -408,20 +445,20 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where rId = recipientId sq -- this method is called from JournalMsgStore deleteQueue that already locks the queue - deleteStoreQueue :: PostgresQueueStore q -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) + deleteStoreQueue :: PostgresQueueStore q -> q -> IO (Either ErrorType QueueRec) deleteStoreQueue st sq = E.uninterruptibleMask_ $ runExceptT $ do q <- ExceptT $ readQueueRecIO qr RoundedSystemTime ts <- liftIO getSystemDate assertUpdated $ withDB' "deleteStoreQueue" st $ \db -> DB.execute db "UPDATE msg_queues SET deleted_at = ? WHERE recipient_id = ? AND deleted_at IS NULL" (ts, rId) atomically $ writeTVar qr Nothing - atomically $ TM.delete (senderId q) $ senders st - forM_ (notifier q) $ \NtfCreds {notifierId} -> do - atomically $ TM.delete notifierId $ notifiers st - atomically $ TM.delete notifierId $ notifierLocks st - mq_ <- atomically $ swapTVar (msgQueue sq) Nothing + when (useCache st) $ do + atomically $ TM.delete (senderId q) $ senders st + forM_ (notifier q) $ \NtfCreds {notifierId} -> do + atomically $ TM.delete notifierId $ notifiers st + atomically $ TM.delete notifierId $ notifierLocks st withLog "deleteStoreQueue" st (`logDeleteQueue` rId) - pure (q, mq_) + pure q where rId = recipientId sq qr = queueRec sq @@ -487,7 +524,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where getServiceQueueCount :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64) getServiceQueueCount st party serviceId = E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCount" st $ \db -> - fmap (fromMaybe 0) $ maybeFirstRow fromOnly $ + maybeFirstRow' 0 fromOnly $ DB.query db query (Only serviceId) where query = case party of @@ -545,8 +582,28 @@ foldServiceRecs st f = DB.fold_ db "SELECT service_id, service_role, service_cert, service_cert_hash, created_at FROM services" mempty $ \ !acc -> fmap (acc <>) . f . rowToServiceRec -foldQueueRecs :: forall a q. Monoid a => Bool -> Bool -> PostgresQueueStore q -> Maybe Int64 -> ((RecipientId, QueueRec) -> IO a) -> IO a -foldQueueRecs tty withData st skipOld_ f = do +foldQueueRecs :: Monoid a => Bool -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a +foldQueueRecs withData = foldQueueRecs_ foldRecs + where + foldRecs db acc f' + | withData = DB.fold_ db (queueRecQueryWithData <> cond) acc $ \acc' -> f' acc' . rowToQueueRecWithData + | otherwise = DB.fold_ db (queueRecQuery <> cond) acc $ \acc' -> f' acc' . rowToQueueRec + cond = " WHERE deleted_at IS NULL ORDER BY recipient_id ASC" + +foldRecentQueueRecs :: Monoid a => Int64 -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a +foldRecentQueueRecs old = foldQueueRecs_ foldRecs + where + foldRecs db acc f' = DB.fold db (queueRecQuery <> cond) (Only old) acc $ \acc' -> f' acc' . rowToQueueRec + cond = " WHERE deleted_at IS NULL AND updated_at > ? ORDER BY recipient_id ASC" + +foldQueueRecs_ :: + Monoid a => + (DB.Connection -> (Int, a) -> ((Int, a) -> (RecipientId, QueueRec) -> IO (Int, a)) -> IO (Int, a)) -> + Bool -> + PostgresQueueStore q -> + ((RecipientId, QueueRec) -> IO a) -> + IO a +foldQueueRecs_ foldRecs tty st f = do (n, r) <- withTransaction (dbStore st) $ \db -> foldRecs db (0 :: Int, mempty) $ \(i, acc) qr -> do r <- f qr @@ -557,13 +614,6 @@ foldQueueRecs tty withData st skipOld_ f = do when tty $ putStrLn $ progress n pure r where - foldRecs db acc f' = case skipOld_ of - Nothing - | withData -> DB.fold_ db (queueRecQueryWithData <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRecWithData - | otherwise -> DB.fold_ db (queueRecQuery <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRec - Just old - | withData -> DB.fold db (queueRecQueryWithData <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRecWithData - | otherwise -> DB.fold db (queueRecQuery <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRec progress i = "Processed: " <> show i <> " records" queueRecQuery :: Query @@ -627,13 +677,14 @@ queueRecToText (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey, (linkId_, queueData_) = queueDataColumns queueData nullable :: ToField a => Maybe a -> Builder nullable = maybe mempty (renderField . toField) - renderField :: Action -> Builder - renderField = \case - Plain bld -> bld - Escape s -> BB.byteString s - EscapeByteA s -> BB.string7 "\\x" <> BB.byteStringHex s - EscapeIdentifier s -> BB.byteString s -- Not used in COPY data - Many as -> mconcat (map renderField as) + +renderField :: Action -> Builder +renderField = \case + Plain bld -> bld + Escape s -> BB.byteString s + EscapeByteA s -> BB.string7 "\\x" <> BB.byteStringHex s + EscapeIdentifier s -> BB.byteString s -- Not used in COPY data + Many as -> mconcat (map renderField as) queueDataColumns :: Maybe (LinkId, QueueLinkData) -> (Maybe LinkId, Maybe QueueLinkData) queueDataColumns = \case diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs index e8469d1cc..be14202c6 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/Migrations.hs @@ -14,7 +14,8 @@ serverSchemaMigrations = [ ("20250207_initial", m20250207_initial, Nothing), ("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index), ("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links), - ("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs) + ("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs), + ("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages) ] -- | The list of migrations in ascending order by date @@ -159,3 +160,299 @@ DROP INDEX idx_services_service_role; DROP TABLE services; |] + +m20250903_store_messages :: Text +m20250903_store_messages = + T.pack + [r| +CREATE TABLE messages( + message_id BIGINT NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + recipient_id BYTEA NOT NULL REFERENCES msg_queues ON DELETE CASCADE ON UPDATE RESTRICT, + msg_id BYTEA NOT NULL, + msg_ts BIGINT NOT NULL, + msg_quota BOOLEAN NOT NULL, + msg_ntf_flag BOOLEAN NOT NULL, + msg_body BYTEA NOT NULL +); + +ALTER TABLE msg_queues + ADD COLUMN msg_can_write BOOLEAN NOT NULL DEFAULT TRUE, + ADD COLUMN msg_queue_expire BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN msg_queue_size BIGINT NOT NULL DEFAULT 0; + +CREATE INDEX idx_messages_recipient_id_message_id ON messages (recipient_id, message_id); +CREATE INDEX idx_messages_recipient_id_msg_ts on messages(recipient_id, msg_ts); +CREATE INDEX idx_messages_recipient_id_msg_quota on messages(recipient_id, msg_quota); + +DROP INDEX idx_msg_queues_updated_at; +CREATE INDEX idx_msg_queues_updated_at_recipient_id ON msg_queues (deleted_at, updated_at, msg_queue_expire, recipient_id); + +CREATE FUNCTION write_message( + p_recipient_id BYTEA, + p_msg_id BYTEA, + p_msg_ts BIGINT, + p_msg_quota BOOLEAN, + p_msg_ntf_flag BOOLEAN, + p_msg_body BYTEA, + p_quota INT +) +RETURNS TABLE (quota_written BOOLEAN, was_empty BOOLEAN) +LANGUAGE plpgsql AS $$ +DECLARE + q_can_write BOOLEAN; + q_size BIGINT; +BEGIN + SELECT msg_can_write, msg_queue_size INTO q_can_write, q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF q_can_write OR q_size = 0 THEN + quota_written := p_msg_quota OR q_size >= p_quota; + was_empty := q_size = 0; + + INSERT INTO messages(recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body) + VALUES (p_recipient_id, p_msg_id, p_msg_ts, quota_written, p_msg_ntf_flag AND NOT quota_written, CASE WHEN quota_written THEN '' :: BYTEA ELSE p_msg_body END); + + UPDATE msg_queues + SET msg_can_write = NOT quota_written, + msg_queue_expire = TRUE, + msg_queue_size = msg_queue_size + 1 + WHERE recipient_id = p_recipient_id; + + RETURN QUERY VALUES (quota_written, was_empty); + END IF; +END; +$$; + +CREATE FUNCTION try_del_msg(p_recipient_id BYTEA, p_msg_id BYTEA) +RETURNS TABLE (r_msg_id BYTEA, r_msg_ts BIGINT, r_msg_quota BOOLEAN, r_msg_ntf_flag BOOLEAN, r_msg_body BYTEA) +LANGUAGE plpgsql AS $$ +DECLARE + q_size BIGINT; + msg RECORD; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF NOT FOUND THEN + RETURN; + END IF; + + SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF NOT FOUND THEN + IF q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + RETURN; + END IF; + + IF msg.msg_id = p_msg_id THEN + DELETE FROM messages WHERE message_id = msg.message_id; + IF FOUND THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= 1, + msg_queue_expire = msg_queue_size > 1, + msg_queue_size = GREATEST(msg_queue_size - 1, 0) + WHERE recipient_id = p_recipient_id; + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; + END IF; +END; +$$; + +CREATE FUNCTION try_del_peek_msg(p_recipient_id BYTEA, p_msg_id BYTEA) +RETURNS TABLE (r_msg_id BYTEA, r_msg_ts BIGINT, r_msg_quota BOOLEAN, r_msg_ntf_flag BOOLEAN, r_msg_body BYTEA) +LANGUAGE plpgsql AS $$ +DECLARE + q_size BIGINT; + msg RECORD; + msg_deleted BOOLEAN; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF NOT FOUND THEN + RETURN; + END IF; + + SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF NOT FOUND THEN + IF q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + RETURN; + END IF; + + IF msg.msg_id = p_msg_id THEN + DELETE FROM messages WHERE message_id = msg.message_id; + + msg_deleted := FOUND; + IF msg_deleted THEN + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; + + SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF FOUND THEN + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + IF msg_deleted THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= 1, + msg_queue_expire = msg_queue_size > 1, + msg_queue_size = GREATEST(msg_queue_size - 1, 0) + WHERE recipient_id = p_recipient_id; + END IF; + ELSIF msg_deleted OR q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + ELSE + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; +END; +$$; + +CREATE FUNCTION delete_expired_msgs(p_recipient_id BYTEA, p_old_ts BIGINT) RETURNS BIGINT +LANGUAGE plpgsql AS $$ +DECLARE + q_size BIGINT; + keep_min_id BIGINT; + del_count BIGINT; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE SKIP LOCKED; + + IF NOT FOUND OR q_size = 0 THEN + RETURN 0; + END IF; + + SELECT MIN(message_id) INTO keep_min_id + FROM messages WHERE recipient_id = p_recipient_id AND msg_ts >= p_old_ts AND msg_quota = FALSE; + + IF keep_min_id IS NULL THEN + DELETE FROM messages WHERE recipient_id = p_recipient_id AND msg_quota = FALSE; + ELSE + DELETE FROM messages WHERE recipient_id = p_recipient_id AND message_id < keep_min_id AND msg_quota = FALSE; + END IF; + + GET DIAGNOSTICS del_count = ROW_COUNT; + IF del_count > 0 THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= del_count, + msg_queue_expire = msg_queue_size > del_count AND keep_min_id IS NOT NULL, + msg_queue_size = GREATEST(msg_queue_size - del_count, 0) + WHERE recipient_id = p_recipient_id; + END IF; + RETURN del_count; +END; +$$; + +CREATE PROCEDURE expire_old_messages( + p_old_queue BIGINT, + p_old_ts BIGINT, + batch_size INT, + OUT r_expired_msgs_count BIGINT, + OUT r_stored_msgs_count BIGINT, + OUT r_stored_queues BIGINT +) +LANGUAGE plpgsql AS $$ +DECLARE + rids BYTEA[]; + rid BYTEA; + last_rid BYTEA := '\x'; + del_count BIGINT; + total_deleted BIGINT := 0; +BEGIN + LOOP + SELECT array_agg(recipient_id) + INTO rids + FROM ( + SELECT recipient_id + FROM msg_queues + WHERE deleted_at IS NULL + AND updated_at > p_old_queue + AND msg_queue_expire = TRUE + AND recipient_id > last_rid + ORDER BY recipient_id ASC + LIMIT batch_size + ) qs; + + EXIT WHEN rids IS NULL OR cardinality(rids) = 0; + + FOREACH rid IN ARRAY rids + LOOP + BEGIN + del_count := delete_expired_msgs(rid, p_old_ts); + total_deleted := total_deleted + del_count; + EXCEPTION WHEN OTHERS THEN + RAISE WARNING 'STORE, expire_old_messages, error expiring queue %: %', encode(rid, 'base64'), SQLERRM; + CONTINUE; + END; + COMMIT; + END LOOP; + last_rid := rids[cardinality(rids)]; + END LOOP; + + r_expired_msgs_count := total_deleted; + r_stored_msgs_count := (SELECT COUNT(1) FROM messages); + r_stored_queues := (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL); +END; +$$; + |] + +down_m20250903_store_messages :: Text +down_m20250903_store_messages = + T.pack + [r| +DROP FUNCTION write_message; +DROP FUNCTION try_del_msg; +DROP FUNCTION try_del_peek_msg; +DROP FUNCTION delete_expired_msgs; +DROP PROCEDURE expire_old_messages; + +DROP INDEX idx_msg_queues_updated_at_recipient_id; +CREATE INDEX idx_msg_queues_updated_at ON msg_queues (deleted_at, updated_at); + +DROP INDEX idx_messages_recipient_id_message_id; +DROP INDEX idx_messages_recipient_id_msg_ts; +DROP INDEX idx_messages_recipient_id_msg_quota; + +ALTER TABLE msg_queues + DROP COLUMN msg_can_write, + DROP COLUMN msg_queue_expire, + DROP COLUMN msg_queue_size; + +DROP TABLE messages; + |] diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql index 6c0501d8b..433d45473 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql @@ -15,9 +15,273 @@ SET row_security = off; CREATE SCHEMA smp_server; + +CREATE FUNCTION smp_server.delete_expired_msgs(p_recipient_id bytea, p_old_ts bigint) RETURNS bigint + LANGUAGE plpgsql + AS $$ +DECLARE + q_size BIGINT; + keep_min_id BIGINT; + del_count BIGINT; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE SKIP LOCKED; + + IF NOT FOUND OR q_size = 0 THEN + RETURN 0; + END IF; + + SELECT MIN(message_id) INTO keep_min_id + FROM messages WHERE recipient_id = p_recipient_id AND msg_ts >= p_old_ts AND msg_quota = FALSE; + + IF keep_min_id IS NULL THEN + DELETE FROM messages WHERE recipient_id = p_recipient_id AND msg_quota = FALSE; + ELSE + DELETE FROM messages WHERE recipient_id = p_recipient_id AND message_id < keep_min_id AND msg_quota = FALSE; + END IF; + + GET DIAGNOSTICS del_count = ROW_COUNT; + IF del_count > 0 THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= del_count, + msg_queue_expire = msg_queue_size > del_count AND keep_min_id IS NOT NULL, + msg_queue_size = GREATEST(msg_queue_size - del_count, 0) + WHERE recipient_id = p_recipient_id; + END IF; + RETURN del_count; +END; +$$; + + + +CREATE PROCEDURE smp_server.expire_old_messages(IN p_old_queue bigint, IN p_old_ts bigint, IN batch_size integer, OUT r_expired_msgs_count bigint, OUT r_stored_msgs_count bigint, OUT r_stored_queues bigint) + LANGUAGE plpgsql + AS $$ +DECLARE + rids BYTEA[]; + rid BYTEA; + last_rid BYTEA := '\x'; + del_count BIGINT; + total_deleted BIGINT := 0; +BEGIN + LOOP + SELECT array_agg(recipient_id) + INTO rids + FROM ( + SELECT recipient_id + FROM msg_queues + WHERE deleted_at IS NULL + AND updated_at > p_old_queue + AND msg_queue_expire = TRUE + AND recipient_id > last_rid + ORDER BY recipient_id ASC + LIMIT batch_size + ) qs; + + EXIT WHEN rids IS NULL OR cardinality(rids) = 0; + + FOREACH rid IN ARRAY rids + LOOP + BEGIN + del_count := delete_expired_msgs(rid, p_old_ts); + total_deleted := total_deleted + del_count; + EXCEPTION WHEN OTHERS THEN + RAISE WARNING 'STORE, expire_old_messages, error expiring queue %: %', encode(rid, 'base64'), SQLERRM; + CONTINUE; + END; + COMMIT; + END LOOP; + last_rid := rids[cardinality(rids)]; + END LOOP; + + r_expired_msgs_count := total_deleted; + r_stored_msgs_count := (SELECT COUNT(1) FROM messages); + r_stored_queues := (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL); +END; +$$; + + + +CREATE FUNCTION smp_server.try_del_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea) + LANGUAGE plpgsql + AS $$ +DECLARE + q_size BIGINT; + msg RECORD; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF NOT FOUND THEN + RETURN; + END IF; + + SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF NOT FOUND THEN + IF q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + RETURN; + END IF; + + IF msg.msg_id = p_msg_id THEN + DELETE FROM messages WHERE message_id = msg.message_id; + IF FOUND THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= 1, + msg_queue_expire = msg_queue_size > 1, + msg_queue_size = GREATEST(msg_queue_size - 1, 0) + WHERE recipient_id = p_recipient_id; + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; + END IF; +END; +$$; + + + +CREATE FUNCTION smp_server.try_del_peek_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea) + LANGUAGE plpgsql + AS $$ +DECLARE + q_size BIGINT; + msg RECORD; + msg_deleted BOOLEAN; +BEGIN + SELECT msg_queue_size INTO q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF NOT FOUND THEN + RETURN; + END IF; + + SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF NOT FOUND THEN + IF q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + RETURN; + END IF; + + IF msg.msg_id = p_msg_id THEN + DELETE FROM messages WHERE message_id = msg.message_id; + + msg_deleted := FOUND; + IF msg_deleted THEN + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; + + SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body + INTO msg + FROM messages + WHERE recipient_id = p_recipient_id + ORDER BY message_id ASC LIMIT 1; + + IF FOUND THEN + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + IF msg_deleted THEN + UPDATE msg_queues + SET msg_can_write = msg_can_write OR msg_queue_size <= 1, + msg_queue_expire = msg_queue_size > 1, + msg_queue_size = GREATEST(msg_queue_size - 1, 0) + WHERE recipient_id = p_recipient_id; + END IF; + ELSIF msg_deleted OR q_size != 0 THEN + UPDATE msg_queues + SET msg_can_write = TRUE, + msg_queue_expire = FALSE, + msg_queue_size = 0 + WHERE recipient_id = p_recipient_id; + END IF; + ELSE + RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body); + END IF; +END; +$$; + + + +CREATE FUNCTION smp_server.write_message(p_recipient_id bytea, p_msg_id bytea, p_msg_ts bigint, p_msg_quota boolean, p_msg_ntf_flag boolean, p_msg_body bytea, p_quota integer) RETURNS TABLE(quota_written boolean, was_empty boolean) + LANGUAGE plpgsql + AS $$ +DECLARE + q_can_write BOOLEAN; + q_size BIGINT; +BEGIN + SELECT msg_can_write, msg_queue_size INTO q_can_write, q_size + FROM msg_queues + WHERE recipient_id = p_recipient_id AND deleted_at IS NULL + FOR UPDATE; + + IF q_can_write OR q_size = 0 THEN + quota_written := p_msg_quota OR q_size >= p_quota; + was_empty := q_size = 0; + + INSERT INTO messages(recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body) + VALUES (p_recipient_id, p_msg_id, p_msg_ts, quota_written, p_msg_ntf_flag AND NOT quota_written, CASE WHEN quota_written THEN '' :: BYTEA ELSE p_msg_body END); + + UPDATE msg_queues + SET msg_can_write = NOT quota_written, + msg_queue_expire = TRUE, + msg_queue_size = msg_queue_size + 1 + WHERE recipient_id = p_recipient_id; + + RETURN QUERY VALUES (quota_written, was_empty); + END IF; +END; +$$; + + SET default_table_access_method = heap; +CREATE TABLE smp_server.messages ( + message_id bigint NOT NULL, + recipient_id bytea NOT NULL, + msg_id bytea NOT NULL, + msg_ts bigint NOT NULL, + msg_quota boolean NOT NULL, + msg_ntf_flag boolean NOT NULL, + msg_body bytea NOT NULL +); + + + +ALTER TABLE smp_server.messages ALTER COLUMN message_id ADD GENERATED ALWAYS AS IDENTITY ( + SEQUENCE NAME smp_server.messages_message_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + + + CREATE TABLE smp_server.migrations ( name text NOT NULL, ts timestamp without time zone NOT NULL, @@ -43,7 +307,10 @@ CREATE TABLE smp_server.msg_queues ( fixed_data bytea, user_data bytea, rcv_service_id bytea, - ntf_service_id bytea + ntf_service_id bytea, + msg_can_write boolean DEFAULT true NOT NULL, + msg_queue_expire boolean DEFAULT false NOT NULL, + msg_queue_size bigint DEFAULT 0 NOT NULL ); @@ -58,6 +325,11 @@ CREATE TABLE smp_server.services ( +ALTER TABLE ONLY smp_server.messages + ADD CONSTRAINT messages_pkey PRIMARY KEY (message_id); + + + ALTER TABLE ONLY smp_server.migrations ADD CONSTRAINT migrations_pkey PRIMARY KEY (name); @@ -78,6 +350,18 @@ ALTER TABLE ONLY smp_server.services +CREATE INDEX idx_messages_recipient_id_message_id ON smp_server.messages USING btree (recipient_id, message_id); + + + +CREATE INDEX idx_messages_recipient_id_msg_quota ON smp_server.messages USING btree (recipient_id, msg_quota); + + + +CREATE INDEX idx_messages_recipient_id_msg_ts ON smp_server.messages USING btree (recipient_id, msg_ts); + + + CREATE UNIQUE INDEX idx_msg_queues_link_id ON smp_server.msg_queues USING btree (link_id); @@ -98,7 +382,7 @@ CREATE UNIQUE INDEX idx_msg_queues_sender_id ON smp_server.msg_queues USING btre -CREATE INDEX idx_msg_queues_updated_at ON smp_server.msg_queues USING btree (deleted_at, updated_at); +CREATE INDEX idx_msg_queues_updated_at_recipient_id ON smp_server.msg_queues USING btree (deleted_at, updated_at, msg_queue_expire, recipient_id); @@ -106,6 +390,11 @@ CREATE INDEX idx_services_service_role ON smp_server.services USING btree (servi +ALTER TABLE ONLY smp_server.messages + ADD CONSTRAINT messages_recipient_id_fkey FOREIGN KEY (recipient_id) REFERENCES smp_server.msg_queues(recipient_id) ON UPDATE RESTRICT ON DELETE CASCADE; + + + ALTER TABLE ONLY smp_server.msg_queues ADD CONSTRAINT msg_queues_ntf_service_id_fkey FOREIGN KEY (ntf_service_id) REFERENCES smp_server.services(service_id) ON UPDATE RESTRICT ON DELETE SET NULL; diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 4dd6240a8..515a0ee77 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -114,7 +114,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0 addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q) - addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData} = do + addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData, rcvServiceId} = do sq <- mkQ rId qr add sq $>> withLog "addStoreQueue" st (\s -> logCreateQueue s rId qr) $> Right sq where @@ -122,8 +122,11 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where add q = atomically $ ifM hasId (pure $ Left DUPLICATE_) $ Right () <$ do TM.insert rId q queues TM.insert sId rId senders - forM_ notifier $ \NtfCreds {notifierId} -> TM.insert notifierId rId notifiers + forM_ notifier $ \NtfCreds {notifierId = nId, ntfServiceId} -> do + TM.insert nId rId notifiers + mapM_ (addServiceQueue st serviceNtfQueues nId) ntfServiceId forM_ queueData $ \(lnkId, _) -> TM.insert lnkId rId links + mapM_ (addServiceQueue st serviceRcvQueues rId) rcvServiceId hasId = anyM [TM.member rId queues, TM.member sId senders, hasNotifier, hasLink] hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData @@ -225,7 +228,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where deleteQueueNotifier :: STMQueueStore q -> q -> IO (Either ErrorType (Maybe NtfCreds)) deleteQueueNotifier st sq = withQueueRec qr delete - $>>= \nc_ -> nc_ <$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq) + $>>= (<$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq)) where qr = queueRec sq delete q = forM (notifier q) $ \nc -> do @@ -261,11 +264,10 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where | changed = q <$$ withLog "updateQueueTime" st (\sl -> logUpdateQueueTime sl (recipientId sq) t) | otherwise = pure $ Right q - deleteStoreQueue :: STMQueueStore q -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) + deleteStoreQueue :: STMQueueStore q -> q -> IO (Either ErrorType QueueRec) deleteStoreQueue st sq = withQueueRec qr delete - $>>= \q -> withLog "deleteStoreQueue" st (`logDeleteQueue` rId) - >>= mapM (\_ -> (q,) <$> atomically (swapTVar (msgQueue sq) Nothing)) + $>>= (<$$ withLog "deleteStoreQueue" st (`logDeleteQueue` rId)) where rId = recipientId sq qr = queueRec sq diff --git a/src/Simplex/Messaging/Server/QueueStore/Types.hs b/src/Simplex/Messaging/Server/QueueStore/Types.hs index 55be4d21d..ee155cf91 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Types.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Types.hs @@ -17,10 +17,8 @@ import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.TMap (TMap) class StoreQueueClass q where - type MsgQueue q = mq | mq -> q recipientId :: q -> RecipientId queueRec :: q -> TVar (Maybe QueueRec) - msgQueue :: q -> TVar (Maybe (MsgQueue q)) withQueueLock :: q -> Text -> IO a -> IO a class StoreQueueClass q => QueueStoreClass q s where @@ -44,7 +42,7 @@ class StoreQueueClass q => QueueStoreClass q s where blockQueue :: s -> q -> BlockingInfo -> IO (Either ErrorType ()) unblockQueue :: s -> q -> IO (Either ErrorType ()) updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec) - deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q))) + deleteStoreQueue :: s -> q -> IO (Either ErrorType QueueRec) getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId) setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ()) getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)])) diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index 1dc2f56e6..ee08ebc93 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -30,12 +30,14 @@ where import Control.Applicative (optional, (<|>)) import Control.Logger.Simple (logError) +import Control.Monad import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAsciiLower, isDigit, isHexDigit) import Data.Default (def) +import Data.Functor (($>)) import Data.IORef import Data.IP import Data.List.NonEmpty (NonEmpty (..)) @@ -58,7 +60,7 @@ import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.Shared -import Simplex.Messaging.Util (bshow, catchAll, tshow, (<$?>)) +import Simplex.Messaging.Util (bshow, catchAll, catchAll_, tshow, (<$?>)) import System.IO.Error import Text.Read (readMaybe) import UnliftIO.Exception (IOException) @@ -156,6 +158,11 @@ clientTransportConfig TransportClientConfig {logTLSErrors} = runTransportClient :: Transport c => TransportClientConfig -> Maybe SocksCredentials -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c 'TClient -> IO a) -> IO a runTransportClient = runTLSTransportClient defaultSupportedParams Nothing +data ConnectionHandle c + = CHSocket Socket + | CHContext T.Context + | CHTransport (c 'TClient) + runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe SocksCredentials -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c 'TClient -> IO a) -> IO a runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, clientALPN, useSNI} socksCreds host port keyHash client = do serverCert <- newEmptyTMVarIO @@ -165,17 +172,22 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, connectTCP = case socksProxy of Just proxy -> connectSocksClient proxy socksCreds (hostAddr host) _ -> connectTCPClient hostName - c <- do - sock <- connectTCP port - mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e) + h <- newIORef Nothing + let set hc = (>>= \c -> writeIORef h (Just $ hc c) $> c) + E.bracket (set CHSocket $ connectTCP port) (\_ -> closeConn h) $ \sock -> do + mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive " <> tshow e) let tCfg = clientTransportConfig cfg -- No TLS timeout to avoid failing connections via SOCKS - tls <- connectTLS (Just hostName) tCfg clientParams sock - chain <- takePeerCertChain serverCert `E.onException` closeTLS tls + tls <- set CHContext $ connectTLS (Just hostName) tCfg clientParams sock + chain <- takePeerCertChain serverCert sent <- readIORef clientCredsSent - getTransportConnection tCfg sent chain tls - client c `E.finally` closeConnection c + client =<< set CHTransport (getTransportConnection tCfg sent chain tls) where + closeConn = readIORef >=> mapM_ (\c -> E.uninterruptibleMask_ $ closeConn_ c `catchAll_` pure ()) + closeConn_ = \case + CHSocket sock -> close sock + CHContext tls -> closeTLS tls + CHTransport c -> closeConnection c hostAddr = \case THIPv4 addr -> SocksAddrIPV4 $ tupleToHostAddress addr THIPv6 addr -> SocksAddrIPV6 addr @@ -199,10 +211,11 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err E.try (open addr) >>= either (`tryOpen` as) pure open :: AddrInfo -> IO Socket - open addr = do - sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) - connect sock $ addrAddress addr - pure sock + open addr = + E.bracketOnError + (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)) + close + (\sock -> connect sock (addrAddress addr) $> sock) defaultSMPPort :: PortNumber defaultSMPPort = 5223 diff --git a/src/Simplex/Messaging/Transport/HTTP2/Client.hs b/src/Simplex/Messaging/Transport/HTTP2/Client.hs index bb3c2b3ac..91a8bf0e5 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Client.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Client.hs @@ -27,6 +27,7 @@ import qualified Network.TLS as T import Numeric.Natural (Natural) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Protocol (NetworkError (..), toNetworkError) import Simplex.Messaging.Transport (ALPN, STransportPeer (..), SessionId, TLS (tlsALPN, tlsPeerCert, tlsUniq), TransportPeer (..), TransportPeerI (..), getServerVerifyKey) import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), defaultTcpConnectTimeout, runTLSTransportClient) import Simplex.Messaging.Transport.HTTP2 @@ -89,7 +90,7 @@ defaultHTTP2ClientConfig = suportedTLSParams = http2TLSParams } -data HTTP2ClientError = HCResponseTimeout | HCNetworkError | HCIOError IOException +data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError IOException deriving (Show) getHTTP2Client :: HostName -> ServiceName -> Maybe XS.CertificateStore -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client) @@ -121,12 +122,15 @@ getVerifiedHTTP2ClientWith config host port disconnected setup = runClient :: HClient -> IO (Either HTTP2ClientError HTTP2Client) runClient c = do cVar <- newEmptyTMVarIO - action <- async $ setup (client c cVar) `E.finally` atomically (putTMVar cVar $ Left HCNetworkError) + action <- + async $ setup (client c cVar) `E.catch` \e -> do + atomically $ putTMVar cVar $ Left $ HCNetworkError $ toNetworkError e + E.throwIO e c_ <- connTimeout config `timeout` atomically (takeTMVar cVar) case c_ of Just (Right c') -> pure $ Right c' {action = Just action} Just (Left e) -> pure $ Left e - Nothing -> cancel action $> Left HCNetworkError + Nothing -> cancel action $> Left (HCNetworkError NETimeoutError) client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> TLS p -> H.Client HTTP2Response client c cVar tls sendReq = do @@ -176,7 +180,7 @@ sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq reqTimeout `timeout` try (sendReq req process) >>= \case Just (Right r) -> pure $ Right r Just (Left e) -> disconnected $> Left (HCIOError e) - Nothing -> pure $ Left HCNetworkError + Nothing -> pure $ Left HCResponseTimeout where process r = do respBody <- getHTTP2Body r $ bodyHeadSize config diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 1a7dedef5..57fb11c21 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -171,28 +171,30 @@ catchAll_ :: IO a -> IO a -> IO a catchAll_ a = catchAll a . const {-# INLINE catchAll_ #-} -tryAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m (Either e a) -tryAllErrors err action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . err) +class Show e => AnyError e where fromSomeException :: E.SomeException -> e + +tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a) +tryAllErrors action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . fromSomeException) {-# INLINE tryAllErrors #-} -tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a) -tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err) +tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a) +tryAllErrors' action = runExceptT action `UE.catch` (pure . Left . fromSomeException) {-# INLINE tryAllErrors' #-} -catchAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a -catchAllErrors err action handler = tryAllErrors err action >>= either handler pure +catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a +catchAllErrors action handler = tryAllErrors action >>= either handler pure {-# INLINE catchAllErrors #-} -catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a -catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure +catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a +catchAllErrors' action handler = tryAllErrors' action >>= either handler pure {-# INLINE catchAllErrors' #-} -catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a -catchThrow action err = catchAllErrors err action throwE +catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a +action `catchThrow` err = ExceptT $ runExceptT action `UE.catch` (pure . Left . err) {-# INLINE catchThrow #-} -allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a -allFinally err action final = tryAllErrors err action >>= \r -> final >> except r +allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a +allFinally action final = tryAllErrors action >>= \r -> final >> except r {-# INLINE allFinally #-} eitherToMaybe :: Either a b -> Maybe b @@ -209,17 +211,25 @@ firstRow f e a = second f . listToEither e <$> a maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b) maybeFirstRow f q = fmap f . listToMaybe <$> q +maybeFirstRow' :: Functor f => b -> (a -> b) -> f [a] -> f b +maybeFirstRow' def f q = maybe def f . listToMaybe <$> q + firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b) firstRow' f e a = (f <=< listToEither e) <$> a groupOn :: Eq k => (a -> k) -> [a] -> [[a]] groupOn = groupBy . eqOn - where - -- it is equivalent to groupBy ((==) `on` f), - -- but it redefines `on` to avoid duplicate computation for most values. - -- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn - -- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` - eqOn f x = let fx = f x in \y -> fx == f y + +groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a] +groupOn' = L.groupBy . eqOn + +-- it is equivalent to groupBy ((==) `on` f), +-- but it redefines `on` to avoid duplicate computation for most values. +-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn +-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` +eqOn :: Eq k => (a -> k) -> a -> a -> Bool +eqOn f x = let fx = f x in \y -> fx == f y +{-# INLINE eqOn #-} groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]] groupAllOn f = groupOn f . sortOn f diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index bde72fb23..a9970c273 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -306,14 +306,8 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca, atomically $ takeTMVar endSession logDebug "Session ended" -catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a -catchRCError = catchAllErrors $ \e -> case fromException e of - Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity - _ -> RCEException $ show e -{-# INLINE catchRCError #-} - putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a -a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e +a `putRCError` r = a `catchAllErrors` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e sendRCPacket :: Encoding a => TLS p -> a -> ExceptT RCErrorType IO () sendRCPacket tls pkt = do @@ -395,7 +389,7 @@ discoverRCCtrl subscribers pairings = pure r where loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a - loop action = action `catchRCError` \e -> logError (tshow e) >> loop action + loop action = action `catchAllErrors` \e -> logError (tshow e) >> loop action findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation) findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index 7b8638e67..93f0c92c7 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -19,6 +19,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Word (Word16) import qualified Data.X509 as X +import qualified Network.TLS as TLS import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding @@ -26,7 +27,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport (TLS, TSbChainKeys, TransportPeer (..)) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (safeDecodeUtf8) +import Simplex.Messaging.Util (AnyError (..), safeDecodeUtf8) import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange) import Simplex.Messaging.Version.Internal import UnliftIO @@ -50,6 +51,12 @@ data RCErrorType | RCESyntax {syntaxErr :: String} deriving (Eq, Show, Exception) +instance AnyError RCErrorType where + fromSomeException e = case fromException e of + Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity + _ -> RCEException $ show e + {-# INLINE fromSomeException #-} + instance StrEncoding RCErrorType where strEncode = \case RCEInternal err -> "INTERNAL" <> text err diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 5340b19a4..a6ee6d7f2 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -90,7 +90,7 @@ import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction) import Simplex.Messaging.Agent.Store.Interface import qualified Simplex.Messaging.Agent.Store.DB as DB -import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationConfirmation (..), MigrationError (..)) import Simplex.Messaging.Client (pattern NRMInteractive, NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (..), defaultClientConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -98,7 +98,7 @@ import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF) -import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), initialSMPClientVersion, srvHostnamesSMPClientVersion, supportedSMPClientVRange) +import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, NetworkError (..), ProtocolServer (..), SubscriptionMode (..), initialSMPClientVersion, srvHostnamesSMPClientVersion, supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), StorePaths (..)) import Simplex.Messaging.Server.Expiration @@ -114,6 +114,7 @@ import Test.Hspec hiding (fit, it) import UnliftIO import Util import XFTPClient (testXFTPServer) + #if defined(dbPostgres) import Fixtures #endif @@ -122,6 +123,7 @@ import qualified Database.PostgreSQL.Simple as PSQL import Simplex.Messaging.Agent.Store (Connection (..), StoredRcvQueue (..), SomeConn (..)) import Simplex.Messaging.Agent.Store.AgentStore (getConn) import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue) +import Simplex.Messaging.Server.MsgStore.Postgres (PostgresQueue) import Simplex.Messaging.Server.MsgStore.Types (QSType (..)) import Simplex.Messaging.Server.QueueStore.Postgres import Simplex.Messaging.Server.QueueStore.Types (QueueStoreClass (..)) @@ -177,7 +179,7 @@ pGet' c skipWarn = do case cmd of CONNECT {} -> pGet c DISCONNECT {} -> pGet c - ERR (BROKER _ NETWORK) -> pGet c + ERR (BROKER _ (NETWORK _)) -> pGet c MWARN {} | skipWarn -> pGet c RFWARN {} | skipWarn -> pGet c SFWARN {} | skipWarn -> pGet c @@ -516,7 +518,7 @@ functionalAPITests ps = do it "should pass without basic auth" $ testSMPServerConnectionTest ps Nothing (noAuthSrv testSMPServer2) `shouldReturn` Nothing let srv1 = testSMPServer2 {keyHash = "1234"} it "should fail with incorrect fingerprint" $ do - testSMPServerConnectionTest ps Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK) + testSMPServerConnectionTest ps Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError) describe "server with password" $ do let auth = Just "abcd" srv = ProtoServerWithAuth testSMPServer2 @@ -1105,7 +1107,7 @@ testAsyncServerOffline ps = withAgentClients2 $ \alice bob -> do (bobId, cReq) <- withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ createConnection alice 1 True SCMInvitation Nothing SMSubscribe -- connection fails - Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe + Left (BROKER _ (NETWORK _)) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe ("", "", DOWN srv conns) <- nGet alice srv `shouldBe` testSMPServer conns `shouldBe` [bobId] @@ -1172,13 +1174,13 @@ testInvitationErrors ps restart = do ("", "", DOWN _ [_]) <- nGet a aId <- runRight $ A.prepareConnectionToJoin b 1 True cReq PQSupportOn -- fails to secure the queue on testPort - BROKER srv NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe + BROKER srv (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe (testPort `isSuffixOf` srv) `shouldBe` True withServer1 ps $ do ("", "", UP _ [_]) <- nGet a let loopSecure = do -- secures the queue on testPort, but fails to create reply queue on testPort2 - BROKER srv2 NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe + BROKER srv2 (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe unless (testPort2 `isSuffixOf` srv2) $ putStrLn "retrying secure" >> threadDelay 200000 >> loopSecure loopSecure ("", "", DOWN _ [_]) <- nGet a @@ -1186,7 +1188,7 @@ testInvitationErrors ps restart = do threadDelay 200000 let loopCreate = do -- creates the reply queue on testPort2, but fails to send it to testPort - BROKER srv' NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe + BROKER srv' (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe unless (testPort `isSuffixOf` srv') $ putStrLn "retrying create" >> threadDelay 200000 >> loopCreate loopCreate restartAgentB restart b [aId] @@ -1242,12 +1244,12 @@ testContactErrors ps restart = do ("", "", DOWN _ [_]) <- nGet a aId <- runRight $ A.prepareConnectionToJoin b 1 True cReq PQSupportOn -- fails to create queue on testPort2 - BROKER srv2 NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe + BROKER srv2 (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe (testPort2 `isSuffixOf` srv2) `shouldBe` True b' <- restartAgentB restart b [aId] let loopCreate2 = do -- creates the reply queue on testPort2, but fails to send invitation to testPort - BROKER srv' NETWORK <- runLeft $ A.joinConnection b' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe + BROKER srv' (NETWORK _) <- runLeft $ A.joinConnection b' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe unless (testPort `isSuffixOf` srv') $ putStrLn "retrying create 2" >> threadDelay 200000 >> loopCreate2 b'' <- withServer2 ps $ do loopCreate2 @@ -1270,7 +1272,7 @@ testContactErrors ps restart = do ("", "", UP _ [_]) <- nGet b'' let loopSecure = do -- secures the queue on testPort2, but fails to create reply queue on testPort - BROKER srv NETWORK <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe + BROKER srv (NETWORK _) <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe unless (testPort `isSuffixOf` srv) $ putStrLn "retrying secure" >> threadDelay 200000 >> loopSecure loopSecure ("", "", DOWN _ [_]) <- nGet b'' @@ -1278,7 +1280,7 @@ testContactErrors ps restart = do ("", "", UP _ [_]) <- nGet a let loopCreate = do -- creates the reply queue on testPort, but fails to send confirmation to testPort2 - BROKER srv2' NETWORK <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe + BROKER srv2' (NETWORK _) <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe unless (testPort2 `isSuffixOf` srv2') $ putStrLn "retrying create" >> threadDelay 200000 >> loopCreate loopCreate restartAgentA restart a [contactId, bId] @@ -1524,20 +1526,29 @@ testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do A.createConnection a NRMInteractive 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate -- make it an "old" queue let updateStoreLog f = replaceSubstringInFile f " queue_mode=C" "" - () <- case testServerStoreConfig msType of - ASSCfg _ _ (SSCMemory (Just StorePaths {storeLogFile})) -> updateStoreLog storeLogFile - ASSCfg _ _ (SSCMemoryJournal {storeLogFile}) -> updateStoreLog storeLogFile - ASSCfg _ _ (SSCDatabaseJournal {storeCfg}) -> do #if defined(dbServerPostgres) - let AgentClient {agentEnv = Env {store}} = a - Right (SomeConn _ (ContactConnection _ RcvQueue {rcvId})) <- withTransaction store (`getConn` contactId) - st :: PostgresQueueStore (JournalQueue 'QSPostgres) <- newQueueStore @(JournalQueue 'QSPostgres) storeCfg - Right 1 <- runExceptT $ withDB' "test" st $ \db -> PSQL.execute db "UPDATE msg_queues SET queue_mode = ? WHERE recipient_id = ?" (Nothing :: Maybe QueueMode, rcvId) - closeQueueStore @(JournalQueue 'QSPostgres) st -#else - error "no dbServerPostgres flag" + updateDbStore :: PostgresQueueStore s -> IO () + updateDbStore st = do + let AgentClient {agentEnv = Env {store}} = a + Right (SomeConn _ (ContactConnection _ RcvQueue {rcvId})) <- withTransaction store (`getConn` contactId) + Right 1 <- runExceptT $ withDB' "test" st $ \db -> PSQL.execute db "UPDATE msg_queues SET queue_mode = ? WHERE recipient_id = ?" (Nothing :: Maybe QueueMode, rcvId) + pure () +#endif + () <- case testServerStoreConfig msType of + ASSCfg _ _ (SSCMemory sp_) -> mapM_ (\StorePaths {storeLogFile} -> updateStoreLog storeLogFile) sp_ + ASSCfg _ _ SSCMemoryJournal {storeLogFile} -> updateStoreLog storeLogFile +#if defined(dbServerPostgres) + ASSCfg _ _ SSCDatabaseJournal {storeCfg} -> do + st :: PostgresQueueStore (JournalQueue 'QSPostgres) <- newQueueStore @(JournalQueue 'QSPostgres) (storeCfg, True) + updateDbStore st + closeQueueStore @(JournalQueue 'QSPostgres) st + ASSCfg _ _ (SSCDatabase storeCfg) -> do + st :: PostgresQueueStore PostgresQueue <- newQueueStore @PostgresQueue (storeCfg, False) + updateDbStore st + closeQueueStore @PostgresQueue st +#else + ASSCfg _ _ SSCDatabaseJournal {} -> error "no dbServerPostgres flag" #endif - _ -> pure () withSmpServer ps $ do let userData = UserLinkData "some user data" @@ -1743,7 +1754,7 @@ testDuplicateMessage ps = do -- commenting two lines below and uncommenting further two lines would also runRight_, -- it is the scenario tested above, when the message was not acknowledged by the user threadDelay 200000 - Left (BROKER _ NETWORK) <- runExceptT $ ackMessage bob1 aliceId 3 Nothing + Left (BROKER _ (NETWORK _)) <- runExceptT $ ackMessage bob1 aliceId 3 Nothing disposeAgentClient alice disposeAgentClient bob1 @@ -1827,8 +1838,8 @@ testDeliveryAfterSubscriptionError ps = do pure (aId, bId) withAgentClients2 $ \a b -> do - Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection a bId - Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection b aId + Left (BROKER _ (NETWORK _)) <- runExceptT $ subscribeConnection a bId + Left (BROKER _ (NETWORK _)) <- runExceptT $ subscribeConnection b aId withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do withUP a bId $ \case ("", c, SENT 2) -> c == bId; _ -> False withUP b aId $ \case ("", c, Msg "hello") -> c == aId; _ -> False @@ -1872,7 +1883,7 @@ testExpireMessage ps = 2 <- runRight $ sendMessage a bId SMP.noMsgFlags "1" threadDelay 1500000 3 <- runRight $ sendMessage a bId SMP.noMsgFlags "2" -- this won't expire - get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False + get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && networkOrTimeoutError e; _ -> False withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do withUP a bId $ \case ("", _, SENT 3) -> True; _ -> False withUP b aId $ \case ("", _, MsgErr 2 (MsgSkipped 2 2) "2") -> True; _ -> False @@ -1891,8 +1902,8 @@ testExpireManyMessages ps = 4 <- sendMessage a bId SMP.noMsgFlags "3" liftIO $ threadDelay 2000000 5 <- sendMessage a bId SMP.noMsgFlags "4" -- this won't expire - get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False - let expected c e = bId == c && (e == TIMEOUT || e == NETWORK) + get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && networkOrTimeoutError e; _ -> False + let expected c e = bId == c && networkOrTimeoutError e get a >>= \case ("", c, MERR 3 (BROKER _ e)) -> do liftIO $ expected c e `shouldBe` True @@ -2633,7 +2644,7 @@ testDeleteConnectionAsync ps = runRight_ $ do deleteConnectionsAsync a False connIds nGet a =##> \case ("", "", DOWN {}) -> True; _ -> False - let delOk = \case (c, _, _, Just (BROKER _ e)) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False + let delOk = \case (c, _, _, Just (BROKER _ e)) -> c `elem` connIds && networkOrTimeoutError e; _ -> False get a =##> \case ("", "", DEL_RCVQS rs) -> length rs == 3 && all delOk rs; _ -> False get a =##> \case ("", "", DEL_CONNS cs) -> length cs == 3 && all (`elem` connIds) cs; _ -> False liftIO $ noMessages a "nothing else should be delivered to alice" @@ -2691,7 +2702,7 @@ testWaitDelivery ps = 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" deleteConnectionsAsync alice True [bobId] - get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" @@ -2748,7 +2759,7 @@ testWaitDeliveryAUTHErr ps = 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" deleteConnectionsAsync alice True [bobId] - get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" @@ -2788,7 +2799,7 @@ testWaitDeliveryTimeout ps = 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" deleteConnectionsAsync alice True [bobId] - get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False get alice =##> \case ("", "", DEL_CONNS [cId]) -> cId == bobId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" @@ -2828,7 +2839,7 @@ testWaitDeliveryTimeout2 ps = 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" deleteConnectionsAsync alice True [bobId] - get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False get alice =##> \case ("", "", DEL_CONNS [cId]) -> cId == bobId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" liftIO $ noMessages bob "nothing else should be delivered to bob" @@ -2849,6 +2860,12 @@ testWaitDeliveryTimeout2 ps = baseId = 1 msgId = subtract baseId +networkOrTimeoutError :: BrokerErrorType -> Bool +networkOrTimeoutError = \case + TIMEOUT -> True + NETWORK _ -> True + _ -> False + testJoinConnectionAsyncReplyErrorV8 :: HasCallStack => (ASrvTransport, AStoreType) -> IO () testJoinConnectionAsyncReplyErrorV8 ps@(t, ASType qsType _) = do let initAgentServersSrv2 = initAgentServers {smp = userServers [testSMPServer2]} @@ -2975,7 +2992,7 @@ testUsersNoServer ps = withAgentClientsCfg2 aCfg agentCfg $ \a b -> do nGet b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False runRight_ $ do deleteUser a auId True - get a =##> \case ("", "", DEL_RCVQS [(c, _, _, Just (BROKER _ e))]) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False + get a =##> \case ("", "", DEL_RCVQS [(c, _, _, Just (BROKER _ e))]) -> c == bId' && networkOrTimeoutError e;; _ -> False get a =##> \case ("", "", DEL_CONNS [c]) -> c == bId'; _ -> False nGet a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False liftIO $ noMessages a "nothing else should be delivered to alice" @@ -3613,13 +3630,13 @@ getSMPAgentClient' clientId cfg' initServers dbPath = do #if defined(dbPostgres) createStore :: String -> IO (Either MigrationError DBStore) -createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) MCError +createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) (MigrationConfig MCError Nothing) insertUser :: DBStore -> IO () insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES") #else createStore :: String -> IO (Either MigrationError DBStore) -createStore dbPath = createAgentStore (DBOpts dbPath "" False True DB.TQOff) MCError +createStore dbPath = createAgentStore (DBOpts dbPath "" False True DB.TQOff) (MigrationConfig MCError Nothing) insertUser :: DBStore -> IO () insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users (user_id) VALUES (1)") @@ -3639,7 +3656,7 @@ testServerMultipleIdentities = exchangeGreetings alice bobId bob aliceId -- this saves queue with second server identity bob' <- liftIO $ do - Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe + Left (BROKER _ (NETWORK _)) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe disposeAgentClient bob threadDelay 250000 getSMPAgentClient' 3 agentCfg initAgentServers testDB2 diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs index 56bf4e128..8245cfd51 100644 --- a/tests/AgentTests/MigrationTests.hs +++ b/tests/AgentTests/MigrationTests.hs @@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do poolSize = 1, createSchema = True } - createDBStore dbOpts migrations confirmMigrations + createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing) cleanup :: Word32 -> IO () cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix) @@ -235,7 +235,7 @@ createStore randSuffix migrations confirmMigrations = do vacuum = True, track = DB.TQOff } - createDBStore dbOpts migrations confirmMigrations + createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing) cleanup :: Word32 -> IO () cleanup randSuffix = removeFile (testDB randSuffix) diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index c7be1a3e2..6a1c5cef9 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -79,7 +79,7 @@ import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Store.Postgres (closeNtfDbStore, newNtfDbStore, withDB') import Simplex.Messaging.Notifications.Types (NtfTknAction (..), NtfToken (..)) import Simplex.Messaging.Parsers (parseAll) -import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NMsgMeta (..), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..)) +import Simplex.Messaging.Protocol (ErrorType (AUTH), NetworkError (..), MsgFlags (MsgFlags), NMsgMeta (..), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..)) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..)) import Simplex.Messaging.Transport (ASrvTransport) @@ -137,7 +137,7 @@ notificationTests ps@(t, _) = do it "should pass" $ testRunNTFServerTests t testNtfServer `shouldReturn` Nothing let srv1 = testNtfServer {keyHash = "1234"} it "should fail with incorrect fingerprint" $ do - testRunNTFServerTests t srv1 `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK) + testRunNTFServerTests t srv1 `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError) describe "Managing notification subscriptions" $ do describe "should create notification subscription for existing connection" $ testNtfMatrix ps testNotificationSubscriptionExistingConnection @@ -321,7 +321,7 @@ testNtfTokenServerRestartReverify t apns = do runRight_ $ do verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" - Left (BROKER _ NETWORK) <- tryE $ verifyNtfToken a tkn nonce verification + Left (BROKER _ (NETWORK _)) <- tryE $ verifyNtfToken a tkn nonce verification pure () threadDelay 1500000 withAgent 2 agentCfg initAgentServers testDB $ \a' -> @@ -478,7 +478,7 @@ testNtfTokenChangeServers t apns = tkn2 <- registerTestToken a "xyzw" NMInstant apns getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort -- not yet changed deleteNtfToken a tkn2 -- force server switch - Left BROKER {brokerErr = NETWORK} <- tryError $ registerTestToken a "qwer" NMInstant apns -- ok, it's down for now + Left BROKER {brokerErr = (NETWORK _)} <- tryError $ registerTestToken a "qwer" NMInstant apns -- ok, it's down for now getTestNtfTokenPort a >>= \port2 -> liftIO $ port2 `shouldBe` ntfTestPort2 -- but the token got updated killThread ntf withNtfServerOn t ntfTestPort2 ntfTestDBCfg2 $ runRight_ $ do diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 056dc5dd3..257c3f90f 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -46,7 +46,7 @@ import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations) import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Store.SQLite.Common (DBStore (..), withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationConfirmation (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Crypto.Ratchet (pattern IKPQOn) @@ -83,7 +83,7 @@ createEncryptedStore key keepKey = do -- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous -- IO operations on multiple similarly named files; error seems to be environment specific r <- randomIO :: IO Word32 - Right st <- createDBStore (DBOpts (testDB <> show r) key keepKey True DB.TQOff) appMigrations MCError + Right st <- createDBStore (DBOpts (testDB <> show r) key keepKey True DB.TQOff) appMigrations (MigrationConfig MCError Nothing) withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1);") pure st diff --git a/tests/AgentTests/SchemaDump.hs b/tests/AgentTests/SchemaDump.hs index b64e2ec81..fdb172883 100644 --- a/tests/AgentTests/SchemaDump.hs +++ b/tests/AgentTests/SchemaDump.hs @@ -14,7 +14,7 @@ import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Agent.Store.SQLite.DB (TrackQueries (..)) import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations -import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) import Simplex.Messaging.Util (ifM) import System.Directory (doesFileExist, removeFile) import System.Process (readCreateProcess, shell) @@ -51,7 +51,7 @@ testVerifySchemaDump :: IO () testVerifySchemaDump = do savedSchema <- ifM (doesFileExist appSchema) (readFile appSchema) (pure "") savedSchema `deepseq` pure () - void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCConsole + void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCConsole Nothing) getSchema testDB appSchema `shouldReturn` savedSchema removeFile testDB @@ -59,14 +59,14 @@ testVerifyLintFKeyIndexes :: IO () testVerifyLintFKeyIndexes = do savedLint <- ifM (doesFileExist appLint) (readFile appLint) (pure "") savedLint `deepseq` pure () - void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCConsole + void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCConsole Nothing) getLintFKeyIndexes testDB "tests/tmp/agent_lint.sql" `shouldReturn` savedLint removeFile testDB testSchemaMigrations :: IO () testSchemaMigrations = do let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) appMigrations - Right st <- createDBStore (DBOpts testDB "" False True TQOff) noDownMigrations MCError + Right st <- createDBStore (DBOpts testDB "" False True TQOff) noDownMigrations (MigrationConfig MCError Nothing) mapM_ (testDownMigration st) $ drop (length noDownMigrations) appMigrations closeDBStore st removeFile testDB @@ -89,7 +89,7 @@ testSchemaMigrations = do testUsersMigrationNew :: IO () testUsersMigrationNew = do - Right st <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCError + Right st <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCError Nothing) withTransaction' st (`SQL.query_` "SELECT user_id FROM users;") `shouldReturn` ([] :: [Only Int]) closeDBStore st @@ -97,11 +97,11 @@ testUsersMigrationNew = do testUsersMigrationOld :: IO () testUsersMigrationOld = do let beforeUsers = takeWhile (("m20230110_users" /=) . name) appMigrations - Right st <- createDBStore (DBOpts testDB "" False True TQOff) beforeUsers MCError + Right st <- createDBStore (DBOpts testDB "" False True TQOff) beforeUsers (MigrationConfig MCError Nothing) withTransaction' st (`SQL.query_` "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users';") `shouldReturn` ([] :: [Only String]) closeDBStore st - Right st' <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCYesUp + Right st' <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCYesUp Nothing) withTransaction' st' (`SQL.query_` "SELECT user_id FROM users;") `shouldReturn` ([Only (1 :: Int)]) closeDBStore st' diff --git a/tests/CLITests.hs b/tests/CLITests.hs index c0c7c04d2..30d798ca7 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -10,7 +10,6 @@ import AgentTests.FunctionalAPITests (runRight_) import Control.Logger.Simple import Control.Monad import qualified Crypto.PubKey.RSA as RSA -import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as BL import qualified Data.HashMap.Strict as HM import Data.Ini (Ini (..), lookupValue, readIniFile, writeIniFile) @@ -46,6 +45,7 @@ import UnliftIO.Exception (bracket) import Util #if defined(dbServerPostgres) +import qualified Data.ByteString.Char8 as B import qualified Database.PostgreSQL.Simple as PSQL import Database.PostgreSQL.Simple.Types (Query (..)) import NtfClient (ntfTestServerDBConnectInfo, ntfTestServerDBConnstr, ntfTestStoreDBOpts) diff --git a/tests/CoreTests/MsgStoreTests.hs b/tests/CoreTests/MsgStoreTests.hs index d25b00c7c..3961a9ce0 100644 --- a/tests/CoreTests/MsgStoreTests.hs +++ b/tests/CoreTests/MsgStoreTests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} @@ -23,9 +24,9 @@ import Control.Monad import Control.Monad.IO.Class import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) -import qualified Data.ByteString.Base64.URL as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Int (Int64) import Data.List (isPrefixOf, isSuffixOf) import Data.Maybe (fromJust) import Data.Time.Clock (addUTCTime) @@ -33,9 +34,9 @@ import Data.Time.Clock.System (SystemTime (..), getSystemTime) import SMPClient (testStoreLogFile, testStoreMsgsDir, testStoreMsgsDir2, testStoreMsgsFile, testStoreMsgsFile2) import Simplex.Messaging.Crypto (pattern MaxLenBS) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol (EntityId (..), LinkId, Message (..), QueueLinkData, RecipientId, SParty (..), noMsgFlags) +import Simplex.Messaging.Protocol (EntityId (..), ErrorType, LinkId, Message (..), QueueLinkData, RecipientId, SParty (..), noMsgFlags) import Simplex.Messaging.Server (exportMessages, importMessages, printMessageStats) -import Simplex.Messaging.Server.Env.STM (journalMsgStoreDepth, readWriteQueueStore) +import Simplex.Messaging.Server.Env.STM (MsgStore (..), journalMsgStoreDepth, readWriteQueueStore) import Simplex.Messaging.Server.Expiration (ExpirationConfig (..), expireBeforeEpoch) import Simplex.Messaging.Server.MsgStore.Journal import Simplex.Messaging.Server.MsgStore.STM @@ -50,28 +51,54 @@ import System.IO (IOMode (..), withFile) import Test.Hspec hiding (fit, it) import Util +#if defined(dbServerPostgres) +import Database.PostgreSQL.Simple (Only (..)) +import qualified Database.PostgreSQL.Simple as DB +import Simplex.Messaging.Agent.Store.Postgres.Common +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) +import Simplex.Messaging.Server.MsgStore.Postgres +import Simplex.Messaging.Server.QueueStore.Postgres +import SMPClient (postgressBracket, testServerDBConnectInfo, testStoreDBOpts) +#endif + msgStoreTests :: Spec msgStoreTests = do around (withMsgStore testSMTStoreConfig) $ describe "STM message store" someMsgStoreTests around (withMsgStore $ testJournalStoreCfg MQStoreCfg) $ describe "Journal message store" $ do someMsgStoreTests + journalMsgStoreTests it "should export and import journal store" testExportImportStore - describe "queue state" $ do - it "should restore queue state from the last line" testQueueState - it "should recover when message is written and state is not" testMessageState - it "should remove journal files when queue is empty" testRemoveJournals - describe "missing files" $ do - it "should create read file when missing" testReadFileMissing - it "should switch to write file when read file missing" testReadFileMissingSwitch - it "should create write file when missing" testWriteFileMissing - it "should create read file when read and write files are missing" testReadAndWriteFilesMissing +#if defined(dbServerPostgres) + around_ (postgressBracket testServerDBConnectInfo) $ do + around (withMsgStore $ testJournalStoreCfg $ PQStoreCfg testPostgresStoreCfg) $ + describe "Postgres+journal message store" $ do + someMsgStoreTests + journalMsgStoreTests + around (withMsgStore testPostgresStoreConfig) $ + describe "Postgres-only message store" $ do + someMsgStoreTests + it "should correctly update message counts and canWrite flag" testUpdateMessageCounts + it "tryDelPeekMsg (ACK not from NSE) should reset message counts when queue is empty" testResetMessageCounts +#endif describe "Journal message store: queue state backup expiration" $ do it "should remove old queue state backups" testRemoveQueueStateBackups it "should expire messages in idle queues" testExpireIdleQueues where + journalMsgStoreTests :: SpecWith (JournalMsgStore s) + journalMsgStoreTests = do + describe "queue state" $ do + it "should restore queue state from the last line" testQueueState + it "should recover when message is written and state is not" testMessageState + it "should remove journal files when queue is empty" testRemoveJournals + describe "missing files" $ do + it "should create read file when missing" testReadFileMissing + it "should switch to write file when read file missing" testReadFileMissingSwitch + it "should create write file when missing" testWriteFileMissing + it "should create read file when read and write files are missing" testReadAndWriteFilesMissing someMsgStoreTests :: MsgStoreClass s => SpecWith s someMsgStoreTests = do it "should get queue and store/read messages" testGetQueue + it "should write/ack messages" testWriteAckMessages it "should not fail on EOF when changing read journal" testChangeReadJournal -- TODO constrain to STM stores? @@ -96,6 +123,24 @@ testJournalStoreCfg queueStoreCfg = keepMinBackups = 1 } +#if defined(dbServerPostgres) +testPostgresStoreConfig :: PostgresMsgStoreCfg +testPostgresStoreConfig = + PostgresMsgStoreCfg + { queueStoreCfg = testPostgresStoreCfg, + quota = 3 + } + +testPostgresStoreCfg :: PostgresStoreCfg +testPostgresStoreCfg = + PostgresStoreCfg + { dbOpts = testStoreDBOpts, + dbStoreLogPath = Nothing, + confirmMigrations = MCYesUp, + deletedTTL = 86400 + } +#endif + mkMessage :: MonadIO m => ByteString -> m Message mkMessage body = liftIO $ do g <- C.newRandom @@ -138,7 +183,6 @@ testNewQueueRecData g qm queueData = do where rndId = atomically $ EntityId <$> C.randomBytes 24 g --- TODO constrain to STM stores testGetQueue :: MsgStoreClass s => s -> IO () testGetQueue ms = do g <- C.newRandom @@ -181,7 +225,28 @@ testGetQueue ms = do (Nothing, Nothing) <- tryDelPeekMsg ms q mId8 void $ ExceptT $ deleteQueue ms q --- TODO constrain to STM stores +-- TODO [messages] test concurrent writing and reading +testWriteAckMessages :: MsgStoreClass s => s -> IO () +testWriteAckMessages ms = do + g <- C.newRandom + (rId1, qr1) <- testNewQueueRec g QMMessaging + (rId2, qr2) <- testNewQueueRec g QMMessaging + runRight_ $ do + q1 <- ExceptT $ addQueue ms rId1 qr1 + q2 <- ExceptT $ addQueue ms rId2 qr2 + let write q s = writeMsg ms q True =<< mkMessage s + 0 <- deleteExpiredMsgs ms q1 0 -- won't expire anything, used here to mimic message sending with expiration on SEND + Just (Message {msgId = mId1}, True) <- write q1 "message 1" + (Msg "message 1", Nothing) <- tryDelPeekMsg ms q1 mId1 + 0 <- deleteExpiredMsgs ms q2 0 + Just (Message {msgId = mId2}, True) <- write q2 "message 2" + (Msg "message 2", Nothing) <- tryDelPeekMsg ms q2 mId2 + 0 <- deleteExpiredMsgs ms q2 0 + Just (Message {msgId = mId3}, True) <- write q2 "message 3" + (Msg "message 3", Nothing) <- tryDelPeekMsg ms q2 mId3 + void $ ExceptT $ deleteQueue ms q1 + void $ ExceptT $ deleteQueue ms q2 + testChangeReadJournal :: MsgStoreClass s => s -> IO () testChangeReadJournal ms = do g <- C.newRandom @@ -226,9 +291,16 @@ testExportImportStore ms = do pure () length <$> listDirectory (msgQueueDirectory ms rId1) `shouldReturn` 2 length <$> listDirectory (msgQueueDirectory ms rId2) `shouldReturn` 3 - exportMessages False ms testStoreMsgsFile False + exportMessages False (StoreJournal ms) testStoreMsgsFile False closeMsgStore ms closeStoreLog sl + -- export with closed queues and compare + ms2 <- newMsgStore $ testJournalStoreCfg MQStoreCfg + readWriteQueueStore True (mkQueue ms2 True) testStoreLogFile (stmQueueStore ms2) >>= closeStoreLog + exportMessages False (StoreJournal ms2) (testStoreMsgsFile <> ".copy") False + s <- B.readFile testStoreMsgsFile + B.readFile (testStoreMsgsFile <> ".copy") `shouldReturn` s + let cfg = (testJournalStoreCfg MQStoreCfg :: JournalStoreConfig 'QSMemory) {storePath = testStoreMsgsDir2} ms' <- newMsgStore cfg readWriteQueueStore True (mkQueue ms' True) testStoreLogFile (stmQueueStore ms') >>= closeStoreLog @@ -237,21 +309,93 @@ testExportImportStore ms = do printMessageStats "Messages" stats length <$> listDirectory (msgQueueDirectory ms rId1) `shouldReturn` 2 length <$> listDirectory (msgQueueDirectory ms rId2) `shouldReturn` 3 -- 2 message files - exportMessages False ms' testStoreMsgsFile2 False + exportMessages False (StoreJournal ms') testStoreMsgsFile2 False (B.readFile testStoreMsgsFile2 `shouldReturn`) =<< B.readFile (testStoreMsgsFile <> ".bak") stmStore <- newMsgStore testSMTStoreConfig readWriteQueueStore True (mkQueue stmStore True) testStoreLogFile (queueStore stmStore) >>= closeStoreLog MessageStats {storedMsgsCount = 5, expiredMsgsCount = 0, storedQueues = 2} <- importMessages False stmStore testStoreMsgsFile2 Nothing False - exportMessages False stmStore testStoreMsgsFile False + exportMessages False (StoreMemory stmStore) testStoreMsgsFile False (B.sort <$> B.readFile testStoreMsgsFile `shouldReturn`) =<< (B.sort <$> B.readFile (testStoreMsgsFile2 <> ".bak")) +#if defined(dbServerPostgres) +testUpdateMessageCounts :: PostgresMsgStore -> IO () +testUpdateMessageCounts ms = do + g <- C.newRandom + (rId, qr) <- testNewQueueRec g QMMessaging + runRight_ $ do + q <- ExceptT $ addQueue ms rId qr + let write s = writeMsg ms q True =<< mkMessage s + hasSize = checkQueueSize ms + q `hasSize` (0, True, False) + Just (Message {msgId = mId1}, True) <- write "message 1" + q `hasSize` (1, True, True) + Just (Message {msgId = mId2}, False) <- write "message 2" + q `hasSize` (2, True, True) + Just (Message {msgId = mId3}, False) <- write "message 3" + q `hasSize` (3, True, True) + Nothing <- write "message 4" + q `hasSize` (4, False, True) + Msg "message 1" <- tryPeekMsg ms q + q `hasSize` (4, False, True) + Msg "message 1" <- tryDelMsg ms q mId1 + q `hasSize` (3, False, True) + Msg "message 2" <- tryPeekMsg ms q + (Msg "message 2", Msg "message 3") <- tryDelPeekMsg ms q mId2 + q `hasSize` (2, False, True) + (Msg "message 3", Just MessageQuota {msgId = mId4}) <- tryDelPeekMsg ms q mId3 + q `hasSize` (1, False, True) + (Just MessageQuota {}, Nothing) <- tryDelPeekMsg ms q mId4 + q `hasSize` (0, True, False) + +checkQueueSize :: PostgresMsgStore -> PostgresQueue -> (Int64, Bool, Bool) -> ExceptT ErrorType IO () +checkQueueSize ms q (size, canWrt, expire) = liftIO $ do + [(size', canWrt', expire')] <- + withTransaction (dbStore $ queueStore ms) $ \db -> + DB.query db "SELECT msg_queue_size, msg_can_write, msg_queue_expire FROM msg_queues WHERE recipient_id = ?" (Only (recipientId q)) + size' `shouldBe` size + canWrt' `shouldBe` canWrt + expire' `shouldBe` expire + +testResetMessageCounts :: PostgresMsgStore -> IO () +testResetMessageCounts ms = do + g <- C.newRandom + (rId, qr) <- testNewQueueRec g QMMessaging + runRight_ $ do + q <- ExceptT $ addQueue ms rId qr + let write s = writeMsg ms q True =<< mkMessage s + hasSize = checkQueueSize ms + Just (Message {msgId = mId1}, True) <- write "message 1" + Just (Message {msgId = mId2}, False) <- write "message 2" + Just (Message {msgId = mId3}, False) <- write "message 3" + Nothing <- write "message 4" + q `hasSize` (4, False, True) + liftIO $ setIncorrectSize q (10, True) + Nothing <- write "message 5" + q `hasSize` (11, False, True) + (Msg "message 1", Msg "message 2") <- tryDelPeekMsg ms q mId1 + q `hasSize` (10, False, True) + (Msg "message 2", Msg "message 3") <- tryDelPeekMsg ms q mId2 + q `hasSize` (9, False, True) + (Msg "message 3", Just MessageQuota {msgId = mId4}) <- tryDelPeekMsg ms q mId3 + q `hasSize` (8, False, True) + (Just MessageQuota {}, Just MessageQuota {msgId = mId5}) <- tryDelPeekMsg ms q mId4 + q `hasSize` (7, False, True) + (Just MessageQuota {}, Nothing) <- tryDelPeekMsg ms q mId5 + q `hasSize` (0, True, False) -- reset + where + setIncorrectSize :: PostgresQueue -> (Int64, Bool) -> IO () + setIncorrectSize q (size, canWrt) = + void $ withTransaction (dbStore $ queueStore ms) $ \db -> + DB.execute db "UPDATE msg_queues SET msg_queue_size = ?, msg_can_write = ? WHERE recipient_id = ?" (size, canWrt, recipientId q) +#endif + testQueueState :: JournalMsgStore s -> IO () testQueueState ms = do g <- C.newRandom rId <- EntityId <$> atomically (C.randomBytes 24 g) let dir = msgQueueDirectory ms rId - statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId) + statePath = msgQueueStatePath dir rId createDirectoryIfMissing True dir state <- newMsgQueueState <$> newJournalId (random ms) withFile statePath WriteMode (`appendState` state) @@ -312,7 +456,7 @@ testMessageState ms = do g <- C.newRandom (rId, qr) <- testNewQueueRec g QMMessaging let dir = msgQueueDirectory ms rId - statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId) + statePath = msgQueueStatePath dir rId write q s = writeMsg ms q True =<< mkMessage s mId1 <- runRight $ do @@ -337,7 +481,7 @@ testRemoveJournals ms = do g <- C.newRandom (rId, qr) <- testNewQueueRec g QMMessaging let dir = msgQueueDirectory ms rId - statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId) + statePath = msgQueueStatePath dir rId write q s = writeMsg ms q True =<< mkMessage s runRight $ do @@ -361,7 +505,7 @@ testRemoveJournals ms = do Nothing <- tryPeekMsg ms q -- still not removed, queue is empty and not opened liftIO $ journalFilesCount dir `shouldReturn` 1 - _mq <- isolateQueue q "test" $ getMsgQueue ms q False + _mq <- isolateQueue ms q "test" $ getMsgQueue ms q False -- journal is removed liftIO $ journalFilesCount dir `shouldReturn` 0 liftIO $ stateBackupCount dir `shouldReturn` 1 @@ -442,7 +586,7 @@ testExpireIdleQueues = do ms <- newMsgStore (testJournalStoreCfg MQStoreCfg) {idleInterval = 0} let dir = msgQueueDirectory ms rId - statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId) + statePath = msgQueueStatePath dir rId write q s = writeMsg ms q True =<< mkMessage s q <- runRight $ do @@ -461,7 +605,7 @@ testExpireIdleQueues = do old <- expireBeforeEpoch ExpirationConfig {ttl = 1, checkInterval = 1} -- no old messages now <- systemSeconds <$> getSystemTime - (expired_, stored) <- runRight $ isolateQueue q "" $ withIdleMsgQueue now ms q $ deleteExpireMsgs_ old q + (expired_, stored) <- runRight $ isolateQueue ms q "" $ withIdleMsgQueue now ms q $ deleteExpireMsgs_ old q expired_ `shouldBe` Just 0 stored `shouldBe` 0 (Nothing, False) <- readQueueState ms statePath @@ -478,7 +622,7 @@ testReadFileMissing ms = do Msg "message 1" <- tryPeekMsg ms q pure q - mq <- fromJust <$> readTVarIO (msgQueue q) + mq <- fromJust <$> readTVarIO (msgQueue' q) MsgQueueState {readState = rs} <- readTVarIO $ state mq closeMsgQueue ms q let path = journalFilePath (queueDirectory $ queue mq) $ journalId rs @@ -497,7 +641,7 @@ testReadFileMissingSwitch ms = do (rId, qr) <- testNewQueueRec g QMMessaging q <- writeMessages ms rId qr - mq <- fromJust <$> readTVarIO (msgQueue q) + mq <- fromJust <$> readTVarIO (msgQueue' q) MsgQueueState {readState = rs} <- readTVarIO $ state mq closeMsgQueue ms q let path = journalFilePath (queueDirectory $ queue mq) $ journalId rs @@ -515,7 +659,7 @@ testWriteFileMissing ms = do (rId, qr) <- testNewQueueRec g QMMessaging q <- writeMessages ms rId qr - mq <- fromJust <$> readTVarIO (msgQueue q) + mq <- fromJust <$> readTVarIO (msgQueue' q) MsgQueueState {writeState = ws} <- readTVarIO $ state mq closeMsgQueue ms q let path = journalFilePath (queueDirectory $ queue mq) $ journalId ws @@ -538,7 +682,7 @@ testReadAndWriteFilesMissing ms = do (rId, qr) <- testNewQueueRec g QMMessaging q <- writeMessages ms rId qr - mq <- fromJust <$> readTVarIO (msgQueue q) + mq <- fromJust <$> readTVarIO (msgQueue' q) MsgQueueState {readState = rs, writeState = ws} <- readTVarIO $ state mq closeMsgQueue ms q removeFile $ journalFilePath (queueDirectory $ queue mq) $ journalId rs diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 4159f25e1..946902358 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -45,56 +45,32 @@ utilTests = do runExceptT (throwTestException `catchError` handleCatch) `shouldThrow` (\(e :: IOError) -> show e == "user error (error)") describe "tryAllErrors" $ do it "should return ExceptT error as Left" $ - runExceptT (tryAllErrors testErr throwTestError) `shouldReturn` Right (Left (TestError "error")) + runExceptT (tryAllErrors throwTestError) `shouldReturn` Right (Left (TestError "error")) it "should return SomeException as Left" $ - runExceptT (tryAllErrors testErr throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) + runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) it "should return no errors as Right" $ - runExceptT (tryAllErrors testErr noErrors) `shouldReturn` Right (Right "no errors") - describe "tryAllErrors specialized as tryTestError" $ do - let tryTestError = tryAllErrors testErr - it "should return ExceptT error as Left" $ - runExceptT (tryTestError throwTestError) `shouldReturn` Right (Left (TestError "error")) - it "should return SomeException as Left" $ - runExceptT (tryTestError throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) - it "should return no errors as Right" $ - runExceptT (tryTestError noErrors) `shouldReturn` Right (Right "no errors") + runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors") describe "catchAllErrors" $ do it "should catch ExceptT error" $ - runExceptT (catchAllErrors testErr throwTestError handleCatch) `shouldReturn` Right "caught TestError \"error\"" + runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\"" it "should catch SomeException" $ - runExceptT (catchAllErrors testErr throwTestException handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" + runExceptT (throwTestException `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" it "should not throw if there are no errors" $ - runExceptT (catchAllErrors testErr noErrors throwError) `shouldReturn` Right "no errors" - describe "catchAllErrors specialized as catchTestError" $ do - let catchTestError = catchAllErrors testErr - it "should catch ExceptT error" $ - runExceptT (throwTestError `catchTestError` handleCatch) `shouldReturn` Right "caught TestError \"error\"" - it "should catch SomeException" $ - runExceptT (throwTestException `catchTestError` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\"" - it "should not throw if there are no errors" $ - runExceptT (noErrors `catchTestError` throwError) `shouldReturn` Right "no errors" + runExceptT (noErrors `catchAllErrors` throwError) `shouldReturn` Right "no errors" describe "catchThrow" $ do it "should re-throw ExceptT error" $ - runExceptT (throwTestError `catchThrow` testErr) `shouldReturn` Left (TestError "error") + runExceptT (throwTestError `catchThrow` fromSomeException) `shouldReturn` Left (TestError "error") it "should catch SomeException and throw as ExceptT error" $ - runExceptT (throwTestException `catchThrow` testErr) `shouldReturn` Left (TestException "user error (error)") + runExceptT (throwTestException `catchThrow` fromSomeException) `shouldReturn` Left (TestException "user error (error)") it "should not throw if there are no exceptions" $ - runExceptT (noErrors `catchThrow` testErr) `shouldReturn` Right "no errors" + runExceptT (noErrors `catchThrow` fromSomeException) `shouldReturn` Right "no errors" describe "allFinally should run final action" $ do it "then throw ExceptT error" $ withFinal $ \final -> - runExceptT (allFinally testErr throwTestError final) `shouldReturn` Left (TestError "error") + runExceptT (throwTestError `allFinally` final) `shouldReturn` Left (TestError "error") it "then throw SomeException as ExceptT error" $ withFinal $ \final -> - runExceptT (allFinally testErr throwTestException final) `shouldReturn` Left (TestException "user error (error)") + runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)") it "and should not throw if there are no exceptions" $ withFinal $ \final -> - runExceptT (allFinally testErr noErrors final) `shouldReturn` Right "no errors" - describe "allFinally specialized as testFinally should run final action" $ do - let testFinally = allFinally testErr - it "then throw ExceptT error" $ withFinal $ \final -> - runExceptT (throwTestError `testFinally` final) `shouldReturn` Left (TestError "error") - it "then throw SomeException as ExceptT error" $ withFinal $ \final -> - runExceptT (throwTestException `testFinally` final) `shouldReturn` Left (TestException "user error (error)") - it "and should not throw if there are no exceptions" $ withFinal $ \final -> - runExceptT (noErrors `testFinally` final) `shouldReturn` Right "no errors" + runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors" where throwTestError :: ExceptT TestError IO String throwTestError = throwError $ TestError "error" @@ -102,8 +78,6 @@ utilTests = do throwTestException = liftIO $ throwIO $ userError "error" noErrors :: ExceptT TestError IO String noErrors = pure "no errors" - testErr :: SomeException -> TestError - testErr = TestException . show handleCatch :: TestError -> ExceptT TestError IO String handleCatch e = pure $ "caught " <> show e handleException :: SomeException -> ExceptT TestError IO String @@ -119,3 +93,6 @@ data TestError = TestError String | TestException String deriving (Eq, Show) instance Exception TestError + +instance AnyError TestError where + fromSomeException = TestException . show diff --git a/tests/PostgresSchemaDump.hs b/tests/PostgresSchemaDump.hs index 77cc08fea..e9b54d540 100644 --- a/tests/PostgresSchemaDump.hs +++ b/tests/PostgresSchemaDump.hs @@ -12,7 +12,7 @@ import Data.Maybe (fromJust, isJust) import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common (DBOpts (..)) import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations -import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) import Simplex.Messaging.Util (ifM, whenM) import System.Directory (doesFileExist, removeFile) import System.Process (readCreateProcess, shell) @@ -30,12 +30,12 @@ postgresSchemaDumpTest migrations skipComparisonForDownMigrations testDBOpts@DBO testVerifySchemaDump = do savedSchema <- ifM (doesFileExist srcSchemaPath) (readFile srcSchemaPath) (pure "") savedSchema `deepseq` pure () - void $ createDBStore testDBOpts migrations MCConsole + void $ createDBStore testDBOpts migrations (MigrationConfig MCConsole Nothing) getSchema srcSchemaPath `shouldReturn` savedSchema testSchemaMigrations = do let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) migrations - Right st <- createDBStore testDBOpts noDownMigrations MCError + Right st <- createDBStore testDBOpts noDownMigrations (MigrationConfig MCError Nothing) mapM_ (testDownMigration st) $ drop (length noDownMigrations) migrations closeDBStore st whenM (doesFileExist testSchemaPath) $ removeFile testSchemaPath diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 2b18a2d51..3c1ac0150 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -229,6 +229,7 @@ cfgMS msType = withStoreCfg (testServerStoreConfig msType) $ \serverStoreCfg -> dailyBlockQueueQuota = 20, messageExpiration = Just defaultMessageExpiration, expireMessagesOnStart = True, + expireMessagesOnSend = False, idleQueueInterval = defaultIdleQueueInterval, notificationExpiration = defaultNtfExpiration, inactiveClientExpiration = Just defaultInactiveClientExpiration, @@ -273,9 +274,14 @@ serverStoreConfig_ useDbStoreLog = \case ASType SQSMemory SMSJournal -> ASSCfg SQSMemory SMSJournal $ SSCMemoryJournal {storeLogFile = testStoreLogFile, storeMsgsPath = testStoreMsgsDir} ASType SQSPostgres SMSJournal -> - let dbStoreLogPath = if useDbStoreLog then Just testStoreLogFile else Nothing - storeCfg = PostgresStoreCfg {dbOpts = testStoreDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = 86400} - in ASSCfg SQSPostgres SMSJournal SSCDatabaseJournal {storeCfg, storeMsgsPath' = testStoreMsgsDir} + ASSCfg SQSPostgres SMSJournal SSCDatabaseJournal {storeCfg, storeMsgsPath' = testStoreMsgsDir} +#if defined(dbServerPostgres) + ASType SQSPostgres SMSPostgres -> + ASSCfg SQSPostgres SMSPostgres $ SSCDatabase storeCfg +#endif + where + dbStoreLogPath = if useDbStoreLog then Just testStoreLogFile else Nothing + storeCfg = PostgresStoreCfg {dbOpts = testStoreDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = 86400} cfgV7 :: AServerConfig cfgV7 = updateCfg cfg $ \cfg' -> cfg' {smpServerVRange = mkVersionRange minServerSMPRelayVersion authCmdsSMPVersion} diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 560ac63d7..53269d6f6 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -14,6 +14,7 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} +{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module ServerTests where @@ -42,10 +43,10 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Protocol import Simplex.Messaging.Server (exportMessages) -import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore) +import Simplex.Messaging.Server.Env.STM (AStoreType (..), MsgStore (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Server.MsgStore.Journal (JournalStoreConfig (..), QStoreCfg (..), stmQueueStore) -import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), SMSType (..), SQSType (..), newMsgStore) +import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), QSType (..), SMSType (..), SQSType (..), newMsgStore) import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..)) import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..), closeStoreLog) import Simplex.Messaging.Transport @@ -59,6 +60,11 @@ import Test.HUnit import Test.Hspec hiding (fit, it) import Util +#if defined(dbServerPostgres) +import CoreTests.MsgStoreTests (testPostgresStoreConfig) +import Simplex.Messaging.Server.MsgStore.Postgres (PostgresMsgStoreCfg (..), exportDbMessages) +#endif + serverTests :: SpecWith (ASrvTransport, AStoreType) serverTests = do describe "SMP queues" $ do @@ -86,6 +92,7 @@ serverTests = do describe "Message notifications" $ do testMessageNotifications testMessageServiceNotifications + testServiceNotificationsTwoRestarts describe "Message expiration" $ do testMsgExpireOnSend testMsgExpireOnInterval @@ -914,14 +921,27 @@ testRestoreExpireMessages = exportStoreMessages :: AStoreType -> IO () exportStoreMessages = \case ASType _ SMSJournal -> export + ASType _ SMSPostgres -> exportDB ASType _ SMSMemory -> pure () where export = do - ms <- newMsgStore (testJournalStoreCfg MQStoreCfg) {quota = 4} + ms <- readWriteQueues + exportMessages False (StoreJournal ms) testStoreMsgsFile False + closeMsgStore ms +#if defined(dbServerPostgres) + exportDB = do + readWriteQueues >>= closeMsgStore + ms' <- newMsgStore (testPostgresStoreConfig {quota = 4} :: PostgresMsgStoreCfg) + _n <- withFile testStoreMsgsFile WriteMode $ exportDbMessages False ms' + closeMsgStore ms' +#else + exportDB = error "compiled without server_postgres flag" +#endif + readWriteQueues = do + ms <- newMsgStore ((testJournalStoreCfg MQStoreCfg) {quota = 4} :: JournalStoreConfig 'QSMemory) readWriteQueueStore True (mkQueue ms True) testStoreLogFile (stmQueueStore ms) >>= closeStoreLog removeFileIfExists testStoreMsgsFile - exportMessages False ms testStoreMsgsFile False - closeMsgStore ms + pure ms runTest :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () @@ -1091,7 +1111,6 @@ testMessageServiceNotifications = (rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g Resp "1" _ (NID nId _) <- signSendRecv rh rKey ("1", rId, NKEY nPub rcvNtfPubDhKey) serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g - -- TODO [certs] we need to get certificate fingerprint and include it into signed over for NSUB commands testNtfServiceClient t serviceKeys $ \nh1 -> do -- can't subscribe without service signature in service connection Resp "2a" _ (ERR SERVICE) <- signSendRecv nh1 nKey ("2a", nId, NSUB) @@ -1155,12 +1174,57 @@ testMessageServiceNotifications = Resp "" _ (NMSG _ _) <- tGet1 nh pure () +testServiceNotificationsTwoRestarts :: SpecWith (ASrvTransport, AStoreType) +testServiceNotificationsTwoRestarts = + it "subscribe notifier as service and deliver notifications after two restarts" $ \ps@(ATransport t, _) -> do + g <- C.newRandom + (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + (nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g + serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g + (rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g + (rId, rKey, sId, dec, serviceId) <- withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> do + (sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub + let dec = decryptMsgV3 dhShared + Resp "0" _ (NID nId _) <- signSendRecv rh rKey ("0", rId, NKEY nPub rcvNtfPubDhKey) + testNtfServiceClient t serviceKeys $ \nh -> do + Resp "1" _ (SOK (Just serviceId)) <- serviceSignSendRecv nh nKey servicePK ("1", nId, NSUB) + deliverMessage rh rId rKey sh sId sKey nh "hello" dec + pure (rId, rKey, sId, dec, serviceId) + threadDelay 250000 + withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> + testNtfServiceClient t serviceKeys $ \nh -> do + Resp "2.1" serviceId' (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS) + n `shouldBe` 1 + Resp "2.2" _ (SOK Nothing) <- signSendRecv rh rKey ("2.2", rId, SUB) + serviceId' `shouldBe` serviceId + deliverMessage rh rId rKey sh sId sKey nh "hello 2" dec + threadDelay 250000 + withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> + testNtfServiceClient t serviceKeys $ \nh -> do + Resp "3.1" _ (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS) + n `shouldBe` 1 + Resp "3.2" _ (SOK Nothing) <- signSendRecv rh rKey ("3.2", rId, SUB) + deliverMessage rh rId rKey sh sId sKey nh "hello 3" dec + where + runTest2 :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> THandleSMP c 'TClient -> IO a) -> ThreadId -> IO a + runTest2 _ test' server = do + a <- testSMPClient $ \h1 -> testSMPClient $ \h2 -> test' h1 h2 + killThread server + pure a + deliverMessage rh rId rKey sh sId sKey nh msgText dec = do + Resp "msg-1" _ OK <- signSendRecv sh sKey ("msg-1", sId, _SEND' msgText) + Resp "" _ (Msg mId msg) <- tGet1 rh + Resp "msg-2" _ OK <- signSendRecv rh rKey ("msg-2", rId, ACK mId) + (dec mId msg, Right msgText) #== "delivered from queue" + Resp "" _ (NMSG _ _) <- tGet1 nh + pure () + testMsgExpireOnSend :: SpecWith (ASrvTransport, AStoreType) testMsgExpireOnSend = it "should expire messages that are not received before messageTTL on SEND" $ \(ATransport (t :: TProxy c 'TServer), msType) -> do g <- C.newRandom (sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g - let cfg' = updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}} + let cfg' = updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {expireMessagesOnSend = True, messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}} withSmpServerConfigOn (ATransport t) cfg' testPort $ \_ -> testSMPClient @c $ \sh -> do (sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub diff --git a/tests/Test.hs b/tests/Test.hs index 4598bb8e4..364080e0c 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -102,9 +102,11 @@ main = do ] -- skipComparisonForDownMigrations testStoreDBOpts "src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql" - aroundAll_ (postgressBracket testServerDBConnectInfo) $ + around_ (postgressBracket testServerDBConnectInfo) $ do describe "SMP server via TLS, postgres+jornal message store" $ before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests + describe "SMP server via TLS, postgres-only message store" $ + before (pure (transport @TLS, ASType SQSPostgres SMSPostgres)) serverTests #endif describe "SMP server via TLS, jornal message store" $ do describe "SMP syntax" $ serverSyntaxTests (transport @TLS) @@ -122,16 +124,21 @@ main = do [] -- skipComparisonForDownMigrations ntfTestStoreDBOpts "src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql" - aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ do + around_ (postgressBracket ntfTestServerDBConnectInfo) $ do describe "Notifications server (SMP server: jornal store)" $ ntfServerTests (transport @TLS, ASType SQSMemory SMSJournal) - aroundAll_ (postgressBracket testServerDBConnectInfo) $ + around_ (postgressBracket testServerDBConnectInfo) $ do describe "Notifications server (SMP server: postgres+jornal store)" $ ntfServerTests (transport @TLS, ASType SQSPostgres SMSJournal) - aroundAll_ (postgressBracket testServerDBConnectInfo) $ do + describe "Notifications server (SMP server: postgres-only store)" $ + ntfServerTests (transport @TLS, ASType SQSPostgres SMSPostgres) + around_ (postgressBracket testServerDBConnectInfo) $ do describe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal) + describe "SMP client agent, postgres-only message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSPostgres) describe "SMP proxy, postgres+jornal message store" $ before (pure $ ASType SQSPostgres SMSJournal) smpProxyTests + describe "SMP proxy, postgres-only message store" $ + before (pure $ ASType SQSPostgres SMSPostgres) smpProxyTests #endif describe "SMP client agent, jornal message store" $ agentTests (transport @TLS, ASType SQSMemory SMSJournal) describe "SMP proxy, jornal message store" $ diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index d19705a37..a83ec08a6 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -38,7 +38,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs) import qualified Simplex.Messaging.Crypto.File as CF import Simplex.Messaging.Encoding.String (StrEncoding (..)) -import Simplex.Messaging.Protocol (BasicAuth, ProtoServerWithAuth (..), ProtocolServer (..), XFTPServerWithAuth) +import Simplex.Messaging.Protocol (BasicAuth, NetworkError (..), ProtoServerWithAuth (..), ProtocolServer (..), XFTPServerWithAuth) import Simplex.Messaging.Server.Expiration (ExpirationConfig (..)) import Simplex.Messaging.Util (tshow) import System.Directory (doesDirectoryExist, doesFileExist, getFileSize, listDirectory, removeFile) @@ -84,7 +84,7 @@ xftpAgentTests = it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing let srv1 = testXFTPServer2 {keyHash = "1234"} it "should fail with incorrect fingerprint" $ do - testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK) + testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError) describe "server with password" $ do let auth = Just "abcd" srv = ProtoServerWithAuth testXFTPServer2