mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-26 01:04:44 +00:00
Merge branch 'master' into ep/rfc-rotation
This commit is contained in:
@@ -83,6 +83,7 @@ library
|
||||
Simplex.Messaging.Server.Stats
|
||||
Simplex.Messaging.Server.StoreLog
|
||||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.TMap2
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Client
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
|
||||
@@ -86,6 +86,7 @@ import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.Functor (($>))
|
||||
import Data.List (deleteFirstsBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -366,8 +367,11 @@ ackMessageAsync' c connId msgId =
|
||||
enqueueCommand c connId (Just server) $ ACK msgId
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId asyncMode enableNtfs cMode = do
|
||||
srv <- getAnySMPServer c
|
||||
newConn c connId asyncMode enableNtfs cMode =
|
||||
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
|
||||
|
||||
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c connId asyncMode enableNtfs cMode srv = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
(rq, qUri) <- newRcvQueue c srv clientVRange True
|
||||
connId' <- setUpConn asyncMode rq
|
||||
@@ -394,7 +398,15 @@ newConn c connId asyncMode enableNtfs cMode = do
|
||||
withStore c $ \db -> createRcvConn db g cData rq cMode
|
||||
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do
|
||||
joinConn c connId asyncMode enableNtfs cReq cInfo = do
|
||||
srv <- case cReq of
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
|
||||
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
|
||||
_ -> getSMPServer c
|
||||
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId
|
||||
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
@@ -410,7 +422,7 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV
|
||||
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
|
||||
connId' <- setUpConn asyncMode cData sq rc
|
||||
let cData' = (cData :: ConnData) {connId = connId'}
|
||||
tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
pure connId'
|
||||
@@ -431,23 +443,22 @@ joinConn c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentV
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConn c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do
|
||||
joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
agentVRange `compatibleVersion` aVRange
|
||||
) of
|
||||
(Just qInfo, Just vrsn) -> do
|
||||
(connId', cReq) <- newConn c connId False enableNtfs SCMInvitation
|
||||
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv
|
||||
sendInvitation c qInfo vrsn cReq cInfo
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConn _c _connId True _enableNtfs (CRContactUri _) _cInfo = do
|
||||
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {server, smpClientVersion} = do
|
||||
srv <- getSMPServer c server
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c srv (versionToRange smpClientVersion) True
|
||||
let qInfo = toVersionT qUri smpClientVersion
|
||||
addSubscription c rq connId
|
||||
@@ -499,22 +510,22 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma
|
||||
-- | Subscribe to receive connection messages (SUB command) in Reader monad
|
||||
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection' c connId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection cData rq sq _ _) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
subscribe rq
|
||||
void . forkIO $ doRcvQueueAction c cData rq sq
|
||||
resumeConnCmds c connId
|
||||
SomeConn _ (SndConnection cData sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure () -- TODO secure queue if this is a new server version
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
resumeConnCmds c connId
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq >> resumeConnCmds c connId
|
||||
SomeConn _ (ContactConnection _ rq) -> subscribe rq >> resumeConnCmds c connId
|
||||
SomeConn _ (NewConnection _) -> resumeConnCmds c connId
|
||||
withStore c (`getConn` connId) >>= \conn -> do
|
||||
resumeConnCmds c connId
|
||||
case conn of
|
||||
SomeConn _ (DuplexConnection cData rq sq _ _) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
subscribe rq
|
||||
void . forkIO $ doRcvQueueAction c cData rq sq
|
||||
SomeConn _ (SndConnection cData sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure () -- TODO secure queue if this is a new server version
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (ContactConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (NewConnection _) -> pure ()
|
||||
where
|
||||
-- TODO sndQueueAction?
|
||||
subscribe :: RcvQueue -> m ()
|
||||
@@ -549,7 +560,7 @@ createNextRcvQueue c cData@ConnData {connId} rq@RcvQueue {server, sndId} sq = do
|
||||
let queueAddress = SMPQueueAddress {smpServer, senderId, dhPublicKey = C.publicKey e2ePrivKey}
|
||||
pure SMPQueueUri {clientVRange, queueAddress}
|
||||
_ -> do
|
||||
srv <- getSMPServer c server
|
||||
srv <- getNextSMPServer c [server]
|
||||
(rq', qUri) <- newRcvQueue c srv clientVRange False
|
||||
withStore' c $ \db -> dbCreateNextRcvQueue db connId rq rq'
|
||||
pure qUri
|
||||
@@ -599,7 +610,7 @@ subscribeConnections' c connIds = do
|
||||
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
|
||||
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
|
||||
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
|
||||
forM_ (M.keys cs) $ resumeConnCmds c
|
||||
mapM_ (resumeConnCmds c) $ M.keys cs
|
||||
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
ns <- asks ntfSupervisor
|
||||
tkn <- readTVarIO (ntfTkn ns)
|
||||
@@ -761,26 +772,35 @@ runCommandProcessing c@AgentClient {subQ} server = do
|
||||
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
|
||||
Left (e :: E.SomeException) ->
|
||||
notify "" $ ERR (INTERNAL $ show e)
|
||||
Right (connId, ACmd _ cmd) ->
|
||||
Right (connId, ACmd _ cmd) -> do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
withRetryInterval ri $ \loop -> do
|
||||
resp <- tryError $ case cmd of
|
||||
NEW enableNtfs (ACM cMode) -> do
|
||||
(_, cReq) <- newConn c connId True enableNtfs cMode
|
||||
notify connId $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> void $ joinConn c connId True enableNtfs cReq connInfo
|
||||
NEW enableNtfs (ACM cMode) ->
|
||||
withNextSrv usedSrvs $ \srv -> do
|
||||
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
|
||||
notify connId $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo ->
|
||||
withNextSrv usedSrvs $ \srv ->
|
||||
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
|
||||
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo
|
||||
ACK msgId -> ackMessage' c connId msgId
|
||||
_ -> notify "" $ ERR (INTERNAL "")
|
||||
_ -> notify connId $ ERR $ INTERNAL $ "unsupported async command " <> show cmd
|
||||
case resp of
|
||||
Left _ ->
|
||||
-- TODO retry NEW and JOIN on different server
|
||||
-- TODO depending on command, some errors shouldn't be retried
|
||||
retryCommand loop
|
||||
Right () -> do
|
||||
delCmd cmdId
|
||||
Left e
|
||||
| temporaryAgentError e || e == BROKER HOST -> retryCommand loop
|
||||
| otherwise -> notify connId $ ERR e
|
||||
Right () -> withStore' c (`deleteCommand` cmdId)
|
||||
where
|
||||
delCmd :: AsyncCmdId -> m ()
|
||||
delCmd cmdId = withStore' c $ \db -> deleteCommand db cmdId
|
||||
withNextSrv :: TVar [SMPServer] -> (SMPServer -> m ()) -> m ()
|
||||
withNextSrv usedSrvs action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srv <- getNextSMPServer c used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then [] else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
action srv
|
||||
notify :: ConnId -> ACommand 'Agent -> m ()
|
||||
notify connId cmd = atomically $ writeTBQueue subQ ("", connId, cmd)
|
||||
retryCommand loop = do
|
||||
@@ -940,7 +960,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- and this branch should never be reached as receive is created before the confirmation,
|
||||
-- so the condition is not necessary here, strictly speaking.
|
||||
_ -> unless (duplexHandshake == Just True) $ do
|
||||
qInfo <- createReplyQueue c cData sq
|
||||
srv <- getSMPServer c
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
AM_QHELLO_ -> do
|
||||
@@ -1270,8 +1291,8 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do
|
||||
-- unsafeIOToSTM $ putStrLn $ "in timeout: suspendSendingAndDatabase"
|
||||
suspendSendingAndDatabase c
|
||||
|
||||
getAnySMPServer :: AgentMonad m => AgentClient -> m SMPServer
|
||||
getAnySMPServer c = readTVarIO (smpServers c) >>= pickServer
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
|
||||
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
|
||||
|
||||
pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer
|
||||
pickServer = \case
|
||||
@@ -1280,14 +1301,14 @@ pickServer = \case
|
||||
gen <- asks randomServer
|
||||
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> SMPServer -> m SMPServer
|
||||
getSMPServer c (SMPServer host port _) = do
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
|
||||
getNextSMPServer c usedSrvs = do
|
||||
srvs <- readTVarIO $ smpServers c
|
||||
case L.nonEmpty $ L.filter different srvs of
|
||||
case L.nonEmpty $ deleteFirstsBy different (L.toList srvs) usedSrvs of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pure $ L.head srvs
|
||||
_ -> pickServer srvs
|
||||
where
|
||||
different (SMPServer host' port' _) = host /= host' || port /= port'
|
||||
different (SMPServer host port _) (SMPServer host' port' _) = host /= host' || port /= port'
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
@@ -1666,8 +1687,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
|
||||
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq ownConnInfo Nothing
|
||||
|
||||
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do
|
||||
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
|
||||
aMessage <- mkAgentMessage agentVersion
|
||||
msg <- mkConfirmation aMessage
|
||||
sendConfirmation c sq msg
|
||||
@@ -1681,7 +1702,7 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e
|
||||
mkAgentMessage :: Version -> m AgentMessage
|
||||
mkAgentMessage 1 = pure $ AgentConnInfo connInfo
|
||||
mkAgentMessage _ = do
|
||||
qInfo <- createReplyQueue c cData sq
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
pure $ AgentConnInfoReply (qInfo :| []) connInfo
|
||||
|
||||
enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
|
||||
@@ -92,7 +92,6 @@ import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Data.Tuple (swap)
|
||||
import Data.Word (Word16)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -133,6 +132,8 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.TMap2 (TMap2)
|
||||
import qualified Simplex.Messaging.TMap2 as TM2
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
@@ -159,10 +160,9 @@ data AgentClient = AgentClient
|
||||
ntfServers :: TVar [NtfServer],
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubscrConns :: TMap ConnId SMPServer,
|
||||
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
connMsgsQueued :: TMap ConnId Bool,
|
||||
smpQueueMsgQueues :: TMap MsgDeliveryKey (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap MsgDeliveryKey (Async ()),
|
||||
@@ -215,10 +215,9 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
ntfServers <- newTVar ntf
|
||||
ntfClients <- TM.empty
|
||||
useNetworkConfig <- newTVar netCfg
|
||||
subscrSrvrs <- TM.empty
|
||||
pendingSubscrSrvrs <- TM.empty
|
||||
subscrConns <- newTVar S.empty
|
||||
activeSubscrConns <- TM.empty
|
||||
activeSubs <- TM2.empty
|
||||
pendingSubs <- TM2.empty
|
||||
connMsgsQueued <- TM.empty
|
||||
smpQueueMsgQueues <- TM.empty
|
||||
smpQueueMsgDeliveries <- TM.empty
|
||||
@@ -237,7 +236,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
|
||||
lock <- newTMVar ()
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, nextRcvQueueMsgs, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
|
||||
|
||||
agentDbPath :: AgentClient -> FilePath
|
||||
agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath
|
||||
@@ -276,26 +275,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union cs v
|
||||
_ -> TM.insert srv cVar ps
|
||||
TM2.insert1 srv cVar $ pendingSubs c
|
||||
readTVar cVar
|
||||
|
||||
serverDown :: Map ConnId RcvQueue -> IO ()
|
||||
serverDown cs = unless (M.null cs) $
|
||||
whenM (readTVarIO active) $ do
|
||||
let conns = M.keys cs
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
unless (null conns) . notifySub "" $ DOWN srv conns
|
||||
serverDown cs = whenM (readTVarIO active) $ do
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
let conns = M.keys cs
|
||||
unless (null conns) $ do
|
||||
notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unliftIO u reconnectServer
|
||||
|
||||
@@ -313,7 +304,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withAgentLock c $
|
||||
atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar)
|
||||
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
|
||||
>>= mapM_ resubscribe
|
||||
where
|
||||
resubscribe :: Map ConnId RcvQueue -> m ()
|
||||
@@ -419,10 +410,9 @@ closeAgentClient c = liftIO $ do
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
cancelActions $ smpQueueMsgDeliveries c
|
||||
clear subscrSrvrs
|
||||
clear pendingSubscrSrvrs
|
||||
atomically . TM2.clear $ activeSubs c
|
||||
atomically . TM2.clear $ pendingSubs c
|
||||
clear subscrConns
|
||||
clear activeSubscrConns
|
||||
clear connMsgsQueued
|
||||
clear smpQueueMsgQueues
|
||||
clear getMsgLocks
|
||||
@@ -534,17 +524,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
addPendingSubscription c rq connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
|
||||
>>= either throwError pure
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq@RcvQueue {server} connId r = do
|
||||
processSubResult c rq connId r = do
|
||||
case r of
|
||||
Left e ->
|
||||
atomically . unless (temporaryClientError e) $
|
||||
removePendingSubscription c server connId
|
||||
TM2.delete connId (pendingSubs c)
|
||||
_ -> addSubscription c rq connId
|
||||
pure r
|
||||
|
||||
@@ -564,9 +554,9 @@ temporaryAgentError = \case
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
|
||||
forM_ qs_ $ \q -> atomically $ do
|
||||
modifyTVar (subscrConns c) . S.insert $ fst q
|
||||
uncurry (addPendingSubscription c) $ swap q
|
||||
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
@@ -588,35 +578,18 @@ subscribeQueues c srv qs = do
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
TM.insert connId server $ activeSubscrConns c
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
addSubs_ (subscrSrvrs c) rq connId
|
||||
removePendingSubscription c server connId
|
||||
TM2.insert server connId rq $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
|
||||
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
|
||||
hasActiveSubscription c connId = TM.member connId (activeSubscrConns c)
|
||||
|
||||
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
|
||||
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
|
||||
|
||||
addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM ()
|
||||
addSubs_ ss rq@RcvQueue {server} connId =
|
||||
TM.lookup server ss >>= \case
|
||||
Just m -> TM.insert connId rq m
|
||||
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
|
||||
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
|
||||
|
||||
removeSubscription :: AgentClient -> ConnId -> STM ()
|
||||
removeSubscription c connId = do
|
||||
modifyTVar (subscrConns c) $ S.delete connId
|
||||
server_ <- TM.lookupDelete connId $ activeSubscrConns c
|
||||
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
|
||||
|
||||
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
|
||||
removePendingSubscription = removeSubs_ . pendingSubscrSrvrs
|
||||
|
||||
removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
|
||||
removeSubs_ ss server connId =
|
||||
TM.lookup server ss >>= mapM_ (TM.delete connId)
|
||||
TM2.delete connId $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
|
||||
getSubscriptions :: AgentClient -> STM (Set ConnId)
|
||||
getSubscriptions = readTVar . subscrConns
|
||||
|
||||
@@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
|
||||
( TMap,
|
||||
empty,
|
||||
singleton,
|
||||
clear,
|
||||
Simplex.Messaging.TMap.null,
|
||||
Simplex.Messaging.TMap.lookup,
|
||||
member,
|
||||
@@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a)
|
||||
singleton k v = newTVar $ M.singleton k v
|
||||
{-# INLINE singleton #-}
|
||||
|
||||
clear :: TMap k a -> STM ()
|
||||
clear m = writeTVar m M.empty
|
||||
{-# INLINE clear #-}
|
||||
|
||||
null :: TMap k a -> STM Bool
|
||||
null m = M.null <$> readTVar m
|
||||
{-# INLINE null #-}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.TMap2
|
||||
( TMap2,
|
||||
empty,
|
||||
clear,
|
||||
Simplex.Messaging.TMap2.lookup,
|
||||
lookup1,
|
||||
member,
|
||||
insert,
|
||||
insert1,
|
||||
delete,
|
||||
lookupDelete1,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (forM_, (>=>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (whenM, ($>>=))
|
||||
|
||||
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
|
||||
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
|
||||
data TMap2 k1 k2 a = TMap2
|
||||
{ _m1 :: TMap k1 (TMap k2 a),
|
||||
_m2 :: TMap k2 k1
|
||||
}
|
||||
|
||||
empty :: STM (TMap2 k1 k2 a)
|
||||
empty = TMap2 <$> TM.empty <*> TM.empty
|
||||
|
||||
clear :: TMap2 k1 k2 a -> STM ()
|
||||
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
|
||||
|
||||
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
|
||||
lookup k2 TMap2 {_m1, _m2} = do
|
||||
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
|
||||
|
||||
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
|
||||
{-# INLINE lookup1 #-}
|
||||
|
||||
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
|
||||
member k2 TMap2 {_m2} = TM.member k2 _m2
|
||||
{-# INLINE member #-}
|
||||
|
||||
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
|
||||
insert k1 k2 v TMap2 {_m1, _m2} =
|
||||
TM.lookup k2 _m2 >>= \case
|
||||
Just k1'
|
||||
| k1 == k1' -> _insert1
|
||||
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
|
||||
_ -> _insert2
|
||||
where
|
||||
_insert1 =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> TM.insert k2 v m
|
||||
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
|
||||
_insert2 = TM.insert k2 k1 _m2 >> _insert1
|
||||
|
||||
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
|
||||
insert1 k1 m' TMap2 {_m1, _m2} =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> readTVar m' >>= (`TM.union` m)
|
||||
_ -> TM.insert k1 m' _m1
|
||||
|
||||
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
|
||||
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
|
||||
|
||||
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
|
||||
_delete1 k1 k2 m1 =
|
||||
TM.lookup k1 m1
|
||||
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
|
||||
|
||||
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookupDelete1 k1 TMap2 {_m1, _m2} = do
|
||||
m_ <- TM.lookupDelete k1 _m1
|
||||
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
|
||||
pure m_
|
||||
Reference in New Issue
Block a user