diff --git a/CHANGELOG.md b/CHANGELOG.md index ad8862b0f..b5792195d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,15 @@ +# 5.7.2 + +SMP agent: +- fix connections failing when connecting via link due to race condition on slow network. +- remove concurrency limit when waiting for connection subscription. +- remove TLS timeout. + +# 5.7.1 + +SMP agent: +- increase timeout for TLS connection via SOCKS + # 5.7.0 Version 5.7.0.4 diff --git a/package.yaml b/package.yaml index 084fe3a8f..7a536f7bf 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.7.0.4 +version: 5.7.2.0 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, diff --git a/simplexmq.cabal b/simplexmq.cabal index 4a8baed9b..aa008f600 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.7.0.4 +version: 5.7.2.0 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e3b48ba78..3b68667e4 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -55,6 +55,7 @@ module Simplex.Messaging.Agent deleteConnectionAsync, deleteConnectionsAsync, createConnection, + prepareConnectionToJoin, joinConnection, allowConnection, acceptContact, @@ -288,9 +289,18 @@ createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs {-# INLINE createConnection #-} --- | Join SMP agent connection (JOIN command) -joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId -joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs +-- | Create SMP agent connection without queue (to be joined with joinConnection passing connection ID). +-- This method is required to prevent race condition when confirmation from peer is received before +-- the caller of joinConnection saves connection ID to the database. +-- Instead of it we could send confirmation asynchronously, but then it would be harder to report +-- "link deleted" (SMP AUTH) interactively, so this approach is simpler overall. +prepareConnectionToJoin :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> PQSupport -> AE ConnId +prepareConnectionToJoin c userId enableNtfs = withAgentEnv c .: newConnToJoin c userId "" enableNtfs + +-- | Join SMP agent connection (JOIN command). +joinConnection :: AgentClient -> UserId -> Maybe ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AE ConnId +joinConnection c userId Nothing enableNtfs = withAgentEnv c .:: joinConn c userId "" False enableNtfs +joinConnection c userId (Just connId) enableNtfs = withAgentEnv c .:: joinConn c userId connId True enableNtfs {-# INLINE joinConnection #-} -- | Allow connection to continue after CONF notification (LET command) @@ -575,7 +585,7 @@ processCommand :: AgentClient -> (EntityId, APartyCmd 'Client) -> AM (EntityId, processCommand c (connId, APC e cmd) = second (APC e) <$> case cmd of NEW enableNtfs (ACM cMode) pqIK subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing pqIK subMode - JOIN enableNtfs (ACR _ cReq) pqEnc subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo pqEnc subMode + JOIN enableNtfs (ACR _ cReq) pqEnc subMode connInfo -> (,OK) <$> joinConn c userId connId False enableNtfs cReq connInfo pqEnc subMode LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) ACPT invId pqEnc ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo pqEnc SMSubscribe RJCT invId -> rejectContact' c connId invId $> (connId, OK) @@ -708,11 +718,14 @@ switchConnectionAsync' c corrId connId = newConn :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> AM (ConnId, ConnectionRequestUri c) newConn c userId connId enableNtfs cMode clientData pqInitKeys subMode = - getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode + getSMPServer c userId >>= newConnSrv c userId connId False enableNtfs cMode clientData pqInitKeys subMode -newConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) -newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do - connId' <- newConnNoQueues c userId connId enableNtfs cMode (CR.connPQEncryption pqInitKeys) +newConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) +newConnSrv c userId connId hasNewConn enableNtfs cMode clientData pqInitKeys subMode srv = do + connId' <- + if hasNewConn + then pure connId + else newConnNoQueues c userId connId enableNtfs cMode (CR.connPQEncryption pqInitKeys) newRcvConnSrv c userId connId' enableNtfs cMode clientData pqInitKeys subMode srv newRcvConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> AM (ConnId, ConnectionRequestUri c) @@ -738,18 +751,36 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) -joinConn :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId -joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do +newConnToJoin :: forall c. AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> PQSupport -> AM ConnId +newConnToJoin c userId connId enableNtfs cReq pqSup = case cReq of + CRInvitationUri {} -> + lift (compatibleInvitationUri cReq) >>= \case + Just (_, (Compatible (CR.E2ERatchetParams v _ _ _)), aVersion) -> create aVersion (Just v) + Nothing -> throwError $ AGENT A_VERSION + CRContactUri {} -> + lift (compatibleContactUri cReq) >>= \case + Just (_, aVersion) -> create aVersion Nothing + Nothing -> throwError $ AGENT A_VERSION + where + create :: Compatible VersionSMPA -> Maybe CR.VersionE2E -> AM ConnId + create (Compatible connAgentVersion) e2eV_ = do + g <- asks random + let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion e2eV_ + cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} + withStore c $ \db -> createNewConn db g cData SCMInvitation + +joinConn :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> AM ConnId +joinConn c userId connId hasNewConn enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> getNextServer c userId [qServer q] _ -> getSMPServer c userId - joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv + joinConnSrv c userId connId hasNewConn enableNtfs cReq cInfo pqSupport subMode srv -startJoinInvitation :: UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> AM (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation :: UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> AM (ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) startJoinInvitation userId connId enableNtfs cReqUri pqSup = lift (compatibleInvitationUri cReqUri) >>= \case - Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do + Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), Compatible connAgentVersion) -> do g <- asks random let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ v kem_ pqSupport) @@ -760,7 +791,7 @@ startJoinInvitation userId connId enableNtfs cReqUri pqSup = rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- lift $ newSndQueue userId "" qInfo let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} - pure (aVersion, cData, q, rc, e2eSndParams) + pure (cData, q, rc, e2eSndParams) Nothing -> throwError $ AGENT A_VERSION connRequestPQSupport :: AgentClient -> PQSupport -> ConnectionRequestUri c -> IO (Maybe (VersionSMPA, PQSupport)) @@ -786,40 +817,43 @@ compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = AgentConfig {smpClientVRange, smpAgentVRange} <- asks config pure $ (,) - <$> (qUri `compatibleVersion` smpClientVRange) + <$> (qUri `compatibleVersion` smpClientVRange) <*> (crAgentVRange `compatibleVersion` smpAgentVRange) versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ {-# INLINE versionPQSupport_ #-} -joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = +joinConnSrv :: AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM ConnId +joinConnSrv c userId connId hasNewConn enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do - (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSup + (cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSup g <- asks random (connId', sq) <- withStore c $ \db -> runExceptT $ do - r@(connId', _) <- ExceptT $ createSndConn db g cData q + r@(connId', _) <- + if hasNewConn + then (connId,) <$> ExceptT (updateNewConnSnd db connId q) + else ExceptT $ createSndConn db g cData q liftIO $ createRatchet db connId' rc pure r let cData' = (cData :: ConnData) {connId = connId'} - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case + tryError (confirmQueue c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e -joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = +joinConnSrv c userId connId hasNewConn enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = lift (compatibleContactUri cReqUri) >>= \case Just (qInfo, vrsn) -> do - (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv + (connId', cReq) <- newConnSrv c userId connId hasNewConn enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv void $ sendInvitation c userId qInfo vrsn cReq cInfo pure connId' Nothing -> throwError $ AGENT A_VERSION joinConnSrvAsync :: AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> AM () joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do - (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport + (cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport q' <- withStore c $ \db -> runExceptT $ do liftIO $ createRatchet db connId rc ExceptT $ updateNewConnSnd db connId q @@ -861,7 +895,7 @@ acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withCon withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConn c userId connId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do + joinConn c userId connId False enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -2565,8 +2599,8 @@ confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode lift $ submitPendingMsg c cData sq -confirmQueue :: Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq srv connInfo e2eEncryption_ subMode = do +confirmQueue :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> AM () +confirmQueue c cData@ConnData {connId, connAgentVersion, pqSupport} sq srv connInfo e2eEncryption_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode void $ sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed @@ -2578,7 +2612,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq s void . liftIO $ updateSndIds db connId let pqEnc = CR.pqSupportToEnc pqSupport (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) currentE2EVersion - pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} + pure . smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> AM AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index a96de2e2a..2936e3841 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -124,7 +124,6 @@ import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, diffToMicroseconds, liftEitherWith, raceAny_, threadDelay', whenM) import Simplex.Messaging.Version import System.Timeout (timeout) -import UnliftIO (pooledMapConcurrentlyN) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- @@ -846,7 +845,7 @@ streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockS mapM_ (cb <=< sendBatch c) bs sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] -sendBatch c@ProtocolClient {client_ = PClient {rcvConcurrency, sndQ}} b = do +sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of TBError e Request {entityId} -> do putStrLn "send error: large message" @@ -855,7 +854,7 @@ sendBatch c@ProtocolClient {client_ = PClient {rcvConcurrency, sndQ}} b = do | n > 0 -> do active <- newTVarIO True atomically $ writeTBQueue sndQ (active, s) - pooledMapConcurrentlyN rcvConcurrency (getResponse c active) rs + mapConcurrently (getResponse c active) rs | otherwise -> pure [] TBTransmission s r -> do active <- newTVarIO True diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index 08cff1d0d..da2c6c253 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -19,7 +19,7 @@ module Simplex.Messaging.Transport.Client TransportHost (..), TransportHosts (..), TransportHosts_ (..), - validateCertificateChain + validateCertificateChain, ) where @@ -52,9 +52,8 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll, parseString) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.KeepAlive -import Simplex.Messaging.Util (bshow, (<$?>), catchAll, tshow) +import Simplex.Messaging.Util (bshow, catchAll, tshow, (<$?>)) import System.IO.Error -import System.Timeout (timeout) import Text.Read (readMaybe) import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E @@ -139,30 +138,26 @@ runTransportClient :: Transport c => TransportClientConfig -> Maybe ByteString - runTransportClient = runTLSTransportClient supportedParameters Nothing runTLSTransportClient :: Transport c => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> IO a) -> IO a -runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpConnectTimeout, tcpKeepAlive, clientCredentials, alpn} proxyUsername host port keyHash client = do +runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials, alpn} proxyUsername host port keyHash client = do serverCert <- newEmptyTMVarIO let hostName = B.unpack $ strEncode host clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials alpn serverCert connectTCP = case socksProxy of - Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host + Just proxy -> connectSocksClient proxy proxyUsername (hostAddr host) _ -> connectTCPClient hostName c <- do sock <- connectTCP port mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e) let tCfg = clientTransportConfig cfg - tcpConnectTimeout `timeout` connectTLS (Just hostName) tCfg clientParams sock >>= \case - Nothing -> do - close sock - logError "connection timed out" - fail "connection timed out" - Just tls -> do - chain <- - atomically (tryTakeTMVar serverCert) >>= \case - Nothing -> do - logError "onServerCertificate didn't fire or failed to get cert chain" - closeTLS tls >> error "onServerCertificate failed" - Just c -> pure c - getClientConnection tCfg chain tls + -- No TLS timeout to avoid failing connections via SOCKS + tls <- connectTLS (Just hostName) tCfg clientParams sock + chain <- + atomically (tryTakeTMVar serverCert) >>= \case + Nothing -> do + logError "onServerCertificate didn't fire or failed to get cert chain" + closeTLS tls >> error "onServerCertificate failed" + Just c -> pure c + getClientConnection tCfg chain tls client c `E.finally` closeConnection c where hostAddr = \case diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index ff6dfeacd..05c461f6e 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -248,7 +248,7 @@ createConnection :: AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) joinConnection :: AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> AE ConnId -joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn +joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId Nothing enableNtfs cReq connInfo PQSupportOn sendMessage :: AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> AE AgentMsgId sendMessage c connId msgFlags msgBody = do @@ -513,7 +513,7 @@ runAgentClientTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> AgentC runAgentClientTest pqSupport viaProxy alice@AgentClient {} bob baseId = runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqSupport) SMSubscribe - aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe + aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" pqSupport SMSubscribe ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" @@ -641,7 +641,9 @@ runAgentClientContactTest :: HasCallStack => PQSupport -> Bool -> AgentClient -> runAgentClientContactTest pqSupport viaProxy alice bob baseId = runRight_ $ do (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqSupport) SMSubscribe - aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe + aliceId <- A.prepareConnectionToJoin bob 1 True qInfo pqSupport + aliceId' <- A.joinConnection bob 1 (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe + liftIO $ aliceId' `shouldBe` aliceId ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport bobId <- acceptContact alice True invId "alice's connInfo" PQSupportOn SMSubscribe @@ -1411,7 +1413,9 @@ makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn makeConnectionForUsers_ :: PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe - aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqSupport SMSubscribe + aliceId <- A.prepareConnectionToJoin bob bobUserId True qInfo pqSupport + aliceId' <- A.joinConnection bob bobUserId (Just aliceId) True qInfo "bob's connInfo" pqSupport SMSubscribe + liftIO $ aliceId' `shouldBe` aliceId ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 91722228b..2b0009344 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -22,10 +22,13 @@ tRcvQueuesTests = do describe "connection API" $ do it "hasConn" hasConnTest it "hasConn, batch add" hasConnTestBatch + it "hasConn, batch idempotent" batchIdempotentTest it "deleteConn" deleteConnTest describe "session API" $ do it "getSessQueues" getSessQueuesTest it "getDelSessQueues" getDelSessQueuesTest + describe "queue transfer" $ do + it "getDelSessQueues-batchAddQueues preserves total length" removeSubsTest checkDataInvariant :: RQ.TRcvQueues -> IO Bool checkDataInvariant trq = atomically $ do @@ -62,6 +65,19 @@ hasConnTestBatch = do atomically (RQ.hasConn "c3" trq) `shouldReturn` True atomically (RQ.hasConn "nope" trq) `shouldReturn` False +batchIdempotentTest :: IO () +batchIdempotentTest = do + trq <- atomically RQ.empty + let qs = [dummyRQ 0 "smp://1234-w==@alpha" "c1", dummyRQ 0 "smp://1234-w==@alpha" "c2", dummyRQ 0 "smp://1234-w==@beta" "c3"] + atomically $ RQ.batchAddQueues trq qs + checkDataInvariant trq `shouldReturn` True + qs' <- readTVarIO $ RQ.getRcvQueues trq + cs' <- readTVarIO $ RQ.getConnections trq + atomically $ RQ.batchAddQueues trq qs + checkDataInvariant trq `shouldReturn` True + readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' + fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn`cs' -- connections get duplicated, but that doesn't appear to affect anybody + deleteConnTest :: IO () deleteConnTest = do trq <- atomically RQ.empty @@ -121,6 +137,40 @@ getDelSessQueuesTest = do atomically (RQ.hasConn "c3" trq) `shouldReturn` True atomically (RQ.hasConn "c4" trq) `shouldReturn` True +removeSubsTest :: IO () +removeSubsTest = do + aq <- atomically RQ.empty + let qs = + [ dummyRQ 0 "smp://1234-w==@alpha" "c1", + dummyRQ 0 "smp://1234-w==@alpha" "c2", + dummyRQ 0 "smp://1234-w==@beta" "c3", + dummyRQ 1 "smp://1234-w==@beta" "c4" + ] + atomically $ RQ.batchAddQueues aq qs + + pq <- atomically RQ.empty + atomically (totalSize aq pq) `shouldReturn` (4, 4) + + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@alpha", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically (totalSize aq pq) `shouldReturn` (4, 4) + + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "non-existent") aq >>= RQ.batchAddQueues pq . fst + atomically (totalSize aq pq) `shouldReturn` (4, 4) + + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@localhost", Nothing) aq >>= RQ.batchAddQueues pq . fst + atomically (totalSize aq pq) `shouldReturn` (4, 4) + + atomically $ RQ.getDelSessQueues (0, "smp://1234-w==@beta", Just "c3") aq >>= RQ.batchAddQueues pq . fst + atomically (totalSize aq pq) `shouldReturn` (4, 4) + +totalSize :: RQ.TRcvQueues -> RQ.TRcvQueues -> STM (Int, Int) +totalSize a b = do + qsizeA <- M.size <$> readTVar (RQ.getRcvQueues a) + qsizeB <- M.size <$> readTVar (RQ.getRcvQueues b) + csizeA <- M.size <$> readTVar (RQ.getConnections a) + csizeB <- M.size <$> readTVar (RQ.getConnections b) + pure (qsizeA + qsizeB, csizeA + csizeB) + dummyRQ :: UserId -> SMPServer -> ConnId -> RcvQueue dummyRQ userId server connId = RcvQueue diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 47251a2d4..b70c88883 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -133,7 +133,7 @@ agentDeliverMessageViaProxy aTestCfg@(aSrvs, _, aViaProxy) bTestCfg@(bSrvs, _, b withAgent 1 aCfg (servers aTestCfg) testDB $ \alice -> withAgent 2 aCfg (servers bTestCfg) testDB2 $ \bob -> runRight_ $ do (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (CR.IKNoPQ PQSupportOn) SMSubscribe - aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" PQSupportOn SMSubscribe + aliceId <- A.joinConnection bob 1 Nothing True qInfo "bob's connInfo" PQSupportOn SMSubscribe ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice liftIO $ pqSup' `shouldBe` PQSupportOn allowConnection alice bobId confId "alice's connInfo"