mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-24 08:05:25 +00:00
Merge branch 'master' into unified-push
This commit is contained in:
+35
-16
@@ -58,11 +58,11 @@ jobs:
|
||||
# =============================
|
||||
|
||||
build:
|
||||
name: "ubuntu-${{ matrix.os }}, GHC: ${{ matrix.ghc }}"
|
||||
name: "ubuntu-${{ matrix.os }}-${{ matrix.arch }}, GHC: ${{ matrix.ghc }}"
|
||||
needs: maybe-release
|
||||
env:
|
||||
apps: "smp-server xftp-server ntf-server xftp"
|
||||
runs-on: ubuntu-${{ matrix.os }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
@@ -81,16 +81,34 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- os: 22.04
|
||||
os_underscore: 22_04
|
||||
arch: x86-64
|
||||
runner: "ubuntu-22.04"
|
||||
ghc: "8.10.7"
|
||||
platform_name: 22_04-8.10.7
|
||||
should_run: ${{ !(github.ref == 'refs/heads/stable' || startsWith(github.ref, 'refs/tags/v')) }}
|
||||
- os: 22.04
|
||||
os_underscore: 22_04
|
||||
arch: x86-64
|
||||
runner: "ubuntu-22.04"
|
||||
ghc: "9.6.3"
|
||||
platform_name: 22_04-x86-64
|
||||
should_run: true
|
||||
- os: 24.04
|
||||
os_underscore: 24_04
|
||||
arch: x86-64
|
||||
runner: "ubuntu-24.04"
|
||||
ghc: "9.6.3"
|
||||
should_run: true
|
||||
- os: 22.04
|
||||
os_underscore: 22_04
|
||||
arch: aarch64
|
||||
runner: "ubuntu-22.04-arm"
|
||||
ghc: "9.6.3"
|
||||
should_run: true
|
||||
- os: 24.04
|
||||
os_underscore: 24_04
|
||||
arch: aarch64
|
||||
runner: "ubuntu-24.04-arm"
|
||||
ghc: "9.6.3"
|
||||
platform_name: 24_04-x86-64
|
||||
should_run: true
|
||||
steps:
|
||||
- name: Clone project
|
||||
@@ -127,11 +145,7 @@ jobs:
|
||||
context: .
|
||||
load: true
|
||||
file: Dockerfile.build
|
||||
tags: build/${{ matrix.platform_name }}:latest
|
||||
cache-from: |
|
||||
type=gha
|
||||
type=gha,scope=master
|
||||
cache-to: type=gha,mode=max
|
||||
tags: build/${{ matrix.os }}:latest
|
||||
build-args: |
|
||||
TAG=${{ matrix.os }}
|
||||
GHC=${{ matrix.ghc }}
|
||||
@@ -143,23 +157,28 @@ jobs:
|
||||
path: |
|
||||
~/.cabal/store
|
||||
dist-newstyle
|
||||
key: ${{ matrix.os }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }}
|
||||
key: ubuntu-${{ matrix.os }}-${{ matrix.arch }}-ghc${{ matrix.ghc }}-${{ hashFiles('cabal.project', 'simplexmq.cabal') }}
|
||||
|
||||
- name: Start container
|
||||
if: matrix.should_run == true
|
||||
shell: bash
|
||||
run: |
|
||||
docker run -t -d \
|
||||
--device /dev/fuse \
|
||||
--cap-add SYS_ADMIN \
|
||||
--security-opt apparmor:unconfined \
|
||||
--name builder \
|
||||
-v ~/.cabal:/root/.cabal \
|
||||
-v /home/runner/work/_temp:/home/runner/work/_temp \
|
||||
-v ${{ github.workspace }}:/project \
|
||||
build/${{ matrix.platform_name }}:latest
|
||||
build/${{ matrix.os }}:latest
|
||||
|
||||
- name: Build smp-server (postgresql) and tests
|
||||
if: matrix.should_run == true
|
||||
shell: docker exec -t builder sh -eu {0}
|
||||
run: |
|
||||
chmod -R 777 dist-newstyle ~/.cabal && git config --global --add safe.directory '*'
|
||||
cabal clean
|
||||
cabal update
|
||||
cabal build --jobs=$(nproc) --enable-tests -fserver_postgres
|
||||
mkdir -p /out
|
||||
@@ -181,7 +200,7 @@ jobs:
|
||||
id: prepare-postgres
|
||||
shell: bash
|
||||
run: |
|
||||
name="smp-server-postgres-ubuntu-${{ matrix.platform_name }}"
|
||||
name="smp-server-postgres-ubuntu-${{ matrix.os_underscore }}-${{ matrix.arch }}"
|
||||
docker cp builder:/out/smp-server $name
|
||||
|
||||
path="${{ github.workspace }}/$name"
|
||||
@@ -213,9 +232,9 @@ jobs:
|
||||
printf 'bins<<EOF\n' > bins.output
|
||||
printf 'hashes<<EOF\n' > hashes.output
|
||||
for i in ${{ env.apps }}; do
|
||||
mv ./out/$i ./$i-ubuntu-${{ matrix.platform_name }}
|
||||
name="$i-ubuntu-${{ matrix.os_underscore }}-${{ matrix.arch }}"
|
||||
|
||||
name="$i-ubuntu-${{ matrix.platform_name }}"
|
||||
mv ./out/$i ./$name
|
||||
|
||||
path="${{ github.workspace }}/$name"
|
||||
hash="SHA2-256($name)= $(openssl sha256 $path | cut -d' ' -f 2)"
|
||||
@@ -246,7 +265,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Test
|
||||
if: matrix.should_run == true
|
||||
if: matrix.should_run == true && matrix.arch == 'x86-64'
|
||||
timeout-minutes: 120
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
+2
-1
@@ -1,7 +1,7 @@
|
||||
cabal-version: 1.12
|
||||
|
||||
name: simplexmq
|
||||
version: 6.4.4.1
|
||||
version: 6.5.0.0.1
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
<./docs/Simplex-Messaging-Client.html client> and
|
||||
@@ -267,6 +267,7 @@ library
|
||||
Simplex.Messaging.Notifications.Server.Store.Postgres
|
||||
Simplex.Messaging.Notifications.Server.Store.Types
|
||||
Simplex.Messaging.Notifications.Server.StoreLog
|
||||
Simplex.Messaging.Server.MsgStore.Postgres
|
||||
Simplex.Messaging.Server.QueueStore.Postgres
|
||||
Simplex.Messaging.Server.QueueStore.Postgres.Migrations
|
||||
other-modules:
|
||||
|
||||
@@ -75,7 +75,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String (strDecode, strEncode)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer, ProtocolType (..), XFTPServer)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (catchAll_, liftError, tshow, unlessM, whenM)
|
||||
import Simplex.Messaging.Util (allFinally, catchAll_, catchAllErrors, liftError, tshow, unlessM, whenM)
|
||||
import System.FilePath (takeFileName, (</>))
|
||||
import UnliftIO
|
||||
import UnliftIO.Directory
|
||||
@@ -198,10 +198,10 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv downloadAttempts
|
||||
downloadFileChunk fc replica approvedRelays
|
||||
`catchAgentError` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP rcv worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c (fromMaybe rcvFileEntityId redirectEntityId_) (RFWARN e)
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateRcvChunkReplicaDelay db rcvChunkReplicaId replicaDelay
|
||||
@@ -280,7 +280,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
runXFTPOperation AgentConfig {rcvFilesTTL} =
|
||||
withWork c doWork (`getNextRcvFileToDecrypt` rcvFilesTTL) $
|
||||
\f@RcvFile {rcvFileId, rcvFileEntityId, tmpPath, redirect} ->
|
||||
decryptFile f `catchAgentError` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
|
||||
decryptFile f `catchAllErrors` rcvWorkerInternalError c rcvFileId rcvFileEntityId (redirectEntityId <$> redirect) tmpPath
|
||||
decryptFile :: RcvFile -> AM ()
|
||||
decryptFile RcvFile {rcvFileId, rcvFileEntityId, size, digest, key, nonce, tmpPath, saveFile, status, chunks, redirect} = do
|
||||
let CryptoFile savePath cfArgs = saveFile
|
||||
@@ -307,7 +307,7 @@ runXFTPRcvLocalWorker c Worker {doWork} = do
|
||||
liftIO $ waitUntilForeground c
|
||||
withStore' c (`updateRcvFileComplete` rcvFileId)
|
||||
-- proceed with redirect
|
||||
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `agentFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
yaml <- liftError (FILE . FILE_IO . show) (CF.readFile $ CryptoFile fsSavePath cfArgs) `allFinally` (lift $ toFSFilePath fsSavePath >>= removePath)
|
||||
next@FileDescription {chunks = nextChunks} <- case strDecode (LB.toStrict yaml) of
|
||||
-- TODO switch to another error constructor
|
||||
Left _ -> throwE . FILE $ REDIRECT "decode error"
|
||||
@@ -399,7 +399,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
runXFTPOperation cfg@AgentConfig {sndFilesTTL} =
|
||||
withWork c doWork (`getNextSndFileToPrepare` sndFilesTTL) $
|
||||
\f@SndFile {sndFileId, sndFileEntityId, prefixPath} ->
|
||||
prepareFile cfg f `catchAgentError` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
|
||||
prepareFile cfg f `catchAllErrors` sndWorkerInternalError c sndFileId sndFileEntityId prefixPath
|
||||
prepareFile :: AgentConfig -> SndFile -> AM ()
|
||||
prepareFile _ SndFile {prefixPath = Nothing} =
|
||||
throwE $ INTERNAL "no prefix path"
|
||||
@@ -468,11 +468,11 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
let triedAllSrvs = n > userSrvCount
|
||||
createWithNextSrv triedHosts
|
||||
`catchAgentError` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP prepare worker" (retryLoop loop triedAllSrvs e) (throwE e) e
|
||||
where
|
||||
-- we don't do closeXFTPServerClient here to not risk closing connection for concurrent chunk upload
|
||||
retryLoop loop triedAllSrvs e = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (triedAllSrvs && serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
liftIO $ assertAgentForeground c
|
||||
loop
|
||||
@@ -508,10 +508,10 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv uploadAttempts
|
||||
uploadFileChunk cfg fc replica
|
||||
`catchAgentError` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP snd worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c sndFileEntityId $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server digest
|
||||
withStore' c $ \db -> updateSndChunkReplicaDelay db sndChunkReplicaId replicaDelay
|
||||
@@ -681,10 +681,10 @@ runXFTPDelWorker c srv Worker {doWork} = do
|
||||
liftIO $ waitForUserNetwork c
|
||||
atomically $ incXFTPServerStat c userId srv deleteAttempts
|
||||
deleteChunkReplica
|
||||
`catchAgentError` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
|
||||
`catchAllErrors` \e -> retryOnError "XFTP del worker" (retryLoop loop e delay') (retryDone e) e
|
||||
where
|
||||
retryLoop loop e replicaDelay = do
|
||||
flip catchAgentError (\_ -> pure ()) $ do
|
||||
flip catchAllErrors (\_ -> pure ()) $ do
|
||||
when (serverHostError e) $ notify c "" $ SFWARN e
|
||||
liftIO $ closeXFTPServerClient c userId server chunkDigest
|
||||
withStore' c $ \db -> updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId replicaDelay
|
||||
|
||||
@@ -59,6 +59,8 @@ import Simplex.Messaging.Protocol
|
||||
RecipientId,
|
||||
SenderId,
|
||||
pattern NoEntity,
|
||||
NetworkError (..),
|
||||
toNetworkError,
|
||||
)
|
||||
import Simplex.Messaging.Transport (ALPN, CertChainPubKey (..), HandshakeError (..), THandleAuth (..), THandleParams (..), TransportError (..), TransportPeer (..), defaultSupportedParams)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost)
|
||||
@@ -191,7 +193,7 @@ xftpHTTP2Config transportConfig XFTPClientConfig {xftpNetworkConfig = NetworkCon
|
||||
xftpClientError :: HTTP2ClientError -> XFTPClientError
|
||||
xftpClientError = \case
|
||||
HCResponseTimeout -> PCEResponseTimeout
|
||||
HCNetworkError -> PCENetworkError
|
||||
HCNetworkError e -> PCENetworkError e
|
||||
HCIOError e -> PCEIOError e
|
||||
|
||||
sendXFTPCommand :: forall p. FilePartyI p => XFTPClient -> C.APrivateAuthKey -> XFTPFileId -> FileCommand p -> Maybe XFTPChunkSpec -> ExceptT XFTPClientError IO (FileResponse, HTTP2Body)
|
||||
@@ -261,9 +263,9 @@ downloadXFTPChunk g c@XFTPClient {config} rpKey fId chunkSpec@XFTPRcvChunkSpec {
|
||||
ExceptT (sequence <$> (t `timeout` (download cbState `catches` errors))) >>= maybe (throwE PCEResponseTimeout) pure
|
||||
where
|
||||
errors =
|
||||
[ Handler $ \(_e :: H.HTTP2Error) -> pure $ Left PCENetworkError,
|
||||
Handler $ \(e :: IOException) -> pure $ Left (PCEIOError e),
|
||||
Handler $ \(_e :: SomeException) -> pure $ Left PCENetworkError
|
||||
[ Handler $ \(e :: H.HTTP2Error) -> pure $ Left $ PCENetworkError $ NEConnectError $ displayException e,
|
||||
Handler $ \(e :: IOException) -> pure $ Left $ PCEIOError e,
|
||||
Handler $ \(e :: SomeException) -> pure $ Left $ PCENetworkError $ toNetworkError e
|
||||
]
|
||||
download cbState =
|
||||
runExceptT . withExceptT PCEResponseError $
|
||||
|
||||
@@ -139,7 +139,7 @@ import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.Aeson as J
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:), (.:.), (.::), (.::.))
|
||||
@@ -284,13 +284,13 @@ saveServersStats c@AgentClient {subQ, smpServersStats, xftpServersStats, ntfServ
|
||||
xss <- mapM (liftIO . getAgentXFTPServerStats) =<< readTVarIO xftpServersStats
|
||||
nss <- mapM (liftIO . getAgentNtfServerStats) =<< readTVarIO ntfServersStats
|
||||
let stats = AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}
|
||||
tryAgentError' (withStore' c (`updateServersStats` stats)) >>= \case
|
||||
tryAllErrors' (withStore' c (`updateServersStats` stats)) >>= \case
|
||||
Left e -> atomically $ writeTBQueue subQ ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
|
||||
Right () -> pure ()
|
||||
|
||||
restoreServersStats :: AgentClient -> AM' ()
|
||||
restoreServersStats c@AgentClient {smpServersStats, xftpServersStats, ntfServersStats, srvStatsStartedAt} = do
|
||||
tryAgentError' (withStore c getServersStats) >>= \case
|
||||
tryAllErrors' (withStore c getServersStats) >>= \case
|
||||
Left e -> atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ INTERNAL $ show e)
|
||||
Right (startedAt, Nothing) -> atomically $ writeTVar srvStatsStartedAt startedAt
|
||||
Right (startedAt, Just AgentPersistedServerStats {smpServersStats = sss, xftpServersStats = xss, ntfServersStats = OptionalMap nss}) -> do
|
||||
@@ -774,7 +774,7 @@ acceptContactAsync' :: AgentClient -> UserId -> ACorrId -> Bool -> InvitationId
|
||||
acceptContactAsync' c userId corrId enableNtfs invId ownConnInfo pqSupport subMode = do
|
||||
Invitation {connReq} <- withStore c $ \db -> getInvitation db "acceptContactAsync'" invId
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do
|
||||
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAllErrors` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwE err
|
||||
|
||||
@@ -961,7 +961,7 @@ newRcvConnSrv c nm userId connId enableNtfs cMode userData_ clientData pqInitKey
|
||||
createRcvQueue nonce_ qd e2eKeys = do
|
||||
AgentConfig {smpClientVRange = vr} <- asks config
|
||||
ntfServer_ <- if enableNtfs then newQueueNtfServer else pure Nothing
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAgentError` \e -> liftIO (print e) >> throwE e
|
||||
(rq, qUri, tSess, sessId) <- newRcvQueue_ c nm userId connId srvWithAuth vr qd (isJust ntfServer_) subMode nonce_ e2eKeys `catchAllErrors` \e -> liftIO (print e) >> throwE e
|
||||
atomically $ incSMPServerStat c userId srv connCreated
|
||||
rq' <- withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
lift . when (subMode == SMSubscribe) $ addNewQueueSubscription c rq' tSess sessId
|
||||
@@ -1269,7 +1269,7 @@ subscribeConnections' c connIds = do
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
|
||||
resumeDelivery cs
|
||||
lift $ resumeConnCmds c $ M.keys cs
|
||||
resumeConnCmds c $ M.keys cs
|
||||
rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs)
|
||||
rcvRs' <- storeClientServiceAssocs rcvRs
|
||||
ns <- asks ntfSupervisor
|
||||
@@ -1351,7 +1351,7 @@ subscribeClientService' = undefined
|
||||
|
||||
-- requesting messages sequentially, to reduce memory usage
|
||||
getConnectionMessages' :: AgentClient -> NonEmpty ConnMsgReq -> AM' (NonEmpty (Either AgentErrorType (Maybe SMPMsgMeta)))
|
||||
getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
|
||||
getConnectionMessages' c = mapM $ tryAllErrors' . getConnectionMessage
|
||||
where
|
||||
getConnectionMessage :: ConnMsgReq -> AM (Maybe SMPMsgMeta)
|
||||
getConnectionMessage (ConnMsgReq connId dbQueueId msgTs_) = do
|
||||
@@ -1363,7 +1363,7 @@ getConnectionMessages' c = mapM $ tryAgentError' . getConnectionMessage
|
||||
ContactConnection _ rq -> pure rq
|
||||
SndConnection _ _ -> throwE $ CONN SIMPLEX "getConnectionMessage"
|
||||
NewConnection _ -> throwE $ CMD PROHIBITED "getConnectionMessage: NewConnection"
|
||||
msg_ <- getQueueMessage c rq `catchAgentError` \e -> atomically (releaseGetLock c rq) >> throwError e
|
||||
msg_ <- getQueueMessage c rq `catchAllErrors` \e -> atomically (releaseGetLock c rq) >> throwError e
|
||||
when (isNothing msg_) $ do
|
||||
atomically $ releaseGetLock c rq
|
||||
forM_ msgTs_ $ \msgTs -> withStore' c $ \db -> setLastBrokerTs db connId (DBEntityId dbQueueId) msgTs
|
||||
@@ -1473,10 +1473,10 @@ resumeSrvCmds :: AgentClient -> ConnId -> Maybe SMPServer -> AM' ()
|
||||
resumeSrvCmds = void .:. getAsyncCmdWorker False
|
||||
{-# INLINE resumeSrvCmds #-}
|
||||
|
||||
resumeConnCmds :: AgentClient -> [ConnId] -> AM' ()
|
||||
resumeConnCmds :: AgentClient -> [ConnId] -> AM ()
|
||||
resumeConnCmds c connIds = do
|
||||
connSrvs <- rights . zipWith (second . (,)) connIds <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds)
|
||||
mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs
|
||||
connSrvs <- withStore' c (`getPendingCommandServers` connIds)
|
||||
lift $ mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs
|
||||
|
||||
getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker
|
||||
getAsyncCmdWorker hasWork c connId server =
|
||||
@@ -1534,7 +1534,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
|
||||
pure CCCompleted
|
||||
-- duplex connection is matched to handle SKEY retries
|
||||
DuplexConnection cData _ (sq :| _) -> do
|
||||
tryAgentError (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
|
||||
tryAllErrors (mapM_ (connectReplyQueues c cData ownConnInfo (Just sq)) (L.nonEmpty $ smpReplyQueues senderConf)) >>= \case
|
||||
Right () -> pure CCCompleted
|
||||
Left e
|
||||
| temporaryOrHostError e && Just server /= server_ -> do
|
||||
@@ -1621,7 +1621,7 @@ runCommandProcessing c@AgentClient {subQ} connId server_ Worker {doWork} = do
|
||||
tryMoveableCommand action = withRetryInterval ri $ \_ loop -> do
|
||||
liftIO $ waitWhileSuspended c
|
||||
liftIO $ waitForUserNetwork c
|
||||
tryAgentError action >>= \case
|
||||
tryAllErrors action >>= \case
|
||||
Left e
|
||||
| temporaryOrHostError e -> retrySndOp c loop
|
||||
| otherwise -> cmdError e
|
||||
@@ -2065,7 +2065,7 @@ synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchroni
|
||||
ackQueueMessage :: AgentClient -> RcvQueue -> SMP.MsgId -> AM ()
|
||||
ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
|
||||
atomically $ incSMPServerStat c userId server ackAttempts
|
||||
tryAgentError (sendAck c rq srvMsgId) >>= \case
|
||||
tryAllErrors (sendAck c rq srvMsgId) >>= \case
|
||||
Right _ -> sendMsgNtf ackMsgs
|
||||
Left (SMP _ SMP.NO_MSG) -> sendMsgNtf ackNoMsgErrs
|
||||
Left e -> do
|
||||
@@ -2076,7 +2076,7 @@ ackQueueMessage c rq@RcvQueue {userId, connId, server} srvMsgId = do
|
||||
atomically $ incSMPServerStat c userId server stat
|
||||
whenM (liftIO $ hasGetLock c rq) $ do
|
||||
atomically $ releaseGetLock c rq
|
||||
brokerTs_ <- eitherToMaybe <$> tryAgentError (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
|
||||
brokerTs_ <- eitherToMaybe <$> tryAllErrors (withStore c $ \db -> getRcvMsgBrokerTs db connId srvMsgId)
|
||||
atomically $ writeTBQueue (subQ c) ("", connId, AEvt SAEConn $ MSGNTF srvMsgId brokerTs_)
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command) in Reader monad
|
||||
@@ -2307,7 +2307,7 @@ registerNtfToken' c nm suppliedDeviceToken suppliedNtfMode =
|
||||
replaceToken :: NtfTokenId -> AM NtfTknStatus
|
||||
replaceToken tknId = do
|
||||
ns <- asks ntfSupervisor
|
||||
tryReplace ns `catchAgentError` \e ->
|
||||
tryReplace ns `catchAllErrors` \e ->
|
||||
if temporaryOrHostError e
|
||||
then throwE e
|
||||
else do
|
||||
@@ -2451,23 +2451,23 @@ sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM ()
|
||||
sendNtfConnCommands c cmd = do
|
||||
ns <- asks ntfSupervisor
|
||||
connIds <- liftIO $ S.toList <$> getSubscriptions c
|
||||
rs <- lift $ withStoreBatch' c (\db -> map (getConnData db) connIds)
|
||||
rs <- withStore' c (`getConnsData` connIds)
|
||||
let (connIds', cErrs) = enabledNtfConns (zip connIds rs)
|
||||
forM_ (L.nonEmpty connIds') $ \connIds'' ->
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (cmd, connIds'')
|
||||
unless (null cErrs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS cErrs)
|
||||
where
|
||||
enabledNtfConns :: [(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)])
|
||||
enabledNtfConns :: [(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)])
|
||||
enabledNtfConns = foldr addEnabledConn ([], [])
|
||||
where
|
||||
addEnabledConn ::
|
||||
(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode))) ->
|
||||
(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode))) ->
|
||||
([ConnId], [(ConnId, AgentErrorType)]) ->
|
||||
([ConnId], [(ConnId, AgentErrorType)])
|
||||
addEnabledConn cData_ (cIds, errs) = case cData_ of
|
||||
(_, Right (Just (ConnData {connId, enableNtfs}, _))) -> if enableNtfs then (connId : cIds, errs) else (cIds, errs)
|
||||
(connId, Right Nothing) -> (cIds, (connId, INTERNAL "no connection data") : errs)
|
||||
(connId, Left e) -> (cIds, (connId, e) : errs)
|
||||
(connId, Left e) -> (cIds, (connId, INTERNAL (show e)) : errs)
|
||||
|
||||
setNtfServers :: AgentClient -> [NtfServer] -> IO ()
|
||||
setNtfServers c = atomically . writeTVar (ntfServers c)
|
||||
@@ -2564,7 +2564,7 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
where
|
||||
run :: forall e. AEntityI e => (AgentErrorType -> AEvent e) -> AM () -> AM' ()
|
||||
run err a = do
|
||||
waitActive . runExceptT $ a `catchAgentError` (notify "" . err)
|
||||
waitActive . runExceptT $ a `catchAllErrors` (notify "" . err)
|
||||
step <- asks $ cleanupStepInterval . config
|
||||
liftIO $ threadDelay step
|
||||
-- we are catching it to avoid CRITICAL errors in tests when this is the only remaining handle to active
|
||||
@@ -2578,33 +2578,33 @@ cleanupManager c@AgentClient {subQ} = do
|
||||
deleteRcvFilesExpired = do
|
||||
rcvFilesTTL <- asks $ rcvFilesTTL . config
|
||||
rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL)
|
||||
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`deleteRcvFile'` dbId)
|
||||
deleteRcvFilesDeleted = do
|
||||
rcvDeleted <- withStore' c getCleanupRcvFilesDeleted
|
||||
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`deleteRcvFile'` dbId)
|
||||
deleteRcvFilesTmpPaths = do
|
||||
rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths
|
||||
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
|
||||
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . RFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`updateRcvFileNoTmpPath` dbId)
|
||||
deleteSndFilesExpired = do
|
||||
sndFilesTTL <- asks $ sndFilesTTL . config
|
||||
sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL)
|
||||
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift . forM_ p $ removePath <=< toFSFilePath
|
||||
withStore' c (`deleteSndFile'` dbId)
|
||||
deleteSndFilesDeleted = do
|
||||
sndDeleted <- withStore' c getCleanupSndFilesDeleted
|
||||
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift . forM_ p $ removePath <=< toFSFilePath
|
||||
withStore' c (`deleteSndFile'` dbId)
|
||||
deleteSndFilesPrefixPaths = do
|
||||
sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths
|
||||
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
|
||||
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAllErrors (notify entId . SFERR) $ do
|
||||
lift $ removePath =<< toFSFilePath p
|
||||
withStore' c (`updateSndFileNoPrefixPath` dbId)
|
||||
deleteExpiredReplicasForDeletion = do
|
||||
@@ -2652,10 +2652,10 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
where
|
||||
withRcvConn :: SMP.RecipientId -> (forall c. RcvQueue -> Connection c -> AM ()) -> AM' ()
|
||||
withRcvConn rId a = do
|
||||
tryAgentError' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
|
||||
tryAllErrors' (withStore c $ \db -> getRcvConn db srv rId) >>= \case
|
||||
Left e -> notify' "" (ERR e)
|
||||
Right (rq@RcvQueue {connId}, SomeConn _ conn) ->
|
||||
tryAgentError' (a rq conn) >>= \case
|
||||
tryAllErrors' (a rq conn) >>= \case
|
||||
Left e -> notify' connId (ERR e)
|
||||
Right () -> pure ()
|
||||
processSubOk :: RcvQueue -> TVar [ConnId] -> AM ()
|
||||
@@ -2739,7 +2739,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
_ -> pure ()
|
||||
let encryptedMsgHash = C.sha256Hash encAgentMessage
|
||||
g <- asks random
|
||||
tryAgentError (agentClientMsg g encryptedMsgHash) >>= \case
|
||||
tryAllErrors (agentClientMsg g encryptedMsgHash) >>= \case
|
||||
Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do
|
||||
conn'' <- resetRatchetSync
|
||||
case aMessage of
|
||||
@@ -2848,7 +2848,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
ackDel :: InternalId -> AM ACKd
|
||||
ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd
|
||||
handleNotifyAck :: AM ACKd -> AM ACKd
|
||||
handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack
|
||||
handleNotifyAck m = m `catchAllErrors` \e -> notify (ERR e) >> ack
|
||||
SMP.END ->
|
||||
atomically (ifM (activeClientSession c tSess sessId) (removeSubscription c connId $> True) (pure False))
|
||||
>>= notifyEnd
|
||||
@@ -3006,7 +3006,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> AM ACKd
|
||||
messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do
|
||||
logServer "<--" c srv rId $ "MSG <RCPT>:" <> logSecret' srvMsgId
|
||||
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing
|
||||
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAllErrors` \e -> notify (ERR e) $> Nothing
|
||||
case L.nonEmpty . catMaybes $ L.toList rs of
|
||||
Just rs' -> notify (RCVD msgMeta rs') $> ACKPending
|
||||
Nothing -> ack
|
||||
|
||||
@@ -130,6 +130,7 @@ module Simplex.Messaging.Agent.Client
|
||||
hasWorkToDo,
|
||||
hasWorkToDo',
|
||||
withWork,
|
||||
withWork_,
|
||||
withWorkItems,
|
||||
agentOperations,
|
||||
agentOperationBracket,
|
||||
@@ -249,6 +250,7 @@ import Simplex.Messaging.Protocol
|
||||
EntityId (..),
|
||||
ServiceId,
|
||||
ErrorType,
|
||||
NetworkError (..),
|
||||
MsgFlags (..),
|
||||
MsgId,
|
||||
NtfServer,
|
||||
@@ -371,12 +373,12 @@ data SMPConnectedClient = SMPConnectedClient
|
||||
|
||||
type ProxiedRelayVar = SessionVar (Either AgentErrorType ProxiedRelay)
|
||||
|
||||
getAgentWorker :: (Ord k, Show k) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> AM ()) -> AM' Worker
|
||||
getAgentWorker :: (Ord k, Show k, AnyError e, MonadUnliftIO m) => String -> Bool -> AgentClient -> k -> TMap k Worker -> (Worker -> ExceptT e m ()) -> m Worker
|
||||
getAgentWorker = getAgentWorker' id pure
|
||||
{-# INLINE getAgentWorker #-}
|
||||
|
||||
getAgentWorker' :: forall a k. (Ord k, Show k) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> AM ()) -> AM' a
|
||||
getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
getAgentWorker' :: forall a k e m. (Ord k, Show k, AnyError e, MonadUnliftIO m) => (a -> Worker) -> (Worker -> STM a) -> String -> Bool -> AgentClient -> k -> TMap k a -> (a -> ExceptT e m ()) -> m a
|
||||
getAgentWorker' toW fromW name hasWork c@AgentClient {agentEnv} key ws work = do
|
||||
atomically (getWorker >>= maybe createWorker whenExists) >>= \w -> runWorker w $> w
|
||||
where
|
||||
getWorker = TM.lookup key ws
|
||||
@@ -389,12 +391,12 @@ getAgentWorker' toW fromW name hasWork c key ws work = do
|
||||
| otherwise = pure w
|
||||
runWorker w = runWorkerAsync (toW w) runWork
|
||||
where
|
||||
runWork :: AM' ()
|
||||
runWork = tryAgentError' (work w) >>= restartOrDelete
|
||||
restartOrDelete :: Either AgentErrorType () -> AM' ()
|
||||
runWork :: m ()
|
||||
runWork = tryAllErrors' (work w) >>= restartOrDelete
|
||||
restartOrDelete :: Either e () -> m ()
|
||||
restartOrDelete e_ = do
|
||||
t <- liftIO getSystemTime
|
||||
maxRestarts <- asks $ maxWorkerRestartsPerMin . config
|
||||
let maxRestarts = maxWorkerRestartsPerMin $ config agentEnv
|
||||
-- worker may terminate because it was deleted from the map (getWorker returns Nothing), then it won't restart
|
||||
restart <- atomically $ getWorker >>= maybe (pure False) (shouldRestart e_ (toW w) t maxRestarts)
|
||||
when restart runWork
|
||||
@@ -431,7 +433,7 @@ newWorker c = do
|
||||
restarts <- newTVar $ RestartCount 0 0
|
||||
pure Worker {workerId, doWork, action, restarts}
|
||||
|
||||
runWorkerAsync :: Worker -> AM' () -> AM' ()
|
||||
runWorkerAsync :: MonadUnliftIO m => Worker -> m () -> m ()
|
||||
runWorkerAsync Worker {action} work =
|
||||
E.bracket
|
||||
(atomically $ takeTMVar action) -- get current action, locking to avoid race conditions
|
||||
@@ -665,7 +667,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
pure (clnt, sess)
|
||||
newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay)
|
||||
newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv =
|
||||
tryAgentError (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
|
||||
tryAllErrors (liftClient SMP (clientServer smp) $ connectSMPProxiedRelay smp nm destSrv proxyAuth) >>= \case
|
||||
Right sess -> do
|
||||
atomically $ putTMVar (sessionVar rv) (Right sess)
|
||||
pure $ Right sess
|
||||
@@ -688,7 +690,7 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq
|
||||
smpConnectClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession -> TMap SMPServer ProxiedRelayVar -> SMPClientVar -> AM SMPConnectedClient
|
||||
smpConnectClient c@AgentClient {smpClients, msgQ, proxySessTs, presetDomains} nm tSess@(_, srv, _) prs v =
|
||||
newProtocolClient c tSess smpClients connectClient v
|
||||
`catchAgentError` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
|
||||
`catchAllErrors` \e -> lift (resubscribeSMPSession c tSess) >> throwE e
|
||||
where
|
||||
connectClient :: SMPClientVar -> AM SMPConnectedClient
|
||||
connectClient v' = do
|
||||
@@ -866,7 +868,7 @@ newProtocolClient ::
|
||||
ClientVar msg ->
|
||||
AM (Client msg)
|
||||
newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
tryAgentError (connectClient v) >>= \case
|
||||
tryAllErrors (connectClient v) >>= \case
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv <> " (user " <> bshow userId <> maybe "" (" for entity " <>) entityId_ <> ")"
|
||||
atomically $ putTMVar (sessionVar v) (Right client)
|
||||
@@ -1027,7 +1029,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
|
||||
withClient_ :: forall a v err msg. ProtocolServerClient v err msg => AgentClient -> NetworkRequestMode -> TransportSession msg -> (Client msg -> AM a) -> AM a
|
||||
withClient_ c nm tSess@(_, srv, _) action = do
|
||||
cl <- getProtocolServerClient c nm tSess
|
||||
action cl `catchAgentError` logServerError
|
||||
action cl `catchAllErrors` logServerError
|
||||
where
|
||||
logServerError :: AgentErrorType -> AM a
|
||||
logServerError e = do
|
||||
@@ -1040,7 +1042,7 @@ withProxySession c nm proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = d
|
||||
logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr
|
||||
case sess_ of
|
||||
Right sess -> do
|
||||
r <- action (cl, sess) `catchAgentError` logServerError cl
|
||||
r <- action (cl, sess) `catchAllErrors` logServerError cl
|
||||
logServer ("<-- " <> proxySrv cl <> " <") c destSrv entId "OK"
|
||||
pure r
|
||||
Left e -> logServerError cl e
|
||||
@@ -1117,7 +1119,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
|
||||
unknownServer = liftIO $ maybe True (\srvs -> all (`S.notMember` knownHosts srvs) destHosts) <$> TM.lookupIO userId (smpServers c)
|
||||
sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer, a)
|
||||
sendViaProxy proxySrv_ destSess@(_, _, connId_) = do
|
||||
r <- tryAgentError . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
|
||||
r <- tryAllErrors . withProxySession c nm proxySrv_ destSess entId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do
|
||||
r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess
|
||||
let proxySrv = protocolClientServer' smp
|
||||
case r' of
|
||||
@@ -1164,7 +1166,7 @@ sendOrProxySMPCommand c nm userId destSrv@ProtocolServer {host = destHosts} conn
|
||||
| otherwise -> throwE e
|
||||
sendDirectly tSess =
|
||||
withLogClient_ c nm tSess (unEntityId entId) ("SEND " <> cmdStr) $ \(SMPConnectedClient smp _) -> do
|
||||
tryAgentError (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
|
||||
tryAllErrors (liftClient SMP (clientServer smp) $ sendCmdDirectly smp) >>= \case
|
||||
Right r -> r <$ atomically (incSMPServerStat c userId destSrv sentDirect)
|
||||
Left e -> throwE e
|
||||
|
||||
@@ -1198,12 +1200,12 @@ protocolClientError protocolError_ host = \case
|
||||
PCEResponseError e -> BROKER host $ RESPONSE $ B.unpack $ smpEncode e
|
||||
PCEUnexpectedResponse e -> BROKER host $ UNEXPECTED $ B.unpack e
|
||||
PCEResponseTimeout -> BROKER host TIMEOUT
|
||||
PCENetworkError -> BROKER host NETWORK
|
||||
PCENetworkError e -> BROKER host $ NETWORK e
|
||||
PCEIncompatibleHost -> BROKER host HOST
|
||||
PCETransportError e -> BROKER host $ TRANSPORT e
|
||||
e@PCECryptoError {} -> INTERNAL $ show e
|
||||
PCEServiceUnavailable {} -> BROKER host NO_SERVICE
|
||||
PCEIOError {} -> BROKER host NETWORK
|
||||
PCEIOError e -> BROKER host $ NETWORK $ NEConnectError $ E.displayException e
|
||||
|
||||
data ProtocolTestStep
|
||||
= TSConnect
|
||||
@@ -1477,7 +1479,7 @@ temporaryAgentError = \case
|
||||
_ -> False
|
||||
where
|
||||
tempBrokerError = \case
|
||||
NETWORK -> True
|
||||
NETWORK _ -> True
|
||||
TIMEOUT -> True
|
||||
_ -> False
|
||||
|
||||
@@ -1517,7 +1519,7 @@ subscribeQueues c qs = do
|
||||
subscribeQueues_ env session smp qs' = do
|
||||
let (userId, srv, _) = transportSession' smp
|
||||
atomically $ incSMPServerStat' c userId srv connSubAttempts $ length qs'
|
||||
rs <- sendBatch (\smp' _ -> subscribeSMPQueues smp') smp NRMBackground qs'
|
||||
rs <- sendBatch (\smp' _ -> subscribeSMPQueues smp') smp NRMBackground qs'
|
||||
active <-
|
||||
atomically $
|
||||
ifM
|
||||
@@ -1528,7 +1530,8 @@ subscribeQueues c qs = do
|
||||
then when (hasTempErrors rs) resubscribe $> rs
|
||||
else do
|
||||
logWarn "subcription batch result for replaced SMP client, resubscribing"
|
||||
resubscribe $> L.map (second $ \_ -> Left PCENetworkError) rs
|
||||
-- TODO we probably use PCENetworkError here instead of the original error, so it becomes temporary.
|
||||
resubscribe $> L.map (second $ Left . PCENetworkError . NESubscribeError . show) rs
|
||||
where
|
||||
tSess = transportSession' smp
|
||||
sessId = sessionId $ thParams smp
|
||||
@@ -1562,7 +1565,7 @@ sendTSessionBatches statCmd toRQ action c nm qs =
|
||||
in M.alter (Just . maybe [q] (q <|)) tSess m
|
||||
sendClientBatch :: (SMPTransportSession, NonEmpty q) -> AM' (BatchResponses q AgentErrorType r)
|
||||
sendClientBatch (tSess@(_, srv, _), qs') =
|
||||
tryAgentError' (getSMPServerClient c nm tSess) >>= \case
|
||||
tryAllErrors' (getSMPServerClient c nm tSess) >>= \case
|
||||
Left e -> pure $ L.map (,Left e) qs'
|
||||
Right (SMPConnectedClient smp _) -> liftIO $ do
|
||||
logServer' "-->" c srv (bshow (length qs') <> " queues") statCmd
|
||||
@@ -1867,7 +1870,7 @@ withNtfBatch ::
|
||||
AM' (NonEmpty (Either AgentErrorType r))
|
||||
withNtfBatch cmdStr action c NtfToken {ntfServer, ntfPrivKey} subs = do
|
||||
let tSess = (0, ntfServer, Nothing)
|
||||
tryAgentError' (getNtfServerClient c NRMBackground tSess) >>= \case
|
||||
tryAllErrors' (getNtfServerClient c NRMBackground tSess) >>= \case
|
||||
Left e -> pure $ L.map (\_ -> Left e) subs
|
||||
Right ntf -> liftIO $ do
|
||||
logServer' "-->" c ntfServer (bshow (length subs) <> " subscriptions") cmdStr
|
||||
@@ -1968,40 +1971,42 @@ waitForWork = void . atomically . readTMVar
|
||||
{-# INLINE waitForWork #-}
|
||||
|
||||
withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (Maybe a))) -> (a -> AM ()) -> AM ()
|
||||
withWork c doWork getWork action =
|
||||
withStore' c getWork >>= \case
|
||||
withWork c doWork = withWork_ c doWork . withStore' c
|
||||
{-# INLINE withWork #-}
|
||||
|
||||
withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWork_ c doWork getWork action =
|
||||
getWork >>= \case
|
||||
Right (Just r) -> action r
|
||||
Right Nothing -> noWork
|
||||
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
|
||||
Left e@SEWorkItemError {} -> noWork >> notifyErr (CRITICAL False) e
|
||||
Left e -> notifyErr INTERNAL e
|
||||
Left e
|
||||
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
|
||||
| otherwise -> notifyErr INTERNAL e
|
||||
where
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
|
||||
|
||||
withWorkItems :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError [Either StoreError a])) -> (NonEmpty a -> AM ()) -> AM ()
|
||||
withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWorkItems c doWork getWork action = do
|
||||
withStore' c getWork >>= \case
|
||||
getWork >>= \case
|
||||
Right [] -> noWork
|
||||
Right rs -> do
|
||||
let (errs, items) = partitionEithers rs
|
||||
case L.nonEmpty items of
|
||||
Just items' -> action items'
|
||||
Nothing -> do
|
||||
let criticalErr = find workItemError errs
|
||||
let criticalErr = find isWorkItemError errs
|
||||
forM_ criticalErr $ \err -> do
|
||||
notifyErr (CRITICAL False) err
|
||||
when (all workItemError errs) noWork
|
||||
when (all isWorkItemError errs) noWork
|
||||
unless (null errs) $
|
||||
atomically $
|
||||
writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS $ map (\e -> ("", INTERNAL $ show e)) errs)
|
||||
Left e
|
||||
| workItemError e -> noWork >> notifyErr (CRITICAL False) e
|
||||
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
|
||||
| otherwise -> notifyErr INTERNAL e
|
||||
where
|
||||
workItemError = \case
|
||||
SEWorkItemError {} -> True
|
||||
_ -> False
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
|
||||
|
||||
|
||||
@@ -27,11 +27,6 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
serverHosts,
|
||||
defaultAgentConfig,
|
||||
defaultReconnectInterval,
|
||||
tryAgentError,
|
||||
tryAgentError',
|
||||
catchAgentError,
|
||||
catchAgentError',
|
||||
agentFinally,
|
||||
Env (..),
|
||||
newSMPAgentEnv,
|
||||
createAgentStore,
|
||||
@@ -45,7 +40,6 @@ module Simplex.Messaging.Agent.Env.SQLite
|
||||
where
|
||||
|
||||
import Control.Concurrent (ThreadId)
|
||||
import Control.Exception (BlockedIndefinitelyOnSTM (..), SomeException, fromException)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
@@ -70,7 +64,7 @@ import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store (createStore)
|
||||
import Simplex.Messaging.Agent.Store.Common (DBStore)
|
||||
import Simplex.Messaging.Agent.Store.Interface (DBOpts)
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationError (..))
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange)
|
||||
@@ -83,7 +77,6 @@ import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryAllErrors, tryAllErrors')
|
||||
import System.Mem.Weak (Weak)
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO.STM
|
||||
@@ -273,7 +266,7 @@ newSMPAgentEnv config store = do
|
||||
multicastSubscribers <- newTMVarIO 0
|
||||
pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers}
|
||||
|
||||
createAgentStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createAgentStore :: DBOpts -> MigrationConfig -> IO (Either MigrationError DBStore)
|
||||
createAgentStore = createStore
|
||||
|
||||
data NtfSupervisor = NtfSupervisor
|
||||
@@ -312,33 +305,6 @@ newXFTPAgent = do
|
||||
xftpDelWorkers <- TM.emptyIO
|
||||
pure XFTPAgent {xftpWorkDir, xftpRcvWorkers, xftpSndWorkers, xftpDelWorkers}
|
||||
|
||||
tryAgentError :: AM a -> AM (Either AgentErrorType a)
|
||||
tryAgentError = tryAllErrors mkInternal
|
||||
{-# INLINE tryAgentError #-}
|
||||
|
||||
-- unlike runExceptT, this ensures we catch IO exceptions as well
|
||||
tryAgentError' :: AM a -> AM' (Either AgentErrorType a)
|
||||
tryAgentError' = tryAllErrors' mkInternal
|
||||
{-# INLINE tryAgentError' #-}
|
||||
|
||||
catchAgentError :: AM a -> (AgentErrorType -> AM a) -> AM a
|
||||
catchAgentError = catchAllErrors mkInternal
|
||||
{-# INLINE catchAgentError #-}
|
||||
|
||||
catchAgentError' :: AM a -> (AgentErrorType -> AM' a) -> AM' a
|
||||
catchAgentError' = catchAllErrors' mkInternal
|
||||
{-# INLINE catchAgentError' #-}
|
||||
|
||||
agentFinally :: AM a -> AM b -> AM a
|
||||
agentFinally = allFinally mkInternal
|
||||
{-# INLINE agentFinally #-}
|
||||
|
||||
mkInternal :: SomeException -> AgentErrorType
|
||||
mkInternal e = case fromException e of
|
||||
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
|
||||
_ -> INTERNAL $ show e
|
||||
{-# INLINE mkInternal #-}
|
||||
|
||||
data Worker = Worker
|
||||
{ workerId :: Int,
|
||||
doWork :: TMVar (),
|
||||
|
||||
@@ -52,7 +52,7 @@ import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Protocol (NtfServer, sameSrvAddr)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (diffToMicroseconds, threadDelay', tshow, whenM)
|
||||
import Simplex.Messaging.Util (catchAllErrors, diffToMicroseconds, threadDelay', tryAllErrors, tshow, whenM)
|
||||
import System.Random (randomR)
|
||||
import UnliftIO
|
||||
import UnliftIO.Concurrent (forkIO)
|
||||
@@ -217,7 +217,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
runNtfOperation :: AM ()
|
||||
runNtfOperation = do
|
||||
ntfBatchSize <- asks $ ntfBatchSize . config
|
||||
withWorkItems c doWork (\db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubNTFActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
logInfo $ "runNtfWorker - length nextSubs = " <> tshow (length nextSubs)
|
||||
currTs <- liftIO getCurrentTime
|
||||
let (creates, checks, deletes, rotates) = splitActions currTs nextSubs
|
||||
@@ -357,7 +357,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
runCatching :: (NtfSubscription -> AM (Maybe NtfSubscription)) -> NtfSubscription -> AM' (Maybe NtfSubscription)
|
||||
runCatching action sub@NtfSubscription {connId} =
|
||||
fromRight Nothing
|
||||
<$> runExceptT (action sub `catchAgentError` \e -> workerInternalError c connId (show e) $> Nothing)
|
||||
<$> runExceptT (action sub `catchAllErrors` \e -> workerInternalError c connId (show e) $> Nothing)
|
||||
-- deleteNtfSub is only used in NSADelete and NSARotate, so also deprecated
|
||||
deleteNtfSub :: NtfSubscription -> AM () -> AM (Maybe NtfSubscription)
|
||||
deleteNtfSub sub@NtfSubscription {userId, ntfSubId} continue = case ntfSubId of
|
||||
@@ -365,7 +365,7 @@ runNtfWorker c srv Worker {doWork} =
|
||||
lift getNtfToken >>= \case
|
||||
Just tkn@NtfToken {ntfServer} -> do
|
||||
atomically $ incNtfServerStat c userId ntfServer ntfDelAttempts
|
||||
tryAgentError (agentNtfDeleteSubscription c nSubId tkn) >>= \case
|
||||
tryAllErrors (agentNtfDeleteSubscription c nSubId tkn) >>= \case
|
||||
Right _ -> do
|
||||
atomically $ incNtfServerStat c userId ntfServer ntfDeleted
|
||||
continue'
|
||||
@@ -385,7 +385,7 @@ runNtfSMPWorker c srv Worker {doWork} = forever $ do
|
||||
runNtfSMPOperation :: AM ()
|
||||
runNtfSMPOperation = do
|
||||
ntfBatchSize <- asks $ ntfBatchSize . config
|
||||
withWorkItems c doWork (\db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
withWorkItems c doWork (withStore' c $ \db -> getNextNtfSubSMPActions db srv ntfBatchSize) $ \nextSubs -> do
|
||||
logInfo $ "runNtfSMPWorker - length nextSubs = " <> tshow (length nextSubs)
|
||||
let (creates, deletes) = splitActions nextSubs
|
||||
retrySubActions c creates createNotifierKeys
|
||||
@@ -567,7 +567,7 @@ runNtfTknDelWorker c srv Worker {doWork} =
|
||||
withRetryInterval ri $ \_ loop -> do
|
||||
liftIO $ waitWhileSuspended c
|
||||
liftIO $ waitForUserNetwork c
|
||||
processTknToDelete nextTknToDelete `catchAgentError` retryTmpError loop nextTknToDelete
|
||||
processTknToDelete nextTknToDelete `catchAllErrors` retryTmpError loop nextTknToDelete
|
||||
retryTmpError :: AM () -> NtfTokenToDelete -> AgentErrorType -> AM ()
|
||||
retryTmpError loop (tknDbId, _, _) e = do
|
||||
logError $ "ntf tkn del error: " <> tshow e
|
||||
|
||||
@@ -173,6 +173,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Exception (BlockedIndefinitelyOnSTM (..), fromException)
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), Value (..), (.:), (.:?))
|
||||
import qualified Data.Aeson as J'
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
@@ -1866,6 +1867,12 @@ data AgentErrorType
|
||||
INACTIVE
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError AgentErrorType where
|
||||
fromSomeException e = case fromException e of
|
||||
Just BlockedIndefinitelyOnSTM -> CRITICAL True "Thread blocked indefinitely in STM transaction"
|
||||
_ -> INTERNAL $ show e
|
||||
{-# INLINE fromSomeException #-}
|
||||
|
||||
-- | SMP agent protocol command or response error.
|
||||
data CommandErrorType
|
||||
= -- | command is prohibited in this context
|
||||
|
||||
@@ -29,10 +29,11 @@ import Data.Time (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval (RI2State)
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Agent.Store.Common
|
||||
import Simplex.Messaging.Agent.Store.Interface (createDBStore)
|
||||
import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations)
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationError (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (MsgEncryptKeyX448, PQEncryption, PQSupport, RatchetX448)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
@@ -52,9 +53,9 @@ import Simplex.Messaging.Protocol
|
||||
VersionSMPC,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Util (AnyError (..), bshow)
|
||||
|
||||
createStore :: DBOpts -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createStore :: DBOpts -> MigrationConfig -> IO (Either MigrationError DBStore)
|
||||
createStore dbOpts = createDBStore dbOpts appMigrations
|
||||
|
||||
-- * Queue types
|
||||
@@ -692,7 +693,20 @@ data StoreError
|
||||
| -- | XFTP Deleted snd chunk replica not found.
|
||||
SEDeletedSndChunkReplicaNotFound
|
||||
| -- | Error when reading work item that suspends worker - do not use!
|
||||
SEWorkItemError ByteString
|
||||
SEWorkItemError {errContext :: String}
|
||||
| -- | Servers stats not found.
|
||||
SEServersStatsNotFound
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError StoreError where
|
||||
fromSomeException = SEInternal . bshow
|
||||
|
||||
class (Show e, AnyError e) => AnyStoreError e where
|
||||
isWorkItemError :: e -> Bool
|
||||
mkWorkItemError :: String -> e
|
||||
|
||||
instance AnyStoreError StoreError where
|
||||
isWorkItemError = \case
|
||||
SEWorkItemError {} -> True
|
||||
_ -> False
|
||||
mkWorkItemError errContext = SEWorkItemError {errContext}
|
||||
|
||||
@@ -43,7 +43,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
getDeletedConn,
|
||||
getConns,
|
||||
getDeletedConns,
|
||||
getConnData,
|
||||
getConnsData,
|
||||
setConnDeleted,
|
||||
setConnUserId,
|
||||
setConnAgentVersion,
|
||||
@@ -237,6 +237,8 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
firstRow',
|
||||
maybeFirstRow,
|
||||
fromOnlyBI,
|
||||
getWorkItem,
|
||||
getWorkItems,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -255,8 +257,9 @@ import Data.List (foldl', sortBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
|
||||
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
|
||||
import Data.Ord (Down (..))
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
|
||||
import Data.Word (Word32)
|
||||
@@ -285,12 +288,14 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Agent.Store.Entity
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, firstRow, firstRow', ifM, maybeFirstRow, tshow, ($>>=), (<$$>))
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
#if defined(dbPostgres)
|
||||
import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..))
|
||||
import Data.List (sortOn)
|
||||
import Data.Map.Strict (Map)
|
||||
import Database.PostgreSQL.Simple (In (..), Only (..), Query, SqlError, (:.) (..))
|
||||
import Database.PostgreSQL.Simple.Errors (constraintViolation)
|
||||
import Database.PostgreSQL.Simple.SqlQQ (sql)
|
||||
#else
|
||||
@@ -424,15 +429,12 @@ deleteConnRecord :: DB.Connection -> ConnId -> IO ()
|
||||
deleteConnRecord db connId = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId)
|
||||
|
||||
checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
|
||||
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
|
||||
fromMaybe False
|
||||
<$> maybeFirstRow
|
||||
fromOnly
|
||||
( DB.query
|
||||
db
|
||||
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
|
||||
(host server, port server, sndId, New)
|
||||
)
|
||||
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} =
|
||||
maybeFirstRow' False fromOnlyBI $
|
||||
DB.query
|
||||
db
|
||||
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
|
||||
(host server, port server, sndId, New)
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
||||
@@ -966,28 +968,25 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} =
|
||||
_ -> Left $ SEInternal "unexpected snd msg data"
|
||||
markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
|
||||
|
||||
getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a))
|
||||
getWorkItem :: (Show i, AnyStoreError e) => String -> IO (Maybe i) -> (i -> IO (Either e a)) -> (i -> IO ()) -> IO (Either e (Maybe a))
|
||||
getWorkItem itemName getId getItem markFailed =
|
||||
runExceptT $ handleWrkErr itemName "getId" getId >>= mapM (tryGetItem itemName getItem markFailed)
|
||||
|
||||
getWorkItems :: Show i => ByteString -> IO [i] -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError [Either StoreError a])
|
||||
getWorkItems :: (Show i, AnyStoreError e) => String -> IO [i] -> (i -> IO (Either e a)) -> (i -> IO ()) -> IO (Either e [Either e a])
|
||||
getWorkItems itemName getIds getItem markFailed =
|
||||
runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed)
|
||||
|
||||
tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a
|
||||
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e
|
||||
tryGetItem :: (Show i, AnyStoreError e) => String -> (i -> IO (Either e a)) -> (i -> IO ()) -> i -> ExceptT e IO a
|
||||
tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchAllErrors` \e -> mark >> throwE e
|
||||
where
|
||||
mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId
|
||||
|
||||
catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a
|
||||
catchStoreError = catchAllErrors (SEInternal . bshow)
|
||||
mark = handleWrkErr itemName ("markFailed ID " <> show itemId) $ markFailed itemId
|
||||
|
||||
-- Errors caught by this function will suspend worker as if there is no more work,
|
||||
handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a
|
||||
handleWrkErr :: forall e a. AnyStoreError e => String -> String -> IO a -> ExceptT e IO a
|
||||
handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action
|
||||
where
|
||||
mkError :: E.SomeException -> StoreError
|
||||
mkError e = SEWorkItemError $ itemName <> " " <> opName <> " error: " <> bshow e
|
||||
mkError :: E.SomeException -> e
|
||||
mkError e = mkWorkItemError $ itemName <> " " <> opName <> " error: " <> show e
|
||||
|
||||
updatePendingMsgRIState :: DB.Connection -> ConnId -> InternalId -> RI2State -> IO ()
|
||||
updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} =
|
||||
@@ -1073,15 +1072,12 @@ toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity,
|
||||
in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck}
|
||||
|
||||
checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
|
||||
checkRcvMsgHashExists db connId hash = do
|
||||
fromMaybe False
|
||||
<$> maybeFirstRow
|
||||
fromOnly
|
||||
( DB.query
|
||||
db
|
||||
"SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
(connId, Binary hash)
|
||||
)
|
||||
checkRcvMsgHashExists db connId hash =
|
||||
maybeFirstRow' False fromOnlyBI $
|
||||
DB.query
|
||||
db
|
||||
"SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
(connId, Binary hash)
|
||||
|
||||
getRcvMsgBrokerTs :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Either StoreError BrokerTs)
|
||||
getRcvMsgBrokerTs db connId msgId =
|
||||
@@ -1305,21 +1301,26 @@ insertedRowId db = fromOnly . head <$> DB.query_ db q
|
||||
q = "SELECT last_insert_rowid()"
|
||||
#endif
|
||||
|
||||
getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer]
|
||||
getPendingCommandServers db connId = do
|
||||
getPendingCommandServers :: DB.Connection -> [ConnId] -> IO [(ConnId, NonEmpty (Maybe SMPServer))]
|
||||
getPendingCommandServers db connIds =
|
||||
-- TODO review whether this can break if, e.g., the server has another key hash.
|
||||
map smpServer
|
||||
<$> DB.query
|
||||
mapMaybe connServers . groupOn' rowConnId
|
||||
<$> DB.query_
|
||||
db
|
||||
[sql|
|
||||
SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash)
|
||||
SELECT DISTINCT c.conn_id, c.host, c.port, COALESCE(c.server_key_hash, s.key_hash)
|
||||
FROM commands c
|
||||
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
|
||||
WHERE conn_id = ?
|
||||
ORDER BY c.conn_id
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
rowConnId (Only connId :. _) = connId
|
||||
connServers rs =
|
||||
let connId = rowConnId $ L.head rs
|
||||
srvs = L.map (\(_ :. r) -> smpServer r) rs
|
||||
in if connId `S.member` conns then Just (connId, srvs) else Nothing
|
||||
smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash
|
||||
conns = S.fromList connIds
|
||||
|
||||
getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand))
|
||||
getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed
|
||||
@@ -2037,21 +2038,19 @@ getDeletedConn = getAnyConn True
|
||||
{-# INLINE getDeletedConn #-}
|
||||
|
||||
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getAnyConn deleted' dbConn connId =
|
||||
getConnData dbConn connId >>= \case
|
||||
getAnyConn deleted' db connId =
|
||||
getConnData deleted' db connId >>= \case
|
||||
Just (cData, cMode) -> do
|
||||
rQ <- getRcvQueuesByConnId_ db connId
|
||||
sQ <- getSndQueuesByConnId_ db connId
|
||||
pure $ case (rQ, sQ, cMode) of
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
Nothing -> pure $ Left SEConnNotFound
|
||||
Just (cData@ConnData {deleted}, cMode)
|
||||
| deleted /= deleted' -> pure $ Left SEConnNotFound
|
||||
| otherwise -> do
|
||||
rQ <- getRcvQueuesByConnId_ dbConn connId
|
||||
sQ <- getSndQueuesByConnId_ dbConn connId
|
||||
pure $ case (rQ, sQ, cMode) of
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
|
||||
getConns = getAnyConns_ False
|
||||
@@ -2061,28 +2060,84 @@ getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
|
||||
getDeletedConns = getAnyConns_ True
|
||||
{-# INLINE getDeletedConns #-}
|
||||
|
||||
#if defined(dbPostgres)
|
||||
getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
|
||||
getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db
|
||||
getAnyConns_ deleted' db connIds = do
|
||||
cs <- getConnsData_ deleted' db connIds
|
||||
let connIds' = M.keys cs
|
||||
rQs :: Map ConnId (NonEmpty RcvQueue) <- getRcvQueuesByConnIds_ connIds'
|
||||
sQs :: Map ConnId (NonEmpty SndQueue) <- getSndQueuesByConnIds_ connIds'
|
||||
pure $ map (result cs rQs sQs) connIds
|
||||
where
|
||||
handleDBError :: E.SomeException -> IO (Either StoreError SomeConn)
|
||||
handleDBError = pure . Left . SEInternal . bshow
|
||||
getRcvQueuesByConnIds_ connIds' =
|
||||
toQueueMap primaryFirst toRcvQueue
|
||||
<$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds'))
|
||||
where
|
||||
primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} =
|
||||
compare (Down p) (Down p') <> compare i i'
|
||||
getSndQueuesByConnIds_ connIds' =
|
||||
toQueueMap primaryFirst toSndQueue
|
||||
<$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds'))
|
||||
where
|
||||
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
|
||||
compare (Down p) (Down p') <> compare i i'
|
||||
toQueueMap primaryFst toQueue =
|
||||
M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue
|
||||
result cs rQs sQs connId = case M.lookup connId cs of
|
||||
Just (cData, cMode) -> case (M.lookup connId rQs, M.lookup connId sQs, cMode) of
|
||||
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
|
||||
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
|
||||
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
|
||||
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
|
||||
_ -> Left SEConnNotFound
|
||||
Nothing -> Left SEConnNotFound
|
||||
|
||||
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData db connId' =
|
||||
maybeFirstRow cData $
|
||||
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
|
||||
getConnsData db connIds = do
|
||||
cs <- getConnsData_ False db connIds
|
||||
pure $ map (Right . (`M.lookup` cs)) connIds
|
||||
|
||||
getConnsData_ :: Bool -> DB.Connection -> [ConnId] -> IO (Map ConnId (ConnData, ConnectionMode))
|
||||
getConnsData_ deleted' db connIds =
|
||||
M.fromList . map ((\c@(ConnData {connId}, _) -> (connId, c)) . rowToConnData) <$>
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT
|
||||
user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
WHERE conn_id IN ? AND deleted = ?
|
||||
|]
|
||||
(Only connId')
|
||||
where
|
||||
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
|
||||
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
|
||||
(In connIds, BI deleted')
|
||||
|
||||
#else
|
||||
getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
|
||||
getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db
|
||||
|
||||
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
|
||||
getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db
|
||||
|
||||
handleDBError :: E.SomeException -> IO (Either StoreError a)
|
||||
handleDBError = pure . Left . SEInternal . bshow
|
||||
#endif
|
||||
|
||||
getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData deleted' db connId' =
|
||||
maybeFirstRow rowToConnData $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
|
||||
FROM connections
|
||||
WHERE conn_id = ? AND deleted = ?
|
||||
|]
|
||||
(connId', BI deleted')
|
||||
|
||||
rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode)
|
||||
rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
|
||||
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
|
||||
|
||||
setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO ()
|
||||
setConnDeleted db waitDelivery connId
|
||||
@@ -2120,15 +2175,12 @@ addProcessedRatchetKeyHash db connId hash =
|
||||
DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary hash)
|
||||
|
||||
checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
|
||||
checkRatchetKeyHashExists db connId hash = do
|
||||
fromMaybe False
|
||||
<$> maybeFirstRow
|
||||
fromOnly
|
||||
( DB.query
|
||||
db
|
||||
"SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
(connId, Binary hash)
|
||||
)
|
||||
checkRatchetKeyHashExists db connId hash =
|
||||
maybeFirstRow' False fromOnlyBI $
|
||||
DB.query
|
||||
db
|
||||
"SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
(connId, Binary hash)
|
||||
|
||||
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
deleteRatchetKeyHashesExpired db ttl = do
|
||||
@@ -2906,8 +2958,8 @@ deleteSndFile' db sndFileId =
|
||||
|
||||
getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool
|
||||
getSndFileDeleted db sndFileId =
|
||||
fromMaybe True
|
||||
<$> maybeFirstRow fromOnlyBI (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId))
|
||||
maybeFirstRow' True fromOnlyBI $
|
||||
DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)
|
||||
|
||||
createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO ()
|
||||
createSndFileReplica db SndFileChunk {sndChunkId} = createSndFileReplica_ db sndChunkId
|
||||
|
||||
@@ -15,7 +15,7 @@ where
|
||||
import Control.Monad
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.Maybe (isNothing, mapMaybe)
|
||||
import Data.Maybe (isJust, isNothing, mapMaybe)
|
||||
import Simplex.Messaging.Agent.Store.Shared
|
||||
import System.Exit (exitFailure)
|
||||
import System.IO (hFlush, stdout)
|
||||
@@ -37,7 +37,7 @@ data DBMigrate = DBMigrate
|
||||
{ initialize :: IO (),
|
||||
getCurrent :: IO [Migration],
|
||||
run :: MigrationsToRun -> IO (),
|
||||
backup :: IO ()
|
||||
backup :: Maybe (IO ())
|
||||
}
|
||||
|
||||
sharedMigrateSchema :: DBMigrate -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError ())
|
||||
@@ -54,20 +54,20 @@ sharedMigrateSchema dbm dbNew' migrations confirmMigrations = do
|
||||
| otherwise -> case confirmMigrations of
|
||||
MCYesUp -> runWithBackup ms
|
||||
MCYesUpDown -> runWithBackup ms
|
||||
MCConsole -> confirm err >> runWithBackup ms
|
||||
MCConsole -> confirm' err >> runWithBackup ms
|
||||
MCError -> pure $ Left err
|
||||
where
|
||||
err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums)
|
||||
Right ms@(MTRDown dms) -> case confirmMigrations of
|
||||
MCYesUpDown -> runWithBackup ms
|
||||
MCConsole -> confirm err >> runWithBackup ms
|
||||
MCConsole -> confirm' err >> runWithBackup ms
|
||||
MCYesUp -> pure $ Left err
|
||||
MCError -> pure $ Left err
|
||||
where
|
||||
err = MEDowngrade $ map downName dms
|
||||
where
|
||||
runWithBackup ms = backup dbm >> run dbm ms $> Right ()
|
||||
confirm err = confirmOrExit $ migrationErrorDescription err
|
||||
runWithBackup ms = sequence (backup dbm) >> run dbm ms $> Right ()
|
||||
confirm' err = confirmOrExit $ migrationErrorDescription (isJust $ backup dbm) err
|
||||
|
||||
confirmOrExit :: String -> IO ()
|
||||
confirmOrExit s = do
|
||||
|
||||
@@ -30,15 +30,15 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc
|
||||
import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..))
|
||||
import Simplex.Messaging.Util (ifM, safeDecodeUtf8)
|
||||
import System.Exit (exitFailure)
|
||||
|
||||
-- | Create a new Postgres DBStore with the given connection string, schema name and migrations.
|
||||
-- If passed schema does not exist in connectInfo database, it will be created.
|
||||
-- Applies necessary migrations to schema.
|
||||
createDBStore :: DBOpts -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createDBStore opts migrations confirmMigrations = do
|
||||
createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore)
|
||||
createDBStore opts migrations MigrationConfig {confirm} = do
|
||||
st <- connectPostgresStore opts
|
||||
r <- migrateSchema st `onException` closeDBStore st
|
||||
case r of
|
||||
@@ -48,15 +48,16 @@ createDBStore opts migrations confirmMigrations = do
|
||||
migrateSchema st =
|
||||
let initialize = Migrations.initialize st
|
||||
getCurrent = withTransaction st Migrations.getCurrentMigrations
|
||||
dbm = DBMigrate {initialize, getCurrent, run = Migrations.run st, backup = pure ()}
|
||||
in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations
|
||||
dbm = DBMigrate {initialize, getCurrent, run = Migrations.run st, backup = Nothing}
|
||||
in sharedMigrateSchema dbm (dbNew st) migrations confirm
|
||||
|
||||
connectPostgresStore :: DBOpts -> IO DBStore
|
||||
connectPostgresStore DBOpts {connstr, schema, poolSize, createSchema} = do
|
||||
dbPriorityPool <- newDBStorePool poolSize
|
||||
dbPool <- newDBStorePool poolSize
|
||||
dbClosed <- newTVarIO True
|
||||
let st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPriorityPool, dbPool, dbNew = False, dbClosed}
|
||||
let dbConnect = fst <$> connectDB connstr schema False
|
||||
st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPriorityPool, dbPool, dbConnect, dbNew = False, dbClosed}
|
||||
dbNew <- connectStore st createSchema
|
||||
pure st {dbNew}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
@@ -19,7 +20,7 @@ where
|
||||
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (bracket)
|
||||
import qualified Control.Exception as E
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Database.PostgreSQL.Simple as PSQL
|
||||
import Numeric.Natural (Natural)
|
||||
@@ -32,11 +33,7 @@ data DBStore = DBStore
|
||||
dbPoolSize :: Int,
|
||||
dbPriorityPool :: DBStorePool,
|
||||
dbPool :: DBStorePool,
|
||||
-- dbPoolSize :: Int,
|
||||
-- dbPool :: TBQueue PSQL.Connection,
|
||||
-- -- MVar is needed for fair pool distribution, without STM retry contention.
|
||||
-- -- Only one thread can be blocked on STM read.
|
||||
-- dbSem :: MVar (),
|
||||
dbConnect :: IO PSQL.Connection,
|
||||
dbClosed :: TVar Bool,
|
||||
dbNew :: Bool
|
||||
}
|
||||
@@ -55,15 +52,23 @@ data DBStorePool = DBStorePool
|
||||
}
|
||||
|
||||
withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a
|
||||
withConnectionPriority DBStore {dbPriorityPool, dbPool} priority =
|
||||
withConnectionPool $ if priority then dbPriorityPool else dbPool
|
||||
withConnectionPriority DBStore {dbPriorityPool, dbPool, dbConnect} priority =
|
||||
withConnectionPool (if priority then dbPriorityPool else dbPool) dbConnect
|
||||
{-# INLINE withConnectionPriority #-}
|
||||
|
||||
withConnectionPool :: DBStorePool -> (PSQL.Connection -> IO a) -> IO a
|
||||
withConnectionPool DBStorePool {dbPoolConns, dbSem} =
|
||||
bracket
|
||||
(withMVar dbSem $ \_ -> atomically $ readTBQueue dbPoolConns)
|
||||
(atomically . writeTBQueue dbPoolConns)
|
||||
withConnectionPool :: DBStorePool -> IO PSQL.Connection -> (PSQL.Connection -> IO a) -> IO a
|
||||
withConnectionPool DBStorePool {dbPoolConns, dbSem} dbConnect action =
|
||||
E.mask $ \restore -> do
|
||||
conn <- withMVar dbSem $ \_ -> atomically $ readTBQueue dbPoolConns
|
||||
r <- restore (action conn) `E.onException` reset conn
|
||||
atomically $ writeTBQueue dbPoolConns conn
|
||||
pure r
|
||||
where
|
||||
reset conn = do
|
||||
conn' <- E.try dbConnect >>= \case
|
||||
Right conn' -> PSQL.close conn >> pure conn'
|
||||
Left (_ :: E.SomeException) -> pure conn
|
||||
atomically $ writeTBQueue dbPoolConns conn'
|
||||
|
||||
withConnection :: DBStore -> (PSQL.Connection -> IO a) -> IO a
|
||||
withConnection st = withConnectionPriority st False
|
||||
|
||||
@@ -57,18 +57,18 @@ import Simplex.Messaging.Agent.Store.Migrations (DBMigrate (..), sharedMigrateSc
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationError (..))
|
||||
import Simplex.Messaging.Util (ifM, safeDecodeUtf8)
|
||||
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (takeDirectory)
|
||||
import System.FilePath (takeDirectory, takeFileName, (</>))
|
||||
import UnliftIO.Exception (bracketOnError, onException)
|
||||
import UnliftIO.MVar
|
||||
import UnliftIO.STM
|
||||
|
||||
-- * SQLite Store implementation
|
||||
|
||||
createDBStore :: DBOpts -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore)
|
||||
createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations confirmMigrations = do
|
||||
createDBStore :: DBOpts -> [Migration] -> MigrationConfig -> IO (Either MigrationError DBStore)
|
||||
createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations MigrationConfig {confirm, backupPath} = do
|
||||
let dbDir = takeDirectory dbFilePath
|
||||
createDirectoryIfMissing True dbDir
|
||||
st <- connectSQLiteStore dbFilePath dbKey keepKey track
|
||||
@@ -81,9 +81,12 @@ createDBStore DBOpts {dbFilePath, dbKey, keepKey, track, vacuum} migrations conf
|
||||
let initialize = Migrations.initialize st
|
||||
getCurrent = withTransaction st Migrations.getCurrentMigrations
|
||||
run = Migrations.run st vacuum
|
||||
backup = copyFile dbFilePath (dbFilePath <> ".bak")
|
||||
backup = mkBackup <$> backupPath
|
||||
mkBackup bp =
|
||||
let f = if null bp then dbFilePath else bp </> takeFileName dbFilePath
|
||||
in copyFile dbFilePath $ f <> ".bak"
|
||||
dbm = DBMigrate {initialize, getCurrent, run, backup}
|
||||
in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations
|
||||
in sharedMigrateSchema dbm (dbNew st) migrations confirm
|
||||
|
||||
connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> DB.TrackQueries -> IO DBStore
|
||||
connectSQLiteStore dbFilePath key keepKey track = do
|
||||
|
||||
@@ -52,7 +52,7 @@ import Simplex.Messaging.Util (diffToMicroseconds, tshow)
|
||||
newtype BoolInt = BI {unBI :: Bool}
|
||||
deriving newtype (FromField, ToField)
|
||||
|
||||
newtype Binary = Binary {fromBinary :: ByteString}
|
||||
newtype Binary a = Binary {fromBinary :: a}
|
||||
deriving newtype (FromField, ToField)
|
||||
|
||||
data Connection = Connection
|
||||
|
||||
@@ -9,6 +9,7 @@ module Simplex.Messaging.Agent.Store.Shared
|
||||
DownMigration (..),
|
||||
MTRError (..),
|
||||
mtrErrorDescription,
|
||||
MigrationConfig (..),
|
||||
MigrationConfirmation (..),
|
||||
MigrationError (..),
|
||||
UpMigration (..),
|
||||
@@ -55,13 +56,15 @@ data MigrationError
|
||||
| MigrationError {mtrError :: MTRError}
|
||||
deriving (Eq, Show)
|
||||
|
||||
migrationErrorDescription :: MigrationError -> String
|
||||
migrationErrorDescription = \case
|
||||
migrationErrorDescription :: Bool -> MigrationError -> String
|
||||
migrationErrorDescription withBackup = \case
|
||||
MEUpgrade ums ->
|
||||
"The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map upName ums)
|
||||
"The app has a newer version than the database.\nConfirm to " <> backupStr <> "upgrade using these migrations: " <> intercalate ", " (map upName ums)
|
||||
MEDowngrade dms ->
|
||||
"Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms
|
||||
"Database version is newer than the app.\nConfirm to " <> backupStr <> "downgrade using these migrations: " <> intercalate ", " dms
|
||||
MigrationError err -> mtrErrorDescription err
|
||||
where
|
||||
backupStr = if withBackup then "back up and " else ""
|
||||
|
||||
data UpMigration = UpMigration {upName :: String, withDown :: Bool}
|
||||
deriving (Eq, Show)
|
||||
@@ -69,6 +72,11 @@ data UpMigration = UpMigration {upName :: String, withDown :: Bool}
|
||||
upMigration :: Migration -> UpMigration
|
||||
upMigration Migration {name, down} = UpMigration name $ isJust down
|
||||
|
||||
data MigrationConfig = MigrationConfig
|
||||
{ confirm :: MigrationConfirmation,
|
||||
backupPath :: Maybe FilePath -- Nothing - no backup, empty string - the same folder
|
||||
}
|
||||
|
||||
data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError
|
||||
deriving (Eq, Show)
|
||||
|
||||
|
||||
@@ -597,12 +597,14 @@ getProtocolClient g nm transportSession@(_, srv, _) cfg@ProtocolClientConfig {qS
|
||||
socksCreds = clientSocksCredentials networkConfig proxySessTs transportSession
|
||||
tId <-
|
||||
runTransportClient tcConfig socksCreds useHost port' (Just $ keyHash srv) (client t c cVar)
|
||||
`forkFinally` \_ -> void (atomically . tryPutTMVar cVar $ Left PCENetworkError)
|
||||
`forkFinally` \r ->
|
||||
let err = either toNetworkError (const NEFailedError) r
|
||||
in void $ atomically $ tryPutTMVar cVar $ Left $ PCENetworkError err
|
||||
c_ <- netTimeoutInt tcpConnectTimeout nm `timeout` atomically (takeTMVar cVar)
|
||||
case c_ of
|
||||
Just (Right c') -> mkWeakThreadId tId >>= \tId' -> pure $ Right c' {action = Just tId'}
|
||||
Just (Left e) -> pure $ Left e
|
||||
Nothing -> killThread tId $> Left PCENetworkError
|
||||
Nothing -> killThread tId $> Left (PCENetworkError NETimeoutError)
|
||||
|
||||
useTransport :: (ServiceName, ATransport 'TClient)
|
||||
useTransport = case port srv of
|
||||
@@ -743,7 +745,7 @@ data ProtocolClientError err
|
||||
PCEResponseTimeout
|
||||
| -- | Failure to establish TCP connection.
|
||||
-- Forwarded to the agent client as `ERR BROKER NETWORK`.
|
||||
PCENetworkError
|
||||
PCENetworkError NetworkError
|
||||
| -- | No host compatible with network configuration
|
||||
PCEIncompatibleHost
|
||||
| -- | Service is unavailable for command that requires service connection
|
||||
@@ -761,7 +763,7 @@ type SMPClientError = ProtocolClientError ErrorType
|
||||
|
||||
temporaryClientError :: ProtocolClientError err -> Bool
|
||||
temporaryClientError = \case
|
||||
PCENetworkError -> True
|
||||
PCENetworkError _ -> True
|
||||
PCEResponseTimeout -> True
|
||||
PCEIOError _ -> True
|
||||
_ -> False
|
||||
@@ -782,7 +784,7 @@ smpProxyError = \case
|
||||
PCEResponseError e -> PROXY $ BROKER $ RESPONSE $ B.unpack $ strEncode e
|
||||
PCEUnexpectedResponse e -> PROXY $ BROKER $ UNEXPECTED $ B.unpack e
|
||||
PCEResponseTimeout -> PROXY $ BROKER TIMEOUT
|
||||
PCENetworkError -> PROXY $ BROKER NETWORK
|
||||
PCENetworkError e -> PROXY $ BROKER $ NETWORK e
|
||||
PCEIncompatibleHost -> PROXY $ BROKER HOST
|
||||
PCEServiceUnavailable -> PROXY $ BROKER $ NO_SERVICE -- for completeness, it cannot happen.
|
||||
PCETransportError t -> PROXY $ BROKER $ TRANSPORT t
|
||||
|
||||
@@ -391,7 +391,7 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE
|
||||
where
|
||||
logSMPError :: SMPClientError -> ExceptT SMPClientError IO a
|
||||
logSMPError e = do
|
||||
logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode $ host srv) <> "): " <> tshow e
|
||||
logInfo $ "SMP error (" <> safeDecodeUtf8 (strEncode srv) <> "): " <> tshow e
|
||||
throwE e
|
||||
|
||||
subscribeQueuesNtfs :: SMPClientAgent 'NotifierService -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
|
||||
|
||||
@@ -613,7 +613,7 @@ ntfSubscriber NtfSubscriber {smpAgent = ca@SMPClientAgent {msgQ, agentQ}} =
|
||||
PCEIncompatibleHost -> Just $ NSErr "IncompatibleHost"
|
||||
PCEServiceUnavailable -> Just NSService -- this error should not happen on individual subscriptions
|
||||
PCEResponseTimeout -> Nothing
|
||||
PCENetworkError -> Nothing
|
||||
PCENetworkError _ -> Nothing
|
||||
PCEIOError _ -> Nothing
|
||||
where
|
||||
-- Note on moving to PostgreSQL: the idea of logging errors without e is removed here
|
||||
|
||||
@@ -54,7 +54,8 @@ import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent.Store.AgentStore ()
|
||||
import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_)
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..))
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -63,7 +64,6 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat
|
||||
import Simplex.Messaging.Notifications.Server.Store.Migrations
|
||||
import Simplex.Messaging.Notifications.Server.Store.Types
|
||||
import Simplex.Messaging.Notifications.Server.StoreLog
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer)
|
||||
import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate)
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_)
|
||||
@@ -76,6 +76,8 @@ import System.IO (IOMode (..), hFlush, stdout, withFile)
|
||||
import Text.Hex (decodeHex)
|
||||
|
||||
#if !defined(dbPostgres)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util (eitherToMaybe)
|
||||
#endif
|
||||
|
||||
@@ -98,7 +100,7 @@ data NtfEntityRec (e :: NtfEntity) where
|
||||
|
||||
newNtfDbStore :: PostgresStoreCfg -> IO NtfPostgresStore
|
||||
newNtfDbStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do
|
||||
dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations confirmMigrations
|
||||
dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations (MigrationConfig confirmMigrations Nothing)
|
||||
dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath
|
||||
pure NtfPostgresStore {dbStore, dbStoreLog, deletedTTL}
|
||||
where
|
||||
|
||||
@@ -81,6 +81,7 @@ module Simplex.Messaging.Protocol
|
||||
CommandError (..),
|
||||
ProxyError (..),
|
||||
BrokerErrorType (..),
|
||||
NetworkError (..),
|
||||
BlockingInfo (..),
|
||||
BlockingReason (..),
|
||||
RawTransmission,
|
||||
@@ -168,6 +169,7 @@ module Simplex.Messaging.Protocol
|
||||
noMsgFlags,
|
||||
messageId,
|
||||
messageTs,
|
||||
toNetworkError,
|
||||
|
||||
-- * Parse and serialize
|
||||
ProtocolMsgTag (..),
|
||||
@@ -212,7 +214,7 @@ module Simplex.Messaging.Protocol
|
||||
where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Exception (Exception)
|
||||
import Control.Exception (Exception, SomeException, displayException, fromException)
|
||||
import Control.Monad.Except
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson.TH as J
|
||||
@@ -241,6 +243,7 @@ import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import qualified GHC.TypeLits as TE
|
||||
import qualified GHC.TypeLits as Type
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Network.TLS as TLS
|
||||
import Simplex.Messaging.Agent.Store.DB (Binary (..), FromField (..), ToField (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -1555,7 +1558,7 @@ data BrokerErrorType
|
||||
| -- | unexpected response
|
||||
UNEXPECTED {respErr :: String}
|
||||
| -- | network error
|
||||
NETWORK
|
||||
NETWORK {networkError :: NetworkError}
|
||||
| -- | no compatible server host (e.g. onion when public is required, or vice versa)
|
||||
HOST
|
||||
| -- | service unavailable client-side - used in agent errors
|
||||
@@ -1566,6 +1569,24 @@ data BrokerErrorType
|
||||
TIMEOUT
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
data NetworkError
|
||||
= NEConnectError {connectError :: String}
|
||||
| NETLSError {tlsError :: String}
|
||||
| NEUnknownCAError
|
||||
| NEFailedError
|
||||
| NETimeoutError
|
||||
| NESubscribeError {subscribeError :: String}
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
toNetworkError :: SomeException -> NetworkError
|
||||
toNetworkError e = maybe (NEConnectError err) fromTLSError (fromException e)
|
||||
where
|
||||
err = displayException e
|
||||
fromTLSError :: TLS.TLSException -> NetworkError
|
||||
fromTLSError = \case
|
||||
TLS.HandshakeFailed (TLS.Error_Protocol _ TLS.UnknownCa) -> NEUnknownCAError
|
||||
_ -> NETLSError err
|
||||
|
||||
data BlockingInfo = BlockingInfo
|
||||
{ reason :: BlockingReason
|
||||
}
|
||||
@@ -2001,7 +2022,7 @@ instance Encoding BrokerErrorType where
|
||||
RESPONSE e -> "RESPONSE " <> smpEncode e
|
||||
UNEXPECTED e -> "UNEXPECTED " <> smpEncode e
|
||||
TRANSPORT e -> "TRANSPORT " <> smpEncode e
|
||||
NETWORK -> "NETWORK"
|
||||
NETWORK _e -> "NETWORK" -- TODO once all upgrade: "NETWORK " <> smpEncode e
|
||||
TIMEOUT -> "TIMEOUT"
|
||||
HOST -> "HOST"
|
||||
NO_SERVICE -> "NO_SERVICE"
|
||||
@@ -2010,7 +2031,7 @@ instance Encoding BrokerErrorType where
|
||||
"RESPONSE" -> RESPONSE <$> _smpP
|
||||
"UNEXPECTED" -> UNEXPECTED <$> _smpP
|
||||
"TRANSPORT" -> TRANSPORT <$> _smpP
|
||||
"NETWORK" -> pure NETWORK
|
||||
"NETWORK" -> NETWORK <$> (_smpP <|> pure NEFailedError)
|
||||
"TIMEOUT" -> pure TIMEOUT
|
||||
"HOST" -> pure HOST
|
||||
"NO_SERVICE" -> pure NO_SERVICE
|
||||
@@ -2021,7 +2042,7 @@ instance StrEncoding BrokerErrorType where
|
||||
RESPONSE e -> "RESPONSE " <> encodeUtf8 (T.pack e)
|
||||
UNEXPECTED e -> "UNEXPECTED " <> encodeUtf8 (T.pack e)
|
||||
TRANSPORT e -> "TRANSPORT " <> smpEncode e
|
||||
NETWORK -> "NETWORK"
|
||||
NETWORK _e -> "NETWORK" -- TODO once all upgrade: "NETWORK " <> strEncode e
|
||||
TIMEOUT -> "TIMEOUT"
|
||||
HOST -> "HOST"
|
||||
NO_SERVICE -> "NO_SERVICE"
|
||||
@@ -2030,13 +2051,50 @@ instance StrEncoding BrokerErrorType where
|
||||
"RESPONSE" -> RESPONSE <$> _textP
|
||||
"UNEXPECTED" -> UNEXPECTED <$> _textP
|
||||
"TRANSPORT" -> TRANSPORT <$> _smpP
|
||||
"NETWORK" -> pure NETWORK
|
||||
"NETWORK" -> NETWORK <$> (_strP <|> pure NEFailedError)
|
||||
"TIMEOUT" -> pure TIMEOUT
|
||||
"HOST" -> pure HOST
|
||||
"NO_SERVICE" -> pure NO_SERVICE
|
||||
_ -> fail "bad BrokerErrorType"
|
||||
where
|
||||
_textP = A.space *> (T.unpack . safeDecodeUtf8 <$> A.takeByteString)
|
||||
|
||||
instance Encoding NetworkError where
|
||||
smpEncode = \case
|
||||
NEConnectError e -> "CONNECT " <> smpEncode e
|
||||
NETLSError e -> "TLS " <> smpEncode e
|
||||
NEUnknownCAError -> "UNKNOWNCA"
|
||||
NEFailedError -> "FAILED"
|
||||
NETimeoutError -> "TIMEOUT"
|
||||
NESubscribeError e -> "SUBSCRIBE " <> smpEncode e
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"CONNECT" -> NEConnectError <$> _smpP
|
||||
"TLS" -> NETLSError <$> _smpP
|
||||
"UNKNOWNCA" -> pure NEUnknownCAError
|
||||
"FAILED" -> pure NEFailedError
|
||||
"TIMEOUT" -> pure NETimeoutError
|
||||
"SUBSCRIBE" -> NESubscribeError <$> _smpP
|
||||
_ -> fail "bad NetworkError"
|
||||
|
||||
instance StrEncoding NetworkError where
|
||||
strEncode = \case
|
||||
NEConnectError e -> "CONNECT " <> encodeUtf8 (T.pack e)
|
||||
NETLSError e -> "TLS " <> encodeUtf8 (T.pack e)
|
||||
NEUnknownCAError -> "UNKNOWNCA"
|
||||
NEFailedError -> "FAILED"
|
||||
NETimeoutError -> "TIMEOUT"
|
||||
NESubscribeError e -> "SUBSCRIBE " <> encodeUtf8 (T.pack e)
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"CONNECT" -> NEConnectError <$> _textP
|
||||
"TLS" -> NETLSError <$> _textP
|
||||
"UNKNOWNCA" -> pure NEUnknownCAError
|
||||
"FAILED" -> pure NEFailedError
|
||||
"TIMEOUT" -> pure NETimeoutError
|
||||
"SUBSCRIBE" -> NESubscribeError <$> _textP
|
||||
_ -> fail "bad NetworkError"
|
||||
|
||||
_textP :: Parser String
|
||||
_textP = A.space *> (T.unpack . safeDecodeUtf8 <$> A.takeByteString)
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle v c p -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
|
||||
@@ -2200,6 +2258,12 @@ $(J.deriveJSON defaultJSON ''MsgFlags)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''CommandError)
|
||||
|
||||
$(J.deriveToJSON (sumTypeJSON $ dropPrefix "NE") ''NetworkError)
|
||||
|
||||
instance FromJSON NetworkError where
|
||||
parseJSON = $(J.mkParseJSON (sumTypeJSON $ dropPrefix "NE") ''NetworkError)
|
||||
omittedField = Just NEFailedError
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''BrokerErrorType)
|
||||
|
||||
$(J.deriveJSON defaultJSON ''BlockingInfo)
|
||||
|
||||
@@ -105,7 +105,7 @@ import Simplex.Messaging.Server.Control
|
||||
import Simplex.Messaging.Server.Env.STM as Env
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue)
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (JournalMsgStore, JournalQueue (..), getJournalQueueMessages)
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.MsgStore.Types
|
||||
import Simplex.Messaging.Server.NtfStore
|
||||
@@ -132,12 +132,17 @@ import UnliftIO.Directory (doesFileExist, renameFile)
|
||||
import UnliftIO.Exception
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
|
||||
#if MIN_VERSION_base(4,18,0)
|
||||
import Data.List (sort)
|
||||
import GHC.Conc (listThreads, threadStatus)
|
||||
import GHC.Conc.Sync (threadLabel)
|
||||
#endif
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres (exportDbMessages, getDbMessageStats)
|
||||
#endif
|
||||
|
||||
-- | Runs an SMP server using passed configuration.
|
||||
--
|
||||
-- See a full server here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-server/Main.hs
|
||||
@@ -477,7 +482,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt
|
||||
atomicWriteIORef (msgCount stats) stored
|
||||
atomicModifyIORef'_ (msgExpired stats) (+ expired)
|
||||
printMessageStats "STORE: messages" msgStats
|
||||
Left e -> logError $ "STORE: withAllMsgQueues, error expiring messages, " <> tshow e
|
||||
Left e -> logError $ "STORE: expireOldMessages, error expiring messages, " <> tshow e
|
||||
|
||||
expireNtfsThread :: ServerConfig s -> M s ()
|
||||
expireNtfsThread ServerConfig {notificationExpiration = expCfg} = do
|
||||
@@ -1848,10 +1853,10 @@ client
|
||||
Right body -> do
|
||||
when (isJust (queueData qr) && isSecuredMsgQueue qr) $ void $ liftIO $
|
||||
deleteQueueLinkData (queueStore ms) q
|
||||
ServerConfig {messageExpiration, msgIdBytes} <- asks config
|
||||
ServerConfig {messageExpiration, expireMessagesOnSend, msgIdBytes} <- asks config
|
||||
msgId <- randomId' msgIdBytes
|
||||
msg_ <- liftIO $ runExceptT $ do
|
||||
expireMessages messageExpiration stats
|
||||
when expireMessagesOnSend $ mapM_ (expireMessages stats) messageExpiration
|
||||
msg <- liftIO $ mkMessage msgId body
|
||||
writeMsg ms q True msg
|
||||
case msg_ of
|
||||
@@ -1875,9 +1880,9 @@ client
|
||||
msgTs <- getSystemTime
|
||||
pure $! Message msgId msgTs msgFlags body
|
||||
|
||||
expireMessages :: Maybe ExpirationConfig -> ServerStats -> ExceptT ErrorType IO ()
|
||||
expireMessages msgExp stats = do
|
||||
deleted <- maybe (pure 0) (deleteExpiredMsgs ms q <=< liftIO . expireBeforeEpoch) msgExp
|
||||
expireMessages :: ServerStats -> ExpirationConfig -> ExceptT ErrorType IO ()
|
||||
expireMessages stats msgExp = do
|
||||
deleted <- deleteExpiredMsgs ms q =<< liftIO (expireBeforeEpoch msgExp)
|
||||
liftIO $ when (deleted > 0) $ atomicModifyIORef'_ (msgExpired stats) (+ deleted)
|
||||
|
||||
-- The condition for delivery of the message is:
|
||||
@@ -2104,27 +2109,42 @@ randomId = fmap EntityId . randomId'
|
||||
{-# INLINE randomId #-}
|
||||
|
||||
saveServerMessages :: Bool -> MsgStore s -> IO ()
|
||||
saveServerMessages drainMsgs = \case
|
||||
StoreMemory ms@STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of
|
||||
saveServerMessages drainMsgs ms = case ms of
|
||||
StoreMemory STMMsgStore {storeConfig = STMStoreConfig {storePath}} -> case storePath of
|
||||
Just f -> exportMessages False ms f drainMsgs
|
||||
Nothing -> logNote "undelivered messages are not saved"
|
||||
StoreJournal _ -> logNote "closed journal message storage"
|
||||
#if defined(dbServerPostgres)
|
||||
StoreDatabase _ -> logNote "closed postgres message storage"
|
||||
#endif
|
||||
|
||||
exportMessages :: MsgStoreClass s => Bool -> s -> FilePath -> Bool -> IO ()
|
||||
exportMessages tty ms f drainMsgs = do
|
||||
exportMessages :: forall s. MsgStoreClass s => Bool -> MsgStore s -> FilePath -> Bool -> IO ()
|
||||
exportMessages tty st f drainMsgs = do
|
||||
logNote $ "saving messages to file " <> T.pack f
|
||||
liftIO $ withFile f WriteMode $ \h ->
|
||||
tryAny (unsafeWithAllMsgQueues tty True ms $ saveQueueMsgs h) >>= \case
|
||||
Right (Sum total) -> logNote $ "messages saved: " <> tshow total
|
||||
run $ case st of
|
||||
StoreMemory ms -> exportMessages_ ms $ getMsgs ms
|
||||
StoreJournal ms -> exportMessages_ ms $ getJournalMsgs ms
|
||||
#if defined(dbServerPostgres)
|
||||
StoreDatabase ms -> exportDbMessages tty ms
|
||||
#endif
|
||||
where
|
||||
exportMessages_ ms get = fmap (\(Sum n) -> n) . unsafeWithAllMsgQueues tty ms . saveQueueMsgs get
|
||||
run :: (Handle -> IO Int) -> IO ()
|
||||
run a = liftIO $ withFile f WriteMode $ tryAny . a >=> \case
|
||||
Right n -> logNote $ "messages saved: " <> tshow n
|
||||
Left e -> do
|
||||
logError $ "error exporting messages: " <> tshow e
|
||||
exitFailure
|
||||
where
|
||||
saveQueueMsgs h q = do
|
||||
msgs <-
|
||||
unsafeRunStore q "saveQueueMsgs" $
|
||||
getQueueMessages_ drainMsgs q =<< getMsgQueue ms q False
|
||||
BLD.hPutBuilder h $ encodeMessages (recipientId q) msgs
|
||||
getJournalMsgs ms q =
|
||||
readTVarIO (msgQueue' q) >>= \case
|
||||
Just _ -> getMsgs ms q
|
||||
Nothing -> getJournalQueueMessages ms q
|
||||
getMsgs :: MsgStoreClass s' => s' -> StoreQueue s' -> IO [Message]
|
||||
getMsgs ms q = unsafeRunStore q "saveQueueMsgs" $ getQueueMessages_ drainMsgs q =<< getMsgQueue ms q False
|
||||
saveQueueMsgs :: (StoreQueue s -> IO [Message]) -> Handle -> StoreQueue s -> IO (Sum Int)
|
||||
saveQueueMsgs get h q = do
|
||||
msgs <- get q
|
||||
unless (null msgs) $ BLD.hPutBuilder h $ encodeMessages (recipientId q) msgs
|
||||
pure $ Sum $ length msgs
|
||||
encodeMessages rId = mconcat . map (\msg -> BLD.byteString (strEncode $ MLRv3 rId msg) <> BLD.char8 '\n')
|
||||
|
||||
@@ -2140,6 +2160,9 @@ processServerMessages StartOptions {skipWarnings} = do
|
||||
Just f -> ifM (doesFileExist f) (Just <$> importMessages False ms f old_ skipWarnings) (pure Nothing)
|
||||
Nothing -> pure Nothing
|
||||
StoreJournal ms -> processJournalMessages old_ expire ms
|
||||
#if defined(dbServerPostgres)
|
||||
StoreDatabase ms -> processDbMessages old_ expire ms
|
||||
#endif
|
||||
processJournalMessages :: forall s. Maybe Int64 -> Bool -> JournalMsgStore s -> IO (Maybe MessageStats)
|
||||
processJournalMessages old_ expire ms
|
||||
| expire = Just <$> case old_ of
|
||||
@@ -2151,7 +2174,7 @@ processServerMessages StartOptions {skipWarnings} = do
|
||||
run processValidateQueue
|
||||
| otherwise = logWarn "skipping message expiration" $> Nothing
|
||||
where
|
||||
run a = unsafeWithAllMsgQueues False False ms a `catchAny` \_ -> exitFailure
|
||||
run a = unsafeWithAllMsgQueues False ms a `catchAny` \_ -> exitFailure
|
||||
processExpireQueue :: Int64 -> JournalQueue s -> IO MessageStats
|
||||
processExpireQueue old q = unsafeRunStore q "processExpireQueue" $ do
|
||||
mq <- getMsgQueue ms q False
|
||||
@@ -2162,6 +2185,17 @@ processServerMessages StartOptions {skipWarnings} = do
|
||||
processValidateQueue q = unsafeRunStore q "processValidateQueue" $ do
|
||||
storedMsgsCount <- getQueueSize_ =<< getMsgQueue ms q False
|
||||
pure newMessageStats {storedMsgsCount, storedQueues = 1}
|
||||
#if defined(dbServerPostgres)
|
||||
processDbMessages old_ expire ms
|
||||
| expire = Just <$> case old_ of
|
||||
Just old -> do
|
||||
-- TODO [messages] expire messages from all queues, not only recent
|
||||
logNote "expiring database store messages..."
|
||||
now <- systemSeconds <$> getSystemTime
|
||||
expireOldMessages False ms now (now - old)
|
||||
Nothing -> getDbMessageStats ms
|
||||
| otherwise = logWarn "skipping message expiration" $> Nothing
|
||||
#endif
|
||||
|
||||
importMessages :: forall s. MsgStoreClass s => Bool -> s -> FilePath -> Maybe Int64 -> Bool -> IO MessageStats
|
||||
importMessages tty ms f old_ skipWarnings = do
|
||||
|
||||
@@ -33,7 +33,7 @@ import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..))
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), ProtocolServer (..), ProtocolTypeI)
|
||||
import Simplex.Messaging.Server.Env.STM (ServerStoreCfg (..), StartOptions (..), StorePaths (..))
|
||||
import Simplex.Messaging.Server.Env.STM (ServerStoreCfg (..), StartOptions (..), dbStoreCfg, storeLogFile')
|
||||
import Simplex.Messaging.Server.Main.GitCommit
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..))
|
||||
import Simplex.Messaging.Transport (ASrvTransport, ATransport (..), TLS, Transport (..), simplexMQVersion)
|
||||
@@ -414,12 +414,13 @@ printServerTransports protocol ts = do
|
||||
\Set `port` in smp-server.ini section [TRANSPORT] to `5223,443`\n"
|
||||
|
||||
printSMPServerConfig :: [(ServiceName, ASrvTransport, AddHTTP)] -> ServerStoreCfg s -> IO ()
|
||||
printSMPServerConfig transports = \case
|
||||
SSCMemory sp_ -> printServerConfig "SMP" transports $ (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_
|
||||
SSCMemoryJournal {storeLogFile} -> printServerConfig "SMP" transports $ Just storeLogFile
|
||||
SSCDatabaseJournal {storeCfg = PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}}} -> do
|
||||
B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema
|
||||
printServerTransports "SMP" transports
|
||||
printSMPServerConfig transports st = case dbStoreCfg st of
|
||||
Just cfg -> printDBConfig cfg
|
||||
Nothing -> printServerConfig "SMP" transports $ storeLogFile' st
|
||||
where
|
||||
printDBConfig PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}} = do
|
||||
B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema
|
||||
printServerTransports "SMP" transports
|
||||
|
||||
deleteDirIfExists :: FilePath -> IO ()
|
||||
deleteDirIfExists path = whenM (doesDirectoryExist path) $ removeDirectoryRecursive path
|
||||
|
||||
@@ -72,7 +72,10 @@ module Simplex.Messaging.Server.Env.STM
|
||||
defaultIdleQueueInterval,
|
||||
journalMsgStoreDepth,
|
||||
readWriteQueueStore,
|
||||
noPostgresExitStr,
|
||||
noPostgresExit,
|
||||
dbStoreCfg,
|
||||
storeLogFile',
|
||||
)
|
||||
where
|
||||
|
||||
@@ -131,6 +134,10 @@ import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres
|
||||
#endif
|
||||
|
||||
data ServerConfig s = ServerConfig
|
||||
{ transports :: [(ServiceName, ASrvTransport, AddHTTP)],
|
||||
smpHandshakeTimeout :: Int,
|
||||
@@ -153,6 +160,7 @@ data ServerConfig s = ServerConfig
|
||||
-- | time after which the messages can be removed from the queues and check interval, seconds
|
||||
messageExpiration :: Maybe ExpirationConfig,
|
||||
expireMessagesOnStart :: Bool,
|
||||
expireMessagesOnSend :: Bool,
|
||||
-- | interval of inactivity after which journal queue is closed
|
||||
idleQueueInterval :: Int64,
|
||||
-- | notification expiration interval (seconds)
|
||||
@@ -274,14 +282,25 @@ fromMsgStore :: MsgStore s -> s
|
||||
fromMsgStore = \case
|
||||
StoreMemory s -> s
|
||||
StoreJournal s -> s
|
||||
#if defined(dbServerPostgres)
|
||||
StoreDatabase s -> s
|
||||
#endif
|
||||
{-# INLINE fromMsgStore #-}
|
||||
|
||||
type family SupportedStore (qs :: QSType) (ms :: MSType) :: Constraint where
|
||||
SupportedStore 'QSMemory 'MSMemory = ()
|
||||
SupportedStore 'QSMemory 'MSJournal = ()
|
||||
SupportedStore 'QSPostgres 'MSJournal = ()
|
||||
SupportedStore 'QSMemory 'MSPostgres =
|
||||
(Int ~ Bool, TypeError ('TE.Text "Storing messages in Postgres DB with queues in memory is not supported"))
|
||||
SupportedStore 'QSPostgres 'MSMemory =
|
||||
(Int ~ Bool, TypeError ('TE.Text "Storing messages in memory with Postgres DB is not supported"))
|
||||
(Int ~ Bool, TypeError ('TE.Text "Storing messages in memory with queues in Postgres DB is not supported"))
|
||||
SupportedStore 'QSPostgres 'MSJournal = ()
|
||||
#if defined(dbServerPostgres)
|
||||
SupportedStore 'QSPostgres 'MSPostgres = ()
|
||||
#else
|
||||
SupportedStore 'QSPostgres 'MSPostgres =
|
||||
(Int ~ Bool, TypeError ('TE.Text "Server compiled without server_postgres flag"))
|
||||
#endif
|
||||
|
||||
data AStoreType =
|
||||
forall qs ms. (SupportedStore qs ms, MsgStoreClass (MsgStoreType qs ms)) =>
|
||||
@@ -291,16 +310,43 @@ data ServerStoreCfg s where
|
||||
SSCMemory :: Maybe StorePaths -> ServerStoreCfg STMMsgStore
|
||||
SSCMemoryJournal :: {storeLogFile :: FilePath, storeMsgsPath :: FilePath} -> ServerStoreCfg (JournalMsgStore 'QSMemory)
|
||||
SSCDatabaseJournal :: {storeCfg :: PostgresStoreCfg, storeMsgsPath' :: FilePath} -> ServerStoreCfg (JournalMsgStore 'QSPostgres)
|
||||
#if defined(dbServerPostgres)
|
||||
SSCDatabase :: PostgresStoreCfg -> ServerStoreCfg PostgresMsgStore
|
||||
#endif
|
||||
|
||||
dbStoreCfg :: ServerStoreCfg s -> Maybe PostgresStoreCfg
|
||||
dbStoreCfg = \case
|
||||
SSCMemory _ -> Nothing
|
||||
SSCMemoryJournal {} -> Nothing
|
||||
SSCDatabaseJournal {storeCfg} -> Just storeCfg
|
||||
#if defined(dbServerPostgres)
|
||||
SSCDatabase cfg -> Just cfg
|
||||
#endif
|
||||
|
||||
storeLogFile' :: ServerStoreCfg s -> Maybe FilePath
|
||||
storeLogFile' = \case
|
||||
SSCMemory sp_ -> (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_
|
||||
SSCMemoryJournal {storeLogFile} -> Just storeLogFile
|
||||
SSCDatabaseJournal {storeCfg = PostgresStoreCfg {dbStoreLogPath}} -> dbStoreLogPath
|
||||
#if defined(dbServerPostgres)
|
||||
SSCDatabase (PostgresStoreCfg {dbStoreLogPath}) -> dbStoreLogPath
|
||||
#endif
|
||||
|
||||
data StorePaths = StorePaths {storeLogFile :: FilePath, storeMsgsFile :: Maybe FilePath}
|
||||
|
||||
type family MsgStoreType (qs :: QSType) (ms :: MSType) where
|
||||
MsgStoreType 'QSMemory 'MSMemory = STMMsgStore
|
||||
MsgStoreType qs 'MSJournal = JournalMsgStore qs
|
||||
#if defined(dbServerPostgres)
|
||||
MsgStoreType 'QSPostgres 'MSPostgres = PostgresMsgStore
|
||||
#endif
|
||||
|
||||
data MsgStore s where
|
||||
StoreMemory :: STMMsgStore -> MsgStore STMMsgStore
|
||||
StoreJournal :: JournalMsgStore qs -> MsgStore (JournalMsgStore qs)
|
||||
#if defined(dbServerPostgres)
|
||||
StoreDatabase :: PostgresMsgStore -> MsgStore PostgresMsgStore
|
||||
#endif
|
||||
|
||||
data Server s = Server
|
||||
{ clients :: ServerClients s,
|
||||
@@ -532,8 +578,12 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp
|
||||
qsCfg = PQStoreCfg (storeCfg {confirmMigrations} :: PostgresStoreCfg)
|
||||
cfg = mkJournalStoreConfig qsCfg storeMsgsPath' msgQueueQuota maxJournalMsgCount maxJournalStateLines idleQueueInterval
|
||||
when compactLog $ compactDbStoreLog $ dbStoreLogPath storeCfg
|
||||
ms <- newMsgStore cfg
|
||||
pure $ StoreJournal ms
|
||||
StoreJournal <$> newMsgStore cfg
|
||||
SSCDatabase storeCfg -> do
|
||||
let StartOptions {compactLog, confirmMigrations} = startOptions config
|
||||
cfg = PostgresMsgStoreCfg storeCfg {confirmMigrations} msgQueueQuota
|
||||
when compactLog $ compactDbStoreLog $ dbStoreLogPath storeCfg
|
||||
StoreDatabase <$> newMsgStore cfg
|
||||
#else
|
||||
SSCDatabaseJournal {} -> noPostgresExit
|
||||
#endif
|
||||
@@ -627,10 +677,12 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp
|
||||
_ -> SPMMessages
|
||||
|
||||
noPostgresExit :: IO a
|
||||
noPostgresExit = do
|
||||
putStrLn "Error: server binary is compiled without support for PostgreSQL database."
|
||||
putStrLn "Please download `smp-server-postgres` or re-compile with `cabal build -fserver_postgres`."
|
||||
exitFailure
|
||||
noPostgresExit = putStrLn noPostgresExitStr >> exitFailure
|
||||
|
||||
noPostgresExitStr :: String
|
||||
noPostgresExitStr =
|
||||
"Error: server binary is compiled without support for PostgreSQL database.\n"
|
||||
<> "Please download `smp-server-postgres` or re-compile with `cabal build -fserver_postgres`."
|
||||
|
||||
mkJournalStoreConfig :: QStoreCfg s -> FilePath -> Int -> Int -> Int -> Int64 -> JournalStoreConfig s
|
||||
mkJournalStoreConfig queueStoreCfg storePath msgQueueQuota maxJournalMsgCount maxJournalStateLines idleQueueInterval =
|
||||
|
||||
@@ -14,7 +14,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (isJust)
|
||||
import Data.Text (Text)
|
||||
import Simplex.Messaging.Agent.Protocol (ConnectionLink, ConnectionMode (..), ConnectionRequestUri)
|
||||
import Simplex.Messaging.Agent.Protocol (ConnectionLink, ConnectionMode (..))
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON)
|
||||
|
||||
|
||||
@@ -18,9 +18,10 @@
|
||||
module Simplex.Messaging.Server.Main where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (finally)
|
||||
import Control.Exception (SomeException, finally, try)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlpha, isAscii, toUpper)
|
||||
@@ -60,11 +61,11 @@ import Simplex.Messaging.Transport (supportedProxyClientSMPRelayVRange, alpnSupp
|
||||
import Simplex.Messaging.Transport.Client (TransportHost (..), defaultSocksProxy)
|
||||
import Simplex.Messaging.Transport.HTTP2 (httpALPN)
|
||||
import Simplex.Messaging.Transport.Server (ServerCredentials (..), mkTransportServerConfig)
|
||||
import Simplex.Messaging.Util (eitherToMaybe, ifM)
|
||||
import Simplex.Messaging.Util (eitherToMaybe, ifM, unlessM)
|
||||
import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
|
||||
import System.IO (BufferMode (..), IOMode (..), hSetBuffering, stderr, stdout, withFile)
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
@@ -73,6 +74,7 @@ import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists)
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue)
|
||||
import Simplex.Messaging.Server.MsgStore.Types (QSType (..))
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (postgresQueueStore)
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres (batchInsertQueues, batchInsertServices, foldQueueRecs, foldServiceRecs)
|
||||
import Simplex.Messaging.Server.QueueStore.STM (STMQueueStore (..))
|
||||
import Simplex.Messaging.Server.QueueStore.Types
|
||||
@@ -129,6 +131,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
printMessageStats "Messages" msgStats
|
||||
putStrLn $ case readStoreType ini of
|
||||
Right (ASType SQSMemory SMSMemory) -> "store_messages set to `memory`, update it to `journal` in INI file"
|
||||
Right (ASType SQSPostgres SMSPostgres) -> "store_messages set to `database`, update it to `journal` in INI file"
|
||||
Right (ASType _ SMSJournal) -> "store_messages set to `journal`"
|
||||
Left e -> e <> ", configure storage correctly"
|
||||
SCExport
|
||||
@@ -140,19 +143,31 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
confirmOrExit
|
||||
("WARNING: journal directory " <> storeMsgsJournalDir <> " will be exported to message log file " <> storeMsgsFilePath)
|
||||
"Journal not exported"
|
||||
ms <- newJournalMsgStore logPath MQStoreCfg
|
||||
-- TODO [postgres] in case postgres configured, queues must be read from database
|
||||
readQueueStore True (mkQueue ms False) storeLogFile $ stmQueueStore ms
|
||||
exportMessages True ms storeMsgsFilePath False
|
||||
putStrLn "Export completed"
|
||||
case readStoreType ini of
|
||||
Right (ASType SQSMemory SMSMemory) -> putStrLn "store_messages set to `memory`, start the server."
|
||||
Right (ASType SQSMemory SMSJournal) -> putStrLn "store_messages set to `journal`, update it to `memory` in INI file"
|
||||
Right (ASType SQSPostgres SMSJournal) ->
|
||||
Right (ASType SQSMemory msType) -> do
|
||||
ms <- newJournalMsgStore logPath MQStoreCfg
|
||||
readQueueStore True (mkQueue ms False) storeLogFile $ stmQueueStore ms
|
||||
exportMessages True (StoreJournal ms) storeMsgsFilePath False
|
||||
putStrLn "Export completed"
|
||||
putStrLn $ case msType of
|
||||
SMSMemory -> "store_messages set to `memory`, start the server."
|
||||
SMSJournal -> "store_messages set to `journal`, update it to `memory` in INI file"
|
||||
#if defined(dbServerPostgres)
|
||||
Right (ASType SQSPostgres SMSJournal) -> do
|
||||
let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath
|
||||
dbOpts@DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts
|
||||
unlessM (checkSchemaExists connstr schema) $ do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr
|
||||
exitFailure
|
||||
ms <- newJournalMsgStore logPath $ PQStoreCfg PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini}
|
||||
exportMessages True (StoreJournal ms) storeMsgsFilePath False
|
||||
putStrLn "Export completed"
|
||||
putStrLn "store_messages set to `journal`, store_queues is set to `database`.\nExport queues to store log to use memory storage for messages (`smp-server database export`)."
|
||||
Right (ASType SQSPostgres SMSPostgres) -> do
|
||||
putStrLn $ "Messages can be exported with `dabatase export --table messages`."
|
||||
exitFailure
|
||||
#else
|
||||
noPostgresExit
|
||||
Right (ASType SQSPostgres SMSJournal) -> noPostgresExit
|
||||
#endif
|
||||
Left e -> putStrLn $ e <> ", configure storage correctly"
|
||||
SCDelete
|
||||
@@ -166,11 +181,12 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
deleteDirIfExists storeMsgsJournalDir
|
||||
putStrLn $ "Deleted all messages in journal " <> storeMsgsJournalDir
|
||||
#if defined(dbServerPostgres)
|
||||
Database cmd dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do
|
||||
Database cmd tables dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do
|
||||
schemaExists <- checkSchemaExists connstr schema
|
||||
storeLogExists <- doesFileExist storeLogFilePath
|
||||
case cmd of
|
||||
SCImport
|
||||
msgsFileExists <- doesFileExist storeMsgsFilePath
|
||||
case (cmd, tables) of
|
||||
(SCImport, DTQueues)
|
||||
| schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema
|
||||
| schemaExists -> do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr
|
||||
@@ -188,12 +204,29 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
putStrLn $ case readStoreType ini of
|
||||
Right (ASType SQSMemory SMSMemory) -> setToDbStr <> "\nstore_messages set to `memory`, import messages to journal to use PostgreSQL database for queues (`smp-server journal import`)"
|
||||
Right (ASType SQSMemory SMSJournal) -> setToDbStr
|
||||
Right (ASType SQSPostgres SMSJournal) -> "store_queues set to `database`, start the server."
|
||||
Right (ASType SQSPostgres _) -> "store_queues set to `database`, start the server."
|
||||
Left e -> e <> ", configure storage correctly"
|
||||
where
|
||||
setToDbStr :: String
|
||||
setToDbStr = "store_queues set to `memory`, update it to `database` in INI file"
|
||||
SCExport
|
||||
(SCImport, DTMessages)
|
||||
| not schemaExists -> do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr
|
||||
exitFailure
|
||||
| not msgsFileExists -> do
|
||||
putStrLn $ storeMsgsFilePath <> " file does not exist."
|
||||
exitFailure
|
||||
| otherwise -> do
|
||||
confirmOrExit
|
||||
("WARNING: message log file " <> storeMsgsFilePath <> " will be imported to PostrgreSQL database " <> B.unpack connstr <> ", schema: " <> B.unpack schema)
|
||||
"Message records not imported"
|
||||
mCnt <- importMessagesToDatabase storeMsgsFilePath dbOpts
|
||||
putStrLn $ "Import completed: " <> show mCnt <> " messages"
|
||||
putStrLn $ case readStoreType ini of
|
||||
Right (ASType SQSPostgres SMSPostgres) -> "store_queues and store_messages set to `database`, start the server."
|
||||
Right _ -> "set store_queues and store_messages set to `database` in INI file"
|
||||
Left e -> e <> ", configure storage correctly"
|
||||
(SCExport, DTQueues)
|
||||
| schemaExists && storeLogExists -> exitConfigureQueueStore connstr schema
|
||||
| not schemaExists -> do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr
|
||||
@@ -203,15 +236,33 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
exitFailure
|
||||
| otherwise -> do
|
||||
confirmOrExit
|
||||
("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath)
|
||||
("WARNING: PostrgreSQL schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath)
|
||||
"Queue records not exported"
|
||||
(sCnt, qCnt) <- exportDatabaseToStoreLog logPath dbOpts storeLogFilePath
|
||||
putStrLn $ "Export completed: " <> show sCnt <> " services, " <> show qCnt <> " queues"
|
||||
putStrLn $ case readStoreType ini of
|
||||
Right (ASType SQSPostgres SMSJournal) -> "store_queues set to `database`, update it to `memory` in INI file."
|
||||
Right (ASType SQSPostgres _) -> "store_queues or store_messages set to `database`, update it to `memory` in INI file."
|
||||
Right (ASType SQSMemory _) -> "store_queues set to `memory`, start the server"
|
||||
Left e -> e <> ", configure storage correctly"
|
||||
SCDelete
|
||||
(SCExport, DTMessages)
|
||||
| not schemaExists -> do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr
|
||||
exitFailure
|
||||
| msgsFileExists -> do
|
||||
putStrLn $ storeMsgsFilePath <> " file already exists."
|
||||
exitFailure
|
||||
| otherwise -> do
|
||||
confirmOrExit
|
||||
("WARNING: Messages from PostrgreSQL schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to message log file " <> storeMsgsFilePath)
|
||||
"Message records not exported"
|
||||
let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL}
|
||||
ms <- newMsgStore $ PostgresMsgStoreCfg storeCfg defaultMsgQueueQuota
|
||||
withFile storeMsgsFilePath WriteMode (try . exportDbMessages True ms) >>= \case
|
||||
Right mCnt -> do
|
||||
putStrLn $ "Export completed: " <> show mCnt <> " messages"
|
||||
putStrLn "Export queues with `smp-server database export queues`"
|
||||
Left (e :: SomeException) -> putStrLn $ "Error exporting messages: " <> show e
|
||||
(SCDelete, _)
|
||||
| not schemaExists -> do
|
||||
putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr
|
||||
exitFailure
|
||||
@@ -245,8 +296,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
readStoreType ini = case (iniStoreQueues, iniStoreMessage) of
|
||||
("memory", "memory") -> Right $ ASType SQSMemory SMSMemory
|
||||
("memory", "journal") -> Right $ ASType SQSMemory SMSJournal
|
||||
("memory", "database") -> Left "Database and memory storage are not compatible."
|
||||
("database", "memory") -> Left "Database and memory storage are not compatible."
|
||||
("database", "journal") -> Right $ ASType SQSPostgres SMSJournal
|
||||
("database", "memory") -> Left "Using PostgreSQL database requires journal memory storage."
|
||||
#if defined(dbServerPostgres)
|
||||
("database", "database") -> Right $ ASType SQSPostgres SMSPostgres
|
||||
#else
|
||||
("database", "database") -> Left noPostgresExitStr
|
||||
#endif
|
||||
(q, m) -> Left $ T.unpack $ "Invalid storage settings: store_queues: " <> q <> ", store_messages: " <> m
|
||||
where
|
||||
iniStoreQueues = fromRight "memory" $ lookupValue "STORE_LOG" "store_queues" ini
|
||||
@@ -396,6 +453,12 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath
|
||||
storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini defaultDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini}
|
||||
in SSCDatabaseJournal {storeCfg, storeMsgsPath' = storeMsgsJournalDir}
|
||||
#if defined(dbServerPostgres)
|
||||
iniStoreCfg SQSPostgres SMSPostgres =
|
||||
let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath
|
||||
storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini defaultDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini}
|
||||
in SSCDatabase storeCfg
|
||||
#endif
|
||||
serverConfig :: ServerStoreCfg s -> ServerConfig s
|
||||
serverConfig serverStoreCfg =
|
||||
ServerConfig
|
||||
@@ -428,6 +491,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
{ ttl = 86400 * readIniDefault defMsgExpirationDays "STORE_LOG" "expire_messages_days" ini
|
||||
},
|
||||
expireMessagesOnStart = fromMaybe True $ iniOnOff "STORE_LOG" "expire_messages_on_start" ini,
|
||||
expireMessagesOnSend = fromMaybe True $ iniOnOff "STORE_LOG" "expire_messages_on_send" ini,
|
||||
idleQueueInterval = defaultIdleQueueInterval,
|
||||
notificationExpiration =
|
||||
defaultNtfExpiration
|
||||
@@ -504,6 +568,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
msgsFileExists <- doesFileExist storeMsgsFilePath
|
||||
storeLogExists <- doesFileExist storeLogFilePath
|
||||
case mode of
|
||||
#if defined(dbServerPostgres)
|
||||
ASType SQSPostgres SMSPostgres
|
||||
| msgsFileExists || msgsDirExists -> do
|
||||
putStrLn $ "Error: " <> storeMsgsFilePath <> " file or " <> storeMsgsJournalDir <> " directory are present."
|
||||
putStrLn "Configure memory storage."
|
||||
exitFailure
|
||||
| otherwise -> checkDbStorage ini storeLogExists
|
||||
#endif
|
||||
ASType qs SMSJournal
|
||||
| msgsFileExists && msgsDirExists -> exitConfigureMsgStorage
|
||||
| msgsFileExists -> do
|
||||
@@ -516,28 +588,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
SQSMemory ->
|
||||
unless (storeLogExists) $ putStrLn $ "store_queues is `memory`, " <> storeLogFilePath <> " file will be created."
|
||||
#if defined(dbServerPostgres)
|
||||
SQSPostgres -> do
|
||||
let DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts
|
||||
schemaExists <- checkSchemaExists connstr schema
|
||||
case enableDbStoreLog' ini of
|
||||
Just ()
|
||||
| not schemaExists -> noDatabaseSchema connstr schema
|
||||
| not storeLogExists -> do
|
||||
putStrLn $ "Error: db_store_log is `on`, " <> storeLogFilePath <> " does not exist"
|
||||
exitFailure
|
||||
| otherwise -> pure ()
|
||||
Nothing
|
||||
| storeLogExists && schemaExists -> exitConfigureQueueStore connstr schema
|
||||
| storeLogExists -> do
|
||||
putStrLn $ "Error: store_queues is `database` with " <> storeLogFilePath <> " file present."
|
||||
putStrLn "Set store_queues to `memory` or use `smp-server database import` to migrate."
|
||||
exitFailure
|
||||
| not schemaExists -> noDatabaseSchema connstr schema
|
||||
| otherwise -> pure ()
|
||||
where
|
||||
noDatabaseSchema connstr schema = do
|
||||
putStrLn $ "Error: store_queues is `database`, create schema " <> B.unpack schema <> " in PostgreSQL database " <> B.unpack connstr
|
||||
exitFailure
|
||||
SQSPostgres -> checkDbStorage ini storeLogExists
|
||||
#else
|
||||
SQSPostgres -> noPostgresExit
|
||||
#endif
|
||||
@@ -555,6 +606,29 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
|
||||
exitFailure
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
checkDbStorage ini storeLogExists = do
|
||||
let DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts
|
||||
schemaExists <- checkSchemaExists connstr schema
|
||||
case enableDbStoreLog' ini of
|
||||
Just ()
|
||||
| not schemaExists -> noDatabaseSchema connstr schema
|
||||
| not storeLogExists -> do
|
||||
putStrLn $ "Error: db_store_log is `on`, " <> storeLogFilePath <> " does not exist"
|
||||
exitFailure
|
||||
| otherwise -> pure ()
|
||||
Nothing
|
||||
| storeLogExists && schemaExists -> exitConfigureQueueStore connstr schema
|
||||
| storeLogExists -> do
|
||||
putStrLn $ "Error: store_queues is `database` with " <> storeLogFilePath <> " file present."
|
||||
putStrLn "Set store_queues to `memory` or use `smp-server database import` to migrate."
|
||||
exitFailure
|
||||
| not schemaExists -> noDatabaseSchema connstr schema
|
||||
| otherwise -> pure ()
|
||||
where
|
||||
noDatabaseSchema connstr schema = do
|
||||
putStrLn $ "Error: store_queues is `database`, create schema " <> B.unpack schema <> " in PostgreSQL database " <> B.unpack connstr
|
||||
exitFailure
|
||||
|
||||
exitConfigureQueueStore connstr schema = do
|
||||
putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")."
|
||||
putStrLn "Configure queue storage."
|
||||
@@ -575,13 +649,28 @@ importStoreLogToDatabase logPath storeLogFile dbOpts = do
|
||||
renameFile storeLogFile $ storeLogFile <> ".bak"
|
||||
pure (sCnt, qCnt)
|
||||
|
||||
importMessagesToDatabase :: FilePath -> DBOpts -> IO Int64
|
||||
importMessagesToDatabase msgsLogFile dbOpts = do
|
||||
let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL}
|
||||
ms <- newMsgStore $ PostgresMsgStoreCfg storeCfg defaultMsgQueueQuota
|
||||
mCnt <- getDbMessageCount ms
|
||||
when (mCnt > 0) $ do
|
||||
confirmOrExit ("WARNING: the database contains messages, they will be deleted.") "Message records not imported"
|
||||
deleteAllMessages ms
|
||||
inserted <- batchInsertMessages True msgsLogFile $ queueStore ms
|
||||
mCnt' <- getDbMessageCount ms
|
||||
unless (inserted == mCnt') $ putStrLn $ "WARNING: inserted " <> show inserted <> " rows, table has " <> show mCnt' <> " messages."
|
||||
updateQueueCounts ms
|
||||
renameFile msgsLogFile $ msgsLogFile <> ".bak"
|
||||
pure mCnt'
|
||||
|
||||
exportDatabaseToStoreLog :: FilePath -> DBOpts -> FilePath -> IO (Int, Int)
|
||||
exportDatabaseToStoreLog logPath dbOpts storeLogFilePath = do
|
||||
let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = 86400 * defaultDeletedTTL}
|
||||
ps <- newJournalMsgStore logPath $ PQStoreCfg storeCfg
|
||||
sl <- openWriteStoreLog False storeLogFilePath
|
||||
Sum sCnt <- foldServiceRecs (postgresQueueStore ps) $ \sr -> logNewService sl sr $> Sum (1 :: Int)
|
||||
Sum qCnt <- foldQueueRecs True True (postgresQueueStore ps) Nothing $ \(rId, qr) -> logCreateQueue sl rId qr $> Sum (1 :: Int)
|
||||
Sum qCnt <- foldQueueRecs True True (postgresQueueStore ps) $ \(rId, qr) -> logCreateQueue sl rId qr $> Sum (1 :: Int)
|
||||
closeStoreLog sl
|
||||
pure (sCnt, qCnt)
|
||||
#endif
|
||||
@@ -667,10 +756,22 @@ data CliCommand
|
||||
| Start StartOptions
|
||||
| Delete
|
||||
| Journal StoreCmd
|
||||
| Database StoreCmd DBOpts
|
||||
| Database StoreCmd DatabaseTable DBOpts
|
||||
|
||||
data StoreCmd = SCImport | SCExport | SCDelete
|
||||
|
||||
data DatabaseTable = DTQueues | DTMessages
|
||||
|
||||
instance StrEncoding DatabaseTable where
|
||||
strEncode = \case
|
||||
DTQueues -> "queues"
|
||||
DTMessages -> "messages"
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"queues" -> pure DTQueues
|
||||
"messages" -> pure DTMessages
|
||||
_ -> fail "DatabaseTable"
|
||||
|
||||
cliCommandP :: FilePath -> FilePath -> FilePath -> Parser CliCommand
|
||||
cliCommandP cfgPath logPath iniFile =
|
||||
hsubparser
|
||||
@@ -679,7 +780,7 @@ cliCommandP cfgPath logPath iniFile =
|
||||
<> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")"))
|
||||
<> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files"))
|
||||
<> command "journal" (info (Journal <$> journalCmdP) (progDesc "Import/export messages to/from journal storage"))
|
||||
<> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultDBOpts) (progDesc "Import/export queues to/from PostgreSQL database storage"))
|
||||
<> command "database" (info (Database <$> databaseCmdP <*> dbTableP <*> dbOptsP defaultDBOpts) (progDesc "Import/export queues to/from PostgreSQL database storage"))
|
||||
)
|
||||
where
|
||||
initP :: Parser InitOptions
|
||||
@@ -833,6 +934,13 @@ cliCommandP cfgPath logPath iniFile =
|
||||
<> command "export" (info (pure SCExport) (progDesc $ "Export " <> dest <> " to " <> src))
|
||||
<> command "delete" (info (pure SCDelete) (progDesc $ "Delete " <> dest))
|
||||
)
|
||||
dbTableP =
|
||||
option
|
||||
strParse
|
||||
( long "table"
|
||||
<> help "Database tables: queues/messages"
|
||||
<> metavar "TABLE"
|
||||
)
|
||||
parseBasicAuth :: ReadM ServerPassword
|
||||
parseBasicAuth = eitherReader $ fmap ServerPassword . strDecode . B.pack
|
||||
entityP :: String -> String -> String -> Parser (Maybe Entity, Maybe Text)
|
||||
|
||||
@@ -87,7 +87,8 @@ iniFileContent cfgPath logPath opts host basicAuth controlPortPwds =
|
||||
<> ("restore_messages: " <> onOff enableStoreLog <> "\n\n")
|
||||
<> "# Messages and notifications expiration periods.\n"
|
||||
<> ("expire_messages_days: " <> tshow defMsgExpirationDays <> "\n")
|
||||
<> "expire_messages_on_start: on\n"
|
||||
<> "expire_messages_on_start: on\n\
|
||||
\expire_messages_on_send: off\n"
|
||||
<> ("expire_ntfs_hours: " <> tshow defNtfExpirationHours <> "\n\n")
|
||||
<> "# Log daily server statistics to CSV file\n"
|
||||
<> ("log_stats: " <> onOff logStats <> "\n\n")
|
||||
|
||||
@@ -24,7 +24,7 @@ module Simplex.Messaging.Server.MsgStore.Journal
|
||||
( JournalMsgStore (random, expireBackupsBefore),
|
||||
QStore (..),
|
||||
QStoreCfg (..),
|
||||
JournalQueue,
|
||||
JournalQueue (msgQueue'), -- msgQueue' is used in tests
|
||||
JournalMsgQueue (queue, state),
|
||||
JMQueue (queueDirectory, statePath),
|
||||
JournalStoreConfig (..),
|
||||
@@ -38,6 +38,7 @@ module Simplex.Messaging.Server.MsgStore.Journal
|
||||
msgQueueStatePath,
|
||||
readQueueState,
|
||||
newMsgQueueState,
|
||||
getJournalQueueMessages,
|
||||
newJournalId,
|
||||
appendState,
|
||||
queueLogFileName,
|
||||
@@ -58,7 +59,7 @@ import Control.Monad.Trans.Except
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Either (fromRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (sort)
|
||||
@@ -290,13 +291,10 @@ newtype StoreIO (s :: QSType) a = StoreIO {unStoreIO :: IO a}
|
||||
deriving newtype (Functor, Applicative, Monad)
|
||||
|
||||
instance StoreQueueClass (JournalQueue s) where
|
||||
type MsgQueue (JournalQueue s) = JournalMsgQueue s
|
||||
recipientId = recipientId'
|
||||
{-# INLINE recipientId #-}
|
||||
queueRec = queueRec'
|
||||
{-# INLINE queueRec #-}
|
||||
msgQueue = msgQueue'
|
||||
{-# INLINE msgQueue #-}
|
||||
withQueueLock :: JournalQueue s -> Text -> IO a -> IO a
|
||||
withQueueLock JournalQueue {recipientId', queueLock, sharedLock} =
|
||||
withLockWaitShared recipientId' queueLock sharedLock
|
||||
@@ -309,7 +307,7 @@ instance QueueStoreClass (JournalQueue s) (QStore s) where
|
||||
newQueueStore = \case
|
||||
MQStoreCfg -> MQStore <$> newQueueStore @(JournalQueue s) ()
|
||||
#if defined(dbServerPostgres)
|
||||
PQStoreCfg cfg -> PQStore <$> newQueueStore @(JournalQueue s) cfg
|
||||
PQStoreCfg cfg -> PQStore <$> newQueueStore @(JournalQueue s) (cfg, True)
|
||||
#endif
|
||||
|
||||
closeQueueStore = withQS (closeQueueStore @(JournalQueue s))
|
||||
@@ -378,6 +376,7 @@ makeQueue_ JournalMsgStore {sharedLock} rId qr queueLock = do
|
||||
|
||||
instance MsgStoreClass (JournalMsgStore s) where
|
||||
type StoreMonad (JournalMsgStore s) = StoreIO s
|
||||
type MsgQueue (JournalMsgStore s) = JournalMsgQueue s
|
||||
type QueueStore (JournalMsgStore s) = QStore s
|
||||
type StoreQueue (JournalMsgStore s) = JournalQueue s
|
||||
type MsgStoreConfig (JournalMsgStore s) = JournalStoreConfig s
|
||||
@@ -405,11 +404,11 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
|
||||
-- This function can only be used in server CLI commands or before server is started.
|
||||
-- It does not cache queues and is NOT concurrency safe.
|
||||
unsafeWithAllMsgQueues :: Monoid a => Bool -> Bool -> JournalMsgStore s -> (JournalQueue s -> IO a) -> IO a
|
||||
unsafeWithAllMsgQueues tty withData ms action = case queueStore_ ms of
|
||||
unsafeWithAllMsgQueues :: Monoid a => Bool -> JournalMsgStore s -> (JournalQueue s -> IO a) -> IO a
|
||||
unsafeWithAllMsgQueues tty ms action = case queueStore_ ms of
|
||||
MQStore st -> withLoadedQueues st run
|
||||
#if defined(dbServerPostgres)
|
||||
PQStore st -> foldQueueRecs tty withData st Nothing $ uncurry (mkQueue ms False) >=> run
|
||||
PQStore st -> foldQueueRecs False tty st $ uncurry (mkQueue ms False) >=> run
|
||||
#endif
|
||||
where
|
||||
run q = do
|
||||
@@ -421,7 +420,7 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
expireOldMessages :: Bool -> JournalMsgStore s -> Int64 -> Int64 -> IO MessageStats
|
||||
expireOldMessages tty ms now ttl = case queueStore_ ms of
|
||||
MQStore st ->
|
||||
withLoadedQueues st $ \q -> run $ isolateQueue q "deleteExpiredMsgs" $ do
|
||||
withLoadedQueues st $ \q -> run $ isolateQueue ms q "deleteExpiredMsgs" $ do
|
||||
StoreIO (readTVarIO $ queueRec q) >>= \case
|
||||
Just QueueRec {updatedAt = Just (RoundedSystemTime t)} | t > veryOld ->
|
||||
expireQueueMsgs ms now old q
|
||||
@@ -429,7 +428,7 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
#if defined(dbServerPostgres)
|
||||
PQStore st -> do
|
||||
let JournalMsgStore {queueLocks, sharedLock} = ms
|
||||
foldQueueRecs tty False st (Just veryOld) $ \(rId, qr) -> do
|
||||
foldRecentQueueRecs veryOld tty st $ \(rId, qr) -> do
|
||||
q <- mkQueue ms False rId qr
|
||||
withSharedWaitLock rId queueLocks sharedLock $ run $ tryStore' "deleteExpiredMsgs" rId $
|
||||
getLoadedQueue q >>= unStoreIO . expireQueueMsgs ms now old
|
||||
@@ -485,7 +484,7 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
where
|
||||
newQ = do
|
||||
let dir = msgQueueDirectory ms rId
|
||||
statePath = msgQueueStatePath dir $ B.unpack (strEncode rId)
|
||||
statePath = msgQueueStatePath dir rId
|
||||
queue = JMQueue {queueDirectory = dir, statePath}
|
||||
q <- ifM (doesDirectoryExist dir) (openMsgQueue ms queue forWrite) (createQ queue)
|
||||
atomically $ writeTVar msgQueue' $ Just q
|
||||
@@ -563,8 +562,9 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
where
|
||||
getSize = maybe (pure (-1)) (fmap size . readTVarIO . state)
|
||||
|
||||
-- drainMsgs is never True with Journal storage
|
||||
getQueueMessages_ :: Bool -> JournalQueue s -> JournalMsgQueue s -> StoreIO s [Message]
|
||||
getQueueMessages_ drainMsgs q' q = StoreIO (run [])
|
||||
getQueueMessages_ drainMsgs q' q = StoreIO $ if drainMsgs then run [] else readTVarIO (state q) >>= runFast
|
||||
where
|
||||
run msgs = readTVarIO (handles q) >>= maybe (pure []) (getMsg msgs)
|
||||
getMsg msgs hs = chooseReadJournal q' q drainMsgs hs >>= maybe (pure msgs) readMsg
|
||||
@@ -573,9 +573,19 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
(msg, len) <- hGetMsgAt h $ bytePos rs
|
||||
updateReadPos q' q drainMsgs len hs
|
||||
(msg :) <$> run msgs
|
||||
runFast MsgQueueState {writeState = ws, readState = rs, size}
|
||||
| size > 0 =
|
||||
readTVarIO (handles q) >>= \case
|
||||
Just (MsgQueueHandles _ rh wh_) -> do
|
||||
msgs <- getJournalRange rh (bytePos rs) (byteCount rs)
|
||||
case wh_ of
|
||||
Just wh -> (msgs ++) <$> getJournalRange wh 0 (bytePos ws)
|
||||
Nothing -> pure msgs
|
||||
Nothing -> pure []
|
||||
| otherwise = pure []
|
||||
|
||||
writeMsg :: JournalMsgStore s -> JournalQueue s -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool))
|
||||
writeMsg ms q' logState msg = isolateQueue q' "writeMsg" $ do
|
||||
writeMsg ms q' logState msg = isolateQueue ms q' "writeMsg" $ do
|
||||
q <- getMsgQueue ms q' True
|
||||
StoreIO $ (`E.finally` updateActiveAt q') $ do
|
||||
st@MsgQueueState {canWrite, size} <- readTVarIO (state q)
|
||||
@@ -649,8 +659,8 @@ instance MsgStoreClass (JournalMsgStore s) where
|
||||
$>>= \len -> readTVarIO handles
|
||||
$>>= \hs -> updateReadPos q mq logState len hs $> Just ()
|
||||
|
||||
isolateQueue :: JournalQueue s -> Text -> StoreIO s a -> ExceptT ErrorType IO a
|
||||
isolateQueue sq op = tryStore' op (recipientId' sq) . withQueueLock sq op . unStoreIO
|
||||
isolateQueue :: JournalMsgStore s -> JournalQueue s -> Text -> StoreIO s a -> ExceptT ErrorType IO a
|
||||
isolateQueue _ sq op = tryStore' op (recipientId' sq) . withQueueLock sq op . unStoreIO
|
||||
|
||||
unsafeRunStore :: JournalQueue s -> Text -> StoreIO s a -> IO a
|
||||
unsafeRunStore sq op a =
|
||||
@@ -795,8 +805,8 @@ msgQueueDirectory JournalMsgStore {config = JournalStoreConfig {storePath, pathP
|
||||
let (seg, s') = B.splitAt 2 s
|
||||
in seg : splitSegments (n - 1) s'
|
||||
|
||||
msgQueueStatePath :: FilePath -> String -> FilePath
|
||||
msgQueueStatePath dir queueId = dir </> (queueLogFileName <> "." <> queueId <> logFileExt)
|
||||
msgQueueStatePath :: FilePath -> RecipientId -> FilePath
|
||||
msgQueueStatePath dir rId = dir </> (queueLogFileName <> "." <> B.unpack (strEncode rId) <> logFileExt)
|
||||
|
||||
createNewJournal :: FilePath -> ByteString -> IO Handle
|
||||
createNewJournal dir journalId = do
|
||||
@@ -965,10 +975,11 @@ deleteQueue_ ms q =
|
||||
pure r
|
||||
where
|
||||
rId = recipientId q
|
||||
remove r@(_, mq_) = do
|
||||
remove qr = do
|
||||
mq_ <- atomically $ swapTVar (msgQueue' q) Nothing
|
||||
mapM_ (closeMsgQueueHandles ms) mq_
|
||||
removeQueueDirectory ms rId
|
||||
pure r
|
||||
pure (qr, mq_)
|
||||
|
||||
closeMsgQueue :: JournalMsgStore s -> JournalQueue s -> IO ()
|
||||
closeMsgQueue ms JournalQueue {msgQueue'} = atomically (swapTVar msgQueue' Nothing) >>= mapM_ (closeMsgQueueHandles ms)
|
||||
@@ -1019,3 +1030,33 @@ hClose h =
|
||||
|
||||
closeOnException :: Handle -> IO a -> IO a
|
||||
closeOnException h a = a `E.onException` hClose h
|
||||
|
||||
getJournalQueueMessages :: JournalMsgStore s -> JournalQueue s -> IO [Message]
|
||||
getJournalQueueMessages ms q =
|
||||
readQueueState ms (msgQueueStatePath dir rId) >>= \case
|
||||
(Just MsgQueueState {readState = rs, writeState = ws, size}, _) | size > 0 -> do
|
||||
msgs <- getMsgs (journalId rs) (bytePos rs) (byteCount rs)
|
||||
if journalId rs == journalId ws
|
||||
then pure msgs
|
||||
else (msgs ++) <$> getMsgs (journalId ws) 0 (bytePos ws)
|
||||
_ -> pure []
|
||||
where
|
||||
rId = recipientId' q
|
||||
dir = msgQueueDirectory ms rId
|
||||
getMsgs jId from to =
|
||||
IO.withFile (journalFilePath dir jId) ReadWriteMode $ \h' ->
|
||||
getJournalRange h' from to
|
||||
|
||||
getJournalRange :: Handle -> Int64 -> Int64 -> IO [Message]
|
||||
getJournalRange h from to
|
||||
| to > from = do
|
||||
IO.hSeek h AbsoluteSeek $ fromIntegral from
|
||||
parseMsgs =<< B.hGet h (fromIntegral $ to - from)
|
||||
| otherwise = pure []
|
||||
where
|
||||
parseMsgs s = do
|
||||
let (errs, msgs) = partitionEithers $ map strDecode $ B.lines s
|
||||
unless (null errs) $ do
|
||||
f <- IO.hShow h
|
||||
putStrLn $ "Error reading " <> show (length errs) <> " messages from " <> f
|
||||
pure msgs
|
||||
|
||||
@@ -0,0 +1,386 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DerivingStrategies #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore.Postgres
|
||||
( PostgresMsgStore,
|
||||
PostgresMsgStoreCfg (..),
|
||||
PostgresQueue,
|
||||
exportDbMessages,
|
||||
getDbMessageStats,
|
||||
getDbMessageCount,
|
||||
deleteAllMessages,
|
||||
batchInsertMessages,
|
||||
updateQueueCounts,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Builder as BB
|
||||
import qualified Data.ByteString.Lazy as LB
|
||||
import Data.Functor (($>))
|
||||
import Data.IORef
|
||||
import Data.Int (Int64)
|
||||
import Data.List (intersperse)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Text (Text)
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Database.PostgreSQL.Simple (Binary (..), Only (..), (:.) (..))
|
||||
import qualified Database.PostgreSQL.Simple as DB
|
||||
import qualified Database.PostgreSQL.Simple.Copy as DB
|
||||
import Database.PostgreSQL.Simple.SqlQQ (sql)
|
||||
import Database.PostgreSQL.Simple.ToField (ToField (..))
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.Types
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres
|
||||
import Simplex.Messaging.Server.QueueStore.Types
|
||||
import Simplex.Messaging.Server.StoreLog (foldLogLines)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Util (maybeFirstRow, maybeFirstRow', (<$$>))
|
||||
import System.IO (Handle, hFlush, stdout)
|
||||
|
||||
data PostgresMsgStore = PostgresMsgStore
|
||||
{ config :: PostgresMsgStoreCfg,
|
||||
queueStore_ :: PostgresQueueStore'
|
||||
}
|
||||
|
||||
data PostgresMsgStoreCfg = PostgresMsgStoreCfg
|
||||
{ queueStoreCfg :: PostgresStoreCfg,
|
||||
quota :: Int
|
||||
}
|
||||
|
||||
type PostgresQueueStore' = PostgresQueueStore PostgresQueue
|
||||
|
||||
data PostgresQueue = PostgresQueue
|
||||
{ recipientId' :: RecipientId,
|
||||
queueRec' :: TVar (Maybe QueueRec)
|
||||
}
|
||||
|
||||
instance StoreQueueClass PostgresQueue where
|
||||
recipientId = recipientId'
|
||||
{-# INLINE recipientId #-}
|
||||
queueRec = queueRec'
|
||||
{-# INLINE queueRec #-}
|
||||
withQueueLock PostgresQueue {} _ = id -- TODO [messages] maybe it's just transaction?
|
||||
{-# INLINE withQueueLock #-}
|
||||
|
||||
newtype DBTransaction = DBTransaction {dbConn :: DB.Connection}
|
||||
|
||||
type DBStoreIO a = ReaderT DBTransaction IO a
|
||||
|
||||
instance MsgStoreClass PostgresMsgStore where
|
||||
type StoreMonad PostgresMsgStore = ReaderT DBTransaction IO
|
||||
type MsgQueue PostgresMsgStore = ()
|
||||
type QueueStore PostgresMsgStore = PostgresQueueStore'
|
||||
type StoreQueue PostgresMsgStore = PostgresQueue
|
||||
type MsgStoreConfig PostgresMsgStore = PostgresMsgStoreCfg
|
||||
|
||||
newMsgStore :: PostgresMsgStoreCfg -> IO PostgresMsgStore
|
||||
newMsgStore config = do
|
||||
queueStore_ <- newQueueStore @PostgresQueue (queueStoreCfg config, False)
|
||||
pure PostgresMsgStore {config, queueStore_}
|
||||
|
||||
closeMsgStore :: PostgresMsgStore -> IO ()
|
||||
closeMsgStore = closeQueueStore @PostgresQueue . queueStore_
|
||||
|
||||
withActiveMsgQueues _ _ = error "withActiveMsgQueues not used"
|
||||
|
||||
unsafeWithAllMsgQueues _ _ _ = error "unsafeWithAllMsgQueues not used"
|
||||
|
||||
expireOldMessages :: Bool -> PostgresMsgStore -> Int64 -> Int64 -> IO MessageStats
|
||||
expireOldMessages _tty ms now ttl =
|
||||
maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db ->
|
||||
DB.query db "CALL expire_old_messages(?,?,?,0,0,0)" (oldQueue, oldMsg, batchSize)
|
||||
where
|
||||
st = dbStore $ queueStore_ ms
|
||||
oldQueue = 0 :: Int64 -- expire all queues
|
||||
oldMsg = now - ttl
|
||||
batchSize = 10000 :: Int
|
||||
toMessageStats (expiredMsgsCount, storedMsgsCount, storedQueues) =
|
||||
MessageStats {expiredMsgsCount, storedMsgsCount, storedQueues}
|
||||
|
||||
logQueueStates _ = error "logQueueStates not used"
|
||||
|
||||
logQueueState _ = error "logQueueState not used"
|
||||
|
||||
queueStore = queueStore_
|
||||
{-# INLINE queueStore #-}
|
||||
|
||||
loadedQueueCounts :: PostgresMsgStore -> IO LoadedQueueCounts
|
||||
loadedQueueCounts ms = do
|
||||
loadedQueueCount <- M.size <$> readTVarIO queues
|
||||
loadedNotifierCount <- M.size <$> readTVarIO notifiers
|
||||
notifierLockCount <- M.size <$> readTVarIO notifierLocks
|
||||
pure LoadedQueueCounts {loadedQueueCount, loadedNotifierCount, openJournalCount = 0, queueLockCount = 0, notifierLockCount}
|
||||
where
|
||||
PostgresQueueStore {queues, notifiers, notifierLocks} = queueStore_ ms
|
||||
|
||||
mkQueue :: PostgresMsgStore -> Bool -> RecipientId -> QueueRec -> IO PostgresQueue
|
||||
mkQueue _ _keepLock rId qr = PostgresQueue rId <$> newTVarIO (Just qr)
|
||||
{-# INLINE mkQueue #-}
|
||||
|
||||
getMsgQueue _ _ _ = pure ()
|
||||
{-# INLINE getMsgQueue #-}
|
||||
|
||||
getPeekMsgQueue :: PostgresMsgStore -> PostgresQueue -> DBStoreIO (Maybe ((), Message))
|
||||
getPeekMsgQueue _ q = ((),) <$$> tryPeekMsg_ q ()
|
||||
|
||||
withIdleMsgQueue :: Int64 -> PostgresMsgStore -> PostgresQueue -> (() -> DBStoreIO a) -> DBStoreIO (Maybe a, Int)
|
||||
withIdleMsgQueue _ _ _ _ = error "withIdleMsgQueue not used"
|
||||
|
||||
deleteQueue :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType QueueRec)
|
||||
deleteQueue ms q = deleteStoreQueue (queueStore_ ms) q
|
||||
{-# INLINE deleteQueue #-}
|
||||
|
||||
deleteQueueSize :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType (QueueRec, Int))
|
||||
deleteQueueSize ms q = runExceptT $ do
|
||||
size <- getQueueSize ms q
|
||||
qr <- ExceptT $ deleteStoreQueue (queueStore_ ms) q
|
||||
pure (qr, size)
|
||||
|
||||
getQueueMessages_ _ _ _ = error "getQueueMessages_ not used"
|
||||
|
||||
writeMsg :: PostgresMsgStore -> PostgresQueue -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool))
|
||||
writeMsg ms q _ msg =
|
||||
uninterruptibleMask_ $
|
||||
withDB' "writeMsg" (queueStore_ ms) $ \db -> do
|
||||
let (msgQuota, ntf, body) = case msg of
|
||||
Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body')
|
||||
MessageQuota {} -> (True, False, B.empty)
|
||||
toResult <$>
|
||||
DB.query
|
||||
db
|
||||
"SELECT quota_written, was_empty FROM write_message(?,?,?,?,?,?,?)"
|
||||
(recipientId' q, Binary (messageId msg), systemSeconds (messageTs msg), msgQuota, ntf, Binary body, quota)
|
||||
where
|
||||
toResult = \case
|
||||
((msgQuota, wasEmpty) : _) -> if msgQuota then Nothing else Just (msg, wasEmpty)
|
||||
[] -> Nothing
|
||||
PostgresMsgStore {config = PostgresMsgStoreCfg {quota}} = ms
|
||||
|
||||
setOverQuota_ :: PostgresQueue -> IO () -- can ONLY be used while restoring messages, not while server running
|
||||
setOverQuota_ _ = error "TODO setOverQuota_" -- TODO [messages]
|
||||
|
||||
getQueueSize_ :: () -> DBStoreIO Int
|
||||
getQueueSize_ _ = error "getQueueSize_ not used"
|
||||
|
||||
getQueueSize :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO Int
|
||||
getQueueSize ms q =
|
||||
withDB' "getQueueSize" (queueStore_ ms) $ \db ->
|
||||
maybeFirstRow' 0 fromOnly $
|
||||
DB.query db "SELECT msg_queue_size FROM msg_queues WHERE recipient_id = ? AND deleted_at IS NULL" (Only (recipientId' q))
|
||||
|
||||
tryPeekMsg_ :: PostgresQueue -> () -> DBStoreIO (Maybe Message)
|
||||
tryPeekMsg_ q _ = do
|
||||
db <- asks dbConn
|
||||
liftIO $ maybeFirstRow toMessage $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
FROM messages
|
||||
WHERE recipient_id = ?
|
||||
ORDER BY message_id ASC LIMIT 1
|
||||
|]
|
||||
(Only (recipientId' q))
|
||||
|
||||
tryDeleteMsg_ :: PostgresQueue -> () -> Bool -> DBStoreIO ()
|
||||
tryDeleteMsg_ _q _ _ = error "tryDeleteMsg_ not used" -- do
|
||||
|
||||
isolateQueue :: PostgresMsgStore -> PostgresQueue -> Text -> DBStoreIO a -> ExceptT ErrorType IO a
|
||||
isolateQueue ms _q op a = uninterruptibleMask_ $ withDB' op (queueStore_ ms) $ runReaderT a . DBTransaction
|
||||
|
||||
unsafeRunStore _ _ _ = error "unsafeRunStore not used"
|
||||
|
||||
tryPeekMsg :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryPeekMsg ms q = isolateQueue ms q "tryPeekMsg" $ tryPeekMsg_ q ()
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
|
||||
tryDelMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryDelMsg ms q msgId =
|
||||
uninterruptibleMask_ $
|
||||
withDB' "tryDelMsg" (queueStore_ ms) $ \db ->
|
||||
maybeFirstRow toMessage $
|
||||
DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_msg(?, ?)" (recipientId' q, Binary msgId)
|
||||
|
||||
tryDelPeekMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message)
|
||||
tryDelPeekMsg ms q msgId =
|
||||
uninterruptibleMask_ $
|
||||
withDB' "tryDelPeekMsg" (queueStore_ ms) $ \db ->
|
||||
toResult . map toMessage
|
||||
<$> DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_peek_msg(?, ?)" (recipientId' q, Binary msgId)
|
||||
where
|
||||
toResult = \case
|
||||
[] -> (Nothing, Nothing)
|
||||
[msg]
|
||||
| messageId msg == msgId -> (Just msg, Nothing)
|
||||
| otherwise -> (Nothing, Just msg)
|
||||
deleted : next : _ -> (Just deleted, Just next)
|
||||
|
||||
deleteExpiredMsgs :: PostgresMsgStore -> PostgresQueue -> Int64 -> ExceptT ErrorType IO Int
|
||||
deleteExpiredMsgs ms q old =
|
||||
uninterruptibleMask_ $
|
||||
maybeFirstRow' 0 (fromIntegral @Int64 . fromOnly) $ withDB' "deleteExpiredMsgs" (queueStore_ ms) $ \db ->
|
||||
DB.query db "SELECT delete_expired_msgs(?, ?)" (recipientId' q, old)
|
||||
|
||||
uninterruptibleMask_ :: ExceptT ErrorType IO a -> ExceptT ErrorType IO a
|
||||
uninterruptibleMask_ = ExceptT . E.uninterruptibleMask_ . runExceptT
|
||||
{-# INLINE uninterruptibleMask_ #-}
|
||||
|
||||
toMessage :: (Binary MsgId, Int64, Bool, Bool, Binary MsgBody) -> Message
|
||||
toMessage (Binary msgId, ts, msgQuota, ntf, Binary body)
|
||||
| msgQuota = MessageQuota {msgId, msgTs}
|
||||
| otherwise = Message {msgId, msgTs, msgFlags = MsgFlags ntf, msgBody = C.unsafeMaxLenBS body} -- TODO [messages] unsafeMaxLenBS?
|
||||
where
|
||||
msgTs = MkSystemTime ts 0
|
||||
|
||||
exportDbMessages :: Bool -> PostgresMsgStore -> Handle -> IO Int
|
||||
exportDbMessages tty ms h = do
|
||||
rows <- newIORef []
|
||||
n <- withConnection st $ \db -> DB.foldWithOptions_ opts db query 0 $ \i r -> do
|
||||
let i' = i + 1
|
||||
if i' `mod` 1000 > 0
|
||||
then modifyIORef rows (r :)
|
||||
else do
|
||||
readIORef rows >>= writeMessages . (r :)
|
||||
writeIORef rows []
|
||||
when tty $ putStr (progress i' <> "\r") >> hFlush stdout
|
||||
pure i'
|
||||
readIORef rows >>= \rs -> unless (null rs) $ writeMessages rs
|
||||
when tty $ putStrLn $ progress n
|
||||
pure n
|
||||
where
|
||||
st = dbStore $ queueStore_ ms
|
||||
opts = DB.defaultFoldOptions {DB.fetchQuantity = DB.Fixed 1000}
|
||||
query =
|
||||
[sql|
|
||||
SELECT recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
FROM messages
|
||||
ORDER BY recipient_id, message_id ASC
|
||||
|]
|
||||
writeMessages = BB.hPutBuilder h . encodeMessages . reverse
|
||||
encodeMessages = mconcat . map (\(Only rId :. msg) -> BB.byteString (strEncode $ MLRv3 rId $ toMessage msg) <> BB.char8 '\n')
|
||||
progress i = "Processed: " <> show i <> " records"
|
||||
|
||||
getDbMessageStats :: PostgresMsgStore -> IO MessageStats
|
||||
getDbMessageStats ms =
|
||||
maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db ->
|
||||
DB.query_
|
||||
db
|
||||
[sql|
|
||||
SELECT
|
||||
(SELECT COUNT (1) FROM msg_queues WHERE deleted_at IS NULL),
|
||||
(SELECT COUNT (1) FROM messages m JOIN msg_queues q USING recipient_id WHERE deleted_at IS NULL)
|
||||
|]
|
||||
where
|
||||
st = dbStore $ queueStore_ ms
|
||||
toMessageStats (storedQueues, storedMsgsCount) =
|
||||
MessageStats {storedQueues, storedMsgsCount, expiredMsgsCount = 0}
|
||||
|
||||
getDbMessageCount :: PostgresMsgStore -> IO Int64
|
||||
getDbMessageCount ms =
|
||||
maybeFirstRow' 0 fromOnly $
|
||||
withConnection (dbStore $ queueStore_ ms) (`DB.query_` "SELECT COUNT(*) FROM messages")
|
||||
|
||||
deleteAllMessages :: PostgresMsgStore -> IO ()
|
||||
deleteAllMessages ms =
|
||||
withConnection (dbStore $ queueStore_ ms) $ \db -> do
|
||||
void $ DB.execute_ db "TRUNCATE messages"
|
||||
void $ DB.execute_
|
||||
db
|
||||
[sql|
|
||||
UPDATE msg_queues
|
||||
SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE
|
||||
WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE
|
||||
|]
|
||||
|
||||
updateQueueCounts :: PostgresMsgStore -> IO ()
|
||||
updateQueueCounts ms =
|
||||
withConnection (dbStore $ queueStore_ ms) $ \db -> do
|
||||
void $ DB.execute_
|
||||
db
|
||||
[sql|
|
||||
CREATE TEMP TABLE queue_stats AS
|
||||
SELECT recipient_id,
|
||||
COUNT(*) AS size,
|
||||
SUM(CASE WHEN msg_quota THEN 1 ELSE 0 END) AS quota_count
|
||||
FROM messages
|
||||
GROUP BY recipient_id
|
||||
|]
|
||||
void $ DB.execute_
|
||||
db
|
||||
[sql|
|
||||
UPDATE msg_queues
|
||||
SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE
|
||||
WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE
|
||||
|]
|
||||
void $ DB.execute_
|
||||
db
|
||||
[sql|
|
||||
UPDATE msg_queues q
|
||||
SET msg_queue_size = s.size,
|
||||
msg_can_write = s.quota_count = 0,
|
||||
msg_queue_expire = s.size > s.quota_count
|
||||
FROM queue_stats s
|
||||
WHERE q.recipient_id = s.recipient_id
|
||||
|]
|
||||
void $ DB.execute_ db "DROP TABLE queue_stats"
|
||||
|
||||
batchInsertMessages :: StoreQueueClass q => Bool -> FilePath -> PostgresQueueStore q -> IO Int64
|
||||
batchInsertMessages tty f toStore = do
|
||||
putStrLn "Importing messages..."
|
||||
let st = dbStore toStore
|
||||
(_, inserted) <-
|
||||
withTransaction st $ \db -> do
|
||||
DB.copy_
|
||||
db
|
||||
[sql|
|
||||
COPY messages (recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body)
|
||||
FROM STDIN WITH (FORMAT CSV)
|
||||
|]
|
||||
foldLogLines tty f (putMessage db) (0 :: Int, 0) >>= (DB.putCopyEnd db $>)
|
||||
pure inserted
|
||||
where
|
||||
putMessage db (!i, !cnt) _eof s = do
|
||||
let i' = i + 1
|
||||
cnt' <- case strDecode s of
|
||||
Right (MLRv3 rId msg) -> (cnt + 1) <$ DB.putCopyData db (messageRecToText rId msg)
|
||||
Left e -> cnt <$ putStrLn ("Error parsing line " <> show i' <> ": " <> e)
|
||||
pure (i', cnt')
|
||||
|
||||
messageRecToText :: RecipientId -> Message -> B.ByteString
|
||||
messageRecToText rId msg =
|
||||
LB.toStrict $ BB.toLazyByteString $ mconcat tabFields <> BB.char7 '\n'
|
||||
where
|
||||
tabFields = BB.char7 ',' `intersperse` fields
|
||||
fields =
|
||||
[ renderField (toField rId),
|
||||
renderField (toField $ Binary (messageId msg)),
|
||||
renderField (toField $ systemSeconds (messageTs msg)),
|
||||
renderField (toField msgQuota),
|
||||
renderField (toField ntf),
|
||||
renderField (toField $ Binary body)
|
||||
]
|
||||
(msgQuota, ntf, body) = case msg of
|
||||
Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body')
|
||||
MessageQuota {} -> (True, False, B.empty)
|
||||
@@ -57,18 +57,16 @@ data STMStoreConfig = STMStoreConfig
|
||||
}
|
||||
|
||||
instance StoreQueueClass STMQueue where
|
||||
type MsgQueue STMQueue = STMMsgQueue
|
||||
recipientId = recipientId'
|
||||
{-# INLINE recipientId #-}
|
||||
queueRec = queueRec'
|
||||
{-# INLINE queueRec #-}
|
||||
msgQueue = msgQueue'
|
||||
{-# INLINE msgQueue #-}
|
||||
withQueueLock _ _ = id
|
||||
{-# INLINE withQueueLock #-}
|
||||
|
||||
instance MsgStoreClass STMMsgStore where
|
||||
type StoreMonad STMMsgStore = STM
|
||||
type MsgQueue STMMsgStore = STMMsgQueue
|
||||
type QueueStore STMMsgStore = STMQueueStore STMQueue
|
||||
type StoreQueue STMMsgStore = STMQueue
|
||||
type MsgStoreConfig STMMsgStore = STMStoreConfig
|
||||
@@ -82,7 +80,7 @@ instance MsgStoreClass STMMsgStore where
|
||||
{-# INLINE closeMsgStore #-}
|
||||
withActiveMsgQueues = withLoadedQueues . queueStore_
|
||||
{-# INLINE withActiveMsgQueues #-}
|
||||
unsafeWithAllMsgQueues _ _ = withLoadedQueues . queueStore_
|
||||
unsafeWithAllMsgQueues _ = withLoadedQueues . queueStore_
|
||||
{-# INLINE unsafeWithAllMsgQueues #-}
|
||||
|
||||
expireOldMessages :: Bool -> STMMsgStore -> Int64 -> Int64 -> IO MessageStats
|
||||
@@ -129,10 +127,10 @@ instance MsgStoreClass STMMsgStore where
|
||||
Nothing -> pure (Nothing, 0)
|
||||
|
||||
deleteQueue :: STMMsgStore -> STMQueue -> IO (Either ErrorType QueueRec)
|
||||
deleteQueue ms q = fst <$$> deleteStoreQueue (queueStore_ ms) q
|
||||
deleteQueue ms q = fst <$$> deleteQueue_ ms q
|
||||
|
||||
deleteQueueSize :: STMMsgStore -> STMQueue -> IO (Either ErrorType (QueueRec, Int))
|
||||
deleteQueueSize ms q = deleteStoreQueue (queueStore_ ms) q >>= mapM (traverse getSize)
|
||||
deleteQueueSize ms q = deleteQueue_ ms q >>= mapM (traverse getSize)
|
||||
-- traverse operates on the second tuple element
|
||||
where
|
||||
getSize = maybe (pure 0) (\STMMsgQueue {size} -> readTVarIO size)
|
||||
@@ -179,10 +177,15 @@ instance MsgStoreClass STMMsgStore where
|
||||
Just _ -> modifyTVar' size (subtract 1)
|
||||
_ -> pure ()
|
||||
|
||||
isolateQueue :: STMQueue -> Text -> STM a -> ExceptT ErrorType IO a
|
||||
isolateQueue _ _ = liftIO . atomically
|
||||
isolateQueue :: STMMsgStore -> STMQueue -> Text -> STM a -> ExceptT ErrorType IO a
|
||||
isolateQueue _ _ _ = liftIO . atomically
|
||||
{-# INLINE isolateQueue #-}
|
||||
|
||||
unsafeRunStore :: STMQueue -> Text -> STM a -> IO a
|
||||
unsafeRunStore _ _ = atomically
|
||||
{-# INLINE unsafeRunStore #-}
|
||||
|
||||
deleteQueue_ :: STMMsgStore -> STMQueue -> IO (Either ErrorType (QueueRec, Maybe STMMsgQueue))
|
||||
deleteQueue_ ms q = deleteStoreQueue (queueStore_ ms) q >>= mapM remove
|
||||
where
|
||||
remove qr = (qr,) <$> atomically (swapTVar (msgQueue' q) Nothing)
|
||||
|
||||
@@ -34,14 +34,15 @@ import Simplex.Messaging.Util ((<$$>), ($>>=))
|
||||
class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => MsgStoreClass s where
|
||||
type StoreMonad s = (m :: Type -> Type) | m -> s
|
||||
type MsgStoreConfig s = c | c -> s
|
||||
type MsgQueue s = q | q -> s
|
||||
type StoreQueue s = q | q -> s
|
||||
type QueueStore s = qs | qs -> s
|
||||
newMsgStore :: MsgStoreConfig s -> IO s
|
||||
closeMsgStore :: s -> IO ()
|
||||
withActiveMsgQueues :: Monoid a => s -> (StoreQueue s -> IO a) -> IO a
|
||||
-- This function can only be used in server CLI commands or before server is started.
|
||||
-- tty, withData, store
|
||||
unsafeWithAllMsgQueues :: Monoid a => Bool -> Bool -> s -> (StoreQueue s -> IO a) -> IO a
|
||||
-- tty, store
|
||||
unsafeWithAllMsgQueues :: Monoid a => Bool -> s -> (StoreQueue s -> IO a) -> IO a
|
||||
-- tty, store, now, ttl
|
||||
expireOldMessages :: Bool -> s -> Int64 -> Int64 -> IO MessageStats
|
||||
logQueueStates :: s -> IO ()
|
||||
@@ -51,29 +52,62 @@ class (Monad (StoreMonad s), QueueStoreClass (StoreQueue s) (QueueStore s)) => M
|
||||
|
||||
-- message store methods
|
||||
mkQueue :: s -> Bool -> RecipientId -> QueueRec -> IO (StoreQueue s)
|
||||
getMsgQueue :: s -> StoreQueue s -> Bool -> StoreMonad s (MsgQueue (StoreQueue s))
|
||||
getPeekMsgQueue :: s -> StoreQueue s -> StoreMonad s (Maybe (MsgQueue (StoreQueue s), Message))
|
||||
getMsgQueue :: s -> StoreQueue s -> Bool -> StoreMonad s (MsgQueue s)
|
||||
getPeekMsgQueue :: s -> StoreQueue s -> StoreMonad s (Maybe (MsgQueue s, Message))
|
||||
|
||||
-- the journal queue will be closed after action if it was initially closed or idle longer than interval in config
|
||||
withIdleMsgQueue :: Int64 -> s -> StoreQueue s -> (MsgQueue (StoreQueue s) -> StoreMonad s a) -> StoreMonad s (Maybe a, Int)
|
||||
withIdleMsgQueue :: Int64 -> s -> StoreQueue s -> (MsgQueue s -> StoreMonad s a) -> StoreMonad s (Maybe a, Int)
|
||||
deleteQueue :: s -> StoreQueue s -> IO (Either ErrorType QueueRec)
|
||||
deleteQueueSize :: s -> StoreQueue s -> IO (Either ErrorType (QueueRec, Int))
|
||||
getQueueMessages_ :: Bool -> StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s [Message]
|
||||
getQueueMessages_ :: Bool -> StoreQueue s -> MsgQueue s -> StoreMonad s [Message]
|
||||
writeMsg :: s -> StoreQueue s -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool))
|
||||
setOverQuota_ :: StoreQueue s -> IO () -- can ONLY be used while restoring messages, not while server running
|
||||
getQueueSize_ :: MsgQueue (StoreQueue s) -> StoreMonad s Int
|
||||
tryPeekMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s (Maybe Message)
|
||||
tryDeleteMsg_ :: StoreQueue s -> MsgQueue (StoreQueue s) -> Bool -> StoreMonad s ()
|
||||
isolateQueue :: StoreQueue s -> Text -> StoreMonad s a -> ExceptT ErrorType IO a
|
||||
getQueueSize_ :: MsgQueue s -> StoreMonad s Int
|
||||
tryPeekMsg_ :: StoreQueue s -> MsgQueue s -> StoreMonad s (Maybe Message)
|
||||
tryDeleteMsg_ :: StoreQueue s -> MsgQueue s -> Bool -> StoreMonad s ()
|
||||
isolateQueue :: s -> StoreQueue s -> Text -> StoreMonad s a -> ExceptT ErrorType IO a
|
||||
unsafeRunStore :: StoreQueue s -> Text -> StoreMonad s a -> IO a
|
||||
|
||||
data MSType = MSMemory | MSJournal
|
||||
-- default implementations are overridden for PostgreSQL storage of messages
|
||||
tryPeekMsg :: s -> StoreQueue s -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryPeekMsg st q = snd <$$> withPeekMsgQueue st q "tryPeekMsg" pure
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
|
||||
tryDelMsg :: s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryDelMsg st q msgId' =
|
||||
withPeekMsgQueue st q "tryDelMsg" $
|
||||
maybe (pure Nothing) $ \(mq, msg) ->
|
||||
if
|
||||
| messageId msg == msgId' ->
|
||||
tryDeleteMsg_ q mq True $> Just msg
|
||||
| otherwise -> pure Nothing
|
||||
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message)
|
||||
tryDelPeekMsg st q msgId' =
|
||||
withPeekMsgQueue st q "tryDelPeekMsg" $
|
||||
maybe (pure (Nothing, Nothing)) $ \(mq, msg) ->
|
||||
if
|
||||
| messageId msg == msgId' -> (Just msg,) <$> (tryDeleteMsg_ q mq True >> tryPeekMsg_ q mq)
|
||||
| otherwise -> pure (Nothing, Just msg)
|
||||
|
||||
deleteExpiredMsgs :: s -> StoreQueue s -> Int64 -> ExceptT ErrorType IO Int
|
||||
deleteExpiredMsgs st q old =
|
||||
isolateQueue st q "deleteExpiredMsgs" $
|
||||
getMsgQueue st q False >>= deleteExpireMsgs_ old q
|
||||
|
||||
getQueueSize :: s -> StoreQueue s -> ExceptT ErrorType IO Int
|
||||
getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst)
|
||||
{-# INLINE getQueueSize #-}
|
||||
|
||||
data MSType = MSMemory | MSJournal | MSPostgres
|
||||
|
||||
data QSType = QSMemory | QSPostgres
|
||||
|
||||
data SMSType :: MSType -> Type where
|
||||
SMSMemory :: SMSType 'MSMemory
|
||||
SMSJournal :: SMSType 'MSJournal
|
||||
SMSPostgres :: SMSType 'MSPostgres
|
||||
|
||||
data SQSType :: QSType -> Type where
|
||||
SQSMemory :: SQSType 'QSMemory
|
||||
@@ -84,6 +118,7 @@ data MessageStats = MessageStats
|
||||
expiredMsgsCount :: Int,
|
||||
storedQueues :: Int
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
instance Monoid MessageStats where
|
||||
mempty = MessageStats 0 0 0
|
||||
@@ -126,48 +161,19 @@ readQueueRec :: StoreQueueClass q => q -> IO (Either ErrorType (q, QueueRec))
|
||||
readQueueRec q = maybe (Left AUTH) (Right . (q,)) <$> readTVarIO (queueRec q)
|
||||
{-# INLINE readQueueRec #-}
|
||||
|
||||
getQueueSize :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO Int
|
||||
getQueueSize st q = withPeekMsgQueue st q "getQueueSize" $ maybe (pure 0) (getQueueSize_ . fst)
|
||||
{-# INLINE getQueueSize #-}
|
||||
|
||||
tryPeekMsg :: MsgStoreClass s => s -> StoreQueue s -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryPeekMsg st q = snd <$$> withPeekMsgQueue st q "tryPeekMsg" pure
|
||||
{-# INLINE tryPeekMsg #-}
|
||||
|
||||
tryDelMsg :: MsgStoreClass s => s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message)
|
||||
tryDelMsg st q msgId' =
|
||||
withPeekMsgQueue st q "tryDelMsg" $
|
||||
maybe (pure Nothing) $ \(mq, msg) ->
|
||||
if
|
||||
| messageId msg == msgId' ->
|
||||
tryDeleteMsg_ q mq True $> Just msg
|
||||
| otherwise -> pure Nothing
|
||||
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: MsgStoreClass s => s -> StoreQueue s -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message)
|
||||
tryDelPeekMsg st q msgId' =
|
||||
withPeekMsgQueue st q "tryDelPeekMsg" $
|
||||
maybe (pure (Nothing, Nothing)) $ \(mq, msg) ->
|
||||
if
|
||||
| messageId msg == msgId' -> (Just msg,) <$> (tryDeleteMsg_ q mq True >> tryPeekMsg_ q mq)
|
||||
| otherwise -> pure (Nothing, Just msg)
|
||||
|
||||
-- The action is called with Nothing when it is known that the queue is empty
|
||||
withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> Text -> (Maybe (MsgQueue (StoreQueue s), Message) -> StoreMonad s a) -> ExceptT ErrorType IO a
|
||||
withPeekMsgQueue st q op a = isolateQueue q op $ getPeekMsgQueue st q >>= a
|
||||
withPeekMsgQueue :: MsgStoreClass s => s -> StoreQueue s -> Text -> (Maybe (MsgQueue s, Message) -> StoreMonad s a) -> ExceptT ErrorType IO a
|
||||
withPeekMsgQueue st q op a = isolateQueue st q op $ getPeekMsgQueue st q >>= a
|
||||
{-# INLINE withPeekMsgQueue #-}
|
||||
|
||||
deleteExpiredMsgs :: MsgStoreClass s => s -> StoreQueue s -> Int64 -> ExceptT ErrorType IO Int
|
||||
deleteExpiredMsgs st q old =
|
||||
isolateQueue q "deleteExpiredMsgs" $
|
||||
getMsgQueue st q False >>= deleteExpireMsgs_ old q
|
||||
|
||||
-- not used with PostgreSQL message store
|
||||
expireQueueMsgs :: MsgStoreClass s => s -> Int64 -> Int64 -> StoreQueue s -> StoreMonad s MessageStats
|
||||
expireQueueMsgs st now old q = do
|
||||
(expired_, stored) <- withIdleMsgQueue now st q $ deleteExpireMsgs_ old q
|
||||
pure MessageStats {storedMsgsCount = stored, expiredMsgsCount = fromMaybe 0 expired_, storedQueues = 1}
|
||||
|
||||
deleteExpireMsgs_ :: MsgStoreClass s => Int64 -> StoreQueue s -> MsgQueue (StoreQueue s) -> StoreMonad s Int
|
||||
-- not used with PostgreSQL message store
|
||||
deleteExpireMsgs_ :: MsgStoreClass s => Int64 -> StoreQueue s -> MsgQueue s -> StoreMonad s Int
|
||||
deleteExpireMsgs_ old q mq = do
|
||||
n <- loop 0
|
||||
logQueueState q
|
||||
|
||||
@@ -25,9 +25,13 @@ module Simplex.Messaging.Server.QueueStore.Postgres
|
||||
batchInsertQueues,
|
||||
foldServiceRecs,
|
||||
foldQueueRecs,
|
||||
foldRecentQueueRecs,
|
||||
handleDuplicate,
|
||||
withLog_,
|
||||
withDB,
|
||||
withDB',
|
||||
assertUpdated,
|
||||
renderField,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -70,6 +74,7 @@ import Simplex.Messaging.Agent.Store.AgentStore ()
|
||||
import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_)
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
@@ -83,7 +88,7 @@ import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (SMPServiceRole (..))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, tshow, (<$$>))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, firstRow, ifM, maybeFirstRow, maybeFirstRow', tshow, (<$$>))
|
||||
import System.Exit (exitFailure)
|
||||
import System.IO (IOMode (..), hFlush, stdout)
|
||||
import UnliftIO.STM
|
||||
@@ -104,15 +109,18 @@ data PostgresQueueStore q = PostgresQueueStore
|
||||
notifiers :: TMap NotifierId RecipientId,
|
||||
notifierLocks :: TMap NotifierId Lock,
|
||||
serviceLocks :: TMap CertFingerprint Lock,
|
||||
deletedTTL :: Int64
|
||||
deletedTTL :: Int64,
|
||||
useCache :: Bool
|
||||
}
|
||||
|
||||
instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
type QueueStoreCfg (PostgresQueueStore q) = PostgresStoreCfg
|
||||
type UseQueueCache = Bool
|
||||
|
||||
newQueueStore :: PostgresStoreCfg -> IO (PostgresQueueStore q)
|
||||
newQueueStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do
|
||||
dbStore <- either err pure =<< createDBStore dbOpts serverMigrations confirmMigrations
|
||||
instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
type QueueStoreCfg (PostgresQueueStore q) = (PostgresStoreCfg, UseQueueCache)
|
||||
|
||||
newQueueStore :: (PostgresStoreCfg, UseQueueCache) -> IO (PostgresQueueStore q)
|
||||
newQueueStore (PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL}, useCache) = do
|
||||
dbStore <- either err pure =<< createDBStore dbOpts serverMigrations (MigrationConfig confirmMigrations Nothing)
|
||||
dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath
|
||||
queues <- TM.emptyIO
|
||||
senders <- TM.emptyIO
|
||||
@@ -120,7 +128,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
notifiers <- TM.emptyIO
|
||||
notifierLocks <- TM.emptyIO
|
||||
serviceLocks <- TM.emptyIO
|
||||
pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, serviceLocks, deletedTTL}
|
||||
pure PostgresQueueStore {dbStore, dbStoreLog, queues, senders, links, notifiers, notifierLocks, serviceLocks, deletedTTL, useCache}
|
||||
where
|
||||
err e = do
|
||||
logError $ "STORE: newQueueStore, error opening PostgreSQL database, " <> tshow e
|
||||
@@ -167,28 +175,35 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
void $ withDB "addQueue_" st $ \db ->
|
||||
E.try (DB.execute db insertQueueQuery $ queueRecToRow (rId, qr))
|
||||
>>= bimapM handleDuplicate pure
|
||||
atomically $ TM.insert rId sq queues
|
||||
atomically $ TM.insert (senderId qr) rId senders
|
||||
forM_ (notifier qr) $ \NtfCreds {notifierId = nId} -> atomically $ TM.insert nId rId notifiers
|
||||
forM_ (queueData qr) $ \(lnkId, _) -> atomically $ TM.insert lnkId rId links
|
||||
when useCache $ do
|
||||
atomically $ TM.insert rId sq queues
|
||||
atomically $ TM.insert (senderId qr) rId senders
|
||||
forM_ (notifier qr) $ \NtfCreds {notifierId = nId} -> atomically $ TM.insert nId rId notifiers
|
||||
forM_ (queueData qr) $ \(lnkId, _) -> atomically $ TM.insert lnkId rId links
|
||||
withLog "addStoreQueue" st $ \s -> logCreateQueue s rId qr
|
||||
pure sq
|
||||
where
|
||||
PostgresQueueStore {queues, senders, links, notifiers} = st
|
||||
PostgresQueueStore {queues, senders, links, notifiers, useCache} = st
|
||||
-- Not doing duplicate checks in maps as the probability of duplicates is very low.
|
||||
-- It needs to be reconsidered when IDs are supplied by the users.
|
||||
-- hasId = anyM [TM.memberIO rId queues, TM.memberIO senderId senders, hasNotifier]
|
||||
-- hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.memberIO notifierId notifiers) notifier
|
||||
|
||||
getQueue_ :: QueueParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> QueueId -> IO (Either ErrorType q)
|
||||
getQueue_ st mkQ party qId = case party of
|
||||
SRecipient -> getRcvQueue qId
|
||||
SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
|
||||
SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue
|
||||
-- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server
|
||||
SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>))
|
||||
getQueue_ st mkQ party qId
|
||||
| useCache = case party of
|
||||
SRecipient -> getRcvQueue qId
|
||||
SSender -> TM.lookupIO qId senders >>= maybe (mask loadSndQueue) getRcvQueue
|
||||
SSenderLink -> TM.lookupIO qId links >>= maybe (mask loadLinkQueue) getRcvQueue
|
||||
-- loaded queue is deleted from notifiers map to reduce cache size after queue was subscribed to by ntf server
|
||||
SNotifier -> TM.lookupIO qId notifiers >>= maybe (mask loadNtfQueue) (getRcvQueue >=> (atomically (TM.delete qId notifiers) $>))
|
||||
| otherwise = case party of
|
||||
SRecipient -> loadQueueNoCache " WHERE recipient_id = ?"
|
||||
SSender -> loadQueueNoCache " WHERE sender_id = ?"
|
||||
SSenderLink -> loadQueueNoCache " WHERE link_id = ?"
|
||||
SNotifier -> loadQueueNoCache " WHERE notifier_id = ?"
|
||||
where
|
||||
PostgresQueueStore {queues, senders, links, notifiers} = st
|
||||
PostgresQueueStore {queues, senders, links, notifiers, useCache} = st
|
||||
getRcvQueue rId = TM.lookupIO rId queues >>= maybe (mask loadRcvQueue) (pure . Right)
|
||||
loadRcvQueue = do
|
||||
(rId, qRec) <- loadQueue " WHERE recipient_id = ?"
|
||||
@@ -205,6 +220,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
liftIO $
|
||||
TM.lookupIO rId queues -- checking recipient map first
|
||||
>>= maybe (cacheQueue rId qRec cacheSender) (atomically (cacheSender rId) $>)
|
||||
loadQueueNoCache cond = mask $ loadQueue cond >>= liftIO . uncurry (mkQ True)
|
||||
mask = E.uninterruptibleMask_ . runExceptT
|
||||
cacheSender rId = TM.insert qId rId senders
|
||||
loadQueue condition =
|
||||
@@ -227,20 +243,27 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
pure sq
|
||||
|
||||
getQueues_ :: forall p. BatchParty p => PostgresQueueStore q -> (Bool -> RecipientId -> QueueRec -> IO q) -> SParty p -> [QueueId] -> IO [Either ErrorType q]
|
||||
getQueues_ st mkQ party qIds = case party of
|
||||
SRecipient -> do
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get qs qId qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue
|
||||
SNotifier -> do
|
||||
ns <- readTVarIO notifiers
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) ->
|
||||
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query
|
||||
(nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs)
|
||||
getQueues_ st mkQ party qIds
|
||||
| null qIds = pure []
|
||||
| useCache = case party of
|
||||
SRecipient -> do
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get qs qId qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE recipient_id IN ?" cacheRcvQueue
|
||||
SNotifier -> do
|
||||
ns <- readTVarIO notifiers
|
||||
qs <- readTVarIO queues
|
||||
let qs' = map (\qId -> get ns qId qId >>= get qs qId) qIds
|
||||
E.uninterruptibleMask_ $ loadQueues qs' " WHERE notifier_id IN ?" $ \(rId, qRec) ->
|
||||
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> -- it is always Just with this query
|
||||
(nId,) <$> maybe (mkQ False rId qRec) pure (M.lookup rId qs)
|
||||
| otherwise = E.uninterruptibleMask_ $ case party of
|
||||
SRecipient -> loadQueuesNoCache " WHERE recipient_id IN ?" $ \(rId, qRec) ->
|
||||
Just . (rId,) <$> mkQ False rId qRec
|
||||
SNotifier -> loadQueuesNoCache " WHERE notifier_id IN ?" $ \(rId, qRec) ->
|
||||
forM (notifier qRec) $ \NtfCreds {notifierId = nId} -> (nId,) <$> mkQ False rId qRec
|
||||
where
|
||||
PostgresQueueStore {queues, notifiers} = st
|
||||
PostgresQueueStore {queues, notifiers, useCache} = st
|
||||
get :: M.Map QueueId a -> QueueId -> QueueId -> Either QueueId a
|
||||
get m qId = maybe (Left qId) Right . (`M.lookup` m)
|
||||
loadQueues :: [Either QueueId q] -> Query -> ((RecipientId, QueueRec) -> IO (Maybe (QueueId, q))) -> IO [Either ErrorType q]
|
||||
@@ -249,15 +272,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
if null qIds'
|
||||
then pure $ map (first (const INTERNAL)) qs'
|
||||
else do
|
||||
qs_ <-
|
||||
runExceptT $ fmap M.fromList $
|
||||
withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds')))
|
||||
>>= liftIO . fmap catMaybes . mapM (mkCacheQueue . rowToQueueRec)
|
||||
qs_ <- dbLoadQueues qIds' cond mkCacheQueue
|
||||
pure $ map (result qs_) qs'
|
||||
where
|
||||
result :: Either ErrorType (M.Map QueueId q) -> Either QueueId q -> Either ErrorType q
|
||||
result _ (Right q) = Right q
|
||||
result qs_ (Left qId) = maybe (Left AUTH) Right . M.lookup qId =<< qs_
|
||||
dbLoadQueues qIds' cond mkQueue' =
|
||||
runExceptT $ fmap M.fromList $
|
||||
withDB' "getQueues_" st (\db -> DB.query db (queueRecQuery <> cond <> " AND deleted_at IS NULL") (Only (In qIds')))
|
||||
>>= liftIO . fmap catMaybes . mapM (mkQueue' . rowToQueueRec)
|
||||
cacheRcvQueue (rId, qRec) = do
|
||||
sq <- mkQ True rId qRec
|
||||
sq' <- withQueueLock sq "getQueue_" $ atomically $
|
||||
@@ -266,6 +290,12 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
Just sq' -> pure sq'
|
||||
Nothing -> sq <$ TM.insert rId sq queues
|
||||
pure $ Just (rId, sq')
|
||||
loadQueuesNoCache cond mkQueue' = do
|
||||
qs_ <- dbLoadQueues qIds cond mkQueue'
|
||||
pure $ map (result qs_) qIds
|
||||
where
|
||||
result :: Either ErrorType (M.Map QueueId q) -> QueueId -> Either ErrorType q
|
||||
result qs_ qId = maybe (Left AUTH) Right . M.lookup qId =<< qs_
|
||||
|
||||
getQueueLinkData :: PostgresQueueStore q -> q -> LinkId -> IO (Either ErrorType QueueLinkData)
|
||||
getQueueLinkData st sq lnkId = runExceptT $ do
|
||||
@@ -331,19 +361,23 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
addQueueNotifier :: PostgresQueueStore q -> q -> NtfCreds -> IO (Either ErrorType (Maybe NtfCreds))
|
||||
addQueueNotifier st sq ntfCreds@NtfCreds {notifierId = nId, notifierKey, rcvNtfDhSecret} =
|
||||
withQueueRec sq "addQueueNotifier" $ \q ->
|
||||
ExceptT $ withLockMap (notifierLocks st) nId "addQueueNotifier" $
|
||||
ifM (TM.memberIO nId notifiers) (pure $ Left DUPLICATE_) $ runExceptT $ do
|
||||
assertUpdated $ withDB "addQueueNotifier" st $ \db ->
|
||||
E.try (update db) >>= bimapM handleDuplicate pure
|
||||
nc_ <- forM (notifier q) $ \nc@NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> nc
|
||||
let !q' = q {notifier = Just ntfCreds}
|
||||
atomically $ writeTVar (queueRec sq) $ Just q'
|
||||
-- cache queue notifier ID – after notifier is added ntf server will likely subscribe
|
||||
checkCachedNotifier $ do
|
||||
assertUpdated $ withDB "addQueueNotifier" st $ \db ->
|
||||
E.try (update db) >>= bimapM handleDuplicate pure
|
||||
nc_ <- forM (notifier q) $ \nc@NtfCreds {notifierId} -> atomically (TM.delete notifierId notifiers) $> nc
|
||||
let !q' = q {notifier = Just ntfCreds}
|
||||
atomically $ writeTVar (queueRec sq) $ Just q'
|
||||
when useCache $ do
|
||||
atomically $ TM.insert nId rId notifiers
|
||||
withLog "addQueueNotifier" st $ \s -> logAddNotifier s rId ntfCreds
|
||||
pure nc_
|
||||
withLog "addQueueNotifier" st $ \s -> logAddNotifier s rId ntfCreds
|
||||
pure nc_
|
||||
where
|
||||
PostgresQueueStore {notifiers} = st
|
||||
checkCachedNotifier add
|
||||
| useCache =
|
||||
ExceptT $ withLockMap (notifierLocks st) nId "addQueueNotifier" $
|
||||
ifM (TM.memberIO nId notifiers) (pure $ Left DUPLICATE_) $ runExceptT add
|
||||
| otherwise = add
|
||||
PostgresQueueStore {notifiers, useCache} = st
|
||||
rId = recipientId sq
|
||||
update db =
|
||||
DB.execute
|
||||
@@ -359,13 +393,16 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
deleteQueueNotifier st sq =
|
||||
withQueueRec sq "deleteQueueNotifier" $ \q ->
|
||||
ExceptT $ fmap sequence $ forM (notifier q) $ \nc@NtfCreds {notifierId = nId} ->
|
||||
withLockMap (notifierLocks st) nId "deleteQueueNotifier" $ runExceptT $ do
|
||||
withNotifierLock nId $ runExceptT $ do
|
||||
assertUpdated $ withDB' "deleteQueueNotifier" st update
|
||||
atomically $ TM.delete nId $ notifiers st
|
||||
when (useCache st) $ atomically $ TM.delete nId $ notifiers st
|
||||
atomically $ writeTVar (queueRec sq) $ Just q {notifier = Nothing}
|
||||
withLog "deleteQueueNotifier" st (`logDeleteNotifier` rId)
|
||||
pure nc
|
||||
where
|
||||
withNotifierLock nId
|
||||
| useCache st = withLockMap (notifierLocks st) nId "deleteQueueNotifier"
|
||||
| otherwise = id
|
||||
rId = recipientId sq
|
||||
update db =
|
||||
DB.execute
|
||||
@@ -408,20 +445,20 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
rId = recipientId sq
|
||||
|
||||
-- this method is called from JournalMsgStore deleteQueue that already locks the queue
|
||||
deleteStoreQueue :: PostgresQueueStore q -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
|
||||
deleteStoreQueue :: PostgresQueueStore q -> q -> IO (Either ErrorType QueueRec)
|
||||
deleteStoreQueue st sq = E.uninterruptibleMask_ $ runExceptT $ do
|
||||
q <- ExceptT $ readQueueRecIO qr
|
||||
RoundedSystemTime ts <- liftIO getSystemDate
|
||||
assertUpdated $ withDB' "deleteStoreQueue" st $ \db ->
|
||||
DB.execute db "UPDATE msg_queues SET deleted_at = ? WHERE recipient_id = ? AND deleted_at IS NULL" (ts, rId)
|
||||
atomically $ writeTVar qr Nothing
|
||||
atomically $ TM.delete (senderId q) $ senders st
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> do
|
||||
atomically $ TM.delete notifierId $ notifiers st
|
||||
atomically $ TM.delete notifierId $ notifierLocks st
|
||||
mq_ <- atomically $ swapTVar (msgQueue sq) Nothing
|
||||
when (useCache st) $ do
|
||||
atomically $ TM.delete (senderId q) $ senders st
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> do
|
||||
atomically $ TM.delete notifierId $ notifiers st
|
||||
atomically $ TM.delete notifierId $ notifierLocks st
|
||||
withLog "deleteStoreQueue" st (`logDeleteQueue` rId)
|
||||
pure (q, mq_)
|
||||
pure q
|
||||
where
|
||||
rId = recipientId sq
|
||||
qr = queueRec sq
|
||||
@@ -487,7 +524,7 @@ instance StoreQueueClass q => QueueStoreClass q (PostgresQueueStore q) where
|
||||
getServiceQueueCount :: (PartyI p, ServiceParty p) => PostgresQueueStore q -> SParty p -> ServiceId -> IO (Either ErrorType Int64)
|
||||
getServiceQueueCount st party serviceId =
|
||||
E.uninterruptibleMask_ $ runExceptT $ withDB' "getServiceQueueCount" st $ \db ->
|
||||
fmap (fromMaybe 0) $ maybeFirstRow fromOnly $
|
||||
maybeFirstRow' 0 fromOnly $
|
||||
DB.query db query (Only serviceId)
|
||||
where
|
||||
query = case party of
|
||||
@@ -545,8 +582,28 @@ foldServiceRecs st f =
|
||||
DB.fold_ db "SELECT service_id, service_role, service_cert, service_cert_hash, created_at FROM services" mempty $
|
||||
\ !acc -> fmap (acc <>) . f . rowToServiceRec
|
||||
|
||||
foldQueueRecs :: forall a q. Monoid a => Bool -> Bool -> PostgresQueueStore q -> Maybe Int64 -> ((RecipientId, QueueRec) -> IO a) -> IO a
|
||||
foldQueueRecs tty withData st skipOld_ f = do
|
||||
foldQueueRecs :: Monoid a => Bool -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a
|
||||
foldQueueRecs withData = foldQueueRecs_ foldRecs
|
||||
where
|
||||
foldRecs db acc f'
|
||||
| withData = DB.fold_ db (queueRecQueryWithData <> cond) acc $ \acc' -> f' acc' . rowToQueueRecWithData
|
||||
| otherwise = DB.fold_ db (queueRecQuery <> cond) acc $ \acc' -> f' acc' . rowToQueueRec
|
||||
cond = " WHERE deleted_at IS NULL ORDER BY recipient_id ASC"
|
||||
|
||||
foldRecentQueueRecs :: Monoid a => Int64 -> Bool -> PostgresQueueStore q -> ((RecipientId, QueueRec) -> IO a) -> IO a
|
||||
foldRecentQueueRecs old = foldQueueRecs_ foldRecs
|
||||
where
|
||||
foldRecs db acc f' = DB.fold db (queueRecQuery <> cond) (Only old) acc $ \acc' -> f' acc' . rowToQueueRec
|
||||
cond = " WHERE deleted_at IS NULL AND updated_at > ? ORDER BY recipient_id ASC"
|
||||
|
||||
foldQueueRecs_ ::
|
||||
Monoid a =>
|
||||
(DB.Connection -> (Int, a) -> ((Int, a) -> (RecipientId, QueueRec) -> IO (Int, a)) -> IO (Int, a)) ->
|
||||
Bool ->
|
||||
PostgresQueueStore q ->
|
||||
((RecipientId, QueueRec) -> IO a) ->
|
||||
IO a
|
||||
foldQueueRecs_ foldRecs tty st f = do
|
||||
(n, r) <- withTransaction (dbStore st) $ \db ->
|
||||
foldRecs db (0 :: Int, mempty) $ \(i, acc) qr -> do
|
||||
r <- f qr
|
||||
@@ -557,13 +614,6 @@ foldQueueRecs tty withData st skipOld_ f = do
|
||||
when tty $ putStrLn $ progress n
|
||||
pure r
|
||||
where
|
||||
foldRecs db acc f' = case skipOld_ of
|
||||
Nothing
|
||||
| withData -> DB.fold_ db (queueRecQueryWithData <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRecWithData
|
||||
| otherwise -> DB.fold_ db (queueRecQuery <> " WHERE deleted_at IS NULL") acc $ \acc' -> f' acc' . rowToQueueRec
|
||||
Just old
|
||||
| withData -> DB.fold db (queueRecQueryWithData <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRecWithData
|
||||
| otherwise -> DB.fold db (queueRecQuery <> " WHERE deleted_at IS NULL AND updated_at > ?") (Only old) acc $ \acc' -> f' acc' . rowToQueueRec
|
||||
progress i = "Processed: " <> show i <> " records"
|
||||
|
||||
queueRecQuery :: Query
|
||||
@@ -627,13 +677,14 @@ queueRecToText (rId, QueueRec {recipientKeys, rcvDhSecret, senderId, senderKey,
|
||||
(linkId_, queueData_) = queueDataColumns queueData
|
||||
nullable :: ToField a => Maybe a -> Builder
|
||||
nullable = maybe mempty (renderField . toField)
|
||||
renderField :: Action -> Builder
|
||||
renderField = \case
|
||||
Plain bld -> bld
|
||||
Escape s -> BB.byteString s
|
||||
EscapeByteA s -> BB.string7 "\\x" <> BB.byteStringHex s
|
||||
EscapeIdentifier s -> BB.byteString s -- Not used in COPY data
|
||||
Many as -> mconcat (map renderField as)
|
||||
|
||||
renderField :: Action -> Builder
|
||||
renderField = \case
|
||||
Plain bld -> bld
|
||||
Escape s -> BB.byteString s
|
||||
EscapeByteA s -> BB.string7 "\\x" <> BB.byteStringHex s
|
||||
EscapeIdentifier s -> BB.byteString s -- Not used in COPY data
|
||||
Many as -> mconcat (map renderField as)
|
||||
|
||||
queueDataColumns :: Maybe (LinkId, QueueLinkData) -> (Maybe LinkId, Maybe QueueLinkData)
|
||||
queueDataColumns = \case
|
||||
|
||||
@@ -14,7 +14,8 @@ serverSchemaMigrations =
|
||||
[ ("20250207_initial", m20250207_initial, Nothing),
|
||||
("20250319_updated_index", m20250319_updated_index, Just down_m20250319_updated_index),
|
||||
("20250320_short_links", m20250320_short_links, Just down_m20250320_short_links),
|
||||
("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs)
|
||||
("20250514_service_certs", m20250514_service_certs, Just down_m20250514_service_certs),
|
||||
("20250903_store_messages", m20250903_store_messages, Just down_m20250903_store_messages)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
@@ -159,3 +160,299 @@ DROP INDEX idx_services_service_role;
|
||||
|
||||
DROP TABLE services;
|
||||
|]
|
||||
|
||||
m20250903_store_messages :: Text
|
||||
m20250903_store_messages =
|
||||
T.pack
|
||||
[r|
|
||||
CREATE TABLE messages(
|
||||
message_id BIGINT NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
|
||||
recipient_id BYTEA NOT NULL REFERENCES msg_queues ON DELETE CASCADE ON UPDATE RESTRICT,
|
||||
msg_id BYTEA NOT NULL,
|
||||
msg_ts BIGINT NOT NULL,
|
||||
msg_quota BOOLEAN NOT NULL,
|
||||
msg_ntf_flag BOOLEAN NOT NULL,
|
||||
msg_body BYTEA NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE msg_queues
|
||||
ADD COLUMN msg_can_write BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
ADD COLUMN msg_queue_expire BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN msg_queue_size BIGINT NOT NULL DEFAULT 0;
|
||||
|
||||
CREATE INDEX idx_messages_recipient_id_message_id ON messages (recipient_id, message_id);
|
||||
CREATE INDEX idx_messages_recipient_id_msg_ts on messages(recipient_id, msg_ts);
|
||||
CREATE INDEX idx_messages_recipient_id_msg_quota on messages(recipient_id, msg_quota);
|
||||
|
||||
DROP INDEX idx_msg_queues_updated_at;
|
||||
CREATE INDEX idx_msg_queues_updated_at_recipient_id ON msg_queues (deleted_at, updated_at, msg_queue_expire, recipient_id);
|
||||
|
||||
CREATE FUNCTION write_message(
|
||||
p_recipient_id BYTEA,
|
||||
p_msg_id BYTEA,
|
||||
p_msg_ts BIGINT,
|
||||
p_msg_quota BOOLEAN,
|
||||
p_msg_ntf_flag BOOLEAN,
|
||||
p_msg_body BYTEA,
|
||||
p_quota INT
|
||||
)
|
||||
RETURNS TABLE (quota_written BOOLEAN, was_empty BOOLEAN)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
q_can_write BOOLEAN;
|
||||
q_size BIGINT;
|
||||
BEGIN
|
||||
SELECT msg_can_write, msg_queue_size INTO q_can_write, q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF q_can_write OR q_size = 0 THEN
|
||||
quota_written := p_msg_quota OR q_size >= p_quota;
|
||||
was_empty := q_size = 0;
|
||||
|
||||
INSERT INTO messages(recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body)
|
||||
VALUES (p_recipient_id, p_msg_id, p_msg_ts, quota_written, p_msg_ntf_flag AND NOT quota_written, CASE WHEN quota_written THEN '' :: BYTEA ELSE p_msg_body END);
|
||||
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = NOT quota_written,
|
||||
msg_queue_expire = TRUE,
|
||||
msg_queue_size = msg_queue_size + 1
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
|
||||
RETURN QUERY VALUES (quota_written, was_empty);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION try_del_msg(p_recipient_id BYTEA, p_msg_id BYTEA)
|
||||
RETURNS TABLE (r_msg_id BYTEA, r_msg_ts BIGINT, r_msg_quota BOOLEAN, r_msg_ntf_flag BOOLEAN, r_msg_body BYTEA)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
msg RECORD;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
IF q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF msg.msg_id = p_msg_id THEN
|
||||
DELETE FROM messages WHERE message_id = msg.message_id;
|
||||
IF FOUND THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= 1,
|
||||
msg_queue_expire = msg_queue_size > 1,
|
||||
msg_queue_size = GREATEST(msg_queue_size - 1, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION try_del_peek_msg(p_recipient_id BYTEA, p_msg_id BYTEA)
|
||||
RETURNS TABLE (r_msg_id BYTEA, r_msg_ts BIGINT, r_msg_quota BOOLEAN, r_msg_ntf_flag BOOLEAN, r_msg_body BYTEA)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
msg RECORD;
|
||||
msg_deleted BOOLEAN;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
IF q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF msg.msg_id = p_msg_id THEN
|
||||
DELETE FROM messages WHERE message_id = msg.message_id;
|
||||
|
||||
msg_deleted := FOUND;
|
||||
IF msg_deleted THEN
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
|
||||
SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF FOUND THEN
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
IF msg_deleted THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= 1,
|
||||
msg_queue_expire = msg_queue_size > 1,
|
||||
msg_queue_size = GREATEST(msg_queue_size - 1, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
ELSIF msg_deleted OR q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
ELSE
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION delete_expired_msgs(p_recipient_id BYTEA, p_old_ts BIGINT) RETURNS BIGINT
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
keep_min_id BIGINT;
|
||||
del_count BIGINT;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE SKIP LOCKED;
|
||||
|
||||
IF NOT FOUND OR q_size = 0 THEN
|
||||
RETURN 0;
|
||||
END IF;
|
||||
|
||||
SELECT MIN(message_id) INTO keep_min_id
|
||||
FROM messages WHERE recipient_id = p_recipient_id AND msg_ts >= p_old_ts AND msg_quota = FALSE;
|
||||
|
||||
IF keep_min_id IS NULL THEN
|
||||
DELETE FROM messages WHERE recipient_id = p_recipient_id AND msg_quota = FALSE;
|
||||
ELSE
|
||||
DELETE FROM messages WHERE recipient_id = p_recipient_id AND message_id < keep_min_id AND msg_quota = FALSE;
|
||||
END IF;
|
||||
|
||||
GET DIAGNOSTICS del_count = ROW_COUNT;
|
||||
IF del_count > 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= del_count,
|
||||
msg_queue_expire = msg_queue_size > del_count AND keep_min_id IS NOT NULL,
|
||||
msg_queue_size = GREATEST(msg_queue_size - del_count, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN del_count;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE PROCEDURE expire_old_messages(
|
||||
p_old_queue BIGINT,
|
||||
p_old_ts BIGINT,
|
||||
batch_size INT,
|
||||
OUT r_expired_msgs_count BIGINT,
|
||||
OUT r_stored_msgs_count BIGINT,
|
||||
OUT r_stored_queues BIGINT
|
||||
)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
rids BYTEA[];
|
||||
rid BYTEA;
|
||||
last_rid BYTEA := '\x';
|
||||
del_count BIGINT;
|
||||
total_deleted BIGINT := 0;
|
||||
BEGIN
|
||||
LOOP
|
||||
SELECT array_agg(recipient_id)
|
||||
INTO rids
|
||||
FROM (
|
||||
SELECT recipient_id
|
||||
FROM msg_queues
|
||||
WHERE deleted_at IS NULL
|
||||
AND updated_at > p_old_queue
|
||||
AND msg_queue_expire = TRUE
|
||||
AND recipient_id > last_rid
|
||||
ORDER BY recipient_id ASC
|
||||
LIMIT batch_size
|
||||
) qs;
|
||||
|
||||
EXIT WHEN rids IS NULL OR cardinality(rids) = 0;
|
||||
|
||||
FOREACH rid IN ARRAY rids
|
||||
LOOP
|
||||
BEGIN
|
||||
del_count := delete_expired_msgs(rid, p_old_ts);
|
||||
total_deleted := total_deleted + del_count;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'STORE, expire_old_messages, error expiring queue %: %', encode(rid, 'base64'), SQLERRM;
|
||||
CONTINUE;
|
||||
END;
|
||||
COMMIT;
|
||||
END LOOP;
|
||||
last_rid := rids[cardinality(rids)];
|
||||
END LOOP;
|
||||
|
||||
r_expired_msgs_count := total_deleted;
|
||||
r_stored_msgs_count := (SELECT COUNT(1) FROM messages);
|
||||
r_stored_queues := (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL);
|
||||
END;
|
||||
$$;
|
||||
|]
|
||||
|
||||
down_m20250903_store_messages :: Text
|
||||
down_m20250903_store_messages =
|
||||
T.pack
|
||||
[r|
|
||||
DROP FUNCTION write_message;
|
||||
DROP FUNCTION try_del_msg;
|
||||
DROP FUNCTION try_del_peek_msg;
|
||||
DROP FUNCTION delete_expired_msgs;
|
||||
DROP PROCEDURE expire_old_messages;
|
||||
|
||||
DROP INDEX idx_msg_queues_updated_at_recipient_id;
|
||||
CREATE INDEX idx_msg_queues_updated_at ON msg_queues (deleted_at, updated_at);
|
||||
|
||||
DROP INDEX idx_messages_recipient_id_message_id;
|
||||
DROP INDEX idx_messages_recipient_id_msg_ts;
|
||||
DROP INDEX idx_messages_recipient_id_msg_quota;
|
||||
|
||||
ALTER TABLE msg_queues
|
||||
DROP COLUMN msg_can_write,
|
||||
DROP COLUMN msg_queue_expire,
|
||||
DROP COLUMN msg_queue_size;
|
||||
|
||||
DROP TABLE messages;
|
||||
|]
|
||||
|
||||
@@ -15,9 +15,273 @@ SET row_security = off;
|
||||
CREATE SCHEMA smp_server;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.delete_expired_msgs(p_recipient_id bytea, p_old_ts bigint) RETURNS bigint
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
keep_min_id BIGINT;
|
||||
del_count BIGINT;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE SKIP LOCKED;
|
||||
|
||||
IF NOT FOUND OR q_size = 0 THEN
|
||||
RETURN 0;
|
||||
END IF;
|
||||
|
||||
SELECT MIN(message_id) INTO keep_min_id
|
||||
FROM messages WHERE recipient_id = p_recipient_id AND msg_ts >= p_old_ts AND msg_quota = FALSE;
|
||||
|
||||
IF keep_min_id IS NULL THEN
|
||||
DELETE FROM messages WHERE recipient_id = p_recipient_id AND msg_quota = FALSE;
|
||||
ELSE
|
||||
DELETE FROM messages WHERE recipient_id = p_recipient_id AND message_id < keep_min_id AND msg_quota = FALSE;
|
||||
END IF;
|
||||
|
||||
GET DIAGNOSTICS del_count = ROW_COUNT;
|
||||
IF del_count > 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= del_count,
|
||||
msg_queue_expire = msg_queue_size > del_count AND keep_min_id IS NOT NULL,
|
||||
msg_queue_size = GREATEST(msg_queue_size - del_count, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN del_count;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE PROCEDURE smp_server.expire_old_messages(IN p_old_queue bigint, IN p_old_ts bigint, IN batch_size integer, OUT r_expired_msgs_count bigint, OUT r_stored_msgs_count bigint, OUT r_stored_queues bigint)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
rids BYTEA[];
|
||||
rid BYTEA;
|
||||
last_rid BYTEA := '\x';
|
||||
del_count BIGINT;
|
||||
total_deleted BIGINT := 0;
|
||||
BEGIN
|
||||
LOOP
|
||||
SELECT array_agg(recipient_id)
|
||||
INTO rids
|
||||
FROM (
|
||||
SELECT recipient_id
|
||||
FROM msg_queues
|
||||
WHERE deleted_at IS NULL
|
||||
AND updated_at > p_old_queue
|
||||
AND msg_queue_expire = TRUE
|
||||
AND recipient_id > last_rid
|
||||
ORDER BY recipient_id ASC
|
||||
LIMIT batch_size
|
||||
) qs;
|
||||
|
||||
EXIT WHEN rids IS NULL OR cardinality(rids) = 0;
|
||||
|
||||
FOREACH rid IN ARRAY rids
|
||||
LOOP
|
||||
BEGIN
|
||||
del_count := delete_expired_msgs(rid, p_old_ts);
|
||||
total_deleted := total_deleted + del_count;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE WARNING 'STORE, expire_old_messages, error expiring queue %: %', encode(rid, 'base64'), SQLERRM;
|
||||
CONTINUE;
|
||||
END;
|
||||
COMMIT;
|
||||
END LOOP;
|
||||
last_rid := rids[cardinality(rids)];
|
||||
END LOOP;
|
||||
|
||||
r_expired_msgs_count := total_deleted;
|
||||
r_stored_msgs_count := (SELECT COUNT(1) FROM messages);
|
||||
r_stored_queues := (SELECT COUNT(1) FROM msg_queues WHERE deleted_at IS NULL);
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.try_del_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
msg RECORD;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
IF q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF msg.msg_id = p_msg_id THEN
|
||||
DELETE FROM messages WHERE message_id = msg.message_id;
|
||||
IF FOUND THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= 1,
|
||||
msg_queue_expire = msg_queue_size > 1,
|
||||
msg_queue_size = GREATEST(msg_queue_size - 1, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.try_del_peek_msg(p_recipient_id bytea, p_msg_id bytea) RETURNS TABLE(r_msg_id bytea, r_msg_ts bigint, r_msg_quota boolean, r_msg_ntf_flag boolean, r_msg_body bytea)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
q_size BIGINT;
|
||||
msg RECORD;
|
||||
msg_deleted BOOLEAN;
|
||||
BEGIN
|
||||
SELECT msg_queue_size INTO q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
SELECT message_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
IF q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF msg.msg_id = p_msg_id THEN
|
||||
DELETE FROM messages WHERE message_id = msg.message_id;
|
||||
|
||||
msg_deleted := FOUND;
|
||||
IF msg_deleted THEN
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
|
||||
SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
|
||||
INTO msg
|
||||
FROM messages
|
||||
WHERE recipient_id = p_recipient_id
|
||||
ORDER BY message_id ASC LIMIT 1;
|
||||
|
||||
IF FOUND THEN
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
IF msg_deleted THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = msg_can_write OR msg_queue_size <= 1,
|
||||
msg_queue_expire = msg_queue_size > 1,
|
||||
msg_queue_size = GREATEST(msg_queue_size - 1, 0)
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
ELSIF msg_deleted OR q_size != 0 THEN
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = TRUE,
|
||||
msg_queue_expire = FALSE,
|
||||
msg_queue_size = 0
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
END IF;
|
||||
ELSE
|
||||
RETURN QUERY VALUES (msg.msg_id, msg.msg_ts, msg.msg_quota, msg.msg_ntf_flag, msg.msg_body);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE FUNCTION smp_server.write_message(p_recipient_id bytea, p_msg_id bytea, p_msg_ts bigint, p_msg_quota boolean, p_msg_ntf_flag boolean, p_msg_body bytea, p_quota integer) RETURNS TABLE(quota_written boolean, was_empty boolean)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
q_can_write BOOLEAN;
|
||||
q_size BIGINT;
|
||||
BEGIN
|
||||
SELECT msg_can_write, msg_queue_size INTO q_can_write, q_size
|
||||
FROM msg_queues
|
||||
WHERE recipient_id = p_recipient_id AND deleted_at IS NULL
|
||||
FOR UPDATE;
|
||||
|
||||
IF q_can_write OR q_size = 0 THEN
|
||||
quota_written := p_msg_quota OR q_size >= p_quota;
|
||||
was_empty := q_size = 0;
|
||||
|
||||
INSERT INTO messages(recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body)
|
||||
VALUES (p_recipient_id, p_msg_id, p_msg_ts, quota_written, p_msg_ntf_flag AND NOT quota_written, CASE WHEN quota_written THEN '' :: BYTEA ELSE p_msg_body END);
|
||||
|
||||
UPDATE msg_queues
|
||||
SET msg_can_write = NOT quota_written,
|
||||
msg_queue_expire = TRUE,
|
||||
msg_queue_size = msg_queue_size + 1
|
||||
WHERE recipient_id = p_recipient_id;
|
||||
|
||||
RETURN QUERY VALUES (quota_written, was_empty);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
SET default_table_access_method = heap;
|
||||
|
||||
|
||||
CREATE TABLE smp_server.messages (
|
||||
message_id bigint NOT NULL,
|
||||
recipient_id bytea NOT NULL,
|
||||
msg_id bytea NOT NULL,
|
||||
msg_ts bigint NOT NULL,
|
||||
msg_quota boolean NOT NULL,
|
||||
msg_ntf_flag boolean NOT NULL,
|
||||
msg_body bytea NOT NULL
|
||||
);
|
||||
|
||||
|
||||
|
||||
ALTER TABLE smp_server.messages ALTER COLUMN message_id ADD GENERATED ALWAYS AS IDENTITY (
|
||||
SEQUENCE NAME smp_server.messages_message_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
|
||||
|
||||
|
||||
CREATE TABLE smp_server.migrations (
|
||||
name text NOT NULL,
|
||||
ts timestamp without time zone NOT NULL,
|
||||
@@ -43,7 +307,10 @@ CREATE TABLE smp_server.msg_queues (
|
||||
fixed_data bytea,
|
||||
user_data bytea,
|
||||
rcv_service_id bytea,
|
||||
ntf_service_id bytea
|
||||
ntf_service_id bytea,
|
||||
msg_can_write boolean DEFAULT true NOT NULL,
|
||||
msg_queue_expire boolean DEFAULT false NOT NULL,
|
||||
msg_queue_size bigint DEFAULT 0 NOT NULL
|
||||
);
|
||||
|
||||
|
||||
@@ -58,6 +325,11 @@ CREATE TABLE smp_server.services (
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY smp_server.messages
|
||||
ADD CONSTRAINT messages_pkey PRIMARY KEY (message_id);
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY smp_server.migrations
|
||||
ADD CONSTRAINT migrations_pkey PRIMARY KEY (name);
|
||||
|
||||
@@ -78,6 +350,18 @@ ALTER TABLE ONLY smp_server.services
|
||||
|
||||
|
||||
|
||||
CREATE INDEX idx_messages_recipient_id_message_id ON smp_server.messages USING btree (recipient_id, message_id);
|
||||
|
||||
|
||||
|
||||
CREATE INDEX idx_messages_recipient_id_msg_quota ON smp_server.messages USING btree (recipient_id, msg_quota);
|
||||
|
||||
|
||||
|
||||
CREATE INDEX idx_messages_recipient_id_msg_ts ON smp_server.messages USING btree (recipient_id, msg_ts);
|
||||
|
||||
|
||||
|
||||
CREATE UNIQUE INDEX idx_msg_queues_link_id ON smp_server.msg_queues USING btree (link_id);
|
||||
|
||||
|
||||
@@ -98,7 +382,7 @@ CREATE UNIQUE INDEX idx_msg_queues_sender_id ON smp_server.msg_queues USING btre
|
||||
|
||||
|
||||
|
||||
CREATE INDEX idx_msg_queues_updated_at ON smp_server.msg_queues USING btree (deleted_at, updated_at);
|
||||
CREATE INDEX idx_msg_queues_updated_at_recipient_id ON smp_server.msg_queues USING btree (deleted_at, updated_at, msg_queue_expire, recipient_id);
|
||||
|
||||
|
||||
|
||||
@@ -106,6 +390,11 @@ CREATE INDEX idx_services_service_role ON smp_server.services USING btree (servi
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY smp_server.messages
|
||||
ADD CONSTRAINT messages_recipient_id_fkey FOREIGN KEY (recipient_id) REFERENCES smp_server.msg_queues(recipient_id) ON UPDATE RESTRICT ON DELETE CASCADE;
|
||||
|
||||
|
||||
|
||||
ALTER TABLE ONLY smp_server.msg_queues
|
||||
ADD CONSTRAINT msg_queues_ntf_service_id_fkey FOREIGN KEY (ntf_service_id) REFERENCES smp_server.services(service_id) ON UPDATE RESTRICT ON DELETE SET NULL;
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
serviceQueuesCount serviceSel = foldM (\n s -> (n +) . S.size <$> readTVarIO (serviceSel s)) 0
|
||||
|
||||
addQueue_ :: STMQueueStore q -> (RecipientId -> QueueRec -> IO q) -> RecipientId -> QueueRec -> IO (Either ErrorType q)
|
||||
addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData} = do
|
||||
addQueue_ st mkQ rId qr@QueueRec {senderId = sId, notifier, queueData, rcvServiceId} = do
|
||||
sq <- mkQ rId qr
|
||||
add sq $>> withLog "addStoreQueue" st (\s -> logCreateQueue s rId qr) $> Right sq
|
||||
where
|
||||
@@ -122,8 +122,11 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
add q = atomically $ ifM hasId (pure $ Left DUPLICATE_) $ Right () <$ do
|
||||
TM.insert rId q queues
|
||||
TM.insert sId rId senders
|
||||
forM_ notifier $ \NtfCreds {notifierId} -> TM.insert notifierId rId notifiers
|
||||
forM_ notifier $ \NtfCreds {notifierId = nId, ntfServiceId} -> do
|
||||
TM.insert nId rId notifiers
|
||||
mapM_ (addServiceQueue st serviceNtfQueues nId) ntfServiceId
|
||||
forM_ queueData $ \(lnkId, _) -> TM.insert lnkId rId links
|
||||
mapM_ (addServiceQueue st serviceRcvQueues rId) rcvServiceId
|
||||
hasId = anyM [TM.member rId queues, TM.member sId senders, hasNotifier, hasLink]
|
||||
hasNotifier = maybe (pure False) (\NtfCreds {notifierId} -> TM.member notifierId notifiers) notifier
|
||||
hasLink = maybe (pure False) (\(lnkId, _) -> TM.member lnkId links) queueData
|
||||
@@ -225,7 +228,7 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
deleteQueueNotifier :: STMQueueStore q -> q -> IO (Either ErrorType (Maybe NtfCreds))
|
||||
deleteQueueNotifier st sq =
|
||||
withQueueRec qr delete
|
||||
$>>= \nc_ -> nc_ <$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq)
|
||||
$>>= (<$$ withLog "deleteQueueNotifier" st (`logDeleteNotifier` recipientId sq))
|
||||
where
|
||||
qr = queueRec sq
|
||||
delete q = forM (notifier q) $ \nc -> do
|
||||
@@ -261,11 +264,10 @@ instance StoreQueueClass q => QueueStoreClass q (STMQueueStore q) where
|
||||
| changed = q <$$ withLog "updateQueueTime" st (\sl -> logUpdateQueueTime sl (recipientId sq) t)
|
||||
| otherwise = pure $ Right q
|
||||
|
||||
deleteStoreQueue :: STMQueueStore q -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
|
||||
deleteStoreQueue :: STMQueueStore q -> q -> IO (Either ErrorType QueueRec)
|
||||
deleteStoreQueue st sq =
|
||||
withQueueRec qr delete
|
||||
$>>= \q -> withLog "deleteStoreQueue" st (`logDeleteQueue` rId)
|
||||
>>= mapM (\_ -> (q,) <$> atomically (swapTVar (msgQueue sq) Nothing))
|
||||
$>>= (<$$ withLog "deleteStoreQueue" st (`logDeleteQueue` rId))
|
||||
where
|
||||
rId = recipientId sq
|
||||
qr = queueRec sq
|
||||
|
||||
@@ -17,10 +17,8 @@ import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
|
||||
class StoreQueueClass q where
|
||||
type MsgQueue q = mq | mq -> q
|
||||
recipientId :: q -> RecipientId
|
||||
queueRec :: q -> TVar (Maybe QueueRec)
|
||||
msgQueue :: q -> TVar (Maybe (MsgQueue q))
|
||||
withQueueLock :: q -> Text -> IO a -> IO a
|
||||
|
||||
class StoreQueueClass q => QueueStoreClass q s where
|
||||
@@ -44,7 +42,7 @@ class StoreQueueClass q => QueueStoreClass q s where
|
||||
blockQueue :: s -> q -> BlockingInfo -> IO (Either ErrorType ())
|
||||
unblockQueue :: s -> q -> IO (Either ErrorType ())
|
||||
updateQueueTime :: s -> q -> RoundedSystemTime -> IO (Either ErrorType QueueRec)
|
||||
deleteStoreQueue :: s -> q -> IO (Either ErrorType (QueueRec, Maybe (MsgQueue q)))
|
||||
deleteStoreQueue :: s -> q -> IO (Either ErrorType QueueRec)
|
||||
getCreateService :: s -> ServiceRec -> IO (Either ErrorType ServiceId)
|
||||
setQueueService :: (PartyI p, ServiceParty p) => s -> q -> SParty p -> Maybe ServiceId -> IO (Either ErrorType ())
|
||||
getQueueNtfServices :: s -> [(NotifierId, a)] -> IO (Either ErrorType ([(Maybe ServiceId, [(NotifierId, a)])], [(NotifierId, a)]))
|
||||
|
||||
@@ -30,12 +30,14 @@ where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Logger.Simple (logError)
|
||||
import Control.Monad
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAsciiLower, isDigit, isHexDigit)
|
||||
import Data.Default (def)
|
||||
import Data.Functor (($>))
|
||||
import Data.IORef
|
||||
import Data.IP
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
@@ -58,7 +60,7 @@ import Simplex.Messaging.Parsers (parseAll, parseString)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import Simplex.Messaging.Transport.Shared
|
||||
import Simplex.Messaging.Util (bshow, catchAll, tshow, (<$?>))
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, tshow, (<$?>))
|
||||
import System.IO.Error
|
||||
import Text.Read (readMaybe)
|
||||
import UnliftIO.Exception (IOException)
|
||||
@@ -156,6 +158,11 @@ clientTransportConfig TransportClientConfig {logTLSErrors} =
|
||||
runTransportClient :: Transport c => TransportClientConfig -> Maybe SocksCredentials -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c 'TClient -> IO a) -> IO a
|
||||
runTransportClient = runTLSTransportClient defaultSupportedParams Nothing
|
||||
|
||||
data ConnectionHandle c
|
||||
= CHSocket Socket
|
||||
| CHContext T.Context
|
||||
| CHTransport (c 'TClient)
|
||||
|
||||
runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe SocksCredentials -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c 'TClient -> IO a) -> IO a
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, clientALPN, useSNI} socksCreds host port keyHash client = do
|
||||
serverCert <- newEmptyTMVarIO
|
||||
@@ -165,17 +172,22 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy socksCreds (hostAddr host)
|
||||
_ -> connectTCPClient hostName
|
||||
c <- do
|
||||
sock <- connectTCP port
|
||||
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e)
|
||||
h <- newIORef Nothing
|
||||
let set hc = (>>= \c -> writeIORef h (Just $ hc c) $> c)
|
||||
E.bracket (set CHSocket $ connectTCP port) (\_ -> closeConn h) $ \sock -> do
|
||||
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive " <> tshow e)
|
||||
let tCfg = clientTransportConfig cfg
|
||||
-- No TLS timeout to avoid failing connections via SOCKS
|
||||
tls <- connectTLS (Just hostName) tCfg clientParams sock
|
||||
chain <- takePeerCertChain serverCert `E.onException` closeTLS tls
|
||||
tls <- set CHContext $ connectTLS (Just hostName) tCfg clientParams sock
|
||||
chain <- takePeerCertChain serverCert
|
||||
sent <- readIORef clientCredsSent
|
||||
getTransportConnection tCfg sent chain tls
|
||||
client c `E.finally` closeConnection c
|
||||
client =<< set CHTransport (getTransportConnection tCfg sent chain tls)
|
||||
where
|
||||
closeConn = readIORef >=> mapM_ (\c -> E.uninterruptibleMask_ $ closeConn_ c `catchAll_` pure ())
|
||||
closeConn_ = \case
|
||||
CHSocket sock -> close sock
|
||||
CHContext tls -> closeTLS tls
|
||||
CHTransport c -> closeConnection c
|
||||
hostAddr = \case
|
||||
THIPv4 addr -> SocksAddrIPV4 $ tupleToHostAddress addr
|
||||
THIPv6 addr -> SocksAddrIPV6 addr
|
||||
@@ -199,10 +211,11 @@ connectTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
|
||||
E.try (open addr) >>= either (`tryOpen` as) pure
|
||||
|
||||
open :: AddrInfo -> IO Socket
|
||||
open addr = do
|
||||
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
|
||||
connect sock $ addrAddress addr
|
||||
pure sock
|
||||
open addr =
|
||||
E.bracketOnError
|
||||
(socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
|
||||
close
|
||||
(\sock -> connect sock (addrAddress addr) $> sock)
|
||||
|
||||
defaultSMPPort :: PortNumber
|
||||
defaultSMPPort = 5223
|
||||
|
||||
@@ -27,6 +27,7 @@ import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (NetworkError (..), toNetworkError)
|
||||
import Simplex.Messaging.Transport (ALPN, STransportPeer (..), SessionId, TLS (tlsALPN, tlsPeerCert, tlsUniq), TransportPeer (..), TransportPeerI (..), getServerVerifyKey)
|
||||
import Simplex.Messaging.Transport.Client (TransportClientConfig (..), TransportHost (..), defaultTcpConnectTimeout, runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
@@ -89,7 +90,7 @@ defaultHTTP2ClientConfig =
|
||||
suportedTLSParams = http2TLSParams
|
||||
}
|
||||
|
||||
data HTTP2ClientError = HCResponseTimeout | HCNetworkError | HCIOError IOException
|
||||
data HTTP2ClientError = HCResponseTimeout | HCNetworkError NetworkError | HCIOError IOException
|
||||
deriving (Show)
|
||||
|
||||
getHTTP2Client :: HostName -> ServiceName -> Maybe XS.CertificateStore -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
@@ -121,12 +122,15 @@ getVerifiedHTTP2ClientWith config host port disconnected setup =
|
||||
runClient :: HClient -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
runClient c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
action <- async $ setup (client c cVar) `E.finally` atomically (putTMVar cVar $ Left HCNetworkError)
|
||||
action <-
|
||||
async $ setup (client c cVar) `E.catch` \e -> do
|
||||
atomically $ putTMVar cVar $ Left $ HCNetworkError $ toNetworkError e
|
||||
E.throwIO e
|
||||
c_ <- connTimeout config `timeout` atomically (takeTMVar cVar)
|
||||
case c_ of
|
||||
Just (Right c') -> pure $ Right c' {action = Just action}
|
||||
Just (Left e) -> pure $ Left e
|
||||
Nothing -> cancel action $> Left HCNetworkError
|
||||
Nothing -> cancel action $> Left (HCNetworkError NETimeoutError)
|
||||
|
||||
client :: HClient -> TMVar (Either HTTP2ClientError HTTP2Client) -> TLS p -> H.Client HTTP2Response
|
||||
client c cVar tls sendReq = do
|
||||
@@ -176,7 +180,7 @@ sendRequestDirect HTTP2Client {client_ = HClient {config, disconnected}, sendReq
|
||||
reqTimeout `timeout` try (sendReq req process) >>= \case
|
||||
Just (Right r) -> pure $ Right r
|
||||
Just (Left e) -> disconnected $> Left (HCIOError e)
|
||||
Nothing -> pure $ Left HCNetworkError
|
||||
Nothing -> pure $ Left HCResponseTimeout
|
||||
where
|
||||
process r = do
|
||||
respBody <- getHTTP2Body r $ bodyHeadSize config
|
||||
|
||||
@@ -171,28 +171,30 @@ catchAll_ :: IO a -> IO a -> IO a
|
||||
catchAll_ a = catchAll a . const
|
||||
{-# INLINE catchAll_ #-}
|
||||
|
||||
tryAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m (Either e a)
|
||||
tryAllErrors err action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . err)
|
||||
class Show e => AnyError e where fromSomeException :: E.SomeException -> e
|
||||
|
||||
tryAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m (Either e a)
|
||||
tryAllErrors action = ExceptT $ Right <$> runExceptT action `UE.catch` (pure . Left . fromSomeException)
|
||||
{-# INLINE tryAllErrors #-}
|
||||
|
||||
tryAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> m (Either e a)
|
||||
tryAllErrors' err action = runExceptT action `UE.catch` (pure . Left . err)
|
||||
tryAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> m (Either e a)
|
||||
tryAllErrors' action = runExceptT action `UE.catch` (pure . Left . fromSomeException)
|
||||
{-# INLINE tryAllErrors' #-}
|
||||
|
||||
catchAllErrors :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
|
||||
catchAllErrors err action handler = tryAllErrors err action >>= either handler pure
|
||||
catchAllErrors :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> ExceptT e m a) -> ExceptT e m a
|
||||
catchAllErrors action handler = tryAllErrors action >>= either handler pure
|
||||
{-# INLINE catchAllErrors #-}
|
||||
|
||||
catchAllErrors' :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> (e -> m a) -> m a
|
||||
catchAllErrors' err action handler = tryAllErrors' err action >>= either handler pure
|
||||
catchAllErrors' :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> (e -> m a) -> m a
|
||||
catchAllErrors' action handler = tryAllErrors' action >>= either handler pure
|
||||
{-# INLINE catchAllErrors' #-}
|
||||
|
||||
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (E.SomeException -> e) -> ExceptT e m a
|
||||
catchThrow action err = catchAllErrors err action throwE
|
||||
catchThrow :: MonadUnliftIO m => ExceptT e m a -> (SomeException -> e) -> ExceptT e m a
|
||||
action `catchThrow` err = ExceptT $ runExceptT action `UE.catch` (pure . Left . err)
|
||||
{-# INLINE catchThrow #-}
|
||||
|
||||
allFinally :: MonadUnliftIO m => (E.SomeException -> e) -> ExceptT e m a -> ExceptT e m b -> ExceptT e m a
|
||||
allFinally err action final = tryAllErrors err action >>= \r -> final >> except r
|
||||
allFinally :: (AnyError e, MonadUnliftIO m) => ExceptT e m a -> ExceptT e m b -> ExceptT e m a
|
||||
allFinally action final = tryAllErrors action >>= \r -> final >> except r
|
||||
{-# INLINE allFinally #-}
|
||||
|
||||
eitherToMaybe :: Either a b -> Maybe b
|
||||
@@ -209,17 +211,25 @@ firstRow f e a = second f . listToEither e <$> a
|
||||
maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b)
|
||||
maybeFirstRow f q = fmap f . listToMaybe <$> q
|
||||
|
||||
maybeFirstRow' :: Functor f => b -> (a -> b) -> f [a] -> f b
|
||||
maybeFirstRow' def f q = maybe def f . listToMaybe <$> q
|
||||
|
||||
firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b)
|
||||
firstRow' f e a = (f <=< listToEither e) <$> a
|
||||
|
||||
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
|
||||
groupOn = groupBy . eqOn
|
||||
where
|
||||
-- it is equivalent to groupBy ((==) `on` f),
|
||||
-- but it redefines `on` to avoid duplicate computation for most values.
|
||||
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
|
||||
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
|
||||
eqOn f x = let fx = f x in \y -> fx == f y
|
||||
|
||||
groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a]
|
||||
groupOn' = L.groupBy . eqOn
|
||||
|
||||
-- it is equivalent to groupBy ((==) `on` f),
|
||||
-- but it redefines `on` to avoid duplicate computation for most values.
|
||||
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
|
||||
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
|
||||
eqOn :: Eq k => (a -> k) -> a -> a -> Bool
|
||||
eqOn f x = let fx = f x in \y -> fx == f y
|
||||
{-# INLINE eqOn #-}
|
||||
|
||||
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
|
||||
groupAllOn f = groupOn f . sortOn f
|
||||
|
||||
@@ -306,14 +306,8 @@ connectRCCtrl_ drg pairing'@RCCtrlPairing {caKey, caCert} inv@RCInvitation {ca,
|
||||
atomically $ takeTMVar endSession
|
||||
logDebug "Session ended"
|
||||
|
||||
catchRCError :: ExceptT RCErrorType IO a -> (RCErrorType -> ExceptT RCErrorType IO a) -> ExceptT RCErrorType IO a
|
||||
catchRCError = catchAllErrors $ \e -> case fromException e of
|
||||
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
|
||||
_ -> RCEException $ show e
|
||||
{-# INLINE catchRCError #-}
|
||||
|
||||
putRCError :: ExceptT RCErrorType IO a -> TMVar (Either RCErrorType b) -> ExceptT RCErrorType IO a
|
||||
a `putRCError` r = a `catchRCError` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
|
||||
a `putRCError` r = a `catchAllErrors` \e -> atomically (tryPutTMVar r $ Left e) >> throwE e
|
||||
|
||||
sendRCPacket :: Encoding a => TLS p -> a -> ExceptT RCErrorType IO ()
|
||||
sendRCPacket tls pkt = do
|
||||
@@ -395,7 +389,7 @@ discoverRCCtrl subscribers pairings =
|
||||
pure r
|
||||
where
|
||||
loop :: ExceptT RCErrorType IO a -> ExceptT RCErrorType IO a
|
||||
loop action = action `catchRCError` \e -> logError (tshow e) >> loop action
|
||||
loop action = action `catchAllErrors` \e -> logError (tshow e) >> loop action
|
||||
|
||||
findRCCtrlPairing :: NonEmpty RCCtrlPairing -> RCEncInvitation -> ExceptT RCErrorType IO (RCCtrlPairing, RCVerifiedInvitation)
|
||||
findRCCtrlPairing pairings RCEncInvitation {dhPubKey, nonce, encInvitation} = do
|
||||
|
||||
@@ -19,6 +19,7 @@ import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Network.TLS as TLS
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
|
||||
import Simplex.Messaging.Encoding
|
||||
@@ -26,7 +27,7 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport (TLS, TSbChainKeys, TransportPeer (..))
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8)
|
||||
import Simplex.Messaging.Util (AnyError (..), safeDecodeUtf8)
|
||||
import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange)
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import UnliftIO
|
||||
@@ -50,6 +51,12 @@ data RCErrorType
|
||||
| RCESyntax {syntaxErr :: String}
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
instance AnyError RCErrorType where
|
||||
fromSomeException e = case fromException e of
|
||||
Just (TLS.Terminated _ _ (TLS.Error_Protocol _ TLS.UnknownCa)) -> RCEIdentity
|
||||
_ -> RCEException $ show e
|
||||
{-# INLINE fromSomeException #-}
|
||||
|
||||
instance StrEncoding RCErrorType where
|
||||
strEncode = \case
|
||||
RCEInternal err -> "INTERNAL" <> text err
|
||||
|
||||
@@ -90,7 +90,7 @@ import qualified Simplex.Messaging.Agent.Protocol as A
|
||||
import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction)
|
||||
import Simplex.Messaging.Agent.Store.Interface
|
||||
import qualified Simplex.Messaging.Agent.Store.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationConfirmation (..), MigrationError (..))
|
||||
import Simplex.Messaging.Client (pattern NRMInteractive, NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (..), defaultClientConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn)
|
||||
@@ -98,7 +98,7 @@ import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF)
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), initialSMPClientVersion, srvHostnamesSMPClientVersion, supportedSMPClientVRange)
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, NetworkError (..), ProtocolServer (..), SubscriptionMode (..), initialSMPClientVersion, srvHostnamesSMPClientVersion, supportedSMPClientVRange)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), StorePaths (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
@@ -114,6 +114,7 @@ import Test.Hspec hiding (fit, it)
|
||||
import UnliftIO
|
||||
import Util
|
||||
import XFTPClient (testXFTPServer)
|
||||
|
||||
#if defined(dbPostgres)
|
||||
import Fixtures
|
||||
#endif
|
||||
@@ -122,6 +123,7 @@ import qualified Database.PostgreSQL.Simple as PSQL
|
||||
import Simplex.Messaging.Agent.Store (Connection (..), StoredRcvQueue (..), SomeConn (..))
|
||||
import Simplex.Messaging.Agent.Store.AgentStore (getConn)
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (JournalQueue)
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres (PostgresQueue)
|
||||
import Simplex.Messaging.Server.MsgStore.Types (QSType (..))
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres
|
||||
import Simplex.Messaging.Server.QueueStore.Types (QueueStoreClass (..))
|
||||
@@ -177,7 +179,7 @@ pGet' c skipWarn = do
|
||||
case cmd of
|
||||
CONNECT {} -> pGet c
|
||||
DISCONNECT {} -> pGet c
|
||||
ERR (BROKER _ NETWORK) -> pGet c
|
||||
ERR (BROKER _ (NETWORK _)) -> pGet c
|
||||
MWARN {} | skipWarn -> pGet c
|
||||
RFWARN {} | skipWarn -> pGet c
|
||||
SFWARN {} | skipWarn -> pGet c
|
||||
@@ -516,7 +518,7 @@ functionalAPITests ps = do
|
||||
it "should pass without basic auth" $ testSMPServerConnectionTest ps Nothing (noAuthSrv testSMPServer2) `shouldReturn` Nothing
|
||||
let srv1 = testSMPServer2 {keyHash = "1234"}
|
||||
it "should fail with incorrect fingerprint" $ do
|
||||
testSMPServerConnectionTest ps Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK)
|
||||
testSMPServerConnectionTest ps Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError)
|
||||
describe "server with password" $ do
|
||||
let auth = Just "abcd"
|
||||
srv = ProtoServerWithAuth testSMPServer2
|
||||
@@ -1105,7 +1107,7 @@ testAsyncServerOffline ps = withAgentClients2 $ \alice bob -> do
|
||||
(bobId, cReq) <- withSmpServerStoreLogOn ps testPort $ \_ ->
|
||||
runRight $ createConnection alice 1 True SCMInvitation Nothing SMSubscribe
|
||||
-- connection fails
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
|
||||
Left (BROKER _ (NETWORK _)) <- runExceptT $ joinConnection bob 1 True cReq "bob's connInfo" SMSubscribe
|
||||
("", "", DOWN srv conns) <- nGet alice
|
||||
srv `shouldBe` testSMPServer
|
||||
conns `shouldBe` [bobId]
|
||||
@@ -1172,13 +1174,13 @@ testInvitationErrors ps restart = do
|
||||
("", "", DOWN _ [_]) <- nGet a
|
||||
aId <- runRight $ A.prepareConnectionToJoin b 1 True cReq PQSupportOn
|
||||
-- fails to secure the queue on testPort
|
||||
BROKER srv NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
(testPort `isSuffixOf` srv) `shouldBe` True
|
||||
withServer1 ps $ do
|
||||
("", "", UP _ [_]) <- nGet a
|
||||
let loopSecure = do
|
||||
-- secures the queue on testPort, but fails to create reply queue on testPort2
|
||||
BROKER srv2 NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv2 (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
unless (testPort2 `isSuffixOf` srv2) $ putStrLn "retrying secure" >> threadDelay 200000 >> loopSecure
|
||||
loopSecure
|
||||
("", "", DOWN _ [_]) <- nGet a
|
||||
@@ -1186,7 +1188,7 @@ testInvitationErrors ps restart = do
|
||||
threadDelay 200000
|
||||
let loopCreate = do
|
||||
-- creates the reply queue on testPort2, but fails to send it to testPort
|
||||
BROKER srv' NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv' (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
unless (testPort `isSuffixOf` srv') $ putStrLn "retrying create" >> threadDelay 200000 >> loopCreate
|
||||
loopCreate
|
||||
restartAgentB restart b [aId]
|
||||
@@ -1242,12 +1244,12 @@ testContactErrors ps restart = do
|
||||
("", "", DOWN _ [_]) <- nGet a
|
||||
aId <- runRight $ A.prepareConnectionToJoin b 1 True cReq PQSupportOn
|
||||
-- fails to create queue on testPort2
|
||||
BROKER srv2 NETWORK <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv2 (NETWORK _) <- runLeft $ A.joinConnection b NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
(testPort2 `isSuffixOf` srv2) `shouldBe` True
|
||||
b' <- restartAgentB restart b [aId]
|
||||
let loopCreate2 = do
|
||||
-- creates the reply queue on testPort2, but fails to send invitation to testPort
|
||||
BROKER srv' NETWORK <- runLeft $ A.joinConnection b' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv' (NETWORK _) <- runLeft $ A.joinConnection b' NRMInteractive 1 aId True cReq "bob's connInfo" PQSupportOn SMSubscribe
|
||||
unless (testPort `isSuffixOf` srv') $ putStrLn "retrying create 2" >> threadDelay 200000 >> loopCreate2
|
||||
b'' <- withServer2 ps $ do
|
||||
loopCreate2
|
||||
@@ -1270,7 +1272,7 @@ testContactErrors ps restart = do
|
||||
("", "", UP _ [_]) <- nGet b''
|
||||
let loopSecure = do
|
||||
-- secures the queue on testPort2, but fails to create reply queue on testPort
|
||||
BROKER srv NETWORK <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv (NETWORK _) <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe
|
||||
unless (testPort `isSuffixOf` srv) $ putStrLn "retrying secure" >> threadDelay 200000 >> loopSecure
|
||||
loopSecure
|
||||
("", "", DOWN _ [_]) <- nGet b''
|
||||
@@ -1278,7 +1280,7 @@ testContactErrors ps restart = do
|
||||
("", "", UP _ [_]) <- nGet a
|
||||
let loopCreate = do
|
||||
-- creates the reply queue on testPort, but fails to send confirmation to testPort2
|
||||
BROKER srv2' NETWORK <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe
|
||||
BROKER srv2' (NETWORK _) <- runLeft $ acceptContact a 1 bId True invId "alice's connInfo" PQSupportOn SMSubscribe
|
||||
unless (testPort2 `isSuffixOf` srv2') $ putStrLn "retrying create" >> threadDelay 200000 >> loopCreate
|
||||
loopCreate
|
||||
restartAgentA restart a [contactId, bId]
|
||||
@@ -1524,20 +1526,29 @@ testOldContactQueueShortLink ps@(_, msType) = withAgentClients2 $ \a b -> do
|
||||
A.createConnection a NRMInteractive 1 True SCMContact Nothing Nothing CR.IKPQOn SMOnlyCreate
|
||||
-- make it an "old" queue
|
||||
let updateStoreLog f = replaceSubstringInFile f " queue_mode=C" ""
|
||||
() <- case testServerStoreConfig msType of
|
||||
ASSCfg _ _ (SSCMemory (Just StorePaths {storeLogFile})) -> updateStoreLog storeLogFile
|
||||
ASSCfg _ _ (SSCMemoryJournal {storeLogFile}) -> updateStoreLog storeLogFile
|
||||
ASSCfg _ _ (SSCDatabaseJournal {storeCfg}) -> do
|
||||
#if defined(dbServerPostgres)
|
||||
let AgentClient {agentEnv = Env {store}} = a
|
||||
Right (SomeConn _ (ContactConnection _ RcvQueue {rcvId})) <- withTransaction store (`getConn` contactId)
|
||||
st :: PostgresQueueStore (JournalQueue 'QSPostgres) <- newQueueStore @(JournalQueue 'QSPostgres) storeCfg
|
||||
Right 1 <- runExceptT $ withDB' "test" st $ \db -> PSQL.execute db "UPDATE msg_queues SET queue_mode = ? WHERE recipient_id = ?" (Nothing :: Maybe QueueMode, rcvId)
|
||||
closeQueueStore @(JournalQueue 'QSPostgres) st
|
||||
#else
|
||||
error "no dbServerPostgres flag"
|
||||
updateDbStore :: PostgresQueueStore s -> IO ()
|
||||
updateDbStore st = do
|
||||
let AgentClient {agentEnv = Env {store}} = a
|
||||
Right (SomeConn _ (ContactConnection _ RcvQueue {rcvId})) <- withTransaction store (`getConn` contactId)
|
||||
Right 1 <- runExceptT $ withDB' "test" st $ \db -> PSQL.execute db "UPDATE msg_queues SET queue_mode = ? WHERE recipient_id = ?" (Nothing :: Maybe QueueMode, rcvId)
|
||||
pure ()
|
||||
#endif
|
||||
() <- case testServerStoreConfig msType of
|
||||
ASSCfg _ _ (SSCMemory sp_) -> mapM_ (\StorePaths {storeLogFile} -> updateStoreLog storeLogFile) sp_
|
||||
ASSCfg _ _ SSCMemoryJournal {storeLogFile} -> updateStoreLog storeLogFile
|
||||
#if defined(dbServerPostgres)
|
||||
ASSCfg _ _ SSCDatabaseJournal {storeCfg} -> do
|
||||
st :: PostgresQueueStore (JournalQueue 'QSPostgres) <- newQueueStore @(JournalQueue 'QSPostgres) (storeCfg, True)
|
||||
updateDbStore st
|
||||
closeQueueStore @(JournalQueue 'QSPostgres) st
|
||||
ASSCfg _ _ (SSCDatabase storeCfg) -> do
|
||||
st :: PostgresQueueStore PostgresQueue <- newQueueStore @PostgresQueue (storeCfg, False)
|
||||
updateDbStore st
|
||||
closeQueueStore @PostgresQueue st
|
||||
#else
|
||||
ASSCfg _ _ SSCDatabaseJournal {} -> error "no dbServerPostgres flag"
|
||||
#endif
|
||||
_ -> pure ()
|
||||
|
||||
withSmpServer ps $ do
|
||||
let userData = UserLinkData "some user data"
|
||||
@@ -1743,7 +1754,7 @@ testDuplicateMessage ps = do
|
||||
-- commenting two lines below and uncommenting further two lines would also runRight_,
|
||||
-- it is the scenario tested above, when the message was not acknowledged by the user
|
||||
threadDelay 200000
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ ackMessage bob1 aliceId 3 Nothing
|
||||
Left (BROKER _ (NETWORK _)) <- runExceptT $ ackMessage bob1 aliceId 3 Nothing
|
||||
|
||||
disposeAgentClient alice
|
||||
disposeAgentClient bob1
|
||||
@@ -1827,8 +1838,8 @@ testDeliveryAfterSubscriptionError ps = do
|
||||
pure (aId, bId)
|
||||
|
||||
withAgentClients2 $ \a b -> do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection a bId
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection b aId
|
||||
Left (BROKER _ (NETWORK _)) <- runExceptT $ subscribeConnection a bId
|
||||
Left (BROKER _ (NETWORK _)) <- runExceptT $ subscribeConnection b aId
|
||||
withSmpServerStoreLogOn ps testPort $ \_ -> runRight $ do
|
||||
withUP a bId $ \case ("", c, SENT 2) -> c == bId; _ -> False
|
||||
withUP b aId $ \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
@@ -1872,7 +1883,7 @@ testExpireMessage ps =
|
||||
2 <- runRight $ sendMessage a bId SMP.noMsgFlags "1"
|
||||
threadDelay 1500000
|
||||
3 <- runRight $ sendMessage a bId SMP.noMsgFlags "2" -- this won't expire
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && networkOrTimeoutError e; _ -> False
|
||||
withSmpServerStoreLogOn ps testPort $ \_ -> runRight_ $ do
|
||||
withUP a bId $ \case ("", _, SENT 3) -> True; _ -> False
|
||||
withUP b aId $ \case ("", _, MsgErr 2 (MsgSkipped 2 2) "2") -> True; _ -> False
|
||||
@@ -1891,8 +1902,8 @@ testExpireManyMessages ps =
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "3"
|
||||
liftIO $ threadDelay 2000000
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "4" -- this won't expire
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
let expected c e = bId == c && (e == TIMEOUT || e == NETWORK)
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && networkOrTimeoutError e; _ -> False
|
||||
let expected c e = bId == c && networkOrTimeoutError e
|
||||
get a >>= \case
|
||||
("", c, MERR 3 (BROKER _ e)) -> do
|
||||
liftIO $ expected c e `shouldBe` True
|
||||
@@ -2633,7 +2644,7 @@ testDeleteConnectionAsync ps =
|
||||
runRight_ $ do
|
||||
deleteConnectionsAsync a False connIds
|
||||
nGet a =##> \case ("", "", DOWN {}) -> True; _ -> False
|
||||
let delOk = \case (c, _, _, Just (BROKER _ e)) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
let delOk = \case (c, _, _, Just (BROKER _ e)) -> c `elem` connIds && networkOrTimeoutError e; _ -> False
|
||||
get a =##> \case ("", "", DEL_RCVQS rs) -> length rs == 3 && all delOk rs; _ -> False
|
||||
get a =##> \case ("", "", DEL_CONNS cs) -> length cs == 3 && all (`elem` connIds) cs; _ -> False
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
@@ -2691,7 +2702,7 @@ testWaitDelivery ps =
|
||||
3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1"
|
||||
deleteConnectionsAsync alice True [bobId]
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessages bob "nothing else should be delivered to bob"
|
||||
|
||||
@@ -2748,7 +2759,7 @@ testWaitDeliveryAUTHErr ps =
|
||||
3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1"
|
||||
deleteConnectionsAsync alice True [bobId]
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessages bob "nothing else should be delivered to bob"
|
||||
|
||||
@@ -2788,7 +2799,7 @@ testWaitDeliveryTimeout ps =
|
||||
3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1"
|
||||
deleteConnectionsAsync alice True [bobId]
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False
|
||||
get alice =##> \case ("", "", DEL_CONNS [cId]) -> cId == bobId; _ -> False
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessages bob "nothing else should be delivered to bob"
|
||||
@@ -2828,7 +2839,7 @@ testWaitDeliveryTimeout2 ps =
|
||||
3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1"
|
||||
deleteConnectionsAsync alice True [bobId]
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get alice =##> \case ("", "", DEL_RCVQS [(cId, _, _, Just (BROKER _ e))]) -> cId == bobId && networkOrTimeoutError e; _ -> False
|
||||
get alice =##> \case ("", "", DEL_CONNS [cId]) -> cId == bobId; _ -> False
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessages bob "nothing else should be delivered to bob"
|
||||
@@ -2849,6 +2860,12 @@ testWaitDeliveryTimeout2 ps =
|
||||
baseId = 1
|
||||
msgId = subtract baseId
|
||||
|
||||
networkOrTimeoutError :: BrokerErrorType -> Bool
|
||||
networkOrTimeoutError = \case
|
||||
TIMEOUT -> True
|
||||
NETWORK _ -> True
|
||||
_ -> False
|
||||
|
||||
testJoinConnectionAsyncReplyErrorV8 :: HasCallStack => (ASrvTransport, AStoreType) -> IO ()
|
||||
testJoinConnectionAsyncReplyErrorV8 ps@(t, ASType qsType _) = do
|
||||
let initAgentServersSrv2 = initAgentServers {smp = userServers [testSMPServer2]}
|
||||
@@ -2975,7 +2992,7 @@ testUsersNoServer ps = withAgentClientsCfg2 aCfg agentCfg $ \a b -> do
|
||||
nGet b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False
|
||||
runRight_ $ do
|
||||
deleteUser a auId True
|
||||
get a =##> \case ("", "", DEL_RCVQS [(c, _, _, Just (BROKER _ e))]) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
get a =##> \case ("", "", DEL_RCVQS [(c, _, _, Just (BROKER _ e))]) -> c == bId' && networkOrTimeoutError e;; _ -> False
|
||||
get a =##> \case ("", "", DEL_CONNS [c]) -> c == bId'; _ -> False
|
||||
nGet a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
|
||||
liftIO $ noMessages a "nothing else should be delivered to alice"
|
||||
@@ -3613,13 +3630,13 @@ getSMPAgentClient' clientId cfg' initServers dbPath = do
|
||||
|
||||
#if defined(dbPostgres)
|
||||
createStore :: String -> IO (Either MigrationError DBStore)
|
||||
createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) MCError
|
||||
createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) (MigrationConfig MCError Nothing)
|
||||
|
||||
insertUser :: DBStore -> IO ()
|
||||
insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES")
|
||||
#else
|
||||
createStore :: String -> IO (Either MigrationError DBStore)
|
||||
createStore dbPath = createAgentStore (DBOpts dbPath "" False True DB.TQOff) MCError
|
||||
createStore dbPath = createAgentStore (DBOpts dbPath "" False True DB.TQOff) (MigrationConfig MCError Nothing)
|
||||
|
||||
insertUser :: DBStore -> IO ()
|
||||
insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users (user_id) VALUES (1)")
|
||||
@@ -3639,7 +3656,7 @@ testServerMultipleIdentities =
|
||||
exchangeGreetings alice bobId bob aliceId
|
||||
-- this saves queue with second server identity
|
||||
bob' <- liftIO $ do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
|
||||
Left (BROKER _ (NETWORK _)) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
threadDelay 250000
|
||||
getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
|
||||
@@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do
|
||||
poolSize = 1,
|
||||
createSchema = True
|
||||
}
|
||||
createDBStore dbOpts migrations confirmMigrations
|
||||
createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing)
|
||||
|
||||
cleanup :: Word32 -> IO ()
|
||||
cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix)
|
||||
@@ -235,7 +235,7 @@ createStore randSuffix migrations confirmMigrations = do
|
||||
vacuum = True,
|
||||
track = DB.TQOff
|
||||
}
|
||||
createDBStore dbOpts migrations confirmMigrations
|
||||
createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing)
|
||||
|
||||
cleanup :: Word32 -> IO ()
|
||||
cleanup randSuffix = removeFile (testDB randSuffix)
|
||||
|
||||
@@ -79,7 +79,7 @@ import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Server.Store.Postgres (closeNtfDbStore, newNtfDbStore, withDB')
|
||||
import Simplex.Messaging.Notifications.Types (NtfTknAction (..), NtfToken (..))
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NMsgMeta (..), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), NetworkError (..), MsgFlags (MsgFlags), NMsgMeta (..), NtfServer, ProtocolServer (..), SMPMsgMeta (..), SubscriptionMode (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..))
|
||||
import Simplex.Messaging.Transport (ASrvTransport)
|
||||
@@ -137,7 +137,7 @@ notificationTests ps@(t, _) = do
|
||||
it "should pass" $ testRunNTFServerTests t testNtfServer `shouldReturn` Nothing
|
||||
let srv1 = testNtfServer {keyHash = "1234"}
|
||||
it "should fail with incorrect fingerprint" $ do
|
||||
testRunNTFServerTests t srv1 `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK)
|
||||
testRunNTFServerTests t srv1 `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError)
|
||||
describe "Managing notification subscriptions" $ do
|
||||
describe "should create notification subscription for existing connection" $
|
||||
testNtfMatrix ps testNotificationSubscriptionExistingConnection
|
||||
@@ -321,7 +321,7 @@ testNtfTokenServerRestartReverify t apns = do
|
||||
runRight_ $ do
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
Left (BROKER _ NETWORK) <- tryE $ verifyNtfToken a tkn nonce verification
|
||||
Left (BROKER _ (NETWORK _)) <- tryE $ verifyNtfToken a tkn nonce verification
|
||||
pure ()
|
||||
threadDelay 1500000
|
||||
withAgent 2 agentCfg initAgentServers testDB $ \a' ->
|
||||
@@ -478,7 +478,7 @@ testNtfTokenChangeServers t apns =
|
||||
tkn2 <- registerTestToken a "xyzw" NMInstant apns
|
||||
getTestNtfTokenPort a >>= \port -> liftIO $ port `shouldBe` ntfTestPort -- not yet changed
|
||||
deleteNtfToken a tkn2 -- force server switch
|
||||
Left BROKER {brokerErr = NETWORK} <- tryError $ registerTestToken a "qwer" NMInstant apns -- ok, it's down for now
|
||||
Left BROKER {brokerErr = (NETWORK _)} <- tryError $ registerTestToken a "qwer" NMInstant apns -- ok, it's down for now
|
||||
getTestNtfTokenPort a >>= \port2 -> liftIO $ port2 `shouldBe` ntfTestPort2 -- but the token got updated
|
||||
killThread ntf
|
||||
withNtfServerOn t ntfTestPort2 ntfTestDBCfg2 $ runRight_ $ do
|
||||
|
||||
@@ -46,7 +46,7 @@ import Simplex.Messaging.Agent.Store.Migrations.App (appMigrations)
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common (DBStore (..), withTransaction')
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..))
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..), MigrationConfirmation (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..))
|
||||
import Simplex.Messaging.Crypto.Ratchet (pattern IKPQOn)
|
||||
@@ -83,7 +83,7 @@ createEncryptedStore key keepKey = do
|
||||
-- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous
|
||||
-- IO operations on multiple similarly named files; error seems to be environment specific
|
||||
r <- randomIO :: IO Word32
|
||||
Right st <- createDBStore (DBOpts (testDB <> show r) key keepKey True DB.TQOff) appMigrations MCError
|
||||
Right st <- createDBStore (DBOpts (testDB <> show r) key keepKey True DB.TQOff) appMigrations (MigrationConfig MCError Nothing)
|
||||
withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1);")
|
||||
pure st
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction')
|
||||
import Simplex.Messaging.Agent.Store.SQLite.DB (TrackQueries (..))
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration)
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration)
|
||||
import Simplex.Messaging.Util (ifM)
|
||||
import System.Directory (doesFileExist, removeFile)
|
||||
import System.Process (readCreateProcess, shell)
|
||||
@@ -51,7 +51,7 @@ testVerifySchemaDump :: IO ()
|
||||
testVerifySchemaDump = do
|
||||
savedSchema <- ifM (doesFileExist appSchema) (readFile appSchema) (pure "")
|
||||
savedSchema `deepseq` pure ()
|
||||
void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCConsole
|
||||
void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCConsole Nothing)
|
||||
getSchema testDB appSchema `shouldReturn` savedSchema
|
||||
removeFile testDB
|
||||
|
||||
@@ -59,14 +59,14 @@ testVerifyLintFKeyIndexes :: IO ()
|
||||
testVerifyLintFKeyIndexes = do
|
||||
savedLint <- ifM (doesFileExist appLint) (readFile appLint) (pure "")
|
||||
savedLint `deepseq` pure ()
|
||||
void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCConsole
|
||||
void $ createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCConsole Nothing)
|
||||
getLintFKeyIndexes testDB "tests/tmp/agent_lint.sql" `shouldReturn` savedLint
|
||||
removeFile testDB
|
||||
|
||||
testSchemaMigrations :: IO ()
|
||||
testSchemaMigrations = do
|
||||
let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) appMigrations
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) noDownMigrations MCError
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) noDownMigrations (MigrationConfig MCError Nothing)
|
||||
mapM_ (testDownMigration st) $ drop (length noDownMigrations) appMigrations
|
||||
closeDBStore st
|
||||
removeFile testDB
|
||||
@@ -89,7 +89,7 @@ testSchemaMigrations = do
|
||||
|
||||
testUsersMigrationNew :: IO ()
|
||||
testUsersMigrationNew = do
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCError
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCError Nothing)
|
||||
withTransaction' st (`SQL.query_` "SELECT user_id FROM users;")
|
||||
`shouldReturn` ([] :: [Only Int])
|
||||
closeDBStore st
|
||||
@@ -97,11 +97,11 @@ testUsersMigrationNew = do
|
||||
testUsersMigrationOld :: IO ()
|
||||
testUsersMigrationOld = do
|
||||
let beforeUsers = takeWhile (("m20230110_users" /=) . name) appMigrations
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) beforeUsers MCError
|
||||
Right st <- createDBStore (DBOpts testDB "" False True TQOff) beforeUsers (MigrationConfig MCError Nothing)
|
||||
withTransaction' st (`SQL.query_` "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users';")
|
||||
`shouldReturn` ([] :: [Only String])
|
||||
closeDBStore st
|
||||
Right st' <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations MCYesUp
|
||||
Right st' <- createDBStore (DBOpts testDB "" False True TQOff) appMigrations (MigrationConfig MCYesUp Nothing)
|
||||
withTransaction' st' (`SQL.query_` "SELECT user_id FROM users;")
|
||||
`shouldReturn` ([Only (1 :: Int)])
|
||||
closeDBStore st'
|
||||
|
||||
+1
-1
@@ -10,7 +10,6 @@ import AgentTests.FunctionalAPITests (runRight_)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import qualified Crypto.PubKey.RSA as RSA
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.HashMap.Strict as HM
|
||||
import Data.Ini (Ini (..), lookupValue, readIniFile, writeIniFile)
|
||||
@@ -46,6 +45,7 @@ import UnliftIO.Exception (bracket)
|
||||
import Util
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Database.PostgreSQL.Simple as PSQL
|
||||
import Database.PostgreSQL.Simple.Types (Query (..))
|
||||
import NtfClient (ntfTestServerDBConnectInfo, ntfTestServerDBConnstr, ntfTestStoreDBOpts)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
@@ -23,9 +24,9 @@ import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import qualified Data.ByteString.Base64.URL as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Int (Int64)
|
||||
import Data.List (isPrefixOf, isSuffixOf)
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Time.Clock (addUTCTime)
|
||||
@@ -33,9 +34,9 @@ import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import SMPClient (testStoreLogFile, testStoreMsgsDir, testStoreMsgsDir2, testStoreMsgsFile, testStoreMsgsFile2)
|
||||
import Simplex.Messaging.Crypto (pattern MaxLenBS)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (EntityId (..), LinkId, Message (..), QueueLinkData, RecipientId, SParty (..), noMsgFlags)
|
||||
import Simplex.Messaging.Protocol (EntityId (..), ErrorType, LinkId, Message (..), QueueLinkData, RecipientId, SParty (..), noMsgFlags)
|
||||
import Simplex.Messaging.Server (exportMessages, importMessages, printMessageStats)
|
||||
import Simplex.Messaging.Server.Env.STM (journalMsgStoreDepth, readWriteQueueStore)
|
||||
import Simplex.Messaging.Server.Env.STM (MsgStore (..), journalMsgStoreDepth, readWriteQueueStore)
|
||||
import Simplex.Messaging.Server.Expiration (ExpirationConfig (..), expireBeforeEpoch)
|
||||
import Simplex.Messaging.Server.MsgStore.Journal
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
@@ -50,28 +51,54 @@ import System.IO (IOMode (..), withFile)
|
||||
import Test.Hspec hiding (fit, it)
|
||||
import Util
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
import Database.PostgreSQL.Simple (Only (..))
|
||||
import qualified Database.PostgreSQL.Simple as DB
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common
|
||||
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..))
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres
|
||||
import Simplex.Messaging.Server.QueueStore.Postgres
|
||||
import SMPClient (postgressBracket, testServerDBConnectInfo, testStoreDBOpts)
|
||||
#endif
|
||||
|
||||
msgStoreTests :: Spec
|
||||
msgStoreTests = do
|
||||
around (withMsgStore testSMTStoreConfig) $ describe "STM message store" someMsgStoreTests
|
||||
around (withMsgStore $ testJournalStoreCfg MQStoreCfg) $ describe "Journal message store" $ do
|
||||
someMsgStoreTests
|
||||
journalMsgStoreTests
|
||||
it "should export and import journal store" testExportImportStore
|
||||
describe "queue state" $ do
|
||||
it "should restore queue state from the last line" testQueueState
|
||||
it "should recover when message is written and state is not" testMessageState
|
||||
it "should remove journal files when queue is empty" testRemoveJournals
|
||||
describe "missing files" $ do
|
||||
it "should create read file when missing" testReadFileMissing
|
||||
it "should switch to write file when read file missing" testReadFileMissingSwitch
|
||||
it "should create write file when missing" testWriteFileMissing
|
||||
it "should create read file when read and write files are missing" testReadAndWriteFilesMissing
|
||||
#if defined(dbServerPostgres)
|
||||
around_ (postgressBracket testServerDBConnectInfo) $ do
|
||||
around (withMsgStore $ testJournalStoreCfg $ PQStoreCfg testPostgresStoreCfg) $
|
||||
describe "Postgres+journal message store" $ do
|
||||
someMsgStoreTests
|
||||
journalMsgStoreTests
|
||||
around (withMsgStore testPostgresStoreConfig) $
|
||||
describe "Postgres-only message store" $ do
|
||||
someMsgStoreTests
|
||||
it "should correctly update message counts and canWrite flag" testUpdateMessageCounts
|
||||
it "tryDelPeekMsg (ACK not from NSE) should reset message counts when queue is empty" testResetMessageCounts
|
||||
#endif
|
||||
describe "Journal message store: queue state backup expiration" $ do
|
||||
it "should remove old queue state backups" testRemoveQueueStateBackups
|
||||
it "should expire messages in idle queues" testExpireIdleQueues
|
||||
where
|
||||
journalMsgStoreTests :: SpecWith (JournalMsgStore s)
|
||||
journalMsgStoreTests = do
|
||||
describe "queue state" $ do
|
||||
it "should restore queue state from the last line" testQueueState
|
||||
it "should recover when message is written and state is not" testMessageState
|
||||
it "should remove journal files when queue is empty" testRemoveJournals
|
||||
describe "missing files" $ do
|
||||
it "should create read file when missing" testReadFileMissing
|
||||
it "should switch to write file when read file missing" testReadFileMissingSwitch
|
||||
it "should create write file when missing" testWriteFileMissing
|
||||
it "should create read file when read and write files are missing" testReadAndWriteFilesMissing
|
||||
someMsgStoreTests :: MsgStoreClass s => SpecWith s
|
||||
someMsgStoreTests = do
|
||||
it "should get queue and store/read messages" testGetQueue
|
||||
it "should write/ack messages" testWriteAckMessages
|
||||
it "should not fail on EOF when changing read journal" testChangeReadJournal
|
||||
|
||||
-- TODO constrain to STM stores?
|
||||
@@ -96,6 +123,24 @@ testJournalStoreCfg queueStoreCfg =
|
||||
keepMinBackups = 1
|
||||
}
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
testPostgresStoreConfig :: PostgresMsgStoreCfg
|
||||
testPostgresStoreConfig =
|
||||
PostgresMsgStoreCfg
|
||||
{ queueStoreCfg = testPostgresStoreCfg,
|
||||
quota = 3
|
||||
}
|
||||
|
||||
testPostgresStoreCfg :: PostgresStoreCfg
|
||||
testPostgresStoreCfg =
|
||||
PostgresStoreCfg
|
||||
{ dbOpts = testStoreDBOpts,
|
||||
dbStoreLogPath = Nothing,
|
||||
confirmMigrations = MCYesUp,
|
||||
deletedTTL = 86400
|
||||
}
|
||||
#endif
|
||||
|
||||
mkMessage :: MonadIO m => ByteString -> m Message
|
||||
mkMessage body = liftIO $ do
|
||||
g <- C.newRandom
|
||||
@@ -138,7 +183,6 @@ testNewQueueRecData g qm queueData = do
|
||||
where
|
||||
rndId = atomically $ EntityId <$> C.randomBytes 24 g
|
||||
|
||||
-- TODO constrain to STM stores
|
||||
testGetQueue :: MsgStoreClass s => s -> IO ()
|
||||
testGetQueue ms = do
|
||||
g <- C.newRandom
|
||||
@@ -181,7 +225,28 @@ testGetQueue ms = do
|
||||
(Nothing, Nothing) <- tryDelPeekMsg ms q mId8
|
||||
void $ ExceptT $ deleteQueue ms q
|
||||
|
||||
-- TODO constrain to STM stores
|
||||
-- TODO [messages] test concurrent writing and reading
|
||||
testWriteAckMessages :: MsgStoreClass s => s -> IO ()
|
||||
testWriteAckMessages ms = do
|
||||
g <- C.newRandom
|
||||
(rId1, qr1) <- testNewQueueRec g QMMessaging
|
||||
(rId2, qr2) <- testNewQueueRec g QMMessaging
|
||||
runRight_ $ do
|
||||
q1 <- ExceptT $ addQueue ms rId1 qr1
|
||||
q2 <- ExceptT $ addQueue ms rId2 qr2
|
||||
let write q s = writeMsg ms q True =<< mkMessage s
|
||||
0 <- deleteExpiredMsgs ms q1 0 -- won't expire anything, used here to mimic message sending with expiration on SEND
|
||||
Just (Message {msgId = mId1}, True) <- write q1 "message 1"
|
||||
(Msg "message 1", Nothing) <- tryDelPeekMsg ms q1 mId1
|
||||
0 <- deleteExpiredMsgs ms q2 0
|
||||
Just (Message {msgId = mId2}, True) <- write q2 "message 2"
|
||||
(Msg "message 2", Nothing) <- tryDelPeekMsg ms q2 mId2
|
||||
0 <- deleteExpiredMsgs ms q2 0
|
||||
Just (Message {msgId = mId3}, True) <- write q2 "message 3"
|
||||
(Msg "message 3", Nothing) <- tryDelPeekMsg ms q2 mId3
|
||||
void $ ExceptT $ deleteQueue ms q1
|
||||
void $ ExceptT $ deleteQueue ms q2
|
||||
|
||||
testChangeReadJournal :: MsgStoreClass s => s -> IO ()
|
||||
testChangeReadJournal ms = do
|
||||
g <- C.newRandom
|
||||
@@ -226,9 +291,16 @@ testExportImportStore ms = do
|
||||
pure ()
|
||||
length <$> listDirectory (msgQueueDirectory ms rId1) `shouldReturn` 2
|
||||
length <$> listDirectory (msgQueueDirectory ms rId2) `shouldReturn` 3
|
||||
exportMessages False ms testStoreMsgsFile False
|
||||
exportMessages False (StoreJournal ms) testStoreMsgsFile False
|
||||
closeMsgStore ms
|
||||
closeStoreLog sl
|
||||
-- export with closed queues and compare
|
||||
ms2 <- newMsgStore $ testJournalStoreCfg MQStoreCfg
|
||||
readWriteQueueStore True (mkQueue ms2 True) testStoreLogFile (stmQueueStore ms2) >>= closeStoreLog
|
||||
exportMessages False (StoreJournal ms2) (testStoreMsgsFile <> ".copy") False
|
||||
s <- B.readFile testStoreMsgsFile
|
||||
B.readFile (testStoreMsgsFile <> ".copy") `shouldReturn` s
|
||||
|
||||
let cfg = (testJournalStoreCfg MQStoreCfg :: JournalStoreConfig 'QSMemory) {storePath = testStoreMsgsDir2}
|
||||
ms' <- newMsgStore cfg
|
||||
readWriteQueueStore True (mkQueue ms' True) testStoreLogFile (stmQueueStore ms') >>= closeStoreLog
|
||||
@@ -237,21 +309,93 @@ testExportImportStore ms = do
|
||||
printMessageStats "Messages" stats
|
||||
length <$> listDirectory (msgQueueDirectory ms rId1) `shouldReturn` 2
|
||||
length <$> listDirectory (msgQueueDirectory ms rId2) `shouldReturn` 3 -- 2 message files
|
||||
exportMessages False ms' testStoreMsgsFile2 False
|
||||
exportMessages False (StoreJournal ms') testStoreMsgsFile2 False
|
||||
(B.readFile testStoreMsgsFile2 `shouldReturn`) =<< B.readFile (testStoreMsgsFile <> ".bak")
|
||||
stmStore <- newMsgStore testSMTStoreConfig
|
||||
readWriteQueueStore True (mkQueue stmStore True) testStoreLogFile (queueStore stmStore) >>= closeStoreLog
|
||||
MessageStats {storedMsgsCount = 5, expiredMsgsCount = 0, storedQueues = 2} <-
|
||||
importMessages False stmStore testStoreMsgsFile2 Nothing False
|
||||
exportMessages False stmStore testStoreMsgsFile False
|
||||
exportMessages False (StoreMemory stmStore) testStoreMsgsFile False
|
||||
(B.sort <$> B.readFile testStoreMsgsFile `shouldReturn`) =<< (B.sort <$> B.readFile (testStoreMsgsFile2 <> ".bak"))
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
testUpdateMessageCounts :: PostgresMsgStore -> IO ()
|
||||
testUpdateMessageCounts ms = do
|
||||
g <- C.newRandom
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
runRight_ $ do
|
||||
q <- ExceptT $ addQueue ms rId qr
|
||||
let write s = writeMsg ms q True =<< mkMessage s
|
||||
hasSize = checkQueueSize ms
|
||||
q `hasSize` (0, True, False)
|
||||
Just (Message {msgId = mId1}, True) <- write "message 1"
|
||||
q `hasSize` (1, True, True)
|
||||
Just (Message {msgId = mId2}, False) <- write "message 2"
|
||||
q `hasSize` (2, True, True)
|
||||
Just (Message {msgId = mId3}, False) <- write "message 3"
|
||||
q `hasSize` (3, True, True)
|
||||
Nothing <- write "message 4"
|
||||
q `hasSize` (4, False, True)
|
||||
Msg "message 1" <- tryPeekMsg ms q
|
||||
q `hasSize` (4, False, True)
|
||||
Msg "message 1" <- tryDelMsg ms q mId1
|
||||
q `hasSize` (3, False, True)
|
||||
Msg "message 2" <- tryPeekMsg ms q
|
||||
(Msg "message 2", Msg "message 3") <- tryDelPeekMsg ms q mId2
|
||||
q `hasSize` (2, False, True)
|
||||
(Msg "message 3", Just MessageQuota {msgId = mId4}) <- tryDelPeekMsg ms q mId3
|
||||
q `hasSize` (1, False, True)
|
||||
(Just MessageQuota {}, Nothing) <- tryDelPeekMsg ms q mId4
|
||||
q `hasSize` (0, True, False)
|
||||
|
||||
checkQueueSize :: PostgresMsgStore -> PostgresQueue -> (Int64, Bool, Bool) -> ExceptT ErrorType IO ()
|
||||
checkQueueSize ms q (size, canWrt, expire) = liftIO $ do
|
||||
[(size', canWrt', expire')] <-
|
||||
withTransaction (dbStore $ queueStore ms) $ \db ->
|
||||
DB.query db "SELECT msg_queue_size, msg_can_write, msg_queue_expire FROM msg_queues WHERE recipient_id = ?" (Only (recipientId q))
|
||||
size' `shouldBe` size
|
||||
canWrt' `shouldBe` canWrt
|
||||
expire' `shouldBe` expire
|
||||
|
||||
testResetMessageCounts :: PostgresMsgStore -> IO ()
|
||||
testResetMessageCounts ms = do
|
||||
g <- C.newRandom
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
runRight_ $ do
|
||||
q <- ExceptT $ addQueue ms rId qr
|
||||
let write s = writeMsg ms q True =<< mkMessage s
|
||||
hasSize = checkQueueSize ms
|
||||
Just (Message {msgId = mId1}, True) <- write "message 1"
|
||||
Just (Message {msgId = mId2}, False) <- write "message 2"
|
||||
Just (Message {msgId = mId3}, False) <- write "message 3"
|
||||
Nothing <- write "message 4"
|
||||
q `hasSize` (4, False, True)
|
||||
liftIO $ setIncorrectSize q (10, True)
|
||||
Nothing <- write "message 5"
|
||||
q `hasSize` (11, False, True)
|
||||
(Msg "message 1", Msg "message 2") <- tryDelPeekMsg ms q mId1
|
||||
q `hasSize` (10, False, True)
|
||||
(Msg "message 2", Msg "message 3") <- tryDelPeekMsg ms q mId2
|
||||
q `hasSize` (9, False, True)
|
||||
(Msg "message 3", Just MessageQuota {msgId = mId4}) <- tryDelPeekMsg ms q mId3
|
||||
q `hasSize` (8, False, True)
|
||||
(Just MessageQuota {}, Just MessageQuota {msgId = mId5}) <- tryDelPeekMsg ms q mId4
|
||||
q `hasSize` (7, False, True)
|
||||
(Just MessageQuota {}, Nothing) <- tryDelPeekMsg ms q mId5
|
||||
q `hasSize` (0, True, False) -- reset
|
||||
where
|
||||
setIncorrectSize :: PostgresQueue -> (Int64, Bool) -> IO ()
|
||||
setIncorrectSize q (size, canWrt) =
|
||||
void $ withTransaction (dbStore $ queueStore ms) $ \db ->
|
||||
DB.execute db "UPDATE msg_queues SET msg_queue_size = ?, msg_can_write = ? WHERE recipient_id = ?" (size, canWrt, recipientId q)
|
||||
#endif
|
||||
|
||||
testQueueState :: JournalMsgStore s -> IO ()
|
||||
testQueueState ms = do
|
||||
g <- C.newRandom
|
||||
rId <- EntityId <$> atomically (C.randomBytes 24 g)
|
||||
let dir = msgQueueDirectory ms rId
|
||||
statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId)
|
||||
statePath = msgQueueStatePath dir rId
|
||||
createDirectoryIfMissing True dir
|
||||
state <- newMsgQueueState <$> newJournalId (random ms)
|
||||
withFile statePath WriteMode (`appendState` state)
|
||||
@@ -312,7 +456,7 @@ testMessageState ms = do
|
||||
g <- C.newRandom
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
let dir = msgQueueDirectory ms rId
|
||||
statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId)
|
||||
statePath = msgQueueStatePath dir rId
|
||||
write q s = writeMsg ms q True =<< mkMessage s
|
||||
|
||||
mId1 <- runRight $ do
|
||||
@@ -337,7 +481,7 @@ testRemoveJournals ms = do
|
||||
g <- C.newRandom
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
let dir = msgQueueDirectory ms rId
|
||||
statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId)
|
||||
statePath = msgQueueStatePath dir rId
|
||||
write q s = writeMsg ms q True =<< mkMessage s
|
||||
|
||||
runRight $ do
|
||||
@@ -361,7 +505,7 @@ testRemoveJournals ms = do
|
||||
Nothing <- tryPeekMsg ms q
|
||||
-- still not removed, queue is empty and not opened
|
||||
liftIO $ journalFilesCount dir `shouldReturn` 1
|
||||
_mq <- isolateQueue q "test" $ getMsgQueue ms q False
|
||||
_mq <- isolateQueue ms q "test" $ getMsgQueue ms q False
|
||||
-- journal is removed
|
||||
liftIO $ journalFilesCount dir `shouldReturn` 0
|
||||
liftIO $ stateBackupCount dir `shouldReturn` 1
|
||||
@@ -442,7 +586,7 @@ testExpireIdleQueues = do
|
||||
ms <- newMsgStore (testJournalStoreCfg MQStoreCfg) {idleInterval = 0}
|
||||
|
||||
let dir = msgQueueDirectory ms rId
|
||||
statePath = msgQueueStatePath dir $ B.unpack (B64.encode $ unEntityId rId)
|
||||
statePath = msgQueueStatePath dir rId
|
||||
write q s = writeMsg ms q True =<< mkMessage s
|
||||
|
||||
q <- runRight $ do
|
||||
@@ -461,7 +605,7 @@ testExpireIdleQueues = do
|
||||
old <- expireBeforeEpoch ExpirationConfig {ttl = 1, checkInterval = 1} -- no old messages
|
||||
now <- systemSeconds <$> getSystemTime
|
||||
|
||||
(expired_, stored) <- runRight $ isolateQueue q "" $ withIdleMsgQueue now ms q $ deleteExpireMsgs_ old q
|
||||
(expired_, stored) <- runRight $ isolateQueue ms q "" $ withIdleMsgQueue now ms q $ deleteExpireMsgs_ old q
|
||||
expired_ `shouldBe` Just 0
|
||||
stored `shouldBe` 0
|
||||
(Nothing, False) <- readQueueState ms statePath
|
||||
@@ -478,7 +622,7 @@ testReadFileMissing ms = do
|
||||
Msg "message 1" <- tryPeekMsg ms q
|
||||
pure q
|
||||
|
||||
mq <- fromJust <$> readTVarIO (msgQueue q)
|
||||
mq <- fromJust <$> readTVarIO (msgQueue' q)
|
||||
MsgQueueState {readState = rs} <- readTVarIO $ state mq
|
||||
closeMsgQueue ms q
|
||||
let path = journalFilePath (queueDirectory $ queue mq) $ journalId rs
|
||||
@@ -497,7 +641,7 @@ testReadFileMissingSwitch ms = do
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
q <- writeMessages ms rId qr
|
||||
|
||||
mq <- fromJust <$> readTVarIO (msgQueue q)
|
||||
mq <- fromJust <$> readTVarIO (msgQueue' q)
|
||||
MsgQueueState {readState = rs} <- readTVarIO $ state mq
|
||||
closeMsgQueue ms q
|
||||
let path = journalFilePath (queueDirectory $ queue mq) $ journalId rs
|
||||
@@ -515,7 +659,7 @@ testWriteFileMissing ms = do
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
q <- writeMessages ms rId qr
|
||||
|
||||
mq <- fromJust <$> readTVarIO (msgQueue q)
|
||||
mq <- fromJust <$> readTVarIO (msgQueue' q)
|
||||
MsgQueueState {writeState = ws} <- readTVarIO $ state mq
|
||||
closeMsgQueue ms q
|
||||
let path = journalFilePath (queueDirectory $ queue mq) $ journalId ws
|
||||
@@ -538,7 +682,7 @@ testReadAndWriteFilesMissing ms = do
|
||||
(rId, qr) <- testNewQueueRec g QMMessaging
|
||||
q <- writeMessages ms rId qr
|
||||
|
||||
mq <- fromJust <$> readTVarIO (msgQueue q)
|
||||
mq <- fromJust <$> readTVarIO (msgQueue' q)
|
||||
MsgQueueState {readState = rs, writeState = ws} <- readTVarIO $ state mq
|
||||
closeMsgQueue ms q
|
||||
removeFile $ journalFilePath (queueDirectory $ queue mq) $ journalId rs
|
||||
|
||||
@@ -45,56 +45,32 @@ utilTests = do
|
||||
runExceptT (throwTestException `catchError` handleCatch) `shouldThrow` (\(e :: IOError) -> show e == "user error (error)")
|
||||
describe "tryAllErrors" $ do
|
||||
it "should return ExceptT error as Left" $
|
||||
runExceptT (tryAllErrors testErr throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
runExceptT (tryAllErrors throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
it "should return SomeException as Left" $
|
||||
runExceptT (tryAllErrors testErr throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
it "should return no errors as Right" $
|
||||
runExceptT (tryAllErrors testErr noErrors) `shouldReturn` Right (Right "no errors")
|
||||
describe "tryAllErrors specialized as tryTestError" $ do
|
||||
let tryTestError = tryAllErrors testErr
|
||||
it "should return ExceptT error as Left" $
|
||||
runExceptT (tryTestError throwTestError) `shouldReturn` Right (Left (TestError "error"))
|
||||
it "should return SomeException as Left" $
|
||||
runExceptT (tryTestError throwTestException) `shouldReturn` Right (Left (TestException "user error (error)"))
|
||||
it "should return no errors as Right" $
|
||||
runExceptT (tryTestError noErrors) `shouldReturn` Right (Right "no errors")
|
||||
runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors")
|
||||
describe "catchAllErrors" $ do
|
||||
it "should catch ExceptT error" $
|
||||
runExceptT (catchAllErrors testErr throwTestError handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
it "should catch SomeException" $
|
||||
runExceptT (catchAllErrors testErr throwTestException handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
runExceptT (throwTestException `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
it "should not throw if there are no errors" $
|
||||
runExceptT (catchAllErrors testErr noErrors throwError) `shouldReturn` Right "no errors"
|
||||
describe "catchAllErrors specialized as catchTestError" $ do
|
||||
let catchTestError = catchAllErrors testErr
|
||||
it "should catch ExceptT error" $
|
||||
runExceptT (throwTestError `catchTestError` handleCatch) `shouldReturn` Right "caught TestError \"error\""
|
||||
it "should catch SomeException" $
|
||||
runExceptT (throwTestException `catchTestError` handleCatch) `shouldReturn` Right "caught TestException \"user error (error)\""
|
||||
it "should not throw if there are no errors" $
|
||||
runExceptT (noErrors `catchTestError` throwError) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `catchAllErrors` throwError) `shouldReturn` Right "no errors"
|
||||
describe "catchThrow" $ do
|
||||
it "should re-throw ExceptT error" $
|
||||
runExceptT (throwTestError `catchThrow` testErr) `shouldReturn` Left (TestError "error")
|
||||
runExceptT (throwTestError `catchThrow` fromSomeException) `shouldReturn` Left (TestError "error")
|
||||
it "should catch SomeException and throw as ExceptT error" $
|
||||
runExceptT (throwTestException `catchThrow` testErr) `shouldReturn` Left (TestException "user error (error)")
|
||||
runExceptT (throwTestException `catchThrow` fromSomeException) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "should not throw if there are no exceptions" $
|
||||
runExceptT (noErrors `catchThrow` testErr) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `catchThrow` fromSomeException) `shouldReturn` Right "no errors"
|
||||
describe "allFinally should run final action" $ do
|
||||
it "then throw ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr throwTestError final) `shouldReturn` Left (TestError "error")
|
||||
runExceptT (throwTestError `allFinally` final) `shouldReturn` Left (TestError "error")
|
||||
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr throwTestException final) `shouldReturn` Left (TestException "user error (error)")
|
||||
runExceptT (throwTestException `allFinally` final) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
|
||||
runExceptT (allFinally testErr noErrors final) `shouldReturn` Right "no errors"
|
||||
describe "allFinally specialized as testFinally should run final action" $ do
|
||||
let testFinally = allFinally testErr
|
||||
it "then throw ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (throwTestError `testFinally` final) `shouldReturn` Left (TestError "error")
|
||||
it "then throw SomeException as ExceptT error" $ withFinal $ \final ->
|
||||
runExceptT (throwTestException `testFinally` final) `shouldReturn` Left (TestException "user error (error)")
|
||||
it "and should not throw if there are no exceptions" $ withFinal $ \final ->
|
||||
runExceptT (noErrors `testFinally` final) `shouldReturn` Right "no errors"
|
||||
runExceptT (noErrors `allFinally` final) `shouldReturn` Right "no errors"
|
||||
where
|
||||
throwTestError :: ExceptT TestError IO String
|
||||
throwTestError = throwError $ TestError "error"
|
||||
@@ -102,8 +78,6 @@ utilTests = do
|
||||
throwTestException = liftIO $ throwIO $ userError "error"
|
||||
noErrors :: ExceptT TestError IO String
|
||||
noErrors = pure "no errors"
|
||||
testErr :: SomeException -> TestError
|
||||
testErr = TestException . show
|
||||
handleCatch :: TestError -> ExceptT TestError IO String
|
||||
handleCatch e = pure $ "caught " <> show e
|
||||
handleException :: SomeException -> ExceptT TestError IO String
|
||||
@@ -119,3 +93,6 @@ data TestError = TestError String | TestException String
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Exception TestError
|
||||
|
||||
instance AnyError TestError where
|
||||
fromSomeException = TestException . show
|
||||
|
||||
@@ -12,7 +12,7 @@ import Data.Maybe (fromJust, isJust)
|
||||
import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore)
|
||||
import Simplex.Messaging.Agent.Store.Postgres.Common (DBOpts (..))
|
||||
import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration)
|
||||
import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfig (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration)
|
||||
import Simplex.Messaging.Util (ifM, whenM)
|
||||
import System.Directory (doesFileExist, removeFile)
|
||||
import System.Process (readCreateProcess, shell)
|
||||
@@ -30,12 +30,12 @@ postgresSchemaDumpTest migrations skipComparisonForDownMigrations testDBOpts@DBO
|
||||
testVerifySchemaDump = do
|
||||
savedSchema <- ifM (doesFileExist srcSchemaPath) (readFile srcSchemaPath) (pure "")
|
||||
savedSchema `deepseq` pure ()
|
||||
void $ createDBStore testDBOpts migrations MCConsole
|
||||
void $ createDBStore testDBOpts migrations (MigrationConfig MCConsole Nothing)
|
||||
getSchema srcSchemaPath `shouldReturn` savedSchema
|
||||
|
||||
testSchemaMigrations = do
|
||||
let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) migrations
|
||||
Right st <- createDBStore testDBOpts noDownMigrations MCError
|
||||
Right st <- createDBStore testDBOpts noDownMigrations (MigrationConfig MCError Nothing)
|
||||
mapM_ (testDownMigration st) $ drop (length noDownMigrations) migrations
|
||||
closeDBStore st
|
||||
whenM (doesFileExist testSchemaPath) $ removeFile testSchemaPath
|
||||
|
||||
+9
-3
@@ -229,6 +229,7 @@ cfgMS msType = withStoreCfg (testServerStoreConfig msType) $ \serverStoreCfg ->
|
||||
dailyBlockQueueQuota = 20,
|
||||
messageExpiration = Just defaultMessageExpiration,
|
||||
expireMessagesOnStart = True,
|
||||
expireMessagesOnSend = False,
|
||||
idleQueueInterval = defaultIdleQueueInterval,
|
||||
notificationExpiration = defaultNtfExpiration,
|
||||
inactiveClientExpiration = Just defaultInactiveClientExpiration,
|
||||
@@ -273,9 +274,14 @@ serverStoreConfig_ useDbStoreLog = \case
|
||||
ASType SQSMemory SMSJournal ->
|
||||
ASSCfg SQSMemory SMSJournal $ SSCMemoryJournal {storeLogFile = testStoreLogFile, storeMsgsPath = testStoreMsgsDir}
|
||||
ASType SQSPostgres SMSJournal ->
|
||||
let dbStoreLogPath = if useDbStoreLog then Just testStoreLogFile else Nothing
|
||||
storeCfg = PostgresStoreCfg {dbOpts = testStoreDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = 86400}
|
||||
in ASSCfg SQSPostgres SMSJournal SSCDatabaseJournal {storeCfg, storeMsgsPath' = testStoreMsgsDir}
|
||||
ASSCfg SQSPostgres SMSJournal SSCDatabaseJournal {storeCfg, storeMsgsPath' = testStoreMsgsDir}
|
||||
#if defined(dbServerPostgres)
|
||||
ASType SQSPostgres SMSPostgres ->
|
||||
ASSCfg SQSPostgres SMSPostgres $ SSCDatabase storeCfg
|
||||
#endif
|
||||
where
|
||||
dbStoreLogPath = if useDbStoreLog then Just testStoreLogFile else Nothing
|
||||
storeCfg = PostgresStoreCfg {dbOpts = testStoreDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = 86400}
|
||||
|
||||
cfgV7 :: AServerConfig
|
||||
cfgV7 = updateCfg cfg $ \cfg' -> cfg' {smpServerVRange = mkVersionRange minServerSMPRelayVersion authCmdsSMPVersion}
|
||||
|
||||
+71
-7
@@ -14,6 +14,7 @@
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module ServerTests where
|
||||
|
||||
@@ -42,10 +43,10 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll, parseString)
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server (exportMessages)
|
||||
import Simplex.Messaging.Server.Env.STM (AStoreType (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore)
|
||||
import Simplex.Messaging.Server.Env.STM (AStoreType (..), MsgStore (..), ServerConfig (..), ServerStoreCfg (..), readWriteQueueStore)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.MsgStore.Journal (JournalStoreConfig (..), QStoreCfg (..), stmQueueStore)
|
||||
import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), SMSType (..), SQSType (..), newMsgStore)
|
||||
import Simplex.Messaging.Server.MsgStore.Types (MsgStoreClass (..), QSType (..), SMSType (..), SQSType (..), newMsgStore)
|
||||
import Simplex.Messaging.Server.Stats (PeriodStatsData (..), ServerStatsData (..))
|
||||
import Simplex.Messaging.Server.StoreLog (StoreLogRecord (..), closeStoreLog)
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -59,6 +60,11 @@ import Test.HUnit
|
||||
import Test.Hspec hiding (fit, it)
|
||||
import Util
|
||||
|
||||
#if defined(dbServerPostgres)
|
||||
import CoreTests.MsgStoreTests (testPostgresStoreConfig)
|
||||
import Simplex.Messaging.Server.MsgStore.Postgres (PostgresMsgStoreCfg (..), exportDbMessages)
|
||||
#endif
|
||||
|
||||
serverTests :: SpecWith (ASrvTransport, AStoreType)
|
||||
serverTests = do
|
||||
describe "SMP queues" $ do
|
||||
@@ -86,6 +92,7 @@ serverTests = do
|
||||
describe "Message notifications" $ do
|
||||
testMessageNotifications
|
||||
testMessageServiceNotifications
|
||||
testServiceNotificationsTwoRestarts
|
||||
describe "Message expiration" $ do
|
||||
testMsgExpireOnSend
|
||||
testMsgExpireOnInterval
|
||||
@@ -914,14 +921,27 @@ testRestoreExpireMessages =
|
||||
exportStoreMessages :: AStoreType -> IO ()
|
||||
exportStoreMessages = \case
|
||||
ASType _ SMSJournal -> export
|
||||
ASType _ SMSPostgres -> exportDB
|
||||
ASType _ SMSMemory -> pure ()
|
||||
where
|
||||
export = do
|
||||
ms <- newMsgStore (testJournalStoreCfg MQStoreCfg) {quota = 4}
|
||||
ms <- readWriteQueues
|
||||
exportMessages False (StoreJournal ms) testStoreMsgsFile False
|
||||
closeMsgStore ms
|
||||
#if defined(dbServerPostgres)
|
||||
exportDB = do
|
||||
readWriteQueues >>= closeMsgStore
|
||||
ms' <- newMsgStore (testPostgresStoreConfig {quota = 4} :: PostgresMsgStoreCfg)
|
||||
_n <- withFile testStoreMsgsFile WriteMode $ exportDbMessages False ms'
|
||||
closeMsgStore ms'
|
||||
#else
|
||||
exportDB = error "compiled without server_postgres flag"
|
||||
#endif
|
||||
readWriteQueues = do
|
||||
ms <- newMsgStore ((testJournalStoreCfg MQStoreCfg) {quota = 4} :: JournalStoreConfig 'QSMemory)
|
||||
readWriteQueueStore True (mkQueue ms True) testStoreLogFile (stmQueueStore ms) >>= closeStoreLog
|
||||
removeFileIfExists testStoreMsgsFile
|
||||
exportMessages False ms testStoreMsgsFile False
|
||||
closeMsgStore ms
|
||||
pure ms
|
||||
runTest :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
@@ -1091,7 +1111,6 @@ testMessageServiceNotifications =
|
||||
(rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g
|
||||
Resp "1" _ (NID nId _) <- signSendRecv rh rKey ("1", rId, NKEY nPub rcvNtfPubDhKey)
|
||||
serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g
|
||||
-- TODO [certs] we need to get certificate fingerprint and include it into signed over for NSUB commands
|
||||
testNtfServiceClient t serviceKeys $ \nh1 -> do
|
||||
-- can't subscribe without service signature in service connection
|
||||
Resp "2a" _ (ERR SERVICE) <- signSendRecv nh1 nKey ("2a", nId, NSUB)
|
||||
@@ -1155,12 +1174,57 @@ testMessageServiceNotifications =
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh
|
||||
pure ()
|
||||
|
||||
testServiceNotificationsTwoRestarts :: SpecWith (ASrvTransport, AStoreType)
|
||||
testServiceNotificationsTwoRestarts =
|
||||
it "subscribe notifier as service and deliver notifications after two restarts" $ \ps@(ATransport t, _) -> do
|
||||
g <- C.newRandom
|
||||
(sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
(nPub, nKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
serviceKeys@(_, servicePK) <- atomically $ C.generateKeyPair g
|
||||
(rcvNtfPubDhKey, _) <- atomically $ C.generateKeyPair g
|
||||
(rId, rKey, sId, dec, serviceId) <- withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh -> do
|
||||
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "0" _ (NID nId _) <- signSendRecv rh rKey ("0", rId, NKEY nPub rcvNtfPubDhKey)
|
||||
testNtfServiceClient t serviceKeys $ \nh -> do
|
||||
Resp "1" _ (SOK (Just serviceId)) <- serviceSignSendRecv nh nKey servicePK ("1", nId, NSUB)
|
||||
deliverMessage rh rId rKey sh sId sKey nh "hello" dec
|
||||
pure (rId, rKey, sId, dec, serviceId)
|
||||
threadDelay 250000
|
||||
withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh ->
|
||||
testNtfServiceClient t serviceKeys $ \nh -> do
|
||||
Resp "2.1" serviceId' (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("2.1", serviceId, NSUBS)
|
||||
n `shouldBe` 1
|
||||
Resp "2.2" _ (SOK Nothing) <- signSendRecv rh rKey ("2.2", rId, SUB)
|
||||
serviceId' `shouldBe` serviceId
|
||||
deliverMessage rh rId rKey sh sId sKey nh "hello 2" dec
|
||||
threadDelay 250000
|
||||
withSmpServerStoreLogOn ps testPort $ runTest2 t $ \sh rh ->
|
||||
testNtfServiceClient t serviceKeys $ \nh -> do
|
||||
Resp "3.1" _ (SOKS n) <- signSendRecv nh (C.APrivateAuthKey C.SEd25519 servicePK) ("3.1", serviceId, NSUBS)
|
||||
n `shouldBe` 1
|
||||
Resp "3.2" _ (SOK Nothing) <- signSendRecv rh rKey ("3.2", rId, SUB)
|
||||
deliverMessage rh rId rKey sh sId sKey nh "hello 3" dec
|
||||
where
|
||||
runTest2 :: Transport c => TProxy c 'TServer -> (THandleSMP c 'TClient -> THandleSMP c 'TClient -> IO a) -> ThreadId -> IO a
|
||||
runTest2 _ test' server = do
|
||||
a <- testSMPClient $ \h1 -> testSMPClient $ \h2 -> test' h1 h2
|
||||
killThread server
|
||||
pure a
|
||||
deliverMessage rh rId rKey sh sId sKey nh msgText dec = do
|
||||
Resp "msg-1" _ OK <- signSendRecv sh sKey ("msg-1", sId, _SEND' msgText)
|
||||
Resp "" _ (Msg mId msg) <- tGet1 rh
|
||||
Resp "msg-2" _ OK <- signSendRecv rh rKey ("msg-2", rId, ACK mId)
|
||||
(dec mId msg, Right msgText) #== "delivered from queue"
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh
|
||||
pure ()
|
||||
|
||||
testMsgExpireOnSend :: SpecWith (ASrvTransport, AStoreType)
|
||||
testMsgExpireOnSend =
|
||||
it "should expire messages that are not received before messageTTL on SEND" $ \(ATransport (t :: TProxy c 'TServer), msType) -> do
|
||||
g <- C.newRandom
|
||||
(sPub, sKey) <- atomically $ C.generateAuthKeyPair C.SEd25519 g
|
||||
let cfg' = updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}}
|
||||
let cfg' = updateCfg (cfgMS msType) $ \cfg_ -> cfg_ {expireMessagesOnSend = True, messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}}
|
||||
withSmpServerConfigOn (ATransport t) cfg' testPort $ \_ ->
|
||||
testSMPClient @c $ \sh -> do
|
||||
(sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub
|
||||
|
||||
+11
-4
@@ -102,9 +102,11 @@ main = do
|
||||
] -- skipComparisonForDownMigrations
|
||||
testStoreDBOpts
|
||||
"src/Simplex/Messaging/Server/QueueStore/Postgres/server_schema.sql"
|
||||
aroundAll_ (postgressBracket testServerDBConnectInfo) $
|
||||
around_ (postgressBracket testServerDBConnectInfo) $ do
|
||||
describe "SMP server via TLS, postgres+jornal message store" $
|
||||
before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests
|
||||
describe "SMP server via TLS, postgres-only message store" $
|
||||
before (pure (transport @TLS, ASType SQSPostgres SMSPostgres)) serverTests
|
||||
#endif
|
||||
describe "SMP server via TLS, jornal message store" $ do
|
||||
describe "SMP syntax" $ serverSyntaxTests (transport @TLS)
|
||||
@@ -122,16 +124,21 @@ main = do
|
||||
[] -- skipComparisonForDownMigrations
|
||||
ntfTestStoreDBOpts
|
||||
"src/Simplex/Messaging/Notifications/Server/Store/ntf_server_schema.sql"
|
||||
aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ do
|
||||
around_ (postgressBracket ntfTestServerDBConnectInfo) $ do
|
||||
describe "Notifications server (SMP server: jornal store)" $
|
||||
ntfServerTests (transport @TLS, ASType SQSMemory SMSJournal)
|
||||
aroundAll_ (postgressBracket testServerDBConnectInfo) $
|
||||
around_ (postgressBracket testServerDBConnectInfo) $ do
|
||||
describe "Notifications server (SMP server: postgres+jornal store)" $
|
||||
ntfServerTests (transport @TLS, ASType SQSPostgres SMSJournal)
|
||||
aroundAll_ (postgressBracket testServerDBConnectInfo) $ do
|
||||
describe "Notifications server (SMP server: postgres-only store)" $
|
||||
ntfServerTests (transport @TLS, ASType SQSPostgres SMSPostgres)
|
||||
around_ (postgressBracket testServerDBConnectInfo) $ do
|
||||
describe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal)
|
||||
describe "SMP client agent, postgres-only message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSPostgres)
|
||||
describe "SMP proxy, postgres+jornal message store" $
|
||||
before (pure $ ASType SQSPostgres SMSJournal) smpProxyTests
|
||||
describe "SMP proxy, postgres-only message store" $
|
||||
before (pure $ ASType SQSPostgres SMSPostgres) smpProxyTests
|
||||
#endif
|
||||
describe "SMP client agent, jornal message store" $ agentTests (transport @TLS, ASType SQSMemory SMSJournal)
|
||||
describe "SMP proxy, jornal message store" $
|
||||
|
||||
+2
-2
@@ -38,7 +38,7 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs)
|
||||
import qualified Simplex.Messaging.Crypto.File as CF
|
||||
import Simplex.Messaging.Encoding.String (StrEncoding (..))
|
||||
import Simplex.Messaging.Protocol (BasicAuth, ProtoServerWithAuth (..), ProtocolServer (..), XFTPServerWithAuth)
|
||||
import Simplex.Messaging.Protocol (BasicAuth, NetworkError (..), ProtoServerWithAuth (..), ProtocolServer (..), XFTPServerWithAuth)
|
||||
import Simplex.Messaging.Server.Expiration (ExpirationConfig (..))
|
||||
import Simplex.Messaging.Util (tshow)
|
||||
import System.Directory (doesDirectoryExist, doesFileExist, getFileSize, listDirectory, removeFile)
|
||||
@@ -84,7 +84,7 @@ xftpAgentTests =
|
||||
it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing
|
||||
let srv1 = testXFTPServer2 {keyHash = "1234"}
|
||||
it "should fail with incorrect fingerprint" $ do
|
||||
testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK)
|
||||
testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) $ NETWORK NEUnknownCAError)
|
||||
describe "server with password" $ do
|
||||
let auth = Just "abcd"
|
||||
srv = ProtoServerWithAuth testXFTPServer2
|
||||
|
||||
Reference in New Issue
Block a user