agent: batch loading of connections with PostgreSQL client (#1639)

* agent: batch loading of connections with PostgreSQL client

* batch more

* optimize getPendingCommandServers

* fix Bool conversion

* enable all tests

* cleanup
This commit is contained in:
Evgeny
2025-09-16 14:28:06 +01:00
committed by GitHub
parent b020a08ea0
commit 7b7616ce7e
5 changed files with 127 additions and 59 deletions

View File

@@ -139,7 +139,7 @@ import Control.Monad.Reader
import Control.Monad.Trans.Except
import Crypto.Random (ChaChaDRG)
import qualified Data.Aeson as J
import Data.Bifunctor (bimap, first, second)
import Data.Bifunctor (bimap, first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::), (.::.))
@@ -1269,7 +1269,7 @@ subscribeConnections' c connIds = do
errs' = M.map (Left . storeError) errs
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
resumeDelivery cs
lift $ resumeConnCmds c $ M.keys cs
resumeConnCmds c $ M.keys cs
rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs)
rcvRs' <- storeClientServiceAssocs rcvRs
ns <- asks ntfSupervisor
@@ -1473,10 +1473,10 @@ resumeSrvCmds :: AgentClient -> ConnId -> Maybe SMPServer -> AM' ()
-- Calls getAsyncCmdWorker for the given client, connection and optional server,
-- passing False as its first (hasWork) argument, and discards the returned Worker.
resumeSrvCmds = void .:. getAsyncCmdWorker False
{-# INLINE resumeSrvCmds #-}
-- Loads the servers of pending commands for all given connections with one
-- getPendingCommandServers store call (via withStore'), then calls resumeSrvCmds
-- for every (connection, server) pair that was returned.
-- Connections that are not in the returned list are not touched.
resumeConnCmds :: AgentClient -> [ConnId] -> AM' ()
resumeConnCmds :: AgentClient -> [ConnId] -> AM ()
resumeConnCmds c connIds = do
connSrvs <- rights . zipWith (second . (,)) connIds <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds)
mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs
connSrvs <- withStore' c (`getPendingCommandServers` connIds)
lift $ mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs
getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker
getAsyncCmdWorker hasWork c connId server =
@@ -2451,23 +2451,23 @@ sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM ()
-- Sends cmd to the notification supervisor queue (ntfSubQ) for the connections
-- returned by getSubscriptions that have enableNtfs set.
-- Nothing is written to ntfSubQ when no such connection exists (L.nonEmpty guard).
-- Connections with no data or with a store error are collected and reported
-- to subQ as a single ERRS event.
sendNtfConnCommands c cmd = do
ns <- asks ntfSupervisor
connIds <- liftIO $ S.toList <$> getSubscriptions c
rs <- lift $ withStoreBatch' c (\db -> map (getConnData db) connIds)
-- one getConnsData store call for all connection IDs; results are zipped back with connIds below
rs <- withStore' c (`getConnsData` connIds)
let (connIds', cErrs) = enabledNtfConns (zip connIds rs)
forM_ (L.nonEmpty connIds') $ \connIds'' ->
atomically $ writeTBQueue (ntfSubQ ns) (cmd, connIds'')
unless (null cErrs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS cErrs)
where
-- Splits the loaded results into IDs of connections with notifications enabled and per-connection errors.
enabledNtfConns :: [(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)])
enabledNtfConns :: [(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)])
enabledNtfConns = foldr addEnabledConn ([], [])
where
addEnabledConn ::
(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode))) ->
(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode))) ->
([ConnId], [(ConnId, AgentErrorType)]) ->
([ConnId], [(ConnId, AgentErrorType)])
addEnabledConn cData_ (cIds, errs) = case cData_ of
(_, Right (Just (ConnData {connId, enableNtfs}, _))) -> if enableNtfs then (connId : cIds, errs) else (cIds, errs)
(connId, Right Nothing) -> (cIds, (connId, INTERNAL "no connection data") : errs)
(connId, Left e) -> (cIds, (connId, e) : errs)
-- a StoreError is converted to an agent error by wrapping its Show output in INTERNAL
(connId, Left e) -> (cIds, (connId, INTERNAL (show e)) : errs)
-- | Atomically replaces the list of notification servers stored in the client's ntfServers TVar.
setNtfServers :: AgentClient -> [NtfServer] -> IO ()
setNtfServers c srvs = atomically $ writeTVar (ntfServers c) srvs

View File

@@ -43,7 +43,7 @@ module Simplex.Messaging.Agent.Store.AgentStore
getDeletedConn,
getConns,
getDeletedConns,
getConnData,
getConnsData,
setConnDeleted,
setConnUserId,
setConnAgentVersion,
@@ -257,8 +257,9 @@ import Data.List (foldl', sortBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.Ord (Down (..))
import qualified Data.Set as S
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
import Data.Word (Word32)
@@ -287,12 +288,14 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Agent.Store.Entity
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, firstRow, firstRow', ifM, maybeFirstRow, maybeFirstRow', tshow, ($>>=), (<$$>))
import Simplex.Messaging.Util
import Simplex.Messaging.Version.Internal
import qualified UnliftIO.Exception as E
import UnliftIO.STM
#if defined(dbPostgres)
import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..))
import Data.List (sortOn)
import Data.Map.Strict (Map)
import Database.PostgreSQL.Simple (In (..), Only (..), Query, SqlError, (:.) (..))
import Database.PostgreSQL.Simple.Errors (constraintViolation)
import Database.PostgreSQL.Simple.SqlQQ (sql)
#else
@@ -427,7 +430,7 @@ deleteConnRecord db connId = DB.execute db "DELETE FROM connections WHERE conn_i
checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} =
maybeFirstRow' False fromOnly $
maybeFirstRow' False fromOnlyBI $
DB.query
db
"SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1"
@@ -1070,7 +1073,7 @@ toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity,
checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
checkRcvMsgHashExists db connId hash =
maybeFirstRow' False fromOnly $
maybeFirstRow' False fromOnlyBI $
DB.query
db
"SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"
@@ -1298,21 +1301,26 @@ insertedRowId db = fromOnly . head <$> DB.query_ db q
q = "SELECT last_insert_rowid()"
#endif
getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer]
getPendingCommandServers db connId = do
-- Batch version: returns, for each of the given connections that has rows in
-- the commands table, the connection ID with the non-empty list of servers of
-- those commands. Connections without commands are absent from the result.
-- The query is run with DB.query_ (no parameters); filtering by the requested IDs
-- is done in connServers using the conns set.
getPendingCommandServers :: DB.Connection -> [ConnId] -> IO [(ConnId, NonEmpty (Maybe SMPServer))]
getPendingCommandServers db connIds =
-- TODO review whether this can break if, e.g., the server has another key hash.
map smpServer
<$> DB.query
-- groupOn' is defined with L.groupBy, so rows of one connection end up in one group
-- only because the query orders rows by conn_id.
mapMaybe connServers . groupOn' rowConnId
<$> DB.query_
db
[sql|
SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash)
SELECT DISTINCT c.conn_id, c.host, c.port, COALESCE(c.server_key_hash, s.key_hash)
FROM commands c
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
WHERE conn_id = ?
ORDER BY c.conn_id
|]
(Only connId)
where
-- connection ID is the first column of each row
rowConnId (Only connId :. _) = connId
-- keeps a group only if its connection ID was requested
connServers rs =
let connId = rowConnId $ L.head rs
srvs = L.map (\(_ :. r) -> smpServer r) rs
in if connId `S.member` conns then Just (connId, srvs) else Nothing
-- presumably Nothing when host, port or key hash is NULL (the fields are combined applicatively) - confirm column types
smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash
conns = S.fromList connIds
getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand))
getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed
@@ -2030,21 +2038,19 @@ getDeletedConn = getAnyConn True
{-# INLINE getDeletedConn #-}
-- Loads one connection whose deleted flag equals deleted' (the flag is passed to getConnData),
-- then loads its receive and send queues and builds the SomeConn variant determined by
-- which queues exist and by the connection mode.
-- Returns Left SEConnNotFound when there is no matching record or when the
-- combination of queues and mode matches none of the listed cases.
getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getAnyConn deleted' dbConn connId =
getConnData dbConn connId >>= \case
getAnyConn deleted' db connId =
getConnData deleted' db connId >>= \case
Just (cData, cMode) -> do
rQ <- getRcvQueuesByConnId_ db connId
sQ <- getSndQueuesByConnId_ db connId
pure $ case (rQ, sQ, cMode) of
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
Nothing -> pure $ Left SEConnNotFound
Just (cData@ConnData {deleted}, cMode)
| deleted /= deleted' -> pure $ Left SEConnNotFound
| otherwise -> do
rQ <- getRcvQueuesByConnId_ dbConn connId
sQ <- getSndQueuesByConnId_ dbConn connId
pure $ case (rQ, sQ, cMode) of
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
-- Loads the given connections with getAnyConns_, passing False as the deleted' flag;
-- the result list has one entry per requested connection ID, in the same order.
getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getConns = getAnyConns_ False
@@ -2054,28 +2060,84 @@ getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
-- Same as getConns, but passes True as the deleted' flag to getAnyConns_.
getDeletedConns = getAnyConns_ True
{-# INLINE getDeletedConns #-}
#if defined(dbPostgres)
-- PostgreSQL implementation (inside #if defined(dbPostgres)).
-- Loads connection data, receive queues and send queues for all requested IDs
-- with one query each, then assembles one result per requested ID in the order of connIds.
-- Queues are only queried for the IDs found by getConnsData_ (M.keys cs).
getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db
getAnyConns_ deleted' db connIds = do
cs <- getConnsData_ deleted' db connIds
let connIds' = M.keys cs
rQs :: Map ConnId (NonEmpty RcvQueue) <- getRcvQueuesByConnIds_ connIds'
sQs :: Map ConnId (NonEmpty SndQueue) <- getSndQueuesByConnIds_ connIds'
pure $ map (result cs rQs sQs) connIds
where
handleDBError :: E.SomeException -> IO (Either StoreError SomeConn)
handleDBError = pure . Left . SEInternal . bshow
getRcvQueuesByConnIds_ connIds' =
toQueueMap primaryFirst toRcvQueue
<$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds'))
where
primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} =
compare (Down p) (Down p') <> compare i i'
getSndQueuesByConnIds_ connIds' =
toQueueMap primaryFirst toSndQueue
<$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds'))
where
primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} =
compare (Down p) (Down p') <> compare i i'
-- Converts rows to queues, groups them by connection ID (after sorting by it),
-- and orders the queues of each connection with the primaryFst comparison.
toQueueMap primaryFst toQueue =
M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue
-- Builds the result for one requested ID from the three maps; an ID missing from cs gives SEConnNotFound.
result cs rQs sQs connId = case M.lookup connId cs of
Just (cData, cMode) -> case (M.lookup connId rQs, M.lookup connId sQs, cMode) of
(Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs)
(Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq)
(Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq)
(Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData)
_ -> Left SEConnNotFound
Nothing -> Left SEConnNotFound
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData db connId' =
maybeFirstRow cData $
-- | PostgreSQL implementation: loads data of all requested non-deleted connections
-- with one getConnsData_ call and returns one result per requested ID, in the order of connIds.
-- Every element is Right; an ID that was not found gives Right Nothing.
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
getConnsData db connIds = do
  connsById <- getConnsData_ False db connIds
  pure [Right (M.lookup connId connsById) | connId <- connIds]
-- Loads, in one query, the data of all connections whose conn_id is in connIds and
-- whose deleted column equals deleted', keyed by connection ID.
-- IDs without a matching row are absent from the returned map.
getConnsData_ :: Bool -> DB.Connection -> [ConnId] -> IO (Map ConnId (ConnData, ConnectionMode))
getConnsData_ deleted' db connIds =
M.fromList . map ((\c@(ConnData {connId}, _) -> (connId, c)) . rowToConnData) <$>
DB.query
db
[sql|
SELECT
user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
FROM connections
WHERE conn_id = ?
WHERE conn_id IN ? AND deleted = ?
|]
(Only connId')
where
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
(In connIds, BI deleted')
#else
-- | Implementation for the #else branch of dbPostgres: loads connections one by one with getAnyConn;
-- an exception raised while loading one connection becomes a Left result
-- for that connection only (see handleDBError).
getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]
getAnyConns_ deleted' db = mapM $ \connId -> E.handle handleDBError $ getAnyConn deleted' db connId
-- | Implementation for the #else branch of dbPostgres: loads data of non-deleted connections
-- one by one with getConnData; an exception raised for one connection becomes
-- a Left result for that connection only.
getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))]
getConnsData db = mapM $ \connId -> E.handle handleDBError (Right <$> getConnData False db connId)
-- | Converts any exception to Left (SEInternal ...) carrying the bshow rendering of the exception.
handleDBError :: E.SomeException -> IO (Either StoreError a)
handleDBError e = pure . Left . SEInternal $ bshow e
#endif
-- Loads the data and mode of one connection by its ID.
-- Only a row whose deleted column equals deleted' is returned; otherwise the result is Nothing.
-- The selected columns must stay in the order expected by rowToConnData.
getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
getConnData deleted' db connId' =
maybeFirstRow rowToConnData $
DB.query
db
[sql|
SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
FROM connections
WHERE conn_id = ? AND deleted = ?
|]
(connId', BI deleted')
-- Converts one row of the connections query to ConnData paired with the connection mode.
-- The tuple fields follow the column order of the SELECT in getConnData.
-- enableNtfs is True when the enable_ntfs value is NULL (Nothing).
rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode)
rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) =
(ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO ()
setConnDeleted db waitDelivery connId
@@ -2114,7 +2176,7 @@ addProcessedRatchetKeyHash db connId hash =
checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool
checkRatchetKeyHashExists db connId hash =
maybeFirstRow' False fromOnly $
maybeFirstRow' False fromOnlyBI $
DB.query
db
"SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1"

View File

@@ -54,7 +54,7 @@ import Network.Socket (ServiceName)
import Simplex.Messaging.Agent.Store.AgentStore ()
import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore)
import Simplex.Messaging.Agent.Store.Postgres.Common
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_)
import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_)
import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..))
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
@@ -64,7 +64,6 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat
import Simplex.Messaging.Notifications.Server.Store.Migrations
import Simplex.Messaging.Notifications.Server.Store.Types
import Simplex.Messaging.Notifications.Server.StoreLog
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer)
import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate)
import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_)
@@ -77,6 +76,8 @@ import System.IO (IOMode (..), hFlush, stdout, withFile)
import Text.Hex (decodeHex)
#if !defined(dbPostgres)
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util (eitherToMaybe)
#endif

View File

@@ -219,12 +219,17 @@ firstRow' f e a = (f <=< listToEither e) <$> a
-- Groups adjacent elements that have equal keys under the given key function
-- (groupBy with the eqOn equality); non-adjacent elements with equal keys stay in separate groups.
groupOn :: Eq k => (a -> k) -> [a] -> [[a]]
groupOn = groupBy . eqOn
where
-- it is equivalent to groupBy ((==) `on` f),
-- but it redefines `on` to avoid duplicate computation for most values.
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f`
eqOn f x = let fx = f x in \y -> fx == f y
-- Variant of groupOn that returns every group as a NonEmpty list.
-- Only adjacent elements with equal keys are grouped, so callers that need
-- one group per key must pass input sorted (or otherwise ordered) by that key.
groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a]
groupOn' = L.groupBy . eqOn
-- Equality of two values by their keys: `eqOn f` is equivalent to `(==) `on` f`,
-- so `groupBy (eqOn f)` is equivalent to `groupBy ((==) `on` f)`.
-- Unlike `on`, the key of the first argument is bound once (`fx`) before the
-- second argument is taken, which avoids duplicate computation for most values.
-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn
-- the on2 in that package is specialized here to only use `==` as the function.
eqOn :: Eq k => (a -> k) -> a -> a -> Bool
eqOn f x = let fx = f x in \y -> fx == f y
{-# INLINE eqOn #-}
-- | Groups all elements with equal keys: the input is first sorted by the key,
-- then adjacent elements with equal keys are grouped with groupOn.
groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]]
groupAllOn key xs = groupOn key (sortOn key xs)

View File

@@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do
poolSize = 1,
createSchema = True
}
createDBStore dbOpts migrations confirmMigrations
createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing)
-- | Drops the test schema derived from the given random suffix, using the test DB connection info.
cleanup :: Word32 -> IO ()
cleanup = dropSchema testDBConnectInfo . testSchema