mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-06-04 08:41:25 +00:00
Merge remote-tracking branch 'origin/master' into ab/bench-target
This commit is contained in:
+2
-1
@@ -1,5 +1,5 @@
|
||||
name: simplexmq
|
||||
version: 5.6.0.0
|
||||
version: 5.6.0.2
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: |
|
||||
This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
@@ -74,6 +74,7 @@ dependencies:
|
||||
- unliftio-core == 0.2.*
|
||||
- websockets == 0.12.*
|
||||
- yaml == 0.11.*
|
||||
- zstd == 0.1.3.*
|
||||
|
||||
flags:
|
||||
swift:
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# Post-quantum double ratchet implementation
|
||||
|
||||
See [the previous doc](https://github.com/simplex-chat/simplex-chat/blob/stable/docs/rfcs/2023-09-30-pq-double-ratchet.md).
|
||||
|
||||
The main implementation consideration is that it should be both backwards and forwards compatible, to allow changing the connection DR to/from using PQ primitives (although client version downgrade may be impossible in this case), and also to decide whether to use PQ primitive on per-connection basis:
|
||||
- use without links (in SMP confirmation or in SMP invitation via address or via member), don't use with links (as they would be too large).
|
||||
- use in small groups, don't use in large groups.
|
||||
|
||||
Also note that for DR to work we need to have 2 KEMs running in parallel.
|
||||
|
||||
Possible combinations (assuming both clients support PQ):
|
||||
|
||||
| Stage | No PQ kem | PQ key sent | PQ key + PQ ct sent |
|
||||
|:------------:|:---------:|:-----------:|:-------------------:|
|
||||
| inv | + | + | - |
|
||||
| conf, in reply to: <br>no-pq inv <br>pq inv | <br>+<br>+ | <br>+<br>- | <br>-<br>+ |
|
||||
| 1st msg, in reply to:<br>no-pq conf<br>pq/pq+ct conf | <br>+<br>+ | <br>+<br>- | <br>-<br>+ |
|
||||
| Nth msg, in reply to:<br>no-pq msg <br>pq/pq+ct msg | <br>+<br>+ | <br>+<br>- | <br>-<br>+ |
|
||||
|
||||
These rules can be reduced to:
|
||||
1. initial invitation optionally has PQ key, but must not have ciphertext.
|
||||
2. all subsequent messages should be allowed without PQ key/ciphertext, but:
|
||||
- if the previous message had PQ key or PQ key with ciphertext, they must either have no PQ key, or have PQ key with ciphertext (PQ key without ciphertext is an error).
|
||||
- if the previous message had no PQ key, they must either have no PQ key, or have PQ key without ciphertext (PQ key with ciphertext is an error).
|
||||
|
||||
The rules for calculating the shared secret for received/sent messages are (assuming received message is valid according to the above rules):
|
||||
|
||||
| sent msg ><br>V received msg | no-pq | pq | pq+ct |
|
||||
|:------------------------------:|:-----------:|:-------:|:---------------:|
|
||||
| no-pq | DH / DH | DH / DH | err |
|
||||
| pq (sent msg was NOT pq) | DH / DH | err | DH / DH+KEM |
|
||||
| pq+ct (sent msg was NOT no-pq) | DH+KEM / DH | err | DH+KEM / DH+KEM |
|
||||
|
||||
To summarize, the upgrade to DH+KEM secret happens in a sent message that has PQ key with ciphertext sent in reply to message with PQ key only (without ciphertext), and the downgrade to DH secret happens in the message that has no PQ key.
|
||||
|
||||
The type for sending PQ key with optional ciphertext is `Maybe E2ERachetKEM` where `data E2ERachetKEM = E2ERachetKEM KEMPublicKey (Maybe KEMCiphertext)`, and for SMP invitation it will be simply `Maybe KEMPublicKey`. Possibly, there is a way to encode the rules above in the types, these types don't constrain possible transitions to valid ones.
|
||||
@@ -0,0 +1,92 @@
|
||||
# Migrating existing connections to post-quantum double ratchet algorithm
|
||||
|
||||
## Problem
|
||||
|
||||
Post-quantum variant of double ratchet algorithm represents an almost full-stack change affecting all parts of the protocol stack except client-server protocol (SMP):
|
||||
- double-ratchet end-to-end encryption: different encoding (additional large keys require byte-strings larger than 255 bytes with 2-byte length prefixes) and larger message headers (increased by ~2200 bytes).
|
||||
- agent-agent protocol: a smaller maximum message size to accomodate larger headers and to fit in 16kb blocks, reduced by ~2200 bytes for the messages and by almost ~4000 bytes for connection information.
|
||||
- chat protocol: also a smaller message size compensated by zstd comression of JSON messages.
|
||||
|
||||
We want the versioning that achieves these objectives:
|
||||
- all changes in all protocol layers happen at the same time, when both clients support it.
|
||||
- ability to downgrade the clients to the previous version without losing connection.
|
||||
- ability to opt-in into this functionality via "experimental" feature toggle, that enables post-quantum encryption in connections when both contacts enable this toggle.
|
||||
|
||||
To have ability to downgrade the clients we have two options:
|
||||
- roll-out this functionality in two stages: 1) roll-out clients support but do not enable the new version, and then 2) upgrade client version. The problem here is that the clients won't be able to opt-in into this experiment.
|
||||
- make offered range dependent on experimental feature being enabled. Currently we have an option to enable PQ encryption in agent API, and this option can be used as a proxy to maxium supported protocol version - if the option is passed, it can be seen as an indication that higher version range (or version) should offered (or accepted).
|
||||
|
||||
## Solution
|
||||
|
||||
Currently ratchet state stores version range. It's unclear what was the intended semantics of that version range - it simply stores the offered/supported version range at the time ratchet was initialised, but only a high bound is used to send in message headers, and it is never upgraded. In JSON this range is encoded as tuple (an array of two elements in JSON).
|
||||
|
||||
We could continue using this range with the meaning of the lower bound to be "currently used ratchet version" and the meaning of higher boundary to be "maximum supported ratchet version". We could also use the version communicated in message headers to upgrade ratchet version, with the condition that upgrade should only happen if both sides want it. Currently it's defined by pqEnableKEM property in ratchet state. We could also make it more explicit by defining maximum version to which ratchet should upgrade. Given that irreversible upgrades are not very common, it is probably ok to keep it implicit.
|
||||
|
||||
We can define a better type than VersionRange to reflect semantics of the range in ratchet (current/max supported range), but for backward compatibility it needs to be encoded in the same way as now.
|
||||
|
||||
To summarize, the proposed solution for ratchet versioning is:
|
||||
- define ratchet versions as new type to include current and maximum allowed versions, where maximum allowed will be either the same or lower than maximum supported based on PQ option (in 5.6), and in 5.7 it will be changed to maximum supported, so version starts upgrading independently from PQ being enabled.
|
||||
- make encodings in ratchet depend on current version (in curent code it depends on max version).
|
||||
- include max allowed in message header.
|
||||
- upgrade current if in range on each new message if less than max and higher than current (same as we do for connections).
|
||||
- increase max allowed once PQ is enabled (only in 5.6). Make max allowed the same as max supported (global constant).
|
||||
|
||||
```haskell
|
||||
data RatchetVR = RatchetVR
|
||||
{ currentVersion :: Version,
|
||||
maxAllowedVersion :: Version
|
||||
}
|
||||
|
||||
instance ToJSON RatchetVR where
|
||||
toEncoding (RatchetVR v1 v2) = toEncoding (v1, v2)
|
||||
toJSON (RatchetVR v1 v2) = toJSON (v1, v2)
|
||||
|
||||
instance FromJSON RatchetVR where
|
||||
parseJSON v = do
|
||||
-- this also verifies that v2 > v1 (although we could remove JSON instances for VersionRange)
|
||||
VersionRange v1 v2 <- parseJSON v
|
||||
pure $ RatchetVR v1 v2
|
||||
```
|
||||
|
||||
For connections, we could also make version used for the purposes of encoding dependent on the PQ being enabled, and version for decoding taken from message header, but then we'd have to not only upgrade ratchets but the connection as well every time PQ mode changes.
|
||||
|
||||
Another suggestion to ensure that correct version range is used in correct contexts could be:
|
||||
- using different newtypes for different version ranges.
|
||||
- define generic type class for version aware encoding that would also accept only specific type class for the version to use the correct range. This may be justified as there will be several version-aware encodings, and not just the protocol as now.
|
||||
|
||||
```haskell
|
||||
class Ord v => EncodingV v a where
|
||||
{-# MINIMAL smpEncodeV, (smpDecodeV | smpVP) #-}
|
||||
smpEncodeV :: v -> a -> ByteString
|
||||
-- default decode uses parser
|
||||
smpDecodeV :: v -> ByteString -> Either String a
|
||||
smpDecodeV = parseAll . smpVP
|
||||
-- default parser decodes from length-specified bytestring
|
||||
smpVP :: v -> Parser a
|
||||
smpVP v = smpDecodeV v <$?> smpP
|
||||
```
|
||||
|
||||
The version will be passed from currently agreed version, it may only change when message is received, not when message is sent. The version will not be extracted from the encoding itself as it happens now in ratchet encodings.
|
||||
|
||||
## Various options how the problem can be simplified
|
||||
|
||||
1. Do not support connection downgrade once both devices upgraded. If applied to all existing connections then it is a bad option, as it would disrupt some important conversations.
|
||||
|
||||
2. Do not provide ability to opt-in into PQ encryption until v5.7 where it will be rolled out automatically. That is also suboptimal, as it won't allow announcing technology design and have testing outside of the team devices.
|
||||
|
||||
3. The logic explained above where connection upgrade and downgrade is possible and applied to all existing connections if both parties consent to it. There are these important downsides:
|
||||
- complexity of this logic
|
||||
- regression risks when this logic is removed.
|
||||
- some non-coordinated upgrades of existing, potentially important conversations, simply because two users opt-in into the experiment without any expectation that another side also opts-in.
|
||||
|
||||
4. Apply upgrade/downgrade logic and enable PQ encryption as opt-in, based on the toggle in the UX, only for the new connections. This seems the least risky, and also simpler than option 3, as it would only apply to the new connections, and both users will have to enable experimental toggle prior to connecting.
|
||||
|
||||
Option 4 seems the best trade-off, and has these sub-options regarding where it is controlled:
|
||||
a) in chat based on connection flag. Chat will pass PQ options only to connections that were created when experimental option was enabled.
|
||||
b) in agent - there will be additional logic to ignore PQ option for existing connections.
|
||||
c) both in chat and in agent.
|
||||
|
||||
Option 4a seems better, as it would:
|
||||
- simplify agent code
|
||||
- minimise required changes when releasing v5.7 (as we do want that all direct and small groups connections migrate to PQ encryption at the time, without any toggles)
|
||||
- allow tests for connection upgrade in the currect code.
|
||||
+14
-1
@@ -5,7 +5,7 @@ cabal-version: 1.12
|
||||
-- see: https://github.com/sol/hpack
|
||||
|
||||
name: simplexmq
|
||||
version: 5.6.0.0
|
||||
version: 5.6.0.2
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
<./docs/Simplex-Messaging-Client.html client> and
|
||||
@@ -103,9 +103,12 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem
|
||||
Simplex.Messaging.Agent.TRcvQueues
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
Simplex.Messaging.Compression
|
||||
Simplex.Messaging.Crypto
|
||||
Simplex.Messaging.Crypto.File
|
||||
Simplex.Messaging.Crypto.Lazy
|
||||
@@ -158,6 +161,7 @@ library
|
||||
Simplex.Messaging.Transport.WebSockets
|
||||
Simplex.Messaging.Util
|
||||
Simplex.Messaging.Version
|
||||
Simplex.Messaging.Version.Internal
|
||||
Simplex.RemoteControl.Client
|
||||
Simplex.RemoteControl.Discovery
|
||||
Simplex.RemoteControl.Discovery.Multicast
|
||||
@@ -226,6 +230,7 @@ library
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -299,6 +304,7 @@ executable ntf-server
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -372,6 +378,7 @@ executable smp-agent
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -445,6 +452,7 @@ executable smp-server
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -518,6 +526,7 @@ executable xftp
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -591,6 +600,7 @@ executable xftp-server
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
@@ -612,6 +622,7 @@ test-suite simplexmq-test
|
||||
AgentTests
|
||||
AgentTests.ConnectionRequestTests
|
||||
AgentTests.DoubleRatchetTests
|
||||
AgentTests.EqInstances
|
||||
AgentTests.FunctionalAPITests
|
||||
AgentTests.MigrationTests
|
||||
AgentTests.NotificationTests
|
||||
@@ -634,6 +645,7 @@ test-suite simplexmq-test
|
||||
ServerTests
|
||||
SMPAgentClient
|
||||
SMPClient
|
||||
Util
|
||||
XFTPAgent
|
||||
XFTPCLI
|
||||
XFTPClient
|
||||
@@ -702,6 +714,7 @@ test-suite simplexmq-test
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, yaml ==0.11.*
|
||||
, zstd ==0.1.3.*
|
||||
default-language: Haskell2010
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
|
||||
@@ -35,6 +35,7 @@ import Control.Monad.Reader
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Composition ((.:))
|
||||
import Data.Either (rights)
|
||||
import Data.Int (Int64)
|
||||
@@ -52,8 +53,8 @@ import Simplex.FileTransfer.Client.Main
|
||||
import Simplex.FileTransfer.Crypto
|
||||
import Simplex.FileTransfer.Description
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..))
|
||||
import qualified Simplex.FileTransfer.Protocol as XFTP
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
|
||||
import qualified Simplex.FileTransfer.Transport as XFTP
|
||||
import Simplex.FileTransfer.Types
|
||||
import Simplex.FileTransfer.Util (removePath, uniqueCombine)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
@@ -389,21 +390,25 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
|
||||
where
|
||||
AgentConfig {xftpMaxRecipientsPerRequest = maxRecipients, messageRetryInterval = ri} = cfg
|
||||
encryptFileForUpload :: SndFile -> FilePath -> m (FileDigest, [(XFTPChunkSpec, FileDigest)])
|
||||
encryptFileForUpload SndFile {key, nonce, srcFile} fsEncPath = do
|
||||
encryptFileForUpload SndFile {key, nonce, srcFile, redirect} fsEncPath = do
|
||||
let CryptoFile {filePath} = srcFile
|
||||
fileName = takeFileName filePath
|
||||
fileSize <- liftIO $ fromInteger <$> CF.getFileContentsSize srcFile
|
||||
when (fileSize > maxFileSize) $ throwError $ INTERNAL "max file size exceeded"
|
||||
when (fileSize > maxFileSizeHard) $ throwError $ INTERNAL "max file size exceeded"
|
||||
let fileHdr = smpEncode FileHeader {fileName, fileExtra = Nothing}
|
||||
fileSize' = fromIntegral (B.length fileHdr) + fileSize
|
||||
chunkSizes = prepareChunkSizes $ fileSize' + fileSizeLen + authTagSize
|
||||
chunkSizes' = map fromIntegral chunkSizes
|
||||
encSize = sum chunkSizes'
|
||||
payloadSize = fileSize' + fileSizeLen + authTagSize
|
||||
chunkSizes <- case redirect of
|
||||
Nothing -> pure $ prepareChunkSizes payloadSize
|
||||
Just _ -> case singleChunkSize payloadSize of
|
||||
Nothing -> throwError $ INTERNAL "max file size exceeded for redirect"
|
||||
Just chunkSize -> pure [chunkSize]
|
||||
let encSize = sum $ map fromIntegral chunkSizes
|
||||
void $ liftError (INTERNAL . show) $ encryptFile srcFile fileHdr key nonce fileSize' encSize fsEncPath
|
||||
digest <- liftIO $ LC.sha512Hash <$> LB.readFile fsEncPath
|
||||
let chunkSpecs = prepareChunkSpecs fsEncPath chunkSizes
|
||||
chunkDigests <- map FileDigest <$> mapM (liftIO . getChunkDigest) chunkSpecs
|
||||
pure (FileDigest digest, zip chunkSpecs chunkDigests)
|
||||
chunkDigests <- liftIO $ mapM getChunkDigest chunkSpecs
|
||||
pure (FileDigest digest, zip chunkSpecs $ coerce chunkDigests)
|
||||
chunkCreated :: SndFileChunk -> Bool
|
||||
chunkCreated SndFileChunk {replicas} =
|
||||
any (\SndFileChunkReplica {replicaStatus} -> replicaStatus == SFRSCreated) replicas
|
||||
|
||||
@@ -57,7 +57,7 @@ import UnliftIO.Directory
|
||||
data XFTPClient = XFTPClient
|
||||
{ http2Client :: HTTP2Client,
|
||||
transportSession :: TransportSession FileResponse,
|
||||
thParams :: THandleParams,
|
||||
thParams :: THandleParams XFTPVersion,
|
||||
config :: XFTPClientConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -16,9 +16,11 @@ module Simplex.FileTransfer.Client.Main
|
||||
xftpClientCLI,
|
||||
cliSendFile,
|
||||
cliSendFileOpts,
|
||||
singleChunkSize,
|
||||
prepareChunkSizes,
|
||||
prepareChunkSpecs,
|
||||
maxFileSize,
|
||||
maxFileSizeHard,
|
||||
fileSizeLen,
|
||||
getChunkDigest,
|
||||
SentRecipientReplica (..),
|
||||
@@ -41,7 +43,7 @@ import Data.List.NonEmpty (NonEmpty (..), nonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Maybe (fromMaybe, listToMaybe)
|
||||
import qualified Data.Text as T
|
||||
import Data.Word (Word32)
|
||||
import GHC.Records (HasField (getField))
|
||||
@@ -76,12 +78,17 @@ import UnliftIO.Directory
|
||||
xftpClientVersion :: String
|
||||
xftpClientVersion = "1.0.1"
|
||||
|
||||
-- | Soft limit for XFTP clients. Should be checked and reported to user.
|
||||
maxFileSize :: Int64
|
||||
maxFileSize = gb 1
|
||||
|
||||
maxFileSizeStr :: String
|
||||
maxFileSizeStr = B.unpack . strEncode $ FileSize maxFileSize
|
||||
|
||||
-- | Hard internal limit for XFTP agent after which it refuses to prepare chunks.
|
||||
maxFileSizeHard :: Int64
|
||||
maxFileSizeHard = gb 5
|
||||
|
||||
fileSizeLen :: Int64
|
||||
fileSizeLen = 8
|
||||
|
||||
@@ -214,13 +221,13 @@ data SentFileChunk = SentFileChunk
|
||||
digest :: FileDigest,
|
||||
replicas :: [SentFileChunkReplica]
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data SentFileChunkReplica = SentFileChunkReplica
|
||||
{ server :: XFTPServer,
|
||||
recipients :: [(ChunkReplicaId, C.APrivateAuthKey)]
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data SentRecipientReplica = SentRecipientReplica
|
||||
{ chunkNo :: Int,
|
||||
@@ -407,7 +414,8 @@ getChunkDigest :: XFTPChunkSpec -> IO ByteString
|
||||
getChunkDigest XFTPChunkSpec {filePath = chunkPath, chunkOffset, chunkSize} =
|
||||
withFile chunkPath ReadMode $ \h -> do
|
||||
hSeek h AbsoluteSeek $ fromIntegral chunkOffset
|
||||
LC.sha256Hash <$> LB.hGet h (fromIntegral chunkSize)
|
||||
chunk <- LB.hGet h (fromIntegral chunkSize)
|
||||
pure $! LC.sha256Hash chunk
|
||||
|
||||
cliReceiveFile :: ReceiveOptions -> ExceptT CLIError IO ()
|
||||
cliReceiveFile ReceiveOptions {fileDescription, filePath, retryCount, tempPath, verbose, yes} =
|
||||
@@ -522,6 +530,12 @@ getFileDescription' path =
|
||||
getFileDescription path >>= \case
|
||||
AVFD fd -> either (throwError . CLIError) pure $ checkParty fd
|
||||
|
||||
singleChunkSize :: Int64 -> Maybe Word32
|
||||
singleChunkSize size' =
|
||||
listToMaybe $ dropWhile (< chunkSize) serverChunkSizes
|
||||
where
|
||||
chunkSize = fromIntegral size'
|
||||
|
||||
prepareChunkSizes :: Int64 -> [Word32]
|
||||
prepareChunkSizes size' = prepareSizes size'
|
||||
where
|
||||
|
||||
@@ -227,7 +227,7 @@ validateFileDescription fd@FileDescription {size, chunks}
|
||||
| otherwise = Right $ ValidFD fd
|
||||
where
|
||||
chunkNos = map (\FileChunk {chunkNo} -> chunkNo) chunks
|
||||
chunksSize = fromIntegral . foldl' (\s FileChunk {chunkSize} -> s + unFileSize chunkSize) 0
|
||||
chunksSize = foldl' (\(s :: Int64) FileChunk {chunkSize} -> s + fromIntegral (unFileSize chunkSize)) 0
|
||||
|
||||
encodeFileDescription :: FileDescription p -> YAMLFileDescription
|
||||
encodeFileDescription FileDescription {party, size, digest, key, nonce, chunkSize, chunks, redirect} =
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
@@ -14,9 +16,7 @@
|
||||
|
||||
module Simplex.FileTransfer.Protocol where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -25,11 +25,11 @@ import Data.List.NonEmpty (NonEmpty (..))
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word32)
|
||||
import Simplex.FileTransfer.Transport (VersionXFTP, XFTPErrorType (..), XFTPVersion, pattern VersionXFTP, xftpClientHandshake)
|
||||
import Simplex.Messaging.Client (authTransmission)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
( BasicAuth,
|
||||
@@ -56,11 +56,10 @@ import Simplex.Messaging.Protocol
|
||||
tParse,
|
||||
)
|
||||
import Simplex.Messaging.Transport (THandleParams (..), TransportError (..))
|
||||
import Simplex.Messaging.Util (bshow, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
currentXFTPVersion :: Version
|
||||
currentXFTPVersion = 1
|
||||
currentXFTPVersion :: VersionXFTP
|
||||
currentXFTPVersion = VersionXFTP 1
|
||||
|
||||
xftpBlockSize :: Int
|
||||
xftpBlockSize = 16384
|
||||
@@ -142,10 +141,10 @@ instance ProtocolMsgTag FileCmdTag where
|
||||
instance FilePartyI p => ProtocolMsgTag (FileCommandTag p) where
|
||||
decodeTag s = decodeTag s >>= (\(FCT _ t) -> checkParty' t)
|
||||
|
||||
instance Protocol XFTPErrorType FileResponse where
|
||||
instance Protocol XFTPVersion XFTPErrorType FileResponse where
|
||||
type ProtoCommand FileResponse = FileCmd
|
||||
type ProtoType FileResponse = 'PXFTP
|
||||
protocolClientHandshake = ntfClientHandshake
|
||||
protocolClientHandshake = xftpClientHandshake
|
||||
protocolPing = FileCmd SFRecipient PING
|
||||
protocolError = \case
|
||||
FRErr e -> Just e
|
||||
@@ -171,11 +170,11 @@ data FileInfo = FileInfo
|
||||
size :: Word32,
|
||||
digest :: ByteString
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
type XFTPFileId = ByteString
|
||||
|
||||
instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
|
||||
instance FilePartyI p => ProtocolEncoding XFTPVersion XFTPErrorType (FileCommand p) where
|
||||
type Tag (FileCommand p) = FileCommandTag p
|
||||
encodeProtocol _v = \case
|
||||
FNEW file rKeys auth_ -> e (FNEW_, ' ', file, rKeys, auth_)
|
||||
@@ -191,7 +190,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
|
||||
|
||||
protocolP v tag = (\(FileCmd _ c) -> checkParty c) <$?> protocolP v (FCT (sFileParty @p) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse
|
||||
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, fileId, _) cmd = case cmd of
|
||||
@@ -208,7 +207,7 @@ instance FilePartyI p => ProtocolEncoding XFTPErrorType (FileCommand p) where
|
||||
| isNothing auth || B.null fileId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance ProtocolEncoding XFTPErrorType FileCmd where
|
||||
instance ProtocolEncoding XFTPVersion XFTPErrorType FileCmd where
|
||||
type Tag FileCmd = FileCmdTag
|
||||
encodeProtocol _v (FileCmd _ c) = encodeProtocol _v c
|
||||
|
||||
@@ -225,7 +224,7 @@ instance ProtocolEncoding XFTPErrorType FileCmd where
|
||||
FACK_ -> pure FACK
|
||||
PING_ -> pure PING
|
||||
|
||||
fromProtocolError = fromProtocolError @XFTPErrorType @FileResponse
|
||||
fromProtocolError = fromProtocolError @XFTPVersion @XFTPErrorType @FileResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (FileCmd p c) = FileCmd p <$> checkCredentials t c
|
||||
@@ -276,7 +275,7 @@ data FileResponse
|
||||
| FRPong
|
||||
deriving (Show)
|
||||
|
||||
instance ProtocolEncoding XFTPErrorType FileResponse where
|
||||
instance ProtocolEncoding XFTPVersion XFTPErrorType FileResponse where
|
||||
type Tag FileResponse = FileResponseTag
|
||||
encodeProtocol _v = \case
|
||||
FRSndIds fId rIds -> e (FRSndIds_, ' ', fId, rIds)
|
||||
@@ -319,82 +318,6 @@ instance ProtocolEncoding XFTPErrorType FileResponse where
|
||||
| B.null entId = Right cmd
|
||||
| otherwise = Left $ CMD HAS_AUTH
|
||||
|
||||
data XFTPErrorType
|
||||
= -- | incorrect block format, encoding or signature size
|
||||
BLOCK
|
||||
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
|
||||
SESSION
|
||||
| -- | SMP command is unknown or has invalid syntax
|
||||
CMD {cmdErr :: CommandError}
|
||||
| -- | command authorization error - bad signature or non-existing SMP queue
|
||||
AUTH
|
||||
| -- | incorrent file size
|
||||
SIZE
|
||||
| -- | storage quota exceeded
|
||||
QUOTA
|
||||
| -- | incorrent file digest
|
||||
DIGEST
|
||||
| -- | file encryption/decryption failed
|
||||
CRYPTO
|
||||
| -- | no expected file body in request/response or no file on the server
|
||||
NO_FILE
|
||||
| -- | unexpected file body
|
||||
HAS_FILE
|
||||
| -- | file IO error
|
||||
FILE_IO
|
||||
| -- | bad redirect data
|
||||
REDIRECT {redirectError :: String}
|
||||
| -- | internal server error
|
||||
INTERNAL
|
||||
| -- | used internally, never returned by the server (to be removed)
|
||||
DUPLICATE_ -- not part of SMP protocol, used internally
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
instance StrEncoding XFTPErrorType where
|
||||
strEncode = \case
|
||||
CMD e -> "CMD " <> bshow e
|
||||
REDIRECT e -> "REDIRECT " <> bshow e
|
||||
e -> bshow e
|
||||
strP =
|
||||
"CMD " *> (CMD <$> parseRead1)
|
||||
<|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString)
|
||||
<|> parseRead1
|
||||
|
||||
instance Encoding XFTPErrorType where
|
||||
smpEncode = \case
|
||||
BLOCK -> "BLOCK"
|
||||
SESSION -> "SESSION"
|
||||
CMD err -> "CMD " <> smpEncode err
|
||||
AUTH -> "AUTH"
|
||||
SIZE -> "SIZE"
|
||||
QUOTA -> "QUOTA"
|
||||
DIGEST -> "DIGEST"
|
||||
CRYPTO -> "CRYPTO"
|
||||
NO_FILE -> "NO_FILE"
|
||||
HAS_FILE -> "HAS_FILE"
|
||||
FILE_IO -> "FILE_IO"
|
||||
REDIRECT err -> "REDIRECT " <> smpEncode err
|
||||
INTERNAL -> "INTERNAL"
|
||||
DUPLICATE_ -> "DUPLICATE_"
|
||||
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"BLOCK" -> pure BLOCK
|
||||
"SESSION" -> pure SESSION
|
||||
"CMD" -> CMD <$> _smpP
|
||||
"AUTH" -> pure AUTH
|
||||
"SIZE" -> pure SIZE
|
||||
"QUOTA" -> pure QUOTA
|
||||
"DIGEST" -> pure DIGEST
|
||||
"CRYPTO" -> pure CRYPTO
|
||||
"NO_FILE" -> pure NO_FILE
|
||||
"HAS_FILE" -> pure HAS_FILE
|
||||
"FILE_IO" -> pure FILE_IO
|
||||
"REDIRECT" -> REDIRECT <$> _smpP
|
||||
"INTERNAL" -> pure INTERNAL
|
||||
"DUPLICATE_" -> pure DUPLICATE_
|
||||
_ -> fail "bad error type"
|
||||
|
||||
checkParty :: forall t p p'. (FilePartyI p, FilePartyI p') => t p' -> Either String (t p)
|
||||
checkParty c = case testEquality (sFileParty @p) (sFileParty @p') of
|
||||
Just Refl -> Right c
|
||||
@@ -405,12 +328,12 @@ checkParty' c = case testEquality (sFileParty @p) (sFileParty @p') of
|
||||
Just Refl -> Just c
|
||||
_ -> Nothing
|
||||
|
||||
xftpEncodeAuthTransmission :: ProtocolEncoding e c => THandleParams -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeAuthTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> C.APrivateAuthKey -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeAuthTransmission thParams pKey (corrId, fId, msg) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, fId, msg)
|
||||
xftpEncodeBatch1 . (,tToSend) =<< authTransmission Nothing (Just pKey) corrId tForAuth
|
||||
|
||||
xftpEncodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> Transmission c -> Either TransportError ByteString
|
||||
xftpEncodeTransmission thParams (corrId, fId, msg) = do
|
||||
let t = encodeTransmission thParams (corrId, fId, msg)
|
||||
xftpEncodeBatch1 (Nothing, t)
|
||||
@@ -419,7 +342,7 @@ xftpEncodeTransmission thParams (corrId, fId, msg) = do
|
||||
xftpEncodeBatch1 :: SentRawTransmission -> Either TransportError ByteString
|
||||
xftpEncodeBatch1 t = first (const TELargeMsg) $ C.pad (tEncodeBatch1 t) xftpBlockSize
|
||||
|
||||
xftpDecodeTransmission :: ProtocolEncoding e c => THandleParams -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
|
||||
xftpDecodeTransmission :: ProtocolEncoding XFTPVersion e c => THandleParams XFTPVersion -> ByteString -> Either XFTPErrorType (SignedTransmission e c)
|
||||
xftpDecodeTransmission thParams t = do
|
||||
t' <- first (const BLOCK) $ C.unPad t
|
||||
case tParse thParams t' of
|
||||
@@ -427,5 +350,3 @@ xftpDecodeTransmission thParams t = do
|
||||
_ -> Left BLOCK
|
||||
|
||||
$(J.deriveJSON (enumJSON $ dropPrefix "F") ''FileParty)
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
|
||||
|
||||
@@ -69,7 +69,7 @@ type M a = ReaderT XFTPEnv IO a
|
||||
|
||||
data XFTPTransportRequest =
|
||||
XFTPTransportRequest
|
||||
{ thParams :: THandleParams,
|
||||
{ thParams :: THandleParams XFTPVersion,
|
||||
reqBody :: HTTP2Body,
|
||||
request :: H.Request,
|
||||
sendResponse :: H.Response -> IO ()
|
||||
|
||||
@@ -28,7 +28,8 @@ import Data.Int (Int64)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPErrorType (..), XFTPFileId)
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), SFileParty (..), XFTPFileId)
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (RcvPublicAuthKey, RecipientId, SenderId)
|
||||
@@ -49,7 +50,6 @@ data FileRec = FileRec
|
||||
recipientIds :: TVar (Set RecipientId),
|
||||
createdAt :: SystemTime
|
||||
}
|
||||
deriving (Eq)
|
||||
|
||||
data FileRecipient = FileRecipient RecipientId RcvPublicAuthKey
|
||||
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Simplex.FileTransfer.Transport
|
||||
( supportedFileServerVRange,
|
||||
xftpClientHandshake, -- stub
|
||||
XFTPVersion,
|
||||
VersionXFTP,
|
||||
pattern VersionXFTP,
|
||||
XFTPErrorType (..),
|
||||
XFTPRcvChunkSpec (..),
|
||||
ReceiveFileError (..),
|
||||
receiveFile,
|
||||
@@ -14,22 +23,31 @@ module Simplex.FileTransfer.Transport
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import qualified Data.Aeson.TH as J
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Builder (Builder, byteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Word (Word32)
|
||||
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
|
||||
import Data.Word (Word16, Word32)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol (CommandError)
|
||||
import Simplex.Messaging.Transport (HandshakeError (..), THandle, TransportError (..))
|
||||
import Simplex.Messaging.Transport.HTTP2.File
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import System.IO (Handle, IOMode (..), withFile)
|
||||
|
||||
data XFTPRcvChunkSpec = XFTPRcvChunkSpec
|
||||
@@ -39,8 +57,26 @@ data XFTPRcvChunkSpec = XFTPRcvChunkSpec
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
supportedFileServerVRange :: VersionRange
|
||||
supportedFileServerVRange = mkVersionRange 1 1
|
||||
data XFTPVersion
|
||||
|
||||
instance VersionScope XFTPVersion
|
||||
|
||||
type VersionXFTP = Version XFTPVersion
|
||||
|
||||
type VersionRangeXFTP = VersionRange XFTPVersion
|
||||
|
||||
pattern VersionXFTP :: Word16 -> VersionXFTP
|
||||
pattern VersionXFTP v = Version v
|
||||
|
||||
initialXFTPVersion :: VersionXFTP
|
||||
initialXFTPVersion = VersionXFTP 1
|
||||
|
||||
supportedFileServerVRange :: VersionRangeXFTP
|
||||
supportedFileServerVRange = mkVersionRange initialXFTPVersion initialXFTPVersion
|
||||
|
||||
-- XFTP protocol does not support handshake
|
||||
xftpClientHandshake :: c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeXFTP -> ExceptT TransportError IO (THandle XFTPVersion c)
|
||||
xftpClientHandshake _c _ks _keyHash _xftpVRange = throwError $ TEHandshake VERSION
|
||||
|
||||
sendEncFile :: Handle -> (Builder -> IO ()) -> LC.SbState -> Word32 -> IO ()
|
||||
sendEncFile h send = go
|
||||
@@ -97,3 +133,81 @@ receiveFile_ receive XFTPRcvChunkSpec {filePath, chunkSize, chunkDigest} = do
|
||||
ExceptT $ withFile filePath WriteMode (`receive` chunkSize)
|
||||
digest' <- liftIO $ LC.sha256Hash <$> LB.readFile filePath
|
||||
when (digest' /= chunkDigest) $ throwError DIGEST
|
||||
|
||||
data XFTPErrorType
|
||||
= -- | incorrect block format, encoding or signature size
|
||||
BLOCK
|
||||
| -- | incorrect SMP session ID (TLS Finished message / tls-unique binding RFC5929)
|
||||
SESSION
|
||||
| -- | SMP command is unknown or has invalid syntax
|
||||
CMD {cmdErr :: CommandError}
|
||||
| -- | command authorization error - bad signature or non-existing SMP queue
|
||||
AUTH
|
||||
| -- | incorrent file size
|
||||
SIZE
|
||||
| -- | storage quota exceeded
|
||||
QUOTA
|
||||
| -- | incorrent file digest
|
||||
DIGEST
|
||||
| -- | file encryption/decryption failed
|
||||
CRYPTO
|
||||
| -- | no expected file body in request/response or no file on the server
|
||||
NO_FILE
|
||||
| -- | unexpected file body
|
||||
HAS_FILE
|
||||
| -- | file IO error
|
||||
FILE_IO
|
||||
| -- | bad redirect data
|
||||
REDIRECT {redirectError :: String}
|
||||
| -- | internal server error
|
||||
INTERNAL
|
||||
| -- | used internally, never returned by the server (to be removed)
|
||||
DUPLICATE_ -- not part of SMP protocol, used internally
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
instance StrEncoding XFTPErrorType where
|
||||
strEncode = \case
|
||||
CMD e -> "CMD " <> bshow e
|
||||
REDIRECT e -> "REDIRECT " <> bshow e
|
||||
e -> bshow e
|
||||
strP =
|
||||
"CMD " *> (CMD <$> parseRead1)
|
||||
<|> "REDIRECT " *> (REDIRECT <$> parseRead A.takeByteString)
|
||||
<|> parseRead1
|
||||
|
||||
instance Encoding XFTPErrorType where
|
||||
smpEncode = \case
|
||||
BLOCK -> "BLOCK"
|
||||
SESSION -> "SESSION"
|
||||
CMD err -> "CMD " <> smpEncode err
|
||||
AUTH -> "AUTH"
|
||||
SIZE -> "SIZE"
|
||||
QUOTA -> "QUOTA"
|
||||
DIGEST -> "DIGEST"
|
||||
CRYPTO -> "CRYPTO"
|
||||
NO_FILE -> "NO_FILE"
|
||||
HAS_FILE -> "HAS_FILE"
|
||||
FILE_IO -> "FILE_IO"
|
||||
REDIRECT err -> "REDIRECT " <> smpEncode err
|
||||
INTERNAL -> "INTERNAL"
|
||||
DUPLICATE_ -> "DUPLICATE_"
|
||||
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"BLOCK" -> pure BLOCK
|
||||
"SESSION" -> pure SESSION
|
||||
"CMD" -> CMD <$> _smpP
|
||||
"AUTH" -> pure AUTH
|
||||
"SIZE" -> pure SIZE
|
||||
"QUOTA" -> pure QUOTA
|
||||
"DIGEST" -> pure DIGEST
|
||||
"CRYPTO" -> pure CRYPTO
|
||||
"NO_FILE" -> pure NO_FILE
|
||||
"HAS_FILE" -> pure HAS_FILE
|
||||
"FILE_IO" -> pure FILE_IO
|
||||
"REDIRECT" -> REDIRECT <$> _smpP
|
||||
"INTERNAL" -> pure INTERNAL
|
||||
"DUPLICATE_" -> pure DUPLICATE_
|
||||
_ -> fail "bad error type"
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''XFTPErrorType)
|
||||
|
||||
@@ -55,7 +55,7 @@ data RcvFile = RcvFile
|
||||
status :: RcvFileStatus,
|
||||
deleted :: Bool
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data RcvFileStatus
|
||||
= RFSReceiving
|
||||
@@ -96,7 +96,7 @@ data RcvFileChunk = RcvFileChunk
|
||||
fileTmpPath :: FilePath,
|
||||
chunkTmpPath :: Maybe FilePath
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data RcvFileChunkReplica = RcvFileChunkReplica
|
||||
{ rcvChunkReplicaId :: Int64,
|
||||
@@ -107,14 +107,14 @@ data RcvFileChunkReplica = RcvFileChunkReplica
|
||||
delay :: Maybe Int64,
|
||||
retries :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data RcvFileRedirect = RcvFileRedirect
|
||||
{ redirectDbId :: DBRcvFileId,
|
||||
redirectEntityId :: RcvFileId,
|
||||
redirectFileInfo :: RedirectFileInfo
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
-- Sending files
|
||||
|
||||
@@ -135,7 +135,7 @@ data SndFile = SndFile
|
||||
deleted :: Bool,
|
||||
redirect :: Maybe RedirectFileInfo
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
sndFileEncPath :: FilePath -> FilePath
|
||||
sndFileEncPath prefixPath = prefixPath </> "xftp.encrypted"
|
||||
@@ -182,7 +182,7 @@ data SndFileChunk = SndFileChunk
|
||||
digest :: FileDigest,
|
||||
replicas :: [SndFileChunkReplica]
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
sndChunkSize :: SndFileChunk -> Word32
|
||||
sndChunkSize SndFileChunk {chunkSpec = XFTPChunkSpec {chunkSize}} = chunkSize
|
||||
@@ -193,7 +193,7 @@ data NewSndChunkReplica = NewSndChunkReplica
|
||||
replicaKey :: C.APrivateAuthKey,
|
||||
rcvIdsKeys :: [(ChunkReplicaId, C.APrivateAuthKey)]
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data SndFileChunkReplica = SndFileChunkReplica
|
||||
{ sndChunkReplicaId :: Int64,
|
||||
@@ -205,7 +205,7 @@ data SndFileChunkReplica = SndFileChunkReplica
|
||||
delay :: Maybe Int64,
|
||||
retries :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data SndFileReplicaStatus
|
||||
= SFRSCreated
|
||||
@@ -235,4 +235,4 @@ data DeletedSndChunkReplica = DeletedSndChunkReplica
|
||||
delay :: Maybe Int64,
|
||||
retries :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
+346
-252
File diff suppressed because it is too large
Load Diff
@@ -110,8 +110,6 @@ module Simplex.Messaging.Agent.Client
|
||||
whenSuspending,
|
||||
withStore,
|
||||
withStore',
|
||||
withStoreCtx,
|
||||
withStoreCtx',
|
||||
withStoreBatch,
|
||||
withStoreBatch',
|
||||
storeError,
|
||||
@@ -167,8 +165,8 @@ import Network.Socket (HostName)
|
||||
import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError)
|
||||
import qualified Simplex.FileTransfer.Client as X
|
||||
import Simplex.FileTransfer.Description (ChunkReplicaId (..), FileDigest (..), kb)
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse, XFTPErrorType (DIGEST))
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), FileResponse)
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (DIGEST), XFTPVersion)
|
||||
import Simplex.FileTransfer.Types (DeletedSndChunkReplica (..), NewSndChunkReplica (..), RcvFileChunkReplica (..), SndFileChunk (..), SndFileChunkReplica (..))
|
||||
import Simplex.FileTransfer.Util (uniqueCombine)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -187,6 +185,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion)
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, enumJSON, parse)
|
||||
import Simplex.Messaging.Protocol
|
||||
@@ -215,9 +214,12 @@ import Simplex.Messaging.Protocol
|
||||
UserProtocol,
|
||||
XFTPServer,
|
||||
XFTPServerWithAuth,
|
||||
VersionSMPC,
|
||||
VersionRangeSMPC,
|
||||
sameSrvAddr',
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (SMPVersion)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
@@ -253,7 +255,7 @@ data AgentClient = AgentClient
|
||||
{ active :: TVar Bool,
|
||||
rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
|
||||
smpServers :: TMap UserId (NonEmpty SMPServerWithAuth),
|
||||
smpClients :: TMap SMPTransportSession SMPClientVar,
|
||||
ntfServers :: TVar [NtfServer],
|
||||
@@ -467,7 +469,7 @@ agentClientStore AgentClient {agentEnv = Env {store}} = store
|
||||
agentDRG :: AgentClient -> TVar ChaChaDRG
|
||||
agentDRG AgentClient {agentEnv = Env {random}} = random
|
||||
|
||||
class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err where
|
||||
class (Encoding err, Show err) => ProtocolServerClient v err msg | msg -> v, msg -> err where
|
||||
type Client msg = c | c -> msg
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> TransportSession msg -> m (Client msg)
|
||||
clientProtocolError :: err -> AgentErrorType
|
||||
@@ -476,8 +478,8 @@ class (Encoding err, Show err) => ProtocolServerClient err msg | msg -> err wher
|
||||
clientTransportHost :: Client msg -> TransportHost
|
||||
clientSessionTs :: Client msg -> UTCTime
|
||||
|
||||
instance ProtocolServerClient ErrorType BrokerMsg where
|
||||
type Client BrokerMsg = ProtocolClient ErrorType BrokerMsg
|
||||
instance ProtocolServerClient SMPVersion ErrorType BrokerMsg where
|
||||
type Client BrokerMsg = ProtocolClient SMPVersion ErrorType BrokerMsg
|
||||
getProtocolServerClient = getSMPServerClient
|
||||
clientProtocolError = SMP
|
||||
closeProtocolServerClient = closeProtocolClient
|
||||
@@ -485,8 +487,8 @@ instance ProtocolServerClient ErrorType BrokerMsg where
|
||||
clientTransportHost = transportHost'
|
||||
clientSessionTs = sessionTs
|
||||
|
||||
instance ProtocolServerClient ErrorType NtfResponse where
|
||||
type Client NtfResponse = ProtocolClient ErrorType NtfResponse
|
||||
instance ProtocolServerClient NTFVersion ErrorType NtfResponse where
|
||||
type Client NtfResponse = ProtocolClient NTFVersion ErrorType NtfResponse
|
||||
getProtocolServerClient = getNtfServerClient
|
||||
clientProtocolError = NTF
|
||||
closeProtocolServerClient = closeProtocolClient
|
||||
@@ -494,7 +496,7 @@ instance ProtocolServerClient ErrorType NtfResponse where
|
||||
clientTransportHost = transportHost'
|
||||
clientSessionTs = sessionTs
|
||||
|
||||
instance ProtocolServerClient XFTPErrorType FileResponse where
|
||||
instance ProtocolServerClient XFTPVersion XFTPErrorType FileResponse where
|
||||
type Client FileResponse = XFTPClient
|
||||
getProtocolServerClient = getXFTPServerClient
|
||||
clientProtocolError = XFTP
|
||||
@@ -683,8 +685,8 @@ waitForProtocolClient c (_, srv, _) v = do
|
||||
|
||||
-- clientConnected arg is only passed for SMP server
|
||||
newProtocolClient ::
|
||||
forall err msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) =>
|
||||
forall v err msg m.
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) =>
|
||||
AgentClient ->
|
||||
TransportSession msg ->
|
||||
TMap (TransportSession msg) (ClientVar msg) ->
|
||||
@@ -706,10 +708,10 @@ newProtocolClient c tSess@(userId, srv, entityId_) clients connectClient v =
|
||||
putTMVar (sessionVar v) (Left e)
|
||||
throwError e -- signal error to caller
|
||||
|
||||
hostEvent :: forall err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
|
||||
hostEvent :: forall v err msg. (ProtocolTypeI (ProtoType msg), ProtocolServerClient v err msg) => (AProtocolType -> TransportHost -> ACommand 'Agent 'AENone) -> Client msg -> ACommand 'Agent 'AENone
|
||||
hostEvent event = event (AProtocolType $ protocolTypeI @(ProtoType msg)) . clientTransportHost
|
||||
|
||||
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig) -> m ProtocolClientConfig
|
||||
getClientConfig :: AgentMonad' m => AgentClient -> (AgentConfig -> ProtocolClientConfig v) -> m (ProtocolClientConfig v)
|
||||
getClientConfig AgentClient {useNetworkConfig} cfgSel = do
|
||||
cfg <- asks $ cfgSel . config
|
||||
networkConfig <- readTVarIO useNetworkConfig
|
||||
@@ -754,19 +756,19 @@ throwWhenNoDelivery c sq =
|
||||
unlessM (TM.member (qAddress sq) $ smpDeliveryWorkers c) $
|
||||
throwSTM ThreadKilled
|
||||
|
||||
closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
closeProtocolServerClients c clientsSel =
|
||||
atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . closeClient_ c)
|
||||
|
||||
reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
|
||||
reconnectServerClients c clientsSel =
|
||||
readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c)
|
||||
|
||||
closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
|
||||
closeClient :: ProtocolServerClient v err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO ()
|
||||
closeClient c clientSel tSess =
|
||||
atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c)
|
||||
|
||||
closeClient_ :: ProtocolServerClient err msg => AgentClient -> ClientVar msg -> IO ()
|
||||
closeClient_ :: ProtocolServerClient v err msg => AgentClient -> ClientVar msg -> IO ()
|
||||
closeClient_ c v = do
|
||||
NetworkConfig {tcpConnectTimeout} <- readTVarIO $ useNetworkConfig c
|
||||
tcpConnectTimeout `timeout` atomically (readTMVar $ sessionVar v) >>= \case
|
||||
@@ -798,7 +800,7 @@ getMapLock locks key = TM.lookup key locks >>= maybe newLock pure
|
||||
where
|
||||
newLock = createLock >>= \l -> TM.insert key l locks $> l
|
||||
|
||||
withClient_ :: forall a m err msg. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
|
||||
withClient_ :: forall a m v err msg. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> m a) -> m a
|
||||
withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
cl <- getProtocolServerClient c tSess
|
||||
(action cl <* stat cl "OK") `catchAgentError` logServerError cl
|
||||
@@ -810,18 +812,18 @@ withClient_ c tSess@(userId, srv, _) statCmd action = do
|
||||
stat cl $ strEncode e
|
||||
throwError e
|
||||
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> m a) -> m a
|
||||
withLogClient_ c tSess@(_, srv, _) entId cmdStr action = do
|
||||
logServer "-->" c srv entId cmdStr
|
||||
res <- withClient_ c tSess cmdStr action
|
||||
logServer "<--" c srv entId "OK"
|
||||
return res
|
||||
|
||||
withClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
|
||||
withClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withClient c tSess statKey action = withClient_ c tSess statKey $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
|
||||
|
||||
withLogClient :: forall m err msg a. (AgentMonad m, ProtocolServerClient err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @err @msg) (clientServer client) $ action client
|
||||
withLogClient :: forall m v err msg a. (AgentMonad m, ProtocolServerClient v err msg) => AgentClient -> TransportSession msg -> EntityId -> ByteString -> (Client msg -> ExceptT (ProtocolClientError err) IO a) -> m a
|
||||
withLogClient c tSess entId cmdStr action = withLogClient_ c tSess entId cmdStr $ \client -> liftClient (clientProtocolError @v @err @msg) (clientServer client) $ action client
|
||||
|
||||
withSMPClient :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMPClient c q cmdStr action = do
|
||||
@@ -837,7 +839,7 @@ withNtfClient :: forall m a. AgentMonad m => AgentClient -> NtfServer -> EntityI
|
||||
withNtfClient c srv = withLogClient c (0, srv, Nothing)
|
||||
|
||||
withXFTPClient ::
|
||||
(AgentMonad m, ProtocolServerClient err msg) =>
|
||||
(AgentMonad m, ProtocolServerClient v err msg) =>
|
||||
AgentClient ->
|
||||
(UserId, ProtoServer msg, EntityId) ->
|
||||
ByteString ->
|
||||
@@ -1001,7 +1003,7 @@ mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)
|
||||
getSessionMode :: AgentMonad' m => AgentClient -> m TransportSessionMode
|
||||
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRangeSMPC -> SubscriptionMode -> m (NewRcvQueue, SMPQueueUri)
|
||||
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange subMode = do
|
||||
C.AuthAlg a <- asks (rcvAuthAlg . config)
|
||||
g <- asks random
|
||||
@@ -1151,7 +1153,7 @@ sendConfirmation c sq@SndQueue {sndId, sndPublicKey = Just sndPublicKey, e2ePubK
|
||||
liftClient SMP (clientServer smp) $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> UserId -> Compatible SMPQueueInfo -> Compatible VersionSMPA -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c userId (Compatible (SMPQueueInfo v SMPQueueAddress {smpServer, senderId, dhPublicKey})) (Compatible agentVersion) connReq connInfo = do
|
||||
tSess <- mkTransportSession c userId smpServer senderId
|
||||
withLogClient_ c tSess senderId "SEND <INV>" $ \smp -> do
|
||||
@@ -1334,7 +1336,7 @@ agentCbEncrypt SndQueue {e2eDhSecret, smpClientVersion} e2ePubKey msg = do
|
||||
pure $ smpEncode SMP.ClientMsgEnvelope {cmHeader, cmNonce, cmEncBody}
|
||||
|
||||
-- add encoding as AgentInvitation'?
|
||||
agentCbEncryptOnce :: AgentMonad m => Version -> C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncryptOnce :: AgentMonad m => VersionSMPC -> C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncryptOnce clientVersion dhRcvPubKey msg = do
|
||||
g <- asks random
|
||||
(dhSndPubKey, dhSndPrivKey) <- atomically $ C.generateKeyPair g
|
||||
@@ -1453,34 +1455,13 @@ waitUntilForeground :: AgentClient -> STM ()
|
||||
waitUntilForeground c = unlessM ((ASForeground ==) <$> readTVar (agentState c)) retry
|
||||
|
||||
withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a
|
||||
withStore' = withStoreCtx_' Nothing
|
||||
withStore' c action = withStore c $ fmap Right . action
|
||||
|
||||
withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
|
||||
withStore = withStoreCtx_ Nothing
|
||||
|
||||
withStoreCtx' :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO a) -> m a
|
||||
withStoreCtx' = withStoreCtx_' . Just
|
||||
|
||||
withStoreCtx :: AgentMonad m => String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
|
||||
withStoreCtx = withStoreCtx_ . Just
|
||||
|
||||
withStoreCtx_' :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO a) -> m a
|
||||
withStoreCtx_' ctx_ c action = withStoreCtx_ ctx_ c $ fmap Right . action
|
||||
|
||||
withStoreCtx_ :: AgentMonad m => Maybe String -> AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
|
||||
withStoreCtx_ ctx_ c action = do
|
||||
withStore c action = do
|
||||
st <- asks store
|
||||
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $ case ctx_ of
|
||||
Nothing -> withTransaction st action `E.catch` handleInternal ""
|
||||
-- uncomment to debug store performance
|
||||
-- Just ctx -> do
|
||||
-- t1 <- liftIO getCurrentTime
|
||||
-- putStrLn $ "agent withStoreCtx start :: " <> show t1 <> " :: " <> ctx
|
||||
-- r <- withTransaction st action `E.catch` handleInternal (" (" <> ctx <> ")")
|
||||
-- t2 <- liftIO getCurrentTime
|
||||
-- putStrLn $ "agent withStoreCtx end :: " <> show t2 <> " :: " <> ctx <> " :: duration=" <> show (diffToMilliseconds $ diffUTCTime t2 t1)
|
||||
-- pure r
|
||||
Just _ -> withTransaction st action `E.catch` handleInternal ""
|
||||
liftEitherError storeError . agentOperationBracket c AODatabase (\_ -> pure ()) $
|
||||
withTransaction st action `E.catch` handleInternal ""
|
||||
where
|
||||
handleInternal :: String -> E.SomeException -> IO (Either StoreError a)
|
||||
handleInternal ctxStr e = pure . Left . SEInternal . B.pack $ show e <> ctxStr
|
||||
@@ -1518,7 +1499,7 @@ incStat AgentClient {agentStats} n k = do
|
||||
Just v -> modifyTVar' v (+ n)
|
||||
_ -> newTVar n >>= \v -> TM.insert k v agentStats
|
||||
|
||||
incClientStat :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> ByteString -> ByteString -> IO ()
|
||||
incClientStat c userId pc = incClientStatN c userId pc 1
|
||||
|
||||
incServerStat :: AgentClient -> UserId -> ProtocolServer p -> ByteString -> ByteString -> IO ()
|
||||
@@ -1528,7 +1509,7 @@ incServerStat c userId ProtocolServer {host} cmd res = do
|
||||
where
|
||||
statsKey = AgentStatsKey {userId, host = strEncode $ L.head host, clientTs = "", cmd, res}
|
||||
|
||||
incClientStatN :: ProtocolServerClient err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN :: ProtocolServerClient v err msg => AgentClient -> UserId -> Client msg -> Int -> ByteString -> ByteString -> IO ()
|
||||
incClientStatN c userId pc n cmd res = do
|
||||
atomically $ incStat c n statsKey
|
||||
where
|
||||
|
||||
@@ -56,16 +56,16 @@ import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (supportedE2EEncryptVRange)
|
||||
import Simplex.Messaging.Crypto.Ratchet (PQSupport, VersionRangeE2E, supportedE2EEncryptVRange)
|
||||
import Simplex.Messaging.Notifications.Client (defaultNTFClientConfig)
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion)
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Protocol (NtfServer, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
|
||||
import Simplex.Messaging.Protocol (NtfServer, VersionRangeSMPC, XFTPServer, XFTPServerWithAuth, supportedSMPClientVRange)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport (SMPVersion, TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.Client (defaultSMPPort)
|
||||
import Simplex.Messaging.Util (allFinally, catchAllErrors, tryAllErrors)
|
||||
import Simplex.Messaging.Version
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO (Async, SomeException)
|
||||
import UnliftIO.STM
|
||||
@@ -87,12 +87,13 @@ data AgentConfig = AgentConfig
|
||||
sndAuthAlg :: C.AuthAlg,
|
||||
connIdBytes :: Int,
|
||||
tbqSize :: Natural,
|
||||
smpCfg :: ProtocolClientConfig,
|
||||
ntfCfg :: ProtocolClientConfig,
|
||||
smpCfg :: ProtocolClientConfig SMPVersion,
|
||||
ntfCfg :: ProtocolClientConfig NTFVersion,
|
||||
xftpCfg :: XFTPClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
messageRetryInterval :: RetryInterval2,
|
||||
messageTimeout :: NominalDiffTime,
|
||||
connDeleteDeliveryTimeout :: NominalDiffTime,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
quotaExceededTimeout :: NominalDiffTime,
|
||||
initialCleanupDelay :: Int64,
|
||||
@@ -115,9 +116,9 @@ data AgentConfig = AgentConfig
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath,
|
||||
e2eEncryptVRange :: VersionRange,
|
||||
smpAgentVRange :: VersionRange,
|
||||
smpClientVRange :: VersionRange
|
||||
e2eEncryptVRange :: PQSupport -> VersionRangeE2E,
|
||||
smpAgentVRange :: PQSupport -> VersionRangeSMPA,
|
||||
smpClientVRange :: VersionRangeSMPC
|
||||
}
|
||||
|
||||
defaultReconnectInterval :: RetryInterval
|
||||
@@ -161,6 +162,7 @@ defaultAgentConfig =
|
||||
reconnectInterval = defaultReconnectInterval,
|
||||
messageRetryInterval = defaultMessageRetryInterval,
|
||||
messageTimeout = 2 * nominalDay,
|
||||
connDeleteDeliveryTimeout = 2 * nominalDay,
|
||||
helloTimeout = 2 * nominalDay,
|
||||
quotaExceededTimeout = 7 * nominalDay,
|
||||
initialCleanupDelay = 30 * 1000000, -- 30 seconds
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
@@ -33,8 +34,14 @@
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
|
||||
module Simplex.Messaging.Agent.Protocol
|
||||
( -- * Protocol parameters
|
||||
VersionSMPA,
|
||||
VersionRangeSMPA,
|
||||
pattern VersionSMPA,
|
||||
duplexHandshakeSMPAgentVersion,
|
||||
ratchetSyncSMPAgentVersion,
|
||||
deliveryRcptsSMPAgentVersion,
|
||||
pqdrSMPAgentVersion,
|
||||
currentSMPAgentVersion,
|
||||
supportedSMPAgentVRange,
|
||||
e2eEncConnInfoLength,
|
||||
e2eEncUserMsgLength,
|
||||
@@ -175,14 +182,25 @@ import Data.Time.Clock.System (SystemTime)
|
||||
import Data.Time.ISO8601
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable ()
|
||||
import Data.Word (Word32)
|
||||
import Data.Word (Word16, Word32)
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.ToField
|
||||
import Simplex.FileTransfer.Description
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType)
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType)
|
||||
import Simplex.Messaging.Agent.QueryString
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (E2ERatchetParams, E2ERatchetParamsUri)
|
||||
import Simplex.Messaging.Crypto.Ratchet
|
||||
( InitialKeys (..),
|
||||
PQEncryption (..),
|
||||
pattern PQEncOff,
|
||||
PQSupport,
|
||||
pattern PQSupportOn,
|
||||
pattern PQSupportOff,
|
||||
RcvE2ERatchetParams,
|
||||
RcvE2ERatchetParamsUri,
|
||||
SndE2ERatchetParams
|
||||
)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
@@ -200,6 +218,10 @@ import Simplex.Messaging.Protocol
|
||||
SMPServerWithAuth,
|
||||
SndPublicAuthKey,
|
||||
SubscriptionMode,
|
||||
SMPClientVersion,
|
||||
VersionSMPC,
|
||||
VersionRangeSMPC,
|
||||
initialSMPClientVersion,
|
||||
legacyEncodeServer,
|
||||
legacyServerP,
|
||||
legacyStrEncodeServer,
|
||||
@@ -215,6 +237,7 @@ import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTra
|
||||
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts_ (..))
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import Simplex.RemoteControl.Types
|
||||
import Text.Read
|
||||
import UnliftIO.Exception (Exception)
|
||||
@@ -224,30 +247,56 @@ import UnliftIO.Exception (Exception)
|
||||
-- 2 - "duplex" (more efficient) connection handshake (6/9/2022)
|
||||
-- 3 - support ratchet renegotiation (6/30/2023)
|
||||
-- 4 - delivery receipts (7/13/2023)
|
||||
-- 5 - post-quantum double ratchet (3/14/2024)
|
||||
|
||||
duplexHandshakeSMPAgentVersion :: Version
|
||||
duplexHandshakeSMPAgentVersion = 2
|
||||
data SMPAgentVersion
|
||||
|
||||
ratchetSyncSMPAgentVersion :: Version
|
||||
ratchetSyncSMPAgentVersion = 3
|
||||
instance VersionScope SMPAgentVersion
|
||||
|
||||
deliveryRcptsSMPAgentVersion :: Version
|
||||
deliveryRcptsSMPAgentVersion = 4
|
||||
type VersionSMPA = Version SMPAgentVersion
|
||||
|
||||
currentSMPAgentVersion :: Version
|
||||
currentSMPAgentVersion = 4
|
||||
type VersionRangeSMPA = VersionRange SMPAgentVersion
|
||||
|
||||
supportedSMPAgentVRange :: VersionRange
|
||||
supportedSMPAgentVRange = mkVersionRange duplexHandshakeSMPAgentVersion currentSMPAgentVersion
|
||||
pattern VersionSMPA :: Word16 -> VersionSMPA
|
||||
pattern VersionSMPA v = Version v
|
||||
|
||||
duplexHandshakeSMPAgentVersion :: VersionSMPA
|
||||
duplexHandshakeSMPAgentVersion = VersionSMPA 2
|
||||
|
||||
ratchetSyncSMPAgentVersion :: VersionSMPA
|
||||
ratchetSyncSMPAgentVersion = VersionSMPA 3
|
||||
|
||||
deliveryRcptsSMPAgentVersion :: VersionSMPA
|
||||
deliveryRcptsSMPAgentVersion = VersionSMPA 4
|
||||
|
||||
pqdrSMPAgentVersion :: VersionSMPA
|
||||
pqdrSMPAgentVersion = VersionSMPA 5
|
||||
|
||||
-- TODO v5.7 increase to 5
|
||||
currentSMPAgentVersion :: VersionSMPA
|
||||
currentSMPAgentVersion = VersionSMPA 4
|
||||
|
||||
-- TODO v5.7 remove dependency of version range on whether PQ support is needed
|
||||
supportedSMPAgentVRange :: PQSupport -> VersionRangeSMPA
|
||||
supportedSMPAgentVRange pq =
|
||||
mkVersionRange duplexHandshakeSMPAgentVersion $ case pq of
|
||||
PQSupportOn -> pqdrSMPAgentVersion
|
||||
PQSupportOff -> currentSMPAgentVersion
|
||||
|
||||
-- it is shorter to allow all handshake headers,
|
||||
-- including E2E (double-ratchet) parameters and
|
||||
-- signing key of the sender for the server
|
||||
e2eEncConnInfoLength :: Int
|
||||
e2eEncConnInfoLength = 14848
|
||||
e2eEncConnInfoLength :: VersionSMPA -> PQSupport -> Int
|
||||
e2eEncConnInfoLength v = \case
|
||||
-- reduced by 3726 (roughly the increase of message ratchet header size + key and ciphertext in reply link)
|
||||
PQSupportOn | v >= pqdrSMPAgentVersion -> 11122
|
||||
_ -> 14848
|
||||
|
||||
e2eEncUserMsgLength :: Int
|
||||
e2eEncUserMsgLength = 15856
|
||||
e2eEncUserMsgLength :: VersionSMPA -> PQSupport -> Int
|
||||
e2eEncUserMsgLength v = \case
|
||||
-- reduced by 2222 (the increase of message ratchet header size)
|
||||
PQSupportOn | v >= pqdrSMPAgentVersion -> 13634
|
||||
_ -> 15856
|
||||
|
||||
-- | Raw (unparsed) SMP agent protocol transmission.
|
||||
type ARawTransmission = (ByteString, ByteString, ByteString)
|
||||
@@ -273,8 +322,6 @@ data SAParty :: AParty -> Type where
|
||||
|
||||
deriving instance Show (SAParty p)
|
||||
|
||||
deriving instance Eq (SAParty p)
|
||||
|
||||
instance TestEquality SAParty where
|
||||
testEquality SAgent SAgent = Just Refl
|
||||
testEquality SClient SClient = Just Refl
|
||||
@@ -297,8 +344,6 @@ data SAEntity :: AEntity -> Type where
|
||||
|
||||
deriving instance Show (SAEntity e)
|
||||
|
||||
deriving instance Eq (SAEntity e)
|
||||
|
||||
instance TestEquality SAEntity where
|
||||
testEquality SAEConn SAEConn = Just Refl
|
||||
testEquality SAERcvFile SAERcvFile = Just Refl
|
||||
@@ -333,16 +378,16 @@ type ConnInfo = ByteString
|
||||
|
||||
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
|
||||
data ACommand (p :: AParty) (e :: AEntity) where
|
||||
NEW :: Bool -> AConnectionMode -> SubscriptionMode -> ACommand Client AEConn -- response INV
|
||||
NEW :: Bool -> AConnectionMode -> InitialKeys -> SubscriptionMode -> ACommand Client AEConn -- response INV
|
||||
INV :: AConnectionRequestUri -> ACommand Agent AEConn
|
||||
JOIN :: Bool -> AConnectionRequestUri -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK
|
||||
CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
|
||||
JOIN :: Bool -> AConnectionRequestUri -> PQSupport -> SubscriptionMode -> ConnInfo -> ACommand Client AEConn -- response OK
|
||||
CONF :: ConfirmationId -> PQSupport -> [SMPServer] -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
|
||||
LET :: ConfirmationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
|
||||
REQ :: InvitationId -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender
|
||||
ACPT :: InvitationId -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
|
||||
REQ :: InvitationId -> PQSupport -> NonEmpty SMPServer -> ConnInfo -> ACommand Agent AEConn -- ConnInfo is from sender
|
||||
ACPT :: InvitationId -> PQSupport -> ConnInfo -> ACommand Client AEConn -- ConnInfo is from client
|
||||
RJCT :: InvitationId -> ACommand Client AEConn
|
||||
INFO :: ConnInfo -> ACommand Agent AEConn
|
||||
CON :: ACommand Agent AEConn -- notification that connection is established
|
||||
INFO :: PQSupport -> ConnInfo -> ACommand Agent AEConn
|
||||
CON :: PQEncryption -> ACommand Agent AEConn -- notification that connection is established
|
||||
SUB :: ACommand Client AEConn
|
||||
END :: ACommand Agent AEConn
|
||||
CONNECT :: AProtocolType -> TransportHost -> ACommand Agent AENone
|
||||
@@ -351,8 +396,8 @@ data ACommand (p :: AParty) (e :: AEntity) where
|
||||
UP :: SMPServer -> [ConnId] -> ACommand Agent AENone
|
||||
SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> ACommand Agent AEConn
|
||||
RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> ACommand Agent AEConn
|
||||
SEND :: MsgFlags -> MsgBody -> ACommand Client AEConn
|
||||
MID :: AgentMsgId -> ACommand Agent AEConn
|
||||
SEND :: PQEncryption -> MsgFlags -> MsgBody -> ACommand Client AEConn
|
||||
MID :: AgentMsgId -> PQEncryption -> ACommand Agent AEConn
|
||||
SENT :: AgentMsgId -> ACommand Agent AEConn
|
||||
MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent AEConn
|
||||
MERRS :: NonEmpty AgentMsgId -> AgentErrorType -> ACommand Agent AEConn
|
||||
@@ -458,8 +503,8 @@ aCommandTag = \case
|
||||
REQ {} -> REQ_
|
||||
ACPT {} -> ACPT_
|
||||
RJCT _ -> RJCT_
|
||||
INFO _ -> INFO_
|
||||
CON -> CON_
|
||||
INFO {} -> INFO_
|
||||
CON _ -> CON_
|
||||
SUB -> SUB_
|
||||
END -> END_
|
||||
CONNECT {} -> CONNECT_
|
||||
@@ -469,7 +514,7 @@ aCommandTag = \case
|
||||
SWITCH {} -> SWITCH_
|
||||
RSYNC {} -> RSYNC_
|
||||
SEND {} -> SEND_
|
||||
MID _ -> MID_
|
||||
MID {} -> MID_
|
||||
SENT _ -> SENT_
|
||||
MERR {} -> MERR_
|
||||
MERRS {} -> MERRS_
|
||||
@@ -665,7 +710,7 @@ instance StrEncoding SndQueueInfo where
|
||||
pure SndQueueInfo {sndServer, sndSwitchStatus}
|
||||
|
||||
data ConnectionStats = ConnectionStats
|
||||
{ connAgentVersion :: Version,
|
||||
{ connAgentVersion :: VersionSMPA,
|
||||
rcvQueuesInfo :: [RcvQueueInfo],
|
||||
sndQueuesInfo :: [SndQueueInfo],
|
||||
ratchetSyncState :: RatchetSyncState,
|
||||
@@ -769,17 +814,19 @@ data MsgMeta = MsgMeta
|
||||
{ integrity :: MsgIntegrity,
|
||||
recipient :: (AgentMsgId, UTCTime),
|
||||
broker :: (MsgId, UTCTime),
|
||||
sndMsgId :: AgentMsgId
|
||||
sndMsgId :: AgentMsgId,
|
||||
pqEncryption :: PQEncryption
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding MsgMeta where
|
||||
strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} =
|
||||
strEncode MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId, pqEncryption} =
|
||||
B.unwords
|
||||
[ strEncode integrity,
|
||||
"R=" <> bshow rmId <> "," <> showTs rTs,
|
||||
"B=" <> encode bmId <> "," <> showTs bTs,
|
||||
"S=" <> bshow sndMsgId
|
||||
"S=" <> bshow sndMsgId,
|
||||
"PQ=" <> strEncode pqEncryption
|
||||
]
|
||||
where
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
@@ -788,7 +835,8 @@ instance StrEncoding MsgMeta where
|
||||
recipient <- " R=" *> partyMeta A.decimal
|
||||
broker <- " B=" *> partyMeta base64P
|
||||
sndMsgId <- " S=" *> A.decimal
|
||||
pure MsgMeta {integrity, recipient, broker, sndMsgId}
|
||||
pqEncryption <- " PQ=" *> strP
|
||||
pure MsgMeta {integrity, recipient, broker, sndMsgId, pqEncryption}
|
||||
where
|
||||
partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P
|
||||
|
||||
@@ -802,28 +850,28 @@ data SMPConfirmation = SMPConfirmation
|
||||
-- | optional reply queues included in confirmation (added in agent protocol v2)
|
||||
smpReplyQueues :: [SMPQueueInfo],
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version
|
||||
smpClientVersion :: VersionSMPC
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data AgentMsgEnvelope
|
||||
= AgentConfirmation
|
||||
{ agentVersion :: Version,
|
||||
e2eEncryption_ :: Maybe (E2ERatchetParams 'C.X448),
|
||||
{ agentVersion :: VersionSMPA,
|
||||
e2eEncryption_ :: Maybe (SndE2ERatchetParams 'C.X448),
|
||||
encConnInfo :: ByteString
|
||||
}
|
||||
| AgentMsgEnvelope
|
||||
{ agentVersion :: Version,
|
||||
{ agentVersion :: VersionSMPA,
|
||||
encAgentMessage :: ByteString
|
||||
}
|
||||
| AgentInvitation -- the connInfo in contactInvite is only encrypted with per-queue E2E, not with double ratchet,
|
||||
{ agentVersion :: Version,
|
||||
{ agentVersion :: VersionSMPA,
|
||||
connReq :: ConnectionRequestUri 'CMInvitation,
|
||||
connInfo :: ByteString -- this message is only encrypted with per-queue E2E, not with double ratchet,
|
||||
}
|
||||
| AgentRatchetKey
|
||||
{ agentVersion :: Version,
|
||||
e2eEncryption :: E2ERatchetParams 'C.X448,
|
||||
{ agentVersion :: VersionSMPA,
|
||||
e2eEncryption :: RcvE2ERatchetParams 'C.X448,
|
||||
info :: ByteString
|
||||
}
|
||||
deriving (Show)
|
||||
@@ -1115,7 +1163,7 @@ instance forall m. ConnectionModeI m => StrEncoding (ConnectionRequestUri m) whe
|
||||
CRInvitationUri crData e2eParams -> crEncode "invitation" crData (Just e2eParams)
|
||||
CRContactUri crData -> crEncode "contact" crData Nothing
|
||||
where
|
||||
crEncode :: ByteString -> ConnReqUriData -> Maybe (E2ERatchetParamsUri 'C.X448) -> ByteString
|
||||
crEncode :: ByteString -> ConnReqUriData -> Maybe (RcvE2ERatchetParamsUri 'C.X448) -> ByteString
|
||||
crEncode crMode ConnReqUriData {crScheme, crAgentVRange, crSmpQueues, crClientData} e2eParams =
|
||||
strEncode crScheme <> "/" <> crMode <> "#/?" <> queryStr
|
||||
where
|
||||
@@ -1228,16 +1276,16 @@ sameQueue :: SMPQueue q => (SMPServer, SMP.QueueId) -> q -> Bool
|
||||
sameQueue addr q = sameQAddress addr (qAddress q)
|
||||
{-# INLINE sameQueue #-}
|
||||
|
||||
data SMPQueueInfo = SMPQueueInfo {clientVersion :: Version, queueAddress :: SMPQueueAddress}
|
||||
data SMPQueueInfo = SMPQueueInfo {clientVersion :: VersionSMPC, queueAddress :: SMPQueueAddress}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding SMPQueueInfo where
|
||||
smpEncode (SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey})
|
||||
| clientVersion > 1 = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
|
||||
| clientVersion > initialSMPClientVersion = smpEncode (clientVersion, smpServer, senderId, dhPublicKey)
|
||||
| otherwise = smpEncode clientVersion <> legacyEncodeServer smpServer <> smpEncode (senderId, dhPublicKey)
|
||||
smpP = do
|
||||
clientVersion <- smpP
|
||||
smpServer <- if clientVersion > 1 then smpP else updateSMPServerHosts <$> legacyServerP
|
||||
smpServer <- if clientVersion > initialSMPClientVersion then smpP else updateSMPServerHosts <$> legacyServerP
|
||||
(senderId, dhPublicKey) <- smpP
|
||||
pure $ SMPQueueInfo clientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey}
|
||||
|
||||
@@ -1245,20 +1293,20 @@ instance Encoding SMPQueueInfo where
|
||||
-- But this is created to allow backward and forward compatibility where SMPQueueUri
|
||||
-- could have more fields to convert to different versions of SMPQueueInfo in a different way,
|
||||
-- and this instance would become non-trivial.
|
||||
instance VersionI SMPQueueInfo where
|
||||
type VersionRangeT SMPQueueInfo = SMPQueueUri
|
||||
instance VersionI SMPClientVersion SMPQueueInfo where
|
||||
type VersionRangeT SMPClientVersion SMPQueueInfo = SMPQueueUri
|
||||
version = clientVersion
|
||||
toVersionRangeT (SMPQueueInfo _v addr) vr = SMPQueueUri vr addr
|
||||
|
||||
instance VersionRangeI SMPQueueUri where
|
||||
type VersionT SMPQueueUri = SMPQueueInfo
|
||||
instance VersionRangeI SMPClientVersion SMPQueueUri where
|
||||
type VersionT SMPClientVersion SMPQueueUri = SMPQueueInfo
|
||||
versionRange = clientVRange
|
||||
toVersionT (SMPQueueUri _vr addr) v = SMPQueueInfo v addr
|
||||
|
||||
-- | SMP queue information sent out-of-band.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#out-of-band-messages
|
||||
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRange, queueAddress :: SMPQueueAddress}
|
||||
data SMPQueueUri = SMPQueueUri {clientVRange :: VersionRangeSMPC, queueAddress :: SMPQueueAddress}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data SMPQueueAddress = SMPQueueAddress
|
||||
@@ -1307,7 +1355,7 @@ instance StrEncoding SMPQueueUri where
|
||||
smpServer = if maxVersion vr < srvHostnamesSMPClientVersion then updateSMPServerHosts srv' else srv'
|
||||
pure $ SMPQueueUri vr SMPQueueAddress {smpServer, senderId, dhPublicKey}
|
||||
where
|
||||
unversioned = (versionToRange 1,[],) <$> strP <* A.endOfInput
|
||||
unversioned = (versionToRange initialSMPClientVersion,[],) <$> strP <* A.endOfInput
|
||||
versioned = do
|
||||
dhKey_ <- optional strP
|
||||
query <- optional (A.char '/') *> A.char '?' *> strP
|
||||
@@ -1324,8 +1372,8 @@ instance Encoding SMPQueueUri where
|
||||
pure $ SMPQueueUri clientVRange SMPQueueAddress {smpServer, senderId, dhPublicKey}
|
||||
|
||||
data ConnectionRequestUri (m :: ConnectionMode) where
|
||||
CRInvitationUri :: ConnReqUriData -> E2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
|
||||
-- contact connection request does NOT contain E2E encryption parameters -
|
||||
CRInvitationUri :: ConnReqUriData -> RcvE2ERatchetParamsUri 'C.X448 -> ConnectionRequestUri CMInvitation
|
||||
-- contact connection request does NOT contain E2E encryption parameters for double ratchet -
|
||||
-- they are passed in AgentInvitation message
|
||||
CRContactUri :: ConnReqUriData -> ConnectionRequestUri CMContact
|
||||
|
||||
@@ -1336,15 +1384,15 @@ deriving instance Show (ConnectionRequestUri m)
|
||||
data AConnectionRequestUri = forall m. ConnectionModeI m => ACR (SConnectionMode m) (ConnectionRequestUri m)
|
||||
|
||||
instance Eq AConnectionRequestUri where
|
||||
ACR m cr == ACR m' cr' = case testEquality m m' of
|
||||
Just Refl -> cr == cr'
|
||||
_ -> False
|
||||
ACR m cr == ACR m' cr' = case testEquality m m' of
|
||||
Just Refl -> cr == cr'
|
||||
_ -> False
|
||||
|
||||
deriving instance Show AConnectionRequestUri
|
||||
|
||||
data ConnReqUriData = ConnReqUriData
|
||||
{ crScheme :: ServiceScheme,
|
||||
crAgentVRange :: VersionRange,
|
||||
crAgentVRange :: VersionRangeSMPA,
|
||||
crSmpQueues :: NonEmpty SMPQueueUri,
|
||||
crClientData :: Maybe CRClientData
|
||||
}
|
||||
@@ -1713,13 +1761,13 @@ commandP binaryP =
|
||||
>>= \case
|
||||
ACmdTag SClient e cmd ->
|
||||
ACmd SClient e <$> case cmd of
|
||||
NEW_ -> s (NEW <$> strP_ <*> strP_ <*> (strP <|> pure SMP.SMSubscribe))
|
||||
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP)
|
||||
NEW_ -> s (NEW <$> strP_ <*> strP_ <*> pqIKP <*> (strP <|> pure SMP.SMSubscribe))
|
||||
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> pqSupP <*> (strP_ <|> pure SMP.SMSubscribe) <*> binaryP)
|
||||
LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP)
|
||||
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP)
|
||||
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> binaryP)
|
||||
RJCT_ -> s (RJCT <$> A.takeByteString)
|
||||
SUB_ -> pure SUB
|
||||
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
|
||||
SEND_ -> s (SEND <$> pqEncP <*> smpP <* A.space <*> binaryP)
|
||||
ACK_ -> s (ACK <$> A.decimal <*> optional (A.space *> binaryP))
|
||||
SWCH_ -> pure SWCH
|
||||
OFF_ -> pure OFF
|
||||
@@ -1728,10 +1776,10 @@ commandP binaryP =
|
||||
ACmdTag SAgent e cmd ->
|
||||
ACmd SAgent e <$> case cmd of
|
||||
INV_ -> s (INV <$> strP)
|
||||
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP)
|
||||
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP)
|
||||
INFO_ -> s (INFO <$> binaryP)
|
||||
CON_ -> pure CON
|
||||
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strListP <* A.space <*> binaryP)
|
||||
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> pqSupP <*> strP_ <*> binaryP)
|
||||
INFO_ -> s (INFO <$> pqSupP <*> binaryP)
|
||||
CON_ -> s (CON <$> strP)
|
||||
END_ -> pure END
|
||||
CONNECT_ -> s (CONNECT <$> strP_ <*> strP)
|
||||
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
|
||||
@@ -1739,7 +1787,7 @@ commandP binaryP =
|
||||
UP_ -> s (UP <$> strP_ <*> connections)
|
||||
SWITCH_ -> s (SWITCH <$> strP_ <*> strP_ <*> strP)
|
||||
RSYNC_ -> s (RSYNC <$> strP_ <*> strP <*> strP)
|
||||
MID_ -> s (MID <$> A.decimal)
|
||||
MID_ -> s (MID <$> A.decimal <*> _strP)
|
||||
SENT_ -> s (SENT <$> A.decimal)
|
||||
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
|
||||
MERRS_ -> s (MERRS <$> strP_ <*> strP)
|
||||
@@ -1762,6 +1810,12 @@ commandP binaryP =
|
||||
where
|
||||
s :: Parser a -> Parser a
|
||||
s p = A.space *> p
|
||||
pqIKP :: Parser InitialKeys
|
||||
pqIKP = strP_ <|> pure (IKNoPQ PQSupportOff)
|
||||
pqSupP :: Parser PQSupport
|
||||
pqSupP = strP_ <|> pure PQSupportOff
|
||||
pqEncP :: Parser PQEncryption
|
||||
pqEncP = strP_ <|> pure PQEncOff
|
||||
connections :: Parser [ConnId]
|
||||
connections = strP `A.sepBy'` A.char ','
|
||||
sfDone :: Text -> Either String (ACommand 'Agent 'AESndFile)
|
||||
@@ -1777,15 +1831,15 @@ parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
|
||||
-- | Serialize SMP agent command.
|
||||
serializeCommand :: ACommand p e -> ByteString
|
||||
serializeCommand = \case
|
||||
NEW ntfs cMode subMode -> s (NEW_, ntfs, cMode, subMode)
|
||||
NEW ntfs cMode pqIK subMode -> s (NEW_, ntfs, cMode, pqIK, subMode)
|
||||
INV cReq -> s (INV_, cReq)
|
||||
JOIN ntfs cReq subMode cInfo -> s (JOIN_, ntfs, cReq, subMode, Str $ serializeBinary cInfo)
|
||||
CONF confId srvs cInfo -> B.unwords [s CONF_, confId, strEncodeList srvs, serializeBinary cInfo]
|
||||
JOIN ntfs cReq pqSup subMode cInfo -> s (JOIN_, ntfs, cReq, pqSup, subMode, Str $ serializeBinary cInfo)
|
||||
CONF confId pqSup srvs cInfo -> B.unwords [s CONF_, confId, s pqSup, strEncodeList srvs, serializeBinary cInfo]
|
||||
LET confId cInfo -> B.unwords [s LET_, confId, serializeBinary cInfo]
|
||||
REQ invId srvs cInfo -> B.unwords [s REQ_, invId, s srvs, serializeBinary cInfo]
|
||||
ACPT invId cInfo -> B.unwords [s ACPT_, invId, serializeBinary cInfo]
|
||||
REQ invId pqSup srvs cInfo -> B.unwords [s REQ_, invId, s pqSup, s srvs, serializeBinary cInfo]
|
||||
ACPT invId pqSup cInfo -> B.unwords [s ACPT_, invId, s pqSup, serializeBinary cInfo]
|
||||
RJCT invId -> B.unwords [s RJCT_, invId]
|
||||
INFO cInfo -> B.unwords [s INFO_, serializeBinary cInfo]
|
||||
INFO pqSup cInfo -> B.unwords [s INFO_, s pqSup, serializeBinary cInfo]
|
||||
SUB -> s SUB_
|
||||
END -> s END_
|
||||
CONNECT p h -> s (CONNECT_, p, h)
|
||||
@@ -1794,8 +1848,8 @@ serializeCommand = \case
|
||||
UP srv conns -> B.unwords [s UP_, s srv, connections conns]
|
||||
SWITCH dir phase srvs -> s (SWITCH_, dir, phase, srvs)
|
||||
RSYNC rrState cryptoErr cstats -> s (RSYNC_, rrState, cryptoErr, cstats)
|
||||
SEND msgFlags msgBody -> B.unwords [s SEND_, smpEncode msgFlags, serializeBinary msgBody]
|
||||
MID mId -> s (MID_, mId)
|
||||
SEND pqEnc msgFlags msgBody -> B.unwords [s SEND_, s pqEnc, smpEncode msgFlags, serializeBinary msgBody]
|
||||
MID mId pqEnc -> s (MID_, mId, pqEnc)
|
||||
SENT mId -> s (SENT_, mId)
|
||||
MERR mId e -> s (MERR_, mId, e)
|
||||
MERRS mIds e -> s (MERRS_, mIds, e)
|
||||
@@ -1811,7 +1865,7 @@ serializeCommand = \case
|
||||
DEL_USER userId -> s (DEL_USER_, userId)
|
||||
CHK -> s CHK_
|
||||
STAT srvs -> s (STAT_, srvs)
|
||||
CON -> s CON_
|
||||
CON pqEnc -> s (CON_, pqEnc)
|
||||
ERR e -> s (ERR_, e)
|
||||
OK -> s OK_
|
||||
SUSPENDED -> s SUSPENDED_
|
||||
@@ -1884,14 +1938,14 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p))
|
||||
cmdWithMsgBody (APC e cmd) =
|
||||
APC e <$$> case cmd of
|
||||
SEND msgFlags body -> SEND msgFlags <$$> getBody body
|
||||
SEND pqEnc msgFlags body -> SEND pqEnc msgFlags <$$> getBody body
|
||||
MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body
|
||||
JOIN ntfs qUri subMode cInfo -> JOIN ntfs qUri subMode <$$> getBody cInfo
|
||||
CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo
|
||||
JOIN ntfs qUri pqSup subMode cInfo -> JOIN ntfs qUri pqSup subMode <$$> getBody cInfo
|
||||
CONF confId pqSup srvs cInfo -> CONF confId pqSup srvs <$$> getBody cInfo
|
||||
LET confId cInfo -> LET confId <$$> getBody cInfo
|
||||
REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo
|
||||
ACPT invId cInfo -> ACPT invId <$$> getBody cInfo
|
||||
INFO cInfo -> INFO <$$> getBody cInfo
|
||||
REQ invId pqSup srvs cInfo -> REQ invId pqSup srvs <$$> getBody cInfo
|
||||
ACPT invId pqSup cInfo -> ACPT invId pqSup <$$> getBody cInfo
|
||||
INFO pqSup cInfo -> INFO pqSup <$$> getBody cInfo
|
||||
_ -> pure $ Right cmd
|
||||
|
||||
getBody :: ByteString -> m (Either AgentErrorType ByteString)
|
||||
|
||||
@@ -30,7 +30,7 @@ import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval (RI2State)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448)
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, PQEncryption, PQSupport)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
( MsgBody,
|
||||
@@ -44,10 +44,10 @@ import Simplex.Messaging.Protocol
|
||||
RcvPrivateAuthKey,
|
||||
SndPrivateAuthKey,
|
||||
SndPublicAuthKey,
|
||||
VersionSMPC,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
|
||||
-- * Queue types
|
||||
|
||||
@@ -61,8 +61,6 @@ data DBQueueId (q :: QueueStored) where
|
||||
DBQueueId :: Int64 -> DBQueueId 'QSStored
|
||||
DBNewQueue :: DBQueueId 'QSNew
|
||||
|
||||
deriving instance Eq (DBQueueId q)
|
||||
|
||||
deriving instance Show (DBQueueId q)
|
||||
|
||||
type RcvQueue = StoredRcvQueue 'QSStored
|
||||
@@ -96,12 +94,12 @@ data StoredRcvQueue (q :: QueueStored) = RcvQueue
|
||||
dbReplaceQueueId :: Maybe Int64,
|
||||
rcvSwchStatus :: Maybe RcvSwitchStatus,
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version,
|
||||
smpClientVersion :: VersionSMPC,
|
||||
-- | credentials used in context of notifications
|
||||
clientNtfCreds :: Maybe ClientNtfCreds,
|
||||
deleteErrors :: Int
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
rcvQueueInfo :: RcvQueue -> RcvQueueInfo
|
||||
rcvQueueInfo rq@RcvQueue {server, rcvSwchStatus} =
|
||||
@@ -128,7 +126,7 @@ data ClientNtfCreds = ClientNtfCreds
|
||||
-- | shared DH secret used to encrypt/decrypt notification metadata (NMsgMeta) from server to recipient
|
||||
rcvNtfDhSecret :: RcvNtfDhSecret
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
type SndQueue = StoredSndQueue 'QSStored
|
||||
|
||||
@@ -159,9 +157,9 @@ data StoredSndQueue (q :: QueueStored) = SndQueue
|
||||
dbReplaceQueueId :: Maybe Int64,
|
||||
sndSwchStatus :: Maybe SndSwitchStatus,
|
||||
-- | SMP client version
|
||||
smpClientVersion :: Version
|
||||
smpClientVersion :: VersionSMPC
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
sndQueueInfo :: SndQueue -> SndQueueInfo
|
||||
sndQueueInfo SndQueue {server, sndSwchStatus} =
|
||||
@@ -256,8 +254,6 @@ data Connection (d :: ConnType) where
|
||||
DuplexConnection :: ConnData -> NonEmpty RcvQueue -> NonEmpty SndQueue -> Connection CDuplex
|
||||
ContactConnection :: ConnData -> RcvQueue -> Connection CContact
|
||||
|
||||
deriving instance Eq (Connection d)
|
||||
|
||||
deriving instance Show (Connection d)
|
||||
|
||||
toConnData :: Connection d -> ConnData
|
||||
@@ -290,8 +286,6 @@ connType SCSnd = CSnd
|
||||
connType SCDuplex = CDuplex
|
||||
connType SCContact = CContact
|
||||
|
||||
deriving instance Eq (SConnType d)
|
||||
|
||||
deriving instance Show (SConnType d)
|
||||
|
||||
instance TestEquality SConnType where
|
||||
@@ -305,21 +299,17 @@ instance TestEquality SConnType where
|
||||
-- Used to refer to an arbitrary connection when retrieving from store.
|
||||
data SomeConn = forall d. SomeConn (SConnType d) (Connection d)
|
||||
|
||||
instance Eq SomeConn where
|
||||
SomeConn d c == SomeConn d' c' = case testEquality d d' of
|
||||
Just Refl -> c == c'
|
||||
_ -> False
|
||||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
data ConnData = ConnData
|
||||
{ connId :: ConnId,
|
||||
userId :: UserId,
|
||||
connAgentVersion :: Version,
|
||||
connAgentVersion :: VersionSMPA,
|
||||
enableNtfs :: Bool,
|
||||
lastExternalSndId :: PrevExternalSndId,
|
||||
deleted :: Bool,
|
||||
ratchetSyncState :: RatchetSyncState
|
||||
ratchetSyncState :: RatchetSyncState,
|
||||
pqSupport :: PQSupport
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -534,6 +524,7 @@ data SndMsgData = SndMsgData
|
||||
msgType :: AgentMessageType,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: MsgBody,
|
||||
pqEncryption :: PQEncryption,
|
||||
internalHash :: MsgHash,
|
||||
prevMsgHash :: MsgHash
|
||||
}
|
||||
@@ -551,6 +542,7 @@ data PendingMsgData = PendingMsgData
|
||||
msgType :: AgentMessageType,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: MsgBody,
|
||||
pqEncryption :: PQEncryption,
|
||||
msgRetryState :: Maybe RI2State,
|
||||
internalTs :: InternalTs
|
||||
}
|
||||
|
||||
@@ -50,7 +50,6 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
createNewConn,
|
||||
updateNewConnRcv,
|
||||
updateNewConnSnd,
|
||||
createRcvConn, -- no longer used
|
||||
createSndConn,
|
||||
getConn,
|
||||
getDeletedConn,
|
||||
@@ -59,7 +58,9 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
getConnData,
|
||||
setConnDeleted,
|
||||
setConnAgentVersion,
|
||||
setConnPQSupport,
|
||||
getDeletedConnIds,
|
||||
getDeletedWaitingDeliveryConnIds,
|
||||
setConnRatchetSync,
|
||||
addProcessedRatchetKeyHash,
|
||||
checkRatchetKeyHashExists,
|
||||
@@ -93,7 +94,6 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
getAcceptedConfirmation,
|
||||
removeConfirmations,
|
||||
-- Invitations - sent via Contact connections
|
||||
setConnectionVersion,
|
||||
createInvitation,
|
||||
getInvitation,
|
||||
acceptInvitation,
|
||||
@@ -241,7 +241,7 @@ import Data.List (foldl', intercalate, sortBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe, isJust, listToMaybe, catMaybes)
|
||||
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe)
|
||||
import Data.Ord (Down (..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
@@ -268,7 +268,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRE
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..))
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys)
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys, PQEncryption (..), PQSupport (..))
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..))
|
||||
@@ -278,7 +279,7 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, ($>>=), (<$$>))
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath (takeDirectory)
|
||||
@@ -542,11 +543,8 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
|
||||
ConnData {connId} -> Right . (connId,) <$> create connId
|
||||
|
||||
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createNewConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} cMode = do
|
||||
fst <$$> createConn_ gVar cData create
|
||||
where
|
||||
create connId =
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
|
||||
createNewConn db gVar cData cMode = do
|
||||
fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode)
|
||||
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue)
|
||||
updateNewConnRcv db connId rq =
|
||||
@@ -568,22 +566,25 @@ updateNewConnSnd db connId sq =
|
||||
updateConn :: IO (Either StoreError SndQueue)
|
||||
updateConn = Right <$> addConnSndQueue_ db connId sq
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
|
||||
createRcvConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, cMode, connAgentVersion, enableNtfs, True)
|
||||
insertRcvQueue_ db connId q serverKeyHash_
|
||||
|
||||
createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue))
|
||||
createSndConn db gVar cData@ConnData {userId, connAgentVersion, enableNtfs} q@SndQueue {server} =
|
||||
createSndConn db gVar cData q@SndQueue {server} =
|
||||
-- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_
|
||||
ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
serverKeyHash_ <- createServer_ db server
|
||||
DB.execute db "INSERT INTO connections (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?,?,?,?,?,?)" (userId, connId, SCMInvitation, connAgentVersion, enableNtfs, True)
|
||||
createConnRecord db connId cData SCMInvitation
|
||||
insertSndQueue_ db connId q serverKeyHash_
|
||||
|
||||
createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO ()
|
||||
createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode =
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
(user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?)
|
||||
|]
|
||||
(userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True)
|
||||
|
||||
checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool
|
||||
checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do
|
||||
fromMaybe False
|
||||
@@ -602,12 +603,32 @@ getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do
|
||||
DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId)
|
||||
(rq,) <$> ExceptT (getConn db connId)
|
||||
|
||||
deleteConn :: DB.Connection -> ConnId -> IO ()
|
||||
deleteConn db connId =
|
||||
DB.executeNamed
|
||||
db
|
||||
"DELETE FROM connections WHERE conn_id = :conn_id;"
|
||||
[":conn_id" := connId]
|
||||
-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted
|
||||
deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId)
|
||||
deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of
|
||||
Nothing -> delete
|
||||
Just timeout ->
|
||||
ifM
|
||||
checkNoPendingDeliveries_
|
||||
delete
|
||||
( ifM
|
||||
(checkWaitDeliveryTimeout_ timeout)
|
||||
delete
|
||||
(pure Nothing)
|
||||
)
|
||||
where
|
||||
delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId
|
||||
checkNoPendingDeliveries_ = do
|
||||
r :: (Maybe Int64) <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId)
|
||||
pure $ isNothing r
|
||||
checkWaitDeliveryTimeout_ timeout = do
|
||||
cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime
|
||||
r :: (Maybe Int64) <-
|
||||
maybeFirstRow fromOnly $
|
||||
DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs)
|
||||
pure $ isJust r
|
||||
|
||||
upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue)
|
||||
upgradeRcvConnToDuplex db connId sq =
|
||||
@@ -681,7 +702,7 @@ setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = d
|
||||
|]
|
||||
(host, port, rcvId)
|
||||
|
||||
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> Version -> IO ()
|
||||
setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO ()
|
||||
setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion =
|
||||
DB.executeNamed
|
||||
db
|
||||
@@ -782,7 +803,7 @@ setRcvQueueNtfCreds db connId clientNtfCreds =
|
||||
Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret)
|
||||
Nothing -> (Nothing, Nothing, Nothing, Nothing)
|
||||
|
||||
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe Version)
|
||||
type SMPConfirmationRow = (SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC)
|
||||
|
||||
smpConfirmation :: SMPConfirmationRow -> SMPConfirmation
|
||||
smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) =
|
||||
@@ -791,7 +812,7 @@ smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersi
|
||||
e2ePubKey,
|
||||
connInfo,
|
||||
smpReplyQueues = fromMaybe [] smpReplyQueues_,
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
|
||||
}
|
||||
|
||||
createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId)
|
||||
@@ -868,10 +889,6 @@ removeConfirmations db connId =
|
||||
|]
|
||||
[":conn_id" := connId]
|
||||
|
||||
setConnectionVersion :: DB.Connection -> ConnId -> Version -> IO ()
|
||||
setConnectionVersion db connId aVersion =
|
||||
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
|
||||
|
||||
createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId)
|
||||
createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} =
|
||||
createWithRandomId gVar $ \invitationId ->
|
||||
@@ -1008,18 +1025,18 @@ getPendingQueueMsg db connId SndQueue {dbQueueId} =
|
||||
DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT m.msg_type, m.msg_flags, m.msg_body, m.internal_ts, s.retry_int_slow, s.retry_int_fast
|
||||
SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast
|
||||
FROM messages m
|
||||
JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id
|
||||
WHERE m.conn_id = ? AND m.internal_id = ?
|
||||
|]
|
||||
(connId, msgId)
|
||||
err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []"
|
||||
pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData
|
||||
pendingMsgData (msgType, msgFlags_, msgBody, internalTs, riSlow_, riFast_) =
|
||||
pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData
|
||||
pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) =
|
||||
let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_
|
||||
msgRetryState = RI2State <$> riSlow_ <*> riFast_
|
||||
in PendingMsgData {msgId, msgType, msgFlags, msgBody, msgRetryState, internalTs}
|
||||
in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs}
|
||||
markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId)
|
||||
|
||||
getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a))
|
||||
@@ -1088,7 +1105,7 @@ getRcvMsg db connId agentMsgId =
|
||||
[sql|
|
||||
SELECT
|
||||
r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
|
||||
m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack
|
||||
m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
|
||||
FROM rcv_messages r
|
||||
JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
|
||||
LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
|
||||
@@ -1104,7 +1121,7 @@ getLastMsg db connId msgId =
|
||||
[sql|
|
||||
SELECT
|
||||
r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
|
||||
m.msg_type, m.msg_body, s.internal_id, s.rcpt_status, r.user_ack
|
||||
m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
|
||||
FROM rcv_messages r
|
||||
JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
|
||||
JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id
|
||||
@@ -1113,9 +1130,9 @@ getLastMsg db connId msgId =
|
||||
|]
|
||||
(connId, msgId)
|
||||
|
||||
toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs, AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg
|
||||
toRcvMsg (agentMsgId, internalTs, brokerId, brokerTs, sndMsgId, integrity, internalHash, msgType, msgBody, rcptInternalId_, rcptStatus_, userAck) =
|
||||
let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity}
|
||||
toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg
|
||||
toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) =
|
||||
let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption}
|
||||
msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_
|
||||
in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck}
|
||||
|
||||
@@ -1175,34 +1192,34 @@ deleteSndMsgsExpired db ttl = do
|
||||
"DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL"
|
||||
(Only cutoffTs)
|
||||
|
||||
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
|
||||
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
|
||||
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2) VALUES (?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2)
|
||||
createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
|
||||
createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
|
||||
DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem)
|
||||
|
||||
getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448))
|
||||
getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams))
|
||||
getRatchetX3dhKeys db connId =
|
||||
fmap hasKeys $
|
||||
firstRow id SEX3dhKeysNotFound $
|
||||
DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2 FROM ratchets WHERE conn_id = ?" (Only connId)
|
||||
firstRow' keys SEX3dhKeysNotFound $
|
||||
DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId)
|
||||
where
|
||||
hasKeys = \case
|
||||
Right (Just k1, Just k2) -> Right (k1, k2)
|
||||
keys = \case
|
||||
(Just k1, Just k2, pKem) -> Right (k1, k2, pKem)
|
||||
_ -> Left SEX3dhKeysNotFound
|
||||
|
||||
-- used to remember new keys when starting ratchet re-synchronization
|
||||
-- TODO remove the columns for public keys in v5.7.
|
||||
-- Currently, the keys are not used but still stored to support app downgrade to the previous version.
|
||||
setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> IO ()
|
||||
setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 =
|
||||
setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO ()
|
||||
setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem =
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE ratchets
|
||||
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?
|
||||
SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, x3dh_pub_key_1 = ?, x3dh_pub_key_2 = ?, pq_priv_kem = ?
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, connId)
|
||||
(x3dhPrivKey1, x3dhPrivKey2, C.publicKey x3dhPrivKey1, C.publicKey x3dhPrivKey2, pqPrivKem, connId)
|
||||
|
||||
-- TODO remove the columns for public keys in v5.7.
|
||||
createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO ()
|
||||
createRatchet db connId rc =
|
||||
DB.executeNamed
|
||||
@@ -1213,7 +1230,10 @@ createRatchet db connId rc =
|
||||
ON CONFLICT (conn_id) DO UPDATE SET
|
||||
ratchet_state = :ratchet_state,
|
||||
x3dh_priv_key_1 = NULL,
|
||||
x3dh_priv_key_2 = NULL
|
||||
x3dh_priv_key_2 = NULL,
|
||||
x3dh_pub_key_1 = NULL,
|
||||
x3dh_pub_key_2 = NULL,
|
||||
pq_priv_kem = NULL
|
||||
|]
|
||||
[":conn_id" := connId, ":ratchet_state" := rc]
|
||||
|
||||
@@ -1752,6 +1772,10 @@ instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEn
|
||||
|
||||
instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
|
||||
|
||||
instance ToField (Version v) where toField (Version v) = toField v
|
||||
|
||||
instance FromField (Version v) where fromField f = Version <$> fromField f
|
||||
|
||||
listToEither :: e -> [a] -> Either e a
|
||||
listToEither _ (x : _) = Right x
|
||||
listToEither e _ = Left e
|
||||
@@ -1903,25 +1927,38 @@ getConnData db connId' =
|
||||
[sql|
|
||||
SELECT
|
||||
user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs,
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state
|
||||
last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support
|
||||
FROM connections
|
||||
WHERE conn_id = ?
|
||||
|]
|
||||
(Only connId')
|
||||
where
|
||||
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState) =
|
||||
(ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState}, cMode)
|
||||
cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) =
|
||||
(ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode)
|
||||
|
||||
setConnDeleted :: DB.Connection -> ConnId -> IO ()
|
||||
setConnDeleted db connId = DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
|
||||
setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO ()
|
||||
setConnDeleted db waitDelivery connId
|
||||
| waitDelivery = do
|
||||
currentTs <- getCurrentTime
|
||||
DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId)
|
||||
| otherwise =
|
||||
DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId)
|
||||
|
||||
setConnAgentVersion :: DB.Connection -> ConnId -> Version -> IO ()
|
||||
setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO ()
|
||||
setConnAgentVersion db connId aVersion =
|
||||
DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId)
|
||||
|
||||
setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO ()
|
||||
setConnPQSupport db connId pqSupport =
|
||||
DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId)
|
||||
|
||||
getDeletedConnIds :: DB.Connection -> IO [ConnId]
|
||||
getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True)
|
||||
|
||||
getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId]
|
||||
getDeletedWaitingDeliveryConnIds db =
|
||||
map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL"
|
||||
|
||||
setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO ()
|
||||
setConnRatchetSync db connId ratchetSyncState =
|
||||
DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId)
|
||||
@@ -1970,12 +2007,12 @@ rcvQueueQuery =
|
||||
|
||||
toRcvQueue ::
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, QueueStatus)
|
||||
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe Version, Int)
|
||||
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int)
|
||||
:. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) ->
|
||||
RcvQueue
|
||||
toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, status) :. (dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) =
|
||||
let server = SMPServer host port keyHash
|
||||
smpClientVersion = fromMaybe 1 smpClientVersion_
|
||||
smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_
|
||||
clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of
|
||||
(Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
_ -> Nothing
|
||||
@@ -2011,7 +2048,7 @@ sndQueueQuery =
|
||||
toSndQueue ::
|
||||
(UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId)
|
||||
:. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus)
|
||||
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, Version) ->
|
||||
:. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) ->
|
||||
SndQueue
|
||||
toSndQueue
|
||||
( (userId, keyHash, connId, host, port, sndId)
|
||||
@@ -2060,23 +2097,15 @@ updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do
|
||||
let MsgMeta {recipient = (internalId, internalTs)} = msgMeta
|
||||
DB.executeNamed
|
||||
let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO messages
|
||||
( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body)
|
||||
VALUES
|
||||
(:conn_id,:internal_id,:internal_ts,:internal_rcv_id, NULL,:msg_type,:msg_flags,:msg_body);
|
||||
(conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
|
||||
VALUES (?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
[ ":conn_id" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":msg_type" := msgType,
|
||||
":msg_flags" := msgFlags,
|
||||
":msg_body" := msgBody
|
||||
]
|
||||
(connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption)
|
||||
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do
|
||||
@@ -2157,23 +2186,16 @@ updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
|
||||
-- * createSndMsg helpers
|
||||
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO messages
|
||||
( conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body)
|
||||
(conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption)
|
||||
VALUES
|
||||
(:conn_id,:internal_id,:internal_ts, NULL,:internal_snd_id,:msg_type,:msg_flags,:msg_body);
|
||||
(?,?,?,?,?,?,?,?,?);
|
||||
|]
|
||||
[ ":conn_id" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":msg_type" := msgType,
|
||||
":msg_flags" := msgFlags,
|
||||
":msg_body" := msgBody
|
||||
]
|
||||
(connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption)
|
||||
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
|
||||
@@ -2267,17 +2289,18 @@ createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redire
|
||||
forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId
|
||||
pure dstEntityId
|
||||
where
|
||||
dummyDst = FileDescription
|
||||
{ party = SFRecipient,
|
||||
size,
|
||||
digest,
|
||||
redirect = Nothing,
|
||||
-- updated later with updateRcvFileRedirect
|
||||
key = C.unsafeSbKey $ B.replicate 32 '#',
|
||||
nonce = C.cbNonce "",
|
||||
chunkSize = FileSize 0,
|
||||
chunks = []
|
||||
}
|
||||
dummyDst =
|
||||
FileDescription
|
||||
{ party = SFRecipient,
|
||||
size,
|
||||
digest,
|
||||
redirect = Nothing,
|
||||
-- updated later with updateRcvFileRedirect
|
||||
key = C.unsafeSbKey $ B.replicate 32 '#',
|
||||
nonce = C.cbNonce "",
|
||||
chunkSize = FileSize 0,
|
||||
chunks = []
|
||||
}
|
||||
|
||||
insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> IO (Either StoreError (RcvFileId, DBRcvFileId))
|
||||
insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ = runExceptT $ do
|
||||
@@ -2346,10 +2369,11 @@ getRcvFile db rcvFileId = runExceptT $ do
|
||||
toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) =
|
||||
let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_
|
||||
saveFile = CryptoFile savePath cfArgs
|
||||
redirect = RcvFileRedirect
|
||||
<$> redirectDbId
|
||||
<*> redirectEntityId
|
||||
<*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_)
|
||||
redirect =
|
||||
RcvFileRedirect
|
||||
<$> redirectDbId
|
||||
<*> redirectEntityId
|
||||
<*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_)
|
||||
in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []}
|
||||
getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk]
|
||||
getChunks rcvFileEntityId userId fileTmpPath = do
|
||||
|
||||
@@ -71,7 +71,8 @@ dbBusyLoop action = loop 500 3000000
|
||||
loop :: Int -> Int -> IO a
|
||||
loop t tLim =
|
||||
action `E.catch` \(e :: SQLError) ->
|
||||
if tLim > t && SQL.sqlError e == SQL.ErrorBusy
|
||||
let se = SQL.sqlError e in
|
||||
if tLim > t && (se == SQL.ErrorBusy || se == SQL.ErrorLocked)
|
||||
then do
|
||||
threadDelay t
|
||||
loop (t * 9 `div` 8) (tLim - t)
|
||||
|
||||
@@ -69,6 +69,8 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
@@ -106,7 +108,9 @@ schemaMigrations =
|
||||
("m20231222_command_created_at", m20231222_command_created_at, Just down_m20231222_command_created_at),
|
||||
("m20231225_failed_work_items", m20231225_failed_work_items, Just down_m20231225_failed_work_items),
|
||||
("m20240121_message_delivery_indexes", m20240121_message_delivery_indexes, Just down_m20240121_message_delivery_indexes),
|
||||
("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect)
|
||||
("m20240124_file_redirect", m20240124_file_redirect, Just down_m20240124_file_redirect),
|
||||
("m20240223_connections_wait_delivery", m20240223_connections_wait_delivery, Just down_m20240223_connections_wait_delivery),
|
||||
("m20240225_ratchet_kem", m20240225_ratchet_kem, Just down_m20240225_ratchet_kem)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20240223_connections_wait_delivery :: Query
|
||||
m20240223_connections_wait_delivery =
|
||||
[sql|
|
||||
ALTER TABLE connections ADD COLUMN deleted_at_wait_delivery TEXT;
|
||||
|]
|
||||
|
||||
down_m20240223_connections_wait_delivery :: Query
|
||||
down_m20240223_connections_wait_delivery =
|
||||
[sql|
|
||||
ALTER TABLE connections DROP COLUMN deleted_at_wait_delivery;
|
||||
|]
|
||||
@@ -0,0 +1,22 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20240225_ratchet_kem :: Query
|
||||
m20240225_ratchet_kem =
|
||||
[sql|
|
||||
ALTER TABLE ratchets ADD COLUMN pq_priv_kem BLOB;
|
||||
ALTER TABLE connections ADD COLUMN pq_support INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE messages ADD COLUMN pq_encryption INTEGER NOT NULL DEFAULT 0;
|
||||
|]
|
||||
|
||||
down_m20240225_ratchet_kem :: Query
|
||||
down_m20240225_ratchet_kem =
|
||||
[sql|
|
||||
ALTER TABLE ratchets DROP COLUMN pq_priv_kem;
|
||||
ALTER TABLE connections DROP COLUMN pq_support;
|
||||
ALTER TABLE messages DROP COLUMN pq_encryption;
|
||||
|]
|
||||
@@ -26,7 +26,9 @@ CREATE TABLE connections(
|
||||
deleted INTEGER DEFAULT 0 CHECK(deleted NOT NULL),
|
||||
user_id INTEGER CHECK(user_id NOT NULL)
|
||||
REFERENCES users ON DELETE CASCADE,
|
||||
ratchet_sync_state TEXT NOT NULL DEFAULT 'ok'
|
||||
ratchet_sync_state TEXT NOT NULL DEFAULT 'ok',
|
||||
deleted_at_wait_delivery TEXT,
|
||||
pq_support INTEGER NOT NULL DEFAULT 0
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE rcv_queues(
|
||||
host TEXT NOT NULL,
|
||||
@@ -89,6 +91,7 @@ CREATE TABLE messages(
|
||||
msg_type BLOB NOT NULL, --(H)ELLO,(R)EPLY,(D)ELETE. Should SMP confirmation be saved too?
|
||||
msg_body BLOB NOT NULL DEFAULT x'',
|
||||
msg_flags TEXT NULL,
|
||||
pq_encryption INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY(conn_id, internal_id),
|
||||
FOREIGN KEY(conn_id, internal_rcv_id) REFERENCES rcv_messages
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
|
||||
@@ -159,7 +162,8 @@ CREATE TABLE ratchets(
|
||||
e2e_version INTEGER NOT NULL DEFAULT 1
|
||||
,
|
||||
x3dh_pub_key_1 BLOB,
|
||||
x3dh_pub_key_2 BLOB
|
||||
x3dh_pub_key_2 BLOB,
|
||||
pq_priv_kem BLOB
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE skipped_messages(
|
||||
skipped_message_id INTEGER PRIMARY KEY,
|
||||
|
||||
@@ -76,7 +76,7 @@ module Simplex.Messaging.Client
|
||||
PCTransmission,
|
||||
mkTransmission,
|
||||
authTransmission,
|
||||
clientStub,
|
||||
smpClientStub,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -117,14 +117,14 @@ import System.Timeout (timeout)
|
||||
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
|
||||
--
|
||||
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
|
||||
data ProtocolClient err msg = ProtocolClient
|
||||
data ProtocolClient v err msg = ProtocolClient
|
||||
{ action :: Maybe (Async ()),
|
||||
thParams :: THandleParams,
|
||||
thParams :: THandleParams v,
|
||||
sessionTs :: UTCTime,
|
||||
client_ :: PClient err msg
|
||||
client_ :: PClient v err msg
|
||||
}
|
||||
|
||||
data PClient err msg = PClient
|
||||
data PClient v err msg = PClient
|
||||
{ connected :: TVar Bool,
|
||||
transportSession :: TransportSession msg,
|
||||
transportHost :: TransportHost,
|
||||
@@ -135,11 +135,11 @@ data PClient err msg = PClient
|
||||
sentCommands :: TMap CorrId (Request err msg),
|
||||
sndQ :: TBQueue ByteString,
|
||||
rcvQ :: TBQueue (NonEmpty (SignedTransmission err msg)),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission msg))
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission v msg))
|
||||
}
|
||||
|
||||
clientStub :: TVar ChaChaDRG -> ByteString -> Version -> Maybe THandleAuth -> STM (ProtocolClient err msg)
|
||||
clientStub g sessionId thVersion thAuth = do
|
||||
smpClientStub :: TVar ChaChaDRG -> ByteString -> VersionSMP -> Maybe THandleAuth -> STM (ProtocolClient SMPVersion err msg)
|
||||
smpClientStub g sessionId thVersion thAuth = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- C.newRandomDRG g
|
||||
sentCommands <- TM.empty
|
||||
@@ -174,13 +174,13 @@ clientStub g sessionId thVersion thAuth = do
|
||||
}
|
||||
}
|
||||
|
||||
type SMPClient = ProtocolClient ErrorType BrokerMsg
|
||||
type SMPClient = ProtocolClient SMPVersion ErrorType BrokerMsg
|
||||
|
||||
-- | Type for client command data
|
||||
type ClientCommand msg = (Maybe C.APrivateAuthKey, EntityId, ProtoCommand msg)
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type ServerTransmission msg = (TransportSession msg, Version, SessionId, EntityId, msg)
|
||||
type ServerTransmission v msg = (TransportSession msg, Version v, SessionId, EntityId, msg)
|
||||
|
||||
data HostMode
|
||||
= -- | prefer (or require) onion hosts when connecting via SOCKS proxy
|
||||
@@ -241,7 +241,7 @@ transportClientConfig NetworkConfig {socksProxy, tcpKeepAlive, logTLSErrors} =
|
||||
TransportClientConfig {socksProxy, tcpKeepAlive, logTLSErrors, clientCredentials = Nothing}
|
||||
|
||||
-- | protocol client configuration.
|
||||
data ProtocolClientConfig = ProtocolClientConfig
|
||||
data ProtocolClientConfig v = ProtocolClientConfig
|
||||
{ -- | size of TBQueue to use for server commands and responses
|
||||
qSize :: Natural,
|
||||
-- | default server port if port is not specified in ProtocolServer
|
||||
@@ -249,13 +249,13 @@ data ProtocolClientConfig = ProtocolClientConfig
|
||||
-- | network configuration
|
||||
networkConfig :: NetworkConfig,
|
||||
-- | client-server protocol version range
|
||||
serverVRange :: VersionRange,
|
||||
serverVRange :: VersionRange v,
|
||||
-- | delay between sending batches of commands (microseconds)
|
||||
batchDelay :: Maybe Int
|
||||
}
|
||||
|
||||
-- | Default protocol client configuration.
|
||||
defaultClientConfig :: VersionRange -> ProtocolClientConfig
|
||||
defaultClientConfig :: VersionRange v -> ProtocolClientConfig v
|
||||
defaultClientConfig serverVRange =
|
||||
ProtocolClientConfig
|
||||
{ qSize = 64,
|
||||
@@ -265,7 +265,7 @@ defaultClientConfig serverVRange =
|
||||
batchDelay = Nothing
|
||||
}
|
||||
|
||||
defaultSMPClientConfig :: ProtocolClientConfig
|
||||
defaultSMPClientConfig :: ProtocolClientConfig SMPVersion
|
||||
defaultSMPClientConfig = defaultClientConfig supportedClientSMPRelayVRange
|
||||
|
||||
data Request err msg = Request
|
||||
@@ -292,15 +292,15 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
|
||||
onionHost = find isOnionHost hosts
|
||||
publicHost = find (not . isOnionHost) hosts
|
||||
|
||||
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient err msg -> String
|
||||
protocolClientServer :: ProtocolTypeI (ProtoType msg) => ProtocolClient v err msg -> String
|
||||
protocolClientServer = B.unpack . strEncode . snd3 . transportSession . client_
|
||||
where
|
||||
snd3 (_, s, _) = s
|
||||
|
||||
transportHost' :: ProtocolClient err msg -> TransportHost
|
||||
transportHost' :: ProtocolClient v err msg -> TransportHost
|
||||
transportHost' = transportHost . client_
|
||||
|
||||
transportSession' :: ProtocolClient err msg -> TransportSession msg
|
||||
transportSession' :: ProtocolClient v err msg -> TransportSession msg
|
||||
transportSession' = transportSession . client_
|
||||
|
||||
type UserId = Int64
|
||||
@@ -313,7 +313,7 @@ type TransportSession msg = (UserId, ProtoServer msg, Maybe EntityId)
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall err msg. Protocol err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
|
||||
getProtocolClient :: forall v err msg. Protocol v err msg => TVar ChaChaDRG -> TransportSession msg -> ProtocolClientConfig v -> Maybe (TBQueue (ServerTransmission v msg)) -> (ProtocolClient v err msg -> IO ()) -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
|
||||
getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, networkConfig, serverVRange, batchDelay} msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host srv) of
|
||||
Right useHost ->
|
||||
@@ -322,7 +322,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig
|
||||
mkProtocolClient :: TransportHost -> STM (PClient err msg)
|
||||
mkProtocolClient :: TransportHost -> STM (PClient v err msg)
|
||||
mkProtocolClient transportHost = do
|
||||
connected <- newTVar False
|
||||
pingErrorCount <- newTVar 0
|
||||
@@ -345,7 +345,7 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
msgQ
|
||||
}
|
||||
|
||||
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient err msg -> IO (Either (ProtocolClientError err) (ProtocolClient err msg))
|
||||
runClient :: (ServiceName, ATransport) -> TransportHost -> PClient v err msg -> IO (Either (ProtocolClientError err) (ProtocolClient v err msg))
|
||||
runClient (port', ATransport t) useHost c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
let tcConfig = transportClientConfig networkConfig
|
||||
@@ -366,10 +366,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
"80" -> ("80", transport @WS)
|
||||
p -> (p, transport @TLS)
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> PClient err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient err msg)) -> c -> IO ()
|
||||
client :: forall c. Transport c => TProxy c -> PClient v err msg -> TMVar (Either (ProtocolClientError err) (ProtocolClient v err msg)) -> c -> IO ()
|
||||
client _ c cVar h = do
|
||||
ks <- atomically $ C.generateKeyPair g
|
||||
runExceptT (protocolClientHandshake @err @msg h ks (keyHash srv) serverVRange) >>= \case
|
||||
runExceptT (protocolClientHandshake @v @err @msg h ks (keyHash srv) serverVRange) >>= \case
|
||||
Left e -> atomically . putTMVar cVar . Left $ PCETransportError e
|
||||
Right th@THandle {params} -> do
|
||||
sessionTs <- getCurrentTime
|
||||
@@ -380,16 +380,16 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
raceAny_ ([send c' th, process c', receive c' th] <> [ping c' | smpPingInterval > 0])
|
||||
`finally` disconnected c'
|
||||
|
||||
send :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
|
||||
send :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
|
||||
send ProtocolClient {client_ = PClient {sndQ}} h = forever $ atomically (readTBQueue sndQ) >>= tPutLog h
|
||||
|
||||
receive :: Transport c => ProtocolClient err msg -> THandle c -> IO ()
|
||||
receive :: Transport c => ProtocolClient v err msg -> THandle v c -> IO ()
|
||||
receive ProtocolClient {client_ = PClient {rcvQ}} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: ProtocolClient err msg -> IO ()
|
||||
ping :: ProtocolClient v err msg -> IO ()
|
||||
ping c@ProtocolClient {client_ = PClient {pingErrorCount}} = do
|
||||
threadDelay' smpPingInterval
|
||||
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @err @msg) >>= \case
|
||||
runExceptT (sendProtocolCommand c Nothing "" $ protocolPing @v @err @msg) >>= \case
|
||||
Left PCEResponseTimeout -> do
|
||||
cnt <- atomically $ stateTVar pingErrorCount $ \cnt -> (cnt + 1, cnt + 1)
|
||||
when (maxCnt == 0 || cnt < maxCnt) $ ping c
|
||||
@@ -397,10 +397,10 @@ getProtocolClient g transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize
|
||||
where
|
||||
maxCnt = smpPingCount networkConfig
|
||||
|
||||
process :: ProtocolClient err msg -> IO ()
|
||||
process :: ProtocolClient v err msg -> IO ()
|
||||
process c = forever $ atomically (readTBQueue $ rcvQ $ client_ c) >>= mapM_ (processMsg c)
|
||||
|
||||
processMsg :: ProtocolClient err msg -> SignedTransmission err msg -> IO ()
|
||||
processMsg :: ProtocolClient v err msg -> SignedTransmission err msg -> IO ()
|
||||
processMsg c@ProtocolClient {client_ = PClient {sentCommands}} (_, _, (corrId, entId, respOrErr)) =
|
||||
if B.null $ bs corrId
|
||||
then sendMsg respOrErr
|
||||
@@ -428,7 +428,7 @@ proxyUsername :: TransportSession msg -> ByteString
|
||||
proxyUsername (userId, _, entityId_) = C.sha256Hash $ bshow userId <> maybe "" (":" <>) entityId_
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeProtocolClient :: ProtocolClient err msg -> IO ()
|
||||
closeProtocolClient :: ProtocolClient v err msg -> IO ()
|
||||
closeProtocolClient = mapM_ uninterruptibleCancel . action
|
||||
|
||||
-- | SMP client error type.
|
||||
@@ -517,7 +517,7 @@ processSUBResponse c (Response rId r) = case r of
|
||||
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
|
||||
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ $ client_ c)
|
||||
|
||||
serverTransmission :: ProtocolClient err msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission :: ProtocolClient v err msg -> RecipientId -> msg -> ServerTransmission v msg
|
||||
serverTransmission ProtocolClient {thParams = THandleParams {thVersion, sessionId}, client_ = PClient {transportSession}} entityId message =
|
||||
(transportSession, thVersion, sessionId, entityId, message)
|
||||
|
||||
@@ -635,7 +635,7 @@ sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
type PCTransmission err msg = (Either TransportError SentRawTransmission, Request err msg)
|
||||
|
||||
-- | Send multiple commands with batching and collect responses
|
||||
sendProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
|
||||
sendProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Response err msg))
|
||||
sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs = do
|
||||
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
|
||||
validate . concat =<< mapM (sendBatch c) bs
|
||||
@@ -652,12 +652,12 @@ sendProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSiz
|
||||
where
|
||||
diff = L.length cs - length rs
|
||||
|
||||
streamProtocolCommands :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
|
||||
streamProtocolCommands :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> NonEmpty (ClientCommand msg) -> ([Response err msg] -> IO ()) -> IO ()
|
||||
streamProtocolCommands c@ProtocolClient {thParams = THandleParams {batch, blockSize}} cs cb = do
|
||||
bs <- batchTransmissions' batch blockSize <$> mapM (mkTransmission c) cs
|
||||
mapM_ (cb <=< sendBatch c) bs
|
||||
|
||||
sendBatch :: ProtocolClient err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
|
||||
sendBatch :: ProtocolClient v err msg -> TransportBatch (Request err msg) -> IO [Response err msg]
|
||||
sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
|
||||
case b of
|
||||
TBError e Request {entityId} -> do
|
||||
@@ -673,7 +673,7 @@ sendBatch c@ProtocolClient {client_ = PClient {sndQ}} b = do
|
||||
(: []) <$> getResponse c r
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
|
||||
sendProtocolCommand :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> Maybe C.APrivateAuthKey -> EntityId -> ProtoCommand msg -> ExceptT (ProtocolClientError err) IO msg
|
||||
sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THandleParams {batch, blockSize}} pKey entId cmd =
|
||||
ExceptT $ uncurry sendRecv =<< mkTransmission c (pKey, entId, cmd)
|
||||
where
|
||||
@@ -690,7 +690,7 @@ sendProtocolCommand c@ProtocolClient {client_ = PClient {sndQ}, thParams = THand
|
||||
| otherwise = tEncode t
|
||||
|
||||
-- TODO switch to timeout or TimeManager that supports Int64
|
||||
getResponse :: ProtocolClient err msg -> Request err msg -> IO (Response err msg)
|
||||
getResponse :: ProtocolClient v err msg -> Request err msg -> IO (Response err msg)
|
||||
getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Request {entityId, responseVar} = do
|
||||
response <-
|
||||
timeout tcpTimeout (atomically (takeTMVar responseVar)) >>= \case
|
||||
@@ -698,7 +698,7 @@ getResponse ProtocolClient {client_ = PClient {tcpTimeout, pingErrorCount}} Requ
|
||||
Nothing -> pure $ Left PCEResponseTimeout
|
||||
pure Response {entityId, response}
|
||||
|
||||
mkTransmission :: forall err msg. ProtocolEncoding err (ProtoCommand msg) => ProtocolClient err msg -> ClientCommand msg -> IO (PCTransmission err msg)
|
||||
mkTransmission :: forall v err msg. ProtocolEncoding v err (ProtoCommand msg) => ProtocolClient v err msg -> ClientCommand msg -> IO (PCTransmission err msg)
|
||||
mkTransmission ProtocolClient {thParams, client_ = PClient {clientCorrId, sentCommands}} (pKey_, entId, cmd) = do
|
||||
corrId <- atomically getNextCorrId
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, entId, cmd)
|
||||
|
||||
@@ -65,7 +65,7 @@ type SMPSub = (SMPSubParty, QueueId)
|
||||
-- type SMPServerSub = (SMPServer, SMPSub)
|
||||
|
||||
data SMPClientAgentConfig = SMPClientAgentConfig
|
||||
{ smpCfg :: ProtocolClientConfig,
|
||||
{ smpCfg :: ProtocolClientConfig SMPVersion,
|
||||
reconnectInterval :: RetryInterval,
|
||||
msgQSize :: Natural,
|
||||
agentQSize :: Natural,
|
||||
@@ -91,7 +91,7 @@ defaultSMPClientAgentConfig =
|
||||
|
||||
data SMPClientAgent = SMPClientAgent
|
||||
{ agentCfg :: SMPClientAgentConfig,
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
msgQ :: TBQueue (ServerTransmission SMPVersion BrokerMsg),
|
||||
agentQ :: TBQueue SMPClientAgentEvent,
|
||||
randomDrg :: TVar ChaChaDRG,
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Compression where
|
||||
|
||||
import qualified Codec.Compression.Zstd.FFI as Z
|
||||
import Control.Monad (forM)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Unsafe as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Foreign
|
||||
import Foreign.C.Types
|
||||
import GHC.IO (unsafePerformIO)
|
||||
import Simplex.Messaging.Encoding
|
||||
import UnliftIO.Exception (bracket)
|
||||
|
||||
data Compressed
|
||||
= -- | Short messages are left intact to skip copying and FFI festivities.
|
||||
Passthrough ByteString
|
||||
| -- | Generic compression using no extra context.
|
||||
Compressed Large
|
||||
|
||||
-- | Messages below this length are not encoded to avoid compression overhead.
|
||||
maxLengthPassthrough :: Int
|
||||
maxLengthPassthrough = 180 -- Sampled from real client data. Messages with length > 180 rapidly gain compression ratio.
|
||||
|
||||
instance Encoding Compressed where
|
||||
smpEncode = \case
|
||||
Passthrough bytes -> "0" <> smpEncode bytes
|
||||
Compressed bytes -> "1" <> smpEncode bytes
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
'0' -> Passthrough <$> smpP
|
||||
'1' -> Compressed <$> smpP
|
||||
x -> fail $ "unknown Compressed tag: " <> show x
|
||||
|
||||
type CompressCtx = (Ptr Z.CCtx, Ptr CChar, CSize)
|
||||
|
||||
withCompressCtx :: CSize -> (CompressCtx -> IO a) -> IO a
|
||||
withCompressCtx scratchSize action =
|
||||
bracket Z.createCCtx Z.freeCCtx $ \cctx ->
|
||||
allocaBytes (fromIntegral scratchSize) $ \scratchPtr ->
|
||||
action (cctx, scratchPtr, scratchSize)
|
||||
|
||||
-- | Compress bytes, falling back to Passthrough in case of some internal error.
|
||||
compress :: CompressCtx -> ByteString -> IO Compressed
|
||||
compress ctx bs = fromRight (Passthrough bs) <$> compress_ ctx bs
|
||||
|
||||
compress_ :: CompressCtx -> ByteString -> IO (Either String Compressed)
|
||||
compress_ (cctx, scratchPtr, scratchSize) bs
|
||||
| B.length bs <= maxLengthPassthrough = pure . Right $ Passthrough bs
|
||||
| otherwise =
|
||||
B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> runExceptT $ do
|
||||
-- should not fail, unless input buffer is too short
|
||||
dstSize <- ExceptT $ Z.checkError $ Z.compressCCtx cctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize) 3
|
||||
liftIO $ Compressed . Large <$> B.packCStringLen (scratchPtr, fromIntegral dstSize)
|
||||
|
||||
type DecompressCtx = (Ptr Z.DCtx, Ptr CChar, CSize)
|
||||
|
||||
withDecompressCtx :: Int -> (DecompressCtx -> IO a) -> IO a
|
||||
withDecompressCtx maxUnpackedSize action =
|
||||
bracket Z.createDCtx Z.freeDCtx $ \dctx ->
|
||||
allocaBytes maxUnpackedSize $ \scratchPtr ->
|
||||
action (dctx, scratchPtr, fromIntegral maxUnpackedSize)
|
||||
|
||||
decompress :: DecompressCtx -> Compressed -> IO (Either String ByteString)
|
||||
decompress (dctx, scratchPtr, scratchSize) = \case
|
||||
Passthrough bs -> pure $ Right bs
|
||||
Compressed (Large bs) ->
|
||||
B.unsafeUseAsCStringLen bs $ \(sourcePtr, sourceSize) -> do
|
||||
res <- Z.checkError $ Z.decompressDCtx dctx scratchPtr scratchSize sourcePtr (fromIntegral sourceSize)
|
||||
forM res $ \dstSize -> B.packCStringLen (scratchPtr, fromIntegral dstSize)
|
||||
|
||||
decompressBatch :: Int -> NonEmpty Compressed -> NonEmpty (Either String ByteString)
|
||||
decompressBatch maxUnpackedSize items = unsafePerformIO $ withDecompressCtx maxUnpackedSize $ forM items . decompress
|
||||
{-# NOINLINE decompressBatch #-} -- prevent double-evaluation under unsafePerformIO
|
||||
@@ -101,6 +101,7 @@ module Simplex.Messaging.Crypto
|
||||
verify,
|
||||
verify',
|
||||
validSignatureSize,
|
||||
checkAlgorithm,
|
||||
|
||||
-- * crypto_box authenticator, as discussed in https://groups.google.com/g/sci.crypt/c/73yb5a9pz2Y/m/LNgRO7IYXOwJ
|
||||
CbAuthenticator (..),
|
||||
@@ -243,8 +244,6 @@ data SAlgorithm :: Algorithm -> Type where
|
||||
SX25519 :: SAlgorithm X25519
|
||||
SX448 :: SAlgorithm X448
|
||||
|
||||
deriving instance Eq (SAlgorithm a)
|
||||
|
||||
deriving instance Show (SAlgorithm a)
|
||||
|
||||
data Alg = forall a. AlgorithmI a => Alg (SAlgorithm a)
|
||||
@@ -297,11 +296,6 @@ data APublicKey
|
||||
AlgorithmI a =>
|
||||
APublicKey (SAlgorithm a) (PublicKey a)
|
||||
|
||||
instance Eq APublicKey where
|
||||
APublicKey a k == APublicKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
instance Encoding APublicKey where
|
||||
smpEncode = smpEncode . encodePubKey
|
||||
{-# INLINE smpEncode #-}
|
||||
@@ -342,11 +336,6 @@ data APrivateKey
|
||||
AlgorithmI a =>
|
||||
APrivateKey (SAlgorithm a) (PrivateKey a)
|
||||
|
||||
instance Eq APrivateKey where
|
||||
APrivateKey a k == APrivateKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Show APrivateKey
|
||||
|
||||
type PrivateKeyEd25519 = PrivateKey Ed25519
|
||||
@@ -372,11 +361,6 @@ data APrivateSignKey
|
||||
(AlgorithmI a, SignatureAlgorithm a) =>
|
||||
APrivateSignKey (SAlgorithm a) (PrivateKey a)
|
||||
|
||||
instance Eq APrivateSignKey where
|
||||
APrivateSignKey a k == APrivateSignKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Show APrivateSignKey
|
||||
|
||||
instance Encoding APrivateSignKey where
|
||||
@@ -396,11 +380,6 @@ data APublicVerifyKey
|
||||
(AlgorithmI a, SignatureAlgorithm a) =>
|
||||
APublicVerifyKey (SAlgorithm a) (PublicKey a)
|
||||
|
||||
instance Eq APublicVerifyKey where
|
||||
APublicVerifyKey a k == APublicVerifyKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Show APublicVerifyKey
|
||||
|
||||
data APrivateDhKey
|
||||
@@ -408,11 +387,6 @@ data APrivateDhKey
|
||||
(AlgorithmI a, DhAlgorithm a) =>
|
||||
APrivateDhKey (SAlgorithm a) (PrivateKey a)
|
||||
|
||||
instance Eq APrivateDhKey where
|
||||
APrivateDhKey a k == APrivateDhKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Show APrivateDhKey
|
||||
|
||||
data APublicDhKey
|
||||
@@ -420,11 +394,6 @@ data APublicDhKey
|
||||
(AlgorithmI a, DhAlgorithm a) =>
|
||||
APublicDhKey (SAlgorithm a) (PublicKey a)
|
||||
|
||||
instance Eq APublicDhKey where
|
||||
APublicDhKey a k == APublicDhKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Show APublicDhKey
|
||||
|
||||
data DhSecret (a :: Algorithm) where
|
||||
@@ -787,8 +756,6 @@ data Signature (a :: Algorithm) where
|
||||
SignatureEd25519 :: Ed25519.Signature -> Signature Ed25519
|
||||
SignatureEd448 :: Ed448.Signature -> Signature Ed448
|
||||
|
||||
deriving instance Eq (Signature a)
|
||||
|
||||
deriving instance Show (Signature a)
|
||||
|
||||
data ASignature
|
||||
@@ -796,11 +763,6 @@ data ASignature
|
||||
(AlgorithmI a, SignatureAlgorithm a) =>
|
||||
ASignature (SAlgorithm a) (Signature a)
|
||||
|
||||
instance Eq ASignature where
|
||||
ASignature a s == ASignature a' s' = case testEquality a a' of
|
||||
Just Refl -> s == s'
|
||||
_ -> False
|
||||
|
||||
deriving instance Show ASignature
|
||||
|
||||
class CryptoSignature s where
|
||||
@@ -885,6 +847,8 @@ data CryptoError
|
||||
CryptoHeaderError String
|
||||
| -- | no sending chain key in ratchet state
|
||||
CERatchetState
|
||||
| -- | no decapsulation key in ratchet state
|
||||
CERatchetKEMState
|
||||
| -- | header decryption error (could indicate that another key should be tried)
|
||||
CERatchetHeader
|
||||
| -- | too many skipped messages
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,16 +19,20 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
|
||||
newtype KEMPublicKey = KEMPublicKey ByteString
|
||||
deriving (Show)
|
||||
deriving (Eq, Show)
|
||||
|
||||
newtype KEMSecretKey = KEMSecretKey ScrubbedBytes
|
||||
deriving (Show)
|
||||
deriving (Eq, Show)
|
||||
|
||||
newtype KEMCiphertext = KEMCiphertext ByteString
|
||||
deriving (Show)
|
||||
deriving (Eq, Show)
|
||||
|
||||
newtype KEMSharedKey = KEMSharedKey ScrubbedBytes
|
||||
deriving (Show)
|
||||
deriving (Eq, Show)
|
||||
|
||||
unsafeRevealKEMSharedKey :: KEMSharedKey -> String
|
||||
unsafeRevealKEMSharedKey (KEMSharedKey scrubbed) = show (BA.convert scrubbed :: ByteString)
|
||||
{-# DEPRECATED unsafeRevealKEMSharedKey "unsafeRevealKEMSharedKey left in code" #-}
|
||||
|
||||
type KEMKeyPair = (KEMPublicKey, KEMSecretKey)
|
||||
|
||||
@@ -60,6 +64,18 @@ sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) =
|
||||
KEMSharedKey
|
||||
<$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr)
|
||||
|
||||
instance Encoding KEMSecretKey where
|
||||
smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c
|
||||
smpP = KEMSecretKey . BA.convert . unLarge <$> smpP
|
||||
|
||||
instance StrEncoding KEMSecretKey where
|
||||
strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString)
|
||||
strP = KEMSecretKey . BA.convert <$> strP @ByteString
|
||||
|
||||
instance Encoding KEMPublicKey where
|
||||
smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk
|
||||
smpP = KEMPublicKey . BA.convert . unLarge <$> smpP
|
||||
|
||||
instance StrEncoding KEMPublicKey where
|
||||
strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString)
|
||||
strP = KEMPublicKey . BA.convert <$> strP @ByteString
|
||||
@@ -68,6 +84,25 @@ instance Encoding KEMCiphertext where
|
||||
smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c
|
||||
smpP = KEMCiphertext . BA.convert . unLarge <$> smpP
|
||||
|
||||
instance Encoding KEMSharedKey where
|
||||
smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString)
|
||||
smpP = KEMSharedKey . BA.convert <$> smpP @ByteString
|
||||
|
||||
instance StrEncoding KEMCiphertext where
|
||||
strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString)
|
||||
strP = KEMCiphertext . BA.convert <$> strP @ByteString
|
||||
|
||||
instance StrEncoding KEMSharedKey where
|
||||
strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString)
|
||||
strP = KEMSharedKey . BA.convert <$> strP @ByteString
|
||||
|
||||
instance ToJSON KEMSecretKey where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
instance FromJSON KEMSecretKey where
|
||||
parseJSON = strParseJSON "KEMSecretKey"
|
||||
|
||||
instance ToJSON KEMPublicKey where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
@@ -75,8 +110,22 @@ instance ToJSON KEMPublicKey where
|
||||
instance FromJSON KEMPublicKey where
|
||||
parseJSON = strParseJSON "KEMPublicKey"
|
||||
|
||||
instance ToJSON KEMCiphertext where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
instance FromJSON KEMCiphertext where
|
||||
parseJSON = strParseJSON "KEMCiphertext"
|
||||
|
||||
instance ToField KEMSharedKey where
|
||||
toField (KEMSharedKey k) = toField (BA.convert k :: ByteString)
|
||||
|
||||
instance FromField KEMSharedKey where
|
||||
fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f
|
||||
|
||||
instance ToJSON KEMSharedKey where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
instance FromJSON KEMSharedKey where
|
||||
parseJSON = strParseJSON "KEMSharedKey"
|
||||
|
||||
@@ -179,6 +179,12 @@ instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncodin
|
||||
strP = (,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncoding e, StrEncoding f) => StrEncoding (a, b, c, d, e, f) where
|
||||
strEncode (a, b, c, d, e, f) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d, strEncode e, strEncode f]
|
||||
{-# INLINE strEncode #-}
|
||||
strP = (,,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
strP_ :: StrEncoding a => Parser a
|
||||
strP_ = strP <* A.space
|
||||
|
||||
|
||||
@@ -10,15 +10,15 @@ import Data.Word (Word16)
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Transport (supportedClientNTFVRange)
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion, supportedClientNTFVRange)
|
||||
import Simplex.Messaging.Protocol (ErrorType)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
|
||||
type NtfClient = ProtocolClient ErrorType NtfResponse
|
||||
type NtfClient = ProtocolClient NTFVersion ErrorType NtfResponse
|
||||
|
||||
type NtfClientError = ProtocolClientError ErrorType
|
||||
|
||||
defaultNTFClientConfig :: ProtocolClientConfig
|
||||
defaultNTFClientConfig :: ProtocolClientConfig NTFVersion
|
||||
defaultNTFClientConfig = defaultClientConfig supportedClientNTFVRange
|
||||
|
||||
ntfRegisterToken :: NtfClient -> C.APrivateAuthKey -> NewNtfEntity 'Token -> ExceptT NtfClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
|
||||
@@ -28,7 +28,7 @@ import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake)
|
||||
import Simplex.Messaging.Parsers (fromTextField_)
|
||||
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, (<$?>))
|
||||
@@ -147,7 +147,7 @@ instance Encoding ANewNtfEntity where
|
||||
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
|
||||
_ -> fail "bad ANewNtfEntity"
|
||||
|
||||
instance Protocol ErrorType NtfResponse where
|
||||
instance Protocol NTFVersion ErrorType NtfResponse where
|
||||
type ProtoCommand NtfResponse = NtfCmd
|
||||
type ProtoType NtfResponse = 'PNTF
|
||||
protocolClientHandshake = ntfClientHandshake
|
||||
@@ -184,7 +184,7 @@ data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
|
||||
|
||||
deriving instance Show NtfCmd
|
||||
|
||||
instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
|
||||
instance NtfEntityI e => ProtocolEncoding NTFVersion ErrorType (NtfCommand e) where
|
||||
type Tag (NtfCommand e) = NtfCommandTag e
|
||||
encodeProtocol _v = \case
|
||||
TNEW newTkn -> e (TNEW_, ' ', newTkn)
|
||||
@@ -203,7 +203,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
|
||||
|
||||
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
|
||||
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, entityId, _) cmd = case cmd of
|
||||
@@ -223,7 +223,7 @@ instance NtfEntityI e => ProtocolEncoding ErrorType (NtfCommand e) where
|
||||
| not (B.null entityId) = Left $ CMD HAS_AUTH
|
||||
| otherwise = Right cmd
|
||||
|
||||
instance ProtocolEncoding ErrorType NtfCmd where
|
||||
instance ProtocolEncoding NTFVersion ErrorType NtfCmd where
|
||||
type Tag NtfCmd = NtfCmdTag
|
||||
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
|
||||
|
||||
@@ -243,7 +243,7 @@ instance ProtocolEncoding ErrorType NtfCmd where
|
||||
SDEL_ -> pure SDEL
|
||||
PING_ -> pure PING
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @NtfResponse
|
||||
fromProtocolError = fromProtocolError @NTFVersion @ErrorType @NtfResponse
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
|
||||
@@ -290,7 +290,7 @@ data NtfResponse
|
||||
| NRPong
|
||||
deriving (Show)
|
||||
|
||||
instance ProtocolEncoding ErrorType NtfResponse where
|
||||
instance ProtocolEncoding NTFVersion ErrorType NtfResponse where
|
||||
type Tag NtfResponse = NtfResponseTag
|
||||
encodeProtocol _v = \case
|
||||
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
|
||||
|
||||
@@ -338,7 +338,7 @@ updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do
|
||||
old <- atomically $ stateTVar tknStatus (,status)
|
||||
when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status
|
||||
|
||||
runNtfClientTransport :: Transport c => THandle c -> M ()
|
||||
runNtfClientTransport :: Transport c => THandleNTF c -> M ()
|
||||
runNtfClientTransport th@THandle {params} = do
|
||||
qSize <- asks $ clientQSize . config
|
||||
ts <- liftIO getSystemTime
|
||||
@@ -355,7 +355,7 @@ runNtfClientTransport th@THandle {params} = do
|
||||
clientDisconnected :: NtfServerClient -> IO ()
|
||||
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
|
||||
|
||||
receive :: Transport c => THandle c -> NtfServerClient -> M ()
|
||||
receive :: Transport c => THandleNTF c -> NtfServerClient -> M ()
|
||||
receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ, rcvActiveAt} = forever $ do
|
||||
ts <- liftIO $ tGet th
|
||||
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
|
||||
@@ -370,7 +370,7 @@ receive th@THandle {params = THandleParams {thAuth}} NtfServerClient {rcvQ, sndQ
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
send :: Transport c => THandle c -> NtfServerClient -> IO ()
|
||||
send :: Transport c => THandleNTF c -> NtfServerClient -> IO ()
|
||||
send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
void . liftIO $ tPut h [Right (Nothing, encodeTransmission params t)]
|
||||
|
||||
@@ -24,6 +24,7 @@ import Numeric.Natural
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF)
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Server.Stats
|
||||
import Simplex.Messaging.Notifications.Server.Store
|
||||
@@ -34,7 +35,6 @@ import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport, THandleParams)
|
||||
import Simplex.Messaging.Transport.Server (TransportServerConfig, loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Version (VersionRange)
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
@@ -60,7 +60,7 @@ data NtfServerConfig = NtfServerConfig
|
||||
logStatsStartTime :: Int64,
|
||||
serverStatsLogFile :: FilePath,
|
||||
serverStatsBackupFile :: Maybe FilePath,
|
||||
ntfServerVRange :: VersionRange,
|
||||
ntfServerVRange :: VersionRangeNTF,
|
||||
transportConfig :: TransportServerConfig
|
||||
}
|
||||
|
||||
@@ -161,13 +161,13 @@ data NtfRequest
|
||||
data NtfServerClient = NtfServerClient
|
||||
{ rcvQ :: TBQueue NtfRequest,
|
||||
sndQ :: TBQueue (Transmission NtfResponse),
|
||||
ntfThParams :: THandleParams,
|
||||
ntfThParams :: THandleParams NTFVersion,
|
||||
connected :: TVar Bool,
|
||||
rcvActiveAt :: TVar SystemTime,
|
||||
sndActiveAt :: TVar SystemTime
|
||||
}
|
||||
|
||||
newNtfServerClient :: Natural -> THandleParams -> SystemTime -> STM NtfServerClient
|
||||
newNtfServerClient :: Natural -> THandleParams NTFVersion -> SystemTime -> STM NtfServerClient
|
||||
newNtfServerClient qSize ntfThParams ts = do
|
||||
rcvQ <- newTBQueue qSize
|
||||
sndQ <- newTBQueue qSize
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Transport where
|
||||
@@ -12,33 +13,51 @@ import Control.Monad.Except
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import Simplex.Messaging.Util (liftEitherWith)
|
||||
|
||||
ntfBlockSize :: Int
|
||||
ntfBlockSize = 512
|
||||
|
||||
authBatchCmdsNTFVersion :: Version
|
||||
authBatchCmdsNTFVersion = 2
|
||||
data NTFVersion
|
||||
|
||||
currentClientNTFVersion :: Version
|
||||
currentClientNTFVersion = 1
|
||||
instance VersionScope NTFVersion
|
||||
|
||||
currentServerNTFVersion :: Version
|
||||
currentServerNTFVersion = 1
|
||||
type VersionNTF = Version NTFVersion
|
||||
|
||||
supportedClientNTFVRange :: VersionRange
|
||||
supportedClientNTFVRange = mkVersionRange 1 currentClientNTFVersion
|
||||
type VersionRangeNTF = VersionRange NTFVersion
|
||||
|
||||
supportedServerNTFVRange :: VersionRange
|
||||
supportedServerNTFVRange = mkVersionRange 1 currentServerNTFVersion
|
||||
pattern VersionNTF :: Word16 -> VersionNTF
|
||||
pattern VersionNTF v = Version v
|
||||
|
||||
initialNTFVersion :: VersionNTF
|
||||
initialNTFVersion = VersionNTF 1
|
||||
|
||||
authBatchCmdsNTFVersion :: VersionNTF
|
||||
authBatchCmdsNTFVersion = VersionNTF 2
|
||||
|
||||
currentClientNTFVersion :: VersionNTF
|
||||
currentClientNTFVersion = VersionNTF 1
|
||||
|
||||
currentServerNTFVersion :: VersionNTF
|
||||
currentServerNTFVersion = VersionNTF 1
|
||||
|
||||
supportedClientNTFVRange :: VersionRangeNTF
|
||||
supportedClientNTFVRange = mkVersionRange initialNTFVersion currentClientNTFVersion
|
||||
|
||||
supportedServerNTFVRange :: VersionRangeNTF
|
||||
supportedServerNTFVRange = mkVersionRange initialNTFVersion currentServerNTFVersion
|
||||
|
||||
type THandleNTF c = THandle NTFVersion c
|
||||
|
||||
data NtfServerHandshake = NtfServerHandshake
|
||||
{ ntfVersionRange :: VersionRange,
|
||||
{ ntfVersionRange :: VersionRangeNTF,
|
||||
sessionId :: SessionId,
|
||||
-- pub key to agree shared secrets for command authorization and entity ID encryption.
|
||||
authPubKey :: Maybe (X.SignedExact X.PubKey)
|
||||
@@ -46,7 +65,7 @@ data NtfServerHandshake = NtfServerHandshake
|
||||
|
||||
data NtfClientHandshake = NtfClientHandshake
|
||||
{ -- | agreed SMP notifications server protocol version
|
||||
ntfVersion :: Version,
|
||||
ntfVersion :: VersionNTF,
|
||||
-- | server identity - CA certificate fingerprint
|
||||
keyHash :: C.KeyHash,
|
||||
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
|
||||
@@ -66,12 +85,12 @@ instance Encoding NtfServerHandshake where
|
||||
authPubKey <- authEncryptCmdsP (maxVersion ntfVersionRange) $ C.getSignedExact <$> smpP
|
||||
pure NtfServerHandshake {ntfVersionRange, sessionId, authPubKey}
|
||||
|
||||
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds :: Encoding a => VersionNTF -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds v k
|
||||
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
|
||||
| otherwise = ""
|
||||
|
||||
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP :: VersionNTF -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP v p = if v >= authBatchCmdsNTFVersion then Just <$> p else pure Nothing
|
||||
|
||||
instance Encoding NtfClientHandshake where
|
||||
@@ -83,16 +102,16 @@ instance Encoding NtfClientHandshake where
|
||||
authPubKey <- ntfAuthPubKeyP ntfVersion
|
||||
pure NtfClientHandshake {ntfVersion, keyHash, authPubKey}
|
||||
|
||||
ntfAuthPubKeyP :: Version -> Parser (Maybe C.PublicKeyX25519)
|
||||
ntfAuthPubKeyP :: VersionNTF -> Parser (Maybe C.PublicKeyX25519)
|
||||
ntfAuthPubKeyP v = if v >= authBatchCmdsNTFVersion then Just <$> smpP else pure Nothing
|
||||
|
||||
encodeNtfAuthPubKey :: Version -> Maybe C.PublicKeyX25519 -> ByteString
|
||||
encodeNtfAuthPubKey :: VersionNTF -> Maybe C.PublicKeyX25519 -> ByteString
|
||||
encodeNtfAuthPubKey v k
|
||||
| v >= authBatchCmdsNTFVersion = maybe "" smpEncode k
|
||||
| otherwise = ""
|
||||
|
||||
-- | Notifcations server transport handshake.
|
||||
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
|
||||
ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
|
||||
let sk = C.signX509 serverSignKey $ C.publicToX509 k
|
||||
@@ -106,7 +125,7 @@ ntfServerHandshake serverSignKey c (k, pk) kh ntfVRange = do
|
||||
| otherwise -> throwError $ TEHandshake VERSION
|
||||
|
||||
-- | Notifcations server client transport handshake.
|
||||
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeNTF -> ExceptT TransportError IO (THandleNTF c)
|
||||
ntfClientHandshake c (k, pk) keyHash ntfVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = ntfTHandle c
|
||||
NtfServerHandshake {sessionId = sessId, ntfVersionRange, authPubKey = sk'} <- getHandshake th
|
||||
@@ -122,15 +141,15 @@ ntfClientHandshake c (k, pk) keyHash ntfVRange = do
|
||||
pure $ ntfThHandle th v pk sk_
|
||||
Nothing -> throwError $ TEHandshake VERSION
|
||||
|
||||
ntfThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
|
||||
ntfThHandle :: forall c. THandleNTF c -> VersionNTF -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleNTF c
|
||||
ntfThHandle th@THandle {params} v privKey k_ =
|
||||
-- TODO drop SMP v6: make thAuth non-optional
|
||||
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
|
||||
v3 = v >= authBatchCmdsNTFVersion
|
||||
params' = params {thVersion = v, thAuth, implySessId = v3, batch = v3}
|
||||
in (th :: THandle c) {params = params'}
|
||||
in (th :: THandleNTF c) {params = params'}
|
||||
|
||||
ntfTHandle :: Transport c => c -> THandle c
|
||||
ntfTHandle :: Transport c => c -> THandleNTF c
|
||||
ntfTHandle c = THandle {connection = c, params}
|
||||
where
|
||||
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = False}
|
||||
params = THandleParams {sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = VersionNTF 0, thAuth = Nothing, implySessId = False, batch = False}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
@@ -46,6 +47,10 @@ module Simplex.Messaging.Protocol
|
||||
e2eEncMessageLength,
|
||||
|
||||
-- * SMP protocol types
|
||||
SMPClientVersion,
|
||||
VersionSMPC,
|
||||
VersionRangeSMPC,
|
||||
pattern VersionSMPC,
|
||||
ProtocolEncoding (..),
|
||||
Command (..),
|
||||
SubscriptionMode (..),
|
||||
@@ -117,6 +122,7 @@ module Simplex.Messaging.Protocol
|
||||
SMPMsgMeta (..),
|
||||
NMsgMeta (..),
|
||||
MsgFlags (..),
|
||||
initialSMPClientVersion,
|
||||
userProtocol,
|
||||
rcvMessageMeta,
|
||||
noMsgFlags,
|
||||
@@ -152,6 +158,7 @@ module Simplex.Messaging.Protocol
|
||||
tEncodeBatch1,
|
||||
batchTransmissions,
|
||||
batchTransmissions',
|
||||
batchTransmissions_,
|
||||
|
||||
-- * exports for tests
|
||||
CommandTag (..),
|
||||
@@ -167,6 +174,7 @@ import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson.TH as J
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser, (<?>))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -180,6 +188,7 @@ import Data.Maybe (isJust, isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError, type (+))
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -191,19 +200,34 @@ import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client (TransportHost, TransportHosts (..))
|
||||
import Simplex.Messaging.Util (bshow, eitherToMaybe, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
|
||||
-- SMP client protocol version history:
|
||||
-- 1 - binary protocol encoding (1/1/2022)
|
||||
-- 2 - multiple server hostnames and versioned queue addresses (8/12/2022)
|
||||
|
||||
srvHostnamesSMPClientVersion :: Version
|
||||
srvHostnamesSMPClientVersion = 2
|
||||
data SMPClientVersion
|
||||
|
||||
currentSMPClientVersion :: Version
|
||||
currentSMPClientVersion = 2
|
||||
instance VersionScope SMPClientVersion
|
||||
|
||||
supportedSMPClientVRange :: VersionRange
|
||||
supportedSMPClientVRange = mkVersionRange 1 currentSMPClientVersion
|
||||
type VersionSMPC = Version SMPClientVersion
|
||||
|
||||
type VersionRangeSMPC = VersionRange SMPClientVersion
|
||||
|
||||
pattern VersionSMPC :: Word16 -> VersionSMPC
|
||||
pattern VersionSMPC v = Version v
|
||||
|
||||
initialSMPClientVersion :: VersionSMPC
|
||||
initialSMPClientVersion = VersionSMPC 1
|
||||
|
||||
srvHostnamesSMPClientVersion :: VersionSMPC
|
||||
srvHostnamesSMPClientVersion = VersionSMPC 2
|
||||
|
||||
currentSMPClientVersion :: VersionSMPC
|
||||
currentSMPClientVersion = VersionSMPC 2
|
||||
|
||||
supportedSMPClientVRange :: VersionRangeSMPC
|
||||
supportedSMPClientVRange = mkVersionRange initialSMPClientVersion currentSMPClientVersion
|
||||
|
||||
maxMessageLength :: Int
|
||||
maxMessageLength = 16088
|
||||
@@ -273,7 +297,7 @@ data RawTransmission = RawTransmission
|
||||
data TransmissionAuth
|
||||
= TASignature C.ASignature
|
||||
| TAAuthenticator C.CbAuthenticator
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
-- this encoding is backwards compatible with v6 that used Maybe C.ASignature instead of TAuthorization
|
||||
tAuthBytes :: Maybe TransmissionAuth -> ByteString
|
||||
@@ -339,8 +363,6 @@ data Command (p :: Party) where
|
||||
|
||||
deriving instance Show (Command p)
|
||||
|
||||
deriving instance Eq (Command p)
|
||||
|
||||
data SubscriptionMode = SMSubscribe | SMOnlyCreate
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -645,7 +667,7 @@ data ClientMsgEnvelope = ClientMsgEnvelope
|
||||
deriving (Show)
|
||||
|
||||
data PubHeader = PubHeader
|
||||
{ phVersion :: Version,
|
||||
{ phVersion :: VersionSMPC,
|
||||
phE2ePubDhKey :: Maybe C.PublicKeyX25519
|
||||
}
|
||||
deriving (Show)
|
||||
@@ -747,11 +769,11 @@ instance NFData (SProtocolType p) where rnf spt = spt `seq` ()
|
||||
|
||||
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
|
||||
|
||||
deriving instance Show AProtocolType
|
||||
|
||||
instance Eq AProtocolType where
|
||||
AProtocolType p == AProtocolType p' = isJust $ testEquality p p'
|
||||
|
||||
deriving instance Show AProtocolType
|
||||
|
||||
instance TestEquality SProtocolType where
|
||||
testEquality SPSMP SPSMP = Just Refl
|
||||
testEquality SPNTF SPNTF = Just Refl
|
||||
@@ -1058,7 +1080,7 @@ data CommandError
|
||||
deriving (Eq, Read, Show)
|
||||
|
||||
-- | SMP transmission parser.
|
||||
transmissionP :: THandleParams -> Parser RawTransmission
|
||||
transmissionP :: THandleParams v -> Parser RawTransmission
|
||||
transmissionP THandleParams {sessionId, implySessId} = do
|
||||
authenticator <- smpP
|
||||
authorized <- A.takeByteString
|
||||
@@ -1072,16 +1094,16 @@ transmissionP THandleParams {sessionId, implySessId} = do
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {authenticator, authorized = authorized', sessId, corrId, entityId, command}
|
||||
|
||||
class (ProtocolEncoding err msg, ProtocolEncoding err (ProtoCommand msg), Show err, Show msg) => Protocol err msg | msg -> err where
|
||||
class (ProtocolEncoding v err msg, ProtocolEncoding v err (ProtoCommand msg), Show err, Show msg) => Protocol v err msg | msg -> v, msg -> err where
|
||||
type ProtoCommand msg = cmd | cmd -> msg
|
||||
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange v -> ExceptT TransportError IO (THandle v c)
|
||||
protocolPing :: ProtoCommand msg
|
||||
protocolError :: msg -> Maybe err
|
||||
|
||||
type ProtoServer msg = ProtocolServer (ProtoType msg)
|
||||
|
||||
instance Protocol ErrorType BrokerMsg where
|
||||
instance Protocol SMPVersion ErrorType BrokerMsg where
|
||||
type ProtoCommand BrokerMsg = Cmd
|
||||
type ProtoType BrokerMsg = 'PSMP
|
||||
protocolClientHandshake = smpClientHandshake
|
||||
@@ -1090,14 +1112,14 @@ instance Protocol ErrorType BrokerMsg where
|
||||
ERR e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding err msg | msg -> err where
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding v err msg | msg -> err, msg -> v where
|
||||
type Tag msg
|
||||
encodeProtocol :: Version -> msg -> ByteString
|
||||
protocolP :: Version -> Tag msg -> Parser msg
|
||||
encodeProtocol :: Version v -> msg -> ByteString
|
||||
protocolP :: Version v -> Tag msg -> Parser msg
|
||||
fromProtocolError :: ProtocolErrorType -> err
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either err msg
|
||||
|
||||
instance PartyI p => ProtocolEncoding ErrorType (Command p) where
|
||||
instance PartyI p => ProtocolEncoding SMPVersion ErrorType (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol v = \case
|
||||
NEW rKey dhKey auth_ subMode
|
||||
@@ -1124,7 +1146,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
|
||||
|
||||
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials (auth, _, queueId, _) cmd = case cmd of
|
||||
@@ -1146,7 +1168,7 @@ instance PartyI p => ProtocolEncoding ErrorType (Command p) where
|
||||
| isNothing auth || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance ProtocolEncoding ErrorType Cmd where
|
||||
instance ProtocolEncoding SMPVersion ErrorType Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol v (Cmd _ c) = encodeProtocol v c
|
||||
|
||||
@@ -1174,12 +1196,12 @@ instance ProtocolEncoding ErrorType Cmd where
|
||||
PING_ -> pure PING
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
|
||||
fromProtocolError = fromProtocolError @ErrorType @BrokerMsg
|
||||
fromProtocolError = fromProtocolError @SMPVersion @ErrorType @BrokerMsg
|
||||
{-# INLINE fromProtocolError #-}
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance ProtocolEncoding ErrorType BrokerMsg where
|
||||
instance ProtocolEncoding SMPVersion ErrorType BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol _v = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
@@ -1231,12 +1253,12 @@ instance ProtocolEncoding ErrorType BrokerMsg where
|
||||
| otherwise -> Right cmd
|
||||
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: forall err msg. ProtocolEncoding err msg => Version -> ByteString -> Either err msg
|
||||
parseProtocol :: forall v err msg. ProtocolEncoding v err msg => Version v -> ByteString -> Either err msg
|
||||
parseProtocol v s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
Just cmd -> parse (protocolP v cmd) (fromProtocolError @err @msg $ PECmdSyntax) params
|
||||
Nothing -> Left $ fromProtocolError @err @msg $ PECmdUnknown
|
||||
Just cmd -> parse (protocolP v cmd) (fromProtocolError @v @err @msg $ PECmdSyntax) params
|
||||
Nothing -> Left $ fromProtocolError @v @err @msg $ PECmdUnknown
|
||||
|
||||
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
|
||||
checkParty c = case testEquality (sParty @p) (sParty @p') of
|
||||
@@ -1291,7 +1313,7 @@ instance Encoding CommandError where
|
||||
_ -> fail "bad command error type"
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
|
||||
tPut :: Transport c => THandle v c -> NonEmpty (Either TransportError SentRawTransmission) -> IO [Either TransportError ()]
|
||||
tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (batch params) (blockSize params)
|
||||
where
|
||||
tPutBatch :: TransportBatch () -> IO [Either TransportError ()]
|
||||
@@ -1300,7 +1322,7 @@ tPut th@THandle {params} = fmap concat . mapM tPutBatch . batchTransmissions (ba
|
||||
TBTransmissions s n _ -> replicate n <$> tPutLog th s
|
||||
TBTransmission s _ -> (: []) <$> tPutLog th s
|
||||
|
||||
tPutLog :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
|
||||
tPutLog :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
|
||||
tPutLog th s = do
|
||||
r <- tPutBlock th s
|
||||
case r of
|
||||
@@ -1314,11 +1336,11 @@ data TransportBatch r = TBTransmissions ByteString Int [r] | TBTransmission Byte
|
||||
batchTransmissions :: Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission) -> [TransportBatch ()]
|
||||
batchTransmissions batch bSize = batchTransmissions' batch bSize . L.map (,())
|
||||
|
||||
-- | encodes and batches transmissions into blocks,
|
||||
-- | encodes and batches transmissions into blocks
|
||||
batchTransmissions' :: forall r. Bool -> Int -> NonEmpty (Either TransportError SentRawTransmission, r) -> [TransportBatch r]
|
||||
batchTransmissions' batch bSize
|
||||
| batch = addBatch . foldr addTransmission ([], 0, 0, [], [])
|
||||
| otherwise = map mkBatch1 . L.toList
|
||||
batchTransmissions' batch bSize ts
|
||||
| batch = batchTransmissions_ bSize $ L.map (first $ fmap tEncodeForBatch) ts
|
||||
| otherwise = map mkBatch1 $ L.toList ts
|
||||
where
|
||||
mkBatch1 :: (Either TransportError SentRawTransmission, r) -> TransportBatch r
|
||||
mkBatch1 (t_, r) = case t_ of
|
||||
@@ -1329,17 +1351,21 @@ batchTransmissions' batch bSize
|
||||
| otherwise -> TBError TELargeMsg r
|
||||
where
|
||||
s = tEncode t
|
||||
|
||||
-- | Pack encoded transmissions into batches
|
||||
batchTransmissions_ :: Int -> NonEmpty (Either TransportError ByteString, r) -> [TransportBatch r]
|
||||
batchTransmissions_ bSize = addBatch . foldr addTransmission ([], 0, 0, [], [])
|
||||
where
|
||||
-- 3 = 2 bytes reserved for pad size + 1 for transmission count
|
||||
bSize' = bSize - 3
|
||||
addTransmission :: (Either TransportError SentRawTransmission, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r])
|
||||
addTransmission (t_, r) acc@(bs, len, n, ss, rs) = case t_ of
|
||||
addTransmission :: (Either TransportError ByteString, r) -> ([TransportBatch r], Int, Int, [ByteString], [r]) -> ([TransportBatch r], Int, Int, [ByteString], [r])
|
||||
addTransmission (t_, r) acc@(bs, !len, !n, ss, rs) = case t_ of
|
||||
Left e -> (TBError e r : addBatch acc, 0, 0, [], [])
|
||||
Right t
|
||||
Right s
|
||||
| len' <= bSize' && n < 255 -> (bs, len', 1 + n, s : ss, r : rs)
|
||||
| sLen <= bSize' -> (addBatch acc, sLen, 1, [s], [r])
|
||||
| otherwise -> (TBError TELargeMsg r : addBatch acc, 0, 0, [], [])
|
||||
where
|
||||
s = tEncodeForBatch t
|
||||
sLen = B.length s
|
||||
len' = len + sLen
|
||||
addBatch :: ([TransportBatch r], Int, Int, [ByteString], [r]) -> [TransportBatch r]
|
||||
@@ -1362,7 +1388,7 @@ tEncodeBatch1 t = lenEncode 1 `B.cons` tEncodeForBatch t
|
||||
-- tForAuth is lazy to avoid computing it when there is no key to sign
|
||||
data TransmissionForAuth = TransmissionForAuth {tForAuth :: ~ByteString, tToSend :: ByteString}
|
||||
|
||||
encodeTransmissionForAuth :: ProtocolEncoding e c => THandleParams -> Transmission c -> TransmissionForAuth
|
||||
encodeTransmissionForAuth :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> TransmissionForAuth
|
||||
encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId} t =
|
||||
TransmissionForAuth {tForAuth, tToSend = if implySessId then t' else tForAuth}
|
||||
where
|
||||
@@ -1370,24 +1396,24 @@ encodeTransmissionForAuth THandleParams {thVersion = v, sessionId, implySessId}
|
||||
t' = encodeTransmission_ v t
|
||||
{-# INLINE encodeTransmissionForAuth #-}
|
||||
|
||||
encodeTransmission :: ProtocolEncoding e c => THandleParams -> Transmission c -> ByteString
|
||||
encodeTransmission :: ProtocolEncoding v e c => THandleParams v -> Transmission c -> ByteString
|
||||
encodeTransmission THandleParams {thVersion = v, sessionId, implySessId} t =
|
||||
if implySessId then t' else smpEncode sessionId <> t'
|
||||
where
|
||||
t' = encodeTransmission_ v t
|
||||
{-# INLINE encodeTransmission #-}
|
||||
|
||||
encodeTransmission_ :: ProtocolEncoding e c => Version -> Transmission c -> ByteString
|
||||
encodeTransmission_ :: ProtocolEncoding v e c => Version v -> Transmission c -> ByteString
|
||||
encodeTransmission_ v (CorrId corrId, queueId, command) =
|
||||
smpEncode (corrId, queueId) <> encodeProtocol v command
|
||||
{-# INLINE encodeTransmission_ #-}
|
||||
|
||||
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
|
||||
tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission))
|
||||
tGetParse :: Transport c => THandle v c -> IO (NonEmpty (Either TransportError RawTransmission))
|
||||
tGetParse th@THandle {params} = eitherList (tParse params) <$> tGetBlock th
|
||||
{-# INLINE tGetParse #-}
|
||||
|
||||
tParse :: THandleParams -> ByteString -> NonEmpty (Either TransportError RawTransmission)
|
||||
tParse :: THandleParams v -> ByteString -> NonEmpty (Either TransportError RawTransmission)
|
||||
tParse thParams@THandleParams {batch} s
|
||||
| batch = eitherList (L.map (\(Large t) -> tParse1 t)) ts
|
||||
| otherwise = [tParse1 s]
|
||||
@@ -1399,24 +1425,24 @@ eitherList :: (a -> NonEmpty (Either e b)) -> Either e a -> NonEmpty (Either e b
|
||||
eitherList = either (\e -> [Left e])
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet :: forall err cmd c. (ProtocolEncoding err cmd, Transport c) => THandle c -> IO (NonEmpty (SignedTransmission err cmd))
|
||||
tGet :: forall v err cmd c. (ProtocolEncoding v err cmd, Transport c) => THandle v c -> IO (NonEmpty (SignedTransmission err cmd))
|
||||
tGet th@THandle {params} = L.map (tDecodeParseValidate params) <$> tGetParse th
|
||||
|
||||
tDecodeParseValidate :: forall err cmd. ProtocolEncoding err cmd => THandleParams -> Either TransportError RawTransmission -> SignedTransmission err cmd
|
||||
tDecodeParseValidate :: forall v err cmd. ProtocolEncoding v err cmd => THandleParams v -> Either TransportError RawTransmission -> SignedTransmission err cmd
|
||||
tDecodeParseValidate THandleParams {sessionId, thVersion = v, implySessId} = \case
|
||||
Right RawTransmission {authenticator, authorized, sessId, corrId, entityId, command}
|
||||
| implySessId || sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> decodeTAuthBytes authenticator
|
||||
in either (const $ tError corrId) (tParseValidate authorized) decodedTransmission
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PESession))
|
||||
| otherwise -> (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PESession))
|
||||
Left _ -> tError ""
|
||||
where
|
||||
tError :: ByteString -> SignedTransmission err cmd
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @err @cmd PEBlock))
|
||||
tError corrId = (Nothing, "", (CorrId corrId, "", Left $ fromProtocolError @v @err @cmd PEBlock))
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> SignedTransmission err cmd
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) =
|
||||
let cmd = parseProtocol @err @cmd v command >>= checkCredentials t
|
||||
let cmd = parseProtocol @v @err @cmd v command >>= checkCredentials t
|
||||
in (sig, signed, (CorrId corrId, entityId, cmd))
|
||||
|
||||
$(J.deriveJSON defaultJSON ''MsgFlags)
|
||||
|
||||
@@ -380,7 +380,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do
|
||||
CPQuit -> pure ()
|
||||
CPSkip -> pure ()
|
||||
|
||||
runClientTransport :: Transport c => THandle c -> M ()
|
||||
runClientTransport :: Transport c => THandleSMP c -> M ()
|
||||
runClientTransport th@THandle {params = THandleParams {thVersion, sessionId}} = do
|
||||
q <- asks $ tbqSize . config
|
||||
ts <- liftIO getSystemTime
|
||||
@@ -428,7 +428,7 @@ cancelSub sub =
|
||||
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
|
||||
_ -> return ()
|
||||
|
||||
receive :: Transport c => THandle c -> Client -> M ()
|
||||
receive :: Transport c => THandleSMP c -> Client -> M ()
|
||||
receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActiveAt, sessionId} = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " receive"
|
||||
forever $ do
|
||||
@@ -449,7 +449,7 @@ receive th@THandle {params = THandleParams {thAuth}} Client {rcvQ, sndQ, rcvActi
|
||||
VRFailed -> Left (corrId, queueId, ERR AUTH)
|
||||
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
||||
|
||||
send :: Transport c => THandle c -> Client -> IO ()
|
||||
send :: Transport c => THandleSMP c -> Client -> IO ()
|
||||
send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " send"
|
||||
forever $ do
|
||||
@@ -464,7 +464,7 @@ send h@THandle {params} Client {sndQ, sessionId, sndActiveAt} = do
|
||||
NMSG {} -> 0
|
||||
_ -> 1
|
||||
|
||||
disconnectTransport :: Transport c => THandle c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
|
||||
disconnectTransport :: Transport c => THandle v c -> TVar SystemTime -> TVar SystemTime -> ExpirationConfig -> IO Bool -> IO ()
|
||||
disconnectTransport THandle {connection, params = THandleParams {sessionId}} rcvActiveAt sndActiveAt expCfg noSubscriptions = do
|
||||
labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disconnectTransport"
|
||||
loop
|
||||
|
||||
@@ -33,9 +33,8 @@ import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport (ATransport, VersionSMP, VersionRangeSMP)
|
||||
import Simplex.Messaging.Transport.Server (SocketState, TransportServerConfig, loadFingerprint, loadTLSServerParams, newSocketState)
|
||||
import Simplex.Messaging.Version
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
@@ -73,7 +72,7 @@ data ServerConfig = ServerConfig
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath,
|
||||
-- | SMP client-server protocol version range
|
||||
smpServerVRange :: VersionRange,
|
||||
smpServerVRange :: VersionRangeSMP,
|
||||
-- | TCP transport config
|
||||
transportConfig :: TransportServerConfig,
|
||||
-- | run listener on control port
|
||||
@@ -128,7 +127,7 @@ data Client = Client
|
||||
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
|
||||
endThreads :: TVar (IntMap (Weak ThreadId)),
|
||||
endThreadSeq :: TVar Int,
|
||||
thVersion :: Version,
|
||||
thVersion :: VersionSMP,
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool,
|
||||
createdAt :: SystemTime,
|
||||
@@ -152,7 +151,7 @@ newServer = do
|
||||
savingLock <- createLock
|
||||
return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers, savingLock}
|
||||
|
||||
newClient :: TVar Int -> Natural -> Version -> ByteString -> SystemTime -> STM Client
|
||||
newClient :: TVar Int -> Natural -> VersionSMP -> ByteString -> SystemTime -> STM Client
|
||||
newClient nextClientId qSize thVersion sessionId createdAt = do
|
||||
clientId <- stateTVar nextClientId $ \next -> (next, next + 1)
|
||||
subscriptions <- TM.empty
|
||||
|
||||
@@ -17,14 +17,14 @@ data QueueRec = QueueRec
|
||||
notifier :: !(Maybe NtfCreds),
|
||||
status :: !ServerQueueStatus
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
data NtfCreds = NtfCreds
|
||||
{ notifierId :: !NotifierId,
|
||||
notifierKey :: !NtfPublicAuthKey,
|
||||
rcvNtfDhSecret :: !RcvNtfDhSecret
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
deriving (Show)
|
||||
|
||||
instance StrEncoding NtfCreds where
|
||||
strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
@@ -27,10 +28,15 @@
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
module Simplex.Messaging.Transport
|
||||
( -- * SMP transport parameters
|
||||
SMPVersion,
|
||||
VersionSMP,
|
||||
VersionRangeSMP,
|
||||
THandleSMP,
|
||||
supportedClientSMPRelayVRange,
|
||||
supportedServerSMPRelayVRange,
|
||||
currentClientSMPRelayVersion,
|
||||
currentServerSMPRelayVersion,
|
||||
batchCmdsSMPVersion,
|
||||
basicAuthSMPVersion,
|
||||
subModeSMPVersion,
|
||||
authCmdsSMPVersion,
|
||||
@@ -85,6 +91,7 @@ import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Default (def)
|
||||
import Data.Functor (($>))
|
||||
import Data.Version (showVersion)
|
||||
import Data.Word (Word16)
|
||||
import qualified Data.X509 as X
|
||||
import qualified Data.X509.Validation as XV
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
@@ -98,6 +105,7 @@ import Simplex.Messaging.Parsers (dropPrefix, parseRead1, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport.Buffer
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_, liftEitherWith)
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
@@ -116,30 +124,41 @@ smpBlockSize = 16384
|
||||
-- 6 - allow creating queues without subscribing (9/10/2023)
|
||||
-- 7 - support authenticated encryption to verify senders' commands, imply but do NOT send session ID in signed part (2/3/2024)
|
||||
|
||||
batchCmdsSMPVersion :: Version
|
||||
batchCmdsSMPVersion = 4
|
||||
data SMPVersion
|
||||
|
||||
basicAuthSMPVersion :: Version
|
||||
basicAuthSMPVersion = 5
|
||||
instance VersionScope SMPVersion
|
||||
|
||||
subModeSMPVersion :: Version
|
||||
subModeSMPVersion = 6
|
||||
type VersionSMP = Version SMPVersion
|
||||
|
||||
authCmdsSMPVersion :: Version
|
||||
authCmdsSMPVersion = 7
|
||||
type VersionRangeSMP = VersionRange SMPVersion
|
||||
|
||||
currentClientSMPRelayVersion :: Version
|
||||
currentClientSMPRelayVersion = 6
|
||||
pattern VersionSMP :: Word16 -> VersionSMP
|
||||
pattern VersionSMP v = Version v
|
||||
|
||||
currentServerSMPRelayVersion :: Version
|
||||
currentServerSMPRelayVersion = 6
|
||||
batchCmdsSMPVersion :: VersionSMP
|
||||
batchCmdsSMPVersion = VersionSMP 4
|
||||
|
||||
basicAuthSMPVersion :: VersionSMP
|
||||
basicAuthSMPVersion = VersionSMP 5
|
||||
|
||||
subModeSMPVersion :: VersionSMP
|
||||
subModeSMPVersion = VersionSMP 6
|
||||
|
||||
authCmdsSMPVersion :: VersionSMP
|
||||
authCmdsSMPVersion = VersionSMP 7
|
||||
|
||||
currentClientSMPRelayVersion :: VersionSMP
|
||||
currentClientSMPRelayVersion = VersionSMP 6
|
||||
|
||||
currentServerSMPRelayVersion :: VersionSMP
|
||||
currentServerSMPRelayVersion = VersionSMP 6
|
||||
|
||||
-- minimal supported protocol version is 4
|
||||
-- TODO remove code that supports sending commands without batching
|
||||
supportedClientSMPRelayVRange :: VersionRange
|
||||
supportedClientSMPRelayVRange :: VersionRangeSMP
|
||||
supportedClientSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentClientSMPRelayVersion
|
||||
|
||||
supportedServerSMPRelayVRange :: VersionRange
|
||||
supportedServerSMPRelayVRange :: VersionRangeSMP
|
||||
supportedServerSMPRelayVRange = mkVersionRange batchCmdsSMPVersion currentServerSMPRelayVersion
|
||||
|
||||
simplexMQVersion :: String
|
||||
@@ -287,16 +306,18 @@ instance Transport TLS where
|
||||
-- * SMP transport
|
||||
|
||||
-- | The handle for SMP encrypted transport connection over Transport.
|
||||
data THandle c = THandle
|
||||
data THandle v c = THandle
|
||||
{ connection :: c,
|
||||
params :: THandleParams
|
||||
params :: THandleParams v
|
||||
}
|
||||
|
||||
data THandleParams = THandleParams
|
||||
type THandleSMP c = THandle SMPVersion c
|
||||
|
||||
data THandleParams v = THandleParams
|
||||
{ sessionId :: SessionId,
|
||||
blockSize :: Int,
|
||||
-- | agreed server protocol version
|
||||
thVersion :: Version,
|
||||
thVersion :: Version v,
|
||||
-- | peer public key for command authorization and shared secrets for entity ID encryption
|
||||
thAuth :: Maybe THandleAuth,
|
||||
-- | do NOT send session ID in transmission, but include it into signed message
|
||||
@@ -316,7 +337,7 @@ data THandleAuth = THandleAuth
|
||||
type SessionId = ByteString
|
||||
|
||||
data ServerHandshake = ServerHandshake
|
||||
{ smpVersionRange :: VersionRange,
|
||||
{ smpVersionRange :: VersionRangeSMP,
|
||||
sessionId :: SessionId,
|
||||
-- pub key to agree shared secrets for command authorization and entity ID encryption.
|
||||
authPubKey :: Maybe (X.CertificateChain, X.SignedExact X.PubKey)
|
||||
@@ -324,7 +345,7 @@ data ServerHandshake = ServerHandshake
|
||||
|
||||
data ClientHandshake = ClientHandshake
|
||||
{ -- | agreed SMP server protocol version
|
||||
smpVersion :: Version,
|
||||
smpVersion :: VersionSMP,
|
||||
-- | server identity - CA certificate fingerprint
|
||||
keyHash :: C.KeyHash,
|
||||
-- pub key to agree shared secret for entity ID encryption, shared secret for command authorization is agreed using per-queue keys.
|
||||
@@ -358,12 +379,12 @@ instance Encoding ServerHandshake where
|
||||
C.SignedObject key <- smpP
|
||||
pure (cert, key)
|
||||
|
||||
encodeAuthEncryptCmds :: Encoding a => Version -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds :: Encoding a => VersionSMP -> Maybe a -> ByteString
|
||||
encodeAuthEncryptCmds v k
|
||||
| v >= authCmdsSMPVersion = maybe "" smpEncode k
|
||||
| otherwise = ""
|
||||
|
||||
authEncryptCmdsP :: Version -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP :: VersionSMP -> Parser a -> Parser (Maybe a)
|
||||
authEncryptCmdsP v p = if v >= authCmdsSMPVersion then Just <$> p else pure Nothing
|
||||
|
||||
-- | Error of SMP encrypted transport over TCP.
|
||||
@@ -412,13 +433,13 @@ serializeTransportError = \case
|
||||
TEHandshake e -> "HANDSHAKE " <> bshow e
|
||||
|
||||
-- | Pad and send block to SMP transport.
|
||||
tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
|
||||
tPutBlock :: Transport c => THandle v c -> ByteString -> IO (Either TransportError ())
|
||||
tPutBlock THandle {connection = c, params = THandleParams {blockSize}} block =
|
||||
bimapM (const $ pure TELargeMsg) (cPut c) $
|
||||
C.pad block blockSize
|
||||
|
||||
-- | Receive block from SMP transport.
|
||||
tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString)
|
||||
tGetBlock :: Transport c => THandle v c -> IO (Either TransportError ByteString)
|
||||
tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
|
||||
msg <- cGet c blockSize
|
||||
if B.length msg == blockSize
|
||||
@@ -428,7 +449,7 @@ tGetBlock THandle {connection = c, params = THandleParams {blockSize}} = do
|
||||
-- | Server SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpServerHandshake :: forall c. Transport c => C.APrivateSignKey -> c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
|
||||
smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
|
||||
sk = C.signX509 serverSignKey $ C.publicToX509 k
|
||||
@@ -445,7 +466,7 @@ smpServerHandshake serverSignKey c (k, pk) kh smpVRange = do
|
||||
-- | Client SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpClientHandshake :: forall c. Transport c => c -> C.KeyPairX25519 -> C.KeyHash -> VersionRangeSMP -> ExceptT TransportError IO (THandleSMP c)
|
||||
smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
|
||||
let th@THandle {params = THandleParams {sessionId}} = smpTHandle c
|
||||
ServerHandshake {sessionId = sessId, smpVersionRange, authPubKey} <- getHandshake th
|
||||
@@ -465,24 +486,24 @@ smpClientHandshake c (k, pk) keyHash@(C.KeyHash kh) smpVRange = do
|
||||
pure $ smpThHandle th v pk sk_
|
||||
Nothing -> throwE $ TEHandshake VERSION
|
||||
|
||||
smpThHandle :: forall c. THandle c -> Version -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandle c
|
||||
smpThHandle :: forall c. THandleSMP c -> VersionSMP -> C.PrivateKeyX25519 -> Maybe C.PublicKeyX25519 -> THandleSMP c
|
||||
smpThHandle th@THandle {params} v privKey k_ =
|
||||
-- TODO drop SMP v6: make thAuth non-optional
|
||||
let thAuth = (\k -> THandleAuth {peerPubKey = k, privKey}) <$> k_
|
||||
params' = params {thVersion = v, thAuth, implySessId = v >= authCmdsSMPVersion}
|
||||
in (th :: THandle c) {params = params'}
|
||||
in (th :: THandleSMP c) {params = params'}
|
||||
|
||||
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
|
||||
sendHandshake :: (Transport c, Encoding smp) => THandle v c -> smp -> ExceptT TransportError IO ()
|
||||
sendHandshake th = ExceptT . tPutBlock th . smpEncode
|
||||
|
||||
-- ignores tail bytes to allow future extensions
|
||||
getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp
|
||||
getHandshake :: (Transport c, Encoding smp) => THandle v c -> ExceptT TransportError IO smp
|
||||
getHandshake th = ExceptT $ (first (\_ -> TEHandshake PARSE) . A.parseOnly smpP =<<) <$> tGetBlock th
|
||||
|
||||
smpTHandle :: Transport c => c -> THandle c
|
||||
smpTHandle :: Transport c => c -> THandleSMP c
|
||||
smpTHandle c = THandle {connection = c, params}
|
||||
where
|
||||
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, thAuth = Nothing, implySessId = False, batch = True}
|
||||
params = THandleParams {sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = VersionSMP 0, thAuth = Nothing, implySessId = False, batch = True}
|
||||
|
||||
$(J.deriveJSON (sumTypeJSON id) ''HandshakeError)
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
{-# LANGUAGE ConstrainedClassMethods #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
|
||||
module Simplex.Messaging.Version
|
||||
( Version,
|
||||
VersionRange (minVersion, maxVersion),
|
||||
VersionScope,
|
||||
pattern VersionRange,
|
||||
VersionI (..),
|
||||
VersionRangeI (..),
|
||||
@@ -24,47 +27,61 @@ module Simplex.Messaging.Version
|
||||
where
|
||||
|
||||
import Control.Applicative (optional)
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import Data.Aeson.Types ((.:), (.=))
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Simplex.Messaging.Version.Internal (Version (..))
|
||||
|
||||
pattern VersionRange :: Word16 -> Word16 -> VersionRange
|
||||
pattern VersionRange :: Version v -> Version v -> VersionRange v
|
||||
pattern VersionRange v1 v2 <- VRange v1 v2
|
||||
|
||||
{-# COMPLETE VersionRange #-}
|
||||
|
||||
type Version = Word16
|
||||
|
||||
data VersionRange = VRange
|
||||
{ minVersion :: Version,
|
||||
maxVersion :: Version
|
||||
data VersionRange v = VRange
|
||||
{ minVersion :: Version v,
|
||||
maxVersion :: Version v
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance J.FromJSON (VersionRange v) where
|
||||
parseJSON (J.Object v) = do
|
||||
minVersion <- v .: "minVersion"
|
||||
maxVersion <- v .: "maxVersion"
|
||||
pure VRange {minVersion, maxVersion}
|
||||
parseJSON invalid =
|
||||
JT.prependFailure "bad VersionRange, " (JT.typeMismatch "Object" invalid)
|
||||
|
||||
instance J.ToJSON (VersionRange v) where
|
||||
toEncoding VRange {minVersion, maxVersion} = JE.pairs $ ("minVersion" .= minVersion) <> ("maxVersion" .= maxVersion)
|
||||
toJSON VRange {minVersion, maxVersion} = J.object ["minVersion" .= minVersion, "maxVersion" .= maxVersion]
|
||||
|
||||
class VersionScope v
|
||||
|
||||
-- | construct valid version range, to be used in constants
|
||||
mkVersionRange :: Version -> Version -> VersionRange
|
||||
mkVersionRange :: Version v -> Version v -> VersionRange v
|
||||
mkVersionRange v1 v2
|
||||
| v1 <= v2 = VRange v1 v2
|
||||
| otherwise = error "invalid version range"
|
||||
|
||||
safeVersionRange :: Version -> Version -> Maybe VersionRange
|
||||
safeVersionRange :: Version v -> Version v -> Maybe (VersionRange v)
|
||||
safeVersionRange v1 v2
|
||||
| v1 <= v2 = Just $ VRange v1 v2
|
||||
| otherwise = Nothing
|
||||
|
||||
versionToRange :: Version -> VersionRange
|
||||
versionToRange :: Version v -> VersionRange v
|
||||
versionToRange v = VRange v v
|
||||
|
||||
instance Encoding VersionRange where
|
||||
instance VersionScope v => Encoding (VersionRange v) where
|
||||
smpEncode (VRange v1 v2) = smpEncode (v1, v2)
|
||||
smpP =
|
||||
maybe (fail "invalid version range") pure
|
||||
=<< safeVersionRange <$> smpP <*> smpP
|
||||
|
||||
instance StrEncoding VersionRange where
|
||||
instance VersionScope v => StrEncoding (VersionRange v) where
|
||||
strEncode (VRange v1 v2)
|
||||
| v1 == v2 = strEncode v1
|
||||
| otherwise = strEncode v1 <> "-" <> strEncode v2
|
||||
@@ -73,32 +90,23 @@ instance StrEncoding VersionRange where
|
||||
v2 <- maybe (pure v1) (const strP) =<< optional (A.char '-')
|
||||
maybe (fail "invalid version range") pure $ safeVersionRange v1 v2
|
||||
|
||||
instance ToJSON VersionRange where
|
||||
toJSON (VRange v1 v2) = toJSON (v1, v2)
|
||||
toEncoding (VRange v1 v2) = toEncoding (v1, v2)
|
||||
class VersionScope v => VersionI v a | a -> v where
|
||||
type VersionRangeT v a
|
||||
version :: a -> Version v
|
||||
toVersionRangeT :: a -> VersionRange v -> VersionRangeT v a
|
||||
|
||||
instance FromJSON VersionRange where
|
||||
parseJSON v =
|
||||
(\(v1, v2) -> maybe (Left "bad VersionRange") Right $ safeVersionRange v1 v2)
|
||||
<$?> parseJSON v
|
||||
class VersionScope v => VersionRangeI v a | a -> v where
|
||||
type VersionT v a
|
||||
versionRange :: a -> VersionRange v
|
||||
toVersionT :: a -> Version v -> VersionT v a
|
||||
|
||||
class VersionI a where
|
||||
type VersionRangeT a
|
||||
version :: a -> Version
|
||||
toVersionRangeT :: a -> VersionRange -> VersionRangeT a
|
||||
|
||||
class VersionRangeI a where
|
||||
type VersionT a
|
||||
versionRange :: a -> VersionRange
|
||||
toVersionT :: a -> Version -> VersionT a
|
||||
|
||||
instance VersionI Version where
|
||||
type VersionRangeT Version = VersionRange
|
||||
instance VersionScope v => VersionI v (Version v) where
|
||||
type VersionRangeT v (Version v) = VersionRange v
|
||||
version = id
|
||||
toVersionRangeT _ vr = vr
|
||||
|
||||
instance VersionRangeI VersionRange where
|
||||
type VersionT VersionRange = Version
|
||||
instance VersionScope v => VersionRangeI v (VersionRange v) where
|
||||
type VersionT v (VersionRange v) = Version v
|
||||
versionRange = id
|
||||
toVersionT _ v = v
|
||||
|
||||
@@ -109,18 +117,18 @@ pattern Compatible a <- Compatible_ a
|
||||
|
||||
{-# COMPLETE Compatible #-}
|
||||
|
||||
isCompatible :: VersionI a => a -> VersionRange -> Bool
|
||||
isCompatible :: VersionI v a => a -> VersionRange v -> Bool
|
||||
isCompatible x (VRange v1 v2) = let v = version x in v1 <= v && v <= v2
|
||||
|
||||
isCompatibleRange :: VersionRangeI a => a -> VersionRange -> Bool
|
||||
isCompatibleRange :: VersionRangeI v a => a -> VersionRange v -> Bool
|
||||
isCompatibleRange x (VRange min2 max2) = min1 <= max2 && min2 <= max1
|
||||
where
|
||||
VRange min1 max1 = versionRange x
|
||||
|
||||
proveCompatible :: VersionI a => a -> VersionRange -> Maybe (Compatible a)
|
||||
proveCompatible :: VersionI v a => a -> VersionRange v -> Maybe (Compatible a)
|
||||
proveCompatible x vr = x `mkCompatibleIf` (x `isCompatible` vr)
|
||||
|
||||
compatibleVersion :: VersionRangeI a => a -> VersionRange -> Maybe (Compatible (VersionT a))
|
||||
compatibleVersion :: VersionRangeI v a => a -> VersionRange v -> Maybe (Compatible (VersionT v a))
|
||||
compatibleVersion x vr =
|
||||
toVersionT x (min max1 max2) `mkCompatibleIf` isCompatibleRange x vr
|
||||
where
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
module Simplex.Messaging.Version.Internal where
|
||||
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
|
||||
-- Do not use constructor of this type directry
|
||||
newtype Version v = Version Word16
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding (Version v) where
|
||||
smpEncode (Version v) = smpEncode v
|
||||
smpP = Version <$> smpP
|
||||
|
||||
instance StrEncoding (Version v) where
|
||||
strEncode (Version v) = strEncode v
|
||||
strP = Version <$> strP
|
||||
|
||||
instance ToJSON (Version v) where
|
||||
toEncoding (Version v) = toEncoding v
|
||||
toJSON (Version v) = toJSON v
|
||||
|
||||
instance FromJSON (Version v) where
|
||||
parseJSON v = Version <$> parseJSON v
|
||||
@@ -68,12 +68,6 @@ import Simplex.RemoteControl.Types
|
||||
import UnliftIO
|
||||
import UnliftIO.Concurrent
|
||||
|
||||
currentRCVersion :: Version
|
||||
currentRCVersion = 1
|
||||
|
||||
supportedRCVRange :: VersionRange
|
||||
supportedRCVRange = mkVersionRange 1 currentRCVersion
|
||||
|
||||
xrcpBlockSize :: Int
|
||||
xrcpBlockSize = 16384
|
||||
|
||||
@@ -181,7 +175,7 @@ connectRCHost drg pairing@RCHostPairing {caKey, caCert, idPrivKey, knownHost} ct
|
||||
{ ca = certFingerprint caCert,
|
||||
host,
|
||||
port = fromIntegral portNum,
|
||||
v = supportedRCVRange,
|
||||
v = supportedRCPVRange,
|
||||
app = ctrlAppInfo,
|
||||
ts,
|
||||
skey = fst sessKeys,
|
||||
@@ -220,7 +214,7 @@ prepareHostSession
|
||||
unless (ca == tlsHostFingerprint) $ throwError RCEIdentity
|
||||
(kemCiphertext, kemSharedKey) <- liftIO $ sntrup761Enc drg kemPubKey
|
||||
let hybridKey = kemHybridSecret dhPubKey dhPrivKey kemSharedKey
|
||||
unless (isCompatible v supportedRCVRange) $ throwError RCEVersion
|
||||
unless (isCompatible v supportedRCPVRange) $ throwError RCEVersion
|
||||
let keys = HostSessKeys {hybridKey, idPrivKey, sessPrivKey}
|
||||
knownHost' <- updateKnownHost ca dhPubKey
|
||||
let ctrlHello = RCCtrlHello {}
|
||||
@@ -334,7 +328,7 @@ prepareHostHello
|
||||
RCInvitation {v, dh = dhPubKey}
|
||||
hostAppInfo = do
|
||||
logDebug "Preparing session"
|
||||
case compatibleVersion v supportedRCVRange of
|
||||
case compatibleVersion v supportedRCPVRange of
|
||||
Nothing -> throwError RCEVersion
|
||||
Just (Compatible v') -> do
|
||||
nonce <- liftIO . atomically $ C.randomCbNonce drg
|
||||
|
||||
@@ -27,7 +27,7 @@ import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Version (VersionRange)
|
||||
import Simplex.RemoteControl.Types (VersionRangeRCP)
|
||||
|
||||
data RCInvitation = RCInvitation
|
||||
{ -- | CA TLS certificate fingerprint of the controller.
|
||||
@@ -37,7 +37,7 @@ data RCInvitation = RCInvitation
|
||||
host :: TransportHost,
|
||||
port :: Word16,
|
||||
-- | Supported version range for remote control protocol
|
||||
v :: VersionRange,
|
||||
v :: VersionRangeRCP,
|
||||
-- | Application information
|
||||
app :: J.Value,
|
||||
-- | Session start time in seconds since epoch
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
@@ -17,6 +18,7 @@ import Data.ByteString (ByteString)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Word (Word16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.SNTRUP761
|
||||
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
|
||||
@@ -26,7 +28,8 @@ import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON)
|
||||
import Simplex.Messaging.Transport (TLS)
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util (safeDecodeUtf8)
|
||||
import Simplex.Messaging.Version (Version, VersionRange, mkVersionRange)
|
||||
import Simplex.Messaging.Version (VersionRange, VersionScope, mkVersionRange)
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import UnliftIO
|
||||
|
||||
data RCErrorType
|
||||
@@ -92,24 +95,37 @@ instance StrEncoding RCErrorType where
|
||||
|
||||
-- * Discovery
|
||||
|
||||
ipProbeVersionRange :: VersionRange
|
||||
ipProbeVersionRange = mkVersionRange 1 1
|
||||
data RCPVersion
|
||||
|
||||
instance VersionScope RCPVersion
|
||||
|
||||
type VersionRCP = Version RCPVersion
|
||||
|
||||
type VersionRangeRCP = VersionRange RCPVersion
|
||||
|
||||
pattern VersionRCP :: Word16 -> VersionRCP
|
||||
pattern VersionRCP v = Version v
|
||||
|
||||
currentRCPVersion :: VersionRCP
|
||||
currentRCPVersion = VersionRCP 1
|
||||
|
||||
supportedRCPVRange :: VersionRangeRCP
|
||||
supportedRCPVRange = mkVersionRange (VersionRCP 1) currentRCPVersion
|
||||
|
||||
data IpProbe = IpProbe
|
||||
{ versionRange :: VersionRange,
|
||||
{ versionRange :: VersionRangeRCP,
|
||||
randomNonce :: ByteString
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding IpProbe where
|
||||
smpEncode IpProbe {versionRange, randomNonce} = smpEncode (versionRange, 'I', randomNonce)
|
||||
|
||||
smpP = IpProbe <$> (smpP <* "I") *> smpP
|
||||
|
||||
-- * Session
|
||||
|
||||
data RCHostHello = RCHostHello
|
||||
{ v :: Version,
|
||||
{ v :: VersionRCP,
|
||||
ca :: C.KeyHash,
|
||||
app :: J.Value,
|
||||
kem :: KEMPublicKey
|
||||
|
||||
+191
-107
@@ -4,6 +4,7 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PostfixOperators #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
@@ -12,12 +13,12 @@ module AgentTests (agentTests) where
|
||||
|
||||
import AgentTests.ConnectionRequestTests
|
||||
import AgentTests.DoubleRatchetTests (doubleRatchetTests)
|
||||
import AgentTests.FunctionalAPITests (functionalAPITests)
|
||||
import AgentTests.FunctionalAPITests (functionalAPITests, pattern Msg, pattern Msg')
|
||||
import AgentTests.MigrationTests (migrationTests)
|
||||
import AgentTests.NotificationTests (notificationTests)
|
||||
import AgentTests.SQLiteTests (storeTests)
|
||||
import Control.Concurrent
|
||||
import Control.Monad (forM_)
|
||||
import Control.Monad (forM_, when)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Maybe (fromJust)
|
||||
@@ -26,15 +27,18 @@ import GHC.Stack (withFrozenCallStack)
|
||||
import Network.HTTP.Types (urlEncode)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Protocol hiding (MID, CONF, INFO, REQ)
|
||||
import qualified Simplex.Messaging.Agent.Protocol as A
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOn, pattern IKPQOff, pattern PQEncOn, pattern PQSupportOn, pattern PQSupportOff)
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Protocol (ErrorType (..))
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import System.Directory (removeFile)
|
||||
import System.Timeout
|
||||
import Test.Hspec
|
||||
import Util
|
||||
|
||||
agentTests :: ATransport -> Spec
|
||||
agentTests (ATransport t) = do
|
||||
@@ -46,24 +50,25 @@ agentTests (ATransport t) = do
|
||||
describe "Migration tests" migrationTests
|
||||
describe "SMP agent protocol syntax" $ syntaxTests t
|
||||
describe "Establishing duplex connection (via agent protocol)" $ do
|
||||
-- These tests are disabled because the agent does not work correctly with multiple connected TCP clients
|
||||
xit "should connect via one server and one agent" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnection t
|
||||
xit "should connect via one server and one agent (random IDs)" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
|
||||
skip "These tests are disabled because the agent does not work correctly with multiple connected TCP clients" $
|
||||
describe "one agent" $ do
|
||||
it "should connect via one server and one agent" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnection t
|
||||
it "should connect via one server and one agent (random IDs)" $ do
|
||||
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via one server and 2 agents" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnection t
|
||||
it "should connect via one server and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via 2 servers and 2 agents" $ do
|
||||
smpAgentTest2_2_2 $ testDuplexConnection t
|
||||
it "should connect via 2 servers and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
|
||||
describe "should connect via 2 servers and 2 agents" $ do
|
||||
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnection'
|
||||
describe "should connect via 2 servers and 2 agents (random IDs)" $ do
|
||||
pqMatrix2 t smpAgentTest2_2_2 testDuplexConnRandomIds'
|
||||
describe "Establishing connections via `contact connection`" $ do
|
||||
it "should connect via contact connection with one server and 3 agents" $ do
|
||||
smpAgentTest3 $ testContactConnection t
|
||||
it "should connect via contact connection with one server and 2 agents (random IDs)" $ do
|
||||
smpAgentTest2_2_1 $ testContactConnRandomIds t
|
||||
describe "should connect via contact connection with one server and 3 agents" $ do
|
||||
pqMatrix3 t smpAgentTest3 testContactConnection
|
||||
describe "should connect via contact connection with one server and 2 agents (random IDs)" $ do
|
||||
pqMatrix2NoInv t smpAgentTest2_2_1 testContactConnRandomIds
|
||||
it "should support rejecting contact request" $ do
|
||||
smpAgentTest2_2_1 $ testRejectContactRequest t
|
||||
describe "Connection subscriptions" $ do
|
||||
@@ -72,8 +77,8 @@ agentTests (ATransport t) = do
|
||||
it "should send notifications to client when server disconnects" $ do
|
||||
smpAgentServerTest $ testSubscrNotification t
|
||||
describe "Message delivery and server reconnection" $ do
|
||||
it "should deliver messages after losing server connection and re-connecting" $ do
|
||||
smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t
|
||||
describe "should deliver messages after losing server connection and re-connecting" $
|
||||
pqMatrix2 t smpAgentTest2_2_2_needs_server testMsgDeliveryServerRestart
|
||||
it "should connect to the server when server goes up if it initially was down" $ do
|
||||
smpAgentTestN [] $ testServerConnectionAfterError t
|
||||
it "should deliver pending messages after agent restarting" $ do
|
||||
@@ -133,6 +138,9 @@ action #> (corrId, connId, cmd) = withFrozenCallStack $ action `shouldReturn` (c
|
||||
(=#>) :: IO (AEntityTransmissionOrError 'Agent 'AEConn) -> (AEntityTransmission 'Agent 'AEConn -> Bool) -> Expectation
|
||||
action =#> p = withFrozenCallStack $ action >>= (`shouldSatisfy` p . correctTransmission)
|
||||
|
||||
pattern MID :: AgentMsgId -> ACommand 'Agent 'AEConn
|
||||
pattern MID msgId = A.MID msgId PQEncOn
|
||||
|
||||
correctTransmission :: (ACorrId, ConnId, Either AgentErrorType cmd) -> (ACorrId, ConnId, cmd)
|
||||
correctTransmission (corrId, connId, cmdOrErr) = case cmdOrErr of
|
||||
Right cmd -> (corrId, connId, cmd)
|
||||
@@ -161,130 +169,188 @@ h #:# err = tryGet `shouldReturn` ()
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent e
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
|
||||
type PQMatrix2 c =
|
||||
HasCallStack =>
|
||||
TProxy c ->
|
||||
(HasCallStack => (c -> c -> IO ()) -> Expectation) ->
|
||||
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> IO ()) ->
|
||||
Spec
|
||||
|
||||
pattern Msg' :: AgentMsgId -> MsgBody -> ACommand 'Agent e
|
||||
pattern Msg' aMsgId msgBody <- MSG MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} _ msgBody
|
||||
pqMatrix2 :: PQMatrix2 c
|
||||
pqMatrix2 = pqMatrix2_ True
|
||||
|
||||
pqMatrix2NoInv :: PQMatrix2 c
|
||||
pqMatrix2NoInv = pqMatrix2_ False
|
||||
|
||||
pqMatrix2_ :: Bool -> PQMatrix2 c
|
||||
pqMatrix2_ pqInv _ smpTest test = do
|
||||
it "dh/dh handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOff)
|
||||
it "dh/pq handshake" $ smpTest $ \a b -> test (a, IKPQOff) (b, PQSupportOn)
|
||||
it "pq/dh handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOff)
|
||||
it "pq/pq handshake" $ smpTest $ \a b -> test (a, IKPQOn) (b, PQSupportOn)
|
||||
when pqInv $ do
|
||||
it "pq-inv/dh handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOff)
|
||||
it "pq-inv/pq handshake" $ smpTest $ \a b -> test (a, IKUsePQ) (b, PQSupportOn)
|
||||
|
||||
pqMatrix3 ::
|
||||
HasCallStack =>
|
||||
TProxy c ->
|
||||
(HasCallStack => (c -> c -> c -> IO ()) -> Expectation) ->
|
||||
(HasCallStack => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()) ->
|
||||
Spec
|
||||
pqMatrix3 _ smpTest test = do
|
||||
it "dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "dh/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "dh/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "dh/pq/pq" $ smpTest $ \a b c -> test (a, IKPQOff) (b, PQSupportOn) (c, PQSupportOn)
|
||||
it "pq/dh/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOff)
|
||||
it "pq/dh/pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOff) (c, PQSupportOn)
|
||||
it "pq/pq/dh" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOff)
|
||||
it "pq" $ smpTest $ \a b c -> test (a, IKPQOn) (b, PQSupportOn) (c, PQSupportOn)
|
||||
|
||||
testDuplexConnection :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = do
|
||||
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV subscribe")
|
||||
testDuplexConnection _ alice bob = testDuplexConnection' (alice, IKPQOn) (bob, PQSupportOn)
|
||||
|
||||
testDuplexConnection' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testDuplexConnection' (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
pqSup = CR.pqEncToSupport pq
|
||||
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "bob", Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` pqSup
|
||||
alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
bob <# ("", "alice", INFO "alice's connInfo")
|
||||
bob <# ("", "alice", CON)
|
||||
alice <# ("", "bob", CON)
|
||||
bob <# ("", "alice", A.INFO pqSup "alice's connInfo")
|
||||
bob <# ("", "alice", CON pq)
|
||||
alice <# ("", "bob", CON pq)
|
||||
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
|
||||
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", MID 4)
|
||||
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", A.MID 4 pq)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg' 4 "hello") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg' 4 pq' "hello") -> pq == pq'; _ -> False
|
||||
bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK)
|
||||
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", MID 5)
|
||||
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", A.MID 5 pq)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
bob <#= \case ("", "alice", Msg' 5 "how are you?") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg' 5 pq' "how are you?") -> pq == pq'; _ -> False
|
||||
bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 6)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", A.MID 6 pq)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice <#= \case ("", "bob", Msg' 6 "hello too") -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg' 6 pq' "hello too") -> pq == pq'; _ -> False
|
||||
alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK)
|
||||
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", MID 7)
|
||||
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", A.MID 7 pq)
|
||||
bob <# ("", "alice", SENT 7)
|
||||
alice <#= \case ("", "bob", Msg' 7 "message 1") -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg' 7 pq' "message 1") -> pq == pq'; _ -> False
|
||||
alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK)
|
||||
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
|
||||
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", MID 8)
|
||||
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", A.MID 8 pq)
|
||||
bob <# ("", "alice", MERR 8 (SMP AUTH))
|
||||
alice #: ("6", "bob", "DEL") #> ("6", "bob", OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = do
|
||||
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV subscribe")
|
||||
testDuplexConnRandomIds :: (HasCallStack, Transport c) => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = testDuplexConnRandomIds' (alice, IKPQOn) (bob, PQSupportOn)
|
||||
|
||||
testDuplexConnRandomIds' :: (HasCallStack, Transport c) => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testDuplexConnRandomIds' (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
pqSup = CR.pqEncToSupport pq
|
||||
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo")
|
||||
("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
|
||||
("", bobConn', Right (A.CONF confId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` pqSup
|
||||
bobConn' `shouldBe` bobConn
|
||||
alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False
|
||||
bob <# ("", aliceConn, INFO "alice's connInfo")
|
||||
bob <# ("", aliceConn, CON)
|
||||
alice <# ("", bobConn, CON)
|
||||
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, MID 4)
|
||||
bob <# ("", aliceConn, A.INFO pqSup "alice's connInfo")
|
||||
bob <# ("", aliceConn, CON pq)
|
||||
alice <# ("", bobConn, CON pq)
|
||||
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, A.MID 4 pq)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False
|
||||
bob <#= \case ("", c, Msg' 4 pq' "hello") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK)
|
||||
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, MID 5)
|
||||
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, A.MID 5 pq)
|
||||
alice <# ("", bobConn, SENT 5)
|
||||
bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False
|
||||
bob <#= \case ("", c, Msg' 5 pq' "how are you?") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK)
|
||||
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, MID 6)
|
||||
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, A.MID 6 pq)
|
||||
bob <# ("", aliceConn, SENT 6)
|
||||
alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False
|
||||
alice <#= \case ("", c, Msg' 6 pq' "hello too") -> c == bobConn && pq == pq'; _ -> False
|
||||
alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK)
|
||||
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, MID 7)
|
||||
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, A.MID 7 pq)
|
||||
bob <# ("", aliceConn, SENT 7)
|
||||
alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False
|
||||
alice <#= \case ("", c, Msg' 7 pq' "message 1") -> c == bobConn && pq == pq'; _ -> False
|
||||
alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK)
|
||||
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
|
||||
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, MID 8)
|
||||
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, A.MID 8 pq)
|
||||
bob <# ("", aliceConn, MERR 8 (SMP AUTH))
|
||||
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testContactConnection _ alice bob tom = do
|
||||
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON subscribe")
|
||||
testContactConnection :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> (c, PQSupport) -> IO ()
|
||||
testContactConnection (alice, aPQ) (bob, bPQ) (tom, tPQ) = do
|
||||
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
abPQ = pqConnectionMode aPQ bPQ
|
||||
abPQSup = CR.pqEncToSupport abPQ
|
||||
aPQMode = CR.connPQEncryption aPQ
|
||||
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
|
||||
alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "alice_contact", Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` bPQ
|
||||
alice #: ("2", "bob", "ACPT " <> aInvId <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
("", "alice", Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
|
||||
pqSup'' `shouldBe` abPQSup
|
||||
bob #: ("12", "alice", "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", "alice", OK)
|
||||
alice <# ("", "bob", INFO "bob's connInfo 2")
|
||||
alice <# ("", "bob", CON)
|
||||
bob <# ("", "alice", CON)
|
||||
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", MID 4)
|
||||
alice <# ("", "bob", A.INFO abPQSup "bob's connInfo 2")
|
||||
alice <# ("", "bob", CON abPQ)
|
||||
bob <# ("", "alice", CON abPQ)
|
||||
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", A.MID 4 abPQ)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg' 4 pq' "hi") -> pq' == abPQ; _ -> False
|
||||
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
|
||||
|
||||
tom #: ("21", "alice", "JOIN T " <> cReq' <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:)
|
||||
alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK)
|
||||
("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:)
|
||||
let atPQ = pqConnectionMode aPQ tPQ
|
||||
atPQSup = CR.pqEncToSupport atPQ
|
||||
tom #: ("21", "alice", "JOIN T " <> cReq' <> enableKEMStr tPQ <> " subscribe 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
("", "alice_contact", Right (A.REQ aInvId' pqSup3 _ "tom's connInfo")) <- (alice <#:)
|
||||
pqSup3 `shouldBe` tPQ
|
||||
alice #: ("4", "tom", "ACPT " <> aInvId' <> enableKEMStr aPQMode <> " 16\nalice's connInfo") #> ("4", "tom", OK)
|
||||
("", "alice", Right (A.CONF tConfId pqSup4 _ "alice's connInfo")) <- (tom <#:)
|
||||
pqSup4 `shouldBe` atPQSup
|
||||
tom #: ("22", "alice", "LET " <> tConfId <> " 16\ntom's connInfo 2") #> ("22", "alice", OK)
|
||||
alice <# ("", "tom", INFO "tom's connInfo 2")
|
||||
alice <# ("", "tom", CON)
|
||||
tom <# ("", "alice", CON)
|
||||
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", MID 4)
|
||||
alice <# ("", "tom", A.INFO atPQSup "tom's connInfo 2")
|
||||
alice <# ("", "tom", CON atPQ)
|
||||
tom <# ("", "alice", CON atPQ)
|
||||
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", A.MID 4 atPQ)
|
||||
alice <# ("", "tom", SENT 4)
|
||||
tom <#= \case ("", "alice", Msg "hi there") -> True; _ -> False
|
||||
tom <#= \case ("", "alice", Msg' 4 pq' "hi there") -> pq' == atPQ; _ -> False
|
||||
tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK)
|
||||
|
||||
testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testContactConnRandomIds _ alice bob = do
|
||||
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON subscribe")
|
||||
testContactConnRandomIds :: Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testContactConnRandomIds (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
pqSup = CR.pqEncToSupport pq
|
||||
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON" <> pqConnModeStr aPQ <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo")
|
||||
("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> enableKEMStr bPQ <> " subscribe 14\nbob's connInfo")
|
||||
("", aliceContact', Right (A.REQ aInvId pqSup' _ "bob's connInfo")) <- (alice <#:)
|
||||
pqSup' `shouldBe` bPQ
|
||||
aliceContact' `shouldBe` aliceContact
|
||||
|
||||
("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> " 16\nalice's connInfo")
|
||||
("", aliceConn', Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
|
||||
("2", bobConn, Right OK) <- alice #: ("2", "", "ACPT " <> aInvId <> enableKEMStr (CR.connPQEncryption aPQ) <> " 16\nalice's connInfo")
|
||||
("", aliceConn', Right (A.CONF bConfId pqSup'' _ "alice's connInfo")) <- (bob <#:)
|
||||
pqSup'' `shouldBe` pqSup
|
||||
aliceConn' `shouldBe` aliceConn
|
||||
|
||||
bob #: ("12", aliceConn, "LET " <> bConfId <> " 16\nbob's connInfo 2") #> ("12", aliceConn, OK)
|
||||
alice <# ("", bobConn, INFO "bob's connInfo 2")
|
||||
alice <# ("", bobConn, CON)
|
||||
bob <# ("", aliceConn, CON)
|
||||
alice <# ("", bobConn, A.INFO pqSup "bob's connInfo 2")
|
||||
alice <# ("", bobConn, CON pq)
|
||||
bob <# ("", aliceConn, CON pq)
|
||||
|
||||
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, MID 4)
|
||||
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, A.MID 4 pq)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg "hi") -> c == aliceConn; _ -> False
|
||||
bob <#= \case ("", c, Msg' 4 pq' "hi") -> c == aliceConn && pq == pq'; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK)
|
||||
|
||||
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
@@ -292,7 +358,7 @@ testRejectContactRequest _ alice bob = do
|
||||
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " subscribe 10\nbob's info") #> ("11", "alice", OK)
|
||||
("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:)
|
||||
("", "a_contact", Right (A.REQ aInvId PQSupportOff _ "bob's info")) <- (alice <#:)
|
||||
-- RJCT must use correct contact connection
|
||||
alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND)
|
||||
alice #: ("2b", "a_contact", "RJCT " <> aInvId) #> ("2b", "a_contact", OK)
|
||||
@@ -327,31 +393,32 @@ testSubscrNotification t (server, _) client = do
|
||||
withSmpServer (ATransport t) $
|
||||
client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue
|
||||
|
||||
testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testMsgDeliveryServerRestart t alice bob = do
|
||||
testMsgDeliveryServerRestart :: forall c. Transport c => (c, InitialKeys) -> (c, PQSupport) -> IO ()
|
||||
testMsgDeliveryServerRestart (alice, aPQ) (bob, bPQ) = do
|
||||
let pq = pqConnectionMode aPQ bPQ
|
||||
withServer $ do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", MID 4)
|
||||
connect' (alice, "alice", aPQ) (bob, "bob", bPQ)
|
||||
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", A.MID 4 pq)
|
||||
bob <# ("", "alice", SENT 4)
|
||||
alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg' _ pq' "hi") -> pq == pq'; _ -> False
|
||||
alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK)
|
||||
alice #:# "nothing else delivered before the server is killed"
|
||||
|
||||
let server = SMPServer "localhost" testPort2 testKeyHash
|
||||
alice <#. ("", "", DOWN server ["bob"])
|
||||
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", MID 5)
|
||||
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", A.MID 5 pq)
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
withServer $ do
|
||||
bob <# ("", "alice", SENT 5)
|
||||
alice <#. ("", "", UP server ["bob"])
|
||||
alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg' _ pq' "hello again") -> pq == pq'; _ -> False
|
||||
alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` ()
|
||||
withServer test' = withSmpServerStoreLogOn (transport @c) testPort2 (const test') `shouldReturn` ()
|
||||
|
||||
testServerConnectionAfterError :: forall c. Transport c => TProxy c -> [c] -> IO ()
|
||||
testServerConnectionAfterError t _ = do
|
||||
@@ -432,7 +499,7 @@ testConcurrentMsgDelivery _ alice bob = do
|
||||
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice2", "JOIN T " <> cReq' <> " subscribe 14\nbob's connInfo") #> ("11", "alice2", OK)
|
||||
("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:)
|
||||
("", "bob2", Right (A.CONF _confId PQSupportOff _ "bob's connInfo")) <- (alice <#:)
|
||||
-- below commands would be needed to accept bob's connection, but alice does not
|
||||
-- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
-- bob <# ("", "alice", INFO "alice's connInfo")
|
||||
@@ -492,16 +559,33 @@ testResumeDeliveryQuotaExceeded _ alice bob = do
|
||||
-- message 8 is skipped because of alice agent sending "QCONT" message
|
||||
bob #: ("5", "alice", "ACK 9") #> ("5", "alice", OK)
|
||||
|
||||
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV subscribe")
|
||||
connect :: Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = connect' (h1, name1, IKPQOn) (h2, name2, PQSupportOn)
|
||||
|
||||
connect' :: forall c. Transport c => (c, ByteString, InitialKeys) -> (c, ByteString, PQSupport) -> IO ()
|
||||
connect' (h1, name1, pqMode1) (h2, name2, pqMode2) = do
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV" <> pqConnModeStr pqMode1 <> " subscribe")
|
||||
let cReq' = strEncode cReq
|
||||
h2 #: ("c2", name1, "JOIN T " <> cReq' <> " subscribe 5\ninfo2") #> ("c2", name1, OK)
|
||||
("", _, Right (CONF connId _ "info2")) <- (h1 <#:)
|
||||
pq = pqConnectionMode pqMode1 pqMode2
|
||||
pqSup = CR.pqEncToSupport pq
|
||||
h2 #: ("c2", name1, "JOIN T " <> cReq' <> enableKEMStr pqMode2 <> " subscribe 5\ninfo2") #> ("c2", name1, OK)
|
||||
("", _, Right (A.CONF connId pqSup' _ "info2")) <- (h1 <#:)
|
||||
pqSup' `shouldBe` pqSup
|
||||
h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK)
|
||||
h2 <# ("", name1, INFO "info1")
|
||||
h2 <# ("", name1, CON)
|
||||
h1 <# ("", name2, CON)
|
||||
h2 <# ("", name1, A.INFO pqSup "info1")
|
||||
h2 <# ("", name1, CON pq)
|
||||
h1 <# ("", name2, CON pq)
|
||||
|
||||
pqConnectionMode :: InitialKeys -> PQSupport -> PQEncryption
|
||||
pqConnectionMode pqMode1 pqMode2 = PQEncryption $ supportPQ (CR.connPQEncryption pqMode1) && supportPQ pqMode2
|
||||
|
||||
enableKEMStr :: PQSupport -> ByteString
|
||||
enableKEMStr PQSupportOn = " " <> strEncode PQSupportOn
|
||||
enableKEMStr _ = ""
|
||||
|
||||
pqConnModeStr :: InitialKeys -> ByteString
|
||||
pqConnModeStr (IKNoPQ PQSupportOff) = ""
|
||||
pqConnModeStr pq = " " <> strEncode pq
|
||||
|
||||
sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO ()
|
||||
sendMessage (h1, name1) (h2, name2) msg = do
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module AgentTests.ConnectionRequestTests where
|
||||
@@ -12,7 +15,7 @@ import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (ProtocolServer (..), supportedSMPClientVRange)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer (..), pattern VersionSMPC, supportedSMPClientVRange)
|
||||
import Simplex.Messaging.ServiceScheme (ServiceScheme (..))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
@@ -38,7 +41,7 @@ queue :: SMPQueueUri
|
||||
queue = SMPQueueUri supportedSMPClientVRange queueAddr
|
||||
|
||||
queueV1 :: SMPQueueUri
|
||||
queueV1 = SMPQueueUri (mkVersionRange 1 1) queueAddr
|
||||
queueV1 = SMPQueueUri (mkVersionRange (VersionSMPC 1) (VersionSMPC 1)) queueAddr
|
||||
|
||||
testDhKey :: C.PublicKeyX25519
|
||||
testDhKey = "MCowBQYDK2VuAyEAjiswwI3O/NlS8Fk3HJUW870EY2bAwmttMBsvRB9eV3o="
|
||||
@@ -53,7 +56,7 @@ connReqData :: ConnReqUriData
|
||||
connReqData =
|
||||
ConnReqUriData
|
||||
{ crScheme = SSSimplex,
|
||||
crAgentVRange = mkVersionRange 2 2,
|
||||
crAgentVRange = mkVersionRange (VersionSMPA 2) (VersionSMPA 2),
|
||||
crSmpQueues = [queueV1],
|
||||
crClientData = Nothing
|
||||
}
|
||||
@@ -61,11 +64,11 @@ connReqData =
|
||||
testDhPubKey :: C.PublicKeyX448
|
||||
testDhPubKey = "MEIwBQYDK2VvAzkAmKuSYeQ/m0SixPDS8Wq8VBaTS1cW+Lp0n0h4Diu+kUpR+qXx4SDJ32YGEFoGFGSbGPry5Ychr6U="
|
||||
|
||||
testE2ERatchetParams :: E2ERatchetParamsUri 'C.X448
|
||||
testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange 1 1) testDhPubKey testDhPubKey
|
||||
testE2ERatchetParams :: RcvE2ERatchetParamsUri 'C.X448
|
||||
testE2ERatchetParams = E2ERatchetParamsUri (mkVersionRange (VersionE2E 1) (VersionE2E 1)) testDhPubKey testDhPubKey Nothing
|
||||
|
||||
testE2ERatchetParams12 :: E2ERatchetParamsUri 'C.X448
|
||||
testE2ERatchetParams12 = E2ERatchetParamsUri supportedE2EEncryptVRange testDhPubKey testDhPubKey
|
||||
testE2ERatchetParams12 :: RcvE2ERatchetParamsUri 'C.X448
|
||||
testE2ERatchetParams12 = E2ERatchetParamsUri (supportedE2EEncryptVRange PQSupportOn) testDhPubKey testDhPubKey Nothing
|
||||
|
||||
connectionRequest :: AConnectionRequestUri
|
||||
connectionRequest =
|
||||
@@ -79,7 +82,7 @@ connectionRequestCurrentRange :: AConnectionRequestUri
|
||||
connectionRequestCurrentRange =
|
||||
ACR SCMInvitation $
|
||||
CRInvitationUri
|
||||
connReqData {crAgentVRange = supportedSMPAgentVRange, crSmpQueues = [queueV1, queueV1]}
|
||||
connReqData {crAgentVRange = supportedSMPAgentVRange PQSupportOn, crSmpQueues = [queueV1, queueV1]}
|
||||
testE2ERatchetParams12
|
||||
|
||||
connectionRequestClientDataEmpty :: AConnectionRequestUri
|
||||
@@ -98,7 +101,7 @@ connectionRequestTests =
|
||||
it "should serialize SMP queue URIs" $ do
|
||||
strEncode (queue :: SMPQueueUri) {queueAddress = queueAddrNoPort}
|
||||
`shouldBe` "smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri
|
||||
strEncode queue {clientVRange = mkVersionRange 1 2}
|
||||
strEncode queue {clientVRange = mkVersionRange (VersionSMPC 1) (VersionSMPC 2)}
|
||||
`shouldBe` "smp://1234-w==@smp.simplex.im:5223/3456-w==#/?v=1-2&dh=" <> testDhKeyStrUri
|
||||
it "should parse SMP queue URIs" $ do
|
||||
strDecode ("smp://1234-w==@smp.simplex.im/3456-w==#/?v=1-2&dh=" <> testDhKeyStr)
|
||||
@@ -119,11 +122,11 @@ connectionRequestTests =
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D1%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
strEncode connectionRequestCurrentRange
|
||||
`shouldBe` "simplex:/invitation#/?v=2-4&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
`shouldBe` "simplex:/invitation#/?v=2-5&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
<> "&e2e=v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&e2e=v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
strEncode connectionRequestClientDataEmpty
|
||||
`shouldBe` "simplex:/invitation#/?v=2&smp=smp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> urlEncode True testDhKeyStrUri
|
||||
@@ -167,9 +170,9 @@ connectionRequestTests =
|
||||
<> testDhKeyStrUri
|
||||
<> "%2Csmp%3A%2F%2F1234-w%3D%3D%40smp.simplex.im%3A5223%2F3456-w%3D%3D%23%2F%3Fv%3D1%26dh%3D"
|
||||
<> testDhKeyStrUri
|
||||
<> "&e2e=extra_key%3Dnew%26v%3D2%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&e2e=extra_key%3Dnew%26v%3D2-3%26x3dh%3DMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D%2CMEIwBQYDK2VvAzkAmKuSYeQ_m0SixPDS8Wq8VBaTS1cW-Lp0n0h4Diu-kUpR-qXx4SDJ32YGEFoGFGSbGPry5Ychr6U%3D"
|
||||
<> "&some_new_param=abc"
|
||||
<> "&v=2-4"
|
||||
<> "&v=2-5"
|
||||
)
|
||||
`shouldBe` Right connectionRequestCurrentRange
|
||||
strDecode
|
||||
|
||||
@@ -1,75 +1,178 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module AgentTests.DoubleRatchetTests where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Class
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import Data.Aeson (FromJSON, ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Crypto (Algorithm (..), AlgorithmI, CryptoError, DhAlgorithm)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
|
||||
import Simplex.Messaging.Crypto.Ratchet
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util ((<$$>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
|
||||
doubleRatchetTests :: Spec
|
||||
doubleRatchetTests = do
|
||||
describe "double-ratchet encryption/decryption" $ do
|
||||
it "should serialize and parse message header" testMessageHeader
|
||||
it "should encrypt and decrypt messages" $ do
|
||||
withRatchets @X25519 testEncryptDecrypt
|
||||
withRatchets @X448 testEncryptDecrypt
|
||||
it "should encrypt and decrypt skipped messages" $ do
|
||||
withRatchets @X25519 testSkippedMessages
|
||||
withRatchets @X448 testSkippedMessages
|
||||
it "should encrypt and decrypt many messages" $ do
|
||||
withRatchets @X25519 testManyMessages
|
||||
it "should allow skipped after ratchet advance" $ do
|
||||
withRatchets @X25519 testSkippedAfterRatchetAdvance
|
||||
it "should serialize and parse message header" $ do
|
||||
testAlgs $ testMessageHeader kdfX3DHE2EEncryptVersion
|
||||
testAlgs $ testMessageHeader $ max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
describe "message tests" $ runMessageTests initRatchets False
|
||||
it "should encode/decode ratchet as JSON" $ do
|
||||
testKeyJSON C.SX25519
|
||||
testKeyJSON C.SX448
|
||||
testRatchetJSON C.SX25519
|
||||
testRatchetJSON C.SX448
|
||||
it "should agree the same ratchet parameters" $ do
|
||||
testX3dh C.SX25519
|
||||
testX3dh C.SX448
|
||||
it "should agree the same ratchet parameters with version 1" $ do
|
||||
testX3dhV1 C.SX25519
|
||||
testX3dhV1 C.SX448
|
||||
testAlgs testKeyJSON
|
||||
testAlgs testRatchetJSON
|
||||
testVersionJSON
|
||||
it "should decode v2 Ratchet with default field values" $ testDecodeV2RatchetJSON
|
||||
it "should agree the same ratchet parameters" $ testAlgs testX3dh
|
||||
it "should agree the same ratchet parameters with version 1" $ testAlgs testX3dhV1
|
||||
describe "post-quantum hybrid KEM double-ratchet algorithm" $ do
|
||||
describe "hybrid KEM key agreement" $ do
|
||||
it "should propose KEM during agreement, but no shared secret" $ testAlgs testPqX3dhProposeInReply
|
||||
it "should agree shared secret using KEM" $ testAlgs testPqX3dhProposeAccept
|
||||
it "should reject proposed KEM in reply" $ testAlgs testPqX3dhProposeReject
|
||||
it "should allow second proposal in reply" $ testAlgs testPqX3dhProposeAgain
|
||||
describe "hybrid KEM key agreement errors" $ do
|
||||
it "should fail if reply contains acceptance without proposal" $ testAlgs testPqX3dhAcceptWithoutProposalError
|
||||
describe "ratchet encryption/decryption" $ do
|
||||
it "should serialize and parse public KEM params" testKEMParams
|
||||
it "should serialize and parse message header" $ testAlgs testMessageHeaderKEM
|
||||
describe "message tests, KEM proposed" $ runMessageTests initRatchetsKEMProposed True
|
||||
describe "message tests, KEM accepted" $ runMessageTests initRatchetsKEMAccepted False
|
||||
describe "message tests, KEM proposed again in reply" $ runMessageTests initRatchetsKEMProposedAgain True
|
||||
it "should disable and re-enable KEM" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEM
|
||||
it "should disable and re-enable KEM (always set PQEncryption)" $ withRatchets_ @X25519 initRatchetsKEMAccepted testDisableEnableKEMStrict
|
||||
it "should enable KEM when it was not enabled in handshake" $ withRatchets_ @X25519 initRatchets testEnableKEM
|
||||
it "should enable KEM when it was not enabled in handshake (always set PQEncryption)" $ withRatchets_ @X25519 initRatchets testEnableKEMStrict
|
||||
|
||||
runMessageTests ::
|
||||
(forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)) ->
|
||||
Bool ->
|
||||
Spec
|
||||
runMessageTests initRatchets_ agreeRatchetKEMs = do
|
||||
it "should encrypt and decrypt messages" $ run $ testEncryptDecrypt agreeRatchetKEMs
|
||||
it "should encrypt and decrypt skipped messages" $ run $ testSkippedMessages agreeRatchetKEMs
|
||||
it "should encrypt and decrypt many messages" $ run $ testManyMessages agreeRatchetKEMs
|
||||
it "should allow skipped after ratchet advance" $ run $ testSkippedAfterRatchetAdvance agreeRatchetKEMs
|
||||
where
|
||||
run :: (forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a) -> IO ()
|
||||
run test = do
|
||||
withRatchets_ @X25519 initRatchets_ test
|
||||
withRatchets_ @X448 initRatchets_ test
|
||||
|
||||
|
||||
testAlgs :: (forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()) -> IO ()
|
||||
testAlgs test = test C.SX25519 >> test C.SX448
|
||||
|
||||
paddedMsgLen :: Int
|
||||
paddedMsgLen = 100
|
||||
|
||||
fullMsgLen :: Int
|
||||
fullMsgLen = 1 + fullHeaderLen + C.authTagSize + paddedMsgLen
|
||||
fullMsgLen :: Ratchet a -> Int
|
||||
fullMsgLen Ratchet {rcSupportKEM, rcVersion} = headerLenLength + fullHeaderLen v rcSupportKEM + C.authTagSize + paddedMsgLen
|
||||
where
|
||||
v = current rcVersion
|
||||
headerLenLength = case rcSupportKEM of
|
||||
PQSupportOn | v >= pqRatchetE2EEncryptVersion -> 3 -- two bytes are added because of two Large used in new encoding
|
||||
_ -> 1
|
||||
|
||||
testMessageHeader :: Expectation
|
||||
testMessageHeader = do
|
||||
(k, _) <- atomically . C.generateKeyPair @X25519 =<< C.newRandom
|
||||
let hdr = MsgHeader {msgMaxVersion = currentE2EEncryptVersion, msgDHRs = k, msgPN = 0, msgNs = 0}
|
||||
parseAll (smpP @(MsgHeader 'X25519)) (smpEncode hdr) `shouldBe` Right hdr
|
||||
testMessageHeader :: forall a. AlgorithmI a => VersionE2E -> C.SAlgorithm a -> Expectation
|
||||
testMessageHeader v _ = do
|
||||
(k, _) <- atomically . C.generateKeyPair @a =<< C.newRandom
|
||||
let hdr = MsgHeader {msgMaxVersion = v, msgDHRs = k, msgKEM = Nothing, msgPN = 0, msgNs = 0}
|
||||
parseAll (msgHeaderP v) (encodeMsgHeader v hdr) `shouldBe` Right hdr
|
||||
|
||||
testKEMParams :: Expectation
|
||||
testKEMParams = do
|
||||
g <- C.newRandom
|
||||
(kem, _) <- sntrup761Keypair g
|
||||
let kemParams = ARKP SRKSProposed $ RKParamsProposed kem
|
||||
parseAll (smpP @ARKEMParams) (smpEncode kemParams) `shouldBe` Right kemParams
|
||||
(kem', _) <- sntrup761Keypair g
|
||||
(ct, _) <- sntrup761Enc g kem
|
||||
let kemParams' = ARKP SRKSAccepted $ RKParamsAccepted ct kem'
|
||||
parseAll (smpP @ARKEMParams) (smpEncode kemParams') `shouldBe` Right kemParams'
|
||||
|
||||
testMessageHeaderKEM :: forall a. AlgorithmI a => C.SAlgorithm a -> Expectation
|
||||
testMessageHeaderKEM _ = do
|
||||
g <- C.newRandom
|
||||
(k, _) <- atomically $ C.generateKeyPair @a g
|
||||
(kem, _) <- sntrup761Keypair g
|
||||
let msgMaxVersion = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
msgKEM = Just . ARKP SRKSProposed $ RKParamsProposed kem
|
||||
hdr = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM, msgPN = 0, msgNs = 0}
|
||||
parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr) `shouldBe` Right hdr
|
||||
(kem', _) <- sntrup761Keypair g
|
||||
(ct, _) <- sntrup761Enc g kem
|
||||
let msgKEM' = Just . ARKP SRKSAccepted $ RKParamsAccepted ct kem'
|
||||
hdr' = MsgHeader {msgMaxVersion, msgDHRs = k, msgKEM = msgKEM', msgPN = 0, msgNs = 0}
|
||||
parseAll (msgHeaderP msgMaxVersion) (encodeMsgHeader msgMaxVersion hdr') `shouldBe` Right hdr'
|
||||
|
||||
pattern Decrypted :: ByteString -> Either CryptoError (Either CryptoError ByteString)
|
||||
pattern Decrypted msg <- Right (Right msg)
|
||||
|
||||
type TestRatchets a = (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()
|
||||
type Encrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString)
|
||||
|
||||
testEncryptDecrypt :: TestRatchets a
|
||||
testEncryptDecrypt alice bob = do
|
||||
type Decrypt a = TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString))
|
||||
|
||||
type EncryptDecryptSpec a = (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation
|
||||
|
||||
type TestRatchets a =
|
||||
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
|
||||
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
|
||||
Encrypt a ->
|
||||
Decrypt a ->
|
||||
EncryptDecryptSpec a ->
|
||||
IO ()
|
||||
|
||||
deriving instance Eq (Ratchet a)
|
||||
|
||||
deriving instance Eq (SndRatchet a)
|
||||
|
||||
deriving instance Eq RcvRatchet
|
||||
|
||||
deriving instance Eq RatchetKEM
|
||||
|
||||
deriving instance Eq RatchetKEMAccepted
|
||||
|
||||
deriving instance Eq RatchetInitParams
|
||||
|
||||
deriving instance Eq RatchetKey
|
||||
|
||||
instance Eq ARKEMParams where
|
||||
(ARKP s ps) == (ARKP s' ps') = case testEquality s s' of
|
||||
Just Refl -> ps == ps'
|
||||
Nothing -> False
|
||||
|
||||
deriving instance Eq (MsgHeader a)
|
||||
|
||||
initRatchetKEM :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()
|
||||
initRatchetKEM s r = encryptDecrypt (Just $ PQEncOn) (const ()) (const ()) (s, "initialising ratchet") r
|
||||
|
||||
testEncryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
|
||||
testEncryptDecrypt agreeRatchetKEMs alice bob encrypt decrypt (#>) = do
|
||||
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
|
||||
(bob, "hello alice") #> alice
|
||||
(alice, "hello bob") #> bob
|
||||
Right b1 <- encrypt bob "how are you, alice?"
|
||||
@@ -88,8 +191,9 @@ testEncryptDecrypt alice bob = do
|
||||
(alice, "I'm here too, same") #> bob
|
||||
pure ()
|
||||
|
||||
testSkippedMessages :: TestRatchets a
|
||||
testSkippedMessages alice bob = do
|
||||
testSkippedMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
|
||||
testSkippedMessages agreeRatchetKEMs alice bob encrypt decrypt _ = do
|
||||
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
|
||||
Right msg1 <- encrypt bob "hello alice"
|
||||
Right msg2 <- encrypt bob "hello there again"
|
||||
Right msg3 <- encrypt bob "are you there?"
|
||||
@@ -99,8 +203,9 @@ testSkippedMessages alice bob = do
|
||||
Decrypted "hello alice" <- decrypt alice msg1
|
||||
pure ()
|
||||
|
||||
testManyMessages :: TestRatchets a
|
||||
testManyMessages alice bob = do
|
||||
testManyMessages :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
|
||||
testManyMessages agreeRatchetKEMs alice bob _ _ (#>) = do
|
||||
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
|
||||
(bob, "b1") #> alice
|
||||
(bob, "b2") #> alice
|
||||
(bob, "b3") #> alice
|
||||
@@ -117,8 +222,9 @@ testManyMessages alice bob = do
|
||||
(bob, "b15") #> alice
|
||||
(bob, "b16") #> alice
|
||||
|
||||
testSkippedAfterRatchetAdvance :: TestRatchets a
|
||||
testSkippedAfterRatchetAdvance alice bob = do
|
||||
testSkippedAfterRatchetAdvance :: (AlgorithmI a, DhAlgorithm a) => Bool -> TestRatchets a
|
||||
testSkippedAfterRatchetAdvance agreeRatchetKEMs alice bob encrypt decrypt (#>) = do
|
||||
when agreeRatchetKEMs $ initRatchetKEM bob alice >> initRatchetKEM alice bob
|
||||
(bob, "b1") #> alice
|
||||
Right b2 <- encrypt bob "b2"
|
||||
Right b3 <- encrypt bob "b3"
|
||||
@@ -152,6 +258,84 @@ testSkippedAfterRatchetAdvance alice bob = do
|
||||
Decrypted "b11" <- decrypt alice b11
|
||||
pure ()
|
||||
|
||||
testDisableEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
|
||||
testDisableEnableKEM alice bob _ _ _ = do
|
||||
(bob, "hello alice") !#> alice
|
||||
(alice, "hello bob") !#> bob
|
||||
(bob, "disabling KEM") !#>\ alice
|
||||
(alice, "still disabling KEM") !#> bob
|
||||
(bob, "now KEM is disabled") \#> alice
|
||||
(alice, "KEM is disabled for both sides") \#> bob
|
||||
(bob, "trying to enable KEM") \#>! alice
|
||||
(alice, "but unless alice enables it too it won't enable") \#> bob
|
||||
(bob, "KEM is disabled") \#> alice
|
||||
(alice, "KEM is disabled for both sides") \#> bob
|
||||
(bob, "enabling KEM again") \#>! alice
|
||||
(alice, "and alice accepts it this time") \#>! bob
|
||||
(bob, "still enabling KEM") \#>! alice
|
||||
(alice, "now KEM is enabled") !#> bob
|
||||
(bob, "KEM is enabled for both sides") !#> alice
|
||||
|
||||
testDisableEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
|
||||
testDisableEnableKEMStrict alice bob _ _ _ = do
|
||||
(bob, "hello alice") !#>! alice
|
||||
(alice, "hello bob") !#>! bob
|
||||
(bob, "disabling KEM") !#>\ alice
|
||||
(alice, "still disabling KEM") !#>! bob
|
||||
(bob, "now KEM is disabled") \#>\ alice
|
||||
(alice, "KEM is disabled for both sides") \#>\ bob
|
||||
(bob, "trying to enable KEM") \#>! alice
|
||||
(alice, "but unless alice enables it too it won't enable") \#>\ bob
|
||||
(bob, "KEM is disabled") \#>! alice
|
||||
(alice, "KEM is disabled for both sides") \#>\ bob
|
||||
(bob, "enabling KEM again") \#>! alice
|
||||
(alice, "and alice accepts it this time") \#>! bob
|
||||
(bob, "still enabling KEM") \#>! alice
|
||||
(alice, "now KEM is enabled") !#>! bob
|
||||
(bob, "KEM is enabled for both sides") !#>! alice
|
||||
|
||||
testEnableKEM :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
|
||||
testEnableKEM alice bob _ _ _ = do
|
||||
(bob, "hello alice") \#> alice
|
||||
(alice, "hello bob") \#> bob
|
||||
(bob, "enabling KEM") \#>! alice
|
||||
(bob, "KEM not enabled yet") \#>! alice
|
||||
(alice, "accepting KEM") \#>! bob
|
||||
(alice, "KEM not enabled yet here too") \#>! bob
|
||||
(bob, "KEM is still not enabled") \#>! alice
|
||||
(alice, "now KEM is enabled") !#>! bob
|
||||
(bob, "now KEM is enabled for both sides") !#> alice
|
||||
(alice, "still enabled for both sides") !#> bob
|
||||
(bob, "still enabled for both sides 2") !#> alice
|
||||
(alice, "disabling KEM") !#>\ bob
|
||||
(bob, "KEM not disabled yet") !#> alice
|
||||
(alice, "KEM disabled") \#> bob
|
||||
(bob, "KEM disabled on both sides") \#> alice
|
||||
|
||||
testEnableKEMStrict :: forall a. (AlgorithmI a, DhAlgorithm a) => TestRatchets a
|
||||
testEnableKEMStrict alice bob _ _ _ = do
|
||||
(bob, "hello alice") \#>\ alice
|
||||
(alice, "hello bob") \#>\ bob
|
||||
(bob, "enabling KEM") \#>! alice
|
||||
(bob, "KEM not enabled yet") \#>! alice
|
||||
(alice, "accepting KEM") \#>! bob
|
||||
(alice, "KEM not enabled yet here too") \#>! bob
|
||||
(bob, "KEM is still not enabled") \#>! alice
|
||||
(alice, "now KEM is enabled") !#>! bob
|
||||
(bob, "now KEM is enabled for both sides") !#>! alice
|
||||
(alice, "still enabled for both sides") !#>! bob
|
||||
(bob, "still enabled for both sides 2") !#>! alice
|
||||
(alice, "disabling KEM") !#>\ bob
|
||||
(bob, "KEM not disabled yet") !#>! alice
|
||||
(alice, "KEM disabled") \#>\ bob
|
||||
(bob, "KEM disabled on both sides") \#>! alice
|
||||
(alice, "KEM still disabled 1") \#>\ bob
|
||||
(bob, "KEM still disabled 2") \#>! alice
|
||||
(alice, "KEM still disabled 3") \#>\ bob
|
||||
(bob, "KEM still disabled 4") \#>! alice
|
||||
(alice, "KEM still disabled 5") \#>\ bob
|
||||
(bob, "KEM still disabled 6") \#>! alice
|
||||
|
||||
testKeyJSON :: forall a. AlgorithmI a => C.SAlgorithm a -> IO ()
|
||||
testKeyJSON _ = do
|
||||
(k, pk) <- atomically . C.generateKeyPair @a =<< C.newRandom
|
||||
@@ -160,10 +344,33 @@ testKeyJSON _ = do
|
||||
|
||||
testRatchetJSON :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testRatchetJSON _ = do
|
||||
(alice, bob) <- initRatchets @a
|
||||
(alice, bob, _, _, _) <- initRatchets @a
|
||||
testEncodeDecode alice
|
||||
testEncodeDecode bob
|
||||
|
||||
testVersionJSON :: IO ()
|
||||
testVersionJSON = do
|
||||
testEncodeDecode $ rv 1 1
|
||||
testEncodeDecode $ rv 1 2
|
||||
-- let bad = RVersions 2 1
|
||||
-- Left err <- pure $ J.eitherDecode' @RatchetVersions (J.encode bad)
|
||||
-- err `shouldContain` "bad version range"
|
||||
testDecodeRV $ (1 :: Int, 2 :: Int)
|
||||
testDecodeRV $ J.object ["current" .= (1 :: Int), "maxSupported" .= (2 :: Int)]
|
||||
where
|
||||
rv v1 v2 = RatchetVersions (VersionE2E v1) (VersionE2E v2)
|
||||
testDecodeRV :: ToJSON a => a -> Expectation
|
||||
testDecodeRV a = J.eitherDecode' (J.encode a) `shouldBe` Right (rv 1 2)
|
||||
|
||||
testDecodeV2RatchetJSON :: IO ()
|
||||
testDecodeV2RatchetJSON = do
|
||||
let v2RatchetJSON = "{\"rcVersion\":[2,2],\"rcAD\":\"2GEJrq48TmQse6NR16I-hrI0tSySZQ57E_g46nDceAPRAiF6j0drq26RTE7be6X7uiB4RaGJGf4QRXzcYuVtWw==\",\"rcDHRs\":\"TUM0Q0FRQXdCUVlESzJWdUJDSUVJRkNYbUxtSHQ3SUNfeHpGTi1Qb3ZqTVQ3S2p6XzZlZlBjOG9fRFY2RWxKOQ==\",\"rcRK\":\"BOX2X7YW5qDSp2XknY_lqacSrtDqQNPvS6iJlZIs3G0=\",\"rcNs\":0,\"rcNr\":0,\"rcPN\":0,\"rcNHKs\":\"IMouSkXUvzT_mo0WM-pqEUK09-HTLk9WOTCFQglyQxU=\",\"rcNHKr\":\"g-tus1clYPV0rGlzkf5a959tUqDYQVZ1FpcPeXdKwxI=\"}"
|
||||
Right (r :: Ratchet X25519) <- pure $ J.eitherDecodeStrict' v2RatchetJSON
|
||||
rcSupportKEM r `shouldBe` PQSupportOff
|
||||
rcEnableKEM r `shouldBe` PQEncOff
|
||||
rcSndKEM r `shouldBe` PQEncOff
|
||||
rcRcvKEM r `shouldBe` PQEncOff
|
||||
|
||||
testEncodeDecode :: (Eq a, Show a, ToJSON a, FromJSON a) => a -> Expectation
|
||||
testEncodeDecode x = do
|
||||
let j = J.encode x
|
||||
@@ -173,77 +380,255 @@ testEncodeDecode x = do
|
||||
testX3dh :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testX3dh _ = do
|
||||
g <- C.newRandom
|
||||
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion
|
||||
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g currentE2EEncryptVersion
|
||||
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
|
||||
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing
|
||||
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
|
||||
let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
|
||||
paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
|
||||
paramsAlice `shouldBe` paramsBob
|
||||
|
||||
testX3dhV1 :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testX3dhV1 _ = do
|
||||
g <- C.newRandom
|
||||
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams @a g 1
|
||||
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams @a g 1
|
||||
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
|
||||
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
|
||||
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g (VersionE2E 1) Nothing
|
||||
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g (VersionE2E 1) PQSupportOff
|
||||
let paramsBob = pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
|
||||
paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
|
||||
paramsAlice `shouldBe` paramsBob
|
||||
|
||||
(#>) :: (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys), ByteString) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> Expectation
|
||||
(alice, msg) #> bob = do
|
||||
Right msg' <- encrypt alice msg
|
||||
Decrypted msg'' <- decrypt bob msg'
|
||||
testPqX3dhProposeInReply :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testPqX3dhProposeInReply _ = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (no KEM)
|
||||
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
|
||||
-- propose KEM in reply
|
||||
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
|
||||
paramsAlice `compatibleRatchets` paramsBob
|
||||
|
||||
testPqX3dhProposeAccept :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testPqX3dhProposeAccept _ = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (propose KEM)
|
||||
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
|
||||
E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice
|
||||
-- accept KEM
|
||||
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM aliceKem)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
|
||||
paramsAlice `compatibleRatchets` paramsBob
|
||||
|
||||
testPqX3dhProposeReject :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testPqX3dhProposeReject _ = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (propose KEM)
|
||||
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
|
||||
E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice
|
||||
-- reject KEM
|
||||
(pkBob1, pkBob2, Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v Nothing
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
|
||||
paramsAlice `compatibleRatchets` paramsBob
|
||||
|
||||
testPqX3dhAcceptWithoutProposalError :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testPqX3dhAcceptWithoutProposalError _ = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (no KEM)
|
||||
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOff
|
||||
E2ERatchetParams _ _ _ Nothing <- pure e2eAlice
|
||||
-- incorrectly accept KEM
|
||||
-- we don't have key in proposal, so we just generate it
|
||||
(k, _) <- sntrup761Keypair g
|
||||
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSAccepted $ AcceptKEM k)
|
||||
pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice `shouldBe` Left C.CERatchetKEMState
|
||||
runExceptT (pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob) `shouldReturn` Left C.CERatchetKEMState
|
||||
|
||||
testPqX3dhProposeAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => C.SAlgorithm a -> IO ()
|
||||
testPqX3dhProposeAgain _ = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (propose KEM)
|
||||
(pkAlice1, pkAlice2, pKemAlice_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams @a g v PQSupportOn
|
||||
E2ERatchetParams _ _ _ (Just (RKParamsProposed _)) <- pure e2eAlice
|
||||
-- propose KEM again in reply - this is not an error
|
||||
(pkBob1, pkBob2, pKemBob_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams @a g v (Just $ AUseKEM SRKSProposed ProposeKEM)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemBob_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKemAlice_ e2eBob
|
||||
paramsAlice `compatibleRatchets` paramsBob
|
||||
|
||||
compatibleRatchets :: (RatchetInitParams, x) -> (RatchetInitParams, x) -> Expectation
|
||||
compatibleRatchets
|
||||
(RatchetInitParams {assocData, ratchetKey, sndHK, rcvNextHK, kemAccepted}, _)
|
||||
(RatchetInitParams {assocData = ad, ratchetKey = rk, sndHK = shk, rcvNextHK = rnhk, kemAccepted = ka}, _) = do
|
||||
assocData == ad && ratchetKey == rk && sndHK == shk && rcvNextHK == rnhk `shouldBe` True
|
||||
case (kemAccepted, ka) of
|
||||
(Just RatchetKEMAccepted {rcPQRr, rcPQRss, rcPQRct}, Just RatchetKEMAccepted {rcPQRr = pqk, rcPQRss = pqss, rcPQRct = pqct}) ->
|
||||
pqk /= rcPQRr && pqss == rcPQRss && pqct == rcPQRct `shouldBe` True
|
||||
(Nothing, Nothing) -> pure ()
|
||||
_ -> expectationFailure "RatchetInitParams params are not compatible"
|
||||
|
||||
encryptDecrypt :: (AlgorithmI a, DhAlgorithm a) => Maybe PQEncryption -> (Ratchet a -> ()) -> (Ratchet a -> ()) -> EncryptDecryptSpec a
|
||||
encryptDecrypt pqEnc validSnd validRcv (alice, msg) bob = do
|
||||
Right msg' <- withTVar (encrypt_ pqEnc) validSnd alice msg
|
||||
Decrypted msg'' <- decrypt' validRcv bob msg'
|
||||
msg'' `shouldBe` msg
|
||||
|
||||
withRatchets :: forall a. (AlgorithmI a, DhAlgorithm a) => (TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> IO ()) -> Expectation
|
||||
withRatchets test = do
|
||||
-- enable KEM (currently disabled)
|
||||
(\#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) \#>! r = encryptDecrypt (Just PQEncOn) noSndKEM noRcvKEM (s, msg) r
|
||||
|
||||
-- enable KEM (currently enabled)
|
||||
(!#>!) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) !#>! r = encryptDecrypt (Just PQEncOn) hasSndKEM hasRcvKEM (s, msg) r
|
||||
|
||||
-- KEM enabled (no user preference)
|
||||
(!#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) !#> r = encryptDecrypt Nothing hasSndKEM hasRcvKEM (s, msg) r
|
||||
|
||||
-- disable KEM (currently enabled)
|
||||
(!#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) !#>\ r = encryptDecrypt (Just PQEncOff) hasSndKEM hasRcvKEM (s, msg) r
|
||||
|
||||
-- disable KEM (currently disabled)
|
||||
(\#>\) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) \#>\ r = encryptDecrypt (Just PQEncOff) noSndKEM noSndKEM (s, msg) r
|
||||
|
||||
-- KEM disabled (no user preference)
|
||||
(\#>) :: (AlgorithmI a, DhAlgorithm a) => EncryptDecryptSpec a
|
||||
(s, msg) \#> r = encryptDecrypt Nothing noSndKEM noSndKEM (s, msg) r
|
||||
|
||||
withRatchets_ :: IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a) -> TestRatchets a -> Expectation
|
||||
withRatchets_ initRatchets_ test = do
|
||||
ga <- C.newRandom
|
||||
gb <- C.newRandom
|
||||
(a, b) <- initRatchets @a
|
||||
(a, b, encrypt, decrypt, (#>)) <- initRatchets_
|
||||
alice <- newTVarIO (ga, a, M.empty)
|
||||
bob <- newTVarIO (gb, b, M.empty)
|
||||
test alice bob `shouldReturn` ()
|
||||
test alice bob encrypt decrypt (#>) `shouldReturn` ()
|
||||
|
||||
initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a)
|
||||
initRatchets :: (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
|
||||
initRatchets = do
|
||||
g <- C.newRandom
|
||||
(pkBob1, pkBob2, e2eBob) <- atomically $ generateE2EParams g currentE2EEncryptVersion
|
||||
(pkAlice1, pkAlice2, e2eAlice) <- atomically $ generateE2EParams g currentE2EEncryptVersion
|
||||
let paramsBob = x3dhSnd pkBob1 pkBob2 e2eAlice
|
||||
paramsAlice = x3dhRcv pkAlice1 pkAlice2 e2eBob
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
(pkBob1, pkBob2, _pKemParams@Nothing, AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v Nothing
|
||||
(pkAlice1, pkAlice2, _pKem@Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 Nothing e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
|
||||
(_, pkBob3) <- atomically $ C.generateKeyPair g
|
||||
let bob = initSndRatchet supportedE2EEncryptVRange (C.publicKey pkAlice2) pkBob3 paramsBob
|
||||
alice = initRcvRatchet supportedE2EEncryptVRange pkAlice2 paramsAlice
|
||||
pure (alice, bob)
|
||||
let vs = testRatchetVersions PQSupportOff
|
||||
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
|
||||
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOff
|
||||
pure (alice, bob, encrypt' noSndKEM, decrypt' noRcvKEM, (\#>))
|
||||
|
||||
encrypt_ :: AlgorithmI a => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff))
|
||||
encrypt_ (_, rc, _) msg =
|
||||
runExceptT (rcEncrypt rc paddedMsgLen msg)
|
||||
initRatchetsKEMProposed :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
|
||||
initRatchetsKEMProposed = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (no KEM)
|
||||
(pkAlice1, pkAlice2, Nothing, e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOff
|
||||
-- propose KEM in reply
|
||||
let useKem = AUseKEM SRKSProposed ProposeKEM
|
||||
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 Nothing e2eBob
|
||||
(_, pkBob3) <- atomically $ C.generateKeyPair g
|
||||
let vs = testRatchetVersions PQSupportOn
|
||||
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
|
||||
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
|
||||
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
|
||||
|
||||
initRatchetsKEMAccepted :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
|
||||
initRatchetsKEMAccepted = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (propose)
|
||||
(pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn
|
||||
E2ERatchetParams _ _ _ (Just (RKParamsProposed aliceKem)) <- pure e2eAlice
|
||||
-- accept
|
||||
let useKem = AUseKEM SRKSAccepted (AcceptKEM aliceKem)
|
||||
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob
|
||||
(_, pkBob3) <- atomically $ C.generateKeyPair g
|
||||
let vs = testRatchetVersions PQSupportOn
|
||||
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
|
||||
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
|
||||
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
|
||||
|
||||
initRatchetsKEMProposedAgain :: forall a. (AlgorithmI a, DhAlgorithm a) => IO (Ratchet a, Ratchet a, Encrypt a, Decrypt a, EncryptDecryptSpec a)
|
||||
initRatchetsKEMProposedAgain = do
|
||||
g <- C.newRandom
|
||||
let v = max pqRatchetE2EEncryptVersion currentE2EEncryptVersion
|
||||
-- initiate (propose KEM)
|
||||
(pkAlice1, pkAlice2, pKem_@(Just _), e2eAlice) <- liftIO $ generateRcvE2EParams g v PQSupportOn
|
||||
-- propose KEM again in reply
|
||||
let useKem = AUseKEM SRKSProposed ProposeKEM
|
||||
(pkBob1, pkBob2, pKemParams_@(Just _), AE2ERatchetParams _ e2eBob) <- liftIO $ generateSndE2EParams g v (Just useKem)
|
||||
Right paramsBob <- pure $ pqX3dhSnd pkBob1 pkBob2 pKemParams_ e2eAlice
|
||||
Right paramsAlice <- runExceptT $ pqX3dhRcv pkAlice1 pkAlice2 pKem_ e2eBob
|
||||
(_, pkBob3) <- atomically $ C.generateKeyPair g
|
||||
let vs = testRatchetVersions PQSupportOn
|
||||
bob = initSndRatchet vs (C.publicKey pkAlice2) pkBob3 paramsBob
|
||||
alice = initRcvRatchet vs pkAlice2 paramsAlice PQSupportOn
|
||||
pure (alice, bob, encrypt' hasSndKEM, decrypt' hasRcvKEM, (!#>))
|
||||
|
||||
testRatchetVersions :: PQSupport -> RatchetVersions
|
||||
testRatchetVersions pq =
|
||||
let v = maxVersion $ supportedE2EEncryptVRange pq
|
||||
in RatchetVersions v v
|
||||
|
||||
encrypt_ :: AlgorithmI a => Maybe PQEncryption -> (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (ByteString, Ratchet a, SkippedMsgDiff))
|
||||
encrypt_ pqEnc_ (_, rc, _) msg =
|
||||
-- print msg >>
|
||||
runExceptT (rcEncrypt rc paddedMsgLen msg pqEnc_ currentE2EEncryptVersion)
|
||||
>>= either (pure . Left) checkLength
|
||||
where
|
||||
checkLength (msg', rc') = do
|
||||
B.length msg' `shouldBe` fullMsgLen
|
||||
B.length msg' `shouldBe` fullMsgLen rc'
|
||||
pure $ Right (msg', rc', SMDNoChange)
|
||||
|
||||
decrypt_ :: (AlgorithmI a, DhAlgorithm a) => (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString, Ratchet a, SkippedMsgDiff))
|
||||
decrypt_ (g, rc, smks) msg = runExceptT $ rcDecrypt g rc smks msg
|
||||
|
||||
encrypt :: AlgorithmI a => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError ByteString)
|
||||
encrypt = withTVar encrypt_
|
||||
encrypt' :: AlgorithmI a => (Ratchet a -> ()) -> Encrypt a
|
||||
encrypt' = withTVar $ encrypt_ Nothing
|
||||
|
||||
decrypt :: (AlgorithmI a, DhAlgorithm a) => TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either CryptoError (Either CryptoError ByteString))
|
||||
decrypt = withTVar decrypt_
|
||||
decrypt' :: (AlgorithmI a, DhAlgorithm a) => (Ratchet a -> ()) -> Decrypt a
|
||||
decrypt' = withTVar decrypt_
|
||||
|
||||
noSndKEM :: Ratchet a -> ()
|
||||
noSndKEM Ratchet {rcSndKEM = PQEncOn} = error "snd ratchet has KEM"
|
||||
noSndKEM _ = ()
|
||||
|
||||
noRcvKEM :: Ratchet a -> ()
|
||||
noRcvKEM Ratchet {rcRcvKEM = PQEncOn} = error "rcv ratchet has KEM"
|
||||
noRcvKEM _ = ()
|
||||
|
||||
hasSndKEM :: Ratchet a -> ()
|
||||
hasSndKEM Ratchet {rcSndKEM = PQEncOn} = ()
|
||||
hasSndKEM _ = error "snd ratchet has no KEM"
|
||||
|
||||
hasRcvKEM :: Ratchet a -> ()
|
||||
hasRcvKEM Ratchet {rcRcvKEM = PQEncOn} = ()
|
||||
hasRcvKEM _ = error "rcv ratchet has no KEM"
|
||||
|
||||
withTVar ::
|
||||
AlgorithmI a =>
|
||||
((TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) -> ByteString -> IO (Either e (r, Ratchet a, SkippedMsgDiff))) ->
|
||||
(Ratchet a -> ()) ->
|
||||
TVar (TVar ChaChaDRG, Ratchet a, SkippedMsgKeys) ->
|
||||
ByteString ->
|
||||
IO (Either e r)
|
||||
withTVar op rcVar msg = do
|
||||
withTVar op valid rcVar msg = do
|
||||
(g, rc, smks) <- readTVarIO rcVar
|
||||
applyDiff smks <$$> (testEncodeDecode rc >> op (g, rc, smks) msg)
|
||||
>>= \case
|
||||
Right (res, rc', smks') -> atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res)
|
||||
Right (res, rc', smks') -> valid rc' `seq` atomically (writeTVar rcVar (g, rc', smks')) >> pure (Right res)
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
applyDiff smks (res, rc', smDiff) = (res, rc', applySMDiff smks smDiff)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module AgentTests.EqInstances where
|
||||
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Store
|
||||
|
||||
instance Eq SomeConn where
|
||||
SomeConn d c == SomeConn d' c' = case testEquality d d' of
|
||||
Just Refl -> c == c'
|
||||
_ -> False
|
||||
|
||||
deriving instance Eq (Connection d)
|
||||
|
||||
deriving instance Eq (SConnType d)
|
||||
|
||||
deriving instance Eq (StoredRcvQueue q)
|
||||
|
||||
deriving instance Eq (StoredSndQueue q)
|
||||
|
||||
deriving instance Eq (DBQueueId q)
|
||||
|
||||
deriving instance Eq ClientNtfCreds
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,28 @@
|
||||
module AgentTests.NotificationTests where
|
||||
|
||||
-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging)
|
||||
import AgentTests.FunctionalAPITests (agentCfgV7, exchangeGreetingsMsgId, get, getSMPAgentClient', makeConnection, nGet, runRight, runRight_, switchComplete, testServerMatrix2, withAgentClientsCfg2, (##>), (=##>), pattern Msg)
|
||||
import AgentTests.FunctionalAPITests
|
||||
( agentCfgV7,
|
||||
createConnection,
|
||||
exchangeGreetingsMsgId,
|
||||
get,
|
||||
getSMPAgentClient',
|
||||
joinConnection,
|
||||
makeConnection,
|
||||
nGet,
|
||||
runRight,
|
||||
runRight_,
|
||||
sendMessage,
|
||||
switchComplete,
|
||||
testServerMatrix2,
|
||||
withAgentClientsCfg2,
|
||||
(##>),
|
||||
(=##>),
|
||||
pattern CON,
|
||||
pattern CONF,
|
||||
pattern INFO,
|
||||
pattern Msg,
|
||||
)
|
||||
import Control.Concurrent (ThreadId, killThread, threadDelay)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
@@ -28,10 +49,10 @@ import Data.Text.Encoding (encodeUtf8)
|
||||
import NtfClient
|
||||
import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testDB3, testNtfServer, testNtfServer2)
|
||||
import SMPClient (cfg, cfgV7, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage)
|
||||
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore')
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO)
|
||||
import Simplex.Messaging.Agent.Store.SQLite (getSavedNtfToken)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
@@ -46,6 +67,7 @@ import Simplex.Messaging.Transport (ATransport)
|
||||
import System.Directory (doesFileExist, removeFile)
|
||||
import Test.Hspec
|
||||
import UnliftIO
|
||||
import Util
|
||||
|
||||
removeFileIfExists :: FilePath -> IO ()
|
||||
removeFileIfExists filePath = do
|
||||
@@ -125,8 +147,8 @@ testNtfMatrix t runTest = do
|
||||
it "next servers: SMP v7, NTF v2; next clients: v7/v2" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfgV7 agentCfgV7 runTest
|
||||
it "next servers: SMP v7, NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfgV2 agentCfg agentCfg runTest
|
||||
it "curr servers: SMP v6, NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfg agentCfg agentCfg runTest
|
||||
-- this case will cannot be supported - see RFC
|
||||
xit "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest
|
||||
skip "this case cannot be supported - see RFC" $
|
||||
it "servers: SMP v6, NTF v1; clients: v7/v2 (not supported)" $ runNtfTestCfg t cfg ntfServerCfg agentCfgV7 agentCfgV7 runTest
|
||||
-- servers can be migrated in any order
|
||||
it "servers: next SMP v7, curr NTF v1; curr clients: v6/v1" $ runNtfTestCfg t cfgV7 ntfServerCfg agentCfg agentCfg runTest
|
||||
it "servers: curr SMP v6, next NTF v2; curr clients: v6/v1" $ runNtfTestCfg t cfg ntfServerCfgV2 agentCfg agentCfg runTest
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
|
||||
|
||||
module AgentTests.SQLiteTests (storeTests) where
|
||||
|
||||
import AgentTests.EqInstances ()
|
||||
import Control.Concurrent.Async (concurrently_)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (SomeException)
|
||||
import Control.Monad (replicateM_)
|
||||
import Control.Monad.Trans.Except
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.ByteArray (ScrubbedBytes)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.List (isInfixOf)
|
||||
@@ -38,9 +45,11 @@ import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction')
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn)
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Crypto.File (CryptoFile (..))
|
||||
import Simplex.Messaging.Encoding.String (StrEncoding (..))
|
||||
import Simplex.Messaging.Protocol (SubscriptionMode (..))
|
||||
import Simplex.Messaging.Protocol (SubscriptionMode (..), pattern VersionSMPC)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import System.Random
|
||||
import Test.Hspec
|
||||
@@ -91,7 +100,7 @@ storeTests = do
|
||||
testForeignKeysEnabled
|
||||
describe "db methods" $ do
|
||||
describe "Queue and Connection management" $ do
|
||||
describe "createRcvConn" $ do
|
||||
describe "create Rcv connection" $ do
|
||||
testCreateRcvConn
|
||||
testCreateRcvConnRandomId
|
||||
testCreateRcvConnDuplicate
|
||||
@@ -172,7 +181,17 @@ testForeignKeysEnabled =
|
||||
`shouldThrow` (\e -> SQL.sqlError e == SQL.ErrorConstraint)
|
||||
|
||||
cData1 :: ConnData
|
||||
cData1 = ConnData {userId = 1, connId = "conn1", connAgentVersion = 1, enableNtfs = True, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
|
||||
cData1 =
|
||||
ConnData
|
||||
{ userId = 1,
|
||||
connId = "conn1",
|
||||
connAgentVersion = VersionSMPA 1,
|
||||
enableNtfs = True,
|
||||
lastExternalSndId = 0,
|
||||
deleted = False,
|
||||
ratchetSyncState = RSOk,
|
||||
pqSupport = CR.PQSupportOn
|
||||
}
|
||||
|
||||
testPrivateAuthKey :: C.APrivateAuthKey
|
||||
testPrivateAuthKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
|
||||
@@ -203,7 +222,7 @@ rcvQueue1 =
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
smpClientVersion = 1,
|
||||
smpClientVersion = VersionSMPC 1,
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
@@ -224,9 +243,15 @@ sndQueue1 =
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
sndSwchStatus = Nothing,
|
||||
smpClientVersion = 1
|
||||
smpClientVersion = VersionSMPC 1
|
||||
}
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewRcvQueue -> SConnectionMode c -> IO (Either StoreError (ConnId, RcvQueue))
|
||||
createRcvConn db g cData rq cMode = runExceptT $ do
|
||||
connId <- ExceptT $ createNewConn db g cData cMode
|
||||
rq' <- ExceptT $ updateNewConnRcv db connId rq
|
||||
pure (connId, rq')
|
||||
|
||||
testCreateRcvConn :: SpecWith SQLiteStore
|
||||
testCreateRcvConn =
|
||||
it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do
|
||||
@@ -312,8 +337,8 @@ testDeleteRcvConn =
|
||||
Right (_, rq) <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
deleteConn db Nothing "conn1"
|
||||
`shouldReturn` Just "conn1"
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left SEConnNotFound
|
||||
|
||||
@@ -324,8 +349,8 @@ testDeleteSndConn =
|
||||
Right (_, sq) <- createSndConn db g cData1 sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
deleteConn db Nothing "conn1"
|
||||
`shouldReturn` Just "conn1"
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left SEConnNotFound
|
||||
|
||||
@@ -337,8 +362,8 @@ testDeleteDuplexConn =
|
||||
Right sq <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq]))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
deleteConn db Nothing "conn1"
|
||||
`shouldReturn` Just "conn1"
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left SEConnNotFound
|
||||
|
||||
@@ -362,7 +387,7 @@ testUpgradeRcvConnToDuplex =
|
||||
sndSwchStatus = Nothing,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1
|
||||
smpClientVersion = VersionSMPC 1
|
||||
}
|
||||
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
|
||||
`shouldReturn` Left (SEBadConnType CSnd)
|
||||
@@ -391,7 +416,7 @@ testUpgradeSndConnToDuplex =
|
||||
rcvSwchStatus = Nothing,
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
smpClientVersion = 1,
|
||||
smpClientVersion = VersionSMPC 1,
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
@@ -459,7 +484,8 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
|
||||
{ integrity = MsgOk,
|
||||
recipient = (unId internalId, ts),
|
||||
sndMsgId = externalSndId,
|
||||
broker = (brokerId, ts)
|
||||
broker = (brokerId, ts),
|
||||
pqEncryption = CR.PQEncOn
|
||||
},
|
||||
msgType = AM_A_MSG_,
|
||||
msgFlags = SMP.noMsgFlags,
|
||||
@@ -497,6 +523,7 @@ mkSndMsgData internalId internalSndId internalHash =
|
||||
msgType = AM_A_MSG_,
|
||||
msgFlags = SMP.noMsgFlags,
|
||||
msgBody = hw,
|
||||
pqEncryption = CR.PQEncOn,
|
||||
internalHash,
|
||||
prevMsgHash = internalHash
|
||||
}
|
||||
@@ -635,7 +662,7 @@ testGetPendingServerCommand st = do
|
||||
Right (Just PendingCommand {corrId = corrId'}) <- getPendingServerCommand db (Just smpServer1)
|
||||
corrId' `shouldBe` "4"
|
||||
where
|
||||
command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) SMSubscribe
|
||||
command = AClientCommand $ APC SAEConn $ NEW True (ACM SCMInvitation) (IKNoPQ PQSupportOn) SMSubscribe
|
||||
corruptCmd :: DB.Connection -> ByteString -> ConnId -> IO ()
|
||||
corruptCmd db corrId connId = DB.execute db "UPDATE commands SET command = cast('bad' as blob) WHERE conn_id = ? AND corr_id = ?" (connId, corrId)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Version (Version)
|
||||
import Test.Hspec
|
||||
|
||||
batchingTests :: Spec
|
||||
@@ -253,27 +252,27 @@ testClientBatchWithLargeMessageV7 = do
|
||||
(length rs1', length rs2') `shouldBe` (74, 136)
|
||||
all lenOk [s1', s2'] `shouldBe` True
|
||||
|
||||
testClientStub :: IO (ProtocolClient ErrorType BrokerMsg)
|
||||
testClientStub :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg)
|
||||
testClientStub = do
|
||||
g <- C.newRandom
|
||||
sessId <- atomically $ C.randomBytes 32 g
|
||||
atomically $ clientStub g sessId (authCmdsSMPVersion - 1) Nothing
|
||||
atomically $ smpClientStub g sessId subModeSMPVersion Nothing
|
||||
|
||||
clientStubV7 :: IO (ProtocolClient ErrorType BrokerMsg)
|
||||
clientStubV7 :: IO (ProtocolClient SMPVersion ErrorType BrokerMsg)
|
||||
clientStubV7 = do
|
||||
g <- C.newRandom
|
||||
sessId <- atomically $ C.randomBytes 32 g
|
||||
(rKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g
|
||||
thAuth_ <- testTHandleAuth authCmdsSMPVersion g rKey
|
||||
atomically $ clientStub g sessId authCmdsSMPVersion thAuth_
|
||||
atomically $ smpClientStub g sessId authCmdsSMPVersion thAuth_
|
||||
|
||||
randomSUB :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSUB = randomSUB_ C.SEd25519 (authCmdsSMPVersion - 1)
|
||||
randomSUB = randomSUB_ C.SEd25519 subModeSMPVersion
|
||||
|
||||
randomSUBv7 :: ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSUBv7 = randomSUB_ C.SEd25519 authCmdsSMPVersion
|
||||
|
||||
randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSUB_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSUB_ a v sessId = do
|
||||
g <- C.newRandom
|
||||
rId <- atomically $ C.randomBytes 24 g
|
||||
@@ -284,13 +283,13 @@ randomSUB_ a v sessId = do
|
||||
TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, rId, Cmd SRecipient SUB)
|
||||
pure $ (,tToSend) <$> authTransmission thAuth_ (Just rpKey) corrId tForAuth
|
||||
|
||||
randomSUBCmd :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmd = randomSUBCmd_ C.SEd25519
|
||||
|
||||
randomSUBCmdV7 :: ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmdV7 = randomSUBCmd_ C.SEd25519 -- same as v6
|
||||
|
||||
randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSUBCmd_ a c = do
|
||||
g <- C.newRandom
|
||||
rId <- atomically $ C.randomBytes 24 g
|
||||
@@ -298,12 +297,12 @@ randomSUBCmd_ a c = do
|
||||
mkTransmission c (Just rpKey, rId, Cmd SRecipient SUB)
|
||||
|
||||
randomSEND :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSEND = randomSEND_ C.SEd25519 (authCmdsSMPVersion - 1)
|
||||
randomSEND = randomSEND_ C.SEd25519 subModeSMPVersion
|
||||
|
||||
randomSENDv7 :: ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSENDv7 = randomSEND_ C.SX25519 authCmdsSMPVersion
|
||||
|
||||
randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> Version -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSEND_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> VersionSMP -> ByteString -> Int -> IO (Either TransportError (Maybe TransmissionAuth, ByteString))
|
||||
randomSEND_ a v sessId len = do
|
||||
g <- C.newRandom
|
||||
sId <- atomically $ C.randomBytes 24 g
|
||||
@@ -315,7 +314,7 @@ randomSEND_ a v sessId len = do
|
||||
TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth thParams (corrId, sId, Cmd SSender $ SEND noMsgFlags msg)
|
||||
pure $ (,tToSend) <$> authTransmission thAuth_ (Just spKey) corrId tForAuth
|
||||
|
||||
testTHandleParams :: Version -> ByteString -> THandleParams
|
||||
testTHandleParams :: VersionSMP -> ByteString -> THandleParams SMPVersion
|
||||
testTHandleParams v sessionId =
|
||||
THandleParams
|
||||
{ sessionId,
|
||||
@@ -326,20 +325,20 @@ testTHandleParams v sessionId =
|
||||
batch = True
|
||||
}
|
||||
|
||||
testTHandleAuth :: Version -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth)
|
||||
testTHandleAuth :: VersionSMP -> TVar ChaChaDRG -> C.APublicAuthKey -> IO (Maybe THandleAuth)
|
||||
testTHandleAuth v g (C.APublicAuthKey a k) = case a of
|
||||
C.SX25519 | v >= authCmdsSMPVersion -> do
|
||||
(_, privKey) <- atomically $ C.generateKeyPair g
|
||||
pure $ Just THandleAuth {peerPubKey = k, privKey}
|
||||
_ -> pure Nothing
|
||||
|
||||
randomSENDCmd :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmd :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmd = randomSENDCmd_ C.SEd25519
|
||||
|
||||
randomSENDCmdV7 :: ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmdV7 :: ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmdV7 = randomSENDCmd_ C.SX25519
|
||||
|
||||
randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmd_ :: (C.AlgorithmI a, C.AuthAlgorithm a) => C.SAlgorithm a -> ProtocolClient SMPVersion ErrorType BrokerMsg -> Int -> IO (PCTransmission ErrorType BrokerMsg)
|
||||
randomSENDCmd_ a c len = do
|
||||
g <- C.newRandom
|
||||
sId <- atomically $ C.randomBytes 24 g
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module CoreTests.CryptoTests (cryptoTests) where
|
||||
|
||||
@@ -13,6 +15,7 @@ import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import qualified Data.Text.Lazy as LT
|
||||
import qualified Data.Text.Lazy.Encoding as LE
|
||||
import Data.Type.Equality
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
|
||||
@@ -91,6 +94,16 @@ cryptoTests = do
|
||||
describe "sntrup761" $
|
||||
it "should enc/dec key" testSNTRUP761
|
||||
|
||||
instance Eq C.APublicKey where
|
||||
C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
instance Eq C.APrivateKey where
|
||||
C.APrivateKey a k == C.APrivateKey a' k' = case testEquality a a' of
|
||||
Just Refl -> k == k'
|
||||
Nothing -> False
|
||||
|
||||
testPadUnpadFile :: IO ()
|
||||
testPadUnpadFile = do
|
||||
let f = "tests/tmp/testpad"
|
||||
|
||||
@@ -12,7 +12,7 @@ import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.FileTransfer.Protocol (XFTPErrorType (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module CoreTests.TRcvQueuesTests where
|
||||
|
||||
import AgentTests.EqInstances ()
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
@@ -11,7 +13,7 @@ import Simplex.Messaging.Agent.Protocol (ConnId, QueueStatus (..), UserId)
|
||||
import Simplex.Messaging.Agent.Store (DBQueueId (..), RcvQueue, StoredRcvQueue (..))
|
||||
import qualified Simplex.Messaging.Agent.TRcvQueues as RQ
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (SMPServer)
|
||||
import Simplex.Messaging.Protocol (SMPServer, pattern VersionSMPC)
|
||||
import Test.Hspec
|
||||
import UnliftIO
|
||||
|
||||
@@ -136,7 +138,7 @@ dummyRQ userId server connId =
|
||||
primary = True,
|
||||
dbReplaceQueueId = Nothing,
|
||||
rcvSwchStatus = Nothing,
|
||||
smpClientVersion = 123,
|
||||
smpClientVersion = VersionSMPC 123,
|
||||
clientNtfCreds = Nothing,
|
||||
deleteErrors = 0
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module CoreTests.VersionRangeTests where
|
||||
@@ -8,6 +9,7 @@ module CoreTests.VersionRangeTests where
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.Messaging.Version
|
||||
import Simplex.Messaging.Version.Internal
|
||||
import Test.Hspec
|
||||
import Test.Hspec.QuickCheck (modifyMaxSuccess)
|
||||
import Test.QuickCheck
|
||||
@@ -16,6 +18,10 @@ data V = V1 | V2 | V3 | V4 | V5 deriving (Eq, Enum, Ord, Generic, Show)
|
||||
|
||||
instance Arbitrary V where arbitrary = genericArbitraryU
|
||||
|
||||
data T
|
||||
|
||||
instance VersionScope T
|
||||
|
||||
versionRangeTests :: Spec
|
||||
versionRangeTests = modifyMaxSuccess (const 1000) $ do
|
||||
describe "VersionRange construction" $ do
|
||||
@@ -25,31 +31,31 @@ versionRangeTests = modifyMaxSuccess (const 1000) $ do
|
||||
(pure $! vr 2 1) `shouldThrow` anyErrorCall
|
||||
describe "compatible version" $ do
|
||||
it "should choose mutually compatible max version" $ do
|
||||
(vr 1 1, vr 1 1) `compatible` Just 1
|
||||
(vr 1 1, vr 1 2) `compatible` Just 1
|
||||
(vr 1 2, vr 1 2) `compatible` Just 2
|
||||
(vr 1 2, vr 2 3) `compatible` Just 2
|
||||
(vr 1 3, vr 2 3) `compatible` Just 3
|
||||
(vr 1 3, vr 2 4) `compatible` Just 3
|
||||
(vr 1 1, vr 1 1) `compatible` Just (Version 1)
|
||||
(vr 1 1, vr 1 2) `compatible` Just (Version 1)
|
||||
(vr 1 2, vr 1 2) `compatible` Just (Version 2)
|
||||
(vr 1 2, vr 2 3) `compatible` Just (Version 2)
|
||||
(vr 1 3, vr 2 3) `compatible` Just (Version 3)
|
||||
(vr 1 3, vr 2 4) `compatible` Just (Version 3)
|
||||
(vr 1 2, vr 3 4) `compatible` Nothing
|
||||
it "should check if version is compatible" $ do
|
||||
isCompatible (1 :: Version) (vr 1 2) `shouldBe` True
|
||||
isCompatible (2 :: Version) (vr 1 2) `shouldBe` True
|
||||
isCompatible (2 :: Version) (vr 1 1) `shouldBe` False
|
||||
isCompatible (1 :: Version) (vr 2 2) `shouldBe` False
|
||||
isCompatible @T (Version 1) (vr 1 2) `shouldBe` True
|
||||
isCompatible @T (Version 2) (vr 1 2) `shouldBe` True
|
||||
isCompatible @T (Version 2) (vr 1 1) `shouldBe` False
|
||||
isCompatible @T (Version 1) (vr 2 2) `shouldBe` False
|
||||
it "compatibleVersion should pass isCompatible check" . property $
|
||||
\((min1, max1) :: (V, V)) ((min2, max2) :: (V, V)) ->
|
||||
min1 > max1
|
||||
|| min2 > max2 -- one of ranges is invalid, skip testing it
|
||||
|| let w = fromIntegral . fromEnum
|
||||
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange
|
||||
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange
|
||||
|| let w = Version . fromIntegral . fromEnum
|
||||
vr1 = mkVersionRange (w min1) (w max1) :: VersionRange T
|
||||
vr2 = mkVersionRange (w min2) (w max2) :: VersionRange T
|
||||
in case compatibleVersion vr1 vr2 of
|
||||
Just (Compatible v) -> v `isCompatible` vr1 && v `isCompatible` vr2
|
||||
_ -> True
|
||||
where
|
||||
vr = mkVersionRange
|
||||
compatible :: (VersionRange, VersionRange) -> Maybe Version -> Expectation
|
||||
vr v1 v2 = mkVersionRange (Version v1) (Version v2)
|
||||
compatible :: (VersionRange T, VersionRange T) -> Maybe (Version T) -> Expectation
|
||||
(vr1, vr2) `compatible` v = do
|
||||
(vr1, vr2) `checkCompatible` v
|
||||
(vr2, vr1) `checkCompatible` v
|
||||
|
||||
+8
-7
@@ -34,6 +34,7 @@ import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost,
|
||||
import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol (NtfResponse)
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
@@ -70,7 +71,7 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
ntfTestStoreLogFile :: FilePath
|
||||
ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
|
||||
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleNTF c -> m a) -> m a
|
||||
testNtfClient client = do
|
||||
Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig Nothing host ntfTestPort (Just testKeyHash) $ \h -> do
|
||||
@@ -114,8 +115,8 @@ ntfServerCfg =
|
||||
ntfServerCfgV2 :: NtfServerConfig
|
||||
ntfServerCfgV2 =
|
||||
ntfServerCfg
|
||||
{ ntfServerVRange = mkVersionRange 1 authBatchCmdsNTFVersion,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange 4 authCmdsSMPVersion}}
|
||||
{ ntfServerVRange = mkVersionRange initialNTFVersion authBatchCmdsNTFVersion,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig {smpCfg = (smpCfg defaultSMPClientAgentConfig) {serverVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}}
|
||||
}
|
||||
|
||||
withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a
|
||||
@@ -139,7 +140,7 @@ withNtfServerOn t port' = withNtfServerThreadOn t port' . const
|
||||
withNtfServer :: ATransport -> IO a -> IO a
|
||||
withNtfServer t = withNtfServerOn t ntfTestPort
|
||||
|
||||
runNtfTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a
|
||||
runNtfTest :: forall c a. Transport c => (THandleNTF c -> IO a) -> IO a
|
||||
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
|
||||
|
||||
ntfServerTest ::
|
||||
@@ -147,10 +148,10 @@ ntfServerTest ::
|
||||
(Transport c, Encoding smp) =>
|
||||
TProxy c ->
|
||||
(Maybe TransmissionAuth, ByteString, ByteString, smp) ->
|
||||
IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg)
|
||||
IO (Maybe TransmissionAuth, ByteString, ByteString, NtfResponse)
|
||||
ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
||||
where
|
||||
tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
|
||||
tPut' :: THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
|
||||
tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do
|
||||
let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp)
|
||||
[Right ()] <- tPut h [Right (sig, t')]
|
||||
@@ -159,7 +160,7 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
||||
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
ntfTest :: Transport c => TProxy c -> (THandleNTF c -> IO ()) -> Expectation
|
||||
ntfTest _ test' = runNtfTest test' `shouldReturn` ()
|
||||
|
||||
data APNSMockRequest = APNSMockRequest
|
||||
|
||||
+13
-8
@@ -5,7 +5,9 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module NtfServerTests where
|
||||
|
||||
@@ -37,6 +39,7 @@ import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS
|
||||
import Simplex.Messaging.Notifications.Transport (THandleNTF)
|
||||
import Simplex.Messaging.Parsers (parse, parseAll)
|
||||
import Simplex.Messaging.Protocol hiding (notification)
|
||||
import Simplex.Messaging.Transport
|
||||
@@ -50,30 +53,32 @@ ntfServerTests t = do
|
||||
|
||||
ntfSyntaxTests :: ATransport -> Spec
|
||||
ntfSyntaxTests (ATransport t) = do
|
||||
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
|
||||
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", NRErr $ CMD UNKNOWN)
|
||||
describe "NEW" $ do
|
||||
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX)
|
||||
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX)
|
||||
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH)
|
||||
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH)
|
||||
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", NRErr $ CMD SYNTAX)
|
||||
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", NRErr $ CMD SYNTAX)
|
||||
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", NRErr $ CMD NO_AUTH)
|
||||
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", NRErr $ CMD HAS_AUTH)
|
||||
where
|
||||
(>#>) ::
|
||||
Encoding smp =>
|
||||
(Maybe TransmissionAuth, ByteString, ByteString, smp) ->
|
||||
(Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg) ->
|
||||
(Maybe TransmissionAuth, ByteString, ByteString, NtfResponse) ->
|
||||
Expectation
|
||||
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
|
||||
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission ErrorType NtfResponse
|
||||
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
deriving instance Eq NtfResponse
|
||||
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> (Maybe TransmissionAuth, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
sendRecvNtf h@THandle {params} (sgn, corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, tToSend)
|
||||
tGet1 h
|
||||
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandleNTF c -> C.APrivateAuthKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission ErrorType NtfResponse)
|
||||
signSendRecvNtf h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (authorize tForAuth, tToSend)
|
||||
|
||||
+21
-20
@@ -26,7 +26,7 @@ import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.Server
|
||||
import Simplex.Messaging.Version (VersionRange, mkVersionRange)
|
||||
import Simplex.Messaging.Version (mkVersionRange)
|
||||
import System.Environment (lookupEnv)
|
||||
import System.Info (os)
|
||||
import Test.Hspec
|
||||
@@ -34,6 +34,7 @@ import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
|
||||
import UnliftIO.Timeout (timeout)
|
||||
import Util
|
||||
|
||||
testHost :: NonEmpty TransportHost
|
||||
testHost = "localhost"
|
||||
@@ -60,17 +61,17 @@ testServerStatsBackupFile :: FilePath
|
||||
testServerStatsBackupFile = "tests/tmp/smp-server-stats.log"
|
||||
|
||||
xit' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a)
|
||||
xit' = if os == "linux" then xit else it
|
||||
xit' d = if os == "linux" then skip "skipped on Linux" . it d else it d
|
||||
|
||||
xit'' :: (HasCallStack, Example a) => String -> a -> SpecWith (Arg a)
|
||||
xit'' d t = do
|
||||
ci <- runIO $ lookupEnv "CI"
|
||||
(if ci == Just "true" then xit else it) d t
|
||||
(if ci == Just "true" then skip "skipped on CI" . it d else it d) t
|
||||
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandle c -> m a) -> m a
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m, MonadFail m) => (THandleSMP c -> m a) -> m a
|
||||
testSMPClient = testSMPClientVR supportedClientSMPRelayVRange
|
||||
|
||||
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRange -> (THandle c -> m a) -> m a
|
||||
testSMPClientVR :: (Transport c, MonadUnliftIO m, MonadFail m) => VersionRangeSMP -> (THandleSMP c -> m a) -> m a
|
||||
testSMPClientVR vr client = do
|
||||
Right useHost <- pure $ chooseTransportHost defaultNetworkConfig testHost
|
||||
runTransportClient defaultTransportClientConfig Nothing useHost testPort (Just testKeyHash) $ \h -> do
|
||||
@@ -109,7 +110,7 @@ cfg =
|
||||
}
|
||||
|
||||
cfgV7 :: ServerConfig
|
||||
cfgV7 = cfg {smpServerVRange = mkVersionRange 4 authCmdsSMPVersion}
|
||||
cfgV7 = cfg {smpServerVRange = mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion}
|
||||
|
||||
withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
|
||||
withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile}
|
||||
@@ -148,16 +149,16 @@ withSmpServer t = withSmpServerOn t testPort
|
||||
withSmpServerV7 :: HasCallStack => ATransport -> IO a -> IO a
|
||||
withSmpServerV7 t = withSmpServerConfigOn t cfgV7 testPort . const
|
||||
|
||||
runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a
|
||||
runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandleSMP c -> IO a) -> IO a
|
||||
runSmpTest test = withSmpServer (transport @c) $ testSMPClient test
|
||||
|
||||
runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a
|
||||
runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a
|
||||
runSmpTestN = runSmpTestNCfg cfg supportedClientSMPRelayVRange
|
||||
|
||||
runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> Int -> (HasCallStack => [THandle c] -> IO a) -> IO a
|
||||
runSmpTestNCfg :: forall c a. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> Int -> (HasCallStack => [THandleSMP c] -> IO a) -> IO a
|
||||
runSmpTestNCfg srvCfg clntVR nClients test = withSmpServerConfigOn (transport @c) srvCfg testPort $ \_ -> run nClients []
|
||||
where
|
||||
run :: Int -> [THandle c] -> IO a
|
||||
run :: Int -> [THandleSMP c] -> IO a
|
||||
run 0 hs = test hs
|
||||
run n hs = testSMPClientVR clntVR $ \h -> run (n - 1) (h : hs)
|
||||
|
||||
@@ -169,7 +170,7 @@ smpServerTest ::
|
||||
IO (Maybe TransmissionAuth, ByteString, ByteString, BrokerMsg)
|
||||
smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
|
||||
where
|
||||
tPut' :: THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
|
||||
tPut' :: THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, smp) -> IO ()
|
||||
tPut' h@THandle {params = THandleParams {sessionId, implySessId}} (sig, corrId, queueId, smp) = do
|
||||
let t' = if implySessId then smpEncode (corrId, queueId, smp) else smpEncode (sessionId, corrId, queueId, smp)
|
||||
[Right ()] <- tPut h [Right (sig, t')]
|
||||
@@ -178,33 +179,33 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
|
||||
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation
|
||||
smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> IO ()) -> Expectation
|
||||
smpTest _ test' = runSmpTest test' `shouldReturn` ()
|
||||
|
||||
smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation
|
||||
smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandleSMP c] -> IO ()) -> Expectation
|
||||
smpTestN n test' = runSmpTestN n test' `shouldReturn` ()
|
||||
|
||||
smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation
|
||||
smpTest2 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
|
||||
smpTest2 = smpTest2Cfg cfg supportedClientSMPRelayVRange
|
||||
|
||||
smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRange -> TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation
|
||||
smpTest2Cfg :: forall c. (HasCallStack, Transport c) => ServerConfig -> VersionRangeSMP -> TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
|
||||
smpTest2Cfg srvCfg clntVR _ test' = runSmpTestNCfg srvCfg clntVR 2 _test `shouldReturn` ()
|
||||
where
|
||||
_test :: HasCallStack => [THandle c] -> IO ()
|
||||
_test :: HasCallStack => [THandleSMP c] -> IO ()
|
||||
_test [h1, h2] = test' h1 h2
|
||||
_test _ = error "expected 2 handles"
|
||||
|
||||
smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
|
||||
smpTest3 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
|
||||
smpTest3 _ test' = smpTestN 3 _test
|
||||
where
|
||||
_test :: HasCallStack => [THandle c] -> IO ()
|
||||
_test :: HasCallStack => [THandleSMP c] -> IO ()
|
||||
_test [h1, h2, h3] = test' h1 h2 h3
|
||||
_test _ = error "expected 3 handles"
|
||||
|
||||
smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
|
||||
smpTest4 :: forall c. (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandleSMP c -> THandleSMP c -> THandleSMP c -> THandleSMP c -> IO ()) -> Expectation
|
||||
smpTest4 _ test' = smpTestN 4 _test
|
||||
where
|
||||
_test :: HasCallStack => [THandle c] -> IO ()
|
||||
_test :: HasCallStack => [THandleSMP c] -> IO ()
|
||||
_test [h1, h2, h3, h4] = test' h1 h2 h3 h4
|
||||
_test _ = error "expected 4 handles"
|
||||
|
||||
|
||||
+33
-20
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -8,7 +9,9 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -Wno-orphans #-}
|
||||
|
||||
module ServerTests where
|
||||
|
||||
@@ -23,6 +26,7 @@ import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Set as S
|
||||
import Data.Type.Equality
|
||||
import GHC.Stack (withFrozenCallStack)
|
||||
import SMPClient
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -74,13 +78,13 @@ pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh)
|
||||
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
|
||||
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
|
||||
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> (Maybe TransmissionAuth, ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
sendRecv h@THandle {params} (sgn, corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (sgn, tToSend)
|
||||
tGet1 h
|
||||
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandleSMP c -> C.APrivateAuthKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission ErrorType BrokerMsg)
|
||||
signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
|
||||
let TransmissionForAuth {tForAuth, tToSend} = encodeTransmissionForAuth params (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut1 h (authorize tForAuth, tToSend)
|
||||
@@ -94,12 +98,12 @@ signSendRecv h@THandle {params} (C.APrivateAuthKey a pk) (corrId, qId, cmd) = do
|
||||
_sx448 -> undefined -- ghc8107 fails to the branch excluded by types
|
||||
#endif
|
||||
|
||||
tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut1 :: Transport c => THandle v c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut1 h t = do
|
||||
[r] <- tPut h [Right t]
|
||||
pure r
|
||||
|
||||
tGet1 :: (ProtocolEncoding err cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission err cmd)
|
||||
tGet1 :: (ProtocolEncoding v err cmd, Transport c, MonadIO m, MonadFail m) => THandle v c -> m (SignedTransmission err cmd)
|
||||
tGet1 h = do
|
||||
[r] <- liftIO $ tGet h
|
||||
pure r
|
||||
@@ -380,7 +384,7 @@ testSwitchSub (ATransport t) =
|
||||
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3)
|
||||
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
|
||||
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh1 >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
@@ -551,12 +555,12 @@ testWithStoreLog at@(ATransport t) =
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
logSize :: FilePath -> IO Int
|
||||
@@ -649,12 +653,12 @@ testRestoreMessages at@(ATransport t) =
|
||||
removeFile testStoreMsgsFile
|
||||
removeFile testServerStatsBackupFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
checkStats :: ServerStatsData -> [RecipientId] -> Int -> Int -> Expectation
|
||||
@@ -723,15 +727,15 @@ testRestoreExpireMessages at@(ATransport t) =
|
||||
Right ServerStatsData {_msgExpired} <- strDecode <$> B.readFile testServerStatsBackupFile
|
||||
_msgExpired `shouldBe` 2
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient :: Transport c => TProxy c -> (THandleSMP c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
createAndSecureQueue :: Transport c => THandle c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret)
|
||||
createAndSecureQueue :: Transport c => THandleSMP c -> SndPublicAuthKey -> IO (SenderId, RecipientId, RcvPrivateAuthKey, RcvDhSecret)
|
||||
createAndSecureQueue h sPub = do
|
||||
g <- C.newRandom
|
||||
(rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g
|
||||
@@ -747,7 +751,7 @@ testTiming (ATransport t) =
|
||||
describe "should have similar time for auth error, whether queue exists or not, for all key types" $
|
||||
forM_ timingTests $ \tst ->
|
||||
it (testName tst) $
|
||||
smpTest2Cfg cfgV7 (mkVersionRange 4 authCmdsSMPVersion) t $ \rh sh ->
|
||||
smpTest2Cfg cfgV7 (mkVersionRange batchCmdsSMPVersion authCmdsSMPVersion) t $ \rh sh ->
|
||||
testSameTiming rh sh tst
|
||||
where
|
||||
testName :: (C.AuthAlg, C.AuthAlg, Int) -> String
|
||||
@@ -766,7 +770,7 @@ testTiming (ATransport t) =
|
||||
]
|
||||
timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const
|
||||
similarTime t1 t2 = abs (t2 / t1 - 1) < 0.15 -- normally the difference between "no queue" and "wrong key" is less than 5%
|
||||
testSameTiming :: forall c. Transport c => THandle c -> THandle c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation
|
||||
testSameTiming :: forall c. Transport c => THandleSMP c -> THandleSMP c -> (C.AuthAlg, C.AuthAlg, Int) -> Expectation
|
||||
testSameTiming rh sh (C.AuthAlg goodKeyAlg, C.AuthAlg badKeyAlg, n) = do
|
||||
g <- C.newRandom
|
||||
(rPub, rKey) <- atomically $ C.generateAuthKeyPair goodKeyAlg g
|
||||
@@ -787,7 +791,7 @@ testTiming (ATransport t) =
|
||||
|
||||
runTimingTest sh badKey sId $ _SEND "hello"
|
||||
where
|
||||
runTimingTest :: PartyI p => THandle c -> C.APrivateAuthKey -> ByteString -> Command p -> IO ()
|
||||
runTimingTest :: PartyI p => THandleSMP c -> C.APrivateAuthKey -> ByteString -> Command p -> IO ()
|
||||
runTimingTest h badKey qId cmd = do
|
||||
threadDelay 100000
|
||||
_ <- timeRepeat n $ do -- "warm up" the server
|
||||
@@ -837,14 +841,14 @@ testMessageNotifications (ATransport t) =
|
||||
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
|
||||
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh2
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg nh1 >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh1 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg nh2 >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
|
||||
|
||||
@@ -864,7 +868,7 @@ testMsgExpireOnSend t =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
@@ -884,7 +888,7 @@ testMsgExpireOnInterval t =
|
||||
signSendRecv rh rKey ("2", rId, SUB) >>= \case
|
||||
Resp "2" _ OK -> pure ()
|
||||
r -> unexpected r
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing should be delivered"
|
||||
|
||||
@@ -903,7 +907,7 @@ testMsgNOTExpireOnInterval t =
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @ErrorType @BrokerMsg rh >>= \case
|
||||
1000 `timeout` tGet @SMPVersion @ErrorType @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else should be delivered"
|
||||
|
||||
@@ -919,6 +923,15 @@ sampleSig = Just $ TASignature "e8JK+8V3fq6kOLqco/SaKlpNaQ7i1gfOrXoqekEl42u4mF8B
|
||||
noAuth :: (Char, Maybe BasicAuth)
|
||||
noAuth = ('A', Nothing)
|
||||
|
||||
deriving instance Eq TransmissionAuth
|
||||
|
||||
instance Eq C.ASignature where
|
||||
C.ASignature a s == C.ASignature a' s' = case testEquality a a' of
|
||||
Just Refl -> s == s'
|
||||
_ -> False
|
||||
|
||||
deriving instance Eq (C.Signature a)
|
||||
|
||||
syntaxTests :: ATransport -> Spec
|
||||
syntaxTests (ATransport t) = do
|
||||
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
module Util where
|
||||
|
||||
import Test.Hspec
|
||||
|
||||
skip :: String -> SpecWith a -> SpecWith a
|
||||
skip = before_ . pendingWith
|
||||
+2
-1
@@ -21,7 +21,8 @@ import Data.List (find, isSuffixOf)
|
||||
import Data.Maybe (fromJust)
|
||||
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2, testDB3)
|
||||
import Simplex.FileTransfer.Description (FileDescription (..), FileDescriptionURI (..), ValidFileDescription, fileDescriptionURI, mb, qrSizeLimit, pattern ValidFileDescription)
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..), XFTPErrorType (AUTH))
|
||||
import Simplex.FileTransfer.Protocol (FileParty (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPErrorType (AUTH))
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..))
|
||||
import Simplex.Messaging.Agent (AgentClient, disconnectAgentClient, testProtocolServer, xftpDeleteRcvFile, xftpDeleteSndFileInternal, xftpDeleteSndFileRemote, xftpReceiveFile, xftpSendDescription, xftpSendFile, xftpStartWorkers)
|
||||
import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..))
|
||||
|
||||
@@ -20,9 +20,9 @@ import Data.List (isInfixOf)
|
||||
import ServerTests (logSize)
|
||||
import Simplex.FileTransfer.Client
|
||||
import Simplex.FileTransfer.Description (kb)
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..), XFTPErrorType (..))
|
||||
import Simplex.FileTransfer.Protocol (FileInfo (..))
|
||||
import Simplex.FileTransfer.Server.Env (XFTPServerConfig (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..))
|
||||
import Simplex.FileTransfer.Transport (XFTPRcvChunkSpec (..), XFTPErrorType (..))
|
||||
import Simplex.Messaging.Client (ProtocolClientError (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Lazy as LC
|
||||
|
||||
Reference in New Issue
Block a user