server: preprocess proxy commands that will not be connecting to network to reduce concurrency, do not wait for destination relay responses before processing the next command (#1174)

* server: preprocess proxy commands that will not be connecting to network to reduce concurrency

* implementation

* tests

* increase proxy client concurrency

* simplify

* refactor

* refactor2

* rename

* refactor3

* fix 8.10.7
This commit is contained in:
Evgeny Poberezkin
2024-05-28 09:38:47 +01:00
committed by GitHub
parent c8b2bb2ae1
commit 4a96dbf871
5 changed files with 99 additions and 72 deletions
+15 -1
View File
@@ -45,7 +45,7 @@ import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Util (catchAll_, ifM, toChunks, whenM, ($>>=))
import Simplex.Messaging.Util (catchAll_, ifM, toChunks, whenM, ($>>=), (<$$>))
import System.Timeout (timeout)
import UnliftIO (async)
import qualified UnliftIO.Exception as E
@@ -287,6 +287,20 @@ notify :: MonadIO m => SMPClientAgent -> SMPClientAgentEvent -> m ()
notify ca evt = atomically $ writeTBQueue (agentQ ca) evt
{-# INLINE notify #-}
-- Returns already connected client for proxying messages or Nothing if client is absent, not connected yet or stores expired error.
-- If Nothing is return proxy will spawn a new thread to wait or to create another client connection to destination relay.
getConnectedSMPServerClient :: SMPClientAgent -> SMPServer -> IO (Maybe (Either SMPClientError (OwnServer, SMPClient)))
getConnectedSMPServerClient SMPClientAgent {smpClients} srv =
atomically (TM.lookup srv smpClients $>>= \v -> (v,) <$$> tryReadTMVar (sessionVar v)) -- Nothing: client is absent or not connected yet
$>>= \case
(_, Right r) -> pure $ Just $ Right r
(v, Left (e, ts_)) ->
pure ts_ $>>= \ts -> -- proxy will create a new connection if ts_ is Nothing
ifM
((ts <) <$> liftIO getCurrentTime) -- error persistence interval period expired?
(Nothing <$ atomically (removeSessVar v srv smpClients)) -- proxy will create a new connection
(pure $ Just $ Left e) -- not expired, returning error
lookupSMPServerClient :: SMPClientAgent -> SessionId -> STM (Maybe (OwnServer, SMPClient))
lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookup sessId smpSessions
+74 -70
View File
@@ -59,7 +59,7 @@ import Data.List (intercalate, mapAccumR)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import Data.Maybe (catMaybes, fromMaybe, isNothing)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1)
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
@@ -70,8 +70,8 @@ import GHC.Stats (getRTSStats)
import GHC.TypeLits (KnownNat)
import Network.Socket (ServiceName, Socket, socketToHandle)
import Simplex.Messaging.Agent.Lock
import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), forwardSMPMessage, smpProxyError, temporaryClientError)
import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient)
import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPMessage, smpProxyError, temporaryClientError)
import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
@@ -103,7 +103,6 @@ import UnliftIO.IO
import UnliftIO.STM
#if MIN_VERSION_base(4,18,0)
import Data.List (sort)
import Data.Maybe (fromMaybe)
import GHC.Conc (listThreads, threadStatus)
import GHC.Conc.Sync (threadLabel)
#endif
@@ -657,40 +656,33 @@ forkClient Client {endThreads, endThreadSeq} label action = do
client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M ()
client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
forever $ do
(proxied, rs) <- partitionEithers . L.toList <$> (mapM processCommand =<< atomically (readTBQueue rcvQ))
forM_ (L.nonEmpty rs) reply
forM_ (L.nonEmpty proxied) $ \cmds -> mapM forkProxiedCmd cmds >>= mapM (atomically . takeTMVar) >>= reply
forever $
atomically (readTBQueue rcvQ)
>>= mapM processCommand
>>= mapM_ reply . L.nonEmpty . catMaybes . L.toList
where
reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m ()
reply = atomically . writeTBQueue sndQ
forkProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (TMVar (Transmission BrokerMsg))
forkProxiedCmd cmd = do
res <- newEmptyTMVarIO
bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $
-- commands MUST be processed under a reasonable timeout or the client would halt
processProxiedCmd cmd >>= atomically . putTMVar res
pure res
where
wait = do
ServerConfig {serverClientConcurrency} <- asks config
atomically $ do
used <- readTVar procThreads
when (used >= serverClientConcurrency) retry
writeTVar procThreads $! used + 1
signal = atomically $ modifyTVar' procThreads (\t -> t - 1)
processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Transmission BrokerMsg)
processProxiedCmd (corrId, sessId, command) = (corrId, sessId,) <$> case command of
PRXY srv auth -> ifM allowProxy getRelay (pure $ ERR $ PROXY BASIC_AUTH)
processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Maybe (Transmission BrokerMsg))
processProxiedCmd (corrId, sessId, command) = (corrId,sessId,) <$$> case command of
PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH)
where
allowProxy = do
ServerConfig {allowSMPProxy, newQueueBasicAuth} <- asks config
pure $ allowSMPProxy && maybe True ((== auth) . Just) newQueueBasicAuth
getRelay = do
ProxyAgent {smpAgent = a} <- asks proxyAgent
liftIO (getConnectedSMPServerClient a srv) >>= \case
Just r -> Just <$> proxyServerResponse a r
Nothing ->
forkProxiedCmd $
liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError))
>>= proxyServerResponse a
proxyServerResponse :: SMPClientAgent -> Either SMPClientError (OwnServer, SMPClient) -> M BrokerMsg
proxyServerResponse a smp_ = do
ServerStats {pRelays, pRelaysOwn} <- asks serverStats
let inc = mkIncProxyStats pRelays pRelaysOwn
ProxyAgent {smpAgent = a} <- asks proxyAgent
liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError)) >>= \case
case smp_ of
Right (own, smp) -> do
inc own pRequests
case proxyResp smp of
@@ -704,7 +696,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
where
proxyResp smp =
let THandleParams {sessionId = srvSessId, thVersion, thServerVRange, thAuth} = thParams smp
in case compatibleVRange thServerVRange proxiedSMPRelayVRange of
in case compatibleVRange thServerVRange proxiedSMPRelayVRange of
-- Cap the destination relay version range to prevent client version fingerprinting.
-- See comment for proxiedSMPRelayVersion.
Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of
@@ -718,54 +710,66 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
atomically (lookupSMPServerClient a sessId) >>= \case
Just (own, smp) -> do
inc own pRequests
if
| v >= sendingProxySMPVersion ->
liftIO (runExceptT (forwardSMPMessage smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case
Right r -> PRES r <$ inc own pSuccesses
Left e -> case e of
PCEProtocolError {} -> ERR err <$ inc own pSuccesses
_ -> ERR err <$ inc own pErrorsOther
where
err = smpProxyError e
| otherwise -> ERR (transportErr TEVersion) <$ inc own pErrorsCompat
if v >= sendingProxySMPVersion
then forkProxiedCmd $ do
liftIO (runExceptT (forwardSMPMessage smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case
Right r -> PRES r <$ inc own pSuccesses
Left e -> ERR (smpProxyError e) <$ case e of
PCEProtocolError {} -> inc own pSuccesses
_ -> inc own pErrorsOther
else Just (ERR $ transportErr TEVersion) <$ inc own pErrorsCompat
where
THandleParams {thVersion = v} = thParams smp
Nothing -> inc False pRequests >> inc False pErrorsConnect $> ERR (PROXY NO_SESSION)
Nothing -> inc False pRequests >> inc False pErrorsConnect $> Just (ERR $ PROXY NO_SESSION)
where
forkProxiedCmd :: M BrokerMsg -> M (Maybe BrokerMsg)
forkProxiedCmd cmdAction = do
bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ do
-- commands MUST be processed under a reasonable timeout or the client would halt
cmdAction >>= \t -> reply [(corrId, sessId, t)]
pure Nothing
where
wait = do
ServerConfig {serverClientConcurrency} <- asks config
atomically $ do
used <- readTVar procThreads
when (used >= serverClientConcurrency) retry
writeTVar procThreads $! used + 1
signal = atomically $ modifyTVar' procThreads (\t -> t - 1)
transportErr :: TransportError -> ErrorType
transportErr = PROXY . BROKER . TRANSPORT
mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> TVar Int) -> m ()
mkIncProxyStats ps psOwn = \own sel -> do
atomically $ modifyTVar' (sel ps) (+ 1)
when own $ atomically $ modifyTVar' (sel psOwn) (+ 1)
processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Either (Transmission (Command 'ProxiedClient)) (Transmission BrokerMsg))
processCommand (qr_, (corrId, queueId, cmd)) = do
st <- asks queueStore
case cmd of
Cmd SProxiedClient command -> pure $ Left (corrId, queueId, command)
Cmd SSender command -> Right <$> case command of
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
PING -> pure (corrId, "", PONG)
RFWD encBlock -> (corrId, "",) <$> processForwardedCommand encBlock
Cmd SNotifier NSUB -> Right <$> subscribeNotifications
Cmd SRecipient command ->
Right <$> case command of
NEW rKey dhKey auth subMode ->
ifM
allowNew
(createQueue st rKey dhKey subMode)
(pure (corrId, queueId, ERR AUTH))
where
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth
SUB -> withQueue (`subscribeQueue` queueId)
GET -> withQueue getMessage
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
KEY sKey -> secureQueue_ st sKey
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
NDEL -> deleteQueueNotifier_ st
OFF -> suspendQueue_ st
DEL -> delQueueAndMsgs st
processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Maybe (Transmission BrokerMsg))
processCommand (qr_, (corrId, queueId, cmd)) = case cmd of
Cmd SProxiedClient command -> processProxiedCmd (corrId, queueId, command)
Cmd SSender command -> Just <$> case command of
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
PING -> pure (corrId, "", PONG)
RFWD encBlock -> (corrId, "",) <$> processForwardedCommand encBlock
Cmd SNotifier NSUB -> Just <$> subscribeNotifications
Cmd SRecipient command -> do
st <- asks queueStore
Just <$> case command of
NEW rKey dhKey auth subMode ->
ifM
allowNew
(createQueue st rKey dhKey subMode)
(pure (corrId, queueId, ERR AUTH))
where
allowNew = do
ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config
pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth
SUB -> withQueue (`subscribeQueue` queueId)
GET -> withQueue getMessage
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
KEY sKey -> secureQueue_ st sKey
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
NDEL -> deleteQueueNotifier_ st
OFF -> suspendQueue_ st
DEL -> delQueueAndMsgs st
where
createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg)
createQueue st recipientKey dhKey subMode = time "NEW" $ do
@@ -1036,7 +1040,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
Right t''@(_, (corrId', entId', cmd')) -> case cmd' of
Cmd SSender SEND {} ->
-- Left will not be returned by processCommand, as only SEND command is allowed
fromRight (corrId', entId', ERR INTERNAL) <$> lift (processCommand t'')
fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand t'')
_ ->
pure (corrId', entId', ERR $ CMD PROHIBITED)
-- encode response
+1 -1
View File
@@ -104,7 +104,7 @@ defaultInactiveClientExpiration =
}
defaultProxyClientConcurrency :: Int
defaultProxyClientConcurrency = 16
defaultProxyClientConcurrency = 32
data Env = Env
{ config :: ServerConfig,
+5
View File
@@ -6,6 +6,7 @@ module AgentTests.EqInstances where
import Data.Type.Equality
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Client (ProxiedRelay (..))
instance Eq SomeConn where
SomeConn d c == SomeConn d' c' = case testEquality d d' of
@@ -23,3 +24,7 @@ deriving instance Eq (StoredSndQueue q)
deriving instance Eq (DBQueueId q)
deriving instance Eq ClientNtfCreds
deriving instance Show ProxiedRelay
deriving instance Eq ProxiedRelay
+4
View File
@@ -12,6 +12,7 @@
module SMPProxyTests where
import AgentTests.EqInstances ()
import AgentTests.FunctionalAPITests
import Control.Logger.Simple
import Control.Monad (forM, forM_, forever)
@@ -150,10 +151,13 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do
QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe
let dec = decryptMsgV3 $ C.dh' srvDh rdhPriv
-- get proxy session
sess0 <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct")
sess <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct")
sess0 `shouldBe` sess
-- send via proxy to unsecured queue
forM_ unsecuredMsgs $ \msg -> do
runExceptT' (proxySMPMessage pc sess Nothing sndId noMsgFlags msg) `shouldReturn` Right ()
runExceptT' (proxySMPMessage pc sess {prSessionId = "bad session"} Nothing sndId noMsgFlags msg) `shouldReturn` Left (ProxyProtocolError $ SMP.PROXY SMP.NO_SESSION)
-- receive 1
(_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ
dec msgId encBody `shouldBe` Right msg