transport: fetch and store server certificate (#985)

* THandleParams (WIP, does not compile)

* transport: fetch and store server certificate

* smp: add getOnlinePubKey example to smpClientHandshake

* add server certs and sign authPub

* cleanup

* update

* style

* load server certs from test fixtures

* sign ntf authPubKey

* fix onServerCertificate

* increase delay before sending messages

* require certificate with key in SMP server handshake

---------

Co-authored-by: Evgeny Poberezkin <evgeny@poberezkin.com>
This commit is contained in:
Alexander Bondarenko
2024-02-13 07:02:03 -08:00
committed by GitHub
parent 6aec0b13fd
commit 76eddfbc9d
11 changed files with 184 additions and 61 deletions

View File

@@ -8,6 +8,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
@@ -73,9 +74,12 @@ module Simplex.Messaging.Crypto
generateAuthKeyPair,
generateDhKeyPair,
privateToX509,
x509ToPublic,
x509ToPrivate,
publicKey,
signatureKeyPair,
publicToX509,
encodeASNObj,
-- * key encoding/decoding
encodePubKey,
@@ -162,10 +166,13 @@ module Simplex.Messaging.Crypto
Certificate,
signCertificate,
signX509,
verifyX509,
certificateFingerprint,
signedFingerprint,
SignatureAlgorithmX509 (..),
SignedObject (..),
encodeCertChain,
certChainP,
-- * Cryptography error type
CryptoError (..),
@@ -210,6 +217,7 @@ import qualified Data.ByteString.Char8 as B
import Data.ByteString.Lazy (fromStrict, toStrict)
import Data.Constraint (Dict (..))
import Data.Kind (Constraint, Type)
import qualified Data.List.NonEmpty as L
import Data.String
import Data.Type.Equality
import Data.Typeable (Proxy (Proxy), Typeable)
@@ -1137,6 +1145,18 @@ signX509 key = fst . objectToSignedExact f
signatureAlgorithmX509 key,
()
)
{-# INLINE signX509 #-}
verifyX509 :: (ASN1Object o, Eq o, Show o) => APublicVerifyKey -> SignedExact o -> Either String o
verifyX509 key exact = do
signature <- case signedAlg of
SignatureALG_IntrinsicHash PubKeyALG_Ed25519 -> ASignature SEd25519 <$> decodeSignature signedSignature
SignatureALG_IntrinsicHash PubKeyALG_Ed448 -> ASignature SEd448 <$> decodeSignature signedSignature
_ -> Left "unknown x509 signature algorithm"
if verify key signature $ getSignedData exact then Right signedObject else Left "bad signature"
where
Signed {signedObject, signedAlg, signedSignature} = getSigned exact
{-# INLINE verifyX509 #-}
certificateFingerprint :: SignedCertificate -> KeyHash
certificateFingerprint = signedFingerprint
@@ -1165,7 +1185,7 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where
signatureAlgorithmX509 = signatureAlgorithmX509 . snd
-- | A wrapper to marshall signed ASN1 objects, like certificates.
newtype SignedObject a = SignedObject (SignedExact a)
newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a}
instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where
fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject
@@ -1173,6 +1193,20 @@ instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a)
instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where
toField (SignedObject s) = toField $ encodeSignedObject s
instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where
smpEncode (SignedObject exact) = smpEncode . Large $ encodeSignedObject exact
smpP = fmap SignedObject . decodeSignedObject . unLarge <$?> smpP
encodeCertChain :: CertificateChain -> L.NonEmpty Large
encodeCertChain cc = L.fromList $ map Large blobs
where
CertificateChainRaw blobs = encodeCertificateChain cc
certChainP :: A.Parser CertificateChain
certChainP = do
rawChain <- CertificateChainRaw . map unLarge . L.toList <$> smpP
either (fail . show) pure $ decodeCertificateChain rawChain
-- | Signature verification.
--
-- Used by SMP servers to authorize SMP commands and by SMP agents to verify messages.

View File

@@ -49,7 +49,7 @@ import Simplex.Messaging.Server
import Simplex.Messaging.Server.Stats
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport (..), THandle (..), THandleAuth (..), THandleParams (..), TProxy, Transport (..))
import Simplex.Messaging.Transport.Server (runTransportServer)
import Simplex.Messaging.Transport.Server (runTransportServer, tlsServerCredentials)
import Simplex.Messaging.Util
import System.Exit (exitFailure)
import System.IO (BufferMode (..), hPutStrLn, hSetBuffering)
@@ -82,14 +82,16 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
runTransportServer started tcpPort serverParams tCfg (runClient t)
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t)
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
runClient :: Transport c => TProxy c -> c -> M ()
runClient _ h = do
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
runClient signKey _ h = do
kh <- asks serverIdentity
ks <- atomically . C.generateKeyPair =<< asks random
NtfServerConfig {ntfServerVRange} <- asks config
liftIO (runExceptT $ ntfServerHandshake h ks kh ntfServerVRange) >>= \case
liftIO (runExceptT $ ntfServerHandshake signKey h ks kh ntfServerVRange) >>= \case
Right th -> runNtfClientTransport th
Left _ -> pure ()

View File

@@ -7,13 +7,17 @@
module Simplex.Messaging.Notifications.Transport where
import Control.Monad (forM)
import Control.Monad.Except
import Data.Attoparsec.ByteString.Char8 (Parser)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.X509 as X
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Transport
import Simplex.Messaging.Version
import Simplex.Messaging.Util (liftEitherWith)
ntfBlockSize :: Int
ntfBlockSize = 512
@@ -40,7 +44,7 @@ data NtfServerHandshake = NtfServerHandshake
{ ntfVersionRange :: VersionRange,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe C.PublicKeyX25519
authPubKey :: Maybe (X.SignedExact X.PubKey)
}
data NtfClientHandshake = NtfClientHandshake
@@ -54,13 +58,25 @@ data NtfClientHandshake = NtfClientHandshake
instance Encoding NtfServerHandshake where
smpEncode NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} =
smpEncode (ntfVersionRange, sessionId) <> encodeNtfAuthPubKey (maxVersion ntfVersionRange) authPubKey
B.concat
[ smpEncode (ntfVersionRange, sessionId),
encodeAuthEncryptCmds (maxVersion ntfVersionRange) $ C.SignedObject <$> authPubKey
]
smpP = do
(ntfVersionRange, sessionId) <- smpP
-- TODO drop SMP v6: remove special parser and make key non-optional
authPubKey <- ntfAuthPubKeyP $ maxVersion ntfVersionRange
authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP
pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey}
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authEncryptCmdsNTFVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authEncryptCmdsNTFVersion then Just <$> p else pure Nothing
instance Encoding NtfClientHandshake where
smpEncode NtfClientHandshake {ntfVersion, keyHash, authPubKey} =
smpEncode (ntfVersion, keyHash) <> encodeNtfAuthPubKey ntfVersion authPubKey
@@ -79,10 +95,11 @@ encodeNtfAuthPubKey v k
| otherwise = ""
-- | Notifcations server transport handshake.
ntfServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfServerHandshake c (k, pk) kh ntfVRange = do
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just k}
let sk = C.signX509 serverSignKey $ C.publicToX509 k
sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just sk}
getHandshake th >>= \case
NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = k'}
| keyHash /= kh ->
@@ -95,13 +112,17 @@ ntfServerHandshake c (k, pk) kh ntfVRange = do
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfClientHandshake c (k, pk) keyHash ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = k'} <- getHandshake th
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
if sessionId /= sessId
then throwError TEBadSession
else case ntfVersionRange `compatibleVersion` ntfVRange of
Just (Compatible v) -> do
sk_ <- forM sk' $ \exact -> liftEitherWith (const $ TEHandshake BAD_AUTH) $ do
serverKey <- getServerVerifyKey c
pubKey <- C.verifyX509 serverKey exact
C.x509ToPublic (pubKey, []) >>= C.pubKey
sendHandshake th $ NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = Just k}
pure $ ntfThHandle th v pk k'
pure $ ntfThHandle th v pk sk_
Nothing -> throwError $ TEHandshake VERSION
ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c

View File

@@ -133,7 +133,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
ss <- asks sockets
runTransportServerState ss started tcpPort serverParams tCfg (runClient t)
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t)
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
saveServer :: Bool -> M ()
saveServer keepMsgs = withLog closeStoreLog >> saveServerMessages keepMsgs >> saveServerStats
@@ -244,13 +246,13 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
]
liftIO $ threadDelay' interval
runClient :: Transport c => TProxy c -> c -> M ()
runClient tp h = do
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
runClient signKey tp h = do
kh <- asks serverIdentity
ks <- atomically . C.generateKeyPair =<< asks random
ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config
labelMyThread $ "smp handshake for " <> transportName tp
liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake h ks kh smpServerVRange) >>= \case
liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake signKey h ks kh smpServerVRange) >>= \case
Just (Right th) -> runClientTransport th
_ -> pure ()

View File

@@ -44,6 +44,7 @@ module Simplex.Messaging.Transport
TProxy (..),
ATransport (..),
TransportPeer (..),
getServerVerifyKey,
-- * TLS Transport
TLS (..),
@@ -71,12 +72,13 @@ module Simplex.Messaging.Transport
where
import Control.Applicative ((<|>))
import Control.Monad (forM)
import Control.Monad.Except
import Control.Monad.Trans.Except (throwE)
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.Bifunctor (bimap, first)
import Data.Bitraversable (bimapM)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -84,6 +86,8 @@ import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Default (def)
import Data.Functor (($>))
import Data.Version (showVersion)
import qualified Data.X509 as X
import qualified Data.X509.Validation as XV
import GHC.IO.Handle.Internals (ioe_EOF)
import Network.Socket
import qualified Network.TLS as T
@@ -93,7 +97,7 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON)
import Simplex.Messaging.Transport.Buffer
import Simplex.Messaging.Util (bshow, catchAll, catchAll_)
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith)
import Simplex.Messaging.Version
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
@@ -163,10 +167,12 @@ class Transport c where
transportConfig :: c -> TransportConfig
-- | Upgrade server TLS context to connection (used in the server)
getServerConnection :: TransportConfig -> T.Context -> IO c
getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c
-- | Upgrade client TLS context to connection (used in the client)
getClientConnection :: TransportConfig -> T.Context -> IO c
getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c
getServerCerts :: c -> X.CertificateChain
-- | tls-unique channel binding per RFC5929
tlsUnique :: c -> SessionId
@@ -194,6 +200,12 @@ data TProxy c = TProxy
data ATransport = forall c. Transport c => ATransport (TProxy c)
getServerVerifyKey :: Transport c => c -> Either String C.APublicVerifyKey
getServerVerifyKey c =
case getServerCerts c of
X.CertificateChain (server : _ca) -> C.x509ToPublic (X.certPubKey . X.signedObject $ X.getSigned server, []) >>= C.pubKey
_ -> Left "no certificate chain"
-- * TLS Transport
data TLS = TLS
@@ -201,6 +213,7 @@ data TLS = TLS
tlsPeer :: TransportPeer,
tlsUniq :: ByteString,
tlsBuffer :: TBuffer,
tlsServerCerts :: X.CertificateChain,
tlsTransportConfig :: TransportConfig
}
@@ -213,12 +226,12 @@ connectTLS host_ TransportConfig {logTLSErrors} params sock =
logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e
host = maybe "" (\h -> " (" <> h <> ")") host_
getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS
getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS
getTLS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO TLS
getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS
where
newTLS tlsUniq = do
tlsBuffer <- atomically newTBuffer
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer}
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
withTlsUnique peer cxt f =
@@ -253,6 +266,7 @@ instance Transport TLS where
transportConfig = tlsTransportConfig
getServerConnection = getTLS TServer
getClientConnection = getTLS TClient
getServerCerts = tlsServerCerts
tlsUnique = tlsUniq
closeConnection tls = closeTLS $ tlsContext tls
@@ -280,7 +294,7 @@ instance Transport TLS where
data THandle c = THandle
{ connection :: c,
params :: THandleParams
}
}
data THandleParams = THandleParams
{ sessionId :: SessionId,
@@ -301,7 +315,7 @@ data THandleParams = THandleParams
data THandleAuth = THandleAuth
{ peerPubKey :: C.PublicKeyX25519, -- used only in the client to combine with per-queue key
privKey :: C.PrivateKeyX25519, -- used to combine with peer's per-queue key (currently only in the server)
dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7
dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7
}
-- | TLS-unique channel binding
@@ -311,7 +325,7 @@ data ServerHandshake = ServerHandshake
{ smpVersionRange :: VersionRange,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe C.PublicKeyX25519
authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey)
}
data ClientHandshake = ClientHandshake
@@ -325,30 +339,39 @@ data ClientHandshake = ClientHandshake
instance Encoding ClientHandshake where
smpEncode ClientHandshake {smpVersion, keyHash, authPubKey} =
smpEncode (smpVersion, keyHash) <> encodeAuthPubKey smpVersion authPubKey
smpEncode (smpVersion, keyHash) <> encodeAuthEncryptCmds smpVersion authPubKey
smpP = do
(smpVersion, keyHash) <- smpP
-- TODO drop SMP v6: remove special parser and make key non-optional
authPubKey <- authPubKeyP smpVersion
authPubKey <- authEncryptCmdsP smpVersion smpP
pure ClientHandshake {smpVersion, keyHash, authPubKey}
instance Encoding ServerHandshake where
smpEncode ServerHandshake {smpVersionRange, sessionId, authPubKey} =
smpEncode (smpVersionRange, sessionId) <> encodeAuthPubKey (maxVersion smpVersionRange) authPubKey
smpEncode (smpVersionRange, sessionId) <> auth
where
auth =
encodeAuthEncryptCmds (maxVersion smpVersionRange) $
bimap C.encodeCertChain C.SignedObject <$> authPubKey
smpP = do
(smpVersionRange, sessionId) <- smpP
-- TODO drop SMP v6: remove special parser and make key non-optional
authPubKey <- authPubKeyP $ maxVersion smpVersionRange
authPubKey <- authEncryptCmdsP (maxVersion smpVersionRange) authP
pure ServerHandshake {smpVersionRange, sessionId, authPubKey}
authPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519)
authPubKeyP v = if v >= authEncryptCmdsSMPVersion then Just <$> smpP else pure Nothing
where
authP = do
cert <- C.certChainP
C.SignedObject key <- smpP
pure (cert, key)
encodeAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString
encodeAuthPubKey v k
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authEncryptCmdsSMPVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authEncryptCmdsSMPVersion then Just <$> p else pure Nothing
-- | Error of SMP encrypted transport over TCP.
data TransportError
= -- | error parsing transport block
@@ -372,6 +395,8 @@ data HandshakeError
VERSION
| -- | incorrect server identity
IDENTITY
| -- | v7 authentication failed
BAD_AUTH
deriving (Eq, Read, Show, Exception)
-- | SMP encrypted transport error parser.
@@ -409,10 +434,12 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
-- | Server SMP transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpServerHandshake c (k, pk) kh smpVRange = do
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just k}
sk = C.signX509 serverSignKey $ C.publicToX509 k
certChain = getServerCerts c
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just (certChain, sk)}
getHandshake th >>= \case
ClientHandshake {smpVersion = v, keyHash, authPubKey = k'}
| keyHash /= kh ->
@@ -425,15 +452,23 @@ smpServerHandshake c (k, pk) kh smpVRange = do
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpClientHandshake c (k, pk) keyHash smpVRange = do
smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey = k'} <- getHandshake th
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th
if sessionId /= sessId
then throwE TEBadSession
else case smpVersionRange `compatibleVersion` smpVRange of
Just (Compatible v) -> do
sk_ <- forM authPubKey $ \(X.CertificateChain cert, exact) ->
liftEitherWith (const $ TEHandshake BAD_AUTH) $ do
case cert of
[_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure ()
_ -> throwError "bad certificate"
serverKey <- getServerVerifyKey c
pubKey <- C.verifyX509 serverKey exact
C.x509ToPublic (pubKey, []) >>= C.pubKey
sendHandshake th $ ClientHandshake {smpVersion = v, keyHash, authPubKey = Just k}
pure $ smpThHandle th v pk k'
pure $ smpThHandle th v pk sk_
Nothing -> throwE $ TEHandshake VERSION
smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c

View File

@@ -22,6 +22,7 @@ where
import Control.Applicative (optional)
import Control.Logger.Simple (logError)
import Control.Monad (when)
import Control.Monad.IO.Unlift
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Attoparsec.ByteString.Char8 as A
@@ -54,6 +55,7 @@ import System.IO.Error
import Text.Read (readMaybe)
import UnliftIO.Exception (IOException)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
data TransportHost
= THIPv4 (Word8, Word8, Word8, Word8)
@@ -129,8 +131,9 @@ runTransportClient = runTLSTransportClient supportedParameters Nothing
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
serverCert <- newEmptyTMVarIO
let hostName = B.unpack $ strEncode host
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials serverCert
connectTCP = case socksProxy of
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
_ -> connectTCPClient hostName
@@ -138,7 +141,13 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
sock <- connectTCP port
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e)
let tCfg = clientTransportConfig cfg
connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg
connectTLS (Just hostName) tCfg clientParams sock >>= \tls -> do
chain <- atomically (tryTakeTMVar serverCert) >>= \case
Nothing -> do
logError "onServerCertificate didn't fire or failed to get cert chain"
closeTLS tls >> error "onServerCertificate failed"
Just c -> pure c
getClientConnection tCfg chain tls
client c `E.finally` liftIO (closeConnection c)
where
hostAddr = \case
@@ -207,19 +216,24 @@ instance ToJSON SocksProxy where
instance FromJSON SocksProxy where
parseJSON = strParseJSON "SocksProxy"
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> T.ClientParams
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ =
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> TMVar X.CertificateChain -> T.ClientParams
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
(T.defaultParamsClient host p)
{ T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_},
T.clientHooks =
def
{ T.onServerCertificate = maybe def (\cafp _ _ _ -> validateCertificateChain cafp host p) cafp_,
{ T.onServerCertificate = onServerCert,
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_
},
T.clientSupported = supported
}
where
p = B.pack port
onServerCert _ _ _ c = do
errs <- maybe def (\ca -> validateCertificateChain ca host p c) cafp_
when (null errs) $
atomically (putTMVar serverCerts c)
pure errs
validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason]
validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain]

View File

@@ -19,6 +19,7 @@ module Simplex.Messaging.Transport.Server
loadTLSServerParams,
loadFingerprint,
smpServerHandshake,
tlsServerCredentials
)
where
@@ -78,13 +79,13 @@ runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketS
runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c))
-- | Run a transport server with provided connection setup and handler.
runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
ss <- atomically newSocketState
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server
-- | Run a transport server with provided connection setup and handler.
runTransportServerSocketState :: (MonadUnliftIO m, T.TLSParams p, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do
u <- askUnliftIO
labelMyThread $ "transport server for " <> threadLabel
@@ -95,7 +96,12 @@ runTransportServerSocketState ss started getSocket threadLabel serverParams cfg
setup conn = timeout (tlsSetupTimeout cfg) $ do
labelMyThread $ threadLabel <> "/setup"
tls <- connectTLS Nothing tCfg serverParams conn
getServerConnection tCfg tls
getServerConnection tCfg (fst $ tlsServerCredentials serverParams) tls
tlsServerCredentials :: T.ServerParams -> (X.CertificateChain, X.PrivKey)
tlsServerCredentials serverParams = case T.sharedCredentials $ T.serverShared serverParams of
T.Credentials [creds] -> creds
_ -> error "server has more than one key"
-- | Run TCP server without TLS
runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()

View File

@@ -8,6 +8,7 @@ import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.X509 as X
import qualified Network.TLS as T
import Network.WebSockets
import Network.WebSockets.Stream (Stream)
@@ -29,7 +30,8 @@ data WS = WS
tlsUniq :: ByteString,
wsStream :: Stream,
wsConnection :: Connection,
wsTransportConfig :: TransportConfig
wsTransportConfig :: TransportConfig,
wsServerCerts :: X.CertificateChain
}
websocketsOpts :: ConnectionOptions
@@ -50,12 +52,15 @@ instance Transport WS where
transportConfig :: WS -> TransportConfig
transportConfig = wsTransportConfig
getServerConnection :: TransportConfig -> T.Context -> IO WS
getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS
getServerConnection = getWS TServer
getClientConnection :: TransportConfig -> T.Context -> IO WS
getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS
getClientConnection = getWS TClient
getServerCerts :: WS -> X.CertificateChain
getServerCerts = wsServerCerts
tlsUnique :: WS -> ByteString
tlsUnique = tlsUniq
@@ -79,13 +84,13 @@ instance Transport WS where
then E.throwIO TEBadBlock
else pure $ B.init s
getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS
getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS
getWS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO WS
getWS wsPeer cfg wsServerCerts cxt = withTlsUnique wsPeer cxt connectWS
where
connectWS tlsUniq = do
s <- makeTLSContextStream cxt
wsConnection <- connectPeer wsPeer s
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg}
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg, wsServerCerts}
connectPeer :: TransportPeer -> Stream -> IO Connection
connectPeer TServer = acceptClientRequest
connectPeer TClient = sendClientRequest