From 6ccbe5e66e7d34889ffb11daefc31daf8ac9791c Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 18 Jan 2023 14:30:25 +0000 Subject: [PATCH] retry unsuccessful subscriptions in case of temporary errors (#613) * retry unsuccessful subscriptions in case of temporary errors * do not send DOWN if connection has any active queues --- src/Simplex/Messaging/Agent/Client.hs | 89 ++++++++++++----------- src/Simplex/Messaging/Agent/TRcvQueues.hs | 8 +- 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index a94d866d0..53c9d4304 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -96,7 +96,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (isRight, partitionEithers) +import Data.Either (isRight, lefts, partitionEithers) import Data.Functor (($>)) import Data.List (partition) import Data.List.NonEmpty (NonEmpty (..)) @@ -286,7 +286,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" atomically (getClientVar srv smpClients) >>= either - (newProtocolClient c srv smpClients connectClient reconnectClient) + (newProtocolClient c srv smpClients connectClient reconnectSMPClient) (waitForProtocolClient c srv) where connectClient :: m SMPClient @@ -303,9 +303,11 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do removeClientAndSubs :: IO ([RcvQueue], [ConnId]) removeClientAndSubs = atomically $ do TM.delete srv smpClients - (qs, conns) <- RQ.getDelSrvQueues srv $ activeSubs c + qs <- RQ.getDelSrvQueues srv $ activeSubs c mapM_ (`RQ.addQueue` pendingSubs c) qs - pure (qs, S.toList conns) + let cs = S.fromList $ map (\RcvQueue {connId} -> connId) qs + cs' <- RQ.getConns $ activeSubs c + pure (qs, S.toList $ cs `S.difference` cs') serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do @@ -314,40 +316,41 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do unless (null conns) $ notifySub "" $ DOWN srv conns unless (null qs) $ do atomically $ mapM_ (releaseGetLock c) qs - unliftIO u reconnectServer + unliftIO u $ reconnectServer c srv - reconnectServer :: m () - reconnectServer = do - a <- async tryReconnectClient - atomically $ modifyTVar' (reconnections c) (a :) + notifySub :: ConnId -> ACommand 'Agent -> IO () + notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) - tryReconnectClient :: m () - tryReconnectClient = do - ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \loop -> - reconnectClient `catchError` const loop - - reconnectClient :: m () - reconnectClient = - withLockMap_ (reconnectLocks c) srv "reconnect" $ - atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe - where - resubscribe :: [RcvQueue] -> m () - resubscribe qs = do - connected <- maybe False isRight <$> atomically (TM.lookup srv smpClients $>>= tryReadTMVar) - cs <- atomically . RQ.getConns $ activeSubs c - (client_, rs) <- subscribeQueues c srv qs - let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs - liftIO $ do - unless connected . forM_ client_ $ \cl -> do - incClientStat c cl "CONNECT" "" - notifySub "" $ hostEvent CONNECT cl - let conns = S.toList $ S.fromList okConns `S.difference` cs - unless (null conns) $ notifySub "" $ UP srv conns - let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs - liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs - mapM_ (throwError . snd) $ listToMaybe tempErrs +reconnectServer :: AgentMonad m => AgentClient -> SMPServer -> m () +reconnectServer c srv = do + a <- async tryReconnectSMPClient + atomically $ modifyTVar' (reconnections c) (a :) + where + tryReconnectSMPClient = do + ri <- asks $ reconnectInterval . config + withRetryInterval ri $ \loop -> + reconnectSMPClient c srv `catchError` const loop +reconnectSMPClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m () +reconnectSMPClient c srv = + withLockMap_ (reconnectLocks c) srv "reconnect" $ + atomically (RQ.getSrvQueues srv $ pendingSubs c) >>= resubscribe + where + resubscribe :: [RcvQueue] -> m () + resubscribe qs = do + connected <- maybe False isRight <$> atomically (TM.lookup srv (smpClients c) $>>= tryReadTMVar) + cs <- atomically . RQ.getConns $ activeSubs c + (client_, rs) <- subscribeQueues c srv qs + let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs + liftIO $ do + unless connected . forM_ client_ $ \cl -> do + incClientStat c cl "CONNECT" "" + notifySub "" $ hostEvent CONNECT cl + let conns = S.toList $ S.fromList okConns `S.difference` cs + unless (null conns) $ notifySub "" $ UP srv conns + let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs + liftIO $ mapM_ (\(connId, e) -> notifySub connId $ ERR e) finalErrs + mapM_ (throwError . snd) $ listToMaybe tempErrs notifySub :: ConnId -> ACommand 'Agent -> IO () notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd) @@ -356,7 +359,7 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped" atomically (getClientVar srv ntfClients) >>= either - (newProtocolClient c srv ntfClients connectClient $ pure ()) + (newProtocolClient c srv ntfClients connectClient $ \_ _ -> pure ()) (waitForProtocolClient c srv) where connectClient :: m NtfClient @@ -396,7 +399,7 @@ newProtocolClient :: ProtoServer msg -> TMap (ProtoServer msg) (ClientVar msg) -> m (ProtocolClient msg) -> - m () -> + (AgentClient -> ProtoServer msg -> m ()) -> ClientVar msg -> m (ProtocolClient msg) newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync @@ -425,7 +428,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon connectAsync :: m () connectAsync = do ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop + withRetryInterval ri $ \loop -> void $ tryConnectClient (const $ reconnectClient c srv) loop hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost' client @@ -609,9 +612,11 @@ subscribeQueue c rq@RcvQueue {connId, server, rcvPrivateKey, rcvId} = do atomically $ do modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c - withLogClient c server rcvId "SUB" $ \smp -> + r <- withLogClient c server rcvId "SUB" $ \smp -> liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq) - >>= either throwError pure + case r of + Left e -> reconnectServer c server >> throwError (protocolClientError SMP (B.unpack $ strEncode server) e) + _ -> pure () processSubResult :: AgentClient -> RcvQueue -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) processSubResult c rq r = do @@ -646,7 +651,7 @@ subscribeQueues c srv qs = do forM_ qs_ $ \rq@RcvQueue {connId} -> atomically $ do modifyTVar' (subscrConns c) $ S.insert connId RQ.addQueue rq $ pendingSubs c - case L.nonEmpty qs_ of + r <- case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) (eitherToMaybe smp_,) . (errs <>) <$> case smp_ of @@ -661,6 +666,8 @@ subscribeQueues c srv qs = do mapM_ (uncurry $ processSubResult c) rs pure $ map (second . first $ protocolClientError SMP $ clientServer smp) rs _ -> pure (Nothing, errs) + when (any temporaryOrHostError . lefts . map snd $ snd r) $ reconnectServer c srv + pure r where checkQueue rq@RcvQueue {rcvId, server} | server == srv = do diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 9ace32db7..bed4138a2 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -40,9 +40,9 @@ getSrvQueues srv (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq@RcvQueue {server} = if srv == server then rq : qs' else qs' -getDelSrvQueues :: SMPServer -> TRcvQueues -> STM ([RcvQueue], Set ConnId) -getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty) +getDelSrvQueues :: SMPServer -> TRcvQueues -> STM [RcvQueue] +getDelSrvQueues srv (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ ([], M.empty) where - addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId, server, rcvId} - | srv == server = ((rq : remQs, S.insert connId remConns), qs') + addQ (removed, qs') rq@RcvQueue {server, rcvId} + | srv == server = (rq : removed, qs') | otherwise = (removed, M.insert (server, rcvId) rq qs')