mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-14 22:36:29 +00:00
transport: fetch and store server certificate (#985)
* THandleParams (WIP, does not compile) * transport: fetch and store server certificate * smp: add getOnlinePubKey example to smpClientHandshake * add server certs and sign authPub * cleanup * update * style * load server certs from test fixtures * sign ntf authPubKey * fix onServerCertificate * increase delay before sending messages * require certificate with key in SMP server handshake --------- Co-authored-by: Evgeny Poberezkin <evgeny@poberezkin.com>
This commit is contained in:
committed by
GitHub
parent
6aec0b13fd
commit
76eddfbc9d
@@ -8,6 +8,7 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
@@ -73,9 +74,12 @@ module Simplex.Messaging.Crypto
|
||||
generateAuthKeyPair,
|
||||
generateDhKeyPair,
|
||||
privateToX509,
|
||||
x509ToPublic,
|
||||
x509ToPrivate,
|
||||
publicKey,
|
||||
signatureKeyPair,
|
||||
publicToX509,
|
||||
encodeASNObj,
|
||||
|
||||
-- * key encoding/decoding
|
||||
encodePubKey,
|
||||
@@ -162,10 +166,13 @@ module Simplex.Messaging.Crypto
|
||||
Certificate,
|
||||
signCertificate,
|
||||
signX509,
|
||||
verifyX509,
|
||||
certificateFingerprint,
|
||||
signedFingerprint,
|
||||
SignatureAlgorithmX509 (..),
|
||||
SignedObject (..),
|
||||
encodeCertChain,
|
||||
certChainP,
|
||||
|
||||
-- * Cryptography error type
|
||||
CryptoError (..),
|
||||
@@ -210,6 +217,7 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.ByteString.Lazy (fromStrict, toStrict)
|
||||
import Data.Constraint (Dict (..))
|
||||
import Data.Kind (Constraint, Type)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.String
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable (Proxy (Proxy), Typeable)
|
||||
@@ -1137,6 +1145,18 @@ signX509 key = fst . objectToSignedExact f
|
||||
signatureAlgorithmX509 key,
|
||||
()
|
||||
)
|
||||
{-# INLINE signX509 #-}
|
||||
|
||||
verifyX509 :: (ASN1Object o, Eq o, Show o) => APublicVerifyKey -> SignedExact o -> Either String o
|
||||
verifyX509 key exact = do
|
||||
signature <- case signedAlg of
|
||||
SignatureALG_IntrinsicHash PubKeyALG_Ed25519 -> ASignature SEd25519 <$> decodeSignature signedSignature
|
||||
SignatureALG_IntrinsicHash PubKeyALG_Ed448 -> ASignature SEd448 <$> decodeSignature signedSignature
|
||||
_ -> Left "unknown x509 signature algorithm"
|
||||
if verify key signature $ getSignedData exact then Right signedObject else Left "bad signature"
|
||||
where
|
||||
Signed {signedObject, signedAlg, signedSignature} = getSigned exact
|
||||
{-# INLINE verifyX509 #-}
|
||||
|
||||
certificateFingerprint :: SignedCertificate -> KeyHash
|
||||
certificateFingerprint = signedFingerprint
|
||||
@@ -1165,7 +1185,7 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where
|
||||
signatureAlgorithmX509 = signatureAlgorithmX509 . snd
|
||||
|
||||
-- | A wrapper to marshall signed ASN1 objects, like certificates.
|
||||
newtype SignedObject a = SignedObject (SignedExact a)
|
||||
newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a}
|
||||
|
||||
instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where
|
||||
fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject
|
||||
@@ -1173,6 +1193,20 @@ instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a)
|
||||
instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where
|
||||
toField (SignedObject s) = toField $ encodeSignedObject s
|
||||
|
||||
instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where
|
||||
smpEncode (SignedObject exact) = smpEncode . Large $ encodeSignedObject exact
|
||||
smpP = fmap SignedObject . decodeSignedObject . unLarge <$?> smpP
|
||||
|
||||
encodeCertChain :: CertificateChain -> L.NonEmpty Large
|
||||
encodeCertChain cc = L.fromList $ map Large blobs
|
||||
where
|
||||
CertificateChainRaw blobs = encodeCertificateChain cc
|
||||
|
||||
certChainP :: A.Parser CertificateChain
|
||||
certChainP = do
|
||||
rawChain <- CertificateChainRaw . map unLarge . L.toList <$> smpP
|
||||
either (fail . show) pure $ decodeCertificateChain rawChain
|
||||
|
||||
-- | Signature verification.
|
||||
--
|
||||
-- Used by SMP servers to authorize SMP commands and by SMP agents to verify messages.
|
||||
|
||||
@@ -49,7 +49,7 @@ import Simplex.Messaging.Server
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), THandleAuth (..), THandleParams (..), TProxy, Transport (..))
|
||||
import Simplex.Messaging.Transport.Server (runTransportServer)
|
||||
import Simplex.Messaging.Transport.Server (runTransportServer, tlsServerCredentials)
|
||||
import Simplex.Messaging.Util
|
||||
import System.Exit (exitFailure)
|
||||
import System.IO (BufferMode (..), hPutStrLn, hSetBuffering)
|
||||
@@ -82,14 +82,16 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do
|
||||
runServer :: (ServiceName, ATransport) -> M ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
runTransportServer started tcpPort serverParams tCfg (runClient t)
|
||||
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
|
||||
runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t)
|
||||
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
|
||||
|
||||
runClient :: Transport c => TProxy c -> c -> M ()
|
||||
runClient _ h = do
|
||||
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
|
||||
runClient signKey _ h = do
|
||||
kh <- asks serverIdentity
|
||||
ks <- atomically . C.generateKeyPair =<< asks random
|
||||
NtfServerConfig {ntfServerVRange} <- asks config
|
||||
liftIO (runExceptT $ ntfServerHandshake h ks kh ntfServerVRange) >>= \case
|
||||
liftIO (runExceptT $ ntfServerHandshake signKey h ks kh ntfServerVRange) >>= \case
|
||||
Right th -> runNtfClientTransport th
|
||||
Left _ -> pure ()
|
||||
|
||||
|
||||
@@ -7,13 +7,17 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Transport where
|
||||
|
||||
import Control.Monad (forM)
|
||||
import Control.Monad.Except
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.X509 as X
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Util (liftEitherWith)
|
||||
|
||||
ntfBlockSize :: Int
|
||||
ntfBlockSize = 512
|
||||
@@ -40,7 +44,7 @@ data NtfServerHandshake = NtfServerHandshake
|
||||
{ ntfVersionRange :: VersionRange,
|
||||
sessionId :: SessionId,
|
||||
-- pub key to agree shared secrets for command authorization and entity ID encryption.
|
||||
authPubKey :: Maybe C.PublicKeyX25519
|
||||
authPubKey :: Maybe (X.SignedExact X.PubKey)
|
||||
}
|
||||
|
||||
data NtfClientHandshake = NtfClientHandshake
|
||||
@@ -54,13 +58,25 @@ data NtfClientHandshake = NtfClientHandshake
|
||||
|
||||
instance Encoding NtfServerHandshake where
|
||||
smpEncode NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} =
|
||||
smpEncode (ntfVersionRange, sessionId) <> encodeNtfAuthPubKey (maxVersion ntfVersionRange) authPubKey
|
||||
B.concat
|
||||
[ smpEncode (ntfVersionRange, sessionId),
|
||||
encodeAuthEncryptCmds (maxVersion ntfVersionRange) $ C.SignedObject <$> authPubKey
|
||||
]
|
||||
|
||||
smpP = do
|
||||
(ntfVersionRange, sessionId) <- smpP
|
||||
-- TODO drop SMP v6: remove special parser and make key non-optional
|
||||
authPubKey <- ntfAuthPubKeyP $ maxVersion ntfVersionRange
|
||||
authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP
|
||||
pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey}
|
||||
|
||||
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds v k
|
||||
| v >= authEncryptCmdsNTFVersion = maybe "" smpEncode k
|
||||
| otherwise = ""
|
||||
|
||||
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP v p = if v >= authEncryptCmdsNTFVersion then Just <$> p else pure Nothing
|
||||
|
||||
instance Encoding NtfClientHandshake where
|
||||
smpEncode NtfClientHandshake {ntfVersion, keyHash, authPubKey} =
|
||||
smpEncode (ntfVersion, keyHash) <> encodeNtfAuthPubKey ntfVersion authPubKey
|
||||
@@ -79,10 +95,11 @@ encodeNtfAuthPubKey v k
|
||||
| otherwise = ""
|
||||
|
||||
-- | Notifcations server transport handshake.
|
||||
ntfServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfServerHandshake c (k, pk) kh ntfVRange = do
|
||||
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
|
||||
sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just k}
|
||||
let sk = C.signX509 serverSignKey $ C.publicToX509 k
|
||||
sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just sk}
|
||||
getHandshake th >>= \case
|
||||
NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = k'}
|
||||
| keyHash /= kh ->
|
||||
@@ -95,13 +112,17 @@ ntfServerHandshake c (k, pk) kh ntfVRange = do
|
||||
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfClientHandshake c (k, pk) keyHash ntfVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
|
||||
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = k'} <- getHandshake th
|
||||
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
|
||||
if sessionId /= sessId
|
||||
then throwError TEBadSession
|
||||
else case ntfVersionRange `compatibleVersion` ntfVRange of
|
||||
Just (Compatible v) -> do
|
||||
sk_ <- forM sk' $ \exact -> liftEitherWith (const $ TEHandshake BAD_AUTH) $ do
|
||||
serverKey <- getServerVerifyKey c
|
||||
pubKey <- C.verifyX509 serverKey exact
|
||||
C.x509ToPublic (pubKey, []) >>= C.pubKey
|
||||
sendHandshake th $ NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = Just k}
|
||||
pure $ ntfThHandle th v pk k'
|
||||
pure $ ntfThHandle th v pk sk_
|
||||
Nothing -> throwError $ TEHandshake VERSION
|
||||
|
||||
ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
|
||||
|
||||
@@ -133,7 +133,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
ss <- asks sockets
|
||||
runTransportServerState ss started tcpPort serverParams tCfg (runClient t)
|
||||
serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams
|
||||
runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t)
|
||||
fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey
|
||||
|
||||
saveServer :: Bool -> M ()
|
||||
saveServer keepMsgs = withLog closeStoreLog >> saveServerMessages keepMsgs >> saveServerStats
|
||||
@@ -244,13 +246,13 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
]
|
||||
liftIO $ threadDelay' interval
|
||||
|
||||
runClient :: Transport c => TProxy c -> c -> M ()
|
||||
runClient tp h = do
|
||||
runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M ()
|
||||
runClient signKey tp h = do
|
||||
kh <- asks serverIdentity
|
||||
ks <- atomically . C.generateKeyPair =<< asks random
|
||||
ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config
|
||||
labelMyThread $ "smp handshake for " <> transportName tp
|
||||
liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake h ks kh smpServerVRange) >>= \case
|
||||
liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake signKey h ks kh smpServerVRange) >>= \case
|
||||
Just (Right th) -> runClientTransport th
|
||||
_ -> pure ()
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ module Simplex.Messaging.Transport
|
||||
TProxy (..),
|
||||
ATransport (..),
|
||||
TransportPeer (..),
|
||||
getServerVerifyKey,
|
||||
|
||||
-- * TLS Transport
|
||||
TLS (..),
|
||||
@@ -71,12 +72,13 @@ module Simplex.Messaging.Transport
|
||||
where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad (forM)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except (throwE)
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import Data.Bitraversable (bimapM)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -84,6 +86,8 @@ import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Default (def)
|
||||
import Data.Functor (($>))
|
||||
import Data.Version (showVersion)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Data.X509.Validation as XV
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
@@ -93,7 +97,7 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport.Buffer
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_)
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith)
|
||||
import Simplex.Messaging.Version
|
||||
import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -163,10 +167,12 @@ class Transport c where
|
||||
transportConfig :: c -> TransportConfig
|
||||
|
||||
-- | Upgrade server TLS context to connection (used in the server)
|
||||
getServerConnection :: TransportConfig -> T.Context -> IO c
|
||||
getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c
|
||||
|
||||
-- | Upgrade client TLS context to connection (used in the client)
|
||||
getClientConnection :: TransportConfig -> T.Context -> IO c
|
||||
getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c
|
||||
|
||||
getServerCerts :: c -> X.CertificateChain
|
||||
|
||||
-- | tls-unique channel binding per RFC5929
|
||||
tlsUnique :: c -> SessionId
|
||||
@@ -194,6 +200,12 @@ data TProxy c = TProxy
|
||||
|
||||
data ATransport = forall c. Transport c => ATransport (TProxy c)
|
||||
|
||||
getServerVerifyKey :: Transport c => c -> Either String C.APublicVerifyKey
|
||||
getServerVerifyKey c =
|
||||
case getServerCerts c of
|
||||
X.CertificateChain (server : _ca) -> C.x509ToPublic (X.certPubKey . X.signedObject $ X.getSigned server, []) >>= C.pubKey
|
||||
_ -> Left "no certificate chain"
|
||||
|
||||
-- * TLS Transport
|
||||
|
||||
data TLS = TLS
|
||||
@@ -201,6 +213,7 @@ data TLS = TLS
|
||||
tlsPeer :: TransportPeer,
|
||||
tlsUniq :: ByteString,
|
||||
tlsBuffer :: TBuffer,
|
||||
tlsServerCerts :: X.CertificateChain,
|
||||
tlsTransportConfig :: TransportConfig
|
||||
}
|
||||
|
||||
@@ -213,12 +226,12 @@ connectTLS host_ TransportConfig {logTLSErrors} params sock =
|
||||
logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e
|
||||
host = maybe "" (\h -> " (" <> h <> ")") host_
|
||||
|
||||
getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS
|
||||
getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS
|
||||
getTLS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO TLS
|
||||
getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS
|
||||
where
|
||||
newTLS tlsUniq = do
|
||||
tlsBuffer <- atomically newTBuffer
|
||||
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer}
|
||||
pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer}
|
||||
|
||||
withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c
|
||||
withTlsUnique peer cxt f =
|
||||
@@ -253,6 +266,7 @@ instance Transport TLS where
|
||||
transportConfig = tlsTransportConfig
|
||||
getServerConnection = getTLS TServer
|
||||
getClientConnection = getTLS TClient
|
||||
getServerCerts = tlsServerCerts
|
||||
tlsUnique = tlsUniq
|
||||
closeConnection tls = closeTLS $ tlsContext tls
|
||||
|
||||
@@ -280,7 +294,7 @@ instance Transport TLS where
|
||||
data THandle c = THandle
|
||||
{ connection :: c,
|
||||
params :: THandleParams
|
||||
}
|
||||
}
|
||||
|
||||
data THandleParams = THandleParams
|
||||
{ sessionId :: SessionId,
|
||||
@@ -301,7 +315,7 @@ data THandleParams = THandleParams
|
||||
data THandleAuth = THandleAuth
|
||||
{ peerPubKey :: C.PublicKeyX25519, -- used only in the client to combine with per-queue key
|
||||
privKey :: C.PrivateKeyX25519, -- used to combine with peer's per-queue key (currently only in the server)
|
||||
dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7
|
||||
dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7
|
||||
}
|
||||
|
||||
-- | TLS-unique channel binding
|
||||
@@ -311,7 +325,7 @@ data ServerHandshake = ServerHandshake
|
||||
{ smpVersionRange :: VersionRange,
|
||||
sessionId :: SessionId,
|
||||
-- pub key to agree shared secrets for command authorization and entity ID encryption.
|
||||
authPubKey :: Maybe C.PublicKeyX25519
|
||||
authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey)
|
||||
}
|
||||
|
||||
data ClientHandshake = ClientHandshake
|
||||
@@ -325,30 +339,39 @@ data ClientHandshake = ClientHandshake
|
||||
|
||||
instance Encoding ClientHandshake where
|
||||
smpEncode ClientHandshake {smpVersion, keyHash, authPubKey} =
|
||||
smpEncode (smpVersion, keyHash) <> encodeAuthPubKey smpVersion authPubKey
|
||||
smpEncode (smpVersion, keyHash) <> encodeAuthEncryptCmds smpVersion authPubKey
|
||||
smpP = do
|
||||
(smpVersion, keyHash) <- smpP
|
||||
-- TODO drop SMP v6: remove special parser and make key non-optional
|
||||
authPubKey <- authPubKeyP smpVersion
|
||||
authPubKey <- authEncryptCmdsP smpVersion smpP
|
||||
pure ClientHandshake {smpVersion, keyHash, authPubKey}
|
||||
|
||||
instance Encoding ServerHandshake where
|
||||
smpEncode ServerHandshake {smpVersionRange, sessionId, authPubKey} =
|
||||
smpEncode (smpVersionRange, sessionId) <> encodeAuthPubKey (maxVersion smpVersionRange) authPubKey
|
||||
smpEncode (smpVersionRange, sessionId) <> auth
|
||||
where
|
||||
auth =
|
||||
encodeAuthEncryptCmds (maxVersion smpVersionRange) $
|
||||
bimap C.encodeCertChain C.SignedObject <$> authPubKey
|
||||
smpP = do
|
||||
(smpVersionRange, sessionId) <- smpP
|
||||
-- TODO drop SMP v6: remove special parser and make key non-optional
|
||||
authPubKey <- authPubKeyP $ maxVersion smpVersionRange
|
||||
authPubKey <- authEncryptCmdsP (maxVersion smpVersionRange) authP
|
||||
pure ServerHandshake {smpVersionRange, sessionId, authPubKey}
|
||||
|
||||
authPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519)
|
||||
authPubKeyP v = if v >= authEncryptCmdsSMPVersion then Just <$> smpP else pure Nothing
|
||||
where
|
||||
authP = do
|
||||
cert <- C.certChainP
|
||||
C.SignedObject key <- smpP
|
||||
pure (cert, key)
|
||||
|
||||
encodeAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString
|
||||
encodeAuthPubKey v k
|
||||
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds v k
|
||||
| v >= authEncryptCmdsSMPVersion = maybe "" smpEncode k
|
||||
| otherwise = ""
|
||||
|
||||
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP v p = if v >= authEncryptCmdsSMPVersion then Just <$> p else pure Nothing
|
||||
|
||||
-- | Error of SMP encrypted transport over TCP.
|
||||
data TransportError
|
||||
= -- | error parsing transport block
|
||||
@@ -372,6 +395,8 @@ data HandshakeError
|
||||
VERSION
|
||||
| -- | incorrect server identity
|
||||
IDENTITY
|
||||
| -- | v7 authentication failed
|
||||
BAD_AUTH
|
||||
deriving (Eq, Read, Show, Exception)
|
||||
|
||||
-- | SMP encrypted transport error parser.
|
||||
@@ -409,10 +434,12 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
|
||||
-- | Server SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
smpServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpServerHandshake c (k, pk) kh smpVRange = do
|
||||
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
|
||||
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just k}
|
||||
sk = C.signX509 serverSignKey $ C.publicToX509 k
|
||||
certChain = getServerCerts c
|
||||
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just (certChain, sk)}
|
||||
getHandshake th >>= \case
|
||||
ClientHandshake {smpVersion = v, keyHash, authPubKey = k'}
|
||||
| keyHash /= kh ->
|
||||
@@ -425,15 +452,23 @@ smpServerHandshake c (k, pk) kh smpVRange = do
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpClientHandshake c (k, pk) keyHash smpVRange = do
|
||||
smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
|
||||
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey = k'} <- getHandshake th
|
||||
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th
|
||||
if sessionId /= sessId
|
||||
then throwE TEBadSession
|
||||
else case smpVersionRange `compatibleVersion` smpVRange of
|
||||
Just (Compatible v) -> do
|
||||
sk_ <- forM authPubKey $ \(X.CertificateChain cert, exact) ->
|
||||
liftEitherWith (const $ TEHandshake BAD_AUTH) $ do
|
||||
case cert of
|
||||
[_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure ()
|
||||
_ -> throwError "bad certificate"
|
||||
serverKey <- getServerVerifyKey c
|
||||
pubKey <- C.verifyX509 serverKey exact
|
||||
C.x509ToPublic (pubKey, []) >>= C.pubKey
|
||||
sendHandshake th $ ClientHandshake {smpVersion = v, keyHash, authPubKey = Just k}
|
||||
pure $ smpThHandle th v pk k'
|
||||
pure $ smpThHandle th v pk sk_
|
||||
Nothing -> throwE $ TEHandshake VERSION
|
||||
|
||||
smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
|
||||
|
||||
@@ -22,6 +22,7 @@ where
|
||||
|
||||
import Control.Applicative (optional)
|
||||
import Control.Logger.Simple (logError)
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
@@ -54,6 +55,7 @@ import System.IO.Error
|
||||
import Text.Read (readMaybe)
|
||||
import UnliftIO.Exception (IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
data TransportHost
|
||||
= THIPv4 (Word8, Word8, Word8, Word8)
|
||||
@@ -129,8 +131,9 @@ runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do
|
||||
serverCert <- newEmptyTMVarIO
|
||||
let hostName = B.unpack $ strEncode host
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials
|
||||
clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials serverCert
|
||||
connectTCP = case socksProxy of
|
||||
Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host
|
||||
_ -> connectTCPClient hostName
|
||||
@@ -138,7 +141,13 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy,
|
||||
sock <- connectTCP port
|
||||
mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e)
|
||||
let tCfg = clientTransportConfig cfg
|
||||
connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg
|
||||
connectTLS (Just hostName) tCfg clientParams sock >>= \tls -> do
|
||||
chain <- atomically (tryTakeTMVar serverCert) >>= \case
|
||||
Nothing -> do
|
||||
logError "onServerCertificate didn't fire or failed to get cert chain"
|
||||
closeTLS tls >> error "onServerCertificate failed"
|
||||
Just c -> pure c
|
||||
getClientConnection tCfg chain tls
|
||||
client c `E.finally` liftIO (closeConnection c)
|
||||
where
|
||||
hostAddr = \case
|
||||
@@ -207,19 +216,24 @@ instance ToJSON SocksProxy where
|
||||
instance FromJSON SocksProxy where
|
||||
parseJSON = strParseJSON "SocksProxy"
|
||||
|
||||
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> T.ClientParams
|
||||
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ =
|
||||
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> TMVar X.CertificateChain -> T.ClientParams
|
||||
mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts =
|
||||
(T.defaultParamsClient host p)
|
||||
{ T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_},
|
||||
T.clientHooks =
|
||||
def
|
||||
{ T.onServerCertificate = maybe def (\cafp _ _ _ -> validateCertificateChain cafp host p) cafp_,
|
||||
{ T.onServerCertificate = onServerCert,
|
||||
T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_
|
||||
},
|
||||
T.clientSupported = supported
|
||||
}
|
||||
where
|
||||
p = B.pack port
|
||||
onServerCert _ _ _ c = do
|
||||
errs <- maybe def (\ca -> validateCertificateChain ca host p c) cafp_
|
||||
when (null errs) $
|
||||
atomically (putTMVar serverCerts c)
|
||||
pure errs
|
||||
|
||||
validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason]
|
||||
validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain]
|
||||
|
||||
@@ -19,6 +19,7 @@ module Simplex.Messaging.Transport.Server
|
||||
loadTLSServerParams,
|
||||
loadFingerprint,
|
||||
smpServerHandshake,
|
||||
tlsServerCredentials
|
||||
)
|
||||
where
|
||||
|
||||
@@ -78,13 +79,13 @@ runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketS
|
||||
runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c))
|
||||
|
||||
-- | Run a transport server with provided connection setup and handler.
|
||||
runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocket started getSocket threadLabel serverParams cfg server = do
|
||||
ss <- atomically newSocketState
|
||||
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server
|
||||
|
||||
-- | Run a transport server with provided connection setup and handler.
|
||||
runTransportServerSocketState :: (MonadUnliftIO m, T.TLSParams p, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m ()
|
||||
runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do
|
||||
u <- askUnliftIO
|
||||
labelMyThread $ "transport server for " <> threadLabel
|
||||
@@ -95,7 +96,12 @@ runTransportServerSocketState ss started getSocket threadLabel serverParams cfg
|
||||
setup conn = timeout (tlsSetupTimeout cfg) $ do
|
||||
labelMyThread $ threadLabel <> "/setup"
|
||||
tls <- connectTLS Nothing tCfg serverParams conn
|
||||
getServerConnection tCfg tls
|
||||
getServerConnection tCfg (fst $ tlsServerCredentials serverParams) tls
|
||||
|
||||
tlsServerCredentials :: T.ServerParams -> (X.CertificateChain, X.PrivKey)
|
||||
tlsServerCredentials serverParams = case T.sharedCredentials $ T.serverShared serverParams of
|
||||
T.Credentials [creds] -> creds
|
||||
_ -> error "server has more than one key"
|
||||
|
||||
-- | Run TCP server without TLS
|
||||
runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
|
||||
|
||||
@@ -8,6 +8,7 @@ import qualified Control.Exception as E
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy as LB
|
||||
import qualified Data.X509 as X
|
||||
import qualified Network.TLS as T
|
||||
import Network.WebSockets
|
||||
import Network.WebSockets.Stream (Stream)
|
||||
@@ -29,7 +30,8 @@ data WS = WS
|
||||
tlsUniq :: ByteString,
|
||||
wsStream :: Stream,
|
||||
wsConnection :: Connection,
|
||||
wsTransportConfig :: TransportConfig
|
||||
wsTransportConfig :: TransportConfig,
|
||||
wsServerCerts :: X.CertificateChain
|
||||
}
|
||||
|
||||
websocketsOpts :: ConnectionOptions
|
||||
@@ -50,12 +52,15 @@ instance Transport WS where
|
||||
transportConfig :: WS -> TransportConfig
|
||||
transportConfig = wsTransportConfig
|
||||
|
||||
getServerConnection :: TransportConfig -> T.Context -> IO WS
|
||||
getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS
|
||||
getServerConnection = getWS TServer
|
||||
|
||||
getClientConnection :: TransportConfig -> T.Context -> IO WS
|
||||
getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS
|
||||
getClientConnection = getWS TClient
|
||||
|
||||
getServerCerts :: WS -> X.CertificateChain
|
||||
getServerCerts = wsServerCerts
|
||||
|
||||
tlsUnique :: WS -> ByteString
|
||||
tlsUnique = tlsUniq
|
||||
|
||||
@@ -79,13 +84,13 @@ instance Transport WS where
|
||||
then E.throwIO TEBadBlock
|
||||
else pure $ B.init s
|
||||
|
||||
getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS
|
||||
getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS
|
||||
getWS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO WS
|
||||
getWS wsPeer cfg wsServerCerts cxt = withTlsUnique wsPeer cxt connectWS
|
||||
where
|
||||
connectWS tlsUniq = do
|
||||
s <- makeTLSContextStream cxt
|
||||
wsConnection <- connectPeer wsPeer s
|
||||
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg}
|
||||
pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg, wsServerCerts}
|
||||
connectPeer :: TransportPeer -> Stream -> IO Connection
|
||||
connectPeer TServer = acceptClientRequest
|
||||
connectPeer TClient = sendClientRequest
|
||||
|
||||
Reference in New Issue
Block a user