From 8c250ebe19f56dd7d53572d984e8016cb0e4d658 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Tue, 19 Dec 2023 23:01:34 +0000 Subject: [PATCH] agent: batch sending messages (#922) * agent: batch sending messages (attempt 4) * handle errors in batch sending * batch attempt 5 (#923) * attempt 5 * remove IORefs * add liftA2 for 8.10 compat * remove db-related zipping * traversable --------- Co-authored-by: IC Rainbow * s/mapE/bindRight/ * name Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> * comment Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> * remove unused funcs --------- Co-authored-by: IC Rainbow Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> --- package.yaml | 3 +- simplexmq.cabal | 21 ++--- src/Simplex/Messaging/Agent.hs | 116 ++++++++++++++++++-------- src/Simplex/Messaging/Agent/Client.hs | 26 +++++- src/Simplex/Messaging/Agent/Lock.hs | 31 ++++++- src/Simplex/Messaging/Server/Stats.hs | 2 +- src/Simplex/Messaging/Util.hs | 12 +++ 7 files changed, 154 insertions(+), 57 deletions(-) diff --git a/package.yaml b/package.yaml index 7e53371d4..7134f6909 100644 --- a/package.yaml +++ b/package.yaml @@ -65,8 +65,7 @@ dependencies: - sqlcipher-simple == 0.4.* - stm == 2.5.* - temporary == 1.3.* - - time == 1.9.* - - time-compat == 1.9.* + - time == 1.12.* - time-manager == 0.0.* - tls >= 1.7.0 && < 1.8 - transformers == 0.6.* diff --git a/simplexmq.cabal b/simplexmq.cabal index 118021ce5..f240b8a59 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -212,8 +212,7 @@ library , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -285,8 +284,7 @@ executable ntf-server , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -358,8 +356,7 @@ executable smp-agent , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -431,8 +428,7 @@ executable smp-server , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -504,8 +500,7 @@ executable xftp , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -577,8 +572,7 @@ executable xftp-server , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , tls >=1.7.0 && <1.8 , transformers ==0.6.* @@ -687,8 +681,7 @@ test-suite simplexmq-test , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* - , time ==1.9.* - , time-compat ==1.9.* + , time ==1.12.* , time-manager ==0.0.* , timeit ==2.0.* , tls >=1.7.0 && <1.8 diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 336d82fbf..675509fb4 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -64,6 +64,7 @@ module Simplex.Messaging.Agent resubscribeConnection, resubscribeConnections, sendMessage, + sendMessages, ackMessage, switchConnection, abortConnectionSwitch, @@ -119,14 +120,16 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Composition ((.:), (.:.), (.::), (.::.)) -import Data.Foldable (foldl') +import Data.Either (rights) +import Data.Foldable (foldl', toList) import Data.Functor (($>)) +import Data.Functor.Identity import Data.List (find) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Time.Clock @@ -277,6 +280,12 @@ resubscribeConnections c = withAgentEnv c . resubscribeConnections' c sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId sendMessage c = withAgentEnv c .:. sendMessage' c +type MsgReq = (ConnId, MsgFlags, MsgBody) + +-- | Send multiple messages to different connections (SEND command) +sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] +sendMessages c = withAgentEnv c . sendMessages' c + ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () ackMessage c = withAgentEnv c .:. ackMessage' c @@ -867,17 +876,29 @@ getNotificationMessage' c nonce encNtfInfo = do -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do - SomeConn _ conn <- withStore c (`getConn` connId) - case conn of - DuplexConnection cData _ sqs -> enqueueMsgs cData sqs - SndConnection cData sq -> enqueueMsgs cData [sq] - _ -> throwError $ CONN SIMPLEX +sendMessage' c connId msgFlags msg = liftEither . runIdentity =<< sendMessagesB c (Identity (Right (connId, msgFlags, msg))) + +-- | Send multiple messages to different connections (SEND command) in Reader monad +sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] +sendMessages' c = sendMessagesB c . map Right + +sendMessagesB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) +sendMessagesB c reqs = withConnLocks c connIds "sendMessages" $ do + reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + let reqs'' = fmap (>>= prepareConn) reqs' + enqueueMessagesB c reqs'' where - enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId - enqueueMsgs cData sqs = do - when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED - enqueueMessages c cData sqs msgFlags $ A_MSG msg + prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) + prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of + DuplexConnection cData _ sqs -> prepareMsg cData sqs + SndConnection cData sq -> prepareMsg cData [sq] + _ -> Left $ CONN SIMPLEX + where + prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) + prepareMsg cData sqs + | ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED + | otherwise = Right (cData, sqs, msgFlags, A_MSG msg) + connIds = map (\(connId, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v @@ -1056,22 +1077,37 @@ enqueueMessages c cData sqs msgFlags aMessage = do enqueueMessages' c cData sqs msgFlags aMessage enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessages' c cData (sq :| sqs) msgFlags aMessage = do - msgId <- enqueueMessage c cData sq msgFlags aMessage - mapM_ (enqueueSavedMessage c cData msgId) $ - filter (\SndQueue {status} -> status == Secured || status == Active) sqs - pure msgId +enqueueMessages' c cData sqs msgFlags aMessage = + liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, msgFlags, aMessage))) + +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType AgentMsgId)) +enqueueMessagesB c reqs = do + reqs' <- enqueueMessageB c reqs + enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' + pure $ fst <$$> reqs' + +isActiveSndQ :: SndQueue -> Bool +isActiveSndQ SndQueue {status} = status == Secured || status == Active enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId -enqueueMessage c cData@ConnData {connId} sq msgFlags aMessage = do - resumeMsgDelivery c cData sq - aVRange <- asks $ smpAgentVRange . config - msgId <- storeSentMsg $ maxVersion aVRange - queuePendingMsgs c sq [msgId] - pure $ unId msgId +enqueueMessage c cData sq msgFlags aMessage = + liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], msgFlags, aMessage))) + +-- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB c reqs = do + void . forME reqs $ \(cData, sq :| _, _, _) -> + runExceptT $ resumeMsgDelivery c cData sq + aVRange <- asks $ maxVersion . smpAgentVRange . config + reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs + forME reqMids $ \((cData, sq :| sqs, _, _), mId) -> do + let InternalId msgId = mId + queuePendingMsgs c sq [mId] + let sqs' = filter isActiveSndQ sqs + pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: Version -> m InternalId - storeSentMsg agentVersion = withStore c $ \db -> runExceptT $ do + storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) + storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash @@ -1084,14 +1120,25 @@ enqueueMessage c cData@ConnData {connId} sq msgFlags aMessage = do msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId - pure internalId + pure (req, internalId) -enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () -enqueueSavedMessage c cData@ConnData {connId} msgId sq = do - resumeMsgDelivery c cData sq - let mId = InternalId msgId - queuePendingMsgs c sq [mId] - withStore' c $ \db -> createSndMsgDelivery db connId sq mId +enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () +enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) + +enqueueSavedMessageB :: (AgentMonad' m, Foldable t) => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> m () +enqueueSavedMessageB c reqs = do + -- saving to the database is in the start to avoid race conditions when delivery is read from queue before it is saved + void $ withStoreBatch' c $ \db -> concatMap (storeDeliveries db) reqs + forM_ reqs $ \(cData, sqs, msgId) -> + forM sqs $ \sq -> do + void . runExceptT $ resumeMsgDelivery c cData sq + let mId = InternalId msgId + queuePendingMsgs c sq [mId] + where + storeDeliveries :: DB.Connection -> (ConnData, [SndQueue], AgentMsgId) -> [IO ()] + storeDeliveries db (ConnData {connId}, sqs, msgId) = do + let mId = InternalId msgId + in map (\sq -> createSndMsgDelivery db connId sq mId) sqs resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m () resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do @@ -1885,7 +1932,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s conn cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of - SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> + SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> handleNotifyAck $ do msg' <- decryptSMPMessage v rq msg handleNotifyAck $ case msg' of @@ -2434,8 +2481,7 @@ storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ agentM enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do msgId <- enqueueRatchetKey c cData sq e2eEncryption - mapM_ (enqueueSavedMessage c cData msgId) $ - filter (\SndQueue {status} -> status == Secured || status == Active) sqs + mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs pure msgId enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 18eb3d642..0ac65d466 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -24,6 +24,7 @@ module Simplex.Messaging.Agent.Client ProtocolTestStep (..), newAgentClient, withConnLock, + withConnLocks, withInvLock, closeAgentClient, closeProtocolServerClients, @@ -99,6 +100,8 @@ module Simplex.Messaging.Agent.Client withStore', withStoreCtx, withStoreCtx', + withStoreBatch, + withStoreBatch', storeError, userServers, pickServer, @@ -658,8 +661,17 @@ withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a withInvLock AgentClient {invLocks} = withLockMap_ invLocks +withConnLocks :: MonadUnliftIO m => AgentClient -> [ConnId] -> String -> m a -> m a +withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null) + withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a -withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure +withLockMap_ = withGetLock . getMapLock + +withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a +withLocksMap_ = withGetLocks . getMapLock + +getMapLock :: Ord k => TMap k Lock -> k -> STM Lock +getMapLock locks key = TM.lookup key locks >>= maybe newLock pure where newLock = createLock >>= \l -> TM.insert key l locks $> l @@ -1291,6 +1303,18 @@ withStoreCtx_ ctx_ c action = do handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr +withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a)) +withStoreBatch c actions = do + st <- asks store + liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $ + withTransaction st $ mapM (`E.catch` handleInternal) . actions + where + handleInternal :: E.SomeException -> IO (Either AgentErrorType a) + handleInternal = pure . Left . INTERNAL . show + +withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a)) +withStoreBatch' c actions = withStoreBatch c $ \db -> fmap Right <$> actions db + storeError :: StoreError -> AgentErrorType storeError = \case SEConnNotFound -> CONN NOT_FOUND diff --git a/src/Simplex/Messaging/Agent/Lock.hs b/src/Simplex/Messaging/Agent/Lock.hs index 10062495d..37b63eb0e 100644 --- a/src/Simplex/Messaging/Agent/Lock.hs +++ b/src/Simplex/Messaging/Agent/Lock.hs @@ -1,8 +1,18 @@ -module Simplex.Messaging.Agent.Lock where +{-# LANGUAGE NamedFieldPuns #-} + +module Simplex.Messaging.Agent.Lock + ( Lock, + createLock, + withLock, + withGetLock, + withGetLocks, + ) +where import Control.Monad (void) import Control.Monad.IO.Unlift import Data.Functor (($>)) +import UnliftIO.Async (forConcurrently) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -18,9 +28,22 @@ withLock lock name = (atomically $ putTMVar lock name) (void . atomically $ takeTMVar lock) -withGetLock :: MonadUnliftIO m => STM Lock -> String -> m a -> m a -withGetLock getLock name a = +withGetLock :: MonadUnliftIO m => (k -> STM Lock) -> k -> String -> m a -> m a +withGetLock getLock key name a = E.bracket - (atomically $ getLock >>= \l -> putTMVar l name $> l) + (atomically $ getPutLock getLock key name) (atomically . takeTMVar) (const a) + +withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a +withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const + where + holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name + -- only this withGetLocks would be holding the locks, + -- so it's safe to combine all lock releases into one transaction + releaseLocks = atomically . mapM_ takeTMVar + +-- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists, +-- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock. +getPutLock :: (k -> STM Lock) -> k -> String -> STM Lock +getPutLock getLock key name = getLock key >>= \l -> putTMVar l name $> l diff --git a/src/Simplex/Messaging/Server/Stats.hs b/src/Simplex/Messaging/Server/Stats.hs index 493bd5ac1..38e1d13db 100644 --- a/src/Simplex/Messaging/Server/Stats.hs +++ b/src/Simplex/Messaging/Server/Stats.hs @@ -11,7 +11,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Char8 as B import Data.Set (Set) import qualified Data.Set as S -import Data.Time.Calendar.Month.Compat (pattern MonthDay) +import Data.Time.Calendar.Month (pattern MonthDay) import Data.Time.Calendar.OrdinalDate (mondayStartWeek) import Data.Time.Clock (UTCTime (..)) import Simplex.Messaging.Encoding.String diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 3143427ca..e9d94f0c2 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -85,6 +85,18 @@ unlessM b = ifM b $ pure () ($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b) f $>>= g = f >>= fmap join . mapM g +mapME :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m (t (Either e b)) +mapME f = mapM (bindRight f) +{-# INLINE mapME #-} + +bindRight :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b) +bindRight = either (pure . Left) +{-# INLINE bindRight #-} + +forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m (t (Either e b)) +forME = flip mapME +{-# INLINE forME #-} + catchAll :: IO a -> (E.SomeException -> IO a) -> IO a catchAll = E.catch {-# INLINE catchAll #-}