mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-06-04 23:51:33 +00:00
extract SessionVar from AgentClient to reuse (#1099)
This commit is contained in:
committed by
GitHub
parent
5e783396e0
commit
ebb75ced12
@@ -40,6 +40,7 @@ import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer)
|
||||
import Simplex.Messaging.Session
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -50,7 +51,7 @@ import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type SMPClientVar = TMVar (Either SMPClientError SMPClient)
|
||||
type SMPClientVar = SessionVar (Either SMPClientError SMPClient)
|
||||
|
||||
data SMPClientAgentEvent
|
||||
= CAConnected SMPServer
|
||||
@@ -100,7 +101,8 @@ data SMPClientAgent = SMPClientAgent
|
||||
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
|
||||
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()]
|
||||
asyncClients :: TVar [Async ()],
|
||||
workerSeq :: TVar Int
|
||||
}
|
||||
|
||||
newtype InternalException e = InternalException {unInternalException :: e}
|
||||
@@ -115,9 +117,10 @@ instance Exception e => MonadUnliftIO (ExceptT e IO) where
|
||||
ExceptT . fmap (first unInternalException) . E.try $
|
||||
withRunInIO $ \run ->
|
||||
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
|
||||
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
|
||||
-- the last two lines could be replaced with:
|
||||
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
|
||||
|
||||
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
|
||||
-- the last two lines could be replaced with:
|
||||
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
|
||||
|
||||
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
|
||||
{-# INLINE withRunInIO #-}
|
||||
@@ -136,46 +139,53 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg
|
||||
pendingSrvSubs <- TM.empty
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
pure SMPClientAgent {agentCfg, msgQ, agentQ, randomDrg, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
|
||||
workerSeq <- newTVar 0
|
||||
pure
|
||||
SMPClientAgent
|
||||
{ agentCfg,
|
||||
msgQ,
|
||||
agentQ,
|
||||
randomDrg,
|
||||
smpClients,
|
||||
srvSubs,
|
||||
pendingSrvSubs,
|
||||
reconnections,
|
||||
asyncClients,
|
||||
workerSeq
|
||||
}
|
||||
|
||||
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
|
||||
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} srv =
|
||||
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
where
|
||||
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
||||
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
|
||||
|
||||
newClientVar :: STM SMPClientVar
|
||||
newClientVar = do
|
||||
smpVar <- newEmptyTMVar
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
getClientVar = getSessVar workerSeq srv smpClients
|
||||
|
||||
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
waitForSMPClient v = do
|
||||
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
|
||||
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar)
|
||||
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
|
||||
liftEither $ case smpClient_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left PCEResponseTimeout
|
||||
|
||||
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync)
|
||||
newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync)
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryE connectClient >>= \r -> case r of
|
||||
tryE (connectClient v) >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar smpVar r
|
||||
atomically $ putTMVar (sessionVar v) r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == PCENetworkError || e == PCEResponseTimeout
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
putTMVar (sessionVar v) (Left e)
|
||||
removeSessVar v srv smpClients
|
||||
throwE e
|
||||
tryConnectAsync :: IO ()
|
||||
tryConnectAsync = do
|
||||
@@ -186,17 +196,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT SMPClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
||||
connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v)
|
||||
|
||||
clientDisconnected :: SMPClient -> IO ()
|
||||
clientDisconnected _ = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
clientDisconnected :: SMPClientVar -> SMPClient -> IO ()
|
||||
clientDisconnected v _ = do
|
||||
removeClientAndSubs v >>= (`forM_` serverDown)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateAuthKey))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
|
||||
removeClientAndSubs v = atomically $ do
|
||||
removeSessVar v srv smpClients
|
||||
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs sVar = do
|
||||
@@ -207,7 +217,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
|
||||
addPendingSubs sVar ss = do
|
||||
let ps = pendingSrvSubs ca
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union ss v
|
||||
Just ss' -> TM.union ss ss'
|
||||
_ -> TM.insert srv sVar ps
|
||||
|
||||
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
|
||||
@@ -268,10 +278,10 @@ closeSMPClientAgent c = do
|
||||
cancelActions $ asyncClients c
|
||||
|
||||
closeSMPServerClients :: SMPClientAgent -> IO ()
|
||||
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
|
||||
closeSMPServerClients c = atomically (smpClients c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient)
|
||||
where
|
||||
closeClient smpVar =
|
||||
atomically (readTMVar smpVar) >>= \case
|
||||
closeClient v =
|
||||
atomically (readTMVar $ sessionVar v) >>= \case
|
||||
Right smp -> closeProtocolClient smp `catchAll_` pure ()
|
||||
_ -> pure ()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user