servers: add TCP timeouts to avoid memory leaks (#776)

* servers: add TCP timeouts to avoid memory leaks

* fix tests

* only use RecvTimeOut

* servers: simple timeout for TCP transport

* revert dependency change

* simplify

* simplify

* simplify 2
This commit is contained in:
Evgeny Poberezkin
2023-06-30 16:22:01 +01:00
committed by GitHub
parent 16367fcb3b
commit 94540a2c71
20 changed files with 130 additions and 59 deletions

View File

@@ -70,7 +70,7 @@ runXFTPServerBlocking :: TMVar Bool -> XFTPServerConfig -> IO ()
runXFTPServerBlocking started cfg = newXFTPServerEnv cfg >>= runReaderT (xftpServer cfg started)
xftpServer :: XFTPServerConfig -> TMVar Bool -> M ()
xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig} started = do
restoreServerStats
raceAny_ (runServer : expireFilesThread_ cfg <> serverStatsThread_ cfg) `finally` stopServer
where
@@ -79,7 +79,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do
serverParams <- asks tlsServerParams
env <- ask
liftIO $
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams logTLSErrors $ \sessionId r sendResponse -> do
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig $ \sessionId r sendResponse -> do
reqBody <- getHTTP2Body r xftpBlockSize
processRequest HTTP2Request {sessionId, request = r, reqBody, sendResponse} `runReaderT` env

View File

@@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey)
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import Simplex.Messaging.Util (tshow)
import System.IO (IOMode (..))
import UnliftIO.STM
@@ -55,7 +55,7 @@ data XFTPServerConfig = XFTPServerConfig
logStatsStartTime :: Int64,
serverStatsLogFile :: FilePath,
serverStatsBackupFile :: Maybe FilePath,
logTLSErrors :: Bool
transportConfig :: TransportServerConfig
}
data XFTPEnv = XFTPEnv

View File

@@ -25,6 +25,7 @@ import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer)
import Simplex.Messaging.Server.CLI
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport.Client (TransportHost (..))
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.FilePath (combine)
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
@@ -151,7 +152,10 @@ xftpServerCLI cfgPath logPath = do
logStatsStartTime = 0, -- seconds from 00:00 UTC
serverStatsLogFile = combine logPath "file-server-stats.daily.log",
serverStatsBackupFile = logStats $> combine logPath "file-server-stats.log",
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
transportConfig =
defaultTransportServerConfig
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
}
}
data CliCommand

View File

@@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion)
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer)
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig)
import Simplex.Messaging.Util (bshow)
import UnliftIO.Async (race_)
import qualified UnliftIO.Exception as E
@@ -48,7 +48,7 @@ runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile,
smpAgent _ = do
-- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
runTransportServer started tcpPort tlsServerParams True $ \(h :: c) -> do
runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
c <- getAgentClient initServers
logConnection c True

View File

@@ -72,7 +72,7 @@ runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtf
type M a = ReaderT NtfEnv IO a
ntfServer :: NtfServerConfig -> TMVar Bool -> M ()
ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do
ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
restoreServerStats
s <- asks subscriber
ps <- asks pushServer
@@ -83,7 +83,7 @@ ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
runTransportServer started tcpPort serverParams logTLSErrors (runClient t)
runTransportServer started tcpPort serverParams tCfg (runClient t)
runClient :: Transport c => TProxy c -> c -> M ()
runClient _ h = do

View File

@@ -32,7 +32,7 @@ import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -57,7 +57,7 @@ data NtfServerConfig = NtfServerConfig
logStatsStartTime :: Int64,
serverStatsLogFile :: FilePath,
serverStatsBackupFile :: Maybe FilePath,
logTLSErrors :: Bool
transportConfig :: TransportServerConfig
}
defaultInactiveClientExpiration :: ExpirationConfig

View File

@@ -23,6 +23,7 @@ import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientCo
import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer)
import Simplex.Messaging.Server.CLI
import Simplex.Messaging.Transport.Client (TransportHost (..))
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.FilePath (combine)
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
@@ -123,7 +124,10 @@ ntfServerCLI cfgPath logPath =
logStatsStartTime = 0, -- seconds from 00:00 UTC
serverStatsLogFile = combine logPath "ntf-server-stats.daily.log",
serverStatsBackupFile = logStats $> combine logPath "ntf-server-stats.log",
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
transportConfig =
defaultTransportServerConfig
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
}
}
data CliCommand

View File

@@ -103,7 +103,7 @@ runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started
type M a = ReaderT Env IO a
smpServer :: TMVar Bool -> ServerConfig -> M ()
smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
s <- asks server
restoreServerMessages
restoreServerStats
@@ -117,7 +117,7 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
runTransportServer started tcpPort serverParams logTLSErrors (runClient t)
runTransportServer started tcpPort serverParams tCfg (runClient t)
serverThread ::
forall s.

View File

@@ -30,7 +30,7 @@ import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
import Simplex.Messaging.Version
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
@@ -69,7 +69,8 @@ data ServerConfig = ServerConfig
certificateFile :: FilePath,
-- | SMP client-server protocol version range
smpServerVRange :: VersionRange,
logTLSErrors :: Bool
-- | TCP transport config
transportConfig :: TransportServerConfig
}
defMsgExpirationDays :: Int64

View File

@@ -28,6 +28,7 @@ import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClien
import Simplex.Messaging.Server.Expiration
import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange)
import Simplex.Messaging.Transport.Client (TransportHost (..))
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
import Simplex.Messaging.Util (safeDecodeUtf8)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.FilePath (combine)
@@ -198,7 +199,10 @@ smpServerCLI cfgPath logPath =
serverStatsLogFile = combine logPath "smp-server-stats.daily.log",
serverStatsBackupFile = logStats $> combine logPath "smp-server-stats.log",
smpServerVRange = supportedSMPServerVRange,
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
transportConfig =
defaultTransportServerConfig
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
}
}
data CliCommand

View File

@@ -29,6 +29,7 @@ module Simplex.Messaging.Transport
supportedSMPServerVRange,
simplexMQVersion,
smpBlockSize,
TransportConfig (..),
-- * Transport connection class
Transport (..),
@@ -104,6 +105,11 @@ simplexMQVersion = showVersion SMQ.version
-- * Transport connection class
data TransportConfig = TransportConfig
{ logTLSErrors :: Bool,
transportTimeout :: Maybe Int
}
class Transport c where
transport :: ATransport
transport = ATransport (TProxy @c)
@@ -112,11 +118,13 @@ class Transport c where
transportPeer :: c -> TransportPeer
transportConfig :: c -> TransportConfig
-- | Upgrade server TLS context to connection (used in the server)
getServerConnection :: T.Context -> IO c
getServerConnection :: TransportConfig -> T.Context -> IO c
-- | Upgrade client TLS context to connection (used in the client)
getClientConnection :: T.Context -> IO c
getClientConnection :: TransportConfig -> T.Context -> IO c
-- | tls-unique channel binding per RFC5929
tlsUnique :: c -> SessionId
@@ -150,24 +158,25 @@ data TLS = TLS
{ tlsContext :: T.Context,
tlsPeer :: TransportPeer,
tlsUniq :: ByteString,
tlsBuffer :: TBuffer
tlsBuffer :: TBuffer,
tlsTransportConfig :: TransportConfig
}
connectTLS :: T.TLSParams p => Maybe HostName -> Bool -> p -> Socket -> IO T.Context
connectTLS host_ logErrors params sock =
connectTLS :: T.TLSParams p => Maybe HostName -> TransportConfig -> p -> Socket -> IO T.Context
connectTLS host_ TransportConfig {logTLSErrors} params sock =
E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx ->
logHandshakeErrors (T.handshake ctx) $> ctx
where
logHandshakeErrors = if logErrors then (`catchAll` logThrow) else id
logHandshakeErrors = if logTLSErrors then (`catchAll` logThrow) else id
logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e
host = maybe "" (\h -> " (" <> h <> ")") host_
getTLS :: TransportPeer -> T.Context -> IO TLS
getTLS tlsPeer cxt = withTlsUnique tlsPeer cxt newTLS
getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS
getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS
where
newTLS tlsUniq = do
tlsBuffer <- atomically newTBuffer
pure TLS {tlsContext = cxt, tlsPeer, tlsUniq, tlsBuffer}
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer}
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
withTlsUnique peer cxt f =
@@ -199,6 +208,7 @@ supportedParameters =
instance Transport TLS where
transportName _ = "TLS"
transportPeer = tlsPeer
transportConfig = tlsTransportConfig
getServerConnection = getTLS TServer
getClientConnection = getTLS TClient
tlsUnique = tlsUniq
@@ -207,10 +217,12 @@ instance Transport TLS where
-- https://hackage.haskell.org/package/tls-1.6.0/docs/Network-TLS.html#v:recvData
-- this function may return less than requested number of bytes
cGet :: TLS -> Int -> IO ByteString
cGet TLS {tlsContext, tlsBuffer} n = getBuffered tlsBuffer n (T.recvData tlsContext)
cGet TLS {tlsContext, tlsBuffer, tlsTransportConfig = TransportConfig {transportTimeout = t_}} n =
getBuffered tlsBuffer n t_ (T.recvData tlsContext)
cPut :: TLS -> ByteString -> IO ()
cPut tls = T.sendData (tlsContext tls) . BL.fromStrict
cPut TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s =
withTimedErr t_ . T.sendData tlsContext $ BL.fromStrict s
getLn :: TLS -> IO ByteString
getLn TLS {tlsContext, tlsBuffer} = do

View File

@@ -8,6 +8,8 @@ import Control.Concurrent.STM
import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import System.Timeout (timeout)
import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..))
data TBuffer = TBuffer
{ buffer :: TVar ByteString,
@@ -26,22 +28,33 @@ withBufferLock TBuffer {getLock} =
(atomically $ takeTMVar getLock)
(atomically $ putTMVar getLock ())
getBuffered :: TBuffer -> Int -> IO ByteString -> IO ByteString
getBuffered tb@TBuffer {buffer} n getChunk = withBufferLock tb $ do
b <- readChunks =<< readTVarIO buffer
getBuffered :: TBuffer -> Int -> Maybe Int -> IO ByteString -> IO ByteString
getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do
b <- readChunks True =<< readTVarIO buffer
let (s, b') = B.splitAt n b
atomically $ writeTVar buffer $! b'
-- This would prevent the need to pad auth tag in HTTP2
-- threadDelay 150
pure s
where
readChunks :: ByteString -> IO ByteString
readChunks b
readChunks :: Bool -> ByteString -> IO ByteString
readChunks firstChunk b
| B.length b >= n = pure b
| otherwise =
getChunk >>= \case
get >>= \case
"" -> pure b
s -> readChunks $ b <> s
s -> readChunks False $ b <> s
where
get
| firstChunk = getChunk
| otherwise = withTimedErr t_ getChunk
withTimedErr :: Maybe Int -> IO a -> IO a
withTimedErr t_ a = case t_ of
Just t -> timeout t a >>= maybe err pure
Nothing -> a
where
err = ioException (IOError Nothing TimeExpired "" "get timeout" Nothing Nothing)
-- This function is only used in test and needs to be improved before it can be used in production,
-- it will never complete if TLS connection is closed before there is newline.

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -114,12 +115,16 @@ data TransportClientConfig = TransportClientConfig
defaultTransportClientConfig :: TransportClientConfig
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
clientTransportConfig :: TransportClientConfig -> TransportConfig
clientTransportConfig TransportClientConfig {logTLSErrors} =
TransportConfig {logTLSErrors, transportTimeout = Nothing}
-- | Connect to passed TCP host:port and pass handle to the client.
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive} proxyUsername host port keyHash client = do
let hostName = B.unpack $ strEncode host
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
connectTCP = case socksProxy of
@@ -128,7 +133,8 @@ runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpK
c <- liftIO $ do
sock <- connectTCP port
mapM_ (setSocketKeepAlive sock) tcpKeepAlive
connectTLS (Just hostName) logTLSErrors clientParams sock >>= getClientConnection
let tCfg = clientTransportConfig cfg
connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg
client c `E.finally` liftIO (closeConnection c)
where
hostAddr = \case

View File

@@ -73,7 +73,7 @@ instance HTTP2BodyChunk HS.Request where
getHTTP2Body :: HTTP2BodyChunk a => a -> Int -> IO HTTP2Body
getHTTP2Body r n = do
bodyBuffer <- atomically newTBuffer
let getPart n' = getBuffered bodyBuffer n' $ getBodyChunk r
let getPart n' = getBuffered bodyBuffer n' Nothing $ getBodyChunk r
bodyHead <- getPart n
let bodySize = fromMaybe 0 $ getBodySize r
-- TODO check bodySize once it is set

View File

@@ -14,7 +14,7 @@ import qualified Network.TLS as T
import Numeric.Natural (Natural)
import Simplex.Messaging.Transport (SessionId)
import Simplex.Messaging.Transport.HTTP2
import Simplex.Messaging.Transport.Server (loadSupportedTLSServerParams, runTransportServer)
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), loadSupportedTLSServerParams, runTransportServer)
type HTTP2ServerFunc = SessionId -> Request -> (Response -> IO ()) -> IO ()
@@ -27,7 +27,7 @@ data HTTP2ServerConfig = HTTP2ServerConfig
caCertificateFile :: FilePath,
privateKeyFile :: FilePath,
certificateFile :: FilePath,
logTLSErrors :: Bool
transportConfig :: TransportServerConfig
}
deriving (Show)
@@ -45,12 +45,12 @@ data HTTP2Server = HTTP2Server
-- This server is for testing only, it processes all requests in a single queue.
getHTTP2Server :: HTTP2ServerConfig -> IO HTTP2Server
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, logTLSErrors} = do
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do
tlsServerParams <- loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile
started <- newEmptyTMVarIO
reqQ <- newTBQueueIO qSize
action <- async $
runHTTP2Server started http2Port bufferSize tlsServerParams logTLSErrors $ \sessionId r sendResponse -> do
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig $ \sessionId r sendResponse -> do
reqBody <- getHTTP2Body r bodyHeadSize
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, request = r, reqBody, sendResponse}
void . atomically $ takeTMVar started
@@ -59,8 +59,8 @@ getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, se
closeHTTP2Server :: HTTP2Server -> IO ()
closeHTTP2Server = uninterruptibleCancel . action
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> Bool -> HTTP2ServerFunc -> IO ()
runHTTP2Server started port bufferSize serverParams logTLSErrors http2Server =
runTransportServer started port serverParams logTLSErrors $ withHTTP2 bufferSize run
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> HTTP2ServerFunc -> IO ()
runHTTP2Server started port bufferSize serverParams transportConfig http2Server =
runTransportServer started port serverParams transportConfig $ withHTTP2 bufferSize run
where
run cfg sessId = H.run cfg $ \req _aux sendResp -> http2Server sessId req (`sendResp` [])

View File

@@ -7,6 +7,8 @@
module Simplex.Messaging.Transport.Server
( runTransportServer,
runTCPServer,
TransportServerConfig (..),
defaultTransportServerConfig,
loadSupportedTLSServerParams,
loadTLSServerParams,
loadFingerprint,
@@ -38,15 +40,32 @@ import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.STM
data TransportServerConfig = TransportServerConfig
{ logTLSErrors :: Bool,
transportTimeout :: Int
}
deriving (Eq, Show)
defaultTransportServerConfig :: TransportServerConfig
defaultTransportServerConfig = TransportServerConfig
{ logTLSErrors = True,
transportTimeout = 40000000
}
serverTransportConfig :: TransportServerConfig -> TransportConfig
serverTransportConfig TransportServerConfig {logTLSErrors, transportTimeout} =
TransportConfig {logTLSErrors, transportTimeout = Just transportTimeout}
-- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar.
--
-- All accepted connections are passed to the passed function.
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> Bool -> (c -> m ()) -> m ()
runTransportServer started port serverParams logTLSErrors server = do
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
runTransportServer started port serverParams cfg server = do
u <- askUnliftIO
let tCfg = serverTransportConfig cfg
liftIO . runTCPServer started port $ \conn ->
E.bracket
(connectTLS Nothing logTLSErrors serverParams conn >>= getServerConnection)
(connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg)
closeConnection
(unliftIO u . server)

View File

@@ -17,6 +17,7 @@ import Simplex.Messaging.Transport
Transport (..),
TransportError (..),
TransportPeer (..),
TransportConfig (..),
closeTLS,
smpBlockSize,
withTlsUnique,
@@ -27,7 +28,8 @@ data WS = WS
{ wsPeer :: TransportPeer,
tlsUniq :: ByteString,
wsStream :: Stream,
wsConnection :: Connection
wsConnection :: Connection,
wsTransportConfig :: TransportConfig
}
websocketsOpts :: ConnectionOptions
@@ -45,10 +47,13 @@ instance Transport WS where
transportPeer :: WS -> TransportPeer
transportPeer = wsPeer
getServerConnection :: T.Context -> IO WS
transportConfig :: WS -> TransportConfig
transportConfig = wsTransportConfig
getServerConnection :: TransportConfig -> T.Context -> IO WS
getServerConnection = getWS TServer
getClientConnection :: T.Context -> IO WS
getClientConnection :: TransportConfig -> T.Context -> IO WS
getClientConnection = getWS TClient
tlsUnique :: WS -> ByteString
@@ -74,13 +79,13 @@ instance Transport WS where
then E.throwIO TEBadBlock
else pure $ B.init s
getWS :: TransportPeer -> T.Context -> IO WS
getWS wsPeer cxt = withTlsUnique wsPeer cxt connectWS
getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS
getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS
where
connectWS tlsUniq = do
s <- makeTLSContextStream cxt
wsConnection <- connectPeer wsPeer s
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection}
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg}
connectPeer :: TransportPeer -> Stream -> IO Connection
connectPeer TServer = acceptClientRequest
connectPeer TClient = sendClientRequest

View File

@@ -44,6 +44,7 @@ import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..), http2TLSParams)
import Simplex.Messaging.Transport.HTTP2.Server
import Simplex.Messaging.Transport.Server
import Test.Hspec
import UnliftIO.Async
import UnliftIO.Concurrent
@@ -99,7 +100,7 @@ ntfServerCfg =
logStatsStartTime = 0,
serverStatsLogFile = "tests/ntf-server-stats.daily.log",
serverStatsBackupFile = Nothing,
logTLSErrors = True
transportConfig = defaultTransportServerConfig
}
withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a
@@ -166,7 +167,7 @@ apnsMockServerConfig =
caCertificateFile = "tests/fixtures/ca.crt",
privateKeyFile = "tests/fixtures/server.key",
certificateFile = "tests/fixtures/server.crt",
logTLSErrors = True
transportConfig = defaultTransportServerConfig
}
withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO ()

View File

@@ -24,6 +24,7 @@ import Simplex.Messaging.Server (runSMPServerBlocking)
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.Server
import Simplex.Messaging.Version
import System.Environment (lookupEnv)
import System.Info (os)
@@ -99,7 +100,7 @@ cfg =
privateKeyFile = "tests/fixtures/server.key",
certificateFile = "tests/fixtures/server.crt",
smpServerVRange = supportedSMPServerVRange,
logTLSErrors = True
transportConfig = defaultTransportServerConfig
}
withSmpServerStoreMsgLogOnV2 :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a

View File

@@ -14,6 +14,7 @@ import Simplex.FileTransfer.Description
import Simplex.FileTransfer.Server (runXFTPServerBlocking)
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration)
import Simplex.Messaging.Protocol (XFTPServer)
import Simplex.Messaging.Transport.Server
import Test.Hspec
xftpTest :: HasCallStack => (HasCallStack => XFTPClient -> IO ()) -> Expectation
@@ -111,7 +112,7 @@ testXFTPServerConfig =
logStatsStartTime = 0,
serverStatsLogFile = "tests/tmp/xftp-server-stats.daily.log",
serverStatsBackupFile = Nothing,
logTLSErrors = True
transportConfig = defaultTransportServerConfig
}
testXFTPClientConfig :: XFTPClientConfig