client: prevent sub actions from zombie sessions

This commit is contained in:
Alexander Bondarenko
2024-05-01 21:51:34 +03:00
parent 9bc9d88971
commit 785ceb78e9
5 changed files with 35 additions and 9 deletions

View File

@@ -121,6 +121,7 @@ import Control.Logger.Simple (logError, logInfo, showText)
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.Trans.Except (throwE)
import Crypto.Random (ChaChaDRG)
import qualified Data.Aeson as J
import Data.Bifunctor (bimap, first, second)
@@ -171,6 +172,7 @@ import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
import Simplex.Messaging.Session (checkSessVar)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId))
import Simplex.Messaging.Util
@@ -1957,11 +1959,11 @@ getSMPServer c userId = withUserServers c userId pickServer
{-# INLINE getSMPServer #-}
subscriber :: AgentClient -> AM' ()
subscriber c@AgentClient {msgQ} = forever $ do
subscriber c@AgentClient {subQ, msgQ} = forever $ do
t <- atomically $ readTBQueue msgQ
agentOperationBracket c AORcvNetwork waitUntilActive $
runExceptT (processSMPTransmission c t) >>= \case
Left e -> liftIO $ print e
Left e -> atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR e)
Right _ -> return ()
cleanupManager :: AgentClient -> AM' ()
@@ -2040,6 +2042,8 @@ data ACKd = ACKd | ACKPending
-- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL
processSMPTransmission :: AgentClient -> ServerTransmission SMPVersion BrokerMsg -> AM ()
processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, isResponse, rId, cmd) = do
unlessM (atomically $ checkSessVar smpClients tSess $ either (const False) ((== sessId) . sessionId . thParams)) . throwE $
CRITICAL False "Congratulations, you've caught a rare beast: zomblie delivery process!"
(rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId)
processSMP rq conn $ toConnData conn
where

View File

@@ -156,6 +156,7 @@ import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (lefts, partitionEithers)
import Data.Function (on)
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List (deleteFirstsBy, foldl', partition, (\\))
@@ -645,6 +646,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
cs <- readTVarIO $ RQ.getConnections $ activeSubs c
rs <- lift . subscribeQueues c $ L.toList qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
mapM_ throwError $ listToMaybe [e | (_conn, e@CRITICAL {}) <- errs]
liftIO $ do
let conns = filter (`M.notMember` cs) okConns
unless (null conns) $ notifySub "" $ UP srv conns
@@ -652,7 +654,7 @@ reconnectSMPClient tc c tSess@(_, srv, _) qs = do
liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs
forM_ (listToMaybe tempErrs) $ \(_, err) -> do
when (null okConns && M.null cs && null finalErrs) . liftIO $
closeClient c smpClients tSess
closeClient c smpClients tSess -- XXX: closing client without checking session
throwError err
notifySub :: forall e. AEntityI e => ConnId -> ACommand 'Agent e -> IO ()
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, APC (sAEntity @e) cmd)
@@ -951,6 +953,7 @@ protocolClientError protocolError_ host = \case
PCEIncompatibleHost -> BROKER host HOST
PCETransportError e -> BROKER host $ TRANSPORT e
e@PCECryptoError {} -> INTERNAL $ show e
PCEZombieSession -> CRITICAL False "A message came from a disconnected session"
PCEIOError {} -> BROKER host NETWORK
data ProtocolTestStep
@@ -1159,7 +1162,7 @@ temporaryOrHostError = \case
-- | Subscribe to queues. The list of results can have a different order.
subscribeQueues :: AgentClient -> [RcvQueue] -> AM' [(RcvQueue, Either AgentErrorType ())]
subscribeQueues c qs = do
subscribeQueues c@AgentClient {smpClients} qs = do
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
atomically $ do
modifyTVar' (subscrConns c) (`S.union` S.fromList (map qConnId qs'))
@@ -1174,10 +1177,14 @@ subscribeQueues c qs = do
subscribeQueues_ :: Env -> SMPClient -> NonEmpty RcvQueue -> IO (BatchResponses SMPClientError ())
subscribeQueues_ env smp qs' = do
rs <- sendBatch subscribeSMPQueues smp qs'
mapM_ (uncurry $ processSubResult c) rs
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
runReaderT (resubscribeSMPSession c $ transportSession' smp) env
pure rs
ok <- atomically $ checkSessVar smpClients (transportSession' smp) (either (const False) $ sameClient smp)
if not ok
then pure $ L.map (\(q, _ignore) -> (q, Left PCEZombieSession)) rs -- the session is gone, don't touch agent state
else do
mapM_ (uncurry $ processSubResult c) rs
when (any temporaryClientError . lefts . map snd $ L.toList rs) $
runReaderT (resubscribeSMPSession c $ transportSession' smp) env
pure rs
activeClientSession :: AgentClient -> SMPTransportSession -> SessionId -> STM Bool
activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClients c)
@@ -1186,7 +1193,10 @@ activeClientSession c tSess sessId = sameSess <$> tryReadSessVar tSess (smpClien
Just (Right smp) -> sessId == sessionId (thParams smp)
_ -> False
type BatchResponses e r = (NonEmpty (RcvQueue, Either e r))
sameClient :: ProtocolClient v err msg -> ProtocolClient v err msg -> Bool
sameClient = (==) `on` (sessionId . thParams)
type BatchResponses e r = NonEmpty (RcvQueue, Either e r)
-- statBatchSize is not used to batch the commands, only for traffic statistics
sendTSessionBatches :: forall q r. ByteString -> Int -> (q -> RcvQueue) -> (SMPClient -> NonEmpty q -> IO (BatchResponses SMPClientError r)) -> AgentClient -> [q] -> AM' [(RcvQueue, Either AgentErrorType r)]

View File

@@ -499,6 +499,8 @@ data ProtocolClientError err
PCETransportError TransportError
| -- | Error when cryptographically "signing" the command or when initializing crypto_box.
PCECryptoError C.CryptoError
| -- | Message came from a killed session.
PCEZombieSession
| -- | IO Error
PCEIOError IOException
deriving (Eq, Show, Exception)

View File

@@ -274,6 +274,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
PCEResponseTimeout -> pure Nothing
PCENetworkError -> pure Nothing
PCEIOError _ -> pure Nothing
PCEZombieSession -> pure Nothing
where
updateErr :: Show e => ByteString -> e -> M (Maybe NtfSubStatus)
updateErr errType e = updateSubStatus smpQueue (NSErr $ errType <> bshow e) $> Just (NSErr errType)

View File

@@ -40,3 +40,12 @@ removeSessVar' v sessKey vs =
tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a)
tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar)
checkSessVar :: Ord k => TMap k (SessionVar a) -> k -> (a -> Bool) -> STM Bool
checkSessVar vs sessKey p =
TM.lookup sessKey vs >>= \case
Nothing -> pure False
Just SessionVar {sessionVar} ->
tryReadTMVar sessionVar >>= \case
Nothing -> pure False
Just x -> pure $ p x