mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 18:35:59 +00:00
agent: fail if non-unique connection IDs are passed to sendMessages (#1170)
This commit is contained in:
committed by
GitHub
parent
984394d906
commit
6309f92c68
@@ -138,6 +138,8 @@ import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock
|
||||
@@ -354,12 +356,12 @@ sendMessage c = withAgentEnv c .:: sendMessage' c
|
||||
type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody)
|
||||
|
||||
-- | Send multiple messages to different connections (SEND command)
|
||||
sendMessages :: AgentClient -> [MsgReq] -> IO [Either AgentErrorType (AgentMsgId, PQEncryption)]
|
||||
sendMessages c = withAgentEnv' c . sendMessages' c
|
||||
sendMessages :: AgentClient -> [MsgReq] -> AE [Either AgentErrorType (AgentMsgId, PQEncryption)]
|
||||
sendMessages c = withAgentEnv c . sendMessages' c
|
||||
{-# INLINE sendMessages #-}
|
||||
|
||||
sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> IO (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
|
||||
sendMessagesB c = withAgentEnv' c . sendMessagesB' c
|
||||
sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AE (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
|
||||
sendMessagesB c = withAgentEnv c . sendMessagesB' c
|
||||
{-# INLINE sendMessagesB #-}
|
||||
|
||||
ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE ()
|
||||
@@ -1033,16 +1035,27 @@ getNotificationMessage' c nonce encNtfInfo = do
|
||||
|
||||
-- | Send message to the connection (SEND command) in Reader monad
|
||||
sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption)
|
||||
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg)))
|
||||
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB_ c (Identity (Right (connId, pqEnc, msgFlags, msg))) (S.singleton connId)
|
||||
{-# INLINE sendMessage' #-}
|
||||
|
||||
-- | Send multiple messages to different connections (SEND command) in Reader monad
|
||||
sendMessages' :: AgentClient -> [MsgReq] -> AM' [Either AgentErrorType (AgentMsgId, PQEncryption)]
|
||||
sendMessages' :: AgentClient -> [MsgReq] -> AM [Either AgentErrorType (AgentMsgId, PQEncryption)]
|
||||
sendMessages' c = sendMessagesB' c . map Right
|
||||
{-# INLINE sendMessages' #-}
|
||||
|
||||
sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
|
||||
sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
|
||||
sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
|
||||
sendMessagesB' c reqs = do
|
||||
connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs
|
||||
lift $ sendMessagesB_ c reqs connIds
|
||||
where
|
||||
addConnId s@(Right s') (Right (connId, _, _, _))
|
||||
| B.null connId = s
|
||||
| connId `S.notMember` s' = Right $ S.insert connId s'
|
||||
| otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID"
|
||||
addConnId s _ = s
|
||||
|
||||
sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
|
||||
sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do
|
||||
reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs)
|
||||
let (toEnable, reqs'') = mapAccumL prepareConn [] reqs'
|
||||
void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable
|
||||
@@ -1064,7 +1077,6 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
|
||||
let cData' = cData {pqSupport = PQSupportOn} :: ConnData
|
||||
in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg))
|
||||
| otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg))
|
||||
connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs
|
||||
|
||||
-- / async command processing v v v
|
||||
|
||||
|
||||
@@ -826,15 +826,15 @@ withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
|
||||
withInvLock' AgentClient {invLocks} = withLockMap invLocks
|
||||
{-# INLINE withInvLock' #-}
|
||||
|
||||
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
|
||||
withConnLocks :: AgentClient -> Set ConnId -> String -> AM' a -> AM' a
|
||||
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks
|
||||
{-# INLINE withConnLocks #-}
|
||||
|
||||
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap = withGetLock . getMapLock
|
||||
{-# INLINE withLockMap #-}
|
||||
|
||||
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
|
||||
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> String -> m a -> m a
|
||||
withLocksMap_ = withGetLocks . getMapLock
|
||||
{-# INLINE withLocksMap_ #-}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import Control.Monad (void)
|
||||
import Control.Monad.Except (ExceptT (..), runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Functor (($>))
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import UnliftIO.Async (forConcurrently)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -39,13 +41,11 @@ withGetLock getLock key name a =
|
||||
(atomically . takeTMVar)
|
||||
(const a)
|
||||
|
||||
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a
|
||||
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> String -> m a -> m a
|
||||
withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const
|
||||
where
|
||||
holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name
|
||||
-- only this withGetLocks would be holding the locks,
|
||||
-- so it's safe to combine all lock releases into one transaction
|
||||
releaseLocks = atomically . mapM_ takeTMVar
|
||||
holdLocks = forConcurrently (S.toList keys) $ \key -> atomically $ getPutLock getLock key name
|
||||
releaseLocks = mapM_ (atomically . takeTMVar)
|
||||
|
||||
-- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists,
|
||||
-- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock.
|
||||
|
||||
Reference in New Issue
Block a user