agent: lock rows for concurrent queries in PostgreSQL (#1688)

* agent: lock rows for concurrent queries in PostgreSQL

* fix race conditions in workers

* refactor
This commit is contained in:
Evgeny
2026-01-08 11:09:58 +00:00
committed by GitHub
parent 07604a146f
commit 6aadcf1f3f
8 changed files with 296 additions and 136 deletions

View File

@@ -223,6 +223,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
agentXFTPDownloadChunk c userId digest replica chunkSpec
liftIO $ waitUntilForeground c
(entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do
liftIO $ lockRcvFileForUpdate db rcvFileId
liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath
RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId
let rcvd = receivedSize chunks
@@ -413,6 +414,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
(digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath
withStore c $ \db -> do
lockSndFileForUpdate db sndFileId
updateSndFileEncrypted db sndFileId digest chunkSpecsDigests
getSndFile db sndFileId
else pure sndFile
@@ -530,6 +532,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec'
liftIO $ waitUntilForeground c
sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do
lockSndFileForUpdate db sndFileId
updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded
getSndFile db sndFileId
let uploaded = uploadedSize chunks

View File

@@ -1145,7 +1145,8 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport}
case sq_ of
Just sq@SndQueue {e2ePubKey = Just _k} -> do
e2eSndParams <- withStore c $ \db ->
e2eSndParams <- withStore c $ \db -> do
lockConnForUpdate db connId
getSndRatchet db connId v >>= \case
Right r -> pure $ Right $ snd r
Left e -> do
@@ -1159,6 +1160,7 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
sndKey_ = snd <$> invLink_
(q, _) <- lift $ newSndQueue userId "" qInfo sndKey_
withStore c $ \db -> runExceptT $ do
liftIO $ lockConnForUpdate db connId
e2eSndParams <- createRatchet_ db g maxSupported pqSupport e2eRcvParams
sq' <- maybe (ExceptT $ updateNewConnSnd db connId q) pure sq_
pure (cData, sq', e2eSndParams, lnkId_)
@@ -1237,7 +1239,8 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su
AgentConfig {smpClientVRange = vr, smpAgentVRange, e2eEncryptVRange = e2eVR} <- asks config
let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMMessaging}
crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] Nothing
e2eRcvParams <- withStore' c $ \db ->
e2eRcvParams <- withStore' c $ \db -> do
lockConnForUpdate db connId
getRatchetX3dhKeys db connId >>= \case
Right keys -> pure $ CR.mkRcvE2ERatchetParams (maxVersion e2eVR) keys
Left e -> do
@@ -1957,7 +1960,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
withRetryLock2 ri' qLock $ \riState loop -> do
liftIO $ waitWhileSuspended c
liftIO $ waitForUserNetwork c
resp <- tryError $ case msgType of
resp <- tryAllErrors $ case msgType of
AM_CONN_INFO -> sendConfirmation c NRMBackground sq msgBody
AM_CONN_INFO_REPLY -> sendConfirmation c NRMBackground sq msgBody
_ -> case pendingMsgPrepData_ of
@@ -2097,10 +2100,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> AM ()
notifyDelMsgs msgId err expireTs = do
notifyDel msgId $ MERR (unId msgId) err
msgIds_ <- withStore' c $ \db -> getExpiredSndMessages db connId sq expireTs
msgIds_ <- withStore' c $ \db -> do
msgIds_ <- getExpiredSndMessages db connId sq expireTs
forM_ msgIds_ $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
pure msgIds_
forM_ (L.nonEmpty msgIds_) $ \msgIds -> do
notify $ MERRS (L.map unId msgIds) err
withStore' c $ \db -> forM_ msgIds $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
atomically $ incSMPServerStat' c userId server sentExpiredErrs (length msgIds_ + 1)
delMsg :: InternalId -> AM ()
delMsg = delMsgKeep False
@@ -3025,7 +3030,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
throwE e
agentClientMsg :: TVar ChaChaDRG -> ByteString -> AM (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448))
agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do
rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY
liftIO $ lockConnForUpdate db connId
rc <- ExceptT $ getRatchetForUpdate db connId -- ratchet state pre-decryption - required for processing EREADY
(agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
@@ -3260,6 +3266,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
Just sqs' -> do
(sq_@SndQueue {sndPrivateKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing
sq2 <- withStore c $ \db -> do
lockConnForUpdate db connId
liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs
addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
logServer "<--" c srv rId $ "MSG <QADD>:" <> logSecret' srvMsgId <> " " <> logSecret (senderId queueAddress)
@@ -3564,7 +3571,7 @@ agentRatchetEncrypt db cData msg getPaddedLen pqEnc_ currentE2EVersion = do
agentRatchetEncryptHeader :: DB.Connection -> ConnData -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (CR.MsgEncryptKeyX448, Int, PQEncryption)
agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport} getPaddedLen pqEnc_ currentE2EVersion = do
rc <- ExceptT $ getRatchet db connId
rc <- ExceptT $ getRatchetForUpdate db connId
let paddedLen = getPaddedLen v pqSupport
(mek, rc') <- withExceptT (SEAgentError . cryptoError) $ CR.rcEncryptHeader rc pqEnc_ currentE2EVersion
liftIO $ updateRatchet db connId rc' CR.SMDNoChange
@@ -3573,7 +3580,7 @@ agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport}
-- encoded EncAgentMessage -> encoded AgentMessage
agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)
agentRatchetDecrypt g db connId encAgentMsg = do
rc <- ExceptT $ getRatchet db connId
rc <- ExceptT $ getRatchetForUpdate db connId
agentRatchetDecrypt' g db connId rc encAgentMsg
agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)

View File

@@ -2114,16 +2114,17 @@ withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (
withWork c doWork = withWork_ c doWork . withStore' c
{-# INLINE withWork #-}
-- setting doWork flag to "no work" before getWork rather than after prevents race condition when flag is set to "has work" by another thread after getWork call.
withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
withWork_ c doWork getWork action =
getWork >>= \case
Right (Just r) -> action r
Right Nothing -> noWork
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
noWork >> getWork >>= \case
Right (Just r) -> hasWork >> action r
Right Nothing -> pure ()
Left e
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
| otherwise -> notifyErr INTERNAL e
| isWorkItemError e -> notifyErr (CRITICAL False) e -- worker remains stopped here because the next iteration is likely to produce the same result
| otherwise -> hasWork >> notifyErr INTERNAL e
where
hasWork = atomically $ hasWorkToDo' doWork
noWork = liftIO $ noWorkToDo doWork
notifyErr err e = do
logError $ "withWork_ error: " <> tshow e
@@ -2131,22 +2132,24 @@ withWork_ c doWork getWork action =
withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
withWorkItems c doWork getWork action = do
getWork >>= \case
Right [] -> noWork
noWork >> getWork >>= \case
Right [] -> pure ()
Right rs -> do
let (errs, items) = partitionEithers rs
case L.nonEmpty items of
Just items' -> action items'
Just items' -> hasWork >> action items'
Nothing -> do
let criticalErr = find isWorkItemError errs
forM_ criticalErr $ \err -> do
notifyErr (CRITICAL False) err
when (all isWorkItemError errs) noWork
case find isWorkItemError errs of
Nothing -> hasWork
Just err -> do
notifyErr (CRITICAL False) err
unless (all isWorkItemError errs) hasWork
forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (\e -> ("", INTERNAL $ show e))
Left e
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
| otherwise -> notifyErr INTERNAL e
| isWorkItemError e -> notifyErr (CRITICAL False) e
| otherwise -> hasWork >> notifyErr INTERNAL e
where
hasWork = atomically $ hasWorkToDo' doWork
noWork = liftIO $ noWorkToDo doWork
notifyErr err e = do
logError $ "withWorkItems error: " <> tshow e

View File

@@ -52,6 +52,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
getConnSubs,
getDeletedConns,
getConnsData,
lockConnForUpdate,
setConnDeleted,
setConnUserId,
setConnAgentVersion,
@@ -140,6 +141,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
createRatchet,
deleteRatchet,
getRatchet,
getRatchetForUpdate,
getSkippedMsgKeys,
updateRatchet,
-- Async commands
@@ -187,6 +189,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
-- Rcv files
createRcvFile,
createRcvFileRedirect,
lockRcvFileForUpdate,
getRcvFile,
getRcvFileByEntityId,
getRcvFileRedirects,
@@ -207,6 +210,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
getRcvFilesExpired,
-- Snd files
createSndFile,
lockSndFileForUpdate,
getSndFile,
getSndFileByEntityId,
getNextSndFileToPrepare,
@@ -405,7 +409,7 @@ createNewConn db gVar cData cMode = do
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
updateNewConnRcv db connId rq subMode =
getConn db connId $>>= \case
getConnForUpdate db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
(SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnRcv" $ connType c
@@ -415,7 +419,7 @@ updateNewConnRcv db connId rq subMode =
updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
updateNewConnSnd db connId sq =
getConn db connId $>>= \case
getConnForUpdate db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnSnd" $ connType c
where
@@ -449,7 +453,11 @@ checkConfirmedSndQueueExists_ db SndQueue {server, sndId} =
maybeFirstRow' False fromOnlyBI $
DB.query
db
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
( "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
#if defined(dpPostgres)
<> " FOR UPDATE"
#endif
)
(host server, port server, sndId, New)
getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn))
@@ -488,14 +496,14 @@ deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
upgradeRcvConnToDuplex db connId sq =
getConn db connId $>>= \case
getConnForUpdate db connId $>>= \case
(SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq
(SomeConn c _) -> pure . Left . SEBadConnType "upgradeRcvConnToDuplex" $ connType c
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
upgradeSndConnToDuplex db connId rq subMode =
getConn db connId >>= \case
getConnForUpdate db connId >>= \case
Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode
Right (SomeConn c _) -> pure . Left . SEBadConnType "upgradeSndConnToDuplex" $ connType c
_ -> pure $ Left SEConnNotFound
@@ -503,7 +511,7 @@ upgradeSndConnToDuplex db connId rq subMode =
-- TODO [certs rcv] store clientServiceId from NewRcvQueue
addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue)
addConnRcvQueue db connId rq subMode =
getConn db connId >>= \case
getConnForUpdate db connId >>= \case
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode
Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnRcvQueue" $ connType c
_ -> pure $ Left SEConnNotFound
@@ -515,7 +523,7 @@ addConnRcvQueue_ db connId rq@RcvQueue {server} subMode = do
addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
addConnSndQueue db connId sq =
getConn db connId >>= \case
getConnForUpdate db connId >>= \case
Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq
Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnSndQueue" $ connType c
_ -> pure $ Left SEConnNotFound
@@ -1048,7 +1056,14 @@ setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError
setMsgUserAck db connId agentMsgId = runExceptT $ do
(dbRcvId, srvMsgId) <-
ExceptT . firstRow id (SEMsgNotFound "setMsgUserAck") $
DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
DB.query
db
( "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?"
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(connId, agentMsgId)
rq <- ExceptT $ getRcvQueueById db connId dbRcvId
liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (BI True, connId, agentMsgId)
pure (rq, srvMsgId)
@@ -1120,6 +1135,9 @@ deleteMsgContent db connId msgId = do
deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO ()
deleteDeliveredSndMsg db connId msgId = do
#if defined(dbPostgres)
_ :: [Only Int] <- DB.query db "SELECT 1 FROM messages WHERE conn_id = ? AND internal_id = ? FOR UPDATE" (connId, msgId)
#endif
cnt <- countPendingSndDeliveries_ db connId msgId
when (cnt == 0) $ deleteMsg db connId msgId
@@ -1138,11 +1156,15 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do
maybeFirstRow id $
DB.query
db
[sql|
SELECT rcpt_status, snd_message_body_id FROM snd_messages
WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0)
AND conn_id = ? AND internal_id = ?
|]
( [sql|
SELECT rcpt_status, snd_message_body_id FROM snd_messages
WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0)
AND conn_id = ? AND internal_id = ?
|]
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(connId, msgId, connId, msgId)
deleteMsgAndBody :: (Maybe MsgReceiptStatus, Maybe Int64) -> IO ()
deleteMsgAndBody (rcptStatus_, sndMsgBodyId_) = do
@@ -1151,9 +1173,11 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do
Just MROk -> deleteMsg
_ -> if keepForReceipt then deleteMsgContent else deleteMsg
del db connId msgId
forM_ sndMsgBodyId_ $ \bodyId ->
-- Delete message body if it is not used by any snd message.
-- The current snd message is already deleted by deleteMsg or cleared by deleteMsgContent.
forM_ sndMsgBodyId_ $ \bodyId -> do
#if defined(dbPostgres)
-- lock for concurrent deletion of different records in snd_messages pointing to the same record in snd_message_bodies
_ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_message_bodies WHERE snd_message_body_id = ? FOR UPDATE" (Only bodyId)
#endif
DB.execute
db
[sql|
@@ -1260,9 +1284,25 @@ deleteRatchet :: DB.Connection -> ConnId -> IO ()
deleteRatchet db connId =
DB.execute db "DELETE FROM ratchets WHERE conn_id = ?" (Only connId)
getRatchetForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
getRatchetForUpdate =
#if defined(dbPostgres)
getRatchet_ (ratchetQuery <> " FOR UPDATE")
#else
getRatchet_ ratchetQuery
#endif
{-# INLINE getRatchetForUpdate #-}
getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
getRatchet db connId =
firstRow' ratchet SERatchetNotFound $ DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId)
getRatchet = getRatchet_ ratchetQuery
{-# INLINE getRatchet #-}
ratchetQuery :: Query
ratchetQuery = "SELECT ratchet_state FROM ratchets WHERE conn_id = ?"
getRatchet_ :: Query -> DB.Connection -> ConnId -> IO (Either StoreError RatchetX448)
getRatchet_ q db connId =
firstRow' ratchet SERatchetNotFound $ DB.query db q (Only connId)
where
ratchet = maybe (Left SERatchetNotFound) Right . fromOnly
@@ -1963,13 +2003,15 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
getServerKeyHash_ db newSrv >>= \case
Right keyHash_ -> pure keyHash_
Left _ -> insertNewServer_ $> Nothing
createServer_ db newSrv@ProtocolServer {host, port, keyHash} = do
r <- insertNewServer_
if null r
then getServerKeyHash_ db newSrv >>= either E.throwIO pure
else pure Nothing
where
insertNewServer_ :: IO [Only Int]
insertNewServer_ =
DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
DB.query db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?) ON CONFLICT (host, port) DO NOTHING RETURNING 1" (host, port, keyHash)
-- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist.
getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
@@ -2166,23 +2208,27 @@ getConnIds :: DB.Connection -> IO [ConnId]
getConnIds db = map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted = 0"
getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn = getAnyConn False
getConn = getAnyConn False False
{-# INLINE getConn #-}
getConnForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConnForUpdate = getAnyConn False True
{-# INLINE getConnForUpdate #-}
getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getDeletedConn = getAnyConn True
getDeletedConn = getAnyConn True False
{-# INLINE getDeletedConn #-}
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getAnyConn :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getAnyConn = getAnyConn_ getRcvQueuesByConnId_ getSndQueuesByConnId_
{-# INLINE getAnyConn #-}
getAnyConn_ ::
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) ->
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) ->
(Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq)))
getAnyConn_ getRQs getSQs deleted' db connId =
getConnData deleted' db connId >>= \case
(Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq)))
getAnyConn_ getRQs getSQs deleted' forUpdate db connId =
getConnData deleted' forUpdate db connId >>= \case
Just (cData, cMode) -> do
rQ <- getRQs db connId
sQ <- getSQs db connId
@@ -2281,28 +2327,39 @@ getAnyConns_ ::
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) ->
(DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) ->
(Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn' rq sq)])
getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' db
getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' False db
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db
getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False False db
handleDBError :: E.SomeException -> IO (Either StoreError a)
handleDBError = pure . Left . SEInternal . bshow
#endif
getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData deleted' db connId' =
getConnData :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData deleted' forUpdate db connId' =
maybeFirstRow rowToConnData $
DB.query
db
[sql|
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
FROM connections
WHERE conn_id = ? AND deleted = ?
|]
( [sql|
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
FROM connections
WHERE conn_id = ? AND deleted = ?
|]
#if defined(dbPostgres)
<> (if forUpdate then " FOR UPDATE" else "")
#endif
)
(connId', BI deleted')
lockConnForUpdate :: DB.Connection -> ConnId -> IO ()
lockConnForUpdate db connId = do
#if defined(dbPostgres)
_ :: [Only Int] <- DB.query db "SELECT 1 FROM connections WHERE conn_id = ? FOR UPDATE" (Only connId)
#endif
pure ()
rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode)
rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
@@ -2347,7 +2404,11 @@ checkRatchetKeyHashExists db connId hash =
maybeFirstRow' False fromOnlyBI $
DB.query
db
"SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
( "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(connId, Binary hash)
deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO ()
@@ -2471,11 +2532,15 @@ retrieveLastIdsAndHashRcv_ dbConn connId = do
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
DB.query
dbConn
[sql|
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
FROM connections
WHERE conn_id = ?
|]
( [sql|
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
FROM connections
WHERE conn_id = ?
|]
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(Only connId)
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
@@ -2542,11 +2607,15 @@ retrieveLastIdsAndHashSnd_ dbConn connId = do
firstRow id SEConnNotFound $
DB.query
dbConn
[sql|
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
FROM connections
WHERE conn_id = ?
|]
( [sql|
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
FROM connections
WHERE conn_id = ?
|]
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(Only connId)
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
@@ -2636,19 +2705,19 @@ ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing)
ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action)
createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64
createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} =
getXFTPServerId_ db newSrv >>= \case
Right srvId -> pure srvId
Left _ -> insertNewServer_
where
insertNewServer_ = do
DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash)
insertedRowId db
getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64)
getXFTPServerId_ db ProtocolServer {host, port, keyHash} = do
firstRow fromOnly SEXFTPServerNotFound $
DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" (host, port, keyHash)
createXFTPServer_ db ProtocolServer {host, port, keyHash} = do
Only serverId : _ <-
DB.query
db
[sql|
INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash)
VALUES (?, ?, ?)
ON CONFLICT (xftp_host, xftp_port, xftp_key_hash)
DO UPDATE SET xftp_host = EXCLUDED.xftp_host
RETURNING xftp_server_id
|]
(host, port, keyHash)
pure serverId
createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId)
createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file approvedRelays = runExceptT $ do
@@ -2728,6 +2797,13 @@ getRcvFileRedirects db rcvFileId = do
redirects <- fromOnly <$$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId)
fmap catMaybes . forM redirects $ getRcvFile db >=> either (const $ pure Nothing) (pure . Just)
lockRcvFileForUpdate :: DB.Connection -> DBRcvFileId -> IO ()
lockRcvFileForUpdate db rcvFileId = do
#if defined(dbPostgres)
_ :: [Only Int] <- DB.query db "SELECT 1 FROM rcv_files WHERE rcv_file_id = ? FOR UPDATE" (Only rcvFileId)
#endif
pure ()
getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile)
getRcvFile db rcvFileId = runExceptT $ do
f@RcvFile {rcvFileEntityId, userId, tmpPath} <- ExceptT getFile
@@ -2739,11 +2815,15 @@ getRcvFile db rcvFileId = runExceptT $ do
firstRow toFile SEFileNotFound $
DB.query
db
[sql|
SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest
FROM rcv_files
WHERE rcv_file_id = ?
|]
( [sql|
SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest
FROM rcv_files
WHERE rcv_file_id = ?
|]
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(Only rcvFileId)
where
toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, BoolInt, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile
@@ -3004,6 +3084,13 @@ getSndFileIdByEntityId_ db sndFileEntityId =
firstRow fromOnly SEFileNotFound $
DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only (Binary sndFileEntityId))
lockSndFileForUpdate :: DB.Connection -> DBSndFileId -> IO ()
lockSndFileForUpdate db sndFileId = do
#if defined(dbPostgres)
_ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_files WHERE snd_file_id = ? FOR UPDATE" (Only sndFileId)
#endif
pure ()
getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile)
getSndFile db sndFileId = runExceptT $ do
f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} <- ExceptT getFile
@@ -3015,11 +3102,15 @@ getSndFile db sndFileId = runExceptT $ do
firstRow toFile SEFileNotFound $
DB.query
db
[sql|
SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest
FROM snd_files
WHERE snd_file_id = ?
|]
( [sql|
SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest
FROM snd_files
WHERE snd_file_id = ?
|]
#if defined(dbPostgres)
<> " FOR UPDATE"
#endif
)
(Only sndFileId)
where
toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, BoolInt, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Agent.Store.Postgres.DB
@@ -12,29 +13,32 @@ module Simplex.Messaging.Agent.Store.Postgres.DB
execute,
execute_,
executeMany,
PSQL.query,
PSQL.query_,
query,
query_,
blobFieldDecoder,
fromTextField_,
)
where
import qualified Control.Exception as E
import Control.Monad (void)
import qualified Data.ByteString as B
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import Data.Typeable (Typeable)
import Data.Word (Word16, Word32)
import Database.PostgreSQL.Simple (ResultError (..))
import Database.PostgreSQL.Simple (Connection, ResultError (..), SqlError (..), FromRow, ToRow)
import qualified Database.PostgreSQL.Simple as PSQL
import Database.PostgreSQL.Simple.FromField (Field (..), FieldParser, FromField (..), returnError)
import Database.PostgreSQL.Simple.ToField (ToField (..))
import Database.PostgreSQL.Simple.TypeInfo.Static (textOid, varcharOid)
import Database.PostgreSQL.Simple.Types (Query (..))
newtype BoolInt = BI {unBI :: Bool}
type SQLError = PSQL.SqlError
type SQLError = SqlError
instance FromField BoolInt where
fromField field dat = BI . (/= (0 :: Int)) <$> fromField field dat
@@ -44,18 +48,30 @@ instance ToField BoolInt where
toField (BI b) = toField ((if b then 1 else 0) :: Int)
{-# INLINE toField #-}
execute :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> q -> IO ()
execute db q qs = void $ PSQL.execute db q qs
execute :: ToRow q => PSQL.Connection -> Query -> q -> IO ()
execute db q qs = void $ PSQL.execute db q qs `E.catch` addSql q
{-# INLINE execute #-}
execute_ :: PSQL.Connection -> PSQL.Query -> IO ()
execute_ db q = void $ PSQL.execute_ db q
execute_ :: PSQL.Connection -> Query -> IO ()
execute_ db q = void $ PSQL.execute_ db q `E.catch` addSql q
{-# INLINE execute_ #-}
executeMany :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> [q] -> IO ()
executeMany db q qs = void $ PSQL.executeMany db q qs
executeMany :: ToRow q => PSQL.Connection -> Query -> [q] -> IO ()
executeMany db q qs = void $ PSQL.executeMany db q qs `E.catch` addSql q
{-# INLINE executeMany #-}
query :: (ToRow q, FromRow r) => PSQL.Connection -> Query -> q -> IO [r]
query db q qs = PSQL.query db q qs `E.catch` addSql q
{-# INLINE query #-}
query_ :: FromRow r => Connection -> Query -> IO [r]
query_ db q = PSQL.query_ db q `E.catch` addSql q
{-# INLINE query_ #-}
addSql :: Query -> SqlError -> IO r
addSql q e@SqlError {sqlErrorHint = hint} =
E.throwIO e {sqlErrorHint = if B.null hint then fromQuery q else hint <> ", " <> fromQuery q}
-- orphan instances
-- used in FileSize

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
@@ -219,6 +220,9 @@ pattern SENT msgId = A.SENT msgId Nothing
pattern Rcvd :: AgentMsgId -> AEvent 'AEConn
pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}]
pattern Rcvd' :: AgentMsgId -> AgentMsgId -> AEvent 'AEConn
pattern Rcvd' aMsgId rcvdMsgId <- RCVD MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} [MsgReceipt {agentMsgId = rcvdMsgId, msgRcptStatus = MROk}]
pattern INV :: AConnectionRequestUri -> AEvent 'AEConn
pattern INV cReq = A.INV cReq Nothing
@@ -331,8 +335,8 @@ functionalAPITests ps = do
describe "Duplex connection - delivery stress test" $ do
describe "one way (50)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 50
xdescribe "one way (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 1000
describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 25
xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 500
describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 50
xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 1000
describe "Establishing duplex connection, different PQ settings" $ do
testPQMatrix2 ps $ runAgentClientTestPQ False True
describe "Establishing duplex connection v2, different Ratchet versions" $
@@ -784,36 +788,64 @@ runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId =
runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do
let pqEnc = PQEncryption $ supportPQ pqSupport
(aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob
let proxySrv = if viaProxy then Just testSMPServer else Nothing
message i = "message " <> bshow i
loop a bId mIdVar i = do
when (i <= n) $ do
mId <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags (message i)
liftIO $ mId >= i `shouldBe` True
let getEvent = do
get a >>= \case
("", c, A.SENT _ srv) -> liftIO $ c == bId && srv == proxySrv `shouldBe` True
("", c, QCONT) -> do
liftIO $ c == bId `shouldBe` True
getEvent
("", c, Msg' mId pq msg) -> do
-- tests that mId increases
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True
ackMessage a bId mId Nothing
r -> liftIO $ expectationFailure $ "wrong message: " <> show r
getEvent
amId <- newTVarIO 0
bmId <- newTVarIO 0
concurrently_
(forM_ ([1 .. n * 2] :: [Int64]) $ loop alice bobId amId)
(forM_ ([1 .. n * 2] :: [Int64]) $ loop bob aliceId bmId)
let n2 = n `div` 2
mapConcurrently_ id
( [ send alice bobId [1 .. n2],
send alice bobId [n2 + 1 .. n],
send bob aliceId [1 .. n2],
send bob aliceId [n2 + 1 .. n],
receive alice bobId amId (n, n, n, 2 * n),
receive bob aliceId bmId (n, n, n, 2 * n)
] :: [ExceptT AgentErrorType IO ()]
)
liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice"
liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob"
where
msgId = subtract baseId . fst
pqEnc = PQEncryption $ supportPQ pqSupport
proxySrv = if viaProxy then Just testSMPServer else Nothing
message i = "message " <> bshow i
send :: AgentClient -> ConnId -> [Int64] -> ExceptT AgentErrorType IO ()
send a bId = mapM_ $ \i -> void $ A.sendMessage a bId pqEnc SMP.noMsgFlags (message i)
receive :: AgentClient -> ConnId -> TVar AgentMsgId -> (Int64, Int64, Int64, Int64) -> ExceptT AgentErrorType IO ()
receive a bId mIdVar acc' = loop acc' >> liftIO drain
where
drain =
timeout 50000 (get a)
>>= mapM_ (\case ("", _, QCONT) -> drain; r -> expectationFailure $ "unexpected: " <> show r)
loop (0, 0, 0, 0) = pure ()
loop acc@(!s, !m, !r, !o) =
timeout 3000000 (get a) >>= \case
Nothing -> error $ "timeout " <> show acc
Just evt -> case evt of
("", c, A.SENT mId srv) -> do
liftIO $ c == bId && srv == proxySrv `shouldBe` True
unless (s > 0) $ error "unexpected SENT"
loop (s - 1, m, r, o)
("", c, QCONT) -> do
liftIO $ c == bId `shouldBe` True
loop (s, m, r, o)
("", c, Msg' mId pq msg) -> do
-- tests that mId increases
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True
ackMessageAsync a "123" bId mId (Just "")
unless (m > 0) $ error "unexpected MSG"
loop (s, m - 1, r, o)
("", c, Rcvd' mId rcvdMsgId) -> do
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
liftIO $ c == bId `shouldBe` True
ackMessageAsync a "123" bId mId Nothing
unless (r > 0) $ error "unexpected RCVD"
loop (s, m, r - 1, o)
("123", c, OK) -> do
liftIO $ c == bId `shouldBe` True
unless (o > 0) $ error "unexpected OK"
loop (s, m, r, o - 1)
_ -> liftIO $ expectationFailure $ "unexpected: " <> show r
testEnablePQEncryption :: HasCallStack => IO ()
testEnablePQEncryption =
@@ -1001,10 +1033,10 @@ noMessages_ :: Bool -> HasCallStack => AgentClient -> String -> Expectation
noMessages_ ingoreQCONT c err = tryGet `shouldReturn` ()
where
tryGet =
10000 `timeout` get c >>= \case
50000 `timeout` get c >>= \case
Just (_, _, QCONT) | ingoreQCONT -> noMessages_ ingoreQCONT c err
Just msg -> error $ err <> ": " <> show msg
_ -> return ()
Nothing -> return ()
testRejectContactRequest :: HasCallStack => IO ()
testRejectContactRequest =
@@ -3713,7 +3745,15 @@ getSMPAgentClient' clientId cfg' initServers dbPath = do
#if defined(dbPostgres)
createStore :: String -> IO (Either MigrationError DBStore)
createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) (MigrationConfig MCError Nothing)
createStore schema = createAgentStore dbOpts $ MigrationConfig MCError Nothing
where
dbOpts =
DBOpts
{ connstr = testDBConnstr,
schema = B.pack schema,
poolSize = 10,
createSchema = True
}
insertUser :: DBStore -> IO ()
insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES")

View File

@@ -87,7 +87,7 @@ ntfTestStoreDBOpts =
DBOpts
{ connstr = ntfTestServerDBConnstr,
schema = "ntf_server",
poolSize = 3,
poolSize = 10,
createSchema = True
}

View File

@@ -93,7 +93,7 @@ testStoreDBOpts =
DBOpts
{ connstr = testServerDBConnstr,
schema = "smp_server",
poolSize = 3,
poolSize = 10,
createSchema = True
}