From 9e7e0d102dc39846103c71101b84d793f16ab8ab Mon Sep 17 00:00:00 2001 From: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:15:08 +0300 Subject: [PATCH] smp-server: conserve resources (#1194) * transport: force auth params, remove async wrapper * stricter new messages * bang more thunks * style * don't produce msgQuota unless requested * strict * refactor * remove bangs --------- Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> --- src/Simplex/Messaging/Agent.hs | 4 +++- src/Simplex/Messaging/Client.hs | 17 +++++++++-------- src/Simplex/Messaging/Client/Agent.hs | 6 +++--- src/Simplex/Messaging/Notifications/Server.hs | 2 +- src/Simplex/Messaging/Server.hs | 13 +++++++------ src/Simplex/Messaging/Server/MsgStore/STM.hs | 5 +++-- src/Simplex/Messaging/Server/QueueStore/STM.hs | 3 ++- src/Simplex/Messaging/Transport.hs | 10 +++++++--- 8 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 63f74b5b2..f1ab78200 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -427,7 +427,9 @@ setNetworkConfig c@AgentClient {useNetworkConfig} cfg' = do (_, cfg) <- readTVar useNetworkConfig if cfg == cfg' then pure False - else True <$ (writeTVar useNetworkConfig $! (slowNetworkConfig cfg', cfg')) + else + let cfgSlow = slowNetworkConfig cfg' + in True <$ (cfgSlow `seq` writeTVar useNetworkConfig (cfgSlow, cfg')) when changed $ reconnectAllServers c setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO () diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index e4413d595..de178e368 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -100,6 +100,7 @@ module Simplex.Messaging.Client where import Control.Applicative ((<|>)) +import Control.Concurrent (ThreadId, forkFinally, killThread, mkWeakThreadId) import Control.Concurrent.Async import Control.Concurrent.STM import Control.Exception @@ -138,13 +139,14 @@ import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tshow, whenM) import Simplex.Messaging.Version +import System.Mem.Weak (Weak, deRefWeak) import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. data ProtocolClient v err msg = ProtocolClient - { action :: Maybe (Async ()), + { action :: Maybe (Weak ThreadId), thParams :: THandleParams v 'TClient, sessionTs :: UTCTime, client_ :: PClient v err msg @@ -475,15 +477,14 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize cVar <- newEmptyTMVarIO let tcConfig = (transportClientConfig networkConfig useHost) {alpn = clientALPN} username = proxyUsername transportSession - action <- - async $ - runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar) - `finally` atomically (tryPutTMVar cVar $ Left PCENetworkError) + tId <- + runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar) + `forkFinally` \_ -> void (atomically . tryPutTMVar cVar $ Left PCENetworkError) c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar) case c_ of - Just (Right c') -> pure $ Right c' {action = Just action} + Just (Right c') -> mkWeakThreadId tId >>= \tId' -> pure $ Right c' {action = Just tId'} Just (Left e) -> pure $ Left e - Nothing -> cancel action $> Left PCENetworkError + Nothing -> killThread tId $> Left PCENetworkError useTransport :: (ServiceName, ATransport) useTransport = case port srv of @@ -589,7 +590,7 @@ proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" ( -- | Disconnects client from the server and terminates client threads. closeProtocolClient :: ProtocolClient v err msg -> IO () -closeProtocolClient = mapM_ uninterruptibleCancel . action +closeProtocolClient = mapM_ (deRefWeak >=> mapM_ killThread) . action {-# INLINE closeProtocolClient #-} -- | SMP client error type. diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 8781d87a6..e7c22eec2 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -1,8 +1,7 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} @@ -171,7 +170,8 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke case r of Right smp -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - let c = (isOwnServer ca srv, smp) + let !owned = isOwnServer ca srv + !c = (owned, smp) atomically $ do putTMVar (sessionVar v) (Right c) TM.insert (sessionId $ thParams smp) c smpSessions diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 892560660..5d3b4d806 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -382,7 +382,7 @@ send :: Transport c => THandleNTF c 'TServer -> NtfServerClient -> IO () send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do t <- atomically $ readTBQueue sndQ void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)] - atomically . writeTVar sndActiveAt =<< liftIO getSystemTime + atomically . (writeTVar sndActiveAt $!) =<< liftIO getSystemTime -- instance Show a => Show (TVar a) where -- show x = unsafePerformIO $ show <$> readTVarIO x diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index dfb4973ea..5c8d13a5e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -525,7 +525,7 @@ receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiv labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" forever $ do ts <- L.toList <$> liftIO (tGet h) - atomically . writeTVar rcvActiveAt =<< liftIO getSystemTime + atomically . (writeTVar rcvActiveAt $!) =<< liftIO getSystemTime stats <- asks serverStats (errs, cmds) <- partitionEithers <$> mapM (cmdAction stats) ts write sndQ errs @@ -581,7 +581,7 @@ tSend :: Transport c => MVar (THandleSMP c 'TServer) -> Client -> NonEmpty (Tran tSend th Client {sndActiveAt} ts = do withMVar th $ \h@THandle {params} -> void . tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts - atomically . writeTVar sndActiveAt =<< liftIO getSystemTime + atomically . (writeTVar sndActiveAt $!) =<< liftIO getSystemTime disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do @@ -1037,15 +1037,16 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi mkMessage body = do msgId <- randomId =<< asks (msgIdBytes . config) msgTs <- liftIO getSystemTime - pure $ Message msgId msgTs msgFlags body + pure $! Message msgId msgTs msgFlags body expireMessages :: MsgQueue -> M () expireMessages q = do msgExp <- asks $ messageExpiration . config old <- liftIO $ mapM expireBeforeEpoch msgExp - stats <- asks serverStats deleted <- atomically $ sum <$> mapM (deleteExpiredMsgs q) old - atomically $ modifyTVar' (msgExpired stats) (+ deleted) + when (deleted > 0) $ do + stats <- asks serverStats + atomically $ modifyTVar' (msgExpired stats) (+ deleted) trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool) trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg = @@ -1164,7 +1165,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi msgTs' = messageTs msg setDelivered :: Sub -> Message -> STM Bool - setDelivered s msg = tryPutTMVar (delivered s) (messageId msg) + setDelivered s msg = tryPutTMVar (delivered s) $! messageId msg getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index 2d735d1d4..e315c4fe5 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -75,7 +76,7 @@ snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue . pure msgs writeMsg :: MsgQueue -> Message -> STM (Maybe Message) -writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do +writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do canWrt <- readTVar canWrite empty <- isEmptyTQueue q if canWrt || empty @@ -85,7 +86,7 @@ writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do modifyTVar' size (+ 1) if canWrt' then writeTQueue q msg $> Just msg - else writeTQueue q msgQuota $> Nothing + else (writeTQueue q $! msgQuota) $> Nothing else pure Nothing where msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 8de7a38c6..d6cdaf10a 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -69,7 +70,7 @@ secureQueue QueueStore {queues} rId sKey = readTVar qVar >>= \q -> case senderKey q of Just k -> pure $ if sKey == k then Just q else Nothing _ -> - let q' = q {senderKey = Just sKey} + let !q' = q {senderKey = Just sKey} in writeTVar qVar q' $> Just q' addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 6eddcabf8..7088480f5 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -83,7 +83,7 @@ module Simplex.Messaging.Transport where import Control.Applicative (optional) -import Control.Monad (forM) +import Control.Monad (forM, (<$!>)) import Control.Monad.Except import Control.Monad.Trans.Except (throwE) import qualified Data.Aeson.TH as J @@ -540,12 +540,12 @@ smpClientHandshake c ks_ keyHash@(C.KeyHash kh) smpVRange = do smpTHandleServer :: forall c. THandleSMP c 'TServer -> VersionSMP -> VersionRangeSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c 'TServer smpTHandleServer th v vr pk k_ = - let thAuth = THAuthServer {serverPrivKey = pk, sessSecret' = (`C.dh'` pk) <$> k_} + let thAuth = THAuthServer {serverPrivKey = pk, sessSecret' = (`C.dh'` pk) <$!> k_} in smpTHandle_ th v vr (Just thAuth) smpTHandleClient :: forall c. THandleSMP c 'TClient -> VersionSMP -> VersionRangeSMP -> Maybe C.PrivateKeyX25519 -> Maybe (C.PublicKeyX25519, (X.CertificateChain, X.SignedExact X.PubKey)) -> THandleSMP c 'TClient smpTHandleClient th v vr pk_ ck_ = - let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = ck, sessSecret = C.dh' k <$> pk_}) <$> ck_ + let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = forceCertChain ck, sessSecret = C.dh' k <$!> pk_}) <$!> ck_ in smpTHandle_ th v vr thAuth smpTHandle_ :: forall c p. THandleSMP c p -> VersionSMP -> VersionRangeSMP -> Maybe (THandleAuth p) -> THandleSMP c p @@ -554,6 +554,10 @@ smpTHandle_ th@THandle {params} v vr thAuth = let params' = params {thVersion = v, thServerVRange = vr, thAuth, implySessId = v >= authCmdsSMPVersion} in (th :: THandleSMP c p) {params = params'} +{-# INLINE forceCertChain #-} +forceCertChain :: (X.CertificateChain, X.SignedExact T.PubKey) -> (X.CertificateChain, X.SignedExact T.PubKey) +forceCertChain cert@(X.CertificateChain cc, signedKey) = length (show cc) `seq` show signedKey `seq` cert + -- This function is only used with v >= 8, so currently it's a simple record update. -- It may require some parameters update in the future, to be consistent with smpTHandle_. smpTHParamsSetVersion :: VersionSMP -> THandleParams SMPVersion p -> THandleParams SMPVersion p