From 9f1a9a5f15bf55e6b92293153315871cabdced7a Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Thu, 11 Jul 2024 12:39:51 +0400 Subject: [PATCH] notify up, fix test --- src/Simplex/Messaging/Agent/Client.hs | 15 ++++++++++++++- tests/AgentTests/FunctionalAPITests.hs | 18 ++++++++---------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 3a74f5def..f61c1b483 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -932,7 +932,20 @@ reconnectSMPServerClients c = do pure (clients, qs <> qs') atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone DOWN_ALL) mapM_ (liftIO . forkIO . closeClient_ c) clients - void $ subscribeQueues c qs + (qSubRs, _) <- subscribeQueues c qs + let upConns = subscribedConnsByServer qSubRs + forM_ (M.toList upConns) $ \(server, connIds) -> + liftIO $ notifyUP server (S.toList . S.fromList $ connIds) + where + subscribedConnsByServer :: [(RcvQueue, Either AgentErrorType ())] -> Map SMPServer [ConnId] + subscribedConnsByServer = foldl' insertConnId M.empty + where + insertConnId :: Map SMPServer [ConnId] -> (RcvQueue, Either AgentErrorType ()) -> Map SMPServer [ConnId] + insertConnId acc (RcvQueue {server, connId}, qSubResult) = case qSubResult of + Right _ -> M.insertWith (<>) server [connId] acc + Left _ -> acc + notifyUP :: SMPServer -> [ConnId] -> IO () + notifyUP server connIds = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone (UP server connIds)) reconnectSMPServer :: AgentClient -> UserId -> SMPServer -> IO () reconnectSMPServer c userId srv = do diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 7d630630f..131cda3c5 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -14,6 +14,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} +{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.FunctionalAPITests ( functionalAPITests, @@ -2904,7 +2905,7 @@ testDeliveryReceiptsConcurrent t = _ -> error "timeout" testTwoUsers :: HasCallStack => IO () -testTwoUsers = withAgentClients2 $ \a b -> do +testTwoUsers = withAgentClientsCfg2 aCfg aCfg $ \a b -> do let nc = netCfg initAgentServers sessionMode nc `shouldBe` TSMUser runRight_ $ do @@ -2916,8 +2917,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do b `hasClients` 1 liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 2 @@ -2926,7 +2926,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do liftIO $ threadDelay 250000 liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 1 @@ -2939,10 +2939,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do b `hasClients` 1 liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a - ("", "", UP _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 4 exchangeGreetingsMsgId 6 a bId1 b aId1 @@ -2952,8 +2949,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do liftIO $ threadDelay 250000 liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 2 exchangeGreetingsMsgId 8 a bId1 b aId1 @@ -2961,6 +2957,8 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 6 a bId2 b aId2 exchangeGreetingsMsgId 6 a bId2' b aId2' where + aCfg :: AgentConfig + aCfg = agentCfg {tbqSize = 16} hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO () hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n