From 73d7f84ee3b68cf34ed85e047f9345214af402f5 Mon Sep 17 00:00:00 2001 From: spaced4ndy <8711996+spaced4ndy@users.noreply.github.com> Date: Tue, 24 Oct 2023 16:55:57 +0400 Subject: [PATCH] agent: take invitation lock on join (#870) --- src/Simplex/Messaging/Agent.hs | 62 ++++++++++++++------------- src/Simplex/Messaging/Agent/Client.hs | 15 ++++++- 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index bd8b3f418..3ad9ef475 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -42,6 +42,7 @@ module Simplex.Messaging.Agent disconnectAgentClient, resumeAgentClient, withConnLock, + withInvLock, createUser, deleteUser, createConnectionAsync, @@ -480,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do - aVRange <- asks $ smpAgentVRange . config - case crAgentVRange `compatibleVersion` aVRange of - Just (Compatible connAgentVersion) -> do - g <- asks idsDrg - let duplexHS = connAgentVersion /= 1 - cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} - connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo - pure connId - _ -> throwError $ AGENT A_VERSION + withInvLock c (B.unpack . strEncode $ cReqUri) "joinConnAsync" $ do + aVRange <- asks $ smpAgentVRange . config + case crAgentVRange `compatibleVersion` aVRange of + Just (Compatible connAgentVersion) -> do + g <- asks idsDrg + let duplexHS = connAgentVersion /= 1 + cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo + pure connId + _ -> throwError $ AGENT A_VERSION joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo = throwError $ CMD PROHIBITED @@ -614,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr _ -> throwError $ AGENT A_VERSION joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do - (aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv - g <- asks idsDrg - connId' <- withStore c $ \db -> runExceptT $ do - connId' <- ExceptT $ createSndConn db g cData q - liftIO $ createRatchet db connId' rc - pure connId' - let sq = (q :: SndQueue) {connId = connId'} - cData' = (cData :: ConnData) {connId = connId'} - duplexHS = connAgentVersion /= 1 - tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case - Right _ -> do - unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = + withInvLock c (B.unpack . strEncode $ inv) "joinConnSrv" $ do + (aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv + g <- asks idsDrg + connId' <- withStore c $ \db -> runExceptT $ do + connId' <- ExceptT $ createSndConn db g cData q + liftIO $ createRatchet db connId' rc pure connId' - Left e -> do - -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md - withStore' c (`deleteConn` connId') - throwError e + let sq = (q :: SndQueue) {connId = connId'} + cData' = (cData :: ConnData) {connId = connId'} + duplexHS = connAgentVersion /= 1 + tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case + Right _ -> do + unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO + pure connId' + Left e -> do + -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md + withStore' c (`deleteConn` connId') + throwError e joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do aVRange <- asks $ smpAgentVRange . config clientVRange <- asks $ smpClientVRange . config @@ -1742,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration] getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn) debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks -debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do +debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do connLocks <- getLocks cs + invLocks <- getLocks is srvLocks <- getLocks rs delLock <- atomically $ tryReadTMVar d - pure AgentLocks {connLocks, srvLocks, delLock} + pure AgentLocks {connLocks, invLocks, srvLocks, delLock} where getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 16811fb86..8c7fb089b 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client ProtocolTestStep (..), newAgentClient, withConnLock, + withInvLock, closeAgentClient, closeProtocolServerClients, closeXFTPServerClient, @@ -249,6 +250,8 @@ data AgentClient = AgentClient getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()), -- locks to prevent concurrent operations with connection connLocks :: TMap ConnId Lock, + -- locks to prevent concurrent operations with connection request invitations + invLocks :: TMap String Lock, -- lock to prevent concurrency between periodic and async connection deletions deleteLock :: Lock, -- locks to prevent concurrent reconnections to SMP servers @@ -279,7 +282,12 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int} data AgentState = ASForeground | ASSuspending | ASSuspended deriving (Eq, Show) -data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String} +data AgentLocks = AgentLocks + { connLocks :: Map String String, + invLocks :: Map String String, + srvLocks :: Map String String, + delLock :: Maybe String + } deriving (Show, Generic, FromJSON) instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions @@ -325,6 +333,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do agentState <- newTVar ASForeground getMsgLocks <- TM.empty connLocks <- TM.empty + invLocks <- TM.empty deleteLock <- createLock reconnectLocks <- TM.empty reconnections <- newTAsyncs @@ -362,6 +371,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do agentState, getMsgLocks, connLocks, + invLocks, deleteLock, reconnectLocks, reconnections, @@ -645,6 +655,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a withConnLock _ "" _ = id withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name +withInvLock :: MonadUnliftIO m => AgentClient -> String -> String -> m a -> m a +withInvLock AgentClient {invLocks} = withLockMap_ invLocks + withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure where