Merge pull request #82 from simplex-chat/v2

This commit is contained in:
Evgeny Poberezkin
2021-05-02 11:34:24 +01:00
committed by GitHub
33 changed files with 2334 additions and 1026 deletions

View File

@@ -1,11 +1,12 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module ChatOptions (getChatOpts, ChatOpts (..)) where
import qualified Data.Attoparsec.ByteString.Char8 as A
import qualified Data.ByteString.Char8 as B
import Options.Applicative
import Simplex.Messaging.Agent.Transmission (SMPServer (..), smpServerP)
import Simplex.Messaging.Parsers (parseAll)
import System.FilePath (combine)
import Types
@@ -30,8 +31,8 @@ chatOpts appDir =
( long "server"
<> short 's'
<> metavar "SERVER"
<> help "SMP server to use (smp.simplex.im:5223)"
<> value (SMPServer "smp.simplex.im" (Just "5223") Nothing)
<> help "SMP server to use (smp1.simplex.im:5223#pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=)"
<> value (SMPServer "smp1.simplex.im" (Just "5223") (Just "pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA="))
)
<*> option
parseTermMode
@@ -45,7 +46,7 @@ chatOpts appDir =
defaultDbFilePath = combine appDir "smp-chat.db"
parseSMPServer :: ReadM SMPServer
parseSMPServer = eitherReader $ A.parseOnly (smpServerP <* A.endOfInput) . B.pack
parseSMPServer = eitherReader $ parseAll smpServerP . B.pack
parseTermMode :: ReadM TermMode
parseTermMode = maybeReader $ \case

View File

@@ -31,6 +31,7 @@ import Simplex.Messaging.Agent.Client (AgentClient (..))
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client (smpDefaultConfig)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util (raceAny_)
import Styled
import System.Console.ANSI.Types
@@ -89,6 +90,7 @@ data ChatResponse
| ReceivedMessage Contact ByteString
| Disconnected Contact
| YesYes
| ContactError ConnectionErrorType Contact
| ErrorInput ByteString
| ChatError AgentErrorType
| NoChatResponse
@@ -107,8 +109,13 @@ serializeChatResponse = \case
Connected c -> [ttyContact c <> " connected"]
Confirmation c -> [ttyContact c <> " ok"]
ReceivedMessage c t -> prependFirst (ttyFromContact c) $ msgPlain t
-- TODO either add command to re-connect or update message below
Disconnected c -> ["disconnected from " <> ttyContact c <> " - try \"/chat " <> bPlain (toBs c) <> "\""]
YesYes -> ["you got it!"]
ContactError e c -> case e of
UNKNOWN -> ["no contact " <> ttyContact c]
DUPLICATE -> ["contact " <> ttyContact c <> " already exists"]
SIMPLEX -> ["contact " <> ttyContact c <> " did not accept invitation yet"]
ErrorInput t -> ["invalid input: " <> bPlain t]
ChatError e -> ["chat error: " <> plain (show e)]
NoChatResponse -> [""]
@@ -172,7 +179,7 @@ main = do
t <- getChatClient smpServer
ct <- newChatTerminal (tbqSize cfg) termMode
-- setLogLevel LogInfo -- LogError
-- withGlobalLogging logCfg $
-- withGlobalLogging logCfg $ do
env <- newSMPAgentEnv cfg {dbFile = dbFileName}
dogFoodChat t ct env
@@ -209,7 +216,7 @@ newChatClient qSize smpServer = do
receiveFromChatTerm :: ChatClient -> ChatTerminal -> IO ()
receiveFromChatTerm t ct = forever $ do
atomically (readTBQueue $ inputQ ct)
>>= processOrError . A.parseOnly (chatCommandP <* A.endOfInput) . encodeUtf8 . T.pack
>>= processOrError . parseAll chatCommandP . encodeUtf8 . T.pack
where
processOrError = \case
Left err -> writeOutQ . ErrorInput $ B.pack err
@@ -259,9 +266,10 @@ receiveFromAgent t ct c = forever . atomically $ do
INV qInfo -> Invitation qInfo
CON -> Connected contact
END -> Disconnected contact
MSG {m_body} -> ReceivedMessage contact m_body
MSG {msgBody} -> ReceivedMessage contact msgBody
SENT _ -> NoChatResponse
OK -> Confirmation contact
ERR (CONN e) -> ContactError e contact
ERR e -> ChatError e
where
contact = Contact a

View File

@@ -1,7 +1,28 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Control.Monad (unless, when)
import qualified Crypto.Store.PKCS8 as S
import qualified Data.ByteString.Char8 as B
import Data.Char (toLower)
import Data.Functor (($>))
import Data.Ini (lookupValue, readIniFile)
import qualified Data.Text as T
import Data.X509 (PrivKey (PrivKeyRSA))
import Options.Applicative
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Server (runSMPServer)
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.Exit (exitFailure)
import System.FilePath (combine)
import System.IO (IOMode (..), hFlush, stdout)
cfg :: ServerConfig
cfg =
@@ -9,10 +30,137 @@ cfg =
{ tcpPort = "5223",
tbqSize = 16,
queueIdBytes = 12,
msgIdBytes = 6
msgIdBytes = 6,
storeLog = Nothing,
-- key is loaded from the file server_key in /etc/opt/simplex directory
serverPrivateKey = undefined
}
newKeySize :: Int
newKeySize = 2048 `div` 8
cfgDir :: FilePath
cfgDir = "/etc/opt/simplex"
logDir :: FilePath
logDir = "/var/opt/simplex"
defaultStoreLogFile :: FilePath
defaultStoreLogFile = combine logDir "smp-server-store.log"
main :: IO ()
main = do
putStrLn $ "Listening on port " ++ tcpPort cfg
runSMPServer cfg
opts <- getServerOpts
putStrLn "SMP Server (-h for help)"
ini <- readCreateIni opts
storeLog <- openStoreLog ini
pk <- readCreateKey
B.putStrLn $ "transport key hash: " <> publicKeyHash (C.publicKey pk)
putStrLn $ "listening on port " <> tcpPort cfg
runSMPServer cfg {serverPrivateKey = pk, storeLog}
data IniOpts = IniOpts
{ enableStoreLog :: Bool,
storeLogFile :: FilePath
}
readCreateIni :: ServerOpts -> IO IniOpts
readCreateIni ServerOpts {configFile} = do
createDirectoryIfMissing True cfgDir
doesFileExist configFile >>= (`unless` createIni)
readIni
where
readIni :: IO IniOpts
readIni = do
ini <- either exitError pure =<< readIniFile configFile
let enableStoreLog = (== Right "on") $ lookupValue "STORE_LOG" "enable" ini
storeLogFile = either (const defaultStoreLogFile) T.unpack $ lookupValue "STORE_LOG" "file" ini
pure IniOpts {enableStoreLog, storeLogFile}
exitError e = do
putStrLn $ "error reading config file " <> configFile <> ": " <> e
exitFailure
createIni :: IO ()
createIni = do
confirm $ "Save default ini file to " <> configFile
writeFile
configFile
"[STORE_LOG]\n\
\# The server uses STM memory to store SMP queues and messages,\n\
\# that will be lost on restart (e.g., as with redis).\n\
\# This option enables saving SMP queues to append only log,\n\
\# and restoring them when the server is started.\n\
\# Log is compacted on start (deleted queues are removed).\n\
\# The messages in the queues are not logged.\n\
\\n\
\# enable: on\n\
\# file: /var/opt/simplex/smp-server-store.log\n"
readCreateKey :: IO C.FullPrivateKey
readCreateKey = do
createDirectoryIfMissing True cfgDir
let path = combine cfgDir "server_key"
hasKey <- doesFileExist path
(if hasKey then readKey else createKey) path
where
createKey :: FilePath -> IO C.FullPrivateKey
createKey path = do
confirm "Generate new server key pair"
(_, pk) <- C.generateKeyPair newKeySize
S.writeKeyFile S.TraditionalFormat path [PrivKeyRSA $ C.rsaPrivateKey pk]
pure pk
readKey :: FilePath -> IO C.FullPrivateKey
readKey path = do
S.readKeyFile path >>= \case
[S.Unprotected (PrivKeyRSA pk)] -> pure $ C.FullPrivateKey pk
[_] -> errorExit "not RSA key"
[] -> errorExit "invalid key file format"
_ -> errorExit "more than one key"
where
errorExit :: String -> IO b
errorExit e = putStrLn (e <> ": " <> path) >> exitFailure
confirm :: String -> IO ()
confirm msg = do
putStr $ msg <> " (y/N): "
hFlush stdout
ok <- getLine
when (map toLower ok /= "y") exitFailure
publicKeyHash :: C.PublicKey -> B.ByteString
publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.encodePubKey
openStoreLog :: IniOpts -> IO (Maybe (StoreLog 'ReadMode))
openStoreLog IniOpts {enableStoreLog, storeLogFile = f}
| enableStoreLog = do
createDirectoryIfMissing True logDir
putStrLn ("store log: " <> f)
Just <$> openReadStoreLog f
| otherwise = putStrLn "store log disabled" $> Nothing
newtype ServerOpts = ServerOpts
{ configFile :: FilePath
}
serverOpts :: Parser ServerOpts
serverOpts =
ServerOpts
<$> strOption
( long "config"
<> short 'c'
<> metavar "INI_FILE"
<> help ("config file (" <> defaultIniFile <> ")")
<> value defaultIniFile
)
where
defaultIniFile = combine cfgDir "smp-server.ini"
getServerOpts :: IO ServerOpts
getServerOpts = execParser opts
where
opts =
info
(serverOpts <**> helper)
( fullDesc
<> header "Simplex Messaging Protocol (SMP) Server"
<> progDesc "Start server with INI_FILE (created on first run)"
)

View File

@@ -13,6 +13,8 @@ extra-source-files:
dependencies:
- ansi-terminal == 0.10.*
- asn1-encoding == 0.9.*
- asn1-types == 0.3.*
- async == 2.2.*
- attoparsec == 0.13.*
- base >= 4.7 && < 5
@@ -22,11 +24,13 @@ dependencies:
- cryptonite == 0.26.*
- directory == 1.3.*
- filepath == 1.4.*
- generic-random == 1.3.*
- iso8601-time == 0.1.*
- memory == 0.15.*
- mtl
- network == 3.1.*
- network-transport == 0.5.*
- QuickCheck == 2.13.*
- simple-logger == 0.1.*
- sqlite-simple == 0.4.*
- stm
@@ -36,6 +40,7 @@ dependencies:
- transformers == 0.5.*
- unliftio == 0.2.*
- unliftio-core == 0.1.*
- x509 == 1.7.*
library:
source-dirs: src
@@ -45,6 +50,9 @@ executables:
source-dirs: apps/smp-server
main: Main.hs
dependencies:
- cryptostore == 0.2.*
- ini == 0.4.*
- optparse-applicative == 0.15.*
- simplex-messaging
ghc-options:
- -threaded
@@ -78,6 +86,8 @@ tests:
- hspec-core == 2.7.*
- HUnit == 1.6.*
- random == 1.1.*
- QuickCheck == 2.13.*
- timeit == 2.0.*
ghc-options:
# - -haddock

View File

@@ -16,49 +16,60 @@ For initial implementation I propose approach to be as simple as possible as lon
One of the consideration is to use [noise protocol framework](https://noiseprotocol.org/noise.html), this section describes ad hoc protocol though.
During TCP session both client and server should use symmetric AES 256 bit encryption using the session key that will be established during the handshake.
During TCP session both client and server should use symmetric AES 256 bit encryption using two session keys and two base IVs that will be agreed during the handshake. Both client and the server should maintain two 32-bit word counters, one for sent and one for the received messages. The IV for each message should be computed by xor-ing the sequential message counter, starting from 0, with the first 32 bits of agreed base IV. TODO - explain it in a more formal way, also document how 32-bit word is encoded - with the most or least significant byte first (currently encodeWord32 from Network.Transport.Internal is used)
To establish the session key, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) and additional server ID (256 bits) in advance in order to be able to establish connection.
To establish the session keys and base IVs, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) in advance in order to be able to establish connection.
The handshake sequence could be the following:
The handshake sequence is the following:
1. Once the connection is established, the server sends its public key to the client
2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection.
3. If the hash is the same, the client should generate a random symmetric AES key and IV that will be used as a session key both by the client and the server.
4. The client then should encrypt this symmetric key with the public key that the server sent and send back to the server the result and the server ID also shared with the client in advance: `rsa-encrypt(aes-key, iv, server-id)`.
5. The server should decrypt the received key, IV and server id with its private key.
6. The server should compare the `server-id` sent by the client and if it does not match its ID terminate the connection.
7. In case of successful decryption and matching server ID, the server should send encrypted welcome header.
1. Once the connection is established, the server sends server_header and its public RSA key encoded in X509 binary format to the client.
2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required.
3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server.
4. The client then should construct client_handshake block and send it to the server encrypted with the server public key: `rsa-encrypt(client_handshake)`. `snd_aes_key` and `snd_base_iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv_aes_key` and `rcv_base_iv` will be used by the client to decrypt **received** messages and by the server to encrypt them.
5. The server should decrypt the received keys and base IVs with its private key.
6. In case of successful decryption, the server should send encrypted welcome block (encrypted_welcome_block) that contains SMP protocol version.
```abnf
aes_welcome_header = aes_header_auth_tag aes_encrypted_header
welcome_header = smp_version ["," smp_mode] *SP ; decrypt(aes_encrypted_header) - 32 bytes
smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format
; for example: v123.456.789-alpha.7
smp_mode = smp_public / smp_authenticated
smp_public = %s"pub" ; public (default) - no auth to create and manage queues
smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues
aes_header_auth_tag = aes_auth_tag
aes_auth_tag = 16*16(OCTET)
```
No payload should follow this header, it is only used to confirm successful handshake and send the SMP protocol version that the server supports.
All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES key and IV sent by the client during the handshake.
All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES keys and base IVs (incremented by counters on both sides) sent by the client during the handshake.
Each transport block sent by the client and the server has this syntax:
```abnf
transport_block = aes_header_auth_tag aes_encrypted_header aes_body_auth_tag aes_encrypted_body
aes_encrypted_header = 32*32(OCTET)
header = padded_body_size payload_size reserved ; decrypt(aes_encrypted_header) - 32 bytes
server_header = block_size protocol key_size
block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the client rejects if > 65536 bytes
protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key
key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard)
client_handshake = client_block_size client_protocol snd_aes_key snd_base_iv rcv_aes_key rcv_base_iv
client_block_size = 4*4(OCTET) ; 4-byte block size sent by the client, currently it is ignored by the server - reserved
client_protocol = 2*2(OCTET) ; currently it is 0 - reserved
snd_aes_key = 32*32(OCTET)
snd_base_iv = 16*16(OCTET)
rcv_aes_key = 32*32(OCTET)
rcv_base_iv = 16*16(OCTET)
transport_block = aes_body_auth_tag aes_encrypted_body
; size is sent by server during handshake, usually 8192 bytes
aes_encrypted_body = 1*OCTET
body = payload pad
padded_body_size = size ; body size in bytes
payload_size = size ; payload_size in bytes
size = 4*4(OCTET)
reserved = 24*24(OCTET)
aes_body_auth_tag = aes_auth_tag
aes_body_auth_tag = 16*16(OCTET)
encrypted_welcome_block = transport_block
welcome_block = smp_version SP pad ; decrypt(encrypted_welcome_block)
smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format
; for example: v123.456.789-alpha.7
pad = 1*OCTET
```
## Possible future improvements/changes
- server id (256 bits), so that only the users that have it can connect to the server. This ID will have to be passed to the server during the handshake
- block size agreed during handshake
- transport encryption protocol agreed during handshake
- welcome block containing SMP mode (smp_mode)
```abnf
smp_mode = smp_public / smp_authenticated
smp_public = %s"pub" ; public (default) - no auth to create and manage queues
smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues
```
## Initial handshake
@@ -105,7 +116,9 @@ Symmetric keys are generated per message and encrypted with receiver's public ke
The syntax of each encrypted message body is the following:
```abnf
encrypted_message_body = rsa_encrypted_header aes_encrypted_body
encrypted_message_body = rsa_signature encrypted_body
encrypted_body = rsa_encrypted_header aes_encrypted_body
rsa_signature = 256*256(OCTET) ; sign(encrypted_body) - assuming 2048 bit key size
rsa_encrypted_header = 256*256(OCTET) ; encrypt(header) - assuming 2048 bit key size
aes_encrypted_body = 1*OCTET ; encrypt(body)
@@ -120,7 +133,6 @@ body = payload pad
Future considerations:
- Generation of symmetric keys per session and session rotation;
- Signature and verification of messages.
## E2E implementation

25
rfcs/2021-03-18-groups.md Normal file
View File

@@ -0,0 +1,25 @@
# SMP agent groups
## Problems
- device/user profile synchronisation
- chat group communication
Both problems would require message broadcast between a group of SMP agents.
## Solution: basic symmetric groups via SMP agent protocol
Additional commands and message envelopes to SMP agent protocol to provide an abstraction layer for device synchronisation and chat groups.
The groups are fully symmetric, all agent who are members of the group have equal rights and can join and leave group at any time.
All the information about the groups is stored only in agents, the commands are used to synchronise the group state between the agents.
```abnf
group_command = create_group / add_to_group / remove_from_group / leave_group
group_response = group_created / added_to_group / removed_from_group
group_notification = added_to_group_by / removed_from_group_by / left_group
create_group = %s"GNEW " group_name ; cAlias must be empty
add_to_group = %s"GADD " group_name ; cAlias is the connection to add to the group
added_to_group = %s"GADDED " name ; cAlias is the connection added to the group
```

View File

@@ -26,6 +26,7 @@ import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import Data.Time.Clock
import Database.SQLite.Simple (SQLError)
import Simplex.Messaging.Agent.Client
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Store
@@ -36,10 +37,9 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (CorrId (..), MsgBody, SenderPublicKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport (putLn, runTCPServer)
import Simplex.Messaging.Util (liftError)
import Simplex.Messaging.Util (bshow)
import System.IO (Handle)
import UnliftIO.Async (race_)
import UnliftIO.Exception (SomeException)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@@ -98,7 +98,7 @@ send h c@AgentClient {sndQ} = forever $ do
logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m ()
logClient AgentClient {clientId} dir (CorrId corrId, cAlias, cmd) = do
logInfo . decodeUtf8 $ B.unwords [B.pack $ show clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd]
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd]
client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
client c@AgentClient {rcvQ, sndQ} st = forever $ do
@@ -114,10 +114,15 @@ withStore ::
withStore action = do
runExceptT (action `E.catch` handleInternal) >>= \case
Right c -> return c
Left _ -> throwError STORE
Left e -> throwError $ storeError e
where
handleInternal :: (MonadError StoreError m') => SomeException -> m' a
handleInternal _ = throwError SEInternal
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
handleInternal e = throwError . SEInternal $ bshow e
storeError :: StoreError -> AgentErrorType
storeError = \case
SEConnNotFound -> CONN UNKNOWN
SEConnDuplicate -> CONN DUPLICATE
e -> INTERNAL $ show e
processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m ()
processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
@@ -156,9 +161,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
withStore (getConn st cAlias) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> subscribe rq
SomeConn _ (RcvConnection _ rq) -> subscribe rq
-- TODO possibly there should be a separate error type trying
-- TODO to send the message to the connection without RcvQueue
_ -> throwError PROHIBITED
_ -> throwError $ CONN SIMPLEX
where
subscribe rq = subscribeQueue c rq cAlias >> respond' cAlias OK
@@ -171,22 +174,32 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
withStore (getConn st connAlias) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
SomeConn _ (SndConnection _ sq) -> sendMsg sq
-- TODO possibly there should be a separate error type trying
-- TODO to send the message to the connection without SndQueue
_ -> throwError PROHIBITED -- NOT_READY ?
_ -> throwError $ CONN SIMPLEX
where
sendMsg sq = do
senderTs <- liftIO getCurrentTime
senderId <- withStore $ createSndMsg st connAlias msgBody senderTs
sendAgentMessage c sq senderTs $ A_MSG msgBody
respond $ SENT (unId senderId)
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
let msgStr =
serializeSMPMessage
SMPMessage
{ senderMsgId = unSndId internalSndId,
senderTimestamp = internalTs,
previousMsgHash,
agentMessage = A_MSG msgBody
}
msgHash = C.sha256Hash msgStr
withStore $
createSndMsg st sq $
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
sendAgentMessage c sq msgStr
respond $ SENT (unId internalId)
suspendConnection :: m ()
suspendConnection =
withStore (getConn st connAlias) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> suspend rq
SomeConn _ (RcvConnection _ rq) -> suspend rq
_ -> throwError PROHIBITED
_ -> throwError $ CONN SIMPLEX
where
suspend rq = suspendQueue c rq >> respond OK
@@ -195,20 +208,26 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
withStore (getConn st connAlias) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> delete rq
SomeConn _ (RcvConnection _ rq) -> delete rq
_ -> throwError PROHIBITED
_ -> delConn
where
delConn = withStore (deleteConn st connAlias) >> respond OK
delete rq = do
deleteQueue c rq
removeSubscription c connAlias
withStore (deleteConn st connAlias)
respond OK
delConn
sendReplyQInfo :: SMPServer -> SndQueue -> m ()
sendReplyQInfo srv sq = do
(rq, qInfo) <- newReceiveQueue c srv connAlias
withStore $ upgradeSndConnToDuplex st connAlias rq
senderTs <- liftIO getCurrentTime
sendAgentMessage c sq senderTs $ REPLY qInfo
senderTimestamp <- liftIO getCurrentTime
sendAgentMessage c sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage = REPLY qInfo
}
respond :: ACommand 'Agent -> m ()
respond = respond' connAlias
@@ -226,11 +245,13 @@ subscriber c@AgentClient {msgQ} st = forever $ do
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
rq@RcvQueue {connAlias, decryptKey, status} <- withStore $ getRcvQueue st srv rId
rq@RcvQueue {connAlias, status} <- withStore $ getRcvQueue st srv rId
case cmd of
SMP.MSG srvMsgId srvTs msgBody -> do
-- TODO deduplicate with previously received
agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody
msg <- decryptAndVerify rq msgBody
let msgHash = C.sha256Hash msg
agentMsg <- liftEither $ parseSMPMessage msg
case agentMsg of
SMPConfirmation senderKey -> do
logServer "<--" c srv rId "MSG <KEY>"
@@ -242,50 +263,74 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
-- TODO update sender key in the store?
secureQueue c rq senderKey
withStore $ setRcvQueueStatus st rq Secured
sendAck c rq
s ->
-- TODO maybe send notification to the user
liftIO . putStrLn $ "unexpected SMP confirmation, queue status " <> show s
SMPMessage {agentMessage, senderMsgId, senderTimestamp} ->
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
case agentMessage of
HELLO _verifyKey _ -> do
HELLO verifyKey _ -> do
logServer "<--" c srv rId "MSG <HELLO>"
-- TODO send status update to the user?
withStore $ setRcvQueueStatus st rq Active
sendAck c rq
case status of
Active -> notify connAlias . ERR $ AGENT A_PROHIBITED
_ -> do
void $ verifyMessage (Just verifyKey) msgBody
withStore $ setRcvQueueActive st rq verifyKey
REPLY qInfo -> do
logServer "<--" c srv rId "MSG <REPLY>"
-- TODO move senderKey inside SndQueue
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
withStore $ upgradeRcvConnToDuplex st connAlias sq
connectToSendQueue c st sq senderKey verifyKey
notify connAlias CON
sendAck c rq
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
-- TODO check message status
recipientTs <- liftIO getCurrentTime
let m_sender = (senderMsgId, senderTimestamp)
let m_broker = (srvMsgId, srvTs)
recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker
notify connAlias $
MSG
{ m_status = MsgOk,
m_recipient = (unId recipientId, recipientTs),
m_sender,
m_broker,
m_body = body
}
sendAck c rq
A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
sendAck c rq
return ()
SMP.END -> do
removeSubscription c connAlias
logServer "<--" c srv rId "END"
notify connAlias END
_ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd
_ -> do
logServer "<--" c srv rId $ "unexpected: " <> bshow cmd
notify connAlias . ERR $ BROKER UNEXPECTED
where
notify :: ConnAlias -> ACommand 'Agent -> m ()
notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg)
agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
logServer "<--" c srv rId "MSG <MSG>"
case status of
Active -> do
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash
withStore $
createRcvMsg st rq $
RcvMsgData
{ internalId,
internalRcvId,
internalTs,
senderMeta,
brokerMeta,
msgBody,
internalHash = msgHash,
externalPrevSndHash = receivedPrevMsgHash,
msgIntegrity
}
notify connAlias $
MSG
{ recipientMeta = (unId internalId, internalTs),
senderMeta,
brokerMeta,
msgBody,
msgIntegrity
}
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
where
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash
| extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk
| extSndId < prevExtSndId = MsgError $ MsgBadId extSndId
| extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate
| extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1)
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
| otherwise = MsgError MsgDuplicate -- this case is not possible
connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m ()
connectToSendQueue c st sq senderKey verifyKey = do
@@ -294,9 +339,6 @@ connectToSendQueue c st sq senderKey verifyKey = do
sendHello c sq verifyKey
withStore $ setSndQueueStatus st sq Active
decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString
decryptMessage decryptKey msg = liftError CRYPTO $ C.decrypt decryptKey msg
newSendQueue ::
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do

View File

@@ -19,11 +19,14 @@ module Simplex.Messaging.Agent.Client
sendHello,
secureQueue,
sendAgentMessage,
decryptAndVerify,
verifyMessage,
sendAck,
suspendQueue,
deleteQueue,
logServer,
removeSubscription,
cryptoError,
)
where
@@ -47,7 +50,7 @@ import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey)
import Simplex.Messaging.Util (liftError)
import Simplex.Messaging.Util (bshow, liftEitherError, liftError)
import UnliftIO.Concurrent
import UnliftIO.Exception (IOException)
import qualified UnliftIO.Exception as E
@@ -86,15 +89,17 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
newSMPClient = do
smp <- connectClient
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
-- TODO how can agent know client lost the connection?
atomically . modifyTVar smpClients $ M.insert srv smp
return smp
connectClient :: m SMPClient
connectClient = do
cfg <- asks $ smpCfg . config
liftIO (getSMPClient srv cfg msgQ clientDisconnected)
`E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection)
liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected)
`E.catch` internalError
where
internalError :: IOException -> m SMPClient
internalError = throwError . INTERNAL . show
clientDisconnected :: IO ()
clientDisconnected = do
@@ -118,31 +123,41 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient
withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMP c srv action =
(getSMPServerClient c srv >>= runAction) `catchError` logServerError
withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a
withSMP_ c srv action =
(getSMPServerClient c srv >>= action) `catchError` logServerError
where
runAction :: SMPClient -> m a
runAction smp = liftError smpClientError $ action smp
smpClientError :: SMPClientError -> AgentErrorType
smpClientError = \case
SMPServerError e -> SMP e
-- TODO handle other errors
_ -> INTERNAL
logServerError :: AgentErrorType -> m a
logServerError e = do
logServer "<--" c srv "" $ (B.pack . show) e
logServer "<--" c srv "" $ bshow e
throwError e
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withLogSMP c srv qId cmdStr action = do
withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a
withLogSMP_ c srv qId cmdStr action = do
logServer "-->" c srv qId cmdStr
res <- withSMP c srv action
res <- withSMP_ c srv action
logServer "<--" c srv qId "OK"
return res
withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMP c srv action = withSMP_ c srv $ liftSMP . action
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action
liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a
liftSMP = liftError smpClientError
smpClientError :: SMPClientError -> AgentErrorType
smpClientError = \case
SMPServerError e -> SMP e
SMPResponseError e -> BROKER $ RESPONSE e
SMPUnexpectedResponse -> BROKER UNEXPECTED
SMPResponseTimeout -> BROKER TIMEOUT
SMPNetworkError -> BROKER NETWORK
SMPTransportError e -> BROKER $ TRANSPORT e
e -> INTERNAL $ show e
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
newReceiveQueue c srv connAlias = do
size <- asks $ rsaKeySize . config
@@ -196,7 +211,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
showServer :: SMPServer -> ByteString
showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv)
@@ -205,35 +220,38 @@ logSecret :: ByteString -> ByteString
logSecret bs = encode $ B.take 3 bs
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> m ()
sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do
msg <- mkConfirmation
withLogSMP c server sndId "SEND <KEY>" $ \smp ->
sendSMPMessage smp Nothing sndId msg
sendConfirmation c sq@SndQueue {server, sndId} senderKey =
withLogSMP_ c server sndId "SEND <KEY>" $ \smp -> do
msg <- mkConfirmation smp
liftSMP $ sendSMPMessage smp Nothing sndId msg
where
mkConfirmation :: m MsgBody
mkConfirmation = do
let msg = serializeSMPMessage $ SMPConfirmation senderKey
paddedSize <- asks paddedMsgSize
liftError CRYPTO $ C.encrypt encryptKey paddedSize msg
mkConfirmation :: SMPClient -> m MsgBody
mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey
sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m ()
sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do
msg <- mkHello $ AckMode On
withLogSMP c server sndId "SEND <HELLO> (retrying)" $
send 20 msg
sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey =
withLogSMP_ c server sndId "SEND <HELLO> (retrying)" $ \smp -> do
msg <- mkHello smp $ AckMode On
liftSMP $ send 8 100000 msg smp
where
mkHello :: AckMode -> m ByteString
mkHello ackMode = do
senderTs <- liftIO getCurrentTime
mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode
mkHello :: SMPClient -> AckMode -> m ByteString
mkHello smp ackMode = do
senderTimestamp <- liftIO getCurrentTime
encryptAndSign smp sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage = HELLO verifyKey ackMode
}
send :: Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO ()
send 0 _ _ = throwE SMPResponseTimeout -- TODO different error
send retry msg smp =
send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO ()
send 0 _ _ _ = throwE $ SMPServerError AUTH
send retry delay msg smp =
sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case
SMPServerError AUTH -> do
threadDelay 100000
send (retry - 1) msg smp
threadDelay delay
send (retry - 1) (delay * 3 `div` 2) msg smp
e -> throwE e
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m ()
@@ -256,21 +274,39 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
withLogSMP c server rcvId "DEL" $ \smp ->
deleteSMPQueue smp rcvPrivateKey rcvId
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m ()
sendAgentMessage c SndQueue {server, sndId, sndPrivateKey, encryptKey} senderTs agentMsg = do
msg <- mkAgentMessage encryptKey senderTs agentMsg
withLogSMP c server sndId "SEND <message>" $ \smp ->
sendSMPMessage smp (Just sndPrivateKey) sndId msg
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg =
withLogSMP_ c server sndId "SEND <message>" $ \smp -> do
msg' <- encryptAndSign smp sq msg
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg'
mkAgentMessage :: AgentMonad m => EncryptionKey -> SenderTimestamp -> AMessage -> m ByteString
mkAgentMessage encKey senderTs agentMessage = do
let msg =
serializeSMPMessage
SMPMessage
{ senderMsgId = 0,
senderTimestamp = senderTs,
previousMsgHash = "1234", -- TODO hash of the previous message
agentMessage
}
paddedSize <- asks paddedMsgSize
liftError CRYPTO $ C.encrypt encKey paddedSize msg
encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString
encryptAndSign smp SndQueue {encryptKey, signKey} msg = do
paddedSize <- asks $ (blockSize smp -) . reservedMsgSize
liftError cryptoError $ do
enc <- C.encrypt encryptKey paddedSize msg
C.Signature sig <- C.sign signKey enc
pure $ sig <> enc
decryptAndVerify :: AgentMonad m => RcvQueue -> ByteString -> m ByteString
decryptAndVerify RcvQueue {decryptKey, verifyKey} msg =
verifyMessage verifyKey msg
>>= liftError cryptoError . C.decrypt decryptKey
verifyMessage :: AgentMonad m => Maybe VerificationKey -> ByteString -> m ByteString
verifyMessage verifyKey msg = do
size <- asks $ rsaKeySize . config
let (sig, enc) = B.splitAt size msg
case verifyKey of
Nothing -> pure enc
Just k
| C.verify k (C.Signature sig) enc -> pure enc
| otherwise -> throwError $ AGENT A_SIGNATURE
cryptoError :: C.CryptoError -> AgentErrorType
cryptoError = \case
C.CryptoLargeMsgError -> CMD LARGE
C.RSADecryptError _ -> AGENT A_ENCRYPTION
C.CryptoHeaderError _ -> AGENT A_ENCRYPTION
C.AESDecryptError -> AGENT A_ENCRYPTION
e -> INTERNAL $ show e

View File

@@ -26,7 +26,7 @@ data Env = Env
{ config :: AgentConfig,
idsDrg :: TVar ChaChaDRG,
clientCounter :: TVar Int,
paddedMsgSize :: Int
reservedMsgSize :: Int
}
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
@@ -34,10 +34,10 @@ newSMPAgentEnv config = do
idsDrg <- drgNew >>= newTVarIO
_ <- createSQLiteStore $ dbFile config
clientCounter <- newTVarIO 0
return Env {config, idsDrg, clientCounter, paddedMsgSize}
return Env {config, idsDrg, clientCounter, reservedMsgSize}
where
-- one rsaKeySize is used by the RSA signature in each command,
-- another - by encrypted message body header
-- 1st rsaKeySize is used by the RSA signature in each command,
-- 2nd - by encrypted message body header
-- 3rd - by message signature
-- smpCommandSize - is the estimated max size for SMP command, queueId, corrId
paddedMsgSize = blockSize smp - 2 * rsaKeySize config - smpCommandSize smp
smp = smpCfg config
reservedMsgSize = 3 * rsaKeySize config + smpCommandSize (smpCfg config)

View File

@@ -10,6 +10,7 @@
module Simplex.Messaging.Agent.Store where
import Control.Exception (Exception)
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import Data.Kind (Type)
import Data.Time (UTCTime)
@@ -38,11 +39,16 @@ class Monad m => MonadAgentStore s m where
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
-- Msg management
createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
getMsg :: s -> ConnAlias -> InternalId -> m Msg
-- * Queue types
@@ -102,6 +108,11 @@ data SConnType :: ConnType -> Type where
SCSnd :: SConnType CSnd
SCDuplex :: SConnType CDuplex
connType :: SConnType c -> ConnType
connType SCRcv = CRcv
connType SCSnd = CSnd
connType SCDuplex = CDuplex
deriving instance Eq (SConnType d)
deriving instance Show (SConnType d)
@@ -123,6 +134,42 @@ instance Eq SomeConn where
deriving instance Show SomeConn
-- * Message integrity validation types
type MsgHash = ByteString
-- | Corresponds to `last_external_snd_msg_id` in `connections` table
type PrevExternalSndId = Int64
-- | Corresponds to `last_rcv_msg_hash` in `connections` table
type PrevRcvMsgHash = MsgHash
-- | Corresponds to `last_snd_msg_hash` in `connections` table
type PrevSndMsgHash = MsgHash
-- ? merge/replace these with RcvMsg and SndMsg
-- * Message data containers - used on Msg creation to reduce number of parameters
data RcvMsgData = RcvMsgData
{ internalId :: InternalId,
internalRcvId :: InternalRcvId,
internalTs :: InternalTs,
senderMeta :: (ExternalSndId, ExternalSndTs),
brokerMeta :: (BrokerId, BrokerTs),
msgBody :: MsgBody,
internalHash :: MsgHash,
externalPrevSndHash :: MsgHash,
msgIntegrity :: MsgIntegrity
}
data SndMsgData = SndMsgData
{ internalId :: InternalId,
internalSndId :: InternalSndId,
internalTs :: InternalTs,
msgBody :: MsgBody,
internalHash :: MsgHash
}
-- * Message types
-- | A message in either direction that is stored by the agent.
@@ -149,7 +196,10 @@ data RcvMsg = RcvMsg
-- | Timestamp of acknowledgement to sender, corresponds to `AcknowledgedToSender` status.
-- Do not mix up with `externalSndTs` - timestamp created at sender before sending,
-- which in its turn corresponds to `internalTs` in sending agent.
ackSenderTs :: AckSenderTs
ackSenderTs :: AckSenderTs,
-- | Hash of previous message as received from sender - stored for integrity forensics.
externalPrevSndHash :: MsgHash,
msgIntegrity :: MsgIntegrity
}
deriving (Eq, Show)
@@ -209,7 +259,9 @@ data MsgBase = MsgBase
-- due to a possibility of implementation errors in different agents.
internalId :: InternalId,
internalTs :: InternalTs,
msgBody :: MsgBody
msgBody :: MsgBody,
-- | Hash of the message as computed by agent.
internalHash :: MsgHash
}
deriving (Eq, Show)
@@ -219,13 +271,11 @@ type InternalTs = UTCTime
-- * Store errors
-- TODO revise
data StoreError
= SEInternal
| SENotFound
| SEBadConn
= SEInternal ByteString
| SEConnNotFound
| SEConnDuplicate
| SEBadConnType ConnType
| SEBadQueueStatus
| SEBadQueueDirection
| SEBadQueueStatus -- not used, planned to check strictly
| SENotImplemented -- TODO remove
deriving (Eq, Show, Exception)

View File

@@ -1,5 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -21,24 +23,26 @@ where
import Control.Monad (when)
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Data.Bifunctor (first)
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Text (isPrefixOf)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import Database.SQLite.Simple as DB
import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field)
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite.Simple.ToField (ToField (..))
import Network.Socket (HostName, ServiceName)
import Network.Socket (ServiceName)
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema)
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Protocol (MsgBody)
import Simplex.Messaging.Parsers (parseAll)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util (liftIOEither)
import Simplex.Messaging.Util (bshow, liftIOEither)
import System.Exit (ExitCode (ExitFailure), exitWith)
import System.FilePath (takeDirectory)
import Text.Read (readMaybe)
@@ -75,76 +79,170 @@ connectSQLiteStore dbFilePath = do
liftIO $ DB.execute_ dbConn "PRAGMA foreign_keys = ON;"
return SQLiteStore {dbFilePath, dbConn}
checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m ()
checkDuplicate action = liftIOEither $ first handleError <$> E.try action
where
handleError :: SQLError -> StoreError
handleError e
| DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate
| otherwise = SEInternal $ bshow e
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
createRcvConn SQLiteStore {dbConn} rcvQueue =
liftIO $
createRcvQueueAndConn dbConn rcvQueue
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
checkDuplicate $
DB.withTransaction dbConn $ do
upsertServer_ dbConn server
insertRcvQueue_ dbConn q
insertRcvConnection_ dbConn q
createSndConn :: SQLiteStore -> SndQueue -> m ()
createSndConn SQLiteStore {dbConn} sndQueue =
liftIO $
createSndQueueAndConn dbConn sndQueue
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
checkDuplicate $
DB.withTransaction dbConn $ do
upsertServer_ dbConn server
insertSndQueue_ dbConn q
insertSndConnection_ dbConn q
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
getConn SQLiteStore {dbConn} connAlias = do
queues <-
liftIO $
retrieveConnQueues dbConn connAlias
case queues of
(Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
(Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
(Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ)
_ -> throwError SEBadConn
getConn SQLiteStore {dbConn} connAlias =
liftIOEither . DB.withTransaction dbConn $
getConn_ dbConn connAlias
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
getAllConnAliases SQLiteStore {dbConn} =
liftIO $
retrieveAllConnAliases dbConn
liftIO $ do
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
return (concat r)
getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue
getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do
rcvQueue <-
r <-
liftIO $
retrieveRcvQueue dbConn host port rcvId
case rcvQueue of
Just rcvQ -> return rcvQ
_ -> throwError SENotFound
DB.queryNamed
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|]
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
case r of
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] ->
let srv = SMPServer hst (deserializePort_ prt) keyHash
in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
_ -> throwError SEConnNotFound
deleteConn :: SQLiteStore -> ConnAlias -> m ()
deleteConn SQLiteStore {dbConn} connAlias =
liftIO $
deleteConnCascade dbConn connAlias
DB.executeNamed
dbConn
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
[":conn_alias" := connAlias]
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue =
liftIOEither $
updateRcvConnWithSndQueue dbConn connAlias sndQueue
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
liftIOEither . DB.withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
Right (SomeConn SCRcv (RcvConnection _ _)) -> do
upsertServer_ dbConn server
insertSndQueue_ dbConn sq
updateConnWithSndQueue_ dbConn connAlias sq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue =
liftIOEither $
updateSndConnWithRcvQueue dbConn connAlias rcvQueue
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
liftIOEither . DB.withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
Right (SomeConn SCSnd (SndConnection _ _)) -> do
upsertServer_ dbConn server
insertRcvQueue_ dbConn rq
updateConnWithRcvQueue_ dbConn connAlias rq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m ()
setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status =
setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status =
-- ? throw error if queue doesn't exist?
liftIO $
updateRcvQueueStatus dbConn rcvQueue status
DB.executeNamed
dbConn
[sql|
UPDATE rcv_queues
SET status = :status
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m ()
setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey =
-- ? throw error if queue doesn't exist?
liftIO $
DB.executeNamed
dbConn
[sql|
UPDATE rcv_queues
SET verify_key = :verify_key, status = :status
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|]
[ ":verify_key" := Just verifyKey,
":status" := Active,
":host" := host,
":port" := serializePort_ port,
":rcv_id" := rcvId
]
setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m ()
setSndQueueStatus SQLiteStore {dbConn} sndQueue status =
setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status =
-- ? throw error if queue doesn't exist?
liftIO $
updateSndQueueStatus dbConn sndQueue status
DB.executeNamed
dbConn
[sql|
UPDATE snd_queues
SET status = :status
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
liftIOEither $
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs)
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
liftIO . DB.withTransaction dbConn $ do
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
let internalId = InternalId $ unId lastInternalId + 1
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs =
liftIOEither $
insertSndMsg dbConn connAlias msgBody internalTs
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
liftIO . DB.withTransaction dbConn $ do
insertRcvMsgBase_ dbConn connAlias rcvMsgData
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
updateHashRcv_ dbConn connAlias rcvMsgData
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
liftIO . DB.withTransaction dbConn $ do
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
let internalId = InternalId $ unId lastInternalId + 1
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
pure (internalId, internalSndId, prevSndHash)
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
liftIO . DB.withTransaction dbConn $ do
insertSndMsgBase_ dbConn connAlias sndMsgData
insertSndMsgDetails_ dbConn connAlias sndMsgData
updateHashSnd_ dbConn connAlias sndMsgData
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
getMsg _st _connAlias _id = throwError SENotImplemented
@@ -179,6 +277,16 @@ instance ToField RcvMsgStatus where toField = toField . show
instance ToField SndMsgStatus where toField = toField . show
instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
instance FromField MsgIntegrity where
fromField = \case
f@(Field (SQLBlob b) _) ->
case parseAll msgIntegrityP b of
Right k -> Ok k
Left e -> returnError ConversionFailed f ("can't parse msg integrity field: " ++ e)
f -> returnError ConversionFailed f "expecting SQLBlob column type"
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
fromFieldToReadable_ = \case
f@(Field (SQLText t) _) ->
@@ -217,13 +325,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
-- * createRcvConn helpers
createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO ()
createRcvQueueAndConn dbConn rcvQueue =
DB.withTransaction dbConn $ do
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
insertRcvQueue_ dbConn rcvQueue
insertRcvConnection_ dbConn rcvQueue
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
insertRcvQueue_ dbConn RcvQueue {..} = do
let port_ = serializePort_ $ port server
@@ -255,22 +356,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
0, 0, 0);
0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
-- * createSndConn helpers
createSndQueueAndConn :: DB.Connection -> SndQueue -> IO ()
createSndQueueAndConn dbConn sndQueue =
DB.withTransaction dbConn $ do
upsertServer_ dbConn (server (sndQueue :: SndQueue))
insertSndQueue_ dbConn sndQueue
insertSndConnection_ dbConn sndQueue
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
insertSndQueue_ dbConn SndQueue {..} = do
let port_ = serializePort_ $ port server
@@ -300,28 +395,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
0, 0, 0);
0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
-- * getConn helpers
retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
retrieveConnQueues dbConn connAlias =
DB.withTransaction -- Avoid inconsistent state between queue reads
dbConn
$ retrieveConnQueues_ dbConn connAlias
-- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap
-- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction
retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
retrieveConnQueues_ dbConn connAlias = do
rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
return (rcvQ, sndQ)
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
getConn_ dbConn connAlias = do
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
pure $ case (rQ, sQ) of
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
_ -> Left SEConnNotFound
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
@@ -363,60 +455,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
_ -> return Nothing
-- * getAllConnAliases helper
retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias]
retrieveAllConnAliases dbConn = do
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
return (concat r)
-- * getRcvQueue helper
retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue)
retrieveRcvQueue dbConn host port rcvId = do
r <-
DB.queryNamed
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|]
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
case r of
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
let srv = SMPServer hst (deserializePort_ prt) keyHash
return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
_ -> return Nothing
-- * deleteConn helper
deleteConnCascade :: DB.Connection -> ConnAlias -> IO ()
deleteConnCascade dbConn connAlias =
DB.executeNamed
dbConn
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
[":conn_alias" := connAlias]
-- * upgradeRcvConnToDuplex helpers
updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ())
updateRcvConnWithSndQueue dbConn connAlias sndQueue =
DB.withTransaction dbConn $ do
queues <- retrieveConnQueues_ dbConn connAlias
case queues of
(Just _rcvQ, Nothing) -> do
upsertServer_ dbConn (server (sndQueue :: SndQueue))
insertSndQueue_ dbConn sndQueue
updateConnWithSndQueue_ dbConn connAlias sndQueue
return $ Right ()
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
_ -> return $ Left SEBadConn
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
let port_ = serializePort_ $ port server
@@ -431,20 +471,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
-- * upgradeSndConnToDuplex helpers
updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ())
updateSndConnWithRcvQueue dbConn connAlias rcvQueue =
DB.withTransaction dbConn $ do
queues <- retrieveConnQueues_ dbConn connAlias
case queues of
(Nothing, Just _sndQ) -> do
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
insertRcvQueue_ dbConn rcvQueue
updateConnWithRcvQueue_ dbConn connAlias rcvQueue
return $ Right ()
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
_ -> return $ Left SEBadConn
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
let port_ = serializePort_ $ port server
@@ -457,74 +483,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|]
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
-- * setRcvQueueStatus helper
-- * updateRcvIds helpers
-- ? throw error if queue doesn't exist?
updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status =
DB.executeNamed
dbConn
[sql|
UPDATE rcv_queues
SET status = :status
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
-- * setSndQueueStatus helper
-- ? throw error if queue doesn't exist?
updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO ()
updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status =
DB.executeNamed
dbConn
[sql|
UPDATE snd_queues
SET status = :status
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
-- * createRcvMsg helpers
insertRcvMsg ::
DB.Connection ->
ConnAlias ->
MsgBody ->
InternalTs ->
(ExternalSndId, ExternalSndTs) ->
(BrokerId, BrokerTs) ->
IO (Either StoreError InternalId)
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
DB.withTransaction dbConn $ do
queues <- retrieveConnQueues_ dbConn connAlias
case queues of
(Just _rcvQ, _) -> do
(lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias
let internalId = InternalId $ unId lastInternalId + 1
let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs)
updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId
return $ Right internalId
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
_ -> return $ Left SEBadConn
retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId)
retrieveLastInternalIdsRcv_ dbConn connAlias = do
[(lastInternalId, lastInternalRcvId)] <-
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
DB.queryNamed
dbConn
[sql|
SELECT last_internal_msg_id, last_internal_rcv_msg_id
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
return (lastInternalId, lastInternalRcvId)
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO ()
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
DB.executeNamed
dbConn
[sql|
UPDATE connections
SET last_internal_msg_id = :last_internal_msg_id,
last_internal_rcv_msg_id = :last_internal_rcv_msg_id
WHERE conn_alias = :conn_alias;
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_rcv_msg_id" := newInternalRcvId,
":conn_alias" := connAlias
]
-- * createRcvMsg helpers
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -540,82 +532,85 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody =
":body" := decodeUtf8 msgBody
]
insertRcvMsgDetails_ ::
DB.Connection ->
ConnAlias ->
InternalRcvId ->
InternalId ->
(ExternalSndId, ExternalSndTs) ->
(BrokerId, BrokerTs) ->
IO ()
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) =
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
DB.executeNamed
dbConn
[sql|
INSERT INTO rcv_messages
( conn_alias, internal_rcv_id, internal_id, external_snd_id, external_snd_ts,
broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts)
broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts,
internal_hash, external_prev_snd_hash, integrity)
VALUES
(:conn_alias,:internal_rcv_id,:internal_id,:external_snd_id,:external_snd_ts,
:broker_id,:broker_ts,:rcv_status, NULL, NULL);
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
:internal_hash,:external_prev_snd_hash,:integrity);
|]
[ ":conn_alias" := connAlias,
":internal_rcv_id" := internalRcvId,
":internal_id" := internalId,
":external_snd_id" := externalSndId,
":external_snd_ts" := externalSndTs,
":broker_id" := brokerId,
":broker_ts" := brokerTs,
":rcv_status" := Received
":external_snd_id" := fst senderMeta,
":external_snd_ts" := snd senderMeta,
":broker_id" := fst brokerMeta,
":broker_ts" := snd brokerMeta,
":rcv_status" := Received,
":internal_hash" := internalHash,
":external_prev_snd_hash" := externalPrevSndHash,
":integrity" := msgIntegrity
]
updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
[sql|
UPDATE connections
SET last_external_snd_msg_id = :last_external_snd_msg_id,
last_rcv_msg_hash = :last_rcv_msg_hash
WHERE conn_alias = :conn_alias
AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id;
|]
[ ":last_external_snd_msg_id" := fst senderMeta,
":last_rcv_msg_hash" := internalHash,
":conn_alias" := connAlias,
":last_internal_rcv_msg_id" := internalRcvId
]
-- * updateSndIds helpers
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
DB.queryNamed
dbConn
[sql|
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
return (lastInternalId, lastInternalSndId, lastSndHash)
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
DB.executeNamed
dbConn
[sql|
UPDATE connections
SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id
SET last_internal_msg_id = :last_internal_msg_id,
last_internal_snd_msg_id = :last_internal_snd_msg_id
WHERE conn_alias = :conn_alias;
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_rcv_msg_id" := newInternalRcvId,
":last_internal_snd_msg_id" := newInternalSndId,
":conn_alias" := connAlias
]
-- * createSndMsg helpers
insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId)
insertSndMsg dbConn connAlias msgBody internalTs =
DB.withTransaction dbConn $ do
queues <- retrieveConnQueues_ dbConn connAlias
case queues of
(_, Just _sndQ) -> do
(lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias
let internalId = InternalId $ unId lastInternalId + 1
let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody
insertSndMsgDetails_ dbConn connAlias internalSndId internalId
updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId
return $ Right internalId
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
_ -> return $ Left SEBadConn
retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId)
retrieveLastInternalIdsSnd_ dbConn connAlias = do
[(lastInternalId, lastInternalSndId)] <-
DB.queryNamed
dbConn
[sql|
SELECT last_internal_msg_id, last_internal_snd_msg_id
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
return (lastInternalId, lastInternalSndId)
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO ()
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -631,32 +626,35 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody =
":body" := decodeUtf8 msgBody
]
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO ()
insertSndMsgDetails_ dbConn connAlias internalSndId internalId =
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
DB.executeNamed
dbConn
[sql|
INSERT INTO snd_messages
( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts)
( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash)
VALUES
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL);
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|]
[ ":conn_alias" := connAlias,
":internal_snd_id" := internalSndId,
":internal_id" := internalId,
":snd_status" := Created
":snd_status" := Created,
":internal_hash" := internalHash
]
updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
updateHashSnd_ dbConn connAlias SndMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
[sql|
UPDATE connections
SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id
WHERE conn_alias = :conn_alias;
SET last_snd_msg_hash = :last_snd_msg_hash
WHERE conn_alias = :conn_alias
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_snd_msg_id" := newInternalSndId,
":conn_alias" := connAlias
[ ":last_snd_msg_hash" := internalHash,
":conn_alias" := connAlias,
":last_internal_snd_msg_id" := internalSndId
]

View File

@@ -90,6 +90,9 @@ connections =
last_internal_msg_id INTEGER NOT NULL,
last_internal_rcv_msg_id INTEGER NOT NULL,
last_internal_snd_msg_id INTEGER NOT NULL,
last_external_snd_msg_id INTEGER NOT NULL,
last_rcv_msg_hash BLOB NOT NULL,
last_snd_msg_hash BLOB NOT NULL,
PRIMARY KEY (conn_alias),
FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id),
FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id)
@@ -135,6 +138,9 @@ rcvMessages =
rcv_status TEXT NOT NULL,
ack_brocker_ts TEXT,
ack_sender_ts TEXT,
internal_hash BLOB NOT NULL,
external_prev_snd_hash BLOB NOT NULL,
integrity BLOB NOT NULL,
PRIMARY KEY (conn_alias, internal_rcv_id),
FOREIGN KEY (conn_alias, internal_id)
REFERENCES messages (conn_alias, internal_id)
@@ -152,6 +158,7 @@ sndMessages =
snd_status TEXT NOT NULL,
sent_ts TEXT,
delivered_ts TEXT,
internal_hash BLOB NOT NULL,
PRIMARY KEY (conn_alias, internal_snd_id),
FOREIGN KEY (conn_alias, internal_id)
REFERENCES messages (conn_alias, internal_id)

View File

@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -12,7 +13,7 @@
module Simplex.Messaging.Agent.Transmission where
import Control.Applicative ((<|>))
import Control.Applicative (optional, (<|>))
import Control.Monad.IO.Class
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
@@ -21,13 +22,14 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Kind
import Data.Kind (Type)
import Data.Time.Clock (UTCTime)
import Data.Time.ISO8601
import Data.Type.Equality
import Data.Typeable ()
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import Network.Socket
import Numeric.Natural
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol
@@ -37,12 +39,12 @@ import Simplex.Messaging.Protocol
MsgBody,
MsgId,
SenderPublicKey,
errMessageBody,
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Transport
import Simplex.Messaging.Util
import System.IO
import Test.QuickCheck (Arbitrary (..))
import Text.Read
import UnliftIO.Exception
@@ -88,11 +90,11 @@ data ACommand (p :: AParty) where
SEND :: MsgBody -> ACommand Client
SENT :: AgentMsgId -> ACommand Agent
MSG ::
{ m_recipient :: (AgentMsgId, UTCTime),
m_broker :: (MsgId, UTCTime),
m_sender :: (AgentMsgId, UTCTime),
m_status :: MsgStatus,
m_body :: MsgBody
{ recipientMeta :: (AgentMsgId, UTCTime),
brokerMeta :: (MsgId, UTCTime),
senderMeta :: (AgentMsgId, UTCTime),
msgIntegrity :: MsgIntegrity,
msgBody :: MsgBody
} ->
ACommand Agent
-- ACK :: AgentMsgId -> ACommand Client
@@ -125,7 +127,7 @@ data AMessage where
deriving (Show)
parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage
parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage
parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE
where
smpMessageP :: Parser SMPMessage
smpMessageP =
@@ -140,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage
SMPMessage
<$> A.decimal <* A.space
<*> tsISO8601P <* A.space
<*> base64P <* A.endOfLine
-- TODO previous message hash should become mandatory when we support HELLO and REPLY
-- (for HELLO it would be the hash of SMPConfirmation)
<*> (base64P <|> pure "") <* A.endOfLine
<*> agentMessageP
serializeSMPMessage :: SMPMessage -> ByteString
@@ -152,7 +156,7 @@ serializeSMPMessage = \case
in smpMessage "" header body
where
messageHeader msgId ts prevMsgHash =
B.unwords [B.pack $ show msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash]
B.unwords [bshow msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash]
smpMessage smpHeader aHeader aBody = B.intercalate "\n" [smpHeader, aHeader, aBody, ""]
agentMessageP :: Parser AMessage
@@ -164,8 +168,8 @@ agentMessageP =
hello = HELLO <$> C.pubKeyP <*> ackMode
reply = REPLY <$> smpQueueInfoP
a_msg = do
size :: Int <- A.decimal
A_MSG <$> (A.endOfLine *> A.take size <* A.endOfLine)
size :: Int <- A.decimal <* A.endOfLine
A_MSG <$> A.take size <* A.endOfLine
ackMode = " NO_ACK" $> AckMode Off <|> pure (AckMode On)
smpQueueInfoP :: Parser SMPQueueInfo
@@ -173,14 +177,14 @@ smpQueueInfoP =
"smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP)
smpServerP :: Parser SMPServer
smpServerP = SMPServer <$> server <*> port <*> msgHash
smpServerP = SMPServer <$> server <*> optional port <*> optional kHash
where
server = B.unpack <$> A.takeTill (A.inClass ":# ")
port = A.char ':' *> (Just . show <$> (A.decimal :: Parser Int)) <|> pure Nothing
msgHash = A.char '#' *> (Just <$> base64P) <|> pure Nothing
port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit)
kHash = A.char '#' *> C.keyHashP
parseAgentMessage :: ByteString -> Either AgentErrorType AMessage
parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage
parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE
serializeAgentMessage :: AMessage -> ByteString
serializeAgentMessage = \case
@@ -194,17 +198,15 @@ serializeSmpQueueInfo (SMPQueueInfo srv qId ek) =
serializeServer :: SMPServer -> ByteString
serializeServer SMPServer {host, port, keyHash} =
B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash
B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack . C.serializeKeyHash) keyHash
data SMPServer = SMPServer
{ host :: HostName,
port :: Maybe ServiceName,
keyHash :: Maybe KeyHash
keyHash :: Maybe C.KeyHash
}
deriving (Eq, Ord, Show)
type KeyHash = Encoded
type ConnAlias = ByteString
type OtherPartyId = Encoded
@@ -220,9 +222,9 @@ data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Eq, Show)
type EncryptionKey = C.PublicKey
type DecryptionKey = C.PrivateKey
type DecryptionKey = C.SafePrivateKey
type SignatureKey = C.PrivateKey
type SignatureKey = C.SafePrivateKey
type VerificationKey = C.PublicKey
@@ -235,56 +237,60 @@ type AgentMsgId = Int64
type SenderTimestamp = UTCTime
data MsgStatus = MsgOk | MsgError MsgErrorType
data MsgIntegrity = MsgOk | MsgError MsgErrorType
deriving (Eq, Show)
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate
deriving (Eq, Show)
-- | error type used in errors sent to agent clients
data AgentErrorType
= UNKNOWN
| PROHIBITED
| SYNTAX Int
| BROKER Natural
| SMP ErrorType
| CRYPTO C.CryptoError
| SIZE
| STORE
| INTERNAL -- etc. TODO SYNTAX Natural
deriving (Eq, Show, Exception)
= CMD CommandErrorType -- command errors
| CONN ConnectionErrorType -- connection state errors
| SMP ErrorType -- SMP protocol errors forwarded to agent clients
| BROKER BrokerErrorType -- SMP server errors
| AGENT SMPAgentError -- errors of other agents
| INTERNAL String -- agent implementation errors
deriving (Eq, Generic, Read, Show, Exception)
data AckStatus = AckOk | AckError AckErrorType
deriving (Show)
data CommandErrorType
= PROHIBITED -- command is prohibited
| SYNTAX -- command syntax is invalid
| NO_CONN -- connection alias is required with this command
| SIZE -- message size is not correct (no terminating space)
| LARGE -- message does not fit SMP block
deriving (Eq, Generic, Read, Show, Exception)
data AckErrorType = AckUnknown | AckProhibited | AckSyntax Int -- etc.
deriving (Show)
data ConnectionErrorType
= UNKNOWN -- connection alias not in database
| DUPLICATE -- connection alias already exists
| SIMPLEX -- connection is simplex, but operation requires another queue
deriving (Eq, Generic, Read, Show, Exception)
errBadEncoding :: Int
errBadEncoding = 10
data BrokerErrorType
= RESPONSE ErrorType -- invalid server response (failed to parse)
| UNEXPECTED -- unexpected response
| NETWORK -- network error
| TRANSPORT TransportError -- handshake or other transport error
| TIMEOUT -- command response timeout
deriving (Eq, Generic, Read, Show, Exception)
errBadCommand :: Int
errBadCommand = 11
data SMPAgentError
= A_MESSAGE -- possibly should include bytestring that failed to parse
| A_PROHIBITED -- possibly should include the prohibited SMP/agent message
| A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header
| A_SIGNATURE -- invalid RSA signature
deriving (Eq, Generic, Read, Show, Exception)
errBadInvitation :: Int
errBadInvitation = 12
instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU
errNoConnAlias :: Int
errNoConnAlias = 13
instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU
errBadMessage :: Int
errBadMessage = 14
instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU
errBadServer :: Int
errBadServer = 15
instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
smpErrTCPConnection :: Natural
smpErrTCPConnection = 1
smpErrCorrelationId :: Natural
smpErrCorrelationId = 2
smpUnexpectedResponse :: Natural
smpUnexpectedResponse = 3
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
commandP :: Parser ACmd
commandP =
@@ -309,28 +315,30 @@ commandP =
sendCmd = ACmd SClient . SEND <$> A.takeByteString
sentResp = ACmd SAgent . SENT <$> A.decimal
message = do
m_status <- status <* A.space
m_recipient <- "R=" *> partyMeta A.decimal
m_broker <- "B=" *> partyMeta base64P
m_sender <- "S=" *> partyMeta A.decimal
m_body <- A.takeByteString
return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body}
-- TODO other error types
agentError = ACmd SAgent . ERR <$> ("SMP " *> smpErrorType)
smpErrorType = "AUTH" $> SMP SMP.AUTH
msgIntegrity <- msgIntegrityP <* A.space
recipientMeta <- "R=" *> partyMeta A.decimal
brokerMeta <- "B=" *> partyMeta base64P
senderMeta <- "S=" *> partyMeta A.decimal
msgBody <- A.takeByteString
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
replyMode =
" NO_REPLY" $> ReplyOff
<|> A.space *> (ReplyVia <$> smpServerP)
<|> pure ReplyOn
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
msgIntegrityP :: Parser MsgIntegrity
msgIntegrityP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
where
msgErrorType =
"ID " *> (MsgBadId <$> A.decimal)
<|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
<|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
<|> "HASH" $> MsgBadHash
<|> "DUPLICATE" $> MsgDuplicate
parseCommand :: ByteString -> Either AgentErrorType ACmd
parseCommand = parse commandP $ SYNTAX errBadCommand
parseCommand = parse commandP $ CMD SYNTAX
serializeCommand :: ACommand p -> ByteString
serializeCommand = \case
@@ -342,19 +350,19 @@ serializeCommand = \case
END -> "END"
SEND msgBody -> "SEND " <> serializeMsg msgBody
SENT mId -> "SENT " <> bshow mId
MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} ->
MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} ->
B.unwords
[ "MSG",
msgStatus m_status,
serializeMsgIntegrity msgIntegrity,
"R=" <> bshow rmId <> "," <> showTs rTs,
"B=" <> encode bmId <> "," <> showTs bTs,
"S=" <> bshow smId <> "," <> showTs sTs,
serializeMsg m_body
serializeMsg msgBody
]
OFF -> "OFF"
DEL -> "DEL"
CON -> "CON"
ERR e -> "ERR " <> B.pack (show e)
ERR e -> "ERR " <> serializeAgentError e
OK -> "OK"
where
replyMode :: ReplyMode -> ByteString
@@ -364,19 +372,35 @@ serializeCommand = \case
ReplyOn -> ""
showTs :: UTCTime -> ByteString
showTs = B.pack . formatISO8601Millis
msgStatus :: MsgStatus -> ByteString
msgStatus = \case
MsgOk -> "OK"
MsgError e ->
"ERR" <> case e of
MsgSkipped fromMsgId toMsgId ->
B.unwords ["NO_ID", B.pack $ show fromMsgId, B.pack $ show toMsgId]
MsgBadId aMsgId -> "ID " <> B.pack (show aMsgId)
MsgBadHash -> "HASH"
-- TODO - save function as in the server Transmission - re-use?
serializeMsgIntegrity :: MsgIntegrity -> ByteString
serializeMsgIntegrity = \case
MsgOk -> "OK"
MsgError e ->
"ERR " <> case e of
MsgSkipped fromMsgId toMsgId ->
B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId]
MsgBadId aMsgId -> "ID " <> bshow aMsgId
MsgBadHash -> "HASH"
MsgDuplicate -> "DUPLICATE"
agentErrorTypeP :: Parser AgentErrorType
agentErrorTypeP =
"SMP " *> (SMP <$> SMP.errorTypeP)
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP)
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
<|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString)
<|> parseRead2
serializeAgentError :: AgentErrorType -> ByteString
serializeAgentError = \case
SMP e -> "SMP " <> SMP.serializeErrorType e
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
e -> bshow e
serializeMsg :: ByteString -> ByteString
serializeMsg body = B.pack (show $ B.length body) <> "\n" <> body
serializeMsg body = bshow (B.length body) <> "\n" <> body
tPutRaw :: Handle -> ARawTransmission -> IO ()
tPutRaw h (corrId, connAlias, command) = do
@@ -404,7 +428,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
fromParty :: ACmd -> Either AgentErrorType (ACommand p)
fromParty (ACmd (p :: p1) cmd) = case testEquality party p of
Just Refl -> Right cmd
_ -> Left PROHIBITED
_ -> Left $ CMD PROHIBITED
tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p)
tConnAlias (_, connAlias, _) cmd = case cmd of
@@ -415,13 +439,13 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
ERR _ -> Right cmd
-- other responses must have connAlias
_
| B.null connAlias -> Left $ SYNTAX errNoConnAlias
| B.null connAlias -> Left $ CMD NO_CONN
| otherwise -> Right cmd
cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p))
cmdWithMsgBody = \case
SEND body -> SEND <$$> getMsgBody body
MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
cmd -> return $ Right cmd
-- TODO refactor with server
@@ -433,5 +457,5 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
Just size -> liftIO $ do
body <- B.hGet h size
s <- getLn h
return $ if B.null s then Right body else Left SIZE
Nothing -> return . Left $ SYNTAX errMessageBody
return $ if B.null s then Right body else Left $ CMD SIZE
Nothing -> return . Left $ CMD SYNTAX

View File

@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
@@ -8,7 +9,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Client
( SMPClient,
( SMPClient (blockSize),
getSMPClient,
closeSMPClient,
createSMPQueue,
@@ -33,33 +34,31 @@ import Control.Exception
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.RSA.Types as RSA
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe
import GHC.IO.Exception (IOErrorType (..))
import Network.Socket (ServiceName)
import Numeric.Natural
import Simplex.Messaging.Agent.Transmission (SMPServer (..))
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
import Simplex.Messaging.Util (liftEitherError, raceAny_)
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
import System.IO
import System.IO.Error
import System.Timeout
data SMPClient = SMPClient
{ action :: Async (),
connected :: TVar Bool,
smpServer :: SMPServer,
tcpTimeout :: Int,
clientCorrId :: TVar Natural,
sentCommands :: TVar (Map CorrId Request),
sndQ :: TBQueue SignedRawTransmission,
rcvQ :: TBQueue SignedTransmissionOrError,
msgQ :: TBQueue SMPServerTransmission
msgQ :: TBQueue SMPServerTransmission,
blockSize :: Int
}
type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker)
@@ -69,7 +68,6 @@ data SMPClientConfig = SMPClientConfig
defaultPort :: ServiceName,
tcpTimeout :: Int,
smpPing :: Int,
blockSize :: Int,
smpCommandSize :: Int
}
@@ -78,36 +76,36 @@ smpDefaultConfig =
SMPClientConfig
{ qSize = 16,
defaultPort = "5223",
tcpTimeout = 2_000_000,
tcpTimeout = 4_000_000,
smpPing = 30_000_000,
blockSize = 8_192, -- 16_384,
smpCommandSize = 256
}
data Request = Request
{ queueId :: QueueId,
responseVar :: TMVar (Either SMPClientError Cmd)
responseVar :: TMVar Response
}
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient
type Response = Either SMPClientError Cmd
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient)
getSMPClient
smpServer@SMPServer {host, port}
smpServer@SMPServer {host, port, keyHash}
SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing}
msgQ
disconnected = do
c <- atomically mkSMPClient
started <- newEmptyTMVarIO
thVar <- newEmptyTMVarIO
action <-
async $
runTCPClient host (fromMaybe defaultPort port) (client c started)
`finally` atomically (putTMVar started False)
tcpTimeout `timeout` atomically (takeTMVar started) >>= \case
Just True -> return c {action}
_ -> throwIO err
runTCPClient host (fromMaybe defaultPort port) (client c thVar)
`finally` atomically (putTMVar thVar $ Left SMPNetworkError)
tHandle <- tcpTimeout `timeout` atomically (takeTMVar thVar)
pure $ case tHandle of
Just (Right THandle {blockSize}) -> Right c {action, blockSize}
Just (Left e) -> Left e
Nothing -> Left SMPNetworkError
where
err :: IOException
err = mkIOError TimeExpired "connection timeout" Nothing Nothing
mkSMPClient :: STM SMPClient
mkSMPClient = do
connected <- newTVar False
@@ -118,8 +116,10 @@ getSMPClient
return
SMPClient
{ action = undefined,
blockSize = undefined,
connected,
smpServer,
tcpTimeout,
clientCorrId,
sentCommands,
sndQ,
@@ -127,19 +127,24 @@ getSMPClient
msgQ
}
client :: SMPClient -> TMVar Bool -> Handle -> IO ()
client c started h = do
_ <- getLn h -- "Welcome to SMP"
client :: SMPClient -> TMVar (Either SMPClientError THandle) -> Handle -> IO ()
client c thVar h =
runExceptT (clientHandshake h keyHash) >>= \case
Right th -> clientTransport c thVar th
Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e
clientTransport :: SMPClient -> TMVar (Either SMPClientError THandle) -> THandle -> IO ()
clientTransport c thVar th = do
atomically $ do
modifyTVar (connected c) (const True)
putTMVar started True
raceAny_ [send c h, process c, receive c h, ping c]
writeTVar (connected c) True
putTMVar thVar $ Right th
raceAny_ [send c th, process c, receive c th, ping c]
`finally` disconnected
send :: SMPClient -> Handle -> IO ()
send :: SMPClient -> THandle -> IO ()
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
receive :: SMPClient -> Handle -> IO ()
receive :: SMPClient -> THandle -> IO ()
receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ
ping :: SMPClient -> IO ()
@@ -165,7 +170,7 @@ getSMPClient
Left e -> Left $ SMPResponseError e
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
Right r -> Right r
else Left SMPQueueIdError
else Left SMPUnexpectedResponse
closeSMPClient :: SMPClient -> IO ()
closeSMPClient = uninterruptibleCancel . action
@@ -173,11 +178,11 @@ closeSMPClient = uninterruptibleCancel . action
data SMPClientError
= SMPServerError ErrorType
| SMPResponseError ErrorType
| SMPQueueIdError
| SMPUnexpectedResponse
| SMPResponseTimeout
| SMPCryptoError RSA.Error
| SMPClientError
| SMPNetworkError
| SMPTransportError TransportError
| SMPSignatureError C.CryptoError
deriving (Eq, Show, Exception)
createSMPQueue ::
@@ -222,14 +227,14 @@ suspendSMPQueue = okSMPCommand $ Cmd SRecipient OFF
deleteSMPQueue :: SMPClient -> RecipientPrivateKey -> QueueId -> ExceptT SMPClientError IO ()
deleteSMPQueue = okSMPCommand $ Cmd SRecipient DEL
okSMPCommand :: Cmd -> SMPClient -> C.PrivateKey -> QueueId -> ExceptT SMPClientError IO ()
okSMPCommand :: Cmd -> SMPClient -> C.SafePrivateKey -> QueueId -> ExceptT SMPClientError IO ()
okSMPCommand cmd c pKey qId =
sendSMPCommand c (Just pKey) qId cmd >>= \case
Cmd _ OK -> return ()
_ -> throwE SMPUnexpectedResponse
sendSMPCommand :: SMPClient -> Maybe C.PrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do
sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do
corrId <- lift_ getNextCorrId
t <- signTransmission $ serializeTransmission (corrId, qId, cmd)
ExceptT $ sendRecv corrId t
@@ -241,20 +246,22 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do
getNextCorrId = do
i <- (+ 1) <$> readTVar clientCorrId
writeTVar clientCorrId i
return . CorrId . B.pack $ show i
return . CorrId $ bshow i
signTransmission :: ByteString -> ExceptT SMPClientError IO SignedRawTransmission
signTransmission t = case pKey of
Nothing -> return ("", t)
Just pk -> do
sig <- liftEitherError SMPCryptoError $ C.sign pk t
sig <- liftError SMPSignatureError $ C.sign pk t
return (sig, t)
-- two separate "atomically" needed to avoid blocking
sendRecv :: CorrId -> SignedRawTransmission -> IO (Either SMPClientError Cmd)
sendRecv corrId t = atomically (send corrId t) >>= atomically . takeTMVar
sendRecv :: CorrId -> SignedRawTransmission -> IO Response
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
where
withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a
send :: CorrId -> SignedRawTransmission -> STM (TMVar (Either SMPClientError Cmd))
send :: CorrId -> SignedRawTransmission -> STM (TMVar Response)
send corrId t = do
r <- newEmptyTMVar
modifyTVar sentCommands . M.insert corrId $ Request qId r

View File

@@ -1,26 +1,53 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Simplex.Messaging.Crypto
( PrivateKey (..),
( PrivateKey (rsaPrivateKey),
SafePrivateKey, -- constructor is not exported
FullPrivateKey (..),
PublicKey (..),
Signature (..),
CryptoError (..),
SafeKeyPair,
FullKeyPair,
Key (..),
IV (..),
KeyHash (..),
generateKeyPair,
publicKey,
publicKeySize,
safePrivateKey,
sign,
verify,
encrypt,
decrypt,
encryptOAEP,
decryptOAEP,
encryptAES,
decryptAES,
serializePrivKey,
serializePubKey,
parsePrivKey,
parsePubKey,
encodePubKey,
serializeKeyHash,
getKeyHash,
sha256Hash,
privKeyP,
pubKeyP,
binaryPubKeyP,
keyHashP,
authTagSize,
authTagToBS,
bsToAuthTag,
randomAesKey,
randomIV,
aesKeyP,
ivP,
)
where
@@ -30,60 +57,89 @@ import Control.Monad.Trans.Except
import Crypto.Cipher.AES (AES256)
import qualified Crypto.Cipher.Types as AES
import qualified Crypto.Error as CE
import Crypto.Hash.Algorithms (SHA256 (..))
import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash)
import Crypto.Number.Generate (generateMax)
import Crypto.Number.Prime (findPrimeFrom)
import Crypto.Number.Serialize (i2osp, os2ip)
import qualified Crypto.PubKey.RSA as R
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import qualified Crypto.PubKey.RSA.PSS as PSS
import Crypto.Random (getRandomBytes)
import Data.ASN1.BinaryEncoding
import Data.ASN1.Encoding
import Data.ASN1.Types
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.Bifunctor (bimap, first)
import qualified Data.ByteArray as BA
import Data.ByteString.Base64
import Data.ByteString.Base64 (decode, encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.ByteString.Internal (c2w, w2c)
import Data.ByteString.Lazy (fromStrict, toStrict)
import Data.String
import Database.SQLite.Simple as DB
import Database.SQLite.Simple.FromField
import Data.Typeable (Typeable)
import Data.X509
import Database.SQLite.Simple (ResultError (..), SQLData (..))
import Database.SQLite.Simple.FromField (FieldParser, FromField (..), returnError)
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Database.SQLite.Simple.ToField (ToField (..))
import Network.Transport.Internal (decodeWord32, encodeWord32)
import Simplex.Messaging.Parsers (base64P)
import Simplex.Messaging.Util (bshow, liftEitherError, (<$$>))
import Simplex.Messaging.Parsers (base64P, parseAll)
import Simplex.Messaging.Util (liftEitherError, (<$?>))
newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show)
data PrivateKey = PrivateKey
{ private_size :: Int,
private_n :: Integer,
private_d :: Integer
}
deriving (Eq, Show)
newtype SafePrivateKey = SafePrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show)
instance ToField PrivateKey where toField = toField . serializePrivKey
newtype FullPrivateKey = FullPrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show)
instance ToField PublicKey where toField = toField . serializePubKey
class PrivateKey k where
rsaPrivateKey :: k -> R.PrivateKey
_privateKey :: R.PrivateKey -> k
mkPrivateKey :: R.PrivateKey -> k
instance FromField PrivateKey where
fromField f@(Field (SQLBlob b) _) =
case parsePrivKey b of
instance PrivateKey SafePrivateKey where
rsaPrivateKey = unPrivateKey
_privateKey = SafePrivateKey
mkPrivateKey R.PrivateKey {private_pub = k, private_d} =
safePrivateKey (R.public_size k, R.public_n k, private_d)
instance PrivateKey FullPrivateKey where
rsaPrivateKey = unPrivateKey
_privateKey = FullPrivateKey
mkPrivateKey = FullPrivateKey
instance IsString FullPrivateKey where
fromString = parseString (decode >=> decodePrivKey)
instance IsString PublicKey where
fromString = parseString (decode >=> decodePubKey)
parseString :: (ByteString -> Either String a) -> (String -> a)
parseString parse = either error id . parse . B.pack
instance ToField SafePrivateKey where toField = toField . encodePrivKey
instance ToField PublicKey where toField = toField . encodePubKey
instance FromField SafePrivateKey where fromField = keyFromField binaryPrivKeyP
instance FromField PublicKey where fromField = keyFromField binaryPubKeyP
keyFromField :: Typeable k => Parser k -> FieldParser k
keyFromField p = \case
f@(Field (SQLBlob b) _) ->
case parseAll p b of
Right k -> Ok k
Left e -> returnError ConversionFailed f ("couldn't parse PrivateKey field: " ++ e)
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
Left e -> returnError ConversionFailed f ("couldn't parse key field: " ++ e)
f -> returnError ConversionFailed f "expecting SQLBlob column type"
instance FromField PublicKey where
fromField f@(Field (SQLBlob b) _) =
case parsePubKey b of
Right k -> Ok k
Left e -> returnError ConversionFailed f ("couldn't parse PublicKey field: " ++ e)
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
type KeyPair k = (PublicKey, k)
type KeyPair = (PublicKey, PrivateKey)
type SafeKeyPair = (PublicKey, SafePrivateKey)
type FullKeyPair = (PublicKey, FullPrivateKey)
newtype Signature = Signature {unSignature :: ByteString} deriving (Eq, Show)
@@ -93,10 +149,12 @@ instance IsString Signature where
newtype Verified = Verified ByteString deriving (Show)
data CryptoError
= CryptoRSAError R.Error
| CryptoCipherError CE.CryptoError
= RSAEncryptError R.Error
| RSADecryptError R.Error
| RSASignError R.Error
| AESCipherError CE.CryptoError
| CryptoIVError
| CryptoDecryptError
| AESDecryptError
| CryptoLargeMsgError
| CryptoHeaderError String
deriving (Eq, Show, Exception)
@@ -110,19 +168,26 @@ aesKeySize = 256 `div` 8
authTagSize :: Int
authTagSize = 128 `div` 8
generateKeyPair :: Int -> IO KeyPair
generateKeyPair :: PrivateKey k => Int -> IO (KeyPair k)
generateKeyPair size = loop
where
publicExponent = findPrimeFrom . (+ 3) <$> generateMax pubExpRange
privateKey s n d = PrivateKey {private_size = s, private_n = n, private_d = d}
loop = do
(pub, priv) <- R.generate size =<< publicExponent
let s = R.public_size pub
n = R.public_n pub
d = R.private_d priv
in if d * d < n
then loop
else return (PublicKey pub, privateKey s n d)
(k, pk) <- R.generate size =<< publicExponent
let n = R.public_n k
d = R.private_d pk
if d * d < n
then loop
else pure (PublicKey k, mkPrivateKey pk)
privateKeySize :: PrivateKey k => k -> Int
privateKeySize = R.public_size . R.private_pub . rsaPrivateKey
publicKey :: FullPrivateKey -> PublicKey
publicKey = PublicKey . R.private_pub . rsaPrivateKey
publicKeySize :: PublicKey -> Int
publicKeySize = R.public_size . rsaPublicKey
data Header = Header
{ aesKey :: Key,
@@ -135,52 +200,89 @@ newtype Key = Key {unKey :: ByteString}
newtype IV = IV {unIV :: ByteString}
newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show)
instance IsString KeyHash where
fromString = parseString $ parseAll keyHashP
instance ToField KeyHash where toField = toField . serializeKeyHash
instance FromField KeyHash where
fromField f@(Field (SQLBlob b) _) =
case parseAll keyHashP b of
Right k -> Ok k
Left e -> returnError ConversionFailed f ("couldn't parse KeyHash field: " ++ e)
fromField f = returnError ConversionFailed f "expecting SQLBlob column type"
serializeKeyHash :: KeyHash -> ByteString
serializeKeyHash = encode . BA.convert . unKeyHash
keyHashP :: Parser KeyHash
keyHashP = do
bs <- base64P
case digestFromByteString bs of
Just d -> pure $ KeyHash d
_ -> fail "invalid digest"
getKeyHash :: ByteString -> KeyHash
getKeyHash = KeyHash . hash
sha256Hash :: ByteString -> ByteString
sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256)
serializeHeader :: Header -> ByteString
serializeHeader Header {aesKey, ivBytes, authTag, msgSize} =
unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize
headerP :: Parser Header
headerP = do
aesKey <- Key <$> A.take aesKeySize
ivBytes <- IV <$> A.take (ivSize @AES256)
aesKey <- aesKeyP
ivBytes <- ivP
authTag <- bsToAuthTag <$> A.take authTagSize
msgSize <- fromIntegral . decodeWord32 <$> A.take 4
return Header {aesKey, ivBytes, authTag, msgSize}
aesKeyP :: Parser Key
aesKeyP = Key <$> A.take aesKeySize
ivP :: Parser IV
ivP = IV <$> A.take (ivSize @AES256)
parseHeader :: ByteString -> Either CryptoError Header
parseHeader = first CryptoHeaderError . A.parseOnly (headerP <* A.endOfInput)
parseHeader = first CryptoHeaderError . parseAll headerP
encrypt :: PublicKey -> Int -> ByteString -> ExceptT CryptoError IO ByteString
encrypt k paddedSize msg = do
aesKey <- Key <$> randomBytes aesKeySize
ivBytes <- IV <$> randomBytes (ivSize @AES256)
aesKey <- liftIO randomAesKey
ivBytes <- liftIO randomIV
(authTag, msg') <- encryptAES aesKey ivBytes paddedSize msg
let header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg}
encHeader <- encryptOAEP k $ serializeHeader header
return $ encHeader <> msg'
decrypt :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString
decrypt pk msg'' = do
let (encHeader, msg') = B.splitAt (privateKeySize pk) msg''
header <- decryptOAEP pk encHeader
Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header
msg <- decryptAES aesKey ivBytes msg' authTag
return $ B.take msgSize msg
encryptAES :: Key -> IV -> Int -> ByteString -> ExceptT CryptoError IO (AES.AuthTag, ByteString)
encryptAES aesKey ivBytes paddedSize msg = do
aead <- initAEAD @AES256 aesKey ivBytes
msg' <- paddedMsg
let (authTag, msg'') = encryptAES aead msg'
header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg}
encHeader <- encryptOAEP k $ serializeHeader header
return $ encHeader <> msg''
return $ AES.aeadSimpleEncrypt aead B.empty msg' authTagSize
where
len = B.length msg
paddedMsg
| len >= paddedSize = throwE CryptoLargeMsgError
| otherwise = return (msg <> B.replicate (paddedSize - len) '#')
decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString
decrypt pk msg'' = do
let (encHeader, msg') = B.splitAt (private_size pk) msg''
header <- decryptOAEP pk encHeader
Header {aesKey, ivBytes, authTag, msgSize} <- ExceptT . return $ parseHeader header
decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString
decryptAES aesKey ivBytes msg authTag = do
aead <- initAEAD @AES256 aesKey ivBytes
msg <- decryptAES aead msg' authTag
return $ B.take msgSize msg
encryptAES :: AES.AEAD AES256 -> ByteString -> (AES.AuthTag, ByteString)
encryptAES aead plaintext = AES.aeadSimpleEncrypt aead B.empty plaintext authTagSize
decryptAES :: AES.AEAD AES256 -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString
decryptAES aead ciphertext authTag =
maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty ciphertext authTag
maybeError AESDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag
initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c)
initAEAD (Key aesKey) (IV ivBytes) = do
@@ -189,15 +291,18 @@ initAEAD (Key aesKey) (IV ivBytes) = do
cipher <- AES.cipherInit aesKey
AES.aeadInit AES.AEAD_GCM cipher iv
randomAesKey :: IO Key
randomAesKey = Key <$> getRandomBytes aesKeySize
randomIV :: IO IV
randomIV = IV <$> getRandomBytes (ivSize @AES256)
ivSize :: forall c. AES.BlockCipher c => Int
ivSize = AES.blockSize (undefined :: c)
makeIV :: AES.BlockCipher c => ByteString -> ExceptT CryptoError IO (AES.IV c)
makeIV bs = maybeError CryptoIVError $ AES.makeIV bs
randomBytes :: Int -> ExceptT CryptoError IO ByteString
randomBytes n = ExceptT $ Right <$> getRandomBytes n
maybeError :: CryptoError -> Maybe a -> ExceptT CryptoError IO a
maybeError e = maybe (throwE e) return
@@ -208,75 +313,93 @@ bsToAuthTag :: ByteString -> AES.AuthTag
bsToAuthTag = AES.AuthTag . BA.pack . map c2w . B.unpack
cryptoFailable :: CE.CryptoFailable a -> ExceptT CryptoError IO a
cryptoFailable = liftEither . first CryptoCipherError . CE.eitherCryptoError
cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError
oaepParams :: OAEP.OAEPParams SHA256 ByteString ByteString
oaepParams = OAEP.defaultOAEPParams SHA256
encryptOAEP :: PublicKey -> ByteString -> ExceptT CryptoError IO ByteString
encryptOAEP (PublicKey k) aesKey =
liftEitherError CryptoRSAError $
liftEitherError RSAEncryptError $
OAEP.encrypt oaepParams k aesKey
decryptOAEP :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString
decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString
decryptOAEP pk encKey =
liftEitherError CryptoRSAError $
liftEitherError RSADecryptError $
OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey
pssParams :: PSS.PSSParams SHA256 ByteString ByteString
pssParams = PSS.defaultPSSParams SHA256
sign :: PrivateKey -> ByteString -> IO (Either R.Error Signature)
sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg
sign :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO Signature
sign pk msg = ExceptT $ bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg
verify :: PublicKey -> Signature -> ByteString -> Bool
verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig
serializePubKey :: PublicKey -> ByteString
serializePubKey (PublicKey k) = serializeKey_ (R.public_size k, R.public_n k, R.public_e k)
serializePubKey = ("rsa:" <>) . encode . encodePubKey
serializePrivKey :: PrivateKey -> ByteString
serializePrivKey pk = serializeKey_ (private_size pk, private_n pk, private_d pk)
serializeKey_ :: (Int, Integer, Integer) -> ByteString
serializeKey_ (size, n, ex) = bshow size <> "," <> encInt n <> "," <> encInt ex
where
encInt = encode . i2osp
serializePrivKey :: PrivateKey k => k -> ByteString
serializePrivKey = ("rsa:" <>) . encode . encodePrivKey
pubKeyP :: Parser PublicKey
pubKeyP = do
(public_size, public_n, public_e) <- keyParser_
return . PublicKey $ R.PublicKey {R.public_size, R.public_n, R.public_e}
pubKeyP = decodePubKey <$?> ("rsa:" *> base64P)
privKeyP :: Parser PrivateKey
privKeyP = do
(private_size, private_n, private_d) <- keyParser_
return PrivateKey {private_size, private_n, private_d}
binaryPubKeyP :: Parser PublicKey
binaryPubKeyP = decodePubKey <$?> A.takeByteString
parsePubKey :: ByteString -> Either String PublicKey
parsePubKey = A.parseOnly (pubKeyP <* A.endOfInput)
privKeyP :: PrivateKey k => Parser k
privKeyP = decodePrivKey <$?> ("rsa:" *> base64P)
parsePrivKey :: ByteString -> Either String PrivateKey
parsePrivKey = A.parseOnly (privKeyP <* A.endOfInput)
binaryPrivKeyP :: PrivateKey k => Parser k
binaryPrivKeyP = decodePrivKey <$?> A.takeByteString
keyParser_ :: Parser (Int, Integer, Integer)
keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP
where
intP = os2ip <$> base64P
safePrivateKey :: (Int, Integer, Integer) -> SafePrivateKey
safePrivateKey = SafePrivateKey . safeRsaPrivateKey
rsaPrivateKey :: PrivateKey -> R.PrivateKey
rsaPrivateKey pk =
safeRsaPrivateKey :: (Int, Integer, Integer) -> R.PrivateKey
safeRsaPrivateKey (size, n, d) =
R.PrivateKey
{ R.private_pub =
{ private_pub =
R.PublicKey
{ R.public_size = private_size pk,
R.public_n = private_n pk,
R.public_e = undefined
{ public_size = size,
public_n = n,
public_e = 0
},
R.private_d = private_d pk,
R.private_p = 0,
R.private_q = 0,
R.private_dP = undefined,
R.private_dQ = undefined,
R.private_qinv = undefined
private_d = d,
private_p = 0,
private_q = 0,
private_dP = 0,
private_dQ = 0,
private_qinv = 0
}
encodePubKey :: PublicKey -> ByteString
encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey
encodePrivKey :: PrivateKey k => k -> ByteString
encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey
encodeKey :: ASN1Object a => a -> ByteString
encodeKey k = toStrict . encodeASN1 DER $ toASN1 k []
decodePubKey :: ByteString -> Either String PublicKey
decodePubKey =
decodeKey >=> \case
(PubKeyRSA k, []) -> Right $ PublicKey k
r -> keyError r
decodePrivKey :: PrivateKey k => ByteString -> Either String k
decodePrivKey =
decodeKey >=> \case
(PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk
r -> keyError r
decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1])
decodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict
keyError :: (a, [ASN1]) -> Either String b
keyError = \case
(_, []) -> Left "not RSA key"
_ -> Left "more than one key"

View File

@@ -1,3 +1,5 @@
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Parsers where
import Data.Attoparsec.ByteString.Char8 (Parser)
@@ -9,15 +11,35 @@ import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum)
import Data.Time.Clock (UTCTime)
import Data.Time.ISO8601 (parseISO8601)
import Simplex.Messaging.Util ((<$?>))
import Text.Read (readMaybe)
base64P :: Parser ByteString
base64P = do
base64P = decode <$?> base64StringP
base64StringP :: Parser ByteString
base64StringP = do
str <- A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/')
pad <- A.takeWhile (== '=')
either fail pure $ decode (str <> pad)
pure $ str <> pad
tsISO8601P :: Parser UTCTime
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
parse :: Parser a -> e -> (ByteString -> Either e a)
parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput)
parse parser err = first (const err) . parseAll parser
parseAll :: Parser a -> (ByteString -> Either String a)
parseAll parser = A.parseOnly (parser <* A.endOfInput)
parseRead :: Read a => Parser ByteString -> Parser a
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
parseRead1 :: Read a => Parser a
parseRead1 = parseRead $ A.takeTill (== ' ')
parseRead2 :: Read a => Parser a
parseRead2 = parseRead $ do
w1 <- A.takeTill (== ' ') <* A.char ' '
w2 <- A.takeTill (== ' ')
pure $ w1 <> " " <> w2

View File

@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -13,7 +14,7 @@ module Simplex.Messaging.Protocol where
import Control.Applicative ((<|>))
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Except
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Base64
@@ -24,12 +25,13 @@ import Data.Kind
import Data.String
import Data.Time.Clock
import Data.Time.ISO8601
import GHC.Generics (Generic)
import Generic.Random (genericArbitraryU)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Parsers
import Simplex.Messaging.Transport
import Simplex.Messaging.Util
import System.IO
import Text.Read
import Test.QuickCheck (Arbitrary (..))
data Party = Broker | Recipient | Sender
deriving (Show)
@@ -95,12 +97,12 @@ instance IsString CorrId where
fromString = CorrId . fromString
-- only used by Agent, kept here so its definition is close to respective public key
type RecipientPrivateKey = C.PrivateKey
type RecipientPrivateKey = C.SafePrivateKey
type RecipientPublicKey = C.PublicKey
-- only used by Agent, kept here so its definition is close to respective public key
type SenderPrivateKey = C.PrivateKey
type SenderPrivateKey = C.SafePrivateKey
type SenderPublicKey = C.PublicKey
@@ -108,25 +110,36 @@ type MsgId = Encoded
type MsgBody = ByteString
data ErrorType = PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq)
data ErrorType
= BLOCK
| CMD CommandError
| AUTH
| NO_MSG
| INTERNAL
| DUPLICATE_ -- TODO remove, not part of SMP protocol
deriving (Eq, Generic, Read, Show)
errBadTransmission :: Int
errBadTransmission = 1
data CommandError
= PROHIBITED
| SYNTAX
| NO_AUTH
| HAS_AUTH
| NO_QUEUE
deriving (Eq, Generic, Read, Show)
errBadSMPCommand :: Int
errBadSMPCommand = 2
instance Arbitrary ErrorType where arbitrary = genericArbitraryU
errNoCredentials :: Int
errNoCredentials = 3
instance Arbitrary CommandError where arbitrary = genericArbitraryU
errHasCredentials :: Int
errHasCredentials = 4
errNoQueueId :: Int
errNoQueueId = 5
errMessageBody :: Int
errMessageBody = 6
transmissionP :: Parser RawTransmission
transmissionP = do
signature <- segment
corrId <- segment
queueId <- segment
command <- A.takeByteString
return (signature, corrId, queueId, command)
where
segment = A.takeTill (== ' ') <* " "
commandP :: Parser Cmd
commandP =
@@ -148,132 +161,110 @@ commandP =
newCmd = Cmd SRecipient . NEW <$> C.pubKeyP
idsResp = Cmd SBroker <$> (IDS <$> (base64P <* A.space) <*> base64P)
keyCmd = Cmd SRecipient . KEY <$> C.pubKeyP
sendCmd = Cmd SSender . SEND <$> A.takeWhile A.isDigit
sendCmd = do
size <- A.decimal <* A.space
Cmd SSender . SEND <$> A.take size <* A.space
message = do
msgId <- base64P <* A.space
ts <- tsISO8601P <* A.space
Cmd SBroker . MSG msgId ts <$> A.takeWhile A.isDigit
serverError = Cmd SBroker . ERR <$> errorType
errorType =
"PROHIBITED" $> PROHIBITED
<|> "SYNTAX " *> (SYNTAX <$> A.decimal)
<|> "SIZE" $> SIZE
<|> "AUTH" $> AUTH
<|> "INTERNAL" $> INTERNAL
size <- A.decimal <* A.space
Cmd SBroker . MSG msgId ts <$> A.take size <* A.space
serverError = Cmd SBroker . ERR <$> errorTypeP
-- TODO ignore the end of block, no need to parse it
parseCommand :: ByteString -> Either ErrorType Cmd
parseCommand = parse commandP $ SYNTAX errBadSMPCommand
parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX
serializeCommand :: Cmd -> ByteString
serializeCommand = \case
Cmd SRecipient (NEW rKey) -> "NEW " <> C.serializePubKey rKey
Cmd SRecipient (KEY sKey) -> "KEY " <> C.serializePubKey sKey
Cmd SRecipient cmd -> B.pack $ show cmd
Cmd SSender (SEND msgBody) -> "SEND" <> serializeMsg msgBody
Cmd SRecipient cmd -> bshow cmd
Cmd SSender (SEND msgBody) -> "SEND " <> serializeMsg msgBody
Cmd SSender PING -> "PING"
Cmd SBroker (MSG msgId ts msgBody) ->
B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts] <> serializeMsg msgBody
B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody]
Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId]
Cmd SBroker (ERR err) -> "ERR " <> B.pack (show err)
Cmd SBroker resp -> B.pack $ show resp
Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err
Cmd SBroker resp -> bshow resp
where
serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\r\n" <> msgBody
serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " "
tPutRaw :: Handle -> RawTransmission -> IO ()
tPutRaw h (signature, corrId, queueId, command) = do
putLn h signature
putLn h corrId
putLn h queueId
putLn h command
errorTypeP :: Parser ErrorType
errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1
tGetRaw :: Handle -> IO RawTransmission
tGetRaw h = do
signature <- getLn h
corrId <- getLn h
queueId <- getLn h
command <- getLn h
return (signature, corrId, queueId, command)
serializeErrorType :: ErrorType -> ByteString
serializeErrorType = bshow
tPut :: Handle -> SignedRawTransmission -> IO ()
tPut h (C.Signature sig, t) = do
putLn h $ encode sig
putLn h t
tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ())
tPut th (C.Signature sig, t) =
tPutEncrypted th $ encode sig <> " " <> t <> " "
serializeTransmission :: Transmission -> ByteString
serializeTransmission (CorrId corrId, queueId, command) =
B.intercalate "\r\n" [corrId, encode queueId, serializeCommand command]
B.intercalate " " [corrId, encode queueId, serializeCommand command]
fromClient :: Cmd -> Either ErrorType Cmd
fromClient = \case
Cmd SBroker _ -> Left PROHIBITED
Cmd SBroker _ -> Left $ CMD PROHIBITED
cmd -> Right cmd
fromServer :: Cmd -> Either ErrorType Cmd
fromServer = \case
cmd@(Cmd SBroker _) -> Right cmd
_ -> Left PROHIBITED
_ -> Left $ CMD PROHIBITED
tGetParse :: THandle -> IO (Either TransportError RawTransmission)
tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th
-- | get client and server transmissions
-- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used
tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> Handle -> m SignedTransmissionOrError
tGet fromParty h = do
(signature, corrId, queueId, command) <- liftIO $ tGetRaw h
let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId)
either (const $ tError corrId) tParseLoadBody decodedTransmission
tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> THandle -> m SignedTransmissionOrError
tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate
where
tError :: ByteString -> m SignedTransmissionOrError
tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission))
decodeParseValidate :: Either TransportError RawTransmission -> m SignedTransmissionOrError
decodeParseValidate = \case
Right (signature, corrId, queueId, command) ->
let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId)
in either (const $ tError corrId) tParseValidate decodedTransmission
Left _ -> tError ""
tParseLoadBody :: RawTransmission -> m SignedTransmissionOrError
tParseLoadBody t@(sig, corrId, queueId, command) = do
tError :: ByteString -> m SignedTransmissionOrError
tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left BLOCK))
tParseValidate :: RawTransmission -> m SignedTransmissionOrError
tParseValidate t@(sig, corrId, queueId, command) = do
let cmd = parseCommand command >>= fromParty >>= tCredentials t
fullCmd <- either (return . Left) cmdWithMsgBody cmd
return (C.Signature sig, (CorrId corrId, queueId, fullCmd))
return (C.Signature sig, (CorrId corrId, queueId, cmd))
tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd
tCredentials (signature, _, queueId, _) cmd = case cmd of
-- IDS response should not have queue ID
-- IDS response must not have queue ID
Cmd SBroker (IDS _ _) -> Right cmd
-- ERR response does not always have queue ID
Cmd SBroker (ERR _) -> Right cmd
-- PONG response should not have queue ID
-- PONG response must not have queue ID
Cmd SBroker PONG
| B.null queueId -> Right cmd
| otherwise -> Left $ SYNTAX errHasCredentials
| otherwise -> Left $ CMD HAS_AUTH
-- other responses must have queue ID
Cmd SBroker _
| B.null queueId -> Left $ SYNTAX errNoQueueId
| B.null queueId -> Left $ CMD NO_QUEUE
| otherwise -> Right cmd
-- NEW must NOT have signature or queue ID
-- NEW must have signature but NOT queue ID
Cmd SRecipient (NEW _)
| B.null signature -> Left $ SYNTAX errNoCredentials
| not (B.null queueId) -> Left $ SYNTAX errHasCredentials
| B.null signature -> Left $ CMD NO_AUTH
| not (B.null queueId) -> Left $ CMD HAS_AUTH
| otherwise -> Right cmd
-- SEND must have queue ID, signature is not always required
Cmd SSender (SEND _)
| B.null queueId -> Left $ SYNTAX errNoQueueId
| B.null queueId -> Left $ CMD NO_QUEUE
| otherwise -> Right cmd
-- PING must not have queue ID or signature
Cmd SSender PING
| B.null queueId && B.null signature -> Right cmd
| otherwise -> Left $ SYNTAX errHasCredentials
| otherwise -> Left $ CMD HAS_AUTH
-- other client commands must have both signature and queue ID
Cmd SRecipient _
| B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials
| B.null signature || B.null queueId -> Left $ CMD NO_AUTH
| otherwise -> Right cmd
cmdWithMsgBody :: Cmd -> m (Either ErrorType Cmd)
cmdWithMsgBody = \case
Cmd SSender (SEND sizeStr) ->
Cmd SSender . SEND <$$> getMsgBody sizeStr
Cmd SBroker (MSG msgId ts sizeStr) ->
Cmd SBroker . MSG msgId ts <$$> getMsgBody sizeStr
cmd -> return $ Right cmd
getMsgBody :: MsgBody -> m (Either ErrorType MsgBody)
getMsgBody sizeStr = case readMaybe (B.unpack sizeStr) :: Maybe Int of
Just size -> liftIO $ do
body <- B.hGet h size
s <- getLn h
return $ if B.null s then Right body else Left SIZE
Nothing -> return $ Left INTERNAL

View File

@@ -15,6 +15,7 @@ module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, randomBytes
import Control.Concurrent.STM (stateTVar)
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Crypto.Random
@@ -30,6 +31,7 @@ import Simplex.Messaging.Server.MsgStore
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
import Simplex.Messaging.Server.StoreLog
import Simplex.Messaging.Transport
import Simplex.Messaging.Util
import UnliftIO.Async
@@ -50,6 +52,7 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do
smpServer = do
s <- asks server
race_ (runTCPServer started tcpPort runClient) (serverThread s)
`finally` withLog closeStoreLog
serverThread :: MonadUnliftIO m => Server -> m ()
serverThread Server {subscribedQ, subscribers} = forever . atomically $ do
@@ -62,11 +65,17 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do
runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m ()
runClient h = do
liftIO $ putLn h "Welcome to SMP v0.2.0"
keyPair <- asks serverKeyPair
liftIO (runExceptT $ serverHandshake h keyPair) >>= \case
Right th -> runClientTransport th
Left _ -> pure ()
runClientTransport :: (MonadUnliftIO m, MonadReader Env m) => THandle -> m ()
runClientTransport th = do
q <- asks $ tbqSize . config
c <- atomically $ newClient q
s <- asks server
raceAny_ [send h c, client c s, receive h c]
raceAny_ [send th c, client c s, receive th c]
`finally` cancelSubscribers c
cancelSubscribers :: MonadUnliftIO m => Client -> m ()
@@ -78,7 +87,7 @@ cancelSub = \case
Sub {subThread = SubThread t} -> killThread t
_ -> return ()
receive :: (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m ()
receive :: (MonadUnliftIO m, MonadReader Env m) => THandle -> Client -> m ()
receive h Client {rcvQ} = forever $ do
(signature, (corrId, queueId, cmdOrError)) <- tGet fromClient h
t <- case cmdOrError of
@@ -86,7 +95,7 @@ receive h Client {rcvQ} = forever $ do
Right cmd -> verifyTransmission (signature, (corrId, queueId, cmd))
atomically $ writeTBQueue rcvQ t
send :: MonadUnliftIO m => Handle -> Client -> m ()
send :: MonadUnliftIO m => THandle -> Client -> m ()
send h Client {sndQ} = forever $ do
t <- atomically $ readTBQueue sndQ
liftIO $ tPut h ("", serializeTransmission t)
@@ -99,25 +108,27 @@ verifyTransmission (sig, t@(corrId, queueId, cmd)) = do
(corrId,queueId,) <$> case cmd of
Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used
Cmd SRecipient (NEW k) -> return $ verifySignature k
Cmd SRecipient _ -> withQueueRec SRecipient $ verifySignature . recipientKey
Cmd SSender (SEND _) -> withQueueRec SSender $ verifySend sig . senderKey
Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey
Cmd SSender (SEND _) -> verifyCmd SSender $ verifySend sig . senderKey
Cmd SSender PING -> return cmd
where
withQueueRec :: SParty (p :: Party) -> (QueueRec -> Cmd) -> m Cmd
withQueueRec party f = do
verifyCmd :: SParty p -> (QueueRec -> Cmd) -> m Cmd
verifyCmd party f = do
(aKey, _) <- asks serverKeyPair -- any public key can be used to mitigate timing attack
st <- asks queueStore
qr <- atomically $ getQueue st party queueId
return $ either smpErr f qr
q <- atomically $ getQueue st party queueId
pure $ either (const $ fakeVerify aKey) f q
fakeVerify :: C.PublicKey -> Cmd
fakeVerify aKey = if verify aKey then authErr else authErr
verifySend :: C.Signature -> Maybe SenderPublicKey -> Cmd
verifySend "" = maybe cmd (const authErr)
verifySend _ = maybe authErr verifySignature
verifySignature :: C.PublicKey -> Cmd
verifySignature key =
if C.verify key sig (serializeTransmission t)
then cmd
else authErr
smpErr e = Cmd SBroker $ ERR e
verifySignature key = if verify key then cmd else authErr
verify :: C.PublicKey -> Bool
verify key = C.verify key sig (serializeTransmission t)
smpErr :: ErrorType -> Cmd
smpErr = Cmd SBroker . ERR
authErr = smpErr AUTH
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
@@ -140,8 +151,8 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
NEW rKey -> createQueue st rKey
SUB -> subscribeQueue queueId
ACK -> acknowledgeMsg
KEY sKey -> okResp <$> atomically (secureQueue st queueId sKey)
OFF -> okResp <$> atomically (suspendQueue st queueId)
KEY sKey -> secureQueue_ st sKey
OFF -> suspendQueue_ st
DEL -> delQueueAndMsgs st
where
createQueue :: QueueStore -> RecipientPublicKey -> m Transmission
@@ -151,22 +162,40 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
addSubscribe =
addQueueRetry 3 >>= \case
Left e -> return $ ERR e
Right (rId, sId) -> subscribeQueue rId $> IDS rId sId
Right (rId, sId) -> do
withLog (`logCreateById` rId)
subscribeQueue rId $> IDS rId sId
addQueueRetry :: Int -> m (Either ErrorType (RecipientId, SenderId))
addQueueRetry 0 = return $ Left INTERNAL
addQueueRetry n = do
ids <- getIds
atomically (addQueue st rKey ids) >>= \case
Left DUPLICATE -> addQueueRetry $ n - 1
Left DUPLICATE_ -> addQueueRetry $ n - 1
Left e -> return $ Left e
Right _ -> return $ Right ids
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
logCreateById s rId =
atomically (getQueue st SRecipient rId) >>= \case
Right q -> logCreateQueue s q
_ -> pure ()
getIds :: m (RecipientId, SenderId)
getIds = do
n <- asks $ queueIdBytes . config
liftM2 (,) (randomId n) (randomId n)
secureQueue_ :: QueueStore -> SenderPublicKey -> m Transmission
secureQueue_ st sKey = do
withLog $ \s -> logSecureQueue s queueId sKey
okResp <$> atomically (secureQueue st queueId sKey)
suspendQueue_ :: QueueStore -> m Transmission
suspendQueue_ st = do
withLog (`logDeleteQueue` queueId)
okResp <$> atomically (suspendQueue st queueId)
subscribeQueue :: RecipientId -> m Transmission
subscribeQueue rId =
atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId
@@ -193,7 +222,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s))
>>= \case
Just (Just s) -> deliverMessage tryDelPeekMsg queueId s
_ -> return $ err PROHIBITED
_ -> return $ err NO_MSG
withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a)
withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId
@@ -253,6 +282,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
delQueueAndMsgs :: QueueStore -> m Transmission
delQueueAndMsgs st = do
withLog (`logDeleteQueue` queueId)
ms <- asks msgStore
atomically $
deleteQueue st queueId >>= \case
@@ -271,6 +301,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} =
msgCmd :: Message -> Command 'Broker
msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody
withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m ()
withLog action = do
env <- ask
liftIO . mapM_ action $ storeLog (env :: Env)
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded
randomId n = do
gVar <- asks idsDrg

View File

@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Server.Env.STM where
@@ -10,16 +12,23 @@ import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Network.Socket (ServiceName)
import Numeric.Natural
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.MsgStore.STM
import Simplex.Messaging.Server.QueueStore (QueueRec (..))
import Simplex.Messaging.Server.QueueStore.STM
import Simplex.Messaging.Server.StoreLog
import System.IO (IOMode (..))
import UnliftIO.STM
data ServerConfig = ServerConfig
{ tcpPort :: ServiceName,
tbqSize :: Natural,
queueIdBytes :: Int,
msgIdBytes :: Int
msgIdBytes :: Int,
storeLog :: Maybe (StoreLog 'ReadMode),
serverPrivateKey :: C.FullPrivateKey
-- serverId :: ByteString
}
data Env = Env
@@ -27,7 +36,9 @@ data Env = Env
server :: Server,
queueStore :: QueueStore,
msgStore :: STMMsgStore,
idsDrg :: TVar ChaChaDRG
idsDrg :: TVar ChaChaDRG,
serverKeyPair :: C.FullKeyPair,
storeLog :: Maybe (StoreLog 'WriteMode)
}
data Server = Server
@@ -66,10 +77,21 @@ newSubscription = do
delivered <- newEmptyTMVar
return Sub {subThread = NoSub, delivered}
newEnv :: (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
newEnv config = do
server <- atomically $ newServer (tbqSize config)
queueStore <- atomically newQueueStore
msgStore <- atomically newMsgStore
idsDrg <- drgNew >>= newTVarIO
return Env {config, server, queueStore, msgStore, idsDrg}
s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig)
let pk = serverPrivateKey config
serverKeyPair = (C.publicKey pk, pk)
return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s'}
where
restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode)
restoreQueues queueStore s = do
(queues, s') <- liftIO $ readWriteStoreLog s
atomically $ modifyTVar queueStore $ \d -> d {queues, senders = M.foldr' addSender M.empty queues}
pure s'
addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId
addSender q = M.insert (senderId q) (recipientId q)

View File

@@ -15,7 +15,7 @@ data QueueRec = QueueRec
status :: QueueStatus
}
data QueueStatus = QueueActive | QueueOff
data QueueStatus = QueueActive | QueueOff deriving (Eq)
class MonadQueueStore s m where
addQueue :: s -> RecipientPublicKey -> (RecipientId, SenderId) -> m (Either ErrorType ())

View File

@@ -32,7 +32,7 @@ instance MonadQueueStore QueueStore STM where
addQueue store rKey ids@(rId, sId) = do
cs@QueueStoreData {queues, senders} <- readTVar store
if M.member rId queues || M.member sId senders
then return $ Left DUPLICATE
then return $ Left DUPLICATE_
else do
writeTVar store $
cs

View File

@@ -0,0 +1,140 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Server.StoreLog
( StoreLog, -- constructors are not exported
openWriteStoreLog,
openReadStoreLog,
closeStoreLog,
logCreateQueue,
logSecureQueue,
logDeleteQueue,
readWriteStoreLog,
)
where
import Control.Applicative (optional, (<|>))
import Control.Monad (unless)
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first, second)
import Data.ByteString.Base64 (encode)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Either (partitionEithers)
import Data.Functor (($>))
import Data.List (foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Parsers (base64P, parseAll)
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.QueueStore (QueueRec (..), QueueStatus (..))
import Simplex.Messaging.Transport (trimCR)
import System.Directory (doesFileExist)
import System.IO
-- | opaque container for file handle with a type-safe IOMode
-- constructors are not exported, openWriteStoreLog and openReadStoreLog should be used instead
data StoreLog (a :: IOMode) where
ReadStoreLog :: FilePath -> Handle -> StoreLog 'ReadMode
WriteStoreLog :: FilePath -> Handle -> StoreLog 'WriteMode
data StoreLogRecord
= CreateQueue QueueRec
| SecureQueue QueueId SenderPublicKey
| DeleteQueue QueueId
storeLogRecordP :: Parser StoreLogRecord
storeLogRecordP =
"CREATE " *> createQueueP
<|> "SECURE " *> secureQueueP
<|> "DELETE " *> (DeleteQueue <$> base64P)
where
createQueueP = CreateQueue <$> queueRecP
secureQueueP = SecureQueue <$> base64P <* A.space <*> C.pubKeyP
queueRecP = do
recipientId <- "rid=" *> base64P <* A.space
senderId <- "sid=" *> base64P <* A.space
recipientKey <- "rk=" *> C.pubKeyP <* A.space
senderKey <- "sk=" *> optional C.pubKeyP
pure QueueRec {recipientId, senderId, recipientKey, senderKey, status = QueueActive}
serializeStoreLogRecord :: StoreLogRecord -> ByteString
serializeStoreLogRecord = \case
CreateQueue q -> "CREATE " <> serializeQueue q
SecureQueue rId sKey -> "SECURE " <> encode rId <> " " <> C.serializePubKey sKey
DeleteQueue rId -> "DELETE " <> encode rId
where
serializeQueue QueueRec {recipientId, senderId, recipientKey, senderKey} =
B.unwords
[ "rid=" <> encode recipientId,
"sid=" <> encode senderId,
"rk=" <> C.serializePubKey recipientKey,
"sk=" <> maybe "" C.serializePubKey senderKey
]
openWriteStoreLog :: FilePath -> IO (StoreLog 'WriteMode)
openWriteStoreLog f = WriteStoreLog f <$> openFile f WriteMode
openReadStoreLog :: FilePath -> IO (StoreLog 'ReadMode)
openReadStoreLog f = do
doesFileExist f >>= (`unless` writeFile f "")
ReadStoreLog f <$> openFile f ReadMode
closeStoreLog :: StoreLog a -> IO ()
closeStoreLog = \case
WriteStoreLog _ h -> hClose h
ReadStoreLog _ h -> hClose h
writeStoreLogRecord :: StoreLog 'WriteMode -> StoreLogRecord -> IO ()
writeStoreLogRecord (WriteStoreLog _ h) r = do
B.hPutStrLn h $ serializeStoreLogRecord r
hFlush h
logCreateQueue :: StoreLog 'WriteMode -> QueueRec -> IO ()
logCreateQueue s = writeStoreLogRecord s . CreateQueue
logSecureQueue :: StoreLog 'WriteMode -> QueueId -> SenderPublicKey -> IO ()
logSecureQueue s qId sKey = writeStoreLogRecord s $ SecureQueue qId sKey
logDeleteQueue :: StoreLog 'WriteMode -> QueueId -> IO ()
logDeleteQueue s = writeStoreLogRecord s . DeleteQueue
readWriteStoreLog :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode)
readWriteStoreLog s@(ReadStoreLog f _) = do
qs <- readQueues s
closeStoreLog s
s' <- openWriteStoreLog f
writeQueues s' qs
pure (qs, s')
writeQueues :: StoreLog 'WriteMode -> Map RecipientId QueueRec -> IO ()
writeQueues s = mapM_ (writeStoreLogRecord s . CreateQueue) . M.filter active
where
active QueueRec {status} = status == QueueActive
type LogParsingError = (String, ByteString)
readQueues :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec)
readQueues (ReadStoreLog _ h) = LB.hGetContents h >>= returnResult . procStoreLog
where
procStoreLog :: LB.ByteString -> ([LogParsingError], Map RecipientId QueueRec)
procStoreLog = second (foldl' procLogRecord M.empty) . partitionEithers . map parseLogRecord . LB.lines
returnResult :: ([LogParsingError], Map RecipientId QueueRec) -> IO (Map RecipientId QueueRec)
returnResult (errs, res) = mapM_ printError errs $> res
parseLogRecord :: LB.ByteString -> Either LogParsingError StoreLogRecord
parseLogRecord = (\s -> first (,s) $ parseAll storeLogRecordP s) . trimCR . LB.toStrict
procLogRecord :: Map RecipientId QueueRec -> StoreLogRecord -> Map RecipientId QueueRec
procLogRecord m = \case
CreateQueue q -> M.insert (recipientId q) q m
SecureQueue qId sKey -> M.adjust (\q -> q {senderKey = Just sKey}) qId m
DeleteQueue qId -> M.delete qId m
printError :: LogParsingError -> IO ()
printError (e, s) = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s

View File

@@ -1,29 +1,52 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Messaging.Transport where
import Control.Monad.IO.Class
import Control.Applicative ((<|>))
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Control.Monad.Trans.Except (throwE)
import Crypto.Cipher.Types (AuthTag)
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteArray (xor)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Set (Set)
import qualified Data.Set as S
import Data.Word (Word32)
import GHC.Generics (Generic)
import GHC.IO.Exception (IOErrorType (..))
import GHC.IO.Handle.Internals (ioe_EOF)
import Generic.Random (genericArbitraryU)
import Network.Socket
import Network.Transport.Internal (decodeNum16, decodeNum32, encodeEnum16, encodeEnum32, encodeWord32)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Parsers (parse, parseAll, parseRead1)
import Simplex.Messaging.Util (bshow, liftError)
import System.IO
import System.IO.Error
import Test.QuickCheck (Arbitrary (..))
import UnliftIO.Concurrent
import UnliftIO.Exception (Exception, IOException)
import qualified UnliftIO.Exception as E
import qualified UnliftIO.IO as IO
import UnliftIO.STM
-- * TCP transport
runTCPServer :: MonadUnliftIO m => TMVar Bool -> ServiceName -> (Handle -> m ()) -> m ()
runTCPServer started port server = do
clients <- newTVarIO S.empty
@@ -33,7 +56,10 @@ runTCPServer started port server = do
atomically . modifyTVar clients $ S.insert tid
where
closeServer :: TVar (Set ThreadId) -> Socket -> IO ()
closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock
closeServer clients sock = do
readTVarIO clients >>= mapM_ killThread
close sock
void . atomically $ tryPutTMVar started False
startTCPServer :: TMVar Bool -> ServiceName -> IO Socket
startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted
@@ -59,9 +85,7 @@ runTCPClient host port client = do
client h `E.finally` IO.hClose h
startTCPClient :: HostName -> ServiceName -> IO Handle
startTCPClient host port =
withSocketsDo $
resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion
startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
where
err :: IOException
err = mkIOError NoSuchThing "no address" Nothing Nothing
@@ -71,9 +95,10 @@ startTCPClient host port =
let hints = defaultHints {addrSocketType = Stream}
in getAddrInfo (Just hints) (Just host) (Just port)
tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle)
tryOpen (Left _) addr = E.try $ open addr
tryOpen h _ = return h
tryOpen :: IOException -> [AddrInfo] -> IO Handle
tryOpen e [] = E.throwIO e
tryOpen _ (addr : as) =
E.try (open addr) >>= either (`tryOpen` as) pure
open :: AddrInfo -> IO Handle
open addr = do
@@ -93,7 +118,261 @@ putLn :: Handle -> ByteString -> IO ()
putLn h = B.hPut h . (<> "\r\n")
getLn :: Handle -> IO ByteString
getLn h = trim_cr <$> B.hGetLine h
getLn h = trimCR <$> B.hGetLine h
trimCR :: ByteString -> ByteString
trimCR "" = ""
trimCR s = if B.last s == '\r' then B.init s else s
-- * Encrypted transport
data SMPVersion = SMPVersion Int Int Int Int
deriving (Eq, Ord)
major :: SMPVersion -> (Int, Int)
major (SMPVersion a b _ _) = (a, b)
currentSMPVersion :: SMPVersion
currentSMPVersion = SMPVersion 0 2 0 0
serializeSMPVersion :: SMPVersion -> ByteString
serializeSMPVersion (SMPVersion a b c d) = B.intercalate "." [bshow a, bshow b, bshow c, bshow d]
smpVersionP :: Parser SMPVersion
smpVersionP =
let ver = A.decimal <* A.char '.'
in SMPVersion <$> ver <*> ver <*> ver <*> A.decimal
data THandle = THandle
{ handle :: Handle,
sndKey :: SessionKey,
rcvKey :: SessionKey,
blockSize :: Int
}
data SessionKey = SessionKey
{ aesKey :: C.Key,
baseIV :: C.IV,
counter :: TVar Word32
}
data ClientHandshake = ClientHandshake
{ blockSize :: Int,
sndKey :: SessionKey,
rcvKey :: SessionKey
}
data TransportError
= TEBadBlock
| TEEncrypt
| TEDecrypt
| TEHandshake HandshakeError
deriving (Eq, Generic, Read, Show, Exception)
data HandshakeError
= ENCRYPT
| DECRYPT
| VERSION
| RSA_KEY
| HEADER
| AES_KEYS
| BAD_HASH
| MAJOR_VERSION
| TERMINATED
deriving (Eq, Generic, Read, Show, Exception)
instance Arbitrary TransportError where arbitrary = genericArbitraryU
instance Arbitrary HandshakeError where arbitrary = genericArbitraryU
transportErrorP :: Parser TransportError
transportErrorP =
"BLOCK" $> TEBadBlock
<|> "AES_ENCRYPT" $> TEEncrypt
<|> "AES_DECRYPT" $> TEDecrypt
<|> TEHandshake <$> parseRead1
serializeTransportError :: TransportError -> ByteString
serializeTransportError = \case
TEEncrypt -> "AES_ENCRYPT"
TEDecrypt -> "AES_DECRYPT"
TEBadBlock -> "BLOCK"
TEHandshake e -> bshow e
tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ())
tPutEncrypted THandle {handle = h, sndKey, blockSize} block =
encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case
Left _ -> pure $ Left TEEncrypt
Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg)
tGetEncrypted :: THandle -> IO (Either TransportError ByteString)
tGetEncrypted THandle {handle = h, rcvKey, blockSize} =
B.hGet h blockSize >>= decryptBlock rcvKey >>= \case
Left _ -> pure $ Left TEDecrypt
Right "" -> ioe_EOF
Right msg -> pure $ Right msg
encryptBlock :: SessionKey -> Int -> ByteString -> IO (Either C.CryptoError (AuthTag, ByteString))
encryptBlock k@SessionKey {aesKey} size block = do
ivBytes <- makeNextIV k
runExceptT $ C.encryptAES aesKey ivBytes size block
decryptBlock :: SessionKey -> ByteString -> IO (Either C.CryptoError ByteString)
decryptBlock k@SessionKey {aesKey} block = do
let (authTag, msg') = B.splitAt C.authTagSize block
ivBytes <- makeNextIV k
runExceptT $ C.decryptAES aesKey ivBytes msg' (C.bsToAuthTag authTag)
makeNextIV :: SessionKey -> IO C.IV
makeNextIV SessionKey {baseIV, counter} = atomically $ do
c <- readTVar counter
writeTVar counter $ c + 1
pure $ iv c
where
trim_cr "" = ""
trim_cr s = if B.last s == '\r' then B.init s else s
(start, rest) = B.splitAt 4 $ C.unIV baseIV
iv c = C.IV $ (start `xor` encodeWord32 c) <> rest
-- | implements server transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption
-- The numbers in function names refer to the steps in the document
serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle
serverHandshake h (k, pk) = do
liftIO sendHeaderAndPublicKey_1
encryptedKeys <- receiveEncryptedKeys_4
-- TODO server currently ignores blockSize returned by the client
-- this is reserved for future support of streams
ClientHandshake {blockSize = _, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys
th <- liftIO $ transportHandle h rcvKey sndKey transportBlockSize -- keys are swapped here
sendWelcome_6 th
pure th
where
sendHeaderAndPublicKey_1 :: IO ()
sendHeaderAndPublicKey_1 = do
let sKey = C.encodePubKey k
header = ServerHeader {blockSize = transportBlockSize, keySize = B.length sKey}
B.hPut h $ binaryServerHeader header <> sKey
receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString
receiveEncryptedKeys_4 =
liftIO (B.hGet h $ C.publicKeySize k) >>= \case
"" -> throwE $ TEHandshake TERMINATED
ks -> pure ks
decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO ClientHandshake
decryptParseKeys_5 encKeys =
liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys)
>>= liftEither . parseClientHandshake
sendWelcome_6 :: THandle -> ExceptT TransportError IO ()
sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " "
-- | implements client transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption
-- The numbers in function names refer to the steps in the document
clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle
clientHandshake h keyHash = do
(k, blkSize) <- getHeaderAndPublicKey_1_2
-- TODO currently client always uses the blkSize returned by the server
keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize
sendEncryptedKeys_4 k keys
th <- liftIO $ transportHandle h sndKey rcvKey blkSize
getWelcome_6 th >>= checkVersion
pure th
where
getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int)
getHeaderAndPublicKey_1_2 = do
header <- liftIO (B.hGet h serverHeaderSize)
ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header
when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $
throwError $ TEHandshake HEADER
s <- liftIO $ B.hGet h keySize
maybe (pure ()) (validateKeyHash_2 s) keyHash
key <- liftEither $ parseKey s
pure (key, blockSize)
parseKey :: ByteString -> Either TransportError C.PublicKey
parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.binaryPubKeyP
validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO ()
validateKeyHash_2 k kHash
| C.getKeyHash k == kHash = pure ()
| otherwise = throwE $ TEHandshake BAD_HASH
generateKeys_3 :: Int -> IO ClientHandshake
generateKeys_3 blkSize = ClientHandshake blkSize <$> generateKey <*> generateKey
generateKey :: IO SessionKey
generateKey = do
aesKey <- C.randomAesKey
baseIV <- C.randomIV
pure SessionKey {aesKey, baseIV, counter = undefined}
sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO ()
sendEncryptedKeys_4 k keys =
liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys)
>>= liftIO . B.hPut h
getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion
getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th
parseSMPVersion :: ByteString -> Either TransportError SMPVersion
parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space)
checkVersion :: SMPVersion -> ExceptT TransportError IO ()
checkVersion smpVersion =
when (major smpVersion > major currentSMPVersion) . throwE $
TEHandshake MAJOR_VERSION
data ServerHeader = ServerHeader {blockSize :: Int, keySize :: Int}
deriving (Eq, Show)
binaryRsaTransport :: Int
binaryRsaTransport = 0
transportBlockSize :: Int
transportBlockSize = 4096
maxTransportBlockSize :: Int
maxTransportBlockSize = 65536
serverHeaderSize :: Int
serverHeaderSize = 8
binaryServerHeader :: ServerHeader -> ByteString
binaryServerHeader ServerHeader {blockSize, keySize} =
encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize
serverHeaderP :: Parser ServerHeader
serverHeaderP = ServerHeader <$> int32 <* binaryRsaTransportP <*> int16
serializeClientHandshake :: ClientHandshake -> ByteString
serializeClientHandshake ClientHandshake {blockSize, sndKey, rcvKey} =
encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> serializeKey sndKey <> serializeKey rcvKey
where
serializeKey :: SessionKey -> ByteString
serializeKey SessionKey {aesKey, baseIV} = C.unKey aesKey <> C.unIV baseIV
clientHandshakeP :: Parser ClientHandshake
clientHandshakeP = ClientHandshake <$> int32 <* binaryRsaTransportP <*> keyP <*> keyP
where
keyP :: Parser SessionKey
keyP = do
aesKey <- C.aesKeyP
baseIV <- C.ivP
pure SessionKey {aesKey, baseIV, counter = undefined}
int32 :: Parser Int
int32 = decodeNum32 <$> A.take 4
int16 :: Parser Int
int16 = decodeNum16 <$> A.take 2
binaryRsaTransportP :: Parser ()
binaryRsaTransportP = binaryRsa =<< int16
where
binaryRsa :: Int -> Parser ()
binaryRsa n
| n == binaryRsaTransport = pure ()
| otherwise = fail "unknown transport mode"
parseClientHandshake :: ByteString -> Either TransportError ClientHandshake
parseClientHandshake = parse clientHandshakeP $ TEHandshake AES_KEYS
transportHandle :: Handle -> SessionKey -> SessionKey -> Int -> IO THandle
transportHandle h sk rk blockSize = do
sndCounter <- newTVarIO 0
rcvCounter <- newTVarIO 0
pure
THandle
{ handle = h,
sndKey = sk {counter = sndCounter},
rcvKey = rk {counter = rcvCounter},
blockSize
}

View File

@@ -31,19 +31,22 @@ raceAny_ = r []
r as (m : ms) = withAsync m $ \a -> r (a : as) ms
r as [] = void $ waitAnyCancel as
infixl 4 <$$>
infixl 4 <$$>, <$?>
(<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
(<$$>) = fmap . fmap
(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b
f <$?> m = m >>= either fail pure . f
bshow :: Show a => a -> ByteString
bshow = B.pack . show
liftIOEither :: (MonadUnliftIO m, MonadError e m) => IO (Either e a) -> m a
liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a
liftIOEither a = liftIO a >>= liftEither
liftError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a
liftError f = liftEitherError f . runExceptT
liftEitherError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a
liftEitherError f a = liftIOEither (first f <$> a)

View File

@@ -0,0 +1,98 @@
# Errors
## Problems
- using numbers and strings to indicate errors (in protocol and in code) - ErrorType, AgentErrorType, TransportError
- re-using the same type in multiple contexts (with some constructors not applicable to all contexts) - ErrorType
## Error types
### ErrorType (Protocol.hs)
- BLOCK - incorrect block format or encoding
- CMD error - command is unknown or has invalid syntax, where `error` can be:
- PROHIBITED - server response sent from client or vice versa
- SYNTAX - error parsing command
- NO_AUTH - transmission has no required credentials (signature or queue ID)
- HAS_AUTH - transmission has not allowed credentials
- NO_QUEUE - transmission has not queue ID
- AUTH - command is not authorised (queue does not exist or signature verification failed).
- NO_MSG - acknowledging (ACK) the message without message
- INTERNAL - internal server error.
- DUPLICATE_ - it is used internally to signal that the queue ID is already used. This is NOT used in the protocol, instead INTERNAL is sent to the client. It has to be removed.
### AgentErrorType (Agent/Transmission.hs)
Some of these errors are not correctly serialized/parsed - see line 322 in Agent/Transmission.hs
- CMD e - command or response error
- PROHIBITED - server response sent as client command (and vice versa)
- SYNTAX - command is unknown or has invalid syntax.
- NO_CONN - connection is required in the command (and absent)
- SIZE - incorrect message size of messages (when parsing SEND and MSG)
- LARGE -- message does not fit SMP block
- CONN e - connection errors
- UNKNOWN - connection alias not in database
- DUPLICATE - connection alias already exists
- SIMPLEX - connection is simplex, but operation requires another queue
- SMP ErrorType - forwarding SMP errors (SMPServerError) to the agent client
- BROKER e - SMP server errors
- RESPONSE ErrorType - invalid SMP server response
- UNEXPECTED - unexpected response
- NETWORK - network TCP connection error
- TRANSPORT TransportError -- handshake or other transport error
- TIMEOUT - command response timeout
- AGENT e - errors of other agents
- A_MESSAGE - SMP message failed to parse
- A_PROHIBITED - SMP message is prohibited with the current queue status
- A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header
- A_SIGNATURE - invalid RSA signature
- INTERNAL ByteString - agent implementation or dependency error
### SMPClientError (Client.hs)
- SMPServerError ErrorType - this is correctly parsed server ERR response. This error is forwarded to the agent client as `ERR SMP err`
- SMPResponseError ErrorType - this is invalid server response that failed to parse - forwarded to the client as `ERR BROKER RESPONSE`.
- SMPUnexpectedResponse - different response from what is expected to a given command, e.g. server should respond `IDS` or `ERR` to `NEW` command, other responses would result in this error - forwarded to the client as `ERR BROKER UNEXPECTED`.
- SMPResponseTimeout - used for TCP connection and command response timeouts -> `ERR BROKER TIMEOUT`.
- SMPNetworkError - fails to establish TCP connection -> `ERR BROKER NETWORK`
- SMPTransportError e - fails connection handshake or some other transport error -> `ERR BROKER TRANSPORT e`
- SMPSignatureError C.CryptoError - error when cryptographically "signing" the command.
### StoreError (Agent/Store.hs)
- SEInternal ByteString - signals exceptions in store actions.
- SEConnNotFound - connection alias not found (or both queues absent).
- SEConnDuplicate - connection alias already used.
- SEBadConnType ConnType - wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. `updateRcvConnWithSndQueue` and `updateSndConnWithRcvQueue` do not allow duplex connections - they would also return this error.
- SEBadQueueStatus - the intention was to pass current expected queue status in methods, as we always know what it should be at any stage of the protocol, and in case it does not match use this error. **Currently not used**.
- SENotImplemented - used in `getMsg` that is not implemented/used.
### CryptoError (Crypto.hs)
- RSAEncryptError R.Error - RSA encryption error
- RSADecryptError R.Error - RSA decryption error
- RSASignError R.Error - RSA signature error
- AESCipherError CE.CryptoError - AES initialization error
- CryptoIVError - IV generation error
- AESDecryptError - AES decryption error
- CryptoLargeMsgError - message does not fit in SMP block
- CryptoHeaderError String - failure parsing RSA-encrypted message header
### TransportError (Transport.hs)
- TEBadBlock - error parsing block
- TEEncrypt - block encryption error
- TEDecrypt - block decryption error
- TEHandshake HandshakeError
### HandshakeError (Transport.hs)
- ENCRYPT - encryption error
- DECRYPT - decryption error
- VERSION - error parsing protocol version
- RSA_KEY - error parsing RSA key
- AES_KEYS - error parsing AES keys
- BAD_HASH - not matching RSA key hash
- MAJOR_VERSION - lower agent version than protocol version
- TERMINATED - transport terminated

View File

@@ -35,9 +35,10 @@ packages:
# forks / in-progress versions pinned to a git hash. For example:
#
extra-deps:
- sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002
- cryptostore-0.2.1.0@sha256:9896e2984f36a1c8790f057fd5ce3da4cbcaf8aa73eb2d9277916886978c5b19,3881
- direct-sqlite-2.3.26@sha256:04e835402f1508abca383182023e4e2b9b86297b8533afbd4e57d1a5652e0c23,3718
- simple-logger-0.1.0@sha256:be8ede4bd251a9cac776533bae7fb643369ebd826eb948a9a18df1a8dd252ff8,1079
- sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002
- terminal-0.2.0.0@sha256:de6770ecaae3197c66ac1f0db5a80cf5a5b1d3b64a66a05b50f442de5ad39570,2977
# - network-run-0.2.4@sha256:7dbb06def522dab413bce4a46af476820bffdff2071974736b06f52f4ab57c96,885
# - git: https://github.com/commercialhaskell/stack.git

View File

@@ -13,6 +13,7 @@ import Control.Concurrent
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import SMPAgentClient
import SMPClient (testKeyHashStr)
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
import System.IO (Handle)
@@ -79,7 +80,7 @@ h #:# err = tryGet `shouldReturn` ()
_ -> return ()
pattern Msg :: MsgBody -> ACommand 'Agent
pattern Msg m_body <- MSG {m_body}
pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk}
testDuplexConnection :: Handle -> Handle -> IO ()
testDuplexConnection alice bob = do
@@ -87,13 +88,13 @@ testDuplexConnection alice bob = do
let qInfo' = serializeSmpQueueInfo qInfo
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON)
alice <# ("", "bob", CON)
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False
alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH))
@@ -127,24 +128,24 @@ testSubscrNotification (server, _) client = do
client <# ("", "conn1", END)
samplePublicKey :: ByteString
samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDlaIClbK39KzPJMljcpnYb2KlSoZ51AhwF5PH2CS+FStc3QzajiqfdOQPet23Hd9YC6pqyTQ7idntqgPrE7yKJF44lUhKlq8QS9KQcbK7W6t7F9uQFw44ceWd2eVf81UV04kQdKWJvC5Sz6jtSZNEfs9mVI8H0wi1amUvS6+7EDJbxikhcCRnFShFO9dUKRYXj6L2JVqXqO5cZgY9BScyneWIg6mhhsTcdDbITM6COlL+pF1f3TjDN+slyV+IzE+ap/9NkpsrCcI8KwwDpqEDmUUV/JQfmQ==,gj2UAiWzSj7iun0iXvI5iz5WEjaqngmB3SzQ5+iarixbaG15LFDtYs3pijG3eGfB1wIFgoP4D2z97vIWn8olT4uCTUClf29zGDDve07h/B3QG/4i0IDnio7MX3AbE8O6PKouqy/GLTfT4WxFUn423g80rpsVYd5oj+SCL2eaxIc="
samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
syntaxTests :: Spec
syntaxTests = do
it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR SYNTAX 11")
it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX")
describe "NEW" do
describe "valid" do
-- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided)
-- TODO: add tests with defined connection alias
xit "only server" $ ("211", "", "NEW localhost") >#>= \case ("211", "", "INV" : _) -> True; _ -> False
it "with port" $ ("212", "", "NEW localhost:5000") >#>= \case ("212", "", "INV" : _) -> True; _ -> False
xit "with keyHash" $ ("213", "", "NEW localhost#1234") >#>= \case ("213", "", "INV" : _) -> True; _ -> False
it "with port and keyHash" $ ("214", "", "NEW localhost:5000#1234") >#>= \case ("214", "", "INV" : _) -> True; _ -> False
xit "with keyHash" $ ("213", "", "NEW localhost#" <> testKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False
it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> testKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False
describe "invalid" do
-- TODO: add tests with defined connection alias
it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11")
it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR SYNTAX 11")
it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR SYNTAX 11")
it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX")
it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR CMD SYNTAX")
it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR CMD SYNTAX")
describe "JOIN" do
describe "valid" do
@@ -154,4 +155,4 @@ syntaxTests = do
("311", "", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "", "ERR SMP AUTH")
describe "invalid" do
-- TODO: JOIN is not merged yet - to be added
it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR SYNTAX 11")
it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX")

View File

@@ -1,18 +1,22 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module AgentTests.SQLiteTests (storeTests) where
import Control.Monad.Except (ExceptT, runExceptT)
import qualified Crypto.PubKey.RSA as R
import Data.ByteString.Char8 (ByteString)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.Time
import Data.Word (Word32)
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.QQ (sql)
import SMPClient (testKeyHash)
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite
import Simplex.Messaging.Agent.Transmission
@@ -48,41 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e
-- TODO add null port tests
storeTests :: Spec
storeTests = withStore do
describe "compiled as threadsafe" testCompiledThreadsafe
describe "foreign keys enabled" testForeignKeysEnabled
describe "store setup" do
testCompiledThreadsafe
testForeignKeysEnabled
describe "store methods" do
describe "createRcvConn" testCreateRcvConn
describe "createSndConn" testCreateSndConn
describe "getAllConnAliases" testGetAllConnAliases
describe "getRcvQueue" testGetRcvQueue
describe "deleteConn" do
describe "RcvConnection" testDeleteRcvConn
describe "SndConnection" testDeleteSndConn
describe "DuplexConnection" testDeleteDuplexConn
describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex
describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex
describe "set queue status" do
describe "setRcvQueueStatus" testSetRcvQueueStatus
describe "setSndQueueStatus" testSetSndQueueStatus
describe "DuplexConnection" testSetQueueStatusDuplex
xdescribe "RcvQueue doesn't exist" testSetRcvQueueStatusNoQueue
xdescribe "SndQueue doesn't exist" testSetSndQueueStatusNoQueue
describe "createRcvMsg" do
describe "RcvQueue exists" testCreateRcvMsg
describe "RcvQueue doesn't exist" testCreateRcvMsgNoQueue
describe "createSndMsg" do
describe "SndQueue exists" testCreateSndMsg
describe "SndQueue doesn't exist" testCreateSndMsgNoQueue
describe "Queue and Connection management" do
describe "createRcvConn" do
testCreateRcvConn
testCreateRcvConnDuplicate
describe "createSndConn" do
testCreateSndConn
testCreateSndConnDuplicate
describe "getAllConnAliases" testGetAllConnAliases
describe "getRcvQueue" testGetRcvQueue
describe "deleteConn" do
testDeleteRcvConn
testDeleteSndConn
testDeleteDuplexConn
describe "upgradeRcvConnToDuplex" do
testUpgradeRcvConnToDuplex
describe "upgradeSndConnToDuplex" do
testUpgradeSndConnToDuplex
describe "set Queue status" do
describe "setRcvQueueStatus" do
testSetRcvQueueStatus
testSetRcvQueueStatusNoQueue
describe "setSndQueueStatus" do
testSetSndQueueStatus
testSetSndQueueStatusNoQueue
testSetQueueStatusDuplex
describe "Msg management" do
describe "create Msg" do
testCreateRcvMsg
testCreateSndMsg
testCreateRcvAndSndMsgs
testCompiledThreadsafe :: SpecWith SQLiteStore
testCompiledThreadsafe = do
it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do
it "compiled sqlite library should be threadsafe" $ \store -> do
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
compileOptions `shouldNotContain` [["THREADSAFE=0"]]
testForeignKeysEnabled :: SpecWith SQLiteStore
testForeignKeysEnabled = do
it "should throw error if foreign keys are enabled" $ \store -> do
it "foreign keys should be enabled" $ \store -> do
let inconsistentQuery =
[sql|
INSERT INTO connections
@@ -96,13 +109,13 @@ testForeignKeysEnabled = do
rcvQueue1 :: RcvQueue
rcvQueue1 =
RcvQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
rcvId = "1234",
connAlias = "conn1",
rcvPrivateKey = C.PrivateKey 1 2 3,
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
sndId = Just "2345",
sndKey = Nothing,
decryptKey = C.PrivateKey 1 2 3,
decryptKey = C.safePrivateKey (1, 2, 3),
verifyKey = Nothing,
status = New
}
@@ -110,12 +123,12 @@ rcvQueue1 =
sndQueue1 :: SndQueue
sndQueue1 =
SndQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
sndId = "3456",
connAlias = "conn1",
sndPrivateKey = C.PrivateKey 1 2 3,
sndPrivateKey = C.safePrivateKey (1, 2, 3),
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
signKey = C.PrivateKey 1 2 3,
signKey = C.safePrivateKey (1, 2, 3),
status = New
}
@@ -131,6 +144,13 @@ testCreateRcvConn = do
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
testCreateRcvConnDuplicate = do
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
createRcvConn store rcvQueue1
`throwsError` SEConnDuplicate
testCreateSndConn :: SpecWith SQLiteStore
testCreateSndConn = do
it "should create SndConnection and add RcvQueue" $ \store -> do
@@ -143,118 +163,113 @@ testCreateSndConn = do
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
testCreateSndConnDuplicate :: SpecWith SQLiteStore
testCreateSndConnDuplicate = do
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
createSndConn store sndQueue1
`throwsError` SEConnDuplicate
testGetAllConnAliases :: SpecWith SQLiteStore
testGetAllConnAliases = do
it "should get all conn aliases" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
createSndConn store sndQueue1 {connAlias = "conn2"}
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
_ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"}
getAllConnAliases store
`returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias]
testGetRcvQueue :: SpecWith SQLiteStore
testGetRcvQueue = do
it "should get RcvQueue" $ \store -> do
let smpServer = SMPServer "smp.simplex.im" (Just "5223") (Just "1234")
let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash
let recipientId = "1234"
createRcvConn store rcvQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
getRcvQueue store smpServer recipientId
`returnsResult` rcvQueue1
testDeleteRcvConn :: SpecWith SQLiteStore
testDeleteRcvConn = do
it "should create RcvConnection and delete it" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
getConn store "conn1"
`throwsError` SEBadConn
`throwsError` SEConnNotFound
testDeleteSndConn :: SpecWith SQLiteStore
testDeleteSndConn = do
it "should create SndConnection and delete it" $ \store -> do
createSndConn store sndQueue1
`returnsResult` ()
_ <- runExceptT $ createSndConn store sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
getConn store "conn1"
`throwsError` SEBadConn
`throwsError` SEConnNotFound
testDeleteDuplexConn :: SpecWith SQLiteStore
testDeleteDuplexConn = do
it "should create DuplexConnection and delete it" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
upgradeRcvConnToDuplex store "conn1" sndQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
getConn store "conn1"
`throwsError` SEBadConn
`throwsError` SEConnNotFound
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
testUpgradeRcvConnToDuplex = do
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
createSndConn store sndQueue1
`returnsResult` ()
_ <- runExceptT $ createSndConn store sndQueue1
let anotherSndQueue =
SndQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
sndId = "2345",
connAlias = "conn1",
sndPrivateKey = C.PrivateKey 1 2 3,
sndPrivateKey = C.safePrivateKey (1, 2, 3),
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
signKey = C.PrivateKey 1 2 3,
signKey = C.safePrivateKey (1, 2, 3),
status = New
}
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
`throwsError` SEBadConnType CSnd
upgradeSndConnToDuplex store "conn1" rcvQueue1
`returnsResult` ()
_ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
`throwsError` SEBadConnType CDuplex
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
testUpgradeSndConnToDuplex = do
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
let anotherRcvQueue =
RcvQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"),
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
rcvId = "3456",
connAlias = "conn1",
rcvPrivateKey = C.PrivateKey 1 2 3,
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
sndId = Just "4567",
sndKey = Nothing,
decryptKey = C.PrivateKey 1 2 3,
decryptKey = C.safePrivateKey (1, 2, 3),
verifyKey = Nothing,
status = New
}
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
`throwsError` SEBadConnType CRcv
upgradeRcvConnToDuplex store "conn1" sndQueue1
`returnsResult` ()
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
`throwsError` SEBadConnType CDuplex
testSetRcvQueueStatus :: SpecWith SQLiteStore
testSetRcvQueueStatus = do
it "should update status of RcvQueue" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
setRcvQueueStatus store rcvQueue1 Confirmed
@@ -265,8 +280,7 @@ testSetRcvQueueStatus = do
testSetSndQueueStatus :: SpecWith SQLiteStore
testSetSndQueueStatus = do
it "should update status of SndQueue" $ \store -> do
createSndConn store sndQueue1
`returnsResult` ()
_ <- runExceptT $ createSndConn store sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
setSndQueueStatus store sndQueue1 Confirmed
@@ -277,10 +291,8 @@ testSetSndQueueStatus = do
testSetQueueStatusDuplex :: SpecWith SQLiteStore
testSetQueueStatusDuplex = do
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
upgradeRcvConnToDuplex store "conn1" sndQueue1
`returnsResult` ()
_ <- runExceptT $ createRcvConn store rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
setRcvQueueStatus store rcvQueue1 Secured
@@ -290,61 +302,88 @@ testSetQueueStatusDuplex = do
setSndQueueStatus store sndQueue1 Confirmed
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn
SCDuplex
( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}
)
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore
testSetRcvQueueStatusNoQueue = do
it "should throw error on attempt to update status of nonexistent RcvQueue" $ \store -> do
xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do
setRcvQueueStatus store rcvQueue1 Confirmed
`throwsError` SEInternal
`throwsError` SEConnNotFound
testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore
testSetSndQueueStatusNoQueue = do
it "should throw error on attempt to update status of nonexistent SndQueue" $ \store -> do
xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do
setSndQueueStatus store sndQueue1 Confirmed
`throwsError` SEInternal
`throwsError` SEConnNotFound
hw :: ByteString
hw = encodeUtf8 "Hello world!"
ts :: UTCTime
ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData
mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
RcvMsgData
{ internalId,
internalRcvId,
internalTs = ts,
senderMeta = (externalSndId, ts),
brokerMeta = (brokerId, ts),
msgBody = hw,
internalHash,
externalPrevSndHash = "hash_from_sender",
msgIntegrity = MsgOk
}
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation
testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do
updateRcvIds store rcvQueue
`returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
createRcvMsg store rcvQueue rcvMsgData
`returnsResult` ()
testCreateRcvMsg :: SpecWith SQLiteStore
testCreateRcvMsg = do
it "should create a RcvMsg and return InternalId" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
it "should reserve internal ids and create a RcvMsg" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
-- TODO getMsg to check message
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
`returnsResult` InternalId 1
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
testCreateRcvMsgNoQueue :: SpecWith SQLiteStore
testCreateRcvMsgNoQueue = do
it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
`throwsError` SEBadConn
createSndConn store sndQueue1
`returnsResult` ()
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
`throwsError` SEBadConnType CSnd
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
mkSndMsgData internalId internalSndId internalHash =
SndMsgData
{ internalId,
internalSndId,
internalTs = ts,
msgBody = hw,
internalHash
}
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation
testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do
updateSndIds store sndQueue
`returnsResult` (internalId, internalSndId, expectedPrevHash)
createSndMsg store sndQueue sndMsgData
`returnsResult` ()
testCreateSndMsg :: SpecWith SQLiteStore
testCreateSndMsg = do
it "should create a SndMsg and return InternalId" $ \store -> do
createSndConn store sndQueue1
`returnsResult` ()
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
-- TODO getMsg to check message
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
`returnsResult` InternalId 1
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
testCreateSndMsgNoQueue :: SpecWith SQLiteStore
testCreateSndMsgNoQueue = do
it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
`throwsError` SEBadConn
createRcvConn store rcvQueue1
`returnsResult` ()
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
`throwsError` SEBadConnType CRcv
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
testCreateRcvAndSndMsgs = do
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"

View File

@@ -0,0 +1,18 @@
module ProtocolErrorTests where
import Simplex.Messaging.Agent.Transmission (AgentErrorType, agentErrorTypeP, serializeAgentError)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Protocol (ErrorType, errorTypeP, serializeErrorType)
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
protocolErrorTests :: Spec
protocolErrorTests = modifyMaxSuccess (const 1000) $ do
describe "errors parsing / serializing" $ do
it "should parse SMP protocol errors" . property $ \err ->
parseAll errorTypeP (serializeErrorType err)
== Right (err :: ErrorType)
it "should parse SMP agent errors" . property $ \err ->
parseAll agentErrorTypeP (serializeAgentError err)
== Right (err :: AgentErrorType)

View File

@@ -6,23 +6,19 @@
module SMPAgentClient where
import Control.Monad
import Control.Monad.IO.Unlift
import Crypto.Random
import Network.Socket (HostName, ServiceName)
import SMPClient (testPort, withSmpServer, withSmpServerThreadOn)
import SMPClient (serverBracket, testPort, withSmpServer, withSmpServerThreadOn)
import Simplex.Messaging.Agent (runSMPAgentBlocking)
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Transmission
import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig)
import Simplex.Messaging.Transport
import System.Timeout (timeout)
import Test.Hspec
import UnliftIO.Concurrent
import UnliftIO.Directory
import qualified UnliftIO.Exception as E
import UnliftIO.IO
import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar)
agentTestHost :: HostName
agentTestHost = "localhost"
@@ -125,12 +121,10 @@ cfg =
}
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn (port', db') f = do
started <- newEmptyTMVarIO
E.bracket
(forkIOWithUnmask ($ runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'}))
(liftIO . killThread >=> const (removeFile db'))
\x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x
withSmpAgentThreadOn (port', db') =
serverBracket
(\started -> runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'})
(removeFile db')
withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a
withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const

View File

@@ -1,23 +1,30 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module SMPClient where
import Control.Monad (void)
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Unlift
import Crypto.Random
import Data.ByteString.Base64 (encode)
import qualified Data.ByteString.Char8 as B
import Network.Socket
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server (runSMPServerBlocking)
import Simplex.Messaging.Server.Env.STM
import Simplex.Messaging.Server.StoreLog (openReadStoreLog)
import Simplex.Messaging.Transport
import System.Timeout (timeout)
import Test.Hspec
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.IO
import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar)
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
import UnliftIO.Timeout (timeout)
testHost :: HostName
testHost = "localhost"
@@ -25,13 +32,21 @@ testHost = "localhost"
testPort :: ServiceName
testPort = "5000"
testSMPClient :: MonadUnliftIO m => (Handle -> m a) -> m a
testSMPClient client = do
runTCPClient testHost testPort $ \h -> do
line <- liftIO $ getLn h
if line == "Welcome to SMP v0.2.0"
then client h
else error "not connected"
testKeyHashStr :: B.ByteString
testKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="
testKeyHash :: Maybe C.KeyHash
testKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="
testStoreLogFile :: FilePath
testStoreLogFile = "tests/tmp/smp-server-store.log"
testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a
testSMPClient client =
runTCPClient testHost testPort $ \h ->
liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case
Right th -> client th
Left e -> error $ show e
cfg :: ServerConfig
cfg =
@@ -39,16 +54,66 @@ cfg =
{ tcpPort = testPort,
tbqSize = 1,
queueIdBytes = 12,
msgIdBytes = 6
msgIdBytes = 6,
storeLog = Nothing,
serverPrivateKey =
-- full RSA private key (only for tests)
"MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\
\kaJ8chEhrtaUgXeSWGooWwqjXEUQE6RVbCC6QVo9VEBSP4xFwVVd9Fj7OsgfcXXh\
\AqWxfctDcBZQ5jTUiJpdBc+Vz2ZkumVNl0W+j9kWm9nfkMLQj8c0cVSDxz4OKpZb\
\qFuj0uzHkis7e7wsrKSKWLPg3M5ZXPZM1m9qn7SfJzDRDfJifamxWI7uz9XK2+Dp\
\NkUQlGQgFJEv1cKN88JAwIqZ1s+TAQMQiB+4QZ2aNfSqGEzRJN7FMCKRK7pM0A9A\
\PCnijyuImvKFxTdk8Bx1q+XNJzsY6fBrLWJZ+QKBgQCySG4tzlcEm+tOVWRcwrWh\
\6zsczGZp9mbf9c8itRx6dlldSYuDG1qnddL70wuAZF2AgS1JZgvcRZECoZRoWP5q\
\Kq2wvpTIYjFPpC39lxgUoA/DXKVKZZdan+gwaVPAPT54my1CS32VrOiAY4gVJ3LJ\
\Mn1/FqZXUFQA326pau3loQKCAQEAoljmJMp88EZoy3HlHUbOjl5UEhzzVsU1TnQi\
\QmPm+aWRe2qelhjW4aTvSVE5mAUJsN6UWTeMf4uvM69Z9I5pfw2pEm8x4+GxRibY\
\iiwF2QNaLxxmzEHm1zQQPTgb39o8mgklhzFPill0JsnL3f6IkVwjFJofWSmpqEGs\
\dFSMRSXUTVXh1p/o7QZrhpwO/475iWKVS7o48N/0Xp513re3aXw+DRNuVnFEaBIe\
\TLvWM9Czn16ndAu1HYiTBuMvtRbAWnGZxU8ewzF4wlWK5tdIL5PTJDd1VhZJAKtB\
\npDvJpwxzKmjAhcTmjx0ckMIWtdVaOVm/2gWCXDty2FEdg7koQKBgQDOUUguJ/i7\
\q0jldWYRnVkotKnpInPdcEaodrehfOqYEHnvro9xlS6OeAS4Vz5AdH45zQ/4J3bV\
\2cH66tNr18ebM9nL//t5G69i89R9W7szyUxCI3LmAIdi3oSEbmz5GQBaw4l6h9Wi\
\n4FmFQaAXZrjQfO2qJcAHvWRsMp2pmqAGwKBgQDXaza0DRsKWywWznsHcmHa0cx8\
\I4jxqGaQmLO7wBJRP1NSFrywy1QfYrVX9CTLBK4V3F0PCgZ01Qv94751CzN43TgF\
\ebd/O9r5NjNTnOXzdWqETbCffLGd6kLgCMwPQWpM9ySVjXHWCGZsRAnF2F6M1O32\
\43StIifvwJQFqSM3ewKBgCaW6y7sRY90Ua7283RErezd9EyT22BWlDlACrPu3FNC\
\LtBf1j43uxBWBQrMLsHe2GtTV0xt9m0MfwZsm2gSsXcm4Xi4DJgfN+Z7rIlyy9UY\
\PCDSdZiU1qSr+NrffDrXlfiAM1cUmCdUX7eKjp/ltkUHNaOGfSn5Pdr3MkAiD/Hf\
\AoGBAKIdKCuOwuYlwjS9J+IRGuSSM4o+OxQdwGmcJDTCpyWb5dEk68e7xKIna3zf\
\jc+H+QdMXv1nkRK9bZgYheXczsXaNZUSTwpxaEldzVD3hNvsXSgJRy9fqHwA4PBq\
\vqiBHoO3RNbqg+2rmTMfDuXreME3S955ZiPZm4Z+T8Hj52mPAoGAQm5QH/gLFtY5\
\+znqU/0G8V6BKISCQMxbbmTQVcTgGySrP2gVd+e4MWvUttaZykhWqs8rpr7mgpIY\
\hul7Swx0SHFN3WpXu8uj+B6MLpRcCbDHO65qU4kQLs+IaXXsuuTjMvJ5LwjkZVrQ\
\TmKzSAw7iVWwEUZR/PeiEKazqrpp9VU="
}
withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a
withSmpServerStoreLogOn port client = do
s <- liftIO $ openReadStoreLog testStoreLogFile
serverBracket
(\started -> runSMPServerBlocking started cfg {tcpPort = port, storeLog = Just s})
(pure ())
client
withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a
withSmpServerThreadOn port f = do
withSmpServerThreadOn port =
serverBracket
(\started -> runSMPServerBlocking started cfg {tcpPort = port})
(pure ())
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
serverBracket process afterProcess f = do
started <- newEmptyTMVarIO
E.bracket
(forkIOWithUnmask ($ runSMPServerBlocking started cfg {tcpPort = port}))
(liftIO . killThread)
\x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x
(forkIOWithUnmask ($ process started))
(\t -> killThread t >> afterProcess >> waitFor started "stop")
(\t -> waitFor started "start" >> f t)
where
waitFor started s =
5_000_000 `timeout` atomically (takeTMVar started) >>= \case
Nothing -> error $ "server did not " <> s
_ -> pure ()
withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a
withSmpServerOn port = withSmpServerThreadOn port . const
@@ -56,33 +121,43 @@ withSmpServerOn port = withSmpServerThreadOn port . const
withSmpServer :: (MonadUnliftIO m, MonadRandom m) => m a -> m a
withSmpServer = withSmpServerOn testPort
runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a
runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (THandle -> m a) -> m a
runSmpTest test = withSmpServer $ testSMPClient test
runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([Handle] -> m a) -> m a
runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([THandle] -> m a) -> m a
runSmpTestN nClients test = withSmpServer $ run nClients []
where
run :: Int -> [Handle] -> m a
run :: Int -> [THandle] -> m a
run 0 hs = test hs
run n hs = testSMPClient $ \h -> run (n - 1) (h : hs)
smpServerTest :: RawTransmission -> IO RawTransmission
smpServerTest cmd = runSmpTest $ \h -> tPutRaw h cmd >> tGetRaw h
smpTest :: (Handle -> IO ()) -> Expectation
smpTest :: (THandle -> IO ()) -> Expectation
smpTest test' = runSmpTest test' `shouldReturn` ()
smpTestN :: Int -> ([Handle] -> IO ()) -> Expectation
smpTestN :: Int -> ([THandle] -> IO ()) -> Expectation
smpTestN n test' = runSmpTestN n test' `shouldReturn` ()
smpTest2 :: (Handle -> Handle -> IO ()) -> Expectation
smpTest2 :: (THandle -> THandle -> IO ()) -> Expectation
smpTest2 test' = smpTestN 2 _test
where
_test [h1, h2] = test' h1 h2
_test _ = error "expected 2 handles"
smpTest3 :: (Handle -> Handle -> Handle -> IO ()) -> Expectation
smpTest3 :: (THandle -> THandle -> THandle -> IO ()) -> Expectation
smpTest3 test' = smpTestN 3 _test
where
_test [h1, h2, h3] = test' h1 h2 h3
_test _ = error "expected 3 handles"
tPutRaw :: THandle -> RawTransmission -> IO ()
tPutRaw h (sig, corrId, queueId, command) = do
let t = B.intercalate " " [corrId, queueId, command]
void $ tPut h (C.Signature sig, t)
tGetRaw :: THandle -> IO RawTransmission
tGetRaw h = do
("", (CorrId corrId, qId, Right cmd)) <- tGet fromServer h
pure ("", corrId, encode qId, serializeCommand cmd)

View File

@@ -4,22 +4,29 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
module ServerTests where
import Control.Concurrent (ThreadId, killThread)
import Control.Concurrent.STM
import Control.Exception (SomeException, try)
import Control.Monad.Except (forM_, runExceptT)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import SMPClient
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol
import System.IO (Handle)
import Simplex.Messaging.Transport
import System.Directory (removeFile)
import System.TimeIt (timeItT)
import System.Timeout
import Test.HUnit
import Test.Hspec
rsaKeySize :: Int
rsaKeySize = 1024 `div` 8
rsaKeySize = 2048 `div` 8
serverTests :: Spec
serverTests = do
@@ -30,18 +37,20 @@ serverTests = do
describe "SMP messages" do
describe "duplex communication over 2 SMP connections" testDuplex
describe "switch subscription to another SMP queue" testSwitchSub
describe "Store log" testWithStoreLog
describe "Timing of AUTH error" testTiming
pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError
pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command)))
sendRecv :: Handle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
sendRecv :: THandle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h
signSendRecv :: Handle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
signSendRecv :: THandle -> C.SafePrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError
signSendRecv h pk (corrId, qId, cmd) = do
let t = B.intercalate "\r\n" [corrId, encode qId, cmd]
Right sig <- C.sign pk t
tPut h (sig, t)
let t = B.intercalate " " [corrId, encode qId, cmd]
Right sig <- runExceptT $ C.sign pk t
_ <- tPut h (sig, t)
tGet fromServer h
cmdSEND :: ByteString -> ByteString
@@ -61,7 +70,7 @@ testCreateSecure =
Resp "abcd" rId1 (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
(rId1, "") #== "creates queue"
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5\r\nhello")
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5 hello ")
(ok1, OK) #== "accepts unsigned SEND"
(sId1, sId) #== "same queue ID in response 1"
@@ -72,10 +81,10 @@ testCreateSecure =
(ok4, OK) #== "replies OK when message acknowledged if no more messages"
Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, "ACK")
(err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages"
(err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages"
(sPub, sKey) <- C.generateKeyPair rsaKeySize
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5\r\nhello")
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ")
(err1, ERR AUTH) #== "rejects signed SEND"
(sId2, sId) #== "same queue ID in response 2"
@@ -93,7 +102,7 @@ testCreateSecure =
Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, keyCmd)
(err4, ERR AUTH) #== "rejects KEY if already secured"
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11\r\nhello again")
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11 hello again ")
(ok3, OK) #== "accepts signed SEND"
Resp "" _ (MSG _ _ msg) <- tGet fromServer h
@@ -102,7 +111,7 @@ testCreateSecure =
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, "ACK")
(ok5, OK) #== "replies OK when message acknowledged 2"
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5\r\nhello")
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5 hello ")
(err5, ERR AUTH) #== "rejects unsigned SEND"
testCreateDelete :: Spec
@@ -117,10 +126,10 @@ testCreateDelete =
Resp "bcda" _ ok1 <- signSendRecv rh rKey ("bcda", rId, "KEY " <> C.serializePubKey sPub)
(ok1, OK) #== "secures queue"
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello")
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ")
(ok2, OK) #== "accepts signed SEND"
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7\r\nhello 2")
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7 hello 2 ")
(ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed"
Resp "" _ (MSG _ _ msg1) <- tGet fromServer rh
@@ -136,10 +145,10 @@ testCreateDelete =
(ok3, OK) #== "suspends queue"
(rId2, rId) #== "same queue ID in response 2"
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5\r\nhello")
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5 hello ")
(err3, ERR AUTH) #== "rejects signed SEND"
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5\r\nhello")
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5 hello ")
(err4, ERR AUTH) #== "reject unsigned SEND too"
Resp "bcda" _ ok4 <- signSendRecv rh rKey ("bcda", rId, "OFF")
@@ -158,10 +167,10 @@ testCreateDelete =
(ok6, OK) #== "deletes queue"
(rId3, rId) #== "same queue ID in response 3"
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello")
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ")
(err7, ERR AUTH) #== "rejects signed SEND when deleted"
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5\r\nhello")
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5 hello ")
(err8, ERR AUTH) #== "rejects unsigned SEND too when deleted"
Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, "ACK")
@@ -211,7 +220,7 @@ testDuplex =
(aliceKey, C.serializePubKey asPub) #== "key received from Alice"
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, "KEY " <> aliceKey)
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8\r\nhi alice")
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8 hi alice ")
Resp "" _ (MSG _ _ msg4) <- tGet fromServer alice
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, "ACK")
@@ -229,7 +238,7 @@ testSwitchSub =
smpTest3 \rh1 rh2 sh -> do
(rPub, rKey) <- C.generateKeyPair rsaKeySize
Resp "abcd" _ (IDS rId sId) <- signSendRecv rh1 rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5\r\ntest1")
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5 test1 ")
(ok1, OK) #== "sent test message 1"
Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, cmdSEND "test2, no ACK")
(ok2, OK) #== "sent test message 2"
@@ -246,13 +255,13 @@ testSwitchSub =
Resp "" _ end <- tGet fromServer rh1
(end, END) #== "unsubscribed the 1st TCP connection"
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5\r\ntest3")
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5 test3 ")
Resp "" _ (MSG _ _ msg3) <- tGet fromServer rh2
(msg3, "test3") #== "delivered to the 2nd TCP connection"
Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, "ACK")
(err, ERR PROHIBITED) #== "rejects ACK from the 1st TCP connection"
(err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection"
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, "ACK")
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
@@ -261,40 +270,128 @@ testSwitchSub =
Nothing -> return ()
Just _ -> error "nothing else is delivered to the 1st TCP connection"
testWithStoreLog :: Spec
testWithStoreLog =
it "should store simplex queues to log and restore them after server restart" $ do
(sPub1, sKey1) <- C.generateKeyPair rsaKeySize
(sPub2, sKey2) <- C.generateKeyPair rsaKeySize
senderId1 <- newTVarIO ""
senderId2 <- newTVarIO ""
withSmpServerStoreLogOn testPort . runTest $ \h -> do
(sId1, _, _) <- createAndSecureQueue h sPub1
atomically $ writeTVar senderId1 sId1
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
Resp "" _ (MSG _ _ "hello") <- tGet fromServer h
(sId2, rId2, rKey2) <- createAndSecureQueue h sPub2
atomically $ writeTVar senderId2 sId2
Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ")
Resp "" _ (MSG _ _ "hello too") <- tGet fromServer h
Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, "DEL")
pure ()
logSize `shouldReturn` 5
withSmpServerThreadOn testPort . runTest $ \h -> do
sId1 <- readTVarIO senderId1
-- fails if store log is disabled
Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
pure ()
withSmpServerStoreLogOn testPort . runTest $ \h -> do
-- this queue is restored
sId1 <- readTVarIO senderId1
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ")
-- this queue is removed - not restored
sId2 <- readTVarIO senderId2
Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ")
pure ()
logSize `shouldReturn` 1
removeFile testStoreLogFile
where
createAndSecureQueue :: THandle -> SenderPublicKey -> IO (SenderId, RecipientId, C.SafePrivateKey)
createAndSecureQueue h sPub = do
(rPub, rKey) <- C.generateKeyPair rsaKeySize
Resp "abcd" "" (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
let keyCmd = "KEY " <> C.serializePubKey sPub
Resp "dabc" rId' OK <- signSendRecv h rKey ("dabc", rId, keyCmd)
(rId', rId) #== "same queue ID"
pure (sId, rId, rKey)
runTest :: (THandle -> IO ()) -> ThreadId -> Expectation
runTest test' server = do
testSMPClient test' `shouldReturn` ()
killThread server
logSize :: IO Int
logSize =
try (length . B.lines <$> B.readFile testStoreLogFile) >>= \case
Right l -> pure l
Left (_ :: SomeException) -> logSize
testTiming :: Spec
testTiming =
it "should have similar time for auth error whether queue exists or not" $
smpTest2 \rh sh -> do
(rPub, rKey) <- C.generateKeyPair rsaKeySize
Resp "abcd" "" (IDS rId sId) <- signSendRecv rh rKey ("abcd", "", "NEW " <> C.serializePubKey rPub)
(sPub, sKey) <- C.generateKeyPair rsaKeySize
let keyCmd = "KEY " <> C.serializePubKey sPub
Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, keyCmd)
Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, "SEND 5 hello ")
timeNoQueue <- timeRepeat 25 $ do
Resp "dabc" _ (ERR AUTH) <- signSendRecv sh sKey ("dabc", rId, "SEND 5 hello ")
return ()
timeWrongKey <- timeRepeat 25 $ do
Resp "cdab" _ (ERR AUTH) <- signSendRecv sh rKey ("cdab", sId, "SEND 5 hello ")
return ()
abs (timeNoQueue - timeWrongKey) / timeNoQueue < 0.15 `shouldBe` True
where
timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const
samplePubKey :: ByteString
samplePubKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
syntaxTests :: Spec
syntaxTests = do
it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR SYNTAX 2")
it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX")
describe "NEW" do
it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR SYNTAX 2")
it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR SYNTAX 2")
it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR SYNTAX 3")
it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR SYNTAX 4")
it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX")
it "many parameters" $ ("1234", "cdab", "", "NEW 1 " <> samplePubKey) >#> ("", "cdab", "", "ERR CMD SYNTAX")
it "no signature" $ ("", "dabc", "", "NEW " <> samplePubKey) >#> ("", "dabc", "", "ERR CMD NO_AUTH")
it "queue ID" $ ("1234", "abcd", "12345678", "NEW " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH")
describe "KEY" do
it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH")
it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR SYNTAX 2")
it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR SYNTAX 2")
it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR SYNTAX 3")
it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR SYNTAX 3")
it "valid syntax" $ ("1234", "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "bcda", "12345678", "ERR AUTH")
it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX")
it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "dabc", "12345678", "ERR CMD SYNTAX")
it "no signature" $ ("", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH")
it "no queue ID" $ ("1234", "bcda", "", "KEY " <> samplePubKey) >#> ("", "bcda", "", "ERR CMD NO_AUTH")
noParamsSyntaxTest "SUB"
noParamsSyntaxTest "ACK"
noParamsSyntaxTest "OFF"
noParamsSyntaxTest "DEL"
describe "SEND" do
it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5\r\nhello") >#> ("", "cdab", "12345678", "ERR AUTH")
it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11\r\nhello there") >#> ("", "dabc", "12345678", "ERR AUTH")
it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2")
it "no queue ID" $ ("1234", "bcda", "", "SEND 5\r\nhello") >#> ("", "bcda", "", "ERR SYNTAX 5")
it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello") >#> ("", "cdab", "12345678", "ERR SYNTAX 2")
it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello") >#> ("", "dabc", "12345678", "ERR SYNTAX 2")
it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4\r\nhello") >#> ("", "abcd", "12345678", "ERR SIZE")
it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH")
it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH")
it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX")
it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE")
it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX")
it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX")
it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX")
describe "PING" do
it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG")
describe "broker response not allowed" do
it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR PROHIBITED")
it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED")
where
noParamsSyntaxTest :: ByteString -> Spec
noParamsSyntaxTest cmd = describe (B.unpack cmd) do
it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH")
it "parameters" $ ("1234", "bcda", "12345678", cmd <> " 1") >#> ("", "bcda", "12345678", "ERR SYNTAX 2")
it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3")
it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3")
it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX")
it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH")
it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH")

View File

@@ -1,5 +1,6 @@
import AgentTests
import MarkdownTests
import ProtocolErrorTests
import ServerTests
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
import Test.Hspec
@@ -9,6 +10,7 @@ main = do
createDirectoryIfMissing False "tests/tmp"
hspec $ do
describe "SimpleX markdown" markdownTests
describe "Protocol errors" protocolErrorTests
describe "SMP server" serverTests
describe "SMP client agent" agentTests
removeDirectoryRecursive "tests/tmp"