diff --git a/package.yaml b/package.yaml index 0f9d08936..85a58d74f 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: simplexmq -version: 5.6.0.0 +version: 5.6.0.2 synopsis: SimpleXMQ message broker description: | This package includes <./docs/Simplex-Messaging-Server.html server>, @@ -74,6 +74,7 @@ dependencies: - unliftio-core == 0.2.* - websockets == 0.12.* - yaml == 0.11.* + - zstd == 0.1.3.* flags: swift: diff --git a/rfcs/2023-12-29-pqdr.md b/rfcs/2023-12-29-pqdr.md new file mode 100644 index 000000000..7fd88ffbb --- /dev/null +++ b/rfcs/2023-12-29-pqdr.md @@ -0,0 +1,36 @@ +# Post-quantum double ratchet implementation + +See [the previous doc](https://github.com/simplex-chat/simplex-chat/blob/stable/docs/rfcs/2023-09-30-pq-double-ratchet.md). + +The main implementation consideration is that it should be both backwards and forwards compatible, to allow changing the connection DR to/from using PQ primitives (although client version downgrade may be impossible in this case), and also to decide whether to use PQ primitive on per-connection basis: +- use without links (in SMP confirmation or in SMP invitation via address or via member), don't use with links (as they would be too large). +- use in small groups, don't use in large groups. + +Also note that for DR to work we need to have 2 KEMs running in parallel. + +Possible combinations (assuming both clients support PQ): + +| Stage | No PQ kem | PQ key sent | PQ key + PQ ct sent | +|:------------:|:---------:|:-----------:|:-------------------:| +| inv | + | + | - | +| conf, in reply to:
no-pq inv
pq inv |  
+
+ |  
+
- |  
-
+ | +| 1st msg, in reply to:
no-pq conf
pq/pq+ct conf |  
+
+ |  
+
- |  
-
+ | +| Nth msg, in reply to:
no-pq msg
pq/pq+ct msg |  
+
+ |  
+
- |  
-
+ | + +These rules can be reduced to: +1. initial invitation optionally has PQ key, but must not have ciphertext. +2. all subsequent messages should be allowed without PQ key/ciphertext, but: + - if the previous message had PQ key or PQ key with ciphertext, they must either have no PQ key, or have PQ key with ciphertext (PQ key without ciphertext is an error). + - if the previous message had no PQ key, they must either have no PQ key, or have PQ key without ciphertext (PQ key with ciphertext is an error). + +The rules for calculating the shared secret for received/sent messages are (assuming received message is valid according to the above rules): + +| sent msg >
V received msg | no-pq | pq | pq+ct | +|:------------------------------:|:-----------:|:-------:|:---------------:| +| no-pq | DH / DH | DH / DH | err | +| pq (sent msg was NOT pq) | DH / DH | err | DH / DH+KEM | +| pq+ct (sent msg was NOT no-pq) | DH+KEM / DH | err | DH+KEM / DH+KEM | + +To summarize, the upgrade to DH+KEM secret happens in a sent message that has PQ key with ciphertext sent in reply to message with PQ key only (without ciphertext), and the downgrade to DH secret happens in the message that has no PQ key. + +The type for sending PQ key with optional ciphertext is `Maybe E2ERachetKEM` where `data E2ERachetKEM = E2ERachetKEM KEMPublicKey (Maybe KEMCiphertext)`, and for SMP invitation it will be simply `Maybe KEMPublicKey`. Possibly, there is a way to encode the rules above in the types, these types don't constrain possible transitions to valid ones. diff --git a/rfcs/2024-03-03-pqdr-version.md b/rfcs/2024-03-03-pqdr-version.md new file mode 100644 index 000000000..5db9f23a5 --- /dev/null +++ b/rfcs/2024-03-03-pqdr-version.md @@ -0,0 +1,92 @@ +# Migrating existing connections to post-quantum double ratchet algorithm + +## Problem + +Post-quantum variant of double ratchet algorithm represents an almost full-stack change affecting all parts of the protocol stack except client-server protocol (SMP): +- double-ratchet end-to-end encryption: different encoding (additional large keys require byte-strings larger than 255 bytes with 2-byte length prefixes) and larger message headers (increased by ~2200 bytes). +- agent-agent protocol: a smaller maximum message size to accomodate larger headers and to fit in 16kb blocks, reduced by ~2200 bytes for the messages and by almost ~4000 bytes for connection information. +- chat protocol: also a smaller message size compensated by zstd comression of JSON messages. + +We want the versioning that achieves these objectives: +- all changes in all protocol layers happen at the same time, when both clients support it. +- ability to downgrade the clients to the previous version without losing connection. +- ability to opt-in into this functionality via "experimental" feature toggle, that enables post-quantum encryption in connections when both contacts enable this toggle. + +To have ability to downgrade the clients we have two options: +- roll-out this functionality in two stages: 1) roll-out clients support but do not enable the new version, and then 2) upgrade client version. The problem here is that the clients won't be able to opt-in into this experiment. +- make offered range dependent on experimental feature being enabled. Currently we have an option to enable PQ encryption in agent API, and this option can be used as a proxy to maxium supported protocol version - if the option is passed, it can be seen as an indication that higher version range (or version) should offered (or accepted). + +## Solution + +Currently ratchet state stores version range. It's unclear what was the intended semantics of that version range - it simply stores the offered/supported version range at the time ratchet was initialised, but only a high bound is used to send in message headers, and it is never upgraded. In JSON this range is encoded as tuple (an array of two elements in JSON). + +We could continue using this range with the meaning of the lower bound to be "currently used ratchet version" and the meaning of higher boundary to be "maximum supported ratchet version". We could also use the version communicated in message headers to upgrade ratchet version, with the condition that upgrade should only happen if both sides want it. Currently it's defined by pqEnableKEM property in ratchet state. We could also make it more explicit by defining maximum version to which ratchet should upgrade. Given that irreversible upgrades are not very common, it is probably ok to keep it implicit. + +We can define a better type than VersionRange to reflect semantics of the range in ratchet (current/max supported range), but for backward compatibility it needs to be encoded in the same way as now. + +To summarize, the proposed solution for ratchet versioning is: +- define ratchet versions as new type to include current and maximum allowed versions, where maximum allowed will be either the same or lower than maximum supported based on PQ option (in 5.6), and in 5.7 it will be changed to maximum supported, so version starts upgrading independently from PQ being enabled. +- make encodings in ratchet depend on current version (in curent code it depends on max version). +- include max allowed in message header. +- upgrade current if in range on each new message if less than max and higher than current (same as we do for connections). +- increase max allowed once PQ is enabled (only in 5.6). Make max allowed the same as max supported (global constant). + +```haskell +data RatchetVR = RatchetVR + { currentVersion :: Version, + maxAllowedVersion :: Version + } + +instance ToJSON RatchetVR where + toEncoding (RatchetVR v1 v2) = toEncoding (v1, v2) + toJSON (RatchetVR v1 v2) = toJSON (v1, v2) + +instance FromJSON RatchetVR where + parseJSON v = do + -- this also verifies that v2 > v1 (although we could remove JSON instances for VersionRange) + VersionRange v1 v2 <- parseJSON v + pure $ RatchetVR v1 v2 +``` + +For connections, we could also make version used for the purposes of encoding dependent on the PQ being enabled, and version for decoding taken from message header, but then we'd have to not only upgrade ratchets but the connection as well every time PQ mode changes. + +Another suggestion to ensure that correct version range is used in correct contexts could be: +- using different newtypes for different version ranges. +- define generic type class for version aware encoding that would also accept only specific type class for the version to use the correct range. This may be justified as there will be several version-aware encodings, and not just the protocol as now. + +```haskell +class Ord v => EncodingV v a where + {-# MINIMAL smpEncodeV, (smpDecodeV | smpVP) #-} + smpEncodeV :: v -> a -> ByteString + -- default decode uses parser + smpDecodeV :: v -> ByteString -> Either String a + smpDecodeV = parseAll . smpVP + -- default parser decodes from length-specified bytestring + smpVP :: v -> Parser a + smpVP v = smpDecodeV v <$?> smpP +``` + +The version will be passed from currently agreed version, it may only change when message is received, not when message is sent. The version will not be extracted from the encoding itself as it happens now in ratchet encodings. + +## Various options how the problem can be simplified + +1. Do not support connection downgrade once both devices upgraded. If applied to all existing connections then it is a bad option, as it would disrupt some important conversations. + +2. Do not provide ability to opt-in into PQ encryption until v5.7 where it will be rolled out automatically. That is also suboptimal, as it won't allow announcing technology design and have testing outside of the team devices. + +3. The logic explained above where connection upgrade and downgrade is possible and applied to all existing connections if both parties consent to it. There are these important downsides: + - complexity of this logic + - regression risks when this logic is removed. + - some non-coordinated upgrades of existing, potentially important conversations, simply because two users opt-in into the experiment without any expectation that another side also opts-in. + +4. Apply upgrade/downgrade logic and enable PQ encryption as opt-in, based on the toggle in the UX, only for the new connections. This seems the least risky, and also simpler than option 3, as it would only apply to the new connections, and both users will have to enable experimental toggle prior to connecting. + +Option 4 seems the best trade-off, and has these sub-options regarding where it is controlled: +a) in chat based on connection flag. Chat will pass PQ options only to connections that were created when experimental option was enabled. +b) in agent - there will be additional logic to ignore PQ option for existing connections. +c) both in chat and in agent. + +Option 4a seems better, as it would: +- simplify agent code +- minimise required changes when releasing v5.7 (as we do want that all direct and small groups connections migrate to PQ encryption at the time, without any toggles) +- allow tests for connection upgrade in the currect code. diff --git a/simplexmq.cabal b/simplexmq.cabal index 66b671139..039c4abd0 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -5,7 +5,7 @@ cabal-version: 1.12 -- see: https://github.com/sol/hpack name: simplexmq -version: 5.6.0.0 +version: 5.6.0.2 synopsis: SimpleXMQ message broker description: This package includes <./docs/Simplex-Messaging-Server.html server>, <./docs/Simplex-Messaging-Client.html client> and @@ -103,9 +103,12 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent + Simplex.Messaging.Compression Simplex.Messaging.Crypto Simplex.Messaging.Crypto.File Simplex.Messaging.Crypto.Lazy @@ -158,6 +161,7 @@ library Simplex.Messaging.Transport.WebSockets Simplex.Messaging.Util Simplex.Messaging.Version + Simplex.Messaging.Version.Internal Simplex.RemoteControl.Client Simplex.RemoteControl.Discovery Simplex.RemoteControl.Discovery.Multicast @@ -226,6 +230,7 @@ library , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -299,6 +304,7 @@ executable ntf-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -372,6 +378,7 @@ executable smp-agent , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -445,6 +452,7 @@ executable smp-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -518,6 +526,7 @@ executable xftp , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -591,6 +600,7 @@ executable xftp-server , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON @@ -612,6 +622,7 @@ test-suite simplexmq-test AgentTests AgentTests.ConnectionRequestTests AgentTests.DoubleRatchetTests + AgentTests.EqInstances AgentTests.FunctionalAPITests AgentTests.MigrationTests AgentTests.NotificationTests @@ -634,6 +645,7 @@ test-suite simplexmq-test ServerTests SMPAgentClient SMPClient + Util XFTPAgent XFTPCLI XFTPClient @@ -702,6 +714,7 @@ test-suite simplexmq-test , unliftio-core ==0.2.* , websockets ==0.12.* , yaml ==0.11.* + , zstd ==0.1.3.* default-language: Haskell2010 if flag(swift) cpp-options: -DswiftJSON diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 9a789a104..2abf8e3dc 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -35,6 +35,7 @@ import Control.Monad.Reader import Data.Bifunctor (first) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Coerce (coerce) import Data.Composition ((.:)) import Data.Either (rights) import Data.Int (Int64) @@ -52,8 +53,8 @@ import Simplex.FileTransfer.Client.Main import Simplex.FileTransfer.Crypto import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) -import qualified Simplex.FileTransfer.Protocol as XFTP import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import qualified Simplex.FileTransfer.Transport as XFTP import Simplex.FileTransfer.Types import Simplex.FileTransfer.Util (removePath, uniqueCombine) import Simplex.Messaging.Agent.Client @@ -389,21 +390,25 @@ runXFTPSndPrepareWorker c Worker {doWork} = do where AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)]) - encryptFileForUpload SndFile {key, nonce, srcFile} fsEncPath = do + encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do let CryptoFile {filePath} = srcFile fileName = takeFileName filePath fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile - when (fileSize > maxFileSize) $ throwError $ INTERNAL "max file size exceeded" + when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded" let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing} fileSize' = fromIntegral (B.length fileHdr) + fileSize - chunkSizes = prepareChunkSizes $ fileSize' + fileSizeLen + authTagSize - chunkSizes' = map fromIntegral chunkSizes - encSize = sum chunkSizes' + payloadSize = fileSize' + fileSizeLen + authTagSize + chunkSizes <- case redirect of + Nothing -> pure $ prepareChunkSizes payloadSize + Just _ -> case singleChunkSize payloadSize of + Nothing -> throwError $ INTERNAL "max file size exceeded for redirect" + Just chunkSize -> pure [chunkSize] + let encSize = sum $ map fromIntegral chunkSizes void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes - chunkDigests <- map FileDigest <$> mapM (liftIO . getChunkDigest) chunkSpecs - pure (FileDigest digest, zip chunkSpecs chunkDigests) + chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs + pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests) chunkCreated :: SndFileChunk -> Bool chunkCreated SndFileChunk {replicas} = any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas diff --git a/src/Simplex/FileTransfer/Client.hs b/src/Simplex/FileTransfer/Client.hs index 84c99eb48..d9c4c058a 100644 --- a/src/Simplex/FileTransfer/Client.hs +++ b/src/Simplex/FileTransfer/Client.hs @@ -57,7 +57,7 @@ import UnliftIO.Directory data XFTPClient = XFTPClient { http2Client :: HTTP2Client, transportSession :: TransportSession FileResponse, - thParams :: THandleParams, + thParams :: THandleParams XFTPVersion, config :: XFTPClientConfig } diff --git a/src/Simplex/FileTransfer/Client/Main.hs b/src/Simplex/FileTransfer/Client/Main.hs index e1ef9f0d8..bca41cea8 100644 --- a/src/Simplex/FileTransfer/Client/Main.hs +++ b/src/Simplex/FileTransfer/Client/Main.hs @@ -16,9 +16,11 @@ module Simplex.FileTransfer.Client.Main xftpClientCLI, cliSendFile, cliSendFileOpts, + singleChunkSize, prepareChunkSizes, prepareChunkSpecs, maxFileSize, + maxFileSizeHard, fileSizeLen, getChunkDigest, SentRecipientReplica (..), @@ -41,7 +43,7 @@ import Data.List.NonEmpty (NonEmpty (..), nonEmpty) import qualified Data.List.NonEmpty as L import Data.Map (Map) import qualified Data.Map as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, listToMaybe) import qualified Data.Text as T import Data.Word (Word32) import GHC.Records (HasField (getField)) @@ -76,12 +78,17 @@ import UnliftIO.Directory xftpClientVersion :: String xftpClientVersion = "1.0.1" +-- | Soft limit for XFTP clients. Should be checked and reported to user. maxFileSize :: Int64 maxFileSize = gb 1 maxFileSizeStr :: String maxFileSizeStr = B.unpack . strEncode $ FileSize maxFileSize +-- | Hard internal limit for XFTP agent after which it refuses to prepare chunks. +maxFileSizeHard :: Int64 +maxFileSizeHard = gb 5 + fileSizeLen :: Int64 fileSizeLen = 8 @@ -214,13 +221,13 @@ data SentFileChunk = SentFileChunk digest :: FileDigest, replicas :: [SentFileChunkReplica] } - deriving (Eq, Show) + deriving (Show) data SentFileChunkReplica = SentFileChunkReplica { server :: XFTPServer, recipients :: [(ChunkReplicaId, C.APrivateAuthKey)] } - deriving (Eq, Show) + deriving (Show) data SentRecipientReplica = SentRecipientReplica { chunkNo :: Int, @@ -407,7 +414,8 @@ getChunkDigest :: XFTPChunkSpec -> IO ByteString getChunkDigest XFTPChunkSpec {filePath = chunkPath, chunkOffset, chunkSize} = withFile chunkPath ReadMode $ \h -> do hSeek h AbsoluteSeek $ fromIntegral chunkOffset - LC.sha256Hash <$> LB.hGet h (fromIntegral chunkSize) + chunk <- LB.hGet h (fromIntegral chunkSize) + pure $! LC.sha256Hash chunk cliReceiveFile :: ReceiveOptions -> ExceptT CLIError IO () cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, verbose, yes} = @@ -522,6 +530,12 @@ getFileDescription' path = getFileDescription path >>= \case AVFD fd -> either (throwError . CLIError) pure $ checkParty fd +singleChunkSize :: Int64 -> Maybe Word32 +singleChunkSize size' = + listToMaybe $ dropWhile (< chunkSize) serverChunkSizes + where + chunkSize = fromIntegral size' + prepareChunkSizes :: Int64 -> [Word32] prepareChunkSizes size' = prepareSizes size' where diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index 58bcb9df3..d5b5e5105 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -227,7 +227,7 @@ validateFileDescription fd@FileDescription {size, chunks} | otherwise = Right $ ValidFD fd where chunkNos = map (\FileChunk {chunkNo} -> chunkNo) chunks - chunksSize = fromIntegral . foldl' (\s FileChunk {chunkSize} -> s + unFileSize chunkSize) 0 + chunksSize = foldl' (\(s :: Int64) FileChunk {chunkSize} -> s + fromIntegral (unFileSize chunkSize)) 0 encodeFileDescription :: FileDescription p -> YAMLFileDescription encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSize, chunks, redirect} = diff --git a/src/Simplex/FileTransfer/Protocol.hs b/src/Simplex/FileTransfer/Protocol.hs index a9de56ddb..2ba75f027 100644 --- a/src/Simplex/FileTransfer/Protocol.hs +++ b/src/Simplex/FileTransfer/Protocol.hs @@ -1,9 +1,11 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} @@ -14,9 +16,7 @@ module Simplex.FileTransfer.Protocol where -import Control.Applicative ((<|>)) import qualified Data.Aeson.TH as J -import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -25,11 +25,11 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word32) +import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake) import Simplex.Messaging.Client (authTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol ( BasicAuth, @@ -56,11 +56,10 @@ import Simplex.Messaging.Protocol tParse, ) import Simplex.Messaging.Transport (THandleParams (..), TransportError (..)) -import Simplex.Messaging.Util (bshow, (<$?>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Util ((<$?>)) -currentXFTPVersion :: Version -currentXFTPVersion = 1 +currentXFTPVersion :: VersionXFTP +currentXFTPVersion = VersionXFTP 1 xftpBlockSize :: Int xftpBlockSize = 16384 @@ -142,10 +141,10 @@ instance ProtocolMsgTag FileCmdTag where instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where decodeTag s = decodeTag s >>= (\(FCT _ t) -> checkParty' t) -instance Protocol XFTPErrorType FileResponse where +instance Protocol XFTPVersion XFTPErrorType FileResponse where type ProtoCommand FileResponse = FileCmd type ProtoType FileResponse = 'PXFTP - protocolClientHandshake = ntfClientHandshake + protocolClientHandshake = xftpClientHandshake protocolPing = FileCmd SFRecipient PING protocolError = \case FRErr e -> Just e @@ -171,11 +170,11 @@ data FileInfo = FileInfo size :: Word32, digest :: ByteString } - deriving (Eq, Show) + deriving (Show) type XFTPFileId = ByteString -instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where +instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand p) where type Tag (FileCommand p) = FileCommandTag p encodeProtocol _v = \case FNEW file rKeys auth_ -> e (FNEW_, ' ', file, rKeys, auth_) @@ -191,7 +190,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where protocolP v tag = (\(FileCmd _ c) -> checkParty c) <$?> protocolP v (FCT (sFileParty @p) tag) - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, fileId, _) cmd = case cmd of @@ -208,7 +207,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where | isNothing auth || B.null fileId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding XFTPErrorType FileCmd where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where type Tag FileCmd = FileCmdTag encodeProtocol _v (FileCmd _ c) = encodeProtocol _v c @@ -225,7 +224,7 @@ instance ProtocolEncoding XFTPErrorType FileCmd where FACK_ -> pure FACK PING_ -> pure PING - fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse + fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse {-# INLINE fromProtocolError #-} checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c @@ -276,7 +275,7 @@ data FileResponse | FRPong deriving (Show) -instance ProtocolEncoding XFTPErrorType FileResponse where +instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where type Tag FileResponse = FileResponseTag encodeProtocol _v = \case FRSndIds fId rIds -> e (FRSndIds_, ' ', fId, rIds) @@ -319,82 +318,6 @@ instance ProtocolEncoding XFTPErrorType FileResponse where | B.null entId = Right cmd | otherwise = Left $ CMD HAS_AUTH -data XFTPErrorType - = -- | incorrect block format, encoding or signature size - BLOCK - | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) - SESSION - | -- | SMP command is unknown or has invalid syntax - CMD {cmdErr :: CommandError} - | -- | command authorization error - bad signature or non-existing SMP queue - AUTH - | -- | incorrent file size - SIZE - | -- | storage quota exceeded - QUOTA - | -- | incorrent file digest - DIGEST - | -- | file encryption/decryption failed - CRYPTO - | -- | no expected file body in request/response or no file on the server - NO_FILE - | -- | unexpected file body - HAS_FILE - | -- | file IO error - FILE_IO - | -- | bad redirect data - REDIRECT {redirectError :: String} - | -- | internal server error - INTERNAL - | -- | used internally, never returned by the server (to be removed) - DUPLICATE_ -- not part of SMP protocol, used internally - deriving (Eq, Read, Show) - -instance StrEncoding XFTPErrorType where - strEncode = \case - CMD e -> "CMD " <> bshow e - REDIRECT e -> "REDIRECT " <> bshow e - e -> bshow e - strP = - "CMD " *> (CMD <$> parseRead1) - <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) - <|> parseRead1 - -instance Encoding XFTPErrorType where - smpEncode = \case - BLOCK -> "BLOCK" - SESSION -> "SESSION" - CMD err -> "CMD " <> smpEncode err - AUTH -> "AUTH" - SIZE -> "SIZE" - QUOTA -> "QUOTA" - DIGEST -> "DIGEST" - CRYPTO -> "CRYPTO" - NO_FILE -> "NO_FILE" - HAS_FILE -> "HAS_FILE" - FILE_IO -> "FILE_IO" - REDIRECT err -> "REDIRECT " <> smpEncode err - INTERNAL -> "INTERNAL" - DUPLICATE_ -> "DUPLICATE_" - - smpP = - A.takeTill (== ' ') >>= \case - "BLOCK" -> pure BLOCK - "SESSION" -> pure SESSION - "CMD" -> CMD <$> _smpP - "AUTH" -> pure AUTH - "SIZE" -> pure SIZE - "QUOTA" -> pure QUOTA - "DIGEST" -> pure DIGEST - "CRYPTO" -> pure CRYPTO - "NO_FILE" -> pure NO_FILE - "HAS_FILE" -> pure HAS_FILE - "FILE_IO" -> pure FILE_IO - "REDIRECT" -> REDIRECT <$> _smpP - "INTERNAL" -> pure INTERNAL - "DUPLICATE_" -> pure DUPLICATE_ - _ -> fail "bad error type" - checkParty :: forall t p p'. (FilePartyI p, FilePartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Right c @@ -405,12 +328,12 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of Just Refl -> Just c _ -> Nothing -xftpEncodeAuthTransmission :: ProtocolEncoding e c => THandleParams -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString +xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg) xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth -xftpEncodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> Either TransportError ByteString +xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString xftpEncodeTransmission thParams (corrId, fId, msg) = do let t = encodeTransmission thParams (corrId, fId, msg) xftpEncodeBatch1 (Nothing, t) @@ -419,7 +342,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize -xftpDecodeTransmission :: ProtocolEncoding e c => THandleParams -> ByteString -> Either XFTPErrorType (SignedTransmission e c) +xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> ByteString -> Either XFTPErrorType (SignedTransmission e c) xftpDecodeTransmission thParams t = do t' <- first (const BLOCK) $ C.unPad t case tParse thParams t' of @@ -427,5 +350,3 @@ xftpDecodeTransmission thParams t = do _ -> Left BLOCK $(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty) - -$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/FileTransfer/Server.hs b/src/Simplex/FileTransfer/Server.hs index 158429d79..ae202c2b0 100644 --- a/src/Simplex/FileTransfer/Server.hs +++ b/src/Simplex/FileTransfer/Server.hs @@ -69,7 +69,7 @@ type M a = ReaderT XFTPEnv IO a data XFTPTransportRequest = XFTPTransportRequest - { thParams :: THandleParams, + { thParams :: THandleParams XFTPVersion, reqBody :: HTTP2Body, request :: H.Request, sendResponse :: H.Response -> IO () diff --git a/src/Simplex/FileTransfer/Server/Store.hs b/src/Simplex/FileTransfer/Server/Store.hs index 031c46f5b..a3681944e 100644 --- a/src/Simplex/FileTransfer/Server/Store.hs +++ b/src/Simplex/FileTransfer/Server/Store.hs @@ -28,7 +28,8 @@ import Data.Int (Int64) import Data.Set (Set) import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..)) -import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPErrorType (..), XFTPFileId) +import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPFileId) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (RcvPublicAuthKey, RecipientId, SenderId) @@ -49,7 +50,6 @@ data FileRec = FileRec recipientIds :: TVar (Set RecipientId), createdAt :: SystemTime } - deriving (Eq) data FileRecipient = FileRecipient RecipientId RcvPublicAuthKey diff --git a/src/Simplex/FileTransfer/Transport.hs b/src/Simplex/FileTransfer/Transport.hs index 90b1a8a44..464a75ac8 100644 --- a/src/Simplex/FileTransfer/Transport.hs +++ b/src/Simplex/FileTransfer/Transport.hs @@ -1,10 +1,19 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} module Simplex.FileTransfer.Transport ( supportedFileServerVRange, + xftpClientHandshake, -- stub + XFTPVersion, + VersionXFTP, + pattern VersionXFTP, + XFTPErrorType (..), XFTPRcvChunkSpec (..), ReceiveFileError (..), receiveFile, @@ -14,22 +23,31 @@ module Simplex.FileTransfer.Transport ) where +import Control.Applicative ((<|>)) import qualified Control.Exception as E import Control.Monad import Control.Monad.Except import Control.Monad.IO.Class +import qualified Data.Aeson.TH as J +import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) import qualified Data.ByteArray as BA import Data.ByteString.Builder (Builder, byteString) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB -import Data.Word (Word32) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Data.Word (Word16, Word32) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers +import Simplex.Messaging.Protocol (CommandError) +import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..)) import Simplex.Messaging.Transport.HTTP2.File +import Simplex.Messaging.Util (bshow) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.IO (Handle, IOMode (..), withFile) data XFTPRcvChunkSpec = XFTPRcvChunkSpec @@ -39,8 +57,26 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec } deriving (Show) -supportedFileServerVRange :: VersionRange -supportedFileServerVRange = mkVersionRange 1 1 +data XFTPVersion + +instance VersionScope XFTPVersion + +type VersionXFTP = Version XFTPVersion + +type VersionRangeXFTP = VersionRange XFTPVersion + +pattern VersionXFTP :: Word16 -> VersionXFTP +pattern VersionXFTP v = Version v + +initialXFTPVersion :: VersionXFTP +initialXFTPVersion = VersionXFTP 1 + +supportedFileServerVRange :: VersionRangeXFTP +supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion + +-- XFTP protocol does not support handshake +xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c) +xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO () sendEncFile h send = go @@ -97,3 +133,81 @@ receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do ExceptT $ withFile filePath WriteMode (`receive` chunkSize) digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath when (digest' /= chunkDigest) $ throwError DIGEST + +data XFTPErrorType + = -- | incorrect block format, encoding or signature size + BLOCK + | -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929) + SESSION + | -- | SMP command is unknown or has invalid syntax + CMD {cmdErr :: CommandError} + | -- | command authorization error - bad signature or non-existing SMP queue + AUTH + | -- | incorrent file size + SIZE + | -- | storage quota exceeded + QUOTA + | -- | incorrent file digest + DIGEST + | -- | file encryption/decryption failed + CRYPTO + | -- | no expected file body in request/response or no file on the server + NO_FILE + | -- | unexpected file body + HAS_FILE + | -- | file IO error + FILE_IO + | -- | bad redirect data + REDIRECT {redirectError :: String} + | -- | internal server error + INTERNAL + | -- | used internally, never returned by the server (to be removed) + DUPLICATE_ -- not part of SMP protocol, used internally + deriving (Eq, Read, Show) + +instance StrEncoding XFTPErrorType where + strEncode = \case + CMD e -> "CMD " <> bshow e + REDIRECT e -> "REDIRECT " <> bshow e + e -> bshow e + strP = + "CMD " *> (CMD <$> parseRead1) + <|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString) + <|> parseRead1 + +instance Encoding XFTPErrorType where + smpEncode = \case + BLOCK -> "BLOCK" + SESSION -> "SESSION" + CMD err -> "CMD " <> smpEncode err + AUTH -> "AUTH" + SIZE -> "SIZE" + QUOTA -> "QUOTA" + DIGEST -> "DIGEST" + CRYPTO -> "CRYPTO" + NO_FILE -> "NO_FILE" + HAS_FILE -> "HAS_FILE" + FILE_IO -> "FILE_IO" + REDIRECT err -> "REDIRECT " <> smpEncode err + INTERNAL -> "INTERNAL" + DUPLICATE_ -> "DUPLICATE_" + + smpP = + A.takeTill (== ' ') >>= \case + "BLOCK" -> pure BLOCK + "SESSION" -> pure SESSION + "CMD" -> CMD <$> _smpP + "AUTH" -> pure AUTH + "SIZE" -> pure SIZE + "QUOTA" -> pure QUOTA + "DIGEST" -> pure DIGEST + "CRYPTO" -> pure CRYPTO + "NO_FILE" -> pure NO_FILE + "HAS_FILE" -> pure HAS_FILE + "FILE_IO" -> pure FILE_IO + "REDIRECT" -> REDIRECT <$> _smpP + "INTERNAL" -> pure INTERNAL + "DUPLICATE_" -> pure DUPLICATE_ + _ -> fail "bad error type" + +$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType) diff --git a/src/Simplex/FileTransfer/Types.hs b/src/Simplex/FileTransfer/Types.hs index 21967a3cd..ba306a6c6 100644 --- a/src/Simplex/FileTransfer/Types.hs +++ b/src/Simplex/FileTransfer/Types.hs @@ -55,7 +55,7 @@ data RcvFile = RcvFile status :: RcvFileStatus, deleted :: Bool } - deriving (Eq, Show) + deriving (Show) data RcvFileStatus = RFSReceiving @@ -96,7 +96,7 @@ data RcvFileChunk = RcvFileChunk fileTmpPath :: FilePath, chunkTmpPath :: Maybe FilePath } - deriving (Eq, Show) + deriving (Show) data RcvFileChunkReplica = RcvFileChunkReplica { rcvChunkReplicaId :: Int64, @@ -107,14 +107,14 @@ data RcvFileChunkReplica = RcvFileChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) data RcvFileRedirect = RcvFileRedirect { redirectDbId :: DBRcvFileId, redirectEntityId :: RcvFileId, redirectFileInfo :: RedirectFileInfo } - deriving (Eq, Show) + deriving (Show) -- Sending files @@ -135,7 +135,7 @@ data SndFile = SndFile deleted :: Bool, redirect :: Maybe RedirectFileInfo } - deriving (Eq, Show) + deriving (Show) sndFileEncPath :: FilePath -> FilePath sndFileEncPath prefixPath = prefixPath "xftp.encrypted" @@ -182,7 +182,7 @@ data SndFileChunk = SndFileChunk digest :: FileDigest, replicas :: [SndFileChunkReplica] } - deriving (Eq, Show) + deriving (Show) sndChunkSize :: SndFileChunk -> Word32 sndChunkSize SndFileChunk {chunkSpec = XFTPChunkSpec {chunkSize}} = chunkSize @@ -193,7 +193,7 @@ data NewSndChunkReplica = NewSndChunkReplica replicaKey :: C.APrivateAuthKey, rcvIdsKeys :: [(ChunkReplicaId, C.APrivateAuthKey)] } - deriving (Eq, Show) + deriving (Show) data SndFileChunkReplica = SndFileChunkReplica { sndChunkReplicaId :: Int64, @@ -205,7 +205,7 @@ data SndFileChunkReplica = SndFileChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) data SndFileReplicaStatus = SFRSCreated @@ -235,4 +235,4 @@ data DeletedSndChunkReplica = DeletedSndChunkReplica delay :: Maybe Int64, retries :: Int } - deriving (Eq, Show) + deriving (Show) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index aa9872e9c..2b2db1dd5 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -9,6 +9,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -45,6 +46,7 @@ module Simplex.Messaging.Agent withInvLock, createUser, deleteUser, + connRequestPQSupport, createConnectionAsync, joinConnectionAsync, allowConnectionAsync, @@ -140,8 +142,9 @@ import Data.Text (Text) import qualified Data.Text as T import Data.Time.Clock import Data.Time.Clock.System (systemToUTCTime) +import Data.Traversable (mapAccumL) import Data.Word (Word16) -import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFilesInternal, deleteSndFileRemote, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') +import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, deleteSndFilesInternal, deleteSndFilesRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpDeleteRcvFiles', xftpReceiveFile', xftpSendDescription', xftpSendFile') import Simplex.FileTransfer.Description (ValidFileDescription) import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Util (removePath) @@ -158,6 +161,7 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport (..), pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String @@ -165,11 +169,11 @@ import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfReg import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..)) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth) +import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicAuthKey, SubscriptionMode (..), UserProtocol, VersionSMPC, XFTPServerWithAuth) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (THandleParams (sessionId)) +import Simplex.Messaging.Transport (SMPVersion, THandleParams (sessionId)) import Simplex.Messaging.Util import Simplex.Messaging.Version import Simplex.RemoteControl.Client @@ -218,20 +222,20 @@ deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m () deleteUser c = withAgentEnv c .: deleteUser' c -- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id -createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId -createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .: newConnAsync c userId aCorrId enableNtfs +createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. newConnAsync c userId aCorrId enableNtfs -- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id -joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. joinConnAsync c userId aCorrId enableNtfs +joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:: joinConnAsync c userId aCorrId enableNtfs -- | Allow connection to continue after CONF notification (LET command), no synchronous response allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c -- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id -acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:. acceptContactAsync' c aCorrId enableNtfs +acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:: acceptContactAsync' c aCorrId enableNtfs -- | Acknowledge message (ACK command) asynchronously, no synchronous response ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () @@ -242,28 +246,28 @@ switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId - switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c -- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response -deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ConnId -> m () -deleteConnectionAsync c = withAgentEnv c . deleteConnectionAsync' c +deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync c waitDelivery = withAgentEnv c . deleteConnectionAsync' c waitDelivery -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response -deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> [ConnId] -> m () -deleteConnectionsAsync c = withAgentEnv c . deleteConnectionsAsync' c +deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync c waitDelivery = withAgentEnv c . deleteConnectionsAsync' c waitDelivery -- | Create SMP agent connection (NEW command) -createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -createConnection c userId enableNtfs = withAgentEnv c .:. newConn c userId "" enableNtfs +createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection c userId enableNtfs = withAgentEnv c .:: newConn c userId "" enableNtfs -- | Join SMP agent connection (JOIN command) -joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnection c userId enableNtfs = withAgentEnv c .:. joinConn c userId "" enableNtfs +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnection c userId enableNtfs = withAgentEnv c .:: joinConn c userId "" enableNtfs -- | Allow connection to continue after CONF notification (LET command) allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () allowConnection c = withAgentEnv c .:. allowConnection' c -- | Accept contact after REQ notification (ACPT command) -acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContact c enableNtfs = withAgentEnv c .:. acceptContact' c "" enableNtfs +acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContact c enableNtfs = withAgentEnv c .:: acceptContact' c "" enableNtfs -- | Reject contact (RJCT command) rejectContact :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> m () @@ -292,16 +296,16 @@ resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map resubscribeConnections c = withAgentEnv c . resubscribeConnections' c -- | Send message to the connection (SEND command) -sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage c = withAgentEnv c .:. sendMessage' c +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) +sendMessage c = withAgentEnv c .:: sendMessage' c -type MsgReq = (ConnId, MsgFlags, MsgBody) +type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody) -- | Send multiple messages to different connections (SEND command) -sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] +sendMessages :: MonadUnliftIO m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages c = withAgentEnv c . sendMessages' c -sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) +sendMessagesB :: (MonadUnliftIO m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB c = withAgentEnv c . sendMessagesB' c ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m () @@ -316,8 +320,8 @@ abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m Connect abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c -- | Re-synchronize connection ratchet keys -synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats -synchronizeRatchet c = withAgentEnv c .: synchronizeRatchet' c +synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats +synchronizeRatchet c = withAgentEnv c .:. synchronizeRatchet' c -- | Suspend SMP agent connection (OFF command) suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () @@ -514,13 +518,13 @@ client c@AgentClient {rcvQ, subQ} = forever $ do processCommand :: forall m. AgentMonad m => AgentClient -> (EntityId, APartyCmd 'Client) -> m (EntityId, APartyCmd 'Agent) processCommand c (connId, APC e cmd) = second (APC e) <$> case cmd of - NEW enableNtfs (ACM cMode) subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing subMode - JOIN enableNtfs (ACR _ cReq) subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo subMode + NEW enableNtfs (ACM cMode) pqIK subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing pqIK subMode + JOIN enableNtfs (ACR _ cReq) pqEnc subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo pqEnc subMode LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK) - ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo SMSubscribe + ACPT invId pqEnc ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo pqEnc SMSubscribe RJCT invId -> rejectContact' c connId invId $> (connId, OK) SUB -> subscribeConnection' c connId $> (connId, OK) - SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody + SEND pqEnc msgFlags msgBody -> (connId,) . uncurry MID <$> sendMessage' c connId pqEnc msgFlags msgBody ACK msgId rcptInfo_ -> ackMessage' c connId msgId rcptInfo_ $> (connId, OK) SWCH -> switchConnection' c connId $> (connId, OK) OFF -> suspendConnection' c connId $> (connId, OK) @@ -541,7 +545,7 @@ createUser' c smp xftp = do deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m () deleteUser' c userId delSMPQueues = do if delSMPQueues - then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c + then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c False else withStore c (`deleteUserRecord` userId) atomically $ TM.delete userId $ smpServers c where @@ -549,32 +553,32 @@ deleteUser' c userId delSMPQueues = do whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $ writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId) -newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId -newConnAsync c userId corrId enableNtfs cMode subMode = do - connId <- newConnNoQueues c userId "" enableNtfs cMode - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) subMode +newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> CR.InitialKeys -> SubscriptionMode -> m ConnId +newConnAsync c userId corrId enableNtfs cMode pqInitKeys subMode = do + connId <- newConnNoQueues c userId "" enableNtfs cMode (CR.connPQEncryption pqInitKeys) + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) pqInitKeys subMode pure connId -newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId -newConnNoQueues c userId connId enableNtfs cMode = do +newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> PQSupport -> m ConnId +newConnNoQueues c userId connId enableNtfs cMode pqSupport = do g <- asks random - connAgentVersion <- asks $ maxVersion . smpAgentVRange . config - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + connAgentVersion <- asks $ maxVersion . ($ pqSupport) . smpAgentVRange . config + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} withStore c $ \db -> createNewConn db g cData cMode -joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do +joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConnAsync c userId corrId enableNtfs cReqUri@CRInvitationUri {} cInfo pqSup subMode = do withInvLock c (strEncode cReqUri) "joinConnAsync" $ do - aVRange <- asks $ smpAgentVRange . config - case crAgentVRange `compatibleVersion` aVRange of - Just (Compatible connAgentVersion) -> do + compatibleInvitationUri cReqUri pqSup >>= \case + Just (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible connAgentVersion) -> do g <- asks random - let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) + cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation - enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo + enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) pqSupport subMode cInfo pure connId - _ -> throwError $ AGENT A_VERSION -joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo = + Nothing -> throwError $ AGENT A_VERSION +joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo _pqEncryption = throwError $ CMD PROHIBITED allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m () @@ -584,13 +588,13 @@ allowConnectionAsync' c corrId connId confId ownConnInfo = enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo _ -> throwError $ CMD PROHIBITED -acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContactAsync' c corrId enableNtfs invId ownConnInfo subMode = do +acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContactAsync' c corrId enableNtfs invId ownConnInfo pqSupport subMode = do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConnAsync c userId corrId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do + joinConnAsync c userId corrId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -608,26 +612,26 @@ ackMessageAsync' c corrId connId msgId rcptInfo_ = do enqueueAck :: m () enqueueAck = do let mId = InternalId msgId - RcvMsg {msgType} <- withStoreCtx "ackMessageAsync': getRcvMsg" c $ \db -> getRcvMsg db connId mId + RcvMsg {msgType} <- withStore c $ \db -> getRcvMsg db connId mId when (isJust rcptInfo_ && msgType /= AM_A_MSG_) $ throwError $ CMD PROHIBITED - (RcvQueue {server}, _) <- withStoreCtx "ackMessageAsync': setMsgUserAck" c $ \db -> setMsgUserAck db connId mId + (RcvQueue {server}, _) <- withStore c $ \db -> setMsgUserAck db connId mId enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_ -deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () -deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId] +deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> Bool -> ConnId -> m () +deleteConnectionAsync' c waitDelivery connId = deleteConnectionsAsync' c waitDelivery [connId] -deleteConnectionsAsync' :: AgentMonad m => AgentClient -> [ConnId] -> m () +deleteConnectionsAsync' :: AgentMonad m => AgentClient -> Bool -> [ConnId] -> m () deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure () -deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> [ConnId] -> m () -deleteConnectionsAsync_ onSuccess c connIds = case connIds of +deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> Bool -> [ConnId] -> m () +deleteConnectionsAsync_ onSuccess c waitDelivery connIds = case connIds of [] -> onSuccess _ -> do - (_, rqs, connIds') <- prepareDeleteConnections_ getConns c connIds - withStore' c $ forM_ connIds' . setConnDeleted + (_, rqs, connIds') <- prepareDeleteConnections_ getConns c waitDelivery connIds + withStore' c $ \db -> forM_ connIds' $ setConnDeleted db waitDelivery void . forkIO $ withLock (deleteLock c) "deleteConnectionsAsync" $ - deleteConnQueues c True rqs >> onSuccess + deleteConnQueues c waitDelivery True rqs >> onSuccess -- | Add connection to the new receive queue switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats @@ -644,17 +648,20 @@ switchConnectionAsync' c corrId connId = pure . connectionStats $ DuplexConnection cData rqs' sqs _ -> throwError $ CMD PROHIBITED -newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) -newConn c userId connId enableNtfs cMode clientData subMode = - getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData subMode +newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +newConn c userId connId enableNtfs cMode clientData pqInitKeys subMode = + getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode -newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) -newConnSrv c userId connId enableNtfs cMode clientData subMode srv = do - connId' <- newConnNoQueues c userId connId enableNtfs cMode - newRcvConnSrv c userId connId' enableNtfs cMode clientData subMode srv +newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do + connId' <- newConnNoQueues c userId connId enableNtfs cMode (CR.connPQEncryption pqInitKeys) + newRcvConnSrv c userId connId' enableNtfs cMode clientData pqInitKeys subMode srv -newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) -newRcvConnSrv c userId connId enableNtfs cMode clientData subMode srv = do +newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> CR.InitialKeys -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) +newRcvConnSrv c userId connId enableNtfs cMode clientData pqInitKeys subMode srv = do + case (cMode, pqInitKeys) of + (SCMContact, CR.IKUsePQ) -> throwError $ CMD PROHIBITED + _ -> pure () AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config (rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e rq' <- withStore c $ \db -> updateNewConnRcv db connId rq @@ -664,44 +671,75 @@ newRcvConnSrv c userId connId enableNtfs cMode clientData subMode srv = do when enableNtfs $ do ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) - let crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] clientData + let pqEnc = CR.connPQEncryption pqInitKeys + crData = ConnReqUriData SSSimplex (smpAgentVRange pqEnc) [qUri] clientData + e2eVRange = e2eEncryptVRange pqEnc case cMode of SCMContact -> pure (connId, CRContactUri crData) SCMInvitation -> do g <- asks random - (pk1, pk2, e2eRcvParams) <- atomically . CR.generateE2EParams g $ maxVersion e2eEncryptVRange - withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 - pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) + (pk1, pk2, pKem, e2eRcvParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion e2eVRange) (CR.initialPQEncryption pqInitKeys) + withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2 pKem + pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eVRange) -joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId -joinConn c userId connId enableNtfs cReq cInfo subMode = do +joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +joinConn c userId connId enableNtfs cReq cInfo pqSupport subMode = do srv <- case cReq of CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ -> getNextServer c userId [qServer q] _ -> getSMPServer c userId - joinConnSrv c userId connId enableNtfs cReq cInfo subMode srv + joinConnSrv c userId connId enableNtfs cReq cInfo pqSupport subMode srv -startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> m (Compatible Version, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.E2ERatchetParams 'C.X448) -startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) = do - AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - case ( qUri `compatibleVersion` smpClientVRange, - e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange, - crAgentVRange `compatibleVersion` smpAgentVRange - ) of - (Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams _ _ rcDHRr)), Just aVersion@(Compatible connAgentVersion)) -> do +startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Compatible VersionSMPA, ConnData, NewSndQueue, CR.Ratchet 'C.X448, CR.SndE2ERatchetParams 'C.X448) +startJoinInvitation userId connId enableNtfs cReqUri pqSup = + compatibleInvitationUri cReqUri pqSup >>= \case + Just (qInfo, (Compatible e2eRcvParams@(CR.E2ERatchetParams v _ rcDHRr kem_)), aVersion@(Compatible connAgentVersion)) -> do g <- asks random - (pk1, pk2, e2eSndParams) <- atomically . CR.generateE2EParams g $ version e2eRcvParams + let pqSupport = pqSup `CR.pqSupportAnd` versionPQSupport_ connAgentVersion (Just v) + (pk1, pk2, pKem, e2eSndParams) <- liftIO $ CR.generateSndE2EParams g v (CR.replyKEM_ v kem_ pqSupport) (_, rcDHRs) <- atomically $ C.generateKeyPair g - let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams + rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 pKem e2eRcvParams + maxSupported <- asks $ maxVersion . ($ pqSup) . e2eEncryptVRange . config + let rcVs = CR.RatchetVersions {current = v, maxSupported} + rc = CR.initSndRatchet rcVs rcDHRr rcDHRs rcParams q <- newSndQueue userId "" qInfo - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} pure (aVersion, cData, q, rc, e2eSndParams) - _ -> throwError $ AGENT A_VERSION + Nothing -> throwError $ AGENT A_VERSION -joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId -joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = +connRequestPQSupport :: MonadUnliftIO m => AgentClient -> PQSupport -> ConnectionRequestUri c -> m (Maybe (VersionSMPA, PQSupport)) +connRequestPQSupport c pqSup cReq = withAgentEnv c $ case cReq of + CRInvitationUri {} -> invPQSupported <$$> compatibleInvitationUri cReq pqSup + where + invPQSupported (_, Compatible (CR.E2ERatchetParams e2eV _ _ _), Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV (Just e2eV)) + CRContactUri {} -> ctPQSupported <$$> compatibleContactUri cReq pqSup + where + ctPQSupported (_, Compatible agentV) = (agentV, pqSup `CR.pqSupportAnd` versionPQSupport_ agentV Nothing) + +compatibleInvitationUri :: AgentMonad' m => ConnectionRequestUri 'CMInvitation -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible (CR.RcvE2ERatchetParams 'C.X448), Compatible VersionSMPA)) +compatibleInvitationUri (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) pqSup = do + AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config + pure $ + (,,) + <$> (qUri `compatibleVersion` smpClientVRange) + <*> (e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange pqSup) + <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) + +compatibleContactUri :: AgentMonad' m => ConnectionRequestUri 'CMContact -> PQSupport -> m (Maybe (Compatible SMPQueueInfo, Compatible VersionSMPA)) +compatibleContactUri (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) pqSup = do + AgentConfig {smpClientVRange, smpAgentVRange} <- asks config + pure $ + (,) + <$> (qUri `compatibleVersion` smpClientVRange) + <*> (crAgentVRange `compatibleVersion` smpAgentVRange pqSup) + +versionPQSupport_ :: VersionSMPA -> Maybe CR.VersionE2E -> PQSupport +versionPQSupport_ agentV e2eV_ = PQSupport $ agentV >= pqdrSMPAgentVersion && maybe True (>= CR.pqRatchetE2EEncryptVersion) e2eV_ + +joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m ConnId +joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSup subMode srv = withInvLock c (strEncode inv) "joinConnSrv" $ do - (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv + (aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSup g <- asks random (connId', sq) <- withStore c $ \db -> runExceptT $ do r@(connId', _) <- ExceptT $ createSndConn db g cData q @@ -712,28 +750,24 @@ joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv Right _ -> pure connId' Left e -> do -- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md - withStore' c (`deleteConn` connId') + void $ withStore' c $ \db -> deleteConn db Nothing connId' throwError e -joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do - aVRange <- asks $ smpAgentVRange . config - clientVRange <- asks $ smpClientVRange . config - case ( qUri `compatibleVersion` clientVRange, - crAgentVRange `compatibleVersion` aVRange - ) of - (Just qInfo, Just vrsn) -> do - (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing subMode srv +joinConnSrv c userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup subMode srv = + compatibleContactUri cReqUri pqSup >>= \case + Just (qInfo, vrsn) -> do + (connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing (CR.IKNoPQ pqSup) subMode srv sendInvitation c userId qInfo vrsn cReq cInfo pure connId' - _ -> throwError $ AGENT A_VERSION + Nothing -> throwError $ AGENT A_VERSION -joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m () -joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do - (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv +joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> PQSupport -> SubscriptionMode -> SMPServerWithAuth -> m () +joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo pqSupport subMode srv = do + (_aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv pqSupport q' <- withStore c $ \db -> runExceptT $ do liftIO $ createRatchet db connId rc ExceptT $ updateNewConnSnd db connId q confirmQueueAsync c cData q' srv cInfo (Just e2eSndParams) subMode -joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _srv = do +joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _pqSupport _srv = do throwError $ CMD PROHIBITED createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> m SMPQueueInfo @@ -764,13 +798,13 @@ allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConne _ -> throwError $ CMD PROHIBITED -- | Accept contact (ACPT command) in Reader monad -acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId -acceptContact' c connId enableNtfs invId ownConnInfo subMode = withConnLock c connId "acceptContact" $ do +acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> PQSupport -> SubscriptionMode -> m ConnId +acceptContact' c connId enableNtfs invId ownConnInfo pqSupport subMode = withConnLock c connId "acceptContact" $ do Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId) withStore c (`getConn` contactConnId) >>= \case SomeConn _ (ContactConnection ConnData {userId} _) -> do withStore' c $ \db -> acceptInvitation db invId ownConnInfo - joinConn c userId connId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do + joinConn c userId connId enableNtfs connReq ownConnInfo pqSupport subMode `catchAgentError` \err -> do withStore' c (`unacceptInvitation` invId) throwError err _ -> throwError $ CMD PROHIBITED @@ -905,30 +939,37 @@ getNotificationMessage' c nonce encNtfInfo = do Nothing -> SMP.notification msgFlags -- | Send message to the connection (SEND command) in Reader monad -sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId -sendMessage' c connId msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, msgFlags, msg))) +sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> m (AgentMsgId, PQEncryption) +sendMessage' c connId pqEnc msgFlags msg = liftEither . runIdentity =<< sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg))) -- | Send multiple messages to different connections (SEND command) in Reader monad -sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType AgentMsgId] +sendMessages' :: forall m. AgentMonad' m => AgentClient -> [MsgReq] -> m [Either AgentErrorType (AgentMsgId, PQEncryption)] sendMessages' c = sendMessagesB' c . map Right -sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType AgentMsgId)) +sendMessagesB' :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType MsgReq) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do - reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) - let reqs'' = fmap (>>= prepareConn) reqs' + reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs) + let (toEnable, reqs'') = mapAccumL prepareConn [] reqs' + void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable enqueueMessagesB c reqs'' where - prepareConn :: (MsgReq, SomeConn) -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) - prepareConn ((_, msgFlags, msg), SomeConn _ conn) = case conn of + prepareConn :: [ConnId] -> Either AgentErrorType (MsgReq, SomeConn) -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareConn acc (Left e) = (acc, Left e) + prepareConn acc (Right ((_, pqEnc, msgFlags, msg), SomeConn _ conn)) = case conn of DuplexConnection cData _ sqs -> prepareMsg cData sqs SndConnection cData sq -> prepareMsg cData [sq] - _ -> Left $ CONN SIMPLEX + _ -> (acc, Left $ CONN SIMPLEX) where - prepareMsg :: ConnData -> NonEmpty SndQueue -> Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) - prepareMsg cData sqs - | ratchetSyncSendProhibited cData = Left $ CMD PROHIBITED - | otherwise = Right (cData, sqs, msgFlags, A_MSG msg) - connIds = map (\(connId, _, _) -> connId) $ rights $ toList reqs + prepareMsg :: ConnData -> NonEmpty SndQueue -> ([ConnId], Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) + prepareMsg cData@ConnData {connId, pqSupport} sqs + | ratchetSyncSendProhibited cData = (acc, Left $ CMD PROHIBITED) + -- connection is only updated if PQ encryption was disabled, and now it has to be enabled. + -- support for PQ encryption (small message envelopes) will not be disabled when message is sent. + | pqEnc == PQEncOn && pqSupport == PQSupportOff = + let cData' = cData {pqSupport = PQSupportOn} :: ConnData + in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg)) + | otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg)) + connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs -- / async command processing v v v @@ -965,16 +1006,16 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do processCmd :: RetryInterval -> PendingCommand -> m () processCmd ri PendingCommand {cmdId, corrId, userId, connId, command} = case command of AClientCommand (APC _ cmd) -> case cmd of - NEW enableNtfs (ACM cMode) subMode -> noServer $ do + NEW enableNtfs (ACM cMode) pqEnc subMode -> noServer $ do usedSrvs <- newTVarIO ([] :: [SMPServer]) tryCommand . withNextSrv c userId usedSrvs [] $ \srv -> do - (_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing subMode srv + (_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing pqEnc subMode srv notify $ INV (ACR cMode cReq) - JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) subMode connInfo -> noServer $ do + JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) pqEnc subMode connInfo -> noServer $ do let initUsed = [qServer q] usedSrvs <- newTVarIO initUsed tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do - joinConnSrvAsync c userId connId enableNtfs cReq connInfo subMode srv + joinConnSrvAsync c userId connId enableNtfs cReq connInfo pqEnc subMode srv notify OK LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK @@ -1077,16 +1118,16 @@ runCommandProcessing c@AgentClient {subQ} server_ Worker {doWork} = do notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd) -- ^ ^ ^ async command processing / -enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId +enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) enqueueMessages c cData sqs msgFlags aMessage = do when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized" enqueueMessages' c cData sqs msgFlags aMessage -enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId +enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, CR.PQEncryption) enqueueMessages' c cData sqs msgFlags aMessage = - liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, msgFlags, aMessage))) + liftEither . runIdentity =<< enqueueMessagesB c (Identity (Right (cData, sqs, Nothing, msgFlags, aMessage))) -enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType AgentMsgId)) +enqueueMessagesB :: (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, PQEncryption))) enqueueMessagesB c reqs = do reqs' <- enqueueMessageB c reqs enqueueSavedMessageB c $ mapMaybe snd $ rights $ toList reqs' @@ -1095,35 +1136,39 @@ enqueueMessagesB c reqs = do isActiveSndQ :: SndQueue -> Bool isActiveSndQ SndQueue {status} = status == Secured || status == Active -enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId +enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m (AgentMsgId, PQEncryption) enqueueMessage c cData sq msgFlags aMessage = - liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], msgFlags, aMessage))) + liftEither . fmap fst . runIdentity =<< enqueueMessageB c (Identity (Right (cData, [sq], Nothing, msgFlags, aMessage))) -- this function is used only for sending messages in batch, it returns the list of successes to enqueue additional deliveries -enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, MsgFlags, AMessage)) -> m (t (Either AgentErrorType (AgentMsgId, Maybe (ConnData, [SndQueue], AgentMsgId)))) +enqueueMessageB :: forall m t. (AgentMonad' m, Traversable t) => AgentClient -> t (Either AgentErrorType (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage)) -> m (t (Either AgentErrorType ((AgentMsgId, PQEncryption), Maybe (ConnData, [SndQueue], AgentMsgId)))) enqueueMessageB c reqs = do - aVRange <- asks $ maxVersion . smpAgentVRange . config - reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db aVRange) reqs - forME reqMids $ \((cData, sq :| sqs, _, _), InternalId msgId) -> do + cfg <- asks config + reqMids <- withStoreBatch c $ \db -> fmap (bindRight $ storeSentMsg db cfg) reqs + forME reqMids $ \((cData, sq :| sqs, _, _, _), InternalId msgId, pqSecr) -> do submitPendingMsg c cData sq let sqs' = filter isActiveSndQ sqs - pure $ Right (msgId, if null sqs' then Nothing else Just (cData, sqs', msgId)) + pure $ Right ((msgId, pqSecr), if null sqs' then Nothing else Just (cData, sqs', msgId)) where - storeSentMsg :: DB.Connection -> Version -> (ConnData, NonEmpty SndQueue, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, MsgFlags, AMessage), InternalId)) - storeSentMsg db agentVersion req@(ConnData {connId}, sq :| _, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + storeSentMsg :: DB.Connection -> AgentConfig -> (ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage) -> IO (Either AgentErrorType ((ConnData, NonEmpty SndQueue, Maybe PQEncryption, MsgFlags, AMessage), InternalId, PQEncryption)) + storeSentMsg db cfg req@(cData@ConnData {connId, pqSupport}, sq :| _, pqEnc_, msgFlags, aMessage) = fmap (first storeError) $ runExceptT $ do + let AgentConfig {smpAgentVRange, e2eEncryptVRange} = cfg internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash agentMsg = AgentMessage privHeader aMessage agentMsgStr = smpEncode agentMsg internalHash = C.sha256Hash agentMsgStr - encAgentMessage <- agentRatchetEncrypt db connId agentMsgStr e2eEncUserMsgLength - let msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} + currentE2EVersion = maxVersion $ e2eEncryptVRange PQSupportOff + (encAgentMessage, pqEnc) <- agentRatchetEncrypt db cData agentMsgStr e2eEncUserMsgLength pqEnc_ currentE2EVersion + -- agent version range is determined by the connection suppport of PQ encryption, that is may be enabled when message is sent + let agentVersion = maxVersion $ smpAgentVRange pqSupport + msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash} + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, pqEncryption = pqEnc, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId - pure (req, internalId) + pure (req, internalId, pqEnc) enqueueSavedMessage :: AgentMonad' m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m () enqueueSavedMessage c cData msgId sq = enqueueSavedMessageB c $ Identity (cData, [sq], msgId) @@ -1166,7 +1211,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork atomically $ throwWhenNoDelivery c sq atomically $ beginAgentOperation c AOSndNetwork withWork c doWork (\db -> getPendingQueueMsg db connId sq) $ - \(rq_, PendingMsgData {msgId, msgType, msgBody, msgFlags, msgRetryState, internalTs}) -> do + \(rq_, PendingMsgData {msgId, msgType, msgBody, pqEncryption, msgFlags, msgRetryState, internalTs}) -> do atomically $ endAgentOperation c AOMsgDelivery -- this operation begins in submitPendingMsg let mId = unId msgId ri' = maybe id updateRetryInterval2 msgRetryState ri @@ -1236,7 +1281,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} ConnData {connId} sq (Worker {doWork -- it would lead to the non-deterministic internal ID of the first sent message, at to some other race conditions, -- because it can be sent before HELLO is received -- With `status == Active` condition, CON is sent here only by the accepting party, that previously received HELLO - when (status == Active) $ notify CON + when (status == Active) $ notify $ CON pqEncryption -- this branch should never be reached as receive queue is created before the confirmation, _ -> logError "HELLO sent without receive queue" AM_A_MSG_ -> notify $ SENT mId @@ -1322,13 +1367,13 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do ack :: m () ack = do -- the stored message was delivered via a specific queue, the rest failed to decrypt and were already acknowledged - (rq, srvMsgId) <- withStoreCtx "ackMessage': setMsgUserAck" c $ \db -> setMsgUserAck db connId $ InternalId msgId + (rq, srvMsgId) <- withStore c $ \db -> setMsgUserAck db connId $ InternalId msgId ackQueueMessage c rq srvMsgId del :: m () - del = withStoreCtx' "ackMessage': deleteMsg" c $ \db -> deleteMsg db connId $ InternalId msgId + del = withStore' c $ \db -> deleteMsg db connId $ InternalId msgId sendRcpt :: Connection 'CDuplex -> m () sendRcpt (DuplexConnection cData@ConnData {connAgentVersion} _ sqs) = do - msg@RcvMsg {msgType, msgReceipt} <- withStoreCtx "ackMessage': getRcvMsg" c $ \db -> getRcvMsg db connId $ InternalId msgId + msg@RcvMsg {msgType, msgReceipt} <- withStore c $ \db -> getRcvMsg db connId $ InternalId msgId case rcptInfo_ of Just rcptInfo -> do unless (msgType == AM_A_MSG_) $ throwError (CMD PROHIBITED) @@ -1339,7 +1384,7 @@ ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do Nothing -> case (msgType, msgReceipt) of -- only remove sent message if receipt hash was Ok, both to debug and for future redundancy (AM_A_RCVD_, Just MsgReceipt {agentMsgId = sndMsgId, msgRcptStatus = MROk}) -> - withStoreCtx' "ackMessage': deleteDeliveredSndMsg" c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId + withStore' c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId _ -> pure () switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats @@ -1394,21 +1439,23 @@ abortConnectionSwitch' c connId = _ -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED -synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats -synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet" $ do +synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> PQSupport -> Bool -> m ConnectionStats +synchronizeRatchet' c connId pqSupport' force = withConnLock c connId "synchronizeRatchet" $ do withStore c (`getConn` connId) >>= \case - SomeConn _ (DuplexConnection cData rqs sqs) + SomeConn _ (DuplexConnection cData@ConnData {pqSupport} rqs sqs) | ratchetSyncAllowed cData || force -> do -- check queues are not switching? + when (pqSupport' /= pqSupport) $ withStore' c $ \db -> setConnPQSupport db connId pqSupport' + let cData' = cData {pqSupport = pqSupport'} :: ConnData AgentConfig {e2eEncryptVRange} <- asks config g <- asks random - (pk1, pk2, e2eParams) <- atomically . CR.generateE2EParams g $ maxVersion e2eEncryptVRange - enqueueRatchetKeyMsgs c cData sqs e2eParams + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g (maxVersion $ e2eEncryptVRange pqSupport') pqSupport' + enqueueRatchetKeyMsgs c cData' sqs e2eParams withStore' c $ \db -> do setConnRatchetSync db connId RSStarted - setRatchetX3dhKeys db connId pk1 pk2 - let cData' = cData {ratchetSyncState = RSStarted} :: ConnData - conn' = DuplexConnection cData' rqs sqs + setRatchetX3dhKeys db connId pk1 pk2 pKem + let cData'' = cData' {ratchetSyncState = RSStarted} :: ConnData + conn' = DuplexConnection cData'' rqs sqs pure $ connectionStats conn' | otherwise -> throwError $ CMD PROHIBITED _ -> throwError $ CMD PROHIBITED @@ -1452,19 +1499,23 @@ disableConn c connId = do -- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure. deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnections' = deleteConnections_ getConns False +deleteConnections' = deleteConnections_ getConns False False deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteDeletedConns = deleteConnections_ getDeletedConns True +deleteDeletedConns = deleteConnections_ getDeletedConns True False + +deleteDeletedWaitingDeliveryConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +deleteDeletedWaitingDeliveryConns = deleteConnections_ getConns True True prepareDeleteConnections_ :: forall m. AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> AgentClient -> + Bool -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId]) -prepareDeleteConnections_ getConnections c connIds = do +prepareDeleteConnections_ getConnections c waitDelivery connIds = do conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConnections` connIds) let (errs, cs) = M.mapEither id conns errs' = M.map (Left . storeError) errs @@ -1472,19 +1523,27 @@ prepareDeleteConnections_ getConnections c connIds = do rqs = concat $ M.elems rcvQs connIds' = M.keys rcvQs forM_ connIds' $ disableConn c - withStore' c $ forM_ (M.keys delRs) . deleteConn + -- ! delRs is not used to notify about the result in any of the calling functions, + -- ! it is only used to check results count in deleteConnections_; + -- ! if it was used to notify about the result, it might be necessary to differentiate + -- ! between completed deletions of connections, and deletions delayed due to wait for delivery (see deleteConn) + deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing + rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) (M.keys delRs)) + forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure (errs' <> delRs, rqs, connIds') where rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue] rcvQueues (SomeConn _ conn) = case connRcvQueues conn of [] -> Left $ Right () rqs -> Right rqs + notify = atomically . writeTBQueue (subQ c) -deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnQueues c ntf rqs = do +deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ())) +deleteConnQueues c waitDelivery ntf rqs = do rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs) let connIds = M.keys $ M.filter isRight rs - rs' <- rights <$> withStoreBatch' c (\db -> map (\cId -> deleteConn db cId $> cId) connIds) + deliveryTimeout <- if waitDelivery then asks (Just . connDeleteDeliveryTimeout . config) else pure Nothing + rs' <- catMaybes . rights <$> withStoreBatch' c (\db -> map (deleteConn db deliveryTimeout) connIds) forM_ rs' $ \cId -> notify ("", cId, APC SAEConn DEL_CONN) pure rs where @@ -1527,13 +1586,14 @@ deleteConnections_ :: AgentMonad m => (DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) -> Bool -> + Bool -> AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) -deleteConnections_ _ _ _ [] = pure M.empty -deleteConnections_ getConnections ntf c connIds = do - (rs, rqs, _) <- prepareDeleteConnections_ getConnections c connIds - rcvRs <- deleteConnQueues c ntf rqs +deleteConnections_ _ _ _ _ [] = pure M.empty +deleteConnections_ getConnections ntf waitDelivery c connIds = do + (rs, rqs, _) <- prepareDeleteConnections_ getConnections c waitDelivery connIds + rcvRs <- deleteConnQueues c waitDelivery ntf rqs let rs' = M.union rs rcvRs notifyResultError rs' pure rs' @@ -1862,6 +1922,7 @@ cleanupManager c@AgentClient {subQ} = do deleteConns = withLock (deleteLock c) "cleanupManager" $ do void $ withStore' c getDeletedConnIds >>= deleteDeletedConns c + void $ withStore' c getDeletedWaitingDeliveryConnIds >>= deleteDeletedWaitingDeliveryConns c withStore' c deleteUsersWithoutConns >>= mapM_ (notify "" . DEL_USER) deleteRcvFilesExpired = do rcvFilesTTL <- asks $ rcvFilesTTL . config @@ -1901,9 +1962,11 @@ cleanupManager c@AgentClient {subQ} = do notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> ExceptT AgentErrorType m () notify entId cmd = atomically $ writeTBQueue subQ ("", entId, APC (sAEntity @e) cmd) +data ACKd = ACKd | ACKPending + -- | make sure to ACK or throw in each message processing branch -- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () +processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission SMPVersion BrokerMsg -> m () processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, sessId, rId, cmd) = do (rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId) processSMP rq conn $ toConnData conn @@ -1915,13 +1978,14 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, cData@ConnData {userId, connId, connAgentVersion, ratchetSyncState = rss} = withConnLock c connId "processSMP" $ case cmd of SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} -> - handleNotifyAck $ do + void . handleNotifyAck $ do msg' <- decryptSMPMessage rq msg - handleNotifyAck $ case msg' of + ack' <- handleNotifyAck $ case msg' of SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody SMP.ClientRcvMsgQuota {} -> queueDrained >> ack whenM (atomically $ hasGetLock c rq) $ notify (MSGNTF $ SMP.rcvMessageMeta srvMsgId msg') + pure ack' where queueDrained = case conn of DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq) @@ -1944,8 +2008,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, decryptClientMessage e2eDh clientMsg >>= \case (SMP.PHEmpty, AgentRatchetKey {agentVersion, e2eEncryption}) -> do conn' <- updateConnVersion conn cData agentVersion - qDuplex conn' "AgentRatchetKey" $ newRatchetKey e2eEncryption - ack + qDuplex conn' "AgentRatchetKey" $ \a -> newRatchetKey e2eEncryption a >> ack (SMP.PHEmpty, AgentMsgEnvelope {agentVersion, encAgentMessage}) -> do conn' <- updateConnVersion conn cData agentVersion -- primary queue is set as Active in helloMsg, below is to set additional queues Active @@ -1967,11 +2030,12 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do conn'' <- resetRatchetSync case aMessage of - HELLO -> helloMsg srvMsgId conn'' >> ackDel msgId + HELLO -> helloMsg srvMsgId msgMeta conn'' >> ackDel msgId -- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command A_MSG body -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId notify $ MSG msgMeta msgFlags body + pure ACKPending A_RCVD rcpts -> qDuplex conn'' "RCVD" $ messagesRcvd rcpts msgMeta QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending srvMsgId addr QADD qs -> qDuplexAckDel conn'' "QADD" $ qAddMsg srvMsgId qs @@ -1982,7 +2046,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, QTEST _ -> logServer "<--" c srv rId ("MSG :" <> logSecret srvMsgId) >> ackDel msgId EREADY _ -> qDuplexAckDel conn'' "EREADY" $ ereadyMsg rcPrev where - qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () + qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m ACKd qDuplexAckDel conn'' name a = qDuplex conn'' name a >> ackDel msgId resetRatchetSync :: m (Connection c) resetRatchetSync @@ -1995,7 +2059,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, | otherwise = pure conn' Right _ -> prohibited >> ack Left e@(AGENT A_DUPLICATE) -> do - withStoreCtx' "processSMP: getLastMsg" c (\db -> getLastMsg db connId srvMsgId) >>= \case + withStore' c (\db -> getLastMsg db connId srvMsgId) >>= \case Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck} | userAck -> ackDel internalId | otherwise -> do @@ -2003,7 +2067,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, AgentMessage _ (A_MSG body) -> do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId notify $ MSG msgMeta msgFlags body - _ -> pure () + pure ACKPending + _ -> ack _ -> checkDuplicateHash e encryptedMsgHash >> ack Left (AGENT (A_CRYPTO e)) -> do exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash @@ -2027,7 +2092,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, agentClientMsg :: TVar ChaChaDRG -> ByteString -> m (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448)) agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY - agentMsgBody <- agentRatchetDecrypt' g db connId rc encAgentMessage + (agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do let msgType = agentMessageType agentMsg @@ -2037,16 +2102,16 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash recipient = (unId internalId, internalTs) broker = (srvMsgId, systemToUTCTime srvTs) - msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId} + msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption} rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash, encryptedMsgHash} liftIO $ createRcvMsg db connId rq rcvMsg pure $ Just (internalId, msgMeta, aMessage, rc) _ -> pure Nothing _ -> prohibited >> ack _ -> prohibited >> ack - updateConnVersion :: Connection c -> ConnData -> Version -> m (Connection c) - updateConnVersion conn' cData' msgAgentVersion = do - aVRange <- asks $ smpAgentVRange . config + updateConnVersion :: Connection c -> ConnData -> VersionSMPA -> m (Connection c) + updateConnVersion conn' cData'@ConnData {pqSupport} msgAgentVersion = do + aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion case msgAVRange `compatibleVersion` aVRange of Just (Compatible av) @@ -2056,11 +2121,11 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, pure $ updateConnection cData'' conn' | otherwise -> pure conn' Nothing -> pure conn' - ack :: m () - ack = enqueueCmd $ ICAck rId srvMsgId - ackDel :: InternalId -> m () - ackDel = enqueueCmd . ICAckDel rId srvMsgId - handleNotifyAck :: m () -> m () + ack :: m ACKd + ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd + ackDel :: InternalId -> m ACKd + ackDel aId = enqueueCmd (ICAckDel rId srvMsgId aId) $> ACKd + handleNotifyAck :: m ACKd -> m ACKd handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack SMP.END -> atomically (TM.lookup tSess smpClients $>>= (tryReadTMVar . sessionVar) >>= processEND) @@ -2107,20 +2172,26 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, parseMessage :: Encoding a => ByteString -> m a parseMessage = liftEither . parse smpP (AGENT A_MESSAGE) - smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m () + smpConfirmation :: SMP.MsgId -> Connection c -> C.APublicAuthKey -> C.PublicKeyX25519 -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> ByteString -> VersionSMPC -> VersionSMPA -> m () smpConfirmation srvMsgId conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config + let ConnData {pqSupport} = toConnData conn' + aVRange = smpAgentVRange pqSupport + e2eVRange = e2eEncryptVRange pqSupport unless - (agentVersion `isCompatible` smpAgentVRange && smpClientVersion `isCompatible` smpClientVRange) + (agentVersion `isCompatible` aVRange && smpClientVersion `isCompatible` smpClientVRange) (throwError $ AGENT A_VERSION) case status of New -> case (conn', e2eEncryption) of -- party initiating connection - (RcvConnection {}, Just e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _)) -> do - unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) - (pk1, rcDHRs) <- withStore c (`getRatchetX3dhKeys` connId) - let rc = CR.initRcvRatchet e2eEncryptVRange rcDHRs $ CR.x3dhRcv pk1 rcDHRs e2eSndParams + (RcvConnection _ _, Just (CR.AE2ERatchetParams _ e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _ _))) -> do + unless (e2eVersion `isCompatible` e2eVRange) (throwError $ AGENT A_VERSION) + (pk1, rcDHRs, pKem) <- withStore c (`getRatchetX3dhKeys` connId) + rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 rcDHRs pKem e2eSndParams + let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion e2eVRange} + pqSupport' = pqSupport `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just e2eVersion) + rc = CR.initRcvRatchet rcVs rcDHRs rcParams pqSupport' g <- asks random (agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt g rc M.empty encConnInfo case (agentMsgBody_, skipped) of @@ -2133,17 +2204,18 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, processConf connInfo senderConf = do let newConfirmation = NewConfirmation {connId, senderConf, ratchetState = rc'} confId <- withStore c $ \db -> do - setConnectionVersion db connId agentVersion + setConnAgentVersion db connId agentVersion + when (pqSupport /= pqSupport') $ setConnPQSupport db connId pqSupport' createConfirmation db g newConfirmation let srvs = map qServer $ smpReplyQueues senderConf - notify $ CONF confId srvs connInfo + notify $ CONF confId pqSupport' srvs connInfo _ -> prohibited -- party accepting connection (DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do g <- asks random - withStore c (\db -> runExceptT $ agentRatchetDecrypt g db connId encConnInfo) >>= parseMessage >>= \case + withStore c (\db -> runExceptT $ agentRatchetDecrypt g db connId encConnInfo) >>= parseMessage . fst >>= \case AgentConnInfo connInfo -> do - notify $ INFO connInfo + notify $ INFO pqSupport connInfo let dhSecret = C.dh' e2ePubKey e2ePrivKey withStore' c $ \db -> setRcvQueueConfirmedE2E db rq dhSecret $ min v' smpClientVersion enqueueCmd $ ICDuplexSecure rId senderKey @@ -2151,8 +2223,8 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> prohibited _ -> prohibited - helloMsg :: SMP.MsgId -> Connection c -> m () - helloMsg srvMsgId conn' = do + helloMsg :: SMP.MsgId -> MsgMeta -> Connection c -> m () + helloMsg srvMsgId MsgMeta {pqEncryption} conn' = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case status of Active -> prohibited @@ -2162,7 +2234,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, -- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO -- this branch is executed by the accepting party in duplexHandshake mode (v2) -- (was executed by initiating party in v1 that is no longer supported) - | sndStatus == Active -> notify CON + | sndStatus == Active -> notify $ CON pqEncryption | otherwise -> enqueueDuplexHello sq _ -> pure () where @@ -2181,18 +2253,20 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, >>= mapM_ (\(_, retryLock) -> tryPutTMVar retryLock ()) Nothing -> qError "QCONT: queue address not found" - messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m () + messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m ACKd messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing case L.nonEmpty . catMaybes $ L.toList rs of - Just rs' -> notify $ RCVD msgMeta rs' -- client must ACK once processed - Nothing -> enqueueCmd $ ICAck rId srvMsgId + Just rs' -> notify (RCVD msgMeta rs') $> ACKPending + Nothing -> ack where + ack :: m ACKd + ack = enqueueCmd (ICAck rId srvMsgId) $> ACKd clientReceipt :: AMessageReceipt -> m (Maybe MsgReceipt) clientReceipt AMessageReceipt {agentMsgId, msgHash} = do let sndMsgId = InternalSndId agentMsgId - SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStoreCtx "messagesRcvd: getSndMsgViaRcpt" c $ \db -> getSndMsgViaRcpt db connId sndMsgId + SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStore c $ \db -> getSndMsgViaRcpt db connId sndMsgId if msgType /= AM_A_MSG_ then notify (ERR $ AGENT A_PROHIBITED) $> Nothing -- unexpected message type for receipt else case msgReceipt of @@ -2279,7 +2353,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, _ -> qError "QUSE: switching SndQueue not found in connection" _ -> qError "QUSE: switched queue address not found in connection" - qError :: String -> m () + qError :: String -> m a qError = throwError . AGENT . A_QUEUE ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> m () @@ -2294,25 +2368,33 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, logServer "<--" c srv rId $ "MSG :" <> logSecret srvMsgId case conn' of ContactConnection {} -> do + -- show connection request even if invitaion via contact address is not compatible. + -- in case invitation not compatible, assume there is no PQ encryption support. + pqSupport <- maybe PQSupportOff pqSupported <$> compatibleInvitationUri connReq PQSupportOn g <- asks random let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo} invId <- withStore c $ \db -> createInvitation db g newInv let srvs = L.map qServer $ crSmpQueues crData - notify $ REQ invId srvs cInfo + notify $ REQ invId pqSupport srvs cInfo _ -> prohibited + where + pqSupported (_, Compatible (CR.E2ERatchetParams v _ _ _), Compatible agentVersion) = + PQSupportOn `CR.pqSupportAnd` versionPQSupport_ agentVersion (Just v) - qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m () + qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m a) -> m a qDuplex conn' name action = case conn' of DuplexConnection {} -> action conn' _ -> qError $ name <> ": message must be sent to duplex connection" - newRatchetKey :: CR.E2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () - newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = + newRatchetKey :: CR.RcvE2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m () + newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv _) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId, pqSupport} _ sqs) = unlessM ratchetExists $ do AgentConfig {e2eEncryptVRange} <- asks config - unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION) + let connE2EVRange = e2eEncryptVRange pqSupport + unless (e2eVersion `isCompatible` connE2EVRange) (throwError $ AGENT A_VERSION) keys <- getSendRatchetKeys - initRatchet e2eEncryptVRange keys + let rcVs = CR.RatchetVersions {current = e2eVersion, maxSupported = maxVersion connE2EVRange} + initRatchet rcVs keys notifyAgreed where rkHashRcv = rkHash k1Rcv k2Rcv @@ -2322,7 +2404,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, exists <- checkRatchetKeyHashExists db connId rkHashRcv unless exists $ addProcessedRatchetKeyHash db connId rkHashRcv pure exists - getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448) + getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) getSendRatchetKeys = case rss of RSOk -> sendReplyKey -- receiving client RSAllowed -> sendReplyKey @@ -2338,9 +2420,9 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, where sendReplyKey = do g <- asks random - (pk1, pk2, e2eParams) <- atomically . CR.generateE2EParams g $ version e2eOtherPartyParams + (pk1, pk2, pKem, e2eParams) <- liftIO $ CR.generateRcvE2EParams g e2eVersion pqSupport enqueueRatchetKeyMsgs c cData' sqs e2eParams - pure (pk1, pk2) + pure (pk1, pk2, pKem) notifyRatchetSyncError = do let cData'' = cData' {ratchetSyncState = RSRequired} :: ConnData conn'' = updateConnection cData'' conn' @@ -2357,13 +2439,15 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), _v, createRatchet db connId rc -- compare public keys `k1` in AgentRatchetKey messages sent by self and other party -- to determine ratchet initilization ordering - initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448) -> m () - initRatchet e2eEncryptVRange (pk1, pk2) + initRatchet :: CR.RatchetVersions -> (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams) -> m () + initRatchet rcVs (pk1, pk2, pKem) | rkHash (C.publicKey pk1) (C.publicKey pk2) <= rkHashRcv = do - recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams + rcParams <- liftError cryptoError $ CR.pqX3dhRcv pk1 pk2 pKem e2eOtherPartyParams + recreateRatchet $ CR.initRcvRatchet rcVs pk2 rcParams pqSupport | otherwise = do (_, rcDHRs) <- atomically . C.generateKeyPair =<< asks random - recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams + rcParams <- liftEitherWith cryptoError $ CR.pqX3dhSnd pk1 pk2 (CR.APRKP CR.SRKSProposed <$> pKem) e2eOtherPartyParams + recreateRatchet $ CR.initSndRatchet rcVs k2Rcv rcDHRs rcParams void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity @@ -2400,59 +2484,67 @@ connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = sq' <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq enqueueConfirmation c cData sq' ownConnInfo Nothing -confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueueAsync :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () confirmQueueAsync c cData sq srv connInfo e2eEncryption_ subMode = do storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation c cData sq srv connInfo subMode submitPendingMsg c cData sq -confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m () -confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ subMode = do +confirmQueue :: forall m. AgentMonad m => Compatible VersionSMPA -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> SubscriptionMode -> m () +confirmQueue (Compatible agentVersion) c cData@ConnData {connId, pqSupport} sq srv connInfo e2eEncryption_ subMode = do msg <- mkConfirmation =<< mkAgentConfirmation c cData sq srv connInfo subMode sendConfirmation c sq msg withStore' c $ \db -> setSndQueueStatus db sq Confirmed where mkConfirmation :: AgentMessage -> m MsgBody - mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do - void . liftIO $ updateSndIds db connId - encConnInfo <- agentRatchetEncrypt db connId (smpEncode aMessage) e2eEncConnInfoLength - pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} + mkConfirmation aMessage = do + -- the version to be used when PQSupport is disabled + currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config + withStore c $ \db -> runExceptT $ do + void . liftIO $ updateSndIds db connId + let pqEnc = CR.pqSupportToEnc pqSupport + (encConnInfo, _) <- agentRatchetEncrypt db cData (smpEncode aMessage) e2eEncConnInfoLength (Just pqEnc) currentE2EVersion + pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo} mkAgentConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage mkAgentConfirmation c cData sq srv connInfo subMode = do qInfo <- createReplyQueue c cData sq subMode srv pure $ AgentConnInfoReply (qInfo :| []) connInfo -enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m () +enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> m () enqueueConfirmation c cData sq connInfo e2eEncryption_ = do storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo submitPendingMsg c cData sq -storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.E2ERatchetParams 'C.X448) -> AgentMessage -> m () -storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ agentMsg = withStore c $ \db -> runExceptT $ do - internalTs <- liftIO getCurrentTime - (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId - let agentMsgStr = smpEncode agentMsg - internalHash = C.sha256Hash agentMsgStr - encConnInfo <- agentRatchetEncrypt db connId agentMsgStr e2eEncConnInfoLength - let msgBody = smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo} - msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} - liftIO $ createSndMsg db connId msgData - liftIO $ createSndMsgDelivery db connId sq internalId +storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.SndE2ERatchetParams 'C.X448) -> AgentMessage -> m () +storeConfirmation c cData@ConnData {connId, pqSupport, connAgentVersion = v} sq e2eEncryption_ agentMsg = do + -- the version to be used when PQSupport is disabled + currentE2EVersion <- asks $ maxVersion . ($ PQSupportOff) . e2eEncryptVRange . config + withStore c $ \db -> runExceptT $ do + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId + let agentMsgStr = smpEncode agentMsg + internalHash = C.sha256Hash agentMsgStr + pqEnc = CR.pqSupportToEnc pqSupport + (encConnInfo, pqEncryption) <- agentRatchetEncrypt db cData agentMsgStr e2eEncConnInfoLength (Just pqEnc) currentE2EVersion + let msgBody = smpEncode $ AgentConfirmation {agentVersion = v, e2eEncryption_, encConnInfo} + msgType = agentMessageType agentMsg + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + liftIO $ createSndMsg db connId msgData + liftIO $ createSndMsgDelivery db connId sq internalId -enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.E2ERatchetParams 'C.X448 -> m () +enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m () enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do msgId <- enqueueRatchetKey c cData sq e2eEncryption mapM_ (enqueueSavedMessage c cData msgId) $ filter isActiveSndQ sqs -enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId -enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do - aVRange <- asks $ smpAgentVRange . config +enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.RcvE2ERatchetParams 'C.X448 -> m AgentMsgId +enqueueRatchetKey c cData@ConnData {connId, pqSupport} sq e2eEncryption = do + aVRange <- asks $ ($ pqSupport) . smpAgentVRange . config msgId <- storeRatchetKey $ maxVersion aVRange submitPendingMsg c cData sq pure $ unId msgId where - storeRatchetKey :: Version -> m InternalId + storeRatchetKey :: VersionSMPA -> m InternalId storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do internalTs <- liftIO getCurrentTime (internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId @@ -2461,31 +2553,33 @@ enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do internalHash = C.sha256Hash agentMsgStr let msgBody = smpEncode $ AgentRatchetKey {agentVersion, e2eEncryption, info = agentMsgStr} msgType = agentMessageType agentMsg - msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} + -- this message is e2e encrypted with queue key, not with double ratchet + msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, pqEncryption = PQEncOff, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash} liftIO $ createSndMsg db connId msgData liftIO $ createSndMsgDelivery db connId sq internalId pure internalId -- encoded AgentMessage -> encoded EncAgentMessage -agentRatchetEncrypt :: DB.Connection -> ConnId -> ByteString -> Int -> ExceptT StoreError IO ByteString -agentRatchetEncrypt db connId msg paddedLen = do +agentRatchetEncrypt :: DB.Connection -> ConnData -> ByteString -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (ByteString, PQEncryption) +agentRatchetEncrypt db ConnData {connId, connAgentVersion = v, pqSupport} msg getPaddedLen pqEnc_ currentE2EVersion = do rc <- ExceptT $ getRatchet db connId - (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg + let paddedLen = getPaddedLen v pqSupport + (encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg pqEnc_ currentE2EVersion liftIO $ updateRatchet db connId rc' CR.SMDNoChange - pure encMsg + pure (encMsg, CR.rcSndKEM rc') -- encoded EncAgentMessage -> encoded AgentMessage -agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO ByteString +agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt g db connId encAgentMsg = do rc <- ExceptT $ getRatchet db connId agentRatchetDecrypt' g db connId rc encAgentMsg -agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO ByteString +agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt' g db connId rc encAgentMsg = do skipped <- liftIO $ getSkippedMsgKeys db connId (agentMsgBody_, rc', skippedDiff) <- liftE (SEAgentError . cryptoError) $ CR.rcDecrypt g rc skipped encAgentMsg liftIO $ updateRatchet db connId rc' skippedDiff - liftEither $ first (SEAgentError . cryptoError) agentMsgBody_ + liftEither $ bimap (SEAgentError . cryptoError) (,CR.rcRcvKEM rc') agentMsgBody_ newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m NewSndQueue newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 23caa2254..8b3b87122 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -110,8 +110,6 @@ module Simplex.Messaging.Agent.Client whenSuspending, withStore, withStore', - withStoreCtx, - withStoreCtx', withStoreBatch, withStoreBatch', storeError, @@ -167,8 +165,8 @@ import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion) import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..)) import Simplex.FileTransfer.Util (uniqueCombine) import Simplex.Messaging.Agent.Env.SQLite @@ -187,6 +185,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse) import Simplex.Messaging.Protocol @@ -215,9 +214,12 @@ import Simplex.Messaging.Protocol UserProtocol, XFTPServer, XFTPServerWithAuth, + VersionSMPC, + VersionRangeSMPC, sameSrvAddr', ) import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Transport (SMPVersion) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport.Client (TransportHost) @@ -253,7 +255,7 @@ data AgentClient = AgentClient { active :: TVar Bool, rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), smpServers :: TMap UserId (NonEmpty SMPServerWithAuth), smpClients :: TMap SMPTransportSession SMPClientVar, ntfServers :: TVar [NtfServer], @@ -467,7 +469,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store agentDRG :: AgentClient -> TVar ChaChaDRG agentDRG AgentClient {agentEnv = Env {random}} = random -class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where +class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where type Client msg = c | c -> msg getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg) clientProtocolError :: err -> AgentErrorType @@ -476,8 +478,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher clientTransportHost :: Client msg -> TransportHost clientSessionTs :: Client msg -> UTCTime -instance ProtocolServerClient ErrorType BrokerMsg where - type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg +instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where + type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg getProtocolServerClient = getSMPServerClient clientProtocolError = SMP closeProtocolServerClient = closeProtocolClient @@ -485,8 +487,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient ErrorType NtfResponse where - type Client NtfResponse = ProtocolClient ErrorType NtfResponse +instance ProtocolServerClient NTFVersion ErrorType NtfResponse where + type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse getProtocolServerClient = getNtfServerClient clientProtocolError = NTF closeProtocolServerClient = closeProtocolClient @@ -494,7 +496,7 @@ instance ProtocolServerClient ErrorType NtfResponse where clientTransportHost = transportHost' clientSessionTs = sessionTs -instance ProtocolServerClient XFTPErrorType FileResponse where +instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where type Client FileResponse = XFTPClient getProtocolServerClient = getXFTPServerClient clientProtocolError = XFTP @@ -683,8 +685,8 @@ waitForProtocolClient c (_, srv, _) v = do -- clientConnected arg is only passed for SMP server newProtocolClient :: - forall err msg m. - (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => + forall v err msg m. + (AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> TMap (TransportSession msg) (ClientVar msg) -> @@ -706,10 +708,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v = putTMVar (sessionVar v) (Left e) throwError e -- signal error to caller -hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone +hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost -getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig +getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v) getClientConfig AgentClient {useNetworkConfig} cfgSel = do cfg <- asks $ cfgSel . config networkConfig <- readTVarIO useNetworkConfig @@ -754,19 +756,19 @@ throwWhenNoDelivery c sq = unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $ throwSTM ThreadKilled -closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c) -reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () reconnectServerClients c clientsSel = readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c) -closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () +closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () closeClient c clientSel tSess = atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c) -closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO () +closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO () closeClient_ c v = do NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case @@ -798,7 +800,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure where newLock = createLock >>= \l -> TM.insert key l locks $> l -withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a +withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a withClient_ c tSess@(userId, srv, _) statCmd action = do cl <- getProtocolServerClient c tSess (action cl <* stat cl "OK") `catchAgentError` logServerError cl @@ -810,18 +812,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do stat cl $ strEncode e throwError e -withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a +withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do logServer "-->" c srv entId cmdStr res <- withClient_ c tSess cmdStr action logServer "<--" c srv entId "OK" return res -withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client -withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a -withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client +withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a +withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a withSMPClient c q cmdStr action = do @@ -837,7 +839,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI withNtfClient c srv = withLogClient c (0, srv, Nothing) withXFTPClient :: - (AgentMonad m, ProtocolServerClient err msg) => + (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> (UserId, ProtoServer msg, EntityId) -> ByteString -> @@ -1001,7 +1003,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig -newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) +newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do C.AuthAlg a <- asks (rcvAuthAlg . config) g <- asks random @@ -1151,7 +1153,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" -sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () +sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do tSess <- mkTransportSession c userId smpServer senderId withLogClient_ c tSess senderId "SEND " $ \smp -> do @@ -1334,7 +1336,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody} -- add encoding as AgentInvitation'? -agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString +agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncryptOnce clientVersion dhRcvPubKey msg = do g <- asks random (dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g @@ -1453,34 +1455,13 @@ waitUntilForeground :: AgentClient -> STM () waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a -withStore' = withStoreCtx_' Nothing +withStore' c action = withStore c $ fmap Right . action withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStore = withStoreCtx_ Nothing - -withStoreCtx' :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO a) -> m a -withStoreCtx' = withStoreCtx_' . Just - -withStoreCtx :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStoreCtx = withStoreCtx_ . Just - -withStoreCtx_' :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO a) -> m a -withStoreCtx_' ctx_ c action = withStoreCtx_ ctx_ c $ fmap Right . action - -withStoreCtx_ :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a -withStoreCtx_ ctx_ c action = do +withStore c action = do st <- asks store - liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ case ctx_ of - Nothing -> withTransaction st action `E.catch` handleInternal "" - -- uncomment to debug store performance - -- Just ctx -> do - -- t1 <- liftIO getCurrentTime - -- putStrLn $ "agent withStoreCtx start :: " <> show t1 <> " :: " <> ctx - -- r <- withTransaction st action `E.catch` handleInternal (" (" <> ctx <> ")") - -- t2 <- liftIO getCurrentTime - -- putStrLn $ "agent withStoreCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1) - -- pure r - Just _ -> withTransaction st action `E.catch` handleInternal "" + liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ + withTransaction st action `E.catch` handleInternal "" where handleInternal :: String -> E.SomeException -> IO (Either StoreError a) handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr @@ -1518,7 +1499,7 @@ incStat AgentClient {agentStats} n k = do Just v -> modifyTVar' v (+ n) _ -> newTVar n >>= \v -> TM.insert k v agentStats -incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () +incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO () incClientStat c userId pc = incClientStatN c userId pc 1 incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO () @@ -1528,7 +1509,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do where statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res} -incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () +incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO () incClientStatN c userId pc n cmd res = do atomically $ incStat c n statsKey where diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 71e710473..20a378a45 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -56,16 +56,16 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange) +import Simplex.Messaging.Crypto.Ratchet (PQSupport, VersionRangeE2E, supportedE2EEncryptVRange) import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig) +import Simplex.Messaging.Notifications.Transport (NTFVersion) import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) +import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (TLS, Transport (..)) +import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..)) import Simplex.Messaging.Transport.Client (defaultSMPPort) import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors) -import Simplex.Messaging.Version import System.Random (StdGen, newStdGen) import UnliftIO (Async, SomeException) import UnliftIO.STM @@ -87,12 +87,13 @@ data AgentConfig = AgentConfig sndAuthAlg :: C.AuthAlg, connIdBytes :: Int, tbqSize :: Natural, - smpCfg :: ProtocolClientConfig, - ntfCfg :: ProtocolClientConfig, + smpCfg :: ProtocolClientConfig SMPVersion, + ntfCfg :: ProtocolClientConfig NTFVersion, xftpCfg :: XFTPClientConfig, reconnectInterval :: RetryInterval, messageRetryInterval :: RetryInterval2, messageTimeout :: NominalDiffTime, + connDeleteDeliveryTimeout :: NominalDiffTime, helloTimeout :: NominalDiffTime, quotaExceededTimeout :: NominalDiffTime, initialCleanupDelay :: Int64, @@ -115,9 +116,9 @@ data AgentConfig = AgentConfig caCertificateFile :: FilePath, privateKeyFile :: FilePath, certificateFile :: FilePath, - e2eEncryptVRange :: VersionRange, - smpAgentVRange :: VersionRange, - smpClientVRange :: VersionRange + e2eEncryptVRange :: PQSupport -> VersionRangeE2E, + smpAgentVRange :: PQSupport -> VersionRangeSMPA, + smpClientVRange :: VersionRangeSMPC } defaultReconnectInterval :: RetryInterval @@ -161,6 +162,7 @@ defaultAgentConfig = reconnectInterval = defaultReconnectInterval, messageRetryInterval = defaultMessageRetryInterval, messageTimeout = 2 * nominalDay, + connDeleteDeliveryTimeout = 2 * nominalDay, helloTimeout = 2 * nominalDay, quotaExceededTimeout = 7 * nominalDay, initialCleanupDelay = 30 * 1000000, -- 30 seconds diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 6129b8503..2c06e0279 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -4,6 +4,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} @@ -33,8 +34,14 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent.Protocol ( -- * Protocol parameters + VersionSMPA, + VersionRangeSMPA, + pattern VersionSMPA, + duplexHandshakeSMPAgentVersion, ratchetSyncSMPAgentVersion, deliveryRcptsSMPAgentVersion, + pqdrSMPAgentVersion, + currentSMPAgentVersion, supportedSMPAgentVRange, e2eEncConnInfoLength, e2eEncUserMsgLength, @@ -175,14 +182,25 @@ import Data.Time.Clock.System (SystemTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField import Database.SQLite.Simple.ToField import Simplex.FileTransfer.Description -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.Messaging.Agent.QueryString import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri) +import Simplex.Messaging.Crypto.Ratchet + ( InitialKeys (..), + PQEncryption (..), + pattern PQEncOff, + PQSupport, + pattern PQSupportOn, + pattern PQSupportOff, + RcvE2ERatchetParams, + RcvE2ERatchetParamsUri, + SndE2ERatchetParams + ) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers @@ -200,6 +218,10 @@ import Simplex.Messaging.Protocol SMPServerWithAuth, SndPublicAuthKey, SubscriptionMode, + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + initialSMPClientVersion, legacyEncodeServer, legacyServerP, legacyStrEncodeServer, @@ -215,6 +237,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..)) import Simplex.Messaging.Util import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.RemoteControl.Types import Text.Read import UnliftIO.Exception (Exception) @@ -224,30 +247,56 @@ import UnliftIO.Exception (Exception) -- 2 - "duplex" (more efficient) connection handshake (6/9/2022) -- 3 - support ratchet renegotiation (6/30/2023) -- 4 - delivery receipts (7/13/2023) +-- 5 - post-quantum double ratchet (3/14/2024) -duplexHandshakeSMPAgentVersion :: Version -duplexHandshakeSMPAgentVersion = 2 +data SMPAgentVersion -ratchetSyncSMPAgentVersion :: Version -ratchetSyncSMPAgentVersion = 3 +instance VersionScope SMPAgentVersion -deliveryRcptsSMPAgentVersion :: Version -deliveryRcptsSMPAgentVersion = 4 +type VersionSMPA = Version SMPAgentVersion -currentSMPAgentVersion :: Version -currentSMPAgentVersion = 4 +type VersionRangeSMPA = VersionRange SMPAgentVersion -supportedSMPAgentVRange :: VersionRange -supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion +pattern VersionSMPA :: Word16 -> VersionSMPA +pattern VersionSMPA v = Version v + +duplexHandshakeSMPAgentVersion :: VersionSMPA +duplexHandshakeSMPAgentVersion = VersionSMPA 2 + +ratchetSyncSMPAgentVersion :: VersionSMPA +ratchetSyncSMPAgentVersion = VersionSMPA 3 + +deliveryRcptsSMPAgentVersion :: VersionSMPA +deliveryRcptsSMPAgentVersion = VersionSMPA 4 + +pqdrSMPAgentVersion :: VersionSMPA +pqdrSMPAgentVersion = VersionSMPA 5 + +-- TODO v5.7 increase to 5 +currentSMPAgentVersion :: VersionSMPA +currentSMPAgentVersion = VersionSMPA 4 + +-- TODO v5.7 remove dependency of version range on whether PQ support is needed +supportedSMPAgentVRange :: PQSupport -> VersionRangeSMPA +supportedSMPAgentVRange pq = + mkVersionRange duplexHandshakeSMPAgentVersion $ case pq of + PQSupportOn -> pqdrSMPAgentVersion + PQSupportOff -> currentSMPAgentVersion -- it is shorter to allow all handshake headers, -- including E2E (double-ratchet) parameters and -- signing key of the sender for the server -e2eEncConnInfoLength :: Int -e2eEncConnInfoLength = 14848 +e2eEncConnInfoLength :: VersionSMPA -> PQSupport -> Int +e2eEncConnInfoLength v = \case + -- reduced by 3726 (roughly the increase of message ratchet header size + key and ciphertext in reply link) + PQSupportOn | v >= pqdrSMPAgentVersion -> 11122 + _ -> 14848 -e2eEncUserMsgLength :: Int -e2eEncUserMsgLength = 15856 +e2eEncUserMsgLength :: VersionSMPA -> PQSupport -> Int +e2eEncUserMsgLength v = \case + -- reduced by 2222 (the increase of message ratchet header size) + PQSupportOn | v >= pqdrSMPAgentVersion -> 13634 + _ -> 15856 -- | Raw (unparsed) SMP agent protocol transmission. type ARawTransmission = (ByteString, ByteString, ByteString) @@ -273,8 +322,6 @@ data SAParty :: AParty -> Type where deriving instance Show (SAParty p) -deriving instance Eq (SAParty p) - instance TestEquality SAParty where testEquality SAgent SAgent = Just Refl testEquality SClient SClient = Just Refl @@ -297,8 +344,6 @@ data SAEntity :: AEntity -> Type where deriving instance Show (SAEntity e) -deriving instance Eq (SAEntity e) - instance TestEquality SAEntity where testEquality SAEConn SAEConn = Just Refl testEquality SAERcvFile SAERcvFile = Just Refl @@ -333,16 +378,16 @@ type ConnInfo = ByteString -- | Parameterized type for SMP agent protocol commands and responses from all participants. data ACommand (p :: AParty) (e :: AEntity) where - NEW :: Bool -> AConnectionMode -> SubscriptionMode -> ACommand Client AEConn -- response INV + NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV INV :: AConnectionRequestUri -> ACommand Agent AEConn - JOIN :: Bool -> AConnectionRequestUri -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK - CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake + JOIN :: Bool -> AConnectionRequestUri -> PQSupport -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK + CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client - REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender - ACPT :: InvitationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client + REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender + ACPT :: InvitationId -> PQSupport -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client RJCT :: InvitationId -> ACommand Client AEConn - INFO :: ConnInfo -> ACommand Agent AEConn - CON :: ACommand Agent AEConn -- notification that connection is established + INFO :: PQSupport -> ConnInfo -> ACommand Agent AEConn + CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established SUB :: ACommand Client AEConn END :: ACommand Agent AEConn CONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone @@ -351,8 +396,8 @@ data ACommand (p :: AParty) (e :: AEntity) where UP :: SMPServer -> [ConnId] -> ACommand Agent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> ACommand Agent AEConn - SEND :: MsgFlags -> MsgBody -> ACommand Client AEConn - MID :: AgentMsgId -> ACommand Agent AEConn + SEND :: PQEncryption -> MsgFlags -> MsgBody -> ACommand Client AEConn + MID :: AgentMsgId -> PQEncryption -> ACommand Agent AEConn SENT :: AgentMsgId -> ACommand Agent AEConn MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn MERRS :: NonEmpty AgentMsgId -> AgentErrorType -> ACommand Agent AEConn @@ -458,8 +503,8 @@ aCommandTag = \case REQ {} -> REQ_ ACPT {} -> ACPT_ RJCT _ -> RJCT_ - INFO _ -> INFO_ - CON -> CON_ + INFO {} -> INFO_ + CON _ -> CON_ SUB -> SUB_ END -> END_ CONNECT {} -> CONNECT_ @@ -469,7 +514,7 @@ aCommandTag = \case SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ SEND {} -> SEND_ - MID _ -> MID_ + MID {} -> MID_ SENT _ -> SENT_ MERR {} -> MERR_ MERRS {} -> MERRS_ @@ -665,7 +710,7 @@ instance StrEncoding SndQueueInfo where pure SndQueueInfo {sndServer, sndSwitchStatus} data ConnectionStats = ConnectionStats - { connAgentVersion :: Version, + { connAgentVersion :: VersionSMPA, rcvQueuesInfo :: [RcvQueueInfo], sndQueuesInfo :: [SndQueueInfo], ratchetSyncState :: RatchetSyncState, @@ -769,17 +814,19 @@ data MsgMeta = MsgMeta { integrity :: MsgIntegrity, recipient :: (AgentMsgId, UTCTime), broker :: (MsgId, UTCTime), - sndMsgId :: AgentMsgId + sndMsgId :: AgentMsgId, + pqEncryption :: PQEncryption } deriving (Eq, Show) instance StrEncoding MsgMeta where - strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} = + strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId, pqEncryption} = B.unwords [ strEncode integrity, "R=" <> bshow rmId <> "," <> showTs rTs, "B=" <> encode bmId <> "," <> showTs bTs, - "S=" <> bshow sndMsgId + "S=" <> bshow sndMsgId, + "PQ=" <> strEncode pqEncryption ] where showTs = B.pack . formatISO8601Millis @@ -788,7 +835,8 @@ instance StrEncoding MsgMeta where recipient <- " R=" *> partyMeta A.decimal broker <- " B=" *> partyMeta base64P sndMsgId <- " S=" *> A.decimal - pure MsgMeta {integrity, recipient, broker, sndMsgId} + pqEncryption <- " PQ=" *> strP + pure MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption} where partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P @@ -802,28 +850,28 @@ data SMPConfirmation = SMPConfirmation -- | optional reply queues included in confirmation (added in agent protocol v2) smpReplyQueues :: [SMPQueueInfo], -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } deriving (Show) data AgentMsgEnvelope = AgentConfirmation - { agentVersion :: Version, - e2eEncryption_ :: Maybe (E2ERatchetParams 'C.X448), + { agentVersion :: VersionSMPA, + e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448), encConnInfo :: ByteString } | AgentMsgEnvelope - { agentVersion :: Version, + { agentVersion :: VersionSMPA, encAgentMessage :: ByteString } | AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet, - { agentVersion :: Version, + { agentVersion :: VersionSMPA, connReq :: ConnectionRequestUri 'CMInvitation, connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet, } | AgentRatchetKey - { agentVersion :: Version, - e2eEncryption :: E2ERatchetParams 'C.X448, + { agentVersion :: VersionSMPA, + e2eEncryption :: RcvE2ERatchetParams 'C.X448, info :: ByteString } deriving (Show) @@ -1115,7 +1163,7 @@ instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) whe CRInvitationUri crData e2eParams -> crEncode "invitation" crData (Just e2eParams) CRContactUri crData -> crEncode "contact" crData Nothing where - crEncode :: ByteString -> ConnReqUriData -> Maybe (E2ERatchetParamsUri 'C.X448) -> ByteString + crEncode :: ByteString -> ConnReqUriData -> Maybe (RcvE2ERatchetParamsUri 'C.X448) -> ByteString crEncode crMode ConnReqUriData {crScheme, crAgentVRange, crSmpQueues, crClientData} e2eParams = strEncode crScheme <> "/" <> crMode <> "#/?" <> queryStr where @@ -1228,16 +1276,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool sameQueue addr q = sameQAddress addr (qAddress q) {-# INLINE sameQueue #-} -data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress} +data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) instance Encoding SMPQueueInfo where smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}) - | clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) + | clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey) | otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey) smpP = do clientVersion <- smpP - smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP + smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP (senderId, dhPublicKey) <- smpP pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey} @@ -1245,20 +1293,20 @@ instance Encoding SMPQueueInfo where -- But this is created to allow backward and forward compatibility where SMPQueueUri -- could have more fields to convert to different versions of SMPQueueInfo in a different way, -- and this instance would become non-trivial. -instance VersionI SMPQueueInfo where - type VersionRangeT SMPQueueInfo = SMPQueueUri +instance VersionI SMPClientVersion SMPQueueInfo where + type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri version = clientVersion toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr -instance VersionRangeI SMPQueueUri where - type VersionT SMPQueueUri = SMPQueueInfo +instance VersionRangeI SMPClientVersion SMPQueueUri where + type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo versionRange = clientVRange toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr -- | SMP queue information sent out-of-band. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages -data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress} +data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress} deriving (Eq, Show) data SMPQueueAddress = SMPQueueAddress @@ -1307,7 +1355,7 @@ instance StrEncoding SMPQueueUri where smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv' pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey} where - unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput + unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput versioned = do dhKey_ <- optional strP query <- optional (A.char '/') *> A.char '?' *> strP @@ -1324,8 +1372,8 @@ instance Encoding SMPQueueUri where pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey} data ConnectionRequestUri (m :: ConnectionMode) where - CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation - -- contact connection request does NOT contain E2E encryption parameters - + CRInvitationUri :: ConnReqUriData -> RcvE2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation + -- contact connection request does NOT contain E2E encryption parameters for double ratchet - -- they are passed in AgentInvitation message CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact @@ -1336,15 +1384,15 @@ deriving instance Show (ConnectionRequestUri m) data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m) instance Eq AConnectionRequestUri where - ACR m cr == ACR m' cr' = case testEquality m m' of - Just Refl -> cr == cr' - _ -> False + ACR m cr == ACR m' cr' = case testEquality m m' of + Just Refl -> cr == cr' + _ -> False deriving instance Show AConnectionRequestUri data ConnReqUriData = ConnReqUriData { crScheme :: ServiceScheme, - crAgentVRange :: VersionRange, + crAgentVRange :: VersionRangeSMPA, crSmpQueues :: NonEmpty SMPQueueUri, crClientData :: Maybe CRClientData } @@ -1713,13 +1761,13 @@ commandP binaryP = >>= \case ACmdTag SClient e cmd -> ACmd SClient e <$> case cmd of - NEW_ -> s (NEW <$> strP_ <*> strP_ <*> (strP <|> pure SMP.SMSubscribe)) - JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) + NEW_ -> s (NEW <$> strP_ <*> strP_ <*> pqIKP <*> (strP <|> pure SMP.SMSubscribe)) + JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqSupP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP) LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP) - ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP) + ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> binaryP) RJCT_ -> s (RJCT <$> A.takeByteString) SUB_ -> pure SUB - SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP) + SEND_ -> s (SEND <$> pqEncP <*> smpP <* A.space <*> binaryP) ACK_ -> s (ACK <$> A.decimal <*> optional (A.space *> binaryP)) SWCH_ -> pure SWCH OFF_ -> pure OFF @@ -1728,10 +1776,10 @@ commandP binaryP = ACmdTag SAgent e cmd -> ACmd SAgent e <$> case cmd of INV_ -> s (INV <$> strP) - CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP) - REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP) - INFO_ -> s (INFO <$> binaryP) - CON_ -> pure CON + CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strListP <* A.space <*> binaryP) + REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strP_ <*> binaryP) + INFO_ -> s (INFO <$> pqSupP <*> binaryP) + CON_ -> s (CON <$> strP) END_ -> pure END CONNECT_ -> s (CONNECT <$> strP_ <*> strP) DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP) @@ -1739,7 +1787,7 @@ commandP binaryP = UP_ -> s (UP <$> strP_ <*> connections) SWITCH_ -> s (SWITCH <$> strP_ <*> strP_ <*> strP) RSYNC_ -> s (RSYNC <$> strP_ <*> strP <*> strP) - MID_ -> s (MID <$> A.decimal) + MID_ -> s (MID <$> A.decimal <*> _strP) SENT_ -> s (SENT <$> A.decimal) MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP) MERRS_ -> s (MERRS <$> strP_ <*> strP) @@ -1762,6 +1810,12 @@ commandP binaryP = where s :: Parser a -> Parser a s p = A.space *> p + pqIKP :: Parser InitialKeys + pqIKP = strP_ <|> pure (IKNoPQ PQSupportOff) + pqSupP :: Parser PQSupport + pqSupP = strP_ <|> pure PQSupportOff + pqEncP :: Parser PQEncryption + pqEncP = strP_ <|> pure PQEncOff connections :: Parser [ConnId] connections = strP `A.sepBy'` A.char ',' sfDone :: Text -> Either String (ACommand 'Agent 'AESndFile) @@ -1777,15 +1831,15 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX -- | Serialize SMP agent command. serializeCommand :: ACommand p e -> ByteString serializeCommand = \case - NEW ntfs cMode subMode -> s (NEW_, ntfs, cMode, subMode) + NEW ntfs cMode pqIK subMode -> s (NEW_, ntfs, cMode, pqIK, subMode) INV cReq -> s (INV_, cReq) - JOIN ntfs cReq subMode cInfo -> s (JOIN_, ntfs, cReq, subMode, Str $ serializeBinary cInfo) - CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo] + JOIN ntfs cReq pqSup subMode cInfo -> s (JOIN_, ntfs, cReq, pqSup, subMode, Str $ serializeBinary cInfo) + CONF confId pqSup srvs cInfo -> B.unwords [s CONF_, confId, s pqSup, strEncodeList srvs, serializeBinary cInfo] LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo] - REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo] - ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo] + REQ invId pqSup srvs cInfo -> B.unwords [s REQ_, invId, s pqSup, s srvs, serializeBinary cInfo] + ACPT invId pqSup cInfo -> B.unwords [s ACPT_, invId, s pqSup, serializeBinary cInfo] RJCT invId -> B.unwords [s RJCT_, invId] - INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo] + INFO pqSup cInfo -> B.unwords [s INFO_, s pqSup, serializeBinary cInfo] SUB -> s SUB_ END -> s END_ CONNECT p h -> s (CONNECT_, p, h) @@ -1794,8 +1848,8 @@ serializeCommand = \case UP srv conns -> B.unwords [s UP_, s srv, connections conns] SWITCH dir phase srvs -> s (SWITCH_, dir, phase, srvs) RSYNC rrState cryptoErr cstats -> s (RSYNC_, rrState, cryptoErr, cstats) - SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody] - MID mId -> s (MID_, mId) + SEND pqEnc msgFlags msgBody -> B.unwords [s SEND_, s pqEnc, smpEncode msgFlags, serializeBinary msgBody] + MID mId pqEnc -> s (MID_, mId, pqEnc) SENT mId -> s (SENT_, mId) MERR mId e -> s (MERR_, mId, e) MERRS mIds e -> s (MERRS_, mIds, e) @@ -1811,7 +1865,7 @@ serializeCommand = \case DEL_USER userId -> s (DEL_USER_, userId) CHK -> s CHK_ STAT srvs -> s (STAT_, srvs) - CON -> s CON_ + CON pqEnc -> s (CON_, pqEnc) ERR e -> s (ERR_, e) OK -> s OK_ SUSPENDED -> s SUSPENDED_ @@ -1884,14 +1938,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) cmdWithMsgBody (APC e cmd) = APC e <$$> case cmd of - SEND msgFlags body -> SEND msgFlags <$$> getBody body + SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body - JOIN ntfs qUri subMode cInfo -> JOIN ntfs qUri subMode <$$> getBody cInfo - CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo + JOIN ntfs qUri pqSup subMode cInfo -> JOIN ntfs qUri pqSup subMode <$$> getBody cInfo + CONF confId pqSup srvs cInfo -> CONF confId pqSup srvs <$$> getBody cInfo LET confId cInfo -> LET confId <$$> getBody cInfo - REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo - ACPT invId cInfo -> ACPT invId <$$> getBody cInfo - INFO cInfo -> INFO <$$> getBody cInfo + REQ invId pqSup srvs cInfo -> REQ invId pqSup srvs <$$> getBody cInfo + ACPT invId pqSup cInfo -> ACPT invId pqSup <$$> getBody cInfo + INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo _ -> pure $ Right cmd getBody :: ByteString -> m (Either AgentErrorType ByteString) diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 8f67c74c2..ce76d5c89 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -30,7 +30,7 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.Ratchet (RatchetX448) +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption, PQSupport) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol ( MsgBody, @@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol RcvPrivateAuthKey, SndPrivateAuthKey, SndPublicAuthKey, + VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util ((<$?>)) -import Simplex.Messaging.Version -- * Queue types @@ -61,8 +61,6 @@ data DBQueueId (q :: QueueStored) where DBQueueId :: Int64 -> DBQueueId 'QSStored DBNewQueue :: DBQueueId 'QSNew -deriving instance Eq (DBQueueId q) - deriving instance Show (DBQueueId q) type RcvQueue = StoredRcvQueue 'QSStored @@ -96,12 +94,12 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue dbReplaceQueueId :: Maybe Int64, rcvSwchStatus :: Maybe RcvSwitchStatus, -- | SMP client version - smpClientVersion :: Version, + smpClientVersion :: VersionSMPC, -- | credentials used in context of notifications clientNtfCreds :: Maybe ClientNtfCreds, deleteErrors :: Int } - deriving (Eq, Show) + deriving (Show) rcvQueueInfo :: RcvQueue -> RcvQueueInfo rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} = @@ -128,7 +126,7 @@ data ClientNtfCreds = ClientNtfCreds -- | shared DH secret used to encrypt/decrypt notification metadata (NMsgMeta) from server to recipient rcvNtfDhSecret :: RcvNtfDhSecret } - deriving (Eq, Show) + deriving (Show) type SndQueue = StoredSndQueue 'QSStored @@ -159,9 +157,9 @@ data StoredSndQueue (q :: QueueStored) = SndQueue dbReplaceQueueId :: Maybe Int64, sndSwchStatus :: Maybe SndSwitchStatus, -- | SMP client version - smpClientVersion :: Version + smpClientVersion :: VersionSMPC } - deriving (Eq, Show) + deriving (Show) sndQueueInfo :: SndQueue -> SndQueueInfo sndQueueInfo SndQueue {server, sndSwchStatus} = @@ -256,8 +254,6 @@ data Connection (d :: ConnType) where DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex ContactConnection :: ConnData -> RcvQueue -> Connection CContact -deriving instance Eq (Connection d) - deriving instance Show (Connection d) toConnData :: Connection d -> ConnData @@ -290,8 +286,6 @@ connType SCSnd = CSnd connType SCDuplex = CDuplex connType SCContact = CContact -deriving instance Eq (SConnType d) - deriving instance Show (SConnType d) instance TestEquality SConnType where @@ -305,21 +299,17 @@ instance TestEquality SConnType where -- Used to refer to an arbitrary connection when retrieving from store. data SomeConn = forall d. SomeConn (SConnType d) (Connection d) -instance Eq SomeConn where - SomeConn d c == SomeConn d' c' = case testEquality d d' of - Just Refl -> c == c' - _ -> False - deriving instance Show SomeConn data ConnData = ConnData { connId :: ConnId, userId :: UserId, - connAgentVersion :: Version, + connAgentVersion :: VersionSMPA, enableNtfs :: Bool, lastExternalSndId :: PrevExternalSndId, deleted :: Bool, - ratchetSyncState :: RatchetSyncState + ratchetSyncState :: RatchetSyncState, + pqSupport :: PQSupport } deriving (Eq, Show) @@ -534,6 +524,7 @@ data SndMsgData = SndMsgData msgType :: AgentMessageType, msgFlags :: MsgFlags, msgBody :: MsgBody, + pqEncryption :: PQEncryption, internalHash :: MsgHash, prevMsgHash :: MsgHash } @@ -551,6 +542,7 @@ data PendingMsgData = PendingMsgData msgType :: AgentMessageType, msgFlags :: MsgFlags, msgBody :: MsgBody, + pqEncryption :: PQEncryption, msgRetryState :: Maybe RI2State, internalTs :: InternalTs } diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index c82e91d9f..2f6707c5a 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -50,7 +50,6 @@ module Simplex.Messaging.Agent.Store.SQLite createNewConn, updateNewConnRcv, updateNewConnSnd, - createRcvConn, -- no longer used createSndConn, getConn, getDeletedConn, @@ -59,7 +58,9 @@ module Simplex.Messaging.Agent.Store.SQLite getConnData, setConnDeleted, setConnAgentVersion, + setConnPQSupport, getDeletedConnIds, + getDeletedWaitingDeliveryConnIds, setConnRatchetSync, addProcessedRatchetKeyHash, checkRatchetKeyHashExists, @@ -93,7 +94,6 @@ module Simplex.Messaging.Agent.Store.SQLite getAcceptedConfirmation, removeConfirmations, -- Invitations - sent via Contact connections - setConnectionVersion, createInvitation, getInvitation, acceptInvitation, @@ -241,7 +241,7 @@ import Data.List (foldl', intercalate, sortBy) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe, isJust, listToMaybe, catMaybes) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe) import Data.Ord (Down (..)) import Data.Text (Text) import qualified Data.Text as T @@ -268,7 +268,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRE import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) -import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) +import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..)) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) @@ -278,7 +279,7 @@ import Simplex.Messaging.Protocol import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>)) -import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (takeDirectory) @@ -542,11 +543,8 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of ConnData {connId} -> Right . (connId,) <$> create connId createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId) -createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} cMode = do - fst <$$> createConn_ gVar cData create - where - create connId = - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True) +createNewConn db gVar cData cMode = do + fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) updateNewConnRcv db connId rq = @@ -568,22 +566,25 @@ updateNewConnSnd db connId sq = updateConn :: IO (Either StoreError SndQueue) updateConn = Right <$> addConnSndQueue_ db connId sq -createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) -createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@RcvQueue {server} cMode = - createConn_ gVar cData $ \connId -> do - serverKeyHash_ <- createServer_ db server - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True) - insertRcvQueue_ db connId q serverKeyHash_ - createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue)) -createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@SndQueue {server} = +createSndConn db gVar cData q@SndQueue {server} = -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_ ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $ createConn_ gVar cData $ \connId -> do serverKeyHash_ <- createServer_ db server - DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, True) + createConnRecord db connId cData SCMInvitation insertSndQueue_ db connId q serverKeyHash_ +createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () +createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode = + DB.execute + db + [sql| + INSERT INTO connections + (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?) + |] + (userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True) + checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do fromMaybe False @@ -602,12 +603,32 @@ getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId) (rq,) <$> ExceptT (getConn db connId) -deleteConn :: DB.Connection -> ConnId -> IO () -deleteConn db connId = - DB.executeNamed - db - "DELETE FROM connections WHERE conn_id = :conn_id;" - [":conn_id" := connId] +-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted +deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId) +deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of + Nothing -> delete + Just timeout -> + ifM + checkNoPendingDeliveries_ + delete + ( ifM + (checkWaitDeliveryTimeout_ timeout) + delete + (pure Nothing) + ) + where + delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId + checkNoPendingDeliveries_ = do + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId) + pure $ isNothing r + checkWaitDeliveryTimeout_ timeout = do + cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs) + pure $ isJust r upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) upgradeRcvConnToDuplex db connId sq = @@ -681,7 +702,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d |] (host, port, rcvId) -setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO () +setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = DB.executeNamed db @@ -782,7 +803,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds = Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) Nothing -> (Nothing, Nothing, Nothing, Nothing) -type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version) +type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC) smpConfirmation :: SMPConfirmationRow -> SMPConfirmation smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) = @@ -791,7 +812,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi e2ePubKey, connInfo, smpReplyQueues = fromMaybe [] smpReplyQueues_, - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ } createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId) @@ -868,10 +889,6 @@ removeConfirmations db connId = |] [":conn_id" := connId] -setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO () -setConnectionVersion db connId aVersion = - DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) - createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId) createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} = createWithRandomId gVar $ \invitationId -> @@ -1008,18 +1025,18 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} = DB.query db [sql| - SELECT m.msg_type, m.msg_flags, m.msg_body, m.internal_ts, s.retry_int_slow, s.retry_int_fast + SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast FROM messages m JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id WHERE m.conn_id = ? AND m.internal_id = ? |] (connId, msgId) err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []" - pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData - pendingMsgData (msgType, msgFlags_, msgBody, internalTs, riSlow_, riFast_) = + pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData + pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) = let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ msgRetryState = RI2State <$> riSlow_ <*> riFast_ - in PendingMsgData {msgId, msgType, msgFlags, msgBody, msgRetryState, internalTs} + in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs} markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId) getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a)) @@ -1088,7 +1105,7 @@ getRcvMsg db connId agentMsgId = [sql| SELECT r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, - m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack FROM rcv_messages r JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id @@ -1104,7 +1121,7 @@ getLastMsg db connId msgId = [sql| SELECT r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, - m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack FROM rcv_messages r JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id @@ -1113,9 +1130,9 @@ getLastMsg db connId msgId = |] (connId, msgId) -toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs, AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg -toRcvMsg (agentMsgId, internalTs, brokerId, brokerTs, sndMsgId, integrity, internalHash, msgType, msgBody, rcptInternalId_, rcptStatus_, userAck) = - let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity} +toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg +toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) = + let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} @@ -1175,34 +1192,34 @@ deleteSndMsgsExpired db ttl = do "DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL" (Only cutoffTs) -createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO () -createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 = - DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2) +createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = + DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem) -getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448)) +getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams)) getRatchetX3dhKeys db connId = - fmap hasKeys $ - firstRow id SEX3dhKeysNotFound $ - DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2 FROM ratchets WHERE conn_id = ?" (Only connId) + firstRow' keys SEX3dhKeysNotFound $ + DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId) where - hasKeys = \case - Right (Just k1, Just k2) -> Right (k1, k2) + keys = \case + (Just k1, Just k2, pKem) -> Right (k1, k2, pKem) _ -> Left SEX3dhKeysNotFound -- used to remember new keys when starting ratchet re-synchronization -- TODO remove the columns for public keys in v5.7. -- Currently, the keys are not used but still stored to support app downgrade to the previous version. -setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO () -setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 = +setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = DB.execute db [sql| UPDATE ratchets - SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ? + SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ? WHERE conn_id = ? |] - (x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, connId) + (x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId) +-- TODO remove the columns for public keys in v5.7. createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO () createRatchet db connId rc = DB.executeNamed @@ -1213,7 +1230,10 @@ createRatchet db connId rc = ON CONFLICT (conn_id) DO UPDATE SET ratchet_state = :ratchet_state, x3dh_priv_key_1 = NULL, - x3dh_priv_key_2 = NULL + x3dh_priv_key_2 = NULL, + x3dh_pub_key_1 = NULL, + x3dh_pub_key_2 = NULL, + pq_priv_kem = NULL |] [":conn_id" := connId, ":ratchet_state" := rc] @@ -1752,6 +1772,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 +instance ToField (Version v) where toField (Version v) = toField v + +instance FromField (Version v) where fromField f = Version <$> fromField f + listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x listToEither e _ = Left e @@ -1903,25 +1927,38 @@ getConnData db connId' = [sql| SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, - last_external_snd_msg_id, deleted, ratchet_sync_state + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support FROM connections WHERE conn_id = ? |] (Only connId') where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState}, cMode) + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) -setConnDeleted :: DB.Connection -> ConnId -> IO () -setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) +setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () +setConnDeleted db waitDelivery connId + | waitDelivery = do + currentTs <- getCurrentTime + DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId) + | otherwise = + DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) -setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO () +setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () setConnAgentVersion db connId aVersion = DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) +setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO () +setConnPQSupport db connId pqSupport = + DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId) + getDeletedConnIds :: DB.Connection -> IO [ConnId] getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) +getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId] +getDeletedWaitingDeliveryConnIds db = + map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL" + setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO () setConnRatchetSync db connId ratchetSyncState = DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId) @@ -1970,12 +2007,12 @@ rcvQueueQuery = toRcvQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int) + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> RcvQueue toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = let server = SMPServer host port keyHash - smpClientVersion = fromMaybe 1 smpClientVersion_ + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} _ -> Nothing @@ -2011,7 +2048,7 @@ sndQueueQuery = toSndQueue :: (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId) :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) -> + :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> SndQueue toSndQueue ( (userId, keyHash, connId, host, port, sndId) @@ -2060,23 +2097,15 @@ updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do - let MsgMeta {recipient = (internalId, internalTs)} = msgMeta - DB.executeNamed + let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta + DB.execute dbConn [sql| INSERT INTO messages - ( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body) - VALUES - (:conn_id,:internal_id,:internal_ts,:internal_rcv_id, NULL,:msg_type,:msg_flags,:msg_body); + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) + VALUES (?,?,?,?,?,?,?,?,?); |] - [ ":conn_id" := connId, - ":internal_id" := internalId, - ":internal_ts" := internalTs, - ":internal_rcv_id" := internalRcvId, - ":msg_type" := msgType, - ":msg_flags" := msgFlags, - ":msg_body" := msgBody - ] + (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption) insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do @@ -2157,23 +2186,16 @@ updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = -- * createSndMsg helpers insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () -insertSndMsgBase_ dbConn connId SndMsgData {..} = do - DB.executeNamed - dbConn +insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do + DB.execute + db [sql| INSERT INTO messages - ( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body) + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) VALUES - (:conn_id,:internal_id,:internal_ts, NULL,:internal_snd_id,:msg_type,:msg_flags,:msg_body); + (?,?,?,?,?,?,?,?,?); |] - [ ":conn_id" := connId, - ":internal_id" := internalId, - ":internal_ts" := internalTs, - ":internal_snd_id" := internalSndId, - ":msg_type" := msgType, - ":msg_flags" := msgFlags, - ":msg_body" := msgBody - ] + (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption) insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () insertSndMsgDetails_ dbConn connId SndMsgData {..} = @@ -2267,17 +2289,18 @@ createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redire forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId pure dstEntityId where - dummyDst = FileDescription - { party = SFRecipient, - size, - digest, - redirect = Nothing, - -- updated later with updateRcvFileRedirect - key = C.unsafeSbKey $ B.replicate 32 '#', - nonce = C.cbNonce "", - chunkSize = FileSize 0, - chunks = [] - } + dummyDst = + FileDescription + { party = SFRecipient, + size, + digest, + redirect = Nothing, + -- updated later with updateRcvFileRedirect + key = C.unsafeSbKey $ B.replicate 32 '#', + nonce = C.cbNonce "", + chunkSize = FileSize 0, + chunks = [] + } insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> IO (Either StoreError (RcvFileId, DBRcvFileId)) insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ = runExceptT $ do @@ -2346,10 +2369,11 @@ getRcvFile db rcvFileId = runExceptT $ do toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_ saveFile = CryptoFile savePath cfArgs - redirect = RcvFileRedirect - <$> redirectDbId - <*> redirectEntityId - <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) + redirect = + RcvFileRedirect + <$> redirectDbId + <*> redirectEntityId + <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []} getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk] getChunks rcvFileEntityId userId fileTmpPath = do diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index 9ba6cd08f..18c16cc8b 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -71,7 +71,8 @@ dbBusyLoop action = loop 500 3000000 loop :: Int -> Int -> IO a loop t tLim = action `E.catch` \(e :: SQLError) -> - if tLim > t && SQL.sqlError e == SQL.ErrorBusy + let se = SQL.sqlError e in + if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked) then do threadDelay t loop (t * 9 `div` 8) (tLim - t) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 83c900f72..344a3f9ce 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -69,6 +69,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery +import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport.Client (TransportHost) @@ -106,7 +108,9 @@ schemaMigrations = ("m20231222_command_created_at", m20231222_command_created_at, Just down_m20231222_command_created_at), ("m20231225_failed_work_items", m20231225_failed_work_items, Just down_m20231225_failed_work_items), ("m20240121_message_delivery_indexes", m20240121_message_delivery_indexes, Just down_m20240121_message_delivery_indexes), - ("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect) + ("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect), + ("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery), + ("m20240225_ratchet_kem", m20240225_ratchet_kem, Just down_m20240225_ratchet_kem) ] -- | The list of migrations in ascending order by date diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs new file mode 100644 index 000000000..e61179768 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240223_connections_wait_delivery.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20240223_connections_wait_delivery :: Query +m20240223_connections_wait_delivery = + [sql| +ALTER TABLE connections ADD COLUMN deleted_at_wait_delivery TEXT; +|] + +down_m20240223_connections_wait_delivery :: Query +down_m20240223_connections_wait_delivery = + [sql| +ALTER TABLE connections DROP COLUMN deleted_at_wait_delivery; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs new file mode 100644 index 000000000..1e8a8db4d --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20240225_ratchet_kem.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20240225_ratchet_kem :: Query +m20240225_ratchet_kem = + [sql| +ALTER TABLE ratchets ADD COLUMN pq_priv_kem BLOB; +ALTER TABLE connections ADD COLUMN pq_support INTEGER NOT NULL DEFAULT 0; +ALTER TABLE messages ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0; +|] + +down_m20240225_ratchet_kem :: Query +down_m20240225_ratchet_kem = + [sql| +ALTER TABLE ratchets DROP COLUMN pq_priv_kem; +ALTER TABLE connections DROP COLUMN pq_support; +ALTER TABLE messages DROP COLUMN pq_encryption; +|] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql index b9efaa05c..0818be904 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql @@ -26,7 +26,9 @@ CREATE TABLE connections( deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL), user_id INTEGER CHECK(user_id NOT NULL) REFERENCES users ON DELETE CASCADE, - ratchet_sync_state TEXT NOT NULL DEFAULT 'ok' + ratchet_sync_state TEXT NOT NULL DEFAULT 'ok', + deleted_at_wait_delivery TEXT, + pq_support INTEGER NOT NULL DEFAULT 0 ) WITHOUT ROWID; CREATE TABLE rcv_queues( host TEXT NOT NULL, @@ -89,6 +91,7 @@ CREATE TABLE messages( msg_type BLOB NOT NULL, --(H)ELLO,(R)EPLY,(D)ELETE. Should SMP confirmation be saved too? msg_body BLOB NOT NULL DEFAULT x'', msg_flags TEXT NULL, + pq_encryption INTEGER NOT NULL DEFAULT 0, PRIMARY KEY(conn_id, internal_id), FOREIGN KEY(conn_id, internal_rcv_id) REFERENCES rcv_messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, @@ -159,7 +162,8 @@ CREATE TABLE ratchets( e2e_version INTEGER NOT NULL DEFAULT 1 , x3dh_pub_key_1 BLOB, - x3dh_pub_key_2 BLOB + x3dh_pub_key_2 BLOB, + pq_priv_kem BLOB ) WITHOUT ROWID; CREATE TABLE skipped_messages( skipped_message_id INTEGER PRIMARY KEY, diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 2cbdced35..d8b202761 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -76,7 +76,7 @@ module Simplex.Messaging.Client PCTransmission, mkTransmission, authTransmission, - clientStub, + smpClientStub, ) where @@ -117,14 +117,14 @@ import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. -data ProtocolClient err msg = ProtocolClient +data ProtocolClient v err msg = ProtocolClient { action :: Maybe (Async ()), - thParams :: THandleParams, + thParams :: THandleParams v, sessionTs :: UTCTime, - client_ :: PClient err msg + client_ :: PClient v err msg } -data PClient err msg = PClient +data PClient v err msg = PClient { connected :: TVar Bool, transportSession :: TransportSession msg, transportHost :: TransportHost, @@ -135,11 +135,11 @@ data PClient err msg = PClient sentCommands :: TMap CorrId (Request err msg), sndQ :: TBQueue ByteString, rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)), - msgQ :: Maybe (TBQueue (ServerTransmission msg)) + msgQ :: Maybe (TBQueue (ServerTransmission v msg)) } -clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg) -clientStub g sessionId thVersion thAuth = do +smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg) +smpClientStub g sessionId thVersion thAuth = do connected <- newTVar False clientCorrId <- C.newRandomDRG g sentCommands <- TM.empty @@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do } } -type SMPClient = ProtocolClient ErrorType BrokerMsg +type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg -- | Type for client command data type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg) -- | Type synonym for transmission from some SPM server queue. -type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg) +type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg) data HostMode = -- | prefer (or require) onion hosts when connecting via SOCKS proxy @@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} = TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing} -- | protocol client configuration. -data ProtocolClientConfig = ProtocolClientConfig +data ProtocolClientConfig v = ProtocolClientConfig { -- | size of TBQueue to use for server commands and responses qSize :: Natural, -- | default server port if port is not specified in ProtocolServer @@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig -- | network configuration networkConfig :: NetworkConfig, -- | client-server protocol version range - serverVRange :: VersionRange, + serverVRange :: VersionRange v, -- | delay between sending batches of commands (microseconds) batchDelay :: Maybe Int } -- | Default protocol client configuration. -defaultClientConfig :: VersionRange -> ProtocolClientConfig +defaultClientConfig :: VersionRange v -> ProtocolClientConfig v defaultClientConfig serverVRange = ProtocolClientConfig { qSize = 64, @@ -265,7 +265,7 @@ defaultClientConfig serverVRange = batchDelay = Nothing } -defaultSMPClientConfig :: ProtocolClientConfig +defaultSMPClientConfig :: ProtocolClientConfig SMPVersion defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange data Request err msg = Request @@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts onionHost = find isOnionHost hosts publicHost = find (not . isOnionHost) hosts -protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String +protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_ where snd3 (_, s, _) = s -transportHost' :: ProtocolClient err msg -> TransportHost +transportHost' :: ProtocolClient v err msg -> TransportHost transportHost' = transportHost . client_ -transportSession' :: ProtocolClient err msg -> TransportSession msg +transportSession' :: ProtocolClient v err msg -> TransportSession msg transportSession' = transportSession . client_ type UserId = Int64 @@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId) -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) +getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do case chooseTransportHost networkConfig (host srv) of Right useHost -> @@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig - mkProtocolClient :: TransportHost -> STM (PClient err msg) + mkProtocolClient :: TransportHost -> STM (PClient v err msg) mkProtocolClient transportHost = do connected <- newTVar False pingErrorCount <- newTVar 0 @@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize msgQ } - runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg)) + runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg)) runClient (port', ATransport t) useHost c = do cVar <- newEmptyTMVarIO let tcConfig = transportClientConfig networkConfig @@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize "80" -> ("80", transport @WS) p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO () + client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO () client _ c cVar h = do ks <- atomically $ C.generateKeyPair g - runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case + runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case Left e -> atomically . putTMVar cVar . Left $ PCETransportError e Right th@THandle {params} -> do sessionTs <- getCurrentTime @@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0]) `finally` disconnected c' - send :: Transport c => ProtocolClient err msg -> THandle c -> IO () + send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h - receive :: Transport c => ProtocolClient err msg -> THandle c -> IO () + receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO () receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ - ping :: ProtocolClient err msg -> IO () + ping :: ProtocolClient v err msg -> IO () ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do threadDelay' smpPingInterval - runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case + runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case Left PCEResponseTimeout -> do cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1) when (maxCnt == 0 || cnt < maxCnt) $ ping c @@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize where maxCnt = smpPingCount networkConfig - process :: ProtocolClient err msg -> IO () + process :: ProtocolClient v err msg -> IO () process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c) - processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO () + processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO () processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) = if B.null $ bs corrId then sendMsg respOrErr @@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_ -- | Disconnects client from the server and terminates client threads. -closeProtocolClient :: ProtocolClient err msg -> IO () +closeProtocolClient :: ProtocolClient v err msg -> IO () closeProtocolClient = mapM_ uninterruptibleCancel . action -- | SMP client error type. @@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c) -serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg +serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message = (transportSession, thVersion, sessionId, entityId, message) @@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg) -- | Send multiple commands with batching and collect responses -sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) +sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg)) sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs validate . concat =<< mapM (sendBatch c) bs @@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz where diff = L.length cs - length rs -streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () +streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO () streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs mapM_ (cb <=< sendBatch c) bs -sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg] +sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg] sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do case b of TBError e Request {entityId} -> do @@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do (: []) <$> getResponse c r -- | Send Protocol command -sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg +sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd = ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd) where @@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand | otherwise = tEncode t -- TODO switch to timeout or TimeManager that supports Int64 -getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg) +getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg) getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do response <- timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case @@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ Nothing -> pure $ Left PCEResponseTimeout pure Response {entityId, response} -mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg) +mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg) mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do corrId <- atomically getNextCorrId let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd) diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 068a52782..73f47648b 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId) -- type SMPServerSub = (SMPServer, SMPSub) data SMPClientAgentConfig = SMPClientAgentConfig - { smpCfg :: ProtocolClientConfig, + { smpCfg :: ProtocolClientConfig SMPVersion, reconnectInterval :: RetryInterval, msgQSize :: Natural, agentQSize :: Natural, @@ -91,7 +91,7 @@ defaultSMPClientAgentConfig = data SMPClientAgent = SMPClientAgent { agentCfg :: SMPClientAgentConfig, - msgQ :: TBQueue (ServerTransmission BrokerMsg), + msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg), agentQ :: TBQueue SMPClientAgentEvent, randomDrg :: TVar ChaChaDRG, smpClients :: TMap SMPServer SMPClientVar, diff --git a/src/Simplex/Messaging/Compression.hs b/src/Simplex/Messaging/Compression.hs new file mode 100644 index 000000000..fec9f8151 --- /dev/null +++ b/src/Simplex/Messaging/Compression.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Simplex.Messaging.Compression where + +import qualified Codec.Compression.Zstd.FFI as Z +import Control.Monad (forM) +import Control.Monad.Except +import Control.Monad.IO.Class +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Unsafe as B +import Data.Either (fromRight) +import Data.List.NonEmpty (NonEmpty) +import Foreign +import Foreign.C.Types +import GHC.IO (unsafePerformIO) +import Simplex.Messaging.Encoding +import UnliftIO.Exception (bracket) + +data Compressed + = -- | Short messages are left intact to skip copying and FFI festivities. + Passthrough ByteString + | -- | Generic compression using no extra context. + Compressed Large + +-- | Messages below this length are not encoded to avoid compression overhead. +maxLengthPassthrough :: Int +maxLengthPassthrough = 180 -- Sampled from real client data. Messages with length > 180 rapidly gain compression ratio. + +instance Encoding Compressed where + smpEncode = \case + Passthrough bytes -> "0" <> smpEncode bytes + Compressed bytes -> "1" <> smpEncode bytes + smpP = + smpP >>= \case + '0' -> Passthrough <$> smpP + '1' -> Compressed <$> smpP + x -> fail $ "unknown Compressed tag: " <> show x + +type CompressCtx = (Ptr Z.CCtx, Ptr CChar, CSize) + +withCompressCtx :: CSize -> (CompressCtx -> IO a) -> IO a +withCompressCtx scratchSize action = + bracket Z.createCCtx Z.freeCCtx $ \cctx -> + allocaBytes (fromIntegral scratchSize) $ \scratchPtr -> + action (cctx, scratchPtr, scratchSize) + +-- | Compress bytes, falling back to Passthrough in case of some internal error. +compress :: CompressCtx -> ByteString -> IO Compressed +compress ctx bs = fromRight (Passthrough bs) <$> compress_ ctx bs + +compress_ :: CompressCtx -> ByteString -> IO (Either String Compressed) +compress_ (cctx, scratchPtr, scratchSize) bs + | B.length bs <= maxLengthPassthrough = pure . Right $ Passthrough bs + | otherwise = + B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> runExceptT $ do + -- should not fail, unless input buffer is too short + dstSize <- ExceptT $ Z.checkError $ Z.compressCCtx cctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) 3 + liftIO $ Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize) + +type DecompressCtx = (Ptr Z.DCtx, Ptr CChar, CSize) + +withDecompressCtx :: Int -> (DecompressCtx -> IO a) -> IO a +withDecompressCtx maxUnpackedSize action = + bracket Z.createDCtx Z.freeDCtx $ \dctx -> + allocaBytes maxUnpackedSize $ \scratchPtr -> + action (dctx, scratchPtr, fromIntegral maxUnpackedSize) + +decompress :: DecompressCtx -> Compressed -> IO (Either String ByteString) +decompress (dctx, scratchPtr, scratchSize) = \case + Passthrough bs -> pure $ Right bs + Compressed (Large bs) -> + B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do + res <- Z.checkError $ Z.decompressDCtx dctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) + forM res $ \dstSize -> B.packCStringLen (scratchPtr, fromIntegral dstSize) + +decompressBatch :: Int -> NonEmpty Compressed -> NonEmpty (Either String ByteString) +decompressBatch maxUnpackedSize items = unsafePerformIO $ withDecompressCtx maxUnpackedSize $ forM items . decompress +{-# NOINLINE decompressBatch #-} -- prevent double-evaluation under unsafePerformIO diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 9a775faa3..28183a1fc 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -101,6 +101,7 @@ module Simplex.Messaging.Crypto verify, verify', validSignatureSize, + checkAlgorithm, -- * crypto_box authenticator, as discussed in https://groups.google.com/g/sci.crypt/c/73yb5a9pz2Y/m/LNgRO7IYXOwJ CbAuthenticator (..), @@ -243,8 +244,6 @@ data SAlgorithm :: Algorithm -> Type where SX25519 :: SAlgorithm X25519 SX448 :: SAlgorithm X448 -deriving instance Eq (SAlgorithm a) - deriving instance Show (SAlgorithm a) data Alg = forall a. AlgorithmI a => Alg (SAlgorithm a) @@ -297,11 +296,6 @@ data APublicKey AlgorithmI a => APublicKey (SAlgorithm a) (PublicKey a) -instance Eq APublicKey where - APublicKey a k == APublicKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - instance Encoding APublicKey where smpEncode = smpEncode . encodePubKey {-# INLINE smpEncode #-} @@ -342,11 +336,6 @@ data APrivateKey AlgorithmI a => APrivateKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateKey where - APrivateKey a k == APrivateKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateKey type PrivateKeyEd25519 = PrivateKey Ed25519 @@ -372,11 +361,6 @@ data APrivateSignKey (AlgorithmI a, SignatureAlgorithm a) => APrivateSignKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateSignKey where - APrivateSignKey a k == APrivateSignKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateSignKey instance Encoding APrivateSignKey where @@ -396,11 +380,6 @@ data APublicVerifyKey (AlgorithmI a, SignatureAlgorithm a) => APublicVerifyKey (SAlgorithm a) (PublicKey a) -instance Eq APublicVerifyKey where - APublicVerifyKey a k == APublicVerifyKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APublicVerifyKey data APrivateDhKey @@ -408,11 +387,6 @@ data APrivateDhKey (AlgorithmI a, DhAlgorithm a) => APrivateDhKey (SAlgorithm a) (PrivateKey a) -instance Eq APrivateDhKey where - APrivateDhKey a k == APrivateDhKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APrivateDhKey data APublicDhKey @@ -420,11 +394,6 @@ data APublicDhKey (AlgorithmI a, DhAlgorithm a) => APublicDhKey (SAlgorithm a) (PublicKey a) -instance Eq APublicDhKey where - APublicDhKey a k == APublicDhKey a' k' = case testEquality a a' of - Just Refl -> k == k' - Nothing -> False - deriving instance Show APublicDhKey data DhSecret (a :: Algorithm) where @@ -787,8 +756,6 @@ data Signature (a :: Algorithm) where SignatureEd25519 :: Ed25519.Signature -> Signature Ed25519 SignatureEd448 :: Ed448.Signature -> Signature Ed448 -deriving instance Eq (Signature a) - deriving instance Show (Signature a) data ASignature @@ -796,11 +763,6 @@ data ASignature (AlgorithmI a, SignatureAlgorithm a) => ASignature (SAlgorithm a) (Signature a) -instance Eq ASignature where - ASignature a s == ASignature a' s' = case testEquality a a' of - Just Refl -> s == s' - _ -> False - deriving instance Show ASignature class CryptoSignature s where @@ -885,6 +847,8 @@ data CryptoError CryptoHeaderError String | -- | no sending chain key in ratchet state CERatchetState + | -- | no decapsulation key in ratchet state + CERatchetKEMState | -- | header decryption error (could indicate that another key should be tried) CERatchetHeader | -- | too many skipped messages diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 0afa06db3..068f62776 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -3,18 +3,89 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -module Simplex.Messaging.Crypto.Ratchet where +module Simplex.Messaging.Crypto.Ratchet + ( Ratchet (..), + RatchetX448, + SkippedMsgDiff (..), + SkippedMsgKeys, + InitialKeys (..), + pattern IKPQOn, + pattern IKPQOff, + PQEncryption (..), + pattern PQEncOn, + pattern PQEncOff, + PQSupport (..), + pattern PQSupportOn, + pattern PQSupportOff, + AUseKEM (..), + RatchetKEMState (..), + SRatchetKEMState (..), + RcvPrivRKEMParams, + APrivRKEMParams (..), + RcvE2ERatchetParamsUri, + RcvE2ERatchetParams, + SndE2ERatchetParams, + AE2ERatchetParams (..), + E2ERatchetParamsUri (..), + E2ERatchetParams (..), + VersionE2E, + VersionRangeE2E, + pattern VersionE2E, + RatchetVersions (..), + kdfX3DHE2EEncryptVersion, + pqRatchetE2EEncryptVersion, + currentE2EEncryptVersion, + supportedE2EEncryptVRange, + generateRcvE2EParams, + generateSndE2EParams, + initialPQEncryption, + connPQEncryption, + replyKEM_, + pqSupportToEnc, + pqEncToSupport, + pqSupportAnd, + pqEnableSupport, + pqX3dhSnd, + pqX3dhRcv, + initSndRatchet, + initRcvRatchet, + rcEncrypt, + rcDecrypt, + -- used in tests + MsgHeader (..), + RatchetInitParams (..), + UseKEM (..), + RKEMParams (..), + ARKEMParams (..), + SndRatchet (..), + RcvRatchet (..), + RatchetKEM (..), + RatchetKEMAccepted (..), + RatchetKey (..), + fullHeaderLen, + applySMDiff, + encodeMsgHeader, + msgHeaderP, + ) +where +import Control.Applicative ((<|>)) import Control.Monad.Except +import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import Crypto.Hash (SHA512) @@ -23,100 +94,371 @@ import Crypto.Random (ChaChaDRG) import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson as J import qualified Data.Aeson.TH as JQ +import Data.Attoparsec.ByteString (Parser, peekWord8') +import qualified Data.Attoparsec.ByteString.Char8 as A +import qualified Data.ByteArray as BA import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as LB +import Data.Composition ((.:), (.:.)) +import Data.Functor (($>)) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, isJust) +import Data.Type.Equality import Data.Typeable (Typeable) -import Data.Word (Word32) +import Data.Word (Word16, Word32) import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString import Simplex.Messaging.Crypto +import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, defaultJSON, parseE, parseE') +import Simplex.Messaging.Util ((<$?>), ($>>=)) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.STM -- e2e encryption headers version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - use KDF in x3dh (10/20/2022) -kdfX3DHE2EEncryptVersion :: Version -kdfX3DHE2EEncryptVersion = 2 +data E2EVersion -currentE2EEncryptVersion :: Version -currentE2EEncryptVersion = 2 +instance VersionScope E2EVersion -supportedE2EEncryptVRange :: VersionRange -supportedE2EEncryptVRange = mkVersionRange kdfX3DHE2EEncryptVersion currentE2EEncryptVersion +type VersionE2E = Version E2EVersion -data E2ERatchetParams (a :: Algorithm) - = E2ERatchetParams Version (PublicKey a) (PublicKey a) +type VersionRangeE2E = VersionRange E2EVersion + +pattern VersionE2E :: Word16 -> VersionE2E +pattern VersionE2E v = Version v + +kdfX3DHE2EEncryptVersion :: VersionE2E +kdfX3DHE2EEncryptVersion = VersionE2E 2 + +pqRatchetE2EEncryptVersion :: VersionE2E +pqRatchetE2EEncryptVersion = VersionE2E 3 + +-- TODO v5.7 increase to 3 +currentE2EEncryptVersion :: VersionE2E +currentE2EEncryptVersion = VersionE2E 2 + +-- TODO v5.7 remove dependency of version range on whether PQ encryption is used +supportedE2EEncryptVRange :: PQSupport -> VersionRangeE2E +supportedE2EEncryptVRange pq = + mkVersionRange kdfX3DHE2EEncryptVersion $ case pq of + PQSupportOn -> pqRatchetE2EEncryptVersion + PQSupportOff -> currentE2EEncryptVersion + +data RatchetKEMState + = RKSProposed -- only KEM encapsulation key + | RKSAccepted -- KEM ciphertext and the next encapsulation key + +data SRatchetKEMState (s :: RatchetKEMState) where + SRKSProposed :: SRatchetKEMState 'RKSProposed + SRKSAccepted :: SRatchetKEMState 'RKSAccepted + +deriving instance Show (SRatchetKEMState s) + +instance TestEquality SRatchetKEMState where + testEquality SRKSProposed SRKSProposed = Just Refl + testEquality SRKSAccepted SRKSAccepted = Just Refl + testEquality _ _ = Nothing + +class RatchetKEMStateI (s :: RatchetKEMState) where sRatchetKEMState :: SRatchetKEMState s + +instance RatchetKEMStateI RKSProposed where sRatchetKEMState = SRKSProposed + +instance RatchetKEMStateI RKSAccepted where sRatchetKEMState = SRKSAccepted + +checkRatchetKEMState :: forall t s s' a. (RatchetKEMStateI s, RatchetKEMStateI s') => t s' a -> Either String (t s a) +checkRatchetKEMState x = case testEquality (sRatchetKEMState @s) (sRatchetKEMState @s') of + Just Refl -> Right x + Nothing -> Left "bad ratchet KEM state" + +checkRatchetKEMState' :: forall t s s'. (RatchetKEMStateI s, RatchetKEMStateI s') => t s' -> Either String (t s) +checkRatchetKEMState' x = case testEquality (sRatchetKEMState @s) (sRatchetKEMState @s') of + Just Refl -> Right x + Nothing -> Left "bad ratchet KEM state" + +data RKEMParams (s :: RatchetKEMState) where + RKParamsProposed :: KEMPublicKey -> RKEMParams 'RKSProposed + RKParamsAccepted :: KEMCiphertext -> KEMPublicKey -> RKEMParams 'RKSAccepted + +deriving instance Eq (RKEMParams s) + +deriving instance Show (RKEMParams s) + +data ARKEMParams = forall s. RatchetKEMStateI s => ARKP (SRatchetKEMState s) (RKEMParams s) + +deriving instance Show ARKEMParams + +instance RatchetKEMStateI s => Encoding (RKEMParams s) where + smpEncode = \case + RKParamsProposed k -> smpEncode ('P', k) + RKParamsAccepted ct k -> smpEncode ('A', ct, k) + smpP = (\(ARKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP + +instance Encoding ARKEMParams where + smpEncode (ARKP _ ps) = smpEncode ps + smpP = + smpP >>= \case + 'P' -> ARKP SRKSProposed . RKParamsProposed <$> smpP + 'A' -> ARKP SRKSAccepted .: RKParamsAccepted <$> smpP <*> smpP + _ -> fail "bad ratchet KEM params" + +data E2ERatchetParams (s :: RatchetKEMState) (a :: Algorithm) + = E2ERatchetParams VersionE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) + deriving (Show) + +data AE2ERatchetParams (a :: Algorithm) + = forall s. + RatchetKEMStateI s => + AE2ERatchetParams (SRatchetKEMState s) (E2ERatchetParams s a) + +deriving instance Show (AE2ERatchetParams a) + +data AnyE2ERatchetParams + = forall s a. + (RatchetKEMStateI s, DhAlgorithm a, AlgorithmI a) => + AnyE2ERatchetParams (SRatchetKEMState s) (SAlgorithm a) (E2ERatchetParams s a) + +deriving instance Show AnyE2ERatchetParams + +instance (RatchetKEMStateI s, AlgorithmI a) => Encoding (E2ERatchetParams s a) where + smpEncode (E2ERatchetParams v k1 k2 kem_) + | v >= pqRatchetE2EEncryptVersion = smpEncode (v, k1, k2, kem_) + | otherwise = smpEncode (v, k1, k2) + smpP = toParams <$?> smpP + where + toParams :: AE2ERatchetParams a -> Either String (E2ERatchetParams s a) + toParams = \case + AE2ERatchetParams _ (E2ERatchetParams v k1 k2 Nothing) -> Right $ E2ERatchetParams v k1 k2 Nothing + AE2ERatchetParams _ ps -> checkRatchetKEMState ps + +instance AlgorithmI a => Encoding (AE2ERatchetParams a) where + smpEncode (AE2ERatchetParams _ ps) = smpEncode ps + smpP = (\(AnyE2ERatchetParams s _ ps) -> AE2ERatchetParams s <$> checkAlgorithm ps) <$?> smpP + +instance Encoding AnyE2ERatchetParams where + smpEncode (AnyE2ERatchetParams _ _ ps) = smpEncode ps + smpP = do + v :: VersionE2E <- smpP + APublicDhKey a k1 <- smpP + APublicDhKey a' k2 <- smpP + case testEquality a a' of + Nothing -> fail "bad e2e params: different key algorithms" + Just Refl -> + kemP v >>= \case + Just (ARKP s kem) -> pure $ AnyE2ERatchetParams s a $ E2ERatchetParams v k1 k2 (Just kem) + Nothing -> pure $ AnyE2ERatchetParams SRKSProposed a $ E2ERatchetParams v k1 k2 Nothing + where + kemP :: VersionE2E -> Parser (Maybe ARKEMParams) + kemP v + | v >= pqRatchetE2EEncryptVersion = smpP + | otherwise = pure Nothing + +instance VersionI E2EVersion (E2ERatchetParams s a) where + type VersionRangeT E2EVersion (E2ERatchetParams s a) = E2ERatchetParamsUri s a + version (E2ERatchetParams v _ _ _) = v + toVersionRangeT (E2ERatchetParams _ k1 k2 kem_) vr = E2ERatchetParamsUri vr k1 k2 kem_ + +instance VersionRangeI E2EVersion (E2ERatchetParamsUri s a) where + type VersionT E2EVersion (E2ERatchetParamsUri s a) = (E2ERatchetParams s a) + versionRange (E2ERatchetParamsUri vr _ _ _) = vr + toVersionT (E2ERatchetParamsUri _ k1 k2 kem_) v = E2ERatchetParams v k1 k2 kem_ + +type RcvE2ERatchetParamsUri a = E2ERatchetParamsUri 'RKSProposed a + +data E2ERatchetParamsUri (s :: RatchetKEMState) (a :: Algorithm) + = E2ERatchetParamsUri VersionRangeE2E (PublicKey a) (PublicKey a) (Maybe (RKEMParams s)) deriving (Eq, Show) -instance AlgorithmI a => Encoding (E2ERatchetParams a) where - smpEncode (E2ERatchetParams v k1 k2) = smpEncode (v, k1, k2) - smpP = E2ERatchetParams <$> smpP <*> smpP <*> smpP +data AE2ERatchetParamsUri (a :: Algorithm) + = forall s. + RatchetKEMStateI s => + AE2ERatchetParamsUri (SRatchetKEMState s) (E2ERatchetParamsUri s a) -instance VersionI (E2ERatchetParams a) where - type VersionRangeT (E2ERatchetParams a) = E2ERatchetParamsUri a - version (E2ERatchetParams v _ _) = v - toVersionRangeT (E2ERatchetParams _ k1 k2) vr = E2ERatchetParamsUri vr k1 k2 +deriving instance Show (AE2ERatchetParamsUri a) -instance VersionRangeI (E2ERatchetParamsUri a) where - type VersionT (E2ERatchetParamsUri a) = (E2ERatchetParams a) - versionRange (E2ERatchetParamsUri vr _ _) = vr - toVersionT (E2ERatchetParamsUri _ k1 k2) v = E2ERatchetParams v k1 k2 +data AnyE2ERatchetParamsUri + = forall s a. + (RatchetKEMStateI s, DhAlgorithm a, AlgorithmI a) => + AnyE2ERatchetParamsUri (SRatchetKEMState s) (SAlgorithm a) (E2ERatchetParamsUri s a) -data E2ERatchetParamsUri (a :: Algorithm) - = E2ERatchetParamsUri VersionRange (PublicKey a) (PublicKey a) - deriving (Eq, Show) +deriving instance Show AnyE2ERatchetParamsUri -instance AlgorithmI a => StrEncoding (E2ERatchetParamsUri a) where - strEncode (E2ERatchetParamsUri vs key1 key2) = - strEncode $ - QSP QNoEscaping [("v", strEncode vs), ("x3dh", strEncodeList [key1, key2])] +instance (RatchetKEMStateI s, AlgorithmI a) => StrEncoding (E2ERatchetParamsUri s a) where + strEncode (E2ERatchetParamsUri vs key1 key2 kem_) = + strEncode . QSP QNoEscaping $ + [("v", strEncode vs), ("x3dh", strEncodeList [key1, key2])] + <> maybe [] encodeKem kem_ + where + encodeKem kem + | maxVersion vs < pqRatchetE2EEncryptVersion = [] + | otherwise = case kem of + RKParamsProposed k -> [("kem_key", strEncode k)] + RKParamsAccepted ct k -> [("kem_ct", strEncode ct), ("kem_key", strEncode k)] + strP = toParamsURI <$?> strP + where + toParamsURI = \case + AE2ERatchetParamsUri _ (E2ERatchetParamsUri vr k1 k2 Nothing) -> Right $ E2ERatchetParamsUri vr k1 k2 Nothing + AE2ERatchetParamsUri _ ps -> checkRatchetKEMState ps + +instance AlgorithmI a => StrEncoding (AE2ERatchetParamsUri a) where + strEncode (AE2ERatchetParamsUri _ ps) = strEncode ps + strP = (\(AnyE2ERatchetParamsUri s _ ps) -> AE2ERatchetParamsUri s <$> checkAlgorithm ps) <$?> strP + +instance StrEncoding AnyE2ERatchetParamsUri where + strEncode (AnyE2ERatchetParamsUri _ _ ps) = strEncode ps strP = do query <- strP - vs <- queryParam "v" query + vr :: VersionRangeE2E <- queryParam "v" query keys <- L.toList <$> queryParam "x3dh" query case keys of - [key1, key2] -> pure $ E2ERatchetParamsUri vs key1 key2 + [APublicDhKey a k1, APublicDhKey a' k2] -> case testEquality a a' of + Nothing -> fail "bad e2e params: different key algorithms" + Just Refl -> + kemP vr query >>= \case + Just (ARKP s kem) -> pure $ AnyE2ERatchetParamsUri s a $ E2ERatchetParamsUri vr k1 k2 (Just kem) + Nothing -> pure $ AnyE2ERatchetParamsUri SRKSProposed a $ E2ERatchetParamsUri vr k1 k2 Nothing _ -> fail "bad e2e params" + where + kemP vr query + | maxVersion vr >= pqRatchetE2EEncryptVersion = + queryParam_ "kem_key" query + $>>= \k -> Just . kemParams k <$> queryParam_ "kem_ct" query + | otherwise = pure Nothing + kemParams k = \case + Nothing -> ARKP SRKSProposed $ RKParamsProposed k + Just ct -> ARKP SRKSAccepted $ RKParamsAccepted ct k -generateE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> Version -> STM (PrivateKey a, PrivateKey a, E2ERatchetParams a) -generateE2EParams g v = do - (k1, pk1) <- generateKeyPair g - (k2, pk2) <- generateKeyPair g - pure (pk1, pk2, E2ERatchetParams v k1 k2) +type RcvE2ERatchetParams a = E2ERatchetParams 'RKSProposed a + +type SndE2ERatchetParams a = AE2ERatchetParams a + +data PrivRKEMParams (s :: RatchetKEMState) where + PrivateRKParamsProposed :: KEMKeyPair -> PrivRKEMParams 'RKSProposed + PrivateRKParamsAccepted :: KEMCiphertext -> KEMSharedKey -> KEMKeyPair -> PrivRKEMParams 'RKSAccepted + +data APrivRKEMParams = forall s. RatchetKEMStateI s => APRKP (SRatchetKEMState s) (PrivRKEMParams s) + +type RcvPrivRKEMParams = PrivRKEMParams 'RKSProposed + +instance RatchetKEMStateI s => Encoding (PrivRKEMParams s) where + smpEncode = \case + PrivateRKParamsProposed k -> smpEncode ('P', k) + PrivateRKParamsAccepted ct shared k -> smpEncode ('A', ct, shared, k) + smpP = (\(APRKP _ ps) -> checkRatchetKEMState' ps) <$?> smpP + +instance Encoding APrivRKEMParams where + smpEncode (APRKP _ ps) = smpEncode ps + smpP = + smpP >>= \case + 'P' -> APRKP SRKSProposed . PrivateRKParamsProposed <$> smpP + 'A' -> APRKP SRKSAccepted .:. PrivateRKParamsAccepted <$> smpP <*> smpP <*> smpP + _ -> fail "bad APrivRKEMParams" + +instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . smpEncode + +instance (Typeable s, RatchetKEMStateI s) => FromField (PrivRKEMParams s) where fromField = blobFieldDecoder smpDecode + +data UseKEM (s :: RatchetKEMState) where + ProposeKEM :: UseKEM 'RKSProposed + AcceptKEM :: KEMPublicKey -> UseKEM 'RKSAccepted + +data AUseKEM = forall s. RatchetKEMStateI s => AUseKEM (SRatchetKEMState s) (UseKEM s) + +generateE2EParams :: forall s a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe (UseKEM s) -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams s), E2ERatchetParams s a) +generateE2EParams g v useKEM_ = do + (k1, pk1) <- atomically $ generateKeyPair g + (k2, pk2) <- atomically $ generateKeyPair g + kems <- kemParams + pure (pk1, pk2, snd <$> kems, E2ERatchetParams v k1 k2 (fst <$> kems)) + where + kemParams :: IO (Maybe (RKEMParams s, PrivRKEMParams s)) + kemParams = case useKEM_ of + Just useKem | v >= pqRatchetE2EEncryptVersion -> Just <$> do + ks@(k, _) <- sntrup761Keypair g + case useKem of + ProposeKEM -> pure (RKParamsProposed k, PrivateRKParamsProposed ks) + AcceptKEM k' -> do + (ct, shared) <- sntrup761Enc g k' + pure (RKParamsAccepted ct k, PrivateRKParamsAccepted ct shared ks) + _ -> pure Nothing + +-- used by party initiating connection, Bob in double-ratchet spec +generateRcvE2EParams :: (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> PQSupport -> IO (PrivateKey a, PrivateKey a, Maybe (PrivRKEMParams 'RKSProposed), E2ERatchetParams 'RKSProposed a) +generateRcvE2EParams g v = generateE2EParams g v . proposeKEM_ + where + proposeKEM_ :: PQSupport -> Maybe (UseKEM 'RKSProposed) + proposeKEM_ = \case + PQSupportOn -> Just ProposeKEM + PQSupportOff -> Nothing + +-- used by party accepting connection, Alice in double-ratchet spec +generateSndE2EParams :: forall a. (AlgorithmI a, DhAlgorithm a) => TVar ChaChaDRG -> VersionE2E -> Maybe AUseKEM -> IO (PrivateKey a, PrivateKey a, Maybe APrivRKEMParams, AE2ERatchetParams a) +generateSndE2EParams g v = \case + Nothing -> do + (pk1, pk2, _, e2eParams) <- generateE2EParams g v Nothing + pure (pk1, pk2, Nothing, AE2ERatchetParams SRKSProposed e2eParams) + Just (AUseKEM s useKEM) -> do + (pk1, pk2, pKem, e2eParams) <- generateE2EParams g v (Just useKEM) + pure (pk1, pk2, APRKP s <$> pKem, AE2ERatchetParams s e2eParams) data RatchetInitParams = RatchetInitParams { assocData :: Str, ratchetKey :: RatchetKey, sndHK :: HeaderKey, - rcvNextHK :: HeaderKey + rcvNextHK :: HeaderKey, + kemAccepted :: Maybe RatchetKEMAccepted } - deriving (Eq, Show) + deriving (Show) -x3dhSnd :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams -x3dhSnd spk1 spk2 (E2ERatchetParams _ rk1 rk2) = - x3dh (publicKey spk1, rk1) (dh' rk1 spk2) (dh' rk2 spk1) (dh' rk2 spk2) +-- this is used by the peer joining the connection +pqX3dhSnd :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> Maybe APrivRKEMParams -> E2ERatchetParams 'RKSProposed a -> Either CryptoError (RatchetInitParams, Maybe KEMKeyPair) +-- 3. replied 2. received +pqX3dhSnd spk1 spk2 spKem_ (E2ERatchetParams v rk1 rk2 rKem_) = do + (ks_, kem_) <- sndPq + let initParams = pqX3dh (publicKey spk1, rk1) (dh' rk1 spk2) (dh' rk2 spk1) (dh' rk2 spk2) kem_ + pure (initParams, ks_) + where + sndPq :: Either CryptoError (Maybe KEMKeyPair, Maybe RatchetKEMAccepted) + sndPq = case spKem_ of + Just (APRKP _ ps) | v >= pqRatchetE2EEncryptVersion -> case (ps, rKem_) of + (PrivateRKParamsAccepted ct shared ks, Just (RKParamsProposed k)) -> Right (Just ks, Just $ RatchetKEMAccepted k shared ct) + (PrivateRKParamsProposed ks, _) -> Right (Just ks, Nothing) -- both parties can send "proposal" in case of ratchet renegotiation + _ -> Left CERatchetKEMState + _ -> Right (Nothing, Nothing) -x3dhRcv :: DhAlgorithm a => PrivateKey a -> PrivateKey a -> E2ERatchetParams a -> RatchetInitParams -x3dhRcv rpk1 rpk2 (E2ERatchetParams _ sk1 sk2) = - x3dh (sk1, publicKey rpk1) (dh' sk2 rpk1) (dh' sk1 rpk2) (dh' sk2 rpk2) +-- this is used by the peer that created new connection, after receiving the reply +pqX3dhRcv :: forall s a. (RatchetKEMStateI s, DhAlgorithm a) => PrivateKey a -> PrivateKey a -> Maybe (PrivRKEMParams 'RKSProposed) -> E2ERatchetParams s a -> ExceptT CryptoError IO (RatchetInitParams, Maybe KEMKeyPair) +-- 1. sent 4. received in reply +pqX3dhRcv rpk1 rpk2 rpKem_ (E2ERatchetParams v sk1 sk2 sKem_) = do + kem_ <- rcvPq + let initParams = pqX3dh (sk1, publicKey rpk1) (dh' sk2 rpk1) (dh' sk1 rpk2) (dh' sk2 rpk2) (snd <$> kem_) + pure (initParams, fst <$> kem_) + where + rcvPq :: ExceptT CryptoError IO (Maybe (KEMKeyPair, RatchetKEMAccepted)) + rcvPq = case sKem_ of + Just (RKParamsAccepted ct k') | v >= pqRatchetE2EEncryptVersion -> case rpKem_ of + Just (PrivateRKParamsProposed ks@(_, pk)) -> do + shared <- liftIO $ sntrup761Dec ct pk + pure $ Just (ks, RatchetKEMAccepted k' shared ct) + Nothing -> throwError CERatchetKEMState + _ -> pure Nothing -- both parties can send "proposal" in case of ratchet renegotiation -x3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> RatchetInitParams -x3dh (sk1, rk1) dh1 dh2 dh3 = - RatchetInitParams {assocData, ratchetKey = RatchetKey sk, sndHK = Key hk, rcvNextHK = Key nhk} +pqX3dh :: DhAlgorithm a => (PublicKey a, PublicKey a) -> DhSecret a -> DhSecret a -> DhSecret a -> Maybe RatchetKEMAccepted -> RatchetInitParams +pqX3dh (sk1, rk1) dh1 dh2 dh3 kemAccepted = + RatchetInitParams {assocData, ratchetKey = RatchetKey sk, sndHK = Key hk, rcvNextHK = Key nhk, kemAccepted} where assocData = Str $ pubKeyBytes sk1 <> pubKeyBytes rk1 - dhs = dhBytes' dh1 <> dhBytes' dh2 <> dhBytes' dh3 + dhs = dhBytes' dh1 <> dhBytes' dh2 <> dhBytes' dh3 <> pq + pq = maybe "" (\RatchetKEMAccepted {rcPQRss = KEMSharedKey ss} -> BA.convert ss) kemAccepted (hk, nhk, sk) = let salt = B.replicate 64 '\0' in hkdf3 salt dhs "SimpleXX3DH" @@ -125,10 +467,15 @@ type RatchetX448 = Ratchet 'X448 data Ratchet a = Ratchet { -- ratchet version range sent in messages (current .. max supported ratchet version) - rcVersion :: VersionRange, + rcVersion :: RatchetVersions, -- associated data - must be the same in both parties ratchets rcAD :: Str, rcDHRs :: PrivateKey a, + rcKEM :: Maybe RatchetKEM, + rcSupportKEM :: PQSupport, -- defines header size, can only be enabled once + rcEnableKEM :: PQEncryption, -- will enable KEM on the next ratchet step + rcSndKEM :: PQEncryption, -- used KEM hybrid secret for sending ratchet + rcRcvKEM :: PQEncryption, -- used KEM hybrid secret for receiving ratchet rcRK :: RatchetKey, rcSnd :: Maybe (SndRatchet a), rcRcv :: Maybe RcvRatchet, @@ -138,20 +485,53 @@ data Ratchet a = Ratchet rcNHKs :: HeaderKey, rcNHKr :: HeaderKey } + deriving (Show) + +data RatchetVersions = RatchetVersions + { current :: VersionE2E, + maxSupported :: VersionE2E + } deriving (Eq, Show) +instance ToJSON RatchetVersions where + -- TODO v5.7 or v5.8 change to the default record encoding + toJSON RatchetVersions {current, maxSupported} = toJSON (current, maxSupported) + toEncoding RatchetVersions {current, maxSupported} = toEncoding (current, maxSupported) + +instance FromJSON RatchetVersions where + -- TODO v5.7 or v5.8 replace comment below with "tuple for backward" + -- this parser supports JSON record encoding for forward compatibility + parseJSON v = toRV <$> (tupleP <|> recordP v) + where + tupleP = parseJSON v + recordP = J.withObject "RatchetVersions" $ \o -> (,) <$> o J..: "current" <*> o J..: "maxSupported" + toRV (current, maxSupported) = RatchetVersions {current, maxSupported} + data SndRatchet a = SndRatchet { rcDHRr :: PublicKey a, rcCKs :: RatchetKey, rcHKs :: HeaderKey } - deriving (Eq, Show) + deriving (Show) data RcvRatchet = RcvRatchet { rcCKr :: RatchetKey, rcHKr :: HeaderKey } - deriving (Eq, Show) + deriving (Show) + +data RatchetKEM = RatchetKEM + { rcPQRs :: KEMKeyPair, + rcKEMs :: Maybe RatchetKEMAccepted + } + deriving (Show) + +data RatchetKEMAccepted = RatchetKEMAccepted + { rcPQRr :: KEMPublicKey, -- received key + rcPQRss :: KEMSharedKey, -- computed shared secret + rcPQRct :: KEMCiphertext -- sent encaps(rcPQRr, rcPQRss) + } + deriving (Show) type SkippedMsgKeys = Map HeaderKey SkippedHdrMsgKeys @@ -189,7 +569,7 @@ instance Encoding MessageKey where -- | Input key material for double ratchet HKDF functions newtype RatchetKey = RatchetKey ByteString - deriving (Eq, Show) + deriving (Show) instance ToJSON RatchetKey where toJSON (RatchetKey k) = strToJSON k @@ -202,19 +582,34 @@ instance ToField MessageKey where toField = toField . smpEncode instance FromField MessageKey where fromField = blobFieldDecoder smpDecode --- | Sending ratchet initialization, equivalent to RatchetInitAliceHE in double ratchet spec +-- | Sending ratchet initialization -- -- Please note that sPKey is not stored, and its public part together with random salt -- is sent to the recipient. +-- @ +-- RatchetInitAlicePQ2HE(state, SK, bob_dh_public_key, shared_hka, shared_nhkb, bob_pq_kem_encapsulation_key) +-- // below added for post-quantum KEM +-- state.PQRs = GENERATE_PQKEM() +-- state.PQRr = bob_pq_kem_encapsulation_key +-- state.PQRss = random // shared secret for KEM +-- state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) // encapsulated additional shared secret +-- // above added for KEM +-- @ initSndRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PublicKey a -> PrivateKey a -> RatchetInitParams -> Ratchet a -initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = do - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr)) - let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs + forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PublicKey a -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> Ratchet a +initSndRatchet rcVersion rcDHRr rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) = do + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(SK, DH(state.DHRs, state.DHRr) || state.PQRss) + let (rcRK, rcCKs, rcNHKs) = rootKdf ratchetKey rcDHRr rcDHRs (rcPQRss <$> kemAccepted) + pqOn = isJust rcPQRs_ in Ratchet { rcVersion, rcAD = assocData, rcDHRs, + rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, + rcSupportKEM = PQSupport pqOn, + rcEnableKEM = PQEncryption pqOn, + rcSndKEM = PQEncryption $ isJust kemAccepted, + rcRcvKEM = PQEncOff, rcRK, rcSnd = Just SndRatchet {rcDHRr, rcCKs, rcHKs = sndHK}, rcRcv = Nothing, @@ -225,17 +620,29 @@ initSndRatchet rcVersion rcDHRr rcDHRs RatchetInitParams {assocData, ratchetKey, rcNHKr = rcvNextHK } --- | Receiving ratchet initialization, equivalent to RatchetInitBobHE in double ratchet spec +-- | Receiving ratchet initialization, equivalent to RatchetInitBobPQ2HE in double ratchet spec +-- +-- def RatchetInitBobPQ2HE(state, SK, bob_dh_key_pair, shared_hka, shared_nhkb, bob_pq_kem_key_pair) -- -- Please note that the public part of rcDHRs was sent to the sender -- as part of the connection request and random salt was received from the sender. initRcvRatchet :: - forall a. (AlgorithmI a, DhAlgorithm a) => VersionRange -> PrivateKey a -> RatchetInitParams -> Ratchet a -initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK} = + forall a. (AlgorithmI a, DhAlgorithm a) => RatchetVersions -> PrivateKey a -> (RatchetInitParams, Maybe KEMKeyPair) -> PQSupport -> Ratchet a +initRcvRatchet rcVersion rcDHRs (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, rcPQRs_) pqSupport = Ratchet { rcVersion, rcAD = assocData, rcDHRs, + -- rcKEM: + -- state.PQRs = bob_pq_kem_key_pair + -- state.PQRr = None + -- state.PQRss = None + -- state.PQRct = None + rcKEM = (`RatchetKEM` kemAccepted) <$> rcPQRs_, + rcSupportKEM = pqSupport, + rcEnableKEM = pqSupportToEnc pqSupport, + rcSndKEM = PQEncOff, + rcRcvKEM = PQEncOff, rcRK = ratchetKey, rcSnd = Nothing, rcRcv = Nothing, @@ -246,91 +653,251 @@ initRcvRatchet rcVersion rcDHRs RatchetInitParams {assocData, ratchetKey, sndHK, rcNHKr = sndHK } +-- encaps = state.PQRs.encaps, // added for KEM #2 +-- ct = state.PQRct // added for KEM #1 data MsgHeader a = MsgHeader { -- | max supported ratchet version - msgMaxVersion :: Version, + msgMaxVersion :: VersionE2E, msgDHRs :: PublicKey a, + msgKEM :: Maybe ARKEMParams, msgPN :: Word32, msgNs :: Word32 } - deriving (Eq, Show) - -data AMsgHeader - = forall a. - (AlgorithmI a, DhAlgorithm a) => - AMsgHeader (SAlgorithm a) (MsgHeader a) + deriving (Show) -- to allow extension without increasing the size, the actual header length is: -- 69 = 2 (original size) + 2 + 1+56 (Curve448) + 4 + 4 -paddedHeaderLen :: Int -paddedHeaderLen = 88 +-- The exact size is 2288, added reserve +paddedHeaderLen :: VersionE2E -> PQSupport -> Int +paddedHeaderLen v = \case + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 2310 + _ -> 88 -- only used in tests to validate correct padding --- (2 bytes - version size, 1 byte - header size, not to have it fixed or version-dependent) -fullHeaderLen :: Int -fullHeaderLen = 2 + 1 + paddedHeaderLen + authTagSize + ivSize @AES256 +-- (2 bytes - version size, 1 byte - header size) +fullHeaderLen :: VersionE2E -> PQSupport -> Int +fullHeaderLen v pq = 2 + 1 + paddedHeaderLen v pq + authTagSize + ivSize @AES256 -instance AlgorithmI a => Encoding (MsgHeader a) where - smpEncode MsgHeader {msgMaxVersion, msgDHRs, msgPN, msgNs} = - smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) - smpP = do - msgMaxVersion <- smpP - msgDHRs <- smpP - msgPN <- smpP - msgNs <- smpP - pure MsgHeader {msgMaxVersion, msgDHRs, msgPN, msgNs} +-- pass the current version, as MsgHeader only includes the max supported version that can be different from the current +encodeMsgHeader :: AlgorithmI a => VersionE2E -> MsgHeader a -> ByteString +encodeMsgHeader v MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} + | v >= pqRatchetE2EEncryptVersion = smpEncode (msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs) + | otherwise = smpEncode (msgMaxVersion, msgDHRs, msgPN, msgNs) + +-- pass the current version, as MsgHeader only includes the max supported version that can be different from the current +msgHeaderP :: AlgorithmI a => VersionE2E -> Parser (MsgHeader a) +msgHeaderP v = do + msgMaxVersion <- smpP + msgDHRs <- smpP + msgKEM <- if v >= pqRatchetE2EEncryptVersion then smpP else pure Nothing + msgPN <- smpP + msgNs <- smpP + pure MsgHeader {msgMaxVersion, msgDHRs, msgKEM, msgPN, msgNs} data EncMessageHeader = EncMessageHeader - { ehVersion :: Version, + { ehVersion :: VersionE2E, -- this is current ratchet version ehIV :: IV, ehAuthTag :: AuthTag, ehBody :: ByteString } +-- this encoding depends on version in EncMessageHeader because it is "current" ratchet version instance Encoding EncMessageHeader where - smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} = - smpEncode (ehVersion, ehIV, ehAuthTag, ehBody) + smpEncode EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} + = smpEncode (ehVersion, ehIV, ehAuthTag) <> encodeLarge ehVersion ehBody smpP = do - (ehVersion, ehIV, ehAuthTag, ehBody) <- smpP + (ehVersion, ehIV, ehAuthTag) <- smpP + ehBody <- largeP pure EncMessageHeader {ehVersion, ehIV, ehAuthTag, ehBody} +-- the encoder always uses 2-byte lengths for the new version, even for short headers without PQ keys. +encodeLarge :: VersionE2E -> ByteString -> ByteString +encodeLarge v s + -- the condition for length is not necessary, it's here as a fallback. + -- | v >= pqRatchetE2EEncryptVersion || B.length s > 255 = smpEncode $ Large s + | v >= pqRatchetE2EEncryptVersion = smpEncode $ Large s + | otherwise = smpEncode s + +-- This parser relies on the fact that header cannot be shorter than 32 bytes (it is ~69 bytes without PQ KEM), +-- therefore if the first byte is less or equal to 31 (x1F), then we have 2 byte-length limited to 8191. +-- This allows upgrading the current version in one message. +largeP :: Parser ByteString +largeP = do + len1 <- peekWord8' + if len1 < 32 then unLarge <$> smpP else smpP + +-- the header is length-prefixed to parse it as string and use as part of associated data for authenticated encryption data EncRatchetMessage = EncRatchetMessage { emHeader :: ByteString, emAuthTag :: AuthTag, emBody :: ByteString } -instance Encoding EncRatchetMessage where - smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} = - smpEncode (emHeader, emAuthTag, Tail emBody) - smpP = do - (emHeader, emAuthTag, Tail emBody) <- smpP - pure EncRatchetMessage {emHeader, emBody, emAuthTag} +encodeEncRatchetMessage :: VersionE2E -> EncRatchetMessage -> ByteString +encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} + = encodeLarge v emHeader <> smpEncode (emAuthTag, Tail emBody) -rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> ExceptT CryptoError IO (ByteString, Ratchet a) -rcEncrypt Ratchet {rcSnd = Nothing} _ _ = throwE CERatchetState -rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcNs, rcPN, rcAD = Str rcAD, rcVersion} paddedMsgLen msg = do +encRatchetMessageP :: Parser EncRatchetMessage +encRatchetMessageP = do + emHeader <- largeP + (emAuthTag, Tail emBody) <- smpP + pure EncRatchetMessage {emHeader, emBody, emAuthTag} + +newtype PQEncryption = PQEncryption {enablePQ :: Bool} + deriving (Eq, Show) + +pattern PQEncOn :: PQEncryption +pattern PQEncOn = PQEncryption True + +pattern PQEncOff :: PQEncryption +pattern PQEncOff = PQEncryption False + +{-# COMPLETE PQEncOn, PQEncOff #-} + +instance ToJSON PQEncryption where + toEncoding (PQEncryption pq) = toEncoding pq + toJSON (PQEncryption pq) = toJSON pq + +instance FromJSON PQEncryption where + parseJSON v = PQEncryption <$> parseJSON v + omittedField = Just PQEncOff + +newtype PQSupport = PQSupport {supportPQ :: Bool} + deriving (Eq, Show) + +pattern PQSupportOn :: PQSupport +pattern PQSupportOn = PQSupport True + +pattern PQSupportOff :: PQSupport +pattern PQSupportOff = PQSupport False + +{-# COMPLETE PQSupportOn, PQSupportOff #-} + +instance ToJSON PQSupport where + toEncoding (PQSupport pq) = toEncoding pq + toJSON (PQSupport pq) = toJSON pq + +instance FromJSON PQSupport where + parseJSON v = PQSupport <$> parseJSON v + omittedField = Just PQSupportOff + +pqSupportToEnc :: PQSupport -> PQEncryption +pqSupportToEnc (PQSupport pq) = PQEncryption pq + +pqEncToSupport :: PQEncryption -> PQSupport +pqEncToSupport (PQEncryption pq) = PQSupport pq + +pqSupportAnd :: PQSupport -> PQSupport -> PQSupport +pqSupportAnd (PQSupport s1) (PQSupport s2) = PQSupport $ s1 && s2 + +pqEnableSupport :: VersionE2E -> PQSupport -> PQEncryption -> PQSupport +pqEnableSupport v (PQSupport sup) (PQEncryption enc) = PQSupport $ sup || (v >= pqRatchetE2EEncryptVersion && enc) + +replyKEM_ :: VersionE2E -> Maybe (RKEMParams 'RKSProposed) -> PQSupport -> Maybe AUseKEM +replyKEM_ v kem_ = \case + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> Just $ case kem_ of + Just (RKParamsProposed k) -> AUseKEM SRKSAccepted $ AcceptKEM k + Nothing -> AUseKEM SRKSProposed ProposeKEM + _ -> Nothing + +instance StrEncoding PQEncryption where + strEncode pqMode + | enablePQ pqMode = "pq=enable" + | otherwise = "pq=disable" + strP = + A.takeTill (== ' ') >>= \case + "pq=enable" -> pq True + "pq=disable" -> pq False + _ -> fail "bad PQEncryption" + where + pq = pure . PQEncryption + +instance StrEncoding PQSupport where + strEncode = strEncode . pqSupportToEnc + {-# INLINE strEncode #-} + strP = pqEncToSupport <$> strP + {-# INLINE strP #-} + +data InitialKeys = IKUsePQ | IKNoPQ PQSupport + deriving (Eq, Show) + +pattern IKPQOn :: InitialKeys +pattern IKPQOn = IKNoPQ PQSupportOn + +pattern IKPQOff :: InitialKeys +pattern IKPQOff = IKNoPQ PQSupportOff + +instance StrEncoding InitialKeys where + strEncode = \case + IKUsePQ -> "pq=invitation" + IKNoPQ pq -> strEncode pq + strP = IKNoPQ <$> strP <|> "pq=invitation" $> IKUsePQ + +-- determines whether PQ key should be included in invitation link +initialPQEncryption :: InitialKeys -> PQSupport +initialPQEncryption = \case + IKUsePQ -> PQSupportOn + IKNoPQ _ -> PQSupportOff -- default + +-- determines whether PQ encryption should be used in connection +connPQEncryption :: InitialKeys -> PQSupport +connPQEncryption = \case + IKUsePQ -> PQSupportOn + IKNoPQ pq -> pq -- default for creating connection is IKNoPQ PQEncOn + +rcEncrypt :: AlgorithmI a => Ratchet a -> Int -> ByteString -> Maybe PQEncryption -> VersionE2E -> ExceptT CryptoError IO (ByteString, Ratchet a) +rcEncrypt Ratchet {rcSnd = Nothing} _ _ _ _ = throwE CERatchetState +rcEncrypt rc@Ratchet {rcSnd = Just sr@SndRatchet {rcCKs, rcHKs}, rcDHRs, rcKEM, rcNs, rcPN, rcAD = Str rcAD, rcSupportKEM, rcEnableKEM, rcVersion} paddedMsgLen msg pqEnc_ supportedE2EVersion = do -- state.CKs, mk = KDF_CK(state.CKs) let (ck', mk, iv, ehIV) = chainKdf rcCKs + v = current rcVersion + -- PQ encryption can be enabled or disabled + rcEnableKEM' = fromMaybe rcEnableKEM pqEnc_ + -- support for PQ encryption (and therefore large headers/small envelopes) can only be enabled, it cannot be disabled + rcSupportKEM' = pqEnableSupport v rcSupportKEM rcEnableKEM' + -- This sets max version to support PQ encryption. + -- Current version upgrade happens when peer decrypts the message. + -- TODO note that maxSupported will not downgrade here below current (v). + maxSupported' = max supportedE2EVersion $ if pqEnc_ == Just PQEncOn then pqRatchetE2EEncryptVersion else v + rcVersion' = rcVersion {maxSupported = maxSupported'} -- enc_header = HENCRYPT(state.HKs, header) - (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV paddedHeaderLen rcAD msgHeader + (ehAuthTag, ehBody) <- encryptAEAD rcHKs ehIV (paddedHeaderLen v rcSupportKEM') rcAD (msgHeader v maxSupported') -- return enc_header, ENCRYPT(mk, plaintext, CONCAT(AD, enc_header)) - let emHeader = smpEncode EncMessageHeader {ehVersion = minVersion rcVersion, ehBody, ehAuthTag, ehIV} + let emHeader = smpEncode EncMessageHeader {ehVersion = v, ehBody, ehAuthTag, ehIV} (emAuthTag, emBody) <- encryptAEAD mk iv paddedMsgLen (rcAD <> emHeader) msg - let msg' = smpEncode EncRatchetMessage {emHeader, emBody, emAuthTag} + let msg' = encodeEncRatchetMessage v EncRatchetMessage {emHeader, emBody, emAuthTag} -- state.Ns += 1 - rc' = rc {rcSnd = Just sr {rcCKs = ck'}, rcNs = rcNs + 1} + rc' = + rc + { rcSnd = Just sr {rcCKs = ck'}, + rcNs = rcNs + 1, + rcSupportKEM = rcSupportKEM', + rcEnableKEM = rcEnableKEM', + rcVersion = rcVersion', + rcKEM = if pqEnc_ == Just PQEncOff then (\rck -> rck {rcKEMs = Nothing}) <$> rcKEM else rcKEM + } pure (msg', rc') where - -- header = HEADER(state.DHRs, state.PN, state.Ns) - msgHeader = - smpEncode + -- header = HEADER_PQ2( + -- dh = state.DHRs.public, + -- kem = state.PQRs.public, // added for KEM #2 + -- ct = state.PQRct, // added for KEM #1 + -- pn = state.PN, + -- n = state.Ns + -- ) + msgHeader v maxSupported' = + encodeMsgHeader + v MsgHeader - { msgMaxVersion = maxVersion rcVersion, + { msgMaxVersion = maxSupported', msgDHRs = publicKey rcDHRs, + msgKEM = msgKEMParams <$> rcKEM, msgPN = rcPN, msgNs = rcNs } + msgKEMParams RatchetKEM {rcPQRs = (k, _), rcKEMs} = case rcKEMs of + Nothing -> ARKP SRKSProposed $ RKParamsProposed k + Just RatchetKEMAccepted {rcPQRct} -> ARKP SRKSAccepted $ RKParamsAccepted rcPQRct k data SkippedMessage a = SMMessage (DecryptResult a) @@ -338,7 +905,7 @@ data SkippedMessage a | SMNone data RatchetStep = AdvanceRatchet | SameRatchet - deriving (Eq) + deriving (Eq, Show) type DecryptResult a = (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff) @@ -353,8 +920,8 @@ rcDecrypt :: SkippedMsgKeys -> ByteString -> ExceptT CryptoError IO (DecryptResult a) -rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do - encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError smpP msg' +rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD, rcVersion} rcMKSkipped msg' = do + encMsg@EncRatchetMessage {emHeader} <- parseE CryptoHeaderError encRatchetMessageP msg' encHdr <- parseE CryptoHeaderError smpP emHeader -- plaintext = TrySkippedMessageKeysHE(state, enc_header, cipher-text, AD) decryptSkipped encHdr encMsg >>= \case @@ -368,9 +935,16 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do SMMessage r -> pure r where decryptRcMessage :: RatchetStep -> MsgHeader a -> EncRatchetMessage -> ExceptT CryptoError IO (DecryptResult a) - decryptRcMessage rcStep MsgHeader {msgDHRs, msgPN, msgNs} encMsg = do + decryptRcMessage rcStep hdr@MsgHeader {msgMaxVersion, msgPN, msgNs} encMsg = do -- if dh_ratchet: - (rc', smks1) <- ratchetStep rcStep + (rc', smks1) <- case rcStep of + SameRatchet -> pure (upgradedRatchet, M.empty) + AdvanceRatchet -> do + -- SkipMessageKeysHE(state, header.pn) + (rc', hmks) <- liftEither $ skipMessageKeys msgPN upgradedRatchet + -- DHRatchetPQ2HE(state, header) + (,hmks) <$> ratchetStep rc' hdr + -- SkipMessageKeysHE(state, header.n) case skipMessageKeys msgNs rc' of Left e -> pure (Left e, rc', smkDiff smks1) Right (rc''@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr}, rcNr}, smks2) -> do @@ -380,37 +954,79 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do msg <- decryptMessage (MessageKey mk iv) encMsg -- state . Nr += 1 pure (msg, rc'' {rcRcv = Just rr {rcCKr = rcCKr'}, rcNr = rcNr + 1}, smkDiff $ smks1 <> smks2) - Right (rc'', smks2) -> do + Right (rc'', smks2) -> pure (Left CERatchetState, rc'', smkDiff $ smks1 <> smks2) where + upgradedRatchet :: Ratchet a + upgradedRatchet + | msgMaxVersion > current = rc {rcVersion = rcVersion {current = max current $ min msgMaxVersion maxSupported}} + | otherwise = rc + where + RatchetVersions {current, maxSupported} = rcVersion smkDiff :: SkippedMsgKeys -> SkippedMsgDiff smkDiff smks = if M.null smks then SMDNoChange else SMDAdd smks - ratchetStep :: RatchetStep -> ExceptT CryptoError IO (Ratchet a, SkippedMsgKeys) - ratchetStep SameRatchet = pure (rc, M.empty) - ratchetStep AdvanceRatchet = - -- SkipMessageKeysHE(state, header.pn) - case skipMessageKeys msgPN rc of - Left e -> throwE e - Right (rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr}, hmks) -> do - -- DHRatchetHE(state, header) - (_, rcDHRs') <- atomically $ generateKeyPair @a g - -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs - -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr)) - (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' - rc'' = - rc' - { rcDHRs = rcDHRs', - rcRK = rcRK'', - rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, - rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, - rcPN = rcNs rc, - rcNs = 0, - rcNr = 0, - rcNHKs = rcNHKs', - rcNHKr = rcNHKr' - } - pure (rc'', hmks) + ratchetStep :: Ratchet a -> MsgHeader a -> ExceptT CryptoError IO (Ratchet a) + ratchetStep rc'@Ratchet {rcDHRs, rcRK, rcNHKs, rcNHKr, rcSupportKEM, rcVersion = rv} MsgHeader {msgDHRs, msgKEM} = do + (kemSS, kemSS', rcKEM') <- pqRatchetStep rc' msgKEM + -- state.DHRs = GENERATE_DH() + (_, rcDHRs') <- atomically $ generateKeyPair @a g + -- state.RK, state.CKr, state.NHKr = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || ss) + let (rcRK', rcCKr', rcNHKr') = rootKdf rcRK msgDHRs rcDHRs kemSS + -- state.RK, state.CKs, state.NHKs = KDF_RK_HE(state.RK, DH(state.DHRs, state.DHRr) || state.PQRss) + (rcRK'', rcCKs', rcNHKs') = rootKdf rcRK' msgDHRs rcDHRs' kemSS' + sndKEM = isJust kemSS' + rcvKEM = isJust kemSS + rcEnableKEM' = PQEncryption $ sndKEM || rcvKEM || isJust rcKEM' + pure + rc' + { rcDHRs = rcDHRs', + rcKEM = rcKEM', + rcSupportKEM = pqEnableSupport (current rv) rcSupportKEM rcEnableKEM', + rcEnableKEM = rcEnableKEM', + rcSndKEM = PQEncryption sndKEM, + rcRcvKEM = PQEncryption rcvKEM, + rcRK = rcRK'', + rcSnd = Just SndRatchet {rcDHRr = msgDHRs, rcCKs = rcCKs', rcHKs = rcNHKs}, + rcRcv = Just RcvRatchet {rcCKr = rcCKr', rcHKr = rcNHKr}, + rcPN = rcNs rc, + rcNs = 0, + rcNr = 0, + rcNHKs = rcNHKs', + rcNHKr = rcNHKr' + } + pqRatchetStep :: Ratchet a -> Maybe ARKEMParams -> ExceptT CryptoError IO (Maybe KEMSharedKey, Maybe KEMSharedKey, Maybe RatchetKEM) + pqRatchetStep Ratchet {rcKEM, rcEnableKEM = PQEncryption pqEnc, rcVersion = rv} = \case + -- received message does not have KEM in header, + -- but the user enabled KEM when sending previous message + Nothing -> case rcKEM of + Nothing | pqEnc && current rv >= pqRatchetE2EEncryptVersion -> do + rcPQRs <- liftIO $ sntrup761Keypair g + pure (Nothing, Nothing, Just RatchetKEM {rcPQRs, rcKEMs = Nothing}) + _ -> pure (Nothing, Nothing, Nothing) + -- received message has KEM in header. + Just (ARKP _ ps) + | pqEnc && current rv >= pqRatchetE2EEncryptVersion -> do + -- state.PQRr = header.kem + (ss, rcPQRr) <- sharedSecret + -- state.PQRct = PQKEM-ENC(state.PQRr, state.PQRss) // encapsulated additional shared secret KEM #1 + (rcPQRct, rcPQRss) <- liftIO $ sntrup761Enc g rcPQRr + -- state.PQRs = GENERATE_PQKEM() + rcPQRs <- liftIO $ sntrup761Keypair g + let kem' = RatchetKEM {rcPQRs, rcKEMs = Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}} + pure (ss, Just rcPQRss, Just kem') + | otherwise -> do + -- state.PQRr = header.kem + (ss, _) <- sharedSecret + pure (ss, Nothing, Nothing) + where + sharedSecret = case ps of + RKParamsProposed k -> pure (Nothing, k) + RKParamsAccepted ct k -> case rcKEM of + Nothing -> throwE CERatchetKEMState + -- ss = PQKEM-DEC(state.PQRs.private, header.ct) + Just RatchetKEM {rcPQRs} -> do + ss <- liftIO $ sntrup761Dec ct (snd rcPQRs) + pure (Just ss, k) skipMessageKeys :: Word32 -> Ratchet a -> Either CryptoError (Ratchet a, SkippedMsgKeys) skipMessageKeys _ r@Ratchet {rcRcv = Nothing} = Right (r, M.empty) skipMessageKeys untilN r@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr, rcHKr}, rcNr} @@ -457,18 +1073,21 @@ rcDecrypt g rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do e -> throwE e -- header = HDECRYPT(state.NHKr, enc_header) decryptNextHeader hdr = (AdvanceRatchet,) <$> decryptHeader (rcNHKr rc) hdr - decryptHeader k EncMessageHeader {ehBody, ehAuthTag, ehIV} = do + decryptHeader k EncMessageHeader {ehVersion, ehBody, ehAuthTag, ehIV} = do header <- decryptAEAD k ehIV rcAD ehBody ehAuthTag `catchE` \_ -> throwE CERatchetHeader - parseE' CryptoHeaderError smpP header + parseE' CryptoHeaderError (msgHeaderP ehVersion) header decryptMessage :: MessageKey -> EncRatchetMessage -> ExceptT CryptoError IO (Either CryptoError ByteString) decryptMessage (MessageKey mk iv) EncRatchetMessage {emHeader, emBody, emAuthTag} = -- DECRYPT(mk, cipher-text, CONCAT(AD, enc_header)) tryE $ decryptAEAD mk iv (rcAD <> emHeader) emBody emAuthTag -rootKdf :: (AlgorithmI a, DhAlgorithm a) => RatchetKey -> PublicKey a -> PrivateKey a -> (RatchetKey, RatchetKey, Key) -rootKdf (RatchetKey rk) k pk = - let dhOut = dhBytes' $ dh' k pk - (rk', ck, nhk) = hkdf3 rk dhOut "SimpleXRootRatchet" +rootKdf :: (AlgorithmI a, DhAlgorithm a) => RatchetKey -> PublicKey a -> PrivateKey a -> Maybe KEMSharedKey -> (RatchetKey, RatchetKey, Key) +rootKdf (RatchetKey rk) k pk kemSecret_ = + let dhOut = dhBytes' (dh' k pk) + ss = case kemSecret_ of + Just (KEMSharedKey s) -> dhOut <> BA.convert s + Nothing -> dhOut + (rk', ck, nhk) = hkdf3 rk ss "SimpleXRootRatchet" in (RatchetKey rk', RatchetKey ck, Key nhk) chainKdf :: RatchetKey -> (RatchetKey, Key, IV, IV) @@ -487,6 +1106,10 @@ hkdf3 salt ikm info = (s1, s2, s3) $(JQ.deriveJSON defaultJSON ''RcvRatchet) +$(JQ.deriveJSON defaultJSON ''RatchetKEMAccepted) + +$(JQ.deriveJSON defaultJSON ''RatchetKEM) + instance AlgorithmI a => ToJSON (SndRatchet a) where toEncoding = $(JQ.mkToEncoding defaultJSON ''SndRatchet) toJSON = $(JQ.mkToJSON defaultJSON ''SndRatchet) @@ -504,3 +1127,11 @@ instance AlgorithmI a => FromJSON (Ratchet a) where instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict' + +instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc + +instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f + +instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc + +instance FromField PQSupport where fromField f = PQSupport <$> fromField f diff --git a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs index 0940c53ba..3b2238086 100644 --- a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs +++ b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs @@ -19,16 +19,20 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String newtype KEMPublicKey = KEMPublicKey ByteString - deriving (Show) + deriving (Eq, Show) newtype KEMSecretKey = KEMSecretKey ScrubbedBytes - deriving (Show) + deriving (Eq, Show) newtype KEMCiphertext = KEMCiphertext ByteString - deriving (Show) + deriving (Eq, Show) newtype KEMSharedKey = KEMSharedKey ScrubbedBytes - deriving (Show) + deriving (Eq, Show) + +unsafeRevealKEMSharedKey :: KEMSharedKey -> String +unsafeRevealKEMSharedKey (KEMSharedKey scrubbed) = show (BA.convert scrubbed :: ByteString) +{-# DEPRECATED unsafeRevealKEMSharedKey "unsafeRevealKEMSharedKey left in code" #-} type KEMKeyPair = (KEMPublicKey, KEMSecretKey) @@ -60,6 +64,18 @@ sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) = KEMSharedKey <$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr) +instance Encoding KEMSecretKey where + smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c + smpP = KEMSecretKey . BA.convert . unLarge <$> smpP + +instance StrEncoding KEMSecretKey where + strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMSecretKey . BA.convert <$> strP @ByteString + +instance Encoding KEMPublicKey where + smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk + smpP = KEMPublicKey . BA.convert . unLarge <$> smpP + instance StrEncoding KEMPublicKey where strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString) strP = KEMPublicKey . BA.convert <$> strP @ByteString @@ -68,6 +84,25 @@ instance Encoding KEMCiphertext where smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c smpP = KEMCiphertext . BA.convert . unLarge <$> smpP +instance Encoding KEMSharedKey where + smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString) + smpP = KEMSharedKey . BA.convert <$> smpP @ByteString + +instance StrEncoding KEMCiphertext where + strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMCiphertext . BA.convert <$> strP @ByteString + +instance StrEncoding KEMSharedKey where + strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString) + strP = KEMSharedKey . BA.convert <$> strP @ByteString + +instance ToJSON KEMSecretKey where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMSecretKey where + parseJSON = strParseJSON "KEMSecretKey" + instance ToJSON KEMPublicKey where toJSON = strToJSON toEncoding = strToJEncoding @@ -75,8 +110,22 @@ instance ToJSON KEMPublicKey where instance FromJSON KEMPublicKey where parseJSON = strParseJSON "KEMPublicKey" +instance ToJSON KEMCiphertext where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMCiphertext where + parseJSON = strParseJSON "KEMCiphertext" + instance ToField KEMSharedKey where toField (KEMSharedKey k) = toField (BA.convert k :: ByteString) instance FromField KEMSharedKey where fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f + +instance ToJSON KEMSharedKey where + toJSON = strToJSON + toEncoding = strToJEncoding + +instance FromJSON KEMSharedKey where + parseJSON = strParseJSON "KEMSharedKey" diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index e81b0da89..fcefdc73d 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -179,6 +179,12 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin strP = (,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP {-# INLINE strP #-} +instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncoding e, StrEncoding f) => StrEncoding (a, b, c, d, e, f) where + strEncode (a, b, c, d, e, f) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d, strEncode e, strEncode f] + {-# INLINE strEncode #-} + strP = (,,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP + {-# INLINE strP #-} + strP_ :: StrEncoding a => Parser a strP_ = strP <* A.space diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index d69114b68..72a92c278 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -10,15 +10,15 @@ import Data.Word (Word16) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange) +import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange) import Simplex.Messaging.Protocol (ErrorType) import Simplex.Messaging.Util (bshow) -type NtfClient = ProtocolClient ErrorType NtfResponse +type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse type NtfClientError = ProtocolClientError ErrorType -defaultNTFClientConfig :: ProtocolClientConfig +defaultNTFClientConfig :: ProtocolClientConfig NTFVersion defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519) diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 73c2dada6..943c30c5a 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) +import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake) import Simplex.Messaging.Parsers (fromTextField_) import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) import Simplex.Messaging.Util (eitherToMaybe, (<$?>)) @@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where 'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP) _ -> fail "bad ANewNtfEntity" -instance Protocol ErrorType NtfResponse where +instance Protocol NTFVersion ErrorType NtfResponse where type ProtoCommand NtfResponse = NtfCmd type ProtoType NtfResponse = 'PNTF protocolClientHandshake = ntfClientHandshake @@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) deriving instance Show NtfCmd -instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where +instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol _v = \case TNEW newTkn -> e (TNEW_, ' ', newTkn) @@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag) - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials (auth, _, entityId, _) cmd = case cmd of @@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where | not (B.null entityId) = Left $ CMD HAS_AUTH | otherwise = Right cmd -instance ProtocolEncoding ErrorType NtfCmd where +instance ProtocolEncoding NTFVersion ErrorType NtfCmd where type Tag NtfCmd = NtfCmdTag encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c @@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where SDEL_ -> pure SDEL PING_ -> pure PING - fromProtocolError = fromProtocolError @ErrorType @NtfResponse + fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse {-# INLINE fromProtocolError #-} checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c @@ -290,7 +290,7 @@ data NtfResponse | NRPong deriving (Show) -instance ProtocolEncoding ErrorType NtfResponse where +instance ProtocolEncoding NTFVersion ErrorType NtfResponse where type Tag NtfResponse = NtfResponseTag encodeProtocol _v = \case NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey) diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 754aa6d62..7ae657fd1 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do old <- atomically $ stateTVar tknStatus (,status) when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status -runNtfClientTransport :: Transport c => THandle c -> M () +runNtfClientTransport :: Transport c => THandleNTF c -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config ts <- liftIO getSystemTime @@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do clientDisconnected :: NtfServerClient -> IO () clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False -receive :: Transport c => THandle c -> NtfServerClient -> M () +receive :: Transport c => THandleNTF c -> NtfServerClient -> M () receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do ts <- liftIO $ tGet th forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do @@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ where write q t = atomically $ writeTBQueue q t -send :: Transport c => THandle c -> NtfServerClient -> IO () +send :: Transport c => THandleNTF c -> NtfServerClient -> IO () send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do t <- atomically $ readTBQueue sndQ void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)] diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 9e3013a8d..0d722dcc3 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -24,6 +24,7 @@ import Numeric.Natural import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats import Simplex.Messaging.Notifications.Server.Store @@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport, THandleParams) import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams) -import Simplex.Messaging.Version (VersionRange) import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig logStatsStartTime :: Int64, serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, - ntfServerVRange :: VersionRange, + ntfServerVRange :: VersionRangeNTF, transportConfig :: TransportServerConfig } @@ -161,13 +161,13 @@ data NtfRequest data NtfServerClient = NtfServerClient { rcvQ :: TBQueue NtfRequest, sndQ :: TBQueue (Transmission NtfResponse), - ntfThParams :: THandleParams, + ntfThParams :: THandleParams NTFVersion, connected :: TVar Bool, rcvActiveAt :: TVar SystemTime, sndActiveAt :: TVar SystemTime } -newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient +newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient newNtfServerClient qSize ntfThParams ts = do rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index 00fd811a2..bc68fab03 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Notifications.Transport where @@ -12,33 +13,51 @@ import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Word (Word16) import qualified Data.X509 as X import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Transport import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Simplex.Messaging.Util (liftEitherWith) ntfBlockSize :: Int ntfBlockSize = 512 -authBatchCmdsNTFVersion :: Version -authBatchCmdsNTFVersion = 2 +data NTFVersion -currentClientNTFVersion :: Version -currentClientNTFVersion = 1 +instance VersionScope NTFVersion -currentServerNTFVersion :: Version -currentServerNTFVersion = 1 +type VersionNTF = Version NTFVersion -supportedClientNTFVRange :: VersionRange -supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion +type VersionRangeNTF = VersionRange NTFVersion -supportedServerNTFVRange :: VersionRange -supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion +pattern VersionNTF :: Word16 -> VersionNTF +pattern VersionNTF v = Version v + +initialNTFVersion :: VersionNTF +initialNTFVersion = VersionNTF 1 + +authBatchCmdsNTFVersion :: VersionNTF +authBatchCmdsNTFVersion = VersionNTF 2 + +currentClientNTFVersion :: VersionNTF +currentClientNTFVersion = VersionNTF 1 + +currentServerNTFVersion :: VersionNTF +currentServerNTFVersion = VersionNTF 1 + +supportedClientNTFVRange :: VersionRangeNTF +supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion + +supportedServerNTFVRange :: VersionRangeNTF +supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion + +type THandleNTF c = THandle NTFVersion c data NtfServerHandshake = NtfServerHandshake - { ntfVersionRange :: VersionRange, + { ntfVersionRange :: VersionRangeNTF, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.SignedExact X.PubKey) @@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake data NtfClientHandshake = NtfClientHandshake { -- | agreed SMP notifications server protocol version - ntfVersion :: Version, + ntfVersion :: VersionNTF, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey} -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing instance Encoding NtfClientHandshake where @@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where authPubKey <- ntfAuthPubKeyP ntfVersion pure NtfClientHandshake {ntfVersion, keyHash, authPubKey} -ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519) +ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519) ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing -encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString +encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString encodeNtfAuthPubKey v k | v >= authBatchCmdsNTFVersion = maybe "" smpEncode k | otherwise = "" -- | Notifcations server transport handshake. -ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c let sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do | otherwise -> throwError $ TEHandshake VERSION -- | Notifcations server client transport handshake. -ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c) ntfClientHandshake c (k, pk) keyHash ntfVRange = do let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th @@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do pure $ ntfThHandle th v pk sk_ Nothing -> throwError $ TEHandshake VERSION -ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c ntfThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ v3 = v >= authBatchCmdsNTFVersion params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3} - in (th :: THandle c) {params = params'} + in (th :: THandleNTF c) {params = params'} -ntfTHandle :: Transport c => c -> THandle c +ntfTHandle :: Transport c => c -> THandleNTF c ntfTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False} + params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 3a2fa241e..d583c0361 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,4 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -46,6 +47,10 @@ module Simplex.Messaging.Protocol e2eEncMessageLength, -- * SMP protocol types + SMPClientVersion, + VersionSMPC, + VersionRangeSMPC, + pattern VersionSMPC, ProtocolEncoding (..), Command (..), SubscriptionMode (..), @@ -117,6 +122,7 @@ module Simplex.Messaging.Protocol SMPMsgMeta (..), NMsgMeta (..), MsgFlags (..), + initialSMPClientVersion, userProtocol, rcvMessageMeta, noMsgFlags, @@ -152,6 +158,7 @@ module Simplex.Messaging.Protocol tEncodeBatch1, batchTransmissions, batchTransmissions', + batchTransmissions_, -- * exports for tests CommandTag (..), @@ -167,6 +174,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as J import Data.Attoparsec.ByteString.Char8 (Parser, ()) import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first) import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -180,6 +188,7 @@ import Data.Maybe (isJust, isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality +import Data.Word (Word16) import GHC.TypeLits (ErrorMessage (..), TypeError, type (+)) import Network.Socket (ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -191,19 +200,34 @@ import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..)) import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>)) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal -- SMP client protocol version history: -- 1 - binary protocol encoding (1/1/2022) -- 2 - multiple server hostnames and versioned queue addresses (8/12/2022) -srvHostnamesSMPClientVersion :: Version -srvHostnamesSMPClientVersion = 2 +data SMPClientVersion -currentSMPClientVersion :: Version -currentSMPClientVersion = 2 +instance VersionScope SMPClientVersion -supportedSMPClientVRange :: VersionRange -supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion +type VersionSMPC = Version SMPClientVersion + +type VersionRangeSMPC = VersionRange SMPClientVersion + +pattern VersionSMPC :: Word16 -> VersionSMPC +pattern VersionSMPC v = Version v + +initialSMPClientVersion :: VersionSMPC +initialSMPClientVersion = VersionSMPC 1 + +srvHostnamesSMPClientVersion :: VersionSMPC +srvHostnamesSMPClientVersion = VersionSMPC 2 + +currentSMPClientVersion :: VersionSMPC +currentSMPClientVersion = VersionSMPC 2 + +supportedSMPClientVRange :: VersionRangeSMPC +supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion maxMessageLength :: Int maxMessageLength = 16088 @@ -273,7 +297,7 @@ data RawTransmission = RawTransmission data TransmissionAuth = TASignature C.ASignature | TAAuthenticator C.CbAuthenticator - deriving (Eq, Show) + deriving (Show) -- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TAuthorization tAuthBytes :: Maybe TransmissionAuth -> ByteString @@ -339,8 +363,6 @@ data Command (p :: Party) where deriving instance Show (Command p) -deriving instance Eq (Command p) - data SubscriptionMode = SMSubscribe | SMOnlyCreate deriving (Eq, Show) @@ -645,7 +667,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope deriving (Show) data PubHeader = PubHeader - { phVersion :: Version, + { phVersion :: VersionSMPC, phE2ePubDhKey :: Maybe C.PublicKeyX25519 } deriving (Show) @@ -747,11 +769,11 @@ instance NFData (SProtocolType p) where rnf spt = spt `seq` () data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p) -deriving instance Show AProtocolType - instance Eq AProtocolType where AProtocolType p == AProtocolType p' = isJust $ testEquality p p' +deriving instance Show AProtocolType + instance TestEquality SProtocolType where testEquality SPSMP SPSMP = Just Refl testEquality SPNTF SPNTF = Just Refl @@ -1058,7 +1080,7 @@ data CommandError deriving (Eq, Read, Show) -- | SMP transmission parser. -transmissionP :: THandleParams -> Parser RawTransmission +transmissionP :: THandleParams v -> Parser RawTransmission transmissionP THandleParams {sessionId, implySessId} = do authenticator <- smpP authorized <- A.takeByteString @@ -1072,16 +1094,16 @@ transmissionP THandleParams {sessionId, implySessId} = do command <- A.takeByteString pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command} -class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where +class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where type ProtoCommand msg = cmd | cmd -> msg type ProtoType msg = (sch :: ProtocolType) | sch -> msg - protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) + protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c) protocolPing :: ProtoCommand msg protocolError :: msg -> Maybe err type ProtoServer msg = ProtocolServer (ProtoType msg) -instance Protocol ErrorType BrokerMsg where +instance Protocol SMPVersion ErrorType BrokerMsg where type ProtoCommand BrokerMsg = Cmd type ProtoType BrokerMsg = 'PSMP protocolClientHandshake = smpClientHandshake @@ -1090,14 +1112,14 @@ instance Protocol ErrorType BrokerMsg where ERR e -> Just e _ -> Nothing -class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where +class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where type Tag msg - encodeProtocol :: Version -> msg -> ByteString - protocolP :: Version -> Tag msg -> Parser msg + encodeProtocol :: Version v -> msg -> ByteString + protocolP :: Version v -> Tag msg -> Parser msg fromProtocolError :: ProtocolErrorType -> err checkCredentials :: SignedRawTransmission -> msg -> Either err msg -instance PartyI p => ProtocolEncoding ErrorType (Command p) where +instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where type Tag (Command p) = CommandTag p encodeProtocol v = \case NEW rKey dhKey auth_ subMode @@ -1124,7 +1146,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag) - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials (auth, _, queueId, _) cmd = case cmd of @@ -1146,7 +1168,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where | isNothing auth || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance ProtocolEncoding ErrorType Cmd where +instance ProtocolEncoding SMPVersion ErrorType Cmd where type Tag Cmd = CmdTag encodeProtocol v (Cmd _ c) = encodeProtocol v c @@ -1174,12 +1196,12 @@ instance ProtocolEncoding ErrorType Cmd where PING_ -> pure PING CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB - fromProtocolError = fromProtocolError @ErrorType @BrokerMsg + fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg {-# INLINE fromProtocolError #-} checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c -instance ProtocolEncoding ErrorType BrokerMsg where +instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol _v = \case IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh) @@ -1231,12 +1253,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where | otherwise -> Right cmd -- | Parse SMP protocol commands and broker messages -parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg +parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg parseProtocol v s = let (tag, params) = B.break (== ' ') s in case decodeTag tag of - Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params - Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown + Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params + Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p) checkParty c = case testEquality (sParty @p) (sParty @p') of @@ -1291,7 +1313,7 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] +tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()] tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params) where tPutBatch :: TransportBatch () -> IO [Either TransportError ()] @@ -1300,7 +1322,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba TBTransmissions s n _ -> replicate n <$> tPutLog th s TBTransmission s _ -> (: []) <$> tPutLog th s -tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutLog th s = do r <- tPutBlock th s case r of @@ -1314,11 +1336,11 @@ data TransportBatch r = TBTransmissions ByteString Int [r] | TBTransmission Byte batchTransmissions :: Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()] batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,()) --- | encodes and batches transmissions into blocks, +-- | encodes and batches transmissions into blocks batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r] -batchTransmissions' batch bSize - | batch = addBatch . foldr addTransmission ([], 0, 0, [], []) - | otherwise = map mkBatch1 . L.toList +batchTransmissions' batch bSize ts + | batch = batchTransmissions_ bSize $ L.map (first $ fmap tEncodeForBatch) ts + | otherwise = map mkBatch1 $ L.toList ts where mkBatch1 :: (Either TransportError SentRawTransmission, r) -> TransportBatch r mkBatch1 (t_, r) = case t_ of @@ -1329,17 +1351,21 @@ batchTransmissions' batch bSize | otherwise -> TBError TELargeMsg r where s = tEncode t + +-- | Pack encoded transmissions into batches +batchTransmissions_ :: Int -> NonEmpty (Either TransportError ByteString, r) -> [TransportBatch r] +batchTransmissions_ bSize = addBatch . foldr addTransmission ([], 0, 0, [], []) + where -- 3 = 2 bytes reserved for pad size + 1 for transmission count bSize' = bSize - 3 - addTransmission :: (Either TransportError SentRawTransmission, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r]) - addTransmission (t_, r) acc@(bs, len, n, ss, rs) = case t_ of + addTransmission :: (Either TransportError ByteString, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r]) + addTransmission (t_, r) acc@(bs, !len, !n, ss, rs) = case t_ of Left e -> (TBError e r : addBatch acc, 0, 0, [], []) - Right t + Right s | len' <= bSize' && n < 255 -> (bs, len', 1 + n, s : ss, r : rs) | sLen <= bSize' -> (addBatch acc, sLen, 1, [s], [r]) | otherwise -> (TBError TELargeMsg r : addBatch acc, 0, 0, [], []) where - s = tEncodeForBatch t sLen = B.length s len' = len + sLen addBatch :: ([TransportBatch r], Int, Int, [ByteString], [r]) -> [TransportBatch r] @@ -1362,7 +1388,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t -- tForAuth is lazy to avoid computing it when there is no key to sign data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString} -encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth +encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t = TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth} where @@ -1370,24 +1396,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t' = encodeTransmission_ v t {-# INLINE encodeTransmissionForAuth #-} -encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString +encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t = if implySessId then t' else smpEncode sessionId <> t' where t' = encodeTransmission_ v t {-# INLINE encodeTransmission #-} -encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString +encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString encodeTransmission_ v (CorrId corrId, queueId, command) = smpEncode (corrId, queueId) <> encodeProtocol v command {-# INLINE encodeTransmission_ #-} -- | Receive and parse transmission from the TCP transport (ignoring any trailing padding). -tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission)) +tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission)) tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th {-# INLINE tGetParse #-} -tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission) +tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission) tParse thParams@THandleParams {batch} s | batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts | otherwise = [tParse1 s] @@ -1399,24 +1425,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b eitherList = either (\e -> [Left e]) -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd)) +tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd)) tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th -tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd +tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command} | implySessId || sessId == sessionId -> let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission - | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession)) + | otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession)) Left _ -> tError "" where tError :: ByteString -> SignedTransmission err cmd - tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock)) + tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock)) tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd tParseValidate signed t@(sig, corrId, entityId, command) = - let cmd = parseProtocol @err @cmd v command >>= checkCredentials t + let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t in (sig, signed, (CorrId corrId, entityId, cmd)) $(J.deriveJSON defaultJSON ''MsgFlags) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index aaa42d91b..0dcde0350 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do CPQuit -> pure () CPSkip -> pure () -runClientTransport :: Transport c => THandle c -> M () +runClientTransport :: Transport c => THandleSMP c -> M () runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do q <- asks $ tbqSize . config ts <- liftIO getSystemTime @@ -428,7 +428,7 @@ cancelSub sub = Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread _ -> return () -receive :: Transport c => THandle c -> Client -> M () +receive :: Transport c => THandleSMP c -> Client -> M () receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive" forever $ do @@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi VRFailed -> Left (corrId, queueId, ERR AUTH) write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty -send :: Transport c => THandle c -> Client -> IO () +send :: Transport c => THandleSMP c -> Client -> IO () send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send" forever $ do @@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do NMSG {} -> 0 _ -> 1 -disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () +disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO () disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport" loop diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 82666a0fc..7b9fef0b3 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Transport (ATransport) +import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP) import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState) -import Simplex.Messaging.Version import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -73,7 +72,7 @@ data ServerConfig = ServerConfig privateKeyFile :: FilePath, certificateFile :: FilePath, -- | SMP client-server protocol version range - smpServerVRange :: VersionRange, + smpServerVRange :: VersionRangeSMP, -- | TCP transport config transportConfig :: TransportServerConfig, -- | run listener on control port @@ -128,7 +127,7 @@ data Client = Client sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), endThreads :: TVar (IntMap (Weak ThreadId)), endThreadSeq :: TVar Int, - thVersion :: Version, + thVersion :: VersionSMP, sessionId :: ByteString, connected :: TVar Bool, createdAt :: SystemTime, @@ -152,7 +151,7 @@ newServer = do savingLock <- createLock return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock} -newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client +newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client newClient nextClientId qSize thVersion sessionId createdAt = do clientId <- stateTVar nextClientId $ \next -> (next, next + 1) subscriptions <- TM.empty diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index 56ce9b679..cd1b94215 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -17,14 +17,14 @@ data QueueRec = QueueRec notifier :: !(Maybe NtfCreds), status :: !ServerQueueStatus } - deriving (Eq, Show) + deriving (Show) data NtfCreds = NtfCreds { notifierId :: !NotifierId, notifierKey :: !NtfPublicAuthKey, rcvNtfDhSecret :: !RcvNtfDhSecret } - deriving (Eq, Show) + deriving (Show) instance StrEncoding NtfCreds where strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret) diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 0d9552f9b..775400260 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -9,6 +9,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -27,10 +28,15 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a module Simplex.Messaging.Transport ( -- * SMP transport parameters + SMPVersion, + VersionSMP, + VersionRangeSMP, + THandleSMP, supportedClientSMPRelayVRange, supportedServerSMPRelayVRange, currentClientSMPRelayVersion, currentServerSMPRelayVersion, + batchCmdsSMPVersion, basicAuthSMPVersion, subModeSMPVersion, authCmdsSMPVersion, @@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB import Data.Default (def) import Data.Functor (($>)) import Data.Version (showVersion) +import Data.Word (Word16) import qualified Data.X509 as X import qualified Data.X509.Validation as XV import GHC.IO.Handle.Internals (ioe_EOF) @@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON) import Simplex.Messaging.Transport.Buffer import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import UnliftIO.Exception (Exception) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -116,30 +124,41 @@ smpBlockSize = 16384 -- 6 - allow creating queues without subscribing (9/10/2023) -- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024) -batchCmdsSMPVersion :: Version -batchCmdsSMPVersion = 4 +data SMPVersion -basicAuthSMPVersion :: Version -basicAuthSMPVersion = 5 +instance VersionScope SMPVersion -subModeSMPVersion :: Version -subModeSMPVersion = 6 +type VersionSMP = Version SMPVersion -authCmdsSMPVersion :: Version -authCmdsSMPVersion = 7 +type VersionRangeSMP = VersionRange SMPVersion -currentClientSMPRelayVersion :: Version -currentClientSMPRelayVersion = 6 +pattern VersionSMP :: Word16 -> VersionSMP +pattern VersionSMP v = Version v -currentServerSMPRelayVersion :: Version -currentServerSMPRelayVersion = 6 +batchCmdsSMPVersion :: VersionSMP +batchCmdsSMPVersion = VersionSMP 4 + +basicAuthSMPVersion :: VersionSMP +basicAuthSMPVersion = VersionSMP 5 + +subModeSMPVersion :: VersionSMP +subModeSMPVersion = VersionSMP 6 + +authCmdsSMPVersion :: VersionSMP +authCmdsSMPVersion = VersionSMP 7 + +currentClientSMPRelayVersion :: VersionSMP +currentClientSMPRelayVersion = VersionSMP 6 + +currentServerSMPRelayVersion :: VersionSMP +currentServerSMPRelayVersion = VersionSMP 6 -- minimal supported protocol version is 4 -- TODO remove code that supports sending commands without batching -supportedClientSMPRelayVRange :: VersionRange +supportedClientSMPRelayVRange :: VersionRangeSMP supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion -supportedServerSMPRelayVRange :: VersionRange +supportedServerSMPRelayVRange :: VersionRangeSMP supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion simplexMQVersion :: String @@ -287,16 +306,18 @@ instance Transport TLS where -- * SMP transport -- | The handle for SMP encrypted transport connection over Transport. -data THandle c = THandle +data THandle v c = THandle { connection :: c, - params :: THandleParams + params :: THandleParams v } -data THandleParams = THandleParams +type THandleSMP c = THandle SMPVersion c + +data THandleParams v = THandleParams { sessionId :: SessionId, blockSize :: Int, -- | agreed server protocol version - thVersion :: Version, + thVersion :: Version v, -- | peer public key for command authorization and shared secrets for entity ID encryption thAuth :: Maybe THandleAuth, -- | do NOT send session ID in transmission, but include it into signed message @@ -316,7 +337,7 @@ data THandleAuth = THandleAuth type SessionId = ByteString data ServerHandshake = ServerHandshake - { smpVersionRange :: VersionRange, + { smpVersionRange :: VersionRangeSMP, sessionId :: SessionId, -- pub key to agree shared secrets for command authorization and entity ID encryption. authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey) @@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake data ClientHandshake = ClientHandshake { -- | agreed SMP server protocol version - smpVersion :: Version, + smpVersion :: VersionSMP, -- | server identity - CA certificate fingerprint keyHash :: C.KeyHash, -- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys. @@ -358,12 +379,12 @@ instance Encoding ServerHandshake where C.SignedObject key <- smpP pure (cert, key) -encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString +encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString encodeAuthEncryptCmds v k | v >= authCmdsSMPVersion = maybe "" smpEncode k | otherwise = "" -authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a) +authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a) authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing -- | Error of SMP encrypted transport over TCP. @@ -412,13 +433,13 @@ serializeTransportError = \case TEHandshake e -> "HANDSHAKE " <> bshow e -- | Pad and send block to SMP transport. -tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ()) +tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ()) tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block = bimapM (const $ pure TELargeMsg) (cPut c) $ C.pad block blockSize -- | Receive block from SMP transport. -tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString) +tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString) tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do msg <- cGet c blockSize if B.length msg == blockSize @@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do -- | Server SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c sk = C.signX509 serverSignKey $ C.publicToX509 k @@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do -- | Client SMP transport handshake. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c) +smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c) smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do let th@THandle {params = THandleParams {sessionId}} = smpTHandle c ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th @@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do pure $ smpThHandle th v pk sk_ Nothing -> throwE $ TEHandshake VERSION -smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c +smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c smpThHandle th@THandle {params} v privKey k_ = -- TODO drop SMP v6: make thAuth non-optional let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_ params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion} - in (th :: THandle c) {params = params'} + in (th :: THandleSMP c) {params = params'} -sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () +sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO () sendHandshake th = ExceptT . tPutBlock th . smpEncode -- ignores tail bytes to allow future extensions -getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp +getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th -smpTHandle :: Transport c => c -> THandle c +smpTHandle :: Transport c => c -> THandleSMP c smpTHandle c = THandle {connection = c, params} where - params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True} + params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True} $(J.deriveJSON (sumTypeJSON id) ''HandshakeError) diff --git a/src/Simplex/Messaging/Version.hs b/src/Simplex/Messaging/Version.hs index dc8cfff68..25f7368d1 100644 --- a/src/Simplex/Messaging/Version.hs +++ b/src/Simplex/Messaging/Version.hs @@ -1,13 +1,16 @@ {-# LANGUAGE ConstrainedClassMethods #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeSynonymInstances #-} module Simplex.Messaging.Version ( Version, VersionRange (minVersion, maxVersion), + VersionScope, pattern VersionRange, VersionI (..), VersionRangeI (..), @@ -24,47 +27,61 @@ module Simplex.Messaging.Version where import Control.Applicative (optional) -import Data.Aeson (FromJSON (..), ToJSON (..)) +import qualified Data.Aeson as J +import qualified Data.Aeson.Encoding as JE +import Data.Aeson.Types ((.:), (.=)) +import qualified Data.Aeson.Types as JT import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Word (Word16) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Util ((<$?>)) +import Simplex.Messaging.Version.Internal (Version (..)) -pattern VersionRange :: Word16 -> Word16 -> VersionRange +pattern VersionRange :: Version v -> Version v -> VersionRange v pattern VersionRange v1 v2 <- VRange v1 v2 {-# COMPLETE VersionRange #-} -type Version = Word16 - -data VersionRange = VRange - { minVersion :: Version, - maxVersion :: Version +data VersionRange v = VRange + { minVersion :: Version v, + maxVersion :: Version v } deriving (Eq, Show) +instance J.FromJSON (VersionRange v) where + parseJSON (J.Object v) = do + minVersion <- v .: "minVersion" + maxVersion <- v .: "maxVersion" + pure VRange {minVersion, maxVersion} + parseJSON invalid = + JT.prependFailure "bad VersionRange, " (JT.typeMismatch "Object" invalid) + +instance J.ToJSON (VersionRange v) where + toEncoding VRange {minVersion, maxVersion} = JE.pairs $ ("minVersion" .= minVersion) <> ("maxVersion" .= maxVersion) + toJSON VRange {minVersion, maxVersion} = J.object ["minVersion" .= minVersion, "maxVersion" .= maxVersion] + +class VersionScope v + -- | construct valid version range, to be used in constants -mkVersionRange :: Version -> Version -> VersionRange +mkVersionRange :: Version v -> Version v -> VersionRange v mkVersionRange v1 v2 | v1 <= v2 = VRange v1 v2 | otherwise = error "invalid version range" -safeVersionRange :: Version -> Version -> Maybe VersionRange +safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v) safeVersionRange v1 v2 | v1 <= v2 = Just $ VRange v1 v2 | otherwise = Nothing -versionToRange :: Version -> VersionRange +versionToRange :: Version v -> VersionRange v versionToRange v = VRange v v -instance Encoding VersionRange where +instance VersionScope v => Encoding (VersionRange v) where smpEncode (VRange v1 v2) = smpEncode (v1, v2) smpP = maybe (fail "invalid version range") pure =<< safeVersionRange <$> smpP <*> smpP -instance StrEncoding VersionRange where +instance VersionScope v => StrEncoding (VersionRange v) where strEncode (VRange v1 v2) | v1 == v2 = strEncode v1 | otherwise = strEncode v1 <> "-" <> strEncode v2 @@ -73,32 +90,23 @@ instance StrEncoding VersionRange where v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-') maybe (fail "invalid version range") pure $ safeVersionRange v1 v2 -instance ToJSON VersionRange where - toJSON (VRange v1 v2) = toJSON (v1, v2) - toEncoding (VRange v1 v2) = toEncoding (v1, v2) +class VersionScope v => VersionI v a | a -> v where + type VersionRangeT v a + version :: a -> Version v + toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a -instance FromJSON VersionRange where - parseJSON v = - (\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2) - <$?> parseJSON v +class VersionScope v => VersionRangeI v a | a -> v where + type VersionT v a + versionRange :: a -> VersionRange v + toVersionT :: a -> Version v -> VersionT v a -class VersionI a where - type VersionRangeT a - version :: a -> Version - toVersionRangeT :: a -> VersionRange -> VersionRangeT a - -class VersionRangeI a where - type VersionT a - versionRange :: a -> VersionRange - toVersionT :: a -> Version -> VersionT a - -instance VersionI Version where - type VersionRangeT Version = VersionRange +instance VersionScope v => VersionI v (Version v) where + type VersionRangeT v (Version v) = VersionRange v version = id toVersionRangeT _ vr = vr -instance VersionRangeI VersionRange where - type VersionT VersionRange = Version +instance VersionScope v => VersionRangeI v (VersionRange v) where + type VersionT v (VersionRange v) = Version v versionRange = id toVersionT _ v = v @@ -109,18 +117,18 @@ pattern Compatible a <- Compatible_ a {-# COMPLETE Compatible #-} -isCompatible :: VersionI a => a -> VersionRange -> Bool +isCompatible :: VersionI v a => a -> VersionRange v -> Bool isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2 -isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool +isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1 where VRange min1 max1 = versionRange x -proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a) +proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a) proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr) -compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a)) +compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a)) compatibleVersion x vr = toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr where diff --git a/src/Simplex/Messaging/Version/Internal.hs b/src/Simplex/Messaging/Version/Internal.hs new file mode 100644 index 000000000..23cab1d1f --- /dev/null +++ b/src/Simplex/Messaging/Version/Internal.hs @@ -0,0 +1,25 @@ +module Simplex.Messaging.Version.Internal where + +import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Word (Word16) +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String + +-- Do not use constructor of this type directry +newtype Version v = Version Word16 + deriving (Eq, Ord, Show) + +instance Encoding (Version v) where + smpEncode (Version v) = smpEncode v + smpP = Version <$> smpP + +instance StrEncoding (Version v) where + strEncode (Version v) = strEncode v + strP = Version <$> strP + +instance ToJSON (Version v) where + toEncoding (Version v) = toEncoding v + toJSON (Version v) = toJSON v + +instance FromJSON (Version v) where + parseJSON v = Version <$> parseJSON v diff --git a/src/Simplex/RemoteControl/Client.hs b/src/Simplex/RemoteControl/Client.hs index c73679439..3cf1050fa 100644 --- a/src/Simplex/RemoteControl/Client.hs +++ b/src/Simplex/RemoteControl/Client.hs @@ -68,12 +68,6 @@ import Simplex.RemoteControl.Types import UnliftIO import UnliftIO.Concurrent -currentRCVersion :: Version -currentRCVersion = 1 - -supportedRCVRange :: VersionRange -supportedRCVRange = mkVersionRange 1 currentRCVersion - xrcpBlockSize :: Int xrcpBlockSize = 16384 @@ -181,7 +175,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct { ca = certFingerprint caCert, host, port = fromIntegral portNum, - v = supportedRCVRange, + v = supportedRCPVRange, app = ctrlAppInfo, ts, skey = fst sessKeys, @@ -220,7 +214,7 @@ prepareHostSession unless (ca == tlsHostFingerprint) $ throwError RCEIdentity (kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey - unless (isCompatible v supportedRCVRange) $ throwError RCEVersion + unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey} knownHost' <- updateKnownHost ca dhPubKey let ctrlHello = RCCtrlHello {} @@ -334,7 +328,7 @@ prepareHostHello RCInvitation {v, dh = dhPubKey} hostAppInfo = do logDebug "Preparing session" - case compatibleVersion v supportedRCVRange of + case compatibleVersion v supportedRCPVRange of Nothing -> throwError RCEVersion Just (Compatible v') -> do nonce <- liftIO . atomically $ C.randomCbNonce drg diff --git a/src/Simplex/RemoteControl/Invitation.hs b/src/Simplex/RemoteControl/Invitation.hs index f5deac9a8..712c41a9d 100644 --- a/src/Simplex/RemoteControl/Invitation.hs +++ b/src/Simplex/RemoteControl/Invitation.hs @@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Version (VersionRange) +import Simplex.RemoteControl.Types (VersionRangeRCP) data RCInvitation = RCInvitation { -- | CA TLS certificate fingerprint of the controller. @@ -37,7 +37,7 @@ data RCInvitation = RCInvitation host :: TransportHost, port :: Word16, -- | Supported version range for remote control protocol - v :: VersionRange, + v :: VersionRangeRCP, -- | Application information app :: J.Value, -- | Session start time in seconds since epoch diff --git a/src/Simplex/RemoteControl/Types.hs b/src/Simplex/RemoteControl/Types.hs index e1598f25c..b8a7c1141 100644 --- a/src/Simplex/RemoteControl/Types.hs +++ b/src/Simplex/RemoteControl/Types.hs @@ -5,6 +5,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} @@ -17,6 +18,7 @@ import Data.ByteString (ByteString) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) +import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.SNTRUP761 import Simplex.Messaging.Crypto.SNTRUP761.Bindings @@ -26,7 +28,8 @@ import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport (TLS) import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util (safeDecodeUtf8) -import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange) +import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange) +import Simplex.Messaging.Version.Internal import UnliftIO data RCErrorType @@ -92,24 +95,37 @@ instance StrEncoding RCErrorType where -- * Discovery -ipProbeVersionRange :: VersionRange -ipProbeVersionRange = mkVersionRange 1 1 +data RCPVersion + +instance VersionScope RCPVersion + +type VersionRCP = Version RCPVersion + +type VersionRangeRCP = VersionRange RCPVersion + +pattern VersionRCP :: Word16 -> VersionRCP +pattern VersionRCP v = Version v + +currentRCPVersion :: VersionRCP +currentRCPVersion = VersionRCP 1 + +supportedRCPVRange :: VersionRangeRCP +supportedRCPVRange = mkVersionRange (VersionRCP 1) currentRCPVersion data IpProbe = IpProbe - { versionRange :: VersionRange, + { versionRange :: VersionRangeRCP, randomNonce :: ByteString } deriving (Show) instance Encoding IpProbe where smpEncode IpProbe {versionRange, randomNonce} = smpEncode (versionRange, 'I', randomNonce) - smpP = IpProbe <$> (smpP <* "I") *> smpP -- * Session data RCHostHello = RCHostHello - { v :: Version, + { v :: VersionRCP, ca :: C.KeyHash, app :: J.Value, kem :: KEMPublicKey diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index f0078ae24..34719e803 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PostfixOperators #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} @@ -12,12 +13,12 @@ module AgentTests (agentTests) where import AgentTests.ConnectionRequestTests import AgentTests.DoubleRatchetTests (doubleRatchetTests) -import AgentTests.FunctionalAPITests (functionalAPITests) +import AgentTests.FunctionalAPITests (functionalAPITests, pattern Msg, pattern Msg') import AgentTests.MigrationTests (migrationTests) import AgentTests.NotificationTests (notificationTests) import AgentTests.SQLiteTests (storeTests) import Control.Concurrent -import Control.Monad (forM_) +import Control.Monad (forM_, when) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Maybe (fromJust) @@ -26,15 +27,18 @@ import GHC.Stack (withFrozenCallStack) import Network.HTTP.Types (urlEncode) import SMPAgentClient import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.Protocol hiding (MID, CONF, INFO, REQ) import qualified Simplex.Messaging.Agent.Protocol as A +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) +import Simplex.Messaging.Protocol (ErrorType (..)) import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..)) import Simplex.Messaging.Util (bshow) import System.Directory (removeFile) import System.Timeout import Test.Hspec +import Util agentTests :: ATransport -> Spec agentTests (ATransport t) = do @@ -46,24 +50,25 @@ agentTests (ATransport t) = do describe "Migration tests" migrationTests describe "SMP agent protocol syntax" $ syntaxTests t describe "Establishing duplex connection (via agent protocol)" $ do - -- These tests are disabled because the agent does not work correctly with multiple connected TCP clients - xit "should connect via one server and one agent" $ do - smpAgentTest2_1_1 $ testDuplexConnection t - xit "should connect via one server and one agent (random IDs)" $ do - smpAgentTest2_1_1 $ testDuplexConnRandomIds t + skip "These tests are disabled because the agent does not work correctly with multiple connected TCP clients" $ + describe "one agent" $ do + it "should connect via one server and one agent" $ do + smpAgentTest2_1_1 $ testDuplexConnection t + it "should connect via one server and one agent (random IDs)" $ do + smpAgentTest2_1_1 $ testDuplexConnRandomIds t it "should connect via one server and 2 agents" $ do smpAgentTest2_2_1 $ testDuplexConnection t it "should connect via one server and 2 agents (random IDs)" $ do smpAgentTest2_2_1 $ testDuplexConnRandomIds t - it "should connect via 2 servers and 2 agents" $ do - smpAgentTest2_2_2 $ testDuplexConnection t - it "should connect via 2 servers and 2 agents (random IDs)" $ do - smpAgentTest2_2_2 $ testDuplexConnRandomIds t + describe "should connect via 2 servers and 2 agents" $ do + pqMatrix2 t smpAgentTest2_2_2 testDuplexConnection' + describe "should connect via 2 servers and 2 agents (random IDs)" $ do + pqMatrix2 t smpAgentTest2_2_2 testDuplexConnRandomIds' describe "Establishing connections via `contact connection`" $ do - it "should connect via contact connection with one server and 3 agents" $ do - smpAgentTest3 $ testContactConnection t - it "should connect via contact connection with one server and 2 agents (random IDs)" $ do - smpAgentTest2_2_1 $ testContactConnRandomIds t + describe "should connect via contact connection with one server and 3 agents" $ do + pqMatrix3 t smpAgentTest3 testContactConnection + describe "should connect via contact connection with one server and 2 agents (random IDs)" $ do + pqMatrix2NoInv t smpAgentTest2_2_1 testContactConnRandomIds it "should support rejecting contact request" $ do smpAgentTest2_2_1 $ testRejectContactRequest t describe "Connection subscriptions" $ do @@ -72,8 +77,8 @@ agentTests (ATransport t) = do it "should send notifications to client when server disconnects" $ do smpAgentServerTest $ testSubscrNotification t describe "Message delivery and server reconnection" $ do - it "should deliver messages after losing server connection and re-connecting" $ do - smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t + describe "should deliver messages after losing server connection and re-connecting" $ + pqMatrix2 t smpAgentTest2_2_2_needs_server testMsgDeliveryServerRestart it "should connect to the server when server goes up if it initially was down" $ do smpAgentTestN [] $ testServerConnectionAfterError t it "should deliver pending messages after agent restarting" $ do @@ -133,6 +138,9 @@ action #> (corrId, connId, cmd) = withFrozenCallStack $ action `shouldReturn` (c (=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation action =#> p = withFrozenCallStack $ action >>= (`shouldSatisfy` p . correctTransmission) +pattern MID :: AgentMsgId -> ACommand 'Agent 'AEConn +pattern MID msgId = A.MID msgId PQEncOn + correctTransmission :: (ACorrId, ConnId, Either AgentErrorType cmd) -> (ACorrId, ConnId, cmd) correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of Right cmd -> (corrId, connId, cmd) @@ -161,130 +169,188 @@ h #:# err = tryGet `shouldReturn` () Just _ -> error err _ -> return () -pattern Msg :: MsgBody -> ACommand 'Agent e -pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody +type PQMatrix2 c = + HasCallStack => + TProxy c -> + (HasCallStack => (c -> c -> IO ()) -> Expectation) -> + (HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> IO ()) -> + Spec -pattern Msg' :: AgentMsgId -> MsgBody -> ACommand 'Agent e -pattern Msg' aMsgId msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} _ msgBody +pqMatrix2 :: PQMatrix2 c +pqMatrix2 = pqMatrix2_ True + +pqMatrix2NoInv :: PQMatrix2 c +pqMatrix2NoInv = pqMatrix2_ False + +pqMatrix2_ :: Bool -> PQMatrix2 c +pqMatrix2_ pqInv _ smpTest test = do + it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff) + it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn) + it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff) + it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn) + when pqInv $ do + it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff) + it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn) + +pqMatrix3 :: + HasCallStack => + TProxy c -> + (HasCallStack => (c -> c -> c -> IO ()) -> Expectation) -> + (HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()) -> + Spec +pqMatrix3 _ smpTest test = do + it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff) + it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn) + it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff) + it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn) + it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff) + it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn) + it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff) + it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn) testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () -testDuplexConnection _ alice bob = do - ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV subscribe") +testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQSupportOn) + +testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () +testDuplexConnection' (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq + ("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) - ("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:) + bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) + ("", "bob", Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` pqSup alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) - bob <# ("", "alice", INFO "alice's connInfo") - bob <# ("", "alice", CON) - alice <# ("", "bob", CON) + bob <# ("", "alice", A.INFO pqSup "alice's connInfo") + bob <# ("", "alice", CON pq) + alice <# ("", "bob", CON pq) -- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4 - alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", MID 4) + alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", A.MID 4 pq) alice <# ("", "bob", SENT 4) - bob <#= \case ("", "alice", Msg' 4 "hello") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 4 pq' "hello") -> pq == pq'; _ -> False bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK) - alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", MID 5) + alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", A.MID 5 pq) alice <# ("", "bob", SENT 5) - bob <#= \case ("", "alice", Msg' 5 "how are you?") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 5 pq' "how are you?") -> pq == pq'; _ -> False bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK) - bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 6) + bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", A.MID 6 pq) bob <# ("", "alice", SENT 6) - alice <#= \case ("", "bob", Msg' 6 "hello too") -> True; _ -> False + alice <#= \case ("", "bob", Msg' 6 pq' "hello too") -> pq == pq'; _ -> False alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK) - bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", MID 7) + bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", A.MID 7 pq) bob <# ("", "alice", SENT 7) - alice <#= \case ("", "bob", Msg' 7 "message 1") -> True; _ -> False + alice <#= \case ("", "bob", Msg' 7 pq' "message 1") -> pq == pq'; _ -> False alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK) alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) - bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", MID 8) + bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq) bob <# ("", "alice", MERR 8 (SMP AUTH)) alice #: ("6", "bob", "DEL") #> ("6", "bob", OK) alice #:# "nothing else should be delivered to alice" -testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () -testDuplexConnRandomIds _ alice bob = do - ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV subscribe") +testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO () +testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQSupportOn) + +testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO () +testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq + ("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") - ("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:) + ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") + ("", bobConn', Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` pqSup bobConn' `shouldBe` bobConn alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False - bob <# ("", aliceConn, INFO "alice's connInfo") - bob <# ("", aliceConn, CON) - alice <# ("", bobConn, CON) - alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, MID 4) + bob <# ("", aliceConn, A.INFO pqSup "alice's connInfo") + bob <# ("", aliceConn, CON pq) + alice <# ("", bobConn, CON pq) + alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq) alice <# ("", bobConn, SENT 4) - bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 4 pq' "hello") -> c == aliceConn && pq == pq'; _ -> False bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK) - alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, MID 5) + alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, A.MID 5 pq) alice <# ("", bobConn, SENT 5) - bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 5 pq' "how are you?") -> c == aliceConn && pq == pq'; _ -> False bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK) - bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, MID 6) + bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, A.MID 6 pq) bob <# ("", aliceConn, SENT 6) - alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False + alice <#= \case ("", c, Msg' 6 pq' "hello too") -> c == bobConn && pq == pq'; _ -> False alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK) - bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, MID 7) + bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, A.MID 7 pq) bob <# ("", aliceConn, SENT 7) - alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False + alice <#= \case ("", c, Msg' 7 pq' "message 1") -> c == bobConn && pq == pq'; _ -> False alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK) alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK) - bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, MID 8) + bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq) bob <# ("", aliceConn, MERR 8 (SMP AUTH)) alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) alice #:# "nothing else should be delivered to alice" -testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO () -testContactConnection _ alice bob tom = do - ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON subscribe") +testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO () +testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do + ("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq + abPQ = pqConnectionMode aPQ bPQ + abPQSup = CR.pqEncToSupport abPQ + aPQMode = CR.connPQEncryption aPQ - bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) - ("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) - alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK) - ("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) + bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK) + ("", "alice_contact", Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` bPQ + alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK) + ("", "alice", Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:) + pqSup'' `shouldBe` abPQSup bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK) - alice <# ("", "bob", INFO "bob's connInfo 2") - alice <# ("", "bob", CON) - bob <# ("", "alice", CON) - alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", MID 4) + alice <# ("", "bob", A.INFO abPQSup "bob's connInfo 2") + alice <# ("", "bob", CON abPQ) + bob <# ("", "alice", CON abPQ) + alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ) alice <# ("", "bob", SENT 4) - bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False + bob <#= \case ("", "alice", Msg' 4 pq' "hi") -> pq' == abPQ; _ -> False bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK) - tom #: ("21", "alice", "JOIN T " <> cReq' <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK) - ("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:) - alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK) - ("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:) + let atPQ = pqConnectionMode aPQ tPQ + atPQSup = CR.pqEncToSupport atPQ + tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK) + ("", "alice_contact", Right (A.REQ aInvId' pqSup3 _ "tom's connInfo")) <- (alice <#:) + pqSup3 `shouldBe` tPQ + alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK) + ("", "alice", Right (A.CONF tConfId pqSup4 _ "alice's connInfo")) <- (tom <#:) + pqSup4 `shouldBe` atPQSup tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK) - alice <# ("", "tom", INFO "tom's connInfo 2") - alice <# ("", "tom", CON) - tom <# ("", "alice", CON) - alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", MID 4) + alice <# ("", "tom", A.INFO atPQSup "tom's connInfo 2") + alice <# ("", "tom", CON atPQ) + tom <# ("", "alice", CON atPQ) + alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ) alice <# ("", "tom", SENT 4) - tom <#= \case ("", "alice", Msg "hi there") -> True; _ -> False + tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK) -testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () -testContactConnRandomIds _ alice bob = do - ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON subscribe") +testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () +testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ + pqSup = CR.pqEncToSupport pq + ("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe") let cReq' = strEncode cReq - ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") - ("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:) + ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") + ("", aliceContact', Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:) + pqSup' `shouldBe` bPQ aliceContact' `shouldBe` aliceContact - ("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> " 16\nalice's connInfo") - ("", aliceConn', Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:) + ("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo") + ("", aliceConn', Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:) + pqSup'' `shouldBe` pqSup aliceConn' `shouldBe` aliceConn bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK) - alice <# ("", bobConn, INFO "bob's connInfo 2") - alice <# ("", bobConn, CON) - bob <# ("", aliceConn, CON) + alice <# ("", bobConn, A.INFO pqSup "bob's connInfo 2") + alice <# ("", bobConn, CON pq) + bob <# ("", aliceConn, CON pq) - alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, MID 4) + alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, A.MID 4 pq) alice <# ("", bobConn, SENT 4) - bob <#= \case ("", c, Msg "hi") -> c == aliceConn; _ -> False + bob <#= \case ("", c, Msg' 4 pq' "hi") -> c == aliceConn && pq == pq'; _ -> False bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK) testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO () @@ -292,7 +358,7 @@ testRejectContactRequest _ alice bob = do ("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON subscribe") let cReq' = strEncode cReq bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 10\nbob's info") #> ("11", "alice", OK) - ("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:) + ("", "a_contact", Right (A.REQ aInvId PQSupportOff _ "bob's info")) <- (alice <#:) -- RJCT must use correct contact connection alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND) alice #: ("2b", "a_contact", "RJCT " <> aInvId) #> ("2b", "a_contact", OK) @@ -327,31 +393,32 @@ testSubscrNotification t (server, _) client = do withSmpServer (ATransport t) $ client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue -testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO () -testMsgDeliveryServerRestart t alice bob = do +testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO () +testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do + let pq = pqConnectionMode aPQ bPQ withServer $ do - connect (alice, "alice") (bob, "bob") - bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", MID 4) + connect' (alice, "alice", aPQ) (bob, "bob", bPQ) + bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", A.MID 4 pq) bob <# ("", "alice", SENT 4) - alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False + alice <#= \case ("", "bob", Msg' _ pq' "hi") -> pq == pq'; _ -> False alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK) alice #:# "nothing else delivered before the server is killed" let server = SMPServer "localhost" testPort2 testKeyHash alice <#. ("", "", DOWN server ["bob"]) - bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", MID 5) + bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", A.MID 5 pq) bob #:# "nothing else delivered before the server is restarted" alice #:# "nothing else delivered before the server is restarted" withServer $ do bob <# ("", "alice", SENT 5) alice <#. ("", "", UP server ["bob"]) - alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False + alice <#= \case ("", "bob", Msg' _ pq' "hello again") -> pq == pq'; _ -> False alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK) removeFile testStoreLogFile where - withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` () + withServer test' = withSmpServerStoreLogOn (transport @c) testPort2 (const test') `shouldReturn` () testServerConnectionAfterError :: forall c. Transport c => TProxy c -> [c] -> IO () testServerConnectionAfterError t _ = do @@ -432,7 +499,7 @@ testConcurrentMsgDelivery _ alice bob = do ("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV subscribe") let cReq' = strEncode cReq bob #: ("11", "alice2", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice2", OK) - ("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:) + ("", "bob2", Right (A.CONF _confId PQSupportOff _ "bob's connInfo")) <- (alice <#:) -- below commands would be needed to accept bob's connection, but alice does not -- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) -- bob <# ("", "alice", INFO "alice's connInfo") @@ -492,16 +559,33 @@ testResumeDeliveryQuotaExceeded _ alice bob = do -- message 8 is skipped because of alice agent sending "QCONT" message bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK) -connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () -connect (h1, name1) (h2, name2) = do - ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV subscribe") +connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO () +connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQSupportOn) + +connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQSupport) -> IO () +connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do + ("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe") let cReq' = strEncode cReq - h2 #: ("c2", name1, "JOIN T " <> cReq' <> " subscribe 5\ninfo2") #> ("c2", name1, OK) - ("", _, Right (CONF connId _ "info2")) <- (h1 <#:) + pq = pqConnectionMode pqMode1 pqMode2 + pqSup = CR.pqEncToSupport pq + h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK) + ("", _, Right (A.CONF connId pqSup' _ "info2")) <- (h1 <#:) + pqSup' `shouldBe` pqSup h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK) - h2 <# ("", name1, INFO "info1") - h2 <# ("", name1, CON) - h1 <# ("", name2, CON) + h2 <# ("", name1, A.INFO pqSup "info1") + h2 <# ("", name1, CON pq) + h1 <# ("", name2, CON pq) + +pqConnectionMode :: InitialKeys -> PQSupport -> PQEncryption +pqConnectionMode pqMode1 pqMode2 = PQEncryption $ supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2 + +enableKEMStr :: PQSupport -> ByteString +enableKEMStr PQSupportOn = " " <> strEncode PQSupportOn +enableKEMStr _ = "" + +pqConnModeStr :: InitialKeys -> ByteString +pqConnModeStr (IKNoPQ PQSupportOff) = "" +pqConnModeStr pq = " " <> strEncode pq sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO () sendMessage (h1, name1) (h2, name2) msg = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index 83548182a..7ab234887 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -1,7 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.ConnectionRequestTests where @@ -12,7 +15,7 @@ import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (ProtocolServer (..), supportedSMPClientVRange) +import Simplex.Messaging.Protocol (ProtocolServer (..), pattern VersionSMPC, supportedSMPClientVRange) import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Version import Test.Hspec @@ -38,7 +41,7 @@ queue :: SMPQueueUri queue = SMPQueueUri supportedSMPClientVRange queueAddr queueV1 :: SMPQueueUri -queueV1 = SMPQueueUri (mkVersionRange 1 1) queueAddr +queueV1 = SMPQueueUri (mkVersionRange (VersionSMPC 1) (VersionSMPC 1)) queueAddr testDhKey :: C.PublicKeyX25519 testDhKey = "MCowBQYDK2VuAyEAjiswwI3O/NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o=" @@ -53,7 +56,7 @@ connReqData :: ConnReqUriData connReqData = ConnReqUriData { crScheme = SSSimplex, - crAgentVRange = mkVersionRange 2 2, + crAgentVRange = mkVersionRange (VersionSMPA 2) (VersionSMPA 2), crSmpQueues = [queueV1], crClientData = Nothing } @@ -61,11 +64,11 @@ connReqData = testDhPubKey :: C.PublicKeyX448 testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U=" -testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448 -testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey +testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448 +testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing -testE2ERatchetParams12 :: E2ERatchetParamsUri 'C.X448 -testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey +testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448 +testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQSupportOn) testDhPubKey testDhPubKey Nothing connectionRequest :: AConnectionRequestUri connectionRequest = @@ -79,7 +82,7 @@ connectionRequestCurrentRange :: AConnectionRequestUri connectionRequestCurrentRange = ACR SCMInvitation $ CRInvitationUri - connReqData {crAgentVRange = supportedSMPAgentVRange, crSmpQueues = [queueV1, queueV1]} + connReqData {crAgentVRange = supportedSMPAgentVRange PQSupportOn, crSmpQueues = [queueV1, queueV1]} testE2ERatchetParams12 connectionRequestClientDataEmpty :: AConnectionRequestUri @@ -98,7 +101,7 @@ connectionRequestTests = it "should serialize SMP queue URIs" $ do strEncode (queue :: SMPQueueUri) {queueAddress = queueAddrNoPort} `shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri - strEncode queue {clientVRange = mkVersionRange 1 2} + strEncode queue {clientVRange = mkVersionRange (VersionSMPC 1) (VersionSMPC 2)} `shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri it "should parse SMP queue URIs" $ do strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStr) @@ -119,11 +122,11 @@ connectionRequestTests = <> urlEncode True testDhKeyStrUri <> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestCurrentRange - `shouldBe` "simplex:/invitation#/?v=2-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" + `shouldBe` "simplex:/invitation#/?v=2-5&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri - <> "&e2e=v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&e2e=v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" strEncode connectionRequestClientDataEmpty `shouldBe` "simplex:/invitation#/?v=2&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> urlEncode True testDhKeyStrUri @@ -167,9 +170,9 @@ connectionRequestTests = <> testDhKeyStrUri <> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D" <> testDhKeyStrUri - <> "&e2e=extra_key%3Dnew%26v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" + <> "&e2e=extra_key%3Dnew%26v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D" <> "&some_new_param=abc" - <> "&v=2-4" + <> "&v=2-5" ) `shouldBe` Right connectionRequestCurrentRange strDecode diff --git a/tests/AgentTests/DoubleRatchetTests.hs b/tests/AgentTests/DoubleRatchetTests.hs index 95e23b333..f95f07029 100644 --- a/tests/AgentTests/DoubleRatchetTests.hs +++ b/tests/AgentTests/DoubleRatchetTests.hs @@ -1,75 +1,178 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module AgentTests.DoubleRatchetTests where import Control.Concurrent.STM +import Control.Monad (when) import Control.Monad.Except +import Control.Monad.IO.Class import Crypto.Random (ChaChaDRG) -import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson (FromJSON, ToJSON, (.=)) import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.Map.Strict as M +import Data.Type.Equality import Simplex.Messaging.Crypto (Algorithm (..), AlgorithmI, CryptoError, DhAlgorithm) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$$>)) +import Simplex.Messaging.Version import Test.Hspec doubleRatchetTests :: Spec doubleRatchetTests = do describe "double-ratchet encryption/decryption" $ do - it "should serialize and parse message header" testMessageHeader - it "should encrypt and decrypt messages" $ do - withRatchets @X25519 testEncryptDecrypt - withRatchets @X448 testEncryptDecrypt - it "should encrypt and decrypt skipped messages" $ do - withRatchets @X25519 testSkippedMessages - withRatchets @X448 testSkippedMessages - it "should encrypt and decrypt many messages" $ do - withRatchets @X25519 testManyMessages - it "should allow skipped after ratchet advance" $ do - withRatchets @X25519 testSkippedAfterRatchetAdvance + it "should serialize and parse message header" $ do + testAlgs $ testMessageHeader kdfX3DHE2EEncryptVersion + testAlgs $ testMessageHeader $ max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + describe "message tests" $ runMessageTests initRatchets False it "should encode/decode ratchet as JSON" $ do - testKeyJSON C.SX25519 - testKeyJSON C.SX448 - testRatchetJSON C.SX25519 - testRatchetJSON C.SX448 - it "should agree the same ratchet parameters" $ do - testX3dh C.SX25519 - testX3dh C.SX448 - it "should agree the same ratchet parameters with version 1" $ do - testX3dhV1 C.SX25519 - testX3dhV1 C.SX448 + testAlgs testKeyJSON + testAlgs testRatchetJSON + testVersionJSON + it "should decode v2 Ratchet with default field values" $ testDecodeV2RatchetJSON + it "should agree the same ratchet parameters" $ testAlgs testX3dh + it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1 + describe "post-quantum hybrid KEM double-ratchet algorithm" $ do + describe "hybrid KEM key agreement" $ do + it "should propose KEM during agreement, but no shared secret" $ testAlgs testPqX3dhProposeInReply + it "should agree shared secret using KEM" $ testAlgs testPqX3dhProposeAccept + it "should reject proposed KEM in reply" $ testAlgs testPqX3dhProposeReject + it "should allow second proposal in reply" $ testAlgs testPqX3dhProposeAgain + describe "hybrid KEM key agreement errors" $ do + it "should fail if reply contains acceptance without proposal" $ testAlgs testPqX3dhAcceptWithoutProposalError + describe "ratchet encryption/decryption" $ do + it "should serialize and parse public KEM params" testKEMParams + it "should serialize and parse message header" $ testAlgs testMessageHeaderKEM + describe "message tests, KEM proposed" $ runMessageTests initRatchetsKEMProposed True + describe "message tests, KEM accepted" $ runMessageTests initRatchetsKEMAccepted False + describe "message tests, KEM proposed again in reply" $ runMessageTests initRatchetsKEMProposedAgain True + it "should disable and re-enable KEM" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEM + it "should disable and re-enable KEM (always set PQEncryption)" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEMStrict + it "should enable KEM when it was not enabled in handshake" $ withRatchets_ @X25519 initRatchets testEnableKEM + it "should enable KEM when it was not enabled in handshake (always set PQEncryption)" $ withRatchets_ @X25519 initRatchets testEnableKEMStrict + +runMessageTests :: + (forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)) -> + Bool -> + Spec +runMessageTests initRatchets_ agreeRatchetKEMs = do + it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs + it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs + it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs + it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs + where + run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO () + run test = do + withRatchets_ @X25519 initRatchets_ test + withRatchets_ @X448 initRatchets_ test + + +testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO () +testAlgs test = test C.SX25519 >> test C.SX448 paddedMsgLen :: Int paddedMsgLen = 100 -fullMsgLen :: Int -fullMsgLen = 1 + fullHeaderLen + C.authTagSize + paddedMsgLen +fullMsgLen :: Ratchet a -> Int +fullMsgLen Ratchet {rcSupportKEM, rcVersion} = headerLenLength + fullHeaderLen v rcSupportKEM + C.authTagSize + paddedMsgLen + where + v = current rcVersion + headerLenLength = case rcSupportKEM of + PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 3 -- two bytes are added because of two Large used in new encoding + _ -> 1 -testMessageHeader :: Expectation -testMessageHeader = do - (k, _) <- atomically . C.generateKeyPair @X25519 =<< C.newRandom - let hdr = MsgHeader {msgMaxVersion = currentE2EEncryptVersion, msgDHRs = k, msgPN = 0, msgNs = 0} - parseAll (smpP @(MsgHeader 'X25519)) (smpEncode hdr) `shouldBe` Right hdr +testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation +testMessageHeader v _ = do + (k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom + let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0} + parseAll (msgHeaderP v) (encodeMsgHeader v hdr) `shouldBe` Right hdr + +testKEMParams :: Expectation +testKEMParams = do + g <- C.newRandom + (kem, _) <- sntrup761Keypair g + let kemParams = ARKP SRKSProposed $ RKParamsProposed kem + parseAll (smpP @ARKEMParams) (smpEncode kemParams) `shouldBe` Right kemParams + (kem', _) <- sntrup761Keypair g + (ct, _) <- sntrup761Enc g kem + let kemParams' = ARKP SRKSAccepted $ RKParamsAccepted ct kem' + parseAll (smpP @ARKEMParams) (smpEncode kemParams') `shouldBe` Right kemParams' + +testMessageHeaderKEM :: forall a. AlgorithmI a => C.SAlgorithm a -> Expectation +testMessageHeaderKEM _ = do + g <- C.newRandom + (k, _) <- atomically $ C.generateKeyPair @a g + (kem, _) <- sntrup761Keypair g + let msgMaxVersion = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + msgKEM = Just . ARKP SRKSProposed $ RKParamsProposed kem + hdr = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM, msgPN = 0, msgNs = 0} + parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr) `shouldBe` Right hdr + (kem', _) <- sntrup761Keypair g + (ct, _) <- sntrup761Enc g kem + let msgKEM' = Just . ARKP SRKSAccepted $ RKParamsAccepted ct kem' + hdr' = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM = msgKEM', msgPN = 0, msgNs = 0} + parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr') `shouldBe` Right hdr' pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString) pattern Decrypted msg <- Right (Right msg) -type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () +type Encrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString) -testEncryptDecrypt :: TestRatchets a -testEncryptDecrypt alice bob = do +type Decrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) + +type EncryptDecryptSpec a = (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation + +type TestRatchets a = + TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> + TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> + Encrypt a -> + Decrypt a -> + EncryptDecryptSpec a -> + IO () + +deriving instance Eq (Ratchet a) + +deriving instance Eq (SndRatchet a) + +deriving instance Eq RcvRatchet + +deriving instance Eq RatchetKEM + +deriving instance Eq RatchetKEMAccepted + +deriving instance Eq RatchetInitParams + +deriving instance Eq RatchetKey + +instance Eq ARKEMParams where + (ARKP s ps) == (ARKP s' ps') = case testEquality s s' of + Just Refl -> ps == ps' + Nothing -> False + +deriving instance Eq (MsgHeader a) + +initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO () +initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r + +testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "hello alice") #> alice (alice, "hello bob") #> bob Right b1 <- encrypt bob "how are you, alice?" @@ -88,8 +191,9 @@ testEncryptDecrypt alice bob = do (alice, "I'm here too, same") #> bob pure () -testSkippedMessages :: TestRatchets a -testSkippedMessages alice bob = do +testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob Right msg1 <- encrypt bob "hello alice" Right msg2 <- encrypt bob "hello there again" Right msg3 <- encrypt bob "are you there?" @@ -99,8 +203,9 @@ testSkippedMessages alice bob = do Decrypted "hello alice" <- decrypt alice msg1 pure () -testManyMessages :: TestRatchets a -testManyMessages alice bob = do +testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "b1") #> alice (bob, "b2") #> alice (bob, "b3") #> alice @@ -117,8 +222,9 @@ testManyMessages alice bob = do (bob, "b15") #> alice (bob, "b16") #> alice -testSkippedAfterRatchetAdvance :: TestRatchets a -testSkippedAfterRatchetAdvance alice bob = do +testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a +testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do + when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob (bob, "b1") #> alice Right b2 <- encrypt bob "b2" Right b3 <- encrypt bob "b3" @@ -152,6 +258,84 @@ testSkippedAfterRatchetAdvance alice bob = do Decrypted "b11" <- decrypt alice b11 pure () +testDisableEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testDisableEnableKEM alice bob _ _ _ = do + (bob, "hello alice") !#> alice + (alice, "hello bob") !#> bob + (bob, "disabling KEM") !#>\ alice + (alice, "still disabling KEM") !#> bob + (bob, "now KEM is disabled") \#> alice + (alice, "KEM is disabled for both sides") \#> bob + (bob, "trying to enable KEM") \#>! alice + (alice, "but unless alice enables it too it won't enable") \#> bob + (bob, "KEM is disabled") \#> alice + (alice, "KEM is disabled for both sides") \#> bob + (bob, "enabling KEM again") \#>! alice + (alice, "and alice accepts it this time") \#>! bob + (bob, "still enabling KEM") \#>! alice + (alice, "now KEM is enabled") !#> bob + (bob, "KEM is enabled for both sides") !#> alice + +testDisableEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testDisableEnableKEMStrict alice bob _ _ _ = do + (bob, "hello alice") !#>! alice + (alice, "hello bob") !#>! bob + (bob, "disabling KEM") !#>\ alice + (alice, "still disabling KEM") !#>! bob + (bob, "now KEM is disabled") \#>\ alice + (alice, "KEM is disabled for both sides") \#>\ bob + (bob, "trying to enable KEM") \#>! alice + (alice, "but unless alice enables it too it won't enable") \#>\ bob + (bob, "KEM is disabled") \#>! alice + (alice, "KEM is disabled for both sides") \#>\ bob + (bob, "enabling KEM again") \#>! alice + (alice, "and alice accepts it this time") \#>! bob + (bob, "still enabling KEM") \#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "KEM is enabled for both sides") !#>! alice + +testEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testEnableKEM alice bob _ _ _ = do + (bob, "hello alice") \#> alice + (alice, "hello bob") \#> bob + (bob, "enabling KEM") \#>! alice + (bob, "KEM not enabled yet") \#>! alice + (alice, "accepting KEM") \#>! bob + (alice, "KEM not enabled yet here too") \#>! bob + (bob, "KEM is still not enabled") \#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "now KEM is enabled for both sides") !#> alice + (alice, "still enabled for both sides") !#> bob + (bob, "still enabled for both sides 2") !#> alice + (alice, "disabling KEM") !#>\ bob + (bob, "KEM not disabled yet") !#> alice + (alice, "KEM disabled") \#> bob + (bob, "KEM disabled on both sides") \#> alice + +testEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a +testEnableKEMStrict alice bob _ _ _ = do + (bob, "hello alice") \#>\ alice + (alice, "hello bob") \#>\ bob + (bob, "enabling KEM") \#>! alice + (bob, "KEM not enabled yet") \#>! alice + (alice, "accepting KEM") \#>! bob + (alice, "KEM not enabled yet here too") \#>! bob + (bob, "KEM is still not enabled") \#>! alice + (alice, "now KEM is enabled") !#>! bob + (bob, "now KEM is enabled for both sides") !#>! alice + (alice, "still enabled for both sides") !#>! bob + (bob, "still enabled for both sides 2") !#>! alice + (alice, "disabling KEM") !#>\ bob + (bob, "KEM not disabled yet") !#>! alice + (alice, "KEM disabled") \#>\ bob + (bob, "KEM disabled on both sides") \#>! alice + (alice, "KEM still disabled 1") \#>\ bob + (bob, "KEM still disabled 2") \#>! alice + (alice, "KEM still disabled 3") \#>\ bob + (bob, "KEM still disabled 4") \#>! alice + (alice, "KEM still disabled 5") \#>\ bob + (bob, "KEM still disabled 6") \#>! alice + testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO () testKeyJSON _ = do (k, pk) <- atomically . C.generateKeyPair @a =<< C.newRandom @@ -160,10 +344,33 @@ testKeyJSON _ = do testRatchetJSON :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testRatchetJSON _ = do - (alice, bob) <- initRatchets @a + (alice, bob, _, _, _) <- initRatchets @a testEncodeDecode alice testEncodeDecode bob +testVersionJSON :: IO () +testVersionJSON = do + testEncodeDecode $ rv 1 1 + testEncodeDecode $ rv 1 2 + -- let bad = RVersions 2 1 + -- Left err <- pure $ J.eitherDecode' @RatchetVersions (J.encode bad) + -- err `shouldContain` "bad version range" + testDecodeRV $ (1 :: Int, 2 :: Int) + testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)] + where + rv v1 v2 = RatchetVersions (VersionE2E v1) (VersionE2E v2) + testDecodeRV :: ToJSON a => a -> Expectation + testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2) + +testDecodeV2RatchetJSON :: IO () +testDecodeV2RatchetJSON = do + let v2RatchetJSON = "{\"rcVersion\":[2,2],\"rcAD\":\"2GEJrq48TmQse6NR16I-hrI0tSySZQ57E_g46nDceAPRAiF6j0drq26RTE7be6X7uiB4RaGJGf4QRXzcYuVtWw==\",\"rcDHRs\":\"TUM0Q0FRQXdCUVlESzJWdUJDSUVJRkNYbUxtSHQ3SUNfeHpGTi1Qb3ZqTVQ3S2p6XzZlZlBjOG9fRFY2RWxKOQ==\",\"rcRK\":\"BOX2X7YW5qDSp2XknY_lqacSrtDqQNPvS6iJlZIs3G0=\",\"rcNs\":0,\"rcNr\":0,\"rcPN\":0,\"rcNHKs\":\"IMouSkXUvzT_mo0WM-pqEUK09-HTLk9WOTCFQglyQxU=\",\"rcNHKr\":\"g-tus1clYPV0rGlzkf5a959tUqDYQVZ1FpcPeXdKwxI=\"}" + Right (r :: Ratchet X25519) <- pure $ J.eitherDecodeStrict' v2RatchetJSON + rcSupportKEM r `shouldBe` PQSupportOff + rcEnableKEM r `shouldBe` PQEncOff + rcSndKEM r `shouldBe` PQEncOff + rcRcvKEM r `shouldBe` PQEncOff + testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation testEncodeDecode x = do let j = J.encode x @@ -173,77 +380,255 @@ testEncodeDecode x = do testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dh _ = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff + let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () testX3dhV1 _ = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g 1 - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g 1 - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g (VersionE2E 1) Nothing + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQSupportOff + let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob paramsAlice `shouldBe` paramsBob -(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation -(alice, msg) #> bob = do - Right msg' <- encrypt alice msg - Decrypted msg'' <- decrypt bob msg' +testPqX3dhProposeInReply :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeInReply _ = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff + -- propose KEM in reply + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhProposeAccept :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeAccept _ = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice + -- accept KEM + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM aliceKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhProposeReject :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeReject _ = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice + -- reject KEM + (pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +testPqX3dhAcceptWithoutProposalError :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhAcceptWithoutProposalError _ = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff + E2ERatchetParams _ _ _ Nothing <- pure e2eAlice + -- incorrectly accept KEM + -- we don't have key in proposal, so we just generate it + (k, _) <- sntrup761Keypair g + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM k) + pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice `shouldBe` Left C.CERatchetKEMState + runExceptT (pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob) `shouldReturn` Left C.CERatchetKEMState + +testPqX3dhProposeAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO () +testPqX3dhProposeAgain _ = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice + -- propose KEM again in reply - this is not an error + (pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob + paramsAlice `compatibleRatchets` paramsBob + +compatibleRatchets :: (RatchetInitParams, x) -> (RatchetInitParams, x) -> Expectation +compatibleRatchets + (RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, _) + (RatchetInitParams {assocData = ad, ratchetKey = rk, sndHK = shk, rcvNextHK = rnhk, kemAccepted = ka}, _) = do + assocData == ad && ratchetKey == rk && sndHK == shk && rcvNextHK == rnhk `shouldBe` True + case (kemAccepted, ka) of + (Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}, Just RatchetKEMAccepted {rcPQRr = pqk, rcPQRss = pqss, rcPQRct = pqct}) -> + pqk /= rcPQRr && pqss == rcPQRss && pqct == rcPQRct `shouldBe` True + (Nothing, Nothing) -> pure () + _ -> expectationFailure "RatchetInitParams params are not compatible" + +encryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Maybe PQEncryption -> (Ratchet a -> ()) -> (Ratchet a -> ()) -> EncryptDecryptSpec a +encryptDecrypt pqEnc validSnd validRcv (alice, msg) bob = do + Right msg' <- withTVar (encrypt_ pqEnc) validSnd alice msg + Decrypted msg'' <- decrypt' validRcv bob msg' msg'' `shouldBe` msg -withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()) -> Expectation -withRatchets test = do +-- enable KEM (currently disabled) +(\#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#>! r = encryptDecrypt (Just PQEncOn) noSndKEM noRcvKEM (s, msg) r + +-- enable KEM (currently enabled) +(!#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#>! r = encryptDecrypt (Just PQEncOn) hasSndKEM hasRcvKEM (s, msg) r + +-- KEM enabled (no user preference) +(!#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#> r = encryptDecrypt Nothing hasSndKEM hasRcvKEM (s, msg) r + +-- disable KEM (currently enabled) +(!#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) !#>\ r = encryptDecrypt (Just PQEncOff) hasSndKEM hasRcvKEM (s, msg) r + +-- disable KEM (currently disabled) +(\#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#>\ r = encryptDecrypt (Just PQEncOff) noSndKEM noSndKEM (s, msg) r + +-- KEM disabled (no user preference) +(\#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a +(s, msg) \#> r = encryptDecrypt Nothing noSndKEM noSndKEM (s, msg) r + +withRatchets_ :: IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) -> TestRatchets a -> Expectation +withRatchets_ initRatchets_ test = do ga <- C.newRandom gb <- C.newRandom - (a, b) <- initRatchets @a + (a, b, encrypt, decrypt, (#>)) <- initRatchets_ alice <- newTVarIO (ga, a, M.empty) bob <- newTVarIO (gb, b, M.empty) - test alice bob `shouldReturn` () + test alice bob encrypt decrypt (#>) `shouldReturn` () -initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a) +initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) initRatchets = do g <- C.newRandom - (pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams g currentE2EEncryptVersion - (pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams g currentE2EEncryptVersion - let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice - paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + (pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing + (pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob (_, pkBob3) <- atomically $ C.generateKeyPair g - let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob - alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice - pure (alice, bob) + let vs = testRatchetVersions PQSupportOff + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOff + pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>)) -encrypt_ :: AlgorithmI a => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) -encrypt_ (_, rc, _) msg = - runExceptT (rcEncrypt rc paddedMsgLen msg) +initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMProposed = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (no KEM) + (pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff + -- propose KEM in reply + let useKem = AUseKEM SRKSProposed ProposeKEM + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let vs = testRatchetVersions PQSupportOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMAccepted = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (propose) + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn + E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice + -- accept + let useKem = AUseKEM SRKSAccepted (AcceptKEM aliceKem) + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let vs = testRatchetVersions PQSupportOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) +initRatchetsKEMProposedAgain = do + g <- C.newRandom + let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion + -- initiate (propose KEM) + (pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn + -- propose KEM again in reply + let useKem = AUseKEM SRKSProposed ProposeKEM + (pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem) + Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice + Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob + (_, pkBob3) <- atomically $ C.generateKeyPair g + let vs = testRatchetVersions PQSupportOn + bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob + alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn + pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>)) + +testRatchetVersions :: PQSupport -> RatchetVersions +testRatchetVersions pq = + let v = maxVersion $ supportedE2EEncryptVRange pq + in RatchetVersions v v + +encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff)) +encrypt_ pqEnc_ (_, rc, _) msg = + -- print msg >> + runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_ currentE2EEncryptVersion) >>= either (pure . Left) checkLength where checkLength (msg', rc') = do - B.length msg' `shouldBe` fullMsgLen + B.length msg' `shouldBe` fullMsgLen rc' pure $ Right (msg', rc', SMDNoChange) decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff)) decrypt_ (g, rc, smks) msg = runExceptT $ rcDecrypt g rc smks msg -encrypt :: AlgorithmI a => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString) -encrypt = withTVar encrypt_ +encrypt' :: AlgorithmI a => (Ratchet a -> ()) -> Encrypt a +encrypt' = withTVar $ encrypt_ Nothing -decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString)) -decrypt = withTVar decrypt_ +decrypt' :: (AlgorithmI a, DhAlgorithm a) => (Ratchet a -> ()) -> Decrypt a +decrypt' = withTVar decrypt_ + +noSndKEM :: Ratchet a -> () +noSndKEM Ratchet {rcSndKEM = PQEncOn} = error "snd ratchet has KEM" +noSndKEM _ = () + +noRcvKEM :: Ratchet a -> () +noRcvKEM Ratchet {rcRcvKEM = PQEncOn} = error "rcv ratchet has KEM" +noRcvKEM _ = () + +hasSndKEM :: Ratchet a -> () +hasSndKEM Ratchet {rcSndKEM = PQEncOn} = () +hasSndKEM _ = error "snd ratchet has no KEM" + +hasRcvKEM :: Ratchet a -> () +hasRcvKEM Ratchet {rcRcvKEM = PQEncOn} = () +hasRcvKEM _ = error "rcv ratchet has no KEM" withTVar :: AlgorithmI a => ((TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e (r, Ratchet a, SkippedMsgDiff))) -> + (Ratchet a -> ()) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e r) -withTVar op rcVar msg = do +withTVar op valid rcVar msg = do (g, rc, smks) <- readTVarIO rcVar applyDiff smks <$$> (testEncodeDecode rc >> op (g, rc, smks) msg) >>= \case - Right (res, rc', smks') -> atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res) + Right (res, rc', smks') -> valid rc' `seq` atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res) Left e -> pure $ Left e where applyDiff smks (res, rc', smDiff) = (res, rc', applySMDiff smks smDiff) diff --git a/tests/AgentTests/EqInstances.hs b/tests/AgentTests/EqInstances.hs new file mode 100644 index 000000000..aaaa2de51 --- /dev/null +++ b/tests/AgentTests/EqInstances.hs @@ -0,0 +1,25 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +module AgentTests.EqInstances where + +import Data.Type.Equality +import Simplex.Messaging.Agent.Store + +instance Eq SomeConn where + SomeConn d c == SomeConn d' c' = case testEquality d d' of + Just Refl -> c == c' + _ -> False + +deriving instance Eq (Connection d) + +deriving instance Eq (SConnType d) + +deriving instance Eq (StoredRcvQueue q) + +deriving instance Eq (StoredSndQueue q) + +deriving instance Eq (DBQueueId q) + +deriving instance Eq ClientNtfCreds diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index a5e994c68..e17f44df3 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -9,7 +10,9 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} module AgentTests.FunctionalAPITests @@ -20,6 +23,9 @@ module AgentTests.FunctionalAPITests makeConnection, exchangeGreetingsMsgId, switchComplete, + createConnection, + joinConnection, + sendMessage, runRight, runRight_, get, @@ -29,7 +35,12 @@ module AgentTests.FunctionalAPITests nGet, (##>), (=##>), + pattern CON, + pattern CONF, + pattern INFO, + pattern REQ, pattern Msg, + pattern Msg', agentCfgV7, ) where @@ -44,38 +55,49 @@ import qualified Data.ByteString.Char8 as B import Data.Either (isRight) import Data.Int (Int64) import Data.List (nub) +import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as M import Data.Maybe (isNothing) import qualified Data.Set as S import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality +import Data.Word (Word16) import qualified Database.SQLite.Simple as SQL import SMPAgentClient -import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerV7, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) -import Simplex.Messaging.Agent +import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, withSmpServerV7) +import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) +import qualified Simplex.Messaging.Agent as A import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) -import Simplex.Messaging.Agent.Protocol as Agent +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ) +import qualified Simplex.Messaging.Agent.Protocol as A import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew)) import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultSMPClientConfig) import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern PQEncOn, pattern PQEncOff, pattern PQSupportOn, pattern PQSupportOff) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Transport (authBatchCmdsNTFVersion) +import Simplex.Messaging.Notifications.Transport (NTFVersion, pattern VersionNTF, authBatchCmdsNTFVersion) import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..), SubscriptionMode (..), supportedSMPClientVRange) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport (ATransport (..), basicAuthSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion) -import Simplex.Messaging.Version +import Simplex.Messaging.Transport (ATransport (..), SMPVersion, VersionSMP, authCmdsSMPVersion, batchCmdsSMPVersion, basicAuthSMPVersion, currentServerSMPRelayVersion) +import Simplex.Messaging.Version (VersionRange (..)) +import qualified Simplex.Messaging.Version as V +import Simplex.Messaging.Version.Internal (Version (..)) import System.Directory (copyFile, renameFile) import Test.Hspec import UnliftIO +import Util import XFTPClient (testXFTPServer) type AEntityTransmission e = (ACorrId, ConnId, ACommand 'Agent e) +-- deriving instance Eq (ValidFileDescription p) + (##>) :: (HasCallStack, MonadUnliftIO m) => m (AEntityTransmission e) -> AEntityTransmission e -> m () a ##> t = withTimeout a (`shouldBe` t) @@ -118,47 +140,74 @@ pGet c = do DISCONNECT {} -> pGet c _ -> pure t -pattern Msg :: MsgBody -> ACommand 'Agent e -pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody +pattern CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand 'Agent e +pattern CONF conId srvs connInfo <- A.CONF conId PQSupportOn srvs connInfo -pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent e +pattern INFO :: ConnInfo -> ACommand 'Agent 'AEConn +pattern INFO connInfo = A.INFO PQSupportOn connInfo + +pattern REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand 'Agent e +pattern REQ invId srvs connInfo <- A.REQ invId PQSupportOn srvs connInfo + +pattern CON :: ACommand 'Agent 'AEConn +pattern CON = A.CON PQEncOn + +pattern Msg :: MsgBody -> ACommand 'Agent e +pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk, pqEncryption = PQEncOn} _ msgBody + +pattern Msg' :: AgentMsgId -> PQEncryption -> MsgBody -> ACommand 'Agent e +pattern Msg' aMsgId pq msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _), pqEncryption = pq} _ msgBody + +pattern MsgErr :: AgentMsgId -> MsgErrorType -> MsgBody -> ACommand 'Agent 'AEConn pattern MsgErr msgId err msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err} _ msgBody -pattern Rcvd :: AgentMsgId -> ACommand 'Agent e +pattern MsgErr' :: AgentMsgId -> MsgErrorType -> PQEncryption -> MsgBody -> ACommand 'Agent 'AEConn +pattern MsgErr' msgId err pq msgBody <- MSG MsgMeta {recipient = (msgId, _), integrity = MsgError err, pqEncryption = pq} _ msgBody + +pattern Rcvd :: AgentMsgId -> ACommand 'Agent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] -smpCfgVPrev :: ProtocolClientConfig +smpCfgVPrev :: ProtocolClientConfig SMPVersion smpCfgVPrev = (smpCfg agentCfg) {serverVRange = prevRange $ serverVRange $ smpCfg agentCfg} -smpCfgV7 :: ProtocolClientConfig -smpCfgV7 = (smpCfg agentCfg) {serverVRange = mkVersionRange 4 authCmdsSMPVersion} +smpCfgV7 :: ProtocolClientConfig SMPVersion +smpCfgV7 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} -ntfCfgV2 :: ProtocolClientConfig -ntfCfgV2 = (smpCfg agentCfg) {serverVRange = mkVersionRange 1 authBatchCmdsNTFVersion} +ntfCfgV2 :: ProtocolClientConfig NTFVersion +ntfCfgV2 = (smpCfg agentCfg) {serverVRange = V.mkVersionRange (VersionNTF 1) authBatchCmdsNTFVersion} agentCfgVPrev :: AgentConfig agentCfgVPrev = agentCfg { sndAuthAlg = C.AuthAlg C.SEd25519, - smpAgentVRange = prevRange $ smpAgentVRange agentCfg, + smpAgentVRange = \_ -> prevRange $ smpAgentVRange agentCfg PQSupportOff, smpClientVRange = prevRange $ smpClientVRange agentCfg, - e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg, + e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQSupportOff, smpCfg = smpCfgVPrev } +-- agent config for the next client version agentCfgV7 :: AgentConfig -agentCfgV7 = +agentCfgV7 = agentCfg { sndAuthAlg = C.AuthAlg C.SX25519, + smpAgentVRange = \_ -> V.mkVersionRange duplexHandshakeSMPAgentVersion $ max pqdrSMPAgentVersion currentSMPAgentVersion, + e2eEncryptVRange = \_ -> V.mkVersionRange CR.kdfX3DHE2EEncryptVersion $ max CR.pqRatchetE2EEncryptVersion CR.currentE2EEncryptVersion, smpCfg = smpCfgV7, ntfCfg = ntfCfgV2 } agentCfgRatchetVPrev :: AgentConfig -agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = prevRange $ e2eEncryptVRange agentCfg} +agentCfgRatchetVPrev = agentCfg {e2eEncryptVRange = \_ -> prevRange $ e2eEncryptVRange agentCfg PQSupportOff} -prevRange :: VersionRange -> VersionRange -prevRange vr = vr {maxVersion = max (minVersion vr) (maxVersion vr - 1)} +prevRange :: VersionRange v -> VersionRange v +prevRange vr = vr {maxVersion = max (minVersion vr) (prevVersion $ maxVersion vr)} + +prevVersion :: Version v -> Version v +prevVersion (Version v) = Version (v - 1) + +mkVersionRange :: Word16 -> Word16 -> VersionRange v +mkVersionRange v1 v2 = V.mkVersionRange (Version v1) (Version v2) runRight_ :: (Eq e, Show e, HasCallStack) => ExceptT e IO () -> Expectation runRight_ action = runExceptT action `shouldReturn` Right () @@ -184,6 +233,18 @@ inAnyOrder g rs = do expected :: a -> (a -> Bool) -> Bool expected r rp = rp r +createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c) +createConnection c userId enableNtfs cMode clientData = A.createConnection c userId enableNtfs cMode clientData (IKNoPQ PQSupportOn) + +joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId +joinConnection c userId enableNtfs cReq connInfo = A.joinConnection c userId enableNtfs cReq connInfo PQSupportOn + +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> SMP.MsgFlags -> MsgBody -> m AgentMsgId +sendMessage c connId msgFlags msgBody = do + (msgId, pqEnc) <- A.sendMessage c connId PQEncOn msgFlags msgBody + liftIO $ pqEnc `shouldBe` PQEncOn + pure msgId + functionalAPITests :: ATransport -> Spec functionalAPITests t = do describe "Establishing duplex connection" $ do @@ -192,6 +253,8 @@ functionalAPITests t = do withSmpServer t testServerMultipleIdentities it "should connect with two peers" $ withSmpServer t testAgentClient3 + it "should establish connection without PQ encryption and enable it" $ + withSmpServer t testEnablePQEncryption describe "Establishing duplex connection v2, different Ratchet versions" $ testRatchetMatrix2 t runAgentClientTest describe "Establish duplex connection via contact address" $ @@ -217,6 +280,7 @@ functionalAPITests t = do testIncreaseConnAgentVersionMaxCompatible t it "should increase when connection was negotiated on different versions" $ testIncreaseConnAgentVersionStartDifferentVersion t + -- TODO PQ tests for upgrading connection to PQ encryption it "should deliver message after client restart" $ testDeliverClientRestart t it "should deliver messages to the user once, even if repeat delivery is made by the server (no ACK)" $ @@ -259,9 +323,9 @@ functionalAPITests t = do describe "Batching SMP commands" $ do it "should subscribe to multiple (200) subscriptions with batching" $ testBatchedSubscriptions 200 10 t - -- 200 subscriptions gets very slow with test coverage, use below test instead - xit "should subscribe to multiple (6) subscriptions with batching" $ - testBatchedSubscriptions 6 3 t + skip "faster version of the previous test (200 subscriptions gets very slow with test coverage)" $ + it "should subscribe to multiple (6) subscriptions with batching" $ + testBatchedSubscriptions 6 3 t describe "Async agent commands" $ do it "should connect using async agent commands" $ withSmpServer t testAsyncCommands @@ -273,6 +337,17 @@ functionalAPITests t = do testDeleteConnectionAsync t it "join connection when reply queue creation fails" $ testJoinConnectionAsyncReplyError t + describe "delete connection waiting for delivery" $ do + it "should delete connection immediately if there are no pending messages" $ + testWaitDeliveryNoPending t + it "should delete connection after waiting for delivery to complete" $ + testWaitDelivery t + it "should delete connection if message can't be delivered due to AUTH error" $ + testWaitDeliveryAUTHErr t + it "should delete connection by timeout even if message wasn't delivered" $ + testWaitDeliveryTimeout t + it "should delete connection by timeout, message in progress can be delivered" $ + testWaitDeliveryTimeout2 t describe "Users" $ do it "should create and delete user with connections" $ withSmpServer t testUsers @@ -300,8 +375,8 @@ functionalAPITests t = do describe "should switch two connections simultaneously, abort one" $ testServerMatrix2 t testSwitch2ConnectionsAbort1 describe "SMP basic auth" $ do - let v4 = basicAuthSMPVersion - 1 - forM_ (nub [authCmdsSMPVersion - 1, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do + let v4 = prevVersion basicAuthSMPVersion + forM_ (nub [prevVersion authCmdsSMPVersion, authCmdsSMPVersion, currentServerSMPRelayVersion]) $ \v -> do describe ("v" <> show v <> ": with server auth") $ do -- allow NEW | server auth, v | clnt1 auth, v | clnt2 auth, v | 2 - success, 1 - JOIN fail, 0 - NEW fail it "success " $ testBasicAuth t True (Just "abcd", v) (Just "abcd", v) (Just "abcd", v) `shouldReturn` 2 @@ -345,9 +420,9 @@ functionalAPITests t = do it "should send delivery receipt only in connection v3+" $ testDeliveryReceiptsVersion t it "send delivery receipts concurrently with messages" $ testDeliveryReceiptsConcurrent t -testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do - let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = mkVersionRange 4 srvVersion} + let testCfg = cfg {allowNewQueues, newQueueBasicAuth = srvAuth, smpServerVRange = V.mkVersionRange batchCmdsSMPVersion srvVersion} canCreate1 = canCreateQueue allowNewQueues srv clnt1 canCreate2 = canCreateQueue allowNewQueues srv clnt2 expected @@ -358,32 +433,31 @@ testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do created `shouldBe` expected pure created -canCreateQueue :: Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> Bool +canCreateQueue :: Bool -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> Bool canCreateQueue allowNew (srvAuth, srvVersion) (clntAuth, clntVersion) = let v = basicAuthSMPVersion in allowNew && (isNothing srvAuth || (srvVersion >= v && clntVersion >= v && srvAuth == clntAuth)) -testMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testMatrix2 t runTest = do - it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 runTest - it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 runTest - it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 runTest - it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 runTest - it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 runTest - it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 runTest - it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 runTest + it "v7" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn + it "v7 to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQSupportOn + it "current to v7" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQSupportOn + it "current with v7 server" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "prev" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfgVPrev 3 $ runTest PQSupportOff + it "prev to current" $ withSmpServer t $ runTestCfg2 agentCfgVPrev agentCfg 3 $ runTest PQSupportOff + it "current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgVPrev 3 $ runTest PQSupportOff -testRatchetMatrix2 :: ATransport -> (AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec +testRatchetMatrix2 :: ATransport -> (PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO ()) -> Spec testRatchetMatrix2 t runTest = do - it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 runTest - pendingV "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 runTest - pendingV "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 runTest - pendingV "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 runTest - where - pendingV = - let vr = e2eEncryptVRange agentCfg - in if minVersion vr == maxVersion vr then xit else it + it "ratchet next" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfgV7 3 $ runTest PQSupportOn + it "ratchet next to current" $ withSmpServerV7 t $ runTestCfg2 agentCfgV7 agentCfg 3 $ runTest PQSupportOn + it "ratchet current to next" $ withSmpServerV7 t $ runTestCfg2 agentCfg agentCfgV7 3 $ runTest PQSupportOn + it "ratchet current" $ withSmpServer t $ runTestCfg2 agentCfg agentCfg 3 $ runTest PQSupportOn + it "ratchet prev" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfgRatchetVPrev 3 $ runTest PQSupportOff + it "ratchets prev to current" $ withSmpServer t $ runTestCfg2 agentCfgRatchetVPrev agentCfg 3 $ runTest PQSupportOff + it "ratchets current to prev" $ withSmpServer t $ runTestCfg2 agentCfg agentCfgRatchetVPrev 3 $ runTest PQSupportOff testServerMatrix2 :: ATransport -> (InitialAgentServers -> IO ()) -> Spec testServerMatrix2 t runTest = do @@ -405,40 +479,112 @@ withAgentClientsCfg2 aCfg bCfg runTest = do withAgentClients2 :: (AgentClient -> AgentClient -> IO ()) -> IO () withAgentClients2 = withAgentClientsCfg2 agentCfg agentCfg -runAgentClientTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientTest alice@AgentClient {} bob baseId = +runAgentClientTest :: HasCallStack => PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest pqSupport alice@AgentClient {} bob baseId = runRight_ $ do - (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe - ("", _, CONF confId _ "bob's connInfo") <- get alice + (bobId, qInfo) <- A.createConnection alice 1 True SCMInvitation Nothing (IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe + ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" - get alice ##> ("", bobId, CON) - get bob ##> ("", aliceId, INFO "alice's connInfo") - get bob ##> ("", aliceId, CON) + let pqEnc = CR.pqSupportToEnc pqSupport + get alice ##> ("", bobId, A.CON pqEnc) + get bob ##> ("", aliceId, A.INFO pqSupport "alice's connInfo") + get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 - 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + 1 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT $ baseId + 1) - 2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 2 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "how are you?" get alice ##> ("", bobId, SENT $ baseId + 2) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "hello") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 1) Nothing - get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "how are you?") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 2) Nothing - 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + 3 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" get bob ##> ("", aliceId, SENT $ baseId + 3) - 4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1" + 4 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 1" get bob ##> ("", aliceId, SENT $ baseId + 4) - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "hello too") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 3) Nothing - get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "message 1") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId - 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where - msgId = subtract baseId + msgId = subtract baseId . fst + +testEnablePQEncryption :: HasCallStack => IO () +testEnablePQEncryption = do + ca <- getSMPAgentClient' 1 agentCfg initAgentServers testDB + cb <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + g <- C.newRandom + runRight_ $ do + (aId, bId) <- makeConnection_ PQSupportOff ca cb + let a = (ca, aId) + b = (cb, bId) + (a, 4, "msg 1") \#>\ b + (b, 5, "msg 2") \#>\ a + -- 45 bytes is used by agent message envelope inside double ratchet message envelope + let largeMsg g' pqEnc = atomically $ C.randomBytes (e2eEncUserMsgLength pqdrSMPAgentVersion pqEnc - 45) g' + lrg <- largeMsg g PQSupportOff + (a, 6, lrg) \#>\ b + (b, 7, lrg) \#>\ a + -- enabling PQ encryption + (a, 8, lrg) \#>! b + (b, 9, lrg) \#>! a + -- switched to smaller envelopes (before reporting PQ encryption enabled) + sml <- largeMsg g PQSupportOn + -- fail because of message size + Left (A.CMD LARGE) <- tryError $ A.sendMessage ca bId PQEncOn SMP.noMsgFlags lrg + (11, PQEncOff) <- A.sendMessage ca bId PQEncOn SMP.noMsgFlags sml + get ca =##> \case ("", connId, SENT 11) -> connId == bId; _ -> False + get cb =##> \case ("", connId, MsgErr' 10 MsgSkipped {} PQEncOff msg') -> connId == aId && msg' == sml; _ -> False + ackMessage cb aId 10 Nothing + -- -- fail in reply to sync IDss + Left (A.CMD LARGE) <- tryError $ A.sendMessage cb aId PQEncOn SMP.noMsgFlags lrg + (12, PQEncOn) <- A.sendMessage cb aId PQEncOn SMP.noMsgFlags sml + get cb =##> \case ("", connId, SENT 12) -> connId == aId; _ -> False + get ca =##> \case ("", connId, MsgErr' 12 MsgSkipped {} PQEncOn msg') -> connId == bId && msg' == sml; _ -> False + ackMessage ca bId 12 Nothing + -- PQ encryption now enabled + (a, 13, sml) !#>! b + (b, 14, sml) !#>! a + -- disabling PQ encryption + (a, 15, sml) !#>\ b + (b, 16, sml) !#>\ a + (a, 17, sml) \#>\ b + (b, 18, sml) \#>\ a + -- enabling PQ encryption again + (a, 19, sml) \#>! b + (b, 20, sml) \#>! a + (a, 21, sml) \#>! b + (b, 22, sml) !#>! a + (a, 23, sml) !#>! b + -- disabling PQ encryption again + (b, 24, sml) !#>\ a + (a, 25, sml) !#>\ b + (b, 26, sml) \#>\ a + (a, 27, sml) \#>\ b + -- PQ encryption is now disabled, but support remained enabled, so we still cannot send larger messages + Left (A.CMD LARGE) <- tryError $ A.sendMessage ca bId PQEncOff SMP.noMsgFlags (sml <> "123456") + Left (A.CMD LARGE) <- tryError $ A.sendMessage cb aId PQEncOff SMP.noMsgFlags (sml <> "123456") + pure () + where + (\#>\) = PQEncOff `sndRcv` PQEncOff + (\#>!) = PQEncOff `sndRcv` PQEncOn + (!#>!) = PQEncOn `sndRcv` PQEncOn + (!#>\) = PQEncOn `sndRcv` PQEncOff + +sndRcv :: PQEncryption -> PQEncryption -> ((AgentClient, ConnId), AgentMsgId, MsgBody) -> (AgentClient, ConnId) -> ExceptT AgentErrorType IO () +sndRcv pqEnc pqEnc' ((c1, id1), mId, msg) (c2, id2) = do + r <- A.sendMessage c1 id2 pqEnc' SMP.noMsgFlags msg + liftIO $ r `shouldBe` (mId, pqEnc) + get c1 =##> \case ("", connId, SENT mId') -> connId == id2 && mId' == mId; _ -> False + get c2 =##> \case ("", connId, Msg' mId' pq msg') -> connId == id1 && mId' == mId && msg' == msg && pq == pqEnc; _ -> False + ackMessage c2 id1 mId Nothing testAgentClient3 :: HasCallStack => IO () testAgentClient3 = do @@ -466,42 +612,45 @@ testAgentClient3 = do get c =##> \case ("", connId, Msg "c5") -> connId == aIdForC; _ -> False ackMessage c aIdForC 5 Nothing -runAgentClientContactTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () -runAgentClientContactTest alice bob baseId = +runAgentClientContactTest :: HasCallStack => PQSupport -> AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest pqSupport alice bob baseId = runRight_ $ do - (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe - ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContact alice True invId "alice's connInfo" SMSubscribe - ("", _, CONF confId _ "alice's connInfo") <- get bob + (_, qInfo) <- A.createConnection alice 1 True SCMContact Nothing (IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob 1 True qInfo "bob's connInfo" pqSupport SMSubscribe + ("", _, A.REQ invId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport + bobId <- acceptContact alice True invId "alice's connInfo" PQSupportOn SMSubscribe + ("", _, A.CONF confId pqSup'' _ "alice's connInfo") <- get bob + liftIO $ pqSup'' `shouldBe` pqSupport allowConnection bob aliceId confId "bob's connInfo" - get alice ##> ("", bobId, INFO "bob's connInfo") - get alice ##> ("", bobId, CON) - get bob ##> ("", aliceId, CON) + let pqEnc = CR.pqSupportToEnc pqSupport + get alice ##> ("", bobId, A.INFO pqSupport "bob's connInfo") + get alice ##> ("", bobId, A.CON pqEnc) + get bob ##> ("", aliceId, A.CON pqEnc) -- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4 - 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + 1 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" get alice ##> ("", bobId, SENT $ baseId + 1) - 2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 2 <- msgId <$> A.sendMessage alice bobId pqEnc SMP.noMsgFlags "how are you?" get alice ##> ("", bobId, SENT $ baseId + 2) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "hello") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 1) Nothing - get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' _ pq "how are you?") -> c == aliceId && pq == pqEnc; _ -> False ackMessage bob aliceId (baseId + 2) Nothing - 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + 3 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" get bob ##> ("", aliceId, SENT $ baseId + 3) - 4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1" + 4 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 1" get bob ##> ("", aliceId, SENT $ baseId + 4) - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "hello too") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 3) Nothing - get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' _ pq "message 1") -> c == bobId && pq == pqEnc; _ -> False ackMessage alice bobId (baseId + 4) Nothing suspendConnection alice bobId - 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + 5 <- msgId <$> A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "message 2" get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH)) deleteConnection alice bobId liftIO $ noMessages alice "nothing else should be delivered to alice" where - msgId = subtract baseId + msgId = subtract baseId . fst noMessages :: HasCallStack => AgentClient -> String -> Expectation noMessages c err = tryGet `shouldReturn` () @@ -625,12 +774,12 @@ testAllowConnectionClientRestart t = do testIncreaseConnAgentVersion :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersion t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -638,60 +787,60 @@ testIncreaseConnAgentVersion t = do -- version doesn't increase if incompatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB runRight_ $ do subscribeConnection alice2 bobId - exchangeGreetingsMsgId 6 alice2 bobId bob aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob aliceId checkVersion alice2 bobId 2 checkVersion bob aliceId 2 -- version increases if compatible disconnectAgentClient bob - bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 runRight_ $ do subscribeConnection bob2 aliceId - exchangeGreetingsMsgId 8 alice2 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 8 alice2 bobId bob2 aliceId checkVersion alice2 bobId 3 checkVersion bob2 aliceId 3 -- version doesn't decrease, even if incompatible disconnectAgentClient alice2 - alice3 <- getSMPAgentClient' 5 agentCfg {smpAgentVRange = mkVersionRange 2 2} initAgentServers testDB + alice3 <- getSMPAgentClient' 5 agentCfg {smpAgentVRange = \_ -> mkVersionRange 2 2} initAgentServers testDB runRight_ $ do subscribeConnection alice3 bobId - exchangeGreetingsMsgId 10 alice3 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 10 alice3 bobId bob2 aliceId checkVersion alice3 bobId 3 checkVersion bob2 aliceId 3 disconnectAgentClient bob2 - bob3 <- getSMPAgentClient' 6 agentCfg {smpAgentVRange = mkVersionRange 1 1} initAgentServers testDB2 + bob3 <- getSMPAgentClient' 6 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 1} initAgentServers testDB2 runRight_ $ do subscribeConnection bob3 aliceId - exchangeGreetingsMsgId 12 alice3 bobId bob3 aliceId + exchangeGreetingsMsgId_ PQEncOff 12 alice3 bobId bob3 aliceId checkVersion alice3 bobId 3 checkVersion bob3 aliceId 3 disconnectAgentClient alice3 disconnectAgentClient bob3 -checkVersion :: AgentClient -> ConnId -> Version -> ExceptT AgentErrorType IO () +checkVersion :: AgentClient -> ConnId -> Word16 -> ExceptT AgentErrorType IO () checkVersion c connId v = do ConnectionStats {connAgentVersion} <- getConnectionServers c connId - liftIO $ connAgentVersion `shouldBe` v + liftIO $ connAgentVersion `shouldBe` VersionSMPA v testIncreaseConnAgentVersionMaxCompatible :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionMaxCompatible t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -699,14 +848,14 @@ testIncreaseConnAgentVersionMaxCompatible t = do -- version increases to max compatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB disconnectAgentClient bob - bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB2 + bob2 <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB2 runRight_ $ do subscribeConnection alice2 bobId subscribeConnection bob2 aliceId - exchangeGreetingsMsgId 6 alice2 bobId bob2 aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob2 aliceId checkVersion alice2 bobId 3 checkVersion bob2 aliceId 3 disconnectAgentClient alice2 @@ -714,12 +863,12 @@ testIncreaseConnAgentVersionMaxCompatible t = do testIncreaseConnAgentVersionStartDifferentVersion :: HasCallStack => ATransport -> IO () testIncreaseConnAgentVersionStartDifferentVersion t = do - alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 2} initAgentServers testDB - bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + alice <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 2} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId) <- runRight $ do - (aliceId, bobId) <- makeConnection alice bob - exchangeGreetingsMsgId 4 alice bobId bob aliceId + (aliceId, bobId) <- makeConnection_ PQSupportOff alice bob + exchangeGreetingsMsgId_ PQEncOff 4 alice bobId bob aliceId checkVersion alice bobId 2 checkVersion bob aliceId 2 pure (aliceId, bobId) @@ -727,11 +876,11 @@ testIncreaseConnAgentVersionStartDifferentVersion t = do -- version increases to max compatible disconnectAgentClient alice - alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB + alice2 <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB runRight_ $ do subscribeConnection alice2 bobId - exchangeGreetingsMsgId 6 alice2 bobId bob aliceId + exchangeGreetingsMsgId_ PQEncOff 6 alice2 bobId bob aliceId checkVersion alice2 bobId 3 checkVersion bob aliceId 3 disconnectAgentClient alice2 @@ -968,7 +1117,7 @@ testRatchetSync t = withAgentClients2 $ \alice bob -> withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aliceId, bobId, bob2) <- setupDesynchronizedRatchet alice bob runRight $ do - ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted get alice =##> ratchetSyncP bobId RSAgreed get bob2 =##> ratchetSyncP aliceId RSAgreed @@ -1012,13 +1161,13 @@ setupDesynchronizedRatchet alice bob = do runRight_ $ do subscribeConnection bob2 aliceId - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId False + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ synchronizeRatchet bob2 aliceId PQSupportOn False 8 <- sendMessage alice bobId SMP.noMsgFlags "hello 5" get alice ##> ("", bobId, SENT 8) get bob2 =##> ratchetSyncP aliceId RSRequired - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ sendMessage bob2 aliceId SMP.noMsgFlags "hello 6" pure () pure (aliceId, bobId, bob2) @@ -1043,7 +1192,7 @@ testRatchetSyncServerOffline t = withAgentClients2 $ \alice bob -> do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1073,7 +1222,7 @@ testRatchetSyncClientRestart t = do setupDesynchronizedRatchet alice bob ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted disconnectAgentClient bob2 bob3 <- getSMPAgentClient' 3 agentCfg initAgentServers testDB2 @@ -1100,7 +1249,7 @@ testRatchetSyncSuspendForeground t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ ratchetSyncState `shouldBe` RSStarted suspendAgent bob2 0 @@ -1134,10 +1283,10 @@ testRatchetSyncSimultaneous t = do ("", "", DOWN _ _) <- nGet alice ("", "", DOWN _ _) <- nGet bob2 - ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId False + ConnectionStats {ratchetSyncState = bRSS} <- runRight $ synchronizeRatchet bob2 aliceId PQSupportOn False liftIO $ bRSS `shouldBe` RSStarted - ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId True + ConnectionStats {ratchetSyncState = aRSS} <- runRight $ synchronizeRatchet alice bobId PQSupportOn True liftIO $ aRSS `shouldBe` RSStarted withSmpServerStoreMsgLogOn t testPort $ \_ -> do @@ -1192,17 +1341,25 @@ testOnlyCreatePull = withAgentClients2 $ \alice bob -> runRight_ $ do pure r makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection alice bob = makeConnectionForUsers alice 1 bob 1 +makeConnection = makeConnection_ PQSupportOn + +makeConnection_ :: PQSupport -> AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnection_ pqEnc alice bob = makeConnectionForUsers_ pqEnc alice 1 bob 1 makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnectionForUsers alice aliceUserId bob bobUserId = do - (bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing SMSubscribe - aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo" SMSubscribe - ("", _, CONF confId _ "bob's connInfo") <- get alice +makeConnectionForUsers = makeConnectionForUsers_ PQSupportOn + +makeConnectionForUsers_ :: PQSupport -> AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers_ pqSupport alice aliceUserId bob bobUserId = do + (bobId, qInfo) <- A.createConnection alice aliceUserId True SCMInvitation Nothing (CR.IKNoPQ pqSupport) SMSubscribe + aliceId <- A.joinConnection bob bobUserId True qInfo "bob's connInfo" pqSupport SMSubscribe + ("", _, A.CONF confId pqSup' _ "bob's connInfo") <- get alice + liftIO $ pqSup' `shouldBe` pqSupport allowConnection alice bobId confId "alice's connInfo" - get alice ##> ("", bobId, CON) - get bob ##> ("", aliceId, INFO "alice's connInfo") - get bob ##> ("", aliceId, CON) + let pqEnc = CR.pqSupportToEnc pqSupport + get alice ##> ("", bobId, A.CON pqEnc) + get bob ##> ("", aliceId, A.INFO pqSupport "alice's connInfo") + get bob ##> ("", aliceId, A.CON pqEnc) pure (aliceId, bobId) testInactiveNoSubs :: ATransport -> IO () @@ -1325,8 +1482,8 @@ testBatchedSubscriptions nCreate nDel t = do a <- getSMPAgentClient' 1 agentCfg initAgentServers2 testDB b <- getSMPAgentClient' 2 agentCfg initAgentServers2 testDB2 conns <- runServers $ do - conns <- replicateM (nCreate :: Int) $ makeConnection a b - forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId + conns <- replicateM (nCreate :: Int) $ makeConnection_ PQSupportOff a b + forM_ conns $ \(aId, bId) -> exchangeGreetings_ PQEncOff a bId b aId let (aIds', bIds') = unzip $ take nDel conns delete a bIds' delete b aIds' @@ -1347,10 +1504,10 @@ testBatchedSubscriptions nCreate nDel t = do (aIds', bIds') = unzip conns' subscribe a bIds subscribe b aIds - forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId + forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId_ PQEncOff 6 a bId b aId void $ resubscribeConnections a bIds void $ resubscribeConnections b aIds - forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId 8 a bId b aId + forM_ conns' $ \(aId, bId) -> exchangeGreetingsMsgId_ PQEncOff 8 a bId b aId delete a bIds' delete b aIds' deleteFail a bIds' @@ -1389,10 +1546,10 @@ testBatchedSubscriptions nCreate nDel t = do testAsyncCommands :: IO () testAsyncCommands = withAgentClients2 $ \alice bob -> runRight_ $ do - bobId <- createConnectionAsync alice 1 "1" True SCMInvitation SMSubscribe + bobId <- createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bobId', INV (ACR _ qInfo)) <- get alice liftIO $ bobId' `shouldBe` bobId - aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" SMSubscribe + aliceId <- joinConnectionAsync bob 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe ("2", aliceId', OK) <- get bob liftIO $ aliceId' `shouldBe` aliceId ("", _, CONF confId _ "bob's connInfo") <- get alice @@ -1428,7 +1585,7 @@ testAsyncCommands = ] ackMessageAsync alice "7" bobId (baseId + 4) Nothing get alice =##> \case ("7", _, OK) -> True; _ -> False - deleteConnectionAsync alice bobId + deleteConnectionAsync alice False bobId get alice =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bobId; _ -> False get alice =##> \case ("", c, DEL_CONN) -> c == bobId; _ -> False liftIO $ noMessages alice "nothing else should be delivered to alice" @@ -1439,7 +1596,7 @@ testAsyncCommands = testAsyncCommandsRestore :: ATransport -> IO () testAsyncCommandsRestore t = do alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB - bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation SMSubscribe + bobId <- runRight $ createConnectionAsync alice 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe liftIO $ noMessages alice "alice doesn't receive INV because server is down" disconnectAgentClient alice alice' <- liftIO $ getSMPAgentClient' 2 agentCfg initAgentServers testDB @@ -1456,7 +1613,7 @@ testAcceptContactAsync = (_, qInfo) <- createConnection alice 1 True SCMContact Nothing SMSubscribe aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" SMSubscribe ("", _, REQ invId _ "bob's connInfo") <- get alice - bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" SMSubscribe + bobId <- acceptContactAsync alice "1" True invId "alice's connInfo" PQSupportOn SMSubscribe get alice =##> \case ("1", c, OK) -> c == bobId; _ -> False ("", _, CONF confId _ "alice's connInfo") <- get bob allowConnection bob aliceId confId "bob's connInfo" @@ -1498,7 +1655,7 @@ testDeleteConnectionAsync t = do (bId3, _inv) <- createConnection a 1 True SCMInvitation Nothing SMSubscribe pure ([bId1, bId2, bId3] :: [ConnId]) runRight_ $ do - deleteConnectionsAsync a connIds + deleteConnectionsAsync a False connIds get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c `elem` connIds && (e == TIMEOUT || e == NETWORK); _ -> False @@ -1508,16 +1665,244 @@ testDeleteConnectionAsync t = do liftIO $ noMessages a "nothing else should be delivered to alice" disconnectAgentClient a +testWaitDeliveryNoPending :: ATransport -> IO () +testWaitDeliveryNoPending t = do + alice <- getSMPAgentClient' 1 agentCfg initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ Nothing) -> cId == bobId; _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + 3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + get bob ##> ("", aliceId, MERR (baseId + 3) (SMP AUTH)) + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testWaitDelivery :: ATransport -> IO () +testWaitDelivery t = do + alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do + get alice ##> ("", bobId, SENT $ baseId + 3) + get alice ##> ("", bobId, SENT $ baseId + 4) + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + liftIO $ + getInAnyOrder + bob + [ \case ("", "", APC SAENone (UP _ [cId])) -> cId == aliceId; _ -> False, + \case ("", cId, APC SAEConn (Msg "how are you?")) -> cId == aliceId; _ -> False + ] + ackMessage bob aliceId (baseId + 3) Nothing + get bob =##> \case ("", c, Msg "message 1") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 4) Nothing + + -- queue wasn't deleted (DEL never reached server, see DEL_RCVQ with error), so bob can send message + 5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2" + get bob ##> ("", aliceId, SENT $ baseId + 5) + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testWaitDeliveryAUTHErr :: ATransport -> IO () +testWaitDeliveryAUTHErr t = do + alice <- getSMPAgentClient' 1 agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (_aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + deleteConnectionsAsync bob False [aliceId] + get bob =##> \case ("", cId, DEL_RCVQ _ _ Nothing) -> cId == aliceId; _ -> False + get bob =##> \case ("", cId, DEL_CONN) -> cId == aliceId; _ -> False + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> do + get alice ##> ("", bobId, MERR (baseId + 3) (SMP AUTH)) + get alice ##> ("", bobId, MERR (baseId + 4) (SMP AUTH)) + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testWaitDeliveryTimeout :: ATransport -> IO () +testWaitDeliveryTimeout t = do + alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 1, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + liftIO $ threadDelay 100000 + + withSmpServerStoreLogOn t testPort $ \_ -> do + nGet bob =##> \case ("", "", UP _ [cId]) -> cId == aliceId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + +testWaitDeliveryTimeout2 :: ATransport -> IO () +testWaitDeliveryTimeout2 t = do + alice <- getSMPAgentClient' 1 agentCfg {connDeleteDeliveryTimeout = 2, messageRetryInterval = fastMessageRetryInterval, initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers testDB + bob <- getSMPAgentClient' 2 agentCfg initAgentServers testDB2 + (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do + (aliceId, bobId) <- makeConnection alice bob + + 1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello" + get alice ##> ("", bobId, SENT $ baseId + 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId (baseId + 1) Nothing + + 2 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too" + get bob ##> ("", aliceId, SENT $ baseId + 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId (baseId + 2) Nothing + + pure (aliceId, bobId) + + runRight_ $ do + ("", "", DOWN _ _) <- nGet alice + ("", "", DOWN _ _) <- nGet bob + 3 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?" + 4 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "message 1" + deleteConnectionsAsync alice True [bobId] + get alice =##> \case ("", cId, DEL_RCVQ _ _ (Just (BROKER _ e))) -> cId == bobId && (e == TIMEOUT || e == NETWORK); _ -> False + get alice =##> \case ("", cId, DEL_CONN) -> cId == bobId; _ -> False + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + withSmpServerStoreLogOn t testPort $ \_ -> do + get alice ##> ("", bobId, SENT $ baseId + 3) + -- "message 1" not delivered + + liftIO $ + getInAnyOrder + bob + [ \case ("", "", APC SAENone (UP _ [cId])) -> cId == aliceId; _ -> False, + \case ("", cId, APC SAEConn (Msg "how are you?")) -> cId == aliceId; _ -> False + ] + liftIO $ noMessages alice "nothing else should be delivered to alice" + liftIO $ noMessages bob "nothing else should be delivered to bob" + + disconnectAgentClient alice + disconnectAgentClient bob + where + baseId = 3 + msgId = subtract baseId + testJoinConnectionAsyncReplyError :: HasCallStack => ATransport -> IO () testJoinConnectionAsyncReplyError t = do let initAgentServersSrv2 = initAgentServers {smp = userServers [noAuthSrv testSMPServer2]} a <- getSMPAgentClient' 1 agentCfg initAgentServers testDB b <- getSMPAgentClient' 2 agentCfg initAgentServersSrv2 testDB2 (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do - bId <- createConnectionAsync a 1 "1" True SCMInvitation SMSubscribe + bId <- createConnectionAsync a 1 "1" True SCMInvitation (IKNoPQ PQSupportOn) SMSubscribe ("1", bId', INV (ACR _ qInfo)) <- get a liftIO $ bId' `shouldBe` bId - aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" SMSubscribe + aId <- joinConnectionAsync b 1 "2" True qInfo "bob's connInfo" PQSupportOn SMSubscribe liftIO $ threadDelay 500000 ConnectionStats {rcvQueuesInfo = [], sndQueuesInfo = [SndQueueInfo {}]} <- getConnectionServers b aId pure (aId, bId) @@ -1714,7 +2099,7 @@ testSwitchDelete servers = do stats <- switchConnectionAsync a "" bId liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] - deleteConnectionAsync a bId + deleteConnectionAsync a False bId get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False get a =##> \case ("", c, DEL_RCVQ _ _ Nothing) -> c == bId; _ -> False get a =##> \case ("", c, DEL_CONN) -> c == bId; _ -> False @@ -1735,7 +2120,7 @@ testAbortSwitchStarted servers = do liftIO $ rcvSwchStatuses' stats `shouldMatchList` [Just RSSwitchStarted] phaseRcv a bId SPStarted [Just RSSendingQADD, Nothing] -- repeat switch is prohibited - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ switchConnectionAsync a "" bId -- abort current switch stats' <- abortConnectionSwitch a bId liftIO $ rcvSwchStatuses' stats' `shouldMatchList` [Nothing] @@ -1857,7 +2242,7 @@ testCannotAbortSwitchSecured servers = do withA' $ \a -> do phaseRcv a bId SPConfirmed [Just RSSendingQADD, Nothing] phaseRcv a bId SPSecured [Just RSSendingQUSE, Nothing] - Left Agent.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId + Left A.CMD {cmdErr = PROHIBITED} <- runExceptT $ abortConnectionSwitch a bId pure () withA $ \a -> withB $ \b -> runRight_ $ do subscribeConnection a bId @@ -1988,7 +2373,7 @@ testSwitch2ConnectionsAbort1 servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -testCreateQueueAuth :: HasCallStack => Version -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int +testCreateQueueAuth :: HasCallStack => VersionSMP -> (Maybe BasicAuth, VersionSMP) -> (Maybe BasicAuth, VersionSMP) -> IO Int testCreateQueueAuth srvVersion clnt1 clnt2 = do a <- getClient 1 clnt1 testDB b <- getClient 2 clnt2 testDB2 @@ -2014,7 +2399,7 @@ testCreateQueueAuth srvVersion clnt1 clnt2 = do where getClient clientId (clntAuth, clntVersion) db = let servers = initAgentServers {smp = userServers [ProtoServerWithAuth testSMPServer clntAuth]} - smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig) {serverVRange = mkVersionRange (basicAuthSMPVersion - 1) clntVersion} + smpCfg = (defaultSMPClientConfig :: ProtocolClientConfig SMPVersion) {serverVRange = V.mkVersionRange (prevVersion basicAuthSMPVersion) clntVersion} sndAuthAlg = if srvVersion >= authCmdsSMPVersion && clntVersion >= authCmdsSMPVersion then C.AuthAlg C.SX25519 else C.AuthAlg C.SEd25519 in getSMPAgentClient' clientId agentCfg {smpCfg, sndAuthAlg} servers db @@ -2049,53 +2434,59 @@ testDeliveryReceipts = get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False ackMessage a bId 6 $ Just "" get b =##> \case ("", c, Rcvd 6) -> c == aId; _ -> False - ackMessage b aId 7 (Just "") `catchError` \e -> liftIO $ e `shouldBe` Agent.CMD PROHIBITED + ackMessage b aId 7 (Just "") `catchError` \e -> liftIO $ e `shouldBe` A.CMD PROHIBITED ackMessage b aId 7 Nothing testDeliveryReceiptsVersion :: HasCallStack => ATransport -> IO () testDeliveryReceiptsVersion t = do - a <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB - b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = mkVersionRange 1 3} initAgentServers testDB2 + a <- getSMPAgentClient' 1 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB + b <- getSMPAgentClient' 2 agentCfg {smpAgentVRange = \_ -> mkVersionRange 1 3} initAgentServers testDB2 withSmpServerStoreMsgLogOn t testPort $ \_ -> do (aId, bId) <- runRight $ do - (aId, bId) <- makeConnection a b + (aId, bId) <- makeConnection_ PQSupportOff a b checkVersion a bId 3 checkVersion b aId 3 - 4 <- sendMessage a bId SMP.noMsgFlags "hello" + (4, _) <- A.sendMessage a bId PQEncOff SMP.noMsgFlags "hello" get a ##> ("", bId, SENT 4) - get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False + get b =##> \case ("", c, Msg' 4 PQEncOff "hello") -> c == aId; _ -> False ackMessage b aId 4 $ Just "" liftIO $ noMessages a "no delivery receipt (unsupported version)" - 5 <- sendMessage b aId SMP.noMsgFlags "hello too" + (5, _) <- A.sendMessage b aId PQEncOff SMP.noMsgFlags "hello too" get b ##> ("", aId, SENT 5) - get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False + get a =##> \case ("", c, Msg' 5 PQEncOff "hello too") -> c == bId; _ -> False ackMessage a bId 5 $ Just "" liftIO $ noMessages b "no delivery receipt (unsupported version)" pure (aId, bId) disconnectAgentClient a disconnectAgentClient b - a' <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB - b' <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = mkVersionRange 1 4} initAgentServers testDB2 + a' <- getSMPAgentClient' 3 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB + b' <- getSMPAgentClient' 4 agentCfg {smpAgentVRange = supportedSMPAgentVRange} initAgentServers testDB2 runRight_ $ do subscribeConnection a' bId subscribeConnection b' aId - exchangeGreetingsMsgId 6 a' bId b' aId + exchangeGreetingsMsgId_ PQEncOff 6 a' bId b' aId checkVersion a' bId 4 checkVersion b' aId 4 - 8 <- sendMessage a' bId SMP.noMsgFlags "hello" + (8, PQEncOff) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello" get a' ##> ("", bId, SENT 8) - get b' =##> \case ("", c, Msg "hello") -> c == aId; _ -> False + get b' =##> \case ("", c, Msg' 8 PQEncOff "hello") -> c == aId; _ -> False ackMessage b' aId 8 $ Just "" get a' =##> \case ("", c, Rcvd 8) -> c == bId; _ -> False ackMessage a' bId 9 Nothing - 10 <- sendMessage b' aId SMP.noMsgFlags "hello too" + (10, PQEncOff) <- A.sendMessage b' aId PQEncOn SMP.noMsgFlags "hello too" get b' ##> ("", aId, SENT 10) - get a' =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False + get a' =##> \case ("", c, Msg' 10 PQEncOff "hello too") -> c == bId; _ -> False ackMessage a' bId 10 $ Just "" get b' =##> \case ("", c, Rcvd 10) -> c == aId; _ -> False ackMessage b' aId 11 Nothing + (12, _) <- A.sendMessage a' bId PQEncOn SMP.noMsgFlags "hello 2" + get a' ##> ("", bId, SENT 12) + get b' =##> \case ("", c, Msg' 12 PQEncOff "hello 2") -> c == aId; _ -> False + ackMessage b' aId 12 $ Just "" + get a' =##> \case ("", c, Rcvd 12) -> c == bId; _ -> False + ackMessage a' bId 13 Nothing disconnectAgentClient a' disconnectAgentClient b' @@ -2107,7 +2498,7 @@ testDeliveryReceiptsConcurrent t = t1 <- liftIO getCurrentTime concurrently_ (runClient "a" a bId) (runClient "b" b aId) t2 <- liftIO getCurrentTime - diffUTCTime t2 t1 `shouldSatisfy` (< 15) + diffUTCTime t2 t1 `shouldSatisfy` (< 60) liftIO $ noMessages a "nothing else should be delivered to alice" liftIO $ noMessages b "nothing else should be delivered to bob" where @@ -2118,7 +2509,6 @@ testDeliveryReceiptsConcurrent t = numMsgs = 100 send = runRight_ $ replicateM_ numMsgs $ do - -- liftIO $ print $ cName <> ": sendMessage" void $ sendMessage client connId SMP.noMsgFlags "hello" receive = runRight_ $ @@ -2146,7 +2536,7 @@ testDeliveryReceiptsConcurrent t = receiveLoop (n - 1) getWithTimeout :: ExceptT AgentErrorType IO (AEntityTransmission 'AEConn) getWithTimeout = do - 1000000 `timeout` get client >>= \case + 3000000 `timeout` get client >>= \case Just r -> pure r _ -> error "timeout" @@ -2256,20 +2646,26 @@ testServerMultipleIdentities = testE2ERatchetParams12 exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetings = exchangeGreetingsMsgId 4 +exchangeGreetings = exchangeGreetings_ PQEncOn + +exchangeGreetings_ :: HasCallStack => PQEncryption -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetings_ pqEnc = exchangeGreetingsMsgId_ pqEnc 4 exchangeGreetingsMsgId :: HasCallStack => Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetingsMsgId msgId alice bobId bob aliceId = do - msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello" - liftIO $ msgId1 `shouldBe` msgId +exchangeGreetingsMsgId = exchangeGreetingsMsgId_ PQEncOn + +exchangeGreetingsMsgId_ :: HasCallStack => PQEncryption -> Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetingsMsgId_ pqEnc msgId alice bobId bob aliceId = do + msgId1 <- A.sendMessage alice bobId pqEnc SMP.noMsgFlags "hello" + liftIO $ msgId1 `shouldBe` (msgId, pqEnc) get alice ##> ("", bobId, SENT msgId) - get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + get bob =##> \case ("", c, Msg' mId pq "hello") -> c == aliceId && mId == msgId && pq == pqEnc; _ -> False ackMessage bob aliceId msgId Nothing - msgId2 <- sendMessage bob aliceId SMP.noMsgFlags "hello too" + msgId2 <- A.sendMessage bob aliceId pqEnc SMP.noMsgFlags "hello too" let msgId' = msgId + 1 - liftIO $ msgId2 `shouldBe` msgId' + liftIO $ msgId2 `shouldBe` (msgId', pqEnc) get bob ##> ("", aliceId, SENT msgId') - get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + get alice =##> \case ("", c, Msg' mId pq "hello too") -> c == bobId && mId == msgId' && pq == pqEnc; _ -> False ackMessage alice bobId msgId' Nothing exchangeGreetingsMsgIds :: HasCallStack => AgentClient -> ConnId -> Int64 -> AgentClient -> ConnId -> Int64 -> ExceptT AgentErrorType IO () diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index bb1e687b3..d8354efed 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -12,7 +12,28 @@ module AgentTests.NotificationTests where -- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging) -import AgentTests.FunctionalAPITests (agentCfgV7, exchangeGreetingsMsgId, get, getSMPAgentClient', makeConnection, nGet, runRight, runRight_, switchComplete, testServerMatrix2, withAgentClientsCfg2, (##>), (=##>), pattern Msg) +import AgentTests.FunctionalAPITests + ( agentCfgV7, + createConnection, + exchangeGreetingsMsgId, + get, + getSMPAgentClient', + joinConnection, + makeConnection, + nGet, + runRight, + runRight_, + sendMessage, + switchComplete, + testServerMatrix2, + withAgentClientsCfg2, + (##>), + (=##>), + pattern CON, + pattern CONF, + pattern INFO, + pattern Msg, + ) import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Monad import Control.Monad.Except @@ -28,10 +49,10 @@ import Data.Text.Encoding (encodeUtf8) import NtfClient import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testDB3, testNtfServer, testNtfServer2) import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn) -import Simplex.Messaging.Agent +import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) -import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO) import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -46,6 +67,7 @@ import Simplex.Messaging.Transport (ATransport) import System.Directory (doesFileExist, removeFile) import Test.Hspec import UnliftIO +import Util removeFileIfExists :: FilePath -> IO () removeFileIfExists filePath = do @@ -125,8 +147,8 @@ testNtfMatrix t runTest = do it "next servers: SMP v7, NTF v2; next clients: v7/v2" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfgV7 agentCfgV7 runTest it "next servers: SMP v7, NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfg runTest it "curr servers: SMP v6, NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfg agentCfg agentCfg runTest - -- this case will cannot be supported - see RFC - xit "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest + skip "this case cannot be supported - see RFC" $ + it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest -- servers can be migrated in any order it "servers: next SMP v7, curr NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfg agentCfg agentCfg runTest it "servers: curr SMP v6, next NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfgV2 agentCfg agentCfg runTest diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 5799a0492..4bac4fb83 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,20 +1,27 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.SQLiteTests (storeTests) where +import AgentTests.EqInstances () import Control.Concurrent.Async (concurrently_) import Control.Concurrent.STM import Control.Exception (SomeException) import Control.Monad (replicateM_) +import Control.Monad.Trans.Except +import Crypto.Random (ChaChaDRG) import Data.ByteArray (ScrubbedBytes) import Data.ByteString.Char8 (ByteString) import Data.List (isInfixOf) @@ -38,9 +45,11 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn) +import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Encoding.String (StrEncoding (..)) -import Simplex.Messaging.Protocol (SubscriptionMode (..)) +import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC) import qualified Simplex.Messaging.Protocol as SMP import System.Random import Test.Hspec @@ -91,7 +100,7 @@ storeTests = do testForeignKeysEnabled describe "db methods" $ do describe "Queue and Connection management" $ do - describe "createRcvConn" $ do + describe "create Rcv connection" $ do testCreateRcvConn testCreateRcvConnRandomId testCreateRcvConnDuplicate @@ -172,7 +181,17 @@ testForeignKeysEnabled = `shouldThrow` (\e -> SQL.sqlError e == SQL.ErrorConstraint) cData1 :: ConnData -cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk} +cData1 = + ConnData + { userId = 1, + connId = "conn1", + connAgentVersion = VersionSMPA 1, + enableNtfs = True, + lastExternalSndId = 0, + deleted = False, + ratchetSyncState = RSOk, + pqSupport = CR.PQSupportOn + } testPrivateAuthKey :: C.APrivateAuthKey testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" @@ -203,7 +222,7 @@ rcvQueue1 = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } @@ -224,9 +243,15 @@ sndQueue1 = primary = True, dbReplaceQueueId = Nothing, sndSwchStatus = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } +createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue)) +createRcvConn db g cData rq cMode = runExceptT $ do + connId <- ExceptT $ createNewConn db g cData cMode + rq' <- ExceptT $ updateNewConnRcv db connId rq + pure (connId, rq') + testCreateRcvConn :: SpecWith SQLiteStore testCreateRcvConn = it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do @@ -312,8 +337,8 @@ testDeleteRcvConn = Right (_, rq) <- createRcvConn db g cData1 rcvQueue1 SCMInvitation getConn db "conn1" `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq)) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound @@ -324,8 +349,8 @@ testDeleteSndConn = Right (_, sq) <- createSndConn db g cData1 sndQueue1 getConn db "conn1" `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq)) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound @@ -337,8 +362,8 @@ testDeleteDuplexConn = Right sq <- upgradeRcvConnToDuplex db "conn1" sndQueue1 getConn db "conn1" `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq])) - deleteConn db "conn1" - `shouldReturn` () + deleteConn db Nothing "conn1" + `shouldReturn` Just "conn1" getConn db "conn1" `shouldReturn` Left SEConnNotFound @@ -362,7 +387,7 @@ testUpgradeRcvConnToDuplex = sndSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1 + smpClientVersion = VersionSMPC 1 } upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType CSnd) @@ -391,7 +416,7 @@ testUpgradeSndConnToDuplex = rcvSwchStatus = Nothing, primary = True, dbReplaceQueueId = Nothing, - smpClientVersion = 1, + smpClientVersion = VersionSMPC 1, clientNtfCreds = Nothing, deleteErrors = 0 } @@ -459,7 +484,8 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = { integrity = MsgOk, recipient = (unId internalId, ts), sndMsgId = externalSndId, - broker = (brokerId, ts) + broker = (brokerId, ts), + pqEncryption = CR.PQEncOn }, msgType = AM_A_MSG_, msgFlags = SMP.noMsgFlags, @@ -497,6 +523,7 @@ mkSndMsgData internalId internalSndId internalHash = msgType = AM_A_MSG_, msgFlags = SMP.noMsgFlags, msgBody = hw, + pqEncryption = CR.PQEncOn, internalHash, prevMsgHash = internalHash } @@ -635,7 +662,7 @@ testGetPendingServerCommand st = do Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1) corrId' `shouldBe` "4" where - command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) SMSubscribe + command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO () corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId) diff --git a/tests/CoreTests/BatchingTests.hs b/tests/CoreTests/BatchingTests.hs index eb9b62d3d..996d2fed1 100644 --- a/tests/CoreTests/BatchingTests.hs +++ b/tests/CoreTests/BatchingTests.hs @@ -15,7 +15,6 @@ import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Version (Version) import Test.Hspec batchingTests :: Spec @@ -253,27 +252,27 @@ testClientBatchWithLargeMessageV7 = do (length rs1', length rs2') `shouldBe` (74, 136) all lenOk [s1', s2'] `shouldBe` True -testClientStub :: IO (ProtocolClient ErrorType BrokerMsg) +testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) testClientStub = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g - atomically $ clientStub g sessId (authCmdsSMPVersion - 1) Nothing + atomically $ smpClientStub g sessId subModeSMPVersion Nothing -clientStubV7 :: IO (ProtocolClient ErrorType BrokerMsg) +clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg) clientStubV7 = do g <- C.newRandom sessId <- atomically $ C.randomBytes 32 g (rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey - atomically $ clientStub g sessId authCmdsSMPVersion thAuth_ + atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_ randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSUB = randomSUB_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion randomSUBv7 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUBv7 = randomSUB_ C.SEd25519 authCmdsSMPVersion -randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSUB_ a v sessId = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -284,13 +283,13 @@ randomSUB_ a v sessId = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, rId, Cmd SRecipient SUB) pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) corrId tForAuth -randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd = randomSUBCmd_ C.SEd25519 -randomSUBCmdV7 :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmdV7 = randomSUBCmd_ C.SEd25519 -- same as v6 -randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) +randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg) randomSUBCmd_ a c = do g <- C.newRandom rId <- atomically $ C.randomBytes 24 g @@ -298,12 +297,12 @@ randomSUBCmd_ a c = do mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB) randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) -randomSEND = randomSEND_ C.SEd25519 (authCmdsSMPVersion - 1) +randomSEND = randomSEND_ C.SEd25519 subModeSMPVersion randomSENDv7 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSENDv7 = randomSEND_ C.SX25519 authCmdsSMPVersion -randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) +randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString)) randomSEND_ a v sessId len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g @@ -315,7 +314,7 @@ randomSEND_ a v sessId len = do TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, sId, Cmd SSender $ SEND noMsgFlags msg) pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) corrId tForAuth -testTHandleParams :: Version -> ByteString -> THandleParams +testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion testTHandleParams v sessionId = THandleParams { sessionId, @@ -326,20 +325,20 @@ testTHandleParams v sessionId = batch = True } -testTHandleAuth :: Version -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) +testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth) testTHandleAuth v g (C.APublicAuthKey a k) = case a of C.SX25519 | v >= authCmdsSMPVersion -> do (_, privKey) <- atomically $ C.generateKeyPair g pure $ Just THandleAuth {peerPubKey = k, privKey} _ -> pure Nothing -randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd = randomSENDCmd_ C.SEd25519 -randomSENDCmdV7 :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmdV7 = randomSENDCmd_ C.SX25519 -randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) +randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg) randomSENDCmd_ a c len = do g <- C.newRandom sId <- atomically $ C.randomBytes 24 g diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs index 39bc17c4b..35e82d6d2 100644 --- a/tests/CoreTests/CryptoTests.hs +++ b/tests/CoreTests/CryptoTests.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -Wno-orphans #-} module CoreTests.CryptoTests (cryptoTests) where @@ -13,6 +15,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.Lazy as LT import qualified Data.Text.Lazy.Encoding as LE +import Data.Type.Equality import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Crypto.SNTRUP761.Bindings @@ -91,6 +94,16 @@ cryptoTests = do describe "sntrup761" $ it "should enc/dec key" testSNTRUP761 +instance Eq C.APublicKey where + C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of + Just Refl -> k == k' + Nothing -> False + +instance Eq C.APrivateKey where + C.APrivateKey a k == C.APrivateKey a' k' = case testEquality a a' of + Just Refl -> k == k' + Nothing -> False + testPadUnpadFile :: IO () testPadUnpadFile = do let f = "tests/tmp/testpad" diff --git a/tests/CoreTests/ProtocolErrorTests.hs b/tests/CoreTests/ProtocolErrorTests.hs index 6dc6f2c02..7b1a7b813 100644 --- a/tests/CoreTests/ProtocolErrorTests.hs +++ b/tests/CoreTests/ProtocolErrorTests.hs @@ -12,7 +12,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) -import Simplex.FileTransfer.Protocol (XFTPErrorType (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (..)) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 70f2d93ab..91722228b 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -1,9 +1,11 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TypeApplications #-} module CoreTests.TRcvQueuesTests where +import AgentTests.EqInstances () import qualified Data.List.NonEmpty as L import qualified Data.Map as M import qualified Data.Set as S @@ -11,7 +13,7 @@ import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId) import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..)) import qualified Simplex.Messaging.Agent.TRcvQueues as RQ import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol (SMPServer) +import Simplex.Messaging.Protocol (SMPServer, pattern VersionSMPC) import Test.Hspec import UnliftIO @@ -136,7 +138,7 @@ dummyRQ userId server connId = primary = True, dbReplaceQueueId = Nothing, rcvSwchStatus = Nothing, - smpClientVersion = 123, + smpClientVersion = VersionSMPC 123, clientNtfCreds = Nothing, deleteErrors = 0 } diff --git a/tests/CoreTests/VersionRangeTests.hs b/tests/CoreTests/VersionRangeTests.hs index be02e38b7..cef556376 100644 --- a/tests/CoreTests/VersionRangeTests.hs +++ b/tests/CoreTests/VersionRangeTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} module CoreTests.VersionRangeTests where @@ -8,6 +9,7 @@ module CoreTests.VersionRangeTests where import GHC.Generics (Generic) import Generic.Random (genericArbitraryU) import Simplex.Messaging.Version +import Simplex.Messaging.Version.Internal import Test.Hspec import Test.Hspec.QuickCheck (modifyMaxSuccess) import Test.QuickCheck @@ -16,6 +18,10 @@ data V = V1 | V2 | V3 | V4 | V5 deriving (Eq, Enum, Ord, Generic, Show) instance Arbitrary V where arbitrary = genericArbitraryU +data T + +instance VersionScope T + versionRangeTests :: Spec versionRangeTests = modifyMaxSuccess (const 1000) $ do describe "VersionRange construction" $ do @@ -25,31 +31,31 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do (pure $! vr 2 1) `shouldThrow` anyErrorCall describe "compatible version" $ do it "should choose mutually compatible max version" $ do - (vr 1 1, vr 1 1) `compatible` Just 1 - (vr 1 1, vr 1 2) `compatible` Just 1 - (vr 1 2, vr 1 2) `compatible` Just 2 - (vr 1 2, vr 2 3) `compatible` Just 2 - (vr 1 3, vr 2 3) `compatible` Just 3 - (vr 1 3, vr 2 4) `compatible` Just 3 + (vr 1 1, vr 1 1) `compatible` Just (Version 1) + (vr 1 1, vr 1 2) `compatible` Just (Version 1) + (vr 1 2, vr 1 2) `compatible` Just (Version 2) + (vr 1 2, vr 2 3) `compatible` Just (Version 2) + (vr 1 3, vr 2 3) `compatible` Just (Version 3) + (vr 1 3, vr 2 4) `compatible` Just (Version 3) (vr 1 2, vr 3 4) `compatible` Nothing it "should check if version is compatible" $ do - isCompatible (1 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 2) `shouldBe` True - isCompatible (2 :: Version) (vr 1 1) `shouldBe` False - isCompatible (1 :: Version) (vr 2 2) `shouldBe` False + isCompatible @T (Version 1) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 2) `shouldBe` True + isCompatible @T (Version 2) (vr 1 1) `shouldBe` False + isCompatible @T (Version 1) (vr 2 2) `shouldBe` False it "compatibleVersion should pass isCompatible check" . property $ \((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) -> min1 > max1 || min2 > max2 -- one of ranges is invalid, skip testing it - || let w = fromIntegral . fromEnum - vr1 = mkVersionRange (w min1) (w max1) :: VersionRange - vr2 = mkVersionRange (w min2) (w max2) :: VersionRange + || let w = Version . fromIntegral . fromEnum + vr1 = mkVersionRange (w min1) (w max1) :: VersionRange T + vr2 = mkVersionRange (w min2) (w max2) :: VersionRange T in case compatibleVersion vr1 vr2 of Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2 _ -> True where - vr = mkVersionRange - compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation + vr v1 v2 = mkVersionRange (Version v1) (Version v2) + compatible :: (VersionRange T, VersionRange T) -> Maybe (Version T) -> Expectation (vr1, vr2) `compatible` v = do (vr1, vr2) `checkCompatible` v (vr2, vr1) `checkCompatible` v diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 43558a86c..5a2dbb8de 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -34,6 +34,7 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Protocol (NtfResponse) import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Push.APNS @@ -70,7 +71,7 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" ntfTestStoreLogFile :: FilePath ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" -testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do @@ -114,8 +115,8 @@ ntfServerCfg = ntfServerCfgV2 :: NtfServerConfig ntfServerCfgV2 = ntfServerCfg - { ntfServerVRange = mkVersionRange 1 authBatchCmdsNTFVersion, - smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange 4 authCmdsSMPVersion}} + { ntfServerVRange = mkVersionRange initialNTFVersion authBatchCmdsNTFVersion, + smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}} } withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a @@ -139,7 +140,7 @@ withNtfServerOn t port' = withNtfServerThreadOn t port' . const withNtfServer :: ATransport -> IO a -> IO a withNtfServer t = withNtfServerOn t ntfTestPort -runNtfTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a +runNtfTest :: forall c a. Transport c => (THandleNTF c -> IO a) -> IO a runNtfTest test = withNtfServer (transport @c) $ testNtfClient test ntfServerTest :: @@ -147,10 +148,10 @@ ntfServerTest :: (Transport c, Encoding smp) => TProxy c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) + IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -159,7 +160,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation +ntfTest :: Transport c => TProxy c -> (THandleNTF c -> IO ()) -> Expectation ntfTest _ test' = runNtfTest test' `shouldReturn` () data APNSMockRequest = APNSMockRequest diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index e29a292ee..e7e2018c2 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -5,7 +5,9 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} +{-# OPTIONS_GHC -Wno-orphans #-} module NtfServerTests where @@ -37,6 +39,7 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS +import Simplex.Messaging.Notifications.Transport (THandleNTF) import Simplex.Messaging.Parsers (parse, parseAll) import Simplex.Messaging.Protocol hiding (notification) import Simplex.Messaging.Transport @@ -50,30 +53,32 @@ ntfServerTests t = do ntfSyntaxTests :: ATransport -> Spec ntfSyntaxTests (ATransport t) = do - it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) + it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", NRErr $ CMD UNKNOWN) describe "NEW" $ do - it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX) - it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX) - it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH) - it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH) + it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", NRErr $ CMD SYNTAX) + it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", NRErr $ CMD SYNTAX) + it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", NRErr $ CMD NO_AUTH) + it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", NRErr $ CMD HAS_AUTH) where (>#>) :: Encoding smp => (Maybe TransmissionAuth, ByteString, ByteString, smp) -> - (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) -> + (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) -> Expectation command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command)) -sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +deriving instance Eq NtfResponse + +sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) +signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse) signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index f1ed84d68..871652f53 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -26,7 +26,7 @@ import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Server -import Simplex.Messaging.Version (VersionRange, mkVersionRange) +import Simplex.Messaging.Version (mkVersionRange) import System.Environment (lookupEnv) import System.Info (os) import Test.Hspec @@ -34,6 +34,7 @@ import UnliftIO.Concurrent import qualified UnliftIO.Exception as E import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) import UnliftIO.Timeout (timeout) +import Util testHost :: NonEmpty TransportHost testHost = "localhost" @@ -60,17 +61,17 @@ testServerStatsBackupFile :: FilePath testServerStatsBackupFile = "tests/tmp/smp-server-stats.log" xit' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) -xit' = if os == "linux" then xit else it +xit' d = if os == "linux" then skip "skipped on Linux" . it d else it d xit'' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a) xit'' d t = do ci <- runIO $ lookupEnv "CI" - (if ci == Just "true" then xit else it) d t + (if ci == Just "true" then skip "skipped on CI" . it d else it d) t -testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a +testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a testSMPClient = testSMPClientVR supportedClientSMPRelayVRange -testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRange -> (THandle c -> m a) -> m a +testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a testSMPClientVR vr client = do Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do @@ -109,7 +110,7 @@ cfg = } cfgV7 :: ServerConfig -cfgV7 = cfg {smpServerVRange = mkVersionRange 4 authCmdsSMPVersion} +cfgV7 = cfg {smpServerVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion} withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile} @@ -148,16 +149,16 @@ withSmpServer t = withSmpServerOn t testPort withSmpServerV7 :: HasCallStack => ATransport -> IO a -> IO a withSmpServerV7 t = withSmpServerConfigOn t cfgV7 testPort . const -runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a +runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandleSMP c -> IO a) -> IO a runSmpTest test = withSmpServer (transport @c) $ testSMPClient test -runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestN = runSmpTestNCfg cfg supportedClientSMPRelayVRange -runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> Int -> (HasCallStack => [THandle c] -> IO a) -> IO a +runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a runSmpTestNCfg srvCfg clntVR nClients test = withSmpServerConfigOn (transport @c) srvCfg testPort $ \_ -> run nClients [] where - run :: Int -> [THandle c] -> IO a + run :: Int -> [THandleSMP c] -> IO a run 0 hs = test hs run n hs = testSMPClientVR clntVR $ \h -> run (n - 1) (h : hs) @@ -169,7 +170,7 @@ smpServerTest :: IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h where - tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () + tPut' :: THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO () tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp) [Right ()] <- tPut h [Right (sig, t')] @@ -178,33 +179,33 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation +smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> IO ()) -> Expectation smpTest _ test' = runSmpTest test' `shouldReturn` () -smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation +smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2 = smpTest2Cfg cfg supportedClientSMPRelayVRange -smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation +smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest2Cfg srvCfg clntVR _ test' = runSmpTestNCfg srvCfg clntVR 2 _test `shouldReturn` () where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest3 _ test' = smpTestN 3 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" -smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation smpTest4 _ test' = smpTestN 4 _test where - _test :: HasCallStack => [THandle c] -> IO () + _test :: HasCallStack => [THandleSMP c] -> IO () _test [h1, h2, h3, h4] = test' h1 h2 h3 h4 _test _ = error "expected 4 handles" diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index d6938fa0f..03935fed5 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -8,7 +9,9 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# OPTIONS_GHC -Wno-orphans #-} module ServerTests where @@ -23,6 +26,7 @@ import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import qualified Data.Set as S +import Data.Type.Equality import GHC.Stack (withFrozenCallStack) import SMPClient import qualified Simplex.Messaging.Crypto as C @@ -74,13 +78,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh) pattern Msg :: MsgId -> MsgBody -> BrokerMsg pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} -sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (sgn, tToSend) tGet1 h -signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) +signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg) signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd) Right () <- tPut1 h (authorize tForAuth, tToSend) @@ -94,12 +98,12 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do _sx448 -> undefined -- ghc8107 fails to the branch excluded by types #endif -tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) +tPut1 :: Transport c => THandle v c -> SentRawTransmission -> IO (Either TransportError ()) tPut1 h t = do [r] <- tPut h [Right t] pure r -tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd) +tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd) tGet1 h = do [r] <- liftIO $ tGet h pure r @@ -380,7 +384,7 @@ testSwitchSub (ATransport t) = Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3) (ok3, OK) #== "accepts ACK from the 2nd TCP connection" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" @@ -551,12 +555,12 @@ testWithStoreLog at@(ATransport t) = logSize testStoreLogFile `shouldReturn` 1 removeFile testStoreLogFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () logSize :: FilePath -> IO Int @@ -649,12 +653,12 @@ testRestoreMessages at@(ATransport t) = removeFile testStoreMsgsFile removeFile testServerStatsBackupFile where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation @@ -723,15 +727,15 @@ testRestoreExpireMessages at@(ATransport t) = Right ServerStatsData {_msgExpired} <- strDecode <$> B.readFile testServerStatsBackupFile _msgExpired `shouldBe` 2 where - runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation + runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation runTest _ test' server = do testSMPClient test' `shouldReturn` () killThread server - runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation + runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation runClient _ test' = testSMPClient test' `shouldReturn` () -createAndSecureQueue :: Transport c => THandle c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) +createAndSecureQueue :: Transport c => THandleSMP c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret) createAndSecureQueue h sPub = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g @@ -747,7 +751,7 @@ testTiming (ATransport t) = describe "should have similar time for auth error, whether queue exists or not, for all key types" $ forM_ timingTests $ \tst -> it (testName tst) $ - smpTest2Cfg cfgV7 (mkVersionRange 4 authCmdsSMPVersion) t $ \rh sh -> + smpTest2Cfg cfgV7 (mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion) t $ \rh sh -> testSameTiming rh sh tst where testName :: (C.AuthAlg, C.AuthAlg, Int) -> String @@ -766,7 +770,7 @@ testTiming (ATransport t) = ] timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const similarTime t1 t2 = abs (t2 / t1 - 1) < 0.15 -- normally the difference between "no queue" and "wrong key" is less than 5% - testSameTiming :: forall c. Transport c => THandle c -> THandle c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation + testSameTiming :: forall c. Transport c => THandleSMP c -> THandleSMP c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation testSameTiming rh sh (C.AuthAlg goodKeyAlg, C.AuthAlg badKeyAlg, n) = do g <- C.newRandom (rPub, rKey) <- atomically $ C.generateAuthKeyPair goodKeyAlg g @@ -787,7 +791,7 @@ testTiming (ATransport t) = runTimingTest sh badKey sId $ _SEND "hello" where - runTimingTest :: PartyI p => THandle c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () + runTimingTest :: PartyI p => THandleSMP c -> C.APrivateAuthKey -> ByteString -> Command p -> IO () runTimingTest h badKey qId cmd = do threadDelay 100000 _ <- timeRepeat n $ do -- "warm up" the server @@ -837,14 +841,14 @@ testMessageNotifications (ATransport t) = Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" Resp "" _ (NMSG _ _) <- tGet1 nh2 - 1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" - 1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection" @@ -864,7 +868,7 @@ testMsgExpireOnSend t = testSMPClient @c $ \rh -> do Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -884,7 +888,7 @@ testMsgExpireOnInterval t = signSendRecv rh rKey ("2", rId, SUB) >>= \case Resp "2" _ OK -> pure () r -> unexpected r - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing should be delivered" @@ -903,7 +907,7 @@ testMsgNOTExpireOnInterval t = testSMPClient @c $ \rh -> do Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB) (dec mId msg, Right "hello (should NOT expire)") #== "delivered" - 1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case + 1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case Nothing -> return () Just _ -> error "nothing else should be delivered" @@ -919,6 +923,15 @@ sampleSig = Just $ TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8B noAuth :: (Char, Maybe BasicAuth) noAuth = ('A', Nothing) +deriving instance Eq TransmissionAuth + +instance Eq C.ASignature where + C.ASignature a s == C.ASignature a' s' = case testEquality a a' of + Just Refl -> s == s' + _ -> False + +deriving instance Eq (C.Signature a) + syntaxTests :: ATransport -> Spec syntaxTests (ATransport t) = do it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) diff --git a/tests/Util.hs b/tests/Util.hs new file mode 100644 index 000000000..a52fee32c --- /dev/null +++ b/tests/Util.hs @@ -0,0 +1,6 @@ +module Util where + +import Test.Hspec + +skip :: String -> SpecWith a -> SpecWith a +skip = before_ . pendingWith diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 46d3d4dd8..999746858 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -21,7 +21,8 @@ import Data.List (find, isSuffixOf) import Data.Maybe (fromJust) import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3) import Simplex.FileTransfer.Description (FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, mb, qrSizeLimit, pattern ValidFileDescription) -import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType (AUTH)) +import Simplex.FileTransfer.Protocol (FileParty (..)) +import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) import Simplex.Messaging.Agent (AgentClient, disconnectAgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..)) diff --git a/tests/XFTPServerTests.hs b/tests/XFTPServerTests.hs index a11ba515a..451406275 100644 --- a/tests/XFTPServerTests.hs +++ b/tests/XFTPServerTests.hs @@ -20,9 +20,9 @@ import Data.List (isInfixOf) import ServerTests (logSize) import Simplex.FileTransfer.Client import Simplex.FileTransfer.Description (kb) -import Simplex.FileTransfer.Protocol (FileInfo (..), XFTPErrorType (..)) +import Simplex.FileTransfer.Protocol (FileInfo (..)) import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..)) -import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..)) +import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..)) import Simplex.Messaging.Client (ProtocolClientError (..)) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC