Files
simplexmq/src/Simplex/Messaging/Client/Agent.hs
Evgeny Poberezkin 91cc48aabe agent: do not mark subscriptions on expired sessions as active, mark delayed subscriptions as active on the same session, do not cancel sending expired commands (#1127)
* agent: do not mark subscriptions on expired sessions as active, do mark delayed subscriptions as active on the same session, SUBOK response in the next SMP protocol version

* client: prevent sub actions from zombie sessions (#1122)

* client: prevent sub actions from zombie sessions

* error handling

* add AERR to pass background errors to client

* switch to activeClientSession

* put closeClient under activeClientSession

* rename

* remove AERR, do not skip processing

* move check and state update to one transaction

* catch extra UPs

* fix

* check queue is still pending before making it active

---------

Co-authored-by: Evgeny Poberezkin <evgeny@poberezkin.com>

* do not forward agent error

* revert not expiring sending subs

* fixes

* track subscription responses better

* add pending connection

* Revert "revert not expiring sending subs"

This reverts commit 4310a69391.

* do not expire sending commands

* rename

* fix race

* function

---------

Co-authored-by: Alexander Bondarenko <486682+dpwiz@users.noreply.github.com>
2024-05-12 17:47:08 +01:00

382 lines
15 KiB
Haskell

{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Simplex.Messaging.Client.Agent where
import Control.Concurrent (forkIO)
import Control.Concurrent.Async (Async, uninterruptibleCancel)
import Control.Logger.Simple
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Trans.Except
import Control.Monad.Trans.Reader
import Crypto.Random (ChaChaDRG)
import Data.Bifunctor (bimap, first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (partitionEithers)
import Data.List (partition)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Set (Set)
import Data.Text.Encoding
import Data.Tuple (swap)
import Numeric.Natural
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, NotifierId, NtfPrivateAuthKey, ProtocolServer (..), QueueId, RcvPrivateAuthKey, RecipientId, SMPServer)
import Simplex.Messaging.Session
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Util (catchAll_, toChunks, ($>>=))
import System.Timeout (timeout)
import UnliftIO (async)
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
-- | Session variable for one SMP server connection: once filled, it holds
-- either the error the connection attempt ended with or the connected client.
type SMPClientVar = SessionVar (Either SMPClientError SMPClient)
-- | Events written by the agent to its 'agentQ' queue.
data SMPClientAgentEvent
  -- NOTE(review): no code visible in this file emits CAConnected - confirm whether it is still used.
  = CAConnected SMPServer
  -- sent when a client disconnects while it had active subscriptions; carries their keys
  | CADisconnected SMPServer (Set SMPSub)
  -- sent once a client for the server is obtained again, before resubscribing
  | CAReconnected SMPServer
  -- sent per batch, with the subscriptions that were successfully re-created
  | CAResubscribed SMPServer (NonEmpty SMPSub)
  -- sent per batch, with the subscriptions that failed with a non-temporary error
  | CASubError SMPServer (NonEmpty (SMPSub, SMPClientError))
-- | The party a subscription is made for; it selects which client call is used
-- to subscribe (queue subscription for 'SPRecipient', notification subscription for 'SPNotifier').
data SMPSubParty = SPRecipient | SPNotifier
  deriving (Eq, Ord, Show)
-- | Subscription identifier: the subscribing party together with the queue ID.
-- It is the key of the per-server subscription maps.
type SMPSub = (SMPSubParty, QueueId)
-- type SMPServerSub = (SMPServer, SMPSub)
-- | Configuration of the SMP client agent.
data SMPClientAgentConfig = SMPClientAgentConfig
  { -- | protocol client configuration passed to every new server connection;
    -- its TCP connect timeout also bounds waiting for a connection made by another thread
    smpCfg :: ProtocolClientConfig SMPVersion,
    -- | retry schedule used for asynchronous connection attempts and for reconnections
    reconnectInterval :: RetryInterval,
    -- | capacity of the queue of transmissions received from servers ('msgQ')
    msgQSize :: Natural,
    -- | capacity of the queue of agent events ('agentQ')
    agentQSize :: Natural,
    -- | maximum number of subscriptions sent in one batch when resubscribing
    agentSubsBatchSize :: Int
  }
-- | Default agent configuration: TLS transport on port 5223, retry interval
-- starting at one second and capped at ten seconds, queues of 256 items,
-- and resubscription batches of 900 subscriptions.
defaultSMPClientAgentConfig :: SMPClientAgentConfig
defaultSMPClientAgentConfig =
  SMPClientAgentConfig
    { smpCfg = defaultSMPClientConfig {defaultTransport = ("5223", transport @TLS)},
      reconnectInterval = retryInterval,
      msgQSize = queueSize,
      agentQSize = queueSize,
      agentSubsBatchSize = 900
    }
  where
    -- one second expressed in microseconds
    oneSecond = 1000000
    queueSize = 256
    retryInterval =
      RetryInterval
        { initialInterval = oneSecond,
          increaseAfter = 10 * oneSecond,
          maxInterval = 10 * oneSecond
        }
-- | State of the SMP client agent shared by all server connections.
data SMPClientAgent = SMPClientAgent
  { -- | configuration the agent was created with
    agentCfg :: SMPClientAgentConfig,
    -- | queue handed to every protocol client for transmissions received from servers
    msgQ :: TBQueue (ServerTransmission SMPVersion ErrorType BrokerMsg),
    -- | queue of agent events (disconnections, reconnections, resubscription results)
    agentQ :: TBQueue SMPClientAgentEvent,
    -- | random generator handed to every protocol client
    randomDrg :: TVar ChaChaDRG,
    -- | session variables of server clients, keyed by server
    smpClients :: TMap SMPServer SMPClientVar,
    -- | active subscriptions with their private keys, per server
    srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
    -- | subscriptions that are requested but not (or no longer) active, per server
    pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey),
    -- | reconnection threads started after disconnections; cancelled when the agent is closed
    reconnections :: TVar [Async ()],
    -- | threads retrying the initial connection; cancelled when the agent is closed
    asyncClients :: TVar [Async ()],
    -- | counter passed to 'getSessVar' when session variables are created
    workerSeq :: TVar Int
  }
-- | Wrapper that lets an 'ExceptT' error value be thrown as an IO exception
-- and recovered unchanged when it is caught again.
newtype InternalException e = InternalException {unInternalException :: e}
  deriving (Eq, Show)
instance Exception e => Exception (InternalException e)
-- | Orphan instance allowing 'ExceptT' over IO to be unlifted.
-- Inside the unlifted action a 'Left' result is thrown as 'InternalException';
-- at the boundary it is caught with 'E.try' and unwrapped back into 'Left'.
instance Exception e => MonadUnliftIO (ExceptT e IO) where
  {-# INLINE withRunInIO #-}
  withRunInIO :: ((forall a. ExceptT e IO a -> IO a) -> IO b) -> ExceptT e IO b
  withRunInIO inner =
    ExceptT . fmap (first unInternalException) . E.try $
      withRunInIO $ \run ->
        inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
-- as MonadUnliftIO instance for IO is `withRunInIO inner = inner id`,
-- the last two lines could be replaced with:
-- inner $ either (E.throwIO . InternalException) pure <=< runExceptT
-- | Orphan instance allowing 'ExceptT' over 'ReaderT' over IO to be unlifted.
-- The run function of the underlying 'ReaderT' instance is used to reach IO;
-- a 'Left' result is thrown as 'InternalException' inside the unlifted action,
-- then caught with 'E.try' and unwrapped back into the 'ExceptT' error.
instance Exception e => MonadUnliftIO (ExceptT e (ReaderT r IO)) where
  {-# INLINE withRunInIO #-}
  withRunInIO :: ((forall a. ExceptT e (ReaderT r IO) a -> IO a) -> IO b) -> ExceptT e (ReaderT r IO) b
  withRunInIO inner =
    withExceptT unInternalException . ExceptT . E.try $
      withRunInIO $ \run ->
        inner $ run . (either (E.throwIO . InternalException) pure <=< runExceptT)
-- | Create the agent state in STM: two bounded queues sized from the
-- configuration, empty client and subscription maps, empty lists of
-- background threads, and a worker counter starting at 0.
newSMPClientAgent :: SMPClientAgentConfig -> TVar ChaChaDRG -> STM SMPClientAgent
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} randomDrg =
  -- arguments follow the order in which the fields of SMPClientAgent are declared
  SMPClientAgent agentCfg
    <$> newTBQueue msgQSize -- msgQ
    <*> newTBQueue agentQSize -- agentQ
    <*> pure randomDrg
    <*> TM.empty -- smpClients
    <*> TM.empty -- srvSubs
    <*> TM.empty -- pendingSrvSubs
    <*> newTVar [] -- reconnections
    <*> newTVar [] -- asyncClients
    <*> newTVar 0 -- workerSeq
-- | Get the SMP client for the server: connect if this call obtained the session
-- variable to fill, otherwise wait for the connection made by another thread.
-- NOTE(review): presumably 'getSessVar' returns Left for a newly created variable
-- and Right for an existing one - confirm against its definition.
--
-- Fails with the connection error, or with 'PCEResponseTimeout' if the wait for
-- another thread's connection exceeds the TCP connect timeout.
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT SMPClientError IO SMPClient
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ, randomDrg, workerSeq} srv =
  atomically getClientVar >>= either newSMPClient waitForSMPClient
  where
    getClientVar :: STM (Either SMPClientVar SMPClientVar)
    getClientVar = getSessVar workerSeq srv smpClients
    -- wait (up to tcpConnectTimeout) until the session variable is filled
    waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
    waitForSMPClient v = do
      let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg
      smpClient_ <- liftIO $ tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v)
      liftEither $ case smpClient_ of
        Just (Right smpClient) -> Right smpClient
        Just (Left e) -> Left e
        Nothing -> Left PCEResponseTimeout
    -- first connection attempt; on a network error or timeout it also starts background retries
    newSMPClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
    newSMPClient v = tryConnectClient pure (liftIO tryConnectAsync)
      where
        -- one connection attempt: on success the result is stored in the session variable
        -- and successAction is run; on failure the error is always re-thrown at the end
        tryConnectClient :: (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO () -> ExceptT SMPClientError IO a
        tryConnectClient successAction retryAction =
          tryE (connectClient v) >>= \r -> case r of
            Right smp -> do
              logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
              atomically $ putTMVar (sessionVar v) r
              successAction smp
            Left e -> do
              -- network error / timeout: the variable stays empty and retryAction runs;
              -- any other error is stored in the variable, which is removed from smpClients
              if e == PCENetworkError || e == PCEResponseTimeout
                then retryAction
                else atomically $ do
                  putTMVar (sessionVar v) (Left e)
                  removeSessVar v srv smpClients
              throwE e
        -- run connectAsync in a separate thread, recording it in asyncClients
        tryConnectAsync :: IO ()
        tryConnectAsync = do
          a <- async $ void $ runExceptT connectAsync
          atomically $ modifyTVar' (asyncClients ca) (a :)
        -- retry connecting with reconnectInterval; once connected, resubscribe via reconnectClient
        connectAsync :: ExceptT SMPClientError IO ()
        connectAsync =
          withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
            void $ tryConnectClient (const reconnectClient) loop
    -- create the protocol client; clientDisconnected is passed as its last argument
    -- NOTE(review): presumably invoked by the client when the connection is lost - confirm
    connectClient :: SMPClientVar -> ExceptT SMPClientError IO SMPClient
    connectClient v = ExceptT $ getProtocolClient randomDrg (1, srv, Nothing) (smpCfg agentCfg) (Just msgQ) (clientDisconnected v)
    -- remove the client and its subscriptions; if the server had a subscription map, run serverDown on it
    clientDisconnected :: SMPClientVar -> SMPClient -> IO ()
    clientDisconnected v _ = do
      removeClientAndSubs v >>= (`forM_` serverDown)
      logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
    -- in one transaction: remove the session variable and move the server's active
    -- subscriptions to pending; returns the subscriptions that were active, if any map existed
    removeClientAndSubs :: SMPClientVar -> IO (Maybe (Map SMPSub C.APrivateAuthKey))
    removeClientAndSubs v = atomically $ do
      removeSessVar v srv smpClients
      TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
      where
        updateSubs sVar = do
          ss <- readTVar sVar
          addPendingSubs sVar ss
          pure ss
        -- merge into the existing pending map, or reuse the removed map as the pending one
        addPendingSubs sVar ss = do
          let ps = pendingSrvSubs ca
          TM.lookup srv ps >>= \case
            Just ss' -> TM.union ss ss'
            _ -> TM.insert srv sVar ps
    -- with no active subscriptions nothing is reported and no reconnection is started
    serverDown :: Map SMPSub C.APrivateAuthKey -> IO ()
    serverDown ss = unless (M.null ss) $ do
      notify . CADisconnected srv $ M.keysSet ss
      reconnectServer
    -- run tryReconnectClient in a separate thread, recording it in reconnections
    reconnectServer :: IO ()
    reconnectServer = do
      a <- async $ void $ runExceptT tryReconnectClient
      atomically $ modifyTVar' (reconnections ca) (a :)
    -- repeat reconnectClient with reconnectInterval for as long as it fails with any error
    tryReconnectClient :: ExceptT SMPClientError IO ()
    tryReconnectClient = do
      withRetryInterval (reconnectInterval agentCfg) $ \_ loop ->
        reconnectClient `catchE` const loop
    -- obtain a client, report CAReconnected, then resubscribe the pending subscriptions
    -- that are not already active: notifier subscriptions first, then recipient ones
    reconnectClient :: ExceptT SMPClientError IO ()
    reconnectClient = do
      withSMP ca srv $ \smp -> do
        liftIO $ notify $ CAReconnected srv
        cs_ <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca)
        forM_ cs_ $ \cs -> do
          subs' <- filterM (fmap not . atomically . hasSub (srvSubs ca) srv . fst) $ M.assocs cs
          let (nSubs, rSubs) = partition (isNotifier . fst . fst) subs'
          subscribe_ smp SPNotifier nSubs
          subscribe_ smp SPRecipient rSubs
      where
        isNotifier = \case
          SPNotifier -> True
          SPRecipient -> False
        -- subscribe in chunks of agentSubsBatchSize
        subscribe_ :: SMPClient -> SMPSubParty -> [(SMPSub, C.APrivateAuthKey)] -> ExceptT SMPClientError IO ()
        subscribe_ smp party = mapM_ subscribeBatch . toChunks (agentSubsBatchSize agentCfg)
          where
            subscribeBatch subs' = do
              let subs'' :: (NonEmpty (QueueId, C.APrivateAuthKey)) = L.map (first snd) subs'
              rs <- liftIO $ smpSubscribeQueues party ca smp srv subs''
              -- pair each result with its original subscription, then split into errors and successes
              let rs' :: (NonEmpty ((SMPSub, C.APrivateAuthKey), Either SMPClientError ())) =
                    L.zipWith (first . const) subs' rs
                  rs'' :: [Either (SMPSub, SMPClientError) (SMPSub, C.APrivateAuthKey)] =
                    map (\(sub, r) -> bimap (fst sub,) (const sub) r) $ L.toList rs'
                  (errs, oks) = partitionEithers rs''
                  (tempErrs, finalErrs) = partition (temporaryClientError . snd) errs
              -- successes become active and are reported; final errors are removed from pending and reported
              mapM_ (atomically . addSubscription ca srv) oks
              mapM_ (liftIO . notify . CAResubscribed srv) $ L.nonEmpty $ map fst oks
              mapM_ (atomically . removePendingSubscription ca srv . fst) finalErrs
              mapM_ (liftIO . notify . CASubError srv) $ L.nonEmpty finalErrs
              -- the first temporary error (if any) is thrown, failing reconnectClient
              mapM_ (throwE . snd) $ listToMaybe tempErrs
    -- write the event to the bounded agentQ queue
    notify :: SMPClientAgentEvent -> IO ()
    notify evt = atomically $ writeTBQueue (agentQ ca) evt
-- | Shut the agent down: close all server clients, then cancel the
-- reconnection threads and the asynchronous connection threads, in that order.
closeSMPClientAgent :: SMPClientAgent -> IO ()
closeSMPClientAgent c = do
  closeSMPServerClients c
  mapM_ cancelActions [reconnections c, asyncClients c]
-- | Atomically empty the map of server clients and close each removed client
-- in its own thread. A client is closed only after its session variable is
-- filled; variables holding a connection error are ignored, and exceptions
-- raised while closing are discarded.
closeSMPServerClients :: SMPClientAgent -> IO ()
closeSMPServerClients c = do
  clientVars <- atomically $ swapTVar (smpClients c) M.empty
  forM_ clientVars $ forkIO . shutdown
  where
    shutdown v = do
      client_ <- atomically . readTMVar $ sessionVar v
      case client_ of
        Right smp -> closeProtocolClient smp `catchAll_` pure ()
        Left _ -> pure ()
-- | Read the current collection of asynchronous actions and cancel each of
-- them with 'uninterruptibleCancel'. The variable itself is not modified.
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
cancelActions as = do
  running <- readTVarIO as
  forM_ running uninterruptibleCancel
-- | Run an action with the client for the server. Any error, from obtaining
-- the client or from the action, is printed to stdout and then re-thrown.
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> ExceptT SMPClientError IO a
withSMP ca srv action =
  catchE (getSMPServerClient' ca srv >>= action) $ \e -> do
    liftIO . putStrLn $ "SMP error (" <> show srv <> "): " <> show e
    throwE e
-- | Subscribe to a single queue. The subscription is recorded as pending
-- first and made active once the server confirms it. If subscribing fails
-- with anything other than a network error or a response timeout, the pending
-- record is removed; the error is re-thrown in every case.
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
subscribeQueue ca srv sub = do
  atomically $ addPendingSubscription ca srv sub
  withSMP ca srv $ \smp -> activate smp `catchE` onError
  where
    activate smp = smpSubscribe smp sub >> atomically (addSubscription ca srv sub)
    onError e = do
      let temporary = e == PCENetworkError || e == PCEResponseTimeout
      atomically . unless temporary $ removePendingSubscription ca srv (fst sub)
      throwE e
-- | Subscribe to a batch of queues as recipient; returns one result per queue.
subscribeQueuesSMP :: SMPClientAgent -> SMPServer -> NonEmpty (RecipientId, RcvPrivateAuthKey) -> IO (NonEmpty (RecipientId, Either SMPClientError ()))
subscribeQueuesSMP ca srv subs = subscribeQueues_ SPRecipient ca srv subs
-- | Subscribe to notifications for a batch of queues; returns one result per queue.
subscribeQueuesNtfs :: SMPClientAgent -> SMPServer -> NonEmpty (NotifierId, NtfPrivateAuthKey) -> IO (NonEmpty (NotifierId, Either SMPClientError ()))
subscribeQueuesNtfs ca srv subs = subscribeQueues_ SPNotifier ca srv subs
-- | Subscribe to a batch of queues for the given party. All subscriptions are
-- recorded as pending in a single transaction before the client is obtained;
-- if no client can be obtained, every queue gets that error as its result.
subscribeQueues_ :: SMPSubParty -> SMPClientAgent -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
subscribeQueues_ party ca srv subs = do
  atomically . forM_ subs $ \sub -> addPendingSubscription ca srv (first (party,) sub)
  client_ <- runExceptT $ getSMPServerClient' ca srv
  case client_ of
    Right smp -> smpSubscribeQueues party ca smp srv subs
    Left e -> pure $ fmap (\sub -> (fst sub, Left e)) subs
-- | Send a batch of subscriptions using the given client and update the
-- subscription state for all results in one transaction: confirmed
-- subscriptions become active, subscriptions that failed with anything other
-- than a network error or a response timeout are removed from pending.
-- Returns each queue ID paired with its result.
smpSubscribeQueues :: SMPSubParty -> SMPClientAgent -> SMPClient -> SMPServer -> NonEmpty (QueueId, C.APrivateAuthKey) -> IO (NonEmpty (QueueId, Either SMPClientError ()))
smpSubscribeQueues party ca smp srv subs = do
  results <- subscribe smp $ L.map swap subs
  atomically . forM (L.zip subs results) $ \(sub, r) -> do
    either (failed sub) (\() -> addSubscription ca srv (first (party,) sub)) r
    pure (fst sub, r)
  where
    subscribe = case party of
      SPRecipient -> subscribeSMPQueues
      SPNotifier -> subscribeSMPQueuesNtfs
    -- temporary errors keep the subscription pending
    failed sub e =
      unless (e == PCENetworkError || e == PCEResponseTimeout) $
        removePendingSubscription ca srv (party, fst sub)
-- | Render the server as its encoded host, followed by a colon and the port
-- when the port is not empty.
showServer :: SMPServer -> ByteString
showServer ProtocolServer {host, port} = strEncode host <> B.pack portSuffix
  where
    portSuffix
      | null port = ""
      | otherwise = ':' : port
-- | Send a single subscription with the client call that matches its party.
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateAuthKey) -> ExceptT SMPClientError IO ()
smpSubscribe smp ((party, queueId), privKey) = case party of
  SPRecipient -> subscribeSMPQueue smp privKey queueId
  SPNotifier -> subscribeSMPQueueNotifications smp privKey queueId
-- | Record the subscription as active for the server and remove it from the
-- server's pending subscriptions.
addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
addSubscription ca srv sub =
  addSub_ (srvSubs ca) srv sub
    >> removePendingSubscription ca srv (fst sub)
-- | Record the subscription as pending for the server.
addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
addPendingSubscription ca = addSub_ (pendingSrvSubs ca)
-- | Insert the subscription key into the server's map, creating that map
-- (with this single entry) when the server has none yet.
addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> (SMPSub, C.APrivateAuthKey) -> STM ()
addSub_ subs srv (s, key) = do
  srvSubs_ <- TM.lookup srv subs
  case srvSubs_ of
    Nothing -> do
      m <- TM.singleton s key
      TM.insert srv m subs
    Just m -> TM.insert s key m
-- | Remove the subscription from the server's active subscriptions.
removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
removeSubscription ca = removeSub_ (srvSubs ca)
-- | Remove the subscription from the server's pending subscriptions.
removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
removePendingSubscription ca = removeSub_ (pendingSrvSubs ca)
-- | Delete the subscription from the server's map; does nothing when the
-- server has no map.
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM ()
removeSub_ subs srv s = do
  srvSubs_ <- TM.lookup srv subs
  forM_ srvSubs_ $ TM.delete s
-- | Look up the private key stored for the subscription on the server, if any.
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateAuthKey)
getSubKey subs srv s = serverSubs $>>= keyFor
  where
    serverSubs = TM.lookup srv subs
    keyFor m = TM.lookup s m
-- | Check whether the subscription is present in the server's map;
-- False when the server has no map.
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSub -> STM Bool
hasSub subs srv s = do
  srvSubs_ <- TM.lookup srv subs
  case srvSubs_ of
    Just m -> TM.member s m
    Nothing -> pure False