agent: take invitation lock on join (#870)

This commit is contained in:
spaced4ndy
2023-10-24 16:55:57 +04:00
committed by GitHub
parent cf8b9c12ff
commit 73d7f84ee3
2 changed files with 47 additions and 30 deletions

View File

@@ -42,6 +42,7 @@ module Simplex.Messaging.Agent
disconnectAgentClient,
resumeAgentClient,
withConnLock,
withInvLock,
createUser,
deleteUser,
createConnectionAsync,
@@ -480,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
withInvLock c (B.unpack . strEncode $ cReqUri) "joinConnAsync" $ do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo =
throwError $ CMD PROHIBITED
@@ -614,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr
_ -> throwError $ AGENT A_VERSION
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
g <- asks idsDrg
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData q
liftIO $ createRatchet db connId' rc
pure connId'
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
duplexHS = connAgentVersion /= 1
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv =
withInvLock c (B.unpack . strEncode $ inv) "joinConnSrv" $ do
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
g <- asks idsDrg
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData q
liftIO $ createRatchet db connId' rc
pure connId'
Left e -> do
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
throwError e
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
duplexHS = connAgentVersion /= 1
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
Left e -> do
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
throwError e
joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
@@ -1742,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration]
getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn)
debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do
connLocks <- getLocks cs
invLocks <- getLocks is
srvLocks <- getLocks rs
delLock <- atomically $ tryReadTMVar d
pure AgentLocks {connLocks, srvLocks, delLock}
pure AgentLocks {connLocks, invLocks, srvLocks, delLock}
where
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)

View File

@@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client
ProtocolTestStep (..),
newAgentClient,
withConnLock,
withInvLock,
closeAgentClient,
closeProtocolServerClients,
closeXFTPServerClient,
@@ -249,6 +250,8 @@ data AgentClient = AgentClient
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
-- locks to prevent concurrent operations with connection
connLocks :: TMap ConnId Lock,
-- locks to prevent concurrent operations with connection request invitations
invLocks :: TMap String Lock,
-- lock to prevent concurrency between periodic and async connection deletions
deleteLock :: Lock,
-- locks to prevent concurrent reconnections to SMP servers
@@ -279,7 +282,12 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
data AgentState = ASForeground | ASSuspending | ASSuspended
deriving (Eq, Show)
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
data AgentLocks = AgentLocks
{ connLocks :: Map String String,
invLocks :: Map String String,
srvLocks :: Map String String,
delLock :: Maybe String
}
deriving (Show, Generic, FromJSON)
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
@@ -325,6 +333,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
agentState <- newTVar ASForeground
getMsgLocks <- TM.empty
connLocks <- TM.empty
invLocks <- TM.empty
deleteLock <- createLock
reconnectLocks <- TM.empty
reconnections <- newTAsyncs
@@ -362,6 +371,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
agentState,
getMsgLocks,
connLocks,
invLocks,
deleteLock,
reconnectLocks,
reconnections,
@@ -645,6 +655,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
withConnLock _ "" _ = id
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
withInvLock :: MonadUnliftIO m => AgentClient -> String -> String -> m a -> m a
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
where