diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 982c3099a..27967bfd6 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -139,7 +139,7 @@ import Control.Monad.Reader import Control.Monad.Trans.Except import Crypto.Random (ChaChaDRG) import qualified Data.Aeson as J -import Data.Bifunctor (bimap, first, second) +import Data.Bifunctor (bimap, first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) @@ -1269,7 +1269,7 @@ subscribeConnections' c connIds = do errs' = M.map (Left . storeError) errs (subRs, rcvQs) = M.mapEither rcvQueueOrResult cs resumeDelivery cs - lift $ resumeConnCmds c $ M.keys cs + resumeConnCmds c $ M.keys cs rcvRs <- lift $ connResults . fst <$> subscribeQueues c (concat $ M.elems rcvQs) rcvRs' <- storeClientServiceAssocs rcvRs ns <- asks ntfSupervisor @@ -1473,10 +1473,10 @@ resumeSrvCmds :: AgentClient -> ConnId -> Maybe SMPServer -> AM' () resumeSrvCmds = void .:. getAsyncCmdWorker False {-# INLINE resumeSrvCmds #-} -resumeConnCmds :: AgentClient -> [ConnId] -> AM' () +resumeConnCmds :: AgentClient -> [ConnId] -> AM () resumeConnCmds c connIds = do - connSrvs <- rights . zipWith (second . (,)) connIds <$> withStoreBatch' c (\db -> fmap (getPendingCommandServers db) connIds) - mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs + connSrvs <- withStore' c (`getPendingCommandServers` connIds) + lift $ mapM_ (\(connId, srvs) -> mapM_ (resumeSrvCmds c connId) srvs) connSrvs getAsyncCmdWorker :: Bool -> AgentClient -> ConnId -> Maybe SMPServer -> AM' Worker getAsyncCmdWorker hasWork c connId server = @@ -2451,23 +2451,23 @@ sendNtfConnCommands :: AgentClient -> NtfSupervisorCommand -> AM () sendNtfConnCommands c cmd = do ns <- asks ntfSupervisor connIds <- liftIO $ S.toList <$> getSubscriptions c - rs <- lift $ withStoreBatch' c (\db -> map (getConnData db) connIds) + rs <- withStore' c (`getConnsData` connIds) let (connIds', cErrs) = enabledNtfConns (zip connIds rs) forM_ (L.nonEmpty connIds') $ \connIds'' -> atomically $ writeTBQueue (ntfSubQ ns) (cmd, connIds'') unless (null cErrs) $ atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone $ ERRS cErrs) where - enabledNtfConns :: [(ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)]) + enabledNtfConns :: [(ConnId, Either StoreError (Maybe (ConnData, ConnectionMode)))] -> ([ConnId], [(ConnId, AgentErrorType)]) enabledNtfConns = foldr addEnabledConn ([], []) where addEnabledConn :: - (ConnId, Either AgentErrorType (Maybe (ConnData, ConnectionMode))) -> + (ConnId, Either StoreError (Maybe (ConnData, ConnectionMode))) -> ([ConnId], [(ConnId, AgentErrorType)]) -> ([ConnId], [(ConnId, AgentErrorType)]) addEnabledConn cData_ (cIds, errs) = case cData_ of (_, Right (Just (ConnData {connId, enableNtfs}, _))) -> if enableNtfs then (connId : cIds, errs) else (cIds, errs) (connId, Right Nothing) -> (cIds, (connId, INTERNAL "no connection data") : errs) - (connId, Left e) -> (cIds, (connId, e) : errs) + (connId, Left e) -> (cIds, (connId, INTERNAL (show e)) : errs) setNtfServers :: AgentClient -> [NtfServer] -> IO () setNtfServers c = atomically . writeTVar (ntfServers c) diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index c958ae710..350d3bfe7 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -43,7 +43,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getDeletedConn, getConns, getDeletedConns, - getConnData, + getConnsData, setConnDeleted, setConnUserId, setConnAgentVersion, @@ -257,8 +257,9 @@ import Data.List (foldl', sortBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Ord (Down (..)) +import qualified Data.Set as S import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Word (Word32) @@ -287,12 +288,14 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Agent.Store.Entity import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, firstRow, firstRow', ifM, maybeFirstRow, maybeFirstRow', tshow, ($>>=), (<$$>)) +import Simplex.Messaging.Util import Simplex.Messaging.Version.Internal import qualified UnliftIO.Exception as E import UnliftIO.STM #if defined(dbPostgres) -import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..)) +import Data.List (sortOn) +import Data.Map.Strict (Map) +import Database.PostgreSQL.Simple (In (..), Only (..), Query, SqlError, (:.) (..)) import Database.PostgreSQL.Simple.Errors (constraintViolation) import Database.PostgreSQL.Simple.SqlQQ (sql) #else @@ -427,7 +430,7 @@ deleteConnRecord db connId = DB.execute db "DELETE FROM connections WHERE conn_i checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = - maybeFirstRow' False fromOnly $ + maybeFirstRow' False fromOnlyBI $ DB.query db "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" @@ -1070,7 +1073,7 @@ toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool checkRcvMsgHashExists db connId hash = - maybeFirstRow' False fromOnly $ + maybeFirstRow' False fromOnlyBI $ DB.query db "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" @@ -1298,21 +1301,26 @@ insertedRowId db = fromOnly . head <$> DB.query_ db q q = "SELECT last_insert_rowid()" #endif -getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer] -getPendingCommandServers db connId = do +getPendingCommandServers :: DB.Connection -> [ConnId] -> IO [(ConnId, NonEmpty (Maybe SMPServer))] +getPendingCommandServers db connIds = -- TODO review whether this can break if, e.g., the server has another key hash. - map smpServer - <$> DB.query + mapMaybe connServers . groupOn' rowConnId + <$> DB.query_ db [sql| - SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) + SELECT DISTINCT c.conn_id, c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) FROM commands c LEFT JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE conn_id = ? + ORDER BY c.conn_id |] - (Only connId) where + rowConnId (Only connId :. _) = connId + connServers rs = + let connId = rowConnId $ L.head rs + srvs = L.map (\(_ :. r) -> smpServer r) rs + in if connId `S.member` conns then Just (connId, srvs) else Nothing smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash + conns = S.fromList connIds getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed @@ -2030,21 +2038,19 @@ getDeletedConn = getAnyConn True {-# INLINE getDeletedConn #-} getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getAnyConn deleted' dbConn connId = - getConnData dbConn connId >>= \case +getAnyConn deleted' db connId = + getConnData deleted' db connId >>= \case + Just (cData, cMode) -> do + rQ <- getRcvQueuesByConnId_ db connId + sQ <- getSndQueuesByConnId_ db connId + pure $ case (rQ, sQ, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound Nothing -> pure $ Left SEConnNotFound - Just (cData@ConnData {deleted}, cMode) - | deleted /= deleted' -> pure $ Left SEConnNotFound - | otherwise -> do - rQ <- getRcvQueuesByConnId_ dbConn connId - sQ <- getSndQueuesByConnId_ dbConn connId - pure $ case (rQ, sQ, cMode) of - (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) - (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) - (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) - (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) - (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) - _ -> Left SEConnNotFound getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getConns = getAnyConns_ False @@ -2054,28 +2060,84 @@ getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] getDeletedConns = getAnyConns_ True {-# INLINE getDeletedConns #-} +#if defined(dbPostgres) getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db +getAnyConns_ deleted' db connIds = do + cs <- getConnsData_ deleted' db connIds + let connIds' = M.keys cs + rQs :: Map ConnId (NonEmpty RcvQueue) <- getRcvQueuesByConnIds_ connIds' + sQs :: Map ConnId (NonEmpty SndQueue) <- getSndQueuesByConnIds_ connIds' + pure $ map (result cs rQs sQs) connIds where - handleDBError :: E.SomeException -> IO (Either StoreError SomeConn) - handleDBError = pure . Left . SEInternal . bshow + getRcvQueuesByConnIds_ connIds' = + toQueueMap primaryFirst toRcvQueue + <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id IN ? AND q.deleted = 0") (Only (In connIds')) + where + primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = + compare (Down p) (Down p') <> compare i i' + getSndQueuesByConnIds_ connIds' = + toQueueMap primaryFirst toSndQueue + <$> DB.query db (sndQueueQuery <> " WHERE q.conn_id IN ?") (Only (In connIds')) + where + primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = + compare (Down p) (Down p') <> compare i i' + toQueueMap primaryFst toQueue = + M.fromList . map (\qs@(q :| _) -> (qConnId q, L.sortBy primaryFst qs)) . groupOn' qConnId . sortOn qConnId . map toQueue + result cs rQs sQs connId = case M.lookup connId cs of + Just (cData, cMode) -> case (M.lookup connId rQs, M.lookup connId sQs, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound + Nothing -> Left SEConnNotFound -getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) -getConnData db connId' = - maybeFirstRow cData $ +getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] +getConnsData db connIds = do + cs <- getConnsData_ False db connIds + pure $ map (Right . (`M.lookup` cs)) connIds + +getConnsData_ :: Bool -> DB.Connection -> [ConnId] -> IO (Map ConnId (ConnData, ConnectionMode)) +getConnsData_ deleted' db connIds = + M.fromList . map ((\c@(ConnData {connId}, _) -> (connId, c)) . rowToConnData) <$> DB.query db [sql| - SELECT - user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support FROM connections - WHERE conn_id = ? + WHERE conn_id IN ? AND deleted = ? |] - (Only connId') - where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) + (In connIds, BI deleted') + +#else +getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db + +getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] +getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db + +handleDBError :: E.SomeException -> IO (Either StoreError a) +handleDBError = pure . Left . SEInternal . bshow +#endif + +getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) +getConnData deleted' db connId' = + maybeFirstRow rowToConnData $ + DB.query + db + [sql| + SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support + FROM connections + WHERE conn_id = ? AND deleted = ? + |] + (connId', BI deleted') + +rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode) +rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () setConnDeleted db waitDelivery connId @@ -2114,7 +2176,7 @@ addProcessedRatchetKeyHash db connId hash = checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool checkRatchetKeyHashExists db connId hash = - maybeFirstRow' False fromOnly $ + maybeFirstRow' False fromOnlyBI $ DB.query db "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index e1ae09cad..78891796f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -54,7 +54,7 @@ import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common -import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Agent.Store.Postgres.DB (fromTextField_) import Simplex.Messaging.Agent.Store.Shared (MigrationConfig (..)) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -64,7 +64,6 @@ import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubDat import Simplex.Messaging.Notifications.Server.Store.Migrations import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Server.StoreLog -import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, ServiceId, pattern SMPServer) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) @@ -77,6 +76,8 @@ import System.IO (IOMode (..), hFlush, stdout, withFile) import Text.Hex (decodeHex) #if !defined(dbPostgres) +import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (eitherToMaybe) #endif diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 1fcee0783..57fb11c21 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -219,12 +219,17 @@ firstRow' f e a = (f <=< listToEither e) <$> a groupOn :: Eq k => (a -> k) -> [a] -> [[a]] groupOn = groupBy . eqOn - where - -- it is equivalent to groupBy ((==) `on` f), - -- but it redefines `on` to avoid duplicate computation for most values. - -- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn - -- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` - eqOn f x = let fx = f x in \y -> fx == f y + +groupOn' :: Eq k => (a -> k) -> [a] -> [NonEmpty a] +groupOn' = L.groupBy . eqOn + +-- it is equivalent to groupBy ((==) `on` f), +-- but it redefines `on` to avoid duplicate computation for most values. +-- source: https://hackage.haskell.org/package/extra-1.7.13/docs/src/Data.List.Extra.html#groupOn +-- the on2 in this package is specialized to only use `==` as the function, `eqOn f` is equivalent to `(==) `on` f` +eqOn :: Eq k => (a -> k) -> a -> a -> Bool +eqOn f x = let fx = f x in \y -> fx == f y +{-# INLINE eqOn #-} groupAllOn :: Ord k => (a -> k) -> [a] -> [[a]] groupAllOn f = groupOn f . sortOn f diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs index e4de45c7a..8245cfd51 100644 --- a/tests/AgentTests/MigrationTests.hs +++ b/tests/AgentTests/MigrationTests.hs @@ -212,7 +212,7 @@ createStore randSuffix migrations confirmMigrations = do poolSize = 1, createSchema = True } - createDBStore dbOpts migrations confirmMigrations + createDBStore dbOpts migrations (MigrationConfig confirmMigrations Nothing) cleanup :: Word32 -> IO () cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix)