Merge branch 'master' into ep/rfc-rotation

This commit is contained in:
Evgeny Poberezkin
2022-09-09 17:06:38 +01:00
5 changed files with 191 additions and 109 deletions
+1
View File
@@ -83,6 +83,7 @@ library
Simplex.Messaging.Server.Stats
Simplex.Messaging.Server.StoreLog
Simplex.Messaging.TMap
Simplex.Messaging.TMap2
Simplex.Messaging.Transport
Simplex.Messaging.Transport.Client
Simplex.Messaging.Transport.HTTP2
+74 -53
View File
@@ -86,6 +86,7 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import Data.Composition ((.:), (.:.))
import Data.Functor (($>))
import Data.List (deleteFirstsBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -366,8 +367,11 @@ ackMessageAsync' c connId msgId =
enqueueCommand c connId (Just server) $ ACK msgId
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
newConn c connId asyncMode enableNtfs cMode = do
srv <- getAnySMPServer c
newConn c connId asyncMode enableNtfs cMode =
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
newConnSrv c connId asyncMode enableNtfs cMode srv = do
clientVRange <- asks $ smpClientVRange . config
(rq, qUri) <- newRcvQueue c srv clientVRange True
connId' <- setUpConn asyncMode rq
@@ -394,7 +398,15 @@ newConn c connId asyncMode enableNtfs cMode = do
withStore c $ \db -> createRcvConn db g cData rq cMode
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do
joinConn c connId asyncMode enableNtfs cReq cInfo = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
_ -> getSMPServer c
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
@@ -410,7 +422,7 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
connId' <- setUpConn asyncMode cData sq rc
let cData' = (cData :: ConnData) {connId = connId'}
tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
@@ -431,23 +443,22 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV
liftIO $ createRatchet db connId' rc
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConn c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do
joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
agentVRange `compatibleVersion` aVRange
) of
(Just qInfo, Just vrsn) -> do
(connId', cReq) <- newConn c connId False enableNtfs SCMInvitation
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv
sendInvitation c qInfo vrsn cReq cInfo
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConn _c _connId True _enableNtfs (CRContactUri _) _cInfo = do
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
throwError $ CMD PROHIBITED
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {server, smpClientVersion} = do
srv <- getSMPServer c server
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c srv (versionToRange smpClientVersion) True
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq connId
@@ -499,22 +510,22 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma
-- | Subscribe to receive connection messages (SUB command) in Reader monad
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
subscribeConnection' c connId =
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rq sq _ _) -> do
resumeMsgDelivery c cData sq
subscribe rq
void . forkIO $ doRcvQueueAction c cData rq sq
resumeConnCmds c connId
SomeConn _ (SndConnection cData sq) -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure () -- TODO secure queue if this is a new server version
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
resumeConnCmds c connId
SomeConn _ (RcvConnection _ rq) -> subscribe rq >> resumeConnCmds c connId
SomeConn _ (ContactConnection _ rq) -> subscribe rq >> resumeConnCmds c connId
SomeConn _ (NewConnection _) -> resumeConnCmds c connId
withStore c (`getConn` connId) >>= \conn -> do
resumeConnCmds c connId
case conn of
SomeConn _ (DuplexConnection cData rq sq _ _) -> do
resumeMsgDelivery c cData sq
subscribe rq
void . forkIO $ doRcvQueueAction c cData rq sq
SomeConn _ (SndConnection cData sq) -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure () -- TODO secure queue if this is a new server version
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
SomeConn _ (RcvConnection _ rq) -> subscribe rq
SomeConn _ (ContactConnection _ rq) -> subscribe rq
SomeConn _ (NewConnection _) -> pure ()
where
-- TODO sndQueueAction?
subscribe :: RcvQueue -> m ()
@@ -549,7 +560,7 @@ createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do
let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey}
pure SMPQueueUri {clientVRange, queueAddress}
_ -> do
srv <- getSMPServer c server
srv <- getNextSMPServer c [server]
(rq', qUri) <- newRcvQueue c srv clientVRange False
withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq'
pure qUri
@@ -599,7 +610,7 @@ subscribeConnections' c connIds = do
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
forM_ (M.keys cs) $ resumeConnCmds c
mapM_ (resumeConnCmds c) $ M.keys cs
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
ns <- asks ntfSupervisor
tkn <- readTVarIO (ntfTkn ns)
@@ -761,26 +772,35 @@ runCommandProcessing c@AgentClient {subQ} server = do
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left (e :: E.SomeException) ->
notify "" $ ERR (INTERNAL $ show e)
Right (connId, ACmd _ cmd) ->
Right (connId, ACmd _ cmd) -> do
usedSrvs <- newTVarIO ([] :: [SMPServer])
withRetryInterval ri $ \loop -> do
resp <- tryError $ case cmd of
NEW enableNtfs (ACM cMode) -> do
(_, cReq) <- newConn c connId True enableNtfs cMode
notify connId $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq) connInfo -> void $ joinConn c connId True enableNtfs cReq connInfo
NEW enableNtfs (ACM cMode) ->
withNextSrv usedSrvs $ \srv -> do
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
notify connId $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq) connInfo ->
withNextSrv usedSrvs $ \srv ->
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo
ACK msgId -> ackMessage' c connId msgId
_ -> notify "" $ ERR (INTERNAL "")
_ -> notify connId $ ERR $ INTERNAL $ "unsupported async command " <> show cmd
case resp of
Left _ ->
-- TODO retry NEW and JOIN on different server
-- TODO depending on command, some errors shouldn't be retried
retryCommand loop
Right () -> do
delCmd cmdId
Left e
| temporaryAgentError e || e == BROKER HOST -> retryCommand loop
| otherwise -> notify connId $ ERR e
Right () -> withStore' c (`deleteCommand` cmdId)
where
delCmd :: AsyncCmdId -> m ()
delCmd cmdId = withStore' c $ \db -> deleteCommand db cmdId
withNextSrv :: TVar [SMPServer] -> (SMPServer -> m ()) -> m ()
withNextSrv usedSrvs action = do
used <- readTVarIO usedSrvs
srv <- getNextSMPServer c used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then [] else srv : used
writeTVar usedSrvs used'
action srv
notify :: ConnId -> ACommand 'Agent -> m ()
notify connId cmd = atomically $ writeTBQueue subQ ("", connId, cmd)
retryCommand loop = do
@@ -940,7 +960,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- and this branch should never be reached as receive is created before the confirmation,
-- so the condition is not necessary here, strictly speaking.
_ -> unless (duplexHandshake == Just True) $ do
qInfo <- createReplyQueue c cData sq
srv <- getSMPServer c
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
AM_QHELLO_ -> do
@@ -1270,8 +1291,8 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do
-- unsafeIOToSTM $ putStrLn $ "in timeout: suspendSendingAndDatabase"
suspendSendingAndDatabase c
getAnySMPServer :: AgentMonad m => AgentClient -> m SMPServer
getAnySMPServer c = readTVarIO (smpServers c) >>= pickServer
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer
pickServer = \case
@@ -1280,14 +1301,14 @@ pickServer = \case
gen <- asks randomServer
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
getSMPServer :: AgentMonad m => AgentClient -> SMPServer -> m SMPServer
getSMPServer c (SMPServer host port _) = do
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
getNextSMPServer c usedSrvs = do
srvs <- readTVarIO $ smpServers c
case L.nonEmpty $ L.filter different srvs of
case L.nonEmpty $ deleteFirstsBy different (L.toList srvs) usedSrvs of
Just srvs' -> pickServer srvs'
_ -> pure $ L.head srvs
_ -> pickServer srvs
where
different (SMPServer host' port' _) = host /= host' || port /= port'
different (SMPServer host port _) (SMPServer host' port' _) = host /= host' || port /= port'
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
@@ -1666,8 +1687,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq ownConnInfo Nothing
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
aMessage <- mkAgentMessage agentVersion
msg <- mkConfirmation aMessage
sendConfirmation c sq msg
@@ -1681,7 +1702,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e
mkAgentMessage :: Version -> m AgentMessage
mkAgentMessage 1 = pure $ AgentConnInfo connInfo
mkAgentMessage _ = do
qInfo <- createReplyQueue c cData sq
qInfo <- createReplyQueue c cData sq srv
pure $ AgentConnInfoReply (qInfo :| []) connInfo
enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
+29 -56
View File
@@ -92,7 +92,6 @@ import Data.Set (Set)
import qualified Data.Set as S
import Data.Text.Encoding
import Data.Time.Clock (getCurrentTime)
import Data.Tuple (swap)
import Data.Word (Word16)
import qualified Database.SQLite.Simple as DB
import Simplex.Messaging.Agent.Env.SQLite
@@ -133,6 +132,8 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.TMap2 (TMap2)
import qualified Simplex.Messaging.TMap2 as TM2
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
@@ -159,10 +160,9 @@ data AgentClient = AgentClient
ntfServers :: TVar [NtfServer],
ntfClients :: TMap NtfServer NtfClientVar,
useNetworkConfig :: TVar NetworkConfig,
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
subscrConns :: TVar (Set ConnId),
activeSubscrConns :: TMap ConnId SMPServer,
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap MsgDeliveryKey (TQueue InternalId),
smpQueueMsgDeliveries :: TMap MsgDeliveryKey (Async ()),
@@ -215,10 +215,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
ntfServers <- newTVar ntf
ntfClients <- TM.empty
useNetworkConfig <- newTVar netCfg
subscrSrvrs <- TM.empty
pendingSubscrSrvrs <- TM.empty
subscrConns <- newTVar S.empty
activeSubscrConns <- TM.empty
activeSubs <- TM2.empty
pendingSubs <- TM2.empty
connMsgsQueued <- TM.empty
smpQueueMsgQueues <- TM.empty
smpQueueMsgDeliveries <- TM.empty
@@ -237,7 +236,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
asyncClients <- newTVar []
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
lock <- newTMVar ()
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
agentDbPath :: AgentClient -> FilePath
agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath
@@ -276,26 +275,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case
Just v -> TM.union cs v
_ -> TM.insert srv cVar ps
TM2.insert1 srv cVar $ pendingSubs c
readTVar cVar
serverDown :: Map ConnId RcvQueue -> IO ()
serverDown cs = unless (M.null cs) $
whenM (readTVarIO active) $ do
let conns = M.keys cs
notifySub "" $ hostEvent DISCONNECT client
unless (null conns) . notifySub "" $ DOWN srv conns
serverDown cs = whenM (readTVarIO active) $ do
notifySub "" $ hostEvent DISCONNECT client
let conns = M.keys cs
unless (null conns) $ do
notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
@@ -313,7 +304,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
reconnectClient :: m ()
reconnectClient =
withAgentLock c $
atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar)
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
>>= mapM_ resubscribe
where
resubscribe :: Map ConnId RcvQueue -> m ()
@@ -419,10 +410,9 @@ closeAgentClient c = liftIO $ do
cancelActions $ reconnections c
cancelActions $ asyncClients c
cancelActions $ smpQueueMsgDeliveries c
clear subscrSrvrs
clear pendingSubscrSrvrs
atomically . TM2.clear $ activeSubs c
atomically . TM2.clear $ pendingSubs c
clear subscrConns
clear activeSubscrConns
clear connMsgsQueued
clear smpQueueMsgQueues
clear getMsgLocks
@@ -534,17 +524,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
addPendingSubscription c rq connId
TM2.insert server connId rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
>>= either throwError pure
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq@RcvQueue {server} connId r = do
processSubResult c rq connId r = do
case r of
Left e ->
atomically . unless (temporaryClientError e) $
removePendingSubscription c server connId
TM2.delete connId (pendingSubs c)
_ -> addSubscription c rq connId
pure r
@@ -564,9 +554,9 @@ temporaryAgentError = \case
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ \q -> atomically $ do
modifyTVar (subscrConns c) . S.insert $ fst q
uncurry (addPendingSubscription c) $ swap q
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
TM2.insert server connId rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
@@ -588,35 +578,18 @@ subscribeQueues c srv qs = do
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
TM.insert connId server $ activeSubscrConns c
modifyTVar (subscrConns c) $ S.insert connId
addSubs_ (subscrSrvrs c) rq connId
removePendingSubscription c server connId
TM2.insert server connId rq $ activeSubs c
TM2.delete connId $ pendingSubs c
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
hasActiveSubscription c connId = TM.member connId (activeSubscrConns c)
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM ()
addSubs_ ss rq@RcvQueue {server} connId =
TM.lookup server ss >>= \case
Just m -> TM.insert connId rq m
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c connId = do
modifyTVar (subscrConns c) $ S.delete connId
server_ <- TM.lookupDelete connId $ activeSubscrConns c
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
removePendingSubscription = removeSubs_ . pendingSubscrSrvrs
removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
removeSubs_ ss server connId =
TM.lookup server ss >>= mapM_ (TM.delete connId)
TM2.delete connId $ activeSubs c
TM2.delete connId $ pendingSubs c
getSubscriptions :: AgentClient -> STM (Set ConnId)
getSubscriptions = readTVar . subscrConns
+5
View File
@@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
( TMap,
empty,
singleton,
clear,
Simplex.Messaging.TMap.null,
Simplex.Messaging.TMap.lookup,
member,
@@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a)
singleton k v = newTVar $ M.singleton k v
{-# INLINE singleton #-}
clear :: TMap k a -> STM ()
clear m = writeTVar m M.empty
{-# INLINE clear #-}
null :: TMap k a -> STM Bool
null m = M.null <$> readTVar m
{-# INLINE null #-}
+82
View File
@@ -0,0 +1,82 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.TMap2
( TMap2,
empty,
clear,
Simplex.Messaging.TMap2.lookup,
lookup1,
member,
insert,
insert1,
delete,
lookupDelete1,
)
where
import Control.Concurrent.STM
import Control.Monad (forM_, (>=>))
import qualified Data.Map.Strict as M
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (whenM, ($>>=))
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
data TMap2 k1 k2 a = TMap2
{ _m1 :: TMap k1 (TMap k2 a),
_m2 :: TMap k2 k1
}
empty :: STM (TMap2 k1 k2 a)
empty = TMap2 <$> TM.empty <*> TM.empty
clear :: TMap2 k1 k2 a -> STM ()
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
lookup k2 TMap2 {_m1, _m2} = do
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
{-# INLINE lookup1 #-}
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
member k2 TMap2 {_m2} = TM.member k2 _m2
{-# INLINE member #-}
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
insert k1 k2 v TMap2 {_m1, _m2} =
TM.lookup k2 _m2 >>= \case
Just k1'
| k1 == k1' -> _insert1
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
_ -> _insert2
where
_insert1 =
TM.lookup k1 _m1 >>= \case
Just m -> TM.insert k2 v m
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
_insert2 = TM.insert k2 k1 _m2 >> _insert1
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
insert1 k1 m' TMap2 {_m1, _m2} =
TM.lookup k1 _m1 >>= \case
Just m -> readTVar m' >>= (`TM.union` m)
_ -> TM.insert k1 m' _m1
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
_delete1 k1 k2 m1 =
TM.lookup k1 m1
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookupDelete1 k1 TMap2 {_m1, _m2} = do
m_ <- TM.lookupDelete k1 _m1
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
pure m_