agent: fail if non-unique connection IDs are passed to sendMessages (#1170)

This commit is contained in:
Evgeny Poberezkin
2024-05-23 22:01:57 +01:00
committed by GitHub
parent 984394d906
commit 6309f92c68
3 changed files with 29 additions and 17 deletions

View File

@@ -138,6 +138,8 @@ import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock
@@ -354,12 +356,12 @@ sendMessage c = withAgentEnv c .:: sendMessage' c
type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody)
-- | Send multiple messages to different connections (SEND command)
sendMessages :: AgentClient -> [MsgReq] -> IO [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages c = withAgentEnv' c . sendMessages' c
sendMessages :: AgentClient -> [MsgReq] -> AE [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages c = withAgentEnv c . sendMessages' c
{-# INLINE sendMessages #-}
sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> IO (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB c = withAgentEnv' c . sendMessagesB' c
sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AE (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB c = withAgentEnv c . sendMessagesB' c
{-# INLINE sendMessagesB #-}
ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE ()
@@ -1033,16 +1035,27 @@ getNotificationMessage' c nonce encNtfInfo = do
-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption)
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg)))
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB_ c (Identity (Right (connId, pqEnc, msgFlags, msg))) (S.singleton connId)
{-# INLINE sendMessage' #-}
-- | Send multiple messages to different connections (SEND command) in Reader monad
sendMessages' :: AgentClient -> [MsgReq] -> AM' [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages' :: AgentClient -> [MsgReq] -> AM [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages' c = sendMessagesB' c . map Right
{-# INLINE sendMessages' #-}
sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB' c reqs = do
connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs
lift $ sendMessagesB_ c reqs connIds
where
addConnId s@(Right s') (Right (connId, _, _, _))
| B.null connId = s
| connId `S.notMember` s' = Right $ S.insert connId s'
| otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID"
addConnId s _ = s
sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do
reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs)
let (toEnable, reqs'') = mapAccumL prepareConn [] reqs'
void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable
@@ -1064,7 +1077,6 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
let cData' = cData {pqSupport = PQSupportOn} :: ConnData
in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg))
| otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg))
connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs
-- / async command processing v v v

View File

@@ -826,15 +826,15 @@ withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
withInvLock' AgentClient {invLocks} = withLockMap invLocks
{-# INLINE withInvLock' #-}
withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
withConnLocks :: AgentClient -> Set ConnId -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks
{-# INLINE withConnLocks #-}
withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap = withGetLock . getMapLock
{-# INLINE withLockMap #-}
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> String -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
{-# INLINE withLocksMap_ #-}

View File

@@ -12,6 +12,8 @@ import Control.Monad (void)
import Control.Monad.Except (ExceptT (..), runExceptT)
import Control.Monad.IO.Unlift
import Data.Functor (($>))
import Data.Set (Set)
import qualified Data.Set as S
import UnliftIO.Async (forConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -39,13 +41,11 @@ withGetLock getLock key name a =
(atomically . takeTMVar)
(const a)
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> String -> m a -> m a
withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const
where
holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name
-- only this withGetLocks would be holding the locks,
-- so it's safe to combine all lock releases into one transaction
releaseLocks = atomically . mapM_ takeTMVar
holdLocks = forConcurrently (S.toList keys) $ \key -> atomically $ getPutLock getLock key name
releaseLocks = mapM_ (atomically . takeTMVar)
-- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists,
-- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock.