diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index bd878ada7..3b4a78cf5 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -45,7 +45,7 @@ import Simplex.Messaging.Session import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport -import Simplex.Messaging.Util (catchAll_, ifM, toChunks, whenM, ($>>=)) +import Simplex.Messaging.Util (catchAll_, ifM, toChunks, whenM, ($>>=), (<$$>)) import System.Timeout (timeout) import UnliftIO (async) import qualified UnliftIO.Exception as E @@ -287,6 +287,20 @@ notify :: MonadIO m => SMPClientAgent -> SMPClientAgentEvent -> m () notify ca evt = atomically $ writeTBQueue (agentQ ca) evt {-# INLINE notify #-} +-- Returns already connected client for proxying messages or Nothing if client is absent, not connected yet or stores expired error. +-- If Nothing is return proxy will spawn a new thread to wait or to create another client connection to destination relay. +getConnectedSMPServerClient :: SMPClientAgent -> SMPServer -> IO (Maybe (Either SMPClientError (OwnServer, SMPClient))) +getConnectedSMPServerClient SMPClientAgent {smpClients} srv = + atomically (TM.lookup srv smpClients $>>= \v -> (v,) <$$> tryReadTMVar (sessionVar v)) -- Nothing: client is absent or not connected yet + $>>= \case + (_, Right r) -> pure $ Just $ Right r + (v, Left (e, ts_)) -> + pure ts_ $>>= \ts -> -- proxy will create a new connection if ts_ is Nothing + ifM + ((ts <) <$> liftIO getCurrentTime) -- error persistence interval period expired? + (Nothing <$ atomically (removeSessVar v srv smpClients)) -- proxy will create a new connection + (pure $ Just $ Left e) -- not expired, returning error + lookupSMPServerClient :: SMPClientAgent -> SessionId -> STM (Maybe (OwnServer, SMPClient)) lookupSMPServerClient SMPClientAgent {smpSessions} sessId = TM.lookup sessId smpSessions diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index c0ab7df8e..adf0b5df7 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -59,7 +59,7 @@ import Data.List (intercalate, mapAccumR) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (isNothing) +import Data.Maybe (catMaybes, fromMaybe, isNothing) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) @@ -70,8 +70,8 @@ import GHC.Stats (getRTSStats) import GHC.TypeLits (KnownNat) import Network.Socket (ServiceName, Socket, socketToHandle) import Simplex.Messaging.Agent.Lock -import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), forwardSMPMessage, smpProxyError, temporaryClientError) -import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient) +import Simplex.Messaging.Client (ProtocolClient (thParams), ProtocolClientError (..), SMPClient, SMPClientError, forwardSMPMessage, smpProxyError, temporaryClientError) +import Simplex.Messaging.Client.Agent (OwnServer, SMPClientAgent (..), SMPClientAgentEvent (..), closeSMPClientAgent, getSMPServerClient'', isOwnServer, lookupSMPServerClient, getConnectedSMPServerClient) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -103,7 +103,6 @@ import UnliftIO.IO import UnliftIO.STM #if MIN_VERSION_base(4,18,0) import Data.List (sort) -import Data.Maybe (fromMaybe) import GHC.Conc (listThreads, threadStatus) import GHC.Conc.Sync (threadLabel) #endif @@ -657,40 +656,33 @@ forkClient Client {endThreads, endThreadSeq} label action = do client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M () client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands" - forever $ do - (proxied, rs) <- partitionEithers . L.toList <$> (mapM processCommand =<< atomically (readTBQueue rcvQ)) - forM_ (L.nonEmpty rs) reply - forM_ (L.nonEmpty proxied) $ \cmds -> mapM forkProxiedCmd cmds >>= mapM (atomically . takeTMVar) >>= reply + forever $ + atomically (readTBQueue rcvQ) + >>= mapM processCommand + >>= mapM_ reply . L.nonEmpty . catMaybes . L.toList where reply :: MonadIO m => NonEmpty (Transmission BrokerMsg) -> m () reply = atomically . writeTBQueue sndQ - forkProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (TMVar (Transmission BrokerMsg)) - forkProxiedCmd cmd = do - res <- newEmptyTMVarIO - bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ - -- commands MUST be processed under a reasonable timeout or the client would halt - processProxiedCmd cmd >>= atomically . putTMVar res - pure res - where - wait = do - ServerConfig {serverClientConcurrency} <- asks config - atomically $ do - used <- readTVar procThreads - when (used >= serverClientConcurrency) retry - writeTVar procThreads $! used + 1 - signal = atomically $ modifyTVar' procThreads (\t -> t - 1) - processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Transmission BrokerMsg) - processProxiedCmd (corrId, sessId, command) = (corrId, sessId,) <$> case command of - PRXY srv auth -> ifM allowProxy getRelay (pure $ ERR $ PROXY BASIC_AUTH) + processProxiedCmd :: Transmission (Command 'ProxiedClient) -> M (Maybe (Transmission BrokerMsg)) + processProxiedCmd (corrId, sessId, command) = (corrId,sessId,) <$$> case command of + PRXY srv auth -> ifM allowProxy getRelay (pure $ Just $ ERR $ PROXY BASIC_AUTH) where allowProxy = do ServerConfig {allowSMPProxy, newQueueBasicAuth} <- asks config pure $ allowSMPProxy && maybe True ((== auth) . Just) newQueueBasicAuth getRelay = do + ProxyAgent {smpAgent = a} <- asks proxyAgent + liftIO (getConnectedSMPServerClient a srv) >>= \case + Just r -> Just <$> proxyServerResponse a r + Nothing -> + forkProxiedCmd $ + liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError)) + >>= proxyServerResponse a + proxyServerResponse :: SMPClientAgent -> Either SMPClientError (OwnServer, SMPClient) -> M BrokerMsg + proxyServerResponse a smp_ = do ServerStats {pRelays, pRelaysOwn} <- asks serverStats let inc = mkIncProxyStats pRelays pRelaysOwn - ProxyAgent {smpAgent = a} <- asks proxyAgent - liftIO (runExceptT (getSMPServerClient'' a srv) `catch` (pure . Left . PCEIOError)) >>= \case + case smp_ of Right (own, smp) -> do inc own pRequests case proxyResp smp of @@ -704,7 +696,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi where proxyResp smp = let THandleParams {sessionId = srvSessId, thVersion, thServerVRange, thAuth} = thParams smp - in case compatibleVRange thServerVRange proxiedSMPRelayVRange of + in case compatibleVRange thServerVRange proxiedSMPRelayVRange of -- Cap the destination relay version range to prevent client version fingerprinting. -- See comment for proxiedSMPRelayVersion. Just (Compatible vr) | thVersion >= sendingProxySMPVersion -> case thAuth of @@ -718,54 +710,66 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi atomically (lookupSMPServerClient a sessId) >>= \case Just (own, smp) -> do inc own pRequests - if - | v >= sendingProxySMPVersion -> - liftIO (runExceptT (forwardSMPMessage smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case - Right r -> PRES r <$ inc own pSuccesses - Left e -> case e of - PCEProtocolError {} -> ERR err <$ inc own pSuccesses - _ -> ERR err <$ inc own pErrorsOther - where - err = smpProxyError e - | otherwise -> ERR (transportErr TEVersion) <$ inc own pErrorsCompat + if v >= sendingProxySMPVersion + then forkProxiedCmd $ do + liftIO (runExceptT (forwardSMPMessage smp corrId fwdV pubKey encBlock) `catch` (pure . Left . PCEIOError)) >>= \case + Right r -> PRES r <$ inc own pSuccesses + Left e -> ERR (smpProxyError e) <$ case e of + PCEProtocolError {} -> inc own pSuccesses + _ -> inc own pErrorsOther + else Just (ERR $ transportErr TEVersion) <$ inc own pErrorsCompat where THandleParams {thVersion = v} = thParams smp - Nothing -> inc False pRequests >> inc False pErrorsConnect $> ERR (PROXY NO_SESSION) + Nothing -> inc False pRequests >> inc False pErrorsConnect $> Just (ERR $ PROXY NO_SESSION) + where + forkProxiedCmd :: M BrokerMsg -> M (Maybe BrokerMsg) + forkProxiedCmd cmdAction = do + bracket_ wait signal . forkClient clnt (B.unpack $ "client $" <> encode sessionId <> " proxy") $ do + -- commands MUST be processed under a reasonable timeout or the client would halt + cmdAction >>= \t -> reply [(corrId, sessId, t)] + pure Nothing + where + wait = do + ServerConfig {serverClientConcurrency} <- asks config + atomically $ do + used <- readTVar procThreads + when (used >= serverClientConcurrency) retry + writeTVar procThreads $! used + 1 + signal = atomically $ modifyTVar' procThreads (\t -> t - 1) transportErr :: TransportError -> ErrorType transportErr = PROXY . BROKER . TRANSPORT mkIncProxyStats :: MonadIO m => ProxyStats -> ProxyStats -> OwnServer -> (ProxyStats -> TVar Int) -> m () mkIncProxyStats ps psOwn = \own sel -> do atomically $ modifyTVar' (sel ps) (+ 1) when own $ atomically $ modifyTVar' (sel psOwn) (+ 1) - processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Either (Transmission (Command 'ProxiedClient)) (Transmission BrokerMsg)) - processCommand (qr_, (corrId, queueId, cmd)) = do - st <- asks queueStore - case cmd of - Cmd SProxiedClient command -> pure $ Left (corrId, queueId, command) - Cmd SSender command -> Right <$> case command of - SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody - PING -> pure (corrId, "", PONG) - RFWD encBlock -> (corrId, "",) <$> processForwardedCommand encBlock - Cmd SNotifier NSUB -> Right <$> subscribeNotifications - Cmd SRecipient command -> - Right <$> case command of - NEW rKey dhKey auth subMode -> - ifM - allowNew - (createQueue st rKey dhKey subMode) - (pure (corrId, queueId, ERR AUTH)) - where - allowNew = do - ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config - pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth - SUB -> withQueue (`subscribeQueue` queueId) - GET -> withQueue getMessage - ACK msgId -> withQueue (`acknowledgeMsg` msgId) - KEY sKey -> secureQueue_ st sKey - NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey - NDEL -> deleteQueueNotifier_ st - OFF -> suspendQueue_ st - DEL -> delQueueAndMsgs st + processCommand :: (Maybe QueueRec, Transmission Cmd) -> M (Maybe (Transmission BrokerMsg)) + processCommand (qr_, (corrId, queueId, cmd)) = case cmd of + Cmd SProxiedClient command -> processProxiedCmd (corrId, queueId, command) + Cmd SSender command -> Just <$> case command of + SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody + PING -> pure (corrId, "", PONG) + RFWD encBlock -> (corrId, "",) <$> processForwardedCommand encBlock + Cmd SNotifier NSUB -> Just <$> subscribeNotifications + Cmd SRecipient command -> do + st <- asks queueStore + Just <$> case command of + NEW rKey dhKey auth subMode -> + ifM + allowNew + (createQueue st rKey dhKey subMode) + (pure (corrId, queueId, ERR AUTH)) + where + allowNew = do + ServerConfig {allowNewQueues, newQueueBasicAuth} <- asks config + pure $ allowNewQueues && maybe True ((== auth) . Just) newQueueBasicAuth + SUB -> withQueue (`subscribeQueue` queueId) + GET -> withQueue getMessage + ACK msgId -> withQueue (`acknowledgeMsg` msgId) + KEY sKey -> secureQueue_ st sKey + NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey + NDEL -> deleteQueueNotifier_ st + OFF -> suspendQueue_ st + DEL -> delQueueAndMsgs st where createQueue :: QueueStore -> RcvPublicAuthKey -> RcvPublicDhKey -> SubscriptionMode -> M (Transmission BrokerMsg) createQueue st recipientKey dhKey subMode = time "NEW" $ do @@ -1036,7 +1040,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi Right t''@(_, (corrId', entId', cmd')) -> case cmd' of Cmd SSender SEND {} -> -- Left will not be returned by processCommand, as only SEND command is allowed - fromRight (corrId', entId', ERR INTERNAL) <$> lift (processCommand t'') + fromMaybe (corrId', entId', ERR INTERNAL) <$> lift (processCommand t'') _ -> pure (corrId', entId', ERR $ CMD PROHIBITED) -- encode response diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 77adb94f4..f602c890b 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -104,7 +104,7 @@ defaultInactiveClientExpiration = } defaultProxyClientConcurrency :: Int -defaultProxyClientConcurrency = 16 +defaultProxyClientConcurrency = 32 data Env = Env { config :: ServerConfig, diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs index aaaa2de51..a810247fe 100644 --- a/tests/AgentTests/EqInstances.hs +++ b/tests/AgentTests/EqInstances.hs @@ -6,6 +6,7 @@ module AgentTests.EqInstances where import Data.Type.Equality import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Client (ProxiedRelay (..)) instance Eq SomeConn where SomeConn d c == SomeConn d' c' = case testEquality d d' of @@ -23,3 +24,7 @@ deriving instance Eq (StoredSndQueue q) deriving instance Eq (DBQueueId q) deriving instance Eq ClientNtfCreds + +deriving instance Show ProxiedRelay + +deriving instance Eq ProxiedRelay diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 748eb34e7..0c5792d7a 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -12,6 +12,7 @@ module SMPProxyTests where +import AgentTests.EqInstances () import AgentTests.FunctionalAPITests import Control.Logger.Simple import Control.Monad (forM, forM_, forever) @@ -150,10 +151,13 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe let dec = decryptMsgV3 $ C.dh' srvDh rdhPriv -- get proxy session + sess0 <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct") sess <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct") + sess0 `shouldBe` sess -- send via proxy to unsecured queue forM_ unsecuredMsgs $ \msg -> do runExceptT' (proxySMPMessage pc sess Nothing sndId noMsgFlags msg) `shouldReturn` Right () + runExceptT' (proxySMPMessage pc sess {prSessionId = "bad session"} Nothing sndId noMsgFlags msg) `shouldReturn` Left (ProxyProtocolError $ SMP.PROXY SMP.NO_SESSION) -- receive 1 (_tSess, _v, _sid, [(_entId, STEvent (Right (SMP.MSG RcvMessage {msgId, msgBody = EncRcvMsgBody encBody})))]) <- atomically $ readTBQueue msgQ dec msgId encBody `shouldBe` Right msg