diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 652e2d58d..dddaa7231 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -138,6 +138,8 @@ import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) +import Data.Set (Set) +import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Data.Time.Clock @@ -354,12 +356,12 @@ sendMessage c = withAgentEnv c .:: sendMessage' c type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: AgentClient -> [MsgReq] -> IO [Either AgentErrorType (AgentMsgId, PQEncryption)] -sendMessages c = withAgentEnv' c . sendMessages' c +sendMessages :: AgentClient -> [MsgReq] -> AE [Either AgentErrorType (AgentMsgId, PQEncryption)] +sendMessages c = withAgentEnv c . sendMessages' c {-# INLINE sendMessages #-} -sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> IO (t (Either AgentErrorType (AgentMsgId, PQEncryption))) -sendMessagesB c = withAgentEnv' c . sendMessagesB' c +sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AE (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +sendMessagesB c = withAgentEnv c . sendMessagesB' c {-# INLINE sendMessagesB #-} ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE () @@ -1033,16 +1035,27 @@ getNotificationMessage' c nonce encNtfInfo = do -- | Send message to the connection (SEND command) in Reader monad sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption) -sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) +sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB_ c (Identity (Right (connId, pqEnc, msgFlags, msg))) (S.singleton connId) {-# INLINE sendMessage' #-} -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: AgentClient -> [MsgReq] -> AM' [Either AgentErrorType (AgentMsgId, PQEncryption)] +sendMessages' :: AgentClient -> [MsgReq] -> AM [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages' c = sendMessagesB' c . map Right {-# INLINE sendMessages' #-} -sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) -sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do +sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +sendMessagesB' c reqs = do + connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs + lift $ sendMessagesB_ c reqs connIds + where + addConnId s@(Right s') (Right (connId, _, _, _)) + | B.null connId = s + | connId `S.notMember` s' = Right $ S.insert connId s' + | otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID" + addConnId s _ = s + +sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) +sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable @@ -1064,7 +1077,6 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do let cData' = cData {pqSupport = PQSupportOn} :: ConnData in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) - connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index e59dee7a1..2dd4ec3c9 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -826,15 +826,15 @@ withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a withInvLock' AgentClient {invLocks} = withLockMap invLocks {-# INLINE withInvLock' #-} -withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a -withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null) +withConnLocks :: AgentClient -> Set ConnId -> String -> AM' a -> AM' a +withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks {-# INLINE withConnLocks #-} withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a withLockMap = withGetLock . getMapLock {-# INLINE withLockMap #-} -withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a +withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> String -> m a -> m a withLocksMap_ = withGetLocks . getMapLock {-# INLINE withLocksMap_ #-} diff --git a/src/Simplex/Messaging/Agent/Lock.hs b/src/Simplex/Messaging/Agent/Lock.hs index c0647b844..69b8169e2 100644 --- a/src/Simplex/Messaging/Agent/Lock.hs +++ b/src/Simplex/Messaging/Agent/Lock.hs @@ -12,6 +12,8 @@ import Control.Monad (void) import Control.Monad.Except (ExceptT (..), runExceptT) import Control.Monad.IO.Unlift import Data.Functor (($>)) +import Data.Set (Set) +import qualified Data.Set as S import UnliftIO.Async (forConcurrently) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -39,13 +41,11 @@ withGetLock getLock key name a = (atomically . takeTMVar) (const a) -withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a +withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> String -> m a -> m a withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const where - holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name - -- only this withGetLocks would be holding the locks, - -- so it's safe to combine all lock releases into one transaction - releaseLocks = atomically . mapM_ takeTMVar + holdLocks = forConcurrently (S.toList keys) $ \key -> atomically $ getPutLock getLock key name + releaseLocks = mapM_ (atomically . takeTMVar) -- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists, -- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock.