diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 7159a1324..60c98b4ea 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -33,6 +33,7 @@ module Simplex.Messaging.Agent AgentClient (..), AE, SubscriptionsInfo (..), + MsgReq, getSMPAgentClient, getSMPAgentClient_, disconnectAgentClient, @@ -393,6 +394,10 @@ sendMessage :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> A sendMessage c = withAgentEnv c .:: sendMessage' c {-# INLINE sendMessage #-} +-- When sending multiple messages to the same connection, +-- only the first MsgReq for this connection should have non-empty ConnId. +-- All subsequent MsgReq in traversable for this connection must be empty. +-- This is done to optimize processing by grouping all messages to one connection together. type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) @@ -1057,38 +1062,49 @@ sendMessages' c = sendMessagesB' c . map Right sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = do - connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs + (_, connIds) <- liftEither $ foldl' addConnId (Right ("", S.empty)) reqs lift $ sendMessagesB_ c reqs connIds where - addConnId s@(Right s') (Right (connId, _, _, _)) - | B.null connId = s - | connId `S.notMember` s' = Right $ S.insert connId s' - | otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID" - addConnId s _ = s + addConnId acc@(Right (prevId, s)) (Right (connId, _, _, _)) + | B.null connId = if B.null prevId then Left $ INTERNAL "sendMessages: empty first connId" else acc + | connId `S.member` s = Left $ INTERNAL "sendMessages: duplicate connId" + | otherwise = Right (connId, S.insert connId s) + addConnId acc _ = acc sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + prev <- newTVarIO Nothing + reqs' <- withStoreBatch c $ \db -> fmap (bindRight $ getConn_ db prev) reqs let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' - void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable + void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) $ S.toList toEnable enqueueMessagesB c reqs'' where - prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) - prepareConn acc (Left e) = (acc, Left e) - prepareConn acc (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of + getConn_ :: DB.Connection -> TVar (Maybe (Either AgentErrorType SomeConn)) -> MsgReq -> IO (Either AgentErrorType (MsgReq, SomeConn)) + getConn_ db prev req@(connId, _, _, _) = + (req,) <$$> + if B.null connId + then fromMaybe (Left $ INTERNAL "sendMessagesB_: empty prev connId") <$> atomically (readTVar prev) + else do + conn <- first storeError <$> getConn db connId + conn <$ atomically (writeTVar prev $ Just conn) + prepareConn :: Set ConnId -> Either AgentErrorType (MsgReq, SomeConn) -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareConn s (Left e) = (s, Left e) + prepareConn s (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] - _ -> (acc, Left $ CONN SIMPLEX) + _ -> (s, Left $ CONN SIMPLEX) where - prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareMsg :: ConnData -> NonEmpty SndQueue -> (Set ConnId, Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) prepareMsg cData@ConnData {connId, pqSupport} sqs - | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") + | ratchetSyncSendProhibited cData = (s, Left $ CMD PROHIBITED "sendMessagesB: send prohibited") -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. | pqEnc == PQEncOn && pqSupport == PQSupportOff = let cData' = cData {pqSupport = PQSupportOn} :: ConnData - in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) - | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) + in (S.insert connId s, mkReq cData') + | otherwise = (s, mkReq cData) + where + mkReq cData' = Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg) -- / async command processing v v v diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 53991ed12..0ab162d08 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -357,6 +357,9 @@ functionalAPITests t = do it "should subscribe to multiple connections with pending messages" $ withSmpServer t $ testBatchedPendingMessages 10 5 + describe "Batch send messages" $ do + it "should send multiple messages to the same connection" $ withSmpServer t testSendMessagesB + it "should send messages to the 2 connections" $ withSmpServer t testSendMessagesB2 describe "Async agent commands" $ do describe "connect using async agent commands" $ testBasicMatrix2 t testAsyncCommands @@ -1932,6 +1935,48 @@ testBatchedPendingMessages nCreate nMsgs = withA = withAgent 1 agentCfg initAgentServers testDB withB = withAgent 2 agentCfg initAgentServers testDB2 +testSendMessagesB :: IO () +testSendMessagesB = withAgentClients2 $ \a b -> runRight_ $ do + (aId, bId) <- makeConnection a b + let msg cId body = Right (cId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4] <- sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3"] :: [Either AgentErrorType MsgReq]) + get a ##> ("", bId, SENT 2) + get a ##> ("", bId, SENT 3) + get a ##> ("", bId, SENT 4) + receiveMsg b aId 2 "msg 1" + receiveMsg b aId 3 "msg 2" + receiveMsg b aId 4 "msg 3" + +testSendMessagesB2 :: IO () +testSendMessagesB2 = withAgentClients3 $ \a b c -> runRight_ $ do + (abId, bId) <- makeConnection a b + (acId, cId) <- makeConnection a c + let msg connId body = Right (connId, PQEncOn, SMP.noMsgFlags, body) + [SentB 2, SentB 3, SentB 4, SentB 2, SentB 3] <- + sendMessagesB a ([msg bId "msg 1", msg "" "msg 2", msg "" "msg 3", msg cId "msg 4", msg "" "msg 5"] :: [Either AgentErrorType MsgReq]) + liftIO $ + getInAnyOrder + a + [ \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 4)) -> cId' == bId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 2)) -> cId' == cId; _ -> False, + \case ("", cId', AEvt SAEConn (SENT 3)) -> cId' == cId; _ -> False + ] + receiveMsg b abId 2 "msg 1" + receiveMsg b abId 3 "msg 2" + receiveMsg b abId 4 "msg 3" + receiveMsg c acId 2 "msg 4" + receiveMsg c acId 3 "msg 5" + +pattern SentB :: AgentMsgId -> Either AgentErrorType (AgentMsgId, PQEncryption) +pattern SentB msgId <- Right (msgId, PQEncOn) + +receiveMsg :: AgentClient -> ConnId -> AgentMsgId -> MsgBody -> ExceptT AgentErrorType IO () +receiveMsg c cId msgId msg = do + get c =##> \case ("", cId', Msg' mId' PQEncOn msg') -> cId' == cId && mId' == msgId && msg' == msg; _ -> False + ackMessage c cId msgId Nothing + testAsyncCommands :: SndQueueSecured -> AgentClient -> AgentClient -> AgentMsgId -> IO () testAsyncCommands sqSecured alice bob baseId = runRight_ $ do