mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 20:45:52 +00:00
agent: lock rows for concurrent queries in PostgreSQL (#1688)
* agent: lock rows for concurrent queries in PostgreSQL * fix race conditions in workers * refactor
This commit is contained in:
@@ -223,6 +223,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
|
||||
agentXFTPDownloadChunk c userId digest replica chunkSpec
|
||||
liftIO $ waitUntilForeground c
|
||||
(entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do
|
||||
liftIO $ lockRcvFileForUpdate db rcvFileId
|
||||
liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath
|
||||
RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId
|
||||
let rcvd = receivedSize chunks
|
||||
@@ -413,6 +414,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
|
||||
(digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath
|
||||
withStore c $ \db -> do
|
||||
lockSndFileForUpdate db sndFileId
|
||||
updateSndFileEncrypted db sndFileId digest chunkSpecsDigests
|
||||
getSndFile db sndFileId
|
||||
else pure sndFile
|
||||
@@ -530,6 +532,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
|
||||
agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec'
|
||||
liftIO $ waitUntilForeground c
|
||||
sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do
|
||||
lockSndFileForUpdate db sndFileId
|
||||
updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded
|
||||
getSndFile db sndFileId
|
||||
let uploaded = uploadedSize chunks
|
||||
|
||||
@@ -1145,7 +1145,8 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
|
||||
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport}
|
||||
case sq_ of
|
||||
Just sq@SndQueue {e2ePubKey = Just _k} -> do
|
||||
e2eSndParams <- withStore c $ \db ->
|
||||
e2eSndParams <- withStore c $ \db -> do
|
||||
lockConnForUpdate db connId
|
||||
getSndRatchet db connId v >>= \case
|
||||
Right r -> pure $ Right $ snd r
|
||||
Left e -> do
|
||||
@@ -1159,6 +1160,7 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
|
||||
sndKey_ = snd <$> invLink_
|
||||
(q, _) <- lift $ newSndQueue userId "" qInfo sndKey_
|
||||
withStore c $ \db -> runExceptT $ do
|
||||
liftIO $ lockConnForUpdate db connId
|
||||
e2eSndParams <- createRatchet_ db g maxSupported pqSupport e2eRcvParams
|
||||
sq' <- maybe (ExceptT $ updateNewConnSnd db connId q) pure sq_
|
||||
pure (cData, sq', e2eSndParams, lnkId_)
|
||||
@@ -1237,7 +1239,8 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su
|
||||
AgentConfig {smpClientVRange = vr, smpAgentVRange, e2eEncryptVRange = e2eVR} <- asks config
|
||||
let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMMessaging}
|
||||
crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] Nothing
|
||||
e2eRcvParams <- withStore' c $ \db ->
|
||||
e2eRcvParams <- withStore' c $ \db -> do
|
||||
lockConnForUpdate db connId
|
||||
getRatchetX3dhKeys db connId >>= \case
|
||||
Right keys -> pure $ CR.mkRcvE2ERatchetParams (maxVersion e2eVR) keys
|
||||
Left e -> do
|
||||
@@ -1957,7 +1960,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
|
||||
withRetryLock2 ri' qLock $ \riState loop -> do
|
||||
liftIO $ waitWhileSuspended c
|
||||
liftIO $ waitForUserNetwork c
|
||||
resp <- tryError $ case msgType of
|
||||
resp <- tryAllErrors $ case msgType of
|
||||
AM_CONN_INFO -> sendConfirmation c NRMBackground sq msgBody
|
||||
AM_CONN_INFO_REPLY -> sendConfirmation c NRMBackground sq msgBody
|
||||
_ -> case pendingMsgPrepData_ of
|
||||
@@ -2097,10 +2100,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
|
||||
notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> AM ()
|
||||
notifyDelMsgs msgId err expireTs = do
|
||||
notifyDel msgId $ MERR (unId msgId) err
|
||||
msgIds_ <- withStore' c $ \db -> getExpiredSndMessages db connId sq expireTs
|
||||
msgIds_ <- withStore' c $ \db -> do
|
||||
msgIds_ <- getExpiredSndMessages db connId sq expireTs
|
||||
forM_ msgIds_ $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
|
||||
pure msgIds_
|
||||
forM_ (L.nonEmpty msgIds_) $ \msgIds -> do
|
||||
notify $ MERRS (L.map unId msgIds) err
|
||||
withStore' c $ \db -> forM_ msgIds $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
|
||||
atomically $ incSMPServerStat' c userId server sentExpiredErrs (length msgIds_ + 1)
|
||||
delMsg :: InternalId -> AM ()
|
||||
delMsg = delMsgKeep False
|
||||
@@ -3025,7 +3030,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
throwE e
|
||||
agentClientMsg :: TVar ChaChaDRG -> ByteString -> AM (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448))
|
||||
agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do
|
||||
rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY
|
||||
liftIO $ lockConnForUpdate db connId
|
||||
rc <- ExceptT $ getRatchetForUpdate db connId -- ratchet state pre-decryption - required for processing EREADY
|
||||
(agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage
|
||||
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
|
||||
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
|
||||
@@ -3260,6 +3266,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
Just sqs' -> do
|
||||
(sq_@SndQueue {sndPrivateKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing
|
||||
sq2 <- withStore c $ \db -> do
|
||||
lockConnForUpdate db connId
|
||||
liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs
|
||||
addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
|
||||
logServer "<--" c srv rId $ "MSG <QADD>:" <> logSecret' srvMsgId <> " " <> logSecret (senderId queueAddress)
|
||||
@@ -3564,7 +3571,7 @@ agentRatchetEncrypt db cData msg getPaddedLen pqEnc_ currentE2EVersion = do
|
||||
|
||||
agentRatchetEncryptHeader :: DB.Connection -> ConnData -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (CR.MsgEncryptKeyX448, Int, PQEncryption)
|
||||
agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport} getPaddedLen pqEnc_ currentE2EVersion = do
|
||||
rc <- ExceptT $ getRatchet db connId
|
||||
rc <- ExceptT $ getRatchetForUpdate db connId
|
||||
let paddedLen = getPaddedLen v pqSupport
|
||||
(mek, rc') <- withExceptT (SEAgentError . cryptoError) $ CR.rcEncryptHeader rc pqEnc_ currentE2EVersion
|
||||
liftIO $ updateRatchet db connId rc' CR.SMDNoChange
|
||||
@@ -3573,7 +3580,7 @@ agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport}
|
||||
-- encoded EncAgentMessage -> encoded AgentMessage
|
||||
agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)
|
||||
agentRatchetDecrypt g db connId encAgentMsg = do
|
||||
rc <- ExceptT $ getRatchet db connId
|
||||
rc <- ExceptT $ getRatchetForUpdate db connId
|
||||
agentRatchetDecrypt' g db connId rc encAgentMsg
|
||||
|
||||
agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)
|
||||
|
||||
@@ -2114,16 +2114,17 @@ withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (
|
||||
withWork c doWork = withWork_ c doWork . withStore' c
|
||||
{-# INLINE withWork #-}
|
||||
|
||||
-- setting doWork flag to "no work" before getWork rather than after prevents race condition when flag is set to "has work" by another thread after getWork call.
|
||||
withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWork_ c doWork getWork action =
|
||||
getWork >>= \case
|
||||
Right (Just r) -> action r
|
||||
Right Nothing -> noWork
|
||||
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
|
||||
noWork >> getWork >>= \case
|
||||
Right (Just r) -> hasWork >> action r
|
||||
Right Nothing -> pure ()
|
||||
Left e
|
||||
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
|
||||
| otherwise -> notifyErr INTERNAL e
|
||||
| isWorkItemError e -> notifyErr (CRITICAL False) e -- worker remains stopped here because the next iteration is likely to produce the same result
|
||||
| otherwise -> hasWork >> notifyErr INTERNAL e
|
||||
where
|
||||
hasWork = atomically $ hasWorkToDo' doWork
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = do
|
||||
logError $ "withWork_ error: " <> tshow e
|
||||
@@ -2131,22 +2132,24 @@ withWork_ c doWork getWork action =
|
||||
|
||||
withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
|
||||
withWorkItems c doWork getWork action = do
|
||||
getWork >>= \case
|
||||
Right [] -> noWork
|
||||
noWork >> getWork >>= \case
|
||||
Right [] -> pure ()
|
||||
Right rs -> do
|
||||
let (errs, items) = partitionEithers rs
|
||||
case L.nonEmpty items of
|
||||
Just items' -> action items'
|
||||
Just items' -> hasWork >> action items'
|
||||
Nothing -> do
|
||||
let criticalErr = find isWorkItemError errs
|
||||
forM_ criticalErr $ \err -> do
|
||||
notifyErr (CRITICAL False) err
|
||||
when (all isWorkItemError errs) noWork
|
||||
case find isWorkItemError errs of
|
||||
Nothing -> hasWork
|
||||
Just err -> do
|
||||
notifyErr (CRITICAL False) err
|
||||
unless (all isWorkItemError errs) hasWork
|
||||
forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (\e -> ("", INTERNAL $ show e))
|
||||
Left e
|
||||
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
|
||||
| otherwise -> notifyErr INTERNAL e
|
||||
| isWorkItemError e -> notifyErr (CRITICAL False) e
|
||||
| otherwise -> hasWork >> notifyErr INTERNAL e
|
||||
where
|
||||
hasWork = atomically $ hasWorkToDo' doWork
|
||||
noWork = liftIO $ noWorkToDo doWork
|
||||
notifyErr err e = do
|
||||
logError $ "withWorkItems error: " <> tshow e
|
||||
|
||||
@@ -52,6 +52,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
getConnSubs,
|
||||
getDeletedConns,
|
||||
getConnsData,
|
||||
lockConnForUpdate,
|
||||
setConnDeleted,
|
||||
setConnUserId,
|
||||
setConnAgentVersion,
|
||||
@@ -140,6 +141,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
createRatchet,
|
||||
deleteRatchet,
|
||||
getRatchet,
|
||||
getRatchetForUpdate,
|
||||
getSkippedMsgKeys,
|
||||
updateRatchet,
|
||||
-- Async commands
|
||||
@@ -187,6 +189,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
-- Rcv files
|
||||
createRcvFile,
|
||||
createRcvFileRedirect,
|
||||
lockRcvFileForUpdate,
|
||||
getRcvFile,
|
||||
getRcvFileByEntityId,
|
||||
getRcvFileRedirects,
|
||||
@@ -207,6 +210,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
|
||||
getRcvFilesExpired,
|
||||
-- Snd files
|
||||
createSndFile,
|
||||
lockSndFileForUpdate,
|
||||
getSndFile,
|
||||
getSndFileByEntityId,
|
||||
getNextSndFileToPrepare,
|
||||
@@ -405,7 +409,7 @@ createNewConn db gVar cData cMode = do
|
||||
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
|
||||
updateNewConnRcv db connId rq subMode =
|
||||
getConn db connId $>>= \case
|
||||
getConnForUpdate db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnRcv" $ connType c
|
||||
@@ -415,7 +419,7 @@ updateNewConnRcv db connId rq subMode =
|
||||
|
||||
updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
|
||||
updateNewConnSnd db connId sq =
|
||||
getConn db connId $>>= \case
|
||||
getConnForUpdate db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnSnd" $ connType c
|
||||
where
|
||||
@@ -449,7 +453,11 @@ checkConfirmedSndQueueExists_ db SndQueue {server, sndId} =
|
||||
maybeFirstRow' False fromOnlyBI $
|
||||
DB.query
|
||||
db
|
||||
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
|
||||
( "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
|
||||
#if defined(dpPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(host server, port server, sndId, New)
|
||||
|
||||
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
|
||||
@@ -488,14 +496,14 @@ deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of
|
||||
|
||||
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
|
||||
upgradeRcvConnToDuplex db connId sq =
|
||||
getConn db connId $>>= \case
|
||||
getConnForUpdate db connId $>>= \case
|
||||
(SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType "upgradeRcvConnToDuplex" $ connType c
|
||||
|
||||
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
|
||||
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
|
||||
upgradeSndConnToDuplex db connId rq subMode =
|
||||
getConn db connId >>= \case
|
||||
getConnForUpdate db connId >>= \case
|
||||
Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType "upgradeSndConnToDuplex" $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
@@ -503,7 +511,7 @@ upgradeSndConnToDuplex db connId rq subMode =
|
||||
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
|
||||
addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
|
||||
addConnRcvQueue db connId rq subMode =
|
||||
getConn db connId >>= \case
|
||||
getConnForUpdate db connId >>= \case
|
||||
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnRcvQueue" $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
@@ -515,7 +523,7 @@ addConnRcvQueue_ db connId rq@RcvQueue {server} subMode = do
|
||||
|
||||
addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
|
||||
addConnSndQueue db connId sq =
|
||||
getConn db connId >>= \case
|
||||
getConnForUpdate db connId >>= \case
|
||||
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnSndQueue" $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
@@ -1048,7 +1056,14 @@ setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError
|
||||
setMsgUserAck db connId agentMsgId = runExceptT $ do
|
||||
(dbRcvId, srvMsgId) <-
|
||||
ExceptT . firstRow id (SEMsgNotFound "setMsgUserAck") $
|
||||
DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
|
||||
DB.query
|
||||
db
|
||||
( "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?"
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(connId, agentMsgId)
|
||||
rq <- ExceptT $ getRcvQueueById db connId dbRcvId
|
||||
liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (BI True, connId, agentMsgId)
|
||||
pure (rq, srvMsgId)
|
||||
@@ -1120,6 +1135,9 @@ deleteMsgContent db connId msgId = do
|
||||
|
||||
deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
|
||||
deleteDeliveredSndMsg db connId msgId = do
|
||||
#if defined(dbPostgres)
|
||||
_ :: [Only Int] <- DB.query db "SELECT 1 FROM messages WHERE conn_id = ? AND internal_id = ? FOR UPDATE" (connId, msgId)
|
||||
#endif
|
||||
cnt <- countPendingSndDeliveries_ db connId msgId
|
||||
when (cnt == 0) $ deleteMsg db connId msgId
|
||||
|
||||
@@ -1138,11 +1156,15 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do
|
||||
maybeFirstRow id $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT rcpt_status, snd_message_body_id FROM snd_messages
|
||||
WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0)
|
||||
AND conn_id = ? AND internal_id = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT rcpt_status, snd_message_body_id FROM snd_messages
|
||||
WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0)
|
||||
AND conn_id = ? AND internal_id = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(connId, msgId, connId, msgId)
|
||||
deleteMsgAndBody :: (Maybe MsgReceiptStatus, Maybe Int64) -> IO ()
|
||||
deleteMsgAndBody (rcptStatus_, sndMsgBodyId_) = do
|
||||
@@ -1151,9 +1173,11 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do
|
||||
Just MROk -> deleteMsg
|
||||
_ -> if keepForReceipt then deleteMsgContent else deleteMsg
|
||||
del db connId msgId
|
||||
forM_ sndMsgBodyId_ $ \bodyId ->
|
||||
-- Delete message body if it is not used by any snd message.
|
||||
-- The current snd message is already deleted by deleteMsg or cleared by deleteMsgContent.
|
||||
forM_ sndMsgBodyId_ $ \bodyId -> do
|
||||
#if defined(dbPostgres)
|
||||
-- lock for concurrent deletion of different records in snd_messages pointing to the same record in snd_message_bodies
|
||||
_ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_message_bodies WHERE snd_message_body_id = ? FOR UPDATE" (Only bodyId)
|
||||
#endif
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
@@ -1260,9 +1284,25 @@ deleteRatchet :: DB.Connection -> ConnId -> IO ()
|
||||
deleteRatchet db connId =
|
||||
DB.execute db "DELETE FROM ratchets WHERE conn_id = ?" (Only connId)
|
||||
|
||||
getRatchetForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
|
||||
getRatchetForUpdate =
|
||||
#if defined(dbPostgres)
|
||||
getRatchet_ (ratchetQuery <> " FOR UPDATE")
|
||||
#else
|
||||
getRatchet_ ratchetQuery
|
||||
#endif
|
||||
{-# INLINE getRatchetForUpdate #-}
|
||||
|
||||
getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
|
||||
getRatchet db connId =
|
||||
firstRow' ratchet SERatchetNotFound $ DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId)
|
||||
getRatchet = getRatchet_ ratchetQuery
|
||||
{-# INLINE getRatchet #-}
|
||||
|
||||
ratchetQuery :: Query
|
||||
ratchetQuery = "SELECT ratchet_state FROM ratchets WHERE conn_id = ?"
|
||||
|
||||
getRatchet_ :: Query -> DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
|
||||
getRatchet_ q db connId =
|
||||
firstRow' ratchet SERatchetNotFound $ DB.query db q (Only connId)
|
||||
where
|
||||
ratchet = maybe (Left SERatchetNotFound) Right . fromOnly
|
||||
|
||||
@@ -1963,13 +2003,15 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
|
||||
|
||||
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
|
||||
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
|
||||
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
|
||||
getServerKeyHash_ db newSrv >>= \case
|
||||
Right keyHash_ -> pure keyHash_
|
||||
Left _ -> insertNewServer_ $> Nothing
|
||||
createServer_ db newSrv@ProtocolServer {host, port, keyHash} = do
|
||||
r <- insertNewServer_
|
||||
if null r
|
||||
then getServerKeyHash_ db newSrv >>= either E.throwIO pure
|
||||
else pure Nothing
|
||||
where
|
||||
insertNewServer_ :: IO [Only Int]
|
||||
insertNewServer_ =
|
||||
DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
|
||||
DB.query db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?) ON CONFLICT (host, port) DO NOTHING RETURNING 1" (host, port, keyHash)
|
||||
|
||||
-- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist.
|
||||
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
|
||||
@@ -2166,23 +2208,27 @@ getConnIds :: DB.Connection -> IO [ConnId]
|
||||
getConnIds db = map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted = 0"
|
||||
|
||||
getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConn = getAnyConn False
|
||||
getConn = getAnyConn False False
|
||||
{-# INLINE getConn #-}
|
||||
|
||||
getConnForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConnForUpdate = getAnyConn False True
|
||||
{-# INLINE getConnForUpdate #-}
|
||||
|
||||
getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getDeletedConn = getAnyConn True
|
||||
getDeletedConn = getAnyConn True False
|
||||
{-# INLINE getDeletedConn #-}
|
||||
|
||||
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getAnyConn :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getAnyConn = getAnyConn_ getRcvQueuesByConnId_ getSndQueuesByConnId_
|
||||
{-# INLINE getAnyConn #-}
|
||||
|
||||
getAnyConn_ ::
|
||||
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) ->
|
||||
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) ->
|
||||
(Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq)))
|
||||
getAnyConn_ getRQs getSQs deleted' db connId =
|
||||
getConnData deleted' db connId >>= \case
|
||||
(Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq)))
|
||||
getAnyConn_ getRQs getSQs deleted' forUpdate db connId =
|
||||
getConnData deleted' forUpdate db connId >>= \case
|
||||
Just (cData, cMode) -> do
|
||||
rQ <- getRQs db connId
|
||||
sQ <- getSQs db connId
|
||||
@@ -2281,28 +2327,39 @@ getAnyConns_ ::
|
||||
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) ->
|
||||
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) ->
|
||||
(Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn' rq sq)])
|
||||
getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' db
|
||||
getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' False db
|
||||
|
||||
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
|
||||
getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db
|
||||
getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False False db
|
||||
|
||||
handleDBError :: E.SomeException -> IO (Either StoreError a)
|
||||
handleDBError = pure . Left . SEInternal . bshow
|
||||
#endif
|
||||
|
||||
getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData deleted' db connId' =
|
||||
getConnData :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
getConnData deleted' forUpdate db connId' =
|
||||
maybeFirstRow rowToConnData $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
|
||||
FROM connections
|
||||
WHERE conn_id = ? AND deleted = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
|
||||
FROM connections
|
||||
WHERE conn_id = ? AND deleted = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> (if forUpdate then " FOR UPDATE" else "")
|
||||
#endif
|
||||
)
|
||||
(connId', BI deleted')
|
||||
|
||||
lockConnForUpdate :: DB.Connection -> ConnId -> IO ()
|
||||
lockConnForUpdate db connId = do
|
||||
#if defined(dbPostgres)
|
||||
_ :: [Only Int] <- DB.query db "SELECT 1 FROM connections WHERE conn_id = ? FOR UPDATE" (Only connId)
|
||||
#endif
|
||||
pure ()
|
||||
|
||||
rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode)
|
||||
rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
|
||||
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
|
||||
@@ -2347,7 +2404,11 @@ checkRatchetKeyHashExists db connId hash =
|
||||
maybeFirstRow' False fromOnlyBI $
|
||||
DB.query
|
||||
db
|
||||
"SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
( "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(connId, Binary hash)
|
||||
|
||||
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
|
||||
@@ -2471,11 +2532,15 @@ retrieveLastIdsAndHashRcv_ dbConn connId = do
|
||||
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
|
||||
DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(Only connId)
|
||||
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
@@ -2542,11 +2607,15 @@ retrieveLastIdsAndHashSnd_ dbConn connId = do
|
||||
firstRow id SEConnNotFound $
|
||||
DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(Only connId)
|
||||
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
|
||||
@@ -2636,19 +2705,19 @@ ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing)
|
||||
ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action)
|
||||
|
||||
createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64
|
||||
createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} =
|
||||
getXFTPServerId_ db newSrv >>= \case
|
||||
Right srvId -> pure srvId
|
||||
Left _ -> insertNewServer_
|
||||
where
|
||||
insertNewServer_ = do
|
||||
DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash)
|
||||
insertedRowId db
|
||||
|
||||
getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64)
|
||||
getXFTPServerId_ db ProtocolServer {host, port, keyHash} = do
|
||||
firstRow fromOnly SEXFTPServerNotFound $
|
||||
DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" (host, port, keyHash)
|
||||
createXFTPServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
Only serverId : _ <-
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (xftp_host, xftp_port, xftp_key_hash)
|
||||
DO UPDATE SET xftp_host = EXCLUDED.xftp_host
|
||||
RETURNING xftp_server_id
|
||||
|]
|
||||
(host, port, keyHash)
|
||||
pure serverId
|
||||
|
||||
createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId)
|
||||
createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file approvedRelays = runExceptT $ do
|
||||
@@ -2728,6 +2797,13 @@ getRcvFileRedirects db rcvFileId = do
|
||||
redirects <- fromOnly <$$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId)
|
||||
fmap catMaybes . forM redirects $ getRcvFile db >=> either (const $ pure Nothing) (pure . Just)
|
||||
|
||||
lockRcvFileForUpdate :: DB.Connection -> DBRcvFileId -> IO ()
|
||||
lockRcvFileForUpdate db rcvFileId = do
|
||||
#if defined(dbPostgres)
|
||||
_ :: [Only Int] <- DB.query db "SELECT 1 FROM rcv_files WHERE rcv_file_id = ? FOR UPDATE" (Only rcvFileId)
|
||||
#endif
|
||||
pure ()
|
||||
|
||||
getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile)
|
||||
getRcvFile db rcvFileId = runExceptT $ do
|
||||
f@RcvFile {rcvFileEntityId, userId, tmpPath} <- ExceptT getFile
|
||||
@@ -2739,11 +2815,15 @@ getRcvFile db rcvFileId = runExceptT $ do
|
||||
firstRow toFile SEFileNotFound $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest
|
||||
FROM rcv_files
|
||||
WHERE rcv_file_id = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest
|
||||
FROM rcv_files
|
||||
WHERE rcv_file_id = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(Only rcvFileId)
|
||||
where
|
||||
toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, BoolInt, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile
|
||||
@@ -3004,6 +3084,13 @@ getSndFileIdByEntityId_ db sndFileEntityId =
|
||||
firstRow fromOnly SEFileNotFound $
|
||||
DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only (Binary sndFileEntityId))
|
||||
|
||||
lockSndFileForUpdate :: DB.Connection -> DBSndFileId -> IO ()
|
||||
lockSndFileForUpdate db sndFileId = do
|
||||
#if defined(dbPostgres)
|
||||
_ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_files WHERE snd_file_id = ? FOR UPDATE" (Only sndFileId)
|
||||
#endif
|
||||
pure ()
|
||||
|
||||
getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile)
|
||||
getSndFile db sndFileId = runExceptT $ do
|
||||
f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} <- ExceptT getFile
|
||||
@@ -3015,11 +3102,15 @@ getSndFile db sndFileId = runExceptT $ do
|
||||
firstRow toFile SEFileNotFound $
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest
|
||||
FROM snd_files
|
||||
WHERE snd_file_id = ?
|
||||
|]
|
||||
( [sql|
|
||||
SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest
|
||||
FROM snd_files
|
||||
WHERE snd_file_id = ?
|
||||
|]
|
||||
#if defined(dbPostgres)
|
||||
<> " FOR UPDATE"
|
||||
#endif
|
||||
)
|
||||
(Only sndFileId)
|
||||
where
|
||||
toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, BoolInt, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.Postgres.DB
|
||||
@@ -12,29 +13,32 @@ module Simplex.Messaging.Agent.Store.Postgres.DB
|
||||
execute,
|
||||
execute_,
|
||||
executeMany,
|
||||
PSQL.query,
|
||||
PSQL.query_,
|
||||
query,
|
||||
query_,
|
||||
blobFieldDecoder,
|
||||
fromTextField_,
|
||||
)
|
||||
where
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad (void)
|
||||
import qualified Data.ByteString as B
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Text (Text)
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word16, Word32)
|
||||
import Database.PostgreSQL.Simple (ResultError (..))
|
||||
import Database.PostgreSQL.Simple (Connection, ResultError (..), SqlError (..), FromRow, ToRow)
|
||||
import qualified Database.PostgreSQL.Simple as PSQL
|
||||
import Database.PostgreSQL.Simple.FromField (Field (..), FieldParser, FromField (..), returnError)
|
||||
import Database.PostgreSQL.Simple.ToField (ToField (..))
|
||||
import Database.PostgreSQL.Simple.TypeInfo.Static (textOid, varcharOid)
|
||||
import Database.PostgreSQL.Simple.Types (Query (..))
|
||||
|
||||
newtype BoolInt = BI {unBI :: Bool}
|
||||
|
||||
type SQLError = PSQL.SqlError
|
||||
type SQLError = SqlError
|
||||
|
||||
instance FromField BoolInt where
|
||||
fromField field dat = BI . (/= (0 :: Int)) <$> fromField field dat
|
||||
@@ -44,18 +48,30 @@ instance ToField BoolInt where
|
||||
toField (BI b) = toField ((if b then 1 else 0) :: Int)
|
||||
{-# INLINE toField #-}
|
||||
|
||||
execute :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> q -> IO ()
|
||||
execute db q qs = void $ PSQL.execute db q qs
|
||||
execute :: ToRow q => PSQL.Connection -> Query -> q -> IO ()
|
||||
execute db q qs = void $ PSQL.execute db q qs `E.catch` addSql q
|
||||
{-# INLINE execute #-}
|
||||
|
||||
execute_ :: PSQL.Connection -> PSQL.Query -> IO ()
|
||||
execute_ db q = void $ PSQL.execute_ db q
|
||||
execute_ :: PSQL.Connection -> Query -> IO ()
|
||||
execute_ db q = void $ PSQL.execute_ db q `E.catch` addSql q
|
||||
{-# INLINE execute_ #-}
|
||||
|
||||
executeMany :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> [q] -> IO ()
|
||||
executeMany db q qs = void $ PSQL.executeMany db q qs
|
||||
executeMany :: ToRow q => PSQL.Connection -> Query -> [q] -> IO ()
|
||||
executeMany db q qs = void $ PSQL.executeMany db q qs `E.catch` addSql q
|
||||
{-# INLINE executeMany #-}
|
||||
|
||||
query :: (ToRow q, FromRow r) => PSQL.Connection -> Query -> q -> IO [r]
|
||||
query db q qs = PSQL.query db q qs `E.catch` addSql q
|
||||
{-# INLINE query #-}
|
||||
|
||||
query_ :: FromRow r => Connection -> Query -> IO [r]
|
||||
query_ db q = PSQL.query_ db q `E.catch` addSql q
|
||||
{-# INLINE query_ #-}
|
||||
|
||||
addSql :: Query -> SqlError -> IO r
|
||||
addSql q e@SqlError {sqlErrorHint = hint} =
|
||||
E.throwIO e {sqlErrorHint = if B.null hint then fromQuery q else hint <> ", " <> fromQuery q}
|
||||
|
||||
-- orphan instances
|
||||
|
||||
-- used in FileSize
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
@@ -219,6 +220,9 @@ pattern SENT msgId = A.SENT msgId Nothing
|
||||
pattern Rcvd :: AgentMsgId -> AEvent 'AEConn
|
||||
pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}]
|
||||
|
||||
pattern Rcvd' :: AgentMsgId -> AgentMsgId -> AEvent 'AEConn
|
||||
pattern Rcvd' aMsgId rcvdMsgId <- RCVD MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} [MsgReceipt {agentMsgId = rcvdMsgId, msgRcptStatus = MROk}]
|
||||
|
||||
pattern INV :: AConnectionRequestUri -> AEvent 'AEConn
|
||||
pattern INV cReq = A.INV cReq Nothing
|
||||
|
||||
@@ -331,8 +335,8 @@ functionalAPITests ps = do
|
||||
describe "Duplex connection - delivery stress test" $ do
|
||||
describe "one way (50)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 50
|
||||
xdescribe "one way (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 1000
|
||||
describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 25
|
||||
xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 500
|
||||
describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 50
|
||||
xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 1000
|
||||
describe "Establishing duplex connection, different PQ settings" $ do
|
||||
testPQMatrix2 ps $ runAgentClientTestPQ False True
|
||||
describe "Establishing duplex connection v2, different Ratchet versions" $
|
||||
@@ -784,36 +788,64 @@ runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId =
|
||||
|
||||
runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do
|
||||
let pqEnc = PQEncryption $ supportPQ pqSupport
|
||||
(aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob
|
||||
let proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
message i = "message " <> bshow i
|
||||
loop a bId mIdVar i = do
|
||||
when (i <= n) $ do
|
||||
mId <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags (message i)
|
||||
liftIO $ mId >= i `shouldBe` True
|
||||
let getEvent = do
|
||||
get a >>= \case
|
||||
("", c, A.SENT _ srv) -> liftIO $ c == bId && srv == proxySrv `shouldBe` True
|
||||
("", c, QCONT) -> do
|
||||
liftIO $ c == bId `shouldBe` True
|
||||
getEvent
|
||||
("", c, Msg' mId pq msg) -> do
|
||||
-- tests that mId increases
|
||||
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
|
||||
liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True
|
||||
ackMessage a bId mId Nothing
|
||||
r -> liftIO $ expectationFailure $ "wrong message: " <> show r
|
||||
getEvent
|
||||
amId <- newTVarIO 0
|
||||
bmId <- newTVarIO 0
|
||||
concurrently_
|
||||
(forM_ ([1 .. n * 2] :: [Int64]) $ loop alice bobId amId)
|
||||
(forM_ ([1 .. n * 2] :: [Int64]) $ loop bob aliceId bmId)
|
||||
let n2 = n `div` 2
|
||||
mapConcurrently_ id
|
||||
( [ send alice bobId [1 .. n2],
|
||||
send alice bobId [n2 + 1 .. n],
|
||||
send bob aliceId [1 .. n2],
|
||||
send bob aliceId [n2 + 1 .. n],
|
||||
receive alice bobId amId (n, n, n, 2 * n),
|
||||
receive bob aliceId bmId (n, n, n, 2 * n)
|
||||
] :: [ExceptT AgentErrorType IO ()]
|
||||
)
|
||||
liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob"
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
pqEnc = PQEncryption $ supportPQ pqSupport
|
||||
proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
message i = "message " <> bshow i
|
||||
send :: AgentClient -> ConnId -> [Int64] -> ExceptT AgentErrorType IO ()
|
||||
send a bId = mapM_ $ \i -> void $ A.sendMessage a bId pqEnc SMP.noMsgFlags (message i)
|
||||
receive :: AgentClient -> ConnId -> TVar AgentMsgId -> (Int64, Int64, Int64, Int64) -> ExceptT AgentErrorType IO ()
|
||||
receive a bId mIdVar acc' = loop acc' >> liftIO drain
|
||||
where
|
||||
drain =
|
||||
timeout 50000 (get a)
|
||||
>>= mapM_ (\case ("", _, QCONT) -> drain; r -> expectationFailure $ "unexpected: " <> show r)
|
||||
loop (0, 0, 0, 0) = pure ()
|
||||
loop acc@(!s, !m, !r, !o) =
|
||||
timeout 3000000 (get a) >>= \case
|
||||
Nothing -> error $ "timeout " <> show acc
|
||||
Just evt -> case evt of
|
||||
("", c, A.SENT mId srv) -> do
|
||||
liftIO $ c == bId && srv == proxySrv `shouldBe` True
|
||||
unless (s > 0) $ error "unexpected SENT"
|
||||
loop (s - 1, m, r, o)
|
||||
("", c, QCONT) -> do
|
||||
liftIO $ c == bId `shouldBe` True
|
||||
loop (s, m, r, o)
|
||||
("", c, Msg' mId pq msg) -> do
|
||||
-- tests that mId increases
|
||||
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
|
||||
liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True
|
||||
ackMessageAsync a "123" bId mId (Just "")
|
||||
unless (m > 0) $ error "unexpected MSG"
|
||||
loop (s, m - 1, r, o)
|
||||
("", c, Rcvd' mId rcvdMsgId) -> do
|
||||
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
|
||||
liftIO $ c == bId `shouldBe` True
|
||||
ackMessageAsync a "123" bId mId Nothing
|
||||
unless (r > 0) $ error "unexpected RCVD"
|
||||
loop (s, m, r - 1, o)
|
||||
("123", c, OK) -> do
|
||||
liftIO $ c == bId `shouldBe` True
|
||||
unless (o > 0) $ error "unexpected OK"
|
||||
loop (s, m, r, o - 1)
|
||||
_ -> liftIO $ expectationFailure $ "unexpected: " <> show r
|
||||
|
||||
testEnablePQEncryption :: HasCallStack => IO ()
|
||||
testEnablePQEncryption =
|
||||
@@ -1001,10 +1033,10 @@ noMessages_ :: Bool -> HasCallStack => AgentClient -> String -> Expectation
|
||||
noMessages_ ingoreQCONT c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
50000 `timeout` get c >>= \case
|
||||
Just (_, _, QCONT) | ingoreQCONT -> noMessages_ ingoreQCONT c err
|
||||
Just msg -> error $ err <> ": " <> show msg
|
||||
_ -> return ()
|
||||
Nothing -> return ()
|
||||
|
||||
testRejectContactRequest :: HasCallStack => IO ()
|
||||
testRejectContactRequest =
|
||||
@@ -3713,7 +3745,15 @@ getSMPAgentClient' clientId cfg' initServers dbPath = do
|
||||
|
||||
#if defined(dbPostgres)
|
||||
createStore :: String -> IO (Either MigrationError DBStore)
|
||||
createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) (MigrationConfig MCError Nothing)
|
||||
createStore schema = createAgentStore dbOpts $ MigrationConfig MCError Nothing
|
||||
where
|
||||
dbOpts =
|
||||
DBOpts
|
||||
{ connstr = testDBConnstr,
|
||||
schema = B.pack schema,
|
||||
poolSize = 10,
|
||||
createSchema = True
|
||||
}
|
||||
|
||||
insertUser :: DBStore -> IO ()
|
||||
insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES")
|
||||
|
||||
@@ -87,7 +87,7 @@ ntfTestStoreDBOpts =
|
||||
DBOpts
|
||||
{ connstr = ntfTestServerDBConnstr,
|
||||
schema = "ntf_server",
|
||||
poolSize = 3,
|
||||
poolSize = 10,
|
||||
createSchema = True
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ testStoreDBOpts =
|
||||
DBOpts
|
||||
{ connstr = testServerDBConnstr,
|
||||
schema = "smp_server",
|
||||
poolSize = 3,
|
||||
poolSize = 10,
|
||||
createSchema = True
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user