diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 5415ebd65..895ba28ae 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -53,7 +53,8 @@ import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) import qualified Data.IntMap.Strict as IM -import Data.List (intercalate) +import Data.List (intercalate, mapAccumR) +import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Data.Maybe (isNothing) @@ -482,16 +483,29 @@ send :: Transport c => THandleSMP c 'TServer -> Client -> IO () send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send" forever $ do - ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ - -- TODO we can authorize responses as well - void . liftIO . tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts + sendTransmissions =<< atomically (readTBQueue sndQ) atomically . writeTVar sndActiveAt =<< liftIO getSystemTime where - tOrder :: Transmission BrokerMsg -> Int - tOrder (_, _, cmd) = case cmd of - MSG {} -> 0 - NMSG {} -> 0 - _ -> 1 + sendTransmissions :: NonEmpty (Transmission BrokerMsg) -> IO () + sendTransmissions ts + | L.length ts <= 2 = tSend ts + | otherwise = do + let (msgs, ts') = mapAccumR splitMessages [] ts + -- If the request had batched subscriptions (L.length ts > 2) + -- this will reply OK to all SUBs in the first batched transmission, + -- to reduce client timeouts. + tSend ts' + -- After that all messages will be sent in separate transmissions, + -- without any client response timeouts. + mapM_ tSend (L.nonEmpty msgs) + where + splitMessages :: [Transmission BrokerMsg] -> Transmission BrokerMsg -> ([Transmission BrokerMsg], Transmission BrokerMsg) + splitMessages msgs t@(corrId, entId, cmd) = case cmd of + -- replace MSG response with OK, accumulating MSG in a separate list. + MSG {} -> ((CorrId "", entId, cmd) : msgs, (corrId, entId, OK)) + _ -> (msgs, t) + tSend :: NonEmpty (Transmission BrokerMsg) -> IO () + tSend = void . tPut h . L.map (\t -> Right (Nothing, encodeTransmission params t)) disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index b890c2c00..32610b54e 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -27,9 +27,9 @@ import GHC.Stack (withFrozenCallStack) import Network.HTTP.Types (urlEncode) import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent.Protocol hiding (MID, CONF, INFO, REQ) +import Simplex.Messaging.Agent.Protocol hiding (CONF, INFO, MID, REQ) import qualified Simplex.Messaging.Agent.Protocol as A -import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff) +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ErrorType (..)) @@ -547,10 +547,10 @@ testResumeDeliveryQuotaExceeded _ alice bob = do bob <#= \case ("", "alice", Msg "message 4") -> True; _ -> False bob #: ("4", "alice", "ACK 7") #> ("4", "alice", OK) inAnyOrder - (tGetAgent alice) - [ \case ("", c, Right (SENT 8)) -> c == "bob"; _ -> False, - \case ("", c, Right QCONT) -> c == "bob"; _ -> False - ] + (tGetAgent alice) + [ \case ("", c, Right (SENT 8)) -> c == "bob"; _ -> False, + \case ("", c, Right QCONT) -> c == "bob"; _ -> False + ] bob <#= \case ("", "alice", Msg "over quota") -> True; _ -> False -- message 8 is skipped because of alice agent sending "QCONT" message bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) @@ -580,7 +580,7 @@ enableKEMStr _ = "" pqConnModeStr :: InitialKeys -> ByteString pqConnModeStr (IKNoPQ PQSupportOff) = "" -pqConnModeStr pq = " " <> strEncode pq +pqConnModeStr pq = " " <> strEncode pq sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO () sendMessage (h1, name1) (h2, name2) msg = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index cdcf5baed..59e433ea5 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -58,10 +58,10 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Either (isRight) import Data.Int (Int64) -import Data.List (nub) +import Data.List (find, nub) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M -import Data.Maybe (isNothing) +import Data.Maybe (isJust, isNothing) import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) @@ -337,6 +337,9 @@ functionalAPITests t = do skip "faster version of the previous test (200 subscriptions gets very slow with test coverage)" $ it "should subscribe to multiple (6) subscriptions with batching" $ testBatchedSubscriptions 6 3 t + it "should subscribe to multiple connections with pending messages" $ + withSmpServer t $ + testBatchedPendingMessages 10 5 describe "Async agent commands" $ do it "should connect using async agent commands" $ withSmpServer t testAsyncCommands @@ -1534,7 +1537,7 @@ testBatchedSubscriptions :: Int -> Int -> ATransport -> IO () testBatchedSubscriptions nCreate nDel t = withAgentClientsCfgServers2 agentCfg agentCfg initAgentServers2 $ \a b -> do conns <- runServers $ do - conns <- replicateM (nCreate :: Int) $ makeConnection_ PQSupportOff a b + conns <- replicateM nCreate $ makeConnection_ PQSupportOff a b forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' @@ -1593,6 +1596,25 @@ testBatchedSubscriptions nCreate nDel t = killThread t1 pure res +testBatchedPendingMessages :: Int -> Int -> IO () +testBatchedPendingMessages nCreate nMsgs = + withA $ \a -> do + conns <- withB $ \b -> runRight $ do + replicateM nCreate $ makeConnection a b + let msgConns = take nMsgs conns + runRight_ $ forM_ msgConns $ \(_, bId) -> sendMessage a bId SMP.noMsgFlags "hello" + replicateM_ nMsgs $ get a =##> \case ("", cId, SENT _) -> isJust $ find ((cId ==) . snd) msgConns; _ -> False + withB $ \b -> runRight_ $ do + r <- subscribeConnections b $ map fst conns + liftIO $ all isRight r `shouldBe` True + replicateM_ nMsgs $ do + ("", cId, Msg' msgId _ "hello") <- get b + liftIO $ isJust (find ((cId ==) . fst) msgConns) `shouldBe` True + ackMessage b cId msgId Nothing + where + withA = withAgent 1 agentCfg initAgentServers testDB + withB = withAgent 2 agentCfg initAgentServers testDB2 + testAsyncCommands :: IO () testAsyncCommands = withAgentClients2 $ \alice bob -> runRight_ $ do