mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-01 16:05:55 +00:00
* agent: do not mark subscriptions on expired sessions as active, do mark delayed subscriptions as active on the same session, SUBOK response in the next SMP protocol version
* client: prevent sub actions from zombie sessions (#1122)
* client: prevent sub actions from zombie sessions
* error handling
* add AERR to pass background errors to client
* switch to activeClientSession
* put closeClient under activeClientSession
* rename
* remove AERR, do not skip processing
* move check and state update to one transaction
* catch extra UPs
* fix
* check queue is still pending before making it active
---------
Co-authored-by: Evgeny Poberezkin <evgeny@poberezkin.com>
* do not forward agent error
* revert not expiring sending subs
* fixes
* track subscription responses better
* add pending connection
* Revert "revert not expiring sending subs"
This reverts commit 4310a69391.
* do not expire sending commands
* rename
* fix race
* function
---------
Co-authored-by: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com>
382 lines
15 KiB
Haskell
382 lines
15 KiB
Haskell
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE FlexibleInstances #-}
|
|
{-# LANGUAGE InstanceSigs #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
{-# LANGUAGE TypeApplications #-}
|
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
|
|
module Simplex.Messaging.Client.Agent where
|
|
|
|
import Control.Concurrent (forkIO)
|
|
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
|
import Control.Logger.Simple
|
|
import Control.Monad
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Unlift
|
|
import Control.Monad.Trans.Except
|
|
import Control.Monad.Trans.Reader
|
|
import Crypto.Random (ChaChaDRG)
|
|
import Data.Bifunctor (bimap, first)
|
|
import Data.ByteString.Char8 (ByteString)
|
|
import qualified Data.ByteString.Char8 as B
|
|
import Data.Either (partitionEithers)
|
|
import Data.List (partition)
|
|
import Data.List.NonEmpty (NonEmpty)
|
|
import qualified Data.List.NonEmpty as L
|
|
import Data.Map.Strict (Map)
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (listToMaybe)
|
|
import Data.Set (Set)
|
|
import Data.Text.Encoding
|
|
import Data.Tuple (swap)
|
|
import Numeric.Natural
|
|
import Simplex.Messaging.Agent.RetryInterval
|
|
import Simplex.Messaging.Client
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Encoding.String
|
|
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer)
|
|
import Simplex.Messaging.Session
|
|
import Simplex.Messaging.TMap (TMap)
|
|
import qualified Simplex.Messaging.TMap as TM
|
|
import Simplex.Messaging.Transport
|
|
import Simplex.Messaging.Util (catchAll_, toChunks, ($>>=))
|
|
import System.Timeout (timeout)
|
|
import UnliftIO (async)
|
|
import UnliftIO.Exception (Exception)
|
|
import qualified UnliftIO.Exception as E
|
|
import UnliftIO.STM
|
|
|
|
type SMPClientVar = SessionVar (Either SMPClientError SMPClient)
|
|
|
|
data SMPClientAgentEvent
|
|
= CAConnected SMPServer
|
|
| CADisconnected SMPServer (Set SMPSub)
|
|
| CAReconnected SMPServer
|
|
| CAResubscribed SMPServer (NonEmpty SMPSub)
|
|
| CASubError SMPServer (NonEmpty (SMPSub, SMPClientError))
|
|
|
|
data SMPSubParty = SPRecipient | SPNotifier
|
|
deriving (Eq, Ord, Show)
|
|
|
|
type SMPSub = (SMPSubParty, QueueId)
|
|
|
|
-- type SMPServerSub = (SMPServer, SMPSub)
|
|
|
|
data SMPClientAgentConfig = SMPClientAgentConfig
|
|
{ smpCfg :: ProtocolClientConfig SMPVersion,
|
|
reconnectInterval :: RetryInterval,
|
|
msgQSize :: Natural,
|
|
agentQSize :: Natural,
|
|
agentSubsBatchSize :: Int
|
|
}
|
|
|
|
defaultSMPClientAgentConfig :: SMPClientAgentConfig
|
|
defaultSMPClientAgentConfig =
|
|
SMPClientAgentConfig
|
|
{ smpCfg = defaultSMPClientConfig {defaultTransport = ("5223", transport @TLS)},
|
|
reconnectInterval =
|
|
RetryInterval
|
|
{ initialInterval = second,
|
|
increaseAfter = 10 * second,
|
|
maxInterval = 10 * second
|
|
},
|
|
msgQSize = 256,
|
|
agentQSize = 256,
|
|
agentSubsBatchSize = 900
|
|
}
|
|
where
|
|
second = 1000000
|
|
|
|
data SMPClientAgent = SMPClientAgent
|
|
{ agentCfg :: SMPClientAgentConfig,
|
|
msgQ :: TBQueue (ServerTransmission SMPVersion ErrorType BrokerMsg),
|
|
agentQ :: TBQueue SMPClientAgentEvent,
|
|
randomDrg :: TVar ChaChaDRG,
|
|
smpClients :: TMap SMPServer SMPClientVar,
|
|
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
|
|
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
|
|
reconnections :: TVar [Async ()],
|
|
asyncClients :: TVar [Async ()],
|
|
workerSeq :: TVar Int
|
|
}
|
|
|
|
newtype InternalException e = InternalException {unInternalException :: e}
|
|
deriving (Eq, Show)
|
|
|
|
instance Exception e => Exception (InternalException e)
|
|
|
|
instance Exception e => MonadUnliftIO (ExceptT e IO) where
|
|
{-# INLINE withRunInIO #-}
|
|
withRunInIO :: ((forall a. ExceptT e IO a -> IO a) -> IO b) -> ExceptT e IO b
|
|
withRunInIO inner =
|
|
ExceptT . fmap (first unInternalException) . E.try $
|
|
withRunInIO $ \run ->
|
|
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
|
|
|
|
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
|
|
-- the last two lines could be replaced with:
|
|
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
|
|
|
|
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
|
|
{-# INLINE withRunInIO #-}
|
|
withRunInIO :: ((forall a. ExceptT e (ReaderT r IO) a -> IO a) -> IO b) -> ExceptT e (ReaderT r IO) b
|
|
withRunInIO inner =
|
|
withExceptT unInternalException . ExceptT . E.try $
|
|
withRunInIO $ \run ->
|
|
inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
|
|
|
|
newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent
|
|
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg = do
|
|
msgQ <- newTBQueue msgQSize
|
|
agentQ <- newTBQueue agentQSize
|
|
smpClients <- TM.empty
|
|
srvSubs <- TM.empty
|
|
pendingSrvSubs <- TM.empty
|
|
reconnections <- newTVar []
|
|
asyncClients <- newTVar []
|
|
workerSeq <- newTVar 0
|
|
pure
|
|
SMPClientAgent
|
|
{ agentCfg,
|
|
msgQ,
|
|
agentQ,
|
|
randomDrg,
|
|
smpClients,
|
|
srvSubs,
|
|
pendingSrvSubs,
|
|
reconnections,
|
|
asyncClients,
|
|
workerSeq
|
|
}
|
|
|
|
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
|
|
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv =
|
|
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
|
where
|
|
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
|
getClientVar = getSessVar workerSeq srv smpClients
|
|
|
|
waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
|
waitForSMPClient v = do
|
|
let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
|
|
smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
|
|
liftEither $ case smpClient_ of
|
|
Just (Right smpClient) -> Right smpClient
|
|
Just (Left e) -> Left e
|
|
Nothing -> Left PCEResponseTimeout
|
|
|
|
newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
|
newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync)
|
|
where
|
|
tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
|
|
tryConnectClient successAction retryAction =
|
|
tryE (connectClient v) >>= \r -> case r of
|
|
Right smp -> do
|
|
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
|
atomically $ putTMVar (sessionVar v) r
|
|
successAction smp
|
|
Left e -> do
|
|
if e == PCENetworkError || e == PCEResponseTimeout
|
|
then retryAction
|
|
else atomically $ do
|
|
putTMVar (sessionVar v) (Left e)
|
|
removeSessVar v srv smpClients
|
|
throwE e
|
|
tryConnectAsync :: IO ()
|
|
tryConnectAsync = do
|
|
a <- async $ void $ runExceptT connectAsync
|
|
atomically $ modifyTVar' (asyncClients ca) (a :)
|
|
connectAsync :: ExceptT SMPClientError IO ()
|
|
connectAsync =
|
|
withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
|
|
void $ tryConnectClient (const reconnectClient) loop
|
|
|
|
connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
|
|
connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v)
|
|
|
|
clientDisconnected :: SMPClientVar -> SMPClient -> IO ()
|
|
clientDisconnected v _ = do
|
|
removeClientAndSubs v >>= (`forM_` serverDown)
|
|
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
|
|
|
removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
|
|
removeClientAndSubs v = atomically $ do
|
|
removeSessVar v srv smpClients
|
|
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
|
|
where
|
|
updateSubs sVar = do
|
|
ss <- readTVar sVar
|
|
addPendingSubs sVar ss
|
|
pure ss
|
|
|
|
addPendingSubs sVar ss = do
|
|
let ps = pendingSrvSubs ca
|
|
TM.lookup srv ps >>= \case
|
|
Just ss' -> TM.union ss ss'
|
|
_ -> TM.insert srv sVar ps
|
|
|
|
serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
|
|
serverDown ss = unless (M.null ss) $ do
|
|
notify . CADisconnected srv $ M.keysSet ss
|
|
reconnectServer
|
|
|
|
reconnectServer :: IO ()
|
|
reconnectServer = do
|
|
a <- async $ void $ runExceptT tryReconnectClient
|
|
atomically $ modifyTVar' (reconnections ca) (a :)
|
|
|
|
tryReconnectClient :: ExceptT SMPClientError IO ()
|
|
tryReconnectClient = do
|
|
withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
|
|
reconnectClient `catchE` const loop
|
|
|
|
reconnectClient :: ExceptT SMPClientError IO ()
|
|
reconnectClient = do
|
|
withSMP ca srv $ \smp -> do
|
|
liftIO $ notify $ CAReconnected srv
|
|
cs_ <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
|
forM_ cs_ $ \cs -> do
|
|
subs' <- filterM (fmap not . atomically . hasSub (srvSubs ca) srv . fst) $ M.assocs cs
|
|
let (nSubs, rSubs) = partition (isNotifier . fst . fst) subs'
|
|
subscribe_ smp SPNotifier nSubs
|
|
subscribe_ smp SPRecipient rSubs
|
|
where
|
|
isNotifier = \case
|
|
SPNotifier -> True
|
|
SPRecipient -> False
|
|
|
|
subscribe_ :: SMPClient -> SMPSubParty -> [(SMPSub, C.APrivateAuthKey)] -> ExceptT SMPClientError IO ()
|
|
subscribe_ smp party = mapM_ subscribeBatch . toChunks (agentSubsBatchSize agentCfg)
|
|
where
|
|
subscribeBatch subs' = do
|
|
let subs'' :: (NonEmpty (QueueId, C.APrivateAuthKey)) = L.map (first snd) subs'
|
|
rs <- liftIO $ smpSubscribeQueues party ca smp srv subs''
|
|
let rs' :: (NonEmpty ((SMPSub, C.APrivateAuthKey), Either SMPClientError ())) =
|
|
L.zipWith (first . const) subs' rs
|
|
rs'' :: [Either (SMPSub, SMPClientError) (SMPSub, C.APrivateAuthKey)] =
|
|
map (\(sub, r) -> bimap (fst sub,) (const sub) r) $ L.toList rs'
|
|
(errs, oks) = partitionEithers rs''
|
|
(tempErrs, finalErrs) = partition (temporaryClientError . snd) errs
|
|
mapM_ (atomically . addSubscription ca srv) oks
|
|
mapM_ (liftIO . notify . CAResubscribed srv) $ L.nonEmpty $ map fst oks
|
|
mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
|
|
mapM_ (liftIO . notify . CASubError srv) $ L.nonEmpty finalErrs
|
|
mapM_ (throwE . snd) $ listToMaybe tempErrs
|
|
|
|
notify :: SMPClientAgentEvent -> IO ()
|
|
notify evt = atomically $ writeTBQueue (agentQ ca) evt
|
|
|
|
closeSMPClientAgent :: SMPClientAgent -> IO ()
|
|
closeSMPClientAgent c = do
|
|
closeSMPServerClients c
|
|
cancelActions $ reconnections c
|
|
cancelActions $ asyncClients c
|
|
|
|
closeSMPServerClients :: SMPClientAgent -> IO ()
|
|
closeSMPServerClients c = atomically (smpClients c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient)
|
|
where
|
|
closeClient v =
|
|
atomically (readTMVar $ sessionVar v) >>= \case
|
|
Right smp -> closeProtocolClient smp `catchAll_` pure ()
|
|
_ -> pure ()
|
|
|
|
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
|
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
|
|
|
|
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO a
|
|
withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError
|
|
where
|
|
logSMPError :: SMPClientError -> ExceptT SMPClientError IO a
|
|
logSMPError e = do
|
|
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
|
|
throwE e
|
|
|
|
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
|
|
subscribeQueue ca srv sub = do
|
|
atomically $ addPendingSubscription ca srv sub
|
|
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleErr
|
|
where
|
|
subscribe_ smp = do
|
|
smpSubscribe smp sub
|
|
atomically $ addSubscription ca srv sub
|
|
|
|
handleErr e = do
|
|
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
|
removePendingSubscription ca srv (fst sub)
|
|
throwE e
|
|
|
|
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ()))
|
|
subscribeQueuesSMP = subscribeQueues_ SPRecipient
|
|
|
|
subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (NotifierId, Either SMPClientError ()))
|
|
subscribeQueuesNtfs = subscribeQueues_ SPNotifier
|
|
|
|
subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
|
|
subscribeQueues_ party ca srv subs = do
|
|
atomically $ forM_ subs $ addPendingSubscription ca srv . first (party,)
|
|
runExceptT (getSMPServerClient' ca srv) >>= \case
|
|
Left e -> pure $ L.map ((,Left e) . fst) subs
|
|
Right smp -> smpSubscribeQueues party ca smp srv subs
|
|
|
|
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
|
|
smpSubscribeQueues party ca smp srv subs = do
|
|
rs <- L.zip subs <$> subscribe smp (L.map swap subs)
|
|
atomically $ forM rs $ \(sub, r) ->
|
|
(fst sub,) <$> case r of
|
|
Right () -> do
|
|
addSubscription ca srv $ first (party,) sub
|
|
pure $ Right ()
|
|
Left e -> do
|
|
when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
|
removePendingSubscription ca srv (party, fst sub)
|
|
pure $ Left e
|
|
where
|
|
subscribe = case party of
|
|
SPRecipient -> subscribeSMPQueues
|
|
SPNotifier -> subscribeSMPQueuesNtfs
|
|
|
|
showServer :: SMPServer -> ByteString
|
|
showServer ProtocolServer {host, port} =
|
|
strEncode host <> B.pack (if null port then "" else ':' : port)
|
|
|
|
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
|
|
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
|
|
where
|
|
subscribe_ = case party of
|
|
SPRecipient -> subscribeSMPQueue
|
|
SPNotifier -> subscribeSMPQueueNotifications
|
|
|
|
addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
|
addSubscription ca srv sub = do
|
|
addSub_ (srvSubs ca) srv sub
|
|
removePendingSubscription ca srv $ fst sub
|
|
|
|
addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
|
addPendingSubscription = addSub_ . pendingSrvSubs
|
|
|
|
addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
|
|
addSub_ subs srv (s, key) =
|
|
TM.lookup srv subs >>= \case
|
|
Just m -> TM.insert s key m
|
|
_ -> TM.singleton s key >>= \v -> TM.insert srv v subs
|
|
|
|
removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
|
removeSubscription = removeSub_ . srvSubs
|
|
|
|
removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
|
removePendingSubscription = removeSub_ . pendingSrvSubs
|
|
|
|
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM ()
|
|
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
|
|
|
|
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateAuthKey)
|
|
getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s
|
|
|
|
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM Bool
|
|
hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs
|