diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 5d021d502..9a775faa3 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -8,6 +8,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} @@ -73,9 +74,12 @@ module Simplex.Messaging.Crypto generateAuthKeyPair, generateDhKeyPair, privateToX509, + x509ToPublic, + x509ToPrivate, publicKey, signatureKeyPair, publicToX509, + encodeASNObj, -- * key encoding/decoding encodePubKey, @@ -162,10 +166,13 @@ module Simplex.Messaging.Crypto Certificate, signCertificate, signX509, + verifyX509, certificateFingerprint, signedFingerprint, SignatureAlgorithmX509 (..), SignedObject (..), + encodeCertChain, + certChainP, -- * Cryptography error type CryptoError (..), @@ -210,6 +217,7 @@ import qualified Data.ByteString.Char8 as B import Data.ByteString.Lazy (fromStrict, toStrict) import Data.Constraint (Dict (..)) import Data.Kind (Constraint, Type) +import qualified Data.List.NonEmpty as L import Data.String import Data.Type.Equality import Data.Typeable (Proxy (Proxy), Typeable) @@ -1137,6 +1145,18 @@ signX509 key = fst . objectToSignedExact f signatureAlgorithmX509 key, () ) +{-# INLINE signX509 #-} + +verifyX509 :: (ASN1Object o, Eq o, Show o) => APublicVerifyKey -> SignedExact o -> Either String o +verifyX509 key exact = do + signature <- case signedAlg of + SignatureALG_IntrinsicHash PubKeyALG_Ed25519 -> ASignature SEd25519 <$> decodeSignature signedSignature + SignatureALG_IntrinsicHash PubKeyALG_Ed448 -> ASignature SEd448 <$> decodeSignature signedSignature + _ -> Left "unknown x509 signature algorithm" + if verify key signature $ getSignedData exact then Right signedObject else Left "bad signature" + where + Signed {signedObject, signedAlg, signedSignature} = getSigned exact +{-# INLINE verifyX509 #-} certificateFingerprint :: SignedCertificate -> KeyHash certificateFingerprint = signedFingerprint @@ -1165,7 +1185,7 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where signatureAlgorithmX509 = signatureAlgorithmX509 . snd -- | A wrapper to marshall signed ASN1 objects, like certificates. -newtype SignedObject a = SignedObject (SignedExact a) +newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a} instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject @@ -1173,6 +1193,20 @@ instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where toField (SignedObject s) = toField $ encodeSignedObject s +instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where + smpEncode (SignedObject exact) = smpEncode . Large $ encodeSignedObject exact + smpP = fmap SignedObject . decodeSignedObject . unLarge <$?> smpP + +encodeCertChain :: CertificateChain -> L.NonEmpty Large +encodeCertChain cc = L.fromList $ map Large blobs + where + CertificateChainRaw blobs = encodeCertificateChain cc + +certChainP :: A.Parser CertificateChain +certChainP = do + rawChain <- CertificateChainRaw . map unLarge . L.toList <$> smpP + either (fail . show) pure $ decodeCertificateChain rawChain + -- | Signature verification. -- -- Used by SMP servers to authorize SMP commands and by SMP agents to verify messages. diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 0822f8d2d..3d61feea9 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -49,7 +49,7 @@ import Simplex.Messaging.Server import Simplex.Messaging.Server.Stats import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport (..), THandle (..), THandleAuth (..), THandleParams (..), TProxy, Transport (..)) -import Simplex.Messaging.Transport.Server (runTransportServer) +import Simplex.Messaging.Transport.Server (runTransportServer, tlsServerCredentials) import Simplex.Messaging.Util import System.Exit (exitFailure) import System.IO (BufferMode (..), hPutStrLn, hSetBuffering) @@ -82,14 +82,16 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do runServer :: (ServiceName, ATransport) -> M () runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams - runTransportServer started tcpPort serverParams tCfg (runClient t) + serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams + runTransportServer started tcpPort serverParams tCfg (runClient serverSignKey t) + fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey - runClient :: Transport c => TProxy c -> c -> M () - runClient _ h = do + runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M () + runClient signKey _ h = do kh <- asks serverIdentity ks <- atomically . C.generateKeyPair =<< asks random NtfServerConfig {ntfServerVRange} <- asks config - liftIO (runExceptT $ ntfServerHandshake h ks kh ntfServerVRange) >>= \case + liftIO (runExceptT $ ntfServerHandshake signKey h ks kh ntfServerVRange) >>= \case Right th -> runNtfClientTransport th Left _ -> pure () diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index c3289d177..a0564a079 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -7,13 +7,17 @@ module Simplex.Messaging.Notifications.Transport where +import Control.Monad (forM) import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import qualified Data.X509 as X import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Transport import Simplex.Messaging.Version +import Simplex.Messaging.Util (liftEitherWith) ntfBlockSize :: Int ntfBlockSize = 512 @@ -40,7 +44,7 @@ data NtfServerHandshake = NtfServerHandshake { ntfVersionRange :: VersionRange, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. - authPubKey :: Maybe C.PublicKeyX25519 + authPubKey :: Maybe (X.SignedExact X.PubKey) } data NtfClientHandshake = NtfClientHandshake @@ -54,13 +58,25 @@ data NtfClientHandshake = NtfClientHandshake instance Encoding NtfServerHandshake where smpEncode NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} = - smpEncode (ntfVersionRange, sessionId) <> encodeNtfAuthPubKey (maxVersion ntfVersionRange) authPubKey + B.concat + [ smpEncode (ntfVersionRange, sessionId), + encodeAuthEncryptCmds (maxVersion ntfVersionRange) $ C.SignedObject <$> authPubKey + ] + smpP = do (ntfVersionRange, sessionId) <- smpP -- TODO drop SMP v6: remove special parser and make key non-optional - authPubKey <- ntfAuthPubKeyP $ maxVersion ntfVersionRange + authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} +encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds v k + | v >= authEncryptCmdsNTFVersion = maybe "" smpEncode k + | otherwise = "" + +authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP v p = if v >= authEncryptCmdsNTFVersion then Just <$> p else pure Nothing + instance Encoding NtfClientHandshake where smpEncode NtfClientHandshake {ntfVersion, keyHash, authPubKey} = smpEncode (ntfVersion, keyHash) <> encodeNtfAuthPubKey ntfVersion authPubKey @@ -79,10 +95,11 @@ encodeNtfAuthPubKey v k | otherwise = "" -- | Notifcations server transport handshake. -ntfServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) -ntfServerHandshake c (k, pk) kh ntfVRange = do +ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c - sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just k} + let sk = C.signX509 serverSignKey $ C.publicToX509 k + sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange, authPubKey = Just sk} getHandshake th >>= \case NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = k'} | keyHash /= kh -> @@ -95,13 +112,17 @@ ntfServerHandshake c (k, pk) kh ntfVRange = do ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) ntfClientHandshake c (k, pk) keyHash ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c - NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = k'} <- getHandshake th + NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th if sessionId /= sessId then throwError TEBadSession else case ntfVersionRange `compatibleVersion` ntfVRange of Just (Compatible v) -> do + sk_ <- forM sk' $ \exact -> liftEitherWith (const $ TEHandshake BAD_AUTH) $ do + serverKey <- getServerVerifyKey c + pubKey <- C.verifyX509 serverKey exact + C.x509ToPublic (pubKey, []) >>= C.pubKey sendHandshake th $ NtfClientHandshake {ntfVersion = v, keyHash, authPubKey = Just k} - pure $ ntfThHandle th v pk k' + pure $ ntfThHandle th v pk sk_ Nothing -> throwError $ TEHandshake VERSION ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 22717521f..a653a62a5 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -133,7 +133,9 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do runServer (tcpPort, ATransport t) = do serverParams <- asks tlsServerParams ss <- asks sockets - runTransportServerState ss started tcpPort serverParams tCfg (runClient t) + serverSignKey <- either fail pure . fromTLSCredentials $ tlsServerCredentials serverParams + runTransportServerState ss started tcpPort serverParams tCfg (runClient serverSignKey t) + fromTLSCredentials (_, pk) = C.x509ToPrivate (pk, []) >>= C.privKey saveServer :: Bool -> M () saveServer keepMsgs = withLog closeStoreLog >> saveServerMessages keepMsgs >> saveServerStats @@ -244,13 +246,13 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do ] liftIO $ threadDelay' interval - runClient :: Transport c => TProxy c -> c -> M () - runClient tp h = do + runClient :: Transport c => C.APrivateSignKey -> TProxy c -> c -> M () + runClient signKey tp h = do kh <- asks serverIdentity ks <- atomically . C.generateKeyPair =<< asks random ServerConfig {smpServerVRange, smpHandshakeTimeout} <- asks config labelMyThread $ "smp handshake for " <> transportName tp - liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake h ks kh smpServerVRange) >>= \case + liftIO (timeout smpHandshakeTimeout . runExceptT $ smpServerHandshake signKey h ks kh smpServerVRange) >>= \case Just (Right th) -> runClientTransport th _ -> pure () diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 1fe00a253..f51e14fb7 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -44,6 +44,7 @@ module Simplex.Messaging.Transport TProxy (..), ATransport (..), TransportPeer (..), + getServerVerifyKey, -- * TLS Transport TLS (..), @@ -71,12 +72,13 @@ module Simplex.Messaging.Transport where import Control.Applicative ((<|>)) +import Control.Monad (forM) import Control.Monad.Except import Control.Monad.Trans.Except (throwE) import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Bifunctor (first) +import Data.Bifunctor (bimap, first) import Data.Bitraversable (bimapM) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -84,6 +86,8 @@ import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) +import qualified Data.X509 as X +import qualified Data.X509.Validation as XV import GHC.IO.Handle.Internals (ioe_EOF) import Network.Socket import qualified Network.TLS as T @@ -93,7 +97,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer -import Simplex.Messaging.Util (bshow, catchAll, catchAll_) +import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E @@ -163,10 +167,12 @@ class Transport c where transportConfig :: c -> TransportConfig -- | Upgrade server TLS context to connection (used in the server) - getServerConnection :: TransportConfig -> T.Context -> IO c + getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c -- | Upgrade client TLS context to connection (used in the client) - getClientConnection :: TransportConfig -> T.Context -> IO c + getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO c + + getServerCerts :: c -> X.CertificateChain -- | tls-unique channel binding per RFC5929 tlsUnique :: c -> SessionId @@ -194,6 +200,12 @@ data TProxy c = TProxy data ATransport = forall c. Transport c => ATransport (TProxy c) +getServerVerifyKey :: Transport c => c -> Either String C.APublicVerifyKey +getServerVerifyKey c = + case getServerCerts c of + X.CertificateChain (server : _ca) -> C.x509ToPublic (X.certPubKey . X.signedObject $ X.getSigned server, []) >>= C.pubKey + _ -> Left "no certificate chain" + -- * TLS Transport data TLS = TLS @@ -201,6 +213,7 @@ data TLS = TLS tlsPeer :: TransportPeer, tlsUniq :: ByteString, tlsBuffer :: TBuffer, + tlsServerCerts :: X.CertificateChain, tlsTransportConfig :: TransportConfig } @@ -213,12 +226,12 @@ connectTLS host_ TransportConfig {logTLSErrors} params sock = logThrow e = putStrLn ("TLS error" <> host <> ": " <> show e) >> E.throwIO e host = maybe "" (\h -> " (" <> h <> ")") host_ -getTLS :: TransportPeer -> TransportConfig -> T.Context -> IO TLS -getTLS tlsPeer cfg cxt = withTlsUnique tlsPeer cxt newTLS +getTLS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO TLS +getTLS tlsPeer cfg tlsServerCerts cxt = withTlsUnique tlsPeer cxt newTLS where newTLS tlsUniq = do tlsBuffer <- atomically newTBuffer - pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsPeer, tlsUniq, tlsBuffer} + pure TLS {tlsContext = cxt, tlsTransportConfig = cfg, tlsServerCerts, tlsPeer, tlsUniq, tlsBuffer} withTlsUnique :: TransportPeer -> T.Context -> (ByteString -> IO c) -> IO c withTlsUnique peer cxt f = @@ -253,6 +266,7 @@ instance Transport TLS where transportConfig = tlsTransportConfig getServerConnection = getTLS TServer getClientConnection = getTLS TClient + getServerCerts = tlsServerCerts tlsUnique = tlsUniq closeConnection tls = closeTLS $ tlsContext tls @@ -280,7 +294,7 @@ instance Transport TLS where data THandle c = THandle { connection :: c, params :: THandleParams - } + } data THandleParams = THandleParams { sessionId :: SessionId, @@ -301,7 +315,7 @@ data THandleParams = THandleParams data THandleAuth = THandleAuth { peerPubKey :: C.PublicKeyX25519, -- used only in the client to combine with per-queue key privKey :: C.PrivateKeyX25519, -- used to combine with peer's per-queue key (currently only in the server) - dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7 + dhSecret :: C.DhSecretX25519 -- used by both parties to encrypt entity IDs in for version >= 7 } -- | TLS-unique channel binding @@ -311,7 +325,7 @@ data ServerHandshake = ServerHandshake { smpVersionRange :: VersionRange, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. - authPubKey :: Maybe C.PublicKeyX25519 + authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) } data ClientHandshake = ClientHandshake @@ -325,30 +339,39 @@ data ClientHandshake = ClientHandshake instance Encoding ClientHandshake where smpEncode ClientHandshake {smpVersion, keyHash, authPubKey} = - smpEncode (smpVersion, keyHash) <> encodeAuthPubKey smpVersion authPubKey + smpEncode (smpVersion, keyHash) <> encodeAuthEncryptCmds smpVersion authPubKey smpP = do (smpVersion, keyHash) <- smpP -- TODO drop SMP v6: remove special parser and make key non-optional - authPubKey <- authPubKeyP smpVersion + authPubKey <- authEncryptCmdsP smpVersion smpP pure ClientHandshake {smpVersion, keyHash, authPubKey} instance Encoding ServerHandshake where smpEncode ServerHandshake {smpVersionRange, sessionId, authPubKey} = - smpEncode (smpVersionRange, sessionId) <> encodeAuthPubKey (maxVersion smpVersionRange) authPubKey + smpEncode (smpVersionRange, sessionId) <> auth + where + auth = + encodeAuthEncryptCmds (maxVersion smpVersionRange) $ + bimap C.encodeCertChain C.SignedObject <$> authPubKey smpP = do (smpVersionRange, sessionId) <- smpP -- TODO drop SMP v6: remove special parser and make key non-optional - authPubKey <- authPubKeyP $ maxVersion smpVersionRange + authPubKey <- authEncryptCmdsP (maxVersion smpVersionRange) authP pure ServerHandshake {smpVersionRange, sessionId, authPubKey} - -authPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519) -authPubKeyP v = if v >= authEncryptCmdsSMPVersion then Just <$> smpP else pure Nothing + where + authP = do + cert <- C.certChainP + C.SignedObject key <- smpP + pure (cert, key) -encodeAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString -encodeAuthPubKey v k +encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds v k | v >= authEncryptCmdsSMPVersion = maybe "" smpEncode k | otherwise = "" +authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP v p = if v >= authEncryptCmdsSMPVersion then Just <$> p else pure Nothing + -- | Error of SMP encrypted transport over TCP. data TransportError = -- | error parsing transport block @@ -372,6 +395,8 @@ data HandshakeError VERSION | -- | incorrect server identity IDENTITY + | -- | v7 authentication failed + BAD_AUTH deriving (Eq, Read, Show, Exception) -- | SMP encrypted transport error parser. @@ -409,10 +434,12 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do -- | Server SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpServerHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) -smpServerHandshake c (k, pk) kh smpVRange = do +smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c - sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just k} + sk = C.signX509 serverSignKey $ C.publicToX509 k + certChain = getServerCerts c + sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange, authPubKey = Just (certChain, sk)} getHandshake th >>= \case ClientHandshake {smpVersion = v, keyHash, authPubKey = k'} | keyHash /= kh -> @@ -425,15 +452,23 @@ smpServerHandshake c (k, pk) kh smpVRange = do -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) -smpClientHandshake c (k, pk) keyHash smpVRange = do +smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c - ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey = k'} <- getHandshake th + ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th if sessionId /= sessId then throwE TEBadSession else case smpVersionRange `compatibleVersion` smpVRange of Just (Compatible v) -> do + sk_ <- forM authPubKey $ \(X.CertificateChain cert, exact) -> + liftEitherWith (const $ TEHandshake BAD_AUTH) $ do + case cert of + [_leaf, ca] | XV.Fingerprint kh == XV.getFingerprint ca X.HashSHA256 -> pure () + _ -> throwError "bad certificate" + serverKey <- getServerVerifyKey c + pubKey <- C.verifyX509 serverKey exact + C.x509ToPublic (pubKey, []) >>= C.pubKey sendHandshake th $ ClientHandshake {smpVersion = v, keyHash, authPubKey = Just k} - pure $ smpThHandle th v pk k' + pure $ smpThHandle th v pk sk_ Nothing -> throwE $ TEHandshake VERSION smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index c36b33719..ddc08ae98 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -22,6 +22,7 @@ where import Control.Applicative (optional) import Control.Logger.Simple (logError) +import Control.Monad (when) import Control.Monad.IO.Unlift import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A @@ -54,6 +55,7 @@ import System.IO.Error import Text.Read (readMaybe) import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E +import UnliftIO.STM data TransportHost = THIPv4 (Word8, Word8, Word8, Word8) @@ -129,8 +131,9 @@ runTransportClient = runTLSTransportClient supportedParameters Nothing runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> TransportClientConfig -> Maybe ByteString -> TransportHost -> ServiceName -> Maybe C.KeyHash -> (c -> m a) -> m a runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, tcpKeepAlive, clientCredentials} proxyUsername host port keyHash client = do + serverCert <- newEmptyTMVarIO let hostName = B.unpack $ strEncode host - clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials + clientParams = mkTLSClientParams tlsParams caStore_ hostName port keyHash clientCredentials serverCert connectTCP = case socksProxy of Just proxy -> connectSocksClient proxy proxyUsername $ hostAddr host _ -> connectTCPClient hostName @@ -138,7 +141,13 @@ runTLSTransportClient tlsParams caStore_ cfg@TransportClientConfig {socksProxy, sock <- connectTCP port mapM_ (setSocketKeepAlive sock) tcpKeepAlive `catchAll` \e -> logError ("Error setting TCP keep-alive" <> tshow e) let tCfg = clientTransportConfig cfg - connectTLS (Just hostName) tCfg clientParams sock >>= getClientConnection tCfg + connectTLS (Just hostName) tCfg clientParams sock >>= \tls -> do + chain <- atomically (tryTakeTMVar serverCert) >>= \case + Nothing -> do + logError "onServerCertificate didn't fire or failed to get cert chain" + closeTLS tls >> error "onServerCertificate failed" + Just c -> pure c + getClientConnection tCfg chain tls client c `E.finally` liftIO (closeConnection c) where hostAddr = \case @@ -207,19 +216,24 @@ instance ToJSON SocksProxy where instance FromJSON SocksProxy where parseJSON = strParseJSON "SocksProxy" -mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> T.ClientParams -mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ = +mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe (X.CertificateChain, T.PrivKey) -> TMVar X.CertificateChain -> T.ClientParams +mkTLSClientParams supported caStore_ host port cafp_ clientCreds_ serverCerts = (T.defaultParamsClient host p) { T.clientShared = def {T.sharedCAStore = fromMaybe (T.sharedCAStore def) caStore_}, T.clientHooks = def - { T.onServerCertificate = maybe def (\cafp _ _ _ -> validateCertificateChain cafp host p) cafp_, + { T.onServerCertificate = onServerCert, T.onCertificateRequest = maybe def (const . pure . Just) clientCreds_ }, T.clientSupported = supported } where p = B.pack port + onServerCert _ _ _ c = do + errs <- maybe def (\ca -> validateCertificateChain ca host p c) cafp_ + when (null errs) $ + atomically (putTMVar serverCerts c) + pure errs validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason] validateCertificateChain _ _ _ (X.CertificateChain []) = pure [XV.EmptyChain] diff --git a/src/Simplex/Messaging/Transport/Server.hs b/src/Simplex/Messaging/Transport/Server.hs index 06f97a353..983068434 100644 --- a/src/Simplex/Messaging/Transport/Server.hs +++ b/src/Simplex/Messaging/Transport/Server.hs @@ -19,6 +19,7 @@ module Simplex.Messaging.Transport.Server loadTLSServerParams, loadFingerprint, smpServerHandshake, + tlsServerCredentials ) where @@ -78,13 +79,13 @@ runTransportServerState :: forall c m. (Transport c, MonadUnliftIO m) => SocketS runTransportServerState ss started port = runTransportServerSocketState ss started (startTCPServer started port) (transportName (TProxy :: TProxy c)) -- | Run a transport server with provided connection setup and handler. -runTransportServerSocket :: (MonadUnliftIO m, T.TLSParams p, Transport a) => TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m () +runTransportServerSocket :: (MonadUnliftIO m, Transport a) => TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m () runTransportServerSocket started getSocket threadLabel serverParams cfg server = do ss <- atomically newSocketState runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server -- | Run a transport server with provided connection setup and handler. -runTransportServerSocketState :: (MonadUnliftIO m, T.TLSParams p, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> p -> TransportServerConfig -> (a -> m ()) -> m () +runTransportServerSocketState :: (MonadUnliftIO m, Transport a) => SocketState -> TMVar Bool -> IO Socket -> String -> T.ServerParams -> TransportServerConfig -> (a -> m ()) -> m () runTransportServerSocketState ss started getSocket threadLabel serverParams cfg server = do u <- askUnliftIO labelMyThread $ "transport server for " <> threadLabel @@ -95,7 +96,12 @@ runTransportServerSocketState ss started getSocket threadLabel serverParams cfg setup conn = timeout (tlsSetupTimeout cfg) $ do labelMyThread $ threadLabel <> "/setup" tls <- connectTLS Nothing tCfg serverParams conn - getServerConnection tCfg tls + getServerConnection tCfg (fst $ tlsServerCredentials serverParams) tls + +tlsServerCredentials :: T.ServerParams -> (X.CertificateChain, X.PrivKey) +tlsServerCredentials serverParams = case T.sharedCredentials $ T.serverShared serverParams of + T.Credentials [creds] -> creds + _ -> error "server has more than one key" -- | Run TCP server without TLS runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO () diff --git a/src/Simplex/Messaging/Transport/WebSockets.hs b/src/Simplex/Messaging/Transport/WebSockets.hs index 4a39234b5..062f4f0f0 100644 --- a/src/Simplex/Messaging/Transport/WebSockets.hs +++ b/src/Simplex/Messaging/Transport/WebSockets.hs @@ -8,6 +8,7 @@ import qualified Control.Exception as E import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as LB +import qualified Data.X509 as X import qualified Network.TLS as T import Network.WebSockets import Network.WebSockets.Stream (Stream) @@ -29,7 +30,8 @@ data WS = WS tlsUniq :: ByteString, wsStream :: Stream, wsConnection :: Connection, - wsTransportConfig :: TransportConfig + wsTransportConfig :: TransportConfig, + wsServerCerts :: X.CertificateChain } websocketsOpts :: ConnectionOptions @@ -50,12 +52,15 @@ instance Transport WS where transportConfig :: WS -> TransportConfig transportConfig = wsTransportConfig - getServerConnection :: TransportConfig -> T.Context -> IO WS + getServerConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS getServerConnection = getWS TServer - getClientConnection :: TransportConfig -> T.Context -> IO WS + getClientConnection :: TransportConfig -> X.CertificateChain -> T.Context -> IO WS getClientConnection = getWS TClient + getServerCerts :: WS -> X.CertificateChain + getServerCerts = wsServerCerts + tlsUnique :: WS -> ByteString tlsUnique = tlsUniq @@ -79,13 +84,13 @@ instance Transport WS where then E.throwIO TEBadBlock else pure $ B.init s -getWS :: TransportPeer -> TransportConfig -> T.Context -> IO WS -getWS wsPeer cfg cxt = withTlsUnique wsPeer cxt connectWS +getWS :: TransportPeer -> TransportConfig -> X.CertificateChain -> T.Context -> IO WS +getWS wsPeer cfg wsServerCerts cxt = withTlsUnique wsPeer cxt connectWS where connectWS tlsUniq = do s <- makeTLSContextStream cxt wsConnection <- connectPeer wsPeer s - pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg} + pure $ WS {wsPeer, tlsUniq, wsStream = s, wsConnection, wsTransportConfig = cfg, wsServerCerts} connectPeer :: TransportPeer -> Stream -> IO Connection connectPeer TServer = acceptClientRequest connectPeer TClient = sendClientRequest diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 1d93e50b5..75de107d6 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -73,9 +73,9 @@ notificationTests t = do withAPNSMockServer $ \apns -> testNtfTokenChangeServers t apns describe "Managing notification subscriptions" $ do - describe "should create notification subscription for existing connection" $ + fdescribe "should create notification subscription for existing connection" $ testNtfMatrix t testNotificationSubscriptionExistingConnection - describe "should create notification subscription for new connection" $ + fdescribe "should create notification subscription for new connection" $ testNtfMatrix t testNotificationSubscriptionNewConnection it "should change notifications mode" $ withSmpServer t $ @@ -348,7 +348,7 @@ testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} alice@Agen verifyNtfToken alice tkn vNonce verification NTActive <- checkNtfToken alice tkn -- send message - liftIO $ threadDelay 50000 + liftIO $ threadDelay 250000 1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello" get bob ##> ("", aliceId, SENT $ baseId + 1) -- notification diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index 7f64c28ad..81f416643 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -8,8 +8,8 @@ module CoreTests.BatchingTests (batchingTests) where import Control.Concurrent.STM import Control.Monad import Crypto.Random (ChaChaDRG) -import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString as B +import Data.ByteString.Char8 (ByteString) import qualified Data.List.NonEmpty as L import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 361e5f14e..3cb2ea57a 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} @@ -89,6 +90,9 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do C.SEd25519 -> Just . TASignature . C.ASignature C.SEd25519 $ C.sign' pk t C.SEd448 -> Just . TASignature . C.ASignature C.SEd448 $ C.sign' pk t C.SX25519 -> (\THandleAuth {peerPubKey} -> TAAuthenticator $ C.cbAuthenticate peerPubKey pk (C.cbNonce corrId) t) <$> thAuth params +#if !MIN_VERSION_base(4,18,0) + _sx448 -> undefined -- ghc8107 fails to the branch excluded by types +#endif tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) tPut1 h t = do