From 3f69636f1a5f48292b6155195eb6840789364f86 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Tue, 7 Jun 2022 11:52:32 +0100 Subject: [PATCH] fix sockets/threads/memory leak (#388) * fix sockets/threads/memory leak * refactor --- src/Simplex/Messaging/Transport/Server.hs | 40 ++++++++++------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index c2e12aff0..3c1604a2f 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -12,20 +12,22 @@ module Simplex.Messaging.Transport.Server ) where +import Control.Concurrent.STM (stateTVar) import Control.Monad.Except import Control.Monad.IO.Unlift import qualified Crypto.Store.X509 as SX import Data.Default (def) -import Data.Set (Set) -import qualified Data.Set as S import qualified Data.X509 as X import Data.X509.Validation (Fingerprint (..)) import qualified Data.X509.Validation as XV import Network.Socket import qualified Network.TLS as T +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Util (catchAll_) import System.Exit (exitFailure) +import System.Mem.Weak (Weak, deRefWeak) import UnliftIO.Concurrent import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -36,37 +38,29 @@ import UnliftIO.STM runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> (c -> m ()) -> m () runTransportServer started port serverParams server = do u <- askUnliftIO - liftIO $ do - clients <- newTVarIO S.empty + liftIO . runTCPServer started port $ \conn -> E.bracket - (startTCPServer started port) - (closeServer started clients) - $ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do - -- catchAll_ is needed here in case the connection was closed earlier - tid <- forkFinally (connectClient u conn) (const . liftIO $ gracefulClose conn 5000 `catchAll_` pure ()) - atomically . modifyTVar' clients $ S.insert tid - where - connectClient :: UnliftIO m -> Socket -> IO () - connectClient u conn = - E.bracket - (connectTLS serverParams conn >>= getServerConnection) - closeConnection - (unliftIO u . server) + (connectTLS serverParams conn >>= getServerConnection) + closeConnection + (unliftIO u . server) --- | Run TCP server without TLS - only used in SimpleX Chat +-- | Run TCP server without TLS runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () runTCPServer started port server = do - clients <- newTVarIO S.empty + clients <- atomically TM.empty + clientId <- newTVarIO 0 E.bracket (startTCPServer started port) (closeServer started clients) $ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do - tid <- forkFinally (server conn) (const $ gracefulClose conn 5000) - atomically . modifyTVar' clients $ S.insert tid + -- catchAll_ is needed here in case the connection was closed earlier + cId <- atomically $ stateTVar clientId $ \cId -> (cId + 1, cId + 1) + tId <- mkWeakThreadId =<< forkFinally (server conn) (const $ gracefulClose conn 5000 `catchAll_` atomically (TM.delete cId clients)) + atomically $ TM.insert cId tId clients -closeServer :: TMVar Bool -> TVar (Set ThreadId) -> Socket -> IO () +closeServer :: TMVar Bool -> TMap Int (Weak ThreadId) -> Socket -> IO () closeServer started clients sock = do - readTVarIO clients >>= mapM_ killThread + readTVarIO clients >>= mapM_ (deRefWeak >=> mapM_ killThread) close sock void . atomically $ tryPutTMVar started False