diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 6cbc4d4aa..c0ab7df8e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -45,6 +45,7 @@ import Control.Monad.IO.Unlift import Control.Monad.Reader import Control.Monad.Trans.Except import Crypto.Random +import Control.Monad.STM (retry) import Data.Bifunctor (first) import Data.ByteString.Base64 (encode) import Data.ByteString.Char8 (ByteString) @@ -95,7 +96,6 @@ import System.Exit (exitFailure) import System.IO (hPrint, hPutStrLn, hSetNewlineMode, universalNewlineMode) import System.Mem.Weak (deRefWeak) import UnliftIO (timeout) -import UnliftIO.Async (mapConcurrently) import UnliftIO.Concurrent import UnliftIO.Directory (doesFileExist, renameFile) import UnliftIO.Exception @@ -182,12 +182,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s) endPreviousSubscriptions (qId, c) = do - tId <- atomically $ stateTVar (endThreadSeq c) $ \next -> (next, next + 1) - t <- forkIO $ do - labelMyThread $ label <> ".endPreviousSubscriptions" + forkClient c (label <> ".endPreviousSubscriptions") $ atomically $ writeTBQueue (sndQ c) [(CorrId "", qId, END)] - atomically $ modifyTVar' (endThreads c) $ IM.delete tId - mkWeakThreadId t >>= atomically . modifyTVar' (endThreads c) . IM.insert tId atomically $ TM.lookupDelete qId (clientSubs c) receiveFromProxyAgent :: ProxyAgent -> M () @@ -364,10 +360,10 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do ss <- unliftIO u $ asks serverStats let putStat :: Show a => ByteString -> (ServerStats -> TVar a) -> IO () putStat label var = readTVarIO (var ss) >>= \v -> B.hPutStr h $ label <> ": " <> bshow v <> "\n" - putProxyStat :: ByteString -> (ServerStats -> ProxyStats) -> IO () + putProxyStat :: ByteString -> (ServerStats -> ProxyStats) -> IO () putProxyStat label var = do ProxyStatsData {_pRequests, _pSuccesses, _pErrorsConnect, _pErrorsCompat, _pErrorsOther} <- atomically $ getProxyStatsData $ var ss - B.hPutStr h $ label <> ": requests=" <> bshow _pRequests <> ", successes=" <> bshow _pSuccesses <> ", errorsConnect=" <> bshow _pErrorsConnect <> ", errorsCompat=" <> bshow _pErrorsCompat <> ", errorsOther=" <> bshow _pErrorsOther <> "\n" + B.hPutStr h $ label <> ": requests=" <> bshow _pRequests <> ", successes=" <> bshow _pSuccesses <> ", errorsConnect=" <> bshow _pErrorsConnect <> ", errorsCompat=" <> bshow _pErrorsCompat <> ", errorsOther=" <> bshow _pErrorsOther <> "\n" putStat "fromTime" fromTime putStat "qCreated" qCreated putStat "qSecured" qSecured @@ -650,18 +646,39 @@ dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/Xo dummyKeyX25519 :: C.PublicKey 'C.X25519 dummyKeyX25519 = "MCowBQYDK2VuAyEA4JGSMYht18H4mas/jHeBwfcM7jLwNYJNOAhi2/g4RXg=" +forkClient :: Client -> String -> M () -> M () +forkClient Client {endThreads, endThreadSeq} label action = do + tId <- atomically $ stateTVar endThreadSeq $ \next -> (next, next + 1) + t <- forkIO $ do + labelMyThread label + action `finally` atomically (modifyTVar' endThreads $ IM.delete tId) + mkWeakThreadId t >>= atomically . modifyTVar' endThreads . IM.insert tId + client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M () -client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId} Server {subscribedQ, ntfSubscribedQ, notifiers} = do +client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" forever $ do (proxied, rs) <- partitionEithers . L.toList <$> (mapM processCommand =<< atomically (readTBQueue rcvQ)) forM_ (L.nonEmpty rs) reply - -- TODO cancel this thread if the client gets disconnected - -- TODO limit client concurrency - forM_ (L.nonEmpty proxied) $ \cmds -> forkIO $ mapConcurrently processProxiedCmd cmds >>= reply + forM_ (L.nonEmpty proxied) $ \cmds -> mapM forkProxiedCmd cmds >>= mapM (atomically . takeTMVar) >>= reply where reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m () reply = atomically . writeTBQueue sndQ + forkProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (TMVar (Transmission BrokerMsg)) + forkProxiedCmd cmd = do + res <- newEmptyTMVarIO + bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ + -- commands MUST be processed under a reasonable timeout or the client would halt + processProxiedCmd cmd >>= atomically . putTMVar res + pure res + where + wait = do + ServerConfig {serverClientConcurrency} <- asks config + atomically $ do + used <- readTVar procThreads + when (used >= serverClientConcurrency) retry + writeTVar procThreads $! used + 1 + signal = atomically $ modifyTVar' procThreads (\t -> t - 1) processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Transmission BrokerMsg) processProxiedCmd (corrId, sessId, command) = (corrId, sessId,) <$> case command of PRXY srv auth -> ifM allowProxy getRelay (pure $ ERR $ PROXY BASIC_AUTH) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 52a6094bc..77adb94f4 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -82,7 +82,8 @@ data ServerConfig = ServerConfig -- | run listener on control port controlPort :: Maybe ServiceName, smpAgentCfg :: SMPClientAgentConfig, - allowSMPProxy :: Bool -- auth is the same with `newQueueBasicAuth` + allowSMPProxy :: Bool, -- auth is the same with `newQueueBasicAuth` + serverClientConcurrency :: Int } defMsgExpirationDays :: Int64 @@ -102,6 +103,9 @@ defaultInactiveClientExpiration = checkInterval = 3600 -- seconds, 1 hours } +defaultProxyClientConcurrency :: Int +defaultProxyClientConcurrency = 16 + data Env = Env { config :: ServerConfig, server :: Server, @@ -139,6 +143,7 @@ data Client = Client rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), msgQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), + procThreads :: TVar Int, endThreads :: TVar (IntMap (Weak ThreadId)), endThreadSeq :: TVar Int, thVersion :: VersionSMP, @@ -173,12 +178,13 @@ newClient nextClientId qSize thVersion sessionId createdAt = do rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize msgQ <- newTBQueue qSize + procThreads <- newTVar 0 endThreads <- newTVar IM.empty endThreadSeq <- newTVar 0 connected <- newTVar True rcvActiveAt <- newTVar createdAt sndActiveAt <- newTVar createdAt - return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, msgQ, endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} + return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, msgQ, procThreads, endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} newSubscription :: SubscriptionThread -> STM Sub newSubscription subThread = do diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 6af1ce2a5..980a7b8d0 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BasicAuth (..), ProtoServerWithAuth (ProtoServerWithAuth), pattern SMPServer) import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.CLI -import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defMsgExpirationDays, defaultInactiveClientExpiration, defaultMessageExpiration) +import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defMsgExpirationDays, defaultInactiveClientExpiration, defaultMessageExpiration, defaultProxyClientConcurrency) import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Transport (batchCmdsSMPVersion, sendingProxySMPVersion, simplexMQVersion, supportedSMPHandshakes, supportedServerSMPRelayVRange) import Simplex.Messaging.Transport.Client (TransportHost (..)) @@ -156,7 +156,9 @@ smpServerCLI cfgPath logPath = \# `socks_mode` can be 'onion' for SOCKS proxy to be used for .onion destination hosts only (default)\n\ \# or 'always' to be used for all destination hosts (can be used if it is an .onion server).\n\ \# socks_mode: onion\n\n\ - \[INACTIVE_CLIENTS]\n\ + \# Limit number of threads a client can spawn to process proxy commands in parrallel.\n" + <> ("# client_concurrency: " <> show defaultProxyClientConcurrency <> "\n\n") + <> "[INACTIVE_CLIENTS]\n\ \# TTL and interval to check inactive clients\n\ \disconnect: off\n" <> ("# ttl: " <> show (ttl defaultInactiveClientExpiration) <> "\n") @@ -251,7 +253,8 @@ smpServerCLI cfgPath logPath = ownServerDomains = either (const []) textToOwnServers $ lookupValue "PROXY" "own_server_domains" ini, persistErrorInterval = 30 -- seconds }, - allowSMPProxy = True + allowSMPProxy = True, + serverClientConcurrency = readIniDefault defaultProxyClientConcurrency "PROXY" "client_concurrency" ini } textToSocksMode :: Text -> SocksMode textToSocksMode = \case diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index d2e11d29b..99633ac1d 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -119,7 +119,8 @@ cfg = transportConfig = defaultTransportServerConfig {Server.alpn = Just supportedSMPHandshakes}, controlPort = Nothing, smpAgentCfg = defaultSMPClientAgentConfig, - allowSMPProxy = False + allowSMPProxy = False, + serverClientConcurrency = 2 } cfgV7 :: ServerConfig