extract SessionVar from AgentClient to reuse (#1099)

This commit is contained in:
Evgeny Poberezkin
2024-04-13 18:33:12 +01:00
committed by GitHub
parent 5e783396e0
commit ebb75ced12
4 changed files with 104 additions and 79 deletions
+43 -33
View File
@@ -40,6 +40,7 @@ import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer)
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
@@ -50,7 +51,7 @@ import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type SMPClientVar = TMVar (Either SMPClientError SMPClient)
type SMPClientVar = SessionVar (Either SMPClientError SMPClient)
data SMPClientAgentEvent
= CAConnected SMPServer
@@ -100,7 +101,8 @@ data SMPClientAgent = SMPClientAgent
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
reconnections :: TVar [Async ()],
asyncClients :: TVar [Async ()]
asyncClients :: TVar [Async ()],
workerSeq :: TVar Int
}
newtype InternalException e = InternalException {unInternalException :: e}
@@ -115,9 +117,10 @@ instance Exception e => MonadUnliftIO (ExceptT e IO) where
ExceptT . fmap (first unInternalException) . E.try $
withRunInIO $ \run ->
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
{-# INLINE withRunInIO #-}
@@ -136,46 +139,53 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg
pendingSrvSubs <- TM.empty
reconnections <- newTVar []
asyncClients <- newTVar []
pure SMPClientAgent {agentCfg, msgQ, agentQ, randomDrg, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
workerSeq <- newTVar 0
pure
SMPClientAgent
{ agentCfg,
msgQ,
agentQ,
randomDrg,
smpClients,
srvSubs,
pendingSrvSubs,
reconnections,
asyncClients,
workerSeq
}
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} srv =
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv =
atomically getClientVar >>= either newSMPClient waitForSMPClient
where
getClientVar :: STM (Either SMPClientVar SMPClientVar)
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
newClientVar :: STM SMPClientVar
newClientVar = do
smpVar <- newEmptyTMVar
TM.insert srv smpVar smpClients
pure smpVar
getClientVar = getSessVar workerSeq srv smpClients
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
waitForSMPClient smpVar = do
waitForSMPClient v = do
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar)
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
liftEither $ case smpClient_ of
Just (Right smpClient) -> Right smpClient
Just (Left e) -> Left e
Nothing -> Left PCEResponseTimeout
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync)
newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync)
where
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
tryConnectClient successAction retryAction =
tryE connectClient >>= \r -> case r of
tryE (connectClient v) >>= \r -> case r of
Right smp -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar smpVar r
atomically $ putTMVar (sessionVar v) r
successAction smp
Left e -> do
if e == PCENetworkError || e == PCEResponseTimeout
then retryAction
else atomically $ do
putTMVar smpVar (Left e)
TM.delete srv smpClients
putTMVar (sessionVar v) (Left e)
removeSessVar v srv smpClients
throwE e
tryConnectAsync :: IO ()
tryConnectAsync = do
@@ -186,17 +196,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
void $ tryConnectClient (const reconnectClient) loop
connectClient :: ExceptT SMPClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v)
clientDisconnected :: SMPClient -> IO ()
clientDisconnected _ = do
removeClientAndSubs >>= (`forM_` serverDown)
clientDisconnected :: SMPClientVar -> SMPClient -> IO ()
clientDisconnected v _ = do
removeClientAndSubs v >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateAuthKey))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
removeClientAndSubs v = atomically $ do
removeSessVar v srv smpClients
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
where
updateSubs sVar = do
@@ -207,7 +217,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
addPendingSubs sVar ss = do
let ps = pendingSrvSubs ca
TM.lookup srv ps >>= \case
Just v -> TM.union ss v
Just ss' -> TM.union ss ss'
_ -> TM.insert srv sVar ps
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
@@ -268,10 +278,10 @@ closeSMPClientAgent c = do
cancelActions $ asyncClients c
closeSMPServerClients :: SMPClientAgent -> IO ()
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
closeSMPServerClients c = atomically (smpClients c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient)
where
closeClient smpVar =
atomically (readTMVar smpVar) >>= \case
closeClient v =
atomically (readTMVar $ sessionVar v) >>= \case
Right smp -> closeProtocolClient smp `catchAll_` pure ()
_ -> pure ()