create TLS ServerParams once per server run; remove tlsServerParams from agent env (fixes functional agent client for chat) (#223)

This commit is contained in:
Efim Poberezkin
2021-12-15 19:03:34 +04:00
committed by GitHub
parent 5aa0e97cd9
commit bcf5e25cab
5 changed files with 33 additions and 35 deletions
+5 -4
View File
@@ -85,7 +85,7 @@ import Simplex.Messaging.Client (SMPServerTransmission)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (MsgBody, SndPublicVerifyKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), currentSMPVersionStr, runTransportServer)
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), currentSMPVersionStr, loadTLSServerParams, runTransportServer)
import Simplex.Messaging.Util (bshow, tryError, unlessM)
import System.Random (randomR)
import UnliftIO.Async (async, race_)
@@ -105,13 +105,14 @@ runSMPAgent t cfg = do
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = do
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, agentCertificateFile, agentPrivateKeyFile} = do
runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
where
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent _ = do
credential <- asks agentCredential
runTransportServer started tcpPort credential $ \(h :: c) -> do
-- tlsServerParams not in env to avoid breaking functional api w/t key and certificate generation
tlsServerParams <- liftIO $ loadTLSServerParams agentCertificateFile agentPrivateKeyFile
runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do
liftIO . putLn h $ "Welcome to SMP agent v" <> currentSMPVersionStr
c <- getAgentClient
logConnection c True
+2 -6
View File
@@ -10,7 +10,6 @@ import Control.Monad.IO.Unlift
import Crypto.Random
import Data.List.NonEmpty (NonEmpty)
import Network.Socket
import qualified Network.TLS as T
import Numeric.Natural
import Simplex.Messaging.Agent.Protocol (SMPServer)
import Simplex.Messaging.Agent.RetryInterval
@@ -18,7 +17,6 @@ import Simplex.Messaging.Agent.Store.SQLite
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Transport (loadServerCredential)
import System.Random (StdGen, newStdGen)
import UnliftIO.STM
@@ -76,8 +74,7 @@ data Env = Env
idsDrg :: TVar ChaChaDRG,
clientCounter :: TVar Int,
reservedMsgSize :: Int,
randomServer :: TVar StdGen,
agentCredential :: T.Credential
randomServer :: TVar StdGen
}
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
@@ -86,8 +83,7 @@ newSMPAgentEnv cfg = do
store <- liftIO $ createSQLiteStore (dbFile cfg) (dbPoolSize cfg) Migrations.app
clientCounter <- newTVarIO 0
randomServer <- newTVarIO =<< liftIO newStdGen
agentCredential <- liftIO $ loadServerCredential (agentPrivateKeyFile cfg) (agentCertificateFile cfg)
return Env {config = cfg, store, idsDrg, clientCounter, reservedMsgSize, randomServer, agentCredential}
return Env {config = cfg, store, idsDrg, clientCounter, reservedMsgSize, randomServer}
where
-- 1st rsaKeySize is used by the RSA signature in each command,
-- 2nd - by encrypted message body header
+2 -2
View File
@@ -83,8 +83,8 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do
runServer :: (MonadUnliftIO m', MonadReader Env m') => (ServiceName, ATransport) -> m' ()
runServer (tcpPort, ATransport t) = do
credential <- asks serverCredential
runTransportServer started tcpPort credential (runClient t)
serverParams <- asks tlsServerParams
runTransportServer started tcpPort serverParams (runClient t)
serverThread ::
forall m' s.
+4 -4
View File
@@ -19,7 +19,7 @@ import Simplex.Messaging.Server.MsgStore.STM
import Simplex.Messaging.Server.QueueStore (QueueRec (..))
import Simplex.Messaging.Server.QueueStore.STM
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.Transport (ATransport, SessionId, loadServerCredential)
import Simplex.Messaging.Transport (ATransport, SessionId, loadTLSServerParams)
import System.IO (IOMode (..))
import UnliftIO.STM
@@ -45,7 +45,7 @@ data Env = Env
idsDrg :: TVar ChaChaDRG,
serverKeyPair :: C.KeyPair 'C.RSA, -- TODO delete
storeLog :: Maybe (StoreLog 'WriteMode),
serverCredential :: T.Credential
tlsServerParams :: T.ServerParams
}
data Server = Server
@@ -100,8 +100,8 @@ newEnv config = do
s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig)
let pk = serverPrivateKey config -- TODO remove
serverKeyPair = (C.publicKey pk, pk)
serverCredential <- liftIO $ loadServerCredential (serverPrivateKeyFile config) (serverCertificateFile config)
return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s', serverCredential}
tlsServerParams <- liftIO $ loadTLSServerParams (serverCertificateFile config) (serverPrivateKeyFile config)
return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s', tlsServerParams}
where
restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode)
restoreQueues queueStore s = do
+20 -19
View File
@@ -33,7 +33,7 @@ module Simplex.Messaging.Transport
-- * Transport over TLS 1.3
runTransportServer,
runTransportClient,
loadServerCredential,
loadTLSServerParams,
-- * TLS 1.3 Transport
TLS (..),
@@ -132,8 +132,8 @@ data ATransport = forall c. Transport c => ATransport (TProxy c)
-- | Run transport server (plain TCP or WebSockets) on passed TCP port and signal when server started and stopped via passed TMVar.
--
-- All accepted connections are passed to the passed function.
runTransportServer :: (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.Credential -> (c -> m ()) -> m ()
runTransportServer started port credential server = do
runTransportServer :: (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> (c -> m ()) -> m ()
runTransportServer started port serverParams server = do
clients <- newTVarIO S.empty
E.bracket
(liftIO $ startTCPServer started port)
@@ -151,7 +151,6 @@ runTransportServer started port credential server = do
acceptConnection :: Transport c => Socket -> IO c
acceptConnection sock = do
(newSock, _) <- accept sock
let serverParams = mkServerParams credential
connectTLS "server" getServerConnection serverParams newSock
startTCPServer :: TMVar Bool -> ServiceName -> IO Socket
@@ -197,12 +196,23 @@ startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
connect sock $ addrAddress addr
connectTLS "client" getClientConnection clientParams sock
-- TODO non lazy
loadServerCredential :: FilePath -> FilePath -> IO T.Credential
loadServerCredential privateKeyFile certificateFile =
T.credentialLoadX509 certificateFile privateKeyFile >>= \case
Right cert -> pure cert
Left _ -> putStrLn "invalid credential" >> exitFailure
loadTLSServerParams :: FilePath -> FilePath -> IO T.ServerParams
loadTLSServerParams certificateFile privateKeyFile =
fromCredential <$> loadServerCredential
where
loadServerCredential :: IO T.Credential
loadServerCredential =
T.credentialLoadX509 certificateFile privateKeyFile >>= \case
Right credential -> pure credential
Left _ -> putStrLn "invalid credential" >> exitFailure
fromCredential :: T.Credential -> T.ServerParams
fromCredential credential =
def
{ T.serverWantClientCert = False,
T.serverShared = def {T.sharedCredentials = T.Credentials [credential]},
T.serverHooks = def,
T.serverSupported = supportedParameters
}
-- * TLS 1.3 Transport
@@ -222,15 +232,6 @@ closeTLS ctx =
(T.bye ctx >> T.contextClose ctx) -- sometimes socket was closed before 'TLS.bye'
`E.catch` (\(_ :: E.SomeException) -> pure ()) -- so we catch the 'Broken pipe' error here
mkServerParams :: T.Credential -> T.ServerParams
mkServerParams credential =
def
{ T.serverWantClientCert = False,
T.serverShared = def {T.sharedCredentials = T.Credentials [credential]},
T.serverHooks = def,
T.serverSupported = supportedParameters
}
clientParams :: T.ClientParams
clientParams =
(T.defaultParamsClient "localhost" "5223")