diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 7b94738b0..34dbcf34a 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -70,7 +70,7 @@ runXFTPServerBlocking :: TMVar Bool -> XFTPServerConfig -> IO () runXFTPServerBlocking started cfg = newXFTPServerEnv cfg >>= runReaderT (xftpServer cfg started) xftpServer :: XFTPServerConfig -> TMVar Bool -> M () -xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do +xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig} started = do restoreServerStats raceAny_ (runServer : expireFilesThread_ cfg <> serverStatsThread_ cfg) `finally` stopServer where @@ -79,7 +79,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do serverParams <- asks tlsServerParams env <- ask liftIO $ - runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams logTLSErrors $ \sessionId r sendResponse -> do + runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig $ \sessionId r sendResponse -> do reqBody <- getHTTP2Body r xftpBlockSize processRequest HTTP2Request {sessionId, request = r, reqBody, sendResponse} `runReaderT` env diff --git a/src/Simplex/FileTransfer/Server/Env.hs b/src/Simplex/FileTransfer/Server/Env.hs index 5ca6ddcc7..584594e96 100644 --- a/src/Simplex/FileTransfer/Server/Env.hs +++ b/src/Simplex/FileTransfer/Server/Env.hs @@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) +import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) import Simplex.Messaging.Util (tshow) import System.IO (IOMode (..)) import UnliftIO.STM @@ -55,7 +55,7 @@ data XFTPServerConfig = XFTPServerConfig logStatsStartTime :: Int64, serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, - logTLSErrors :: Bool + transportConfig :: TransportServerConfig } data XFTPEnv = XFTPEnv diff --git a/src/Simplex/FileTransfer/Server/Main.hs b/src/Simplex/FileTransfer/Server/Main.hs index 0a5d0290e..9f496e330 100644 --- a/src/Simplex/FileTransfer/Server/Main.hs +++ b/src/Simplex/FileTransfer/Server/Main.hs @@ -25,6 +25,7 @@ import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer) import Simplex.Messaging.Server.CLI import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Transport.Client (TransportHost (..)) +import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig) import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) @@ -151,7 +152,10 @@ xftpServerCLI cfgPath logPath = do logStatsStartTime = 0, -- seconds from 00:00 UTC serverStatsLogFile = combine logPath "file-server-stats.daily.log", serverStatsBackupFile = logStats $> combine logPath "file-server-stats.log", - logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + transportConfig = + defaultTransportServerConfig + { logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + } } data CliCommand diff --git a/src/Simplex/Messaging/Agent/Server.hs b/src/Simplex/Messaging/Agent/Server.hs index b51903c43..ef1bf5edc 100644 --- a/src/Simplex/Messaging/Agent/Server.hs +++ b/src/Simplex/Messaging/Agent/Server.hs @@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion) -import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer) +import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig) import Simplex.Messaging.Util (bshow) import UnliftIO.Async (race_) import qualified UnliftIO.Exception as E @@ -48,7 +48,7 @@ runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile, smpAgent _ = do -- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile - runTransportServer started tcpPort tlsServerParams True $ \(h :: c) -> do + runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion c <- getAgentClient initServers logConnection c True diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 8c6c6bc03..7a1abd20f 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -72,7 +72,7 @@ runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtf type M a = ReaderT NtfEnv IO a ntfServer :: NtfServerConfig -> TMVar Bool -> M () -ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do +ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do restoreServerStats s <- asks subscriber ps <- asks pushServer @@ -83,7 +83,7 @@ ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do runServer :: (ServiceName, ATransport) -> M () runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams - runTransportServer started tcpPort serverParams logTLSErrors (runClient t) + runTransportServer started tcpPort serverParams tCfg (runClient t) runClient :: Transport c => TProxy c -> c -> M () runClient _ h = do diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index f9f3926c5..fa3f40cab 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -32,7 +32,7 @@ import Simplex.Messaging.Server.Expiration import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport) -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) +import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -57,7 +57,7 @@ data NtfServerConfig = NtfServerConfig logStatsStartTime :: Int64, serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, - logTLSErrors :: Bool + transportConfig :: TransportServerConfig } defaultInactiveClientExpiration :: ExpirationConfig diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index efe6aeb2a..cd456fc55 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -23,6 +23,7 @@ import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientCo import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer) import Simplex.Messaging.Server.CLI import Simplex.Messaging.Transport.Client (TransportHost (..)) +import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig) import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) @@ -123,7 +124,10 @@ ntfServerCLI cfgPath logPath = logStatsStartTime = 0, -- seconds from 00:00 UTC serverStatsLogFile = combine logPath "ntf-server-stats.daily.log", serverStatsBackupFile = logStats $> combine logPath "ntf-server-stats.log", - logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + transportConfig = + defaultTransportServerConfig + { logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + } } data CliCommand diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index fa30520ff..dcd02ca5e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -103,7 +103,7 @@ runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started type M a = ReaderT Env IO a smpServer :: TMVar Bool -> ServerConfig -> M () -smpServer started cfg@ServerConfig {transports, logTLSErrors} = do +smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do s <- asks server restoreServerMessages restoreServerStats @@ -117,7 +117,7 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do runServer :: (ServiceName, ATransport) -> M () runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams - runTransportServer started tcpPort serverParams logTLSErrors (runClient t) + runTransportServer started tcpPort serverParams tCfg (runClient t) serverThread :: forall s. diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 27291540e..609438848 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -30,7 +30,7 @@ import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport) -import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) +import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig) import Simplex.Messaging.Version import System.IO (IOMode (..)) import System.Mem.Weak (Weak) @@ -69,7 +69,8 @@ data ServerConfig = ServerConfig certificateFile :: FilePath, -- | SMP client-server protocol version range smpServerVRange :: VersionRange, - logTLSErrors :: Bool + -- | TCP transport config + transportConfig :: TransportServerConfig } defMsgExpirationDays :: Int64 diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 004f72971..df15cebda 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -28,6 +28,7 @@ import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClien import Simplex.Messaging.Server.Expiration import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange) import Simplex.Messaging.Transport.Client (TransportHost (..)) +import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig) import Simplex.Messaging.Util (safeDecodeUtf8) import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (combine) @@ -198,7 +199,10 @@ smpServerCLI cfgPath logPath = serverStatsLogFile = combine logPath "smp-server-stats.daily.log", serverStatsBackupFile = logStats $> combine logPath "smp-server-stats.log", smpServerVRange = supportedSMPServerVRange, - logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + transportConfig = + defaultTransportServerConfig + { logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini + } } data CliCommand diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 3344a8197..3c44e0989 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -29,6 +29,7 @@ module Simplex.Messaging.Transport supportedSMPServerVRange, simplexMQVersion, smpBlockSize, + TransportConfig (..), -- * Transport connection class Transport (..), @@ -104,6 +105,11 @@ simplexMQVersion = showVersion SMQ.version -- * Transport connection class +data TransportConfig = TransportConfig + { logTLSErrors :: Bool, + transportTimeout :: Maybe Int + } + class Transport c where transport :: ATransport transport = ATransport (TProxy @c) @@ -112,11 +118,13 @@ class Transport c where transportPeer :: c -> TransportPeer + transportConfig :: c -> TransportConfig + -- | Upgrade server TLS context to connection (used in the server) - getServerConnection :: T.Context -> IO c + getServerConnection :: TransportConfig -> T.Context -> IO c -- | Upgrade client TLS context to connection (used in the client) - getClientConnection :: T.Context -> IO c + getClientConnection :: TransportConfig -> T.Context -> IO c -- | tls-unique channel binding per RFC5929 tlsUnique :: c -> SessionId @@ -150,24 +158,25 @@ data TLS = TLS { tlsContext :: T.Context, tlsPeer :: TransportPeer, tlsUniq :: ByteString, - tlsBuffer :: TBuffer + tlsBuffer :: TBuffer, + tlsTransportConfig :: TransportConfig } -connectTLS :: T.TLSParams p => Maybe HostName -> Bool -> p -> Socket -> IO T.Context -connectTLS host_ logErrors params sock = +connectTLS :: T.TLSParams p => Maybe HostName -> TransportConfig -> p -> Socket -> IO T.Context +connectTLS host_ TransportConfig {logTLSErrors} params sock = E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx -> logHandshakeErrors (T.handshake ctx) $> ctx where - logHandshakeErrors = if logErrors then (`catchAll` logThrow) else id + logHandshakeErrors = if logTLSErrors then (`catchAll` logThrow) else id logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e host = maybe "" (\h -> " (" <> h <> ")") host_ -getTLS :: TransportPeer -> T.Context -> IO TLS -getTLS tlsPeer cxt = withTlsUnique tlsPeer cxt newTLS +getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS +getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS where newTLS tlsUniq = do tlsBuffer <- atomically newTBuffer - pure TLS {tlsContext = cxt, tlsPeer, tlsUniq, tlsBuffer} + pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer} withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c withTlsUnique peer cxt f = @@ -199,6 +208,7 @@ supportedParameters = instance Transport TLS where transportName _ = "TLS" transportPeer = tlsPeer + transportConfig = tlsTransportConfig getServerConnection = getTLS TServer getClientConnection = getTLS TClient tlsUnique = tlsUniq @@ -207,10 +217,12 @@ instance Transport TLS where -- https://hackage.haskell.org/package/tls-1.6.0/docs/Network-TLS.html#v:recvData -- this function may return less than requested number of bytes cGet :: TLS -> Int -> IO ByteString - cGet TLS {tlsContext, tlsBuffer} n = getBuffered tlsBuffer n (T.recvData tlsContext) - + cGet TLS {tlsContext, tlsBuffer, tlsTransportConfig = TransportConfig {transportTimeout = t_}} n = + getBuffered tlsBuffer n t_ (T.recvData tlsContext) + cPut :: TLS -> ByteString -> IO () - cPut tls = T.sendData (tlsContext tls) . BL.fromStrict + cPut TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s = + withTimedErr t_ . T.sendData tlsContext $ BL.fromStrict s getLn :: TLS -> IO ByteString getLn TLS {tlsContext, tlsBuffer} = do diff --git a/src/Simplex/Messaging/Transport/Buffer.hs b/src/Simplex/Messaging/Transport/Buffer.hs index 7d09d6819..141690386 100644 --- a/src/Simplex/Messaging/Transport/Buffer.hs +++ b/src/Simplex/Messaging/Transport/Buffer.hs @@ -8,6 +8,8 @@ import Control.Concurrent.STM import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import System.Timeout (timeout) +import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..)) data TBuffer = TBuffer { buffer :: TVar ByteString, @@ -26,22 +28,33 @@ withBufferLock TBuffer {getLock} = (atomically $ takeTMVar getLock) (atomically $ putTMVar getLock ()) -getBuffered :: TBuffer -> Int -> IO ByteString -> IO ByteString -getBuffered tb@TBuffer {buffer} n getChunk = withBufferLock tb $ do - b <- readChunks =<< readTVarIO buffer +getBuffered :: TBuffer -> Int -> Maybe Int -> IO ByteString -> IO ByteString +getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do + b <- readChunks True =<< readTVarIO buffer let (s, b') = B.splitAt n b atomically $ writeTVar buffer $! b' -- This would prevent the need to pad auth tag in HTTP2 -- threadDelay 150 pure s where - readChunks :: ByteString -> IO ByteString - readChunks b + readChunks :: Bool -> ByteString -> IO ByteString + readChunks firstChunk b | B.length b >= n = pure b | otherwise = - getChunk >>= \case + get >>= \case "" -> pure b - s -> readChunks $ b <> s + s -> readChunks False $ b <> s + where + get + | firstChunk = getChunk + | otherwise = withTimedErr t_ getChunk + +withTimedErr :: Maybe Int -> IO a -> IO a +withTimedErr t_ a = case t_ of + Just t -> timeout t a >>= maybe err pure + Nothing -> a + where + err = ioException (IOError Nothing TimeExpired "" "get timeout" Nothing Nothing) -- This function is only used in test and needs to be improved before it can be used in production, -- it will never complete if TLS connection is closed before there is newline. diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index 464d3230e..1f7fd44a8 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -114,12 +115,16 @@ data TransportClientConfig = TransportClientConfig defaultTransportClientConfig :: TransportClientConfig defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True +clientTransportConfig :: TransportClientConfig -> TransportConfig +clientTransportConfig TransportClientConfig {logTLSErrors} = + TransportConfig {logTLSErrors, transportTimeout = Nothing} + -- | Connect to passed TCP host:port and pass handle to the client. runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a runTransportClient = runTLSTransportClient supportedParameters Nothing runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a -runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do +runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive} proxyUsername host port keyHash client = do let hostName = B.unpack $ strEncode host clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash connectTCP = case socksProxy of @@ -128,7 +133,8 @@ runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpK c <- liftIO $ do sock <- connectTCP port mapM_ (setSocketKeepAlive sock) tcpKeepAlive - connectTLS (Just hostName) logTLSErrors clientParams sock >>= getClientConnection + let tCfg = clientTransportConfig cfg + connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg client c `E.finally` liftIO (closeConnection c) where hostAddr = \case diff --git a/src/Simplex/Messaging/Transport/HTTP2.hs b/src/Simplex/Messaging/Transport/HTTP2.hs index f258f9dc9..1feccce88 100644 --- a/src/Simplex/Messaging/Transport/HTTP2.hs +++ b/src/Simplex/Messaging/Transport/HTTP2.hs @@ -73,7 +73,7 @@ instance HTTP2BodyChunk HS.Request where getHTTP2Body :: HTTP2BodyChunk a => a -> Int -> IO HTTP2Body getHTTP2Body r n = do bodyBuffer <- atomically newTBuffer - let getPart n' = getBuffered bodyBuffer n' $ getBodyChunk r + let getPart n' = getBuffered bodyBuffer n' Nothing $ getBodyChunk r bodyHead <- getPart n let bodySize = fromMaybe 0 $ getBodySize r -- TODO check bodySize once it is set diff --git a/src/Simplex/Messaging/Transport/HTTP2/Server.hs b/src/Simplex/Messaging/Transport/HTTP2/Server.hs index 9a305ed8a..650026ef4 100644 --- a/src/Simplex/Messaging/Transport/HTTP2/Server.hs +++ b/src/Simplex/Messaging/Transport/HTTP2/Server.hs @@ -14,7 +14,7 @@ import qualified Network.TLS as T import Numeric.Natural (Natural) import Simplex.Messaging.Transport (SessionId) import Simplex.Messaging.Transport.HTTP2 -import Simplex.Messaging.Transport.Server (loadSupportedTLSServerParams, runTransportServer) +import Simplex.Messaging.Transport.Server (TransportServerConfig (..), loadSupportedTLSServerParams, runTransportServer) type HTTP2ServerFunc = SessionId -> Request -> (Response -> IO ()) -> IO () @@ -27,7 +27,7 @@ data HTTP2ServerConfig = HTTP2ServerConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - logTLSErrors :: Bool + transportConfig :: TransportServerConfig } deriving (Show) @@ -45,12 +45,12 @@ data HTTP2Server = HTTP2Server -- This server is for testing only, it processes all requests in a single queue. getHTTP2Server :: HTTP2ServerConfig -> IO HTTP2Server -getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, logTLSErrors} = do +getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do tlsServerParams <- loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile started <- newEmptyTMVarIO reqQ <- newTBQueueIO qSize action <- async $ - runHTTP2Server started http2Port bufferSize tlsServerParams logTLSErrors $ \sessionId r sendResponse -> do + runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig $ \sessionId r sendResponse -> do reqBody <- getHTTP2Body r bodyHeadSize atomically $ writeTBQueue reqQ HTTP2Request {sessionId, request = r, reqBody, sendResponse} void . atomically $ takeTMVar started @@ -59,8 +59,8 @@ getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, se closeHTTP2Server :: HTTP2Server -> IO () closeHTTP2Server = uninterruptibleCancel . action -runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> Bool -> HTTP2ServerFunc -> IO () -runHTTP2Server started port bufferSize serverParams logTLSErrors http2Server = - runTransportServer started port serverParams logTLSErrors $ withHTTP2 bufferSize run +runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> HTTP2ServerFunc -> IO () +runHTTP2Server started port bufferSize serverParams transportConfig http2Server = + runTransportServer started port serverParams transportConfig $ withHTTP2 bufferSize run where run cfg sessId = H.run cfg $ \req _aux sendResp -> http2Server sessId req (`sendResp` []) diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index 536ebb0db..02d163079 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -7,6 +7,8 @@ module Simplex.Messaging.Transport.Server ( runTransportServer, runTCPServer, + TransportServerConfig (..), + defaultTransportServerConfig, loadSupportedTLSServerParams, loadTLSServerParams, loadFingerprint, @@ -38,15 +40,32 @@ import UnliftIO.Concurrent import qualified UnliftIO.Exception as E import UnliftIO.STM +data TransportServerConfig = TransportServerConfig + { logTLSErrors :: Bool, + transportTimeout :: Int + } + deriving (Eq, Show) + +defaultTransportServerConfig :: TransportServerConfig +defaultTransportServerConfig = TransportServerConfig + { logTLSErrors = True, + transportTimeout = 40000000 + } + +serverTransportConfig :: TransportServerConfig -> TransportConfig +serverTransportConfig TransportServerConfig {logTLSErrors, transportTimeout} = + TransportConfig {logTLSErrors, transportTimeout = Just transportTimeout} + -- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar. -- -- All accepted connections are passed to the passed function. -runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> Bool -> (c -> m ()) -> m () -runTransportServer started port serverParams logTLSErrors server = do +runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m () +runTransportServer started port serverParams cfg server = do u <- askUnliftIO + let tCfg = serverTransportConfig cfg liftIO . runTCPServer started port $ \conn -> E.bracket - (connectTLS Nothing logTLSErrors serverParams conn >>= getServerConnection) + (connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg) closeConnection (unliftIO u . server) diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index c6b1a2610..a0633e09e 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -17,6 +17,7 @@ import Simplex.Messaging.Transport Transport (..), TransportError (..), TransportPeer (..), + TransportConfig (..), closeTLS, smpBlockSize, withTlsUnique, @@ -27,7 +28,8 @@ data WS = WS { wsPeer :: TransportPeer, tlsUniq :: ByteString, wsStream :: Stream, - wsConnection :: Connection + wsConnection :: Connection, + wsTransportConfig :: TransportConfig } websocketsOpts :: ConnectionOptions @@ -45,10 +47,13 @@ instance Transport WS where transportPeer :: WS -> TransportPeer transportPeer = wsPeer - getServerConnection :: T.Context -> IO WS + transportConfig :: WS -> TransportConfig + transportConfig = wsTransportConfig + + getServerConnection :: TransportConfig -> T.Context -> IO WS getServerConnection = getWS TServer - getClientConnection :: T.Context -> IO WS + getClientConnection :: TransportConfig -> T.Context -> IO WS getClientConnection = getWS TClient tlsUnique :: WS -> ByteString @@ -74,13 +79,13 @@ instance Transport WS where then E.throwIO TEBadBlock else pure $ B.init s -getWS :: TransportPeer -> T.Context -> IO WS -getWS wsPeer cxt = withTlsUnique wsPeer cxt connectWS +getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS +getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS where connectWS tlsUniq = do s <- makeTLSContextStream cxt wsConnection <- connectPeer wsPeer s - pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection} + pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg} connectPeer :: TransportPeer -> Stream -> IO Connection connectPeer TServer = acceptClientRequest connectPeer TClient = sendClientRequest diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 3984ff881..184d184d6 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -44,6 +44,7 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..), http2TLSParams) import Simplex.Messaging.Transport.HTTP2.Server +import Simplex.Messaging.Transport.Server import Test.Hspec import UnliftIO.Async import UnliftIO.Concurrent @@ -99,7 +100,7 @@ ntfServerCfg = logStatsStartTime = 0, serverStatsLogFile = "tests/ntf-server-stats.daily.log", serverStatsBackupFile = Nothing, - logTLSErrors = True + transportConfig = defaultTransportServerConfig } withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a @@ -166,7 +167,7 @@ apnsMockServerConfig = caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt", - logTLSErrors = True + transportConfig = defaultTransportServerConfig } withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO () diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 3e60b9061..c49b65ee6 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -24,6 +24,7 @@ import Simplex.Messaging.Server (runSMPServerBlocking) import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client +import Simplex.Messaging.Transport.Server import Simplex.Messaging.Version import System.Environment (lookupEnv) import System.Info (os) @@ -99,7 +100,7 @@ cfg = privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt", smpServerVRange = supportedSMPServerVRange, - logTLSErrors = True + transportConfig = defaultTransportServerConfig } withSmpServerStoreMsgLogOnV2 :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a diff --git a/tests/XFTPClient.hs b/tests/XFTPClient.hs index 4bd91997a..658344aed 100644 --- a/tests/XFTPClient.hs +++ b/tests/XFTPClient.hs @@ -14,6 +14,7 @@ import Simplex.FileTransfer.Description import Simplex.FileTransfer.Server (runXFTPServerBlocking) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration) import Simplex.Messaging.Protocol (XFTPServer) +import Simplex.Messaging.Transport.Server import Test.Hspec xftpTest :: HasCallStack => (HasCallStack => XFTPClient -> IO ()) -> Expectation @@ -111,7 +112,7 @@ testXFTPServerConfig = logStatsStartTime = 0, serverStatsLogFile = "tests/tmp/xftp-server-stats.daily.log", serverStatsBackupFile = Nothing, - logTLSErrors = True + transportConfig = defaultTransportServerConfig } testXFTPClientConfig :: XFTPClientConfig