Merge remote-tracking branch 'origin/master' into ab/bench-target

This commit is contained in:
Alexander Bondarenko
2024-03-15 18:18:13 +02:00
65 changed files with 3762 additions and 1543 deletions
+2 -1
View File
@@ -1,5 +1,5 @@
name: simplexmq
version: 5.6.0.0
version: 5.6.0.2
synopsis: SimpleXMQ message broker
description: |
This package includes <./docs/Simplex-Messaging-Server.html server>,
@@ -74,6 +74,7 @@ dependencies:
- unliftio-core == 0.2.*
- websockets == 0.12.*
- yaml == 0.11.*
- zstd == 0.1.3.*
flags:
swift:
+36
View File
@@ -0,0 +1,36 @@
# Post-quantum double ratchet implementation
See [the previous doc](https://github.com/simplex-chat/simplex-chat/blob/stable/docs/rfcs/2023-09-30-pq-double-ratchet.md).
The main implementation consideration is that it should be both backwards and forwards compatible, to allow changing the connection DR to/from using PQ primitives (although client version downgrade may be impossible in this case), and also to decide whether to use PQ primitive on per-connection basis:
- use without links (in SMP confirmation or in SMP invitation via address or via member), don't use with links (as they would be too large).
- use in small groups, don't use in large groups.
Also note that for DR to work we need to have 2 KEMs running in parallel.
Possible combinations (assuming both clients support PQ):
| Stage | No PQ kem | PQ key sent | PQ key + PQ ct sent |
|:------------:|:---------:|:-----------:|:-------------------:|
| inv | + | + | - |
| conf, in reply to: <br>no-pq inv <br>pq inv | &nbsp;<br>+<br>+ | &nbsp;<br>+<br>- | &nbsp;<br>-<br>+ |
| 1st msg, in reply to:<br>no-pq conf<br>pq/pq+ct conf | &nbsp;<br>+<br>+ | &nbsp;<br>+<br>- | &nbsp;<br>-<br>+ |
| Nth msg, in reply to:<br>no-pq msg <br>pq/pq+ct msg | &nbsp;<br>+<br>+ | &nbsp;<br>+<br>- | &nbsp;<br>-<br>+ |
These rules can be reduced to:
1. initial invitation optionally has PQ key, but must not have ciphertext.
2. all subsequent messages should be allowed without PQ key/ciphertext, but:
- if the previous message had PQ key or PQ key with ciphertext, they must either have no PQ key, or have PQ key with ciphertext (PQ key without ciphertext is an error).
- if the previous message had no PQ key, they must either have no PQ key, or have PQ key without ciphertext (PQ key with ciphertext is an error).
The rules for calculating the shared secret for received/sent messages are (assuming received message is valid according to the above rules):
| sent msg ><br>V received msg | no-pq | pq | pq+ct |
|:------------------------------:|:-----------:|:-------:|:---------------:|
| no-pq | DH / DH | DH / DH | err |
| pq (sent msg was NOT pq) | DH / DH | err | DH / DH+KEM |
| pq+ct (sent msg was NOT no-pq) | DH+KEM / DH | err | DH+KEM / DH+KEM |
To summarize, the upgrade to DH+KEM secret happens in a sent message that has PQ key with ciphertext sent in reply to message with PQ key only (without ciphertext), and the downgrade to DH secret happens in the message that has no PQ key.
The type for sending PQ key with optional ciphertext is `Maybe E2ERachetKEM` where `data E2ERachetKEM = E2ERachetKEM KEMPublicKey (Maybe KEMCiphertext)`, and for SMP invitation it will be simply `Maybe KEMPublicKey`. Possibly, there is a way to encode the rules above in the types, these types don't constrain possible transitions to valid ones.
+92
View File
@@ -0,0 +1,92 @@
# Migrating existing connections to post-quantum double ratchet algorithm
## Problem
Post-quantum variant of double ratchet algorithm represents an almost full-stack change affecting all parts of the protocol stack except client-server protocol (SMP):
- double-ratchet end-to-end encryption: different encoding (additional large keys require byte-strings larger than 255 bytes with 2-byte length prefixes) and larger message headers (increased by ~2200 bytes).
- agent-agent protocol: a smaller maximum message size to accomodate larger headers and to fit in 16kb blocks, reduced by ~2200 bytes for the messages and by almost ~4000 bytes for connection information.
- chat protocol: also a smaller message size compensated by zstd comression of JSON messages.
We want the versioning that achieves these objectives:
- all changes in all protocol layers happen at the same time, when both clients support it.
- ability to downgrade the clients to the previous version without losing connection.
- ability to opt-in into this functionality via "experimental" feature toggle, that enables post-quantum encryption in connections when both contacts enable this toggle.
To have ability to downgrade the clients we have two options:
- roll-out this functionality in two stages: 1) roll-out clients support but do not enable the new version, and then 2) upgrade client version. The problem here is that the clients won't be able to opt-in into this experiment.
- make offered range dependent on experimental feature being enabled. Currently we have an option to enable PQ encryption in agent API, and this option can be used as a proxy to maxium supported protocol version - if the option is passed, it can be seen as an indication that higher version range (or version) should offered (or accepted).
## Solution
Currently ratchet state stores version range. It's unclear what was the intended semantics of that version range - it simply stores the offered/supported version range at the time ratchet was initialised, but only a high bound is used to send in message headers, and it is never upgraded. In JSON this range is encoded as tuple (an array of two elements in JSON).
We could continue using this range with the meaning of the lower bound to be "currently used ratchet version" and the meaning of higher boundary to be "maximum supported ratchet version". We could also use the version communicated in message headers to upgrade ratchet version, with the condition that upgrade should only happen if both sides want it. Currently it's defined by pqEnableKEM property in ratchet state. We could also make it more explicit by defining maximum version to which ratchet should upgrade. Given that irreversible upgrades are not very common, it is probably ok to keep it implicit.
We can define a better type than VersionRange to reflect semantics of the range in ratchet (current/max supported range), but for backward compatibility it needs to be encoded in the same way as now.
To summarize, the proposed solution for ratchet versioning is:
- define ratchet versions as new type to include current and maximum allowed versions, where maximum allowed will be either the same or lower than maximum supported based on PQ option (in 5.6), and in 5.7 it will be changed to maximum supported, so version starts upgrading independently from PQ being enabled.
- make encodings in ratchet depend on current version (in curent code it depends on max version).
- include max allowed in message header.
- upgrade current if in range on each new message if less than max and higher than current (same as we do for connections).
- increase max allowed once PQ is enabled (only in 5.6). Make max allowed the same as max supported (global constant).
```haskell
data RatchetVR = RatchetVR
{ currentVersion :: Version,
maxAllowedVersion :: Version
}
instance ToJSON RatchetVR where
toEncoding (RatchetVR v1 v2) = toEncoding (v1, v2)
toJSON (RatchetVR v1 v2) = toJSON (v1, v2)
instance FromJSON RatchetVR where
parseJSON v = do
-- this also verifies that v2 > v1 (although we could remove JSON instances for VersionRange)
VersionRange v1 v2 <- parseJSON v
pure $ RatchetVR v1 v2
```
For connections, we could also make version used for the purposes of encoding dependent on the PQ being enabled, and version for decoding taken from message header, but then we'd have to not only upgrade ratchets but the connection as well every time PQ mode changes.
Another suggestion to ensure that correct version range is used in correct contexts could be:
- using different newtypes for different version ranges.
- define generic type class for version aware encoding that would also accept only specific type class for the version to use the correct range. This may be justified as there will be several version-aware encodings, and not just the protocol as now.
```haskell
class Ord v => EncodingV v a where
{-# MINIMAL smpEncodeV, (smpDecodeV | smpVP) #-}
smpEncodeV :: v -> a -> ByteString
-- default decode uses parser
smpDecodeV :: v -> ByteString -> Either String a
smpDecodeV = parseAll . smpVP
-- default parser decodes from length-specified bytestring
smpVP :: v -> Parser a
smpVP v = smpDecodeV v <$?> smpP
```
The version will be passed from currently agreed version, it may only change when message is received, not when message is sent. The version will not be extracted from the encoding itself as it happens now in ratchet encodings.
## Various options how the problem can be simplified
1. Do not support connection downgrade once both devices upgraded. If applied to all existing connections then it is a bad option, as it would disrupt some important conversations.
2. Do not provide ability to opt-in into PQ encryption until v5.7 where it will be rolled out automatically. That is also suboptimal, as it won't allow announcing technology design and have testing outside of the team devices.
3. The logic explained above where connection upgrade and downgrade is possible and applied to all existing connections if both parties consent to it. There are these important downsides:
- complexity of this logic
- regression risks when this logic is removed.
- some non-coordinated upgrades of existing, potentially important conversations, simply because two users opt-in into the experiment without any expectation that another side also opts-in.
4. Apply upgrade/downgrade logic and enable PQ encryption as opt-in, based on the toggle in the UX, only for the new connections. This seems the least risky, and also simpler than option 3, as it would only apply to the new connections, and both users will have to enable experimental toggle prior to connecting.
Option 4 seems the best trade-off, and has these sub-options regarding where it is controlled:
a) in chat based on connection flag. Chat will pass PQ options only to connections that were created when experimental option was enabled.
b) in agent - there will be additional logic to ignore PQ option for existing connections.
c) both in chat and in agent.
Option 4a seems better, as it would:
- simplify agent code
- minimise required changes when releasing v5.7 (as we do want that all direct and small groups connections migrate to PQ encryption at the time, without any toggles)
- allow tests for connection upgrade in the currect code.
+14 -1
View File
@@ -5,7 +5,7 @@ cabal-version: 1.12
-- see: https://github.com/sol/hpack
name: simplexmq
version: 5.6.0.0
version: 5.6.0.2
synopsis: SimpleXMQ message broker
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
<./docs/Simplex-Messaging-Client.html client> and
@@ -103,9 +103,12 @@ library
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem
Simplex.Messaging.Agent.TRcvQueues
Simplex.Messaging.Client
Simplex.Messaging.Client.Agent
Simplex.Messaging.Compression
Simplex.Messaging.Crypto
Simplex.Messaging.Crypto.File
Simplex.Messaging.Crypto.Lazy
@@ -158,6 +161,7 @@ library
Simplex.Messaging.Transport.WebSockets
Simplex.Messaging.Util
Simplex.Messaging.Version
Simplex.Messaging.Version.Internal
Simplex.RemoteControl.Client
Simplex.RemoteControl.Discovery
Simplex.RemoteControl.Discovery.Multicast
@@ -226,6 +230,7 @@ library
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -299,6 +304,7 @@ executable ntf-server
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -372,6 +378,7 @@ executable smp-agent
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -445,6 +452,7 @@ executable smp-server
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -518,6 +526,7 @@ executable xftp
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -591,6 +600,7 @@ executable xftp-server
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
@@ -612,6 +622,7 @@ test-suite simplexmq-test
AgentTests
AgentTests.ConnectionRequestTests
AgentTests.DoubleRatchetTests
AgentTests.EqInstances
AgentTests.FunctionalAPITests
AgentTests.MigrationTests
AgentTests.NotificationTests
@@ -634,6 +645,7 @@ test-suite simplexmq-test
ServerTests
SMPAgentClient
SMPClient
Util
XFTPAgent
XFTPCLI
XFTPClient
@@ -702,6 +714,7 @@ test-suite simplexmq-test
, unliftio-core ==0.2.*
, websockets ==0.12.*
, yaml ==0.11.*
, zstd ==0.1.3.*
default-language: Haskell2010
if flag(swift)
cpp-options: -DswiftJSON
+13 -8
View File
@@ -35,6 +35,7 @@ import Control.Monad.Reader
import Data.Bifunctor (first)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Coerce (coerce)
import Data.Composition ((.:))
import Data.Either (rights)
import Data.Int (Int64)
@@ -52,8 +53,8 @@ import Simplex.FileTransfer.Client.Main
import Simplex.FileTransfer.Crypto
import Simplex.FileTransfer.Description
import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..))
import qualified Simplex.FileTransfer.Protocol as XFTP
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
import qualified Simplex.FileTransfer.Transport as XFTP
import Simplex.FileTransfer.Types
import Simplex.FileTransfer.Util (removePath, uniqueCombine)
import Simplex.Messaging.Agent.Client
@@ -389,21 +390,25 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
where
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)])
encryptFileForUpload SndFile {key, nonce, srcFile} fsEncPath = do
encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do
let CryptoFile {filePath} = srcFile
fileName = takeFileName filePath
fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile
when (fileSize > maxFileSize) $ throwError $ INTERNAL "max file size exceeded"
when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded"
let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing}
fileSize' = fromIntegral (B.length fileHdr) + fileSize
chunkSizes = prepareChunkSizes $ fileSize' + fileSizeLen + authTagSize
chunkSizes' = map fromIntegral chunkSizes
encSize = sum chunkSizes'
payloadSize = fileSize' + fileSizeLen + authTagSize
chunkSizes <- case redirect of
Nothing -> pure $ prepareChunkSizes payloadSize
Just _ -> case singleChunkSize payloadSize of
Nothing -> throwError $ INTERNAL "max file size exceeded for redirect"
Just chunkSize -> pure [chunkSize]
let encSize = sum $ map fromIntegral chunkSizes
void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath
digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath
let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes
chunkDigests <- map FileDigest <$> mapM (liftIO . getChunkDigest) chunkSpecs
pure (FileDigest digest, zip chunkSpecs chunkDigests)
chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs
pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests)
chunkCreated :: SndFileChunk -> Bool
chunkCreated SndFileChunk {replicas} =
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
+1 -1
View File
@@ -57,7 +57,7 @@ import UnliftIO.Directory
data XFTPClient = XFTPClient
{ http2Client :: HTTP2Client,
transportSession :: TransportSession FileResponse,
thParams :: THandleParams,
thParams :: THandleParams XFTPVersion,
config :: XFTPClientConfig
}
+18 -4
View File
@@ -16,9 +16,11 @@ module Simplex.FileTransfer.Client.Main
xftpClientCLI,
cliSendFile,
cliSendFileOpts,
singleChunkSize,
prepareChunkSizes,
prepareChunkSpecs,
maxFileSize,
maxFileSizeHard,
fileSizeLen,
getChunkDigest,
SentRecipientReplica (..),
@@ -41,7 +43,7 @@ import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Maybe (fromMaybe, listToMaybe)
import qualified Data.Text as T
import Data.Word (Word32)
import GHC.Records (HasField (getField))
@@ -76,12 +78,17 @@ import UnliftIO.Directory
xftpClientVersion :: String
xftpClientVersion = "1.0.1"
-- | Soft limit for XFTP clients. Should be checked and reported to user.
maxFileSize :: Int64
maxFileSize = gb 1
maxFileSizeStr :: String
maxFileSizeStr = B.unpack . strEncode $ FileSize maxFileSize
-- | Hard internal limit for XFTP agent after which it refuses to prepare chunks.
maxFileSizeHard :: Int64
maxFileSizeHard = gb 5
fileSizeLen :: Int64
fileSizeLen = 8
@@ -214,13 +221,13 @@ data SentFileChunk = SentFileChunk
digest :: FileDigest,
replicas :: [SentFileChunkReplica]
}
deriving (Eq, Show)
deriving (Show)
data SentFileChunkReplica = SentFileChunkReplica
{ server :: XFTPServer,
recipients :: [(ChunkReplicaId, C.APrivateAuthKey)]
}
deriving (Eq, Show)
deriving (Show)
data SentRecipientReplica = SentRecipientReplica
{ chunkNo :: Int,
@@ -407,7 +414,8 @@ getChunkDigest :: XFTPChunkSpec -> IO ByteString
getChunkDigest XFTPChunkSpec {filePath = chunkPath, chunkOffset, chunkSize} =
withFile chunkPath ReadMode $ \h -> do
hSeek h AbsoluteSeek $ fromIntegral chunkOffset
LC.sha256Hash <$> LB.hGet h (fromIntegral chunkSize)
chunk <- LB.hGet h (fromIntegral chunkSize)
pure $! LC.sha256Hash chunk
cliReceiveFile :: ReceiveOptions -> ExceptT CLIError IO ()
cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, verbose, yes} =
@@ -522,6 +530,12 @@ getFileDescription' path =
getFileDescription path >>= \case
AVFD fd -> either (throwError . CLIError) pure $ checkParty fd
singleChunkSize :: Int64 -> Maybe Word32
singleChunkSize size' =
listToMaybe $ dropWhile (< chunkSize) serverChunkSizes
where
chunkSize = fromIntegral size'
prepareChunkSizes :: Int64 -> [Word32]
prepareChunkSizes size' = prepareSizes size'
where
+1 -1
View File
@@ -227,7 +227,7 @@ validateFileDescription fd@FileDescription {size, chunks}
| otherwise = Right $ ValidFD fd
where
chunkNos = map (\FileChunk {chunkNo} -> chunkNo) chunks
chunksSize = fromIntegral . foldl' (\s FileChunk {chunkSize} -> s + unFileSize chunkSize) 0
chunksSize = foldl' (\(s :: Int64) FileChunk {chunkSize} -> s + fromIntegral (unFileSize chunkSize)) 0
encodeFileDescription :: FileDescription p -> YAMLFileDescription
encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSize, chunks, redirect} =
+17 -96
View File
@@ -1,9 +1,11 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
@@ -14,9 +16,7 @@
module Simplex.FileTransfer.Protocol where
import Control.Applicative ((<|>))
import qualified Data.Aeson.TH as J
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -25,11 +25,11 @@ import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isNothing)
import Data.Type.Equality
import Data.Word (Word32)
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake)
import Simplex.Messaging.Client (authTransmission)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol
( BasicAuth,
@@ -56,11 +56,10 @@ import Simplex.Messaging.Protocol
tParse,
)
import Simplex.Messaging.Transport (THandleParams (..), TransportError (..))
import Simplex.Messaging.Util (bshow, (<$?>))
import Simplex.Messaging.Version
import Simplex.Messaging.Util ((<$?>))
currentXFTPVersion :: Version
currentXFTPVersion = 1
currentXFTPVersion :: VersionXFTP
currentXFTPVersion = VersionXFTP 1
xftpBlockSize :: Int
xftpBlockSize = 16384
@@ -142,10 +141,10 @@ instance ProtocolMsgTag FileCmdTag where
instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where
decodeTag s = decodeTag s >>= (\(FCT _ t) -> checkParty' t)
instance Protocol XFTPErrorType FileResponse where
instance Protocol XFTPVersion XFTPErrorType FileResponse where
type ProtoCommand FileResponse = FileCmd
type ProtoType FileResponse = 'PXFTP
protocolClientHandshake = ntfClientHandshake
protocolClientHandshake = xftpClientHandshake
protocolPing = FileCmd SFRecipient PING
protocolError = \case
FRErr e -> Just e
@@ -171,11 +170,11 @@ data FileInfo = FileInfo
size :: Word32,
digest :: ByteString
}
deriving (Eq, Show)
deriving (Show)
type XFTPFileId = ByteString
instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand p) where
type Tag (FileCommand p) = FileCommandTag p
encodeProtocol _v = \case
FNEW file rKeys auth_ -> e (FNEW_, ' ', file, rKeys, auth_)
@@ -191,7 +190,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
protocolP v tag = (\(FileCmd _ c) -> checkParty c) <$?> protocolP v (FCT (sFileParty @p) tag)
fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, fileId, _) cmd = case cmd of
@@ -208,7 +207,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
| isNothing auth || B.null fileId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
instance ProtocolEncoding XFTPErrorType FileCmd where
instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where
type Tag FileCmd = FileCmdTag
encodeProtocol _v (FileCmd _ c) = encodeProtocol _v c
@@ -225,7 +224,7 @@ instance ProtocolEncoding XFTPErrorType FileCmd where
FACK_ -> pure FACK
PING_ -> pure PING
fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c
@@ -276,7 +275,7 @@ data FileResponse
| FRPong
deriving (Show)
instance ProtocolEncoding XFTPErrorType FileResponse where
instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where
type Tag FileResponse = FileResponseTag
encodeProtocol _v = \case
FRSndIds fId rIds -> e (FRSndIds_, ' ', fId, rIds)
@@ -319,82 +318,6 @@ instance ProtocolEncoding XFTPErrorType FileResponse where
| B.null entId = Right cmd
| otherwise = Left $ CMD HAS_AUTH
data XFTPErrorType
= -- | incorrect block format, encoding or signature size
BLOCK
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
SESSION
| -- | SMP command is unknown or has invalid syntax
CMD {cmdErr :: CommandError}
| -- | command authorization error - bad signature or non-existing SMP queue
AUTH
| -- | incorrent file size
SIZE
| -- | storage quota exceeded
QUOTA
| -- | incorrent file digest
DIGEST
| -- | file encryption/decryption failed
CRYPTO
| -- | no expected file body in request/response or no file on the server
NO_FILE
| -- | unexpected file body
HAS_FILE
| -- | file IO error
FILE_IO
| -- | bad redirect data
REDIRECT {redirectError :: String}
| -- | internal server error
INTERNAL
| -- | used internally, never returned by the server (to be removed)
DUPLICATE_ -- not part of SMP protocol, used internally
deriving (Eq, Read, Show)
instance StrEncoding XFTPErrorType where
strEncode = \case
CMD e -> "CMD " <> bshow e
REDIRECT e -> "REDIRECT " <> bshow e
e -> bshow e
strP =
"CMD " *> (CMD <$> parseRead1)
<|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString)
<|> parseRead1
instance Encoding XFTPErrorType where
smpEncode = \case
BLOCK -> "BLOCK"
SESSION -> "SESSION"
CMD err -> "CMD " <> smpEncode err
AUTH -> "AUTH"
SIZE -> "SIZE"
QUOTA -> "QUOTA"
DIGEST -> "DIGEST"
CRYPTO -> "CRYPTO"
NO_FILE -> "NO_FILE"
HAS_FILE -> "HAS_FILE"
FILE_IO -> "FILE_IO"
REDIRECT err -> "REDIRECT " <> smpEncode err
INTERNAL -> "INTERNAL"
DUPLICATE_ -> "DUPLICATE_"
smpP =
A.takeTill (== ' ') >>= \case
"BLOCK" -> pure BLOCK
"SESSION" -> pure SESSION
"CMD" -> CMD <$> _smpP
"AUTH" -> pure AUTH
"SIZE" -> pure SIZE
"QUOTA" -> pure QUOTA
"DIGEST" -> pure DIGEST
"CRYPTO" -> pure CRYPTO
"NO_FILE" -> pure NO_FILE
"HAS_FILE" -> pure HAS_FILE
"FILE_IO" -> pure FILE_IO
"REDIRECT" -> REDIRECT <$> _smpP
"INTERNAL" -> pure INTERNAL
"DUPLICATE_" -> pure DUPLICATE_
_ -> fail "bad error type"
checkParty :: forall t p p'. (FilePartyI p, FilePartyI p') => t p' -> Either String (t p)
checkParty c = case testEquality (sFileParty @p) (sFileParty @p') of
Just Refl -> Right c
@@ -405,12 +328,12 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
Just Refl -> Just c
_ -> Nothing
xftpEncodeAuthTransmission :: ProtocolEncoding e c => THandleParams -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth
xftpEncodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> Either TransportError ByteString
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString
xftpEncodeTransmission thParams (corrId, fId, msg) = do
let t = encodeTransmission thParams (corrId, fId, msg)
xftpEncodeBatch1 (Nothing, t)
@@ -419,7 +342,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do
xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString
xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize
xftpDecodeTransmission :: ProtocolEncoding e c => THandleParams -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
xftpDecodeTransmission thParams t = do
t' <- first (const BLOCK) $ C.unPad t
case tParse thParams t' of
@@ -427,5 +350,3 @@ xftpDecodeTransmission thParams t = do
_ -> Left BLOCK
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
+1 -1
View File
@@ -69,7 +69,7 @@ type M a = ReaderT XFTPEnv IO a
data XFTPTransportRequest =
XFTPTransportRequest
{ thParams :: THandleParams,
{ thParams :: THandleParams XFTPVersion,
reqBody :: HTTP2Body,
request :: H.Request,
sendResponse :: H.Response -> IO ()
+2 -2
View File
@@ -28,7 +28,8 @@ import Data.Int (Int64)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Time.Clock.System (SystemTime (..))
import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPErrorType (..), XFTPFileId)
import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPFileId)
import Simplex.FileTransfer.Transport (XFTPErrorType (..))
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (RcvPublicAuthKey, RecipientId, SenderId)
@@ -49,7 +50,6 @@ data FileRec = FileRec
recipientIds :: TVar (Set RecipientId),
createdAt :: SystemTime
}
deriving (Eq)
data FileRecipient = FileRecipient RecipientId RcvPublicAuthKey
+118 -4
View File
@@ -1,10 +1,19 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Simplex.FileTransfer.Transport
( supportedFileServerVRange,
xftpClientHandshake, -- stub
XFTPVersion,
VersionXFTP,
pattern VersionXFTP,
XFTPErrorType (..),
XFTPRcvChunkSpec (..),
ReceiveFileError (..),
receiveFile,
@@ -14,22 +23,31 @@ module Simplex.FileTransfer.Transport
)
where
import Control.Applicative ((<|>))
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class
import qualified Data.Aeson.TH as J
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import qualified Data.ByteArray as BA
import Data.ByteString.Builder (Builder, byteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Word (Word32)
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
import Data.Word (Word16, Word32)
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol (CommandError)
import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..))
import Simplex.Messaging.Transport.HTTP2.File
import Simplex.Messaging.Util (bshow)
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import System.IO (Handle, IOMode (..), withFile)
data XFTPRcvChunkSpec = XFTPRcvChunkSpec
@@ -39,8 +57,26 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec
}
deriving (Show)
supportedFileServerVRange :: VersionRange
supportedFileServerVRange = mkVersionRange 1 1
data XFTPVersion
instance VersionScope XFTPVersion
type VersionXFTP = Version XFTPVersion
type VersionRangeXFTP = VersionRange XFTPVersion
pattern VersionXFTP :: Word16 -> VersionXFTP
pattern VersionXFTP v = Version v
initialXFTPVersion :: VersionXFTP
initialXFTPVersion = VersionXFTP 1
supportedFileServerVRange :: VersionRangeXFTP
supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion
-- XFTP protocol does not support handshake
xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
sendEncFile h send = go
@@ -97,3 +133,81 @@ receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do
ExceptT $ withFile filePath WriteMode (`receive` chunkSize)
digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath
when (digest' /= chunkDigest) $ throwError DIGEST
data XFTPErrorType
= -- | incorrect block format, encoding or signature size
BLOCK
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
SESSION
| -- | SMP command is unknown or has invalid syntax
CMD {cmdErr :: CommandError}
| -- | command authorization error - bad signature or non-existing SMP queue
AUTH
| -- | incorrent file size
SIZE
| -- | storage quota exceeded
QUOTA
| -- | incorrent file digest
DIGEST
| -- | file encryption/decryption failed
CRYPTO
| -- | no expected file body in request/response or no file on the server
NO_FILE
| -- | unexpected file body
HAS_FILE
| -- | file IO error
FILE_IO
| -- | bad redirect data
REDIRECT {redirectError :: String}
| -- | internal server error
INTERNAL
| -- | used internally, never returned by the server (to be removed)
DUPLICATE_ -- not part of SMP protocol, used internally
deriving (Eq, Read, Show)
instance StrEncoding XFTPErrorType where
strEncode = \case
CMD e -> "CMD " <> bshow e
REDIRECT e -> "REDIRECT " <> bshow e
e -> bshow e
strP =
"CMD " *> (CMD <$> parseRead1)
<|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString)
<|> parseRead1
instance Encoding XFTPErrorType where
smpEncode = \case
BLOCK -> "BLOCK"
SESSION -> "SESSION"
CMD err -> "CMD " <> smpEncode err
AUTH -> "AUTH"
SIZE -> "SIZE"
QUOTA -> "QUOTA"
DIGEST -> "DIGEST"
CRYPTO -> "CRYPTO"
NO_FILE -> "NO_FILE"
HAS_FILE -> "HAS_FILE"
FILE_IO -> "FILE_IO"
REDIRECT err -> "REDIRECT " <> smpEncode err
INTERNAL -> "INTERNAL"
DUPLICATE_ -> "DUPLICATE_"
smpP =
A.takeTill (== ' ') >>= \case
"BLOCK" -> pure BLOCK
"SESSION" -> pure SESSION
"CMD" -> CMD <$> _smpP
"AUTH" -> pure AUTH
"SIZE" -> pure SIZE
"QUOTA" -> pure QUOTA
"DIGEST" -> pure DIGEST
"CRYPTO" -> pure CRYPTO
"NO_FILE" -> pure NO_FILE
"HAS_FILE" -> pure HAS_FILE
"FILE_IO" -> pure FILE_IO
"REDIRECT" -> REDIRECT <$> _smpP
"INTERNAL" -> pure INTERNAL
"DUPLICATE_" -> pure DUPLICATE_
_ -> fail "bad error type"
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
+9 -9
View File
@@ -55,7 +55,7 @@ data RcvFile = RcvFile
status :: RcvFileStatus,
deleted :: Bool
}
deriving (Eq, Show)
deriving (Show)
data RcvFileStatus
= RFSReceiving
@@ -96,7 +96,7 @@ data RcvFileChunk = RcvFileChunk
fileTmpPath :: FilePath,
chunkTmpPath :: Maybe FilePath
}
deriving (Eq, Show)
deriving (Show)
data RcvFileChunkReplica = RcvFileChunkReplica
{ rcvChunkReplicaId :: Int64,
@@ -107,14 +107,14 @@ data RcvFileChunkReplica = RcvFileChunkReplica
delay :: Maybe Int64,
retries :: Int
}
deriving (Eq, Show)
deriving (Show)
data RcvFileRedirect = RcvFileRedirect
{ redirectDbId :: DBRcvFileId,
redirectEntityId :: RcvFileId,
redirectFileInfo :: RedirectFileInfo
}
deriving (Eq, Show)
deriving (Show)
-- Sending files
@@ -135,7 +135,7 @@ data SndFile = SndFile
deleted :: Bool,
redirect :: Maybe RedirectFileInfo
}
deriving (Eq, Show)
deriving (Show)
sndFileEncPath :: FilePath -> FilePath
sndFileEncPath prefixPath = prefixPath </> "xftp.encrypted"
@@ -182,7 +182,7 @@ data SndFileChunk = SndFileChunk
digest :: FileDigest,
replicas :: [SndFileChunkReplica]
}
deriving (Eq, Show)
deriving (Show)
sndChunkSize :: SndFileChunk -> Word32
sndChunkSize SndFileChunk {chunkSpec = XFTPChunkSpec {chunkSize}} = chunkSize
@@ -193,7 +193,7 @@ data NewSndChunkReplica = NewSndChunkReplica
replicaKey :: C.APrivateAuthKey,
rcvIdsKeys :: [(ChunkReplicaId, C.APrivateAuthKey)]
}
deriving (Eq, Show)
deriving (Show)
data SndFileChunkReplica = SndFileChunkReplica
{ sndChunkReplicaId :: Int64,
@@ -205,7 +205,7 @@ data SndFileChunkReplica = SndFileChunkReplica
delay :: Maybe Int64,
retries :: Int
}
deriving (Eq, Show)
deriving (Show)
data SndFileReplicaStatus
= SFRSCreated
@@ -235,4 +235,4 @@ data DeletedSndChunkReplica = DeletedSndChunkReplica
delay :: Maybe Int64,
retries :: Int
}
deriving (Eq, Show)
deriving (Show)
File diff suppressed because it is too large Load Diff
+37 -56
View File
@@ -110,8 +110,6 @@ module Simplex.Messaging.Agent.Client
whenSuspending,
withStore,
withStore',
withStoreCtx,
withStoreCtx',
withStoreBatch,
withStoreBatch',
storeError,
@@ -167,8 +165,8 @@ import Network.Socket (HostName)
import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError)
import qualified Simplex.FileTransfer.Client as X
import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb)
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST))
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse)
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion)
import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..))
import Simplex.FileTransfer.Util (uniqueCombine)
import Simplex.Messaging.Agent.Env.SQLite
@@ -187,6 +185,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (NTFVersion)
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse)
import Simplex.Messaging.Protocol
@@ -215,9 +214,12 @@ import Simplex.Messaging.Protocol
UserProtocol,
XFTPServer,
XFTPServerWithAuth,
VersionSMPC,
VersionRangeSMPC,
sameSrvAddr',
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport (SMPVersion)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -253,7 +255,7 @@ data AgentClient = AgentClient
{ active :: TVar Bool,
rcvQ :: TBQueue (ATransmission 'Client),
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue (ServerTransmission BrokerMsg),
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
smpClients :: TMap SMPTransportSession SMPClientVar,
ntfServers :: TVar [NtfServer],
@@ -467,7 +469,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store
agentDRG :: AgentClient -> TVar ChaChaDRG
agentDRG AgentClient {agentEnv = Env {random}} = random
class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
type Client msg = c | c -> msg
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg)
clientProtocolError :: err -> AgentErrorType
@@ -476,8 +478,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher
clientTransportHost :: Client msg -> TransportHost
clientSessionTs :: Client msg -> UTCTime
instance ProtocolServerClient ErrorType BrokerMsg where
type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg
instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where
type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg
getProtocolServerClient = getSMPServerClient
clientProtocolError = SMP
closeProtocolServerClient = closeProtocolClient
@@ -485,8 +487,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where
clientTransportHost = transportHost'
clientSessionTs = sessionTs
instance ProtocolServerClient ErrorType NtfResponse where
type Client NtfResponse = ProtocolClient ErrorType NtfResponse
instance ProtocolServerClient NTFVersion ErrorType NtfResponse where
type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse
getProtocolServerClient = getNtfServerClient
clientProtocolError = NTF
closeProtocolServerClient = closeProtocolClient
@@ -494,7 +496,7 @@ instance ProtocolServerClient ErrorType NtfResponse where
clientTransportHost = transportHost'
clientSessionTs = sessionTs
instance ProtocolServerClient XFTPErrorType FileResponse where
instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
type Client FileResponse = XFTPClient
getProtocolServerClient = getXFTPServerClient
clientProtocolError = XFTP
@@ -683,8 +685,8 @@ waitForProtocolClient c (_, srv, _) v = do
-- clientConnected arg is only passed for SMP server
newProtocolClient ::
forall err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) =>
forall v err msg m.
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
AgentClient ->
TransportSession msg ->
TMap (TransportSession msg) (ClientVar msg) ->
@@ -706,10 +708,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
putTMVar (sessionVar v) (Left e)
throwError e -- signal error to caller
hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v)
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
cfg <- asks $ cfgSel . config
networkConfig <- readTVarIO useNetworkConfig
@@ -754,19 +756,19 @@ throwWhenNoDelivery c sq =
unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $
throwSTM ThreadKilled
closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c)
reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
reconnectServerClients c clientsSel =
readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c)
closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
closeClient c clientSel tSess =
atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c)
closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
closeClient_ c v = do
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
@@ -798,7 +800,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
where
newLock = createLock >>= \l -> TM.insert key l locks $> l
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
withClient_ c tSess@(userId, srv, _) statCmd action = do
cl <- getProtocolServerClient c tSess
(action cl <* stat cl "OK") `catchAgentError` logServerError cl
@@ -810,18 +812,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do
stat cl $ strEncode e
throwError e
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
logServer "-->" c srv entId cmdStr
res <- withClient_ c tSess cmdStr action
logServer "<--" c srv entId "OK"
return res
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMPClient c q cmdStr action = do
@@ -837,7 +839,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI
withNtfClient c srv = withLogClient c (0, srv, Nothing)
withXFTPClient ::
(AgentMonad m, ProtocolServerClient err msg) =>
(AgentMonad m, ProtocolServerClient v err msg) =>
AgentClient ->
(UserId, ProtoServer msg, EntityId) ->
ByteString ->
@@ -1001,7 +1003,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
C.AuthAlg a <- asks (rcvAuthAlg . config)
g <- asks random
@@ -1151,7 +1153,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
tSess <- mkTransportSession c userId smpServer senderId
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
@@ -1334,7 +1336,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody}
-- add encoding as AgentInvitation'?
agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
g <- asks random
(dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g
@@ -1453,34 +1455,13 @@ waitUntilForeground :: AgentClient -> STM ()
waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry
withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a
withStore' = withStoreCtx_' Nothing
withStore' c action = withStore c $ fmap Right . action
withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
withStore = withStoreCtx_ Nothing
withStoreCtx' :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO a) -> m a
withStoreCtx' = withStoreCtx_' . Just
withStoreCtx :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
withStoreCtx = withStoreCtx_ . Just
withStoreCtx_' :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO a) -> m a
withStoreCtx_' ctx_ c action = withStoreCtx_ ctx_ c $ fmap Right . action
withStoreCtx_ :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
withStoreCtx_ ctx_ c action = do
withStore c action = do
st <- asks store
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ case ctx_ of
Nothing -> withTransaction st action `E.catch` handleInternal ""
-- uncomment to debug store performance
-- Just ctx -> do
-- t1 <- liftIO getCurrentTime
-- putStrLn $ "agent withStoreCtx start :: " <> show t1 <> " :: " <> ctx
-- r <- withTransaction st action `E.catch` handleInternal (" (" <> ctx <> ")")
-- t2 <- liftIO getCurrentTime
-- putStrLn $ "agent withStoreCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1)
-- pure r
Just _ -> withTransaction st action `E.catch` handleInternal ""
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $
withTransaction st action `E.catch` handleInternal ""
where
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
@@ -1518,7 +1499,7 @@ incStat AgentClient {agentStats} n k = do
Just v -> modifyTVar' v (+ n)
_ -> newTVar n >>= \v -> TM.insert k v agentStats
incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
incClientStat c userId pc = incClientStatN c userId pc 1
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
@@ -1528,7 +1509,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
where
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
incClientStatN c userId pc n cmd res = do
atomically $ incStat c n statsKey
where
+11 -9
View File
@@ -56,16 +56,16 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange)
import Simplex.Messaging.Crypto.Ratchet (PQSupport, VersionRangeE2E, supportedE2EEncryptVRange)
import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig)
import Simplex.Messaging.Notifications.Transport (NTFVersion)
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (TLS, Transport (..))
import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..))
import Simplex.Messaging.Transport.Client (defaultSMPPort)
import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors)
import Simplex.Messaging.Version
import System.Random (StdGen, newStdGen)
import UnliftIO (Async, SomeException)
import UnliftIO.STM
@@ -87,12 +87,13 @@ data AgentConfig = AgentConfig
sndAuthAlg :: C.AuthAlg,
connIdBytes :: Int,
tbqSize :: Natural,
smpCfg :: ProtocolClientConfig,
ntfCfg :: ProtocolClientConfig,
smpCfg :: ProtocolClientConfig SMPVersion,
ntfCfg :: ProtocolClientConfig NTFVersion,
xftpCfg :: XFTPClientConfig,
reconnectInterval :: RetryInterval,
messageRetryInterval :: RetryInterval2,
messageTimeout :: NominalDiffTime,
connDeleteDeliveryTimeout :: NominalDiffTime,
helloTimeout :: NominalDiffTime,
quotaExceededTimeout :: NominalDiffTime,
initialCleanupDelay :: Int64,
@@ -115,9 +116,9 @@ data AgentConfig = AgentConfig
caCertificateFile :: FilePath,
privateKeyFile :: FilePath,
certificateFile :: FilePath,
e2eEncryptVRange :: VersionRange,
smpAgentVRange :: VersionRange,
smpClientVRange :: VersionRange
e2eEncryptVRange :: PQSupport -> VersionRangeE2E,
smpAgentVRange :: PQSupport -> VersionRangeSMPA,
smpClientVRange :: VersionRangeSMPC
}
defaultReconnectInterval :: RetryInterval
@@ -161,6 +162,7 @@ defaultAgentConfig =
reconnectInterval = defaultReconnectInterval,
messageRetryInterval = defaultMessageRetryInterval,
messageTimeout = 2 * nominalDay,
connDeleteDeliveryTimeout = 2 * nominalDay,
helloTimeout = 2 * nominalDay,
quotaExceededTimeout = 7 * nominalDay,
initialCleanupDelay = 30 * 1000000, -- 30 seconds
+139 -85
View File
@@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -33,8 +34,14 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
module Simplex.Messaging.Agent.Protocol
( -- * Protocol parameters
VersionSMPA,
VersionRangeSMPA,
pattern VersionSMPA,
duplexHandshakeSMPAgentVersion,
ratchetSyncSMPAgentVersion,
deliveryRcptsSMPAgentVersion,
pqdrSMPAgentVersion,
currentSMPAgentVersion,
supportedSMPAgentVRange,
e2eEncConnInfoLength,
e2eEncUserMsgLength,
@@ -175,14 +182,25 @@ import Data.Time.Clock.System (SystemTime)
import Data.Time.ISO8601
import Data.Type.Equality
import Data.Typeable ()
import Data.Word (Word32)
import Data.Word (Word16, Word32)
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.ToField
import Simplex.FileTransfer.Description
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType)
import Simplex.FileTransfer.Protocol (FileParty (..))
import Simplex.FileTransfer.Transport (XFTPErrorType)
import Simplex.Messaging.Agent.QueryString
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri)
import Simplex.Messaging.Crypto.Ratchet
( InitialKeys (..),
PQEncryption (..),
pattern PQEncOff,
PQSupport,
pattern PQSupportOn,
pattern PQSupportOff,
RcvE2ERatchetParams,
RcvE2ERatchetParamsUri,
SndE2ERatchetParams
)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
@@ -200,6 +218,10 @@ import Simplex.Messaging.Protocol
SMPServerWithAuth,
SndPublicAuthKey,
SubscriptionMode,
SMPClientVersion,
VersionSMPC,
VersionRangeSMPC,
initialSMPClientVersion,
legacyEncodeServer,
legacyServerP,
legacyStrEncodeServer,
@@ -215,6 +237,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..))
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import Simplex.RemoteControl.Types
import Text.Read
import UnliftIO.Exception (Exception)
@@ -224,30 +247,56 @@ import UnliftIO.Exception (Exception)
-- 2 - "duplex" (more efficient) connection handshake (6/9/2022)
-- 3 - support ratchet renegotiation (6/30/2023)
-- 4 - delivery receipts (7/13/2023)
-- 5 - post-quantum double ratchet (3/14/2024)
duplexHandshakeSMPAgentVersion :: Version
duplexHandshakeSMPAgentVersion = 2
data SMPAgentVersion
ratchetSyncSMPAgentVersion :: Version
ratchetSyncSMPAgentVersion = 3
instance VersionScope SMPAgentVersion
deliveryRcptsSMPAgentVersion :: Version
deliveryRcptsSMPAgentVersion = 4
type VersionSMPA = Version SMPAgentVersion
currentSMPAgentVersion :: Version
currentSMPAgentVersion = 4
type VersionRangeSMPA = VersionRange SMPAgentVersion
supportedSMPAgentVRange :: VersionRange
supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion
pattern VersionSMPA :: Word16 -> VersionSMPA
pattern VersionSMPA v = Version v
duplexHandshakeSMPAgentVersion :: VersionSMPA
duplexHandshakeSMPAgentVersion = VersionSMPA 2
ratchetSyncSMPAgentVersion :: VersionSMPA
ratchetSyncSMPAgentVersion = VersionSMPA 3
deliveryRcptsSMPAgentVersion :: VersionSMPA
deliveryRcptsSMPAgentVersion = VersionSMPA 4
pqdrSMPAgentVersion :: VersionSMPA
pqdrSMPAgentVersion = VersionSMPA 5
-- TODO v5.7 increase to 5
currentSMPAgentVersion :: VersionSMPA
currentSMPAgentVersion = VersionSMPA 4
-- TODO v5.7 remove dependency of version range on whether PQ support is needed
supportedSMPAgentVRange :: PQSupport -> VersionRangeSMPA
supportedSMPAgentVRange pq =
mkVersionRange duplexHandshakeSMPAgentVersion $ case pq of
PQSupportOn -> pqdrSMPAgentVersion
PQSupportOff -> currentSMPAgentVersion
-- it is shorter to allow all handshake headers,
-- including E2E (double-ratchet) parameters and
-- signing key of the sender for the server
e2eEncConnInfoLength :: Int
e2eEncConnInfoLength = 14848
e2eEncConnInfoLength :: VersionSMPA -> PQSupport -> Int
e2eEncConnInfoLength v = \case
-- reduced by 3726 (roughly the increase of message ratchet header size + key and ciphertext in reply link)
PQSupportOn | v >= pqdrSMPAgentVersion -> 11122
_ -> 14848
e2eEncUserMsgLength :: Int
e2eEncUserMsgLength = 15856
e2eEncUserMsgLength :: VersionSMPA -> PQSupport -> Int
e2eEncUserMsgLength v = \case
-- reduced by 2222 (the increase of message ratchet header size)
PQSupportOn | v >= pqdrSMPAgentVersion -> 13634
_ -> 15856
-- | Raw (unparsed) SMP agent protocol transmission.
type ARawTransmission = (ByteString, ByteString, ByteString)
@@ -273,8 +322,6 @@ data SAParty :: AParty -> Type where
deriving instance Show (SAParty p)
deriving instance Eq (SAParty p)
instance TestEquality SAParty where
testEquality SAgent SAgent = Just Refl
testEquality SClient SClient = Just Refl
@@ -297,8 +344,6 @@ data SAEntity :: AEntity -> Type where
deriving instance Show (SAEntity e)
deriving instance Eq (SAEntity e)
instance TestEquality SAEntity where
testEquality SAEConn SAEConn = Just Refl
testEquality SAERcvFile SAERcvFile = Just Refl
@@ -333,16 +378,16 @@ type ConnInfo = ByteString
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
data ACommand (p :: AParty) (e :: AEntity) where
NEW :: Bool -> AConnectionMode -> SubscriptionMode -> ACommand Client AEConn -- response INV
NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV
INV :: AConnectionRequestUri -> ACommand Agent AEConn
JOIN :: Bool -> AConnectionRequestUri -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK
CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
JOIN :: Bool -> AConnectionRequestUri -> PQSupport -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK
CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender
ACPT :: InvitationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender
ACPT :: InvitationId -> PQSupport -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
RJCT :: InvitationId -> ACommand Client AEConn
INFO :: ConnInfo -> ACommand Agent AEConn
CON :: ACommand Agent AEConn -- notification that connection is established
INFO :: PQSupport -> ConnInfo -> ACommand Agent AEConn
CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established
SUB :: ACommand Client AEConn
END :: ACommand Agent AEConn
CONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone
@@ -351,8 +396,8 @@ data ACommand (p :: AParty) (e :: AEntity) where
UP :: SMPServer -> [ConnId] -> ACommand Agent AENone
SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent AEConn
RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> ACommand Agent AEConn
SEND :: MsgFlags -> MsgBody -> ACommand Client AEConn
MID :: AgentMsgId -> ACommand Agent AEConn
SEND :: PQEncryption -> MsgFlags -> MsgBody -> ACommand Client AEConn
MID :: AgentMsgId -> PQEncryption -> ACommand Agent AEConn
SENT :: AgentMsgId -> ACommand Agent AEConn
MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn
MERRS :: NonEmpty AgentMsgId -> AgentErrorType -> ACommand Agent AEConn
@@ -458,8 +503,8 @@ aCommandTag = \case
REQ {} -> REQ_
ACPT {} -> ACPT_
RJCT _ -> RJCT_
INFO _ -> INFO_
CON -> CON_
INFO {} -> INFO_
CON _ -> CON_
SUB -> SUB_
END -> END_
CONNECT {} -> CONNECT_
@@ -469,7 +514,7 @@ aCommandTag = \case
SWITCH {} -> SWITCH_
RSYNC {} -> RSYNC_
SEND {} -> SEND_
MID _ -> MID_
MID {} -> MID_
SENT _ -> SENT_
MERR {} -> MERR_
MERRS {} -> MERRS_
@@ -665,7 +710,7 @@ instance StrEncoding SndQueueInfo where
pure SndQueueInfo {sndServer, sndSwitchStatus}
data ConnectionStats = ConnectionStats
{ connAgentVersion :: Version,
{ connAgentVersion :: VersionSMPA,
rcvQueuesInfo :: [RcvQueueInfo],
sndQueuesInfo :: [SndQueueInfo],
ratchetSyncState :: RatchetSyncState,
@@ -769,17 +814,19 @@ data MsgMeta = MsgMeta
{ integrity :: MsgIntegrity,
recipient :: (AgentMsgId, UTCTime),
broker :: (MsgId, UTCTime),
sndMsgId :: AgentMsgId
sndMsgId :: AgentMsgId,
pqEncryption :: PQEncryption
}
deriving (Eq, Show)
instance StrEncoding MsgMeta where
strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} =
strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId, pqEncryption} =
B.unwords
[ strEncode integrity,
"R=" <> bshow rmId <> "," <> showTs rTs,
"B=" <> encode bmId <> "," <> showTs bTs,
"S=" <> bshow sndMsgId
"S=" <> bshow sndMsgId,
"PQ=" <> strEncode pqEncryption
]
where
showTs = B.pack . formatISO8601Millis
@@ -788,7 +835,8 @@ instance StrEncoding MsgMeta where
recipient <- " R=" *> partyMeta A.decimal
broker <- " B=" *> partyMeta base64P
sndMsgId <- " S=" *> A.decimal
pure MsgMeta {integrity, recipient, broker, sndMsgId}
pqEncryption <- " PQ=" *> strP
pure MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption}
where
partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P
@@ -802,28 +850,28 @@ data SMPConfirmation = SMPConfirmation
-- | optional reply queues included in confirmation (added in agent protocol v2)
smpReplyQueues :: [SMPQueueInfo],
-- | SMP client version
smpClientVersion :: Version
smpClientVersion :: VersionSMPC
}
deriving (Show)
data AgentMsgEnvelope
= AgentConfirmation
{ agentVersion :: Version,
e2eEncryption_ :: Maybe (E2ERatchetParams 'C.X448),
{ agentVersion :: VersionSMPA,
e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448),
encConnInfo :: ByteString
}
| AgentMsgEnvelope
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
encAgentMessage :: ByteString
}
| AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet,
{ agentVersion :: Version,
{ agentVersion :: VersionSMPA,
connReq :: ConnectionRequestUri 'CMInvitation,
connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet,
}
| AgentRatchetKey
{ agentVersion :: Version,
e2eEncryption :: E2ERatchetParams 'C.X448,
{ agentVersion :: VersionSMPA,
e2eEncryption :: RcvE2ERatchetParams 'C.X448,
info :: ByteString
}
deriving (Show)
@@ -1115,7 +1163,7 @@ instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) whe
CRInvitationUri crData e2eParams -> crEncode "invitation" crData (Just e2eParams)
CRContactUri crData -> crEncode "contact" crData Nothing
where
crEncode :: ByteString -> ConnReqUriData -> Maybe (E2ERatchetParamsUri 'C.X448) -> ByteString
crEncode :: ByteString -> ConnReqUriData -> Maybe (RcvE2ERatchetParamsUri 'C.X448) -> ByteString
crEncode crMode ConnReqUriData {crScheme, crAgentVRange, crSmpQueues, crClientData} e2eParams =
strEncode crScheme <> "/" <> crMode <> "#/?" <> queryStr
where
@@ -1228,16 +1276,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool
sameQueue addr q = sameQAddress addr (qAddress q)
{-# INLINE sameQueue #-}
data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress}
data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress}
deriving (Eq, Show)
instance Encoding SMPQueueInfo where
smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey})
| clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
| clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
| otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey)
smpP = do
clientVersion <- smpP
smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP
smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP
(senderId, dhPublicKey) <- smpP
pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}
@@ -1245,20 +1293,20 @@ instance Encoding SMPQueueInfo where
-- But this is created to allow backward and forward compatibility where SMPQueueUri
-- could have more fields to convert to different versions of SMPQueueInfo in a different way,
-- and this instance would become non-trivial.
instance VersionI SMPQueueInfo where
type VersionRangeT SMPQueueInfo = SMPQueueUri
instance VersionI SMPClientVersion SMPQueueInfo where
type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri
version = clientVersion
toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr
instance VersionRangeI SMPQueueUri where
type VersionT SMPQueueUri = SMPQueueInfo
instance VersionRangeI SMPClientVersion SMPQueueUri where
type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo
versionRange = clientVRange
toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr
-- | SMP queue information sent out-of-band.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress}
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress}
deriving (Eq, Show)
data SMPQueueAddress = SMPQueueAddress
@@ -1307,7 +1355,7 @@ instance StrEncoding SMPQueueUri where
smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv'
pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey}
where
unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput
unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput
versioned = do
dhKey_ <- optional strP
query <- optional (A.char '/') *> A.char '?' *> strP
@@ -1324,8 +1372,8 @@ instance Encoding SMPQueueUri where
pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}
data ConnectionRequestUri (m :: ConnectionMode) where
CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
-- contact connection request does NOT contain E2E encryption parameters -
CRInvitationUri :: ConnReqUriData -> RcvE2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
-- contact connection request does NOT contain E2E encryption parameters for double ratchet -
-- they are passed in AgentInvitation message
CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact
@@ -1336,15 +1384,15 @@ deriving instance Show (ConnectionRequestUri m)
data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m)
instance Eq AConnectionRequestUri where
ACR m cr == ACR m' cr' = case testEquality m m' of
Just Refl -> cr == cr'
_ -> False
ACR m cr == ACR m' cr' = case testEquality m m' of
Just Refl -> cr == cr'
_ -> False
deriving instance Show AConnectionRequestUri
data ConnReqUriData = ConnReqUriData
{ crScheme :: ServiceScheme,
crAgentVRange :: VersionRange,
crAgentVRange :: VersionRangeSMPA,
crSmpQueues :: NonEmpty SMPQueueUri,
crClientData :: Maybe CRClientData
}
@@ -1713,13 +1761,13 @@ commandP binaryP =
>>= \case
ACmdTag SClient e cmd ->
ACmd SClient e <$> case cmd of
NEW_ -> s (NEW <$> strP_ <*> strP_ <*> (strP <|> pure SMP.SMSubscribe))
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP)
NEW_ -> s (NEW <$> strP_ <*> strP_ <*> pqIKP <*> (strP <|> pure SMP.SMSubscribe))
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqSupP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP)
LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP)
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP)
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> binaryP)
RJCT_ -> s (RJCT <$> A.takeByteString)
SUB_ -> pure SUB
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
SEND_ -> s (SEND <$> pqEncP <*> smpP <* A.space <*> binaryP)
ACK_ -> s (ACK <$> A.decimal <*> optional (A.space *> binaryP))
SWCH_ -> pure SWCH
OFF_ -> pure OFF
@@ -1728,10 +1776,10 @@ commandP binaryP =
ACmdTag SAgent e cmd ->
ACmd SAgent e <$> case cmd of
INV_ -> s (INV <$> strP)
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP)
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP)
INFO_ -> s (INFO <$> binaryP)
CON_ -> pure CON
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strListP <* A.space <*> binaryP)
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strP_ <*> binaryP)
INFO_ -> s (INFO <$> pqSupP <*> binaryP)
CON_ -> s (CON <$> strP)
END_ -> pure END
CONNECT_ -> s (CONNECT <$> strP_ <*> strP)
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
@@ -1739,7 +1787,7 @@ commandP binaryP =
UP_ -> s (UP <$> strP_ <*> connections)
SWITCH_ -> s (SWITCH <$> strP_ <*> strP_ <*> strP)
RSYNC_ -> s (RSYNC <$> strP_ <*> strP <*> strP)
MID_ -> s (MID <$> A.decimal)
MID_ -> s (MID <$> A.decimal <*> _strP)
SENT_ -> s (SENT <$> A.decimal)
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
MERRS_ -> s (MERRS <$> strP_ <*> strP)
@@ -1762,6 +1810,12 @@ commandP binaryP =
where
s :: Parser a -> Parser a
s p = A.space *> p
pqIKP :: Parser InitialKeys
pqIKP = strP_ <|> pure (IKNoPQ PQSupportOff)
pqSupP :: Parser PQSupport
pqSupP = strP_ <|> pure PQSupportOff
pqEncP :: Parser PQEncryption
pqEncP = strP_ <|> pure PQEncOff
connections :: Parser [ConnId]
connections = strP `A.sepBy'` A.char ','
sfDone :: Text -> Either String (ACommand 'Agent 'AESndFile)
@@ -1777,15 +1831,15 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
-- | Serialize SMP agent command.
serializeCommand :: ACommand p e -> ByteString
serializeCommand = \case
NEW ntfs cMode subMode -> s (NEW_, ntfs, cMode, subMode)
NEW ntfs cMode pqIK subMode -> s (NEW_, ntfs, cMode, pqIK, subMode)
INV cReq -> s (INV_, cReq)
JOIN ntfs cReq subMode cInfo -> s (JOIN_, ntfs, cReq, subMode, Str $ serializeBinary cInfo)
CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo]
JOIN ntfs cReq pqSup subMode cInfo -> s (JOIN_, ntfs, cReq, pqSup, subMode, Str $ serializeBinary cInfo)
CONF confId pqSup srvs cInfo -> B.unwords [s CONF_, confId, s pqSup, strEncodeList srvs, serializeBinary cInfo]
LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo]
REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo]
ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo]
REQ invId pqSup srvs cInfo -> B.unwords [s REQ_, invId, s pqSup, s srvs, serializeBinary cInfo]
ACPT invId pqSup cInfo -> B.unwords [s ACPT_, invId, s pqSup, serializeBinary cInfo]
RJCT invId -> B.unwords [s RJCT_, invId]
INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo]
INFO pqSup cInfo -> B.unwords [s INFO_, s pqSup, serializeBinary cInfo]
SUB -> s SUB_
END -> s END_
CONNECT p h -> s (CONNECT_, p, h)
@@ -1794,8 +1848,8 @@ serializeCommand = \case
UP srv conns -> B.unwords [s UP_, s srv, connections conns]
SWITCH dir phase srvs -> s (SWITCH_, dir, phase, srvs)
RSYNC rrState cryptoErr cstats -> s (RSYNC_, rrState, cryptoErr, cstats)
SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody]
MID mId -> s (MID_, mId)
SEND pqEnc msgFlags msgBody -> B.unwords [s SEND_, s pqEnc, smpEncode msgFlags, serializeBinary msgBody]
MID mId pqEnc -> s (MID_, mId, pqEnc)
SENT mId -> s (SENT_, mId)
MERR mId e -> s (MERR_, mId, e)
MERRS mIds e -> s (MERRS_, mIds, e)
@@ -1811,7 +1865,7 @@ serializeCommand = \case
DEL_USER userId -> s (DEL_USER_, userId)
CHK -> s CHK_
STAT srvs -> s (STAT_, srvs)
CON -> s CON_
CON pqEnc -> s (CON_, pqEnc)
ERR e -> s (ERR_, e)
OK -> s OK_
SUSPENDED -> s SUSPENDED_
@@ -1884,14 +1938,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p))
cmdWithMsgBody (APC e cmd) =
APC e <$$> case cmd of
SEND msgFlags body -> SEND msgFlags <$$> getBody body
SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body
MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body
JOIN ntfs qUri subMode cInfo -> JOIN ntfs qUri subMode <$$> getBody cInfo
CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo
JOIN ntfs qUri pqSup subMode cInfo -> JOIN ntfs qUri pqSup subMode <$$> getBody cInfo
CONF confId pqSup srvs cInfo -> CONF confId pqSup srvs <$$> getBody cInfo
LET confId cInfo -> LET confId <$$> getBody cInfo
REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo
ACPT invId cInfo -> ACPT invId <$$> getBody cInfo
INFO cInfo -> INFO <$$> getBody cInfo
REQ invId pqSup srvs cInfo -> REQ invId pqSup srvs <$$> getBody cInfo
ACPT invId pqSup cInfo -> ACPT invId pqSup <$$> getBody cInfo
INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo
_ -> pure $ Right cmd
getBody :: ByteString -> m (Either AgentErrorType ByteString)
+12 -20
View File
@@ -30,7 +30,7 @@ import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval (RI2State)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (RatchetX448)
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption, PQSupport)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
( MsgBody,
@@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol
RcvPrivateAuthKey,
SndPrivateAuthKey,
SndPublicAuthKey,
VersionSMPC,
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util ((<$?>))
import Simplex.Messaging.Version
-- * Queue types
@@ -61,8 +61,6 @@ data DBQueueId (q :: QueueStored) where
DBQueueId :: Int64 -> DBQueueId 'QSStored
DBNewQueue :: DBQueueId 'QSNew
deriving instance Eq (DBQueueId q)
deriving instance Show (DBQueueId q)
type RcvQueue = StoredRcvQueue 'QSStored
@@ -96,12 +94,12 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue
dbReplaceQueueId :: Maybe Int64,
rcvSwchStatus :: Maybe RcvSwitchStatus,
-- | SMP client version
smpClientVersion :: Version,
smpClientVersion :: VersionSMPC,
-- | credentials used in context of notifications
clientNtfCreds :: Maybe ClientNtfCreds,
deleteErrors :: Int
}
deriving (Eq, Show)
deriving (Show)
rcvQueueInfo :: RcvQueue -> RcvQueueInfo
rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} =
@@ -128,7 +126,7 @@ data ClientNtfCreds = ClientNtfCreds
-- | shared DH secret used to encrypt/decrypt notification metadata (NMsgMeta) from server to recipient
rcvNtfDhSecret :: RcvNtfDhSecret
}
deriving (Eq, Show)
deriving (Show)
type SndQueue = StoredSndQueue 'QSStored
@@ -159,9 +157,9 @@ data StoredSndQueue (q :: QueueStored) = SndQueue
dbReplaceQueueId :: Maybe Int64,
sndSwchStatus :: Maybe SndSwitchStatus,
-- | SMP client version
smpClientVersion :: Version
smpClientVersion :: VersionSMPC
}
deriving (Eq, Show)
deriving (Show)
sndQueueInfo :: SndQueue -> SndQueueInfo
sndQueueInfo SndQueue {server, sndSwchStatus} =
@@ -256,8 +254,6 @@ data Connection (d :: ConnType) where
DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex
ContactConnection :: ConnData -> RcvQueue -> Connection CContact
deriving instance Eq (Connection d)
deriving instance Show (Connection d)
toConnData :: Connection d -> ConnData
@@ -290,8 +286,6 @@ connType SCSnd = CSnd
connType SCDuplex = CDuplex
connType SCContact = CContact
deriving instance Eq (SConnType d)
deriving instance Show (SConnType d)
instance TestEquality SConnType where
@@ -305,21 +299,17 @@ instance TestEquality SConnType where
-- Used to refer to an arbitrary connection when retrieving from store.
data SomeConn = forall d. SomeConn (SConnType d) (Connection d)
instance Eq SomeConn where
SomeConn d c == SomeConn d' c' = case testEquality d d' of
Just Refl -> c == c'
_ -> False
deriving instance Show SomeConn
data ConnData = ConnData
{ connId :: ConnId,
userId :: UserId,
connAgentVersion :: Version,
connAgentVersion :: VersionSMPA,
enableNtfs :: Bool,
lastExternalSndId :: PrevExternalSndId,
deleted :: Bool,
ratchetSyncState :: RatchetSyncState
ratchetSyncState :: RatchetSyncState,
pqSupport :: PQSupport
}
deriving (Eq, Show)
@@ -534,6 +524,7 @@ data SndMsgData = SndMsgData
msgType :: AgentMessageType,
msgFlags :: MsgFlags,
msgBody :: MsgBody,
pqEncryption :: PQEncryption,
internalHash :: MsgHash,
prevMsgHash :: MsgHash
}
@@ -551,6 +542,7 @@ data PendingMsgData = PendingMsgData
msgType :: AgentMessageType,
msgFlags :: MsgFlags,
msgBody :: MsgBody,
pqEncryption :: PQEncryption,
msgRetryState :: Maybe RI2State,
internalTs :: InternalTs
}
+129 -105
View File
@@ -50,7 +50,6 @@ module Simplex.Messaging.Agent.Store.SQLite
createNewConn,
updateNewConnRcv,
updateNewConnSnd,
createRcvConn, -- no longer used
createSndConn,
getConn,
getDeletedConn,
@@ -59,7 +58,9 @@ module Simplex.Messaging.Agent.Store.SQLite
getConnData,
setConnDeleted,
setConnAgentVersion,
setConnPQSupport,
getDeletedConnIds,
getDeletedWaitingDeliveryConnIds,
setConnRatchetSync,
addProcessedRatchetKeyHash,
checkRatchetKeyHashExists,
@@ -93,7 +94,6 @@ module Simplex.Messaging.Agent.Store.SQLite
getAcceptedConfirmation,
removeConfirmations,
-- Invitations - sent via Contact connections
setConnectionVersion,
createInvitation,
getInvitation,
acceptInvitation,
@@ -241,7 +241,7 @@ import Data.List (foldl', intercalate, sortBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, isJust, listToMaybe, catMaybes)
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe)
import Data.Ord (Down (..))
import Data.Text (Text)
import qualified Data.Text as T
@@ -268,7 +268,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRE
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys)
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
@@ -278,7 +279,7 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
import System.Exit (exitFailure)
import System.FilePath (takeDirectory)
@@ -542,11 +543,8 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
ConnData {connId} -> Right . (connId,) <$> create connId
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} cMode = do
fst <$$> createConn_ gVar cData create
where
create connId =
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
createNewConn db gVar cData cMode = do
fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode)
updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
updateNewConnRcv db connId rq =
@@ -568,22 +566,25 @@ updateNewConnSnd db connId sq =
updateConn :: IO (Either StoreError SndQueue)
updateConn = Right <$> addConnSndQueue_ db connId sq
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
insertRcvQueue_ db connId q serverKeyHash_
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue))
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@SndQueue {server} =
createSndConn db gVar cData q@SndQueue {server} =
-- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $
createConn_ gVar cData $ \connId -> do
serverKeyHash_ <- createServer_ db server
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, True)
createConnRecord db connId cData SCMInvitation
insertSndQueue_ db connId q serverKeyHash_
createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO ()
createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode =
DB.execute
db
[sql|
INSERT INTO connections
(user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?)
|]
(userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True)
checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
fromMaybe False
@@ -602,12 +603,32 @@ getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId)
(rq,) <$> ExceptT (getConn db connId)
deleteConn :: DB.Connection -> ConnId -> IO ()
deleteConn db connId =
DB.executeNamed
db
"DELETE FROM connections WHERE conn_id = :conn_id;"
[":conn_id" := connId]
-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted
deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId)
deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of
Nothing -> delete
Just timeout ->
ifM
checkNoPendingDeliveries_
delete
( ifM
(checkWaitDeliveryTimeout_ timeout)
delete
(pure Nothing)
)
where
delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId
checkNoPendingDeliveries_ = do
r :: (Maybe Int64) <-
maybeFirstRow fromOnly $
DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId)
pure $ isNothing r
checkWaitDeliveryTimeout_ timeout = do
cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime
r :: (Maybe Int64) <-
maybeFirstRow fromOnly $
DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs)
pure $ isJust r
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
upgradeRcvConnToDuplex db connId sq =
@@ -681,7 +702,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d
|]
(host, port, rcvId)
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO ()
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO ()
setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion =
DB.executeNamed
db
@@ -782,7 +803,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
Nothing -> (Nothing, Nothing, Nothing, Nothing)
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version)
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC)
smpConfirmation :: SMPConfirmationRow -> SMPConfirmation
smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) =
@@ -791,7 +812,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi
e2ePubKey,
connInfo,
smpReplyQueues = fromMaybe [] smpReplyQueues_,
smpClientVersion = fromMaybe 1 smpClientVersion_
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
}
createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId)
@@ -868,10 +889,6 @@ removeConfirmations db connId =
|]
[":conn_id" := connId]
setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO ()
setConnectionVersion db connId aVersion =
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId)
createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} =
createWithRandomId gVar $ \invitationId ->
@@ -1008,18 +1025,18 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} =
DB.query
db
[sql|
SELECT m.msg_type, m.msg_flags, m.msg_body, m.internal_ts, s.retry_int_slow, s.retry_int_fast
SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast
FROM messages m
JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id
WHERE m.conn_id = ? AND m.internal_id = ?
|]
(connId, msgId)
err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []"
pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData
pendingMsgData (msgType, msgFlags_, msgBody, internalTs, riSlow_, riFast_) =
pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData
pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) =
let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_
msgRetryState = RI2State <$> riSlow_ <*> riFast_
in PendingMsgData {msgId, msgType, msgFlags, msgBody, msgRetryState, internalTs}
in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs}
markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a))
@@ -1088,7 +1105,7 @@ getRcvMsg db connId agentMsgId =
[sql|
SELECT
r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack
m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
FROM rcv_messages r
JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
@@ -1104,7 +1121,7 @@ getLastMsg db connId msgId =
[sql|
SELECT
r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack
m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
FROM rcv_messages r
JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id
@@ -1113,9 +1130,9 @@ getLastMsg db connId msgId =
|]
(connId, msgId)
toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs, AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg
toRcvMsg (agentMsgId, internalTs, brokerId, brokerTs, sndMsgId, integrity, internalHash, msgType, msgBody, rcptInternalId_, rcptStatus_, userAck) =
let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity}
toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg
toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) =
let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption}
msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_
in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck}
@@ -1175,34 +1192,34 @@ deleteSndMsgsExpired db ttl = do
"DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL"
(Only cutoffTs)
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2)
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem)
getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448))
getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams))
getRatchetX3dhKeys db connId =
fmap hasKeys $
firstRow id SEX3dhKeysNotFound $
DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2 FROM ratchets WHERE conn_id = ?" (Only connId)
firstRow' keys SEX3dhKeysNotFound $
DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId)
where
hasKeys = \case
Right (Just k1, Just k2) -> Right (k1, k2)
keys = \case
(Just k1, Just k2, pKem) -> Right (k1, k2, pKem)
_ -> Left SEX3dhKeysNotFound
-- used to remember new keys when starting ratchet re-synchronization
-- TODO remove the columns for public keys in v5.7.
-- Currently, the keys are not used but still stored to support app downgrade to the previous version.
setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
DB.execute
db
[sql|
UPDATE ratchets
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
WHERE conn_id = ?
|]
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, connId)
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId)
-- TODO remove the columns for public keys in v5.7.
createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO ()
createRatchet db connId rc =
DB.executeNamed
@@ -1213,7 +1230,10 @@ createRatchet db connId rc =
ON CONFLICT (conn_id) DO UPDATE SET
ratchet_state = :ratchet_state,
x3dh_priv_key_1 = NULL,
x3dh_priv_key_2 = NULL
x3dh_priv_key_2 = NULL,
x3dh_pub_key_1 = NULL,
x3dh_pub_key_2 = NULL,
pq_priv_kem = NULL
|]
[":conn_id" := connId, ":ratchet_state" := rc]
@@ -1752,6 +1772,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn
instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
instance ToField (Version v) where toField (Version v) = toField v
instance FromField (Version v) where fromField f = Version <$> fromField f
listToEither :: e -> [a] -> Either e a
listToEither _ (x : _) = Right x
listToEither e _ = Left e
@@ -1903,25 +1927,38 @@ getConnData db connId' =
[sql|
SELECT
user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
last_external_snd_msg_id, deleted, ratchet_sync_state
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
FROM connections
WHERE conn_id = ?
|]
(Only connId')
where
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState) =
(ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState}, cMode)
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) =
(ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
setConnDeleted :: DB.Connection -> ConnId -> IO ()
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO ()
setConnDeleted db waitDelivery connId
| waitDelivery = do
currentTs <- getCurrentTime
DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId)
| otherwise =
DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO ()
setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
setConnAgentVersion db connId aVersion =
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO ()
setConnPQSupport db connId pqSupport =
DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId)
getDeletedConnIds :: DB.Connection -> IO [ConnId]
getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True)
getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId]
getDeletedWaitingDeliveryConnIds db =
map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL"
setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO ()
setConnRatchetSync db connId ratchetSyncState =
DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId)
@@ -1970,12 +2007,12 @@ rcvQueueQuery =
toRcvQueue ::
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
RcvQueue
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
let server = SMPServer host port keyHash
smpClientVersion = fromMaybe 1 smpClientVersion_
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
_ -> Nothing
@@ -2011,7 +2048,7 @@ sndQueueQuery =
toSndQueue ::
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId)
:. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus)
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) ->
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
SndQueue
toSndQueue
( (userId, keyHash, connId, host, port, sndId)
@@ -2060,23 +2097,15 @@ updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do
let MsgMeta {recipient = (internalId, internalTs)} = msgMeta
DB.executeNamed
let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta
DB.execute
dbConn
[sql|
INSERT INTO messages
( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body)
VALUES
(:conn_id,:internal_id,:internal_ts,:internal_rcv_id, NULL,:msg_type,:msg_flags,:msg_body);
(conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
VALUES (?,?,?,?,?,?,?,?,?);
|]
[ ":conn_id" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_rcv_id" := internalRcvId,
":msg_type" := msgType,
":msg_flags" := msgFlags,
":msg_body" := msgBody
]
(connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption)
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do
@@ -2157,23 +2186,16 @@ updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
-- * createSndMsg helpers
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
DB.executeNamed
dbConn
insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do
DB.execute
db
[sql|
INSERT INTO messages
( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body)
(conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
VALUES
(:conn_id,:internal_id,:internal_ts, NULL,:internal_snd_id,:msg_type,:msg_flags,:msg_body);
(?,?,?,?,?,?,?,?,?);
|]
[ ":conn_id" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_snd_id" := internalSndId,
":msg_type" := msgType,
":msg_flags" := msgFlags,
":msg_body" := msgBody
]
(connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption)
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
@@ -2267,17 +2289,18 @@ createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redire
forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId
pure dstEntityId
where
dummyDst = FileDescription
{ party = SFRecipient,
size,
digest,
redirect = Nothing,
-- updated later with updateRcvFileRedirect
key = C.unsafeSbKey $ B.replicate 32 '#',
nonce = C.cbNonce "",
chunkSize = FileSize 0,
chunks = []
}
dummyDst =
FileDescription
{ party = SFRecipient,
size,
digest,
redirect = Nothing,
-- updated later with updateRcvFileRedirect
key = C.unsafeSbKey $ B.replicate 32 '#',
nonce = C.cbNonce "",
chunkSize = FileSize 0,
chunks = []
}
insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> IO (Either StoreError (RcvFileId, DBRcvFileId))
insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ = runExceptT $ do
@@ -2346,10 +2369,11 @@ getRcvFile db rcvFileId = runExceptT $ do
toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) =
let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_
saveFile = CryptoFile savePath cfArgs
redirect = RcvFileRedirect
<$> redirectDbId
<*> redirectEntityId
<*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_)
redirect =
RcvFileRedirect
<$> redirectDbId
<*> redirectEntityId
<*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_)
in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []}
getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk]
getChunks rcvFileEntityId userId fileTmpPath = do
@@ -71,7 +71,8 @@ dbBusyLoop action = loop 500 3000000
loop :: Int -> Int -> IO a
loop t tLim =
action `E.catch` \(e :: SQLError) ->
if tLim > t && SQL.sqlError e == SQL.ErrorBusy
let se = SQL.sqlError e in
if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked)
then do
threadDelay t
loop (t * 9 `div` 8) (tLim - t)
@@ -69,6 +69,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON)
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -106,7 +108,9 @@ schemaMigrations =
("m20231222_command_created_at", m20231222_command_created_at, Just down_m20231222_command_created_at),
("m20231225_failed_work_items", m20231225_failed_work_items, Just down_m20231225_failed_work_items),
("m20240121_message_delivery_indexes", m20240121_message_delivery_indexes, Just down_m20240121_message_delivery_indexes),
("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect)
("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect),
("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery),
("m20240225_ratchet_kem", m20240225_ratchet_kem, Just down_m20240225_ratchet_kem)
]
-- | The list of migrations in ascending order by date
@@ -0,0 +1,18 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20240223_connections_wait_delivery :: Query
m20240223_connections_wait_delivery =
[sql|
ALTER TABLE connections ADD COLUMN deleted_at_wait_delivery TEXT;
|]
down_m20240223_connections_wait_delivery :: Query
down_m20240223_connections_wait_delivery =
[sql|
ALTER TABLE connections DROP COLUMN deleted_at_wait_delivery;
|]
@@ -0,0 +1,22 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20240225_ratchet_kem :: Query
m20240225_ratchet_kem =
[sql|
ALTER TABLE ratchets ADD COLUMN pq_priv_kem BLOB;
ALTER TABLE connections ADD COLUMN pq_support INTEGER NOT NULL DEFAULT 0;
ALTER TABLE messages ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0;
|]
down_m20240225_ratchet_kem :: Query
down_m20240225_ratchet_kem =
[sql|
ALTER TABLE ratchets DROP COLUMN pq_priv_kem;
ALTER TABLE connections DROP COLUMN pq_support;
ALTER TABLE messages DROP COLUMN pq_encryption;
|]
@@ -26,7 +26,9 @@ CREATE TABLE connections(
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
user_id INTEGER CHECK(user_id NOT NULL)
REFERENCES users ON DELETE CASCADE,
ratchet_sync_state TEXT NOT NULL DEFAULT 'ok'
ratchet_sync_state TEXT NOT NULL DEFAULT 'ok',
deleted_at_wait_delivery TEXT,
pq_support INTEGER NOT NULL DEFAULT 0
) WITHOUT ROWID;
CREATE TABLE rcv_queues(
host TEXT NOT NULL,
@@ -89,6 +91,7 @@ CREATE TABLE messages(
msg_type BLOB NOT NULL, --(H)ELLO,(R)EPLY,(D)ELETE. Should SMP confirmation be saved too?
msg_body BLOB NOT NULL DEFAULT x'',
msg_flags TEXT NULL,
pq_encryption INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY(conn_id, internal_id),
FOREIGN KEY(conn_id, internal_rcv_id) REFERENCES rcv_messages
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
@@ -159,7 +162,8 @@ CREATE TABLE ratchets(
e2e_version INTEGER NOT NULL DEFAULT 1
,
x3dh_pub_key_1 BLOB,
x3dh_pub_key_2 BLOB
x3dh_pub_key_2 BLOB,
pq_priv_kem BLOB
) WITHOUT ROWID;
CREATE TABLE skipped_messages(
skipped_message_id INTEGER PRIMARY KEY,
+36 -36
View File
@@ -76,7 +76,7 @@ module Simplex.Messaging.Client
PCTransmission,
mkTransmission,
authTransmission,
clientStub,
smpClientStub,
)
where
@@ -117,14 +117,14 @@ import System.Timeout (timeout)
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
--
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
data ProtocolClient err msg = ProtocolClient
data ProtocolClient v err msg = ProtocolClient
{ action :: Maybe (Async ()),
thParams :: THandleParams,
thParams :: THandleParams v,
sessionTs :: UTCTime,
client_ :: PClient err msg
client_ :: PClient v err msg
}
data PClient err msg = PClient
data PClient v err msg = PClient
{ connected :: TVar Bool,
transportSession :: TransportSession msg,
transportHost :: TransportHost,
@@ -135,11 +135,11 @@ data PClient err msg = PClient
sentCommands :: TMap CorrId (Request err msg),
sndQ :: TBQueue ByteString,
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
msgQ :: Maybe (TBQueue (ServerTransmission msg))
msgQ :: Maybe (TBQueue (ServerTransmission v msg))
}
clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg)
clientStub g sessionId thVersion thAuth = do
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg)
smpClientStub g sessionId thVersion thAuth = do
connected <- newTVar False
clientCorrId <- C.newRandomDRG g
sentCommands <- TM.empty
@@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do
}
}
type SMPClient = ProtocolClient ErrorType BrokerMsg
type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg
-- | Type for client command data
type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg)
-- | Type synonym for transmission from some SPM server queue.
type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg)
type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg)
data HostMode
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
@@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
-- | protocol client configuration.
data ProtocolClientConfig = ProtocolClientConfig
data ProtocolClientConfig v = ProtocolClientConfig
{ -- | size of TBQueue to use for server commands and responses
qSize :: Natural,
-- | default server port if port is not specified in ProtocolServer
@@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig
-- | network configuration
networkConfig :: NetworkConfig,
-- | client-server protocol version range
serverVRange :: VersionRange,
serverVRange :: VersionRange v,
-- | delay between sending batches of commands (microseconds)
batchDelay :: Maybe Int
}
-- | Default protocol client configuration.
defaultClientConfig :: VersionRange -> ProtocolClientConfig
defaultClientConfig :: VersionRange v -> ProtocolClientConfig v
defaultClientConfig serverVRange =
ProtocolClientConfig
{ qSize = 64,
@@ -265,7 +265,7 @@ defaultClientConfig serverVRange =
batchDelay = Nothing
}
defaultSMPClientConfig :: ProtocolClientConfig
defaultSMPClientConfig :: ProtocolClientConfig SMPVersion
defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange
data Request err msg = Request
@@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
onionHost = find isOnionHost hosts
publicHost = find (not . isOnionHost) hosts
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String
protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_
where
snd3 (_, s, _) = s
transportHost' :: ProtocolClient err msg -> TransportHost
transportHost' :: ProtocolClient v err msg -> TransportHost
transportHost' = transportHost . client_
transportSession' :: ProtocolClient err msg -> TransportSession msg
transportSession' :: ProtocolClient v err msg -> TransportSession msg
transportSession' = transportSession . client_
type UserId = Int64
@@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do
case chooseTransportHost networkConfig (host srv) of
Right useHost ->
@@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
Left e -> pure $ Left e
where
NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig
mkProtocolClient :: TransportHost -> STM (PClient err msg)
mkProtocolClient :: TransportHost -> STM (PClient v err msg)
mkProtocolClient transportHost = do
connected <- newTVar False
pingErrorCount <- newTVar 0
@@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
msgQ
}
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
runClient (port', ATransport t) useHost c = do
cVar <- newEmptyTMVarIO
let tcConfig = transportClientConfig networkConfig
@@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
"80" -> ("80", transport @WS)
p -> (p, transport @TLS)
client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO ()
client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO ()
client _ c cVar h = do
ks <- atomically $ C.generateKeyPair g
runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case
runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
Right th@THandle {params} -> do
sessionTs <- getCurrentTime
@@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0])
`finally` disconnected c'
send :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h
receive :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
ping :: ProtocolClient err msg -> IO ()
ping :: ProtocolClient v err msg -> IO ()
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
threadDelay' smpPingInterval
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case
Left PCEResponseTimeout -> do
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
when (maxCnt == 0 || cnt < maxCnt) $ ping c
@@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
where
maxCnt = smpPingCount networkConfig
process :: ProtocolClient err msg -> IO ()
process :: ProtocolClient v err msg -> IO ()
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO ()
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) =
if B.null $ bs corrId
then sendMsg respOrErr
@@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
-- | Disconnects client from the server and terminates client threads.
closeProtocolClient :: ProtocolClient err msg -> IO ()
closeProtocolClient :: ProtocolClient v err msg -> IO ()
closeProtocolClient = mapM_ uninterruptibleCancel . action
-- | SMP client error type.
@@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg
serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message =
(transportSession, thVersion, sessionId, entityId, message)
@@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg)
-- | Send multiple commands with batching and collect responses
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
validate . concat =<< mapM (sendBatch c) bs
@@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz
where
diff = L.length cs - length rs
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
mapM_ (cb <=< sendBatch c) bs
sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
case b of
TBError e Request {entityId} -> do
@@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
(: []) <$> getResponse c r
-- | Send Protocol command
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd =
ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd)
where
@@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand
| otherwise = tEncode t
-- TODO switch to timeout or TimeManager that supports Int64
getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg)
getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg)
getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do
response <-
timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case
@@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ
Nothing -> pure $ Left PCEResponseTimeout
pure Response {entityId, response}
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg)
mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg)
mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do
corrId <- atomically getNextCorrId
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd)
+2 -2
View File
@@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId)
-- type SMPServerSub = (SMPServer, SMPSub)
data SMPClientAgentConfig = SMPClientAgentConfig
{ smpCfg :: ProtocolClientConfig,
{ smpCfg :: ProtocolClientConfig SMPVersion,
reconnectInterval :: RetryInterval,
msgQSize :: Natural,
agentQSize :: Natural,
@@ -91,7 +91,7 @@ defaultSMPClientAgentConfig =
data SMPClientAgent = SMPClientAgent
{ agentCfg :: SMPClientAgentConfig,
msgQ :: TBQueue (ServerTransmission BrokerMsg),
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
agentQ :: TBQueue SMPClientAgentEvent,
randomDrg :: TVar ChaChaDRG,
smpClients :: TMap SMPServer SMPClientVar,
+80
View File
@@ -0,0 +1,80 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Compression where
import qualified Codec.Compression.Zstd.FFI as Z
import Control.Monad (forM)
import Control.Monad.Except
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Data.Either (fromRight)
import Data.List.NonEmpty (NonEmpty)
import Foreign
import Foreign.C.Types
import GHC.IO (unsafePerformIO)
import Simplex.Messaging.Encoding
import UnliftIO.Exception (bracket)
data Compressed
= -- | Short messages are left intact to skip copying and FFI festivities.
Passthrough ByteString
| -- | Generic compression using no extra context.
Compressed Large
-- | Messages below this length are not encoded to avoid compression overhead.
maxLengthPassthrough :: Int
maxLengthPassthrough = 180 -- Sampled from real client data. Messages with length > 180 rapidly gain compression ratio.
instance Encoding Compressed where
smpEncode = \case
Passthrough bytes -> "0" <> smpEncode bytes
Compressed bytes -> "1" <> smpEncode bytes
smpP =
smpP >>= \case
'0' -> Passthrough <$> smpP
'1' -> Compressed <$> smpP
x -> fail $ "unknown Compressed tag: " <> show x
type CompressCtx = (Ptr Z.CCtx, Ptr CChar, CSize)
withCompressCtx :: CSize -> (CompressCtx -> IO a) -> IO a
withCompressCtx scratchSize action =
bracket Z.createCCtx Z.freeCCtx $ \cctx ->
allocaBytes (fromIntegral scratchSize) $ \scratchPtr ->
action (cctx, scratchPtr, scratchSize)
-- | Compress bytes, falling back to Passthrough in case of some internal error.
compress :: CompressCtx -> ByteString -> IO Compressed
compress ctx bs = fromRight (Passthrough bs) <$> compress_ ctx bs
compress_ :: CompressCtx -> ByteString -> IO (Either String Compressed)
compress_ (cctx, scratchPtr, scratchSize) bs
| B.length bs <= maxLengthPassthrough = pure . Right $ Passthrough bs
| otherwise =
B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> runExceptT $ do
-- should not fail, unless input buffer is too short
dstSize <- ExceptT $ Z.checkError $ Z.compressCCtx cctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) 3
liftIO $ Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize)
type DecompressCtx = (Ptr Z.DCtx, Ptr CChar, CSize)
withDecompressCtx :: Int -> (DecompressCtx -> IO a) -> IO a
withDecompressCtx maxUnpackedSize action =
bracket Z.createDCtx Z.freeDCtx $ \dctx ->
allocaBytes maxUnpackedSize $ \scratchPtr ->
action (dctx, scratchPtr, fromIntegral maxUnpackedSize)
decompress :: DecompressCtx -> Compressed -> IO (Either String ByteString)
decompress (dctx, scratchPtr, scratchSize) = \case
Passthrough bs -> pure $ Right bs
Compressed (Large bs) ->
B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do
res <- Z.checkError $ Z.decompressDCtx dctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize)
forM res $ \dstSize -> B.packCStringLen (scratchPtr, fromIntegral dstSize)
decompressBatch :: Int -> NonEmpty Compressed -> NonEmpty (Either String ByteString)
decompressBatch maxUnpackedSize items = unsafePerformIO $ withDecompressCtx maxUnpackedSize $ forM items . decompress
{-# NOINLINE decompressBatch #-} -- prevent double-evaluation under unsafePerformIO
+3 -39
View File
@@ -101,6 +101,7 @@ module Simplex.Messaging.Crypto
verify,
verify',
validSignatureSize,
checkAlgorithm,
-- * crypto_box authenticator, as discussed in https://groups.google.com/g/sci.crypt/c/73yb5a9pz2Y/m/LNgRO7IYXOwJ
CbAuthenticator (..),
@@ -243,8 +244,6 @@ data SAlgorithm :: Algorithm -> Type where
SX25519 :: SAlgorithm X25519
SX448 :: SAlgorithm X448
deriving instance Eq (SAlgorithm a)
deriving instance Show (SAlgorithm a)
data Alg = forall a. AlgorithmI a => Alg (SAlgorithm a)
@@ -297,11 +296,6 @@ data APublicKey
AlgorithmI a =>
APublicKey (SAlgorithm a) (PublicKey a)
instance Eq APublicKey where
APublicKey a k == APublicKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
instance Encoding APublicKey where
smpEncode = smpEncode . encodePubKey
{-# INLINE smpEncode #-}
@@ -342,11 +336,6 @@ data APrivateKey
AlgorithmI a =>
APrivateKey (SAlgorithm a) (PrivateKey a)
instance Eq APrivateKey where
APrivateKey a k == APrivateKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
deriving instance Show APrivateKey
type PrivateKeyEd25519 = PrivateKey Ed25519
@@ -372,11 +361,6 @@ data APrivateSignKey
(AlgorithmI a, SignatureAlgorithm a) =>
APrivateSignKey (SAlgorithm a) (PrivateKey a)
instance Eq APrivateSignKey where
APrivateSignKey a k == APrivateSignKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
deriving instance Show APrivateSignKey
instance Encoding APrivateSignKey where
@@ -396,11 +380,6 @@ data APublicVerifyKey
(AlgorithmI a, SignatureAlgorithm a) =>
APublicVerifyKey (SAlgorithm a) (PublicKey a)
instance Eq APublicVerifyKey where
APublicVerifyKey a k == APublicVerifyKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
deriving instance Show APublicVerifyKey
data APrivateDhKey
@@ -408,11 +387,6 @@ data APrivateDhKey
(AlgorithmI a, DhAlgorithm a) =>
APrivateDhKey (SAlgorithm a) (PrivateKey a)
instance Eq APrivateDhKey where
APrivateDhKey a k == APrivateDhKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
deriving instance Show APrivateDhKey
data APublicDhKey
@@ -420,11 +394,6 @@ data APublicDhKey
(AlgorithmI a, DhAlgorithm a) =>
APublicDhKey (SAlgorithm a) (PublicKey a)
instance Eq APublicDhKey where
APublicDhKey a k == APublicDhKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
deriving instance Show APublicDhKey
data DhSecret (a :: Algorithm) where
@@ -787,8 +756,6 @@ data Signature (a :: Algorithm) where
SignatureEd25519 :: Ed25519.Signature -> Signature Ed25519
SignatureEd448 :: Ed448.Signature -> Signature Ed448
deriving instance Eq (Signature a)
deriving instance Show (Signature a)
data ASignature
@@ -796,11 +763,6 @@ data ASignature
(AlgorithmI a, SignatureAlgorithm a) =>
ASignature (SAlgorithm a) (Signature a)
instance Eq ASignature where
ASignature a s == ASignature a' s' = case testEquality a a' of
Just Refl -> s == s'
_ -> False
deriving instance Show ASignature
class CryptoSignature s where
@@ -885,6 +847,8 @@ data CryptoError
CryptoHeaderError String
| -- | no sending chain key in ratchet state
CERatchetState
| -- | no decapsulation key in ratchet state
CERatchetKEMState
| -- | header decryption error (could indicate that another key should be tried)
CERatchetHeader
| -- | too many skipped messages
File diff suppressed because it is too large Load Diff
@@ -19,16 +19,20 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
newtype KEMPublicKey = KEMPublicKey ByteString
deriving (Show)
deriving (Eq, Show)
newtype KEMSecretKey = KEMSecretKey ScrubbedBytes
deriving (Show)
deriving (Eq, Show)
newtype KEMCiphertext = KEMCiphertext ByteString
deriving (Show)
deriving (Eq, Show)
newtype KEMSharedKey = KEMSharedKey ScrubbedBytes
deriving (Show)
deriving (Eq, Show)
unsafeRevealKEMSharedKey :: KEMSharedKey -> String
unsafeRevealKEMSharedKey (KEMSharedKey scrubbed) = show (BA.convert scrubbed :: ByteString)
{-# DEPRECATED unsafeRevealKEMSharedKey "unsafeRevealKEMSharedKey left in code" #-}
type KEMKeyPair = (KEMPublicKey, KEMSecretKey)
@@ -60,6 +64,18 @@ sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) =
KEMSharedKey
<$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr)
instance Encoding KEMSecretKey where
smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c
smpP = KEMSecretKey . BA.convert . unLarge <$> smpP
instance StrEncoding KEMSecretKey where
strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMSecretKey . BA.convert <$> strP @ByteString
instance Encoding KEMPublicKey where
smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk
smpP = KEMPublicKey . BA.convert . unLarge <$> smpP
instance StrEncoding KEMPublicKey where
strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMPublicKey . BA.convert <$> strP @ByteString
@@ -68,6 +84,25 @@ instance Encoding KEMCiphertext where
smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c
smpP = KEMCiphertext . BA.convert . unLarge <$> smpP
instance Encoding KEMSharedKey where
smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString)
smpP = KEMSharedKey . BA.convert <$> smpP @ByteString
instance StrEncoding KEMCiphertext where
strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMCiphertext . BA.convert <$> strP @ByteString
instance StrEncoding KEMSharedKey where
strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMSharedKey . BA.convert <$> strP @ByteString
instance ToJSON KEMSecretKey where
toJSON = strToJSON
toEncoding = strToJEncoding
instance FromJSON KEMSecretKey where
parseJSON = strParseJSON "KEMSecretKey"
instance ToJSON KEMPublicKey where
toJSON = strToJSON
toEncoding = strToJEncoding
@@ -75,8 +110,22 @@ instance ToJSON KEMPublicKey where
instance FromJSON KEMPublicKey where
parseJSON = strParseJSON "KEMPublicKey"
instance ToJSON KEMCiphertext where
toJSON = strToJSON
toEncoding = strToJEncoding
instance FromJSON KEMCiphertext where
parseJSON = strParseJSON "KEMCiphertext"
instance ToField KEMSharedKey where
toField (KEMSharedKey k) = toField (BA.convert k :: ByteString)
instance FromField KEMSharedKey where
fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f
instance ToJSON KEMSharedKey where
toJSON = strToJSON
toEncoding = strToJEncoding
instance FromJSON KEMSharedKey where
parseJSON = strParseJSON "KEMSharedKey"
+6
View File
@@ -179,6 +179,12 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin
strP = (,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP
{-# INLINE strP #-}
instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncoding e, StrEncoding f) => StrEncoding (a, b, c, d, e, f) where
strEncode (a, b, c, d, e, f) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d, strEncode e, strEncode f]
{-# INLINE strEncode #-}
strP = (,,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP
{-# INLINE strP #-}
strP_ :: StrEncoding a => Parser a
strP_ = strP <* A.space
@@ -10,15 +10,15 @@ import Data.Word (Word16)
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange)
import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange)
import Simplex.Messaging.Protocol (ErrorType)
import Simplex.Messaging.Util (bshow)
type NtfClient = ProtocolClient ErrorType NtfResponse
type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse
type NtfClientError = ProtocolClientError ErrorType
defaultNTFClientConfig :: ProtocolClientConfig
defaultNTFClientConfig :: ProtocolClientConfig NTFVersion
defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange
ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519)
@@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake)
import Simplex.Messaging.Parsers (fromTextField_)
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
import Simplex.Messaging.Util (eitherToMaybe, (<$?>))
@@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
_ -> fail "bad ANewNtfEntity"
instance Protocol ErrorType NtfResponse where
instance Protocol NTFVersion ErrorType NtfResponse where
type ProtoCommand NtfResponse = NtfCmd
type ProtoType NtfResponse = 'PNTF
protocolClientHandshake = ntfClientHandshake
@@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
deriving instance Show NtfCmd
instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where
type Tag (NtfCommand e) = NtfCommandTag e
encodeProtocol _v = \case
TNEW newTkn -> e (TNEW_, ' ', newTkn)
@@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, entityId, _) cmd = case cmd of
@@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
| not (B.null entityId) = Left $ CMD HAS_AUTH
| otherwise = Right cmd
instance ProtocolEncoding ErrorType NtfCmd where
instance ProtocolEncoding NTFVersion ErrorType NtfCmd where
type Tag NtfCmd = NtfCmdTag
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
@@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where
SDEL_ -> pure SDEL
PING_ -> pure PING
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
{-# INLINE fromProtocolError #-}
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
@@ -290,7 +290,7 @@ data NtfResponse
| NRPong
deriving (Show)
instance ProtocolEncoding ErrorType NtfResponse where
instance ProtocolEncoding NTFVersion ErrorType NtfResponse where
type Tag NtfResponse = NtfResponseTag
encodeProtocol _v = \case
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
@@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do
old <- atomically $ stateTVar tknStatus (,status)
when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status
runNtfClientTransport :: Transport c => THandle c -> M ()
runNtfClientTransport :: Transport c => THandleNTF c -> M ()
runNtfClientTransport th@THandle {params} = do
qSize <- asks $ clientQSize . config
ts <- liftIO getSystemTime
@@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do
clientDisconnected :: NtfServerClient -> IO ()
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
receive :: Transport c => THandle c -> NtfServerClient -> M ()
receive :: Transport c => THandleNTF c -> NtfServerClient -> M ()
receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do
ts <- liftIO $ tGet th
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
@@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ
where
write q t = atomically $ writeTBQueue q t
send :: Transport c => THandle c -> NtfServerClient -> IO ()
send :: Transport c => THandleNTF c -> NtfServerClient -> IO ()
send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
t <- atomically $ readTBQueue sndQ
void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)]
@@ -24,6 +24,7 @@ import Numeric.Natural
import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF)
import Simplex.Messaging.Notifications.Server.Push.APNS
import Simplex.Messaging.Notifications.Server.Stats
import Simplex.Messaging.Notifications.Server.Store
@@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport, THandleParams)
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
import Simplex.Messaging.Version (VersionRange)
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig
logStatsStartTime :: Int64,
serverStatsLogFile :: FilePath,
serverStatsBackupFile :: Maybe FilePath,
ntfServerVRange :: VersionRange,
ntfServerVRange :: VersionRangeNTF,
transportConfig :: TransportServerConfig
}
@@ -161,13 +161,13 @@ data NtfRequest
data NtfServerClient = NtfServerClient
{ rcvQ :: TBQueue NtfRequest,
sndQ :: TBQueue (Transmission NtfResponse),
ntfThParams :: THandleParams,
ntfThParams :: THandleParams NTFVersion,
connected :: TVar Bool,
rcvActiveAt :: TVar SystemTime,
sndActiveAt :: TVar SystemTime
}
newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient
newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient
newNtfServerClient qSize ntfThParams ts = do
rcvQ <- newTBQueue qSize
sndQ <- newTBQueue qSize
@@ -3,6 +3,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Notifications.Transport where
@@ -12,33 +13,51 @@ import Control.Monad.Except
import Data.Attoparsec.ByteString.Char8 (Parser)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Word (Word16)
import qualified Data.X509 as X
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Transport
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import Simplex.Messaging.Util (liftEitherWith)
ntfBlockSize :: Int
ntfBlockSize = 512
authBatchCmdsNTFVersion :: Version
authBatchCmdsNTFVersion = 2
data NTFVersion
currentClientNTFVersion :: Version
currentClientNTFVersion = 1
instance VersionScope NTFVersion
currentServerNTFVersion :: Version
currentServerNTFVersion = 1
type VersionNTF = Version NTFVersion
supportedClientNTFVRange :: VersionRange
supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion
type VersionRangeNTF = VersionRange NTFVersion
supportedServerNTFVRange :: VersionRange
supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion
pattern VersionNTF :: Word16 -> VersionNTF
pattern VersionNTF v = Version v
initialNTFVersion :: VersionNTF
initialNTFVersion = VersionNTF 1
authBatchCmdsNTFVersion :: VersionNTF
authBatchCmdsNTFVersion = VersionNTF 2
currentClientNTFVersion :: VersionNTF
currentClientNTFVersion = VersionNTF 1
currentServerNTFVersion :: VersionNTF
currentServerNTFVersion = VersionNTF 1
supportedClientNTFVRange :: VersionRangeNTF
supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion
supportedServerNTFVRange :: VersionRangeNTF
supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion
type THandleNTF c = THandle NTFVersion c
data NtfServerHandshake = NtfServerHandshake
{ ntfVersionRange :: VersionRange,
{ ntfVersionRange :: VersionRangeNTF,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe (X.SignedExact X.PubKey)
@@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake
data NtfClientHandshake = NtfClientHandshake
{ -- | agreed SMP notifications server protocol version
ntfVersion :: Version,
ntfVersion :: VersionNTF,
-- | server identity - CA certificate fingerprint
keyHash :: C.KeyHash,
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
@@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where
authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP
pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey}
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing
instance Encoding NtfClientHandshake where
@@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where
authPubKey <- ntfAuthPubKeyP ntfVersion
pure NtfClientHandshake {ntfVersion, keyHash, authPubKey}
ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519)
ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519)
ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing
encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString
encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString
encodeNtfAuthPubKey v k
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
| otherwise = ""
-- | Notifcations server transport handshake.
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
let sk = C.signX509 serverSignKey $ C.publicToX509 k
@@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
| otherwise -> throwError $ TEHandshake VERSION
-- | Notifcations server client transport handshake.
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
ntfClientHandshake c (k, pk) keyHash ntfVRange = do
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
@@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do
pure $ ntfThHandle th v pk sk_
Nothing -> throwError $ TEHandshake VERSION
ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c
ntfThHandle th@THandle {params} v privKey k_ =
-- TODO drop SMP v6: make thAuth non-optional
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
v3 = v >= authBatchCmdsNTFVersion
params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3}
in (th :: THandle c) {params = params'}
in (th :: THandleNTF c) {params = params'}
ntfTHandle :: Transport c => c -> THandle c
ntfTHandle :: Transport c => c -> THandleNTF c
ntfTHandle c = THandle {connection = c, params}
where
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False}
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False}
+73 -47
View File
@@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
@@ -46,6 +47,10 @@ module Simplex.Messaging.Protocol
e2eEncMessageLength,
-- * SMP protocol types
SMPClientVersion,
VersionSMPC,
VersionRangeSMPC,
pattern VersionSMPC,
ProtocolEncoding (..),
Command (..),
SubscriptionMode (..),
@@ -117,6 +122,7 @@ module Simplex.Messaging.Protocol
SMPMsgMeta (..),
NMsgMeta (..),
MsgFlags (..),
initialSMPClientVersion,
userProtocol,
rcvMessageMeta,
noMsgFlags,
@@ -152,6 +158,7 @@ module Simplex.Messaging.Protocol
tEncodeBatch1,
batchTransmissions,
batchTransmissions',
batchTransmissions_,
-- * exports for tests
CommandTag (..),
@@ -167,6 +174,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson.TH as J
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -180,6 +188,7 @@ import Data.Maybe (isJust, isNothing)
import Data.String
import Data.Time.Clock.System (SystemTime (..))
import Data.Type.Equality
import Data.Word (Word16)
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
import Network.Socket (ServiceName)
import qualified Simplex.Messaging.Crypto as C
@@ -191,19 +200,34 @@ import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..))
import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>))
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
-- SMP client protocol version history:
-- 1 - binary protocol encoding (1/1/2022)
-- 2 - multiple server hostnames and versioned queue addresses (8/12/2022)
srvHostnamesSMPClientVersion :: Version
srvHostnamesSMPClientVersion = 2
data SMPClientVersion
currentSMPClientVersion :: Version
currentSMPClientVersion = 2
instance VersionScope SMPClientVersion
supportedSMPClientVRange :: VersionRange
supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion
type VersionSMPC = Version SMPClientVersion
type VersionRangeSMPC = VersionRange SMPClientVersion
pattern VersionSMPC :: Word16 -> VersionSMPC
pattern VersionSMPC v = Version v
initialSMPClientVersion :: VersionSMPC
initialSMPClientVersion = VersionSMPC 1
srvHostnamesSMPClientVersion :: VersionSMPC
srvHostnamesSMPClientVersion = VersionSMPC 2
currentSMPClientVersion :: VersionSMPC
currentSMPClientVersion = VersionSMPC 2
supportedSMPClientVRange :: VersionRangeSMPC
supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion
maxMessageLength :: Int
maxMessageLength = 16088
@@ -273,7 +297,7 @@ data RawTransmission = RawTransmission
data TransmissionAuth
= TASignature C.ASignature
| TAAuthenticator C.CbAuthenticator
deriving (Eq, Show)
deriving (Show)
-- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TAuthorization
tAuthBytes :: Maybe TransmissionAuth -> ByteString
@@ -339,8 +363,6 @@ data Command (p :: Party) where
deriving instance Show (Command p)
deriving instance Eq (Command p)
data SubscriptionMode = SMSubscribe | SMOnlyCreate
deriving (Eq, Show)
@@ -645,7 +667,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope
deriving (Show)
data PubHeader = PubHeader
{ phVersion :: Version,
{ phVersion :: VersionSMPC,
phE2ePubDhKey :: Maybe C.PublicKeyX25519
}
deriving (Show)
@@ -747,11 +769,11 @@ instance NFData (SProtocolType p) where rnf spt = spt `seq` ()
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
deriving instance Show AProtocolType
instance Eq AProtocolType where
AProtocolType p == AProtocolType p' = isJust $ testEquality p p'
deriving instance Show AProtocolType
instance TestEquality SProtocolType where
testEquality SPSMP SPSMP = Just Refl
testEquality SPNTF SPNTF = Just Refl
@@ -1058,7 +1080,7 @@ data CommandError
deriving (Eq, Read, Show)
-- | SMP transmission parser.
transmissionP :: THandleParams -> Parser RawTransmission
transmissionP :: THandleParams v -> Parser RawTransmission
transmissionP THandleParams {sessionId, implySessId} = do
authenticator <- smpP
authorized <- A.takeByteString
@@ -1072,16 +1094,16 @@ transmissionP THandleParams {sessionId, implySessId} = do
command <- A.takeByteString
pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command}
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where
type ProtoCommand msg = cmd | cmd -> msg
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c)
protocolPing :: ProtoCommand msg
protocolError :: msg -> Maybe err
type ProtoServer msg = ProtocolServer (ProtoType msg)
instance Protocol ErrorType BrokerMsg where
instance Protocol SMPVersion ErrorType BrokerMsg where
type ProtoCommand BrokerMsg = Cmd
type ProtoType BrokerMsg = 'PSMP
protocolClientHandshake = smpClientHandshake
@@ -1090,14 +1112,14 @@ instance Protocol ErrorType BrokerMsg where
ERR e -> Just e
_ -> Nothing
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where
type Tag msg
encodeProtocol :: Version -> msg -> ByteString
protocolP :: Version -> Tag msg -> Parser msg
encodeProtocol :: Version v -> msg -> ByteString
protocolP :: Version v -> Tag msg -> Parser msg
fromProtocolError :: ProtocolErrorType -> err
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
type Tag (Command p) = CommandTag p
encodeProtocol v = \case
NEW rKey dhKey auth_ subMode
@@ -1124,7 +1146,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials (auth, _, queueId, _) cmd = case cmd of
@@ -1146,7 +1168,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
| isNothing auth || B.null queueId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
instance ProtocolEncoding ErrorType Cmd where
instance ProtocolEncoding SMPVersion ErrorType Cmd where
type Tag Cmd = CmdTag
encodeProtocol v (Cmd _ c) = encodeProtocol v c
@@ -1174,12 +1196,12 @@ instance ProtocolEncoding ErrorType Cmd where
PING_ -> pure PING
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
{-# INLINE fromProtocolError #-}
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
instance ProtocolEncoding ErrorType BrokerMsg where
instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
type Tag BrokerMsg = BrokerMsgTag
encodeProtocol _v = \case
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
@@ -1231,12 +1253,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where
| otherwise -> Right cmd
-- | Parse SMP protocol commands and broker messages
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg
parseProtocol v s =
let (tag, params) = B.break (== ' ') s
in case decodeTag tag of
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params
Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
checkParty c = case testEquality (sParty @p) (sParty @p') of
@@ -1291,7 +1313,7 @@ instance Encoding CommandError where
_ -> fail "bad command error type"
-- | Send signed SMP transmission to TCP transport.
tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params)
where
tPutBatch :: TransportBatch () -> IO [Either TransportError ()]
@@ -1300,7 +1322,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba
TBTransmissions s n _ -> replicate n <$> tPutLog th s
TBTransmission s _ -> (: []) <$> tPutLog th s
tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
tPutLog th s = do
r <- tPutBlock th s
case r of
@@ -1314,11 +1336,11 @@ data TransportBatch r = TBTransmissions ByteString Int [r] | TBTransmission Byte
batchTransmissions :: Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()]
batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,())
-- | encodes and batches transmissions into blocks,
-- | encodes and batches transmissions into blocks
batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r]
batchTransmissions' batch bSize
| batch = addBatch . foldr addTransmission ([], 0, 0, [], [])
| otherwise = map mkBatch1 . L.toList
batchTransmissions' batch bSize ts
| batch = batchTransmissions_ bSize $ L.map (first $ fmap tEncodeForBatch) ts
| otherwise = map mkBatch1 $ L.toList ts
where
mkBatch1 :: (Either TransportError SentRawTransmission, r) -> TransportBatch r
mkBatch1 (t_, r) = case t_ of
@@ -1329,17 +1351,21 @@ batchTransmissions' batch bSize
| otherwise -> TBError TELargeMsg r
where
s = tEncode t
-- | Pack encoded transmissions into batches
batchTransmissions_ :: Int -> NonEmpty (Either TransportError ByteString, r) -> [TransportBatch r]
batchTransmissions_ bSize = addBatch . foldr addTransmission ([], 0, 0, [], [])
where
-- 3 = 2 bytes reserved for pad size + 1 for transmission count
bSize' = bSize - 3
addTransmission :: (Either TransportError SentRawTransmission, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r])
addTransmission (t_, r) acc@(bs, len, n, ss, rs) = case t_ of
addTransmission :: (Either TransportError ByteString, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r])
addTransmission (t_, r) acc@(bs, !len, !n, ss, rs) = case t_ of
Left e -> (TBError e r : addBatch acc, 0, 0, [], [])
Right t
Right s
| len' <= bSize' && n < 255 -> (bs, len', 1 + n, s : ss, r : rs)
| sLen <= bSize' -> (addBatch acc, sLen, 1, [s], [r])
| otherwise -> (TBError TELargeMsg r : addBatch acc, 0, 0, [], [])
where
s = tEncodeForBatch t
sLen = B.length s
len' = len + sLen
addBatch :: ([TransportBatch r], Int, Int, [ByteString], [r]) -> [TransportBatch r]
@@ -1362,7 +1388,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t
-- tForAuth is lazy to avoid computing it when there is no key to sign
data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString}
encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth
encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth
encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t =
TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth}
where
@@ -1370,24 +1396,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId}
t' = encodeTransmission_ v t
{-# INLINE encodeTransmissionForAuth #-}
encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString
encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString
encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t =
if implySessId then t' else smpEncode sessionId <> t'
where
t' = encodeTransmission_ v t
{-# INLINE encodeTransmission #-}
encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString
encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString
encodeTransmission_ v (CorrId corrId, queueId, command) =
smpEncode (corrId, queueId) <> encodeProtocol v command
{-# INLINE encodeTransmission_ #-}
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission))
tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission))
tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th
{-# INLINE tGetParse #-}
tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission)
tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission)
tParse thParams@THandleParams {batch} s
| batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts
| otherwise = [tParse1 s]
@@ -1399,24 +1425,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
eitherList = either (\e -> [Left e])
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd))
tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd
tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case
Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command}
| implySessId || sessId == sessionId ->
let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator
in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession))
Left _ -> tError ""
where
tError :: ByteString -> SignedTransmission err cmd
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock))
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
tParseValidate signed t@(sig, corrId, entityId, command) =
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t
in (sig, signed, (CorrId corrId, entityId, cmd))
$(J.deriveJSON defaultJSON ''MsgFlags)
+4 -4
View File
@@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
CPQuit -> pure ()
CPSkip -> pure ()
runClientTransport :: Transport c => THandle c -> M ()
runClientTransport :: Transport c => THandleSMP c -> M ()
runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do
q <- asks $ tbqSize . config
ts <- liftIO getSystemTime
@@ -428,7 +428,7 @@ cancelSub sub =
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
_ -> return ()
receive :: Transport c => THandle c -> Client -> M ()
receive :: Transport c => THandleSMP c -> Client -> M ()
receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
forever $ do
@@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi
VRFailed -> Left (corrId, queueId, ERR AUTH)
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
send :: Transport c => THandle c -> Client -> IO ()
send :: Transport c => THandleSMP c -> Client -> IO ()
send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
forever $ do
@@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
NMSG {} -> 0
_ -> 1
disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
loop
+4 -5
View File
@@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP)
import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState)
import Simplex.Messaging.Version
import System.IO (IOMode (..))
import System.Mem.Weak (Weak)
import UnliftIO.STM
@@ -73,7 +72,7 @@ data ServerConfig = ServerConfig
privateKeyFile :: FilePath,
certificateFile :: FilePath,
-- | SMP client-server protocol version range
smpServerVRange :: VersionRange,
smpServerVRange :: VersionRangeSMP,
-- | TCP transport config
transportConfig :: TransportServerConfig,
-- | run listener on control port
@@ -128,7 +127,7 @@ data Client = Client
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
endThreads :: TVar (IntMap (Weak ThreadId)),
endThreadSeq :: TVar Int,
thVersion :: Version,
thVersion :: VersionSMP,
sessionId :: ByteString,
connected :: TVar Bool,
createdAt :: SystemTime,
@@ -152,7 +151,7 @@ newServer = do
savingLock <- createLock
return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock}
newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client
newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client
newClient nextClientId qSize thVersion sessionId createdAt = do
clientId <- stateTVar nextClientId $ \next -> (next, next + 1)
subscriptions <- TM.empty
+2 -2
View File
@@ -17,14 +17,14 @@ data QueueRec = QueueRec
notifier :: !(Maybe NtfCreds),
status :: !ServerQueueStatus
}
deriving (Eq, Show)
deriving (Show)
data NtfCreds = NtfCreds
{ notifierId :: !NotifierId,
notifierKey :: !NtfPublicAuthKey,
rcvNtfDhSecret :: !RcvNtfDhSecret
}
deriving (Eq, Show)
deriving (Show)
instance StrEncoding NtfCreds where
strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret)
+53 -32
View File
@@ -9,6 +9,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
@@ -27,10 +28,15 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
module Simplex.Messaging.Transport
( -- * SMP transport parameters
SMPVersion,
VersionSMP,
VersionRangeSMP,
THandleSMP,
supportedClientSMPRelayVRange,
supportedServerSMPRelayVRange,
currentClientSMPRelayVersion,
currentServerSMPRelayVersion,
batchCmdsSMPVersion,
basicAuthSMPVersion,
subModeSMPVersion,
authCmdsSMPVersion,
@@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Default (def)
import Data.Functor (($>))
import Data.Version (showVersion)
import Data.Word (Word16)
import qualified Data.X509 as X
import qualified Data.X509.Validation as XV
import GHC.IO.Handle.Internals (ioe_EOF)
@@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON)
import Simplex.Messaging.Transport.Buffer
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith)
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import UnliftIO.Exception (Exception)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -116,30 +124,41 @@ smpBlockSize = 16384
-- 6 - allow creating queues without subscribing (9/10/2023)
-- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024)
batchCmdsSMPVersion :: Version
batchCmdsSMPVersion = 4
data SMPVersion
basicAuthSMPVersion :: Version
basicAuthSMPVersion = 5
instance VersionScope SMPVersion
subModeSMPVersion :: Version
subModeSMPVersion = 6
type VersionSMP = Version SMPVersion
authCmdsSMPVersion :: Version
authCmdsSMPVersion = 7
type VersionRangeSMP = VersionRange SMPVersion
currentClientSMPRelayVersion :: Version
currentClientSMPRelayVersion = 6
pattern VersionSMP :: Word16 -> VersionSMP
pattern VersionSMP v = Version v
currentServerSMPRelayVersion :: Version
currentServerSMPRelayVersion = 6
batchCmdsSMPVersion :: VersionSMP
batchCmdsSMPVersion = VersionSMP 4
basicAuthSMPVersion :: VersionSMP
basicAuthSMPVersion = VersionSMP 5
subModeSMPVersion :: VersionSMP
subModeSMPVersion = VersionSMP 6
authCmdsSMPVersion :: VersionSMP
authCmdsSMPVersion = VersionSMP 7
currentClientSMPRelayVersion :: VersionSMP
currentClientSMPRelayVersion = VersionSMP 6
currentServerSMPRelayVersion :: VersionSMP
currentServerSMPRelayVersion = VersionSMP 6
-- minimal supported protocol version is 4
-- TODO remove code that supports sending commands without batching
supportedClientSMPRelayVRange :: VersionRange
supportedClientSMPRelayVRange :: VersionRangeSMP
supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion
supportedServerSMPRelayVRange :: VersionRange
supportedServerSMPRelayVRange :: VersionRangeSMP
supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion
simplexMQVersion :: String
@@ -287,16 +306,18 @@ instance Transport TLS where
-- * SMP transport
-- | The handle for SMP encrypted transport connection over Transport.
data THandle c = THandle
data THandle v c = THandle
{ connection :: c,
params :: THandleParams
params :: THandleParams v
}
data THandleParams = THandleParams
type THandleSMP c = THandle SMPVersion c
data THandleParams v = THandleParams
{ sessionId :: SessionId,
blockSize :: Int,
-- | agreed server protocol version
thVersion :: Version,
thVersion :: Version v,
-- | peer public key for command authorization and shared secrets for entity ID encryption
thAuth :: Maybe THandleAuth,
-- | do NOT send session ID in transmission, but include it into signed message
@@ -316,7 +337,7 @@ data THandleAuth = THandleAuth
type SessionId = ByteString
data ServerHandshake = ServerHandshake
{ smpVersionRange :: VersionRange,
{ smpVersionRange :: VersionRangeSMP,
sessionId :: SessionId,
-- pub key to agree shared secrets for command authorization and entity ID encryption.
authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey)
@@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake
data ClientHandshake = ClientHandshake
{ -- | agreed SMP server protocol version
smpVersion :: Version,
smpVersion :: VersionSMP,
-- | server identity - CA certificate fingerprint
keyHash :: C.KeyHash,
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
@@ -358,12 +379,12 @@ instance Encoding ServerHandshake where
C.SignedObject key <- smpP
pure (cert, key)
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString
encodeAuthEncryptCmds v k
| v >= authCmdsSMPVersion = maybe "" smpEncode k
| otherwise = ""
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a)
authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing
-- | Error of SMP encrypted transport over TCP.
@@ -412,13 +433,13 @@ serializeTransportError = \case
TEHandshake e -> "HANDSHAKE " <> bshow e
-- | Pad and send block to SMP transport.
tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block =
bimapM (const $ pure TELargeMsg) (cPut c) $
C.pad block blockSize
-- | Receive block from SMP transport.
tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString)
tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString)
tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
msg <- cGet c blockSize
if B.length msg == blockSize
@@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
-- | Server SMP transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
sk = C.signX509 serverSignKey $ C.publicToX509 k
@@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
-- | Client SMP transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th
@@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
pure $ smpThHandle th v pk sk_
Nothing -> throwE $ TEHandshake VERSION
smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c
smpThHandle th@THandle {params} v privKey k_ =
-- TODO drop SMP v6: make thAuth non-optional
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion}
in (th :: THandle c) {params = params'}
in (th :: THandleSMP c) {params = params'}
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO ()
sendHandshake th = ExceptT . tPutBlock th . smpEncode
-- ignores tail bytes to allow future extensions
getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp
getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp
getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th
smpTHandle :: Transport c => c -> THandle c
smpTHandle :: Transport c => c -> THandleSMP c
smpTHandle c = THandle {connection = c, params}
where
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True}
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True}
$(J.deriveJSON (sumTypeJSON id) ''HandshakeError)
+48 -40
View File
@@ -1,13 +1,16 @@
{-# LANGUAGE ConstrainedClassMethods #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Simplex.Messaging.Version
( Version,
VersionRange (minVersion, maxVersion),
VersionScope,
pattern VersionRange,
VersionI (..),
VersionRangeI (..),
@@ -24,47 +27,61 @@ module Simplex.Messaging.Version
where
import Control.Applicative (optional)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Aeson as J
import qualified Data.Aeson.Encoding as JE
import Data.Aeson.Types ((.:), (.=))
import qualified Data.Aeson.Types as JT
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Word (Word16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Util ((<$?>))
import Simplex.Messaging.Version.Internal (Version (..))
pattern VersionRange :: Word16 -> Word16 -> VersionRange
pattern VersionRange :: Version v -> Version v -> VersionRange v
pattern VersionRange v1 v2 <- VRange v1 v2
{-# COMPLETE VersionRange #-}
type Version = Word16
data VersionRange = VRange
{ minVersion :: Version,
maxVersion :: Version
data VersionRange v = VRange
{ minVersion :: Version v,
maxVersion :: Version v
}
deriving (Eq, Show)
instance J.FromJSON (VersionRange v) where
parseJSON (J.Object v) = do
minVersion <- v .: "minVersion"
maxVersion <- v .: "maxVersion"
pure VRange {minVersion, maxVersion}
parseJSON invalid =
JT.prependFailure "bad VersionRange, " (JT.typeMismatch "Object" invalid)
instance J.ToJSON (VersionRange v) where
toEncoding VRange {minVersion, maxVersion} = JE.pairs $ ("minVersion" .= minVersion) <> ("maxVersion" .= maxVersion)
toJSON VRange {minVersion, maxVersion} = J.object ["minVersion" .= minVersion, "maxVersion" .= maxVersion]
class VersionScope v
-- | construct valid version range, to be used in constants
mkVersionRange :: Version -> Version -> VersionRange
mkVersionRange :: Version v -> Version v -> VersionRange v
mkVersionRange v1 v2
| v1 <= v2 = VRange v1 v2
| otherwise = error "invalid version range"
safeVersionRange :: Version -> Version -> Maybe VersionRange
safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v)
safeVersionRange v1 v2
| v1 <= v2 = Just $ VRange v1 v2
| otherwise = Nothing
versionToRange :: Version -> VersionRange
versionToRange :: Version v -> VersionRange v
versionToRange v = VRange v v
instance Encoding VersionRange where
instance VersionScope v => Encoding (VersionRange v) where
smpEncode (VRange v1 v2) = smpEncode (v1, v2)
smpP =
maybe (fail "invalid version range") pure
=<< safeVersionRange <$> smpP <*> smpP
instance StrEncoding VersionRange where
instance VersionScope v => StrEncoding (VersionRange v) where
strEncode (VRange v1 v2)
| v1 == v2 = strEncode v1
| otherwise = strEncode v1 <> "-" <> strEncode v2
@@ -73,32 +90,23 @@ instance StrEncoding VersionRange where
v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-')
maybe (fail "invalid version range") pure $ safeVersionRange v1 v2
instance ToJSON VersionRange where
toJSON (VRange v1 v2) = toJSON (v1, v2)
toEncoding (VRange v1 v2) = toEncoding (v1, v2)
class VersionScope v => VersionI v a | a -> v where
type VersionRangeT v a
version :: a -> Version v
toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a
instance FromJSON VersionRange where
parseJSON v =
(\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2)
<$?> parseJSON v
class VersionScope v => VersionRangeI v a | a -> v where
type VersionT v a
versionRange :: a -> VersionRange v
toVersionT :: a -> Version v -> VersionT v a
class VersionI a where
type VersionRangeT a
version :: a -> Version
toVersionRangeT :: a -> VersionRange -> VersionRangeT a
class VersionRangeI a where
type VersionT a
versionRange :: a -> VersionRange
toVersionT :: a -> Version -> VersionT a
instance VersionI Version where
type VersionRangeT Version = VersionRange
instance VersionScope v => VersionI v (Version v) where
type VersionRangeT v (Version v) = VersionRange v
version = id
toVersionRangeT _ vr = vr
instance VersionRangeI VersionRange where
type VersionT VersionRange = Version
instance VersionScope v => VersionRangeI v (VersionRange v) where
type VersionT v (VersionRange v) = Version v
versionRange = id
toVersionT _ v = v
@@ -109,18 +117,18 @@ pattern Compatible a <- Compatible_ a
{-# COMPLETE Compatible #-}
isCompatible :: VersionI a => a -> VersionRange -> Bool
isCompatible :: VersionI v a => a -> VersionRange v -> Bool
isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2
isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool
isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool
isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1
where
VRange min1 max1 = versionRange x
proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a)
proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a)
proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr)
compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a))
compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a))
compatibleVersion x vr =
toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr
where
+25
View File
@@ -0,0 +1,25 @@
module Simplex.Messaging.Version.Internal where
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Word (Word16)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
-- Do not use constructor of this type directry
newtype Version v = Version Word16
deriving (Eq, Ord, Show)
instance Encoding (Version v) where
smpEncode (Version v) = smpEncode v
smpP = Version <$> smpP
instance StrEncoding (Version v) where
strEncode (Version v) = strEncode v
strP = Version <$> strP
instance ToJSON (Version v) where
toEncoding (Version v) = toEncoding v
toJSON (Version v) = toJSON v
instance FromJSON (Version v) where
parseJSON v = Version <$> parseJSON v
+3 -9
View File
@@ -68,12 +68,6 @@ import Simplex.RemoteControl.Types
import UnliftIO
import UnliftIO.Concurrent
currentRCVersion :: Version
currentRCVersion = 1
supportedRCVRange :: VersionRange
supportedRCVRange = mkVersionRange 1 currentRCVersion
xrcpBlockSize :: Int
xrcpBlockSize = 16384
@@ -181,7 +175,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
{ ca = certFingerprint caCert,
host,
port = fromIntegral portNum,
v = supportedRCVRange,
v = supportedRCPVRange,
app = ctrlAppInfo,
ts,
skey = fst sessKeys,
@@ -220,7 +214,7 @@ prepareHostSession
unless (ca == tlsHostFingerprint) $ throwError RCEIdentity
(kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey
let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey
unless (isCompatible v supportedRCVRange) $ throwError RCEVersion
unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion
let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey}
knownHost' <- updateKnownHost ca dhPubKey
let ctrlHello = RCCtrlHello {}
@@ -334,7 +328,7 @@ prepareHostHello
RCInvitation {v, dh = dhPubKey}
hostAppInfo = do
logDebug "Preparing session"
case compatibleVersion v supportedRCVRange of
case compatibleVersion v supportedRCPVRange of
Nothing -> throwError RCEVersion
Just (Compatible v') -> do
nonce <- liftIO . atomically $ C.randomCbNonce drg
+2 -2
View File
@@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Version (VersionRange)
import Simplex.RemoteControl.Types (VersionRangeRCP)
data RCInvitation = RCInvitation
{ -- | CA TLS certificate fingerprint of the controller.
@@ -37,7 +37,7 @@ data RCInvitation = RCInvitation
host :: TransportHost,
port :: Word16,
-- | Supported version range for remote control protocol
v :: VersionRange,
v :: VersionRangeRCP,
-- | Application information
app :: J.Value,
-- | Session start time in seconds since epoch
+22 -6
View File
@@ -5,6 +5,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
@@ -17,6 +18,7 @@ import Data.ByteString (ByteString)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.Word (Word16)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.SNTRUP761
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
@@ -26,7 +28,8 @@ import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
import Simplex.Messaging.Transport (TLS)
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util (safeDecodeUtf8)
import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange)
import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange)
import Simplex.Messaging.Version.Internal
import UnliftIO
data RCErrorType
@@ -92,24 +95,37 @@ instance StrEncoding RCErrorType where
-- * Discovery
ipProbeVersionRange :: VersionRange
ipProbeVersionRange = mkVersionRange 1 1
data RCPVersion
instance VersionScope RCPVersion
type VersionRCP = Version RCPVersion
type VersionRangeRCP = VersionRange RCPVersion
pattern VersionRCP :: Word16 -> VersionRCP
pattern VersionRCP v = Version v
currentRCPVersion :: VersionRCP
currentRCPVersion = VersionRCP 1
supportedRCPVRange :: VersionRangeRCP
supportedRCPVRange = mkVersionRange (VersionRCP 1) currentRCPVersion
data IpProbe = IpProbe
{ versionRange :: VersionRange,
{ versionRange :: VersionRangeRCP,
randomNonce :: ByteString
}
deriving (Show)
instance Encoding IpProbe where
smpEncode IpProbe {versionRange, randomNonce} = smpEncode (versionRange, 'I', randomNonce)
smpP = IpProbe <$> (smpP <* "I") *> smpP
-- * Session
data RCHostHello = RCHostHello
{ v :: Version,
{ v :: VersionRCP,
ca :: C.KeyHash,
app :: J.Value,
kem :: KEMPublicKey
+191 -107
View File
@@ -4,6 +4,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PostfixOperators #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
@@ -12,12 +13,12 @@ module AgentTests (agentTests) where
import AgentTests.ConnectionRequestTests
import AgentTests.DoubleRatchetTests (doubleRatchetTests)
import AgentTests.FunctionalAPITests (functionalAPITests)
import AgentTests.FunctionalAPITests (functionalAPITests, pattern Msg, pattern Msg')
import AgentTests.MigrationTests (migrationTests)
import AgentTests.NotificationTests (notificationTests)
import AgentTests.SQLiteTests (storeTests)
import Control.Concurrent
import Control.Monad (forM_)
import Control.Monad (forM_, when)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Maybe (fromJust)
@@ -26,15 +27,18 @@ import GHC.Stack (withFrozenCallStack)
import Network.HTTP.Types (urlEncode)
import SMPAgentClient
import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn)
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Protocol hiding (MID, CONF, INFO, REQ)
import qualified Simplex.Messaging.Agent.Protocol as A
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff)
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
import Simplex.Messaging.Protocol (ErrorType (..))
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
import Simplex.Messaging.Util (bshow)
import System.Directory (removeFile)
import System.Timeout
import Test.Hspec
import Util
agentTests :: ATransport -> Spec
agentTests (ATransport t) = do
@@ -46,24 +50,25 @@ agentTests (ATransport t) = do
describe "Migration tests" migrationTests
describe "SMP agent protocol syntax" $ syntaxTests t
describe "Establishing duplex connection (via agent protocol)" $ do
-- These tests are disabled because the agent does not work correctly with multiple connected TCP clients
xit "should connect via one server and one agent" $ do
smpAgentTest2_1_1 $ testDuplexConnection t
xit "should connect via one server and one agent (random IDs)" $ do
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
skip "These tests are disabled because the agent does not work correctly with multiple connected TCP clients" $
describe "one agent" $ do
it "should connect via one server and one agent" $ do
smpAgentTest2_1_1 $ testDuplexConnection t
it "should connect via one server and one agent (random IDs)" $ do
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
it "should connect via one server and 2 agents" $ do
smpAgentTest2_2_1 $ testDuplexConnection t
it "should connect via one server and 2 agents (random IDs)" $ do
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
it "should connect via 2 servers and 2 agents" $ do
smpAgentTest2_2_2 $ testDuplexConnection t
it "should connect via 2 servers and 2 agents (random IDs)" $ do
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
describe "should connect via 2 servers and 2 agents" $ do
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnection'
describe "should connect via 2 servers and 2 agents (random IDs)" $ do
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnRandomIds'
describe "Establishing connections via `contact connection`" $ do
it "should connect via contact connection with one server and 3 agents" $ do
smpAgentTest3 $ testContactConnection t
it "should connect via contact connection with one server and 2 agents (random IDs)" $ do
smpAgentTest2_2_1 $ testContactConnRandomIds t
describe "should connect via contact connection with one server and 3 agents" $ do
pqMatrix3 t smpAgentTest3 testContactConnection
describe "should connect via contact connection with one server and 2 agents (random IDs)" $ do
pqMatrix2NoInv t smpAgentTest2_2_1 testContactConnRandomIds
it "should support rejecting contact request" $ do
smpAgentTest2_2_1 $ testRejectContactRequest t
describe "Connection subscriptions" $ do
@@ -72,8 +77,8 @@ agentTests (ATransport t) = do
it "should send notifications to client when server disconnects" $ do
smpAgentServerTest $ testSubscrNotification t
describe "Message delivery and server reconnection" $ do
it "should deliver messages after losing server connection and re-connecting" $ do
smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t
describe "should deliver messages after losing server connection and re-connecting" $
pqMatrix2 t smpAgentTest2_2_2_needs_server testMsgDeliveryServerRestart
it "should connect to the server when server goes up if it initially was down" $ do
smpAgentTestN [] $ testServerConnectionAfterError t
it "should deliver pending messages after agent restarting" $ do
@@ -133,6 +138,9 @@ action #> (corrId, connId, cmd) = withFrozenCallStack $ action `shouldReturn` (c
(=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation
action =#> p = withFrozenCallStack $ action >>= (`shouldSatisfy` p . correctTransmission)
pattern MID :: AgentMsgId -> ACommand 'Agent 'AEConn
pattern MID msgId = A.MID msgId PQEncOn
correctTransmission :: (ACorrId, ConnId, Either AgentErrorType cmd) -> (ACorrId, ConnId, cmd)
correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of
Right cmd -> (corrId, connId, cmd)
@@ -161,130 +169,188 @@ h #:# err = tryGet `shouldReturn` ()
Just _ -> error err
_ -> return ()
pattern Msg :: MsgBody -> ACommand 'Agent e
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
type PQMatrix2 c =
HasCallStack =>
TProxy c ->
(HasCallStack => (c -> c -> IO ()) -> Expectation) ->
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> IO ()) ->
Spec
pattern Msg' :: AgentMsgId -> MsgBody -> ACommand 'Agent e
pattern Msg' aMsgId msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} _ msgBody
pqMatrix2 :: PQMatrix2 c
pqMatrix2 = pqMatrix2_ True
pqMatrix2NoInv :: PQMatrix2 c
pqMatrix2NoInv = pqMatrix2_ False
pqMatrix2_ :: Bool -> PQMatrix2 c
pqMatrix2_ pqInv _ smpTest test = do
it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff)
it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn)
it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff)
it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn)
when pqInv $ do
it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff)
it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn)
pqMatrix3 ::
HasCallStack =>
TProxy c ->
(HasCallStack => (c -> c -> c -> IO ()) -> Expectation) ->
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()) ->
Spec
pqMatrix3 _ smpTest test = do
it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff)
it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn)
it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff)
it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn)
it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff)
it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn)
it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff)
it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn)
testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
testDuplexConnection _ alice bob = do
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV subscribe")
testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQSupportOn)
testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
testDuplexConnection' (alice, aPQ) (bob, bPQ) = do
let pq = pqConnectionMode aPQ bPQ
pqSup = CR.pqEncToSupport pq
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
let cReq' = strEncode cReq
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
("", "bob", Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
pqSup' `shouldBe` pqSup
alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
bob <# ("", "alice", INFO "alice's connInfo")
bob <# ("", "alice", CON)
alice <# ("", "bob", CON)
bob <# ("", "alice", A.INFO pqSup "alice's connInfo")
bob <# ("", "alice", CON pq)
alice <# ("", "bob", CON pq)
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", MID 4)
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", A.MID 4 pq)
alice <# ("", "bob", SENT 4)
bob <#= \case ("", "alice", Msg' 4 "hello") -> True; _ -> False
bob <#= \case ("", "alice", Msg' 4 pq' "hello") -> pq == pq'; _ -> False
bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK)
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", MID 5)
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", A.MID 5 pq)
alice <# ("", "bob", SENT 5)
bob <#= \case ("", "alice", Msg' 5 "how are you?") -> True; _ -> False
bob <#= \case ("", "alice", Msg' 5 pq' "how are you?") -> pq == pq'; _ -> False
bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK)
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 6)
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", A.MID 6 pq)
bob <# ("", "alice", SENT 6)
alice <#= \case ("", "bob", Msg' 6 "hello too") -> True; _ -> False
alice <#= \case ("", "bob", Msg' 6 pq' "hello too") -> pq == pq'; _ -> False
alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK)
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", MID 7)
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", A.MID 7 pq)
bob <# ("", "alice", SENT 7)
alice <#= \case ("", "bob", Msg' 7 "message 1") -> True; _ -> False
alice <#= \case ("", "bob", Msg' 7 pq' "message 1") -> pq == pq'; _ -> False
alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK)
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", MID 8)
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq)
bob <# ("", "alice", MERR 8 (SMP AUTH))
alice #: ("6", "bob", "DEL") #> ("6", "bob", OK)
alice #:# "nothing else should be delivered to alice"
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
testDuplexConnRandomIds _ alice bob = do
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV subscribe")
testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQSupportOn)
testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do
let pq = pqConnectionMode aPQ bPQ
pqSup = CR.pqEncToSupport pq
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
let cReq' = strEncode cReq
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo")
("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
("", bobConn', Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
pqSup' `shouldBe` pqSup
bobConn' `shouldBe` bobConn
alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False
bob <# ("", aliceConn, INFO "alice's connInfo")
bob <# ("", aliceConn, CON)
alice <# ("", bobConn, CON)
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, MID 4)
bob <# ("", aliceConn, A.INFO pqSup "alice's connInfo")
bob <# ("", aliceConn, CON pq)
alice <# ("", bobConn, CON pq)
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq)
alice <# ("", bobConn, SENT 4)
bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False
bob <#= \case ("", c, Msg' 4 pq' "hello") -> c == aliceConn && pq == pq'; _ -> False
bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK)
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, MID 5)
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, A.MID 5 pq)
alice <# ("", bobConn, SENT 5)
bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False
bob <#= \case ("", c, Msg' 5 pq' "how are you?") -> c == aliceConn && pq == pq'; _ -> False
bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK)
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, MID 6)
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, A.MID 6 pq)
bob <# ("", aliceConn, SENT 6)
alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False
alice <#= \case ("", c, Msg' 6 pq' "hello too") -> c == bobConn && pq == pq'; _ -> False
alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK)
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, MID 7)
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, A.MID 7 pq)
bob <# ("", aliceConn, SENT 7)
alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False
alice <#= \case ("", c, Msg' 7 pq' "message 1") -> c == bobConn && pq == pq'; _ -> False
alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK)
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, MID 8)
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq)
bob <# ("", aliceConn, MERR 8 (SMP AUTH))
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
alice #:# "nothing else should be delivered to alice"
testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO ()
testContactConnection _ alice bob tom = do
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON subscribe")
testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()
testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
let cReq' = strEncode cReq
abPQ = pqConnectionMode aPQ bPQ
abPQSup = CR.pqEncToSupport abPQ
aPQMode = CR.connPQEncryption aPQ
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
("", "alice_contact", Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:)
pqSup' `shouldBe` bPQ
alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK)
("", "alice", Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
pqSup'' `shouldBe` abPQSup
bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK)
alice <# ("", "bob", INFO "bob's connInfo 2")
alice <# ("", "bob", CON)
bob <# ("", "alice", CON)
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", MID 4)
alice <# ("", "bob", A.INFO abPQSup "bob's connInfo 2")
alice <# ("", "bob", CON abPQ)
bob <# ("", "alice", CON abPQ)
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ)
alice <# ("", "bob", SENT 4)
bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False
bob <#= \case ("", "alice", Msg' 4 pq' "hi") -> pq' == abPQ; _ -> False
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
tom #: ("21", "alice", "JOIN T " <> cReq' <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK)
("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:)
alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK)
("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:)
let atPQ = pqConnectionMode aPQ tPQ
atPQSup = CR.pqEncToSupport atPQ
tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK)
("", "alice_contact", Right (A.REQ aInvId' pqSup3 _ "tom's connInfo")) <- (alice <#:)
pqSup3 `shouldBe` tPQ
alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK)
("", "alice", Right (A.CONF tConfId pqSup4 _ "alice's connInfo")) <- (tom <#:)
pqSup4 `shouldBe` atPQSup
tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK)
alice <# ("", "tom", INFO "tom's connInfo 2")
alice <# ("", "tom", CON)
tom <# ("", "alice", CON)
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", MID 4)
alice <# ("", "tom", A.INFO atPQSup "tom's connInfo 2")
alice <# ("", "tom", CON atPQ)
tom <# ("", "alice", CON atPQ)
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ)
alice <# ("", "tom", SENT 4)
tom <#= \case ("", "alice", Msg "hi there") -> True; _ -> False
tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False
tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK)
testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
testContactConnRandomIds _ alice bob = do
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON subscribe")
testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do
let pq = pqConnectionMode aPQ bPQ
pqSup = CR.pqEncToSupport pq
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
let cReq' = strEncode cReq
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo")
("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
("", aliceContact', Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:)
pqSup' `shouldBe` bPQ
aliceContact' `shouldBe` aliceContact
("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> " 16\nalice's connInfo")
("", aliceConn', Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo")
("", aliceConn', Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
pqSup'' `shouldBe` pqSup
aliceConn' `shouldBe` aliceConn
bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK)
alice <# ("", bobConn, INFO "bob's connInfo 2")
alice <# ("", bobConn, CON)
bob <# ("", aliceConn, CON)
alice <# ("", bobConn, A.INFO pqSup "bob's connInfo 2")
alice <# ("", bobConn, CON pq)
bob <# ("", aliceConn, CON pq)
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, MID 4)
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, A.MID 4 pq)
alice <# ("", bobConn, SENT 4)
bob <#= \case ("", c, Msg "hi") -> c == aliceConn; _ -> False
bob <#= \case ("", c, Msg' 4 pq' "hi") -> c == aliceConn && pq == pq'; _ -> False
bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK)
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
@@ -292,7 +358,7 @@ testRejectContactRequest _ alice bob = do
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON subscribe")
let cReq' = strEncode cReq
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 10\nbob's info") #> ("11", "alice", OK)
("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:)
("", "a_contact", Right (A.REQ aInvId PQSupportOff _ "bob's info")) <- (alice <#:)
-- RJCT must use correct contact connection
alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND)
alice #: ("2b", "a_contact", "RJCT " <> aInvId) #> ("2b", "a_contact", OK)
@@ -327,31 +393,32 @@ testSubscrNotification t (server, _) client = do
withSmpServer (ATransport t) $
client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue
testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO ()
testMsgDeliveryServerRestart t alice bob = do
testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do
let pq = pqConnectionMode aPQ bPQ
withServer $ do
connect (alice, "alice") (bob, "bob")
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", MID 4)
connect' (alice, "alice", aPQ) (bob, "bob", bPQ)
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", A.MID 4 pq)
bob <# ("", "alice", SENT 4)
alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False
alice <#= \case ("", "bob", Msg' _ pq' "hi") -> pq == pq'; _ -> False
alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK)
alice #:# "nothing else delivered before the server is killed"
let server = SMPServer "localhost" testPort2 testKeyHash
alice <#. ("", "", DOWN server ["bob"])
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", MID 5)
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", A.MID 5 pq)
bob #:# "nothing else delivered before the server is restarted"
alice #:# "nothing else delivered before the server is restarted"
withServer $ do
bob <# ("", "alice", SENT 5)
alice <#. ("", "", UP server ["bob"])
alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
alice <#= \case ("", "bob", Msg' _ pq' "hello again") -> pq == pq'; _ -> False
alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK)
removeFile testStoreLogFile
where
withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` ()
withServer test' = withSmpServerStoreLogOn (transport @c) testPort2 (const test') `shouldReturn` ()
testServerConnectionAfterError :: forall c. Transport c => TProxy c -> [c] -> IO ()
testServerConnectionAfterError t _ = do
@@ -432,7 +499,7 @@ testConcurrentMsgDelivery _ alice bob = do
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV subscribe")
let cReq' = strEncode cReq
bob #: ("11", "alice2", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice2", OK)
("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:)
("", "bob2", Right (A.CONF _confId PQSupportOff _ "bob's connInfo")) <- (alice <#:)
-- below commands would be needed to accept bob's connection, but alice does not
-- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
-- bob <# ("", "alice", INFO "alice's connInfo")
@@ -492,16 +559,33 @@ testResumeDeliveryQuotaExceeded _ alice bob = do
-- message 8 is skipped because of alice agent sending "QCONT" message
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV subscribe")
connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQSupportOn)
connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQSupport) -> IO ()
connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe")
let cReq' = strEncode cReq
h2 #: ("c2", name1, "JOIN T " <> cReq' <> " subscribe 5\ninfo2") #> ("c2", name1, OK)
("", _, Right (CONF connId _ "info2")) <- (h1 <#:)
pq = pqConnectionMode pqMode1 pqMode2
pqSup = CR.pqEncToSupport pq
h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK)
("", _, Right (A.CONF connId pqSup' _ "info2")) <- (h1 <#:)
pqSup' `shouldBe` pqSup
h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK)
h2 <# ("", name1, INFO "info1")
h2 <# ("", name1, CON)
h1 <# ("", name2, CON)
h2 <# ("", name1, A.INFO pqSup "info1")
h2 <# ("", name1, CON pq)
h1 <# ("", name2, CON pq)
pqConnectionMode :: InitialKeys -> PQSupport -> PQEncryption
pqConnectionMode pqMode1 pqMode2 = PQEncryption $ supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2
enableKEMStr :: PQSupport -> ByteString
enableKEMStr PQSupportOn = " " <> strEncode PQSupportOn
enableKEMStr _ = ""
pqConnModeStr :: InitialKeys -> ByteString
pqConnModeStr (IKNoPQ PQSupportOff) = ""
pqConnModeStr pq = " " <> strEncode pq
sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO ()
sendMessage (h1, name1) (h2, name2) msg = do
+16 -13
View File
@@ -1,7 +1,10 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module AgentTests.ConnectionRequestTests where
@@ -12,7 +15,7 @@ import Simplex.Messaging.Agent.Protocol
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (ProtocolServer (..), supportedSMPClientVRange)
import Simplex.Messaging.Protocol (ProtocolServer (..), pattern VersionSMPC, supportedSMPClientVRange)
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
import Simplex.Messaging.Version
import Test.Hspec
@@ -38,7 +41,7 @@ queue :: SMPQueueUri
queue = SMPQueueUri supportedSMPClientVRange queueAddr
queueV1 :: SMPQueueUri
queueV1 = SMPQueueUri (mkVersionRange 1 1) queueAddr
queueV1 = SMPQueueUri (mkVersionRange (VersionSMPC 1) (VersionSMPC 1)) queueAddr
testDhKey :: C.PublicKeyX25519
testDhKey = "MCowBQYDK2VuAyEAjiswwI3O/NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o="
@@ -53,7 +56,7 @@ connReqData :: ConnReqUriData
connReqData =
ConnReqUriData
{ crScheme = SSSimplex,
crAgentVRange = mkVersionRange 2 2,
crAgentVRange = mkVersionRange (VersionSMPA 2) (VersionSMPA 2),
crSmpQueues = [queueV1],
crClientData = Nothing
}
@@ -61,11 +64,11 @@ connReqData =
testDhPubKey :: C.PublicKeyX448
testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U="
testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448
testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey
testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448
testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing
testE2ERatchetParams12 :: E2ERatchetParamsUri 'C.X448
testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey
testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448
testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQSupportOn) testDhPubKey testDhPubKey Nothing
connectionRequest :: AConnectionRequestUri
connectionRequest =
@@ -79,7 +82,7 @@ connectionRequestCurrentRange :: AConnectionRequestUri
connectionRequestCurrentRange =
ACR SCMInvitation $
CRInvitationUri
connReqData {crAgentVRange = supportedSMPAgentVRange, crSmpQueues = [queueV1, queueV1]}
connReqData {crAgentVRange = supportedSMPAgentVRange PQSupportOn, crSmpQueues = [queueV1, queueV1]}
testE2ERatchetParams12
connectionRequestClientDataEmpty :: AConnectionRequestUri
@@ -98,7 +101,7 @@ connectionRequestTests =
it "should serialize SMP queue URIs" $ do
strEncode (queue :: SMPQueueUri) {queueAddress = queueAddrNoPort}
`shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri
strEncode queue {clientVRange = mkVersionRange 1 2}
strEncode queue {clientVRange = mkVersionRange (VersionSMPC 1) (VersionSMPC 2)}
`shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri
it "should parse SMP queue URIs" $ do
strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStr)
@@ -119,11 +122,11 @@ connectionRequestTests =
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
strEncode connectionRequestCurrentRange
`shouldBe` "simplex:/invitation#/?v=2-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
`shouldBe` "simplex:/invitation#/?v=2-5&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
<> "&e2e=v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&e2e=v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
strEncode connectionRequestClientDataEmpty
`shouldBe` "simplex:/invitation#/?v=2&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> urlEncode True testDhKeyStrUri
@@ -167,9 +170,9 @@ connectionRequestTests =
<> testDhKeyStrUri
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
<> testDhKeyStrUri
<> "&e2e=extra_key%3Dnew%26v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&e2e=extra_key%3Dnew%26v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
<> "&some_new_param=abc"
<> "&v=2-4"
<> "&v=2-5"
)
`shouldBe` Right connectionRequestCurrentRange
strDecode
+458 -73
View File
@@ -1,75 +1,178 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
module AgentTests.DoubleRatchetTests where
import Control.Concurrent.STM
import Control.Monad (when)
import Control.Monad.Except
import Control.Monad.IO.Class
import Crypto.Random (ChaChaDRG)
import Data.Aeson (FromJSON, ToJSON)
import Data.Aeson (FromJSON, ToJSON, (.=))
import qualified Data.Aeson as J
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Map.Strict as M
import Data.Type.Equality
import Simplex.Messaging.Crypto (Algorithm (..), AlgorithmI, CryptoError, DhAlgorithm)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
import Simplex.Messaging.Crypto.Ratchet
import Simplex.Messaging.Encoding
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util ((<$$>))
import Simplex.Messaging.Version
import Test.Hspec
doubleRatchetTests :: Spec
doubleRatchetTests = do
describe "double-ratchet encryption/decryption" $ do
it "should serialize and parse message header" testMessageHeader
it "should encrypt and decrypt messages" $ do
withRatchets @X25519 testEncryptDecrypt
withRatchets @X448 testEncryptDecrypt
it "should encrypt and decrypt skipped messages" $ do
withRatchets @X25519 testSkippedMessages
withRatchets @X448 testSkippedMessages
it "should encrypt and decrypt many messages" $ do
withRatchets @X25519 testManyMessages
it "should allow skipped after ratchet advance" $ do
withRatchets @X25519 testSkippedAfterRatchetAdvance
it "should serialize and parse message header" $ do
testAlgs $ testMessageHeader kdfX3DHE2EEncryptVersion
testAlgs $ testMessageHeader $ max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
describe "message tests" $ runMessageTests initRatchets False
it "should encode/decode ratchet as JSON" $ do
testKeyJSON C.SX25519
testKeyJSON C.SX448
testRatchetJSON C.SX25519
testRatchetJSON C.SX448
it "should agree the same ratchet parameters" $ do
testX3dh C.SX25519
testX3dh C.SX448
it "should agree the same ratchet parameters with version 1" $ do
testX3dhV1 C.SX25519
testX3dhV1 C.SX448
testAlgs testKeyJSON
testAlgs testRatchetJSON
testVersionJSON
it "should decode v2 Ratchet with default field values" $ testDecodeV2RatchetJSON
it "should agree the same ratchet parameters" $ testAlgs testX3dh
it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1
describe "post-quantum hybrid KEM double-ratchet algorithm" $ do
describe "hybrid KEM key agreement" $ do
it "should propose KEM during agreement, but no shared secret" $ testAlgs testPqX3dhProposeInReply
it "should agree shared secret using KEM" $ testAlgs testPqX3dhProposeAccept
it "should reject proposed KEM in reply" $ testAlgs testPqX3dhProposeReject
it "should allow second proposal in reply" $ testAlgs testPqX3dhProposeAgain
describe "hybrid KEM key agreement errors" $ do
it "should fail if reply contains acceptance without proposal" $ testAlgs testPqX3dhAcceptWithoutProposalError
describe "ratchet encryption/decryption" $ do
it "should serialize and parse public KEM params" testKEMParams
it "should serialize and parse message header" $ testAlgs testMessageHeaderKEM
describe "message tests, KEM proposed" $ runMessageTests initRatchetsKEMProposed True
describe "message tests, KEM accepted" $ runMessageTests initRatchetsKEMAccepted False
describe "message tests, KEM proposed again in reply" $ runMessageTests initRatchetsKEMProposedAgain True
it "should disable and re-enable KEM" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEM
it "should disable and re-enable KEM (always set PQEncryption)" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEMStrict
it "should enable KEM when it was not enabled in handshake" $ withRatchets_ @X25519 initRatchets testEnableKEM
it "should enable KEM when it was not enabled in handshake (always set PQEncryption)" $ withRatchets_ @X25519 initRatchets testEnableKEMStrict
runMessageTests ::
(forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)) ->
Bool ->
Spec
runMessageTests initRatchets_ agreeRatchetKEMs = do
it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs
it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs
it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs
it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs
where
run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO ()
run test = do
withRatchets_ @X25519 initRatchets_ test
withRatchets_ @X448 initRatchets_ test
testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO ()
testAlgs test = test C.SX25519 >> test C.SX448
paddedMsgLen :: Int
paddedMsgLen = 100
fullMsgLen :: Int
fullMsgLen = 1 + fullHeaderLen + C.authTagSize + paddedMsgLen
fullMsgLen :: Ratchet a -> Int
fullMsgLen Ratchet {rcSupportKEM, rcVersion} = headerLenLength + fullHeaderLen v rcSupportKEM + C.authTagSize + paddedMsgLen
where
v = current rcVersion
headerLenLength = case rcSupportKEM of
PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 3 -- two bytes are added because of two Large used in new encoding
_ -> 1
testMessageHeader :: Expectation
testMessageHeader = do
(k, _) <- atomically . C.generateKeyPair @X25519 =<< C.newRandom
let hdr = MsgHeader {msgMaxVersion = currentE2EEncryptVersion, msgDHRs = k, msgPN = 0, msgNs = 0}
parseAll (smpP @(MsgHeader 'X25519)) (smpEncode hdr) `shouldBe` Right hdr
testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation
testMessageHeader v _ = do
(k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom
let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0}
parseAll (msgHeaderP v) (encodeMsgHeader v hdr) `shouldBe` Right hdr
testKEMParams :: Expectation
testKEMParams = do
g <- C.newRandom
(kem, _) <- sntrup761Keypair g
let kemParams = ARKP SRKSProposed $ RKParamsProposed kem
parseAll (smpP @ARKEMParams) (smpEncode kemParams) `shouldBe` Right kemParams
(kem', _) <- sntrup761Keypair g
(ct, _) <- sntrup761Enc g kem
let kemParams' = ARKP SRKSAccepted $ RKParamsAccepted ct kem'
parseAll (smpP @ARKEMParams) (smpEncode kemParams') `shouldBe` Right kemParams'
testMessageHeaderKEM :: forall a. AlgorithmI a => C.SAlgorithm a -> Expectation
testMessageHeaderKEM _ = do
g <- C.newRandom
(k, _) <- atomically $ C.generateKeyPair @a g
(kem, _) <- sntrup761Keypair g
let msgMaxVersion = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
msgKEM = Just . ARKP SRKSProposed $ RKParamsProposed kem
hdr = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM, msgPN = 0, msgNs = 0}
parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr) `shouldBe` Right hdr
(kem', _) <- sntrup761Keypair g
(ct, _) <- sntrup761Enc g kem
let msgKEM' = Just . ARKP SRKSAccepted $ RKParamsAccepted ct kem'
hdr' = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM = msgKEM', msgPN = 0, msgNs = 0}
parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr') `shouldBe` Right hdr'
pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString)
pattern Decrypted msg <- Right (Right msg)
type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()
type Encrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString)
testEncryptDecrypt :: TestRatchets a
testEncryptDecrypt alice bob = do
type Decrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString))
type EncryptDecryptSpec a = (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation
type TestRatchets a =
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
Encrypt a ->
Decrypt a ->
EncryptDecryptSpec a ->
IO ()
deriving instance Eq (Ratchet a)
deriving instance Eq (SndRatchet a)
deriving instance Eq RcvRatchet
deriving instance Eq RatchetKEM
deriving instance Eq RatchetKEMAccepted
deriving instance Eq RatchetInitParams
deriving instance Eq RatchetKey
instance Eq ARKEMParams where
(ARKP s ps) == (ARKP s' ps') = case testEquality s s' of
Just Refl -> ps == ps'
Nothing -> False
deriving instance Eq (MsgHeader a)
initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()
initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r
testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
(bob, "hello alice") #> alice
(alice, "hello bob") #> bob
Right b1 <- encrypt bob "how are you, alice?"
@@ -88,8 +191,9 @@ testEncryptDecrypt alice bob = do
(alice, "I'm here too, same") #> bob
pure ()
testSkippedMessages :: TestRatchets a
testSkippedMessages alice bob = do
testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
Right msg1 <- encrypt bob "hello alice"
Right msg2 <- encrypt bob "hello there again"
Right msg3 <- encrypt bob "are you there?"
@@ -99,8 +203,9 @@ testSkippedMessages alice bob = do
Decrypted "hello alice" <- decrypt alice msg1
pure ()
testManyMessages :: TestRatchets a
testManyMessages alice bob = do
testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
(bob, "b1") #> alice
(bob, "b2") #> alice
(bob, "b3") #> alice
@@ -117,8 +222,9 @@ testManyMessages alice bob = do
(bob, "b15") #> alice
(bob, "b16") #> alice
testSkippedAfterRatchetAdvance :: TestRatchets a
testSkippedAfterRatchetAdvance alice bob = do
testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
(bob, "b1") #> alice
Right b2 <- encrypt bob "b2"
Right b3 <- encrypt bob "b3"
@@ -152,6 +258,84 @@ testSkippedAfterRatchetAdvance alice bob = do
Decrypted "b11" <- decrypt alice b11
pure ()
testDisableEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
testDisableEnableKEM alice bob _ _ _ = do
(bob, "hello alice") !#> alice
(alice, "hello bob") !#> bob
(bob, "disabling KEM") !#>\ alice
(alice, "still disabling KEM") !#> bob
(bob, "now KEM is disabled") \#> alice
(alice, "KEM is disabled for both sides") \#> bob
(bob, "trying to enable KEM") \#>! alice
(alice, "but unless alice enables it too it won't enable") \#> bob
(bob, "KEM is disabled") \#> alice
(alice, "KEM is disabled for both sides") \#> bob
(bob, "enabling KEM again") \#>! alice
(alice, "and alice accepts it this time") \#>! bob
(bob, "still enabling KEM") \#>! alice
(alice, "now KEM is enabled") !#> bob
(bob, "KEM is enabled for both sides") !#> alice
testDisableEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
testDisableEnableKEMStrict alice bob _ _ _ = do
(bob, "hello alice") !#>! alice
(alice, "hello bob") !#>! bob
(bob, "disabling KEM") !#>\ alice
(alice, "still disabling KEM") !#>! bob
(bob, "now KEM is disabled") \#>\ alice
(alice, "KEM is disabled for both sides") \#>\ bob
(bob, "trying to enable KEM") \#>! alice
(alice, "but unless alice enables it too it won't enable") \#>\ bob
(bob, "KEM is disabled") \#>! alice
(alice, "KEM is disabled for both sides") \#>\ bob
(bob, "enabling KEM again") \#>! alice
(alice, "and alice accepts it this time") \#>! bob
(bob, "still enabling KEM") \#>! alice
(alice, "now KEM is enabled") !#>! bob
(bob, "KEM is enabled for both sides") !#>! alice
testEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
testEnableKEM alice bob _ _ _ = do
(bob, "hello alice") \#> alice
(alice, "hello bob") \#> bob
(bob, "enabling KEM") \#>! alice
(bob, "KEM not enabled yet") \#>! alice
(alice, "accepting KEM") \#>! bob
(alice, "KEM not enabled yet here too") \#>! bob
(bob, "KEM is still not enabled") \#>! alice
(alice, "now KEM is enabled") !#>! bob
(bob, "now KEM is enabled for both sides") !#> alice
(alice, "still enabled for both sides") !#> bob
(bob, "still enabled for both sides 2") !#> alice
(alice, "disabling KEM") !#>\ bob
(bob, "KEM not disabled yet") !#> alice
(alice, "KEM disabled") \#> bob
(bob, "KEM disabled on both sides") \#> alice
testEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
testEnableKEMStrict alice bob _ _ _ = do
(bob, "hello alice") \#>\ alice
(alice, "hello bob") \#>\ bob
(bob, "enabling KEM") \#>! alice
(bob, "KEM not enabled yet") \#>! alice
(alice, "accepting KEM") \#>! bob
(alice, "KEM not enabled yet here too") \#>! bob
(bob, "KEM is still not enabled") \#>! alice
(alice, "now KEM is enabled") !#>! bob
(bob, "now KEM is enabled for both sides") !#>! alice
(alice, "still enabled for both sides") !#>! bob
(bob, "still enabled for both sides 2") !#>! alice
(alice, "disabling KEM") !#>\ bob
(bob, "KEM not disabled yet") !#>! alice
(alice, "KEM disabled") \#>\ bob
(bob, "KEM disabled on both sides") \#>! alice
(alice, "KEM still disabled 1") \#>\ bob
(bob, "KEM still disabled 2") \#>! alice
(alice, "KEM still disabled 3") \#>\ bob
(bob, "KEM still disabled 4") \#>! alice
(alice, "KEM still disabled 5") \#>\ bob
(bob, "KEM still disabled 6") \#>! alice
testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO ()
testKeyJSON _ = do
(k, pk) <- atomically . C.generateKeyPair @a =<< C.newRandom
@@ -160,10 +344,33 @@ testKeyJSON _ = do
testRatchetJSON :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testRatchetJSON _ = do
(alice, bob) <- initRatchets @a
(alice, bob, _, _, _) <- initRatchets @a
testEncodeDecode alice
testEncodeDecode bob
testVersionJSON :: IO ()
testVersionJSON = do
testEncodeDecode $ rv 1 1
testEncodeDecode $ rv 1 2
-- let bad = RVersions 2 1
-- Left err <- pure $ J.eitherDecode' @RatchetVersions (J.encode bad)
-- err `shouldContain` "bad version range"
testDecodeRV $ (1 :: Int, 2 :: Int)
testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)]
where
rv v1 v2 = RatchetVersions (VersionE2E v1) (VersionE2E v2)
testDecodeRV :: ToJSON a => a -> Expectation
testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2)
testDecodeV2RatchetJSON :: IO ()
testDecodeV2RatchetJSON = do
let v2RatchetJSON = "{\"rcVersion\":[2,2],\"rcAD\":\"2GEJrq48TmQse6NR16I-hrI0tSySZQ57E_g46nDceAPRAiF6j0drq26RTE7be6X7uiB4RaGJGf4QRXzcYuVtWw==\",\"rcDHRs\":\"TUM0Q0FRQXdCUVlESzJWdUJDSUVJRkNYbUxtSHQ3SUNfeHpGTi1Qb3ZqTVQ3S2p6XzZlZlBjOG9fRFY2RWxKOQ==\",\"rcRK\":\"BOX2X7YW5qDSp2XknY_lqacSrtDqQNPvS6iJlZIs3G0=\",\"rcNs\":0,\"rcNr\":0,\"rcPN\":0,\"rcNHKs\":\"IMouSkXUvzT_mo0WM-pqEUK09-HTLk9WOTCFQglyQxU=\",\"rcNHKr\":\"g-tus1clYPV0rGlzkf5a959tUqDYQVZ1FpcPeXdKwxI=\"}"
Right (r :: Ratchet X25519) <- pure $ J.eitherDecodeStrict' v2RatchetJSON
rcSupportKEM r `shouldBe` PQSupportOff
rcEnableKEM r `shouldBe` PQEncOff
rcSndKEM r `shouldBe` PQEncOff
rcRcvKEM r `shouldBe` PQEncOff
testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation
testEncodeDecode x = do
let j = J.encode x
@@ -173,77 +380,255 @@ testEncodeDecode x = do
testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testX3dh _ = do
g <- C.newRandom
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
paramsAlice `shouldBe` paramsBob
testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testX3dhV1 _ = do
g <- C.newRandom
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g 1
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g 1
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g (VersionE2E 1) Nothing
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQSupportOff
let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
paramsAlice `shouldBe` paramsBob
(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation
(alice, msg) #> bob = do
Right msg' <- encrypt alice msg
Decrypted msg'' <- decrypt bob msg'
testPqX3dhProposeInReply :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testPqX3dhProposeInReply _ = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (no KEM)
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
-- propose KEM in reply
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
paramsAlice `compatibleRatchets` paramsBob
testPqX3dhProposeAccept :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testPqX3dhProposeAccept _ = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (propose KEM)
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice
-- accept KEM
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM aliceKem)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
paramsAlice `compatibleRatchets` paramsBob
testPqX3dhProposeReject :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testPqX3dhProposeReject _ = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (propose KEM)
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice
-- reject KEM
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
paramsAlice `compatibleRatchets` paramsBob
testPqX3dhAcceptWithoutProposalError :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testPqX3dhAcceptWithoutProposalError _ = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (no KEM)
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
E2ERatchetParams _ _ _ Nothing <- pure e2eAlice
-- incorrectly accept KEM
-- we don't have key in proposal, so we just generate it
(k, _) <- sntrup761Keypair g
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM k)
pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice `shouldBe` Left C.CERatchetKEMState
runExceptT (pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob) `shouldReturn` Left C.CERatchetKEMState
testPqX3dhProposeAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
testPqX3dhProposeAgain _ = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (propose KEM)
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice
-- propose KEM again in reply - this is not an error
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
paramsAlice `compatibleRatchets` paramsBob
compatibleRatchets :: (RatchetInitParams, x) -> (RatchetInitParams, x) -> Expectation
compatibleRatchets
(RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, _)
(RatchetInitParams {assocData = ad, ratchetKey = rk, sndHK = shk, rcvNextHK = rnhk, kemAccepted = ka}, _) = do
assocData == ad && ratchetKey == rk && sndHK == shk && rcvNextHK == rnhk `shouldBe` True
case (kemAccepted, ka) of
(Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}, Just RatchetKEMAccepted {rcPQRr = pqk, rcPQRss = pqss, rcPQRct = pqct}) ->
pqk /= rcPQRr && pqss == rcPQRss && pqct == rcPQRct `shouldBe` True
(Nothing, Nothing) -> pure ()
_ -> expectationFailure "RatchetInitParams params are not compatible"
encryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Maybe PQEncryption -> (Ratchet a -> ()) -> (Ratchet a -> ()) -> EncryptDecryptSpec a
encryptDecrypt pqEnc validSnd validRcv (alice, msg) bob = do
Right msg' <- withTVar (encrypt_ pqEnc) validSnd alice msg
Decrypted msg'' <- decrypt' validRcv bob msg'
msg'' `shouldBe` msg
withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()) -> Expectation
withRatchets test = do
-- enable KEM (currently disabled)
(\#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) \#>! r = encryptDecrypt (Just PQEncOn) noSndKEM noRcvKEM (s, msg) r
-- enable KEM (currently enabled)
(!#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) !#>! r = encryptDecrypt (Just PQEncOn) hasSndKEM hasRcvKEM (s, msg) r
-- KEM enabled (no user preference)
(!#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) !#> r = encryptDecrypt Nothing hasSndKEM hasRcvKEM (s, msg) r
-- disable KEM (currently enabled)
(!#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) !#>\ r = encryptDecrypt (Just PQEncOff) hasSndKEM hasRcvKEM (s, msg) r
-- disable KEM (currently disabled)
(\#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) \#>\ r = encryptDecrypt (Just PQEncOff) noSndKEM noSndKEM (s, msg) r
-- KEM disabled (no user preference)
(\#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
(s, msg) \#> r = encryptDecrypt Nothing noSndKEM noSndKEM (s, msg) r
withRatchets_ :: IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) -> TestRatchets a -> Expectation
withRatchets_ initRatchets_ test = do
ga <- C.newRandom
gb <- C.newRandom
(a, b) <- initRatchets @a
(a, b, encrypt, decrypt, (#>)) <- initRatchets_
alice <- newTVarIO (ga, a, M.empty)
bob <- newTVarIO (gb, b, M.empty)
test alice bob `shouldReturn` ()
test alice bob encrypt decrypt (#>) `shouldReturn` ()
initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a)
initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
initRatchets = do
g <- C.newRandom
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams g currentE2EEncryptVersion
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams g currentE2EEncryptVersion
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
(pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing
(pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
(_, pkBob3) <- atomically $ C.generateKeyPair g
let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob
alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice
pure (alice, bob)
let vs = testRatchetVersions PQSupportOff
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOff
pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>))
encrypt_ :: AlgorithmI a => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff))
encrypt_ (_, rc, _) msg =
runExceptT (rcEncrypt rc paddedMsgLen msg)
initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
initRatchetsKEMProposed = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (no KEM)
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff
-- propose KEM in reply
let useKem = AUseKEM SRKSProposed ProposeKEM
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
(_, pkBob3) <- atomically $ C.generateKeyPair g
let vs = testRatchetVersions PQSupportOn
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
initRatchetsKEMAccepted = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (propose)
(pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn
E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice
-- accept
let useKem = AUseKEM SRKSAccepted (AcceptKEM aliceKem)
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob
(_, pkBob3) <- atomically $ C.generateKeyPair g
let vs = testRatchetVersions PQSupportOn
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
initRatchetsKEMProposedAgain = do
g <- C.newRandom
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
-- initiate (propose KEM)
(pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn
-- propose KEM again in reply
let useKem = AUseKEM SRKSProposed ProposeKEM
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob
(_, pkBob3) <- atomically $ C.generateKeyPair g
let vs = testRatchetVersions PQSupportOn
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
testRatchetVersions :: PQSupport -> RatchetVersions
testRatchetVersions pq =
let v = maxVersion $ supportedE2EEncryptVRange pq
in RatchetVersions v v
encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff))
encrypt_ pqEnc_ (_, rc, _) msg =
-- print msg >>
runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_ currentE2EEncryptVersion)
>>= either (pure . Left) checkLength
where
checkLength (msg', rc') = do
B.length msg' `shouldBe` fullMsgLen
B.length msg' `shouldBe` fullMsgLen rc'
pure $ Right (msg', rc', SMDNoChange)
decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff))
decrypt_ (g, rc, smks) msg = runExceptT $ rcDecrypt g rc smks msg
encrypt :: AlgorithmI a => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString)
encrypt = withTVar encrypt_
encrypt' :: AlgorithmI a => (Ratchet a -> ()) -> Encrypt a
encrypt' = withTVar $ encrypt_ Nothing
decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString))
decrypt = withTVar decrypt_
decrypt' :: (AlgorithmI a, DhAlgorithm a) => (Ratchet a -> ()) -> Decrypt a
decrypt' = withTVar decrypt_
noSndKEM :: Ratchet a -> ()
noSndKEM Ratchet {rcSndKEM = PQEncOn} = error "snd ratchet has KEM"
noSndKEM _ = ()
noRcvKEM :: Ratchet a -> ()
noRcvKEM Ratchet {rcRcvKEM = PQEncOn} = error "rcv ratchet has KEM"
noRcvKEM _ = ()
hasSndKEM :: Ratchet a -> ()
hasSndKEM Ratchet {rcSndKEM = PQEncOn} = ()
hasSndKEM _ = error "snd ratchet has no KEM"
hasRcvKEM :: Ratchet a -> ()
hasRcvKEM Ratchet {rcRcvKEM = PQEncOn} = ()
hasRcvKEM _ = error "rcv ratchet has no KEM"
withTVar ::
AlgorithmI a =>
((TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e (r, Ratchet a, SkippedMsgDiff))) ->
(Ratchet a -> ()) ->
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
ByteString ->
IO (Either e r)
withTVar op rcVar msg = do
withTVar op valid rcVar msg = do
(g, rc, smks) <- readTVarIO rcVar
applyDiff smks <$$> (testEncodeDecode rc >> op (g, rc, smks) msg)
>>= \case
Right (res, rc', smks') -> atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res)
Right (res, rc', smks') -> valid rc' `seq` atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res)
Left e -> pure $ Left e
where
applyDiff smks (res, rc', smDiff) = (res, rc', applySMDiff smks smDiff)
+25
View File
@@ -0,0 +1,25 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module AgentTests.EqInstances where
import Data.Type.Equality
import Simplex.Messaging.Agent.Store
instance Eq SomeConn where
SomeConn d c == SomeConn d' c' = case testEquality d d' of
Just Refl -> c == c'
_ -> False
deriving instance Eq (Connection d)
deriving instance Eq (SConnType d)
deriving instance Eq (StoredRcvQueue q)
deriving instance Eq (StoredSndQueue q)
deriving instance Eq (DBQueueId q)
deriving instance Eq ClientNtfCreds
File diff suppressed because it is too large Load Diff
+27 -5
View File
@@ -12,7 +12,28 @@
module AgentTests.NotificationTests where
-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging)
import AgentTests.FunctionalAPITests (agentCfgV7, exchangeGreetingsMsgId, get, getSMPAgentClient', makeConnection, nGet, runRight, runRight_, switchComplete, testServerMatrix2, withAgentClientsCfg2, (##>), (=##>), pattern Msg)
import AgentTests.FunctionalAPITests
( agentCfgV7,
createConnection,
exchangeGreetingsMsgId,
get,
getSMPAgentClient',
joinConnection,
makeConnection,
nGet,
runRight,
runRight_,
sendMessage,
switchComplete,
testServerMatrix2,
withAgentClientsCfg2,
(##>),
(=##>),
pattern CON,
pattern CONF,
pattern INFO,
pattern Msg,
)
import Control.Concurrent (ThreadId, killThread, threadDelay)
import Control.Monad
import Control.Monad.Except
@@ -28,10 +49,10 @@ import Data.Text.Encoding (encodeUtf8)
import NtfClient
import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testDB3, testNtfServer, testNtfServer2)
import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage)
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore')
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers)
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO)
import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
@@ -46,6 +67,7 @@ import Simplex.Messaging.Transport (ATransport)
import System.Directory (doesFileExist, removeFile)
import Test.Hspec
import UnliftIO
import Util
removeFileIfExists :: FilePath -> IO ()
removeFileIfExists filePath = do
@@ -125,8 +147,8 @@ testNtfMatrix t runTest = do
it "next servers: SMP v7, NTF v2; next clients: v7/v2" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfgV7 agentCfgV7 runTest
it "next servers: SMP v7, NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfg runTest
it "curr servers: SMP v6, NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfg agentCfg agentCfg runTest
-- this case will cannot be supported - see RFC
xit "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest
skip "this case cannot be supported - see RFC" $
it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest
-- servers can be migrated in any order
it "servers: next SMP v7, curr NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfg agentCfg agentCfg runTest
it "servers: curr SMP v6, next NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfgV2 agentCfg agentCfg runTest
+42 -15
View File
@@ -1,20 +1,27 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
module AgentTests.SQLiteTests (storeTests) where
import AgentTests.EqInstances ()
import Control.Concurrent.Async (concurrently_)
import Control.Concurrent.STM
import Control.Exception (SomeException)
import Control.Monad (replicateM_)
import Control.Monad.Trans.Except
import Crypto.Random (ChaChaDRG)
import Data.ByteArray (ScrubbedBytes)
import Data.ByteString.Char8 (ByteString)
import Data.List (isInfixOf)
@@ -38,9 +45,11 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction')
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn)
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Crypto.File (CryptoFile (..))
import Simplex.Messaging.Encoding.String (StrEncoding (..))
import Simplex.Messaging.Protocol (SubscriptionMode (..))
import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC)
import qualified Simplex.Messaging.Protocol as SMP
import System.Random
import Test.Hspec
@@ -91,7 +100,7 @@ storeTests = do
testForeignKeysEnabled
describe "db methods" $ do
describe "Queue and Connection management" $ do
describe "createRcvConn" $ do
describe "create Rcv connection" $ do
testCreateRcvConn
testCreateRcvConnRandomId
testCreateRcvConnDuplicate
@@ -172,7 +181,17 @@ testForeignKeysEnabled =
`shouldThrow` (\e -> SQL.sqlError e == SQL.ErrorConstraint)
cData1 :: ConnData
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
cData1 =
ConnData
{ userId = 1,
connId = "conn1",
connAgentVersion = VersionSMPA 1,
enableNtfs = True,
lastExternalSndId = 0,
deleted = False,
ratchetSyncState = RSOk,
pqSupport = CR.PQSupportOn
}
testPrivateAuthKey :: C.APrivateAuthKey
testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
@@ -203,7 +222,7 @@ rcvQueue1 =
primary = True,
dbReplaceQueueId = Nothing,
rcvSwchStatus = Nothing,
smpClientVersion = 1,
smpClientVersion = VersionSMPC 1,
clientNtfCreds = Nothing,
deleteErrors = 0
}
@@ -224,9 +243,15 @@ sndQueue1 =
primary = True,
dbReplaceQueueId = Nothing,
sndSwchStatus = Nothing,
smpClientVersion = 1
smpClientVersion = VersionSMPC 1
}
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
createRcvConn db g cData rq cMode = runExceptT $ do
connId <- ExceptT $ createNewConn db g cData cMode
rq' <- ExceptT $ updateNewConnRcv db connId rq
pure (connId, rq')
testCreateRcvConn :: SpecWith SQLiteStore
testCreateRcvConn =
it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do
@@ -312,8 +337,8 @@ testDeleteRcvConn =
Right (_, rq) <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
getConn db "conn1"
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq))
deleteConn db "conn1"
`shouldReturn` ()
deleteConn db Nothing "conn1"
`shouldReturn` Just "conn1"
getConn db "conn1"
`shouldReturn` Left SEConnNotFound
@@ -324,8 +349,8 @@ testDeleteSndConn =
Right (_, sq) <- createSndConn db g cData1 sndQueue1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq))
deleteConn db "conn1"
`shouldReturn` ()
deleteConn db Nothing "conn1"
`shouldReturn` Just "conn1"
getConn db "conn1"
`shouldReturn` Left SEConnNotFound
@@ -337,8 +362,8 @@ testDeleteDuplexConn =
Right sq <- upgradeRcvConnToDuplex db "conn1" sndQueue1
getConn db "conn1"
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq]))
deleteConn db "conn1"
`shouldReturn` ()
deleteConn db Nothing "conn1"
`shouldReturn` Just "conn1"
getConn db "conn1"
`shouldReturn` Left SEConnNotFound
@@ -362,7 +387,7 @@ testUpgradeRcvConnToDuplex =
sndSwchStatus = Nothing,
primary = True,
dbReplaceQueueId = Nothing,
smpClientVersion = 1
smpClientVersion = VersionSMPC 1
}
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
`shouldReturn` Left (SEBadConnType CSnd)
@@ -391,7 +416,7 @@ testUpgradeSndConnToDuplex =
rcvSwchStatus = Nothing,
primary = True,
dbReplaceQueueId = Nothing,
smpClientVersion = 1,
smpClientVersion = VersionSMPC 1,
clientNtfCreds = Nothing,
deleteErrors = 0
}
@@ -459,7 +484,8 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
{ integrity = MsgOk,
recipient = (unId internalId, ts),
sndMsgId = externalSndId,
broker = (brokerId, ts)
broker = (brokerId, ts),
pqEncryption = CR.PQEncOn
},
msgType = AM_A_MSG_,
msgFlags = SMP.noMsgFlags,
@@ -497,6 +523,7 @@ mkSndMsgData internalId internalSndId internalHash =
msgType = AM_A_MSG_,
msgFlags = SMP.noMsgFlags,
msgBody = hw,
pqEncryption = CR.PQEncOn,
internalHash,
prevMsgHash = internalHash
}
@@ -635,7 +662,7 @@ testGetPendingServerCommand st = do
Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1)
corrId' `shouldBe` "4"
where
command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) SMSubscribe
command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe
corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO ()
corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId)
+16 -17
View File
@@ -15,7 +15,6 @@ import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
import Simplex.Messaging.Version (Version)
import Test.Hspec
batchingTests :: Spec
@@ -253,27 +252,27 @@ testClientBatchWithLargeMessageV7 = do
(length rs1', length rs2') `shouldBe` (74, 136)
all lenOk [s1', s2'] `shouldBe` True
testClientStub :: IO (ProtocolClient ErrorType BrokerMsg)
testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg)
testClientStub = do
g <- C.newRandom
sessId <- atomically $ C.randomBytes 32 g
atomically $ clientStub g sessId (authCmdsSMPVersion - 1) Nothing
atomically $ smpClientStub g sessId subModeSMPVersion Nothing
clientStubV7 :: IO (ProtocolClient ErrorType BrokerMsg)
clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg)
clientStubV7 = do
g <- C.newRandom
sessId <- atomically $ C.randomBytes 32 g
(rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g
thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey
atomically $ clientStub g sessId authCmdsSMPVersion thAuth_
atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_
randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSUB = randomSUB_ C.SEd25519 (authCmdsSMPVersion - 1)
randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion
randomSUBv7 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSUBv7 = randomSUB_ C.SEd25519 authCmdsSMPVersion
randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSUB_ a v sessId = do
g <- C.newRandom
rId <- atomically $ C.randomBytes 24 g
@@ -284,13 +283,13 @@ randomSUB_ a v sessId = do
TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, rId, Cmd SRecipient SUB)
pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) corrId tForAuth
randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmd = randomSUBCmd_ C.SEd25519
randomSUBCmdV7 :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmdV7 = randomSUBCmd_ C.SEd25519 -- same as v6
randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
randomSUBCmd_ a c = do
g <- C.newRandom
rId <- atomically $ C.randomBytes 24 g
@@ -298,12 +297,12 @@ randomSUBCmd_ a c = do
mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB)
randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSEND = randomSEND_ C.SEd25519 (authCmdsSMPVersion - 1)
randomSEND = randomSEND_ C.SEd25519 subModeSMPVersion
randomSENDv7 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSENDv7 = randomSEND_ C.SX25519 authCmdsSMPVersion
randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
randomSEND_ a v sessId len = do
g <- C.newRandom
sId <- atomically $ C.randomBytes 24 g
@@ -315,7 +314,7 @@ randomSEND_ a v sessId len = do
TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, sId, Cmd SSender $ SEND noMsgFlags msg)
pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) corrId tForAuth
testTHandleParams :: Version -> ByteString -> THandleParams
testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion
testTHandleParams v sessionId =
THandleParams
{ sessionId,
@@ -326,20 +325,20 @@ testTHandleParams v sessionId =
batch = True
}
testTHandleAuth :: Version -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth)
testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth)
testTHandleAuth v g (C.APublicAuthKey a k) = case a of
C.SX25519 | v >= authCmdsSMPVersion -> do
(_, privKey) <- atomically $ C.generateKeyPair g
pure $ Just THandleAuth {peerPubKey = k, privKey}
_ -> pure Nothing
randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmd = randomSENDCmd_ C.SEd25519
randomSENDCmdV7 :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmdV7 = randomSENDCmd_ C.SX25519
randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
randomSENDCmd_ a c len = do
g <- C.newRandom
sId <- atomically $ C.randomBytes 24 g
+13
View File
@@ -1,5 +1,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module CoreTests.CryptoTests (cryptoTests) where
@@ -13,6 +15,7 @@ import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import qualified Data.Text.Lazy as LT
import qualified Data.Text.Lazy.Encoding as LE
import Data.Type.Equality
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
@@ -91,6 +94,16 @@ cryptoTests = do
describe "sntrup761" $
it "should enc/dec key" testSNTRUP761
instance Eq C.APublicKey where
C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
instance Eq C.APrivateKey where
C.APrivateKey a k == C.APrivateKey a' k' = case testEquality a a' of
Just Refl -> k == k'
Nothing -> False
testPadUnpadFile :: IO ()
testPadUnpadFile = do
let f = "tests/tmp/testpad"
+1 -1
View File
@@ -12,7 +12,7 @@ import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
import Simplex.FileTransfer.Transport (XFTPErrorType (..))
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
+4 -2
View File
@@ -1,9 +1,11 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeApplications #-}
module CoreTests.TRcvQueuesTests where
import AgentTests.EqInstances ()
import qualified Data.List.NonEmpty as L
import qualified Data.Map as M
import qualified Data.Set as S
@@ -11,7 +13,7 @@ import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId)
import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..))
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (SMPServer)
import Simplex.Messaging.Protocol (SMPServer, pattern VersionSMPC)
import Test.Hspec
import UnliftIO
@@ -136,7 +138,7 @@ dummyRQ userId server connId =
primary = True,
dbReplaceQueueId = Nothing,
rcvSwchStatus = Nothing,
smpClientVersion = 123,
smpClientVersion = VersionSMPC 123,
clientNtfCreds = Nothing,
deleteErrors = 0
}
+21 -15
View File
@@ -1,6 +1,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module CoreTests.VersionRangeTests where
@@ -8,6 +9,7 @@ module CoreTests.VersionRangeTests where
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import Simplex.Messaging.Version
import Simplex.Messaging.Version.Internal
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
@@ -16,6 +18,10 @@ data V = V1 | V2 | V3 | V4 | V5 deriving (Eq, Enum, Ord, Generic, Show)
instance Arbitrary V where arbitrary = genericArbitraryU
data T
instance VersionScope T
versionRangeTests :: Spec
versionRangeTests = modifyMaxSuccess (const 1000) $ do
describe "VersionRange construction" $ do
@@ -25,31 +31,31 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do
(pure $! vr 2 1) `shouldThrow` anyErrorCall
describe "compatible version" $ do
it "should choose mutually compatible max version" $ do
(vr 1 1, vr 1 1) `compatible` Just 1
(vr 1 1, vr 1 2) `compatible` Just 1
(vr 1 2, vr 1 2) `compatible` Just 2
(vr 1 2, vr 2 3) `compatible` Just 2
(vr 1 3, vr 2 3) `compatible` Just 3
(vr 1 3, vr 2 4) `compatible` Just 3
(vr 1 1, vr 1 1) `compatible` Just (Version 1)
(vr 1 1, vr 1 2) `compatible` Just (Version 1)
(vr 1 2, vr 1 2) `compatible` Just (Version 2)
(vr 1 2, vr 2 3) `compatible` Just (Version 2)
(vr 1 3, vr 2 3) `compatible` Just (Version 3)
(vr 1 3, vr 2 4) `compatible` Just (Version 3)
(vr 1 2, vr 3 4) `compatible` Nothing
it "should check if version is compatible" $ do
isCompatible (1 :: Version) (vr 1 2) `shouldBe` True
isCompatible (2 :: Version) (vr 1 2) `shouldBe` True
isCompatible (2 :: Version) (vr 1 1) `shouldBe` False
isCompatible (1 :: Version) (vr 2 2) `shouldBe` False
isCompatible @T (Version 1) (vr 1 2) `shouldBe` True
isCompatible @T (Version 2) (vr 1 2) `shouldBe` True
isCompatible @T (Version 2) (vr 1 1) `shouldBe` False
isCompatible @T (Version 1) (vr 2 2) `shouldBe` False
it "compatibleVersion should pass isCompatible check" . property $
\((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) ->
min1 > max1
|| min2 > max2 -- one of ranges is invalid, skip testing it
|| let w = fromIntegral . fromEnum
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange
|| let w = Version . fromIntegral . fromEnum
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange T
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange T
in case compatibleVersion vr1 vr2 of
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
_ -> True
where
vr = mkVersionRange
compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation
vr v1 v2 = mkVersionRange (Version v1) (Version v2)
compatible :: (VersionRange T, VersionRange T) -> Maybe (Version T) -> Expectation
(vr1, vr2) `compatible` v = do
(vr1, vr2) `checkCompatible` v
(vr2, vr1) `checkCompatible` v
+8 -7
View File
@@ -34,6 +34,7 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost,
import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Protocol (NtfResponse)
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
import Simplex.Messaging.Notifications.Server.Env
import Simplex.Messaging.Notifications.Server.Push.APNS
@@ -70,7 +71,7 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
ntfTestStoreLogFile :: FilePath
ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a
testNtfClient client = do
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do
@@ -114,8 +115,8 @@ ntfServerCfg =
ntfServerCfgV2 :: NtfServerConfig
ntfServerCfgV2 =
ntfServerCfg
{ ntfServerVRange = mkVersionRange 1 authBatchCmdsNTFVersion,
smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange 4 authCmdsSMPVersion}}
{ ntfServerVRange = mkVersionRange initialNTFVersion authBatchCmdsNTFVersion,
smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}}
}
withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a
@@ -139,7 +140,7 @@ withNtfServerOn t port' = withNtfServerThreadOn t port' . const
withNtfServer :: ATransport -> IO a -> IO a
withNtfServer t = withNtfServerOn t ntfTestPort
runNtfTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a
runNtfTest :: forall c a. Transport c => (THandleNTF c -> IO a) -> IO a
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
ntfServerTest ::
@@ -147,10 +148,10 @@ ntfServerTest ::
(Transport c, Encoding smp) =>
TProxy c ->
(Maybe TransmissionAuth, ByteString, ByteString, smp) ->
IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg)
IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse)
ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
where
tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
tPut' :: THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do
let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp)
[Right ()] <- tPut h [Right (sig, t')]
@@ -159,7 +160,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
pure (Nothing, corrId, qId, cmd)
ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
ntfTest :: Transport c => TProxy c -> (THandleNTF c -> IO ()) -> Expectation
ntfTest _ test' = runNtfTest test' `shouldReturn` ()
data APNSMockRequest = APNSMockRequest
+13 -8
View File
@@ -5,7 +5,9 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module NtfServerTests where
@@ -37,6 +39,7 @@ import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS
import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS
import Simplex.Messaging.Notifications.Transport (THandleNTF)
import Simplex.Messaging.Parsers (parse, parseAll)
import Simplex.Messaging.Protocol hiding (notification)
import Simplex.Messaging.Transport
@@ -50,30 +53,32 @@ ntfServerTests t = do
ntfSyntaxTests :: ATransport -> Spec
ntfSyntaxTests (ATransport t) = do
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", NRErr $ CMD UNKNOWN)
describe "NEW" $ do
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX)
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX)
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH)
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH)
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", NRErr $ CMD SYNTAX)
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", NRErr $ CMD SYNTAX)
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", NRErr $ CMD NO_AUTH)
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", NRErr $ CMD HAS_AUTH)
where
(>#>) ::
Encoding smp =>
(Maybe TransmissionAuth, ByteString, ByteString, smp) ->
(Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) ->
(Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) ->
Expectation
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
deriving instance Eq NtfResponse
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, tToSend)
tGet1 h
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (authorize tForAuth, tToSend)
+21 -20
View File
@@ -26,7 +26,7 @@ import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.Server
import Simplex.Messaging.Version (VersionRange, mkVersionRange)
import Simplex.Messaging.Version (mkVersionRange)
import System.Environment (lookupEnv)
import System.Info (os)
import Test.Hspec
@@ -34,6 +34,7 @@ import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
import UnliftIO.Timeout (timeout)
import Util
testHost :: NonEmpty TransportHost
testHost = "localhost"
@@ -60,17 +61,17 @@ testServerStatsBackupFile :: FilePath
testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
xit' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a)
xit' = if os == "linux" then xit else it
xit' d = if os == "linux" then skip "skipped on Linux" . it d else it d
xit'' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a)
xit'' d t = do
ci <- runIO $ lookupEnv "CI"
(if ci == Just "true" then xit else it) d t
(if ci == Just "true" then skip "skipped on CI" . it d else it d) t
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a
testSMPClient = testSMPClientVR supportedClientSMPRelayVRange
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRange -> (THandle c -> m a) -> m a
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a
testSMPClientVR vr client = do
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do
@@ -109,7 +110,7 @@ cfg =
}
cfgV7 :: ServerConfig
cfgV7 = cfg {smpServerVRange = mkVersionRange 4 authCmdsSMPVersion}
cfgV7 = cfg {smpServerVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}
withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile}
@@ -148,16 +149,16 @@ withSmpServer t = withSmpServerOn t testPort
withSmpServerV7 :: HasCallStack => ATransport -> IO a -> IO a
withSmpServerV7 t = withSmpServerConfigOn t cfgV7 testPort . const
runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a
runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandleSMP c -> IO a) -> IO a
runSmpTest test = withSmpServer (transport @c) $ testSMPClient test
runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a
runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a
runSmpTestN = runSmpTestNCfg cfg supportedClientSMPRelayVRange
runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> Int -> (HasCallStack => [THandle c] -> IO a) -> IO a
runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a
runSmpTestNCfg srvCfg clntVR nClients test = withSmpServerConfigOn (transport @c) srvCfg testPort $ \_ -> run nClients []
where
run :: Int -> [THandle c] -> IO a
run :: Int -> [THandleSMP c] -> IO a
run 0 hs = test hs
run n hs = testSMPClientVR clntVR $ \h -> run (n - 1) (h : hs)
@@ -169,7 +170,7 @@ smpServerTest ::
IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg)
smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
where
tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
tPut' :: THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do
let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp)
[Right ()] <- tPut h [Right (sig, t')]
@@ -178,33 +179,33 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
pure (Nothing, corrId, qId, cmd)
smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation
smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> IO ()) -> Expectation
smpTest _ test' = runSmpTest test' `shouldReturn` ()
smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation
smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO ()) -> Expectation
smpTestN n test' = runSmpTestN n test' `shouldReturn` ()
smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation
smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
smpTest2 = smpTest2Cfg cfg supportedClientSMPRelayVRange
smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation
smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
smpTest2Cfg srvCfg clntVR _ test' = runSmpTestNCfg srvCfg clntVR 2 _test `shouldReturn` ()
where
_test :: HasCallStack => [THandle c] -> IO ()
_test :: HasCallStack => [THandleSMP c] -> IO ()
_test [h1, h2] = test' h1 h2
_test _ = error "expected 2 handles"
smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
smpTest3 _ test' = smpTestN 3 _test
where
_test :: HasCallStack => [THandle c] -> IO ()
_test :: HasCallStack => [THandleSMP c] -> IO ()
_test [h1, h2, h3] = test' h1 h2 h3
_test _ = error "expected 3 handles"
smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
smpTest4 _ test' = smpTestN 4 _test
where
_test :: HasCallStack => [THandle c] -> IO ()
_test :: HasCallStack => [THandleSMP c] -> IO ()
_test [h1, h2, h3, h4] = test' h1 h2 h3 h4
_test _ = error "expected 4 handles"
+33 -20
View File
@@ -1,6 +1,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -8,7 +9,9 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module ServerTests where
@@ -23,6 +26,7 @@ import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Set as S
import Data.Type.Equality
import GHC.Stack (withFrozenCallStack)
import SMPClient
import qualified Simplex.Messaging.Crypto as C
@@ -74,13 +78,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh)
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (sgn, tToSend)
tGet1 h
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
Right () <- tPut1 h (authorize tForAuth, tToSend)
@@ -94,12 +98,12 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
_sx448 -> undefined -- ghc8107 fails to the branch excluded by types
#endif
tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
tPut1 :: Transport c => THandle v c -> SentRawTransmission -> IO (Either TransportError ())
tPut1 h t = do
[r] <- tPut h [Right t]
pure r
tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd)
tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd)
tGet1 h = do
[r] <- liftIO $ tGet h
pure r
@@ -380,7 +384,7 @@ testSwitchSub (ATransport t) =
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3)
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
Nothing -> return ()
Just _ -> error "nothing else is delivered to the 1st TCP connection"
@@ -551,12 +555,12 @@ testWithStoreLog at@(ATransport t) =
logSize testStoreLogFile `shouldReturn` 1
removeFile testStoreLogFile
where
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
runTest _ test' server = do
testSMPClient test' `shouldReturn` ()
killThread server
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
runClient _ test' = testSMPClient test' `shouldReturn` ()
logSize :: FilePath -> IO Int
@@ -649,12 +653,12 @@ testRestoreMessages at@(ATransport t) =
removeFile testStoreMsgsFile
removeFile testServerStatsBackupFile
where
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
runTest _ test' server = do
testSMPClient test' `shouldReturn` ()
killThread server
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
runClient _ test' = testSMPClient test' `shouldReturn` ()
checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation
@@ -723,15 +727,15 @@ testRestoreExpireMessages at@(ATransport t) =
Right ServerStatsData {_msgExpired} <- strDecode <$> B.readFile testServerStatsBackupFile
_msgExpired `shouldBe` 2
where
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
runTest _ test' server = do
testSMPClient test' `shouldReturn` ()
killThread server
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
runClient _ test' = testSMPClient test' `shouldReturn` ()
createAndSecureQueue :: Transport c => THandle c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret)
createAndSecureQueue :: Transport c => THandleSMP c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret)
createAndSecureQueue h sPub = do
g <- C.newRandom
(rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g
@@ -747,7 +751,7 @@ testTiming (ATransport t) =
describe "should have similar time for auth error, whether queue exists or not, for all key types" $
forM_ timingTests $ \tst ->
it (testName tst) $
smpTest2Cfg cfgV7 (mkVersionRange 4 authCmdsSMPVersion) t $ \rh sh ->
smpTest2Cfg cfgV7 (mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion) t $ \rh sh ->
testSameTiming rh sh tst
where
testName :: (C.AuthAlg, C.AuthAlg, Int) -> String
@@ -766,7 +770,7 @@ testTiming (ATransport t) =
]
timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const
similarTime t1 t2 = abs (t2 / t1 - 1) < 0.15 -- normally the difference between "no queue" and "wrong key" is less than 5%
testSameTiming :: forall c. Transport c => THandle c -> THandle c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation
testSameTiming :: forall c. Transport c => THandleSMP c -> THandleSMP c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation
testSameTiming rh sh (C.AuthAlg goodKeyAlg, C.AuthAlg badKeyAlg, n) = do
g <- C.newRandom
(rPub, rKey) <- atomically $ C.generateAuthKeyPair goodKeyAlg g
@@ -787,7 +791,7 @@ testTiming (ATransport t) =
runTimingTest sh badKey sId $ _SEND "hello"
where
runTimingTest :: PartyI p => THandle c -> C.APrivateAuthKey -> ByteString -> Command p -> IO ()
runTimingTest :: PartyI p => THandleSMP c -> C.APrivateAuthKey -> ByteString -> Command p -> IO ()
runTimingTest h badKey qId cmd = do
threadDelay 100000
_ <- timeRepeat n $ do -- "warm up" the server
@@ -837,14 +841,14 @@ testMessageNotifications (ATransport t) =
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
Resp "" _ (NMSG _ _) <- tGet1 nh2
1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
Nothing -> pure ()
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
@@ -864,7 +868,7 @@ testMsgExpireOnSend t =
testSMPClient @c $ \rh -> do
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"
@@ -884,7 +888,7 @@ testMsgExpireOnInterval t =
signSendRecv rh rKey ("2", rId, SUB) >>= \case
Resp "2" _ OK -> pure ()
r -> unexpected r
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing should be delivered"
@@ -903,7 +907,7 @@ testMsgNOTExpireOnInterval t =
testSMPClient @c $ \rh -> do
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
Nothing -> return ()
Just _ -> error "nothing else should be delivered"
@@ -919,6 +923,15 @@ sampleSig = Just $ TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8B
noAuth :: (Char, Maybe BasicAuth)
noAuth = ('A', Nothing)
deriving instance Eq TransmissionAuth
instance Eq C.ASignature where
C.ASignature a s == C.ASignature a' s' = case testEquality a a' of
Just Refl -> s == s'
_ -> False
deriving instance Eq (C.Signature a)
syntaxTests :: ATransport -> Spec
syntaxTests (ATransport t) = do
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
+6
View File
@@ -0,0 +1,6 @@
module Util where
import Test.Hspec
skip :: String -> SpecWith a -> SpecWith a
skip = before_ . pendingWith
+2 -1
View File
@@ -21,7 +21,8 @@ import Data.List (find, isSuffixOf)
import Data.Maybe (fromJust)
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3)
import Simplex.FileTransfer.Description (FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, mb, qrSizeLimit, pattern ValidFileDescription)
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType (AUTH))
import Simplex.FileTransfer.Protocol (FileParty (..))
import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH))
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..))
import Simplex.Messaging.Agent (AgentClient, disconnectAgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers)
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..))
+2 -2
View File
@@ -20,9 +20,9 @@ import Data.List (isInfixOf)
import ServerTests (logSize)
import Simplex.FileTransfer.Client
import Simplex.FileTransfer.Description (kb)
import Simplex.FileTransfer.Protocol (FileInfo (..), XFTPErrorType (..))
import Simplex.FileTransfer.Protocol (FileInfo (..))
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..))
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..))
import Simplex.Messaging.Client (ProtocolClientError (..))
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC