mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-31 20:36:22 +00:00
Merge branch 'master' into f/reconnect-servers
This commit is contained in:
@@ -1126,10 +1126,14 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
lift $ waitForWork doWork
|
||||
atomically $ throwWhenInactive c
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
withWork c doWork (`getPendingServerCommand` server_) $ processCmd (riFast ri)
|
||||
withWork c doWork (`getPendingServerCommand` server_) $ runProcessCmd (riFast ri)
|
||||
where
|
||||
processCmd :: RetryInterval -> PendingCommand -> AM ()
|
||||
processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of
|
||||
runProcessCmd ri cmd = do
|
||||
pending <- newTVarIO []
|
||||
processCmd ri cmd pending
|
||||
mapM_ (atomically . writeTBQueue subQ) . reverse =<< readTVarIO pending
|
||||
processCmd :: RetryInterval -> PendingCommand -> TVar [ATransmission] -> AM ()
|
||||
processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} pendingCmds = case command of
|
||||
AClientCommand cmd -> case cmd of
|
||||
NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
@@ -1145,7 +1149,7 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK
|
||||
SWCH ->
|
||||
noServer . tryCommand . withConnLock c connId "switchConnection" $
|
||||
noServer . tryWithLock "switchConnection" $
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ conn@(DuplexConnection _ (replaced :| _rqs) _) ->
|
||||
switchDuplexConnection c conn replaced >>= notify . SWITCH QDRcv SPStarted
|
||||
@@ -1247,7 +1251,9 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do
|
||||
internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command)
|
||||
cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
|
||||
notify :: forall e. AEntityI e => AEvent e -> AM ()
|
||||
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, AEvt (sAEntity @e) cmd)
|
||||
notify cmd =
|
||||
let t = (corrId, connId, AEvt (sAEntity @e) cmd)
|
||||
in atomically $ ifM (isFullTBQueue subQ) (modifyTVar' pendingCmds (t :)) (writeTBQueue subQ t)
|
||||
-- ^ ^ ^ async command processing /
|
||||
|
||||
enqueueMessages :: AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> AM (AgentMsgId, PQEncryption)
|
||||
@@ -2159,7 +2165,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
forM_ ts $ \(entId, t) -> case t of
|
||||
STEvent msgOrErr ->
|
||||
withRcvConn entId $ \rq@RcvQueue {connId} conn -> case msgOrErr of
|
||||
Right msg -> processSMP rq conn (toConnData conn) msg
|
||||
Right msg -> runProcessSMP rq conn (toConnData conn) msg
|
||||
Left e -> lift $ notifyErr connId e
|
||||
STResponse (Cmd SRecipient cmd) respOrErr ->
|
||||
withRcvConn entId $ \rq conn -> case cmd of
|
||||
@@ -2167,11 +2173,11 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
Right SMP.OK -> processSubOk rq upConnIds
|
||||
Right msg@SMP.MSG {} -> do
|
||||
processSubOk rq upConnIds -- the connection is UP even when processing this particular message fails
|
||||
processSMP rq conn (toConnData conn) msg
|
||||
runProcessSMP rq conn (toConnData conn) msg
|
||||
Right r -> processSubErr rq $ unexpectedResponse r
|
||||
Left e -> unless (temporaryClientError e) $ processSubErr rq e -- timeout/network was already reported
|
||||
SMP.ACK _ -> case respOrErr of
|
||||
Right msg@SMP.MSG {} -> processSMP rq conn (toConnData conn) msg
|
||||
Right msg@SMP.MSG {} -> runProcessSMP rq conn (toConnData conn) msg
|
||||
_ -> pure () -- TODO process OK response to ACK
|
||||
_ -> pure () -- TODO process expired response to DEL
|
||||
STResponse {} -> pure () -- TODO process expired responses to sent messages
|
||||
@@ -2209,12 +2215,18 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
notify' connId msg = atomically $ writeTBQueue subQ ("", connId, AEvt (sAEntity @e) msg)
|
||||
notifyErr :: ConnId -> SMPClientError -> AM' ()
|
||||
notifyErr connId = notify' connId . ERR . protocolClientError SMP (B.unpack $ strEncode srv)
|
||||
processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> BrokerMsg -> AM ()
|
||||
runProcessSMP :: RcvQueue -> Connection c -> ConnData -> BrokerMsg -> AM ()
|
||||
runProcessSMP rq conn cData msg = do
|
||||
pending <- newTVarIO []
|
||||
processSMP rq conn cData msg pending
|
||||
mapM_ (atomically . writeTBQueue subQ) . reverse =<< readTVarIO pending
|
||||
processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> BrokerMsg -> TVar [ATransmission] -> AM ()
|
||||
processSMP
|
||||
rq@RcvQueue {rcvId = rId, sndSecure, e2ePrivKey, e2eDhSecret, status}
|
||||
conn
|
||||
cData@ConnData {connId, connAgentVersion, ratchetSyncState = rss}
|
||||
smpMsg =
|
||||
smpMsg
|
||||
pendingMsgs =
|
||||
withConnLock c connId "processSMP" $ case smpMsg of
|
||||
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> do
|
||||
atomically $ incSMPServerStat c userId srv recvMsgs
|
||||
@@ -2395,7 +2407,9 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
|
||||
r -> unexpected r
|
||||
where
|
||||
notify :: forall e m. (AEntityI e, MonadIO m) => AEvent e -> m ()
|
||||
notify = notify' connId
|
||||
notify msg =
|
||||
let t = ("", connId, AEvt (sAEntity @e) msg)
|
||||
in atomically $ ifM (isFullTBQueue subQ) (modifyTVar' pendingMsgs (t :)) (writeTBQueue subQ t)
|
||||
|
||||
prohibited :: Text -> AM ()
|
||||
prohibited s = do
|
||||
|
||||
@@ -20,17 +20,15 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (partitionEithers)
|
||||
import Data.List (partition)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (listToMaybe)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding
|
||||
import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
|
||||
import Data.Tuple (swap)
|
||||
@@ -55,8 +53,8 @@ type SMPClientVar = SessionVar (Either (SMPClientError, Maybe UTCTime) (OwnServe
|
||||
data SMPClientAgentEvent
|
||||
= CAConnected SMPServer
|
||||
| CADisconnected SMPServer (Set SMPSub)
|
||||
| CAResubscribed SMPServer (NonEmpty SMPSub)
|
||||
| CASubError SMPServer (NonEmpty (SMPSub, SMPClientError))
|
||||
| CASubscribed SMPServer SMPSubParty (NonEmpty QueueId)
|
||||
| CASubError SMPServer SMPSubParty (NonEmpty (QueueId, SMPClientError))
|
||||
|
||||
data SMPSubParty = SPRecipient | SPNotifier
|
||||
deriving (Eq, Ord, Show)
|
||||
@@ -86,9 +84,9 @@ defaultSMPClientAgentConfig =
|
||||
maxInterval = 10 * second
|
||||
},
|
||||
persistErrorInterval = 30, -- seconds
|
||||
msgQSize = 256,
|
||||
agentQSize = 256,
|
||||
agentSubsBatchSize = 900,
|
||||
msgQSize = 1024,
|
||||
agentQSize = 1024,
|
||||
agentSubsBatchSize = 1360,
|
||||
ownServerDomains = []
|
||||
}
|
||||
where
|
||||
@@ -192,7 +190,7 @@ getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, worke
|
||||
isOwnServer :: SMPClientAgent -> SMPServer -> OwnServer
|
||||
isOwnServer SMPClientAgent {agentCfg} ProtocolServer {host} =
|
||||
let srv = strEncode $ L.head host
|
||||
in any (\s -> s == srv || (B.cons '.' s) `B.isSuffixOf` srv) (ownServerDomains agentCfg)
|
||||
in any (\s -> s == srv || B.cons '.' s `B.isSuffixOf` srv) (ownServerDomains agentCfg)
|
||||
|
||||
-- | Run an SMP client for SMPClientVar
|
||||
connectClient :: SMPClientAgent -> SMPServer -> SMPClientVar -> IO (Either SMPClientError SMPClient)
|
||||
@@ -212,15 +210,9 @@ connectClient ca@SMPClientAgent {agentCfg, smpClients, smpSessions, msgQ, random
|
||||
where
|
||||
updateSubs sVar = do
|
||||
ss <- readTVar sVar
|
||||
addPendingSubs sVar ss
|
||||
addSubs_ (pendingSrvSubs ca) srv ss
|
||||
pure ss
|
||||
|
||||
addPendingSubs sVar ss = do
|
||||
let ps = pendingSrvSubs ca
|
||||
TM.lookup srv ps >>= \case
|
||||
Just ss' -> TM.union ss ss'
|
||||
_ -> TM.insert srv sVar ps
|
||||
|
||||
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
|
||||
serverDown ss = unless (M.null ss) $ do
|
||||
notify ca . CADisconnected srv $ M.keysSet ss
|
||||
@@ -244,11 +236,11 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
|
||||
runSubWorker =
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \_ loop -> do
|
||||
pending <- atomically getPending
|
||||
forM_ pending $ \cs -> whenM (readTVarIO active) $ do
|
||||
void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv cs)
|
||||
unless (null pending) $ whenM (readTVarIO active) $ do
|
||||
void $ tcpConnectTimeout `timeout` runExceptT (reconnectSMPClient ca srv pending)
|
||||
loop
|
||||
ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
|
||||
getPending = mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
||||
getPending = maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
||||
cleanup :: SessionVar (Async ()) -> STM ()
|
||||
cleanup v = do
|
||||
-- Here we wait until TMVar is not empty to prevent worker cleanup happening before worker is added to TMVar.
|
||||
@@ -258,32 +250,22 @@ reconnectClient ca@SMPClientAgent {active, agentCfg, smpSubWorkers, workerSeq} s
|
||||
|
||||
reconnectSMPClient :: SMPClientAgent -> SMPServer -> Map SMPSub C.APrivateAuthKey -> ExceptT SMPClientError IO ()
|
||||
reconnectSMPClient ca@SMPClientAgent {agentCfg} srv cs =
|
||||
withSMP ca srv $ \smp -> do
|
||||
subs' <- filterM (fmap not . atomically . hasSub (srvSubs ca) srv . fst) $ M.assocs cs
|
||||
let (nSubs, rSubs) = partition (isNotifier . fst . fst) subs'
|
||||
withSMP ca srv $ \smp -> liftIO $ do
|
||||
currSubs <- atomically $ maybe (pure M.empty) readTVar =<< TM.lookup srv (srvSubs ca)
|
||||
let (nSubs, rSubs) = foldr (groupSub currSubs) ([], []) $ M.assocs cs
|
||||
subscribe_ smp SPNotifier nSubs
|
||||
subscribe_ smp SPRecipient rSubs
|
||||
where
|
||||
isNotifier = \case
|
||||
SPNotifier -> True
|
||||
SPRecipient -> False
|
||||
subscribe_ :: SMPClient -> SMPSubParty -> [(SMPSub, C.APrivateAuthKey)] -> ExceptT SMPClientError IO ()
|
||||
subscribe_ smp party = mapM_ subscribeBatch . toChunks (agentSubsBatchSize agentCfg)
|
||||
groupSub :: Map SMPSub C.APrivateAuthKey -> (SMPSub, C.APrivateAuthKey) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)]) -> ([(QueueId, C.APrivateAuthKey)], [(QueueId, C.APrivateAuthKey)])
|
||||
groupSub currSubs (s@(party, qId), k) (nSubs, rSubs)
|
||||
| M.member s currSubs = (nSubs, rSubs)
|
||||
| otherwise = case party of
|
||||
SPNotifier -> (s' : nSubs, rSubs)
|
||||
SPRecipient -> (nSubs, s' : rSubs)
|
||||
where
|
||||
subscribeBatch subs' = do
|
||||
let subs'' :: (NonEmpty (QueueId, C.APrivateAuthKey)) = L.map (first snd) subs'
|
||||
rs <- liftIO $ smpSubscribeQueues party ca smp srv subs''
|
||||
let rs' :: (NonEmpty ((SMPSub, C.APrivateAuthKey), Either SMPClientError ())) =
|
||||
L.zipWith (first . const) subs' rs
|
||||
rs'' :: [Either (SMPSub, SMPClientError) (SMPSub, C.APrivateAuthKey)] =
|
||||
map (\(sub, r) -> bimap (fst sub,) (const sub) r) $ L.toList rs'
|
||||
(errs, oks) = partitionEithers rs''
|
||||
(tempErrs, finalErrs) = partition (temporaryClientError . snd) errs
|
||||
mapM_ (atomically . addSubscription ca srv) oks
|
||||
mapM_ (notify ca . CAResubscribed srv) $ L.nonEmpty $ map fst oks
|
||||
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
|
||||
mapM_ (notify ca . CASubError srv) $ L.nonEmpty finalErrs
|
||||
mapM_ (throwE . snd) $ listToMaybe tempErrs
|
||||
s' = (qId, k)
|
||||
subscribe_ :: SMPClient -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> IO ()
|
||||
subscribe_ smp party = mapM_ (smpSubscribeQueues party ca smp srv) . toChunks (agentSubsBatchSize agentCfg)
|
||||
|
||||
notify :: MonadIO m => SMPClientAgent -> SMPClientAgentEvent -> m ()
|
||||
notify ca evt = atomically $ writeTBQueue (agentQ ca) evt
|
||||
@@ -297,7 +279,8 @@ getConnectedSMPServerClient SMPClientAgent {smpClients} srv =
|
||||
$>>= \case
|
||||
(_, Right r) -> pure $ Just $ Right r
|
||||
(v, Left (e, ts_)) ->
|
||||
pure ts_ $>>= \ts -> -- proxy will create a new connection if ts_ is Nothing
|
||||
pure ts_ $>>= \ts ->
|
||||
-- proxy will create a new connection if ts_ is Nothing
|
||||
ifM
|
||||
((ts <) <$> liftIO getCurrentTime) -- error persistence interval period expired?
|
||||
(Nothing <$ atomically (removeSessVar v srv smpClients)) -- proxy will create a new connection
|
||||
@@ -334,86 +317,99 @@ withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPE
|
||||
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
|
||||
throwE e
|
||||
|
||||
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
|
||||
subscribeQueue ca srv sub = do
|
||||
atomically $ addPendingSubscription ca srv sub
|
||||
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleErr
|
||||
where
|
||||
subscribe_ smp = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
|
||||
handleErr e = do
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv (fst sub)
|
||||
throwE e
|
||||
|
||||
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ()))
|
||||
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO ()
|
||||
subscribeQueuesSMP = subscribeQueues_ SPRecipient
|
||||
|
||||
subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (NotifierId, Either SMPClientError ()))
|
||||
subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO ()
|
||||
subscribeQueuesNtfs = subscribeQueues_ SPNotifier
|
||||
|
||||
subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
|
||||
subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
subscribeQueues_ party ca srv subs = do
|
||||
atomically $ forM_ subs $ addPendingSubscription ca srv . first (party,)
|
||||
atomically $ addPendingSubs ca srv party $ L.toList subs
|
||||
runExceptT (getSMPServerClient' ca srv) >>= \case
|
||||
Left e -> pure $ L.map ((,Left e) . fst) subs
|
||||
Right smp -> smpSubscribeQueues party ca smp srv subs
|
||||
Left _ -> pure () -- no call to reconnectClient - failing getSMPServerClient' does that
|
||||
|
||||
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
|
||||
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO ()
|
||||
smpSubscribeQueues party ca smp srv subs = do
|
||||
rs <- L.zip subs <$> subscribe smp (L.map swap subs)
|
||||
atomically $ forM rs $ \(sub, r) ->
|
||||
(fst sub,) <$> case r of
|
||||
Right () -> do
|
||||
addSubscription ca srv $ first (party,) sub
|
||||
pure $ Right ()
|
||||
Left e -> do
|
||||
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv (party, fst sub)
|
||||
pure $ Left e
|
||||
rs <- subscribe smp $ L.map swap subs
|
||||
rs' <-
|
||||
atomically $
|
||||
ifM
|
||||
(activeClientSession ca smp srv)
|
||||
(Just <$> processSubscriptions rs)
|
||||
(pure Nothing)
|
||||
case rs' of
|
||||
Just (tempErrs, finalErrs, oks, _) -> do
|
||||
notify_ CASubscribed $ map fst oks
|
||||
notify_ CASubError finalErrs
|
||||
when tempErrs $ reconnectClient ca srv
|
||||
Nothing -> reconnectClient ca srv
|
||||
where
|
||||
processSubscriptions :: NonEmpty (Either SMPClientError ()) -> STM (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId])
|
||||
processSubscriptions rs = do
|
||||
pending <- maybe (pure M.empty) readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
||||
let acc@(_, _, oks, notPending) = foldr (groupSub pending) (False, [], [], []) (L.zip subs rs)
|
||||
unless (null oks) $ addSubscriptions ca srv party oks
|
||||
unless (null notPending) $ removePendingSubs ca srv party notPending
|
||||
pure acc
|
||||
groupSub :: Map SMPSub C.APrivateAuthKey -> ((QueueId, C.APrivateAuthKey), Either SMPClientError ()) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId]) -> (Bool, [(QueueId, SMPClientError)], [(QueueId, C.APrivateAuthKey)], [QueueId])
|
||||
groupSub pending (s@(qId, _), r) acc@(!tempErrs, finalErrs, oks, notPending) = case r of
|
||||
Right ()
|
||||
| M.member (party, qId) pending -> (tempErrs, finalErrs, s : oks, qId : notPending)
|
||||
| otherwise -> acc
|
||||
Left e
|
||||
| temporaryClientError e -> (True, finalErrs, oks, notPending)
|
||||
| otherwise -> (tempErrs, (qId, e) : finalErrs, oks, qId : notPending)
|
||||
subscribe = case party of
|
||||
SPRecipient -> subscribeSMPQueues
|
||||
SPNotifier -> subscribeSMPQueuesNtfs
|
||||
notify_ :: (SMPServer -> SMPSubParty -> NonEmpty a -> SMPClientAgentEvent) -> [a] -> IO ()
|
||||
notify_ evt qs = mapM_ (notify ca . evt srv party) $ L.nonEmpty qs
|
||||
|
||||
activeClientSession :: SMPClientAgent -> SMPClient -> SMPServer -> STM Bool
|
||||
activeClientSession ca smp srv = sameSess <$> tryReadSessVar srv (smpClients ca)
|
||||
where
|
||||
sessId = sessionId . thParams
|
||||
sameSess = \case
|
||||
Just (Right (_, smp')) -> sessId smp == sessId smp'
|
||||
_ -> False
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
strEncode host <> B.pack (if null port then "" else ':' : port)
|
||||
|
||||
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
|
||||
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
|
||||
addSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM ()
|
||||
addSubscriptions = addSubsList_ . srvSubs
|
||||
{-# INLINE addSubscriptions #-}
|
||||
|
||||
addPendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM ()
|
||||
addPendingSubs = addSubsList_ . pendingSrvSubs
|
||||
{-# INLINE addPendingSubs #-}
|
||||
|
||||
addSubsList_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [(QueueId, C.APrivateAuthKey)] -> STM ()
|
||||
addSubsList_ subs srv party ss = addSubs_ subs srv ss'
|
||||
where
|
||||
subscribe_ = case party of
|
||||
SPRecipient -> subscribeSMPQueue
|
||||
SPNotifier -> subscribeSMPQueueNotifications
|
||||
ss' = M.fromList $ map (first (party,)) ss
|
||||
|
||||
addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
||||
addSubscription ca srv sub = do
|
||||
addSub_ (srvSubs ca) srv sub
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
|
||||
addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
||||
addPendingSubscription = addSub_ . pendingSrvSubs
|
||||
|
||||
addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
||||
addSub_ subs srv (s, key) =
|
||||
addSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> Map SMPSub C.APrivateAuthKey -> STM ()
|
||||
addSubs_ subs srv ss =
|
||||
TM.lookup srv subs >>= \case
|
||||
Just m -> TM.insert s key m
|
||||
_ -> TM.singleton s key >>= \v -> TM.insert srv v subs
|
||||
Just m -> TM.union ss m
|
||||
_ -> newTVar ss >>= \v -> TM.insert srv v subs
|
||||
|
||||
removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removeSubscription = removeSub_ . srvSubs
|
||||
|
||||
removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removePendingSubscription = removeSub_ . pendingSrvSubs
|
||||
{-# INLINE removeSubscription #-}
|
||||
|
||||
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM ()
|
||||
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
|
||||
|
||||
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateAuthKey)
|
||||
getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s
|
||||
removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM ()
|
||||
removePendingSubs = removeSubs_ . pendingSrvSubs
|
||||
{-# INLINE removePendingSubs #-}
|
||||
|
||||
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM Bool
|
||||
hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs
|
||||
removeSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [QueueId] -> STM ()
|
||||
removeSubs_ subs srv party qs = TM.lookup srv subs >>= mapM_ (`modifyTVar'` (`M.withoutKeys` ss))
|
||||
where
|
||||
ss = S.fromList $ map (party,) qs
|
||||
|
||||
@@ -188,33 +188,16 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
runSMPSubscriber :: SMPSubscriber -> M ()
|
||||
runSMPSubscriber SMPSubscriber {newSubQ = subscriberSubQ} =
|
||||
forever $ do
|
||||
subs <- atomically (peekTQueue subscriberSubQ)
|
||||
subs <- atomically $ readTQueue subscriberSubQ
|
||||
let subs' = L.map (\(NtfSub sub) -> sub) subs
|
||||
srv = server $ L.head subs
|
||||
logSubStatus srv "subscribing" $ length subs
|
||||
mapM_ (\NtfSubData {smpQueue} -> updateSubStatus smpQueue NSPending) subs'
|
||||
rs <- liftIO $ subscribeQueues srv subs'
|
||||
(subs'', oks, errs) <- foldM process ([], 0, []) rs
|
||||
atomically $ do
|
||||
void $ readTQueue subscriberSubQ
|
||||
mapM_ (writeTQueue subscriberSubQ . L.map NtfSub) $ L.nonEmpty subs''
|
||||
logSubStatus srv "retrying" $ length subs''
|
||||
logSubStatus srv "subscribed" oks
|
||||
logSubErrors srv errs
|
||||
where
|
||||
process :: ([NtfSubData], Int, [NtfSubStatus]) -> (NtfSubData, Either SMPClientError ()) -> M ([NtfSubData], Int, [NtfSubStatus])
|
||||
process (subs, oks, errs) (sub@NtfSubData {smpQueue}, r) = case r of
|
||||
Right _ -> updateSubStatus smpQueue NSActive $> (subs, oks + 1, errs)
|
||||
Left e -> update <$> handleSubError smpQueue e
|
||||
where
|
||||
update = \case
|
||||
Just err -> (subs, oks, err : errs) -- permanent error, log and don't retry subscription
|
||||
Nothing -> (sub : subs, oks, errs) -- temporary error, retry subscription
|
||||
liftIO $ subscribeQueues srv subs'
|
||||
|
||||
-- \| Subscribe to queues. The list of results can have a different order.
|
||||
subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO (NonEmpty (NtfSubData, Either SMPClientError ()))
|
||||
subscribeQueues srv subs =
|
||||
L.zipWith (\s r -> (s, snd r)) subs <$> subscribeQueuesNtfs ca srv (L.map sub subs)
|
||||
subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO ()
|
||||
subscribeQueues srv subs = subscribeQueuesNtfs ca srv (L.map sub subs)
|
||||
where
|
||||
sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey)
|
||||
|
||||
@@ -239,7 +222,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
incNtfStat ntfReceived
|
||||
Right SMP.END -> updateSubStatus smpQueue NSEnd
|
||||
Right (SMP.ERR e) -> logError $ "SMP server error: " <> tshow e
|
||||
Right _ -> logError $ "SMP server unexpected response"
|
||||
Right _ -> logError "SMP server unexpected response"
|
||||
Left e -> logError $ "SMP client error: " <> tshow e
|
||||
|
||||
receiveAgent =
|
||||
@@ -252,11 +235,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
|
||||
forM_ subs $ \(_, ntfId) -> do
|
||||
let smpQueue = SMPQueueNtf srv ntfId
|
||||
updateSubStatus smpQueue NSInactive
|
||||
CAResubscribed srv subs -> do
|
||||
forM_ subs $ \(_, ntfId) -> updateSubStatus (SMPQueueNtf srv ntfId) NSActive
|
||||
logSubStatus srv "resubscribed" $ length subs
|
||||
CASubError srv errs ->
|
||||
forM errs (\((_, ntfId), err) -> handleSubError (SMPQueueNtf srv ntfId) err)
|
||||
CASubscribed srv _ subs -> do
|
||||
forM_ subs $ \ntfId -> updateSubStatus (SMPQueueNtf srv ntfId) NSActive
|
||||
logSubStatus srv "subscribed" $ length subs
|
||||
CASubError srv _ errs ->
|
||||
forM errs (\(ntfId, err) -> handleSubError (SMPQueueNtf srv ntfId) err)
|
||||
>>= logSubErrors srv . catMaybes . L.toList
|
||||
|
||||
logSubStatus srv event n =
|
||||
|
||||
@@ -193,8 +193,8 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
CAConnected srv -> logInfo $ "SMP server connected " <> showServer' srv
|
||||
CADisconnected srv [] -> logInfo $ "SMP server disconnected " <> showServer' srv
|
||||
CADisconnected srv subs -> logError $ "SMP server disconnected " <> showServer' srv <> " / subscriptions: " <> tshow (length subs)
|
||||
CAResubscribed srv subs -> logError $ "SMP server resubscribed " <> showServer' srv <> " / subscriptions: " <> tshow (length subs)
|
||||
CASubError srv errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs)
|
||||
CASubscribed srv _ subs -> logError $ "SMP server subscribed " <> showServer' srv <> " / subscriptions: " <> tshow (length subs)
|
||||
CASubError srv _ errs -> logError $ "SMP server subscription errors " <> showServer' srv <> " / errors: " <> tshow (length errs)
|
||||
where
|
||||
showServer' = decodeLatin1 . strEncode . host
|
||||
|
||||
@@ -514,11 +514,11 @@ clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endT
|
||||
sameClientId :: Client -> Client -> Bool
|
||||
sameClientId Client {clientId} Client {clientId = cId'} = clientId == cId'
|
||||
|
||||
cancelSub :: TVar Sub -> IO ()
|
||||
cancelSub sub =
|
||||
readTVarIO sub >>= \case
|
||||
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
|
||||
_ -> return ()
|
||||
cancelSub :: Sub -> IO ()
|
||||
cancelSub s =
|
||||
readTVarIO (subThread s) >>= \case
|
||||
SubThread t -> liftIO $ deRefWeak t >>= mapM_ killThread
|
||||
_ -> pure ()
|
||||
|
||||
receive :: Transport c => THandleSMP c 'TServer -> Client -> M ()
|
||||
receive h@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
|
||||
@@ -683,7 +683,7 @@ forkClient Client {endThreads, endThreadSeq} label action = do
|
||||
mkWeakThreadId t >>= atomically . modifyTVar' endThreads . IM.insert tId
|
||||
|
||||
client :: THandleParams SMPVersion 'TServer -> Client -> Server -> M ()
|
||||
client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, notifiers} = do
|
||||
client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, procThreads} Server {subscribedQ, ntfSubscribedQ, subscribers, notifiers} = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " commands"
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
@@ -901,36 +901,36 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
Nothing -> do
|
||||
atomically $ modifyTVar' (qSub stats) (+ 1)
|
||||
newSub >>= deliver
|
||||
Just sub ->
|
||||
readTVarIO sub >>= \case
|
||||
Sub {subThread = ProhibitSub} -> do
|
||||
Just s@Sub {subThread} ->
|
||||
readTVarIO subThread >>= \case
|
||||
ProhibitSub -> do
|
||||
-- cannot use SUB in the same connection where GET was used
|
||||
atomically $ modifyTVar' (qSubProhibited stats) (+ 1)
|
||||
pure (corrId, rId, ERR $ CMD PROHIBITED)
|
||||
s -> do
|
||||
_ -> do
|
||||
atomically $ modifyTVar' (qSubDuplicate stats) (+ 1)
|
||||
atomically (tryTakeTMVar $ delivered s) >> deliver sub
|
||||
atomically (tryTakeTMVar $ delivered s) >> deliver s
|
||||
where
|
||||
newSub :: M (TVar Sub)
|
||||
newSub :: M Sub
|
||||
newSub = time "SUB newSub" . atomically $ do
|
||||
writeTQueue subscribedQ (rId, clnt)
|
||||
sub <- newTVar =<< newSubscription NoSub
|
||||
sub <- newSubscription NoSub
|
||||
TM.insert rId sub subscriptions
|
||||
pure sub
|
||||
deliver :: TVar Sub -> M (Transmission BrokerMsg)
|
||||
deliver :: Sub -> M (Transmission BrokerMsg)
|
||||
deliver sub = do
|
||||
q <- getStoreMsgQueue "SUB" rId
|
||||
msg_ <- atomically $ tryPeekMsg q
|
||||
deliverMessage "SUB" qr rId sub q msg_
|
||||
deliverMessage "SUB" qr rId sub msg_
|
||||
|
||||
getMessage :: QueueRec -> M (Transmission BrokerMsg)
|
||||
getMessage qr = time "GET" $ do
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
atomically newSub >>= getMessage_
|
||||
Just sub ->
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = ProhibitSub} ->
|
||||
Just s@Sub {subThread} ->
|
||||
readTVarIO subThread >>= \case
|
||||
ProhibitSub ->
|
||||
atomically (tryTakeTMVar $ delivered s)
|
||||
>> getMessage_ s
|
||||
-- cannot use GET in the same connection where there is an active subscription
|
||||
@@ -939,8 +939,7 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
newSub :: STM Sub
|
||||
newSub = do
|
||||
s <- newSubscription ProhibitSub
|
||||
sub <- newTVar s
|
||||
TM.insert queueId sub subscriptions
|
||||
TM.insert queueId s subscriptions
|
||||
pure s
|
||||
getMessage_ :: Sub -> M (Transmission BrokerMsg)
|
||||
getMessage_ s = do
|
||||
@@ -968,25 +967,24 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
Nothing -> pure $ err NO_MSG
|
||||
Just sub ->
|
||||
atomically (getDelivered sub) >>= \case
|
||||
Just s -> do
|
||||
Just st -> do
|
||||
q <- getStoreMsgQueue "ACK" queueId
|
||||
case s of
|
||||
Sub {subThread = ProhibitSub} -> do
|
||||
case st of
|
||||
ProhibitSub -> do
|
||||
deletedMsg_ <- atomically $ tryDelMsg q msgId
|
||||
mapM_ updateStats deletedMsg_
|
||||
pure ok
|
||||
_ -> do
|
||||
(deletedMsg_, msg_) <- atomically $ tryDelPeekMsg q msgId
|
||||
mapM_ updateStats deletedMsg_
|
||||
deliverMessage "ACK" qr queueId sub q msg_
|
||||
deliverMessage "ACK" qr queueId sub msg_
|
||||
_ -> pure $ err NO_MSG
|
||||
where
|
||||
getDelivered :: TVar Sub -> STM (Maybe Sub)
|
||||
getDelivered sub = do
|
||||
s@Sub {delivered} <- readTVar sub
|
||||
getDelivered :: Sub -> STM (Maybe SubscriptionThread)
|
||||
getDelivered Sub {delivered, subThread} = do
|
||||
tryTakeTMVar delivered $>>= \msgId' ->
|
||||
if msgId == msgId' || B.null msgId
|
||||
then pure $ Just s
|
||||
then Just <$> readTVar subThread
|
||||
else putTMVar delivered msgId' $> Nothing
|
||||
updateStats :: Message -> M ()
|
||||
updateStats = \case
|
||||
@@ -1024,7 +1022,8 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
Nothing -> do
|
||||
atomically $ modifyTVar' (msgSentQuota stats) (+ 1)
|
||||
pure $ err QUOTA
|
||||
Just msg -> time "SEND ok" $ do
|
||||
Just (msg, wasEmpty) -> time "SEND ok" $ do
|
||||
when wasEmpty $ tryDeliverMessage msg
|
||||
when (notification msgFlags) $ do
|
||||
forM_ (notifier qr) $ \ntf -> do
|
||||
asks random >>= atomically . trySendNotification ntf msg >>= \case
|
||||
@@ -1058,6 +1057,52 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar' (msgExpired stats) (+ deleted)
|
||||
|
||||
-- The condition for delivery of the message is:
|
||||
-- - the queue was empty when the message was sent,
|
||||
-- - there is subscribed recipient,
|
||||
-- - no message was "delivered" that was not acknowledged.
|
||||
-- If the send queue of the subscribed client is not full the message is put there in the same transaction.
|
||||
-- If the queue is not full, then the thread is created where these checks are made:
|
||||
-- - it is the same subscribed client (in case it was reconnected it would receive message via SUB command)
|
||||
-- - nothing was delivered to this subscription (to avoid race conditions with the recipient).
|
||||
tryDeliverMessage :: Message -> M ()
|
||||
tryDeliverMessage msg = atomically deliverToSub >>= mapM_ forkDeliver
|
||||
where
|
||||
rId = recipientId qr
|
||||
deliverToSub =
|
||||
TM.lookup rId subscribers
|
||||
$>>= \rc@Client {subscriptions = subs, sndQ = q} -> TM.lookup rId subs
|
||||
$>>= \s@Sub {subThread, delivered} -> readTVar subThread >>= \case
|
||||
NoSub ->
|
||||
tryTakeTMVar delivered >>= \case
|
||||
Just _ -> pure Nothing -- if a message was already delivered, should not deliver more
|
||||
Nothing ->
|
||||
ifM
|
||||
(isFullTBQueue q)
|
||||
(writeTVar subThread SubPending $> Just (rc, s))
|
||||
(deliver q s $> Nothing)
|
||||
_ -> pure Nothing
|
||||
deliver q s = do
|
||||
let encMsg = encryptMsg qr msg
|
||||
writeTBQueue q [(CorrId "", rId, MSG encMsg)]
|
||||
void $ setDelivered s msg
|
||||
forkDeliver (rc@Client {sndQ = q}, s@Sub {subThread, delivered}) = do
|
||||
t <- mkWeakThreadId =<< forkIO deliverThread
|
||||
atomically . modifyTVar' subThread $ \case
|
||||
-- this case is needed because deliverThread can exit before it
|
||||
SubPending -> SubThread t
|
||||
st -> st
|
||||
where
|
||||
deliverThread = do
|
||||
labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " deliver/SEND"
|
||||
time "deliver" . atomically $
|
||||
whenM (maybe False (sameClientId rc) <$> TM.lookup rId subscribers) $ do
|
||||
tryTakeTMVar delivered >>= \case
|
||||
Just _ -> pure () -- if a message was already delivered, should not deliver more
|
||||
Nothing -> do
|
||||
deliver q s
|
||||
writeTVar subThread NoSub
|
||||
|
||||
trySendNotification :: NtfCreds -> Message -> TVar ChaChaDRG -> STM (Maybe Bool)
|
||||
trySendNotification NtfCreds {notifierId, rcvNtfDhSecret} msg ntfNonceDrg =
|
||||
mapM (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
|
||||
@@ -1132,35 +1177,17 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
verified = \case
|
||||
VRVerified qr -> Right (qr, (corrId', entId', cmd'))
|
||||
VRFailed -> Left (corrId', entId', ERR AUTH)
|
||||
deliverMessage :: T.Text -> QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> M (Transmission BrokerMsg)
|
||||
deliverMessage name qr rId sub q msg_ = time (name <> " deliver") $ do
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = NoSub} ->
|
||||
case msg_ of
|
||||
Just msg ->
|
||||
let encMsg = encryptMsg qr msg
|
||||
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
|
||||
_ -> forkSub $> resp
|
||||
_ -> pure resp
|
||||
deliverMessage :: T.Text -> QueueRec -> RecipientId -> Sub -> Maybe Message -> M (Transmission BrokerMsg)
|
||||
deliverMessage name qr rId s@Sub {subThread} msg_ = time (name <> " deliver") . atomically $
|
||||
readTVar subThread >>= \case
|
||||
ProhibitSub -> pure resp
|
||||
_ -> case msg_ of
|
||||
Just msg ->
|
||||
let encMsg = encryptMsg qr msg
|
||||
in setDelivered s msg $> (corrId, rId, MSG encMsg)
|
||||
_ -> pure resp
|
||||
where
|
||||
resp = (corrId, rId, OK)
|
||||
forkSub :: M ()
|
||||
forkSub = do
|
||||
atomically . modifyTVar' sub $ \s -> s {subThread = SubPending}
|
||||
t <- mkWeakThreadId =<< forkIO subscriber
|
||||
atomically . modifyTVar' sub $ \case
|
||||
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
|
||||
s -> s
|
||||
where
|
||||
subscriber = do
|
||||
labelMyThread $ B.unpack ("client $" <> encode sessionId) <> " subscriber/" <> T.unpack name
|
||||
msg <- atomically $ peekMsg q
|
||||
time "subscriber" . atomically $ do
|
||||
let encMsg = encryptMsg qr msg
|
||||
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
|
||||
s <- readTVar sub
|
||||
void $ setDelivered s msg
|
||||
writeTVar sub $! s {subThread = NoSub}
|
||||
|
||||
time :: T.Text -> M a -> M a
|
||||
time name = timed name queueId
|
||||
@@ -1202,9 +1229,9 @@ client thParams' clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessi
|
||||
pure QueueInfo {qiSnd = isJust senderKey, qiNtf = isJust notifier, qiSub, qiSize, qiMsg}
|
||||
pure (corrId, queueId, INFO info)
|
||||
where
|
||||
mkQSub sub = do
|
||||
Sub {subThread, delivered} <- readTVar sub
|
||||
let qSubThread = case subThread of
|
||||
mkQSub Sub {subThread, delivered} = do
|
||||
st <- readTVar subThread
|
||||
let qSubThread = case st of
|
||||
NoSub -> QNoSub
|
||||
SubPending -> QSubPending
|
||||
SubThread _ -> QSubThread
|
||||
|
||||
@@ -47,7 +47,6 @@ data ServerConfig = ServerConfig
|
||||
{ transports :: [(ServiceName, ATransport)],
|
||||
smpHandshakeTimeout :: Int,
|
||||
tbqSize :: Natural,
|
||||
-- serverTbqSize :: Natural,
|
||||
msgQueueQuota :: Int,
|
||||
queueIdBytes :: Int,
|
||||
msgIdBytes :: Int,
|
||||
@@ -145,7 +144,7 @@ type ClientId = Int
|
||||
|
||||
data Client = Client
|
||||
{ clientId :: ClientId,
|
||||
subscriptions :: TMap RecipientId (TVar Sub),
|
||||
subscriptions :: TMap RecipientId Sub,
|
||||
ntfSubscriptions :: TMap NotifierId (),
|
||||
rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)),
|
||||
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
|
||||
@@ -164,7 +163,7 @@ data Client = Client
|
||||
data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) | ProhibitSub
|
||||
|
||||
data Sub = Sub
|
||||
{ subThread :: SubscriptionThread,
|
||||
{ subThread :: TVar SubscriptionThread,
|
||||
delivered :: TMVar MsgId
|
||||
}
|
||||
|
||||
@@ -194,8 +193,9 @@ newClient nextClientId qSize thVersion sessionId createdAt = do
|
||||
return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, msgQ, procThreads, endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt}
|
||||
|
||||
newSubscription :: SubscriptionThread -> STM Sub
|
||||
newSubscription subThread = do
|
||||
newSubscription st = do
|
||||
delivered <- newEmptyTMVar
|
||||
subThread <- newTVar st
|
||||
return Sub {subThread, delivered}
|
||||
|
||||
newEnv :: ServerConfig -> IO Env
|
||||
|
||||
@@ -255,7 +255,6 @@ smpServerCLI_ generateSite serveStaticFiles cfgPath logPath =
|
||||
{ transports = iniTransports ini,
|
||||
smpHandshakeTimeout = 120000000,
|
||||
tbqSize = 64,
|
||||
-- serverTbqSize = 1024,
|
||||
msgQueueQuota = 128,
|
||||
queueIdBytes = 24,
|
||||
msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20
|
||||
|
||||
@@ -75,7 +75,7 @@ snapshotMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (snapshotTQueue .
|
||||
mapM_ (writeTQueue q) msgs
|
||||
pure msgs
|
||||
|
||||
writeMsg :: MsgQueue -> Message -> STM (Maybe Message)
|
||||
writeMsg :: MsgQueue -> Message -> STM (Maybe (Message, Bool))
|
||||
writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do
|
||||
canWrt <- readTVar canWrite
|
||||
empty <- isEmptyTQueue q
|
||||
@@ -85,7 +85,7 @@ writeMsg MsgQueue {msgQueue = q, quota, canWrite, size} !msg = do
|
||||
writeTVar canWrite $! canWrt'
|
||||
modifyTVar' size (+ 1)
|
||||
if canWrt'
|
||||
then writeTQueue q msg $> Just msg
|
||||
then writeTQueue q msg $> Just (msg, empty)
|
||||
else (writeTQueue q $! msgQuota) $> Nothing
|
||||
else pure Nothing
|
||||
where
|
||||
|
||||
@@ -263,6 +263,11 @@ functionalAPITests t = do
|
||||
withSmpServer t testAgentClient3
|
||||
it "should establish connection without PQ encryption and enable it" $
|
||||
withSmpServer t testEnablePQEncryption
|
||||
describe "Duplex connection - delivery stress test" $ do
|
||||
describe "one way (50)" $ testMatrix2Stress t $ runAgentClientStressTestOneWay 50
|
||||
xdescribe "one way (1000)" $ testMatrix2Stress t $ runAgentClientStressTestOneWay 1000
|
||||
describe "two way concurrently (50)" $ testMatrix2Stress t $ runAgentClientStressTestConc 25
|
||||
xdescribe "two way concurrently (1000)" $ testMatrix2Stress t $ runAgentClientStressTestConc 500
|
||||
describe "Establishing duplex connection, different PQ settings" $ do
|
||||
testPQMatrix2 t $ runAgentClientTestPQ True
|
||||
describe "Establishing duplex connection v2, different Ratchet versions" $
|
||||
@@ -482,6 +487,19 @@ testMatrix2 t runTest = do
|
||||
it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff False
|
||||
it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff False
|
||||
|
||||
testMatrix2Stress :: HasCallStack => ATransport -> (PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testMatrix2Stress t runTest = do
|
||||
it "current, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aCfg aCfg (initAgentServersProxy SPMAlways SPFProhibit) 1 $ runTest PQSupportOn True
|
||||
it "v8, via proxy" $ withSmpServerProxy t $ runTestCfgServers2 aProxyCfgV8 aProxyCfgV8 (initAgentServersProxy SPMAlways SPFProhibit) 3 $ runTest PQSupportOn True
|
||||
it "current" $ withSmpServer t $ runTestCfg2 aCfg aCfg 1 $ runTest PQSupportOn False
|
||||
it "prev" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfgVPrev 3 $ runTest PQSupportOff False
|
||||
it "prev to current" $ withSmpServer t $ runTestCfg2 aCfgVPrev aCfg 3 $ runTest PQSupportOff False
|
||||
it "current to prev" $ withSmpServer t $ runTestCfg2 aCfg aCfgVPrev 3 $ runTest PQSupportOff False
|
||||
where
|
||||
aCfg = agentCfg {messageRetryInterval = fastMessageRetryInterval}
|
||||
aProxyCfgV8 = agentProxyCfgV8 {messageRetryInterval = fastMessageRetryInterval}
|
||||
aCfgVPrev = agentCfgVPrev {messageRetryInterval = fastMessageRetryInterval}
|
||||
|
||||
testBasicMatrix2 :: HasCallStack => ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec
|
||||
testBasicMatrix2 t runTest = do
|
||||
it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 1 $ runTest
|
||||
@@ -616,6 +634,71 @@ runAgentClientTestPQ viaProxy (alice, aPQ) (bob, bPQ) baseId =
|
||||
pqConnectionMode :: InitialKeys -> PQSupport -> Bool
|
||||
pqConnectionMode pqMode1 pqMode2 = supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2
|
||||
|
||||
runAgentClientStressTestOneWay :: HasCallStack => Int64 -> PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientStressTestOneWay n pqSupport viaProxy alice bob baseId = runRight_ $ do
|
||||
let pqEnc = PQEncryption $ supportPQ pqSupport
|
||||
(aliceId, bobId) <- makeConnection_ pqSupport alice bob
|
||||
let proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
message i = "message " <> bshow i
|
||||
concurrently_
|
||||
( forM_ ([1 .. n] :: [Int64]) $ \i -> do
|
||||
mId <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags (message i)
|
||||
liftIO $ do
|
||||
mId >= i `shouldBe` True
|
||||
let getEvent =
|
||||
get alice >>= \case
|
||||
("", c, A.SENT mId' srv) -> c == bobId && mId' >= baseId + i && srv == proxySrv `shouldBe` True
|
||||
("", c, QCONT) -> do
|
||||
c == bobId `shouldBe` True
|
||||
getEvent
|
||||
r -> expectationFailure $ "wrong message: " <> show r
|
||||
getEvent
|
||||
)
|
||||
( forM_ ([1 .. n] :: [Int64]) $ \i -> do
|
||||
get bob >>= \case
|
||||
("", c, Msg' mId pq msg) -> do
|
||||
liftIO $ c == aliceId && mId >= baseId + i && pq == pqEnc && msg == message i `shouldBe` True
|
||||
ackMessage bob aliceId mId Nothing
|
||||
r -> liftIO $ expectationFailure $ "wrong message: " <> show r
|
||||
)
|
||||
liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob"
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
|
||||
runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientStressTestConc n pqSupport viaProxy alice bob baseId = runRight_ $ do
|
||||
let pqEnc = PQEncryption $ supportPQ pqSupport
|
||||
(aliceId, bobId) <- makeConnection_ pqSupport alice bob
|
||||
let proxySrv = if viaProxy then Just testSMPServer else Nothing
|
||||
message i = "message " <> bshow i
|
||||
loop a bId mIdVar i = do
|
||||
when (i <= n) $ do
|
||||
mId <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags (message i)
|
||||
liftIO $ mId >= i `shouldBe` True
|
||||
let getEvent = do
|
||||
get a >>= \case
|
||||
("", c, A.SENT _ srv) -> liftIO $ c == bId && srv == proxySrv `shouldBe` True
|
||||
("", c, QCONT) -> do
|
||||
liftIO $ c == bId `shouldBe` True
|
||||
getEvent
|
||||
("", c, Msg' mId pq msg) -> do
|
||||
-- tests that mId increases
|
||||
liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True
|
||||
liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True
|
||||
ackMessage a bId mId Nothing
|
||||
r -> liftIO $ expectationFailure $ "wrong message: " <> show r
|
||||
getEvent
|
||||
amId <- newTVarIO 0
|
||||
bmId <- newTVarIO 0
|
||||
concurrently_
|
||||
(forM_ ([1 .. n * 2] :: [Int64]) $ loop alice bobId amId)
|
||||
(forM_ ([1 .. n * 2] :: [Int64]) $ loop bob aliceId bmId)
|
||||
liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice"
|
||||
liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob"
|
||||
where
|
||||
msgId = subtract baseId . fst
|
||||
|
||||
testEnablePQEncryption :: HasCallStack => IO ()
|
||||
testEnablePQEncryption =
|
||||
withAgentClients2 $ \ca cb -> runRight_ $ do
|
||||
@@ -789,10 +872,17 @@ runAgentClientContactTestPQ3 viaProxy (alice, aPQ) (bob, bPQ) (tom, tPQ) baseId
|
||||
ackMessage a bId (baseId + 2) Nothing
|
||||
|
||||
noMessages :: HasCallStack => AgentClient -> String -> Expectation
|
||||
noMessages c err = tryGet `shouldReturn` ()
|
||||
noMessages = noMessages_ False
|
||||
|
||||
noMessagesIngoreQCONT :: AgentClient -> String -> Expectation
|
||||
noMessagesIngoreQCONT = noMessages_ True
|
||||
|
||||
noMessages_ :: Bool -> HasCallStack => AgentClient -> String -> Expectation
|
||||
noMessages_ ingoreQCONT c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
Just (_, _, QCONT) | ingoreQCONT -> noMessages_ ingoreQCONT c err
|
||||
Just msg -> error $ err <> ": " <> show msg
|
||||
_ -> return ()
|
||||
|
||||
@@ -1038,6 +1128,7 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do
|
||||
-- version increases to max compatible
|
||||
|
||||
disposeAgentClient alice
|
||||
threadDelay 250000
|
||||
alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB
|
||||
|
||||
runRight_ $ do
|
||||
@@ -1193,7 +1284,6 @@ testDeliveryAfterSubscriptionError t = do
|
||||
withAgentClients2 $ \a b -> do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection a bId
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ subscribeConnection b aId
|
||||
pure ()
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
|
||||
withUP a bId $ \case ("", c, SENT 2) -> c == bId; _ -> False
|
||||
withUP b aId $ \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
@@ -1229,13 +1319,13 @@ testMsgDeliveryQuotaExceeded t =
|
||||
|
||||
testExpireMessage :: HasCallStack => ATransport -> IO ()
|
||||
testExpireMessage t =
|
||||
withAgent 1 agentCfg {messageTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a ->
|
||||
withAgent 1 agentCfg {messageTimeout = 1.5, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a ->
|
||||
withAgent 2 agentCfg initAgentServers testDB2 $ \b -> do
|
||||
(aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ makeConnection a b
|
||||
nGet a =##> \case ("", "", DOWN _ [c]) -> c == bId; _ -> False
|
||||
nGet b =##> \case ("", "", DOWN _ [c]) -> c == aId; _ -> False
|
||||
2 <- runRight $ sendMessage a bId SMP.noMsgFlags "1"
|
||||
threadDelay 1000000
|
||||
threadDelay 1500000
|
||||
3 <- runRight $ sendMessage a bId SMP.noMsgFlags "2" -- this won't expire
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
|
||||
@@ -1245,7 +1335,7 @@ testExpireMessage t =
|
||||
|
||||
testExpireManyMessages :: HasCallStack => ATransport -> IO ()
|
||||
testExpireManyMessages t =
|
||||
withAgent 1 agentCfg {messageTimeout = 1, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a ->
|
||||
withAgent 1 agentCfg {messageTimeout = 1.5, messageRetryInterval = fastMessageRetryInterval} initAgentServers testDB $ \a ->
|
||||
withAgent 2 agentCfg initAgentServers testDB2 $ \b -> do
|
||||
(aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ makeConnection a b
|
||||
runRight_ $ do
|
||||
@@ -1254,7 +1344,7 @@ testExpireManyMessages t =
|
||||
2 <- sendMessage a bId SMP.noMsgFlags "1"
|
||||
3 <- sendMessage a bId SMP.noMsgFlags "2"
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "3"
|
||||
liftIO $ threadDelay 1000000
|
||||
liftIO $ threadDelay 1500000
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "4" -- this won't expire
|
||||
get a =##> \case ("", c, MERR 2 (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
-- get a =##> \case ("", c, MERRS [5, 6] (BROKER _ e)) -> bId == c && (e == TIMEOUT || e == NETWORK); _ -> False
|
||||
@@ -1274,7 +1364,7 @@ testExpireManyMessages t =
|
||||
withUP b aId $ \case ("", _, MsgErr 2 (MsgSkipped 2 4) "4") -> True; _ -> False
|
||||
ackMessage b aId 2 Nothing
|
||||
|
||||
withUP :: AgentClient -> ConnId -> (AEntityTransmission 'AEConn -> Bool) -> ExceptT AgentErrorType IO ()
|
||||
withUP :: HasCallStack => AgentClient -> ConnId -> (AEntityTransmission 'AEConn -> Bool) -> ExceptT AgentErrorType IO ()
|
||||
withUP a bId p =
|
||||
liftIO $
|
||||
getInAnyOrder
|
||||
@@ -2896,6 +2986,7 @@ testServerMultipleIdentities =
|
||||
bob' <- liftIO $ do
|
||||
Left (BROKER _ NETWORK) <- runExceptT $ joinConnection bob 1 True secondIdentityCReq "bob's connInfo" SMSubscribe
|
||||
disposeAgentClient bob
|
||||
threadDelay 250000
|
||||
getSMPAgentClient' 3 agentCfg initAgentServers testDB2
|
||||
subscribeConnection bob' aliceId
|
||||
exchangeGreetingsMsgId 4 alice bobId bob' aliceId
|
||||
@@ -3045,7 +3136,7 @@ testServerQueueInfo = do
|
||||
pure ()
|
||||
where
|
||||
checkEmptyQ c cId qiSnd' = do
|
||||
r <- checkQ c cId qiSnd' (Just QSubThread) 0 Nothing
|
||||
r <- checkQ c cId qiSnd' (Just QNoSub) 0 Nothing
|
||||
liftIO $ r `shouldBe` Nothing
|
||||
checkMsgQ c cId qiSize' = do
|
||||
r <- checkQ c cId True (Just QNoSub) qiSize' (Just MTMessage)
|
||||
|
||||
@@ -374,20 +374,20 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} baseId ali
|
||||
-- alice client already has subscription for the connection
|
||||
Left (CMD PROHIBITED _) <- runExceptT $ getNotificationMessage alice nonce message
|
||||
|
||||
threadDelay 300000
|
||||
threadDelay 500000
|
||||
suspendAgent alice 0
|
||||
closeSQLiteStore store
|
||||
threadDelay 300000
|
||||
threadDelay 500000
|
||||
|
||||
-- aliceNtf client doesn't have subscription and is allowed to get notification message
|
||||
withAgent 3 aliceCfg initAgentServers testDB $ \aliceNtf -> runRight_ $ do
|
||||
(_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message
|
||||
pure ()
|
||||
|
||||
threadDelay 300000
|
||||
threadDelay 500000
|
||||
reopenSQLiteStore store
|
||||
foregroundAgent alice
|
||||
threadDelay 300000
|
||||
threadDelay 500000
|
||||
|
||||
runRight_ $ do
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
|
||||
@@ -100,7 +100,6 @@ cfg =
|
||||
{ transports = [],
|
||||
smpHandshakeTimeout = 60000000,
|
||||
tbqSize = 1,
|
||||
-- serverTbqSize = 1,
|
||||
msgQueueQuota = 4,
|
||||
queueIdBytes = 24,
|
||||
msgIdBytes = 24,
|
||||
|
||||
@@ -509,19 +509,21 @@ testWithStoreLog at@(ATransport t) =
|
||||
writeTVar senderId1 sId1
|
||||
writeTVar notifierId nId
|
||||
Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB)
|
||||
signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") >>= \case
|
||||
Resp "bcda" _ OK -> pure ()
|
||||
r -> unexpected r
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h
|
||||
(mId1, msg1) <-
|
||||
signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") >>= \case
|
||||
Resp "" _ (Msg mId1 msg1) -> pure (mId1, msg1)
|
||||
r -> error $ "unexpected response " <> take 100 (show r)
|
||||
Resp "bcda" _ OK <- tGet1 h
|
||||
(decryptMsgV3 dhShared mId1 msg1, Right "hello") #== "delivered from queue 1"
|
||||
Resp "" _ (NMSG _ _) <- tGet1 h1
|
||||
|
||||
(sId2, rId2, rKey2, dhShared2) <- createAndSecureQueue h sPub2
|
||||
atomically $ writeTVar senderId2 sId2
|
||||
signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") >>= \case
|
||||
Resp "cdab" _ OK -> pure ()
|
||||
r -> unexpected r
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 h
|
||||
(mId2, msg2) <-
|
||||
signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") >>= \case
|
||||
Resp "" _ (Msg mId2 msg2) -> pure (mId2, msg2)
|
||||
r -> error $ "unexpected response " <> take 100 (show r)
|
||||
Resp "cdab" _ OK <- tGet1 h
|
||||
(decryptMsgV3 dhShared2 mId2 msg2, Right "hello too") #== "delivered from queue 2"
|
||||
|
||||
Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, DEL)
|
||||
@@ -884,7 +886,7 @@ testMsgExpireOnInterval t =
|
||||
testSMPClient @c $ \sh -> do
|
||||
(sId, rId, rKey, _) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello (should expire)")
|
||||
threadDelay 2500000
|
||||
threadDelay 3000000
|
||||
testSMPClient @c $ \rh -> do
|
||||
signSendRecv rh rKey ("2", rId, SUB) >>= \case
|
||||
Resp "2" _ OK -> pure ()
|
||||
|
||||
Reference in New Issue
Block a user