diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index b8614c460..43b3b8064 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -597,10 +597,10 @@ getSMPServerClient c@AgentClient {active, smpClients, workerSeq} tSess = do prs <- atomically TM.empty smpConnectClient c tSess prs v -getSMPProxyClient :: AgentClient -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) -getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq} destSess@(userId, destSrv, qId) = do +getSMPProxyClient :: AgentClient -> Maybe SMPServerWithAuth -> SMPTransportSession -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) +getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq} proxySrv_ destSess@(userId, destSrv, qId) = do unlessM (readTVarIO active) $ throwE INACTIVE - proxySrv <- getNextServer c userId [destSrv] + proxySrv <- maybe (getNextServer c userId [destSrv]) pure proxySrv_ ts <- liftIO getCurrentTime atomically (getClientVar proxySrv ts) >>= \(tSess, auth, v) -> either (newProxyClient tSess auth ts) (waitForProxyClient tSess auth) v @@ -993,9 +993,9 @@ withClient_ c tSess@(_, srv, _) action = do logServer "<--" c srv "" $ bshow e throwE e -withProxySession :: AgentClient -> SMPTransportSession -> SMP.SenderId -> ByteString -> ((SMPConnectedClient, ProxiedRelay) -> AM a) -> AM a -withProxySession c destSess@(_, destSrv, _) entId cmdStr action = do - (cl, sess_) <- getSMPProxyClient c destSess +withProxySession :: AgentClient -> Maybe SMPServerWithAuth -> SMPTransportSession -> SMP.SenderId -> ByteString -> ((SMPConnectedClient, ProxiedRelay) -> AM a) -> AM a +withProxySession c proxySrv_ destSess@(_, destSrv, _) entId cmdStr action = do + (cl, sess_) <- getSMPProxyClient c proxySrv_ destSess logServer ("--> " <> proxySrv cl <> " >") c destSrv entId cmdStr case sess_ of Right sess -> do @@ -1053,7 +1053,7 @@ sendOrProxySMPCommand :: AM (Maybe SMPServer) sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDirectly = do sess <- liftIO $ mkTransportSession c userId destSrv senderId - ifM (atomically shouldUseProxy) (sendViaProxy sess) (sendDirectly sess $> Nothing) + ifM (atomically shouldUseProxy) (sendViaProxy Nothing sess) (sendDirectly sess $> Nothing) where shouldUseProxy = do cfg <- getNetworkConfig c @@ -1071,22 +1071,31 @@ sendOrProxySMPCommand c userId destSrv cmdStr senderId sendCmdViaProxy sendCmdDi SPFAllowProtected -> ipAddressProtected cfg destSrv SPFProhibit -> False unknownServer = maybe True (notElem destSrv . knownSrvs) <$> TM.lookup userId (smpServers c) - sendViaProxy destSess@(_, _, qId) = do - r <- tryAgentError . withProxySession c destSess senderId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess) -> do + sendViaProxy :: Maybe SMPServerWithAuth -> SMPTransportSession -> AM (Maybe SMPServer) + sendViaProxy proxySrv_ destSess@(_, _, qId) = do + r <- tryAgentError . withProxySession c proxySrv_ destSess senderId ("PFWD " <> cmdStr) $ \(SMPConnectedClient smp _, proxySess@ProxiedRelay {prBasicAuth}) -> do r' <- liftClient SMP (clientServer smp) $ sendCmdViaProxy smp proxySess + let proxySrv = protocolClientServer' smp case r' of - Right () -> pure . Just $ protocolClientServer' smp + Right () -> pure $ Just proxySrv Left proxyErr -> do case proxyErr of - (ProxyProtocolError (SMP.PROXY SMP.NO_SESSION)) -> atomically deleteRelaySession - _ -> pure () - throwE - PROXY - { proxyServer = protocolClientServer smp, - relayServer = B.unpack $ strEncode destSrv, - proxyErr - } + ProxyProtocolError (SMP.PROXY SMP.NO_SESSION) -> do + atomically deleteRelaySession + case proxySrv_ of + Just _ -> proxyError + -- sendViaProxy is called recursively here to re-create the session via the same server + -- to avoid failure in interactive calls that don't retry after the session disconnection. + Nothing -> sendViaProxy (Just $ ProtoServerWithAuth proxySrv prBasicAuth) destSess + _ -> proxyError where + proxyError = + throwE + PROXY + { proxyServer = protocolClientServer smp, + relayServer = B.unpack $ strEncode destSrv, + proxyErr + } -- checks that the current proxied relay session is the same one that was used to send the message and removes it deleteRelaySession = ( TM.lookup destSess (smpProxiedRelays c) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 1837a256b..4b1e673b0 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -823,7 +823,7 @@ connectSMPProxiedRelay c@ProtocolClient {client_ = PClient {tcpConnectTimeout, t PKEY sId vr (chain, key) -> case supportedClientSMPRelayVRange `compatibleVersion` vr of Nothing -> throwE $ transportErr TEVersion - Just (Compatible v) -> liftEitherWith (const $ transportErr $ TEHandshake IDENTITY) $ ProxiedRelay sId v <$> validateRelay chain key + Just (Compatible v) -> liftEitherWith (const $ transportErr $ TEHandshake IDENTITY) $ ProxiedRelay sId v proxyAuth <$> validateRelay chain key r -> throwE $ unexpectedResponse r | otherwise = throwE $ PCETransportError TEVersion where @@ -842,6 +842,7 @@ connectSMPProxiedRelay c@ProtocolClient {client_ = PClient {tcpConnectTimeout, t data ProxiedRelay = ProxiedRelay { prSessionId :: SessionId, prVersion :: VersionSMP, + prBasicAuth :: Maybe BasicAuth, -- auth is included here to allow reconnecting via the same proxy after NO_SESSION error prServerKey :: C.PublicKeyX25519 } @@ -902,7 +903,7 @@ proxySMPCommand :: SenderId -> Command 'Sender -> ExceptT SMPClientError IO (Either ProxyClientError ()) -proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {clientCorrId = g, tcpTimeout}} (ProxiedRelay sessionId v serverKey) spKey sId command = do +proxySMPCommand c@ProtocolClient {thParams = proxyThParams, client_ = PClient {clientCorrId = g, tcpTimeout}} (ProxiedRelay sessionId v _ serverKey) spKey sId command = do -- prepare params let serverThAuth = (\ta -> ta {serverPeerPubKey = serverKey}) <$> thAuth proxyThParams serverThParams = smpTHParamsSetVersion v proxyThParams {sessionId, thAuth = serverThAuth} diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 625dfbda7..7505ef977 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -34,7 +34,8 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (pattern PQSupportOn) import qualified Simplex.Messaging.Crypto.Ratchet as CR -import Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Protocol (EncRcvMsgBody (..), MsgBody, RcvMessage (..), SubscriptionMode (..), maxMessageLength, noMsgFlags) +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Transport import Simplex.Messaging.Util (bshow, tshow) @@ -122,6 +123,8 @@ smpProxyTests = do agentViaProxyVersionError it "retries sending when destination or proxy relay is offline" $ agentViaProxyRetryOffline + it "retries sending when destination relay session disconnects in proxy" $ + agentViaProxyRetryNoSession describe "stress test 1k" $ do let deliver nAgents nMsgs = agentDeliverMessagesViaProxyConc (replicate nAgents [srv1]) (map bshow [1 :: Int .. nMsgs]) it "2 agents, 250 messages" . oneServer $ deliver 2 250 @@ -157,7 +160,7 @@ deliverMessagesViaProxy proxyServ relayServ alg unsecuredMsgs securedMsgs = do -- prepare receiving queue (rPub, rPriv) <- atomically $ C.generateAuthKeyPair alg g (rdhPub, rdhPriv :: C.PrivateKeyX25519) <- atomically $ C.generateKeyPair g - QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe False + SMP.QIK {rcvId, sndId, rcvPublicDhKey = srvDh} <- runExceptT' $ createSMPQueue rc (rPub, rPriv) rdhPub (Just "correct") SMSubscribe False let dec = decryptMsgV3 $ C.dh' srvDh rdhPriv -- get proxy session sess0 <- runExceptT' $ connectSMPProxiedRelay pc relayServ (Just "correct") @@ -374,18 +377,38 @@ agentViaProxyRetryOffline = do msgId = subtract baseId . fst servers srv = (initAgentServersProxy SPMAlways SPFProhibit) {smp = userServers [srv]} +agentViaProxyRetryNoSession :: IO () +agentViaProxyRetryNoSession = do + let srv1 = SMPServer testHost testPort testKeyHash + srv2 = SMPServer testHost testPort2 testKeyHash + withAgent 1 agentCfg (servers srv1) testDB $ \a -> + withAgent 2 agentCfg (servers srv2) testDB2 $ \b -> do + withSmpServerConfigOn (transport @TLS) proxyCfg testPort $ \_ -> do + (aId, _) <- withServer2 $ \_ -> runRight $ makeConnection a b + nGet b =##> \case ("", "", DOWN _ [c]) -> c == aId; _ -> False + withServer2 $ \_ -> do + nGet b =##> \case ("", "", UP _ [c]) -> c == aId; _ -> False + -- to test retry in case of NO_SESSION error, + -- the client using server 1 as proxy and server 2 as destination + -- should be joining the connection, so the order is swapped here. + _ <- runRight $ makeConnection b a + pure () + where + withServer2 = withSmpServerConfigOn (transport @TLS) proxyCfg {storeLogFile = Just testStoreLogFile2, storeMsgsFile = Just testStoreMsgsFile2} testPort2 + servers srv = (initAgentServersProxy SPMAlways SPFProhibit) {smp = userServers [srv]} + testNoProxy :: IO () testNoProxy = do withSmpServerConfigOn (transport @TLS) cfg testPort2 $ \_ -> do testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - (_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", PRXY testSMPServer Nothing) + (_, _, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", SMP.PRXY testSMPServer Nothing) reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) testProxyAuth :: IO () testProxyAuth = do withSmpServerConfigOn (transport @TLS) proxyCfgAuth testPort $ \_ -> do testSMPClient_ "127.0.0.1" testPort proxyVRangeV8 $ \(th :: THandleSMP TLS 'TClient) -> do - (_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", PRXY testSMPServer2 $ Just "wrong") + (_, _s, (_corrId, _entityId, reply)) <- sendRecv th (Nothing, "0", "", SMP.PRXY testSMPServer2 $ Just "wrong") reply `shouldBe` Right (SMP.ERR $ SMP.PROXY SMP.BASIC_AUTH) where proxyCfgAuth = proxyCfg {newQueueBasicAuth = Just "correct"}