From 7cb289e88aacb7fac96d7a413adba63488eb24b3 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Thu, 15 Oct 2020 07:08:21 +0100 Subject: [PATCH] refactor: TCP transport --- src/Server.hs | 12 ++--------- src/Transport.hs | 52 +++++++++++++++++++++++++++++++++------------- tests/SMPClient.hs | 24 +++------------------ 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/src/Server.hs b/src/Server.hs index 3e9a4ebb0..ad65b6945 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -21,25 +21,17 @@ import Numeric.Natural import Transmission import Transport import UnliftIO.Async -import UnliftIO.Concurrent -import qualified UnliftIO.Exception as E import UnliftIO.IO import UnliftIO.STM runSMPServer :: MonadUnliftIO m => ServiceName -> Natural -> m () runSMPServer port queueSize = do env <- atomically $ newEnv port queueSize - runReaderT (runTCPServer runClient) env - -runTCPServer :: (MonadUnliftIO m, MonadReader Env m) => (Handle -> m ()) -> m () -runTCPServer server = - E.bracket startTCPServer (liftIO . close) $ \sock -> forever $ do - h <- acceptTCPConn sock - putLn h "Welcome to SMP" - forkFinally (server h) (const $ hClose h) + runReaderT (runTCPServer port runClient) env runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do + putLn h "Welcome to SMP" q <- asks queueSize c <- atomically $ newClient q void $ race (client h c) (receive h c) diff --git a/src/Transport.hs b/src/Transport.hs index 6d6ab090c..33931a72f 100644 --- a/src/Transport.hs +++ b/src/Transport.hs @@ -8,26 +8,36 @@ module Transport where import Control.Monad.IO.Class +import Control.Monad.IO.Unlift import Control.Monad.Reader import qualified Data.ByteString.Char8 as B -import Env.STM import Network.Socket import System.IO import Text.Read import Transmission +import UnliftIO.Concurrent +import qualified UnliftIO.Exception as E +import qualified UnliftIO.IO as IO -startTCPServer :: (MonadReader Env m, MonadIO m) => m Socket -startTCPServer = do - port <- asks tcpPort - liftIO . withSocketsDo $ do - let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} - addr <- head <$> getAddrInfo (Just hints) Nothing (Just port) - sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) - setSocketOption sock ReuseAddr 1 - withFdSocket sock setCloseOnExecIfNeeded - bind sock $ addrAddress addr - listen sock 1024 - return sock +startTCPServer :: MonadIO m => ServiceName -> m Socket +startTCPServer port = liftIO . withSocketsDo $ resolve >>= open + where + resolve = do + let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} + head <$> getAddrInfo (Just hints) Nothing (Just port) + open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + setSocketOption sock ReuseAddr 1 + withFdSocket sock setCloseOnExecIfNeeded + bind sock $ addrAddress addr + listen sock 1024 + return sock + +runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m () +runTCPServer port server = + E.bracket (startTCPServer port) (liftIO . close) $ \sock -> forever $ do + h <- acceptTCPConn sock + forkFinally (server h) (const $ IO.hClose h) acceptTCPConn :: MonadIO m => Socket -> m Handle acceptTCPConn sock = liftIO $ do @@ -35,6 +45,20 @@ acceptTCPConn sock = liftIO $ do -- putStrLn $ "Accepted connection from " ++ show peer getSocketHandle conn +startTCPClient :: MonadIO m => HostName -> ServiceName -> m Handle +startTCPClient host port = liftIO . withSocketsDo $ resolve >>= open + where + resolve = do + let hints = defaultHints {addrSocketType = Stream} + head <$> getAddrInfo (Just hints) (Just host) (Just port) + open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + connect sock $ addrAddress addr + getSocketHandle sock + +runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a +runTCPClient host port = E.bracket (startTCPClient host port) IO.hClose + getSocketHandle :: MonadIO m => Socket -> m Handle getSocketHandle conn = liftIO $ do h <- socketToHandle conn ReadWriteMode @@ -104,7 +128,7 @@ tGet fromParty h = do | null connId -> Left $ SYNTAX errNoConnectionId | otherwise -> Right cmd -- other client commands must have both signature and connection ID - _ + Cmd SRecipient _ | null signature || null connId -> Left $ SYNTAX errNoCredentials | otherwise -> Right cmd diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index f41349eb1..726be399f 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -2,7 +2,6 @@ module SMPClient where -import Control.Monad import Control.Monad.IO.Unlift import Network.Socket import Numeric.Natural @@ -13,26 +12,9 @@ import UnliftIO.Concurrent import qualified UnliftIO.Exception as E import UnliftIO.IO -startTCPClient :: MonadIO m => HostName -> ServiceName -> m Handle -startTCPClient host port = liftIO . withSocketsDo $ do +testSMPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a +testSMPClient host port client = do threadDelay 1 -- TODO hack: thread delay for SMP server to start - addr <- resolve - open addr - where - resolve = do - let hints = defaultHints {addrSocketType = Stream} - head <$> getAddrInfo (Just hints) (Just host) (Just port) - open addr = do - sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) - connect sock $ addrAddress addr - getSocketHandle sock - -runTCPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a -runTCPClient host port client = do - E.bracket (startTCPClient host port) (liftIO . hClose) client - -runSMPClient :: MonadUnliftIO m => HostName -> ServiceName -> (Handle -> m a) -> m a -runSMPClient host port client = runTCPClient host port $ \h -> do line <- getLn h if line == "Welcome to SMP" @@ -55,7 +37,7 @@ runSmpTest test = E.bracket (forkIO $ runSMPServer testPort queueSize) (liftIO . killThread) - \_ -> runSMPClient "localhost" testPort test + \_ -> testSMPClient "localhost" testPort test smpServerTest :: [TestTransmission] -> IO [TestTransmission] smpServerTest commands = runSmpTest \h -> mapM (sendReceive h) commands