smp-server: conserve resources (#1194)

* transport: force auth params, remove async wrapper

* stricter new messages

* bang more thunks

* style

* don't produce msgQuota unless requested

* strict

* refactor

* remove bangs

---------

Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com>
This commit is contained in:
Alexander Bondarenko
2024-06-24 15:15:08 +03:00
committed by GitHub
parent d47c099ac9
commit 9e7e0d102d
8 changed files with 35 additions and 25 deletions
+3 -1
View File
@@ -427,7 +427,9 @@ setNetworkConfig c@AgentClient {useNetworkConfig} cfg' = do
(_, cfg) <- readTVar useNetworkConfig (_, cfg) <- readTVar useNetworkConfig
if cfg == cfg' if cfg == cfg'
then pure False then pure False
else True <$ (writeTVar useNetworkConfig $! (slowNetworkConfig cfg', cfg')) else
let cfgSlow = slowNetworkConfig cfg'
in True <$ (cfgSlow `seq` writeTVar useNetworkConfig (cfgSlow, cfg'))
when changed $ reconnectAllServers c when changed $ reconnectAllServers c
setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO () setUserNetworkInfo :: AgentClient -> UserNetworkInfo -> IO ()
+9 -8
View File
@@ -100,6 +100,7 @@ module Simplex.Messaging.Client
where where
import Control.Applicative ((<|>)) import Control.Applicative ((<|>))
import Control.Concurrent (ThreadId, forkFinally, killThread, mkWeakThreadId)
import Control.Concurrent.Async import Control.Concurrent.Async
import Control.Concurrent.STM import Control.Concurrent.STM
import Control.Exception import Control.Exception
@@ -138,13 +139,14 @@ import Simplex.Messaging.Transport.KeepAlive
import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Transport.WebSockets (WS)
import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tshow, whenM) import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tshow, whenM)
import Simplex.Messaging.Version import Simplex.Messaging.Version
import System.Mem.Weak (Weak, deRefWeak)
import System.Timeout (timeout) import System.Timeout (timeout)
-- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- | 'SMPClient' is a handle used to send commands to a specific SMP server.
-- --
-- Use 'getSMPClient' to connect to an SMP server and create a client handle. -- Use 'getSMPClient' to connect to an SMP server and create a client handle.
data ProtocolClient v err msg = ProtocolClient data ProtocolClient v err msg = ProtocolClient
{ action :: Maybe (Async ()), { action :: Maybe (Weak ThreadId),
thParams :: THandleParams v 'TClient, thParams :: THandleParams v 'TClient,
sessionTs :: UTCTime, sessionTs :: UTCTime,
client_ :: PClient v err msg client_ :: PClient v err msg
@@ -475,15 +477,14 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
cVar <- newEmptyTMVarIO cVar <- newEmptyTMVarIO
let tcConfig = (transportClientConfig networkConfig useHost) {alpn = clientALPN} let tcConfig = (transportClientConfig networkConfig useHost) {alpn = clientALPN}
username = proxyUsername transportSession username = proxyUsername transportSession
action <- tId <-
async $ runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar)
runTransportClient tcConfig (Just username) useHost port' (Just $ keyHash srv) (client t c cVar) `forkFinally` \_ -> void (atomically . tryPutTMVar cVar $ Left PCENetworkError)
`finally` atomically (tryPutTMVar cVar $ Left PCENetworkError)
c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar) c_ <- tcpConnectTimeout `timeout` atomically (takeTMVar cVar)
case c_ of case c_ of
Just (Right c') -> pure $ Right c' {action = Just action} Just (Right c') -> mkWeakThreadId tId >>= \tId' -> pure $ Right c' {action = Just tId'}
Just (Left e) -> pure $ Left e Just (Left e) -> pure $ Left e
Nothing -> cancel action $> Left PCENetworkError Nothing -> killThread tId $> Left PCENetworkError
useTransport :: (ServiceName, ATransport) useTransport :: (ServiceName, ATransport)
useTransport = case port srv of useTransport = case port srv of
@@ -589,7 +590,7 @@ proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (
-- | Disconnects client from the server and terminates client threads. -- | Disconnects client from the server and terminates client threads.
closeProtocolClient :: ProtocolClient v err msg -> IO () closeProtocolClient :: ProtocolClient v err msg -> IO ()
closeProtocolClient = mapM_ uninterruptibleCancel . action closeProtocolClient = mapM_ (deRefWeak >=> mapM_ killThread) . action
{-# INLINE closeProtocolClient #-} {-# INLINE closeProtocolClient #-}
-- | SMP client error type. -- | SMP client error type.
+3 -3
View File
@@ -1,8 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
@@ -171,7 +170,8 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke
case r of case r of
Right smp -> do Right smp -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
let c = (isOwnServer ca srv, smp) let !owned = isOwnServer ca srv
!c = (owned, smp)
atomically $ do atomically $ do
putTMVar (sessionVar v) (Right c) putTMVar (sessionVar v) (Right c)
TM.insert (sessionId $ thParams smp) c smpSessions TM.insert (sessionId $ thParams smp) c smpSessions
@@ -382,7 +382,7 @@ send :: Transport c => THandleNTF c 'TServer -> NtfServerClient -> IO ()
send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
t <- atomically $ readTBQueue sndQ t <- atomically $ readTBQueue sndQ
void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)] void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)]
atomically . writeTVar sndActiveAt =<< liftIO getSystemTime atomically . (writeTVar sndActiveAt $!) =<< liftIO getSystemTime
-- instance Show a => Show (TVar a) where -- instance Show a => Show (TVar a) where
-- show x = unsafePerformIO $ show <$> readTVarIO x -- show x = unsafePerformIO $ show <$> readTVarIO x
+7 -6
View File
@@ -525,7 +525,7 @@ receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiv
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
forever $ do forever $ do
ts <- L.toList <$> liftIO (tGet h) ts <- L.toList <$> liftIO (tGet h)
atomically . writeTVar rcvActiveAt =<< liftIO getSystemTime atomically . (writeTVar rcvActiveAt $!) =<< liftIO getSystemTime
stats <- asks serverStats stats <- asks serverStats
(errs, cmds) <- partitionEithers <$> mapM (cmdAction stats) ts (errs, cmds) <- partitionEithers <$> mapM (cmdAction stats) ts
write sndQ errs write sndQ errs
@@ -581,7 +581,7 @@ tSend :: Transport c => MVar (THandleSMP c 'TServer) -> Client -> NonEmpty (Tran
tSend th Client {sndActiveAt} ts = do tSend th Client {sndActiveAt} ts = do
withMVar th $ \h@THandle {params} -> withMVar th $ \h@THandle {params} ->
void . tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts void . tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts
atomically . writeTVar sndActiveAt =<< liftIO getSystemTime atomically . (writeTVar sndActiveAt $!) =<< liftIO getSystemTime
disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport :: Transport c => THandle v c 'TServer -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
@@ -1037,15 +1037,16 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
mkMessage body = do mkMessage body = do
msgId <- randomId =<< asks (msgIdBytes . config) msgId <- randomId =<< asks (msgIdBytes . config)
msgTs <- liftIO getSystemTime msgTs <- liftIO getSystemTime
pure $ Message msgId msgTs msgFlags body pure $! Message msgId msgTs msgFlags body
expireMessages :: MsgQueue -> M () expireMessages :: MsgQueue -> M ()
expireMessages q = do expireMessages q = do
msgExp <- asks $ messageExpiration . config msgExp <- asks $ messageExpiration . config
old <- liftIO $ mapM expireBeforeEpoch msgExp old <- liftIO $ mapM expireBeforeEpoch msgExp
stats <- asks serverStats
deleted <- atomically $ sum <$> mapM (deleteExpiredMsgs q) old deleted <- atomically $ sum <$> mapM (deleteExpiredMsgs q) old
atomically $ modifyTVar' (msgExpired stats) (+ deleted) when (deleted > 0) $ do
stats <- asks serverStats
atomically $ modifyTVar' (msgExpired stats) (+ deleted)
trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool) trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool)
trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg = trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg =
@@ -1164,7 +1165,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
msgTs' = messageTs msg msgTs' = messageTs msg
setDelivered :: Sub -> Message -> STM Bool setDelivered :: Sub -> Message -> STM Bool
setDelivered s msg = tryPutTMVar (delivered s) (messageId msg) setDelivered s msg = tryPutTMVar (delivered s) $! messageId msg
getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue getStoreMsgQueue :: T.Text -> RecipientId -> M MsgQueue
getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do getStoreMsgQueue name rId = time (name <> " getMsgQueue") $ do
+3 -2
View File
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
@@ -75,7 +76,7 @@ snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue .
pure msgs pure msgs
writeMsg :: MsgQueue -> Message -> STM (Maybe Message) writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do
canWrt <- readTVar canWrite canWrt <- readTVar canWrite
empty <- isEmptyTQueue q empty <- isEmptyTQueue q
if canWrt || empty if canWrt || empty
@@ -85,7 +86,7 @@ writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} msg = do
modifyTVar' size (+ 1) modifyTVar' size (+ 1)
if canWrt' if canWrt'
then writeTQueue q msg $> Just msg then writeTQueue q msg $> Just msg
else writeTQueue q msgQuota $> Nothing else (writeTQueue q $! msgQuota) $> Nothing
else pure Nothing else pure Nothing
where where
msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg} msgQuota = MessageQuota {msgId = msgId msg, msgTs = msgTs msg}
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
@@ -69,7 +70,7 @@ secureQueue QueueStore {queues} rId sKey =
readTVar qVar >>= \q -> case senderKey q of readTVar qVar >>= \q -> case senderKey q of
Just k -> pure $ if sKey == k then Just q else Nothing Just k -> pure $ if sKey == k then Just q else Nothing
_ -> _ ->
let q' = q {senderKey = Just sKey} let !q' = q {senderKey = Just sKey}
in writeTVar qVar q' $> Just q' in writeTVar qVar q' $> Just q'
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec) addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
+7 -3
View File
@@ -83,7 +83,7 @@ module Simplex.Messaging.Transport
where where
import Control.Applicative (optional) import Control.Applicative (optional)
import Control.Monad (forM) import Control.Monad (forM, (<$!>))
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Trans.Except (throwE) import Control.Monad.Trans.Except (throwE)
import qualified Data.Aeson.TH as J import qualified Data.Aeson.TH as J
@@ -540,12 +540,12 @@ smpClientHandshake c ks_ keyHash@(C.KeyHash kh) smpVRange = do
smpTHandleServer :: forall c. THandleSMP c 'TServer -> VersionSMP -> VersionRangeSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c 'TServer smpTHandleServer :: forall c. THandleSMP c 'TServer -> VersionSMP -> VersionRangeSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c 'TServer
smpTHandleServer th v vr pk k_ = smpTHandleServer th v vr pk k_ =
let thAuth = THAuthServer {serverPrivKey = pk, sessSecret' = (`C.dh'` pk) <$> k_} let thAuth = THAuthServer {serverPrivKey = pk, sessSecret' = (`C.dh'` pk) <$!> k_}
in smpTHandle_ th v vr (Just thAuth) in smpTHandle_ th v vr (Just thAuth)
smpTHandleClient :: forall c. THandleSMP c 'TClient -> VersionSMP -> VersionRangeSMP -> Maybe C.PrivateKeyX25519 -> Maybe (C.PublicKeyX25519, (X.CertificateChain, X.SignedExact X.PubKey)) -> THandleSMP c 'TClient smpTHandleClient :: forall c. THandleSMP c 'TClient -> VersionSMP -> VersionRangeSMP -> Maybe C.PrivateKeyX25519 -> Maybe (C.PublicKeyX25519, (X.CertificateChain, X.SignedExact X.PubKey)) -> THandleSMP c 'TClient
smpTHandleClient th v vr pk_ ck_ = smpTHandleClient th v vr pk_ ck_ =
let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = ck, sessSecret = C.dh' k <$> pk_}) <$> ck_ let thAuth = (\(k, ck) -> THAuthClient {serverPeerPubKey = k, serverCertKey = forceCertChain ck, sessSecret = C.dh' k <$!> pk_}) <$!> ck_
in smpTHandle_ th v vr thAuth in smpTHandle_ th v vr thAuth
smpTHandle_ :: forall c p. THandleSMP c p -> VersionSMP -> VersionRangeSMP -> Maybe (THandleAuth p) -> THandleSMP c p smpTHandle_ :: forall c p. THandleSMP c p -> VersionSMP -> VersionRangeSMP -> Maybe (THandleAuth p) -> THandleSMP c p
@@ -554,6 +554,10 @@ smpTHandle_ th@THandle {params} v vr thAuth =
let params' = params {thVersion = v, thServerVRange = vr, thAuth, implySessId = v >= authCmdsSMPVersion} let params' = params {thVersion = v, thServerVRange = vr, thAuth, implySessId = v >= authCmdsSMPVersion}
in (th :: THandleSMP c p) {params = params'} in (th :: THandleSMP c p) {params = params'}
{-# INLINE forceCertChain #-}
forceCertChain :: (X.CertificateChain, X.SignedExact T.PubKey) -> (X.CertificateChain, X.SignedExact T.PubKey)
forceCertChain cert@(X.CertificateChain cc, signedKey) = length (show cc) `seq` show signedKey `seq` cert
-- This function is only used with v >= 8, so currently it's a simple record update. -- This function is only used with v >= 8, so currently it's a simple record update.
-- It may require some parameters update in the future, to be consistent with smpTHandle_. -- It may require some parameters update in the future, to be consistent with smpTHandle_.
smpTHParamsSetVersion :: VersionSMP -> THandleParams SMPVersion p -> THandleParams SMPVersion p smpTHParamsSetVersion :: VersionSMP -> THandleParams SMPVersion p -> THandleParams SMPVersion p