use TMap for subscription maps (#341)

* use TMap for subscription maps

* refactor

* correction
This commit is contained in:
Evgeny Poberezkin
2022-03-28 18:49:17 +01:00
committed by GitHub
parent 6ef6bedc03
commit cd22e06b3a
4 changed files with 76 additions and 60 deletions

View File

@@ -63,7 +63,7 @@ import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKe
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError)
import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, whenM)
import Simplex.Messaging.Version
import System.Timeout (timeout)
import UnliftIO (async, forConcurrently_)
@@ -79,8 +79,8 @@ data AgentClient = AgentClient
msgQ :: TBQueue SMPServerTransmission,
smpServers :: TVar (NonEmpty SMPServer),
smpClients :: TMap SMPServer SMPClientVar,
subscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (Map ConnId RcvQueue),
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
subscrConns :: TMap ConnId SMPServer,
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId),
@@ -188,21 +188,24 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
clientDisconnected :: UnliftIO m -> IO ()
clientDisconnected u = do
removeClientSubs >>= (`forM_` serverDown u)
removeClientAndSubs >>= (`forM_` serverDown u)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientSubs = atomically $ do
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
cs_ <- TM.lookupDelete srv $ subscrSrvrs c
forM_ cs_ $ \cs -> do
modifyTVar' (TM.tVar $ subscrConns c) (`M.withoutKeys` M.keysSet cs)
modifyTVar' (TM.tVar $ pendingSubscrSrvrs c) $ addPendingSubs cs
return cs_
cVar_ <- TM.lookupDelete srv $ subscrSrvrs c
forM cVar_ $ \cVar -> do
cs <- readTVar cVar
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
where
addPendingSubs :: Map ConnId RcvQueue -> Map SMPServer (Map ConnId RcvQueue) -> Map SMPServer (Map ConnId RcvQueue)
addPendingSubs cs = M.alter (Just . addSubs cs) srv
addSubs cs = maybe cs (M.union cs)
addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case
Just v -> TM.union cs v
_ -> TM.insert srv cVar ps
serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO ()
serverDown u cs = unless (M.null cs) $ do
@@ -221,19 +224,26 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
reconnectClient `catchError` const loop
reconnectClient :: m ()
reconnectClient = do
reconnectClient =
withAgentLock c . withSMP c srv $ \smp -> do
subs <- readTVarIO . TM.tVar $ subscrConns c
cs <- atomically . TM.lookup srv $ pendingSubscrSrvrs c
forConcurrently_ (maybe [] M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) ->
when (isNothing $ M.lookup connId subs) $ do
subscribeSMPQueue smp rcvPrivateKey rcvId
`catchError` \case
e@SMPResponseTimeout -> throwError e
e@SMPNetworkError -> throwError e
e -> liftIO $ notifySub (ERR $ smpClientError e) connId
addSubscription c rq connId
liftIO $ notifySub UP connId
cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c)
forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) ->
whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $
subscribe_ smp sub `catchError` handleError connId
where
subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO ()
subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do
subscribeSMPQueue smp rcvPrivateKey rcvId
addSubscription c rq connId
liftIO $ notifySub UP connId
handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO ()
handleError connId = \case
e@SMPResponseTimeout -> throwError e
e@SMPNetworkError -> throwError e
e -> do
liftIO $ notifySub (ERR $ smpClientError e) connId
atomically $ removePendingSubscription c srv connId
notifySub :: ACommand 'Agent -> ConnId -> IO ()
notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
@@ -243,10 +253,10 @@ closeAgentClient c = liftIO $ do
closeSMPServerClients c
cancelActions $ reconnections c
cancelActions $ asyncClients c
cancelActions . TM.tVar $ smpQueueMsgDeliveries c
cancelActions $ smpQueueMsgDeliveries c
closeSMPServerClients :: AgentClient -> IO ()
closeSMPServerClients c = readTVarIO (TM.tVar $ smpClients c) >>= mapM_ (forkIO . closeClient)
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
where
closeClient smpVar =
atomically (readTMVar smpVar) >>= \case
@@ -331,7 +341,7 @@ newRcvQueue_ a c srv = do
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
addPendingSubscription c rq connId
atomically $ addPendingSubscription c rq connId
withLogSMP c server rcvId "SUB" $ \smp -> do
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
Left e -> do
@@ -343,17 +353,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
TM.insert connId server $ subscrConns c
addSubs_ rq connId $ subscrSrvrs c
addSubs_ (subscrSrvrs c) rq connId
removePendingSubscription c server connId
addPendingSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addPendingSubscription c rq connId =
atomically . addSubs_ rq connId $ pendingSubscrSrvrs c
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
addSubs_ :: RcvQueue -> ConnId -> TMap SMPServer (Map ConnId RcvQueue) -> STM ()
addSubs_ rq@RcvQueue {server} connId = TM.alter (Just . addSub) server
where
addSub = maybe (M.singleton connId rq) (M.insert connId rq)
addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM ()
addSubs_ ss rq@RcvQueue {server} connId =
TM.lookup server ss >>= \case
Just m -> TM.insert connId rq m
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m ()
removeSubscription c@AgentClient {subscrConns} connId = atomically $ do
@@ -363,13 +373,9 @@ removeSubscription c@AgentClient {subscrConns} connId = atomically $ do
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
removePendingSubscription = removeSubs_ . pendingSubscrSrvrs
removeSubs_ :: TMap SMPServer (Map ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
removeSubs_ ss server connId = TM.update delSub server ss
where
delSub :: Map ConnId RcvQueue -> Maybe (Map ConnId RcvQueue)
delSub cs =
let cs' = M.delete connId cs
in if M.null cs' then Nothing else Just cs'
removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
removeSubs_ ss server connId =
TM.lookup server ss >>= mapM_ (TM.delete connId)
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =

View File

@@ -137,7 +137,7 @@ runClientTransport th@THandle {sessionId} = do
clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
clientDisconnected c@Client {subscriptions, connected} = do
atomically $ writeTVar connected False
subs <- readTVarIO $ TM.tVar subscriptions
subs <- readTVarIO subscriptions
mapM_ cancelSub subs
cs <- asks $ subscribers . server
atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs

View File

@@ -114,9 +114,9 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile}
restoreQueues QueueStore {queues, senders, notifiers} s = do
(qs, s') <- liftIO $ readWriteStoreLog s
atomically $ do
writeTVar (TM.tVar queues) =<< mapM newTVar qs
writeTVar (TM.tVar senders) $ M.foldr' addSender M.empty qs
writeTVar (TM.tVar notifiers) $ M.foldr' addNotifier M.empty qs
writeTVar queues =<< mapM newTVar qs
writeTVar senders $ M.foldr' addSender M.empty qs
writeTVar notifiers $ M.foldr' addNotifier M.empty qs
pure s'
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
addSender q = M.insert (senderId q) (recipientId q)

View File

@@ -1,6 +1,7 @@
module Simplex.Messaging.TMap
( TMap (..),
( TMap,
empty,
singleton,
Simplex.Messaging.TMap.lookup,
member,
insert,
@@ -10,6 +11,7 @@ module Simplex.Messaging.TMap
adjust,
update,
alter,
union,
)
where
@@ -17,44 +19,52 @@ import Control.Concurrent.STM
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
newtype TMap k a = TMap {tVar :: TVar (Map k a)}
type TMap k a = TVar (Map k a)
empty :: STM (TMap k a)
empty = TMap <$> newTVar M.empty
empty = newTVar M.empty
{-# INLINE empty #-}
singleton :: k -> a -> STM (TMap k a)
singleton k v = newTVar $ M.singleton k v
{-# INLINE singleton #-}
lookup :: Ord k => k -> TMap k a -> STM (Maybe a)
lookup k (TMap m) = M.lookup k <$> readTVar m
lookup k m = M.lookup k <$> readTVar m
{-# INLINE lookup #-}
member :: Ord k => k -> TMap k a -> STM Bool
member k (TMap m) = M.member k <$> readTVar m
member k m = M.member k <$> readTVar m
{-# INLINE member #-}
insert :: Ord k => k -> a -> TMap k a -> STM ()
insert k v (TMap m) = modifyTVar' m $ M.insert k v
insert k v m = modifyTVar' m $ M.insert k v
{-# INLINE insert #-}
delete :: Ord k => k -> TMap k a -> STM ()
delete k (TMap m) = modifyTVar' m $ M.delete k
delete k m = modifyTVar' m $ M.delete k
{-# INLINE delete #-}
lookupInsert :: Ord k => k -> a -> TMap k a -> STM (Maybe a)
lookupInsert k v (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv)
lookupInsert k v m = stateTVar m $ \mv -> (M.lookup k mv, M.insert k v mv)
{-# INLINE lookupInsert #-}
lookupDelete :: Ord k => k -> TMap k a -> STM (Maybe a)
lookupDelete k (TMap m) = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv)
lookupDelete k m = stateTVar m $ \mv -> (M.lookup k mv, M.delete k mv)
{-# INLINE lookupDelete #-}
adjust :: Ord k => (a -> a) -> k -> TMap k a -> STM ()
adjust f k (TMap m) = modifyTVar' m $ M.adjust f k
adjust f k m = modifyTVar' m $ M.adjust f k
{-# INLINE adjust #-}
update :: Ord k => (a -> Maybe a) -> k -> TMap k a -> STM ()
update f k (TMap m) = modifyTVar' m $ M.update f k
update f k m = modifyTVar' m $ M.update f k
{-# INLINE update #-}
alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM ()
alter f k (TMap m) = modifyTVar' m $ M.alter f k
alter f k m = modifyTVar' m $ M.alter f k
{-# INLINE alter #-}
union :: Ord k => Map k a -> TMap k a -> STM ()
union m' m = modifyTVar' m $ M.union m'
{-# INLINE union #-}