From c6dde772b459a2d8f392ad5455d6f8fd6f16e867 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 25 Apr 2022 08:26:52 +0100 Subject: [PATCH] batch all connections in DOWN/UP agent messages (#363) --- src/Simplex/Messaging/Agent/Client.hs | 31 ++++++++++++++----------- src/Simplex/Messaging/Agent/Protocol.hs | 19 ++++++++++----- src/Simplex/Messaging/Protocol.hs | 2 ++ tests/AgentTests.hs | 25 +++++++++++--------- tests/AgentTests/FunctionalAPITests.hs | 11 +++++---- tests/AgentTests/NotificationTests.hs | 12 +++++----- tests/SMPAgentClient.hs | 5 +++- 7 files changed, 64 insertions(+), 41 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index e063ae9bc..44c843074 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -55,7 +55,7 @@ import qualified Data.ByteString.Char8 as B import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (isNothing) +import Data.Maybe (catMaybes, isNothing) import Data.Text.Encoding import Data.Word (Word16) import Simplex.Messaging.Agent.Env.SQLite @@ -72,10 +72,10 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, ProtocolServer (..), Qu import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, whenM) +import Simplex.Messaging.Util (bshow, ifM, liftEitherError, liftError, tryError) import Simplex.Messaging.Version import System.Timeout (timeout) -import UnliftIO (async, forConcurrently_) +import UnliftIO (async, forConcurrently) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -181,7 +181,8 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () serverDown u cs = unless (M.null cs) $ do - mapM_ (notifySub DOWN) $ M.keysSet cs + let conns = M.keys cs + unless (null conns) . notifySub "" $ DOWN srv conns unliftIO u reconnectServer reconnectServer :: m () @@ -199,26 +200,30 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectClient = withAgentLock c . withClient c srv $ \smp -> do cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c) - forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) -> - whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $ - subscribe_ smp sub `catchError` handleError connId + conns <- forConcurrently (maybe [] M.toList cs) $ \sub@(connId, _) -> + ifM + (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) + (subscribe_ smp sub `catchError` handleError connId) + (pure $ Just connId) + liftIO . unless (null conns) . notifySub "" . UP srv $ catMaybes conns where - subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT ProtocolClientError IO () + subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT ProtocolClientError IO (Maybe ConnId) subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do subscribeSMPQueue smp rcvPrivateKey rcvId addSubscription c rq connId - liftIO $ notifySub UP connId + pure $ Just connId - handleError :: ConnId -> ProtocolClientError -> ExceptT ProtocolClientError IO () + handleError :: ConnId -> ProtocolClientError -> ExceptT ProtocolClientError IO (Maybe ConnId) handleError connId = \case e@PCEResponseTimeout -> throwError e e@PCENetworkError -> throwError e e -> do - liftIO $ notifySub (ERR $ protocolClientError SMP e) connId + liftIO . notifySub connId . ERR $ protocolClientError SMP e atomically $ removePendingSubscription c srv connId + pure Nothing - notifySub :: ACommand 'Agent -> ConnId -> IO () - notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) + notifySub :: ConnId -> ACommand 'Agent -> IO () + notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient getNtfServerClient c@AgentClient {ntfClients} srv = diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 8be63a060..31ce96c3b 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -210,8 +210,8 @@ data ACommand (p :: AParty) where CON :: ACommand Agent -- notification that connection is established SUB :: ACommand Client END :: ACommand Agent - DOWN :: ACommand Agent - UP :: ACommand Agent + DOWN :: SMPServer -> [ConnId] -> ACommand Agent + UP :: SMPServer -> [ConnId] -> ACommand Agent SEND :: MsgBody -> ACommand Client MID :: AgentMsgId -> ACommand Agent SENT :: AgentMsgId -> ACommand Agent @@ -817,8 +817,8 @@ commandP = <|> "INFO " *> infoCmd <|> "SUB" $> ACmd SClient SUB <|> "END" $> ACmd SAgent END - <|> "DOWN" $> ACmd SAgent DOWN - <|> "UP" $> ACmd SAgent UP + <|> "DOWN " *> downsResp + <|> "UP " *> upsResp <|> "SEND " *> sendCmd <|> "MID " *> msgIdResp <|> "SENT " *> sentResp @@ -840,12 +840,15 @@ commandP = acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString rjctCmd = ACmd SClient . RJCT <$> A.takeByteString infoCmd = ACmd SAgent . INFO <$> A.takeByteString + downsResp = ACmd SAgent .: DOWN <$> strP <* A.space <*> connections + upsResp = ACmd SAgent .: UP <$> strP <* A.space <*> connections sendCmd = ACmd SClient . SEND <$> A.takeByteString msgIdResp = ACmd SAgent . MID <$> A.decimal sentResp = ACmd SAgent . SENT <$> A.decimal msgErrResp = ACmd SAgent .: MERR <$> A.decimal <* A.space <*> strP message = ACmd SAgent .: MSG <$> msgMetaP <* A.space <*> A.takeByteString ackCmd = ACmd SClient . ACK <$> A.decimal + connections = strP `A.sepBy'` (A.char ',') msgMetaP = do integrity <- strP recipient <- " R=" *> partyMeta A.decimal @@ -872,8 +875,8 @@ serializeCommand = \case INFO cInfo -> "INFO " <> serializeBinary cInfo SUB -> "SUB" END -> "END" - DOWN -> "DOWN" - UP -> "UP" + DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns] + UP srv conns -> B.unwords ["UP", strEncode srv, connections conns] SEND msgBody -> "SEND " <> serializeBinary msgBody MID mId -> "MID " <> bshow mId SENT mId -> "SENT " <> bshow mId @@ -888,6 +891,8 @@ serializeCommand = \case where showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis + connections :: [ConnId] -> ByteString + connections = B.intercalate "," . map strEncode serializeMsgMeta :: MsgMeta -> ByteString serializeMsgMeta MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} = B.unwords @@ -939,6 +944,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody ACPT {} -> Right cmd -- ERROR response does not always have connId ERR _ -> Right cmd + DOWN {} -> Right cmd + UP {} -> Right cmd -- other responses must have connId _ | B.null connId -> Left $ CMD NO_CONN diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 4c58c8adc..dde5992b3 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -388,6 +388,8 @@ type SMPServer = ProtocolServer pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer pattern SMPServer host port keyHash = ProtocolServer host port keyHash +{-# COMPLETE SMPServer #-} + -- | SMP server location and transport key digest (hash). data ProtocolServer = ProtocolServer { host :: HostName, diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 6885fa8e6..f71531fa2 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -19,7 +19,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Network.HTTP.Types (urlEncode) import SMPAgentClient -import SMPClient (testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) +import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Encoding.String @@ -273,7 +273,7 @@ testSubscrNotification t (server, _) client = do client #: ("1", "conn1", "NEW INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False client #:# "nothing should be delivered to client before the server is killed" killThread server - client <# ("", "conn1", DOWN) + client <# ("", "", DOWN testSMPServer ["conn1"]) withSmpServer (ATransport t) $ client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue @@ -287,14 +287,15 @@ testMsgDeliveryServerRestart t alice bob = do alice #: ("11", "bob", "ACK 5") #> ("11", "bob", OK) alice #:# "nothing else delivered before the server is killed" - alice <# ("", "bob", DOWN) + let server = (SMPServer "localhost" testPort2 testKeyHash) + alice <# ("", "", DOWN server ["bob"]) bob #: ("2", "alice", "SEND 11\nhello again") #> ("2", "alice", MID 6) bob #:# "nothing else delivered before the server is restarted" alice #:# "nothing else delivered before the server is restarted" withServer $ do bob <# ("", "alice", SENT 6) - alice <# ("", "bob", UP) + alice <# ("", "", UP server ["bob"]) alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False alice #: ("12", "bob", "ACK 6") #> ("12", "bob", OK) @@ -309,8 +310,8 @@ testServerConnectionAfterError t _ = do withServer $ do connect (bob, "bob") (alice, "alice") - bob <# ("", "alice", DOWN) - alice <# ("", "bob", DOWN) + bob <# ("", "", DOWN server ["alice"]) + alice <# ("", "", DOWN server ["bob"]) alice #: ("1", "bob", "SEND 5\nhello") #> ("1", "bob", MID 5) alice #:# "nothing else delivered before the server is restarted" bob #:# "nothing else delivered before the server is restarted" @@ -320,11 +321,11 @@ testServerConnectionAfterError t _ = do bob #: ("1", "alice", "SUB") #> ("1", "alice", ERR (BROKER NETWORK)) alice #: ("1", "bob", "SUB") #> ("1", "bob", ERR (BROKER NETWORK)) withServer $ do - alice <#= \case ("", "bob", cmd) -> cmd == UP || cmd == SENT 5; _ -> False - alice <#= \case ("", "bob", cmd) -> cmd == UP || cmd == SENT 5; _ -> False - bob <# ("", "alice", UP) + alice <# ("", "bob", SENT 5) + bob <# ("", "", UP server ["alice"]) bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK) + alice <# ("", "", UP server ["bob"]) alice #: ("1", "bob", "SEND 11\nhello again") #> ("1", "bob", MID 6) alice <# ("", "bob", SENT 6) bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False @@ -333,6 +334,7 @@ testServerConnectionAfterError t _ = do removeFile testDB removeFile testDB2 where + server = SMPServer "localhost" testPort2 testKeyHash withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` () withAgent1 = withAgent agentTestPort testDB withAgent2 = withAgent agentTestPort2 testDB2 @@ -341,6 +343,7 @@ testServerConnectionAfterError t _ = do testMsgDeliveryAgentRestart :: Transport c => TProxy c -> c -> IO () testMsgDeliveryAgentRestart t bob = do + let server = SMPServer "localhost" testPort2 testKeyHash withAgent $ \alice -> do withServer $ do connect (bob, "bob") (alice, "alice") @@ -350,7 +353,7 @@ testMsgDeliveryAgentRestart t bob = do bob #: ("11", "alice", "ACK 5") #> ("11", "alice", OK) bob #:# "nothing else delivered before the server is down" - bob <# ("", "alice", DOWN) + bob <# ("", "", DOWN server ["alice"]) alice #: ("2", "bob", "SEND 11\nhello again") #> ("2", "bob", MID 6) alice #:# "nothing else delivered before the server is restarted" bob #:# "nothing else delivered before the server is restarted" @@ -363,7 +366,7 @@ testMsgDeliveryAgentRestart t bob = do (corrId == "3" && cmd == OK) || (corrId == "" && cmd == SENT 6) _ -> False - bob <# ("", "alice", UP) + bob <# ("", "", UP server ["alice"]) bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False bob #: ("12", "alice", "ACK 6") #> ("12", "alice", OK) diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index b20154a3f..030fad7b3 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -159,12 +159,15 @@ testAsyncServerOffline t = do runExceptT $ createConnection alice SCMInvitation -- connection fails Left (BROKER NETWORK) <- runExceptT $ joinConnection bob cReq "bob's connInfo" - ("", bobId1, DOWN) <- get alice - bobId1 `shouldBe` bobId + ("", "", DOWN srv conns) <- get alice + srv `shouldBe` testSMPServer + conns `shouldBe` [bobId] -- connection succeeds after server start Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do - ("", bobId2, UP) <- get alice - liftIO $ bobId2 `shouldBe` bobId + ("", "", UP srv1 conns1) <- get alice + liftIO $ do + srv1 `shouldBe` testSMPServer + conns1 `shouldBe` [bobId] aliceId <- joinConnection bob cReq "bob's connInfo" ("", _, CONF confId "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 752c97760..599e421f6 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -50,7 +50,7 @@ testNotificationToken APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg initAgentServers Right () <- runExceptT $ do let tkn = DeviceToken PPApns "abcd" - registerNtfToken a tkn + NTRegistered <- registerNtfToken a tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -80,13 +80,13 @@ testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg initAgentServers Right () <- runExceptT $ do let tkn = DeviceToken PPApns "abcd" - registerNtfToken a tkn + NTRegistered <- registerNtfToken a tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" liftIO $ sendApnsResponse APNSRespOk - registerNtfToken a tkn + NTRegistered <- registerNtfToken a tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- atomically $ readTBQueue apnsQ _ <- ntfData' .-> "verification" @@ -107,7 +107,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do a' <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers Right () <- runExceptT $ do let tkn = DeviceToken PPApns "abcd" - registerNtfToken a tkn + NTRegistered <- registerNtfToken a tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ verification <- ntfData .-> "verification" @@ -115,7 +115,7 @@ testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do liftIO $ sendApnsResponse APNSRespOk verifyNtfToken a tkn verification nonce - registerNtfToken a' tkn + NTRegistered <- registerNtfToken a' tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- atomically $ readTBQueue apnsQ verification' <- ntfData' .-> "verification" @@ -141,7 +141,7 @@ testNtfTokenServerRestart t APNSMockServer {apnsQ} = do a <- getSMPAgentClient agentCfg initAgentServers let tkn = DeviceToken PPApns "abcd" Right ntfData <- withNtfServer t . runExceptT $ do - registerNtfToken a tkn + NTRegistered <- registerNtfToken a tkn APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- atomically $ readTBQueue apnsQ liftIO $ sendApnsResponse APNSRespOk diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index d23ab522e..0865e7208 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -155,10 +155,13 @@ smpAgentTest1_1_1 test' = _test [h] = test' h _test _ = error "expected 1 handle" +testSMPServer :: SMPServer +testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001" + initAgentServers :: InitialAgentServers initAgentServers = InitialAgentServers - { smp = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"], + { smp = L.fromList [testSMPServer], ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"] }