agent: batch sending messages (#922)

* agent: batch sending messages (attempt 4)

* handle errors in batch sending

* batch attempt 5 (#923)

* attempt 5

* remove IORefs

* add liftA2 for 8.10 compat

* remove db-related zipping

* traversable

---------

Co-authored-by: IC Rainbow <aenor.realm@gmail.com>

* s/mapE/bindRight/

* name

Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com>

* comment

Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com>

* remove unused funcs

---------

Co-authored-by: IC Rainbow <aenor.realm@gmail.com>
Co-authored-by: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com>
This commit is contained in:
Evgeny Poberezkin
2023-12-19 23:01:34 +00:00
committed by GitHub
parent 7627ce6b69
commit 8c250ebe19
7 changed files with 154 additions and 57 deletions
+1 -2
View File
@@ -65,8 +65,7 @@ dependencies:
- sqlcipher-simple == 0.4.*
- stm == 2.5.*
- temporary == 1.3.*
- time == 1.9.*
- time-compat == 1.9.*
- time == 1.12.*
- time-manager == 0.0.*
- tls >= 1.7.0 && < 1.8
- transformers == 0.6.*
+7 -14
View File
@@ -212,8 +212,7 @@ library
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -285,8 +284,7 @@ executable ntf-server
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -358,8 +356,7 @@ executable smp-agent
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -431,8 +428,7 @@ executable smp-server
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -504,8 +500,7 @@ executable xftp
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -577,8 +572,7 @@ executable xftp-server
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, tls >=1.7.0 && <1.8
, transformers ==0.6.*
@@ -687,8 +681,7 @@ test-suite simplexmq-test
, sqlcipher-simple ==0.4.*
, stm ==2.5.*
, temporary ==1.3.*
, time ==1.9.*
, time-compat ==1.9.*
, time ==1.12.*
, time-manager ==0.0.*
, timeit ==2.0.*
, tls >=1.7.0 && <1.8
+81 -35
View File
@@ -64,6 +64,7 @@ module Simplex.Messaging.Agent
resubscribeConnection,
resubscribeConnections,
sendMessage,
sendMessages,
ackMessage,
switchConnection,
abortConnectionSwitch,
@@ -119,14 +120,16 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::), (.::.))
import Data.Foldable (foldl')
import Data.Either (rights)
import Data.Foldable (foldl', toList)
import Data.Functor (($>))
import Data.Functor.Identity
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock
@@ -277,6 +280,12 @@ resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage c = withAgentEnv c .:. sendMessage' c
type MsgReq = (ConnId, MsgFlags, MsgBody)
-- | Send multiple messages to different connections (SEND command)
sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId]
sendMessages c = withAgentEnv c . sendMessages' c
ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m ()
ackMessage c = withAgentEnv c .:. ackMessage' c
@@ -867,17 +876,29 @@ getNotificationMessage' c nonce encNtfInfo = do
-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData _ sqs -> enqueueMsgs cData sqs
SndConnection cData sq -> enqueueMsgs cData [sq]
_ -> throwError $ CONN SIMPLEX
sendMessage' c connId msgFlags msg = liftEither . runIdentity =<< sendMessagesB c (Identity (Right (connId, msgFlags, msg)))
-- | Send multiple messages to different connections (SEND command) in Reader monad
sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId]
sendMessages' c = sendMessagesB c . map Right
sendMessagesB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId))
sendMessagesB c reqs = withConnLocks c connIds "sendMessages" $ do
reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs)
let reqs'' = fmap (>>= prepareConn) reqs'
enqueueMessagesB c reqs''
where
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId
enqueueMsgs cData sqs = do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
enqueueMessages c cData sqs msgFlags $ A_MSG msg
prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)
prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of
DuplexConnection cData _ sqs -> prepareMsg cData sqs
SndConnection cData sq -> prepareMsg cData [sq]
_ -> Left $ CONN SIMPLEX
where
prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)
prepareMsg cData sqs
| ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED
| otherwise = Right (cData, sqs, msgFlags, A_MSG msg)
connIds = map (\(connId, _, _) -> connId) $ rights $ toList reqs
-- / async command processing v v v
@@ -1056,22 +1077,37 @@ enqueueMessages c cData sqs msgFlags aMessage = do
enqueueMessages' c cData sqs msgFlags aMessage
enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessages' c cData (sq :| sqs) msgFlags aMessage = do
msgId <- enqueueMessage c cData sq msgFlags aMessage
mapM_ (enqueueSavedMessage c cData msgId) $
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
pure msgId
enqueueMessages' c cData sqs msgFlags aMessage =
liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, msgFlags, aMessage)))
enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType AgentMsgId))
enqueueMessagesB c reqs = do
reqs' <- enqueueMessageB c reqs
enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs'
pure $ fst <$$> reqs'
isActiveSndQ :: SndQueue -> Bool
isActiveSndQ SndQueue {status} = status == Secured || status == Active
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessage c cData@ConnData {connId} sq msgFlags aMessage = do
resumeMsgDelivery c cData sq
aVRange <- asks $ smpAgentVRange . config
msgId <- storeSentMsg $ maxVersion aVRange
queuePendingMsgs c sq [msgId]
pure $ unId msgId
enqueueMessage c cData sq msgFlags aMessage =
liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], msgFlags, aMessage)))
-- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries
enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId))))
enqueueMessageB c reqs = do
void . forME reqs $ \(cData, sq :| _, _, _) ->
runExceptT $ resumeMsgDelivery c cData sq
aVRange <- asks $ maxVersion . smpAgentVRange . config
reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs
forME reqMids $ \((cData, sq :| sqs, _, _), mId) -> do
let InternalId msgId = mId
queuePendingMsgs c sq [mId]
let sqs' = filter isActiveSndQ sqs
pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId))
where
storeSentMsg :: Version -> m InternalId
storeSentMsg agentVersion = withStore c $ \db -> runExceptT $ do
storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId))
storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash
@@ -1084,14 +1120,25 @@ enqueueMessage c cData@ConnData {connId} sq msgFlags aMessage = do
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash}
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
pure (req, internalId)
enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
enqueueSavedMessage c cData@ConnData {connId} msgId sq = do
resumeMsgDelivery c cData sq
let mId = InternalId msgId
queuePendingMsgs c sq [mId]
withStore' c $ \db -> createSndMsgDelivery db connId sq mId
enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId)
enqueueSavedMessageB :: (AgentMonad' m, Foldable t) => AgentClient -> t (ConnData, [SndQueue], AgentMsgId) -> m ()
enqueueSavedMessageB c reqs = do
-- saving to the database is in the start to avoid race conditions when delivery is read from queue before it is saved
void $ withStoreBatch' c $ \db -> concatMap (storeDeliveries db) reqs
forM_ reqs $ \(cData, sqs, msgId) ->
forM sqs $ \sq -> do
void . runExceptT $ resumeMsgDelivery c cData sq
let mId = InternalId msgId
queuePendingMsgs c sq [mId]
where
storeDeliveries :: DB.Connection -> (ConnData, [SndQueue], AgentMsgId) -> [IO ()]
storeDeliveries db (ConnData {connId}, sqs, msgId) = do
let mId = InternalId msgId
in map (\sq -> createSndMsgDelivery db connId sq mId) sqs
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
@@ -1885,7 +1932,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, s
conn
cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} =
withConnLock c connId "processSMP" $ case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $ do
msg' <- decryptSMPMessage v rq msg
handleNotifyAck $ case msg' of
@@ -2434,8 +2481,7 @@ storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ agentM
enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId
enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do
msgId <- enqueueRatchetKey c cData sq e2eEncryption
mapM_ (enqueueSavedMessage c cData msgId) $
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs
pure msgId
enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId
+25 -1
View File
@@ -24,6 +24,7 @@ module Simplex.Messaging.Agent.Client
ProtocolTestStep (..),
newAgentClient,
withConnLock,
withConnLocks,
withInvLock,
closeAgentClient,
closeProtocolServerClients,
@@ -99,6 +100,8 @@ module Simplex.Messaging.Agent.Client
withStore',
withStoreCtx,
withStoreCtx',
withStoreBatch,
withStoreBatch',
storeError,
userServers,
pickServer,
@@ -658,8 +661,17 @@ withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId
withInvLock :: MonadUnliftIO m => AgentClient -> ByteString -> String -> m a -> m a
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
withConnLocks :: MonadUnliftIO m => AgentClient -> [ConnId] -> String -> m a -> m a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
withLockMap_ = withGetLock . getMapLock
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
getMapLock :: Ord k => TMap k Lock -> k -> STM Lock
getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
@@ -1291,6 +1303,18 @@ withStoreCtx_ ctx_ c action = do
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
withStoreBatch :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> m (t (Either AgentErrorType a))
withStoreBatch c actions = do
st <- asks store
liftIO $ agentOperationBracket c AODatabase (\_ -> pure ()) $
withTransaction st $ mapM (`E.catch` handleInternal) . actions
where
handleInternal :: E.SomeException -> IO (Either AgentErrorType a)
handleInternal = pure . Left . INTERNAL . show
withStoreBatch' :: (AgentMonad' m, Traversable t) => AgentClient -> (DB.Connection -> t (IO a)) -> m (t (Either AgentErrorType a))
withStoreBatch' c actions = withStoreBatch c $ \db -> fmap Right <$> actions db
storeError :: StoreError -> AgentErrorType
storeError = \case
SEConnNotFound -> CONN NOT_FOUND
+27 -4
View File
@@ -1,8 +1,18 @@
module Simplex.Messaging.Agent.Lock where
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.Agent.Lock
( Lock,
createLock,
withLock,
withGetLock,
withGetLocks,
)
where
import Control.Monad (void)
import Control.Monad.IO.Unlift
import Data.Functor (($>))
import UnliftIO.Async (forConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -18,9 +28,22 @@ withLock lock name =
(atomically $ putTMVar lock name)
(void . atomically $ takeTMVar lock)
withGetLock :: MonadUnliftIO m => STM Lock -> String -> m a -> m a
withGetLock getLock name a =
withGetLock :: MonadUnliftIO m => (k -> STM Lock) -> k -> String -> m a -> m a
withGetLock getLock key name a =
E.bracket
(atomically $ getLock >>= \l -> putTMVar l name $> l)
(atomically $ getPutLock getLock key name)
(atomically . takeTMVar)
(const a)
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a
withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const
where
holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name
-- only this withGetLocks would be holding the locks,
-- so it's safe to combine all lock releases into one transaction
releaseLocks = atomically . mapM_ takeTMVar
-- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists,
-- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock.
getPutLock :: (k -> STM Lock) -> k -> String -> STM Lock
getPutLock getLock key name = getLock key >>= \l -> putTMVar l name $> l
+1 -1
View File
@@ -11,7 +11,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Char8 as B
import Data.Set (Set)
import qualified Data.Set as S
import Data.Time.Calendar.Month.Compat (pattern MonthDay)
import Data.Time.Calendar.Month (pattern MonthDay)
import Data.Time.Calendar.OrdinalDate (mondayStartWeek)
import Data.Time.Clock (UTCTime (..))
import Simplex.Messaging.Encoding.String
+12
View File
@@ -85,6 +85,18 @@ unlessM b = ifM b $ pure ()
($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
f $>>= g = f >>= fmap join . mapM g
mapME :: (Monad m, Traversable t) => (a -> m (Either e b)) -> t (Either e a) -> m (t (Either e b))
mapME f = mapM (bindRight f)
{-# INLINE mapME #-}
bindRight :: Monad m => (a -> m (Either e b)) -> Either e a -> m (Either e b)
bindRight = either (pure . Left)
{-# INLINE bindRight #-}
forME :: (Monad m, Traversable t) => t (Either e a) -> (a -> m (Either e b)) -> m (t (Either e b))
forME = flip mapME
{-# INLINE forME #-}
catchAll :: IO a -> (E.SomeException -> IO a) -> IO a
catchAll = E.catch
{-# INLINE catchAll #-}