mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-25 12:02:18 +00:00
Merge pull request #82 from simplex-chat/v2
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module ChatOptions (getChatOpts, ChatOpts (..)) where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Options.Applicative
|
||||
import Simplex.Messaging.Agent.Transmission (SMPServer (..), smpServerP)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import System.FilePath (combine)
|
||||
import Types
|
||||
|
||||
@@ -30,8 +31,8 @@ chatOpts appDir =
|
||||
( long "server"
|
||||
<> short 's'
|
||||
<> metavar "SERVER"
|
||||
<> help "SMP server to use (smp.simplex.im:5223)"
|
||||
<> value (SMPServer "smp.simplex.im" (Just "5223") Nothing)
|
||||
<> help "SMP server to use (smp1.simplex.im:5223#pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=)"
|
||||
<> value (SMPServer "smp1.simplex.im" (Just "5223") (Just "pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA="))
|
||||
)
|
||||
<*> option
|
||||
parseTermMode
|
||||
@@ -45,7 +46,7 @@ chatOpts appDir =
|
||||
defaultDbFilePath = combine appDir "smp-chat.db"
|
||||
|
||||
parseSMPServer :: ReadM SMPServer
|
||||
parseSMPServer = eitherReader $ A.parseOnly (smpServerP <* A.endOfInput) . B.pack
|
||||
parseSMPServer = eitherReader $ parseAll smpServerP . B.pack
|
||||
|
||||
parseTermMode :: ReadM TermMode
|
||||
parseTermMode = maybeReader $ \case
|
||||
|
||||
@@ -31,6 +31,7 @@ import Simplex.Messaging.Agent.Client (AgentClient (..))
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client (smpDefaultConfig)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util (raceAny_)
|
||||
import Styled
|
||||
import System.Console.ANSI.Types
|
||||
@@ -89,6 +90,7 @@ data ChatResponse
|
||||
| ReceivedMessage Contact ByteString
|
||||
| Disconnected Contact
|
||||
| YesYes
|
||||
| ContactError ConnectionErrorType Contact
|
||||
| ErrorInput ByteString
|
||||
| ChatError AgentErrorType
|
||||
| NoChatResponse
|
||||
@@ -107,8 +109,13 @@ serializeChatResponse = \case
|
||||
Connected c -> [ttyContact c <> " connected"]
|
||||
Confirmation c -> [ttyContact c <> " ok"]
|
||||
ReceivedMessage c t -> prependFirst (ttyFromContact c) $ msgPlain t
|
||||
-- TODO either add command to re-connect or update message below
|
||||
Disconnected c -> ["disconnected from " <> ttyContact c <> " - try \"/chat " <> bPlain (toBs c) <> "\""]
|
||||
YesYes -> ["you got it!"]
|
||||
ContactError e c -> case e of
|
||||
UNKNOWN -> ["no contact " <> ttyContact c]
|
||||
DUPLICATE -> ["contact " <> ttyContact c <> " already exists"]
|
||||
SIMPLEX -> ["contact " <> ttyContact c <> " did not accept invitation yet"]
|
||||
ErrorInput t -> ["invalid input: " <> bPlain t]
|
||||
ChatError e -> ["chat error: " <> plain (show e)]
|
||||
NoChatResponse -> [""]
|
||||
@@ -172,7 +179,7 @@ main = do
|
||||
t <- getChatClient smpServer
|
||||
ct <- newChatTerminal (tbqSize cfg) termMode
|
||||
-- setLogLevel LogInfo -- LogError
|
||||
-- withGlobalLogging logCfg $
|
||||
-- withGlobalLogging logCfg $ do
|
||||
env <- newSMPAgentEnv cfg {dbFile = dbFileName}
|
||||
dogFoodChat t ct env
|
||||
|
||||
@@ -209,7 +216,7 @@ newChatClient qSize smpServer = do
|
||||
receiveFromChatTerm :: ChatClient -> ChatTerminal -> IO ()
|
||||
receiveFromChatTerm t ct = forever $ do
|
||||
atomically (readTBQueue $ inputQ ct)
|
||||
>>= processOrError . A.parseOnly (chatCommandP <* A.endOfInput) . encodeUtf8 . T.pack
|
||||
>>= processOrError . parseAll chatCommandP . encodeUtf8 . T.pack
|
||||
where
|
||||
processOrError = \case
|
||||
Left err -> writeOutQ . ErrorInput $ B.pack err
|
||||
@@ -259,9 +266,10 @@ receiveFromAgent t ct c = forever . atomically $ do
|
||||
INV qInfo -> Invitation qInfo
|
||||
CON -> Connected contact
|
||||
END -> Disconnected contact
|
||||
MSG {m_body} -> ReceivedMessage contact m_body
|
||||
MSG {msgBody} -> ReceivedMessage contact msgBody
|
||||
SENT _ -> NoChatResponse
|
||||
OK -> Confirmation contact
|
||||
ERR (CONN e) -> ContactError e contact
|
||||
ERR e -> ChatError e
|
||||
where
|
||||
contact = Contact a
|
||||
|
||||
@@ -1,7 +1,28 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad (unless, when)
|
||||
import qualified Crypto.Store.PKCS8 as S
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.Ini (lookupValue, readIniFile)
|
||||
import qualified Data.Text as T
|
||||
import Data.X509 (PrivKey (PrivKeyRSA))
|
||||
import Options.Applicative
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Server (runSMPServer)
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog)
|
||||
import System.Directory (createDirectoryIfMissing, doesFileExist)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (IOMode (..), hFlush, stdout)
|
||||
|
||||
cfg :: ServerConfig
|
||||
cfg =
|
||||
@@ -9,10 +30,137 @@ cfg =
|
||||
{ tcpPort = "5223",
|
||||
tbqSize = 16,
|
||||
queueIdBytes = 12,
|
||||
msgIdBytes = 6
|
||||
msgIdBytes = 6,
|
||||
storeLog = Nothing,
|
||||
-- key is loaded from the file server_key in /etc/opt/simplex directory
|
||||
serverPrivateKey = undefined
|
||||
}
|
||||
|
||||
newKeySize :: Int
|
||||
newKeySize = 2048 `div` 8
|
||||
|
||||
cfgDir :: FilePath
|
||||
cfgDir = "/etc/opt/simplex"
|
||||
|
||||
logDir :: FilePath
|
||||
logDir = "/var/opt/simplex"
|
||||
|
||||
defaultStoreLogFile :: FilePath
|
||||
defaultStoreLogFile = combine logDir "smp-server-store.log"
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
putStrLn $ "Listening on port " ++ tcpPort cfg
|
||||
runSMPServer cfg
|
||||
opts <- getServerOpts
|
||||
putStrLn "SMP Server (-h for help)"
|
||||
ini <- readCreateIni opts
|
||||
storeLog <- openStoreLog ini
|
||||
pk <- readCreateKey
|
||||
B.putStrLn $ "transport key hash: " <> publicKeyHash (C.publicKey pk)
|
||||
putStrLn $ "listening on port " <> tcpPort cfg
|
||||
runSMPServer cfg {serverPrivateKey = pk, storeLog}
|
||||
|
||||
data IniOpts = IniOpts
|
||||
{ enableStoreLog :: Bool,
|
||||
storeLogFile :: FilePath
|
||||
}
|
||||
|
||||
readCreateIni :: ServerOpts -> IO IniOpts
|
||||
readCreateIni ServerOpts {configFile} = do
|
||||
createDirectoryIfMissing True cfgDir
|
||||
doesFileExist configFile >>= (`unless` createIni)
|
||||
readIni
|
||||
where
|
||||
readIni :: IO IniOpts
|
||||
readIni = do
|
||||
ini <- either exitError pure =<< readIniFile configFile
|
||||
let enableStoreLog = (== Right "on") $ lookupValue "STORE_LOG" "enable" ini
|
||||
storeLogFile = either (const defaultStoreLogFile) T.unpack $ lookupValue "STORE_LOG" "file" ini
|
||||
pure IniOpts {enableStoreLog, storeLogFile}
|
||||
exitError e = do
|
||||
putStrLn $ "error reading config file " <> configFile <> ": " <> e
|
||||
exitFailure
|
||||
createIni :: IO ()
|
||||
createIni = do
|
||||
confirm $ "Save default ini file to " <> configFile
|
||||
writeFile
|
||||
configFile
|
||||
"[STORE_LOG]\n\
|
||||
\# The server uses STM memory to store SMP queues and messages,\n\
|
||||
\# that will be lost on restart (e.g., as with redis).\n\
|
||||
\# This option enables saving SMP queues to append only log,\n\
|
||||
\# and restoring them when the server is started.\n\
|
||||
\# Log is compacted on start (deleted queues are removed).\n\
|
||||
\# The messages in the queues are not logged.\n\
|
||||
\\n\
|
||||
\# enable: on\n\
|
||||
\# file: /var/opt/simplex/smp-server-store.log\n"
|
||||
|
||||
readCreateKey :: IO C.FullPrivateKey
|
||||
readCreateKey = do
|
||||
createDirectoryIfMissing True cfgDir
|
||||
let path = combine cfgDir "server_key"
|
||||
hasKey <- doesFileExist path
|
||||
(if hasKey then readKey else createKey) path
|
||||
where
|
||||
createKey :: FilePath -> IO C.FullPrivateKey
|
||||
createKey path = do
|
||||
confirm "Generate new server key pair"
|
||||
(_, pk) <- C.generateKeyPair newKeySize
|
||||
S.writeKeyFile S.TraditionalFormat path [PrivKeyRSA $ C.rsaPrivateKey pk]
|
||||
pure pk
|
||||
readKey :: FilePath -> IO C.FullPrivateKey
|
||||
readKey path = do
|
||||
S.readKeyFile path >>= \case
|
||||
[S.Unprotected (PrivKeyRSA pk)] -> pure $ C.FullPrivateKey pk
|
||||
[_] -> errorExit "not RSA key"
|
||||
[] -> errorExit "invalid key file format"
|
||||
_ -> errorExit "more than one key"
|
||||
where
|
||||
errorExit :: String -> IO b
|
||||
errorExit e = putStrLn (e <> ": " <> path) >> exitFailure
|
||||
|
||||
confirm :: String -> IO ()
|
||||
confirm msg = do
|
||||
putStr $ msg <> " (y/N): "
|
||||
hFlush stdout
|
||||
ok <- getLine
|
||||
when (map toLower ok /= "y") exitFailure
|
||||
|
||||
publicKeyHash :: C.PublicKey -> B.ByteString
|
||||
publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.encodePubKey
|
||||
|
||||
openStoreLog :: IniOpts -> IO (Maybe (StoreLog 'ReadMode))
|
||||
openStoreLog IniOpts {enableStoreLog, storeLogFile = f}
|
||||
| enableStoreLog = do
|
||||
createDirectoryIfMissing True logDir
|
||||
putStrLn ("store log: " <> f)
|
||||
Just <$> openReadStoreLog f
|
||||
| otherwise = putStrLn "store log disabled" $> Nothing
|
||||
|
||||
newtype ServerOpts = ServerOpts
|
||||
{ configFile :: FilePath
|
||||
}
|
||||
|
||||
serverOpts :: Parser ServerOpts
|
||||
serverOpts =
|
||||
ServerOpts
|
||||
<$> strOption
|
||||
( long "config"
|
||||
<> short 'c'
|
||||
<> metavar "INI_FILE"
|
||||
<> help ("config file (" <> defaultIniFile <> ")")
|
||||
<> value defaultIniFile
|
||||
)
|
||||
where
|
||||
defaultIniFile = combine cfgDir "smp-server.ini"
|
||||
|
||||
getServerOpts :: IO ServerOpts
|
||||
getServerOpts = execParser opts
|
||||
where
|
||||
opts =
|
||||
info
|
||||
(serverOpts <**> helper)
|
||||
( fullDesc
|
||||
<> header "Simplex Messaging Protocol (SMP) Server"
|
||||
<> progDesc "Start server with INI_FILE (created on first run)"
|
||||
)
|
||||
|
||||
10
package.yaml
10
package.yaml
@@ -13,6 +13,8 @@ extra-source-files:
|
||||
|
||||
dependencies:
|
||||
- ansi-terminal == 0.10.*
|
||||
- asn1-encoding == 0.9.*
|
||||
- asn1-types == 0.3.*
|
||||
- async == 2.2.*
|
||||
- attoparsec == 0.13.*
|
||||
- base >= 4.7 && < 5
|
||||
@@ -22,11 +24,13 @@ dependencies:
|
||||
- cryptonite == 0.26.*
|
||||
- directory == 1.3.*
|
||||
- filepath == 1.4.*
|
||||
- generic-random == 1.3.*
|
||||
- iso8601-time == 0.1.*
|
||||
- memory == 0.15.*
|
||||
- mtl
|
||||
- network == 3.1.*
|
||||
- network-transport == 0.5.*
|
||||
- QuickCheck == 2.13.*
|
||||
- simple-logger == 0.1.*
|
||||
- sqlite-simple == 0.4.*
|
||||
- stm
|
||||
@@ -36,6 +40,7 @@ dependencies:
|
||||
- transformers == 0.5.*
|
||||
- unliftio == 0.2.*
|
||||
- unliftio-core == 0.1.*
|
||||
- x509 == 1.7.*
|
||||
|
||||
library:
|
||||
source-dirs: src
|
||||
@@ -45,6 +50,9 @@ executables:
|
||||
source-dirs: apps/smp-server
|
||||
main: Main.hs
|
||||
dependencies:
|
||||
- cryptostore == 0.2.*
|
||||
- ini == 0.4.*
|
||||
- optparse-applicative == 0.15.*
|
||||
- simplex-messaging
|
||||
ghc-options:
|
||||
- -threaded
|
||||
@@ -78,6 +86,8 @@ tests:
|
||||
- hspec-core == 2.7.*
|
||||
- HUnit == 1.6.*
|
||||
- random == 1.1.*
|
||||
- QuickCheck == 2.13.*
|
||||
- timeit == 2.0.*
|
||||
|
||||
ghc-options:
|
||||
# - -haddock
|
||||
|
||||
@@ -16,49 +16,60 @@ For initial implementation I propose approach to be as simple as possible as lon
|
||||
|
||||
One of the consideration is to use [noise protocol framework](https://noiseprotocol.org/noise.html), this section describes ad hoc protocol though.
|
||||
|
||||
During TCP session both client and server should use symmetric AES 256 bit encryption using the session key that will be established during the handshake.
|
||||
During TCP session both client and server should use symmetric AES 256 bit encryption using two session keys and two base IVs that will be agreed during the handshake. Both client and the server should maintain two 32-bit word counters, one for sent and one for the received messages. The IV for each message should be computed by xor-ing the sequential message counter, starting from 0, with the first 32 bits of agreed base IV. TODO - explain it in a more formal way, also document how 32-bit word is encoded - with the most or least significant byte first (currently encodeWord32 from Network.Transport.Internal is used)
|
||||
|
||||
To establish the session key, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) and additional server ID (256 bits) in advance in order to be able to establish connection.
|
||||
To establish the session keys and base IVs, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) in advance in order to be able to establish connection.
|
||||
|
||||
The handshake sequence could be the following:
|
||||
The handshake sequence is the following:
|
||||
|
||||
1. Once the connection is established, the server sends its public key to the client
|
||||
2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection.
|
||||
3. If the hash is the same, the client should generate a random symmetric AES key and IV that will be used as a session key both by the client and the server.
|
||||
4. The client then should encrypt this symmetric key with the public key that the server sent and send back to the server the result and the server ID also shared with the client in advance: `rsa-encrypt(aes-key, iv, server-id)`.
|
||||
5. The server should decrypt the received key, IV and server id with its private key.
|
||||
6. The server should compare the `server-id` sent by the client and if it does not match its ID terminate the connection.
|
||||
7. In case of successful decryption and matching server ID, the server should send encrypted welcome header.
|
||||
1. Once the connection is established, the server sends server_header and its public RSA key encoded in X509 binary format to the client.
|
||||
2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required.
|
||||
3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server.
|
||||
4. The client then should construct client_handshake block and send it to the server encrypted with the server public key: `rsa-encrypt(client_handshake)`. `snd_aes_key` and `snd_base_iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv_aes_key` and `rcv_base_iv` will be used by the client to decrypt **received** messages and by the server to encrypt them.
|
||||
5. The server should decrypt the received keys and base IVs with its private key.
|
||||
6. In case of successful decryption, the server should send encrypted welcome block (encrypted_welcome_block) that contains SMP protocol version.
|
||||
|
||||
```abnf
|
||||
aes_welcome_header = aes_header_auth_tag aes_encrypted_header
|
||||
welcome_header = smp_version ["," smp_mode] *SP ; decrypt(aes_encrypted_header) - 32 bytes
|
||||
smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format
|
||||
; for example: v123.456.789-alpha.7
|
||||
smp_mode = smp_public / smp_authenticated
|
||||
smp_public = %s"pub" ; public (default) - no auth to create and manage queues
|
||||
smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues
|
||||
aes_header_auth_tag = aes_auth_tag
|
||||
aes_auth_tag = 16*16(OCTET)
|
||||
```
|
||||
|
||||
No payload should follow this header, it is only used to confirm successful handshake and send the SMP protocol version that the server supports.
|
||||
|
||||
All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES key and IV sent by the client during the handshake.
|
||||
All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES keys and base IVs (incremented by counters on both sides) sent by the client during the handshake.
|
||||
|
||||
Each transport block sent by the client and the server has this syntax:
|
||||
|
||||
```abnf
|
||||
transport_block = aes_header_auth_tag aes_encrypted_header aes_body_auth_tag aes_encrypted_body
|
||||
aes_encrypted_header = 32*32(OCTET)
|
||||
header = padded_body_size payload_size reserved ; decrypt(aes_encrypted_header) - 32 bytes
|
||||
server_header = block_size protocol key_size
|
||||
block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the client rejects if > 65536 bytes
|
||||
protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key
|
||||
key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard)
|
||||
|
||||
client_handshake = client_block_size client_protocol snd_aes_key snd_base_iv rcv_aes_key rcv_base_iv
|
||||
client_block_size = 4*4(OCTET) ; 4-byte block size sent by the client, currently it is ignored by the server - reserved
|
||||
client_protocol = 2*2(OCTET) ; currently it is 0 - reserved
|
||||
snd_aes_key = 32*32(OCTET)
|
||||
snd_base_iv = 16*16(OCTET)
|
||||
rcv_aes_key = 32*32(OCTET)
|
||||
rcv_base_iv = 16*16(OCTET)
|
||||
|
||||
transport_block = aes_body_auth_tag aes_encrypted_body
|
||||
; size is sent by server during handshake, usually 8192 bytes
|
||||
aes_encrypted_body = 1*OCTET
|
||||
body = payload pad
|
||||
padded_body_size = size ; body size in bytes
|
||||
payload_size = size ; payload_size in bytes
|
||||
size = 4*4(OCTET)
|
||||
reserved = 24*24(OCTET)
|
||||
aes_body_auth_tag = aes_auth_tag
|
||||
aes_body_auth_tag = 16*16(OCTET)
|
||||
|
||||
encrypted_welcome_block = transport_block
|
||||
welcome_block = smp_version SP pad ; decrypt(encrypted_welcome_block)
|
||||
smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format
|
||||
; for example: v123.456.789-alpha.7
|
||||
pad = 1*OCTET
|
||||
```
|
||||
|
||||
## Possible future improvements/changes
|
||||
|
||||
- server id (256 bits), so that only the users that have it can connect to the server. This ID will have to be passed to the server during the handshake
|
||||
- block size agreed during handshake
|
||||
- transport encryption protocol agreed during handshake
|
||||
- welcome block containing SMP mode (smp_mode)
|
||||
|
||||
```abnf
|
||||
smp_mode = smp_public / smp_authenticated
|
||||
smp_public = %s"pub" ; public (default) - no auth to create and manage queues
|
||||
smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues
|
||||
```
|
||||
|
||||
## Initial handshake
|
||||
@@ -105,7 +116,9 @@ Symmetric keys are generated per message and encrypted with receiver's public ke
|
||||
The syntax of each encrypted message body is the following:
|
||||
|
||||
```abnf
|
||||
encrypted_message_body = rsa_encrypted_header aes_encrypted_body
|
||||
encrypted_message_body = rsa_signature encrypted_body
|
||||
encrypted_body = rsa_encrypted_header aes_encrypted_body
|
||||
rsa_signature = 256*256(OCTET) ; sign(encrypted_body) - assuming 2048 bit key size
|
||||
rsa_encrypted_header = 256*256(OCTET) ; encrypt(header) - assuming 2048 bit key size
|
||||
aes_encrypted_body = 1*OCTET ; encrypt(body)
|
||||
|
||||
@@ -120,7 +133,6 @@ body = payload pad
|
||||
|
||||
Future considerations:
|
||||
- Generation of symmetric keys per session and session rotation;
|
||||
- Signature and verification of messages.
|
||||
|
||||
## E2E implementation
|
||||
|
||||
|
||||
25
rfcs/2021-03-18-groups.md
Normal file
25
rfcs/2021-03-18-groups.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# SMP agent groups
|
||||
|
||||
## Problems
|
||||
|
||||
- device/user profile synchronisation
|
||||
- chat group communication
|
||||
|
||||
Both problems would require message broadcast between a group of SMP agents.
|
||||
|
||||
## Solution: basic symmetric groups via SMP agent protocol
|
||||
|
||||
Additional commands and message envelopes to SMP agent protocol to provide an abstraction layer for device synchronisation and chat groups.
|
||||
|
||||
The groups are fully symmetric, all agent who are members of the group have equal rights and can join and leave group at any time.
|
||||
|
||||
All the information about the groups is stored only in agents, the commands are used to synchronise the group state between the agents.
|
||||
|
||||
```abnf
|
||||
group_command = create_group / add_to_group / remove_from_group / leave_group
|
||||
group_response = group_created / added_to_group / removed_from_group
|
||||
group_notification = added_to_group_by / removed_from_group_by / left_group
|
||||
create_group = %s"GNEW " group_name ; cAlias must be empty
|
||||
add_to_group = %s"GADD " group_name ; cAlias is the connection to add to the group
|
||||
added_to_group = %s"GADDED " name ; cAlias is the connection added to the group
|
||||
```
|
||||
@@ -26,6 +26,7 @@ import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Data.Time.Clock
|
||||
import Database.SQLite.Simple (SQLError)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Store
|
||||
@@ -36,10 +37,9 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (CorrId (..), MsgBody, SenderPublicKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (putLn, runTCPServer)
|
||||
import Simplex.Messaging.Util (liftError)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import System.IO (Handle)
|
||||
import UnliftIO.Async (race_)
|
||||
import UnliftIO.Exception (SomeException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -98,7 +98,7 @@ send h c@AgentClient {sndQ} = forever $ do
|
||||
|
||||
logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m ()
|
||||
logClient AgentClient {clientId} dir (CorrId corrId, cAlias, cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [B.pack $ show clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
|
||||
client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client c@AgentClient {rcvQ, sndQ} st = forever $ do
|
||||
@@ -114,10 +114,15 @@ withStore ::
|
||||
withStore action = do
|
||||
runExceptT (action `E.catch` handleInternal) >>= \case
|
||||
Right c -> return c
|
||||
Left _ -> throwError STORE
|
||||
Left e -> throwError $ storeError e
|
||||
where
|
||||
handleInternal :: (MonadError StoreError m') => SomeException -> m' a
|
||||
handleInternal _ = throwError SEInternal
|
||||
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
|
||||
handleInternal e = throwError . SEInternal $ bshow e
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
storeError = \case
|
||||
SEConnNotFound -> CONN UNKNOWN
|
||||
SEConnDuplicate -> CONN DUPLICATE
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m ()
|
||||
processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
@@ -156,9 +161,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
withStore (getConn st cAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> subscribe rq
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
-- TODO possibly there should be a separate error type trying
|
||||
-- TODO to send the message to the connection without RcvQueue
|
||||
_ -> throwError PROHIBITED
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
subscribe rq = subscribeQueue c rq cAlias >> respond' cAlias OK
|
||||
|
||||
@@ -171,22 +174,32 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
withStore (getConn st connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
|
||||
SomeConn _ (SndConnection _ sq) -> sendMsg sq
|
||||
-- TODO possibly there should be a separate error type trying
|
||||
-- TODO to send the message to the connection without SndQueue
|
||||
_ -> throwError PROHIBITED -- NOT_READY ?
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
sendMsg sq = do
|
||||
senderTs <- liftIO getCurrentTime
|
||||
senderId <- withStore $ createSndMsg st connAlias msgBody senderTs
|
||||
sendAgentMessage c sq senderTs $ A_MSG msgBody
|
||||
respond $ SENT (unId senderId)
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
{ senderMsgId = unSndId internalSndId,
|
||||
senderTimestamp = internalTs,
|
||||
previousMsgHash,
|
||||
agentMessage = A_MSG msgBody
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st sq $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
respond $ SENT (unId internalId)
|
||||
|
||||
suspendConnection :: m ()
|
||||
suspendConnection =
|
||||
withStore (getConn st connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> suspend rq
|
||||
SomeConn _ (RcvConnection _ rq) -> suspend rq
|
||||
_ -> throwError PROHIBITED
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
suspend rq = suspendQueue c rq >> respond OK
|
||||
|
||||
@@ -195,20 +208,26 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
withStore (getConn st connAlias) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> delete rq
|
||||
SomeConn _ (RcvConnection _ rq) -> delete rq
|
||||
_ -> throwError PROHIBITED
|
||||
_ -> delConn
|
||||
where
|
||||
delConn = withStore (deleteConn st connAlias) >> respond OK
|
||||
delete rq = do
|
||||
deleteQueue c rq
|
||||
removeSubscription c connAlias
|
||||
withStore (deleteConn st connAlias)
|
||||
respond OK
|
||||
delConn
|
||||
|
||||
sendReplyQInfo :: SMPServer -> SndQueue -> m ()
|
||||
sendReplyQInfo srv sq = do
|
||||
(rq, qInfo) <- newReceiveQueue c srv connAlias
|
||||
withStore $ upgradeSndConnToDuplex st connAlias rq
|
||||
senderTs <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq senderTs $ REPLY qInfo
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = REPLY qInfo
|
||||
}
|
||||
|
||||
respond :: ACommand 'Agent -> m ()
|
||||
respond = respond' connAlias
|
||||
@@ -226,11 +245,13 @@ subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
rq@RcvQueue {connAlias, decryptKey, status} <- withStore $ getRcvQueue st srv rId
|
||||
rq@RcvQueue {connAlias, status} <- withStore $ getRcvQueue st srv rId
|
||||
case cmd of
|
||||
SMP.MSG srvMsgId srvTs msgBody -> do
|
||||
-- TODO deduplicate with previously received
|
||||
agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody
|
||||
msg <- decryptAndVerify rq msgBody
|
||||
let msgHash = C.sha256Hash msg
|
||||
agentMsg <- liftEither $ parseSMPMessage msg
|
||||
case agentMsg of
|
||||
SMPConfirmation senderKey -> do
|
||||
logServer "<--" c srv rId "MSG <KEY>"
|
||||
@@ -242,50 +263,74 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
-- TODO update sender key in the store?
|
||||
secureQueue c rq senderKey
|
||||
withStore $ setRcvQueueStatus st rq Secured
|
||||
sendAck c rq
|
||||
s ->
|
||||
-- TODO maybe send notification to the user
|
||||
liftIO . putStrLn $ "unexpected SMP confirmation, queue status " <> show s
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp} ->
|
||||
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case agentMessage of
|
||||
HELLO _verifyKey _ -> do
|
||||
HELLO verifyKey _ -> do
|
||||
logServer "<--" c srv rId "MSG <HELLO>"
|
||||
-- TODO send status update to the user?
|
||||
withStore $ setRcvQueueStatus st rq Active
|
||||
sendAck c rq
|
||||
case status of
|
||||
Active -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
_ -> do
|
||||
void $ verifyMessage (Just verifyKey) msgBody
|
||||
withStore $ setRcvQueueActive st rq verifyKey
|
||||
REPLY qInfo -> do
|
||||
logServer "<--" c srv rId "MSG <REPLY>"
|
||||
-- TODO move senderKey inside SndQueue
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
|
||||
withStore $ upgradeRcvConnToDuplex st connAlias sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
notify connAlias CON
|
||||
sendAck c rq
|
||||
A_MSG body -> do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
-- TODO check message status
|
||||
recipientTs <- liftIO getCurrentTime
|
||||
let m_sender = (senderMsgId, senderTimestamp)
|
||||
let m_broker = (srvMsgId, srvTs)
|
||||
recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker
|
||||
notify connAlias $
|
||||
MSG
|
||||
{ m_status = MsgOk,
|
||||
m_recipient = (unId recipientId, recipientTs),
|
||||
m_sender,
|
||||
m_broker,
|
||||
m_body = body
|
||||
}
|
||||
sendAck c rq
|
||||
A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
|
||||
sendAck c rq
|
||||
return ()
|
||||
SMP.END -> do
|
||||
removeSubscription c connAlias
|
||||
logServer "<--" c srv rId "END"
|
||||
notify connAlias END
|
||||
_ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd
|
||||
_ -> do
|
||||
logServer "<--" c srv rId $ "unexpected: " <> bshow cmd
|
||||
notify connAlias . ERR $ BROKER UNEXPECTED
|
||||
where
|
||||
notify :: ConnAlias -> ACommand 'Agent -> m ()
|
||||
notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg)
|
||||
agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
|
||||
agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
case status of
|
||||
Active -> do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
|
||||
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash
|
||||
withStore $
|
||||
createRcvMsg st rq $
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
internalTs,
|
||||
senderMeta,
|
||||
brokerMeta,
|
||||
msgBody,
|
||||
internalHash = msgHash,
|
||||
externalPrevSndHash = receivedPrevMsgHash,
|
||||
msgIntegrity
|
||||
}
|
||||
notify connAlias $
|
||||
MSG
|
||||
{ recipientMeta = (unId internalId, internalTs),
|
||||
senderMeta,
|
||||
brokerMeta,
|
||||
msgBody,
|
||||
msgIntegrity
|
||||
}
|
||||
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
where
|
||||
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity
|
||||
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash
|
||||
| extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk
|
||||
| extSndId < prevExtSndId = MsgError $ MsgBadId extSndId
|
||||
| extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate
|
||||
| extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1)
|
||||
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m ()
|
||||
connectToSendQueue c st sq senderKey verifyKey = do
|
||||
@@ -294,9 +339,6 @@ connectToSendQueue c st sq senderKey verifyKey = do
|
||||
sendHello c sq verifyKey
|
||||
withStore $ setSndQueueStatus st sq Active
|
||||
|
||||
decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString
|
||||
decryptMessage decryptKey msg = liftError CRYPTO $ C.decrypt decryptKey msg
|
||||
|
||||
newSendQueue ::
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
|
||||
|
||||
@@ -19,11 +19,14 @@ module Simplex.Messaging.Agent.Client
|
||||
sendHello,
|
||||
secureQueue,
|
||||
sendAgentMessage,
|
||||
decryptAndVerify,
|
||||
verifyMessage,
|
||||
sendAck,
|
||||
suspendQueue,
|
||||
deleteQueue,
|
||||
logServer,
|
||||
removeSubscription,
|
||||
cryptoError,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -47,7 +50,7 @@ import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey)
|
||||
import Simplex.Messaging.Util (liftError)
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, liftError)
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Exception (IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
@@ -86,15 +89,17 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
newSMPClient = do
|
||||
smp <- connectClient
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
-- TODO how can agent know client lost the connection?
|
||||
atomically . modifyTVar smpClients $ M.insert srv smp
|
||||
return smp
|
||||
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
liftIO (getSMPClient srv cfg msgQ clientDisconnected)
|
||||
`E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection)
|
||||
liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected)
|
||||
`E.catch` internalError
|
||||
where
|
||||
internalError :: IOException -> m SMPClient
|
||||
internalError = throwError . INTERNAL . show
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
@@ -118,31 +123,41 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
|
||||
closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient
|
||||
|
||||
withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action =
|
||||
(getSMPServerClient c srv >>= runAction) `catchError` logServerError
|
||||
withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a
|
||||
withSMP_ c srv action =
|
||||
(getSMPServerClient c srv >>= action) `catchError` logServerError
|
||||
where
|
||||
runAction :: SMPClient -> m a
|
||||
runAction smp = liftError smpClientError $ action smp
|
||||
|
||||
smpClientError :: SMPClientError -> AgentErrorType
|
||||
smpClientError = \case
|
||||
SMPServerError e -> SMP e
|
||||
-- TODO handle other errors
|
||||
_ -> INTERNAL
|
||||
|
||||
logServerError :: AgentErrorType -> m a
|
||||
logServerError e = do
|
||||
logServer "<--" c srv "" $ (B.pack . show) e
|
||||
logServer "<--" c srv "" $ bshow e
|
||||
throwError e
|
||||
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP c srv qId cmdStr action = do
|
||||
withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withLogSMP_ c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withSMP c srv action
|
||||
res <- withSMP_ c srv action
|
||||
logServer "<--" c srv qId "OK"
|
||||
return res
|
||||
|
||||
withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action = withSMP_ c srv $ liftSMP . action
|
||||
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action
|
||||
|
||||
liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a
|
||||
liftSMP = liftError smpClientError
|
||||
|
||||
smpClientError :: SMPClientError -> AgentErrorType
|
||||
smpClientError = \case
|
||||
SMPServerError e -> SMP e
|
||||
SMPResponseError e -> BROKER $ RESPONSE e
|
||||
SMPUnexpectedResponse -> BROKER UNEXPECTED
|
||||
SMPResponseTimeout -> BROKER TIMEOUT
|
||||
SMPNetworkError -> BROKER NETWORK
|
||||
SMPTransportError e -> BROKER $ TRANSPORT e
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv connAlias = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
@@ -196,7 +211,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv)
|
||||
@@ -205,35 +220,38 @@ logSecret :: ByteString -> ByteString
|
||||
logSecret bs = encode $ B.take 3 bs
|
||||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> m ()
|
||||
sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do
|
||||
msg <- mkConfirmation
|
||||
withLogSMP c server sndId "SEND <KEY>" $ \smp ->
|
||||
sendSMPMessage smp Nothing sndId msg
|
||||
sendConfirmation c sq@SndQueue {server, sndId} senderKey =
|
||||
withLogSMP_ c server sndId "SEND <KEY>" $ \smp -> do
|
||||
msg <- mkConfirmation smp
|
||||
liftSMP $ sendSMPMessage smp Nothing sndId msg
|
||||
where
|
||||
mkConfirmation :: m MsgBody
|
||||
mkConfirmation = do
|
||||
let msg = serializeSMPMessage $ SMPConfirmation senderKey
|
||||
paddedSize <- asks paddedMsgSize
|
||||
liftError CRYPTO $ C.encrypt encryptKey paddedSize msg
|
||||
mkConfirmation :: SMPClient -> m MsgBody
|
||||
mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey
|
||||
|
||||
sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m ()
|
||||
sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do
|
||||
msg <- mkHello $ AckMode On
|
||||
withLogSMP c server sndId "SEND <HELLO> (retrying)" $
|
||||
send 20 msg
|
||||
sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey =
|
||||
withLogSMP_ c server sndId "SEND <HELLO> (retrying)" $ \smp -> do
|
||||
msg <- mkHello smp $ AckMode On
|
||||
liftSMP $ send 8 100000 msg smp
|
||||
where
|
||||
mkHello :: AckMode -> m ByteString
|
||||
mkHello ackMode = do
|
||||
senderTs <- liftIO getCurrentTime
|
||||
mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode
|
||||
mkHello :: SMPClient -> AckMode -> m ByteString
|
||||
mkHello smp ackMode = do
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
encryptAndSign smp sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = HELLO verifyKey ackMode
|
||||
}
|
||||
|
||||
send :: Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO ()
|
||||
send 0 _ _ = throwE SMPResponseTimeout -- TODO different error
|
||||
send retry msg smp =
|
||||
send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO ()
|
||||
send 0 _ _ _ = throwE $ SMPServerError AUTH
|
||||
send retry delay msg smp =
|
||||
sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case
|
||||
SMPServerError AUTH -> do
|
||||
threadDelay 100000
|
||||
send (retry - 1) msg smp
|
||||
threadDelay delay
|
||||
send (retry - 1) (delay * 3 `div` 2) msg smp
|
||||
e -> throwE e
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m ()
|
||||
@@ -256,21 +274,39 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m ()
|
||||
sendAgentMessage c SndQueue {server, sndId, sndPrivateKey, encryptKey} senderTs agentMsg = do
|
||||
msg <- mkAgentMessage encryptKey senderTs agentMsg
|
||||
withLogSMP c server sndId "SEND <message>" $ \smp ->
|
||||
sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg =
|
||||
withLogSMP_ c server sndId "SEND <message>" $ \smp -> do
|
||||
msg' <- encryptAndSign smp sq msg
|
||||
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg'
|
||||
|
||||
mkAgentMessage :: AgentMonad m => EncryptionKey -> SenderTimestamp -> AMessage -> m ByteString
|
||||
mkAgentMessage encKey senderTs agentMessage = do
|
||||
let msg =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp = senderTs,
|
||||
previousMsgHash = "1234", -- TODO hash of the previous message
|
||||
agentMessage
|
||||
}
|
||||
paddedSize <- asks paddedMsgSize
|
||||
liftError CRYPTO $ C.encrypt encKey paddedSize msg
|
||||
encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString
|
||||
encryptAndSign smp SndQueue {encryptKey, signKey} msg = do
|
||||
paddedSize <- asks $ (blockSize smp -) . reservedMsgSize
|
||||
liftError cryptoError $ do
|
||||
enc <- C.encrypt encryptKey paddedSize msg
|
||||
C.Signature sig <- C.sign signKey enc
|
||||
pure $ sig <> enc
|
||||
|
||||
decryptAndVerify :: AgentMonad m => RcvQueue -> ByteString -> m ByteString
|
||||
decryptAndVerify RcvQueue {decryptKey, verifyKey} msg =
|
||||
verifyMessage verifyKey msg
|
||||
>>= liftError cryptoError . C.decrypt decryptKey
|
||||
|
||||
verifyMessage :: AgentMonad m => Maybe VerificationKey -> ByteString -> m ByteString
|
||||
verifyMessage verifyKey msg = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
let (sig, enc) = B.splitAt size msg
|
||||
case verifyKey of
|
||||
Nothing -> pure enc
|
||||
Just k
|
||||
| C.verify k (C.Signature sig) enc -> pure enc
|
||||
| otherwise -> throwError $ AGENT A_SIGNATURE
|
||||
|
||||
cryptoError :: C.CryptoError -> AgentErrorType
|
||||
cryptoError = \case
|
||||
C.CryptoLargeMsgError -> CMD LARGE
|
||||
C.RSADecryptError _ -> AGENT A_ENCRYPTION
|
||||
C.CryptoHeaderError _ -> AGENT A_ENCRYPTION
|
||||
C.AESDecryptError -> AGENT A_ENCRYPTION
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
@@ -26,7 +26,7 @@ data Env = Env
|
||||
{ config :: AgentConfig,
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
clientCounter :: TVar Int,
|
||||
paddedMsgSize :: Int
|
||||
reservedMsgSize :: Int
|
||||
}
|
||||
|
||||
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
|
||||
@@ -34,10 +34,10 @@ newSMPAgentEnv config = do
|
||||
idsDrg <- drgNew >>= newTVarIO
|
||||
_ <- createSQLiteStore $ dbFile config
|
||||
clientCounter <- newTVarIO 0
|
||||
return Env {config, idsDrg, clientCounter, paddedMsgSize}
|
||||
return Env {config, idsDrg, clientCounter, reservedMsgSize}
|
||||
where
|
||||
-- one rsaKeySize is used by the RSA signature in each command,
|
||||
-- another - by encrypted message body header
|
||||
-- 1st rsaKeySize is used by the RSA signature in each command,
|
||||
-- 2nd - by encrypted message body header
|
||||
-- 3rd - by message signature
|
||||
-- smpCommandSize - is the estimated max size for SMP command, queueId, corrId
|
||||
paddedMsgSize = blockSize smp - 2 * rsaKeySize config - smpCommandSize smp
|
||||
smp = smpCfg config
|
||||
reservedMsgSize = 3 * rsaKeySize config + smpCommandSize (smpCfg config)
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
module Simplex.Messaging.Agent.Store where
|
||||
|
||||
import Control.Exception (Exception)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
import Data.Time (UTCTime)
|
||||
@@ -38,11 +39,16 @@ class Monad m => MonadAgentStore s m where
|
||||
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
|
||||
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
|
||||
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
|
||||
|
||||
-- Msg management
|
||||
createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
|
||||
createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
|
||||
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
|
||||
|
||||
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
|
||||
|
||||
getMsg :: s -> ConnAlias -> InternalId -> m Msg
|
||||
|
||||
-- * Queue types
|
||||
@@ -102,6 +108,11 @@ data SConnType :: ConnType -> Type where
|
||||
SCSnd :: SConnType CSnd
|
||||
SCDuplex :: SConnType CDuplex
|
||||
|
||||
connType :: SConnType c -> ConnType
|
||||
connType SCRcv = CRcv
|
||||
connType SCSnd = CSnd
|
||||
connType SCDuplex = CDuplex
|
||||
|
||||
deriving instance Eq (SConnType d)
|
||||
|
||||
deriving instance Show (SConnType d)
|
||||
@@ -123,6 +134,42 @@ instance Eq SomeConn where
|
||||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
-- * Message integrity validation types
|
||||
|
||||
type MsgHash = ByteString
|
||||
|
||||
-- | Corresponds to `last_external_snd_msg_id` in `connections` table
|
||||
type PrevExternalSndId = Int64
|
||||
|
||||
-- | Corresponds to `last_rcv_msg_hash` in `connections` table
|
||||
type PrevRcvMsgHash = MsgHash
|
||||
|
||||
-- | Corresponds to `last_snd_msg_hash` in `connections` table
|
||||
type PrevSndMsgHash = MsgHash
|
||||
|
||||
-- ? merge/replace these with RcvMsg and SndMsg
|
||||
-- * Message data containers - used on Msg creation to reduce number of parameters
|
||||
|
||||
data RcvMsgData = RcvMsgData
|
||||
{ internalId :: InternalId,
|
||||
internalRcvId :: InternalRcvId,
|
||||
internalTs :: InternalTs,
|
||||
senderMeta :: (ExternalSndId, ExternalSndTs),
|
||||
brokerMeta :: (BrokerId, BrokerTs),
|
||||
msgBody :: MsgBody,
|
||||
internalHash :: MsgHash,
|
||||
externalPrevSndHash :: MsgHash,
|
||||
msgIntegrity :: MsgIntegrity
|
||||
}
|
||||
|
||||
data SndMsgData = SndMsgData
|
||||
{ internalId :: InternalId,
|
||||
internalSndId :: InternalSndId,
|
||||
internalTs :: InternalTs,
|
||||
msgBody :: MsgBody,
|
||||
internalHash :: MsgHash
|
||||
}
|
||||
|
||||
-- * Message types
|
||||
|
||||
-- | A message in either direction that is stored by the agent.
|
||||
@@ -149,7 +196,10 @@ data RcvMsg = RcvMsg
|
||||
-- | Timestamp of acknowledgement to sender, corresponds to `AcknowledgedToSender` status.
|
||||
-- Do not mix up with `externalSndTs` - timestamp created at sender before sending,
|
||||
-- which in its turn corresponds to `internalTs` in sending agent.
|
||||
ackSenderTs :: AckSenderTs
|
||||
ackSenderTs :: AckSenderTs,
|
||||
-- | Hash of previous message as received from sender - stored for integrity forensics.
|
||||
externalPrevSndHash :: MsgHash,
|
||||
msgIntegrity :: MsgIntegrity
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -209,7 +259,9 @@ data MsgBase = MsgBase
|
||||
-- due to a possibility of implementation errors in different agents.
|
||||
internalId :: InternalId,
|
||||
internalTs :: InternalTs,
|
||||
msgBody :: MsgBody
|
||||
msgBody :: MsgBody,
|
||||
-- | Hash of the message as computed by agent.
|
||||
internalHash :: MsgHash
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
@@ -219,13 +271,11 @@ type InternalTs = UTCTime
|
||||
|
||||
-- * Store errors
|
||||
|
||||
-- TODO revise
|
||||
data StoreError
|
||||
= SEInternal
|
||||
| SENotFound
|
||||
| SEBadConn
|
||||
= SEInternal ByteString
|
||||
| SEConnNotFound
|
||||
| SEConnDuplicate
|
||||
| SEBadConnType ConnType
|
||||
| SEBadQueueStatus
|
||||
| SEBadQueueDirection
|
||||
| SEBadQueueStatus -- not used, planned to check strictly
|
||||
| SENotImplemented -- TODO remove
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@@ -21,24 +23,26 @@ where
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.List (find)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Text (isPrefixOf)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema)
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Protocol (MsgBody)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (liftIOEither)
|
||||
import Simplex.Messaging.Util (bshow, liftIOEither)
|
||||
import System.Exit (ExitCode (ExitFailure), exitWith)
|
||||
import System.FilePath (takeDirectory)
|
||||
import Text.Read (readMaybe)
|
||||
@@ -75,76 +79,170 @@ connectSQLiteStore dbFilePath = do
|
||||
liftIO $ DB.execute_ dbConn "PRAGMA foreign_keys = ON;"
|
||||
return SQLiteStore {dbFilePath, dbConn}
|
||||
|
||||
checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m ()
|
||||
checkDuplicate action = liftIOEither $ first handleError <$> E.try action
|
||||
where
|
||||
handleError :: SQLError -> StoreError
|
||||
handleError e
|
||||
| DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate
|
||||
| otherwise = SEInternal $ bshow e
|
||||
|
||||
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
|
||||
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
|
||||
createRcvConn SQLiteStore {dbConn} rcvQueue =
|
||||
liftIO $
|
||||
createRcvQueueAndConn dbConn rcvQueue
|
||||
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
|
||||
checkDuplicate $
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn q
|
||||
insertRcvConnection_ dbConn q
|
||||
|
||||
createSndConn :: SQLiteStore -> SndQueue -> m ()
|
||||
createSndConn SQLiteStore {dbConn} sndQueue =
|
||||
liftIO $
|
||||
createSndQueueAndConn dbConn sndQueue
|
||||
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
|
||||
checkDuplicate $
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn q
|
||||
insertSndConnection_ dbConn q
|
||||
|
||||
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connAlias = do
|
||||
queues <-
|
||||
liftIO $
|
||||
retrieveConnQueues dbConn connAlias
|
||||
case queues of
|
||||
(Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> throwError SEBadConn
|
||||
getConn SQLiteStore {dbConn} connAlias =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias
|
||||
|
||||
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
|
||||
getAllConnAliases SQLiteStore {dbConn} =
|
||||
liftIO $
|
||||
retrieveAllConnAliases dbConn
|
||||
liftIO $ do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
return (concat r)
|
||||
|
||||
getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue
|
||||
getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do
|
||||
rcvQueue <-
|
||||
r <-
|
||||
liftIO $
|
||||
retrieveRcvQueue dbConn host port rcvId
|
||||
case rcvQueue of
|
||||
Just rcvQ -> return rcvQ
|
||||
_ -> throwError SENotFound
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
case r of
|
||||
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] ->
|
||||
let srv = SMPServer hst (deserializePort_ prt) keyHash
|
||||
in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> throwError SEConnNotFound
|
||||
|
||||
deleteConn :: SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connAlias =
|
||||
liftIO $
|
||||
deleteConnCascade dbConn connAlias
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue =
|
||||
liftIOEither $
|
||||
updateRcvConnWithSndQueue dbConn connAlias sndQueue
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
Right (SomeConn SCRcv (RcvConnection _ _)) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn sq
|
||||
updateConnWithSndQueue_ dbConn connAlias sq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue =
|
||||
liftIOEither $
|
||||
updateSndConnWithRcvQueue dbConn connAlias rcvQueue
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
Right (SomeConn SCSnd (SndConnection _ _)) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn rq
|
||||
updateConnWithRcvQueue_ dbConn connAlias rq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status =
|
||||
setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
updateRcvQueueStatus dbConn rcvQueue status
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
|
||||
setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m ()
|
||||
setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET verify_key = :verify_key, status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[ ":verify_key" := Just verifyKey,
|
||||
":status" := Active,
|
||||
":host" := host,
|
||||
":port" := serializePort_ port,
|
||||
":rcv_id" := rcvId
|
||||
]
|
||||
|
||||
setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m ()
|
||||
setSndQueueStatus SQLiteStore {dbConn} sndQueue status =
|
||||
setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
updateSndQueueStatus dbConn sndQueue status
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE snd_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
|
||||
createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
liftIOEither $
|
||||
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs)
|
||||
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
|
||||
createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs =
|
||||
liftIOEither $
|
||||
insertSndMsg dbConn connAlias msgBody internalTs
|
||||
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
insertRcvMsgBase_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
|
||||
updateHashRcv_ dbConn connAlias rcvMsgData
|
||||
|
||||
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
pure (internalId, internalSndId, prevSndHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
insertSndMsgBase_ dbConn connAlias sndMsgData
|
||||
insertSndMsgDetails_ dbConn connAlias sndMsgData
|
||||
updateHashSnd_ dbConn connAlias sndMsgData
|
||||
|
||||
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg _st _connAlias _id = throwError SENotImplemented
|
||||
@@ -179,6 +277,16 @@ instance ToField RcvMsgStatus where toField = toField . show
|
||||
|
||||
instance ToField SndMsgStatus where toField = toField . show
|
||||
|
||||
instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
|
||||
|
||||
instance FromField MsgIntegrity where
|
||||
fromField = \case
|
||||
f@(Field (SQLBlob b) _) ->
|
||||
case parseAll msgIntegrityP b of
|
||||
Right k -> Ok k
|
||||
Left e -> returnError ConversionFailed f ("can't parse msg integrity field: " ++ e)
|
||||
f -> returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
|
||||
fromFieldToReadable_ = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
@@ -217,13 +325,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO ()
|
||||
createRcvQueueAndConn dbConn rcvQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
|
||||
insertRcvQueue_ dbConn rcvQueue
|
||||
insertRcvConnection_ dbConn rcvQueue
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -255,22 +356,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
|
||||
0, 0, 0);
|
||||
0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
createSndQueueAndConn :: DB.Connection -> SndQueue -> IO ()
|
||||
createSndQueueAndConn dbConn sndQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn (server (sndQueue :: SndQueue))
|
||||
insertSndQueue_ dbConn sndQueue
|
||||
insertSndConnection_ dbConn sndQueue
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn SndQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -300,28 +395,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
|
||||
0, 0, 0);
|
||||
0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
|
||||
retrieveConnQueues dbConn connAlias =
|
||||
DB.withTransaction -- Avoid inconsistent state between queue reads
|
||||
dbConn
|
||||
$ retrieveConnQueues_ dbConn connAlias
|
||||
|
||||
-- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap
|
||||
-- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction
|
||||
retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
|
||||
retrieveConnQueues_ dbConn connAlias = do
|
||||
rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
return (rcvQ, sndQ)
|
||||
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connAlias = do
|
||||
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
|
||||
@@ -363,60 +455,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do
|
||||
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
|
||||
_ -> return Nothing
|
||||
|
||||
-- * getAllConnAliases helper
|
||||
|
||||
retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias]
|
||||
retrieveAllConnAliases dbConn = do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
return (concat r)
|
||||
|
||||
-- * getRcvQueue helper
|
||||
|
||||
retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueue dbConn host port rcvId = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
case r of
|
||||
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
|
||||
let srv = SMPServer hst (deserializePort_ prt) keyHash
|
||||
return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> return Nothing
|
||||
|
||||
-- * deleteConn helper
|
||||
|
||||
deleteConnCascade :: DB.Connection -> ConnAlias -> IO ()
|
||||
deleteConnCascade dbConn connAlias =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
|
||||
-- * upgradeRcvConnToDuplex helpers
|
||||
|
||||
updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ())
|
||||
updateRcvConnWithSndQueue dbConn connAlias sndQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Just _rcvQ, Nothing) -> do
|
||||
upsertServer_ dbConn (server (sndQueue :: SndQueue))
|
||||
insertSndQueue_ dbConn sndQueue
|
||||
updateConnWithSndQueue_ dbConn connAlias sndQueue
|
||||
return $ Right ()
|
||||
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
|
||||
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
|
||||
_ -> return $ Left SEBadConn
|
||||
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -431,20 +471,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
|
||||
-- * upgradeSndConnToDuplex helpers
|
||||
|
||||
updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ())
|
||||
updateSndConnWithRcvQueue dbConn connAlias rcvQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Nothing, Just _sndQ) -> do
|
||||
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
|
||||
insertRcvQueue_ dbConn rcvQueue
|
||||
updateConnWithRcvQueue_ dbConn connAlias rcvQueue
|
||||
return $ Right ()
|
||||
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
|
||||
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
|
||||
_ -> return $ Left SEBadConn
|
||||
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -457,74 +483,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
|]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
|
||||
|
||||
-- * setRcvQueueStatus helper
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
-- ? throw error if queue doesn't exist?
|
||||
updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
|
||||
updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
|
||||
-- * setSndQueueStatus helper
|
||||
|
||||
-- ? throw error if queue doesn't exist?
|
||||
updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO ()
|
||||
updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE snd_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsg ::
|
||||
DB.Connection ->
|
||||
ConnAlias ->
|
||||
MsgBody ->
|
||||
InternalTs ->
|
||||
(ExternalSndId, ExternalSndTs) ->
|
||||
(BrokerId, BrokerTs) ->
|
||||
IO (Either StoreError InternalId)
|
||||
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Just _rcvQ, _) -> do
|
||||
(lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody
|
||||
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs)
|
||||
updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
return $ Right internalId
|
||||
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
|
||||
_ -> return $ Left SEBadConn
|
||||
|
||||
retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId)
|
||||
retrieveLastInternalIdsRcv_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalRcvId)] <-
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalRcvId)
|
||||
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id,
|
||||
last_internal_rcv_msg_id = :last_internal_rcv_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":conn_alias" := connAlias
|
||||
]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -540,82 +532,85 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody =
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertRcvMsgDetails_ ::
|
||||
DB.Connection ->
|
||||
ConnAlias ->
|
||||
InternalRcvId ->
|
||||
InternalId ->
|
||||
(ExternalSndId, ExternalSndTs) ->
|
||||
(BrokerId, BrokerTs) ->
|
||||
IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO rcv_messages
|
||||
( conn_alias, internal_rcv_id, internal_id, external_snd_id, external_snd_ts,
|
||||
broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts)
|
||||
broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts,
|
||||
internal_hash, external_prev_snd_hash, integrity)
|
||||
VALUES
|
||||
(:conn_alias,:internal_rcv_id,:internal_id,:external_snd_id,:external_snd_ts,
|
||||
:broker_id,:broker_ts,:rcv_status, NULL, NULL);
|
||||
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
|
||||
:internal_hash,:external_prev_snd_hash,:integrity);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":internal_id" := internalId,
|
||||
":external_snd_id" := externalSndId,
|
||||
":external_snd_ts" := externalSndTs,
|
||||
":broker_id" := brokerId,
|
||||
":broker_ts" := brokerTs,
|
||||
":rcv_status" := Received
|
||||
":external_snd_id" := fst senderMeta,
|
||||
":external_snd_ts" := snd senderMeta,
|
||||
":broker_id" := fst brokerMeta,
|
||||
":broker_ts" := snd brokerMeta,
|
||||
":rcv_status" := Received,
|
||||
":internal_hash" := internalHash,
|
||||
":external_prev_snd_hash" := externalPrevSndHash,
|
||||
":integrity" := msgIntegrity
|
||||
]
|
||||
|
||||
updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_external_snd_msg_id = :last_external_snd_msg_id,
|
||||
last_rcv_msg_hash = :last_rcv_msg_hash
|
||||
WHERE conn_alias = :conn_alias
|
||||
AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id;
|
||||
|]
|
||||
[ ":last_external_snd_msg_id" := fst senderMeta,
|
||||
":last_rcv_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":last_internal_rcv_msg_id" := internalRcvId
|
||||
]
|
||||
|
||||
-- * updateSndIds helpers
|
||||
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalSndId, lastSndHash)
|
||||
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id
|
||||
SET last_internal_msg_id = :last_internal_msg_id,
|
||||
last_internal_snd_msg_id = :last_internal_snd_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
]
|
||||
|
||||
-- * createSndMsg helpers
|
||||
|
||||
insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId)
|
||||
insertSndMsg dbConn connAlias msgBody internalTs =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(_, Just _sndQ) -> do
|
||||
(lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody
|
||||
insertSndMsgDetails_ dbConn connAlias internalSndId internalId
|
||||
updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
return $ Right internalId
|
||||
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
|
||||
_ -> return $ Left SEBadConn
|
||||
|
||||
retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId)
|
||||
retrieveLastInternalIdsSnd_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalSndId)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalSndId)
|
||||
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -631,32 +626,35 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody =
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias internalSndId internalId =
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO snd_messages
|
||||
( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts)
|
||||
( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash)
|
||||
VALUES
|
||||
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL);
|
||||
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":internal_id" := internalId,
|
||||
":snd_status" := Created
|
||||
":snd_status" := Created,
|
||||
":internal_hash" := internalHash
|
||||
]
|
||||
|
||||
updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
SET last_snd_msg_hash = :last_snd_msg_hash
|
||||
WHERE conn_alias = :conn_alias
|
||||
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
[ ":last_snd_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":last_internal_snd_msg_id" := internalSndId
|
||||
]
|
||||
|
||||
@@ -90,6 +90,9 @@ connections =
|
||||
last_internal_msg_id INTEGER NOT NULL,
|
||||
last_internal_rcv_msg_id INTEGER NOT NULL,
|
||||
last_internal_snd_msg_id INTEGER NOT NULL,
|
||||
last_external_snd_msg_id INTEGER NOT NULL,
|
||||
last_rcv_msg_hash BLOB NOT NULL,
|
||||
last_snd_msg_hash BLOB NOT NULL,
|
||||
PRIMARY KEY (conn_alias),
|
||||
FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id),
|
||||
FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id)
|
||||
@@ -135,6 +138,9 @@ rcvMessages =
|
||||
rcv_status TEXT NOT NULL,
|
||||
ack_brocker_ts TEXT,
|
||||
ack_sender_ts TEXT,
|
||||
internal_hash BLOB NOT NULL,
|
||||
external_prev_snd_hash BLOB NOT NULL,
|
||||
integrity BLOB NOT NULL,
|
||||
PRIMARY KEY (conn_alias, internal_rcv_id),
|
||||
FOREIGN KEY (conn_alias, internal_id)
|
||||
REFERENCES messages (conn_alias, internal_id)
|
||||
@@ -152,6 +158,7 @@ sndMessages =
|
||||
snd_status TEXT NOT NULL,
|
||||
sent_ts TEXT,
|
||||
delivered_ts TEXT,
|
||||
internal_hash BLOB NOT NULL,
|
||||
PRIMARY KEY (conn_alias, internal_snd_id),
|
||||
FOREIGN KEY (conn_alias, internal_id)
|
||||
REFERENCES messages (conn_alias, internal_id)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -12,7 +13,7 @@
|
||||
|
||||
module Simplex.Messaging.Agent.Transmission where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
@@ -21,13 +22,14 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind
|
||||
import Data.Kind (Type)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.ISO8601
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable ()
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
@@ -37,12 +39,12 @@ import Simplex.Messaging.Protocol
|
||||
MsgBody,
|
||||
MsgId,
|
||||
SenderPublicKey,
|
||||
errMessageBody,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
import Text.Read
|
||||
import UnliftIO.Exception
|
||||
|
||||
@@ -88,11 +90,11 @@ data ACommand (p :: AParty) where
|
||||
SEND :: MsgBody -> ACommand Client
|
||||
SENT :: AgentMsgId -> ACommand Agent
|
||||
MSG ::
|
||||
{ m_recipient :: (AgentMsgId, UTCTime),
|
||||
m_broker :: (MsgId, UTCTime),
|
||||
m_sender :: (AgentMsgId, UTCTime),
|
||||
m_status :: MsgStatus,
|
||||
m_body :: MsgBody
|
||||
{ recipientMeta :: (AgentMsgId, UTCTime),
|
||||
brokerMeta :: (MsgId, UTCTime),
|
||||
senderMeta :: (AgentMsgId, UTCTime),
|
||||
msgIntegrity :: MsgIntegrity,
|
||||
msgBody :: MsgBody
|
||||
} ->
|
||||
ACommand Agent
|
||||
-- ACK :: AgentMsgId -> ACommand Client
|
||||
@@ -125,7 +127,7 @@ data AMessage where
|
||||
deriving (Show)
|
||||
|
||||
parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage
|
||||
parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage
|
||||
parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE
|
||||
where
|
||||
smpMessageP :: Parser SMPMessage
|
||||
smpMessageP =
|
||||
@@ -140,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage
|
||||
SMPMessage
|
||||
<$> A.decimal <* A.space
|
||||
<*> tsISO8601P <* A.space
|
||||
<*> base64P <* A.endOfLine
|
||||
-- TODO previous message hash should become mandatory when we support HELLO and REPLY
|
||||
-- (for HELLO it would be the hash of SMPConfirmation)
|
||||
<*> (base64P <|> pure "") <* A.endOfLine
|
||||
<*> agentMessageP
|
||||
|
||||
serializeSMPMessage :: SMPMessage -> ByteString
|
||||
@@ -152,7 +156,7 @@ serializeSMPMessage = \case
|
||||
in smpMessage "" header body
|
||||
where
|
||||
messageHeader msgId ts prevMsgHash =
|
||||
B.unwords [B.pack $ show msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash]
|
||||
B.unwords [bshow msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash]
|
||||
smpMessage smpHeader aHeader aBody = B.intercalate "\n" [smpHeader, aHeader, aBody, ""]
|
||||
|
||||
agentMessageP :: Parser AMessage
|
||||
@@ -164,8 +168,8 @@ agentMessageP =
|
||||
hello = HELLO <$> C.pubKeyP <*> ackMode
|
||||
reply = REPLY <$> smpQueueInfoP
|
||||
a_msg = do
|
||||
size :: Int <- A.decimal
|
||||
A_MSG <$> (A.endOfLine *> A.take size <* A.endOfLine)
|
||||
size :: Int <- A.decimal <* A.endOfLine
|
||||
A_MSG <$> A.take size <* A.endOfLine
|
||||
ackMode = " NO_ACK" $> AckMode Off <|> pure (AckMode On)
|
||||
|
||||
smpQueueInfoP :: Parser SMPQueueInfo
|
||||
@@ -173,14 +177,14 @@ smpQueueInfoP =
|
||||
"smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP)
|
||||
|
||||
smpServerP :: Parser SMPServer
|
||||
smpServerP = SMPServer <$> server <*> port <*> msgHash
|
||||
smpServerP = SMPServer <$> server <*> optional port <*> optional kHash
|
||||
where
|
||||
server = B.unpack <$> A.takeTill (A.inClass ":# ")
|
||||
port = A.char ':' *> (Just . show <$> (A.decimal :: Parser Int)) <|> pure Nothing
|
||||
msgHash = A.char '#' *> (Just <$> base64P) <|> pure Nothing
|
||||
port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit)
|
||||
kHash = A.char '#' *> C.keyHashP
|
||||
|
||||
parseAgentMessage :: ByteString -> Either AgentErrorType AMessage
|
||||
parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage
|
||||
parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE
|
||||
|
||||
serializeAgentMessage :: AMessage -> ByteString
|
||||
serializeAgentMessage = \case
|
||||
@@ -194,17 +198,15 @@ serializeSmpQueueInfo (SMPQueueInfo srv qId ek) =
|
||||
|
||||
serializeServer :: SMPServer -> ByteString
|
||||
serializeServer SMPServer {host, port, keyHash} =
|
||||
B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash
|
||||
B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack . C.serializeKeyHash) keyHash
|
||||
|
||||
data SMPServer = SMPServer
|
||||
{ host :: HostName,
|
||||
port :: Maybe ServiceName,
|
||||
keyHash :: Maybe KeyHash
|
||||
keyHash :: Maybe C.KeyHash
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
type KeyHash = Encoded
|
||||
|
||||
type ConnAlias = ByteString
|
||||
|
||||
type OtherPartyId = Encoded
|
||||
@@ -220,9 +222,9 @@ data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Eq, Show)
|
||||
|
||||
type EncryptionKey = C.PublicKey
|
||||
|
||||
type DecryptionKey = C.PrivateKey
|
||||
type DecryptionKey = C.SafePrivateKey
|
||||
|
||||
type SignatureKey = C.PrivateKey
|
||||
type SignatureKey = C.SafePrivateKey
|
||||
|
||||
type VerificationKey = C.PublicKey
|
||||
|
||||
@@ -235,56 +237,60 @@ type AgentMsgId = Int64
|
||||
|
||||
type SenderTimestamp = UTCTime
|
||||
|
||||
data MsgStatus = MsgOk | MsgError MsgErrorType
|
||||
data MsgIntegrity = MsgOk | MsgError MsgErrorType
|
||||
deriving (Eq, Show)
|
||||
|
||||
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash
|
||||
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | error type used in errors sent to agent clients
|
||||
data AgentErrorType
|
||||
= UNKNOWN
|
||||
| PROHIBITED
|
||||
| SYNTAX Int
|
||||
| BROKER Natural
|
||||
| SMP ErrorType
|
||||
| CRYPTO C.CryptoError
|
||||
| SIZE
|
||||
| STORE
|
||||
| INTERNAL -- etc. TODO SYNTAX Natural
|
||||
deriving (Eq, Show, Exception)
|
||||
= CMD CommandErrorType -- command errors
|
||||
| CONN ConnectionErrorType -- connection state errors
|
||||
| SMP ErrorType -- SMP protocol errors forwarded to agent clients
|
||||
| BROKER BrokerErrorType -- SMP server errors
|
||||
| AGENT SMPAgentError -- errors of other agents
|
||||
| INTERNAL String -- agent implementation errors
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
data AckStatus = AckOk | AckError AckErrorType
|
||||
deriving (Show)
|
||||
data CommandErrorType
|
||||
= PROHIBITED -- command is prohibited
|
||||
| SYNTAX -- command syntax is invalid
|
||||
| NO_CONN -- connection alias is required with this command
|
||||
| SIZE -- message size is not correct (no terminating space)
|
||||
| LARGE -- message does not fit SMP block
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
data AckErrorType = AckUnknown | AckProhibited | AckSyntax Int -- etc.
|
||||
deriving (Show)
|
||||
data ConnectionErrorType
|
||||
= UNKNOWN -- connection alias not in database
|
||||
| DUPLICATE -- connection alias already exists
|
||||
| SIMPLEX -- connection is simplex, but operation requires another queue
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
errBadEncoding :: Int
|
||||
errBadEncoding = 10
|
||||
data BrokerErrorType
|
||||
= RESPONSE ErrorType -- invalid server response (failed to parse)
|
||||
| UNEXPECTED -- unexpected response
|
||||
| NETWORK -- network error
|
||||
| TRANSPORT TransportError -- handshake or other transport error
|
||||
| TIMEOUT -- command response timeout
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
errBadCommand :: Int
|
||||
errBadCommand = 11
|
||||
data SMPAgentError
|
||||
= A_MESSAGE -- possibly should include bytestring that failed to parse
|
||||
| A_PROHIBITED -- possibly should include the prohibited SMP/agent message
|
||||
| A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header
|
||||
| A_SIGNATURE -- invalid RSA signature
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
errBadInvitation :: Int
|
||||
errBadInvitation = 12
|
||||
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
errNoConnAlias :: Int
|
||||
errNoConnAlias = 13
|
||||
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
errBadMessage :: Int
|
||||
errBadMessage = 14
|
||||
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
errBadServer :: Int
|
||||
errBadServer = 15
|
||||
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
smpErrTCPConnection :: Natural
|
||||
smpErrTCPConnection = 1
|
||||
|
||||
smpErrCorrelationId :: Natural
|
||||
smpErrCorrelationId = 2
|
||||
|
||||
smpUnexpectedResponse :: Natural
|
||||
smpUnexpectedResponse = 3
|
||||
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
|
||||
|
||||
commandP :: Parser ACmd
|
||||
commandP =
|
||||
@@ -309,28 +315,30 @@ commandP =
|
||||
sendCmd = ACmd SClient . SEND <$> A.takeByteString
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
message = do
|
||||
m_status <- status <* A.space
|
||||
m_recipient <- "R=" *> partyMeta A.decimal
|
||||
m_broker <- "B=" *> partyMeta base64P
|
||||
m_sender <- "S=" *> partyMeta A.decimal
|
||||
m_body <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body}
|
||||
-- TODO other error types
|
||||
agentError = ACmd SAgent . ERR <$> ("SMP " *> smpErrorType)
|
||||
smpErrorType = "AUTH" $> SMP SMP.AUTH
|
||||
msgIntegrity <- msgIntegrityP <* A.space
|
||||
recipientMeta <- "R=" *> partyMeta A.decimal
|
||||
brokerMeta <- "B=" *> partyMeta base64P
|
||||
senderMeta <- "S=" *> partyMeta A.decimal
|
||||
msgBody <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
|
||||
replyMode =
|
||||
" NO_REPLY" $> ReplyOff
|
||||
<|> A.space *> (ReplyVia <$> smpServerP)
|
||||
<|> pure ReplyOn
|
||||
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
|
||||
status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
|
||||
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
|
||||
|
||||
msgIntegrityP :: Parser MsgIntegrity
|
||||
msgIntegrityP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
|
||||
where
|
||||
msgErrorType =
|
||||
"ID " *> (MsgBadId <$> A.decimal)
|
||||
<|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
|
||||
<|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
|
||||
<|> "HASH" $> MsgBadHash
|
||||
<|> "DUPLICATE" $> MsgDuplicate
|
||||
|
||||
parseCommand :: ByteString -> Either AgentErrorType ACmd
|
||||
parseCommand = parse commandP $ SYNTAX errBadCommand
|
||||
parseCommand = parse commandP $ CMD SYNTAX
|
||||
|
||||
serializeCommand :: ACommand p -> ByteString
|
||||
serializeCommand = \case
|
||||
@@ -342,19 +350,19 @@ serializeCommand = \case
|
||||
END -> "END"
|
||||
SEND msgBody -> "SEND " <> serializeMsg msgBody
|
||||
SENT mId -> "SENT " <> bshow mId
|
||||
MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} ->
|
||||
MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} ->
|
||||
B.unwords
|
||||
[ "MSG",
|
||||
msgStatus m_status,
|
||||
serializeMsgIntegrity msgIntegrity,
|
||||
"R=" <> bshow rmId <> "," <> showTs rTs,
|
||||
"B=" <> encode bmId <> "," <> showTs bTs,
|
||||
"S=" <> bshow smId <> "," <> showTs sTs,
|
||||
serializeMsg m_body
|
||||
serializeMsg msgBody
|
||||
]
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
CON -> "CON"
|
||||
ERR e -> "ERR " <> B.pack (show e)
|
||||
ERR e -> "ERR " <> serializeAgentError e
|
||||
OK -> "OK"
|
||||
where
|
||||
replyMode :: ReplyMode -> ByteString
|
||||
@@ -364,19 +372,35 @@ serializeCommand = \case
|
||||
ReplyOn -> ""
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
msgStatus :: MsgStatus -> ByteString
|
||||
msgStatus = \case
|
||||
MsgOk -> "OK"
|
||||
MsgError e ->
|
||||
"ERR" <> case e of
|
||||
MsgSkipped fromMsgId toMsgId ->
|
||||
B.unwords ["NO_ID", B.pack $ show fromMsgId, B.pack $ show toMsgId]
|
||||
MsgBadId aMsgId -> "ID " <> B.pack (show aMsgId)
|
||||
MsgBadHash -> "HASH"
|
||||
|
||||
-- TODO - save function as in the server Transmission - re-use?
|
||||
serializeMsgIntegrity :: MsgIntegrity -> ByteString
|
||||
serializeMsgIntegrity = \case
|
||||
MsgOk -> "OK"
|
||||
MsgError e ->
|
||||
"ERR " <> case e of
|
||||
MsgSkipped fromMsgId toMsgId ->
|
||||
B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId]
|
||||
MsgBadId aMsgId -> "ID " <> bshow aMsgId
|
||||
MsgBadHash -> "HASH"
|
||||
MsgDuplicate -> "DUPLICATE"
|
||||
|
||||
agentErrorTypeP :: Parser AgentErrorType
|
||||
agentErrorTypeP =
|
||||
"SMP " *> (SMP <$> SMP.errorTypeP)
|
||||
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP)
|
||||
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
|
||||
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
|
||||
<|> parseRead2
|
||||
|
||||
serializeAgentError :: AgentErrorType -> ByteString
|
||||
serializeAgentError = \case
|
||||
SMP e -> "SMP " <> SMP.serializeErrorType e
|
||||
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e
|
||||
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
|
||||
e -> bshow e
|
||||
|
||||
serializeMsg :: ByteString -> ByteString
|
||||
serializeMsg body = B.pack (show $ B.length body) <> "\n" <> body
|
||||
serializeMsg body = bshow (B.length body) <> "\n" <> body
|
||||
|
||||
tPutRaw :: Handle -> ARawTransmission -> IO ()
|
||||
tPutRaw h (corrId, connAlias, command) = do
|
||||
@@ -404,7 +428,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
fromParty :: ACmd -> Either AgentErrorType (ACommand p)
|
||||
fromParty (ACmd (p :: p1) cmd) = case testEquality party p of
|
||||
Just Refl -> Right cmd
|
||||
_ -> Left PROHIBITED
|
||||
_ -> Left $ CMD PROHIBITED
|
||||
|
||||
tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p)
|
||||
tConnAlias (_, connAlias, _) cmd = case cmd of
|
||||
@@ -415,13 +439,13 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
ERR _ -> Right cmd
|
||||
-- other responses must have connAlias
|
||||
_
|
||||
| B.null connAlias -> Left $ SYNTAX errNoConnAlias
|
||||
| B.null connAlias -> Left $ CMD NO_CONN
|
||||
| otherwise -> Right cmd
|
||||
|
||||
cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p))
|
||||
cmdWithMsgBody = \case
|
||||
SEND body -> SEND <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
|
||||
cmd -> return $ Right cmd
|
||||
|
||||
-- TODO refactor with server
|
||||
@@ -433,5 +457,5 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
Just size -> liftIO $ do
|
||||
body <- B.hGet h size
|
||||
s <- getLn h
|
||||
return $ if B.null s then Right body else Left SIZE
|
||||
Nothing -> return . Left $ SYNTAX errMessageBody
|
||||
return $ if B.null s then Right body else Left $ CMD SIZE
|
||||
Nothing -> return . Left $ CMD SYNTAX
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
@@ -8,7 +9,7 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Client
|
||||
( SMPClient,
|
||||
( SMPClient (blockSize),
|
||||
getSMPClient,
|
||||
closeSMPClient,
|
||||
createSMPQueue,
|
||||
@@ -33,33 +34,31 @@ import Control.Exception
|
||||
import Control.Monad
|
||||
import Control.Monad.Trans.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import qualified Crypto.PubKey.RSA.Types as RSA
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe
|
||||
import GHC.IO.Exception (IOErrorType (..))
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Transmission (SMPServer (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (liftEitherError, raceAny_)
|
||||
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
|
||||
import System.IO
|
||||
import System.IO.Error
|
||||
import System.Timeout
|
||||
|
||||
data SMPClient = SMPClient
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
smpServer :: SMPServer,
|
||||
tcpTimeout :: Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TVar (Map CorrId Request),
|
||||
sndQ :: TBQueue SignedRawTransmission,
|
||||
rcvQ :: TBQueue SignedTransmissionOrError,
|
||||
msgQ :: TBQueue SMPServerTransmission
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
blockSize :: Int
|
||||
}
|
||||
|
||||
type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker)
|
||||
@@ -69,7 +68,6 @@ data SMPClientConfig = SMPClientConfig
|
||||
defaultPort :: ServiceName,
|
||||
tcpTimeout :: Int,
|
||||
smpPing :: Int,
|
||||
blockSize :: Int,
|
||||
smpCommandSize :: Int
|
||||
}
|
||||
|
||||
@@ -78,36 +76,36 @@ smpDefaultConfig =
|
||||
SMPClientConfig
|
||||
{ qSize = 16,
|
||||
defaultPort = "5223",
|
||||
tcpTimeout = 2_000_000,
|
||||
tcpTimeout = 4_000_000,
|
||||
smpPing = 30_000_000,
|
||||
blockSize = 8_192, -- 16_384,
|
||||
smpCommandSize = 256
|
||||
}
|
||||
|
||||
data Request = Request
|
||||
{ queueId :: QueueId,
|
||||
responseVar :: TMVar (Either SMPClientError Cmd)
|
||||
responseVar :: TMVar Response
|
||||
}
|
||||
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient
|
||||
type Response = Either SMPClientError Cmd
|
||||
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient)
|
||||
getSMPClient
|
||||
smpServer@SMPServer {host, port}
|
||||
smpServer@SMPServer {host, port, keyHash}
|
||||
SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing}
|
||||
msgQ
|
||||
disconnected = do
|
||||
c <- atomically mkSMPClient
|
||||
started <- newEmptyTMVarIO
|
||||
thVar <- newEmptyTMVarIO
|
||||
action <-
|
||||
async $
|
||||
runTCPClient host (fromMaybe defaultPort port) (client c started)
|
||||
`finally` atomically (putTMVar started False)
|
||||
tcpTimeout `timeout` atomically (takeTMVar started) >>= \case
|
||||
Just True -> return c {action}
|
||||
_ -> throwIO err
|
||||
runTCPClient host (fromMaybe defaultPort port) (client c thVar)
|
||||
`finally` atomically (putTMVar thVar $ Left SMPNetworkError)
|
||||
tHandle <- tcpTimeout `timeout` atomically (takeTMVar thVar)
|
||||
pure $ case tHandle of
|
||||
Just (Right THandle {blockSize}) -> Right c {action, blockSize}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left SMPNetworkError
|
||||
where
|
||||
err :: IOException
|
||||
err = mkIOError TimeExpired "connection timeout" Nothing Nothing
|
||||
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient = do
|
||||
connected <- newTVar False
|
||||
@@ -118,8 +116,10 @@ getSMPClient
|
||||
return
|
||||
SMPClient
|
||||
{ action = undefined,
|
||||
blockSize = undefined,
|
||||
connected,
|
||||
smpServer,
|
||||
tcpTimeout,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
sndQ,
|
||||
@@ -127,19 +127,24 @@ getSMPClient
|
||||
msgQ
|
||||
}
|
||||
|
||||
client :: SMPClient -> TMVar Bool -> Handle -> IO ()
|
||||
client c started h = do
|
||||
_ <- getLn h -- "Welcome to SMP"
|
||||
client :: SMPClient -> TMVar (Either SMPClientError THandle) -> Handle -> IO ()
|
||||
client c thVar h =
|
||||
runExceptT (clientHandshake h keyHash) >>= \case
|
||||
Right th -> clientTransport c thVar th
|
||||
Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e
|
||||
|
||||
clientTransport :: SMPClient -> TMVar (Either SMPClientError THandle) -> THandle -> IO ()
|
||||
clientTransport c thVar th = do
|
||||
atomically $ do
|
||||
modifyTVar (connected c) (const True)
|
||||
putTMVar started True
|
||||
raceAny_ [send c h, process c, receive c h, ping c]
|
||||
writeTVar (connected c) True
|
||||
putTMVar thVar $ Right th
|
||||
raceAny_ [send c th, process c, receive c th, ping c]
|
||||
`finally` disconnected
|
||||
|
||||
send :: SMPClient -> Handle -> IO ()
|
||||
send :: SMPClient -> THandle -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
receive :: SMPClient -> Handle -> IO ()
|
||||
receive :: SMPClient -> THandle -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: SMPClient -> IO ()
|
||||
@@ -165,7 +170,7 @@ getSMPClient
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPQueueIdError
|
||||
else Left SMPUnexpectedResponse
|
||||
|
||||
closeSMPClient :: SMPClient -> IO ()
|
||||
closeSMPClient = uninterruptibleCancel . action
|
||||
@@ -173,11 +178,11 @@ closeSMPClient = uninterruptibleCancel . action
|
||||
data SMPClientError
|
||||
= SMPServerError ErrorType
|
||||
| SMPResponseError ErrorType
|
||||
| SMPQueueIdError
|
||||
| SMPUnexpectedResponse
|
||||
| SMPResponseTimeout
|
||||
| SMPCryptoError RSA.Error
|
||||
| SMPClientError
|
||||
| SMPNetworkError
|
||||
| SMPTransportError TransportError
|
||||
| SMPSignatureError C.CryptoError
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
createSMPQueue ::
|
||||
@@ -222,14 +227,14 @@ suspendSMPQueue = okSMPCommand $ Cmd SRecipient OFF
|
||||
deleteSMPQueue :: SMPClient -> RecipientPrivateKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand $ Cmd SRecipient DEL
|
||||
|
||||
okSMPCommand :: Cmd -> SMPClient -> C.PrivateKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand :: Cmd -> SMPClient -> C.SafePrivateKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
Cmd _ OK -> return ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
|
||||
sendSMPCommand :: SMPClient -> Maybe C.PrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd
|
||||
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do
|
||||
sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd
|
||||
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do
|
||||
corrId <- lift_ getNextCorrId
|
||||
t <- signTransmission $ serializeTransmission (corrId, qId, cmd)
|
||||
ExceptT $ sendRecv corrId t
|
||||
@@ -241,20 +246,22 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do
|
||||
getNextCorrId = do
|
||||
i <- (+ 1) <$> readTVar clientCorrId
|
||||
writeTVar clientCorrId i
|
||||
return . CorrId . B.pack $ show i
|
||||
return . CorrId $ bshow i
|
||||
|
||||
signTransmission :: ByteString -> ExceptT SMPClientError IO SignedRawTransmission
|
||||
signTransmission t = case pKey of
|
||||
Nothing -> return ("", t)
|
||||
Just pk -> do
|
||||
sig <- liftEitherError SMPCryptoError $ C.sign pk t
|
||||
sig <- liftError SMPSignatureError $ C.sign pk t
|
||||
return (sig, t)
|
||||
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: CorrId -> SignedRawTransmission -> IO (Either SMPClientError Cmd)
|
||||
sendRecv corrId t = atomically (send corrId t) >>= atomically . takeTMVar
|
||||
sendRecv :: CorrId -> SignedRawTransmission -> IO Response
|
||||
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
|
||||
where
|
||||
withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
send :: CorrId -> SignedRawTransmission -> STM (TMVar (Either SMPClientError Cmd))
|
||||
send :: CorrId -> SignedRawTransmission -> STM (TMVar Response)
|
||||
send corrId t = do
|
||||
r <- newEmptyTMVar
|
||||
modifyTVar sentCommands . M.insert corrId $ Request qId r
|
||||
|
||||
@@ -1,26 +1,53 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Simplex.Messaging.Crypto
|
||||
( PrivateKey (..),
|
||||
( PrivateKey (rsaPrivateKey),
|
||||
SafePrivateKey, -- constructor is not exported
|
||||
FullPrivateKey (..),
|
||||
PublicKey (..),
|
||||
Signature (..),
|
||||
CryptoError (..),
|
||||
SafeKeyPair,
|
||||
FullKeyPair,
|
||||
Key (..),
|
||||
IV (..),
|
||||
KeyHash (..),
|
||||
generateKeyPair,
|
||||
publicKey,
|
||||
publicKeySize,
|
||||
safePrivateKey,
|
||||
sign,
|
||||
verify,
|
||||
encrypt,
|
||||
decrypt,
|
||||
encryptOAEP,
|
||||
decryptOAEP,
|
||||
encryptAES,
|
||||
decryptAES,
|
||||
serializePrivKey,
|
||||
serializePubKey,
|
||||
parsePrivKey,
|
||||
parsePubKey,
|
||||
encodePubKey,
|
||||
serializeKeyHash,
|
||||
getKeyHash,
|
||||
sha256Hash,
|
||||
privKeyP,
|
||||
pubKeyP,
|
||||
binaryPubKeyP,
|
||||
keyHashP,
|
||||
authTagSize,
|
||||
authTagToBS,
|
||||
bsToAuthTag,
|
||||
randomAesKey,
|
||||
randomIV,
|
||||
aesKeyP,
|
||||
ivP,
|
||||
)
|
||||
where
|
||||
|
||||
@@ -30,60 +57,89 @@ import Control.Monad.Trans.Except
|
||||
import Crypto.Cipher.AES (AES256)
|
||||
import qualified Crypto.Cipher.Types as AES
|
||||
import qualified Crypto.Error as CE
|
||||
import Crypto.Hash.Algorithms (SHA256 (..))
|
||||
import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash)
|
||||
import Crypto.Number.Generate (generateMax)
|
||||
import Crypto.Number.Prime (findPrimeFrom)
|
||||
import Crypto.Number.Serialize (i2osp, os2ip)
|
||||
import qualified Crypto.PubKey.RSA as R
|
||||
import qualified Crypto.PubKey.RSA.OAEP as OAEP
|
||||
import qualified Crypto.PubKey.RSA.PSS as PSS
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import Data.ASN1.BinaryEncoding
|
||||
import Data.ASN1.Encoding
|
||||
import Data.ASN1.Types
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import qualified Data.ByteArray as BA
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Base64 (decode, encode)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.ByteString.Internal (c2w, w2c)
|
||||
import Data.ByteString.Lazy (fromStrict, toStrict)
|
||||
import Data.String
|
||||
import Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.X509
|
||||
import Database.SQLite.Simple (ResultError (..), SQLData (..))
|
||||
import Database.SQLite.Simple.FromField (FieldParser, FromField (..), returnError)
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Network.Transport.Internal (decodeWord32, encodeWord32)
|
||||
import Simplex.Messaging.Parsers (base64P)
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, (<$$>))
|
||||
import Simplex.Messaging.Parsers (base64P, parseAll)
|
||||
import Simplex.Messaging.Util (liftEitherError, (<$?>))
|
||||
|
||||
newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show)
|
||||
|
||||
data PrivateKey = PrivateKey
|
||||
{ private_size :: Int,
|
||||
private_n :: Integer,
|
||||
private_d :: Integer
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
newtype SafePrivateKey = SafePrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show)
|
||||
|
||||
instance ToField PrivateKey where toField = toField . serializePrivKey
|
||||
newtype FullPrivateKey = FullPrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show)
|
||||
|
||||
instance ToField PublicKey where toField = toField . serializePubKey
|
||||
class PrivateKey k where
|
||||
rsaPrivateKey :: k -> R.PrivateKey
|
||||
_privateKey :: R.PrivateKey -> k
|
||||
mkPrivateKey :: R.PrivateKey -> k
|
||||
|
||||
instance FromField PrivateKey where
|
||||
fromField f@(Field (SQLBlob b) _) =
|
||||
case parsePrivKey b of
|
||||
instance PrivateKey SafePrivateKey where
|
||||
rsaPrivateKey = unPrivateKey
|
||||
_privateKey = SafePrivateKey
|
||||
mkPrivateKey R.PrivateKey {private_pub = k, private_d} =
|
||||
safePrivateKey (R.public_size k, R.public_n k, private_d)
|
||||
|
||||
instance PrivateKey FullPrivateKey where
|
||||
rsaPrivateKey = unPrivateKey
|
||||
_privateKey = FullPrivateKey
|
||||
mkPrivateKey = FullPrivateKey
|
||||
|
||||
instance IsString FullPrivateKey where
|
||||
fromString = parseString (decode >=> decodePrivKey)
|
||||
|
||||
instance IsString PublicKey where
|
||||
fromString = parseString (decode >=> decodePubKey)
|
||||
|
||||
parseString :: (ByteString -> Either String a) -> (String -> a)
|
||||
parseString parse = either error id . parse . B.pack
|
||||
|
||||
instance ToField SafePrivateKey where toField = toField . encodePrivKey
|
||||
|
||||
instance ToField PublicKey where toField = toField . encodePubKey
|
||||
|
||||
instance FromField SafePrivateKey where fromField = keyFromField binaryPrivKeyP
|
||||
|
||||
instance FromField PublicKey where fromField = keyFromField binaryPubKeyP
|
||||
|
||||
keyFromField :: Typeable k => Parser k -> FieldParser k
|
||||
keyFromField p = \case
|
||||
f@(Field (SQLBlob b) _) ->
|
||||
case parseAll p b of
|
||||
Right k -> Ok k
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse PrivateKey field: " ++ e)
|
||||
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse key field: " ++ e)
|
||||
f -> returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
instance FromField PublicKey where
|
||||
fromField f@(Field (SQLBlob b) _) =
|
||||
case parsePubKey b of
|
||||
Right k -> Ok k
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse PublicKey field: " ++ e)
|
||||
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
type KeyPair k = (PublicKey, k)
|
||||
|
||||
type KeyPair = (PublicKey, PrivateKey)
|
||||
type SafeKeyPair = (PublicKey, SafePrivateKey)
|
||||
|
||||
type FullKeyPair = (PublicKey, FullPrivateKey)
|
||||
|
||||
newtype Signature = Signature {unSignature :: ByteString} deriving (Eq, Show)
|
||||
|
||||
@@ -93,10 +149,12 @@ instance IsString Signature where
|
||||
newtype Verified = Verified ByteString deriving (Show)
|
||||
|
||||
data CryptoError
|
||||
= CryptoRSAError R.Error
|
||||
| CryptoCipherError CE.CryptoError
|
||||
= RSAEncryptError R.Error
|
||||
| RSADecryptError R.Error
|
||||
| RSASignError R.Error
|
||||
| AESCipherError CE.CryptoError
|
||||
| CryptoIVError
|
||||
| CryptoDecryptError
|
||||
| AESDecryptError
|
||||
| CryptoLargeMsgError
|
||||
| CryptoHeaderError String
|
||||
deriving (Eq, Show, Exception)
|
||||
@@ -110,19 +168,26 @@ aesKeySize = 256 `div` 8
|
||||
authTagSize :: Int
|
||||
authTagSize = 128 `div` 8
|
||||
|
||||
generateKeyPair :: Int -> IO KeyPair
|
||||
generateKeyPair :: PrivateKey k => Int -> IO (KeyPair k)
|
||||
generateKeyPair size = loop
|
||||
where
|
||||
publicExponent = findPrimeFrom . (+ 3) <$> generateMax pubExpRange
|
||||
privateKey s n d = PrivateKey {private_size = s, private_n = n, private_d = d}
|
||||
loop = do
|
||||
(pub, priv) <- R.generate size =<< publicExponent
|
||||
let s = R.public_size pub
|
||||
n = R.public_n pub
|
||||
d = R.private_d priv
|
||||
in if d * d < n
|
||||
then loop
|
||||
else return (PublicKey pub, privateKey s n d)
|
||||
(k, pk) <- R.generate size =<< publicExponent
|
||||
let n = R.public_n k
|
||||
d = R.private_d pk
|
||||
if d * d < n
|
||||
then loop
|
||||
else pure (PublicKey k, mkPrivateKey pk)
|
||||
|
||||
privateKeySize :: PrivateKey k => k -> Int
|
||||
privateKeySize = R.public_size . R.private_pub . rsaPrivateKey
|
||||
|
||||
publicKey :: FullPrivateKey -> PublicKey
|
||||
publicKey = PublicKey . R.private_pub . rsaPrivateKey
|
||||
|
||||
publicKeySize :: PublicKey -> Int
|
||||
publicKeySize = R.public_size . rsaPublicKey
|
||||
|
||||
data Header = Header
|
||||
{ aesKey :: Key,
|
||||
@@ -135,52 +200,89 @@ newtype Key = Key {unKey :: ByteString}
|
||||
|
||||
newtype IV = IV {unIV :: ByteString}
|
||||
|
||||
newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString KeyHash where
|
||||
fromString = parseString $ parseAll keyHashP
|
||||
|
||||
instance ToField KeyHash where toField = toField . serializeKeyHash
|
||||
|
||||
instance FromField KeyHash where
|
||||
fromField f@(Field (SQLBlob b) _) =
|
||||
case parseAll keyHashP b of
|
||||
Right k -> Ok k
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse KeyHash field: " ++ e)
|
||||
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
serializeKeyHash :: KeyHash -> ByteString
|
||||
serializeKeyHash = encode . BA.convert . unKeyHash
|
||||
|
||||
keyHashP :: Parser KeyHash
|
||||
keyHashP = do
|
||||
bs <- base64P
|
||||
case digestFromByteString bs of
|
||||
Just d -> pure $ KeyHash d
|
||||
_ -> fail "invalid digest"
|
||||
|
||||
getKeyHash :: ByteString -> KeyHash
|
||||
getKeyHash = KeyHash . hash
|
||||
|
||||
sha256Hash :: ByteString -> ByteString
|
||||
sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256)
|
||||
|
||||
serializeHeader :: Header -> ByteString
|
||||
serializeHeader Header {aesKey, ivBytes, authTag, msgSize} =
|
||||
unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize
|
||||
|
||||
headerP :: Parser Header
|
||||
headerP = do
|
||||
aesKey <- Key <$> A.take aesKeySize
|
||||
ivBytes <- IV <$> A.take (ivSize @AES256)
|
||||
aesKey <- aesKeyP
|
||||
ivBytes <- ivP
|
||||
authTag <- bsToAuthTag <$> A.take authTagSize
|
||||
msgSize <- fromIntegral . decodeWord32 <$> A.take 4
|
||||
return Header {aesKey, ivBytes, authTag, msgSize}
|
||||
|
||||
aesKeyP :: Parser Key
|
||||
aesKeyP = Key <$> A.take aesKeySize
|
||||
|
||||
ivP :: Parser IV
|
||||
ivP = IV <$> A.take (ivSize @AES256)
|
||||
|
||||
parseHeader :: ByteString -> Either CryptoError Header
|
||||
parseHeader = first CryptoHeaderError . A.parseOnly (headerP <* A.endOfInput)
|
||||
parseHeader = first CryptoHeaderError . parseAll headerP
|
||||
|
||||
encrypt :: PublicKey -> Int -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
encrypt k paddedSize msg = do
|
||||
aesKey <- Key <$> randomBytes aesKeySize
|
||||
ivBytes <- IV <$> randomBytes (ivSize @AES256)
|
||||
aesKey <- liftIO randomAesKey
|
||||
ivBytes <- liftIO randomIV
|
||||
(authTag, msg') <- encryptAES aesKey ivBytes paddedSize msg
|
||||
let header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg}
|
||||
encHeader <- encryptOAEP k $ serializeHeader header
|
||||
return $ encHeader <> msg'
|
||||
|
||||
decrypt :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
decrypt pk msg'' = do
|
||||
let (encHeader, msg') = B.splitAt (privateKeySize pk) msg''
|
||||
header <- decryptOAEP pk encHeader
|
||||
Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header
|
||||
msg <- decryptAES aesKey ivBytes msg' authTag
|
||||
return $ B.take msgSize msg
|
||||
|
||||
encryptAES :: Key -> IV -> Int -> ByteString -> ExceptT CryptoError IO (AES.AuthTag, ByteString)
|
||||
encryptAES aesKey ivBytes paddedSize msg = do
|
||||
aead <- initAEAD @AES256 aesKey ivBytes
|
||||
msg' <- paddedMsg
|
||||
let (authTag, msg'') = encryptAES aead msg'
|
||||
header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg}
|
||||
encHeader <- encryptOAEP k $ serializeHeader header
|
||||
return $ encHeader <> msg''
|
||||
return $ AES.aeadSimpleEncrypt aead B.empty msg' authTagSize
|
||||
where
|
||||
len = B.length msg
|
||||
paddedMsg
|
||||
| len >= paddedSize = throwE CryptoLargeMsgError
|
||||
| otherwise = return (msg <> B.replicate (paddedSize - len) '#')
|
||||
|
||||
decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
decrypt pk msg'' = do
|
||||
let (encHeader, msg') = B.splitAt (private_size pk) msg''
|
||||
header <- decryptOAEP pk encHeader
|
||||
Header {aesKey, ivBytes, authTag, msgSize} <- ExceptT . return $ parseHeader header
|
||||
decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString
|
||||
decryptAES aesKey ivBytes msg authTag = do
|
||||
aead <- initAEAD @AES256 aesKey ivBytes
|
||||
msg <- decryptAES aead msg' authTag
|
||||
return $ B.take msgSize msg
|
||||
|
||||
encryptAES :: AES.AEAD AES256 -> ByteString -> (AES.AuthTag, ByteString)
|
||||
encryptAES aead plaintext = AES.aeadSimpleEncrypt aead B.empty plaintext authTagSize
|
||||
|
||||
decryptAES :: AES.AEAD AES256 -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString
|
||||
decryptAES aead ciphertext authTag =
|
||||
maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty ciphertext authTag
|
||||
maybeError AESDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag
|
||||
|
||||
initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c)
|
||||
initAEAD (Key aesKey) (IV ivBytes) = do
|
||||
@@ -189,15 +291,18 @@ initAEAD (Key aesKey) (IV ivBytes) = do
|
||||
cipher <- AES.cipherInit aesKey
|
||||
AES.aeadInit AES.AEAD_GCM cipher iv
|
||||
|
||||
randomAesKey :: IO Key
|
||||
randomAesKey = Key <$> getRandomBytes aesKeySize
|
||||
|
||||
randomIV :: IO IV
|
||||
randomIV = IV <$> getRandomBytes (ivSize @AES256)
|
||||
|
||||
ivSize :: forall c. AES.BlockCipher c => Int
|
||||
ivSize = AES.blockSize (undefined :: c)
|
||||
|
||||
makeIV :: AES.BlockCipher c => ByteString -> ExceptT CryptoError IO (AES.IV c)
|
||||
makeIV bs = maybeError CryptoIVError $ AES.makeIV bs
|
||||
|
||||
randomBytes :: Int -> ExceptT CryptoError IO ByteString
|
||||
randomBytes n = ExceptT $ Right <$> getRandomBytes n
|
||||
|
||||
maybeError :: CryptoError -> Maybe a -> ExceptT CryptoError IO a
|
||||
maybeError e = maybe (throwE e) return
|
||||
|
||||
@@ -208,75 +313,93 @@ bsToAuthTag :: ByteString -> AES.AuthTag
|
||||
bsToAuthTag = AES.AuthTag . BA.pack . map c2w . B.unpack
|
||||
|
||||
cryptoFailable :: CE.CryptoFailable a -> ExceptT CryptoError IO a
|
||||
cryptoFailable = liftEither . first CryptoCipherError . CE.eitherCryptoError
|
||||
cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError
|
||||
|
||||
oaepParams :: OAEP.OAEPParams SHA256 ByteString ByteString
|
||||
oaepParams = OAEP.defaultOAEPParams SHA256
|
||||
|
||||
encryptOAEP :: PublicKey -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
encryptOAEP (PublicKey k) aesKey =
|
||||
liftEitherError CryptoRSAError $
|
||||
liftEitherError RSAEncryptError $
|
||||
OAEP.encrypt oaepParams k aesKey
|
||||
|
||||
decryptOAEP :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString
|
||||
decryptOAEP pk encKey =
|
||||
liftEitherError CryptoRSAError $
|
||||
liftEitherError RSADecryptError $
|
||||
OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey
|
||||
|
||||
pssParams :: PSS.PSSParams SHA256 ByteString ByteString
|
||||
pssParams = PSS.defaultPSSParams SHA256
|
||||
|
||||
sign :: PrivateKey -> ByteString -> IO (Either R.Error Signature)
|
||||
sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg
|
||||
sign :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO Signature
|
||||
sign pk msg = ExceptT $ bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg
|
||||
|
||||
verify :: PublicKey -> Signature -> ByteString -> Bool
|
||||
verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig
|
||||
|
||||
serializePubKey :: PublicKey -> ByteString
|
||||
serializePubKey (PublicKey k) = serializeKey_ (R.public_size k, R.public_n k, R.public_e k)
|
||||
serializePubKey = ("rsa:" <>) . encode . encodePubKey
|
||||
|
||||
serializePrivKey :: PrivateKey -> ByteString
|
||||
serializePrivKey pk = serializeKey_ (private_size pk, private_n pk, private_d pk)
|
||||
|
||||
serializeKey_ :: (Int, Integer, Integer) -> ByteString
|
||||
serializeKey_ (size, n, ex) = bshow size <> "," <> encInt n <> "," <> encInt ex
|
||||
where
|
||||
encInt = encode . i2osp
|
||||
serializePrivKey :: PrivateKey k => k -> ByteString
|
||||
serializePrivKey = ("rsa:" <>) . encode . encodePrivKey
|
||||
|
||||
pubKeyP :: Parser PublicKey
|
||||
pubKeyP = do
|
||||
(public_size, public_n, public_e) <- keyParser_
|
||||
return . PublicKey $ R.PublicKey {R.public_size, R.public_n, R.public_e}
|
||||
pubKeyP = decodePubKey <$?> ("rsa:" *> base64P)
|
||||
|
||||
privKeyP :: Parser PrivateKey
|
||||
privKeyP = do
|
||||
(private_size, private_n, private_d) <- keyParser_
|
||||
return PrivateKey {private_size, private_n, private_d}
|
||||
binaryPubKeyP :: Parser PublicKey
|
||||
binaryPubKeyP = decodePubKey <$?> A.takeByteString
|
||||
|
||||
parsePubKey :: ByteString -> Either String PublicKey
|
||||
parsePubKey = A.parseOnly (pubKeyP <* A.endOfInput)
|
||||
privKeyP :: PrivateKey k => Parser k
|
||||
privKeyP = decodePrivKey <$?> ("rsa:" *> base64P)
|
||||
|
||||
parsePrivKey :: ByteString -> Either String PrivateKey
|
||||
parsePrivKey = A.parseOnly (privKeyP <* A.endOfInput)
|
||||
binaryPrivKeyP :: PrivateKey k => Parser k
|
||||
binaryPrivKeyP = decodePrivKey <$?> A.takeByteString
|
||||
|
||||
keyParser_ :: Parser (Int, Integer, Integer)
|
||||
keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP
|
||||
where
|
||||
intP = os2ip <$> base64P
|
||||
safePrivateKey :: (Int, Integer, Integer) -> SafePrivateKey
|
||||
safePrivateKey = SafePrivateKey . safeRsaPrivateKey
|
||||
|
||||
rsaPrivateKey :: PrivateKey -> R.PrivateKey
|
||||
rsaPrivateKey pk =
|
||||
safeRsaPrivateKey :: (Int, Integer, Integer) -> R.PrivateKey
|
||||
safeRsaPrivateKey (size, n, d) =
|
||||
R.PrivateKey
|
||||
{ R.private_pub =
|
||||
{ private_pub =
|
||||
R.PublicKey
|
||||
{ R.public_size = private_size pk,
|
||||
R.public_n = private_n pk,
|
||||
R.public_e = undefined
|
||||
{ public_size = size,
|
||||
public_n = n,
|
||||
public_e = 0
|
||||
},
|
||||
R.private_d = private_d pk,
|
||||
R.private_p = 0,
|
||||
R.private_q = 0,
|
||||
R.private_dP = undefined,
|
||||
R.private_dQ = undefined,
|
||||
R.private_qinv = undefined
|
||||
private_d = d,
|
||||
private_p = 0,
|
||||
private_q = 0,
|
||||
private_dP = 0,
|
||||
private_dQ = 0,
|
||||
private_qinv = 0
|
||||
}
|
||||
|
||||
encodePubKey :: PublicKey -> ByteString
|
||||
encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey
|
||||
|
||||
encodePrivKey :: PrivateKey k => k -> ByteString
|
||||
encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey
|
||||
|
||||
encodeKey :: ASN1Object a => a -> ByteString
|
||||
encodeKey k = toStrict . encodeASN1 DER $ toASN1 k []
|
||||
|
||||
decodePubKey :: ByteString -> Either String PublicKey
|
||||
decodePubKey =
|
||||
decodeKey >=> \case
|
||||
(PubKeyRSA k, []) -> Right $ PublicKey k
|
||||
r -> keyError r
|
||||
|
||||
decodePrivKey :: PrivateKey k => ByteString -> Either String k
|
||||
decodePrivKey =
|
||||
decodeKey >=> \case
|
||||
(PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk
|
||||
r -> keyError r
|
||||
|
||||
decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1])
|
||||
decodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict
|
||||
|
||||
keyError :: (a, [ASN1]) -> Either String b
|
||||
keyError = \case
|
||||
(_, []) -> Left "not RSA key"
|
||||
_ -> Left "more than one key"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Parsers where
|
||||
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
@@ -9,15 +11,35 @@ import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.ISO8601 (parseISO8601)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
base64P :: Parser ByteString
|
||||
base64P = do
|
||||
base64P = decode <$?> base64StringP
|
||||
|
||||
base64StringP :: Parser ByteString
|
||||
base64StringP = do
|
||||
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
|
||||
pad <- A.takeWhile (== '=')
|
||||
either fail pure $ decode (str <> pad)
|
||||
pure $ str <> pad
|
||||
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
|
||||
|
||||
parse :: Parser a -> e -> (ByteString -> Either e a)
|
||||
parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput)
|
||||
parse parser err = first (const err) . parseAll parser
|
||||
|
||||
parseAll :: Parser a -> (ByteString -> Either String a)
|
||||
parseAll parser = A.parseOnly (parser <* A.endOfInput)
|
||||
|
||||
parseRead :: Read a => Parser ByteString -> Parser a
|
||||
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
|
||||
|
||||
parseRead1 :: Read a => Parser a
|
||||
parseRead1 = parseRead $ A.takeTill (== ' ')
|
||||
|
||||
parseRead2 :: Read a => Parser a
|
||||
parseRead2 = parseRead $ do
|
||||
w1 <- A.takeTill (== ' ') <* A.char ' '
|
||||
w2 <- A.takeTill (== ' ')
|
||||
pure $ w1 <> " " <> w2
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
@@ -13,7 +14,7 @@ module Simplex.Messaging.Protocol where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Monad.Except
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Base64
|
||||
@@ -24,12 +25,13 @@ import Data.Kind
|
||||
import Data.String
|
||||
import Data.Time.Clock
|
||||
import Data.Time.ISO8601
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util
|
||||
import System.IO
|
||||
import Text.Read
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
|
||||
data Party = Broker | Recipient | Sender
|
||||
deriving (Show)
|
||||
@@ -95,12 +97,12 @@ instance IsString CorrId where
|
||||
fromString = CorrId . fromString
|
||||
|
||||
-- only used by Agent, kept here so its definition is close to respective public key
|
||||
type RecipientPrivateKey = C.PrivateKey
|
||||
type RecipientPrivateKey = C.SafePrivateKey
|
||||
|
||||
type RecipientPublicKey = C.PublicKey
|
||||
|
||||
-- only used by Agent, kept here so its definition is close to respective public key
|
||||
type SenderPrivateKey = C.PrivateKey
|
||||
type SenderPrivateKey = C.SafePrivateKey
|
||||
|
||||
type SenderPublicKey = C.PublicKey
|
||||
|
||||
@@ -108,25 +110,36 @@ type MsgId = Encoded
|
||||
|
||||
type MsgBody = ByteString
|
||||
|
||||
data ErrorType = PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
|
||||
data ErrorType
|
||||
= BLOCK
|
||||
| CMD CommandError
|
||||
| AUTH
|
||||
| NO_MSG
|
||||
| INTERNAL
|
||||
| DUPLICATE_ -- TODO remove, not part of SMP protocol
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
errBadTransmission :: Int
|
||||
errBadTransmission = 1
|
||||
data CommandError
|
||||
= PROHIBITED
|
||||
| SYNTAX
|
||||
| NO_AUTH
|
||||
| HAS_AUTH
|
||||
| NO_QUEUE
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
errBadSMPCommand :: Int
|
||||
errBadSMPCommand = 2
|
||||
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
errNoCredentials :: Int
|
||||
errNoCredentials = 3
|
||||
instance Arbitrary CommandError where arbitrary = genericArbitraryU
|
||||
|
||||
errHasCredentials :: Int
|
||||
errHasCredentials = 4
|
||||
|
||||
errNoQueueId :: Int
|
||||
errNoQueueId = 5
|
||||
|
||||
errMessageBody :: Int
|
||||
errMessageBody = 6
|
||||
transmissionP :: Parser RawTransmission
|
||||
transmissionP = do
|
||||
signature <- segment
|
||||
corrId <- segment
|
||||
queueId <- segment
|
||||
command <- A.takeByteString
|
||||
return (signature, corrId, queueId, command)
|
||||
where
|
||||
segment = A.takeTill (== ' ') <* " "
|
||||
|
||||
commandP :: Parser Cmd
|
||||
commandP =
|
||||
@@ -148,132 +161,110 @@ commandP =
|
||||
newCmd = Cmd SRecipient . NEW <$> C.pubKeyP
|
||||
idsResp = Cmd SBroker <$> (IDS <$> (base64P <* A.space) <*> base64P)
|
||||
keyCmd = Cmd SRecipient . KEY <$> C.pubKeyP
|
||||
sendCmd = Cmd SSender . SEND <$> A.takeWhile A.isDigit
|
||||
sendCmd = do
|
||||
size <- A.decimal <* A.space
|
||||
Cmd SSender . SEND <$> A.take size <* A.space
|
||||
message = do
|
||||
msgId <- base64P <* A.space
|
||||
ts <- tsISO8601P <* A.space
|
||||
Cmd SBroker . MSG msgId ts <$> A.takeWhile A.isDigit
|
||||
serverError = Cmd SBroker . ERR <$> errorType
|
||||
errorType =
|
||||
"PROHIBITED" $> PROHIBITED
|
||||
<|> "SYNTAX " *> (SYNTAX <$> A.decimal)
|
||||
<|> "SIZE" $> SIZE
|
||||
<|> "AUTH" $> AUTH
|
||||
<|> "INTERNAL" $> INTERNAL
|
||||
size <- A.decimal <* A.space
|
||||
Cmd SBroker . MSG msgId ts <$> A.take size <* A.space
|
||||
serverError = Cmd SBroker . ERR <$> errorTypeP
|
||||
|
||||
-- TODO ignore the end of block, no need to parse it
|
||||
parseCommand :: ByteString -> Either ErrorType Cmd
|
||||
parseCommand = parse commandP $ SYNTAX errBadSMPCommand
|
||||
parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX
|
||||
|
||||
serializeCommand :: Cmd -> ByteString
|
||||
serializeCommand = \case
|
||||
Cmd SRecipient (NEW rKey) -> "NEW " <> C.serializePubKey rKey
|
||||
Cmd SRecipient (KEY sKey) -> "KEY " <> C.serializePubKey sKey
|
||||
Cmd SRecipient cmd -> B.pack $ show cmd
|
||||
Cmd SSender (SEND msgBody) -> "SEND" <> serializeMsg msgBody
|
||||
Cmd SRecipient cmd -> bshow cmd
|
||||
Cmd SSender (SEND msgBody) -> "SEND " <> serializeMsg msgBody
|
||||
Cmd SSender PING -> "PING"
|
||||
Cmd SBroker (MSG msgId ts msgBody) ->
|
||||
B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts] <> serializeMsg msgBody
|
||||
B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody]
|
||||
Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId]
|
||||
Cmd SBroker (ERR err) -> "ERR " <> B.pack (show err)
|
||||
Cmd SBroker resp -> B.pack $ show resp
|
||||
Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err
|
||||
Cmd SBroker resp -> bshow resp
|
||||
where
|
||||
serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\r\n" <> msgBody
|
||||
serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " "
|
||||
|
||||
tPutRaw :: Handle -> RawTransmission -> IO ()
|
||||
tPutRaw h (signature, corrId, queueId, command) = do
|
||||
putLn h signature
|
||||
putLn h corrId
|
||||
putLn h queueId
|
||||
putLn h command
|
||||
errorTypeP :: Parser ErrorType
|
||||
errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1
|
||||
|
||||
tGetRaw :: Handle -> IO RawTransmission
|
||||
tGetRaw h = do
|
||||
signature <- getLn h
|
||||
corrId <- getLn h
|
||||
queueId <- getLn h
|
||||
command <- getLn h
|
||||
return (signature, corrId, queueId, command)
|
||||
serializeErrorType :: ErrorType -> ByteString
|
||||
serializeErrorType = bshow
|
||||
|
||||
tPut :: Handle -> SignedRawTransmission -> IO ()
|
||||
tPut h (C.Signature sig, t) = do
|
||||
putLn h $ encode sig
|
||||
putLn h t
|
||||
tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ())
|
||||
tPut th (C.Signature sig, t) =
|
||||
tPutEncrypted th $ encode sig <> " " <> t <> " "
|
||||
|
||||
serializeTransmission :: Transmission -> ByteString
|
||||
serializeTransmission (CorrId corrId, queueId, command) =
|
||||
B.intercalate "\r\n" [corrId, encode queueId, serializeCommand command]
|
||||
B.intercalate " " [corrId, encode queueId, serializeCommand command]
|
||||
|
||||
fromClient :: Cmd -> Either ErrorType Cmd
|
||||
fromClient = \case
|
||||
Cmd SBroker _ -> Left PROHIBITED
|
||||
Cmd SBroker _ -> Left $ CMD PROHIBITED
|
||||
cmd -> Right cmd
|
||||
|
||||
fromServer :: Cmd -> Either ErrorType Cmd
|
||||
fromServer = \case
|
||||
cmd@(Cmd SBroker _) -> Right cmd
|
||||
_ -> Left PROHIBITED
|
||||
_ -> Left $ CMD PROHIBITED
|
||||
|
||||
tGetParse :: THandle -> IO (Either TransportError RawTransmission)
|
||||
tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th
|
||||
|
||||
-- | get client and server transmissions
|
||||
-- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used
|
||||
tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> Handle -> m SignedTransmissionOrError
|
||||
tGet fromParty h = do
|
||||
(signature, corrId, queueId, command) <- liftIO $ tGetRaw h
|
||||
let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId)
|
||||
either (const $ tError corrId) tParseLoadBody decodedTransmission
|
||||
tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> THandle -> m SignedTransmissionOrError
|
||||
tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
where
|
||||
tError :: ByteString -> m SignedTransmissionOrError
|
||||
tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission))
|
||||
decodeParseValidate :: Either TransportError RawTransmission -> m SignedTransmissionOrError
|
||||
decodeParseValidate = \case
|
||||
Right (signature, corrId, queueId, command) ->
|
||||
let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId)
|
||||
in either (const $ tError corrId) tParseValidate decodedTransmission
|
||||
Left _ -> tError ""
|
||||
|
||||
tParseLoadBody :: RawTransmission -> m SignedTransmissionOrError
|
||||
tParseLoadBody t@(sig, corrId, queueId, command) = do
|
||||
tError :: ByteString -> m SignedTransmissionOrError
|
||||
tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left BLOCK))
|
||||
|
||||
tParseValidate :: RawTransmission -> m SignedTransmissionOrError
|
||||
tParseValidate t@(sig, corrId, queueId, command) = do
|
||||
let cmd = parseCommand command >>= fromParty >>= tCredentials t
|
||||
fullCmd <- either (return . Left) cmdWithMsgBody cmd
|
||||
return (C.Signature sig, (CorrId corrId, queueId, fullCmd))
|
||||
return (C.Signature sig, (CorrId corrId, queueId, cmd))
|
||||
|
||||
tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd
|
||||
tCredentials (signature, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response should not have queue ID
|
||||
-- IDS response must not have queue ID
|
||||
Cmd SBroker (IDS _ _) -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
Cmd SBroker (ERR _) -> Right cmd
|
||||
-- PONG response should not have queue ID
|
||||
-- PONG response must not have queue ID
|
||||
Cmd SBroker PONG
|
||||
| B.null queueId -> Right cmd
|
||||
| otherwise -> Left $ SYNTAX errHasCredentials
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other responses must have queue ID
|
||||
Cmd SBroker _
|
||||
| B.null queueId -> Left $ SYNTAX errNoQueueId
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
-- NEW must NOT have signature or queue ID
|
||||
-- NEW must have signature but NOT queue ID
|
||||
Cmd SRecipient (NEW _)
|
||||
| B.null signature -> Left $ SYNTAX errNoCredentials
|
||||
| not (B.null queueId) -> Left $ SYNTAX errHasCredentials
|
||||
| B.null signature -> Left $ CMD NO_AUTH
|
||||
| not (B.null queueId) -> Left $ CMD HAS_AUTH
|
||||
| otherwise -> Right cmd
|
||||
-- SEND must have queue ID, signature is not always required
|
||||
Cmd SSender (SEND _)
|
||||
| B.null queueId -> Left $ SYNTAX errNoQueueId
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| otherwise -> Right cmd
|
||||
-- PING must not have queue ID or signature
|
||||
Cmd SSender PING
|
||||
| B.null queueId && B.null signature -> Right cmd
|
||||
| otherwise -> Left $ SYNTAX errHasCredentials
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other client commands must have both signature and queue ID
|
||||
Cmd SRecipient _
|
||||
| B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials
|
||||
| B.null signature || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
cmdWithMsgBody :: Cmd -> m (Either ErrorType Cmd)
|
||||
cmdWithMsgBody = \case
|
||||
Cmd SSender (SEND sizeStr) ->
|
||||
Cmd SSender . SEND <$$> getMsgBody sizeStr
|
||||
Cmd SBroker (MSG msgId ts sizeStr) ->
|
||||
Cmd SBroker . MSG msgId ts <$$> getMsgBody sizeStr
|
||||
cmd -> return $ Right cmd
|
||||
|
||||
getMsgBody :: MsgBody -> m (Either ErrorType MsgBody)
|
||||
getMsgBody sizeStr = case readMaybe (B.unpack sizeStr) :: Maybe Int of
|
||||
Just size -> liftIO $ do
|
||||
body <- B.hGet h size
|
||||
s <- getLn h
|
||||
return $ if B.null s then Right body else Left SIZE
|
||||
Nothing -> return $ Left INTERNAL
|
||||
|
||||
@@ -15,6 +15,7 @@ module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, randomBytes
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
@@ -30,6 +31,7 @@ import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util
|
||||
import UnliftIO.Async
|
||||
@@ -50,6 +52,7 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do
|
||||
smpServer = do
|
||||
s <- asks server
|
||||
race_ (runTCPServer started tcpPort runClient) (serverThread s)
|
||||
`finally` withLog closeStoreLog
|
||||
|
||||
serverThread :: MonadUnliftIO m => Server -> m ()
|
||||
serverThread Server {subscribedQ, subscribers} = forever . atomically $ do
|
||||
@@ -62,11 +65,17 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do
|
||||
|
||||
runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m ()
|
||||
runClient h = do
|
||||
liftIO $ putLn h "Welcome to SMP v0.2.0"
|
||||
keyPair <- asks serverKeyPair
|
||||
liftIO (runExceptT $ serverHandshake h keyPair) >>= \case
|
||||
Right th -> runClientTransport th
|
||||
Left _ -> pure ()
|
||||
|
||||
runClientTransport :: (MonadUnliftIO m, MonadReader Env m) => THandle -> m ()
|
||||
runClientTransport th = do
|
||||
q <- asks $ tbqSize . config
|
||||
c <- atomically $ newClient q
|
||||
s <- asks server
|
||||
raceAny_ [send h c, client c s, receive h c]
|
||||
raceAny_ [send th c, client c s, receive th c]
|
||||
`finally` cancelSubscribers c
|
||||
|
||||
cancelSubscribers :: MonadUnliftIO m => Client -> m ()
|
||||
@@ -78,7 +87,7 @@ cancelSub = \case
|
||||
Sub {subThread = SubThread t} -> killThread t
|
||||
_ -> return ()
|
||||
|
||||
receive :: (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m ()
|
||||
receive :: (MonadUnliftIO m, MonadReader Env m) => THandle -> Client -> m ()
|
||||
receive h Client {rcvQ} = forever $ do
|
||||
(signature, (corrId, queueId, cmdOrError)) <- tGet fromClient h
|
||||
t <- case cmdOrError of
|
||||
@@ -86,7 +95,7 @@ receive h Client {rcvQ} = forever $ do
|
||||
Right cmd -> verifyTransmission (signature, (corrId, queueId, cmd))
|
||||
atomically $ writeTBQueue rcvQ t
|
||||
|
||||
send :: MonadUnliftIO m => Handle -> Client -> m ()
|
||||
send :: MonadUnliftIO m => THandle -> Client -> m ()
|
||||
send h Client {sndQ} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
liftIO $ tPut h ("", serializeTransmission t)
|
||||
@@ -99,25 +108,27 @@ verifyTransmission (sig, t@(corrId, queueId, cmd)) = do
|
||||
(corrId,queueId,) <$> case cmd of
|
||||
Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used
|
||||
Cmd SRecipient (NEW k) -> return $ verifySignature k
|
||||
Cmd SRecipient _ -> withQueueRec SRecipient $ verifySignature . recipientKey
|
||||
Cmd SSender (SEND _) -> withQueueRec SSender $ verifySend sig . senderKey
|
||||
Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey
|
||||
Cmd SSender (SEND _) -> verifyCmd SSender $ verifySend sig . senderKey
|
||||
Cmd SSender PING -> return cmd
|
||||
where
|
||||
withQueueRec :: SParty (p :: Party) -> (QueueRec -> Cmd) -> m Cmd
|
||||
withQueueRec party f = do
|
||||
verifyCmd :: SParty p -> (QueueRec -> Cmd) -> m Cmd
|
||||
verifyCmd party f = do
|
||||
(aKey, _) <- asks serverKeyPair -- any public key can be used to mitigate timing attack
|
||||
st <- asks queueStore
|
||||
qr <- atomically $ getQueue st party queueId
|
||||
return $ either smpErr f qr
|
||||
q <- atomically $ getQueue st party queueId
|
||||
pure $ either (const $ fakeVerify aKey) f q
|
||||
fakeVerify :: C.PublicKey -> Cmd
|
||||
fakeVerify aKey = if verify aKey then authErr else authErr
|
||||
verifySend :: C.Signature -> Maybe SenderPublicKey -> Cmd
|
||||
verifySend "" = maybe cmd (const authErr)
|
||||
verifySend _ = maybe authErr verifySignature
|
||||
verifySignature :: C.PublicKey -> Cmd
|
||||
verifySignature key =
|
||||
if C.verify key sig (serializeTransmission t)
|
||||
then cmd
|
||||
else authErr
|
||||
|
||||
smpErr e = Cmd SBroker $ ERR e
|
||||
verifySignature key = if verify key then cmd else authErr
|
||||
verify :: C.PublicKey -> Bool
|
||||
verify key = C.verify key sig (serializeTransmission t)
|
||||
smpErr :: ErrorType -> Cmd
|
||||
smpErr = Cmd SBroker . ERR
|
||||
authErr = smpErr AUTH
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
|
||||
@@ -140,8 +151,8 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
NEW rKey -> createQueue st rKey
|
||||
SUB -> subscribeQueue queueId
|
||||
ACK -> acknowledgeMsg
|
||||
KEY sKey -> okResp <$> atomically (secureQueue st queueId sKey)
|
||||
OFF -> okResp <$> atomically (suspendQueue st queueId)
|
||||
KEY sKey -> secureQueue_ st sKey
|
||||
OFF -> suspendQueue_ st
|
||||
DEL -> delQueueAndMsgs st
|
||||
where
|
||||
createQueue :: QueueStore -> RecipientPublicKey -> m Transmission
|
||||
@@ -151,22 +162,40 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
addSubscribe =
|
||||
addQueueRetry 3 >>= \case
|
||||
Left e -> return $ ERR e
|
||||
Right (rId, sId) -> subscribeQueue rId $> IDS rId sId
|
||||
Right (rId, sId) -> do
|
||||
withLog (`logCreateById` rId)
|
||||
subscribeQueue rId $> IDS rId sId
|
||||
|
||||
addQueueRetry :: Int -> m (Either ErrorType (RecipientId, SenderId))
|
||||
addQueueRetry 0 = return $ Left INTERNAL
|
||||
addQueueRetry n = do
|
||||
ids <- getIds
|
||||
atomically (addQueue st rKey ids) >>= \case
|
||||
Left DUPLICATE -> addQueueRetry $ n - 1
|
||||
Left DUPLICATE_ -> addQueueRetry $ n - 1
|
||||
Left e -> return $ Left e
|
||||
Right _ -> return $ Right ids
|
||||
|
||||
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
||||
logCreateById s rId =
|
||||
atomically (getQueue st SRecipient rId) >>= \case
|
||||
Right q -> logCreateQueue s q
|
||||
_ -> pure ()
|
||||
|
||||
getIds :: m (RecipientId, SenderId)
|
||||
getIds = do
|
||||
n <- asks $ queueIdBytes . config
|
||||
liftM2 (,) (randomId n) (randomId n)
|
||||
|
||||
secureQueue_ :: QueueStore -> SenderPublicKey -> m Transmission
|
||||
secureQueue_ st sKey = do
|
||||
withLog $ \s -> logSecureQueue s queueId sKey
|
||||
okResp <$> atomically (secureQueue st queueId sKey)
|
||||
|
||||
suspendQueue_ :: QueueStore -> m Transmission
|
||||
suspendQueue_ st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
okResp <$> atomically (suspendQueue st queueId)
|
||||
|
||||
subscribeQueue :: RecipientId -> m Transmission
|
||||
subscribeQueue rId =
|
||||
atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId
|
||||
@@ -193,7 +222,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s))
|
||||
>>= \case
|
||||
Just (Just s) -> deliverMessage tryDelPeekMsg queueId s
|
||||
_ -> return $ err PROHIBITED
|
||||
_ -> return $ err NO_MSG
|
||||
|
||||
withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a)
|
||||
withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId
|
||||
@@ -253,6 +282,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
|
||||
delQueueAndMsgs :: QueueStore -> m Transmission
|
||||
delQueueAndMsgs st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
ms <- asks msgStore
|
||||
atomically $
|
||||
deleteQueue st queueId >>= \case
|
||||
@@ -271,6 +301,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
|
||||
msgCmd :: Message -> Command 'Broker
|
||||
msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody
|
||||
|
||||
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
|
||||
withLog action = do
|
||||
env <- ask
|
||||
liftIO . mapM_ action $ storeLog (env :: Env)
|
||||
|
||||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded
|
||||
randomId n = do
|
||||
gVar <- asks idsDrg
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Server.Env.STM where
|
||||
|
||||
@@ -10,16 +12,23 @@ import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.QueueStore (QueueRec (..))
|
||||
import Simplex.Messaging.Server.QueueStore.STM
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import System.IO (IOMode (..))
|
||||
import UnliftIO.STM
|
||||
|
||||
data ServerConfig = ServerConfig
|
||||
{ tcpPort :: ServiceName,
|
||||
tbqSize :: Natural,
|
||||
queueIdBytes :: Int,
|
||||
msgIdBytes :: Int
|
||||
msgIdBytes :: Int,
|
||||
storeLog :: Maybe (StoreLog 'ReadMode),
|
||||
serverPrivateKey :: C.FullPrivateKey
|
||||
-- serverId :: ByteString
|
||||
}
|
||||
|
||||
data Env = Env
|
||||
@@ -27,7 +36,9 @@ data Env = Env
|
||||
server :: Server,
|
||||
queueStore :: QueueStore,
|
||||
msgStore :: STMMsgStore,
|
||||
idsDrg :: TVar ChaChaDRG
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
serverKeyPair :: C.FullKeyPair,
|
||||
storeLog :: Maybe (StoreLog 'WriteMode)
|
||||
}
|
||||
|
||||
data Server = Server
|
||||
@@ -66,10 +77,21 @@ newSubscription = do
|
||||
delivered <- newEmptyTMVar
|
||||
return Sub {subThread = NoSub, delivered}
|
||||
|
||||
newEnv :: (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
|
||||
newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
|
||||
newEnv config = do
|
||||
server <- atomically $ newServer (tbqSize config)
|
||||
queueStore <- atomically newQueueStore
|
||||
msgStore <- atomically newMsgStore
|
||||
idsDrg <- drgNew >>= newTVarIO
|
||||
return Env {config, server, queueStore, msgStore, idsDrg}
|
||||
s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig)
|
||||
let pk = serverPrivateKey config
|
||||
serverKeyPair = (C.publicKey pk, pk)
|
||||
return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s'}
|
||||
where
|
||||
restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode)
|
||||
restoreQueues queueStore s = do
|
||||
(queues, s') <- liftIO $ readWriteStoreLog s
|
||||
atomically $ modifyTVar queueStore $ \d -> d {queues, senders = M.foldr' addSender M.empty queues}
|
||||
pure s'
|
||||
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
|
||||
addSender q = M.insert (senderId q) (recipientId q)
|
||||
|
||||
@@ -15,7 +15,7 @@ data QueueRec = QueueRec
|
||||
status :: QueueStatus
|
||||
}
|
||||
|
||||
data QueueStatus = QueueActive | QueueOff
|
||||
data QueueStatus = QueueActive | QueueOff deriving (Eq)
|
||||
|
||||
class MonadQueueStore s m where
|
||||
addQueue :: s -> RecipientPublicKey -> (RecipientId, SenderId) -> m (Either ErrorType ())
|
||||
|
||||
@@ -32,7 +32,7 @@ instance MonadQueueStore QueueStore STM where
|
||||
addQueue store rKey ids@(rId, sId) = do
|
||||
cs@QueueStoreData {queues, senders} <- readTVar store
|
||||
if M.member rId queues || M.member sId senders
|
||||
then return $ Left DUPLICATE
|
||||
then return $ Left DUPLICATE_
|
||||
else do
|
||||
writeTVar store $
|
||||
cs
|
||||
|
||||
140
src/Simplex/Messaging/Server/StoreLog.hs
Normal file
140
src/Simplex/Messaging/Server/StoreLog.hs
Normal file
@@ -0,0 +1,140 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Server.StoreLog
|
||||
( StoreLog, -- constructors are not exported
|
||||
openWriteStoreLog,
|
||||
openReadStoreLog,
|
||||
closeStoreLog,
|
||||
logCreateQueue,
|
||||
logSecureQueue,
|
||||
logDeleteQueue,
|
||||
readWriteStoreLog,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Monad (unless)
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first, second)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import Data.Either (partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (foldl')
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Parsers (base64P, parseAll)
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.QueueStore (QueueRec (..), QueueStatus (..))
|
||||
import Simplex.Messaging.Transport (trimCR)
|
||||
import System.Directory (doesFileExist)
|
||||
import System.IO
|
||||
|
||||
-- | opaque container for file handle with a type-safe IOMode
|
||||
-- constructors are not exported, openWriteStoreLog and openReadStoreLog should be used instead
|
||||
data StoreLog (a :: IOMode) where
|
||||
ReadStoreLog :: FilePath -> Handle -> StoreLog 'ReadMode
|
||||
WriteStoreLog :: FilePath -> Handle -> StoreLog 'WriteMode
|
||||
|
||||
data StoreLogRecord
|
||||
= CreateQueue QueueRec
|
||||
| SecureQueue QueueId SenderPublicKey
|
||||
| DeleteQueue QueueId
|
||||
|
||||
storeLogRecordP :: Parser StoreLogRecord
|
||||
storeLogRecordP =
|
||||
"CREATE " *> createQueueP
|
||||
<|> "SECURE " *> secureQueueP
|
||||
<|> "DELETE " *> (DeleteQueue <$> base64P)
|
||||
where
|
||||
createQueueP = CreateQueue <$> queueRecP
|
||||
secureQueueP = SecureQueue <$> base64P <* A.space <*> C.pubKeyP
|
||||
queueRecP = do
|
||||
recipientId <- "rid=" *> base64P <* A.space
|
||||
senderId <- "sid=" *> base64P <* A.space
|
||||
recipientKey <- "rk=" *> C.pubKeyP <* A.space
|
||||
senderKey <- "sk=" *> optional C.pubKeyP
|
||||
pure QueueRec {recipientId, senderId, recipientKey, senderKey, status = QueueActive}
|
||||
|
||||
serializeStoreLogRecord :: StoreLogRecord -> ByteString
|
||||
serializeStoreLogRecord = \case
|
||||
CreateQueue q -> "CREATE " <> serializeQueue q
|
||||
SecureQueue rId sKey -> "SECURE " <> encode rId <> " " <> C.serializePubKey sKey
|
||||
DeleteQueue rId -> "DELETE " <> encode rId
|
||||
where
|
||||
serializeQueue QueueRec {recipientId, senderId, recipientKey, senderKey} =
|
||||
B.unwords
|
||||
[ "rid=" <> encode recipientId,
|
||||
"sid=" <> encode senderId,
|
||||
"rk=" <> C.serializePubKey recipientKey,
|
||||
"sk=" <> maybe "" C.serializePubKey senderKey
|
||||
]
|
||||
|
||||
openWriteStoreLog :: FilePath -> IO (StoreLog 'WriteMode)
|
||||
openWriteStoreLog f = WriteStoreLog f <$> openFile f WriteMode
|
||||
|
||||
openReadStoreLog :: FilePath -> IO (StoreLog 'ReadMode)
|
||||
openReadStoreLog f = do
|
||||
doesFileExist f >>= (`unless` writeFile f "")
|
||||
ReadStoreLog f <$> openFile f ReadMode
|
||||
|
||||
closeStoreLog :: StoreLog a -> IO ()
|
||||
closeStoreLog = \case
|
||||
WriteStoreLog _ h -> hClose h
|
||||
ReadStoreLog _ h -> hClose h
|
||||
|
||||
writeStoreLogRecord :: StoreLog 'WriteMode -> StoreLogRecord -> IO ()
|
||||
writeStoreLogRecord (WriteStoreLog _ h) r = do
|
||||
B.hPutStrLn h $ serializeStoreLogRecord r
|
||||
hFlush h
|
||||
|
||||
logCreateQueue :: StoreLog 'WriteMode -> QueueRec -> IO ()
|
||||
logCreateQueue s = writeStoreLogRecord s . CreateQueue
|
||||
|
||||
logSecureQueue :: StoreLog 'WriteMode -> QueueId -> SenderPublicKey -> IO ()
|
||||
logSecureQueue s qId sKey = writeStoreLogRecord s $ SecureQueue qId sKey
|
||||
|
||||
logDeleteQueue :: StoreLog 'WriteMode -> QueueId -> IO ()
|
||||
logDeleteQueue s = writeStoreLogRecord s . DeleteQueue
|
||||
|
||||
readWriteStoreLog :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode)
|
||||
readWriteStoreLog s@(ReadStoreLog f _) = do
|
||||
qs <- readQueues s
|
||||
closeStoreLog s
|
||||
s' <- openWriteStoreLog f
|
||||
writeQueues s' qs
|
||||
pure (qs, s')
|
||||
|
||||
writeQueues :: StoreLog 'WriteMode -> Map RecipientId QueueRec -> IO ()
|
||||
writeQueues s = mapM_ (writeStoreLogRecord s . CreateQueue) . M.filter active
|
||||
where
|
||||
active QueueRec {status} = status == QueueActive
|
||||
|
||||
type LogParsingError = (String, ByteString)
|
||||
|
||||
readQueues :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec)
|
||||
readQueues (ReadStoreLog _ h) = LB.hGetContents h >>= returnResult . procStoreLog
|
||||
where
|
||||
procStoreLog :: LB.ByteString -> ([LogParsingError], Map RecipientId QueueRec)
|
||||
procStoreLog = second (foldl' procLogRecord M.empty) . partitionEithers . map parseLogRecord . LB.lines
|
||||
returnResult :: ([LogParsingError], Map RecipientId QueueRec) -> IO (Map RecipientId QueueRec)
|
||||
returnResult (errs, res) = mapM_ printError errs $> res
|
||||
parseLogRecord :: LB.ByteString -> Either LogParsingError StoreLogRecord
|
||||
parseLogRecord = (\s -> first (,s) $ parseAll storeLogRecordP s) . trimCR . LB.toStrict
|
||||
procLogRecord :: Map RecipientId QueueRec -> StoreLogRecord -> Map RecipientId QueueRec
|
||||
procLogRecord m = \case
|
||||
CreateQueue q -> M.insert (recipientId q) q m
|
||||
SecureQueue qId sKey -> M.adjust (\q -> q {senderKey = Just sKey}) qId m
|
||||
DeleteQueue qId -> M.delete qId m
|
||||
printError :: LogParsingError -> IO ()
|
||||
printError (e, s) = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s
|
||||
@@ -1,29 +1,52 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Transport where
|
||||
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.Trans.Except (throwE)
|
||||
import Crypto.Cipher.Types (AuthTag)
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteArray (xor)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Word (Word32)
|
||||
import GHC.Generics (Generic)
|
||||
import GHC.IO.Exception (IOErrorType (..))
|
||||
import GHC.IO.Handle.Internals (ioe_EOF)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Socket
|
||||
import Network.Transport.Internal (decodeNum16, decodeNum32, encodeEnum16, encodeEnum32, encodeWord32)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Parsers (parse, parseAll, parseRead1)
|
||||
import Simplex.Messaging.Util (bshow, liftError)
|
||||
import System.IO
|
||||
import System.IO.Error
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Exception (Exception, IOException)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import qualified UnliftIO.IO as IO
|
||||
import UnliftIO.STM
|
||||
|
||||
-- * TCP transport
|
||||
|
||||
runTCPServer :: MonadUnliftIO m => TMVar Bool -> ServiceName -> (Handle -> m ()) -> m ()
|
||||
runTCPServer started port server = do
|
||||
clients <- newTVarIO S.empty
|
||||
@@ -33,7 +56,10 @@ runTCPServer started port server = do
|
||||
atomically . modifyTVar clients $ S.insert tid
|
||||
where
|
||||
closeServer :: TVar (Set ThreadId) -> Socket -> IO ()
|
||||
closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock
|
||||
closeServer clients sock = do
|
||||
readTVarIO clients >>= mapM_ killThread
|
||||
close sock
|
||||
void . atomically $ tryPutTMVar started False
|
||||
|
||||
startTCPServer :: TMVar Bool -> ServiceName -> IO Socket
|
||||
startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted
|
||||
@@ -59,9 +85,7 @@ runTCPClient host port client = do
|
||||
client h `E.finally` IO.hClose h
|
||||
|
||||
startTCPClient :: HostName -> ServiceName -> IO Handle
|
||||
startTCPClient host port =
|
||||
withSocketsDo $
|
||||
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion
|
||||
startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
|
||||
where
|
||||
err :: IOException
|
||||
err = mkIOError NoSuchThing "no address" Nothing Nothing
|
||||
@@ -71,9 +95,10 @@ startTCPClient host port =
|
||||
let hints = defaultHints {addrSocketType = Stream}
|
||||
in getAddrInfo (Just hints) (Just host) (Just port)
|
||||
|
||||
tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle)
|
||||
tryOpen (Left _) addr = E.try $ open addr
|
||||
tryOpen h _ = return h
|
||||
tryOpen :: IOException -> [AddrInfo] -> IO Handle
|
||||
tryOpen e [] = E.throwIO e
|
||||
tryOpen _ (addr : as) =
|
||||
E.try (open addr) >>= either (`tryOpen` as) pure
|
||||
|
||||
open :: AddrInfo -> IO Handle
|
||||
open addr = do
|
||||
@@ -93,7 +118,261 @@ putLn :: Handle -> ByteString -> IO ()
|
||||
putLn h = B.hPut h . (<> "\r\n")
|
||||
|
||||
getLn :: Handle -> IO ByteString
|
||||
getLn h = trim_cr <$> B.hGetLine h
|
||||
getLn h = trimCR <$> B.hGetLine h
|
||||
|
||||
trimCR :: ByteString -> ByteString
|
||||
trimCR "" = ""
|
||||
trimCR s = if B.last s == '\r' then B.init s else s
|
||||
|
||||
-- * Encrypted transport
|
||||
|
||||
data SMPVersion = SMPVersion Int Int Int Int
|
||||
deriving (Eq, Ord)
|
||||
|
||||
major :: SMPVersion -> (Int, Int)
|
||||
major (SMPVersion a b _ _) = (a, b)
|
||||
|
||||
currentSMPVersion :: SMPVersion
|
||||
currentSMPVersion = SMPVersion 0 2 0 0
|
||||
|
||||
serializeSMPVersion :: SMPVersion -> ByteString
|
||||
serializeSMPVersion (SMPVersion a b c d) = B.intercalate "." [bshow a, bshow b, bshow c, bshow d]
|
||||
|
||||
smpVersionP :: Parser SMPVersion
|
||||
smpVersionP =
|
||||
let ver = A.decimal <* A.char '.'
|
||||
in SMPVersion <$> ver <*> ver <*> ver <*> A.decimal
|
||||
|
||||
data THandle = THandle
|
||||
{ handle :: Handle,
|
||||
sndKey :: SessionKey,
|
||||
rcvKey :: SessionKey,
|
||||
blockSize :: Int
|
||||
}
|
||||
|
||||
data SessionKey = SessionKey
|
||||
{ aesKey :: C.Key,
|
||||
baseIV :: C.IV,
|
||||
counter :: TVar Word32
|
||||
}
|
||||
|
||||
data ClientHandshake = ClientHandshake
|
||||
{ blockSize :: Int,
|
||||
sndKey :: SessionKey,
|
||||
rcvKey :: SessionKey
|
||||
}
|
||||
|
||||
data TransportError
|
||||
= TEBadBlock
|
||||
| TEEncrypt
|
||||
| TEDecrypt
|
||||
| TEHandshake HandshakeError
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
data HandshakeError
|
||||
= ENCRYPT
|
||||
| DECRYPT
|
||||
| VERSION
|
||||
| RSA_KEY
|
||||
| HEADER
|
||||
| AES_KEYS
|
||||
| BAD_HASH
|
||||
| MAJOR_VERSION
|
||||
| TERMINATED
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance Arbitrary TransportError where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
|
||||
|
||||
transportErrorP :: Parser TransportError
|
||||
transportErrorP =
|
||||
"BLOCK" $> TEBadBlock
|
||||
<|> "AES_ENCRYPT" $> TEEncrypt
|
||||
<|> "AES_DECRYPT" $> TEDecrypt
|
||||
<|> TEHandshake <$> parseRead1
|
||||
|
||||
serializeTransportError :: TransportError -> ByteString
|
||||
serializeTransportError = \case
|
||||
TEEncrypt -> "AES_ENCRYPT"
|
||||
TEDecrypt -> "AES_DECRYPT"
|
||||
TEBadBlock -> "BLOCK"
|
||||
TEHandshake e -> bshow e
|
||||
|
||||
tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ())
|
||||
tPutEncrypted THandle {handle = h, sndKey, blockSize} block =
|
||||
encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case
|
||||
Left _ -> pure $ Left TEEncrypt
|
||||
Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg)
|
||||
|
||||
tGetEncrypted :: THandle -> IO (Either TransportError ByteString)
|
||||
tGetEncrypted THandle {handle = h, rcvKey, blockSize} =
|
||||
B.hGet h blockSize >>= decryptBlock rcvKey >>= \case
|
||||
Left _ -> pure $ Left TEDecrypt
|
||||
Right "" -> ioe_EOF
|
||||
Right msg -> pure $ Right msg
|
||||
|
||||
encryptBlock :: SessionKey -> Int -> ByteString -> IO (Either C.CryptoError (AuthTag, ByteString))
|
||||
encryptBlock k@SessionKey {aesKey} size block = do
|
||||
ivBytes <- makeNextIV k
|
||||
runExceptT $ C.encryptAES aesKey ivBytes size block
|
||||
|
||||
decryptBlock :: SessionKey -> ByteString -> IO (Either C.CryptoError ByteString)
|
||||
decryptBlock k@SessionKey {aesKey} block = do
|
||||
let (authTag, msg') = B.splitAt C.authTagSize block
|
||||
ivBytes <- makeNextIV k
|
||||
runExceptT $ C.decryptAES aesKey ivBytes msg' (C.bsToAuthTag authTag)
|
||||
|
||||
makeNextIV :: SessionKey -> IO C.IV
|
||||
makeNextIV SessionKey {baseIV, counter} = atomically $ do
|
||||
c <- readTVar counter
|
||||
writeTVar counter $ c + 1
|
||||
pure $ iv c
|
||||
where
|
||||
trim_cr "" = ""
|
||||
trim_cr s = if B.last s == '\r' then B.init s else s
|
||||
(start, rest) = B.splitAt 4 $ C.unIV baseIV
|
||||
iv c = C.IV $ (start `xor` encodeWord32 c) <> rest
|
||||
|
||||
-- | implements server transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption
|
||||
-- The numbers in function names refer to the steps in the document
|
||||
serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle
|
||||
serverHandshake h (k, pk) = do
|
||||
liftIO sendHeaderAndPublicKey_1
|
||||
encryptedKeys <- receiveEncryptedKeys_4
|
||||
-- TODO server currently ignores blockSize returned by the client
|
||||
-- this is reserved for future support of streams
|
||||
ClientHandshake {blockSize = _, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys
|
||||
th <- liftIO $ transportHandle h rcvKey sndKey transportBlockSize -- keys are swapped here
|
||||
sendWelcome_6 th
|
||||
pure th
|
||||
where
|
||||
sendHeaderAndPublicKey_1 :: IO ()
|
||||
sendHeaderAndPublicKey_1 = do
|
||||
let sKey = C.encodePubKey k
|
||||
header = ServerHeader {blockSize = transportBlockSize, keySize = B.length sKey}
|
||||
B.hPut h $ binaryServerHeader header <> sKey
|
||||
receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString
|
||||
receiveEncryptedKeys_4 =
|
||||
liftIO (B.hGet h $ C.publicKeySize k) >>= \case
|
||||
"" -> throwE $ TEHandshake TERMINATED
|
||||
ks -> pure ks
|
||||
decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO ClientHandshake
|
||||
decryptParseKeys_5 encKeys =
|
||||
liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys)
|
||||
>>= liftEither . parseClientHandshake
|
||||
sendWelcome_6 :: THandle -> ExceptT TransportError IO ()
|
||||
sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " "
|
||||
|
||||
-- | implements client transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption
|
||||
-- The numbers in function names refer to the steps in the document
|
||||
clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle
|
||||
clientHandshake h keyHash = do
|
||||
(k, blkSize) <- getHeaderAndPublicKey_1_2
|
||||
-- TODO currently client always uses the blkSize returned by the server
|
||||
keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize
|
||||
sendEncryptedKeys_4 k keys
|
||||
th <- liftIO $ transportHandle h sndKey rcvKey blkSize
|
||||
getWelcome_6 th >>= checkVersion
|
||||
pure th
|
||||
where
|
||||
getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int)
|
||||
getHeaderAndPublicKey_1_2 = do
|
||||
header <- liftIO (B.hGet h serverHeaderSize)
|
||||
ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header
|
||||
when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $
|
||||
throwError $ TEHandshake HEADER
|
||||
s <- liftIO $ B.hGet h keySize
|
||||
maybe (pure ()) (validateKeyHash_2 s) keyHash
|
||||
key <- liftEither $ parseKey s
|
||||
pure (key, blockSize)
|
||||
parseKey :: ByteString -> Either TransportError C.PublicKey
|
||||
parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.binaryPubKeyP
|
||||
validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO ()
|
||||
validateKeyHash_2 k kHash
|
||||
| C.getKeyHash k == kHash = pure ()
|
||||
| otherwise = throwE $ TEHandshake BAD_HASH
|
||||
generateKeys_3 :: Int -> IO ClientHandshake
|
||||
generateKeys_3 blkSize = ClientHandshake blkSize <$> generateKey <*> generateKey
|
||||
generateKey :: IO SessionKey
|
||||
generateKey = do
|
||||
aesKey <- C.randomAesKey
|
||||
baseIV <- C.randomIV
|
||||
pure SessionKey {aesKey, baseIV, counter = undefined}
|
||||
sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO ()
|
||||
sendEncryptedKeys_4 k keys =
|
||||
liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys)
|
||||
>>= liftIO . B.hPut h
|
||||
getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion
|
||||
getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th
|
||||
parseSMPVersion :: ByteString -> Either TransportError SMPVersion
|
||||
parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space)
|
||||
checkVersion :: SMPVersion -> ExceptT TransportError IO ()
|
||||
checkVersion smpVersion =
|
||||
when (major smpVersion > major currentSMPVersion) . throwE $
|
||||
TEHandshake MAJOR_VERSION
|
||||
|
||||
data ServerHeader = ServerHeader {blockSize :: Int, keySize :: Int}
|
||||
deriving (Eq, Show)
|
||||
|
||||
binaryRsaTransport :: Int
|
||||
binaryRsaTransport = 0
|
||||
|
||||
transportBlockSize :: Int
|
||||
transportBlockSize = 4096
|
||||
|
||||
maxTransportBlockSize :: Int
|
||||
maxTransportBlockSize = 65536
|
||||
|
||||
serverHeaderSize :: Int
|
||||
serverHeaderSize = 8
|
||||
|
||||
binaryServerHeader :: ServerHeader -> ByteString
|
||||
binaryServerHeader ServerHeader {blockSize, keySize} =
|
||||
encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize
|
||||
|
||||
serverHeaderP :: Parser ServerHeader
|
||||
serverHeaderP = ServerHeader <$> int32 <* binaryRsaTransportP <*> int16
|
||||
|
||||
serializeClientHandshake :: ClientHandshake -> ByteString
|
||||
serializeClientHandshake ClientHandshake {blockSize, sndKey, rcvKey} =
|
||||
encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> serializeKey sndKey <> serializeKey rcvKey
|
||||
where
|
||||
serializeKey :: SessionKey -> ByteString
|
||||
serializeKey SessionKey {aesKey, baseIV} = C.unKey aesKey <> C.unIV baseIV
|
||||
|
||||
clientHandshakeP :: Parser ClientHandshake
|
||||
clientHandshakeP = ClientHandshake <$> int32 <* binaryRsaTransportP <*> keyP <*> keyP
|
||||
where
|
||||
keyP :: Parser SessionKey
|
||||
keyP = do
|
||||
aesKey <- C.aesKeyP
|
||||
baseIV <- C.ivP
|
||||
pure SessionKey {aesKey, baseIV, counter = undefined}
|
||||
|
||||
int32 :: Parser Int
|
||||
int32 = decodeNum32 <$> A.take 4
|
||||
|
||||
int16 :: Parser Int
|
||||
int16 = decodeNum16 <$> A.take 2
|
||||
|
||||
binaryRsaTransportP :: Parser ()
|
||||
binaryRsaTransportP = binaryRsa =<< int16
|
||||
where
|
||||
binaryRsa :: Int -> Parser ()
|
||||
binaryRsa n
|
||||
| n == binaryRsaTransport = pure ()
|
||||
| otherwise = fail "unknown transport mode"
|
||||
|
||||
parseClientHandshake :: ByteString -> Either TransportError ClientHandshake
|
||||
parseClientHandshake = parse clientHandshakeP $ TEHandshake AES_KEYS
|
||||
|
||||
transportHandle :: Handle -> SessionKey -> SessionKey -> Int -> IO THandle
|
||||
transportHandle h sk rk blockSize = do
|
||||
sndCounter <- newTVarIO 0
|
||||
rcvCounter <- newTVarIO 0
|
||||
pure
|
||||
THandle
|
||||
{ handle = h,
|
||||
sndKey = sk {counter = sndCounter},
|
||||
rcvKey = rk {counter = rcvCounter},
|
||||
blockSize
|
||||
}
|
||||
|
||||
@@ -31,19 +31,22 @@ raceAny_ = r []
|
||||
r as (m : ms) = withAsync m $ \a -> r (a : as) ms
|
||||
r as [] = void $ waitAnyCancel as
|
||||
|
||||
infixl 4 <$$>
|
||||
infixl 4 <$$>, <$?>
|
||||
|
||||
(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
|
||||
(<$$>) = fmap . fmap
|
||||
|
||||
(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b
|
||||
f <$?> m = m >>= either fail pure . f
|
||||
|
||||
bshow :: Show a => a -> ByteString
|
||||
bshow = B.pack . show
|
||||
|
||||
liftIOEither :: (MonadUnliftIO m, MonadError e m) => IO (Either e a) -> m a
|
||||
liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a
|
||||
liftIOEither a = liftIO a >>= liftEither
|
||||
|
||||
liftError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
|
||||
liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
|
||||
liftError f = liftEitherError f . runExceptT
|
||||
|
||||
liftEitherError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
|
||||
liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
|
||||
liftEitherError f a = liftIOEither (first f <$> a)
|
||||
|
||||
98
src/Simplex/Messaging/errors.md
Normal file
98
src/Simplex/Messaging/errors.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# Errors
|
||||
|
||||
## Problems
|
||||
|
||||
- using numbers and strings to indicate errors (in protocol and in code) - ErrorType, AgentErrorType, TransportError
|
||||
- re-using the same type in multiple contexts (with some constructors not applicable to all contexts) - ErrorType
|
||||
|
||||
## Error types
|
||||
|
||||
### ErrorType (Protocol.hs)
|
||||
|
||||
- BLOCK - incorrect block format or encoding
|
||||
- CMD error - command is unknown or has invalid syntax, where `error` can be:
|
||||
- PROHIBITED - server response sent from client or vice versa
|
||||
- SYNTAX - error parsing command
|
||||
- NO_AUTH - transmission has no required credentials (signature or queue ID)
|
||||
- HAS_AUTH - transmission has not allowed credentials
|
||||
- NO_QUEUE - transmission has not queue ID
|
||||
- AUTH - command is not authorised (queue does not exist or signature verification failed).
|
||||
- NO_MSG - acknowledging (ACK) the message without message
|
||||
- INTERNAL - internal server error.
|
||||
- DUPLICATE_ - it is used internally to signal that the queue ID is already used. This is NOT used in the protocol, instead INTERNAL is sent to the client. It has to be removed.
|
||||
|
||||
### AgentErrorType (Agent/Transmission.hs)
|
||||
|
||||
Some of these errors are not correctly serialized/parsed - see line 322 in Agent/Transmission.hs
|
||||
|
||||
- CMD e - command or response error
|
||||
- PROHIBITED - server response sent as client command (and vice versa)
|
||||
- SYNTAX - command is unknown or has invalid syntax.
|
||||
- NO_CONN - connection is required in the command (and absent)
|
||||
- SIZE - incorrect message size of messages (when parsing SEND and MSG)
|
||||
- LARGE -- message does not fit SMP block
|
||||
- CONN e - connection errors
|
||||
- UNKNOWN - connection alias not in database
|
||||
- DUPLICATE - connection alias already exists
|
||||
- SIMPLEX - connection is simplex, but operation requires another queue
|
||||
- SMP ErrorType - forwarding SMP errors (SMPServerError) to the agent client
|
||||
- BROKER e - SMP server errors
|
||||
- RESPONSE ErrorType - invalid SMP server response
|
||||
- UNEXPECTED - unexpected response
|
||||
- NETWORK - network TCP connection error
|
||||
- TRANSPORT TransportError -- handshake or other transport error
|
||||
- TIMEOUT - command response timeout
|
||||
- AGENT e - errors of other agents
|
||||
- A_MESSAGE - SMP message failed to parse
|
||||
- A_PROHIBITED - SMP message is prohibited with the current queue status
|
||||
- A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header
|
||||
- A_SIGNATURE - invalid RSA signature
|
||||
- INTERNAL ByteString - agent implementation or dependency error
|
||||
|
||||
### SMPClientError (Client.hs)
|
||||
|
||||
- SMPServerError ErrorType - this is correctly parsed server ERR response. This error is forwarded to the agent client as `ERR SMP err`
|
||||
- SMPResponseError ErrorType - this is invalid server response that failed to parse - forwarded to the client as `ERR BROKER RESPONSE`.
|
||||
- SMPUnexpectedResponse - different response from what is expected to a given command, e.g. server should respond `IDS` or `ERR` to `NEW` command, other responses would result in this error - forwarded to the client as `ERR BROKER UNEXPECTED`.
|
||||
- SMPResponseTimeout - used for TCP connection and command response timeouts -> `ERR BROKER TIMEOUT`.
|
||||
- SMPNetworkError - fails to establish TCP connection -> `ERR BROKER NETWORK`
|
||||
- SMPTransportError e - fails connection handshake or some other transport error -> `ERR BROKER TRANSPORT e`
|
||||
- SMPSignatureError C.CryptoError - error when cryptographically "signing" the command.
|
||||
|
||||
### StoreError (Agent/Store.hs)
|
||||
|
||||
- SEInternal ByteString - signals exceptions in store actions.
|
||||
- SEConnNotFound - connection alias not found (or both queues absent).
|
||||
- SEConnDuplicate - connection alias already used.
|
||||
- SEBadConnType ConnType - wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. `updateRcvConnWithSndQueue` and `updateSndConnWithRcvQueue` do not allow duplex connections - they would also return this error.
|
||||
- SEBadQueueStatus - the intention was to pass current expected queue status in methods, as we always know what it should be at any stage of the protocol, and in case it does not match use this error. **Currently not used**.
|
||||
- SENotImplemented - used in `getMsg` that is not implemented/used.
|
||||
|
||||
### CryptoError (Crypto.hs)
|
||||
|
||||
- RSAEncryptError R.Error - RSA encryption error
|
||||
- RSADecryptError R.Error - RSA decryption error
|
||||
- RSASignError R.Error - RSA signature error
|
||||
- AESCipherError CE.CryptoError - AES initialization error
|
||||
- CryptoIVError - IV generation error
|
||||
- AESDecryptError - AES decryption error
|
||||
- CryptoLargeMsgError - message does not fit in SMP block
|
||||
- CryptoHeaderError String - failure parsing RSA-encrypted message header
|
||||
|
||||
### TransportError (Transport.hs)
|
||||
|
||||
- TEBadBlock - error parsing block
|
||||
- TEEncrypt - block encryption error
|
||||
- TEDecrypt - block decryption error
|
||||
- TEHandshake HandshakeError
|
||||
|
||||
### HandshakeError (Transport.hs)
|
||||
|
||||
- ENCRYPT - encryption error
|
||||
- DECRYPT - decryption error
|
||||
- VERSION - error parsing protocol version
|
||||
- RSA_KEY - error parsing RSA key
|
||||
- AES_KEYS - error parsing AES keys
|
||||
- BAD_HASH - not matching RSA key hash
|
||||
- MAJOR_VERSION - lower agent version than protocol version
|
||||
- TERMINATED - transport terminated
|
||||
@@ -35,9 +35,10 @@ packages:
|
||||
# forks / in-progress versions pinned to a git hash. For example:
|
||||
#
|
||||
extra-deps:
|
||||
- sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002
|
||||
- cryptostore-0.2.1.0@sha256:9896e2984f36a1c8790f057fd5ce3da4cbcaf8aa73eb2d9277916886978c5b19,3881
|
||||
- direct-sqlite-2.3.26@sha256:04e835402f1508abca383182023e4e2b9b86297b8533afbd4e57d1a5652e0c23,3718
|
||||
- simple-logger-0.1.0@sha256:be8ede4bd251a9cac776533bae7fb643369ebd826eb948a9a18df1a8dd252ff8,1079
|
||||
- sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002
|
||||
- terminal-0.2.0.0@sha256:de6770ecaae3197c66ac1f0db5a80cf5a5b1d3b64a66a05b50f442de5ad39570,2977
|
||||
# - network-run-0.2.4@sha256:7dbb06def522dab413bce4a46af476820bffdff2071974736b06f52f4ab57c96,885
|
||||
# - git: https://github.com/commercialhaskell/stack.git
|
||||
|
||||
@@ -13,6 +13,7 @@ import Control.Concurrent
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import SMPAgentClient
|
||||
import SMPClient (testKeyHashStr)
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import System.IO (Handle)
|
||||
@@ -79,7 +80,7 @@ h #:# err = tryGet `shouldReturn` ()
|
||||
_ -> return ()
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg m_body <- MSG {m_body}
|
||||
pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk}
|
||||
|
||||
testDuplexConnection :: Handle -> Handle -> IO ()
|
||||
testDuplexConnection alice bob = do
|
||||
@@ -87,13 +88,13 @@ testDuplexConnection alice bob = do
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON)
|
||||
alice <# ("", "bob", CON)
|
||||
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False
|
||||
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False
|
||||
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False
|
||||
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
|
||||
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False
|
||||
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False
|
||||
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
|
||||
bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH))
|
||||
@@ -127,24 +128,24 @@ testSubscrNotification (server, _) client = do
|
||||
client <# ("", "conn1", END)
|
||||
|
||||
samplePublicKey :: ByteString
|
||||
samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDlaIClbK39KzPJMljcpnYb2KlSoZ51AhwF5PH2CS+FStc3QzajiqfdOQPet23Hd9YC6pqyTQ7idntqgPrE7yKJF44lUhKlq8QS9KQcbK7W6t7F9uQFw44ceWd2eVf81UV04kQdKWJvC5Sz6jtSZNEfs9mVI8H0wi1amUvS6+7EDJbxikhcCRnFShFO9dUKRYXj6L2JVqXqO5cZgY9BScyneWIg6mhhsTcdDbITM6COlL+pF1f3TjDN+slyV+IzE+ap/9NkpsrCcI8KwwDpqEDmUUV/JQfmQ==,gj2UAiWzSj7iun0iXvI5iz5WEjaqngmB3SzQ5+iarixbaG15LFDtYs3pijG3eGfB1wIFgoP4D2z97vIWn8olT4uCTUClf29zGDDve07h/B3QG/4i0IDnio7MX3AbE8O6PKouqy/GLTfT4WxFUn423g80rpsVYd5oj+SCL2eaxIc="
|
||||
samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
|
||||
|
||||
syntaxTests :: Spec
|
||||
syntaxTests = do
|
||||
it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR SYNTAX 11")
|
||||
it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX")
|
||||
describe "NEW" do
|
||||
describe "valid" do
|
||||
-- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided)
|
||||
-- TODO: add tests with defined connection alias
|
||||
xit "only server" $ ("211", "", "NEW localhost") >#>= \case ("211", "", "INV" : _) -> True; _ -> False
|
||||
it "with port" $ ("212", "", "NEW localhost:5000") >#>= \case ("212", "", "INV" : _) -> True; _ -> False
|
||||
xit "with keyHash" $ ("213", "", "NEW localhost#1234") >#>= \case ("213", "", "INV" : _) -> True; _ -> False
|
||||
it "with port and keyHash" $ ("214", "", "NEW localhost:5000#1234") >#>= \case ("214", "", "INV" : _) -> True; _ -> False
|
||||
xit "with keyHash" $ ("213", "", "NEW localhost#" <> testKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False
|
||||
it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> testKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False
|
||||
describe "invalid" do
|
||||
-- TODO: add tests with defined connection alias
|
||||
it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11")
|
||||
it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR SYNTAX 11")
|
||||
it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR SYNTAX 11")
|
||||
it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX")
|
||||
it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR CMD SYNTAX")
|
||||
it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR CMD SYNTAX")
|
||||
|
||||
describe "JOIN" do
|
||||
describe "valid" do
|
||||
@@ -154,4 +155,4 @@ syntaxTests = do
|
||||
("311", "", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "", "ERR SMP AUTH")
|
||||
describe "invalid" do
|
||||
-- TODO: JOIN is not merged yet - to be added
|
||||
it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR SYNTAX 11")
|
||||
it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX")
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
|
||||
module AgentTests.SQLiteTests (storeTests) where
|
||||
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import qualified Crypto.PubKey.RSA as R
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Time
|
||||
import Data.Word (Word32)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import SMPClient (testKeyHash)
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
@@ -48,41 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e
|
||||
-- TODO add null port tests
|
||||
storeTests :: Spec
|
||||
storeTests = withStore do
|
||||
describe "compiled as threadsafe" testCompiledThreadsafe
|
||||
describe "foreign keys enabled" testForeignKeysEnabled
|
||||
describe "store setup" do
|
||||
testCompiledThreadsafe
|
||||
testForeignKeysEnabled
|
||||
describe "store methods" do
|
||||
describe "createRcvConn" testCreateRcvConn
|
||||
describe "createSndConn" testCreateSndConn
|
||||
describe "getAllConnAliases" testGetAllConnAliases
|
||||
describe "getRcvQueue" testGetRcvQueue
|
||||
describe "deleteConn" do
|
||||
describe "RcvConnection" testDeleteRcvConn
|
||||
describe "SndConnection" testDeleteSndConn
|
||||
describe "DuplexConnection" testDeleteDuplexConn
|
||||
describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex
|
||||
describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex
|
||||
describe "set queue status" do
|
||||
describe "setRcvQueueStatus" testSetRcvQueueStatus
|
||||
describe "setSndQueueStatus" testSetSndQueueStatus
|
||||
describe "DuplexConnection" testSetQueueStatusDuplex
|
||||
xdescribe "RcvQueue doesn't exist" testSetRcvQueueStatusNoQueue
|
||||
xdescribe "SndQueue doesn't exist" testSetSndQueueStatusNoQueue
|
||||
describe "createRcvMsg" do
|
||||
describe "RcvQueue exists" testCreateRcvMsg
|
||||
describe "RcvQueue doesn't exist" testCreateRcvMsgNoQueue
|
||||
describe "createSndMsg" do
|
||||
describe "SndQueue exists" testCreateSndMsg
|
||||
describe "SndQueue doesn't exist" testCreateSndMsgNoQueue
|
||||
describe "Queue and Connection management" do
|
||||
describe "createRcvConn" do
|
||||
testCreateRcvConn
|
||||
testCreateRcvConnDuplicate
|
||||
describe "createSndConn" do
|
||||
testCreateSndConn
|
||||
testCreateSndConnDuplicate
|
||||
describe "getAllConnAliases" testGetAllConnAliases
|
||||
describe "getRcvQueue" testGetRcvQueue
|
||||
describe "deleteConn" do
|
||||
testDeleteRcvConn
|
||||
testDeleteSndConn
|
||||
testDeleteDuplexConn
|
||||
describe "upgradeRcvConnToDuplex" do
|
||||
testUpgradeRcvConnToDuplex
|
||||
describe "upgradeSndConnToDuplex" do
|
||||
testUpgradeSndConnToDuplex
|
||||
describe "set Queue status" do
|
||||
describe "setRcvQueueStatus" do
|
||||
testSetRcvQueueStatus
|
||||
testSetRcvQueueStatusNoQueue
|
||||
describe "setSndQueueStatus" do
|
||||
testSetSndQueueStatus
|
||||
testSetSndQueueStatusNoQueue
|
||||
testSetQueueStatusDuplex
|
||||
describe "Msg management" do
|
||||
describe "create Msg" do
|
||||
testCreateRcvMsg
|
||||
testCreateSndMsg
|
||||
testCreateRcvAndSndMsgs
|
||||
|
||||
testCompiledThreadsafe :: SpecWith SQLiteStore
|
||||
testCompiledThreadsafe = do
|
||||
it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do
|
||||
it "compiled sqlite library should be threadsafe" $ \store -> do
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
|
||||
compileOptions `shouldNotContain` [["THREADSAFE=0"]]
|
||||
|
||||
testForeignKeysEnabled :: SpecWith SQLiteStore
|
||||
testForeignKeysEnabled = do
|
||||
it "should throw error if foreign keys are enabled" $ \store -> do
|
||||
it "foreign keys should be enabled" $ \store -> do
|
||||
let inconsistentQuery =
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
@@ -96,13 +109,13 @@ testForeignKeysEnabled = do
|
||||
rcvQueue1 :: RcvQueue
|
||||
rcvQueue1 =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
rcvId = "1234",
|
||||
connAlias = "conn1",
|
||||
rcvPrivateKey = C.PrivateKey 1 2 3,
|
||||
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
sndId = Just "2345",
|
||||
sndKey = Nothing,
|
||||
decryptKey = C.PrivateKey 1 2 3,
|
||||
decryptKey = C.safePrivateKey (1, 2, 3),
|
||||
verifyKey = Nothing,
|
||||
status = New
|
||||
}
|
||||
@@ -110,12 +123,12 @@ rcvQueue1 =
|
||||
sndQueue1 :: SndQueue
|
||||
sndQueue1 =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
sndId = "3456",
|
||||
connAlias = "conn1",
|
||||
sndPrivateKey = C.PrivateKey 1 2 3,
|
||||
sndPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
|
||||
signKey = C.PrivateKey 1 2 3,
|
||||
signKey = C.safePrivateKey (1, 2, 3),
|
||||
status = New
|
||||
}
|
||||
|
||||
@@ -131,6 +144,13 @@ testCreateRcvConn = do
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateRcvConnDuplicate = do
|
||||
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
createRcvConn store rcvQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
testCreateSndConn :: SpecWith SQLiteStore
|
||||
testCreateSndConn = do
|
||||
it "should create SndConnection and add RcvQueue" $ \store -> do
|
||||
@@ -143,118 +163,113 @@ testCreateSndConn = do
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateSndConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateSndConnDuplicate = do
|
||||
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
createSndConn store sndQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
testGetAllConnAliases :: SpecWith SQLiteStore
|
||||
testGetAllConnAliases = do
|
||||
it "should get all conn aliases" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
createSndConn store sndQueue1 {connAlias = "conn2"}
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"}
|
||||
getAllConnAliases store
|
||||
`returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias]
|
||||
|
||||
testGetRcvQueue :: SpecWith SQLiteStore
|
||||
testGetRcvQueue = do
|
||||
it "should get RcvQueue" $ \store -> do
|
||||
let smpServer = SMPServer "smp.simplex.im" (Just "5223") (Just "1234")
|
||||
let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash
|
||||
let recipientId = "1234"
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getRcvQueue store smpServer recipientId
|
||||
`returnsResult` rcvQueue1
|
||||
|
||||
testDeleteRcvConn :: SpecWith SQLiteStore
|
||||
testDeleteRcvConn = do
|
||||
it "should create RcvConnection and delete it" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEBadConn
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
testDeleteSndConn :: SpecWith SQLiteStore
|
||||
testDeleteSndConn = do
|
||||
it "should create SndConnection and delete it" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEBadConn
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
testDeleteDuplexConn :: SpecWith SQLiteStore
|
||||
testDeleteDuplexConn = do
|
||||
it "should create DuplexConnection and delete it" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEBadConn
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeRcvConnToDuplex = do
|
||||
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
sndId = "2345",
|
||||
connAlias = "conn1",
|
||||
sndPrivateKey = C.PrivateKey 1 2 3,
|
||||
sndPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
|
||||
signKey = C.PrivateKey 1 2 3,
|
||||
signKey = C.safePrivateKey (1, 2, 3),
|
||||
status = New
|
||||
}
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CSnd
|
||||
upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
|
||||
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeSndConnToDuplex = do
|
||||
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
rcvId = "3456",
|
||||
connAlias = "conn1",
|
||||
rcvPrivateKey = C.PrivateKey 1 2 3,
|
||||
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
sndId = Just "4567",
|
||||
sndKey = Nothing,
|
||||
decryptKey = C.PrivateKey 1 2 3,
|
||||
decryptKey = C.safePrivateKey (1, 2, 3),
|
||||
verifyKey = Nothing,
|
||||
status = New
|
||||
}
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CRcv
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
|
||||
testSetRcvQueueStatus :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatus = do
|
||||
it "should update status of RcvQueue" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
@@ -265,8 +280,7 @@ testSetRcvQueueStatus = do
|
||||
testSetSndQueueStatus :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatus = do
|
||||
it "should update status of SndQueue" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
@@ -277,10 +291,8 @@ testSetSndQueueStatus = do
|
||||
testSetQueueStatusDuplex :: SpecWith SQLiteStore
|
||||
testSetQueueStatusDuplex = do
|
||||
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Secured
|
||||
@@ -290,61 +302,88 @@ testSetQueueStatusDuplex = do
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn
|
||||
SCDuplex
|
||||
( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}
|
||||
)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
|
||||
|
||||
testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatusNoQueue = do
|
||||
it "should throw error on attempt to update status of nonexistent RcvQueue" $ \store -> do
|
||||
xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
`throwsError` SEInternal
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatusNoQueue = do
|
||||
it "should throw error on attempt to update status of nonexistent SndQueue" $ \store -> do
|
||||
xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`throwsError` SEInternal
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
hw :: ByteString
|
||||
hw = encodeUtf8 "Hello world!"
|
||||
|
||||
ts :: UTCTime
|
||||
ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
|
||||
mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData
|
||||
mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
internalTs = ts,
|
||||
senderMeta = (externalSndId, ts),
|
||||
brokerMeta = (brokerId, ts),
|
||||
msgBody = hw,
|
||||
internalHash,
|
||||
externalPrevSndHash = "hash_from_sender",
|
||||
msgIntegrity = MsgOk
|
||||
}
|
||||
|
||||
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do
|
||||
updateRcvIds store rcvQueue
|
||||
`returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg store rcvQueue rcvMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateRcvMsg :: SpecWith SQLiteStore
|
||||
testCreateRcvMsg = do
|
||||
it "should create a RcvMsg and return InternalId" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
it "should reserve internal ids and create a RcvMsg" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
-- TODO getMsg to check message
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`returnsResult` InternalId 1
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
|
||||
testCreateRcvMsgNoQueue :: SpecWith SQLiteStore
|
||||
testCreateRcvMsgNoQueue = do
|
||||
it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`throwsError` SEBadConn
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`throwsError` SEBadConnType CSnd
|
||||
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
|
||||
mkSndMsgData internalId internalSndId internalHash =
|
||||
SndMsgData
|
||||
{ internalId,
|
||||
internalSndId,
|
||||
internalTs = ts,
|
||||
msgBody = hw,
|
||||
internalHash
|
||||
}
|
||||
|
||||
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation
|
||||
testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds store sndQueue
|
||||
`returnsResult` (internalId, internalSndId, expectedPrevHash)
|
||||
createSndMsg store sndQueue sndMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateSndMsg :: SpecWith SQLiteStore
|
||||
testCreateSndMsg = do
|
||||
it "should create a SndMsg and return InternalId" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
-- TODO getMsg to check message
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`returnsResult` InternalId 1
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
|
||||
testCreateSndMsgNoQueue :: SpecWith SQLiteStore
|
||||
testCreateSndMsgNoQueue = do
|
||||
it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`throwsError` SEBadConn
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`throwsError` SEBadConnType CRcv
|
||||
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
|
||||
testCreateRcvAndSndMsgs = do
|
||||
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
|
||||
18
tests/ProtocolErrorTests.hs
Normal file
18
tests/ProtocolErrorTests.hs
Normal file
@@ -0,0 +1,18 @@
|
||||
module ProtocolErrorTests where
|
||||
|
||||
import Simplex.Messaging.Agent.Transmission (AgentErrorType, agentErrorTypeP, serializeAgentError)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol (ErrorType, errorTypeP, serializeErrorType)
|
||||
import Test.Hspec
|
||||
import Test.Hspec.QuickCheck (modifyMaxSuccess)
|
||||
import Test.QuickCheck
|
||||
|
||||
protocolErrorTests :: Spec
|
||||
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
|
||||
describe "errors parsing / serializing" $ do
|
||||
it "should parse SMP protocol errors" . property $ \err ->
|
||||
parseAll errorTypeP (serializeErrorType err)
|
||||
== Right (err :: ErrorType)
|
||||
it "should parse SMP agent errors" . property $ \err ->
|
||||
parseAll agentErrorTypeP (serializeAgentError err)
|
||||
== Right (err :: AgentErrorType)
|
||||
@@ -6,23 +6,19 @@
|
||||
|
||||
module SMPAgentClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerThreadOn)
|
||||
import SMPClient (serverBracket, testPort, withSmpServer, withSmpServerThreadOn)
|
||||
import Simplex.Messaging.Agent (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig)
|
||||
import Simplex.Messaging.Transport
|
||||
import System.Timeout (timeout)
|
||||
import Test.Hspec
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Directory
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar)
|
||||
|
||||
agentTestHost :: HostName
|
||||
agentTestHost = "localhost"
|
||||
@@ -125,12 +121,10 @@ cfg =
|
||||
}
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn (port', db') f = do
|
||||
started <- newEmptyTMVarIO
|
||||
E.bracket
|
||||
(forkIOWithUnmask ($ runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'}))
|
||||
(liftIO . killThread >=> const (removeFile db'))
|
||||
\x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x
|
||||
withSmpAgentThreadOn (port', db') =
|
||||
serverBracket
|
||||
(\started -> runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'})
|
||||
(removeFile db')
|
||||
|
||||
withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a
|
||||
withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module SMPClient where
|
||||
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Network.Socket
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server (runSMPServerBlocking)
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.StoreLog (openReadStoreLog)
|
||||
import Simplex.Messaging.Transport
|
||||
import System.Timeout (timeout)
|
||||
import Test.Hspec
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar)
|
||||
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
|
||||
import UnliftIO.Timeout (timeout)
|
||||
|
||||
testHost :: HostName
|
||||
testHost = "localhost"
|
||||
@@ -25,13 +32,21 @@ testHost = "localhost"
|
||||
testPort :: ServiceName
|
||||
testPort = "5000"
|
||||
|
||||
testSMPClient :: MonadUnliftIO m => (Handle -> m a) -> m a
|
||||
testSMPClient client = do
|
||||
runTCPClient testHost testPort $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
if line == "Welcome to SMP v0.2.0"
|
||||
then client h
|
||||
else error "not connected"
|
||||
testKeyHashStr :: B.ByteString
|
||||
testKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="
|
||||
|
||||
testKeyHash :: Maybe C.KeyHash
|
||||
testKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="
|
||||
|
||||
testStoreLogFile :: FilePath
|
||||
testStoreLogFile = "tests/tmp/smp-server-store.log"
|
||||
|
||||
testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a
|
||||
testSMPClient client =
|
||||
runTCPClient testHost testPort $ \h ->
|
||||
liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
cfg :: ServerConfig
|
||||
cfg =
|
||||
@@ -39,16 +54,66 @@ cfg =
|
||||
{ tcpPort = testPort,
|
||||
tbqSize = 1,
|
||||
queueIdBytes = 12,
|
||||
msgIdBytes = 6
|
||||
msgIdBytes = 6,
|
||||
storeLog = Nothing,
|
||||
serverPrivateKey =
|
||||
-- full RSA private key (only for tests)
|
||||
"MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\
|
||||
\kaJ8chEhrtaUgXeSWGooWwqjXEUQE6RVbCC6QVo9VEBSP4xFwVVd9Fj7OsgfcXXh\
|
||||
\AqWxfctDcBZQ5jTUiJpdBc+Vz2ZkumVNl0W+j9kWm9nfkMLQj8c0cVSDxz4OKpZb\
|
||||
\qFuj0uzHkis7e7wsrKSKWLPg3M5ZXPZM1m9qn7SfJzDRDfJifamxWI7uz9XK2+Dp\
|
||||
\NkUQlGQgFJEv1cKN88JAwIqZ1s+TAQMQiB+4QZ2aNfSqGEzRJN7FMCKRK7pM0A9A\
|
||||
\PCnijyuImvKFxTdk8Bx1q+XNJzsY6fBrLWJZ+QKBgQCySG4tzlcEm+tOVWRcwrWh\
|
||||
\6zsczGZp9mbf9c8itRx6dlldSYuDG1qnddL70wuAZF2AgS1JZgvcRZECoZRoWP5q\
|
||||
\Kq2wvpTIYjFPpC39lxgUoA/DXKVKZZdan+gwaVPAPT54my1CS32VrOiAY4gVJ3LJ\
|
||||
\Mn1/FqZXUFQA326pau3loQKCAQEAoljmJMp88EZoy3HlHUbOjl5UEhzzVsU1TnQi\
|
||||
\QmPm+aWRe2qelhjW4aTvSVE5mAUJsN6UWTeMf4uvM69Z9I5pfw2pEm8x4+GxRibY\
|
||||
\iiwF2QNaLxxmzEHm1zQQPTgb39o8mgklhzFPill0JsnL3f6IkVwjFJofWSmpqEGs\
|
||||
\dFSMRSXUTVXh1p/o7QZrhpwO/475iWKVS7o48N/0Xp513re3aXw+DRNuVnFEaBIe\
|
||||
\TLvWM9Czn16ndAu1HYiTBuMvtRbAWnGZxU8ewzF4wlWK5tdIL5PTJDd1VhZJAKtB\
|
||||
\npDvJpwxzKmjAhcTmjx0ckMIWtdVaOVm/2gWCXDty2FEdg7koQKBgQDOUUguJ/i7\
|
||||
\q0jldWYRnVkotKnpInPdcEaodrehfOqYEHnvro9xlS6OeAS4Vz5AdH45zQ/4J3bV\
|
||||
\2cH66tNr18ebM9nL//t5G69i89R9W7szyUxCI3LmAIdi3oSEbmz5GQBaw4l6h9Wi\
|
||||
\n4FmFQaAXZrjQfO2qJcAHvWRsMp2pmqAGwKBgQDXaza0DRsKWywWznsHcmHa0cx8\
|
||||
\I4jxqGaQmLO7wBJRP1NSFrywy1QfYrVX9CTLBK4V3F0PCgZ01Qv94751CzN43TgF\
|
||||
\ebd/O9r5NjNTnOXzdWqETbCffLGd6kLgCMwPQWpM9ySVjXHWCGZsRAnF2F6M1O32\
|
||||
\43StIifvwJQFqSM3ewKBgCaW6y7sRY90Ua7283RErezd9EyT22BWlDlACrPu3FNC\
|
||||
\LtBf1j43uxBWBQrMLsHe2GtTV0xt9m0MfwZsm2gSsXcm4Xi4DJgfN+Z7rIlyy9UY\
|
||||
\PCDSdZiU1qSr+NrffDrXlfiAM1cUmCdUX7eKjp/ltkUHNaOGfSn5Pdr3MkAiD/Hf\
|
||||
\AoGBAKIdKCuOwuYlwjS9J+IRGuSSM4o+OxQdwGmcJDTCpyWb5dEk68e7xKIna3zf\
|
||||
\jc+H+QdMXv1nkRK9bZgYheXczsXaNZUSTwpxaEldzVD3hNvsXSgJRy9fqHwA4PBq\
|
||||
\vqiBHoO3RNbqg+2rmTMfDuXreME3S955ZiPZm4Z+T8Hj52mPAoGAQm5QH/gLFtY5\
|
||||
\+znqU/0G8V6BKISCQMxbbmTQVcTgGySrP2gVd+e4MWvUttaZykhWqs8rpr7mgpIY\
|
||||
\hul7Swx0SHFN3WpXu8uj+B6MLpRcCbDHO65qU4kQLs+IaXXsuuTjMvJ5LwjkZVrQ\
|
||||
\TmKzSAw7iVWwEUZR/PeiEKazqrpp9VU="
|
||||
}
|
||||
|
||||
withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerStoreLogOn port client = do
|
||||
s <- liftIO $ openReadStoreLog testStoreLogFile
|
||||
serverBracket
|
||||
(\started -> runSMPServerBlocking started cfg {tcpPort = port, storeLog = Just s})
|
||||
(pure ())
|
||||
client
|
||||
|
||||
withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerThreadOn port f = do
|
||||
withSmpServerThreadOn port =
|
||||
serverBracket
|
||||
(\started -> runSMPServerBlocking started cfg {tcpPort = port})
|
||||
(pure ())
|
||||
|
||||
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
|
||||
serverBracket process afterProcess f = do
|
||||
started <- newEmptyTMVarIO
|
||||
E.bracket
|
||||
(forkIOWithUnmask ($ runSMPServerBlocking started cfg {tcpPort = port}))
|
||||
(liftIO . killThread)
|
||||
\x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x
|
||||
(forkIOWithUnmask ($ process started))
|
||||
(\t -> killThread t >> afterProcess >> waitFor started "stop")
|
||||
(\t -> waitFor started "start" >> f t)
|
||||
where
|
||||
waitFor started s =
|
||||
5_000_000 `timeout` atomically (takeTMVar started) >>= \case
|
||||
Nothing -> error $ "server did not " <> s
|
||||
_ -> pure ()
|
||||
|
||||
withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a
|
||||
withSmpServerOn port = withSmpServerThreadOn port . const
|
||||
@@ -56,33 +121,43 @@ withSmpServerOn port = withSmpServerThreadOn port . const
|
||||
withSmpServer :: (MonadUnliftIO m, MonadRandom m) => m a -> m a
|
||||
withSmpServer = withSmpServerOn testPort
|
||||
|
||||
runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a
|
||||
runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (THandle -> m a) -> m a
|
||||
runSmpTest test = withSmpServer $ testSMPClient test
|
||||
|
||||
runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([Handle] -> m a) -> m a
|
||||
runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([THandle] -> m a) -> m a
|
||||
runSmpTestN nClients test = withSmpServer $ run nClients []
|
||||
where
|
||||
run :: Int -> [Handle] -> m a
|
||||
run :: Int -> [THandle] -> m a
|
||||
run 0 hs = test hs
|
||||
run n hs = testSMPClient $ \h -> run (n - 1) (h : hs)
|
||||
|
||||
smpServerTest :: RawTransmission -> IO RawTransmission
|
||||
smpServerTest cmd = runSmpTest $ \h -> tPutRaw h cmd >> tGetRaw h
|
||||
|
||||
smpTest :: (Handle -> IO ()) -> Expectation
|
||||
smpTest :: (THandle -> IO ()) -> Expectation
|
||||
smpTest test' = runSmpTest test' `shouldReturn` ()
|
||||
|
||||
smpTestN :: Int -> ([Handle] -> IO ()) -> Expectation
|
||||
smpTestN :: Int -> ([THandle] -> IO ()) -> Expectation
|
||||
smpTestN n test' = runSmpTestN n test' `shouldReturn` ()
|
||||
|
||||
smpTest2 :: (Handle -> Handle -> IO ()) -> Expectation
|
||||
smpTest2 :: (THandle -> THandle -> IO ()) -> Expectation
|
||||
smpTest2 test' = smpTestN 2 _test
|
||||
where
|
||||
_test [h1, h2] = test' h1 h2
|
||||
_test _ = error "expected 2 handles"
|
||||
|
||||
smpTest3 :: (Handle -> Handle -> Handle -> IO ()) -> Expectation
|
||||
smpTest3 :: (THandle -> THandle -> THandle -> IO ()) -> Expectation
|
||||
smpTest3 test' = smpTestN 3 _test
|
||||
where
|
||||
_test [h1, h2, h3] = test' h1 h2 h3
|
||||
_test _ = error "expected 3 handles"
|
||||
|
||||
tPutRaw :: THandle -> RawTransmission -> IO ()
|
||||
tPutRaw h (sig, corrId, queueId, command) = do
|
||||
let t = B.intercalate " " [corrId, queueId, command]
|
||||
void $ tPut h (C.Signature sig, t)
|
||||
|
||||
tGetRaw :: THandle -> IO RawTransmission
|
||||
tGetRaw h = do
|
||||
("", (CorrId corrId, qId, Right cmd)) <- tGet fromServer h
|
||||
pure ("", corrId, encode qId, serializeCommand cmd)
|
||||
|
||||
@@ -4,22 +4,29 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module ServerTests where
|
||||
|
||||
import Control.Concurrent (ThreadId, killThread)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (SomeException, try)
|
||||
import Control.Monad.Except (forM_, runExceptT)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import SMPClient
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import System.IO (Handle)
|
||||
import Simplex.Messaging.Transport
|
||||
import System.Directory (removeFile)
|
||||
import System.TimeIt (timeItT)
|
||||
import System.Timeout
|
||||
import Test.HUnit
|
||||
import Test.Hspec
|
||||
|
||||
rsaKeySize :: Int
|
||||
rsaKeySize = 1024 `div` 8
|
||||
rsaKeySize = 2048 `div` 8
|
||||
|
||||
serverTests :: Spec
|
||||
serverTests = do
|
||||
@@ -30,18 +37,20 @@ serverTests = do
|
||||
describe "SMP messages" do
|
||||
describe "duplex communication over 2 SMP connections" testDuplex
|
||||
describe "switch subscription to another SMP queue" testSwitchSub
|
||||
describe "Store log" testWithStoreLog
|
||||
describe "Timing of AUTH error" testTiming
|
||||
|
||||
pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError
|
||||
pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command)))
|
||||
|
||||
sendRecv :: Handle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
|
||||
sendRecv :: THandle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
|
||||
sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h
|
||||
|
||||
signSendRecv :: Handle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
|
||||
signSendRecv :: THandle -> C.SafePrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
|
||||
signSendRecv h pk (corrId, qId, cmd) = do
|
||||
let t = B.intercalate "\r\n" [corrId, encode qId, cmd]
|
||||
Right sig <- C.sign pk t
|
||||
tPut h (sig, t)
|
||||
let t = B.intercalate " " [corrId, encode qId, cmd]
|
||||
Right sig <- runExceptT $ C.sign pk t
|
||||
_ <- tPut h (sig, t)
|
||||
tGet fromServer h
|
||||
|
||||
cmdSEND :: ByteString -> ByteString
|
||||
@@ -61,7 +70,7 @@ testCreateSecure =
|
||||
Resp "abcd" rId1 (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
|
||||
(rId1, "") #== "creates queue"
|
||||
|
||||
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5\r\nhello")
|
||||
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5 hello ")
|
||||
(ok1, OK) #== "accepts unsigned SEND"
|
||||
(sId1, sId) #== "same queue ID in response 1"
|
||||
|
||||
@@ -72,10 +81,10 @@ testCreateSecure =
|
||||
(ok4, OK) #== "replies OK when message acknowledged if no more messages"
|
||||
|
||||
Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, "ACK")
|
||||
(err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages"
|
||||
(err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages"
|
||||
|
||||
(sPub, sKey) <- C.generateKeyPair rsaKeySize
|
||||
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5\r\nhello")
|
||||
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ")
|
||||
(err1, ERR AUTH) #== "rejects signed SEND"
|
||||
(sId2, sId) #== "same queue ID in response 2"
|
||||
|
||||
@@ -93,7 +102,7 @@ testCreateSecure =
|
||||
Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, keyCmd)
|
||||
(err4, ERR AUTH) #== "rejects KEY if already secured"
|
||||
|
||||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11\r\nhello again")
|
||||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11 hello again ")
|
||||
(ok3, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "" _ (MSG _ _ msg) <- tGet fromServer h
|
||||
@@ -102,7 +111,7 @@ testCreateSecure =
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, "ACK")
|
||||
(ok5, OK) #== "replies OK when message acknowledged 2"
|
||||
|
||||
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5\r\nhello")
|
||||
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5 hello ")
|
||||
(err5, ERR AUTH) #== "rejects unsigned SEND"
|
||||
|
||||
testCreateDelete :: Spec
|
||||
@@ -117,10 +126,10 @@ testCreateDelete =
|
||||
Resp "bcda" _ ok1 <- signSendRecv rh rKey ("bcda", rId, "KEY " <> C.serializePubKey sPub)
|
||||
(ok1, OK) #== "secures queue"
|
||||
|
||||
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello")
|
||||
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ")
|
||||
(ok2, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7\r\nhello 2")
|
||||
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7 hello 2 ")
|
||||
(ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed"
|
||||
|
||||
Resp "" _ (MSG _ _ msg1) <- tGet fromServer rh
|
||||
@@ -136,10 +145,10 @@ testCreateDelete =
|
||||
(ok3, OK) #== "suspends queue"
|
||||
(rId2, rId) #== "same queue ID in response 2"
|
||||
|
||||
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5\r\nhello")
|
||||
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5 hello ")
|
||||
(err3, ERR AUTH) #== "rejects signed SEND"
|
||||
|
||||
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5\r\nhello")
|
||||
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5 hello ")
|
||||
(err4, ERR AUTH) #== "reject unsigned SEND too"
|
||||
|
||||
Resp "bcda" _ ok4 <- signSendRecv rh rKey ("bcda", rId, "OFF")
|
||||
@@ -158,10 +167,10 @@ testCreateDelete =
|
||||
(ok6, OK) #== "deletes queue"
|
||||
(rId3, rId) #== "same queue ID in response 3"
|
||||
|
||||
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello")
|
||||
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ")
|
||||
(err7, ERR AUTH) #== "rejects signed SEND when deleted"
|
||||
|
||||
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5\r\nhello")
|
||||
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5 hello ")
|
||||
(err8, ERR AUTH) #== "rejects unsigned SEND too when deleted"
|
||||
|
||||
Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, "ACK")
|
||||
@@ -211,7 +220,7 @@ testDuplex =
|
||||
(aliceKey, C.serializePubKey asPub) #== "key received from Alice"
|
||||
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, "KEY " <> aliceKey)
|
||||
|
||||
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8\r\nhi alice")
|
||||
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8 hi alice ")
|
||||
|
||||
Resp "" _ (MSG _ _ msg4) <- tGet fromServer alice
|
||||
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, "ACK")
|
||||
@@ -229,7 +238,7 @@ testSwitchSub =
|
||||
smpTest3 \rh1 rh2 sh -> do
|
||||
(rPub, rKey) <- C.generateKeyPair rsaKeySize
|
||||
Resp "abcd" _ (IDS rId sId) <- signSendRecv rh1 rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
|
||||
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5\r\ntest1")
|
||||
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5 test1 ")
|
||||
(ok1, OK) #== "sent test message 1"
|
||||
Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, cmdSEND "test2, no ACK")
|
||||
(ok2, OK) #== "sent test message 2"
|
||||
@@ -246,13 +255,13 @@ testSwitchSub =
|
||||
Resp "" _ end <- tGet fromServer rh1
|
||||
(end, END) #== "unsubscribed the 1st TCP connection"
|
||||
|
||||
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5\r\ntest3")
|
||||
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5 test3 ")
|
||||
|
||||
Resp "" _ (MSG _ _ msg3) <- tGet fromServer rh2
|
||||
(msg3, "test3") #== "delivered to the 2nd TCP connection"
|
||||
|
||||
Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, "ACK")
|
||||
(err, ERR PROHIBITED) #== "rejects ACK from the 1st TCP connection"
|
||||
(err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection"
|
||||
|
||||
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, "ACK")
|
||||
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
|
||||
@@ -261,40 +270,128 @@ testSwitchSub =
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
testWithStoreLog :: Spec
|
||||
testWithStoreLog =
|
||||
it "should store simplex queues to log and restore them after server restart" $ do
|
||||
(sPub1, sKey1) <- C.generateKeyPair rsaKeySize
|
||||
(sPub2, sKey2) <- C.generateKeyPair rsaKeySize
|
||||
senderId1 <- newTVarIO ""
|
||||
senderId2 <- newTVarIO ""
|
||||
|
||||
withSmpServerStoreLogOn testPort . runTest $ \h -> do
|
||||
(sId1, _, _) <- createAndSecureQueue h sPub1
|
||||
atomically $ writeTVar senderId1 sId1
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
|
||||
Resp "" _ (MSG _ _ "hello") <- tGet fromServer h
|
||||
|
||||
(sId2, rId2, rKey2) <- createAndSecureQueue h sPub2
|
||||
atomically $ writeTVar senderId2 sId2
|
||||
Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ")
|
||||
Resp "" _ (MSG _ _ "hello too") <- tGet fromServer h
|
||||
|
||||
Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, "DEL")
|
||||
pure ()
|
||||
|
||||
logSize `shouldReturn` 5
|
||||
|
||||
withSmpServerThreadOn testPort . runTest $ \h -> do
|
||||
sId1 <- readTVarIO senderId1
|
||||
-- fails if store log is disabled
|
||||
Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
|
||||
pure ()
|
||||
|
||||
withSmpServerStoreLogOn testPort . runTest $ \h -> do
|
||||
-- this queue is restored
|
||||
sId1 <- readTVarIO senderId1
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
|
||||
-- this queue is removed - not restored
|
||||
sId2 <- readTVarIO senderId2
|
||||
Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ")
|
||||
pure ()
|
||||
|
||||
logSize `shouldReturn` 1
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
createAndSecureQueue :: THandle -> SenderPublicKey -> IO (SenderId, RecipientId, C.SafePrivateKey)
|
||||
createAndSecureQueue h sPub = do
|
||||
(rPub, rKey) <- C.generateKeyPair rsaKeySize
|
||||
Resp "abcd" "" (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
|
||||
let keyCmd = "KEY " <> C.serializePubKey sPub
|
||||
Resp "dabc" rId' OK <- signSendRecv h rKey ("dabc", rId, keyCmd)
|
||||
(rId', rId) #== "same queue ID"
|
||||
pure (sId, rId, rKey)
|
||||
|
||||
runTest :: (THandle -> IO ()) -> ThreadId -> Expectation
|
||||
runTest test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
logSize :: IO Int
|
||||
logSize =
|
||||
try (length . B.lines <$> B.readFile testStoreLogFile) >>= \case
|
||||
Right l -> pure l
|
||||
Left (_ :: SomeException) -> logSize
|
||||
|
||||
testTiming :: Spec
|
||||
testTiming =
|
||||
it "should have similar time for auth error whether queue exists or not" $
|
||||
smpTest2 \rh sh -> do
|
||||
(rPub, rKey) <- C.generateKeyPair rsaKeySize
|
||||
Resp "abcd" "" (IDS rId sId) <- signSendRecv rh rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
|
||||
|
||||
(sPub, sKey) <- C.generateKeyPair rsaKeySize
|
||||
let keyCmd = "KEY " <> C.serializePubKey sPub
|
||||
Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, keyCmd)
|
||||
|
||||
Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, "SEND 5 hello ")
|
||||
|
||||
timeNoQueue <- timeRepeat 25 $ do
|
||||
Resp "dabc" _ (ERR AUTH) <- signSendRecv sh sKey ("dabc", rId, "SEND 5 hello ")
|
||||
return ()
|
||||
timeWrongKey <- timeRepeat 25 $ do
|
||||
Resp "cdab" _ (ERR AUTH) <- signSendRecv sh rKey ("cdab", sId, "SEND 5 hello ")
|
||||
return ()
|
||||
abs (timeNoQueue - timeWrongKey) / timeNoQueue < 0.15 `shouldBe` True
|
||||
where
|
||||
timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const
|
||||
|
||||
samplePubKey :: ByteString
|
||||
samplePubKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
|
||||
|
||||
syntaxTests :: Spec
|
||||
syntaxTests = do
|
||||
it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR SYNTAX 2")
|
||||
it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX")
|
||||
describe "NEW" do
|
||||
it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR SYNTAX 2")
|
||||
it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR SYNTAX 2")
|
||||
it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR SYNTAX 3")
|
||||
it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR SYNTAX 4")
|
||||
it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX")
|
||||
it "many parameters" $ ("1234", "cdab", "", "NEW 1 " <> samplePubKey) >#> ("", "cdab", "", "ERR CMD SYNTAX")
|
||||
it "no signature" $ ("", "dabc", "", "NEW " <> samplePubKey) >#> ("", "dabc", "", "ERR CMD NO_AUTH")
|
||||
it "queue ID" $ ("1234", "abcd", "12345678", "NEW " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH")
|
||||
describe "KEY" do
|
||||
it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH")
|
||||
it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR SYNTAX 2")
|
||||
it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR SYNTAX 2")
|
||||
it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR SYNTAX 3")
|
||||
it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR SYNTAX 3")
|
||||
it "valid syntax" $ ("1234", "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "bcda", "12345678", "ERR AUTH")
|
||||
it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX")
|
||||
it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "dabc", "12345678", "ERR CMD SYNTAX")
|
||||
it "no signature" $ ("", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH")
|
||||
it "no queue ID" $ ("1234", "bcda", "", "KEY " <> samplePubKey) >#> ("", "bcda", "", "ERR CMD NO_AUTH")
|
||||
noParamsSyntaxTest "SUB"
|
||||
noParamsSyntaxTest "ACK"
|
||||
noParamsSyntaxTest "OFF"
|
||||
noParamsSyntaxTest "DEL"
|
||||
describe "SEND" do
|
||||
it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5\r\nhello") >#> ("", "cdab", "12345678", "ERR AUTH")
|
||||
it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11\r\nhello there") >#> ("", "dabc", "12345678", "ERR AUTH")
|
||||
it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2")
|
||||
it "no queue ID" $ ("1234", "bcda", "", "SEND 5\r\nhello") >#> ("", "bcda", "", "ERR SYNTAX 5")
|
||||
it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello") >#> ("", "cdab", "12345678", "ERR SYNTAX 2")
|
||||
it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello") >#> ("", "dabc", "12345678", "ERR SYNTAX 2")
|
||||
it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4\r\nhello") >#> ("", "abcd", "12345678", "ERR SIZE")
|
||||
it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH")
|
||||
it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH")
|
||||
it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX")
|
||||
it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE")
|
||||
it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX")
|
||||
it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX")
|
||||
it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX")
|
||||
describe "PING" do
|
||||
it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG")
|
||||
describe "broker response not allowed" do
|
||||
it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR PROHIBITED")
|
||||
it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED")
|
||||
where
|
||||
noParamsSyntaxTest :: ByteString -> Spec
|
||||
noParamsSyntaxTest cmd = describe (B.unpack cmd) do
|
||||
it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH")
|
||||
it "parameters" $ ("1234", "bcda", "12345678", cmd <> " 1") >#> ("", "bcda", "12345678", "ERR SYNTAX 2")
|
||||
it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3")
|
||||
it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3")
|
||||
it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX")
|
||||
it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH")
|
||||
it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import AgentTests
|
||||
import MarkdownTests
|
||||
import ProtocolErrorTests
|
||||
import ServerTests
|
||||
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
|
||||
import Test.Hspec
|
||||
@@ -9,6 +10,7 @@ main = do
|
||||
createDirectoryIfMissing False "tests/tmp"
|
||||
hspec $ do
|
||||
describe "SimpleX markdown" markdownTests
|
||||
describe "Protocol errors" protocolErrorTests
|
||||
describe "SMP server" serverTests
|
||||
describe "SMP client agent" agentTests
|
||||
removeDirectoryRecursive "tests/tmp"
|
||||
|
||||
Reference in New Issue
Block a user