mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 18:35:59 +00:00
servers: add TCP timeouts to avoid memory leaks (#776)
* servers: add TCP timeouts to avoid memory leaks * fix tests * only use RecvTimeOut * servers: simple timeout for TCP transport * revert dependency change * simplify * simplify * simplify 2
This commit is contained in:
committed by
GitHub
parent
16367fcb3b
commit
94540a2c71
@@ -70,7 +70,7 @@ runXFTPServerBlocking :: TMVar Bool -> XFTPServerConfig -> IO ()
|
||||
runXFTPServerBlocking started cfg = newXFTPServerEnv cfg >>= runReaderT (xftpServer cfg started)
|
||||
|
||||
xftpServer :: XFTPServerConfig -> TMVar Bool -> M ()
|
||||
xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do
|
||||
xftpServer cfg@XFTPServerConfig {xftpPort, transportConfig} started = do
|
||||
restoreServerStats
|
||||
raceAny_ (runServer : expireFilesThread_ cfg <> serverStatsThread_ cfg) `finally` stopServer
|
||||
where
|
||||
@@ -79,7 +79,7 @@ xftpServer cfg@XFTPServerConfig {xftpPort, logTLSErrors} started = do
|
||||
serverParams <- asks tlsServerParams
|
||||
env <- ask
|
||||
liftIO $
|
||||
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams logTLSErrors $ \sessionId r sendResponse -> do
|
||||
runHTTP2Server started xftpPort defaultHTTP2BufferSize serverParams transportConfig $ \sessionId r sendResponse -> do
|
||||
reqBody <- getHTTP2Body r xftpBlockSize
|
||||
processRequest HTTP2Request {sessionId, request = r, reqBody, sendResponse} `runReaderT` env
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ import Simplex.FileTransfer.Server.StoreLog
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (BasicAuth, RcvPublicVerifyKey)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import Simplex.Messaging.Util (tshow)
|
||||
import System.IO (IOMode (..))
|
||||
import UnliftIO.STM
|
||||
@@ -55,7 +55,7 @@ data XFTPServerConfig = XFTPServerConfig
|
||||
logStatsStartTime :: Int64,
|
||||
serverStatsLogFile :: FilePath,
|
||||
serverStatsBackupFile :: Maybe FilePath,
|
||||
logTLSErrors :: Bool
|
||||
transportConfig :: TransportServerConfig
|
||||
}
|
||||
|
||||
data XFTPEnv = XFTPEnv
|
||||
|
||||
@@ -25,6 +25,7 @@ import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern XFTPServer)
|
||||
import Simplex.Messaging.Server.CLI
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport.Client (TransportHost (..))
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
|
||||
import System.Directory (createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
|
||||
@@ -151,7 +152,10 @@ xftpServerCLI cfgPath logPath = do
|
||||
logStatsStartTime = 0, -- seconds from 00:00 UTC
|
||||
serverStatsLogFile = combine logPath "file-server-stats.daily.log",
|
||||
serverStatsBackupFile = logStats $> combine logPath "file-server-stats.log",
|
||||
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
transportConfig =
|
||||
defaultTransportServerConfig
|
||||
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
}
|
||||
}
|
||||
|
||||
data CliCommand
|
||||
|
||||
@@ -23,7 +23,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), simplexMQVersion)
|
||||
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Transport.Server (loadTLSServerParams, runTransportServer, defaultTransportServerConfig)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import UnliftIO.Async (race_)
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -48,7 +48,7 @@ runSMPAgentBlocking (ATransport t) cfg@AgentConfig {tcpPort, caCertificateFile,
|
||||
smpAgent _ = do
|
||||
-- tlsServerParams is not in Env to avoid breaking functional API w/t key and certificate generation
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
runTransportServer started tcpPort tlsServerParams True $ \(h :: c) -> do
|
||||
runTransportServer started tcpPort tlsServerParams defaultTransportServerConfig $ \(h :: c) -> do
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
c <- getAgentClient initServers
|
||||
logConnection c True
|
||||
|
||||
@@ -72,7 +72,7 @@ runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtf
|
||||
type M a = ReaderT NtfEnv IO a
|
||||
|
||||
ntfServer :: NtfServerConfig -> TMVar Bool -> M ()
|
||||
ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do
|
||||
ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
|
||||
restoreServerStats
|
||||
s <- asks subscriber
|
||||
ps <- asks pushServer
|
||||
@@ -83,7 +83,7 @@ ntfServer cfg@NtfServerConfig {transports, logTLSErrors} started = do
|
||||
runServer :: (ServiceName, ATransport) -> M ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
runTransportServer started tcpPort serverParams logTLSErrors (runClient t)
|
||||
runTransportServer started tcpPort serverParams tCfg (runClient t)
|
||||
|
||||
runClient :: Transport c => TProxy c -> c -> M ()
|
||||
runClient _ h = do
|
||||
|
||||
@@ -32,7 +32,7 @@ import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
@@ -57,7 +57,7 @@ data NtfServerConfig = NtfServerConfig
|
||||
logStatsStartTime :: Int64,
|
||||
serverStatsLogFile :: FilePath,
|
||||
serverStatsBackupFile :: Maybe FilePath,
|
||||
logTLSErrors :: Bool
|
||||
transportConfig :: TransportServerConfig
|
||||
}
|
||||
|
||||
defaultInactiveClientExpiration :: ExpirationConfig
|
||||
|
||||
@@ -23,6 +23,7 @@ import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientCo
|
||||
import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer)
|
||||
import Simplex.Messaging.Server.CLI
|
||||
import Simplex.Messaging.Transport.Client (TransportHost (..))
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
|
||||
import System.Directory (createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (BufferMode (..), hSetBuffering, stderr, stdout)
|
||||
@@ -123,7 +124,10 @@ ntfServerCLI cfgPath logPath =
|
||||
logStatsStartTime = 0, -- seconds from 00:00 UTC
|
||||
serverStatsLogFile = combine logPath "ntf-server-stats.daily.log",
|
||||
serverStatsBackupFile = logStats $> combine logPath "ntf-server-stats.log",
|
||||
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
transportConfig =
|
||||
defaultTransportServerConfig
|
||||
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
}
|
||||
}
|
||||
|
||||
data CliCommand
|
||||
|
||||
@@ -103,7 +103,7 @@ runSMPServerBlocking started cfg = newEnv cfg >>= runReaderT (smpServer started
|
||||
type M a = ReaderT Env IO a
|
||||
|
||||
smpServer :: TMVar Bool -> ServerConfig -> M ()
|
||||
smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
|
||||
smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
s <- asks server
|
||||
restoreServerMessages
|
||||
restoreServerStats
|
||||
@@ -117,7 +117,7 @@ smpServer started cfg@ServerConfig {transports, logTLSErrors} = do
|
||||
runServer :: (ServiceName, ATransport) -> M ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
runTransportServer started tcpPort serverParams logTLSErrors (runClient t)
|
||||
runTransportServer started tcpPort serverParams tCfg (runClient t)
|
||||
|
||||
serverThread ::
|
||||
forall s.
|
||||
|
||||
@@ -30,7 +30,7 @@ import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams, TransportServerConfig)
|
||||
import Simplex.Messaging.Version
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
@@ -69,7 +69,8 @@ data ServerConfig = ServerConfig
|
||||
certificateFile :: FilePath,
|
||||
-- | SMP client-server protocol version range
|
||||
smpServerVRange :: VersionRange,
|
||||
logTLSErrors :: Bool
|
||||
-- | TCP transport config
|
||||
transportConfig :: TransportServerConfig
|
||||
}
|
||||
|
||||
defMsgExpirationDays :: Int64
|
||||
|
||||
@@ -28,6 +28,7 @@ import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClien
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost (..))
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), defaultTransportServerConfig)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8)
|
||||
import System.Directory (createDirectoryIfMissing, doesFileExist)
|
||||
import System.FilePath (combine)
|
||||
@@ -198,7 +199,10 @@ smpServerCLI cfgPath logPath =
|
||||
serverStatsLogFile = combine logPath "smp-server-stats.daily.log",
|
||||
serverStatsBackupFile = logStats $> combine logPath "smp-server-stats.log",
|
||||
smpServerVRange = supportedSMPServerVRange,
|
||||
logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
transportConfig =
|
||||
defaultTransportServerConfig
|
||||
{ logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini
|
||||
}
|
||||
}
|
||||
|
||||
data CliCommand
|
||||
|
||||
@@ -29,6 +29,7 @@ module Simplex.Messaging.Transport
|
||||
supportedSMPServerVRange,
|
||||
simplexMQVersion,
|
||||
smpBlockSize,
|
||||
TransportConfig (..),
|
||||
|
||||
-- * Transport connection class
|
||||
Transport (..),
|
||||
@@ -104,6 +105,11 @@ simplexMQVersion = showVersion SMQ.version
|
||||
|
||||
-- * Transport connection class
|
||||
|
||||
data TransportConfig = TransportConfig
|
||||
{ logTLSErrors :: Bool,
|
||||
transportTimeout :: Maybe Int
|
||||
}
|
||||
|
||||
class Transport c where
|
||||
transport :: ATransport
|
||||
transport = ATransport (TProxy @c)
|
||||
@@ -112,11 +118,13 @@ class Transport c where
|
||||
|
||||
transportPeer :: c -> TransportPeer
|
||||
|
||||
transportConfig :: c -> TransportConfig
|
||||
|
||||
-- | Upgrade server TLS context to connection (used in the server)
|
||||
getServerConnection :: T.Context -> IO c
|
||||
getServerConnection :: TransportConfig -> T.Context -> IO c
|
||||
|
||||
-- | Upgrade client TLS context to connection (used in the client)
|
||||
getClientConnection :: T.Context -> IO c
|
||||
getClientConnection :: TransportConfig -> T.Context -> IO c
|
||||
|
||||
-- | tls-unique channel binding per RFC5929
|
||||
tlsUnique :: c -> SessionId
|
||||
@@ -150,24 +158,25 @@ data TLS = TLS
|
||||
{ tlsContext :: T.Context,
|
||||
tlsPeer :: TransportPeer,
|
||||
tlsUniq :: ByteString,
|
||||
tlsBuffer :: TBuffer
|
||||
tlsBuffer :: TBuffer,
|
||||
tlsTransportConfig :: TransportConfig
|
||||
}
|
||||
|
||||
connectTLS :: T.TLSParams p => Maybe HostName -> Bool -> p -> Socket -> IO T.Context
|
||||
connectTLS host_ logErrors params sock =
|
||||
connectTLS :: T.TLSParams p => Maybe HostName -> TransportConfig -> p -> Socket -> IO T.Context
|
||||
connectTLS host_ TransportConfig {logTLSErrors} params sock =
|
||||
E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx ->
|
||||
logHandshakeErrors (T.handshake ctx) $> ctx
|
||||
where
|
||||
logHandshakeErrors = if logErrors then (`catchAll` logThrow) else id
|
||||
logHandshakeErrors = if logTLSErrors then (`catchAll` logThrow) else id
|
||||
logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e
|
||||
host = maybe "" (\h -> " (" <> h <> ")") host_
|
||||
|
||||
getTLS :: TransportPeer -> T.Context -> IO TLS
|
||||
getTLS tlsPeer cxt = withTlsUnique tlsPeer cxt newTLS
|
||||
getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS
|
||||
getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS
|
||||
where
|
||||
newTLS tlsUniq = do
|
||||
tlsBuffer <- atomically newTBuffer
|
||||
pure TLS {tlsContext = cxt, tlsPeer, tlsUniq, tlsBuffer}
|
||||
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer}
|
||||
|
||||
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
|
||||
withTlsUnique peer cxt f =
|
||||
@@ -199,6 +208,7 @@ supportedParameters =
|
||||
instance Transport TLS where
|
||||
transportName _ = "TLS"
|
||||
transportPeer = tlsPeer
|
||||
transportConfig = tlsTransportConfig
|
||||
getServerConnection = getTLS TServer
|
||||
getClientConnection = getTLS TClient
|
||||
tlsUnique = tlsUniq
|
||||
@@ -207,10 +217,12 @@ instance Transport TLS where
|
||||
-- https://hackage.haskell.org/package/tls-1.6.0/docs/Network-TLS.html#v:recvData
|
||||
-- this function may return less than requested number of bytes
|
||||
cGet :: TLS -> Int -> IO ByteString
|
||||
cGet TLS {tlsContext, tlsBuffer} n = getBuffered tlsBuffer n (T.recvData tlsContext)
|
||||
|
||||
cGet TLS {tlsContext, tlsBuffer, tlsTransportConfig = TransportConfig {transportTimeout = t_}} n =
|
||||
getBuffered tlsBuffer n t_ (T.recvData tlsContext)
|
||||
|
||||
cPut :: TLS -> ByteString -> IO ()
|
||||
cPut tls = T.sendData (tlsContext tls) . BL.fromStrict
|
||||
cPut TLS {tlsContext, tlsTransportConfig = TransportConfig {transportTimeout = t_}} s =
|
||||
withTimedErr t_ . T.sendData tlsContext $ BL.fromStrict s
|
||||
|
||||
getLn :: TLS -> IO ByteString
|
||||
getLn TLS {tlsContext, tlsBuffer} = do
|
||||
|
||||
@@ -8,6 +8,8 @@ import Control.Concurrent.STM
|
||||
import qualified Control.Exception as E
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import System.Timeout (timeout)
|
||||
import GHC.IO.Exception (ioException, IOException (..), IOErrorType (..))
|
||||
|
||||
data TBuffer = TBuffer
|
||||
{ buffer :: TVar ByteString,
|
||||
@@ -26,22 +28,33 @@ withBufferLock TBuffer {getLock} =
|
||||
(atomically $ takeTMVar getLock)
|
||||
(atomically $ putTMVar getLock ())
|
||||
|
||||
getBuffered :: TBuffer -> Int -> IO ByteString -> IO ByteString
|
||||
getBuffered tb@TBuffer {buffer} n getChunk = withBufferLock tb $ do
|
||||
b <- readChunks =<< readTVarIO buffer
|
||||
getBuffered :: TBuffer -> Int -> Maybe Int -> IO ByteString -> IO ByteString
|
||||
getBuffered tb@TBuffer {buffer} n t_ getChunk = withBufferLock tb $ do
|
||||
b <- readChunks True =<< readTVarIO buffer
|
||||
let (s, b') = B.splitAt n b
|
||||
atomically $ writeTVar buffer $! b'
|
||||
-- This would prevent the need to pad auth tag in HTTP2
|
||||
-- threadDelay 150
|
||||
pure s
|
||||
where
|
||||
readChunks :: ByteString -> IO ByteString
|
||||
readChunks b
|
||||
readChunks :: Bool -> ByteString -> IO ByteString
|
||||
readChunks firstChunk b
|
||||
| B.length b >= n = pure b
|
||||
| otherwise =
|
||||
getChunk >>= \case
|
||||
get >>= \case
|
||||
"" -> pure b
|
||||
s -> readChunks $ b <> s
|
||||
s -> readChunks False $ b <> s
|
||||
where
|
||||
get
|
||||
| firstChunk = getChunk
|
||||
| otherwise = withTimedErr t_ getChunk
|
||||
|
||||
withTimedErr :: Maybe Int -> IO a -> IO a
|
||||
withTimedErr t_ a = case t_ of
|
||||
Just t -> timeout t a >>= maybe err pure
|
||||
Nothing -> a
|
||||
where
|
||||
err = ioException (IOError Nothing TimeExpired "" "get timeout" Nothing Nothing)
|
||||
|
||||
-- This function is only used in test and needs to be improved before it can be used in production,
|
||||
-- it will never complete if TLS connection is closed before there is newline.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -114,12 +115,16 @@ data TransportClientConfig = TransportClientConfig
|
||||
defaultTransportClientConfig :: TransportClientConfig
|
||||
defaultTransportClientConfig = TransportClientConfig Nothing (Just defaultKeepAliveOpts) True
|
||||
|
||||
clientTransportConfig :: TransportClientConfig -> TransportConfig
|
||||
clientTransportConfig TransportClientConfig {logTLSErrors} =
|
||||
TransportConfig {logTLSErrors, transportTimeout = Nothing}
|
||||
|
||||
-- | Connect to passed TCP host:port and pass handle to the client.
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors} proxyUsername host port keyHash client = do
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive} proxyUsername host port keyHash client = do
|
||||
let hostName = B.unpack $ strEncode host
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash
|
||||
connectTCP = case socksProxy of
|
||||
@@ -128,7 +133,8 @@ runTLSTransportClient tlsParams caStore_ TransportClientConfig {socksProxy, tcpK
|
||||
c <- liftIO $ do
|
||||
sock <- connectTCP port
|
||||
mapM_ (setSocketKeepAlive sock) tcpKeepAlive
|
||||
connectTLS (Just hostName) logTLSErrors clientParams sock >>= getClientConnection
|
||||
let tCfg = clientTransportConfig cfg
|
||||
connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg
|
||||
client c `E.finally` liftIO (closeConnection c)
|
||||
where
|
||||
hostAddr = \case
|
||||
|
||||
@@ -73,7 +73,7 @@ instance HTTP2BodyChunk HS.Request where
|
||||
getHTTP2Body :: HTTP2BodyChunk a => a -> Int -> IO HTTP2Body
|
||||
getHTTP2Body r n = do
|
||||
bodyBuffer <- atomically newTBuffer
|
||||
let getPart n' = getBuffered bodyBuffer n' $ getBodyChunk r
|
||||
let getPart n' = getBuffered bodyBuffer n' Nothing $ getBodyChunk r
|
||||
bodyHead <- getPart n
|
||||
let bodySize = fromMaybe 0 $ getBodySize r
|
||||
-- TODO check bodySize once it is set
|
||||
|
||||
@@ -14,7 +14,7 @@ import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport (SessionId)
|
||||
import Simplex.Messaging.Transport.HTTP2
|
||||
import Simplex.Messaging.Transport.Server (loadSupportedTLSServerParams, runTransportServer)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig (..), loadSupportedTLSServerParams, runTransportServer)
|
||||
|
||||
type HTTP2ServerFunc = SessionId -> Request -> (Response -> IO ()) -> IO ()
|
||||
|
||||
@@ -27,7 +27,7 @@ data HTTP2ServerConfig = HTTP2ServerConfig
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath,
|
||||
logTLSErrors :: Bool
|
||||
transportConfig :: TransportServerConfig
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
@@ -45,12 +45,12 @@ data HTTP2Server = HTTP2Server
|
||||
|
||||
-- This server is for testing only, it processes all requests in a single queue.
|
||||
getHTTP2Server :: HTTP2ServerConfig -> IO HTTP2Server
|
||||
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, logTLSErrors} = do
|
||||
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, serverSupported, caCertificateFile, certificateFile, privateKeyFile, transportConfig} = do
|
||||
tlsServerParams <- loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile
|
||||
started <- newEmptyTMVarIO
|
||||
reqQ <- newTBQueueIO qSize
|
||||
action <- async $
|
||||
runHTTP2Server started http2Port bufferSize tlsServerParams logTLSErrors $ \sessionId r sendResponse -> do
|
||||
runHTTP2Server started http2Port bufferSize tlsServerParams transportConfig $ \sessionId r sendResponse -> do
|
||||
reqBody <- getHTTP2Body r bodyHeadSize
|
||||
atomically $ writeTBQueue reqQ HTTP2Request {sessionId, request = r, reqBody, sendResponse}
|
||||
void . atomically $ takeTMVar started
|
||||
@@ -59,8 +59,8 @@ getHTTP2Server HTTP2ServerConfig {qSize, http2Port, bufferSize, bodyHeadSize, se
|
||||
closeHTTP2Server :: HTTP2Server -> IO ()
|
||||
closeHTTP2Server = uninterruptibleCancel . action
|
||||
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> Bool -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port bufferSize serverParams logTLSErrors http2Server =
|
||||
runTransportServer started port serverParams logTLSErrors $ withHTTP2 bufferSize run
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> BufferSize -> T.ServerParams -> TransportServerConfig -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port bufferSize serverParams transportConfig http2Server =
|
||||
runTransportServer started port serverParams transportConfig $ withHTTP2 bufferSize run
|
||||
where
|
||||
run cfg sessId = H.run cfg $ \req _aux sendResp -> http2Server sessId req (`sendResp` [])
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
module Simplex.Messaging.Transport.Server
|
||||
( runTransportServer,
|
||||
runTCPServer,
|
||||
TransportServerConfig (..),
|
||||
defaultTransportServerConfig,
|
||||
loadSupportedTLSServerParams,
|
||||
loadTLSServerParams,
|
||||
loadFingerprint,
|
||||
@@ -38,15 +40,32 @@ import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
data TransportServerConfig = TransportServerConfig
|
||||
{ logTLSErrors :: Bool,
|
||||
transportTimeout :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
defaultTransportServerConfig :: TransportServerConfig
|
||||
defaultTransportServerConfig = TransportServerConfig
|
||||
{ logTLSErrors = True,
|
||||
transportTimeout = 40000000
|
||||
}
|
||||
|
||||
serverTransportConfig :: TransportServerConfig -> TransportConfig
|
||||
serverTransportConfig TransportServerConfig {logTLSErrors, transportTimeout} =
|
||||
TransportConfig {logTLSErrors, transportTimeout = Just transportTimeout}
|
||||
|
||||
-- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar.
|
||||
--
|
||||
-- All accepted connections are passed to the passed function.
|
||||
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> Bool -> (c -> m ()) -> m ()
|
||||
runTransportServer started port serverParams logTLSErrors server = do
|
||||
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> TransportServerConfig -> (c -> m ()) -> m ()
|
||||
runTransportServer started port serverParams cfg server = do
|
||||
u <- askUnliftIO
|
||||
let tCfg = serverTransportConfig cfg
|
||||
liftIO . runTCPServer started port $ \conn ->
|
||||
E.bracket
|
||||
(connectTLS Nothing logTLSErrors serverParams conn >>= getServerConnection)
|
||||
(connectTLS Nothing tCfg serverParams conn >>= getServerConnection tCfg)
|
||||
closeConnection
|
||||
(unliftIO u . server)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import Simplex.Messaging.Transport
|
||||
Transport (..),
|
||||
TransportError (..),
|
||||
TransportPeer (..),
|
||||
TransportConfig (..),
|
||||
closeTLS,
|
||||
smpBlockSize,
|
||||
withTlsUnique,
|
||||
@@ -27,7 +28,8 @@ data WS = WS
|
||||
{ wsPeer :: TransportPeer,
|
||||
tlsUniq :: ByteString,
|
||||
wsStream :: Stream,
|
||||
wsConnection :: Connection
|
||||
wsConnection :: Connection,
|
||||
wsTransportConfig :: TransportConfig
|
||||
}
|
||||
|
||||
websocketsOpts :: ConnectionOptions
|
||||
@@ -45,10 +47,13 @@ instance Transport WS where
|
||||
transportPeer :: WS -> TransportPeer
|
||||
transportPeer = wsPeer
|
||||
|
||||
getServerConnection :: T.Context -> IO WS
|
||||
transportConfig :: WS -> TransportConfig
|
||||
transportConfig = wsTransportConfig
|
||||
|
||||
getServerConnection :: TransportConfig -> T.Context -> IO WS
|
||||
getServerConnection = getWS TServer
|
||||
|
||||
getClientConnection :: T.Context -> IO WS
|
||||
getClientConnection :: TransportConfig -> T.Context -> IO WS
|
||||
getClientConnection = getWS TClient
|
||||
|
||||
tlsUnique :: WS -> ByteString
|
||||
@@ -74,13 +79,13 @@ instance Transport WS where
|
||||
then E.throwIO TEBadBlock
|
||||
else pure $ B.init s
|
||||
|
||||
getWS :: TransportPeer -> T.Context -> IO WS
|
||||
getWS wsPeer cxt = withTlsUnique wsPeer cxt connectWS
|
||||
getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS
|
||||
getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS
|
||||
where
|
||||
connectWS tlsUniq = do
|
||||
s <- makeTLSContextStream cxt
|
||||
wsConnection <- connectPeer wsPeer s
|
||||
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection}
|
||||
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg}
|
||||
connectPeer :: TransportPeer -> Stream -> IO Connection
|
||||
connectPeer TServer = acceptClientRequest
|
||||
connectPeer TClient = sendClientRequest
|
||||
|
||||
@@ -44,6 +44,7 @@ import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..), http2TLSParams)
|
||||
import Simplex.Messaging.Transport.HTTP2.Server
|
||||
import Simplex.Messaging.Transport.Server
|
||||
import Test.Hspec
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
@@ -99,7 +100,7 @@ ntfServerCfg =
|
||||
logStatsStartTime = 0,
|
||||
serverStatsLogFile = "tests/ntf-server-stats.daily.log",
|
||||
serverStatsBackupFile = Nothing,
|
||||
logTLSErrors = True
|
||||
transportConfig = defaultTransportServerConfig
|
||||
}
|
||||
|
||||
withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a
|
||||
@@ -166,7 +167,7 @@ apnsMockServerConfig =
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt",
|
||||
logTLSErrors = True
|
||||
transportConfig = defaultTransportServerConfig
|
||||
}
|
||||
|
||||
withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO ()
|
||||
|
||||
@@ -24,6 +24,7 @@ import Simplex.Messaging.Server (runSMPServerBlocking)
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.Server
|
||||
import Simplex.Messaging.Version
|
||||
import System.Environment (lookupEnv)
|
||||
import System.Info (os)
|
||||
@@ -99,7 +100,7 @@ cfg =
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt",
|
||||
smpServerVRange = supportedSMPServerVRange,
|
||||
logTLSErrors = True
|
||||
transportConfig = defaultTransportServerConfig
|
||||
}
|
||||
|
||||
withSmpServerStoreMsgLogOnV2 :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
|
||||
|
||||
@@ -14,6 +14,7 @@ import Simplex.FileTransfer.Description
|
||||
import Simplex.FileTransfer.Server (runXFTPServerBlocking)
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..), defaultFileExpiration)
|
||||
import Simplex.Messaging.Protocol (XFTPServer)
|
||||
import Simplex.Messaging.Transport.Server
|
||||
import Test.Hspec
|
||||
|
||||
xftpTest :: HasCallStack => (HasCallStack => XFTPClient -> IO ()) -> Expectation
|
||||
@@ -111,7 +112,7 @@ testXFTPServerConfig =
|
||||
logStatsStartTime = 0,
|
||||
serverStatsLogFile = "tests/tmp/xftp-server-stats.daily.log",
|
||||
serverStatsBackupFile = Nothing,
|
||||
logTLSErrors = True
|
||||
transportConfig = defaultTransportServerConfig
|
||||
}
|
||||
|
||||
testXFTPClientConfig :: XFTPClientConfig
|
||||
|
||||
Reference in New Issue
Block a user