From 7249cb0f0e63ae4e38f16fea55e937e84a5971d4 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 29 Apr 2022 13:12:30 +0100 Subject: [PATCH] close socket on connection exception (#365) --- src/Simplex/Messaging/Agent/Client.hs | 4 ++-- src/Simplex/Messaging/Client/Agent.hs | 4 ++-- src/Simplex/Messaging/Transport.hs | 9 +++++---- src/Simplex/Messaging/Transport/Server.hs | 5 +++-- src/Simplex/Messaging/Util.hs | 10 ++++++++++ 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 6e5112cef..21128b2ce 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -72,7 +72,7 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, ProtocolServer (..), Qu import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, ifM, liftEitherError, liftError, tryError, unlessM, whenM) +import Simplex.Messaging.Util (bshow, catchAll_, ifM, liftEitherError, liftError, tryError, unlessM, whenM) import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, forConcurrently) @@ -312,7 +312,7 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli where closeClient smpVar = atomically (readTMVar smpVar) >>= \case - Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure () + Right smp -> closeProtocolClient smp `catchAll_` pure () _ -> pure () cancelActions :: Foldable f => TVar (f (Async ())) -> IO () diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 1f9d215cd..dd3b0e90b 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -28,7 +28,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (tryE, whenM, ($>>=)) +import Simplex.Messaging.Util (catchAll_, tryE, whenM, ($>>=)) import System.Timeout (timeout) import UnliftIO (async, forConcurrently_) import UnliftIO.Exception (Exception) @@ -232,7 +232,7 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli where closeClient smpVar = atomically (readTMVar smpVar) >>= \case - Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure () + Right smp -> closeProtocolClient smp `catchAll_` pure () _ -> pure () cancelActions :: Foldable f => TVar (f (Async ())) -> IO () diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 3d0071733..f3f553b79 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -80,7 +80,7 @@ import qualified Network.TLS.Extra as TE import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON) -import Simplex.Messaging.Util (bshow) +import Simplex.Messaging.Util (bshow, catchAll, catchAll_) import Simplex.Messaging.Version import Test.QuickCheck (Arbitrary (..)) import UnliftIO.Exception (Exception) @@ -154,7 +154,7 @@ connectTLS :: T.TLSParams p => p -> Socket -> IO T.Context connectTLS params sock = E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx -> do T.handshake ctx - `E.catch` \(e :: E.SomeException) -> putStrLn ("exception: " <> show e) >> E.throwIO e + `catchAll` \e -> putStrLn ("exception: " <> show e) >> E.throwIO e pure ctx getTLS :: TransportPeer -> T.Context -> IO TLS @@ -175,8 +175,9 @@ withTlsUnique peer cxt f = closeTLS :: T.Context -> IO () closeTLS ctx = - (T.bye ctx >> T.contextClose ctx) -- sometimes socket was closed before 'TLS.bye' - `E.catch` (\(_ :: E.SomeException) -> pure ()) -- so we catch the 'Broken pipe' error here + T.bye ctx -- sometimes socket was closed before 'TLS.bye' so we catch the 'Broken pipe' error here + `catchAll_` T.contextClose ctx + `catchAll_` pure () supportedParameters :: T.Supported supportedParameters = diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index d4009b4f6..55393d278 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -24,6 +24,7 @@ import qualified Data.X509.Validation as XV import Network.Socket import qualified Network.TLS as T import Simplex.Messaging.Transport +import Simplex.Messaging.Util (catchAll_) import System.Exit (exitFailure) import UnliftIO.Concurrent import qualified UnliftIO.Exception as E @@ -42,7 +43,7 @@ runTransportServer started port serverParams server = do (closeServer started clients) $ \sock -> forever $ do (connSock, _) <- accept sock - tid <- forkIO $ connectClient u connSock `E.catch` \(_ :: E.SomeException) -> pure () + tid <- forkIO $ connectClient u connSock `catchAll_` close connSock `catchAll_` pure () atomically . modifyTVar' clients $ S.insert tid where connectClient :: UnliftIO m -> Socket -> IO () @@ -60,7 +61,7 @@ runTCPServer started port server = do (closeServer started clients) $ \sock -> forever $ do (connSock, _) <- accept sock - tid <- forkIO $ server connSock `E.catch` \(_ :: E.SomeException) -> pure () + tid <- forkIO $ server connSock `catchAll_` pure () atomically . modifyTVar' clients $ S.insert tid closeServer :: TMVar Bool -> TVar (Set ThreadId) -> Socket -> IO () diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 0d741bb43..24291773b 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -1,7 +1,9 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Util where +import qualified Control.Exception as E import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except @@ -68,3 +70,11 @@ unlessM b = ifM b $ pure () ($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b) f $>>= g = f >>= fmap join . mapM g + +catchAll :: IO a -> (E.SomeException -> IO a) -> IO a +catchAll = E.catch +{-# INLINE catchAll #-} + +catchAll_ :: IO a -> IO a -> IO a +catchAll_ a = catchAll a . const +{-# INLINE catchAll_ #-}