extract SessionVar from AgentClient to reuse (#1099)

This commit is contained in:
Evgeny Poberezkin
2024-04-13 18:33:12 +01:00
committed by GitHub
parent 5e783396e0
commit ebb75ced12
4 changed files with 104 additions and 79 deletions
+1
View File
@@ -147,6 +147,7 @@ library
Simplex.Messaging.Server.Stats
Simplex.Messaging.Server.StoreLog
Simplex.Messaging.ServiceScheme
Simplex.Messaging.Session
Simplex.Messaging.TMap
Simplex.Messaging.Transport
Simplex.Messaging.Transport.Buffer
+22 -46
View File
@@ -152,7 +152,6 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:.))
import Data.Either (lefts, partitionEithers)
import Data.Functor (($>))
import Data.Int (Int64)
@@ -227,6 +226,7 @@ import Simplex.Messaging.Protocol
sameSrvAddr',
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion)
@@ -241,11 +241,6 @@ import UnliftIO.Directory (doesFileExist, getTemporaryDirectory, removeFile)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
data SessionVar a = SessionVar
{ sessionVar :: TMVar a,
sessionVarId :: Int
}
type ClientVar msg = SessionVar (Either AgentErrorType (Client msg))
type SMPClientVar = ClientVar SMP.BrokerMsg
@@ -550,9 +545,9 @@ instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
clientSessionTs = X.xftpSessionTs
getSMPServerClient :: AgentClient -> SMPTransportSession -> AM SMPClient
getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do
getSMPServerClient c@AgentClient {active, smpClients, msgQ, workerSeq} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess smpClients)
atomically (getSessVar workerSeq tSess smpClients)
>>= either newClient (waitForProtocolClient c tSess)
where
-- we resubscribe only on newClient error, but not on waitForProtocolClient error,
@@ -581,7 +576,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
removeClientAndSubs :: IO ([RcvQueue], [ConnId])
removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], [])
where
currentActiveClient = (&&) <$> removeTSessVar' v tSess smpClients <*> readTVar active
currentActiveClient = (&&) <$> removeSessVar' v tSess smpClients <*> readTVar active
removeSubs = do
(qs, cs) <- RQ.getDelSessQueues tSess $ activeSubs c
RQ.batchAddQueues (pendingSubs c) qs
@@ -600,14 +595,14 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
resubscribeSMPSession :: AgentClient -> SMPTransportSession -> AM' ()
resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
resubscribeSMPSession c@AgentClient {smpSubWorkers, workerSeq} tSess =
atomically getWorkerVar >>= mapM_ (either newSubWorker (\_ -> pure ()))
where
getWorkerVar =
ifM
(null <$> getPending)
(pure Nothing) -- prevent race with cleanup and adding pending queues in another call
(Just <$> getTSessVar c tSess smpSubWorkers)
(Just <$> getSessVar workerSeq tSess smpSubWorkers)
newSubWorker v = do
a <- async $ void (E.tryAny runSubWorker) >> atomically (cleanup v)
atomically $ putTMVar (sessionVar v) a
@@ -626,7 +621,7 @@ resubscribeSMPSession c@AgentClient {smpSubWorkers} tSess =
-- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar.
-- Not waiting may result in terminated worker remaining in the map.
whenM (isEmptyTMVar $ sessionVar v) retry
removeTSessVar v tSess smpSubWorkers
removeSessVar v tSess smpSubWorkers
reconnectSMPClient :: TVar Int -> AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> AM ()
reconnectSMPClient tc c tSess@(_, srv, _) qs = do
@@ -660,9 +655,9 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
getNtfServerClient :: AgentClient -> NtfTransportSession -> AM NtfClient
getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = do
getNtfServerClient c@AgentClient {active, ntfClients, workerSeq} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess ntfClients)
atomically (getSessVar workerSeq tSess ntfClients)
>>= either
(newProtocolClient c tSess ntfClients connectClient)
(waitForProtocolClient c tSess)
@@ -677,15 +672,15 @@ getNtfServerClient c@AgentClient {active, ntfClients} tSess@(userId, srv, _) = d
clientDisconnected :: NtfClientVar -> NtfClient -> IO ()
clientDisconnected v client = do
atomically $ removeTSessVar v tSess ntfClients
atomically $ removeSessVar v tSess ntfClients
incClientStat c userId client "DISCONNECT" ""
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient
getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) = do
getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq} tSess@(userId, srv, _) = do
unlessM (readTVarIO active) . throwError $ INACTIVE
atomically (getTSessVar c tSess xftpClients)
atomically (getSessVar workerSeq tSess xftpClients)
>>= either
(newProtocolClient c tSess xftpClients connectClient)
(waitForProtocolClient c tSess)
@@ -701,32 +696,11 @@ getXFTPServerClient c@AgentClient {active, xftpClients} tSess@(userId, srv, _) =
clientDisconnected :: XFTPClientVar -> XFTPClient -> IO ()
clientDisconnected v client = do
atomically $ removeTSessVar v tSess xftpClients
atomically $ removeSessVar v tSess xftpClients
incClientStat c userId client "DISCONNECT" ""
atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getTSessVar :: forall a s. AgentClient -> TransportSession s -> TMap (TransportSession s) (SessionVar a) -> STM (Either (SessionVar a) (SessionVar a))
getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lookup tSess vs
where
newSessionVar :: STM (SessionVar a)
newSessionVar = do
sessionVar <- newEmptyTMVar
sessionVarId <- stateTVar (workerSeq c) $ \next -> (next, next + 1)
let v = SessionVar {sessionVar, sessionVarId}
TM.insert tSess v vs
pure v
removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM ()
removeTSessVar = void .:. removeTSessVar'
{-# INLINE removeTSessVar #-}
removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool
removeTSessVar' v tSess vs =
TM.lookup tSess vs >>= \case
Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True
_ -> pure False
waitForProtocolClient :: ProtocolTypeI (ProtoType msg) => AgentClient -> TransportSession msg -> ClientVar msg -> AM (Client msg)
waitForProtocolClient c (_, srv, _) v = do
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
@@ -757,7 +731,7 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
Left e -> do
liftIO $ incServerStat c userId srv "CLIENT" $ strEncode e
atomically $ do
removeTSessVar v tSess clients
removeSessVar v tSess clients
putTMVar (sessionVar v) (Left e)
throwError e -- signal error to caller
@@ -781,10 +755,11 @@ getNetworkConfig c = do
waitForUserNetwork :: AgentClient -> AM' ()
waitForUserNetwork AgentClient {userNetworkState} =
(offline <$> readTVarIO userNetworkState) >>= mapM_ waitWhileOffline
readTVarIO userNetworkState >>= mapM_ waitWhileOffline . offline
where
waitWhileOffline UNSOffline {offlineDelay = d} =
unlessM (liftIO $ waitOnline d False) $ do -- network delay reached, increase delay
unlessM (liftIO $ waitOnline d False) $ do
-- network delay reached, increase delay
ts' <- liftIO getCurrentTime
ni <- asks $ userNetworkInterval . config
atomically $ do
@@ -794,7 +769,7 @@ waitForUserNetwork AgentClient {userNetworkState} =
-- and to reset `offlineDelay` if network went `on` and `off` again.
writeTVar userNetworkState $!
let d'' = nextRetryDelay (diffToMicroseconds $ diffUTCTime ts' ts) (min d d') ni
in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}}
in ns {offline = Just UNSOffline {offlineDelay = d'', offlineFrom = ts}}
waitOnline :: Int64 -> Bool -> IO Bool
waitOnline t online'
| t <= 0 = pure online'
@@ -866,9 +841,10 @@ closeClient c clientSel tSess =
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ c v = do
NetworkConfig {tcpConnectTimeout} <- atomically $ getNetworkConfig c
E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
_ -> pure ()
E.handle (\BlockedIndefinitelyOnSTM -> pure ()) $
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
Just (Right client) -> closeProtocolServerClient client `catchAll_` pure ()
_ -> pure ()
closeXFTPServerClient :: AgentClient -> UserId -> XFTPServer -> FileDigest -> IO ()
closeXFTPServerClient c userId server (FileDigest chunkDigest) =
+43 -33
View File
@@ -40,6 +40,7 @@ import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BrokerMsg, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer)
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
@@ -50,7 +51,7 @@ import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
type SMPClientVar = TMVar (Either SMPClientError SMPClient)
type SMPClientVar = SessionVar (Either SMPClientError SMPClient)
data SMPClientAgentEvent
= CAConnected SMPServer
@@ -100,7 +101,8 @@ data SMPClientAgent = SMPClientAgent
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
reconnections :: TVar [Async ()],
asyncClients :: TVar [Async ()]
asyncClients :: TVar [Async ()],
workerSeq :: TVar Int
}
newtype InternalException e = InternalException {unInternalException :: e}
@@ -115,9 +117,10 @@ instance Exception e => MonadUnliftIO (ExceptT e IO) where
ExceptT . fmap (first unInternalException) . E.try $
withRunInIO $ \run ->
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
{-# INLINE withRunInIO #-}
@@ -136,46 +139,53 @@ newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg
pendingSrvSubs <- TM.empty
reconnections <- newTVar []
asyncClients <- newTVar []
pure SMPClientAgent {agentCfg, msgQ, agentQ, randomDrg, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
workerSeq <- newTVar 0
pure
SMPClientAgent
{ agentCfg,
msgQ,
agentQ,
randomDrg,
smpClients,
srvSubs,
pendingSrvSubs,
reconnections,
asyncClients,
workerSeq
}
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} srv =
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv =
atomically getClientVar >>= either newSMPClient waitForSMPClient
where
getClientVar :: STM (Either SMPClientVar SMPClientVar)
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
newClientVar :: STM SMPClientVar
newClientVar = do
smpVar <- newEmptyTMVar
TM.insert srv smpVar smpClients
pure smpVar
getClientVar = getSessVar workerSeq srv smpClients
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
waitForSMPClient smpVar = do
waitForSMPClient v = do
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar smpVar)
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
liftEither $ case smpClient_ of
Just (Right smpClient) -> Right smpClient
Just (Left e) -> Left e
Nothing -> Left PCEResponseTimeout
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
newSMPClient smpVar = tryConnectClient pure (liftIO tryConnectAsync)
newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync)
where
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
tryConnectClient successAction retryAction =
tryE connectClient >>= \r -> case r of
tryE (connectClient v) >>= \r -> case r of
Right smp -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar smpVar r
atomically $ putTMVar (sessionVar v) r
successAction smp
Left e -> do
if e == PCENetworkError || e == PCEResponseTimeout
then retryAction
else atomically $ do
putTMVar smpVar (Left e)
TM.delete srv smpClients
putTMVar (sessionVar v) (Left e)
removeSessVar v srv smpClients
throwE e
tryConnectAsync :: IO ()
tryConnectAsync = do
@@ -186,17 +196,17 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
void $ tryConnectClient (const reconnectClient) loop
connectClient :: ExceptT SMPClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) clientDisconnected
connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v)
clientDisconnected :: SMPClient -> IO ()
clientDisconnected _ = do
removeClientAndSubs >>= (`forM_` serverDown)
clientDisconnected :: SMPClientVar -> SMPClient -> IO ()
clientDisconnected v _ = do
removeClientAndSubs v >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateAuthKey))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
removeClientAndSubs v = atomically $ do
removeSessVar v srv smpClients
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
where
updateSubs sVar = do
@@ -207,7 +217,7 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg} sr
addPendingSubs sVar ss = do
let ps = pendingSrvSubs ca
TM.lookup srv ps >>= \case
Just v -> TM.union ss v
Just ss' -> TM.union ss ss'
_ -> TM.insert srv sVar ps
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
@@ -268,10 +278,10 @@ closeSMPClientAgent c = do
cancelActions $ asyncClients c
closeSMPServerClients :: SMPClientAgent -> IO ()
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
closeSMPServerClients c = atomically (smpClients c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient)
where
closeClient smpVar =
atomically (readTMVar smpVar) >>= \case
closeClient v =
atomically (readTMVar $ sessionVar v) >>= \case
Right smp -> closeProtocolClient smp `catchAll_` pure ()
_ -> pure ()
+38
View File
@@ -0,0 +1,38 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Session where
import Control.Concurrent.STM
import Control.Monad
import Data.Composition ((.:.))
import Data.Functor (($>))
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
data SessionVar a = SessionVar
{ sessionVar :: TMVar a,
sessionVarId :: Int
}
getSessVar :: forall k a. Ord k => TVar Int -> k -> TMap k (SessionVar a) -> STM (Either (SessionVar a) (SessionVar a))
getSessVar sessSeq sessKey vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lookup sessKey vs
where
newSessionVar :: STM (SessionVar a)
newSessionVar = do
sessionVar <- newEmptyTMVar
sessionVarId <- stateTVar sessSeq $ \next -> (next, next + 1)
let v = SessionVar {sessionVar, sessionVarId}
TM.insert sessKey v vs
pure v
removeSessVar :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM ()
removeSessVar = void .:. removeSessVar'
{-# INLINE removeSessVar #-}
removeSessVar' :: Ord k => SessionVar a -> k -> TMap k (SessionVar a) -> STM Bool
removeSessVar' v sessKey vs =
TM.lookup sessKey vs >>= \case
Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs $> True
_ -> pure False