mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-29 18:49:58 +00:00
agent: take invitation lock on join (#870)
This commit is contained in:
@@ -42,6 +42,7 @@ module Simplex.Messaging.Agent
|
||||
disconnectAgentClient,
|
||||
resumeAgentClient,
|
||||
withConnLock,
|
||||
withInvLock,
|
||||
createUser,
|
||||
deleteUser,
|
||||
createConnectionAsync,
|
||||
@@ -480,16 +481,17 @@ newConnNoQueues c userId connId enableNtfs cMode = do
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
|
||||
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
withInvLock c (B.unpack . strEncode $ cReqUri) "joinConnAsync" $ do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case crAgentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo =
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
@@ -614,24 +616,25 @@ startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {cr
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId
|
||||
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do
|
||||
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
|
||||
g <- asks idsDrg
|
||||
connId' <- withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData q
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
duplexHS = connAgentVersion /= 1
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv =
|
||||
withInvLock c (B.unpack . strEncode $ inv) "joinConnSrv" $ do
|
||||
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
|
||||
g <- asks idsDrg
|
||||
connId' <- withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData q
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
Left e -> do
|
||||
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
withStore' c (`deleteConn` connId')
|
||||
throwError e
|
||||
let sq = (q :: SndQueue) {connId = connId'}
|
||||
cData' = (cData :: ConnData) {connId = connId'}
|
||||
duplexHS = connAgentVersion /= 1
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
pure connId'
|
||||
Left e -> do
|
||||
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
withStore' c (`deleteConn` connId')
|
||||
throwError e
|
||||
joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
@@ -1742,11 +1745,12 @@ getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration]
|
||||
getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn)
|
||||
|
||||
debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks
|
||||
debugAgentLocks' AgentClient {connLocks = cs, reconnectLocks = rs, deleteLock = d} = do
|
||||
debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do
|
||||
connLocks <- getLocks cs
|
||||
invLocks <- getLocks is
|
||||
srvLocks <- getLocks rs
|
||||
delLock <- atomically $ tryReadTMVar d
|
||||
pure AgentLocks {connLocks, srvLocks, delLock}
|
||||
pure AgentLocks {connLocks, invLocks, srvLocks, delLock}
|
||||
where
|
||||
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client
|
||||
ProtocolTestStep (..),
|
||||
newAgentClient,
|
||||
withConnLock,
|
||||
withInvLock,
|
||||
closeAgentClient,
|
||||
closeProtocolServerClients,
|
||||
closeXFTPServerClient,
|
||||
@@ -249,6 +250,8 @@ data AgentClient = AgentClient
|
||||
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
|
||||
-- locks to prevent concurrent operations with connection
|
||||
connLocks :: TMap ConnId Lock,
|
||||
-- locks to prevent concurrent operations with connection request invitations
|
||||
invLocks :: TMap String Lock,
|
||||
-- lock to prevent concurrency between periodic and async connection deletions
|
||||
deleteLock :: Lock,
|
||||
-- locks to prevent concurrent reconnections to SMP servers
|
||||
@@ -279,7 +282,12 @@ data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
|
||||
data AgentState = ASForeground | ASSuspending | ASSuspended
|
||||
deriving (Eq, Show)
|
||||
|
||||
data AgentLocks = AgentLocks {connLocks :: Map String String, srvLocks :: Map String String, delLock :: Maybe String}
|
||||
data AgentLocks = AgentLocks
|
||||
{ connLocks :: Map String String,
|
||||
invLocks :: Map String String,
|
||||
srvLocks :: Map String String,
|
||||
delLock :: Maybe String
|
||||
}
|
||||
deriving (Show, Generic, FromJSON)
|
||||
|
||||
instance ToJSON AgentLocks where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
@@ -325,6 +333,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
agentState <- newTVar ASForeground
|
||||
getMsgLocks <- TM.empty
|
||||
connLocks <- TM.empty
|
||||
invLocks <- TM.empty
|
||||
deleteLock <- createLock
|
||||
reconnectLocks <- TM.empty
|
||||
reconnections <- newTAsyncs
|
||||
@@ -362,6 +371,7 @@ newAgentClient InitialAgentServers {smp, ntf, xftp, netCfg} agentEnv = do
|
||||
agentState,
|
||||
getMsgLocks,
|
||||
connLocks,
|
||||
invLocks,
|
||||
deleteLock,
|
||||
reconnectLocks,
|
||||
reconnections,
|
||||
@@ -645,6 +655,9 @@ withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
|
||||
withConnLock _ "" _ = id
|
||||
withConnLock AgentClient {connLocks} connId name = withLockMap_ connLocks connId name
|
||||
|
||||
withInvLock :: MonadUnliftIO m => AgentClient -> String -> String -> m a -> m a
|
||||
withInvLock AgentClient {invLocks} = withLockMap_ invLocks
|
||||
|
||||
withLockMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
|
||||
withLockMap_ locks key = withGetLock $ TM.lookup key locks >>= maybe newLock pure
|
||||
where
|
||||
|
||||
Reference in New Issue
Block a user