Merge branch 'master' into pq

This commit is contained in:
Evgeny Poberezkin
2024-03-04 20:13:01 +00:00
44 changed files with 851 additions and 574 deletions
+133 -48
View File
@@ -3,6 +3,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
@@ -14,10 +15,64 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
-- {-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Simplex.Messaging.Crypto.Ratchet where
module Simplex.Messaging.Crypto.Ratchet
( Ratchet (..),
RatchetX448,
SkippedMsgDiff (..),
SkippedMsgKeys,
InitialKeys (..),
PQEncryption (..),
pattern PQEncOn,
pattern PQEncOff,
AUseKEM (..),
RatchetKEMState (..),
SRatchetKEMState (..),
RcvPrivRKEMParams,
APrivRKEMParams (..),
RcvE2ERatchetParamsUri,
RcvE2ERatchetParams,
SndE2ERatchetParams,
AE2ERatchetParams (..),
E2ERatchetParamsUri (..),
E2ERatchetParams (..),
VersionE2E,
VersionRangeE2E,
pattern VersionE2E,
kdfX3DHE2EEncryptVersion,
pqRatchetVersion,
currentE2EEncryptVersion,
supportedE2EEncryptVRange,
generateRcvE2EParams,
generateSndE2EParams,
initialPQEncryption,
connPQEncryption,
joinContactInitialKeys,
replyKEM_,
pqX3dhSnd,
pqX3dhRcv,
initSndRatchet,
initRcvRatchet,
rcEncrypt,
rcDecrypt,
-- used in tests
MsgHeader (..),
RatchetVersions (..),
RatchetInitParams (..),
UseKEM (..),
RKEMParams (..),
ARKEMParams (..),
SndRatchet (..),
RcvRatchet (..),
RatchetKEM (..),
RatchetKEMAccepted (..),
RatchetKey (..),
ratchetVersions,
fullHeaderLen,
applySMDiff,
)
where
import Control.Applicative ((<|>))
import Control.Monad.Except
@@ -44,7 +99,7 @@ import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, isJust)
import Data.Type.Equality
import Data.Typeable (Typeable)
import Data.Word (Word32)
import Data.Word (Word16, Word32)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Agent.QueryString
@@ -55,22 +110,34 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE')
import Simplex.Messaging.Util ((<$?>), ($>>=))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import UnliftIO.STM
-- e2e encryption headers version history:
-- 1 - binary protocol encoding (1/1/2022)
-- 2 - use KDF in x3dh (10/20/2022)
kdfX3DHE2EEncryptVersion :: Version
kdfX3DHE2EEncryptVersion = 2
data E2EVersion
pqRatchetVersion :: Version
pqRatchetVersion = 3
instance VersionScope E2EVersion
currentE2EEncryptVersion :: Version
currentE2EEncryptVersion = 3
type VersionE2E = Version E2EVersion
supportedE2EEncryptVRange :: VersionRange
type VersionRangeE2E = VersionRange E2EVersion
pattern VersionE2E :: Word16 -> VersionE2E
pattern VersionE2E v = Version v
kdfX3DHE2EEncryptVersion :: VersionE2E
kdfX3DHE2EEncryptVersion = VersionE2E 2
pqRatchetVersion :: VersionE2E
pqRatchetVersion = VersionE2E 3
currentE2EEncryptVersion :: VersionE2E
currentE2EEncryptVersion = VersionE2E 3
supportedE2EEncryptVRange :: VersionRangeE2E
supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion
data RatchetKEMState
@@ -120,7 +187,7 @@ instance RatchetKEMStateI s => Encoding (RKEMParams s) where
RKParamsAccepted ct k -> smpEncode ('A', ct, k)
smpP = (\(ARKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP
instance Encoding (ARKEMParams) where
instance Encoding ARKEMParams where
smpEncode (ARKP _ ps) = smpEncode ps
smpP =
smpP >>= \case
@@ -129,7 +196,7 @@ instance Encoding (ARKEMParams) where
_ -> fail "bad ratchet KEM params"
data E2ERatchetParams (s :: RatchetKEMState) (a :: Algorithm)
= E2ERatchetParams Version (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
= E2ERatchetParams VersionE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
deriving (Show)
data AE2ERatchetParams (a :: Algorithm)
@@ -159,12 +226,12 @@ instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) w
instance AlgorithmI a => Encoding (AE2ERatchetParams a) where
smpEncode (AE2ERatchetParams _ ps) = smpEncode ps
smpP = (\(AnyE2ERatchetParams s _ ps) -> (AE2ERatchetParams s) <$> checkAlgorithm ps) <$?> smpP
smpP = (\(AnyE2ERatchetParams s _ ps) -> AE2ERatchetParams s <$> checkAlgorithm ps) <$?> smpP
instance Encoding AnyE2ERatchetParams where
smpEncode (AnyE2ERatchetParams _ _ ps) = smpEncode ps
smpP = do
v :: Version <- smpP
v :: VersionE2E <- smpP
APublicDhKey a k1 <- smpP
APublicDhKey a' k2 <- smpP
case testEquality a a' of
@@ -174,25 +241,25 @@ instance Encoding AnyE2ERatchetParams where
Just (ARKP s kem) -> pure $ AnyE2ERatchetParams s a $ E2ERatchetParams v k1 k2 (Just kem)
Nothing -> pure $ AnyE2ERatchetParams SRKSProposed a $ E2ERatchetParams v k1 k2 Nothing
where
kemP :: Version -> Parser (Maybe (ARKEMParams))
kemP :: VersionE2E -> Parser (Maybe ARKEMParams)
kemP v
| v >= pqRatchetVersion = smpP
| otherwise = pure Nothing
instance VersionI (E2ERatchetParams s a) where
type VersionRangeT (E2ERatchetParams s a) = E2ERatchetParamsUri s a
instance VersionI E2EVersion (E2ERatchetParams s a) where
type VersionRangeT E2EVersion (E2ERatchetParams s a) = E2ERatchetParamsUri s a
version (E2ERatchetParams v _ _ _) = v
toVersionRangeT (E2ERatchetParams _ k1 k2 kem_) vr = E2ERatchetParamsUri vr k1 k2 kem_
instance VersionRangeI (E2ERatchetParamsUri s a) where
type VersionT (E2ERatchetParamsUri s a) = (E2ERatchetParams s a)
instance VersionRangeI E2EVersion (E2ERatchetParamsUri s a) where
type VersionT E2EVersion (E2ERatchetParamsUri s a) = (E2ERatchetParams s a)
versionRange (E2ERatchetParamsUri vr _ _ _) = vr
toVersionT (E2ERatchetParamsUri _ k1 k2 kem_) v = E2ERatchetParams v k1 k2 kem_
type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a
data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm)
= E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
= E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s))
deriving (Show)
data AE2ERatchetParamsUri (a :: Algorithm)
@@ -228,13 +295,13 @@ instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri
instance AlgorithmI a => StrEncoding (AE2ERatchetParamsUri a) where
strEncode (AE2ERatchetParamsUri _ ps) = strEncode ps
strP = (\(AnyE2ERatchetParamsUri s _ ps) -> (AE2ERatchetParamsUri s) <$> checkAlgorithm ps) <$?> strP
strP = (\(AnyE2ERatchetParamsUri s _ ps) -> AE2ERatchetParamsUri s <$> checkAlgorithm ps) <$?> strP
instance StrEncoding AnyE2ERatchetParamsUri where
strEncode (AnyE2ERatchetParamsUri _ _ ps) = strEncode ps
strP = do
query <- strP
vr :: VersionRange <- queryParam "v" query
vr :: VersionRangeE2E <- queryParam "v" query
keys <- L.toList <$> queryParam "x3dh" query
case keys of
[APublicDhKey a k1, APublicDhKey a' k2] -> case testEquality a a' of
@@ -248,7 +315,7 @@ instance StrEncoding AnyE2ERatchetParamsUri where
kemP vr query
| maxVersion vr >= pqRatchetVersion =
queryParam_ "kem_key" query
$>>= \k -> (Just . kemParams k <$> queryParam_ "kem_ct" query)
$>>= \k -> Just . kemParams k <$> queryParam_ "kem_ct" query
| otherwise = pure Nothing
kemParams k = \case
Nothing -> ARKP SRKSProposed $ RKParamsProposed k
@@ -272,7 +339,7 @@ instance RatchetKEMStateI s => Encoding (PrivRKEMParams s) where
PrivateRKParamsAccepted ct shared k -> smpEncode ('A', ct, shared, k)
smpP = (\(APRKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP
instance Encoding (APrivRKEMParams) where
instance Encoding APrivRKEMParams where
smpEncode (APRKP _ ps) = smpEncode ps
smpP =
smpP >>= \case
@@ -290,7 +357,7 @@ data UseKEM (s :: RatchetKEMState) where
data AUseKEM = forall s. RatchetKEMStateI s => AUseKEM (SRatchetKEMState s) (UseKEM s)
generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a)
generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a)
generateE2EParams g v useKEM_ = do
(k1, pk1) <- atomically $ generateKeyPair g
(k2, pk2) <- atomically $ generateKeyPair g
@@ -309,7 +376,7 @@ generateE2EParams g v useKEM_ = do
_ -> pure Nothing
-- used by party initiating connection, Bob in double-ratchet spec
generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a)
generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQEncryption -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a)
generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_
where
proposeKEM_ :: PQEncryption -> Maybe (UseKEM 'RKSProposed)
@@ -319,7 +386,7 @@ generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_
-- used by party accepting connection, Alice in double-ratchet spec
generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a)
generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a)
generateSndE2EParams g v = \case
Nothing -> do
(pk1, pk2, _, e2eParams) <- generateE2EParams g v Nothing
@@ -385,7 +452,7 @@ type RatchetX448 = Ratchet 'X448
data Ratchet a = Ratchet
{ -- ratchet version range sent in messages (current .. max supported ratchet version)
rcVersion :: VersionRange,
rcVersion :: RatchetVersions,
-- associated data - must be the same in both parties ratchets
rcAD :: Str,
rcDHRs :: PrivateKey a,
@@ -405,6 +472,29 @@ data Ratchet a = Ratchet
}
deriving (Show)
data RatchetVersions = RVersions
{ current :: VersionE2E,
maxSupported :: VersionE2E
}
deriving (Eq, Show)
instance ToJSON RatchetVersions where
-- TODO v5.7 or v5.8 change to the default record encoding
toJSON (RVersions v1 v2) = toJSON (v1, v2)
toEncoding (RVersions v1 v2) = toEncoding (v1, v2)
instance FromJSON RatchetVersions where
-- TODO v6.0 replace with the default record parser
-- this parser supports JSON record encoding for forward compatibility
parseJSON v = (tupleP <|> recordP v) >>= toRV
where
tupleP = parseJSON v
recordP = J.withObject "RatchetVersions" $ \o -> (,) <$> o J..: "current" <*> o J..: "maxSupported"
toRV (v1, v2) = maybe (fail "bad version range") (pure . ratchetVersions) $ safeVersionRange v1 v2
ratchetVersions :: VersionRangeE2E -> RatchetVersions
ratchetVersions (VersionRange v1 v2) = RVersions {current = v1, maxSupported = v2}
data SndRatchet a = SndRatchet
{ rcDHRr :: PublicKey a,
rcCKs :: RatchetKey,
@@ -494,12 +584,12 @@ instance FromField MessageKey where fromField = blobFieldDecoder smpDecode
-- // above added for KEM
-- @
initSndRatchet ::
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a
initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a
initSndRatchet v rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do
-- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss)
let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted)
in Ratchet
{ rcVersion,
{ rcVersion = ratchetVersions v,
rcAD = assocData,
rcDHRs,
rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_,
@@ -523,10 +613,10 @@ initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey
-- Please note that the public part of rcDHRs was sent to the sender
-- as part of the connection request and random salt was received from the sender.
initRcvRatchet ::
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a
initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM =
forall a. (AlgorithmI a, DhAlgorithm a) => VersionRangeE2E -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQEncryption -> Ratchet a
initRcvRatchet v rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) rcEnableKEM =
Ratchet
{ rcVersion,
{ rcVersion = ratchetVersions v,
rcAD = assocData,
rcDHRs,
-- rcKEM:
@@ -552,7 +642,7 @@ initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK
-- ct = state.PQRct // added for KEM #1
data MsgHeader a = MsgHeader
{ -- | max supported ratchet version
msgMaxVersion :: Version,
msgMaxVersion :: VersionE2E,
msgDHRs :: PublicKey a,
msgKEM :: Maybe ARKEMParams,
msgPN :: Word32,
@@ -560,11 +650,6 @@ data MsgHeader a = MsgHeader
}
deriving (Show)
data AMsgHeader
= forall a.
(AlgorithmI a, DhAlgorithm a) =>
AMsgHeader (SAlgorithm a) (MsgHeader a)
-- to allow extension without increasing the size, the actual header length is:
-- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4
-- TODO PQ this must be version-dependent
@@ -590,7 +675,7 @@ instance AlgorithmI a => Encoding (MsgHeader a) where
pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs}
data EncMessageHeader = EncMessageHeader
{ ehVersion :: Version,
{ ehVersion :: VersionE2E,
ehIV :: IV,
ehAuthTag :: AuthTag,
ehBody :: ByteString
@@ -611,12 +696,12 @@ data EncRatchetMessage = EncRatchetMessage
emBody :: ByteString
}
encodeEncRatchetMessage :: Version -> EncRatchetMessage -> ByteString
encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString
encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag}
| v >= pqRatchetVersion = smpEncode (Large emHeader, emAuthTag, Tail emBody)
| otherwise = smpEncode (emHeader, emAuthTag, Tail emBody)
encRatchetMessageP :: Version -> Parser EncRatchetMessage
encRatchetMessageP :: VersionE2E -> Parser EncRatchetMessage
encRatchetMessageP v = do
emHeader <- if v >= pqRatchetVersion then unLarge <$> smpP else smpP
(emAuthTag, Tail emBody) <- smpP
@@ -694,10 +779,10 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM,
-- enc_header = HENCRYPT(state.HKs, header)
(ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader
-- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header))
-- TODO PQ versioning in Ratchet should change somehow
let emHeader = smpEncode EncMessageHeader {ehVersion = maxVersion rcVersion, ehBody, ehAuthTag, ehIV}
-- TODO PQ versioning in Ratchet should change: we should use "current" version here
let emHeader = smpEncode EncMessageHeader {ehVersion = maxSupported rcVersion, ehBody, ehAuthTag, ehIV}
(emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg
let msg' = encodeEncRatchetMessage (maxVersion rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag}
let msg' = encodeEncRatchetMessage (maxSupported rcVersion) EncRatchetMessage {emHeader, emBody, emAuthTag}
-- state.Ns += 1
rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1}
rc'' = case pqMode_ of
@@ -719,7 +804,7 @@ rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM,
msgHeader =
smpEncode
MsgHeader
{ msgMaxVersion = maxVersion rcVersion,
{ msgMaxVersion = maxSupported rcVersion,
msgDHRs = publicKey rcDHRs,
msgKEM = msgKEMParams <$> rcKEM,
msgPN = rcPN,
@@ -752,7 +837,7 @@ rcDecrypt ::
ExceptT CryptoError IO (DecryptResult a)
rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do
-- TODO PQ versioning should change
encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxVersion rcVersion) msg'
encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError (encRatchetMessageP $ maxSupported rcVersion) msg'
encHdr <- parseE CryptoHeaderError smpP emHeader
-- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD)
decryptSkipped encHdr encMsg >>= \case