diff --git a/apps/dog-food/ChatOptions.hs b/apps/dog-food/ChatOptions.hs index 8da33f816..8d0a0560f 100644 --- a/apps/dog-food/ChatOptions.hs +++ b/apps/dog-food/ChatOptions.hs @@ -1,11 +1,12 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} module ChatOptions (getChatOpts, ChatOpts (..)) where -import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Char8 as B import Options.Applicative import Simplex.Messaging.Agent.Transmission (SMPServer (..), smpServerP) +import Simplex.Messaging.Parsers (parseAll) import System.FilePath (combine) import Types @@ -30,8 +31,8 @@ chatOpts appDir = ( long "server" <> short 's' <> metavar "SERVER" - <> help "SMP server to use (smp.simplex.im:5223)" - <> value (SMPServer "smp.simplex.im" (Just "5223") Nothing) + <> help "SMP server to use (smp1.simplex.im:5223#pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=)" + <> value (SMPServer "smp1.simplex.im" (Just "5223") (Just "pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=")) ) <*> option parseTermMode @@ -45,7 +46,7 @@ chatOpts appDir = defaultDbFilePath = combine appDir "smp-chat.db" parseSMPServer :: ReadM SMPServer -parseSMPServer = eitherReader $ A.parseOnly (smpServerP <* A.endOfInput) . B.pack +parseSMPServer = eitherReader $ parseAll smpServerP . B.pack parseTermMode :: ReadM TermMode parseTermMode = maybeReader $ \case diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index 7204971cd..e96d16d8d 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -31,6 +31,7 @@ import Simplex.Messaging.Agent.Client (AgentClient (..)) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (smpDefaultConfig) +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (raceAny_) import Styled import System.Console.ANSI.Types @@ -89,6 +90,7 @@ data ChatResponse | ReceivedMessage Contact ByteString | Disconnected Contact | YesYes + | ContactError ConnectionErrorType Contact | ErrorInput ByteString | ChatError AgentErrorType | NoChatResponse @@ -107,8 +109,13 @@ serializeChatResponse = \case Connected c -> [ttyContact c <> " connected"] Confirmation c -> [ttyContact c <> " ok"] ReceivedMessage c t -> prependFirst (ttyFromContact c) $ msgPlain t + -- TODO either add command to re-connect or update message below Disconnected c -> ["disconnected from " <> ttyContact c <> " - try \"/chat " <> bPlain (toBs c) <> "\""] YesYes -> ["you got it!"] + ContactError e c -> case e of + UNKNOWN -> ["no contact " <> ttyContact c] + DUPLICATE -> ["contact " <> ttyContact c <> " already exists"] + SIMPLEX -> ["contact " <> ttyContact c <> " did not accept invitation yet"] ErrorInput t -> ["invalid input: " <> bPlain t] ChatError e -> ["chat error: " <> plain (show e)] NoChatResponse -> [""] @@ -172,7 +179,7 @@ main = do t <- getChatClient smpServer ct <- newChatTerminal (tbqSize cfg) termMode -- setLogLevel LogInfo -- LogError - -- withGlobalLogging logCfg $ + -- withGlobalLogging logCfg $ do env <- newSMPAgentEnv cfg {dbFile = dbFileName} dogFoodChat t ct env @@ -209,7 +216,7 @@ newChatClient qSize smpServer = do receiveFromChatTerm :: ChatClient -> ChatTerminal -> IO () receiveFromChatTerm t ct = forever $ do atomically (readTBQueue $ inputQ ct) - >>= processOrError . A.parseOnly (chatCommandP <* A.endOfInput) . encodeUtf8 . T.pack + >>= processOrError . parseAll chatCommandP . encodeUtf8 . T.pack where processOrError = \case Left err -> writeOutQ . ErrorInput $ B.pack err @@ -259,9 +266,10 @@ receiveFromAgent t ct c = forever . atomically $ do INV qInfo -> Invitation qInfo CON -> Connected contact END -> Disconnected contact - MSG {m_body} -> ReceivedMessage contact m_body + MSG {msgBody} -> ReceivedMessage contact msgBody SENT _ -> NoChatResponse OK -> Confirmation contact + ERR (CONN e) -> ContactError e contact ERR e -> ChatError e where contact = Contact a diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index d43c97193..2de39e47c 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -1,7 +1,28 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} + module Main where +import Control.Monad (unless, when) +import qualified Crypto.Store.PKCS8 as S +import qualified Data.ByteString.Char8 as B +import Data.Char (toLower) +import Data.Functor (($>)) +import Data.Ini (lookupValue, readIniFile) +import qualified Data.Text as T +import Data.X509 (PrivKey (PrivKeyRSA)) +import Options.Applicative +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.Env.STM +import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog) +import System.Directory (createDirectoryIfMissing, doesFileExist) +import System.Exit (exitFailure) +import System.FilePath (combine) +import System.IO (IOMode (..), hFlush, stdout) cfg :: ServerConfig cfg = @@ -9,10 +30,137 @@ cfg = { tcpPort = "5223", tbqSize = 16, queueIdBytes = 12, - msgIdBytes = 6 + msgIdBytes = 6, + storeLog = Nothing, + -- key is loaded from the file server_key in /etc/opt/simplex directory + serverPrivateKey = undefined } +newKeySize :: Int +newKeySize = 2048 `div` 8 + +cfgDir :: FilePath +cfgDir = "/etc/opt/simplex" + +logDir :: FilePath +logDir = "/var/opt/simplex" + +defaultStoreLogFile :: FilePath +defaultStoreLogFile = combine logDir "smp-server-store.log" + main :: IO () main = do - putStrLn $ "Listening on port " ++ tcpPort cfg - runSMPServer cfg + opts <- getServerOpts + putStrLn "SMP Server (-h for help)" + ini <- readCreateIni opts + storeLog <- openStoreLog ini + pk <- readCreateKey + B.putStrLn $ "transport key hash: " <> publicKeyHash (C.publicKey pk) + putStrLn $ "listening on port " <> tcpPort cfg + runSMPServer cfg {serverPrivateKey = pk, storeLog} + +data IniOpts = IniOpts + { enableStoreLog :: Bool, + storeLogFile :: FilePath + } + +readCreateIni :: ServerOpts -> IO IniOpts +readCreateIni ServerOpts {configFile} = do + createDirectoryIfMissing True cfgDir + doesFileExist configFile >>= (`unless` createIni) + readIni + where + readIni :: IO IniOpts + readIni = do + ini <- either exitError pure =<< readIniFile configFile + let enableStoreLog = (== Right "on") $ lookupValue "STORE_LOG" "enable" ini + storeLogFile = either (const defaultStoreLogFile) T.unpack $ lookupValue "STORE_LOG" "file" ini + pure IniOpts {enableStoreLog, storeLogFile} + exitError e = do + putStrLn $ "error reading config file " <> configFile <> ": " <> e + exitFailure + createIni :: IO () + createIni = do + confirm $ "Save default ini file to " <> configFile + writeFile + configFile + "[STORE_LOG]\n\ + \# The server uses STM memory to store SMP queues and messages,\n\ + \# that will be lost on restart (e.g., as with redis).\n\ + \# This option enables saving SMP queues to append only log,\n\ + \# and restoring them when the server is started.\n\ + \# Log is compacted on start (deleted queues are removed).\n\ + \# The messages in the queues are not logged.\n\ + \\n\ + \# enable: on\n\ + \# file: /var/opt/simplex/smp-server-store.log\n" + +readCreateKey :: IO C.FullPrivateKey +readCreateKey = do + createDirectoryIfMissing True cfgDir + let path = combine cfgDir "server_key" + hasKey <- doesFileExist path + (if hasKey then readKey else createKey) path + where + createKey :: FilePath -> IO C.FullPrivateKey + createKey path = do + confirm "Generate new server key pair" + (_, pk) <- C.generateKeyPair newKeySize + S.writeKeyFile S.TraditionalFormat path [PrivKeyRSA $ C.rsaPrivateKey pk] + pure pk + readKey :: FilePath -> IO C.FullPrivateKey + readKey path = do + S.readKeyFile path >>= \case + [S.Unprotected (PrivKeyRSA pk)] -> pure $ C.FullPrivateKey pk + [_] -> errorExit "not RSA key" + [] -> errorExit "invalid key file format" + _ -> errorExit "more than one key" + where + errorExit :: String -> IO b + errorExit e = putStrLn (e <> ": " <> path) >> exitFailure + +confirm :: String -> IO () +confirm msg = do + putStr $ msg <> " (y/N): " + hFlush stdout + ok <- getLine + when (map toLower ok /= "y") exitFailure + +publicKeyHash :: C.PublicKey -> B.ByteString +publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.encodePubKey + +openStoreLog :: IniOpts -> IO (Maybe (StoreLog 'ReadMode)) +openStoreLog IniOpts {enableStoreLog, storeLogFile = f} + | enableStoreLog = do + createDirectoryIfMissing True logDir + putStrLn ("store log: " <> f) + Just <$> openReadStoreLog f + | otherwise = putStrLn "store log disabled" $> Nothing + +newtype ServerOpts = ServerOpts + { configFile :: FilePath + } + +serverOpts :: Parser ServerOpts +serverOpts = + ServerOpts + <$> strOption + ( long "config" + <> short 'c' + <> metavar "INI_FILE" + <> help ("config file (" <> defaultIniFile <> ")") + <> value defaultIniFile + ) + where + defaultIniFile = combine cfgDir "smp-server.ini" + +getServerOpts :: IO ServerOpts +getServerOpts = execParser opts + where + opts = + info + (serverOpts <**> helper) + ( fullDesc + <> header "Simplex Messaging Protocol (SMP) Server" + <> progDesc "Start server with INI_FILE (created on first run)" + ) diff --git a/package.yaml b/package.yaml index be4272382..d89e592ca 100644 --- a/package.yaml +++ b/package.yaml @@ -13,6 +13,8 @@ extra-source-files: dependencies: - ansi-terminal == 0.10.* + - asn1-encoding == 0.9.* + - asn1-types == 0.3.* - async == 2.2.* - attoparsec == 0.13.* - base >= 4.7 && < 5 @@ -22,11 +24,13 @@ dependencies: - cryptonite == 0.26.* - directory == 1.3.* - filepath == 1.4.* + - generic-random == 1.3.* - iso8601-time == 0.1.* - memory == 0.15.* - mtl - network == 3.1.* - network-transport == 0.5.* + - QuickCheck == 2.13.* - simple-logger == 0.1.* - sqlite-simple == 0.4.* - stm @@ -36,6 +40,7 @@ dependencies: - transformers == 0.5.* - unliftio == 0.2.* - unliftio-core == 0.1.* + - x509 == 1.7.* library: source-dirs: src @@ -45,6 +50,9 @@ executables: source-dirs: apps/smp-server main: Main.hs dependencies: + - cryptostore == 0.2.* + - ini == 0.4.* + - optparse-applicative == 0.15.* - simplex-messaging ghc-options: - -threaded @@ -78,6 +86,8 @@ tests: - hspec-core == 2.7.* - HUnit == 1.6.* - random == 1.1.* + - QuickCheck == 2.13.* + - timeit == 2.0.* ghc-options: # - -haddock diff --git a/rfcs/2021-01-26-crypto.md b/rfcs/2021-01-26-crypto.md index e9ea8aa49..6e397da18 100644 --- a/rfcs/2021-01-26-crypto.md +++ b/rfcs/2021-01-26-crypto.md @@ -16,49 +16,60 @@ For initial implementation I propose approach to be as simple as possible as lon One of the consideration is to use [noise protocol framework](https://noiseprotocol.org/noise.html), this section describes ad hoc protocol though. -During TCP session both client and server should use symmetric AES 256 bit encryption using the session key that will be established during the handshake. +During TCP session both client and server should use symmetric AES 256 bit encryption using two session keys and two base IVs that will be agreed during the handshake. Both client and the server should maintain two 32-bit word counters, one for sent and one for the received messages. The IV for each message should be computed by xor-ing the sequential message counter, starting from 0, with the first 32 bits of agreed base IV. TODO - explain it in a more formal way, also document how 32-bit word is encoded - with the most or least significant byte first (currently encodeWord32 from Network.Transport.Internal is used) -To establish the session key, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) and additional server ID (256 bits) in advance in order to be able to establish connection. +To establish the session keys and base IVs, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) in advance in order to be able to establish connection. -The handshake sequence could be the following: +The handshake sequence is the following: -1. Once the connection is established, the server sends its public key to the client -2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. -3. If the hash is the same, the client should generate a random symmetric AES key and IV that will be used as a session key both by the client and the server. -4. The client then should encrypt this symmetric key with the public key that the server sent and send back to the server the result and the server ID also shared with the client in advance: `rsa-encrypt(aes-key, iv, server-id)`. -5. The server should decrypt the received key, IV and server id with its private key. -6. The server should compare the `server-id` sent by the client and if it does not match its ID terminate the connection. -7. In case of successful decryption and matching server ID, the server should send encrypted welcome header. +1. Once the connection is established, the server sends server_header and its public RSA key encoded in X509 binary format to the client. +2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required. +3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server. +4. The client then should construct client_handshake block and send it to the server encrypted with the server public key: `rsa-encrypt(client_handshake)`. `snd_aes_key` and `snd_base_iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv_aes_key` and `rcv_base_iv` will be used by the client to decrypt **received** messages and by the server to encrypt them. +5. The server should decrypt the received keys and base IVs with its private key. +6. In case of successful decryption, the server should send encrypted welcome block (encrypted_welcome_block) that contains SMP protocol version. -```abnf -aes_welcome_header = aes_header_auth_tag aes_encrypted_header -welcome_header = smp_version ["," smp_mode] *SP ; decrypt(aes_encrypted_header) - 32 bytes -smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format - ; for example: v123.456.789-alpha.7 -smp_mode = smp_public / smp_authenticated -smp_public = %s"pub" ; public (default) - no auth to create and manage queues -smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues -aes_header_auth_tag = aes_auth_tag -aes_auth_tag = 16*16(OCTET) -``` - -No payload should follow this header, it is only used to confirm successful handshake and send the SMP protocol version that the server supports. - -All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES key and IV sent by the client during the handshake. +All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES keys and base IVs (incremented by counters on both sides) sent by the client during the handshake. Each transport block sent by the client and the server has this syntax: ```abnf -transport_block = aes_header_auth_tag aes_encrypted_header aes_body_auth_tag aes_encrypted_body -aes_encrypted_header = 32*32(OCTET) -header = padded_body_size payload_size reserved ; decrypt(aes_encrypted_header) - 32 bytes +server_header = block_size protocol key_size +block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the client rejects if > 65536 bytes +protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key +key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard) + +client_handshake = client_block_size client_protocol snd_aes_key snd_base_iv rcv_aes_key rcv_base_iv +client_block_size = 4*4(OCTET) ; 4-byte block size sent by the client, currently it is ignored by the server - reserved +client_protocol = 2*2(OCTET) ; currently it is 0 - reserved +snd_aes_key = 32*32(OCTET) +snd_base_iv = 16*16(OCTET) +rcv_aes_key = 32*32(OCTET) +rcv_base_iv = 16*16(OCTET) + +transport_block = aes_body_auth_tag aes_encrypted_body +; size is sent by server during handshake, usually 8192 bytes aes_encrypted_body = 1*OCTET -body = payload pad -padded_body_size = size ; body size in bytes -payload_size = size ; payload_size in bytes -size = 4*4(OCTET) -reserved = 24*24(OCTET) -aes_body_auth_tag = aes_auth_tag +aes_body_auth_tag = 16*16(OCTET) + +encrypted_welcome_block = transport_block +welcome_block = smp_version SP pad ; decrypt(encrypted_welcome_block) +smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format + ; for example: v123.456.789-alpha.7 +pad = 1*OCTET +``` + +## Possible future improvements/changes + +- server id (256 bits), so that only the users that have it can connect to the server. This ID will have to be passed to the server during the handshake +- block size agreed during handshake +- transport encryption protocol agreed during handshake +- welcome block containing SMP mode (smp_mode) + +```abnf +smp_mode = smp_public / smp_authenticated +smp_public = %s"pub" ; public (default) - no auth to create and manage queues +smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues ``` ## Initial handshake @@ -105,7 +116,9 @@ Symmetric keys are generated per message and encrypted with receiver's public ke The syntax of each encrypted message body is the following: ```abnf -encrypted_message_body = rsa_encrypted_header aes_encrypted_body +encrypted_message_body = rsa_signature encrypted_body +encrypted_body = rsa_encrypted_header aes_encrypted_body +rsa_signature = 256*256(OCTET) ; sign(encrypted_body) - assuming 2048 bit key size rsa_encrypted_header = 256*256(OCTET) ; encrypt(header) - assuming 2048 bit key size aes_encrypted_body = 1*OCTET ; encrypt(body) @@ -120,7 +133,6 @@ body = payload pad Future considerations: - Generation of symmetric keys per session and session rotation; -- Signature and verification of messages. ## E2E implementation diff --git a/rfcs/2021-03-18-groups.md b/rfcs/2021-03-18-groups.md new file mode 100644 index 000000000..6da1fba69 --- /dev/null +++ b/rfcs/2021-03-18-groups.md @@ -0,0 +1,25 @@ +# SMP agent groups + +## Problems + +- device/user profile synchronisation +- chat group communication + +Both problems would require message broadcast between a group of SMP agents. + +## Solution: basic symmetric groups via SMP agent protocol + +Additional commands and message envelopes to SMP agent protocol to provide an abstraction layer for device synchronisation and chat groups. + +The groups are fully symmetric, all agent who are members of the group have equal rights and can join and leave group at any time. + +All the information about the groups is stored only in agents, the commands are used to synchronise the group state between the agents. + +```abnf +group_command = create_group / add_to_group / remove_from_group / leave_group +group_response = group_created / added_to_group / removed_from_group +group_notification = added_to_group_by / removed_from_group_by / left_group +create_group = %s"GNEW " group_name ; cAlias must be empty +add_to_group = %s"GADD " group_name ; cAlias is the connection to add to the group +added_to_group = %s"GADDED " name ; cAlias is the connection added to the group +``` \ No newline at end of file diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index e11d91f8d..e5b214033 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -26,6 +26,7 @@ import qualified Data.ByteString.Char8 as B import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) import Data.Time.Clock +import Database.SQLite.Simple (SQLError) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Store @@ -36,10 +37,9 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (CorrId (..), MsgBody, SenderPublicKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (putLn, runTCPServer) -import Simplex.Messaging.Util (liftError) +import Simplex.Messaging.Util (bshow) import System.IO (Handle) import UnliftIO.Async (race_) -import UnliftIO.Exception (SomeException) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -98,7 +98,7 @@ send h c@AgentClient {sndQ} = forever $ do logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m () logClient AgentClient {clientId} dir (CorrId corrId, cAlias, cmd) = do - logInfo . decodeUtf8 $ B.unwords [B.pack $ show clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd] + logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd] client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () client c@AgentClient {rcvQ, sndQ} st = forever $ do @@ -114,10 +114,15 @@ withStore :: withStore action = do runExceptT (action `E.catch` handleInternal) >>= \case Right c -> return c - Left _ -> throwError STORE + Left e -> throwError $ storeError e where - handleInternal :: (MonadError StoreError m') => SomeException -> m' a - handleInternal _ = throwError SEInternal + handleInternal :: (MonadError StoreError m') => SQLError -> m' a + handleInternal e = throwError . SEInternal $ bshow e + storeError :: StoreError -> AgentErrorType + storeError = \case + SEConnNotFound -> CONN UNKNOWN + SEConnDuplicate -> CONN DUPLICATE + e -> INTERNAL $ show e processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m () processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = @@ -156,9 +161,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st cAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> subscribe rq SomeConn _ (RcvConnection _ rq) -> subscribe rq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without RcvQueue - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where subscribe rq = subscribeQueue c rq cAlias >> respond' cAlias OK @@ -171,22 +174,32 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq SomeConn _ (SndConnection _ sq) -> sendMsg sq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without SndQueue - _ -> throwError PROHIBITED -- NOT_READY ? + _ -> throwError $ CONN SIMPLEX where sendMsg sq = do - senderTs <- liftIO getCurrentTime - senderId <- withStore $ createSndMsg st connAlias msgBody senderTs - sendAgentMessage c sq senderTs $ A_MSG msgBody - respond $ SENT (unId senderId) + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq + let msgStr = + serializeSMPMessage + SMPMessage + { senderMsgId = unSndId internalSndId, + senderTimestamp = internalTs, + previousMsgHash, + agentMessage = A_MSG msgBody + } + msgHash = C.sha256Hash msgStr + withStore $ + createSndMsg st sq $ + SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} + sendAgentMessage c sq msgStr + respond $ SENT (unId internalId) suspendConnection :: m () suspendConnection = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> suspend rq SomeConn _ (RcvConnection _ rq) -> suspend rq - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where suspend rq = suspendQueue c rq >> respond OK @@ -195,20 +208,26 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> delete rq SomeConn _ (RcvConnection _ rq) -> delete rq - _ -> throwError PROHIBITED + _ -> delConn where + delConn = withStore (deleteConn st connAlias) >> respond OK delete rq = do deleteQueue c rq removeSubscription c connAlias - withStore (deleteConn st connAlias) - respond OK + delConn sendReplyQInfo :: SMPServer -> SndQueue -> m () sendReplyQInfo srv sq = do (rq, qInfo) <- newReceiveQueue c srv connAlias withStore $ upgradeSndConnToDuplex st connAlias rq - senderTs <- liftIO getCurrentTime - sendAgentMessage c sq senderTs $ REPLY qInfo + senderTimestamp <- liftIO getCurrentTime + sendAgentMessage c sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = REPLY qInfo + } respond :: ACommand 'Agent -> m () respond = respond' connAlias @@ -226,11 +245,13 @@ subscriber c@AgentClient {msgQ} st = forever $ do processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m () processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do - rq@RcvQueue {connAlias, decryptKey, status} <- withStore $ getRcvQueue st srv rId + rq@RcvQueue {connAlias, status} <- withStore $ getRcvQueue st srv rId case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received - agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody + msg <- decryptAndVerify rq msgBody + let msgHash = C.sha256Hash msg + agentMsg <- liftEither $ parseSMPMessage msg case agentMsg of SMPConfirmation senderKey -> do logServer "<--" c srv rId "MSG " @@ -242,50 +263,74 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do -- TODO update sender key in the store? secureQueue c rq senderKey withStore $ setRcvQueueStatus st rq Secured - sendAck c rq - s -> - -- TODO maybe send notification to the user - liftIO . putStrLn $ "unexpected SMP confirmation, queue status " <> show s - SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> case agentMessage of - HELLO _verifyKey _ -> do + HELLO verifyKey _ -> do logServer "<--" c srv rId "MSG " - -- TODO send status update to the user? - withStore $ setRcvQueueStatus st rq Active - sendAck c rq + case status of + Active -> notify connAlias . ERR $ AGENT A_PROHIBITED + _ -> do + void $ verifyMessage (Just verifyKey) msgBody + withStore $ setRcvQueueActive st rq verifyKey REPLY qInfo -> do logServer "<--" c srv rId "MSG " - -- TODO move senderKey inside SndQueue (sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey notify connAlias CON - sendAck c rq - A_MSG body -> do - logServer "<--" c srv rId "MSG " - -- TODO check message status - recipientTs <- liftIO getCurrentTime - let m_sender = (senderMsgId, senderTimestamp) - let m_broker = (srvMsgId, srvTs) - recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker - notify connAlias $ - MSG - { m_status = MsgOk, - m_recipient = (unId recipientId, recipientTs), - m_sender, - m_broker, - m_body = body - } - sendAck c rq + A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash + sendAck c rq return () SMP.END -> do removeSubscription c connAlias logServer "<--" c srv rId "END" notify connAlias END - _ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd + _ -> do + logServer "<--" c srv rId $ "unexpected: " <> bshow cmd + notify connAlias . ERR $ BROKER UNEXPECTED where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) + agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () + agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do + logServer "<--" c srv rId "MSG " + case status of + Active -> do + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq + let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash + withStore $ + createRcvMsg st rq $ + RcvMsgData + { internalId, + internalRcvId, + internalTs, + senderMeta, + brokerMeta, + msgBody, + internalHash = msgHash, + externalPrevSndHash = receivedPrevMsgHash, + msgIntegrity + } + notify connAlias $ + MSG + { recipientMeta = (unId internalId, internalTs), + senderMeta, + brokerMeta, + msgBody, + msgIntegrity + } + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + where + checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity + checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash + | extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk + | extSndId < prevExtSndId = MsgError $ MsgBadId extSndId + | extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate + | extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1) + | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash + | otherwise = MsgError MsgDuplicate -- this case is not possible connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m () connectToSendQueue c st sq senderKey verifyKey = do @@ -294,9 +339,6 @@ connectToSendQueue c st sq senderKey verifyKey = do sendHello c sq verifyKey withStore $ setSndQueueStatus st sq Active -decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString -decryptMessage decryptKey msg = liftError CRYPTO $ C.decrypt decryptKey msg - newSendQueue :: (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 1226c36be..30c510b28 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -19,11 +19,14 @@ module Simplex.Messaging.Agent.Client sendHello, secureQueue, sendAgentMessage, + decryptAndVerify, + verifyMessage, sendAck, suspendQueue, deleteQueue, logServer, removeSubscription, + cryptoError, ) where @@ -47,7 +50,7 @@ import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey) -import Simplex.Messaging.Util (liftError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError) import UnliftIO.Concurrent import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E @@ -86,15 +89,17 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = newSMPClient = do smp <- connectClient logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - -- TODO how can agent know client lost the connection? atomically . modifyTVar smpClients $ M.insert srv smp return smp connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config - liftIO (getSMPClient srv cfg msgQ clientDisconnected) - `E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection) + liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected) + `E.catch` internalError + where + internalError :: IOException -> m SMPClient + internalError = throwError . INTERNAL . show clientDisconnected :: IO () clientDisconnected = do @@ -118,31 +123,41 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m () closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient -withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withSMP c srv action = - (getSMPServerClient c srv >>= runAction) `catchError` logServerError +withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a +withSMP_ c srv action = + (getSMPServerClient c srv >>= action) `catchError` logServerError where - runAction :: SMPClient -> m a - runAction smp = liftError smpClientError $ action smp - - smpClientError :: SMPClientError -> AgentErrorType - smpClientError = \case - SMPServerError e -> SMP e - -- TODO handle other errors - _ -> INTERNAL - logServerError :: AgentErrorType -> m a logServerError e = do - logServer "<--" c srv "" $ (B.pack . show) e + logServer "<--" c srv "" $ bshow e throwError e -withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withLogSMP c srv qId cmdStr action = do +withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a +withLogSMP_ c srv qId cmdStr action = do logServer "-->" c srv qId cmdStr - res <- withSMP c srv action + res <- withSMP_ c srv action logServer "<--" c srv qId "OK" return res +withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withSMP c srv action = withSMP_ c srv $ liftSMP . action + +withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action + +liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a +liftSMP = liftError smpClientError + +smpClientError :: SMPClientError -> AgentErrorType +smpClientError = \case + SMPServerError e -> SMP e + SMPResponseError e -> BROKER $ RESPONSE e + SMPUnexpectedResponse -> BROKER UNEXPECTED + SMPResponseTimeout -> BROKER TIMEOUT + SMPNetworkError -> BROKER NETWORK + SMPTransportError e -> BROKER $ TRANSPORT e + e -> INTERNAL $ show e + newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo) newReceiveQueue c srv connAlias = do size <- asks $ rsaKeySize . config @@ -196,7 +211,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = - logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] + logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] showServer :: SMPServer -> ByteString showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv) @@ -205,35 +220,38 @@ logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> m () -sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do - msg <- mkConfirmation - withLogSMP c server sndId "SEND " $ \smp -> - sendSMPMessage smp Nothing sndId msg +sendConfirmation c sq@SndQueue {server, sndId} senderKey = + withLogSMP_ c server sndId "SEND " $ \smp -> do + msg <- mkConfirmation smp + liftSMP $ sendSMPMessage smp Nothing sndId msg where - mkConfirmation :: m MsgBody - mkConfirmation = do - let msg = serializeSMPMessage $ SMPConfirmation senderKey - paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encryptKey paddedSize msg + mkConfirmation :: SMPClient -> m MsgBody + mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () -sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do - msg <- mkHello $ AckMode On - withLogSMP c server sndId "SEND (retrying)" $ - send 20 msg +sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = + withLogSMP_ c server sndId "SEND (retrying)" $ \smp -> do + msg <- mkHello smp $ AckMode On + liftSMP $ send 8 100000 msg smp where - mkHello :: AckMode -> m ByteString - mkHello ackMode = do - senderTs <- liftIO getCurrentTime - mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode + mkHello :: SMPClient -> AckMode -> m ByteString + mkHello smp ackMode = do + senderTimestamp <- liftIO getCurrentTime + encryptAndSign smp sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = HELLO verifyKey ackMode + } - send :: Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () - send 0 _ _ = throwE SMPResponseTimeout -- TODO different error - send retry msg smp = + send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () + send 0 _ _ _ = throwE $ SMPServerError AUTH + send retry delay msg smp = sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case SMPServerError AUTH -> do - threadDelay 100000 - send (retry - 1) msg smp + threadDelay delay + send (retry - 1) (delay * 3 `div` 2) msg smp e -> throwE e secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m () @@ -256,21 +274,39 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = withLogSMP c server rcvId "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId -sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m () -sendAgentMessage c SndQueue {server, sndId, sndPrivateKey, encryptKey} senderTs agentMsg = do - msg <- mkAgentMessage encryptKey senderTs agentMsg - withLogSMP c server sndId "SEND " $ \smp -> - sendSMPMessage smp (Just sndPrivateKey) sndId msg +sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () +sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg = + withLogSMP_ c server sndId "SEND " $ \smp -> do + msg' <- encryptAndSign smp sq msg + liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg' -mkAgentMessage :: AgentMonad m => EncryptionKey -> SenderTimestamp -> AMessage -> m ByteString -mkAgentMessage encKey senderTs agentMessage = do - let msg = - serializeSMPMessage - SMPMessage - { senderMsgId = 0, - senderTimestamp = senderTs, - previousMsgHash = "1234", -- TODO hash of the previous message - agentMessage - } - paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encKey paddedSize msg +encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString +encryptAndSign smp SndQueue {encryptKey, signKey} msg = do + paddedSize <- asks $ (blockSize smp -) . reservedMsgSize + liftError cryptoError $ do + enc <- C.encrypt encryptKey paddedSize msg + C.Signature sig <- C.sign signKey enc + pure $ sig <> enc + +decryptAndVerify :: AgentMonad m => RcvQueue -> ByteString -> m ByteString +decryptAndVerify RcvQueue {decryptKey, verifyKey} msg = + verifyMessage verifyKey msg + >>= liftError cryptoError . C.decrypt decryptKey + +verifyMessage :: AgentMonad m => Maybe VerificationKey -> ByteString -> m ByteString +verifyMessage verifyKey msg = do + size <- asks $ rsaKeySize . config + let (sig, enc) = B.splitAt size msg + case verifyKey of + Nothing -> pure enc + Just k + | C.verify k (C.Signature sig) enc -> pure enc + | otherwise -> throwError $ AGENT A_SIGNATURE + +cryptoError :: C.CryptoError -> AgentErrorType +cryptoError = \case + C.CryptoLargeMsgError -> CMD LARGE + C.RSADecryptError _ -> AGENT A_ENCRYPTION + C.CryptoHeaderError _ -> AGENT A_ENCRYPTION + C.AESDecryptError -> AGENT A_ENCRYPTION + e -> INTERNAL $ show e diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 6ecd85ea2..b14bfb4a6 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -26,7 +26,7 @@ data Env = Env { config :: AgentConfig, idsDrg :: TVar ChaChaDRG, clientCounter :: TVar Int, - paddedMsgSize :: Int + reservedMsgSize :: Int } newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env @@ -34,10 +34,10 @@ newSMPAgentEnv config = do idsDrg <- drgNew >>= newTVarIO _ <- createSQLiteStore $ dbFile config clientCounter <- newTVarIO 0 - return Env {config, idsDrg, clientCounter, paddedMsgSize} + return Env {config, idsDrg, clientCounter, reservedMsgSize} where - -- one rsaKeySize is used by the RSA signature in each command, - -- another - by encrypted message body header + -- 1st rsaKeySize is used by the RSA signature in each command, + -- 2nd - by encrypted message body header + -- 3rd - by message signature -- smpCommandSize - is the estimated max size for SMP command, queueId, corrId - paddedMsgSize = blockSize smp - 2 * rsaKeySize config - smpCommandSize smp - smp = smpCfg config + reservedMsgSize = 3 * rsaKeySize config + smpCommandSize (smpCfg config) diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6708bb13d..778b6b6be 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -10,6 +10,7 @@ module Simplex.Messaging.Agent.Store where import Control.Exception (Exception) +import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) import Data.Time (UTCTime) @@ -38,11 +39,16 @@ class Monad m => MonadAgentStore s m where upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m () upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m () setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m () + setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m () setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () -- Msg management - createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId + updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m () + + updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + createSndMsg :: s -> SndQueue -> SndMsgData -> m () + getMsg :: s -> ConnAlias -> InternalId -> m Msg -- * Queue types @@ -102,6 +108,11 @@ data SConnType :: ConnType -> Type where SCSnd :: SConnType CSnd SCDuplex :: SConnType CDuplex +connType :: SConnType c -> ConnType +connType SCRcv = CRcv +connType SCSnd = CSnd +connType SCDuplex = CDuplex + deriving instance Eq (SConnType d) deriving instance Show (SConnType d) @@ -123,6 +134,42 @@ instance Eq SomeConn where deriving instance Show SomeConn +-- * Message integrity validation types + +type MsgHash = ByteString + +-- | Corresponds to `last_external_snd_msg_id` in `connections` table +type PrevExternalSndId = Int64 + +-- | Corresponds to `last_rcv_msg_hash` in `connections` table +type PrevRcvMsgHash = MsgHash + +-- | Corresponds to `last_snd_msg_hash` in `connections` table +type PrevSndMsgHash = MsgHash + +-- ? merge/replace these with RcvMsg and SndMsg +-- * Message data containers - used on Msg creation to reduce number of parameters + +data RcvMsgData = RcvMsgData + { internalId :: InternalId, + internalRcvId :: InternalRcvId, + internalTs :: InternalTs, + senderMeta :: (ExternalSndId, ExternalSndTs), + brokerMeta :: (BrokerId, BrokerTs), + msgBody :: MsgBody, + internalHash :: MsgHash, + externalPrevSndHash :: MsgHash, + msgIntegrity :: MsgIntegrity + } + +data SndMsgData = SndMsgData + { internalId :: InternalId, + internalSndId :: InternalSndId, + internalTs :: InternalTs, + msgBody :: MsgBody, + internalHash :: MsgHash + } + -- * Message types -- | A message in either direction that is stored by the agent. @@ -149,7 +196,10 @@ data RcvMsg = RcvMsg -- | Timestamp of acknowledgement to sender, corresponds to `AcknowledgedToSender` status. -- Do not mix up with `externalSndTs` - timestamp created at sender before sending, -- which in its turn corresponds to `internalTs` in sending agent. - ackSenderTs :: AckSenderTs + ackSenderTs :: AckSenderTs, + -- | Hash of previous message as received from sender - stored for integrity forensics. + externalPrevSndHash :: MsgHash, + msgIntegrity :: MsgIntegrity } deriving (Eq, Show) @@ -209,7 +259,9 @@ data MsgBase = MsgBase -- due to a possibility of implementation errors in different agents. internalId :: InternalId, internalTs :: InternalTs, - msgBody :: MsgBody + msgBody :: MsgBody, + -- | Hash of the message as computed by agent. + internalHash :: MsgHash } deriving (Eq, Show) @@ -219,13 +271,11 @@ type InternalTs = UTCTime -- * Store errors --- TODO revise data StoreError - = SEInternal - | SENotFound - | SEBadConn + = SEInternal ByteString + | SEConnNotFound + | SEConnDuplicate | SEBadConnType ConnType - | SEBadQueueStatus - | SEBadQueueDirection + | SEBadQueueStatus -- not used, planned to check strictly | SENotImplemented -- TODO remove deriving (Eq, Show, Exception) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 0b83bbd77..ae9473e7d 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -21,24 +23,26 @@ where import Control.Monad (when) import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO)) import Control.Monad.IO.Unlift (MonadUnliftIO) +import Data.Bifunctor (first) import Data.List (find) import Data.Maybe (fromMaybe) import Data.Text (isPrefixOf) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) -import Database.SQLite.Simple as DB +import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field) +import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) -import Network.Socket (HostName, ServiceName) +import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema) import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Protocol (MsgBody) +import Simplex.Messaging.Parsers (parseAll) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Util (liftIOEither) +import Simplex.Messaging.Util (bshow, liftIOEither) import System.Exit (ExitCode (ExitFailure), exitWith) import System.FilePath (takeDirectory) import Text.Read (readMaybe) @@ -75,76 +79,170 @@ connectSQLiteStore dbFilePath = do liftIO $ DB.execute_ dbConn "PRAGMA foreign_keys = ON;" return SQLiteStore {dbFilePath, dbConn} +checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m () +checkDuplicate action = liftIOEither $ first handleError <$> E.try action + where + handleError :: SQLError -> StoreError + handleError e + | DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate + | otherwise = SEInternal $ bshow e + instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} rcvQueue = - liftIO $ - createRcvQueueAndConn dbConn rcvQueue + createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertRcvQueue_ dbConn q + insertRcvConnection_ dbConn q createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} sndQueue = - liftIO $ - createSndQueueAndConn dbConn sndQueue + createSndConn SQLiteStore {dbConn} q@SndQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertSndQueue_ dbConn q + insertSndConnection_ dbConn q getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn SQLiteStore {dbConn} connAlias = do - queues <- - liftIO $ - retrieveConnQueues dbConn connAlias - case queues of - (Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) - (Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ) - (Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> throwError SEBadConn + getConn SQLiteStore {dbConn} connAlias = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias getAllConnAliases :: SQLiteStore -> m [ConnAlias] getAllConnAliases SQLiteStore {dbConn} = - liftIO $ - retrieveAllConnAliases dbConn + liftIO $ do + r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] + return (concat r) getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do - rcvQueue <- + r <- liftIO $ - retrieveRcvQueue dbConn host port rcvId - case rcvQueue of - Just rcvQ -> return rcvQ - _ -> throwError SENotFound + DB.queryNamed + dbConn + [sql| + SELECT + s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, + q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status + FROM rcv_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; + |] + [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] + case r of + [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> + let srv = SMPServer hst (deserializePort_ prt) keyHash + in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status + _ -> throwError SEConnNotFound deleteConn :: SQLiteStore -> ConnAlias -> m () deleteConn SQLiteStore {dbConn} connAlias = liftIO $ - deleteConnCascade dbConn connAlias + DB.executeNamed + dbConn + "DELETE FROM connections WHERE conn_alias = :conn_alias;" + [":conn_alias" := connAlias] upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m () - upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue = - liftIOEither $ - updateRcvConnWithSndQueue dbConn connAlias sndQueue + upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCRcv (RcvConnection _ _)) -> do + upsertServer_ dbConn server + insertSndQueue_ dbConn sq + updateConnWithSndQueue_ dbConn connAlias sq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m () - upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue = - liftIOEither $ - updateSndConnWithRcvQueue dbConn connAlias rcvQueue + upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCSnd (SndConnection _ _)) -> do + upsertServer_ dbConn server + insertRcvQueue_ dbConn rq + updateConnWithRcvQueue_ dbConn connAlias rq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () - setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status = + setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateRcvQueueStatus dbConn rcvQueue status + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] + + setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () + setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = + -- ? throw error if queue doesn't exist? + liftIO $ + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET verify_key = :verify_key, status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [ ":verify_key" := Just verifyKey, + ":status" := Active, + ":host" := host, + ":port" := serializePort_ port, + ":rcv_id" := rcvId + ] setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () - setSndQueueStatus SQLiteStore {dbConn} sndQueue status = + setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateSndQueueStatus dbConn sndQueue status + DB.executeNamed + dbConn + [sql| + UPDATE snd_queues + SET status = :status + WHERE host = :host AND port = :port AND snd_id = :snd_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - liftIOEither $ - insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) + updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 + updateLastIdsRcv_ dbConn connAlias internalId internalRcvId + pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId - createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs = - liftIOEither $ - insertSndMsg dbConn connAlias msgBody internalTs + createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m () + createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData = + liftIO . DB.withTransaction dbConn $ do + insertRcvMsgBase_ dbConn connAlias rcvMsgData + insertRcvMsgDetails_ dbConn connAlias rcvMsgData + updateHashRcv_ dbConn connAlias rcvMsgData + + updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 + updateLastIdsSnd_ dbConn connAlias internalId internalSndId + pure (internalId, internalSndId, prevSndHash) + + createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m () + createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData = + liftIO . DB.withTransaction dbConn $ do + insertSndMsgBase_ dbConn connAlias sndMsgData + insertSndMsgDetails_ dbConn connAlias sndMsgData + updateHashSnd_ dbConn connAlias sndMsgData getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg getMsg _st _connAlias _id = throwError SENotImplemented @@ -179,6 +277,16 @@ instance ToField RcvMsgStatus where toField = toField . show instance ToField SndMsgStatus where toField = toField . show +instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity + +instance FromField MsgIntegrity where + fromField = \case + f@(Field (SQLBlob b) _) -> + case parseAll msgIntegrityP b of + Right k -> Ok k + Left e -> returnError ConversionFailed f ("can't parse msg integrity field: " ++ e) + f -> returnError ConversionFailed f "expecting SQLBlob column type" + fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a fromFieldToReadable_ = \case f@(Field (SQLText t) _) -> @@ -217,13 +325,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do -- * createRcvConn helpers -createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO () -createRcvQueueAndConn dbConn rcvQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - insertRcvConnection_ dbConn rcvQueue - insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO () insertRcvQueue_ dbConn RcvQueue {..} = do let port_ = serializePort_ $ port server @@ -255,22 +356,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId] -- * createSndConn helpers -createSndQueueAndConn :: DB.Connection -> SndQueue -> IO () -createSndQueueAndConn dbConn sndQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - insertSndConnection_ dbConn sndQueue - insertSndQueue_ :: DB.Connection -> SndQueue -> IO () insertSndQueue_ dbConn SndQueue {..} = do let port_ = serializePort_ $ port server @@ -300,28 +395,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId] -- * getConn helpers -retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues dbConn connAlias = - DB.withTransaction -- Avoid inconsistent state between queue reads - dbConn - $ retrieveConnQueues_ dbConn connAlias - --- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap --- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction -retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues_ dbConn connAlias = do - rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias - sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias - return (rcvQ, sndQ) +getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn) +getConn_ dbConn connAlias = do + rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias + sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias + pure $ case (rQ, sQ) of + (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) + (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ) + (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ) + _ -> Left SEConnNotFound retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue) retrieveRcvQueueByConnAlias_ dbConn connAlias = do @@ -363,60 +455,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status _ -> return Nothing --- * getAllConnAliases helper - -retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias] -retrieveAllConnAliases dbConn = do - r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] - return (concat r) - --- * getRcvQueue helper - -retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue) -retrieveRcvQueue dbConn host port rcvId = do - r <- - DB.queryNamed - dbConn - [sql| - SELECT - s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, - q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status - FROM rcv_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; - |] - [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - case r of - [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do - let srv = SMPServer hst (deserializePort_ prt) keyHash - return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status - _ -> return Nothing - --- * deleteConn helper - -deleteConnCascade :: DB.Connection -> ConnAlias -> IO () -deleteConnCascade dbConn connAlias = - DB.executeNamed - dbConn - "DELETE FROM connections WHERE conn_alias = :conn_alias;" - [":conn_alias" := connAlias] - -- * upgradeRcvConnToDuplex helpers -updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ()) -updateRcvConnWithSndQueue dbConn connAlias sndQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, Nothing) -> do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - updateConnWithSndQueue_ dbConn connAlias sndQueue - return $ Right () - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn - updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do let port_ = serializePort_ $ port server @@ -431,20 +471,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do -- * upgradeSndConnToDuplex helpers -updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ()) -updateSndConnWithRcvQueue dbConn connAlias rcvQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Nothing, Just _sndQ) -> do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - updateConnWithRcvQueue_ dbConn connAlias rcvQueue - return $ Right () - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn - updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server @@ -457,74 +483,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do |] [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias] --- * setRcvQueueStatus helper +-- * updateRcvIds helpers --- ? throw error if queue doesn't exist? -updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () -updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE rcv_queues - SET status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - --- * setSndQueueStatus helper - --- ? throw error if queue doesn't exist? -updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () -updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE snd_queues - SET status = :status - WHERE host = :host AND port = :port AND snd_id = :snd_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - --- * createRcvMsg helpers - -insertRcvMsg :: - DB.Connection -> - ConnAlias -> - MsgBody -> - InternalTs -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO (Either StoreError InternalId) -insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, _) -> do - (lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody - insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) - updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId - return $ Right internalId - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - _ -> return $ Left SEBadConn - -retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId) -retrieveLastInternalIdsRcv_ dbConn connAlias = do - [(lastInternalId, lastInternalRcvId)] <- +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connAlias = do + [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.queryNamed dbConn [sql| - SELECT last_internal_msg_id, last_internal_rcv_msg_id + SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash FROM connections WHERE conn_alias = :conn_alias; |] [":conn_alias" := connAlias] - return (lastInternalId, lastInternalRcvId) + return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) -insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO () -insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do +updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = + DB.executeNamed + dbConn + [sql| + UPDATE connections + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_rcv_msg_id = :last_internal_rcv_msg_id + WHERE conn_alias = :conn_alias; + |] + [ ":last_internal_msg_id" := newInternalId, + ":last_internal_rcv_msg_id" := newInternalRcvId, + ":conn_alias" := connAlias + ] + +-- * createRcvMsg helpers + +insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do DB.executeNamed dbConn [sql| @@ -540,82 +532,85 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = ":body" := decodeUtf8 msgBody ] -insertRcvMsgDetails_ :: - DB.Connection -> - ConnAlias -> - InternalRcvId -> - InternalId -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO () -insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) = +insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = DB.executeNamed dbConn [sql| INSERT INTO rcv_messages ( conn_alias, internal_rcv_id, internal_id, external_snd_id, external_snd_ts, - broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts) + broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts, + internal_hash, external_prev_snd_hash, integrity) VALUES (:conn_alias,:internal_rcv_id,:internal_id,:external_snd_id,:external_snd_ts, - :broker_id,:broker_ts,:rcv_status, NULL, NULL); + :broker_id,:broker_ts,:rcv_status, NULL, NULL, + :internal_hash,:external_prev_snd_hash,:integrity); |] [ ":conn_alias" := connAlias, ":internal_rcv_id" := internalRcvId, ":internal_id" := internalId, - ":external_snd_id" := externalSndId, - ":external_snd_ts" := externalSndTs, - ":broker_id" := brokerId, - ":broker_ts" := brokerTs, - ":rcv_status" := Received + ":external_snd_id" := fst senderMeta, + ":external_snd_ts" := snd senderMeta, + ":broker_id" := fst brokerMeta, + ":broker_ts" := snd brokerMeta, + ":rcv_status" := Received, + ":internal_hash" := internalHash, + ":external_prev_snd_hash" := externalPrevSndHash, + ":integrity" := msgIntegrity ] -updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () -updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = +updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +updateHashRcv_ dbConn connAlias RcvMsgData {..} = + DB.executeNamed + dbConn + -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved + [sql| + UPDATE connections + SET last_external_snd_msg_id = :last_external_snd_msg_id, + last_rcv_msg_hash = :last_rcv_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; + |] + [ ":last_external_snd_msg_id" := fst senderMeta, + ":last_rcv_msg_hash" := internalHash, + ":conn_alias" := connAlias, + ":last_internal_rcv_msg_id" := internalRcvId + ] + +-- * updateSndIds helpers + +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash) +retrieveLastIdsAndHashSnd_ dbConn connAlias = do + [(lastInternalId, lastInternalSndId, lastSndHash)] <- + DB.queryNamed + dbConn + [sql| + SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash + FROM connections + WHERE conn_alias = :conn_alias; + |] + [":conn_alias" := connAlias] + return (lastInternalId, lastInternalSndId, lastSndHash) + +updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = DB.executeNamed dbConn [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_snd_msg_id = :last_internal_snd_msg_id WHERE conn_alias = :conn_alias; |] [ ":last_internal_msg_id" := newInternalId, - ":last_internal_rcv_msg_id" := newInternalRcvId, + ":last_internal_snd_msg_id" := newInternalSndId, ":conn_alias" := connAlias ] -- * createSndMsg helpers -insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId) -insertSndMsg dbConn connAlias msgBody internalTs = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (_, Just _sndQ) -> do - (lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody - insertSndMsgDetails_ dbConn connAlias internalSndId internalId - updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId - return $ Right internalId - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - _ -> return $ Left SEBadConn - -retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId) -retrieveLastInternalIdsSnd_ dbConn connAlias = do - [(lastInternalId, lastInternalSndId)] <- - DB.queryNamed - dbConn - [sql| - SELECT last_internal_msg_id, last_internal_snd_msg_id - FROM connections - WHERE conn_alias = :conn_alias; - |] - [":conn_alias" := connAlias] - return (lastInternalId, lastInternalSndId) - -insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO () -insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do +insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do DB.executeNamed dbConn [sql| @@ -631,32 +626,35 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = ":body" := decodeUtf8 msgBody ] -insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO () -insertSndMsgDetails_ dbConn connAlias internalSndId internalId = +insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn [sql| INSERT INTO snd_messages - ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts) + ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash) VALUES - (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL); + (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash); |] [ ":conn_alias" := connAlias, ":internal_snd_id" := internalSndId, ":internal_id" := internalId, - ":snd_status" := Created + ":snd_status" := Created, + ":internal_hash" := internalHash ] -updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () -updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId = +updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +updateHashSnd_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn + -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id - WHERE conn_alias = :conn_alias; + SET last_snd_msg_hash = :last_snd_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_alias" := connAlias + [ ":last_snd_msg_hash" := internalHash, + ":conn_alias" := connAlias, + ":last_internal_snd_msg_id" := internalSndId ] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs index 12d886a4e..99ba67ccf 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs @@ -90,6 +90,9 @@ connections = last_internal_msg_id INTEGER NOT NULL, last_internal_rcv_msg_id INTEGER NOT NULL, last_internal_snd_msg_id INTEGER NOT NULL, + last_external_snd_msg_id INTEGER NOT NULL, + last_rcv_msg_hash BLOB NOT NULL, + last_snd_msg_hash BLOB NOT NULL, PRIMARY KEY (conn_alias), FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id), FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id) @@ -135,6 +138,9 @@ rcvMessages = rcv_status TEXT NOT NULL, ack_brocker_ts TEXT, ack_sender_ts TEXT, + internal_hash BLOB NOT NULL, + external_prev_snd_hash BLOB NOT NULL, + integrity BLOB NOT NULL, PRIMARY KEY (conn_alias, internal_rcv_id), FOREIGN KEY (conn_alias, internal_id) REFERENCES messages (conn_alias, internal_id) @@ -152,6 +158,7 @@ sndMessages = snd_status TEXT NOT NULL, sent_ts TEXT, delivered_ts TEXT, + internal_hash BLOB NOT NULL, PRIMARY KEY (conn_alias, internal_snd_id), FOREIGN KEY (conn_alias, internal_id) REFERENCES messages (conn_alias, internal_id) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 207809276..fbaff38b6 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -12,7 +13,7 @@ module Simplex.Messaging.Agent.Transmission where -import Control.Applicative ((<|>)) +import Control.Applicative (optional, (<|>)) import Control.Monad.IO.Class import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A @@ -21,13 +22,14 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) -import Data.Kind +import Data.Kind (Type) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import Network.Socket -import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol @@ -37,12 +39,12 @@ import Simplex.Messaging.Protocol MsgBody, MsgId, SenderPublicKey, - errMessageBody, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport import Simplex.Messaging.Util import System.IO +import Test.QuickCheck (Arbitrary (..)) import Text.Read import UnliftIO.Exception @@ -88,11 +90,11 @@ data ACommand (p :: AParty) where SEND :: MsgBody -> ACommand Client SENT :: AgentMsgId -> ACommand Agent MSG :: - { m_recipient :: (AgentMsgId, UTCTime), - m_broker :: (MsgId, UTCTime), - m_sender :: (AgentMsgId, UTCTime), - m_status :: MsgStatus, - m_body :: MsgBody + { recipientMeta :: (AgentMsgId, UTCTime), + brokerMeta :: (MsgId, UTCTime), + senderMeta :: (AgentMsgId, UTCTime), + msgIntegrity :: MsgIntegrity, + msgBody :: MsgBody } -> ACommand Agent -- ACK :: AgentMsgId -> ACommand Client @@ -125,7 +127,7 @@ data AMessage where deriving (Show) parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage -parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage +parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE where smpMessageP :: Parser SMPMessage smpMessageP = @@ -140,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage SMPMessage <$> A.decimal <* A.space <*> tsISO8601P <* A.space - <*> base64P <* A.endOfLine + -- TODO previous message hash should become mandatory when we support HELLO and REPLY + -- (for HELLO it would be the hash of SMPConfirmation) + <*> (base64P <|> pure "") <* A.endOfLine <*> agentMessageP serializeSMPMessage :: SMPMessage -> ByteString @@ -152,7 +156,7 @@ serializeSMPMessage = \case in smpMessage "" header body where messageHeader msgId ts prevMsgHash = - B.unwords [B.pack $ show msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash] + B.unwords [bshow msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash] smpMessage smpHeader aHeader aBody = B.intercalate "\n" [smpHeader, aHeader, aBody, ""] agentMessageP :: Parser AMessage @@ -164,8 +168,8 @@ agentMessageP = hello = HELLO <$> C.pubKeyP <*> ackMode reply = REPLY <$> smpQueueInfoP a_msg = do - size :: Int <- A.decimal - A_MSG <$> (A.endOfLine *> A.take size <* A.endOfLine) + size :: Int <- A.decimal <* A.endOfLine + A_MSG <$> A.take size <* A.endOfLine ackMode = " NO_ACK" $> AckMode Off <|> pure (AckMode On) smpQueueInfoP :: Parser SMPQueueInfo @@ -173,14 +177,14 @@ smpQueueInfoP = "smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP) smpServerP :: Parser SMPServer -smpServerP = SMPServer <$> server <*> port <*> msgHash +smpServerP = SMPServer <$> server <*> optional port <*> optional kHash where server = B.unpack <$> A.takeTill (A.inClass ":# ") - port = A.char ':' *> (Just . show <$> (A.decimal :: Parser Int)) <|> pure Nothing - msgHash = A.char '#' *> (Just <$> base64P) <|> pure Nothing + port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit) + kHash = A.char '#' *> C.keyHashP parseAgentMessage :: ByteString -> Either AgentErrorType AMessage -parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage +parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE serializeAgentMessage :: AMessage -> ByteString serializeAgentMessage = \case @@ -194,17 +198,15 @@ serializeSmpQueueInfo (SMPQueueInfo srv qId ek) = serializeServer :: SMPServer -> ByteString serializeServer SMPServer {host, port, keyHash} = - B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash + B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack . C.serializeKeyHash) keyHash data SMPServer = SMPServer { host :: HostName, port :: Maybe ServiceName, - keyHash :: Maybe KeyHash + keyHash :: Maybe C.KeyHash } deriving (Eq, Ord, Show) -type KeyHash = Encoded - type ConnAlias = ByteString type OtherPartyId = Encoded @@ -220,9 +222,9 @@ data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Eq, Show) type EncryptionKey = C.PublicKey -type DecryptionKey = C.PrivateKey +type DecryptionKey = C.SafePrivateKey -type SignatureKey = C.PrivateKey +type SignatureKey = C.SafePrivateKey type VerificationKey = C.PublicKey @@ -235,56 +237,60 @@ type AgentMsgId = Int64 type SenderTimestamp = UTCTime -data MsgStatus = MsgOk | MsgError MsgErrorType +data MsgIntegrity = MsgOk | MsgError MsgErrorType deriving (Eq, Show) -data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash +data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate deriving (Eq, Show) +-- | error type used in errors sent to agent clients data AgentErrorType - = UNKNOWN - | PROHIBITED - | SYNTAX Int - | BROKER Natural - | SMP ErrorType - | CRYPTO C.CryptoError - | SIZE - | STORE - | INTERNAL -- etc. TODO SYNTAX Natural - deriving (Eq, Show, Exception) + = CMD CommandErrorType -- command errors + | CONN ConnectionErrorType -- connection state errors + | SMP ErrorType -- SMP protocol errors forwarded to agent clients + | BROKER BrokerErrorType -- SMP server errors + | AGENT SMPAgentError -- errors of other agents + | INTERNAL String -- agent implementation errors + deriving (Eq, Generic, Read, Show, Exception) -data AckStatus = AckOk | AckError AckErrorType - deriving (Show) +data CommandErrorType + = PROHIBITED -- command is prohibited + | SYNTAX -- command syntax is invalid + | NO_CONN -- connection alias is required with this command + | SIZE -- message size is not correct (no terminating space) + | LARGE -- message does not fit SMP block + deriving (Eq, Generic, Read, Show, Exception) -data AckErrorType = AckUnknown | AckProhibited | AckSyntax Int -- etc. - deriving (Show) +data ConnectionErrorType + = UNKNOWN -- connection alias not in database + | DUPLICATE -- connection alias already exists + | SIMPLEX -- connection is simplex, but operation requires another queue + deriving (Eq, Generic, Read, Show, Exception) -errBadEncoding :: Int -errBadEncoding = 10 +data BrokerErrorType + = RESPONSE ErrorType -- invalid server response (failed to parse) + | UNEXPECTED -- unexpected response + | NETWORK -- network error + | TRANSPORT TransportError -- handshake or other transport error + | TIMEOUT -- command response timeout + deriving (Eq, Generic, Read, Show, Exception) -errBadCommand :: Int -errBadCommand = 11 +data SMPAgentError + = A_MESSAGE -- possibly should include bytestring that failed to parse + | A_PROHIBITED -- possibly should include the prohibited SMP/agent message + | A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header + | A_SIGNATURE -- invalid RSA signature + deriving (Eq, Generic, Read, Show, Exception) -errBadInvitation :: Int -errBadInvitation = 12 +instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU -errNoConnAlias :: Int -errNoConnAlias = 13 +instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU -errBadMessage :: Int -errBadMessage = 14 +instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU -errBadServer :: Int -errBadServer = 15 +instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU -smpErrTCPConnection :: Natural -smpErrTCPConnection = 1 - -smpErrCorrelationId :: Natural -smpErrCorrelationId = 2 - -smpUnexpectedResponse :: Natural -smpUnexpectedResponse = 3 +instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU commandP :: Parser ACmd commandP = @@ -309,28 +315,30 @@ commandP = sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal message = do - m_status <- status <* A.space - m_recipient <- "R=" *> partyMeta A.decimal - m_broker <- "B=" *> partyMeta base64P - m_sender <- "S=" *> partyMeta A.decimal - m_body <- A.takeByteString - return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body} - -- TODO other error types - agentError = ACmd SAgent . ERR <$> ("SMP " *> smpErrorType) - smpErrorType = "AUTH" $> SMP SMP.AUTH + msgIntegrity <- msgIntegrityP <* A.space + recipientMeta <- "R=" *> partyMeta A.decimal + brokerMeta <- "B=" *> partyMeta base64P + senderMeta <- "S=" *> partyMeta A.decimal + msgBody <- A.takeByteString + return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody} replyMode = " NO_REPLY" $> ReplyOff <|> A.space *> (ReplyVia <$> smpServerP) <|> pure ReplyOn partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space - status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + agentError = ACmd SAgent . ERR <$> agentErrorTypeP + +msgIntegrityP :: Parser MsgIntegrity +msgIntegrityP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + where msgErrorType = "ID " *> (MsgBadId <$> A.decimal) - <|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) + <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash + <|> "DUPLICATE" $> MsgDuplicate parseCommand :: ByteString -> Either AgentErrorType ACmd -parseCommand = parse commandP $ SYNTAX errBadCommand +parseCommand = parse commandP $ CMD SYNTAX serializeCommand :: ACommand p -> ByteString serializeCommand = \case @@ -342,19 +350,19 @@ serializeCommand = \case END -> "END" SEND msgBody -> "SEND " <> serializeMsg msgBody SENT mId -> "SENT " <> bshow mId - MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} -> + MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} -> B.unwords [ "MSG", - msgStatus m_status, + serializeMsgIntegrity msgIntegrity, "R=" <> bshow rmId <> "," <> showTs rTs, "B=" <> encode bmId <> "," <> showTs bTs, "S=" <> bshow smId <> "," <> showTs sTs, - serializeMsg m_body + serializeMsg msgBody ] OFF -> "OFF" DEL -> "DEL" CON -> "CON" - ERR e -> "ERR " <> B.pack (show e) + ERR e -> "ERR " <> serializeAgentError e OK -> "OK" where replyMode :: ReplyMode -> ByteString @@ -364,19 +372,35 @@ serializeCommand = \case ReplyOn -> "" showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis - msgStatus :: MsgStatus -> ByteString - msgStatus = \case - MsgOk -> "OK" - MsgError e -> - "ERR" <> case e of - MsgSkipped fromMsgId toMsgId -> - B.unwords ["NO_ID", B.pack $ show fromMsgId, B.pack $ show toMsgId] - MsgBadId aMsgId -> "ID " <> B.pack (show aMsgId) - MsgBadHash -> "HASH" --- TODO - save function as in the server Transmission - re-use? +serializeMsgIntegrity :: MsgIntegrity -> ByteString +serializeMsgIntegrity = \case + MsgOk -> "OK" + MsgError e -> + "ERR " <> case e of + MsgSkipped fromMsgId toMsgId -> + B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] + MsgBadId aMsgId -> "ID " <> bshow aMsgId + MsgBadHash -> "HASH" + MsgDuplicate -> "DUPLICATE" + +agentErrorTypeP :: Parser AgentErrorType +agentErrorTypeP = + "SMP " *> (SMP <$> SMP.errorTypeP) + <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP) + <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) + <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) + <|> parseRead2 + +serializeAgentError :: AgentErrorType -> ByteString +serializeAgentError = \case + SMP e -> "SMP " <> SMP.serializeErrorType e + BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e + BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e + e -> bshow e + serializeMsg :: ByteString -> ByteString -serializeMsg body = B.pack (show $ B.length body) <> "\n" <> body +serializeMsg body = bshow (B.length body) <> "\n" <> body tPutRaw :: Handle -> ARawTransmission -> IO () tPutRaw h (corrId, connAlias, command) = do @@ -404,7 +428,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody fromParty :: ACmd -> Either AgentErrorType (ACommand p) fromParty (ACmd (p :: p1) cmd) = case testEquality party p of Just Refl -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) tConnAlias (_, connAlias, _) cmd = case cmd of @@ -415,13 +439,13 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody ERR _ -> Right cmd -- other responses must have connAlias _ - | B.null connAlias -> Left $ SYNTAX errNoConnAlias + | B.null connAlias -> Left $ CMD NO_CONN | otherwise -> Right cmd cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) cmdWithMsgBody = \case SEND body -> SEND <$$> getMsgBody body - MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body + MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body cmd -> return $ Right cmd -- TODO refactor with server @@ -433,5 +457,5 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody Just size -> liftIO $ do body <- B.hGet h size s <- getLn h - return $ if B.null s then Right body else Left SIZE - Nothing -> return . Left $ SYNTAX errMessageBody + return $ if B.null s then Right body else Left $ CMD SIZE + Nothing -> return . Left $ CMD SYNTAX diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index a4737e399..8def2ab8a 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -8,7 +9,7 @@ {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Client - ( SMPClient, + ( SMPClient (blockSize), getSMPClient, closeSMPClient, createSMPQueue, @@ -33,33 +34,31 @@ import Control.Exception import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Except -import qualified Crypto.PubKey.RSA.Types as RSA import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe -import GHC.IO.Exception (IOErrorType (..)) import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Transmission (SMPServer (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Util (liftEitherError, raceAny_) +import Simplex.Messaging.Util (bshow, liftError, raceAny_) import System.IO -import System.IO.Error import System.Timeout data SMPClient = SMPClient { action :: Async (), connected :: TVar Bool, smpServer :: SMPServer, + tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TVar (Map CorrId Request), sndQ :: TBQueue SignedRawTransmission, rcvQ :: TBQueue SignedTransmissionOrError, - msgQ :: TBQueue SMPServerTransmission + msgQ :: TBQueue SMPServerTransmission, + blockSize :: Int } type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker) @@ -69,7 +68,6 @@ data SMPClientConfig = SMPClientConfig defaultPort :: ServiceName, tcpTimeout :: Int, smpPing :: Int, - blockSize :: Int, smpCommandSize :: Int } @@ -78,36 +76,36 @@ smpDefaultConfig = SMPClientConfig { qSize = 16, defaultPort = "5223", - tcpTimeout = 2_000_000, + tcpTimeout = 4_000_000, smpPing = 30_000_000, - blockSize = 8_192, -- 16_384, smpCommandSize = 256 } data Request = Request { queueId :: QueueId, - responseVar :: TMVar (Either SMPClientError Cmd) + responseVar :: TMVar Response } -getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient +type Response = Either SMPClientError Cmd + +getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient) getSMPClient - smpServer@SMPServer {host, port} + smpServer@SMPServer {host, port, keyHash} SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing} msgQ disconnected = do c <- atomically mkSMPClient - started <- newEmptyTMVarIO + thVar <- newEmptyTMVarIO action <- async $ - runTCPClient host (fromMaybe defaultPort port) (client c started) - `finally` atomically (putTMVar started False) - tcpTimeout `timeout` atomically (takeTMVar started) >>= \case - Just True -> return c {action} - _ -> throwIO err + runTCPClient host (fromMaybe defaultPort port) (client c thVar) + `finally` atomically (putTMVar thVar $ Left SMPNetworkError) + tHandle <- tcpTimeout `timeout` atomically (takeTMVar thVar) + pure $ case tHandle of + Just (Right THandle {blockSize}) -> Right c {action, blockSize} + Just (Left e) -> Left e + Nothing -> Left SMPNetworkError where - err :: IOException - err = mkIOError TimeExpired "connection timeout" Nothing Nothing - mkSMPClient :: STM SMPClient mkSMPClient = do connected <- newTVar False @@ -118,8 +116,10 @@ getSMPClient return SMPClient { action = undefined, + blockSize = undefined, connected, smpServer, + tcpTimeout, clientCorrId, sentCommands, sndQ, @@ -127,19 +127,24 @@ getSMPClient msgQ } - client :: SMPClient -> TMVar Bool -> Handle -> IO () - client c started h = do - _ <- getLn h -- "Welcome to SMP" + client :: SMPClient -> TMVar (Either SMPClientError THandle) -> Handle -> IO () + client c thVar h = + runExceptT (clientHandshake h keyHash) >>= \case + Right th -> clientTransport c thVar th + Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e + + clientTransport :: SMPClient -> TMVar (Either SMPClientError THandle) -> THandle -> IO () + clientTransport c thVar th = do atomically $ do - modifyTVar (connected c) (const True) - putTMVar started True - raceAny_ [send c h, process c, receive c h, ping c] + writeTVar (connected c) True + putTMVar thVar $ Right th + raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected - send :: SMPClient -> Handle -> IO () + send :: SMPClient -> THandle -> IO () send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h - receive :: SMPClient -> Handle -> IO () + receive :: SMPClient -> THandle -> IO () receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ ping :: SMPClient -> IO () @@ -165,7 +170,7 @@ getSMPClient Left e -> Left $ SMPResponseError e Right (Cmd _ (ERR e)) -> Left $ SMPServerError e Right r -> Right r - else Left SMPQueueIdError + else Left SMPUnexpectedResponse closeSMPClient :: SMPClient -> IO () closeSMPClient = uninterruptibleCancel . action @@ -173,11 +178,11 @@ closeSMPClient = uninterruptibleCancel . action data SMPClientError = SMPServerError ErrorType | SMPResponseError ErrorType - | SMPQueueIdError | SMPUnexpectedResponse | SMPResponseTimeout - | SMPCryptoError RSA.Error - | SMPClientError + | SMPNetworkError + | SMPTransportError TransportError + | SMPSignatureError C.CryptoError deriving (Eq, Show, Exception) createSMPQueue :: @@ -222,14 +227,14 @@ suspendSMPQueue = okSMPCommand $ Cmd SRecipient OFF deleteSMPQueue :: SMPClient -> RecipientPrivateKey -> QueueId -> ExceptT SMPClientError IO () deleteSMPQueue = okSMPCommand $ Cmd SRecipient DEL -okSMPCommand :: Cmd -> SMPClient -> C.PrivateKey -> QueueId -> ExceptT SMPClientError IO () +okSMPCommand :: Cmd -> SMPClient -> C.SafePrivateKey -> QueueId -> ExceptT SMPClientError IO () okSMPCommand cmd c pKey qId = sendSMPCommand c (Just pKey) qId cmd >>= \case Cmd _ OK -> return () _ -> throwE SMPUnexpectedResponse -sendSMPCommand :: SMPClient -> Maybe C.PrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd -sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do +sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd +sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ serializeTransmission (corrId, qId, cmd) ExceptT $ sendRecv corrId t @@ -241,20 +246,22 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do getNextCorrId = do i <- (+ 1) <$> readTVar clientCorrId writeTVar clientCorrId i - return . CorrId . B.pack $ show i + return . CorrId $ bshow i signTransmission :: ByteString -> ExceptT SMPClientError IO SignedRawTransmission signTransmission t = case pKey of Nothing -> return ("", t) Just pk -> do - sig <- liftEitherError SMPCryptoError $ C.sign pk t + sig <- liftError SMPSignatureError $ C.sign pk t return (sig, t) -- two separate "atomically" needed to avoid blocking - sendRecv :: CorrId -> SignedRawTransmission -> IO (Either SMPClientError Cmd) - sendRecv corrId t = atomically (send corrId t) >>= atomically . takeTMVar + sendRecv :: CorrId -> SignedRawTransmission -> IO Response + sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar + where + withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a - send :: CorrId -> SignedRawTransmission -> STM (TMVar (Either SMPClientError Cmd)) + send :: CorrId -> SignedRawTransmission -> STM (TMVar Response) send corrId t = do r <- newEmptyTMVar modifyTVar sentCommands . M.insert corrId $ Request qId r diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 719056eae..808699fe7 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -1,26 +1,53 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Crypto - ( PrivateKey (..), + ( PrivateKey (rsaPrivateKey), + SafePrivateKey, -- constructor is not exported + FullPrivateKey (..), PublicKey (..), Signature (..), CryptoError (..), + SafeKeyPair, + FullKeyPair, + Key (..), + IV (..), + KeyHash (..), generateKeyPair, + publicKey, + publicKeySize, + safePrivateKey, sign, verify, encrypt, decrypt, + encryptOAEP, + decryptOAEP, + encryptAES, + decryptAES, serializePrivKey, serializePubKey, - parsePrivKey, - parsePubKey, + encodePubKey, + serializeKeyHash, + getKeyHash, + sha256Hash, privKeyP, pubKeyP, + binaryPubKeyP, + keyHashP, + authTagSize, + authTagToBS, + bsToAuthTag, + randomAesKey, + randomIV, + aesKeyP, + ivP, ) where @@ -30,60 +57,89 @@ import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import qualified Crypto.Cipher.Types as AES import qualified Crypto.Error as CE -import Crypto.Hash.Algorithms (SHA256 (..)) +import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash) import Crypto.Number.Generate (generateMax) import Crypto.Number.Prime (findPrimeFrom) -import Crypto.Number.Serialize (i2osp, os2ip) import qualified Crypto.PubKey.RSA as R import qualified Crypto.PubKey.RSA.OAEP as OAEP import qualified Crypto.PubKey.RSA.PSS as PSS import Crypto.Random (getRandomBytes) +import Data.ASN1.BinaryEncoding +import Data.ASN1.Encoding +import Data.ASN1.Types import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Bifunctor (first) +import Data.Bifunctor (bimap, first) import qualified Data.ByteArray as BA -import Data.ByteString.Base64 +import Data.ByteString.Base64 (decode, encode) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Internal (c2w, w2c) +import Data.ByteString.Lazy (fromStrict, toStrict) import Data.String -import Database.SQLite.Simple as DB -import Database.SQLite.Simple.FromField +import Data.Typeable (Typeable) +import Data.X509 +import Database.SQLite.Simple (ResultError (..), SQLData (..)) +import Database.SQLite.Simple.FromField (FieldParser, FromField (..), returnError) import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) -import Simplex.Messaging.Parsers (base64P) -import Simplex.Messaging.Util (bshow, liftEitherError, (<$$>)) +import Simplex.Messaging.Parsers (base64P, parseAll) +import Simplex.Messaging.Util (liftEitherError, (<$?>)) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) -data PrivateKey = PrivateKey - { private_size :: Int, - private_n :: Integer, - private_d :: Integer - } - deriving (Eq, Show) +newtype SafePrivateKey = SafePrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) -instance ToField PrivateKey where toField = toField . serializePrivKey +newtype FullPrivateKey = FullPrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) -instance ToField PublicKey where toField = toField . serializePubKey +class PrivateKey k where + rsaPrivateKey :: k -> R.PrivateKey + _privateKey :: R.PrivateKey -> k + mkPrivateKey :: R.PrivateKey -> k -instance FromField PrivateKey where - fromField f@(Field (SQLBlob b) _) = - case parsePrivKey b of +instance PrivateKey SafePrivateKey where + rsaPrivateKey = unPrivateKey + _privateKey = SafePrivateKey + mkPrivateKey R.PrivateKey {private_pub = k, private_d} = + safePrivateKey (R.public_size k, R.public_n k, private_d) + +instance PrivateKey FullPrivateKey where + rsaPrivateKey = unPrivateKey + _privateKey = FullPrivateKey + mkPrivateKey = FullPrivateKey + +instance IsString FullPrivateKey where + fromString = parseString (decode >=> decodePrivKey) + +instance IsString PublicKey where + fromString = parseString (decode >=> decodePubKey) + +parseString :: (ByteString -> Either String a) -> (String -> a) +parseString parse = either error id . parse . B.pack + +instance ToField SafePrivateKey where toField = toField . encodePrivKey + +instance ToField PublicKey where toField = toField . encodePubKey + +instance FromField SafePrivateKey where fromField = keyFromField binaryPrivKeyP + +instance FromField PublicKey where fromField = keyFromField binaryPubKeyP + +keyFromField :: Typeable k => Parser k -> FieldParser k +keyFromField p = \case + f@(Field (SQLBlob b) _) -> + case parseAll p b of Right k -> Ok k - Left e -> returnError ConversionFailed f ("couldn't parse PrivateKey field: " ++ e) - fromField f = returnError ConversionFailed f "expecting SQLBlob column type" + Left e -> returnError ConversionFailed f ("couldn't parse key field: " ++ e) + f -> returnError ConversionFailed f "expecting SQLBlob column type" -instance FromField PublicKey where - fromField f@(Field (SQLBlob b) _) = - case parsePubKey b of - Right k -> Ok k - Left e -> returnError ConversionFailed f ("couldn't parse PublicKey field: " ++ e) - fromField f = returnError ConversionFailed f "expecting SQLBlob column type" +type KeyPair k = (PublicKey, k) -type KeyPair = (PublicKey, PrivateKey) +type SafeKeyPair = (PublicKey, SafePrivateKey) + +type FullKeyPair = (PublicKey, FullPrivateKey) newtype Signature = Signature {unSignature :: ByteString} deriving (Eq, Show) @@ -93,10 +149,12 @@ instance IsString Signature where newtype Verified = Verified ByteString deriving (Show) data CryptoError - = CryptoRSAError R.Error - | CryptoCipherError CE.CryptoError + = RSAEncryptError R.Error + | RSADecryptError R.Error + | RSASignError R.Error + | AESCipherError CE.CryptoError | CryptoIVError - | CryptoDecryptError + | AESDecryptError | CryptoLargeMsgError | CryptoHeaderError String deriving (Eq, Show, Exception) @@ -110,19 +168,26 @@ aesKeySize = 256 `div` 8 authTagSize :: Int authTagSize = 128 `div` 8 -generateKeyPair :: Int -> IO KeyPair +generateKeyPair :: PrivateKey k => Int -> IO (KeyPair k) generateKeyPair size = loop where publicExponent = findPrimeFrom . (+ 3) <$> generateMax pubExpRange - privateKey s n d = PrivateKey {private_size = s, private_n = n, private_d = d} loop = do - (pub, priv) <- R.generate size =<< publicExponent - let s = R.public_size pub - n = R.public_n pub - d = R.private_d priv - in if d * d < n - then loop - else return (PublicKey pub, privateKey s n d) + (k, pk) <- R.generate size =<< publicExponent + let n = R.public_n k + d = R.private_d pk + if d * d < n + then loop + else pure (PublicKey k, mkPrivateKey pk) + +privateKeySize :: PrivateKey k => k -> Int +privateKeySize = R.public_size . R.private_pub . rsaPrivateKey + +publicKey :: FullPrivateKey -> PublicKey +publicKey = PublicKey . R.private_pub . rsaPrivateKey + +publicKeySize :: PublicKey -> Int +publicKeySize = R.public_size . rsaPublicKey data Header = Header { aesKey :: Key, @@ -135,52 +200,89 @@ newtype Key = Key {unKey :: ByteString} newtype IV = IV {unIV :: ByteString} +newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show) + +instance IsString KeyHash where + fromString = parseString $ parseAll keyHashP + +instance ToField KeyHash where toField = toField . serializeKeyHash + +instance FromField KeyHash where + fromField f@(Field (SQLBlob b) _) = + case parseAll keyHashP b of + Right k -> Ok k + Left e -> returnError ConversionFailed f ("couldn't parse KeyHash field: " ++ e) + fromField f = returnError ConversionFailed f "expecting SQLBlob column type" + +serializeKeyHash :: KeyHash -> ByteString +serializeKeyHash = encode . BA.convert . unKeyHash + +keyHashP :: Parser KeyHash +keyHashP = do + bs <- base64P + case digestFromByteString bs of + Just d -> pure $ KeyHash d + _ -> fail "invalid digest" + +getKeyHash :: ByteString -> KeyHash +getKeyHash = KeyHash . hash + +sha256Hash :: ByteString -> ByteString +sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256) + serializeHeader :: Header -> ByteString serializeHeader Header {aesKey, ivBytes, authTag, msgSize} = unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize headerP :: Parser Header headerP = do - aesKey <- Key <$> A.take aesKeySize - ivBytes <- IV <$> A.take (ivSize @AES256) + aesKey <- aesKeyP + ivBytes <- ivP authTag <- bsToAuthTag <$> A.take authTagSize msgSize <- fromIntegral . decodeWord32 <$> A.take 4 return Header {aesKey, ivBytes, authTag, msgSize} +aesKeyP :: Parser Key +aesKeyP = Key <$> A.take aesKeySize + +ivP :: Parser IV +ivP = IV <$> A.take (ivSize @AES256) + parseHeader :: ByteString -> Either CryptoError Header -parseHeader = first CryptoHeaderError . A.parseOnly (headerP <* A.endOfInput) +parseHeader = first CryptoHeaderError . parseAll headerP encrypt :: PublicKey -> Int -> ByteString -> ExceptT CryptoError IO ByteString encrypt k paddedSize msg = do - aesKey <- Key <$> randomBytes aesKeySize - ivBytes <- IV <$> randomBytes (ivSize @AES256) + aesKey <- liftIO randomAesKey + ivBytes <- liftIO randomIV + (authTag, msg') <- encryptAES aesKey ivBytes paddedSize msg + let header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg} + encHeader <- encryptOAEP k $ serializeHeader header + return $ encHeader <> msg' + +decrypt :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString +decrypt pk msg'' = do + let (encHeader, msg') = B.splitAt (privateKeySize pk) msg'' + header <- decryptOAEP pk encHeader + Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header + msg <- decryptAES aesKey ivBytes msg' authTag + return $ B.take msgSize msg + +encryptAES :: Key -> IV -> Int -> ByteString -> ExceptT CryptoError IO (AES.AuthTag, ByteString) +encryptAES aesKey ivBytes paddedSize msg = do aead <- initAEAD @AES256 aesKey ivBytes msg' <- paddedMsg - let (authTag, msg'') = encryptAES aead msg' - header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg} - encHeader <- encryptOAEP k $ serializeHeader header - return $ encHeader <> msg'' + return $ AES.aeadSimpleEncrypt aead B.empty msg' authTagSize where len = B.length msg paddedMsg | len >= paddedSize = throwE CryptoLargeMsgError | otherwise = return (msg <> B.replicate (paddedSize - len) '#') -decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString -decrypt pk msg'' = do - let (encHeader, msg') = B.splitAt (private_size pk) msg'' - header <- decryptOAEP pk encHeader - Header {aesKey, ivBytes, authTag, msgSize} <- ExceptT . return $ parseHeader header +decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString +decryptAES aesKey ivBytes msg authTag = do aead <- initAEAD @AES256 aesKey ivBytes - msg <- decryptAES aead msg' authTag - return $ B.take msgSize msg - -encryptAES :: AES.AEAD AES256 -> ByteString -> (AES.AuthTag, ByteString) -encryptAES aead plaintext = AES.aeadSimpleEncrypt aead B.empty plaintext authTagSize - -decryptAES :: AES.AEAD AES256 -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString -decryptAES aead ciphertext authTag = - maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty ciphertext authTag + maybeError AESDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c) initAEAD (Key aesKey) (IV ivBytes) = do @@ -189,15 +291,18 @@ initAEAD (Key aesKey) (IV ivBytes) = do cipher <- AES.cipherInit aesKey AES.aeadInit AES.AEAD_GCM cipher iv +randomAesKey :: IO Key +randomAesKey = Key <$> getRandomBytes aesKeySize + +randomIV :: IO IV +randomIV = IV <$> getRandomBytes (ivSize @AES256) + ivSize :: forall c. AES.BlockCipher c => Int ivSize = AES.blockSize (undefined :: c) makeIV :: AES.BlockCipher c => ByteString -> ExceptT CryptoError IO (AES.IV c) makeIV bs = maybeError CryptoIVError $ AES.makeIV bs -randomBytes :: Int -> ExceptT CryptoError IO ByteString -randomBytes n = ExceptT $ Right <$> getRandomBytes n - maybeError :: CryptoError -> Maybe a -> ExceptT CryptoError IO a maybeError e = maybe (throwE e) return @@ -208,75 +313,93 @@ bsToAuthTag :: ByteString -> AES.AuthTag bsToAuthTag = AES.AuthTag . BA.pack . map c2w . B.unpack cryptoFailable :: CE.CryptoFailable a -> ExceptT CryptoError IO a -cryptoFailable = liftEither . first CryptoCipherError . CE.eitherCryptoError +cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError oaepParams :: OAEP.OAEPParams SHA256 ByteString ByteString oaepParams = OAEP.defaultOAEPParams SHA256 encryptOAEP :: PublicKey -> ByteString -> ExceptT CryptoError IO ByteString encryptOAEP (PublicKey k) aesKey = - liftEitherError CryptoRSAError $ + liftEitherError RSAEncryptError $ OAEP.encrypt oaepParams k aesKey -decryptOAEP :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString +decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decryptOAEP pk encKey = - liftEitherError CryptoRSAError $ + liftEitherError RSADecryptError $ OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey pssParams :: PSS.PSSParams SHA256 ByteString ByteString pssParams = PSS.defaultPSSParams SHA256 -sign :: PrivateKey -> ByteString -> IO (Either R.Error Signature) -sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg +sign :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO Signature +sign pk msg = ExceptT $ bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig serializePubKey :: PublicKey -> ByteString -serializePubKey (PublicKey k) = serializeKey_ (R.public_size k, R.public_n k, R.public_e k) +serializePubKey = ("rsa:" <>) . encode . encodePubKey -serializePrivKey :: PrivateKey -> ByteString -serializePrivKey pk = serializeKey_ (private_size pk, private_n pk, private_d pk) - -serializeKey_ :: (Int, Integer, Integer) -> ByteString -serializeKey_ (size, n, ex) = bshow size <> "," <> encInt n <> "," <> encInt ex - where - encInt = encode . i2osp +serializePrivKey :: PrivateKey k => k -> ByteString +serializePrivKey = ("rsa:" <>) . encode . encodePrivKey pubKeyP :: Parser PublicKey -pubKeyP = do - (public_size, public_n, public_e) <- keyParser_ - return . PublicKey $ R.PublicKey {R.public_size, R.public_n, R.public_e} +pubKeyP = decodePubKey <$?> ("rsa:" *> base64P) -privKeyP :: Parser PrivateKey -privKeyP = do - (private_size, private_n, private_d) <- keyParser_ - return PrivateKey {private_size, private_n, private_d} +binaryPubKeyP :: Parser PublicKey +binaryPubKeyP = decodePubKey <$?> A.takeByteString -parsePubKey :: ByteString -> Either String PublicKey -parsePubKey = A.parseOnly (pubKeyP <* A.endOfInput) +privKeyP :: PrivateKey k => Parser k +privKeyP = decodePrivKey <$?> ("rsa:" *> base64P) -parsePrivKey :: ByteString -> Either String PrivateKey -parsePrivKey = A.parseOnly (privKeyP <* A.endOfInput) +binaryPrivKeyP :: PrivateKey k => Parser k +binaryPrivKeyP = decodePrivKey <$?> A.takeByteString -keyParser_ :: Parser (Int, Integer, Integer) -keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP - where - intP = os2ip <$> base64P +safePrivateKey :: (Int, Integer, Integer) -> SafePrivateKey +safePrivateKey = SafePrivateKey . safeRsaPrivateKey -rsaPrivateKey :: PrivateKey -> R.PrivateKey -rsaPrivateKey pk = +safeRsaPrivateKey :: (Int, Integer, Integer) -> R.PrivateKey +safeRsaPrivateKey (size, n, d) = R.PrivateKey - { R.private_pub = + { private_pub = R.PublicKey - { R.public_size = private_size pk, - R.public_n = private_n pk, - R.public_e = undefined + { public_size = size, + public_n = n, + public_e = 0 }, - R.private_d = private_d pk, - R.private_p = 0, - R.private_q = 0, - R.private_dP = undefined, - R.private_dQ = undefined, - R.private_qinv = undefined + private_d = d, + private_p = 0, + private_q = 0, + private_dP = 0, + private_dQ = 0, + private_qinv = 0 } + +encodePubKey :: PublicKey -> ByteString +encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey + +encodePrivKey :: PrivateKey k => k -> ByteString +encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey + +encodeKey :: ASN1Object a => a -> ByteString +encodeKey k = toStrict . encodeASN1 DER $ toASN1 k [] + +decodePubKey :: ByteString -> Either String PublicKey +decodePubKey = + decodeKey >=> \case + (PubKeyRSA k, []) -> Right $ PublicKey k + r -> keyError r + +decodePrivKey :: PrivateKey k => ByteString -> Either String k +decodePrivKey = + decodeKey >=> \case + (PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk + r -> keyError r + +decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) +decodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict + +keyError :: (a, [ASN1]) -> Either String b +keyError = \case + (_, []) -> Left "not RSA key" + _ -> Left "more than one key" diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 461c38c5e..25e2f32bb 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE OverloadedStrings #-} + module Simplex.Messaging.Parsers where import Data.Attoparsec.ByteString.Char8 (Parser) @@ -9,15 +11,35 @@ import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) +import Simplex.Messaging.Util ((<$?>)) +import Text.Read (readMaybe) base64P :: Parser ByteString -base64P = do +base64P = decode <$?> base64StringP + +base64StringP :: Parser ByteString +base64StringP = do str <- A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/') pad <- A.takeWhile (== '=') - either fail pure $ decode (str <> pad) + pure $ str <> pad tsISO8601P :: Parser UTCTime tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ') parse :: Parser a -> e -> (ByteString -> Either e a) -parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput) +parse parser err = first (const err) . parseAll parser + +parseAll :: Parser a -> (ByteString -> Either String a) +parseAll parser = A.parseOnly (parser <* A.endOfInput) + +parseRead :: Read a => Parser ByteString -> Parser a +parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack) + +parseRead1 :: Read a => Parser a +parseRead1 = parseRead $ A.takeTill (== ' ') + +parseRead2 :: Read a => Parser a +parseRead2 = parseRead $ do + w1 <- A.takeTill (== ' ') <* A.char ' ' + w2 <- A.takeTill (== ' ') + pure $ w1 <> " " <> w2 diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 94f49e24e..30d20ec5a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -13,7 +14,7 @@ module Simplex.Messaging.Protocol where import Control.Applicative ((<|>)) import Control.Monad -import Control.Monad.IO.Class +import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Base64 @@ -24,12 +25,13 @@ import Data.Kind import Data.String import Data.Time.Clock import Data.Time.ISO8601 +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Transport import Simplex.Messaging.Util -import System.IO -import Text.Read +import Test.QuickCheck (Arbitrary (..)) data Party = Broker | Recipient | Sender deriving (Show) @@ -95,12 +97,12 @@ instance IsString CorrId where fromString = CorrId . fromString -- only used by Agent, kept here so its definition is close to respective public key -type RecipientPrivateKey = C.PrivateKey +type RecipientPrivateKey = C.SafePrivateKey type RecipientPublicKey = C.PublicKey -- only used by Agent, kept here so its definition is close to respective public key -type SenderPrivateKey = C.PrivateKey +type SenderPrivateKey = C.SafePrivateKey type SenderPublicKey = C.PublicKey @@ -108,25 +110,36 @@ type MsgId = Encoded type MsgBody = ByteString -data ErrorType = PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) +data ErrorType + = BLOCK + | CMD CommandError + | AUTH + | NO_MSG + | INTERNAL + | DUPLICATE_ -- TODO remove, not part of SMP protocol + deriving (Eq, Generic, Read, Show) -errBadTransmission :: Int -errBadTransmission = 1 +data CommandError + = PROHIBITED + | SYNTAX + | NO_AUTH + | HAS_AUTH + | NO_QUEUE + deriving (Eq, Generic, Read, Show) -errBadSMPCommand :: Int -errBadSMPCommand = 2 +instance Arbitrary ErrorType where arbitrary = genericArbitraryU -errNoCredentials :: Int -errNoCredentials = 3 +instance Arbitrary CommandError where arbitrary = genericArbitraryU -errHasCredentials :: Int -errHasCredentials = 4 - -errNoQueueId :: Int -errNoQueueId = 5 - -errMessageBody :: Int -errMessageBody = 6 +transmissionP :: Parser RawTransmission +transmissionP = do + signature <- segment + corrId <- segment + queueId <- segment + command <- A.takeByteString + return (signature, corrId, queueId, command) + where + segment = A.takeTill (== ' ') <* " " commandP :: Parser Cmd commandP = @@ -148,132 +161,110 @@ commandP = newCmd = Cmd SRecipient . NEW <$> C.pubKeyP idsResp = Cmd SBroker <$> (IDS <$> (base64P <* A.space) <*> base64P) keyCmd = Cmd SRecipient . KEY <$> C.pubKeyP - sendCmd = Cmd SSender . SEND <$> A.takeWhile A.isDigit + sendCmd = do + size <- A.decimal <* A.space + Cmd SSender . SEND <$> A.take size <* A.space message = do msgId <- base64P <* A.space ts <- tsISO8601P <* A.space - Cmd SBroker . MSG msgId ts <$> A.takeWhile A.isDigit - serverError = Cmd SBroker . ERR <$> errorType - errorType = - "PROHIBITED" $> PROHIBITED - <|> "SYNTAX " *> (SYNTAX <$> A.decimal) - <|> "SIZE" $> SIZE - <|> "AUTH" $> AUTH - <|> "INTERNAL" $> INTERNAL + size <- A.decimal <* A.space + Cmd SBroker . MSG msgId ts <$> A.take size <* A.space + serverError = Cmd SBroker . ERR <$> errorTypeP +-- TODO ignore the end of block, no need to parse it parseCommand :: ByteString -> Either ErrorType Cmd -parseCommand = parse commandP $ SYNTAX errBadSMPCommand +parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX serializeCommand :: Cmd -> ByteString serializeCommand = \case Cmd SRecipient (NEW rKey) -> "NEW " <> C.serializePubKey rKey Cmd SRecipient (KEY sKey) -> "KEY " <> C.serializePubKey sKey - Cmd SRecipient cmd -> B.pack $ show cmd - Cmd SSender (SEND msgBody) -> "SEND" <> serializeMsg msgBody + Cmd SRecipient cmd -> bshow cmd + Cmd SSender (SEND msgBody) -> "SEND " <> serializeMsg msgBody Cmd SSender PING -> "PING" Cmd SBroker (MSG msgId ts msgBody) -> - B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts] <> serializeMsg msgBody + B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody] Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId] - Cmd SBroker (ERR err) -> "ERR " <> B.pack (show err) - Cmd SBroker resp -> B.pack $ show resp + Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err + Cmd SBroker resp -> bshow resp where - serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\r\n" <> msgBody + serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " " -tPutRaw :: Handle -> RawTransmission -> IO () -tPutRaw h (signature, corrId, queueId, command) = do - putLn h signature - putLn h corrId - putLn h queueId - putLn h command +errorTypeP :: Parser ErrorType +errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1 -tGetRaw :: Handle -> IO RawTransmission -tGetRaw h = do - signature <- getLn h - corrId <- getLn h - queueId <- getLn h - command <- getLn h - return (signature, corrId, queueId, command) +serializeErrorType :: ErrorType -> ByteString +serializeErrorType = bshow -tPut :: Handle -> SignedRawTransmission -> IO () -tPut h (C.Signature sig, t) = do - putLn h $ encode sig - putLn h t +tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ()) +tPut th (C.Signature sig, t) = + tPutEncrypted th $ encode sig <> " " <> t <> " " serializeTransmission :: Transmission -> ByteString serializeTransmission (CorrId corrId, queueId, command) = - B.intercalate "\r\n" [corrId, encode queueId, serializeCommand command] + B.intercalate " " [corrId, encode queueId, serializeCommand command] fromClient :: Cmd -> Either ErrorType Cmd fromClient = \case - Cmd SBroker _ -> Left PROHIBITED + Cmd SBroker _ -> Left $ CMD PROHIBITED cmd -> Right cmd fromServer :: Cmd -> Either ErrorType Cmd fromServer = \case cmd@(Cmd SBroker _) -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED + +tGetParse :: THandle -> IO (Either TransportError RawTransmission) +tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th -- | get client and server transmissions -- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used -tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> Handle -> m SignedTransmissionOrError -tGet fromParty h = do - (signature, corrId, queueId, command) <- liftIO $ tGetRaw h - let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId) - either (const $ tError corrId) tParseLoadBody decodedTransmission +tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> THandle -> m SignedTransmissionOrError +tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate where - tError :: ByteString -> m SignedTransmissionOrError - tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission)) + decodeParseValidate :: Either TransportError RawTransmission -> m SignedTransmissionOrError + decodeParseValidate = \case + Right (signature, corrId, queueId, command) -> + let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId) + in either (const $ tError corrId) tParseValidate decodedTransmission + Left _ -> tError "" - tParseLoadBody :: RawTransmission -> m SignedTransmissionOrError - tParseLoadBody t@(sig, corrId, queueId, command) = do + tError :: ByteString -> m SignedTransmissionOrError + tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left BLOCK)) + + tParseValidate :: RawTransmission -> m SignedTransmissionOrError + tParseValidate t@(sig, corrId, queueId, command) = do let cmd = parseCommand command >>= fromParty >>= tCredentials t - fullCmd <- either (return . Left) cmdWithMsgBody cmd - return (C.Signature sig, (CorrId corrId, queueId, fullCmd)) + return (C.Signature sig, (CorrId corrId, queueId, cmd)) tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd tCredentials (signature, _, queueId, _) cmd = case cmd of - -- IDS response should not have queue ID + -- IDS response must not have queue ID Cmd SBroker (IDS _ _) -> Right cmd -- ERR response does not always have queue ID Cmd SBroker (ERR _) -> Right cmd - -- PONG response should not have queue ID + -- PONG response must not have queue ID Cmd SBroker PONG | B.null queueId -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other responses must have queue ID Cmd SBroker _ - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd - -- NEW must NOT have signature or queue ID + -- NEW must have signature but NOT queue ID Cmd SRecipient (NEW _) - | B.null signature -> Left $ SYNTAX errNoCredentials - | not (B.null queueId) -> Left $ SYNTAX errHasCredentials + | B.null signature -> Left $ CMD NO_AUTH + | not (B.null queueId) -> Left $ CMD HAS_AUTH | otherwise -> Right cmd -- SEND must have queue ID, signature is not always required Cmd SSender (SEND _) - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd -- PING must not have queue ID or signature Cmd SSender PING | B.null queueId && B.null signature -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other client commands must have both signature and queue ID Cmd SRecipient _ - | B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials + | B.null signature || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd - - cmdWithMsgBody :: Cmd -> m (Either ErrorType Cmd) - cmdWithMsgBody = \case - Cmd SSender (SEND sizeStr) -> - Cmd SSender . SEND <$$> getMsgBody sizeStr - Cmd SBroker (MSG msgId ts sizeStr) -> - Cmd SBroker . MSG msgId ts <$$> getMsgBody sizeStr - cmd -> return $ Right cmd - - getMsgBody :: MsgBody -> m (Either ErrorType MsgBody) - getMsgBody sizeStr = case readMaybe (B.unpack sizeStr) :: Maybe Int of - Just size -> liftIO $ do - body <- B.hGet h size - s <- getLn h - return $ if B.null s then Right body else Left SIZE - Nothing -> return $ Left INTERNAL diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 034ff8ff8..d0d736673 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -15,6 +15,7 @@ module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, randomBytes import Control.Concurrent.STM (stateTVar) import Control.Monad +import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random @@ -30,6 +31,7 @@ import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (QueueStore) +import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Transport import Simplex.Messaging.Util import UnliftIO.Async @@ -50,6 +52,7 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do smpServer = do s <- asks server race_ (runTCPServer started tcpPort runClient) (serverThread s) + `finally` withLog closeStoreLog serverThread :: MonadUnliftIO m => Server -> m () serverThread Server {subscribedQ, subscribers} = forever . atomically $ do @@ -62,11 +65,17 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do - liftIO $ putLn h "Welcome to SMP v0.2.0" + keyPair <- asks serverKeyPair + liftIO (runExceptT $ serverHandshake h keyPair) >>= \case + Right th -> runClientTransport th + Left _ -> pure () + +runClientTransport :: (MonadUnliftIO m, MonadReader Env m) => THandle -> m () +runClientTransport th = do q <- asks $ tbqSize . config c <- atomically $ newClient q s <- asks server - raceAny_ [send h c, client c s, receive h c] + raceAny_ [send th c, client c s, receive th c] `finally` cancelSubscribers c cancelSubscribers :: MonadUnliftIO m => Client -> m () @@ -78,7 +87,7 @@ cancelSub = \case Sub {subThread = SubThread t} -> killThread t _ -> return () -receive :: (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m () +receive :: (MonadUnliftIO m, MonadReader Env m) => THandle -> Client -> m () receive h Client {rcvQ} = forever $ do (signature, (corrId, queueId, cmdOrError)) <- tGet fromClient h t <- case cmdOrError of @@ -86,7 +95,7 @@ receive h Client {rcvQ} = forever $ do Right cmd -> verifyTransmission (signature, (corrId, queueId, cmd)) atomically $ writeTBQueue rcvQ t -send :: MonadUnliftIO m => Handle -> Client -> m () +send :: MonadUnliftIO m => THandle -> Client -> m () send h Client {sndQ} = forever $ do t <- atomically $ readTBQueue sndQ liftIO $ tPut h ("", serializeTransmission t) @@ -99,25 +108,27 @@ verifyTransmission (sig, t@(corrId, queueId, cmd)) = do (corrId,queueId,) <$> case cmd of Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used Cmd SRecipient (NEW k) -> return $ verifySignature k - Cmd SRecipient _ -> withQueueRec SRecipient $ verifySignature . recipientKey - Cmd SSender (SEND _) -> withQueueRec SSender $ verifySend sig . senderKey + Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey + Cmd SSender (SEND _) -> verifyCmd SSender $ verifySend sig . senderKey Cmd SSender PING -> return cmd where - withQueueRec :: SParty (p :: Party) -> (QueueRec -> Cmd) -> m Cmd - withQueueRec party f = do + verifyCmd :: SParty p -> (QueueRec -> Cmd) -> m Cmd + verifyCmd party f = do + (aKey, _) <- asks serverKeyPair -- any public key can be used to mitigate timing attack st <- asks queueStore - qr <- atomically $ getQueue st party queueId - return $ either smpErr f qr + q <- atomically $ getQueue st party queueId + pure $ either (const $ fakeVerify aKey) f q + fakeVerify :: C.PublicKey -> Cmd + fakeVerify aKey = if verify aKey then authErr else authErr verifySend :: C.Signature -> Maybe SenderPublicKey -> Cmd verifySend "" = maybe cmd (const authErr) verifySend _ = maybe authErr verifySignature verifySignature :: C.PublicKey -> Cmd - verifySignature key = - if C.verify key sig (serializeTransmission t) - then cmd - else authErr - - smpErr e = Cmd SBroker $ ERR e + verifySignature key = if verify key then cmd else authErr + verify :: C.PublicKey -> Bool + verify key = C.verify key sig (serializeTransmission t) + smpErr :: ErrorType -> Cmd + smpErr = Cmd SBroker . ERR authErr = smpErr AUTH client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m () @@ -140,8 +151,8 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = NEW rKey -> createQueue st rKey SUB -> subscribeQueue queueId ACK -> acknowledgeMsg - KEY sKey -> okResp <$> atomically (secureQueue st queueId sKey) - OFF -> okResp <$> atomically (suspendQueue st queueId) + KEY sKey -> secureQueue_ st sKey + OFF -> suspendQueue_ st DEL -> delQueueAndMsgs st where createQueue :: QueueStore -> RecipientPublicKey -> m Transmission @@ -151,22 +162,40 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = addSubscribe = addQueueRetry 3 >>= \case Left e -> return $ ERR e - Right (rId, sId) -> subscribeQueue rId $> IDS rId sId + Right (rId, sId) -> do + withLog (`logCreateById` rId) + subscribeQueue rId $> IDS rId sId addQueueRetry :: Int -> m (Either ErrorType (RecipientId, SenderId)) addQueueRetry 0 = return $ Left INTERNAL addQueueRetry n = do ids <- getIds atomically (addQueue st rKey ids) >>= \case - Left DUPLICATE -> addQueueRetry $ n - 1 + Left DUPLICATE_ -> addQueueRetry $ n - 1 Left e -> return $ Left e Right _ -> return $ Right ids + logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO () + logCreateById s rId = + atomically (getQueue st SRecipient rId) >>= \case + Right q -> logCreateQueue s q + _ -> pure () + getIds :: m (RecipientId, SenderId) getIds = do n <- asks $ queueIdBytes . config liftM2 (,) (randomId n) (randomId n) + secureQueue_ :: QueueStore -> SenderPublicKey -> m Transmission + secureQueue_ st sKey = do + withLog $ \s -> logSecureQueue s queueId sKey + okResp <$> atomically (secureQueue st queueId sKey) + + suspendQueue_ :: QueueStore -> m Transmission + suspendQueue_ st = do + withLog (`logDeleteQueue` queueId) + okResp <$> atomically (suspendQueue st queueId) + subscribeQueue :: RecipientId -> m Transmission subscribeQueue rId = atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId @@ -193,7 +222,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s)) >>= \case Just (Just s) -> deliverMessage tryDelPeekMsg queueId s - _ -> return $ err PROHIBITED + _ -> return $ err NO_MSG withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a) withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId @@ -253,6 +282,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = delQueueAndMsgs :: QueueStore -> m Transmission delQueueAndMsgs st = do + withLog (`logDeleteQueue` queueId) ms <- asks msgStore atomically $ deleteQueue st queueId >>= \case @@ -271,6 +301,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = msgCmd :: Message -> Command 'Broker msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody +withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m () +withLog action = do + env <- ask + liftIO . mapM_ action $ storeLog (env :: Env) + randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded randomId n = do gVar <- asks idsDrg diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index c2d170e05..9a61e0243 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Server.Env.STM where @@ -10,16 +12,23 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Network.Socket (ServiceName) import Numeric.Natural +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.STM +import Simplex.Messaging.Server.QueueStore (QueueRec (..)) import Simplex.Messaging.Server.QueueStore.STM +import Simplex.Messaging.Server.StoreLog +import System.IO (IOMode (..)) import UnliftIO.STM data ServerConfig = ServerConfig { tcpPort :: ServiceName, tbqSize :: Natural, queueIdBytes :: Int, - msgIdBytes :: Int + msgIdBytes :: Int, + storeLog :: Maybe (StoreLog 'ReadMode), + serverPrivateKey :: C.FullPrivateKey + -- serverId :: ByteString } data Env = Env @@ -27,7 +36,9 @@ data Env = Env server :: Server, queueStore :: QueueStore, msgStore :: STMMsgStore, - idsDrg :: TVar ChaChaDRG + idsDrg :: TVar ChaChaDRG, + serverKeyPair :: C.FullKeyPair, + storeLog :: Maybe (StoreLog 'WriteMode) } data Server = Server @@ -66,10 +77,21 @@ newSubscription = do delivered <- newEmptyTMVar return Sub {subThread = NoSub, delivered} -newEnv :: (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env +newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env newEnv config = do server <- atomically $ newServer (tbqSize config) queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore idsDrg <- drgNew >>= newTVarIO - return Env {config, server, queueStore, msgStore, idsDrg} + s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig) + let pk = serverPrivateKey config + serverKeyPair = (C.publicKey pk, pk) + return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s'} + where + restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode) + restoreQueues queueStore s = do + (queues, s') <- liftIO $ readWriteStoreLog s + atomically $ modifyTVar queueStore $ \d -> d {queues, senders = M.foldr' addSender M.empty queues} + pure s' + addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId + addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index fd6783106..79eb2daee 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -15,7 +15,7 @@ data QueueRec = QueueRec status :: QueueStatus } -data QueueStatus = QueueActive | QueueOff +data QueueStatus = QueueActive | QueueOff deriving (Eq) class MonadQueueStore s m where addQueue :: s -> RecipientPublicKey -> (RecipientId, SenderId) -> m (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index dca596e82..86caff78f 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -32,7 +32,7 @@ instance MonadQueueStore QueueStore STM where addQueue store rKey ids@(rId, sId) = do cs@QueueStoreData {queues, senders} <- readTVar store if M.member rId queues || M.member sId senders - then return $ Left DUPLICATE + then return $ Left DUPLICATE_ else do writeTVar store $ cs diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs new file mode 100644 index 000000000..5841b23c5 --- /dev/null +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} + +module Simplex.Messaging.Server.StoreLog + ( StoreLog, -- constructors are not exported + openWriteStoreLog, + openReadStoreLog, + closeStoreLog, + logCreateQueue, + logSecureQueue, + logDeleteQueue, + readWriteStoreLog, + ) +where + +import Control.Applicative (optional, (<|>)) +import Control.Monad (unless) +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first, second) +import Data.ByteString.Base64 (encode) +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Either (partitionEithers) +import Data.Functor (($>)) +import Data.List (foldl') +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Parsers (base64P, parseAll) +import Simplex.Messaging.Protocol +import Simplex.Messaging.Server.QueueStore (QueueRec (..), QueueStatus (..)) +import Simplex.Messaging.Transport (trimCR) +import System.Directory (doesFileExist) +import System.IO + +-- | opaque container for file handle with a type-safe IOMode +-- constructors are not exported, openWriteStoreLog and openReadStoreLog should be used instead +data StoreLog (a :: IOMode) where + ReadStoreLog :: FilePath -> Handle -> StoreLog 'ReadMode + WriteStoreLog :: FilePath -> Handle -> StoreLog 'WriteMode + +data StoreLogRecord + = CreateQueue QueueRec + | SecureQueue QueueId SenderPublicKey + | DeleteQueue QueueId + +storeLogRecordP :: Parser StoreLogRecord +storeLogRecordP = + "CREATE " *> createQueueP + <|> "SECURE " *> secureQueueP + <|> "DELETE " *> (DeleteQueue <$> base64P) + where + createQueueP = CreateQueue <$> queueRecP + secureQueueP = SecureQueue <$> base64P <* A.space <*> C.pubKeyP + queueRecP = do + recipientId <- "rid=" *> base64P <* A.space + senderId <- "sid=" *> base64P <* A.space + recipientKey <- "rk=" *> C.pubKeyP <* A.space + senderKey <- "sk=" *> optional C.pubKeyP + pure QueueRec {recipientId, senderId, recipientKey, senderKey, status = QueueActive} + +serializeStoreLogRecord :: StoreLogRecord -> ByteString +serializeStoreLogRecord = \case + CreateQueue q -> "CREATE " <> serializeQueue q + SecureQueue rId sKey -> "SECURE " <> encode rId <> " " <> C.serializePubKey sKey + DeleteQueue rId -> "DELETE " <> encode rId + where + serializeQueue QueueRec {recipientId, senderId, recipientKey, senderKey} = + B.unwords + [ "rid=" <> encode recipientId, + "sid=" <> encode senderId, + "rk=" <> C.serializePubKey recipientKey, + "sk=" <> maybe "" C.serializePubKey senderKey + ] + +openWriteStoreLog :: FilePath -> IO (StoreLog 'WriteMode) +openWriteStoreLog f = WriteStoreLog f <$> openFile f WriteMode + +openReadStoreLog :: FilePath -> IO (StoreLog 'ReadMode) +openReadStoreLog f = do + doesFileExist f >>= (`unless` writeFile f "") + ReadStoreLog f <$> openFile f ReadMode + +closeStoreLog :: StoreLog a -> IO () +closeStoreLog = \case + WriteStoreLog _ h -> hClose h + ReadStoreLog _ h -> hClose h + +writeStoreLogRecord :: StoreLog 'WriteMode -> StoreLogRecord -> IO () +writeStoreLogRecord (WriteStoreLog _ h) r = do + B.hPutStrLn h $ serializeStoreLogRecord r + hFlush h + +logCreateQueue :: StoreLog 'WriteMode -> QueueRec -> IO () +logCreateQueue s = writeStoreLogRecord s . CreateQueue + +logSecureQueue :: StoreLog 'WriteMode -> QueueId -> SenderPublicKey -> IO () +logSecureQueue s qId sKey = writeStoreLogRecord s $ SecureQueue qId sKey + +logDeleteQueue :: StoreLog 'WriteMode -> QueueId -> IO () +logDeleteQueue s = writeStoreLogRecord s . DeleteQueue + +readWriteStoreLog :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode) +readWriteStoreLog s@(ReadStoreLog f _) = do + qs <- readQueues s + closeStoreLog s + s' <- openWriteStoreLog f + writeQueues s' qs + pure (qs, s') + +writeQueues :: StoreLog 'WriteMode -> Map RecipientId QueueRec -> IO () +writeQueues s = mapM_ (writeStoreLogRecord s . CreateQueue) . M.filter active + where + active QueueRec {status} = status == QueueActive + +type LogParsingError = (String, ByteString) + +readQueues :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec) +readQueues (ReadStoreLog _ h) = LB.hGetContents h >>= returnResult . procStoreLog + where + procStoreLog :: LB.ByteString -> ([LogParsingError], Map RecipientId QueueRec) + procStoreLog = second (foldl' procLogRecord M.empty) . partitionEithers . map parseLogRecord . LB.lines + returnResult :: ([LogParsingError], Map RecipientId QueueRec) -> IO (Map RecipientId QueueRec) + returnResult (errs, res) = mapM_ printError errs $> res + parseLogRecord :: LB.ByteString -> Either LogParsingError StoreLogRecord + parseLogRecord = (\s -> first (,s) $ parseAll storeLogRecordP s) . trimCR . LB.toStrict + procLogRecord :: Map RecipientId QueueRec -> StoreLogRecord -> Map RecipientId QueueRec + procLogRecord m = \case + CreateQueue q -> M.insert (recipientId q) q m + SecureQueue qId sKey -> M.adjust (\q -> q {senderKey = Just sKey}) qId m + DeleteQueue qId -> M.delete qId m + printError :: LogParsingError -> IO () + printError (e, s) = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 02346fa3d..731f3d4d2 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,29 +1,52 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Transport where -import Control.Monad.IO.Class +import Control.Applicative ((<|>)) +import Control.Monad.Except import Control.Monad.IO.Unlift -import Control.Monad.Reader +import Control.Monad.Trans.Except (throwE) +import Crypto.Cipher.Types (AuthTag) +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first) +import Data.ByteArray (xor) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Functor (($>)) import Data.Set (Set) import qualified Data.Set as S +import Data.Word (Word32) +import GHC.Generics (Generic) import GHC.IO.Exception (IOErrorType (..)) +import GHC.IO.Handle.Internals (ioe_EOF) +import Generic.Random (genericArbitraryU) import Network.Socket +import Network.Transport.Internal (decodeNum16, decodeNum32, encodeEnum16, encodeEnum32, encodeWord32) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Parsers (parse, parseAll, parseRead1) +import Simplex.Messaging.Util (bshow, liftError) import System.IO import System.IO.Error +import Test.QuickCheck (Arbitrary (..)) import UnliftIO.Concurrent import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import qualified UnliftIO.IO as IO import UnliftIO.STM +-- * TCP transport + runTCPServer :: MonadUnliftIO m => TMVar Bool -> ServiceName -> (Handle -> m ()) -> m () runTCPServer started port server = do clients <- newTVarIO S.empty @@ -33,7 +56,10 @@ runTCPServer started port server = do atomically . modifyTVar clients $ S.insert tid where closeServer :: TVar (Set ThreadId) -> Socket -> IO () - closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock + closeServer clients sock = do + readTVarIO clients >>= mapM_ killThread + close sock + void . atomically $ tryPutTMVar started False startTCPServer :: TMVar Bool -> ServiceName -> IO Socket startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted @@ -59,9 +85,7 @@ runTCPClient host port client = do client h `E.finally` IO.hClose h startTCPClient :: HostName -> ServiceName -> IO Handle -startTCPClient host port = - withSocketsDo $ - resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion +startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err where err :: IOException err = mkIOError NoSuchThing "no address" Nothing Nothing @@ -71,9 +95,10 @@ startTCPClient host port = let hints = defaultHints {addrSocketType = Stream} in getAddrInfo (Just hints) (Just host) (Just port) - tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle) - tryOpen (Left _) addr = E.try $ open addr - tryOpen h _ = return h + tryOpen :: IOException -> [AddrInfo] -> IO Handle + tryOpen e [] = E.throwIO e + tryOpen _ (addr : as) = + E.try (open addr) >>= either (`tryOpen` as) pure open :: AddrInfo -> IO Handle open addr = do @@ -93,7 +118,261 @@ putLn :: Handle -> ByteString -> IO () putLn h = B.hPut h . (<> "\r\n") getLn :: Handle -> IO ByteString -getLn h = trim_cr <$> B.hGetLine h +getLn h = trimCR <$> B.hGetLine h + +trimCR :: ByteString -> ByteString +trimCR "" = "" +trimCR s = if B.last s == '\r' then B.init s else s + +-- * Encrypted transport + +data SMPVersion = SMPVersion Int Int Int Int + deriving (Eq, Ord) + +major :: SMPVersion -> (Int, Int) +major (SMPVersion a b _ _) = (a, b) + +currentSMPVersion :: SMPVersion +currentSMPVersion = SMPVersion 0 2 0 0 + +serializeSMPVersion :: SMPVersion -> ByteString +serializeSMPVersion (SMPVersion a b c d) = B.intercalate "." [bshow a, bshow b, bshow c, bshow d] + +smpVersionP :: Parser SMPVersion +smpVersionP = + let ver = A.decimal <* A.char '.' + in SMPVersion <$> ver <*> ver <*> ver <*> A.decimal + +data THandle = THandle + { handle :: Handle, + sndKey :: SessionKey, + rcvKey :: SessionKey, + blockSize :: Int + } + +data SessionKey = SessionKey + { aesKey :: C.Key, + baseIV :: C.IV, + counter :: TVar Word32 + } + +data ClientHandshake = ClientHandshake + { blockSize :: Int, + sndKey :: SessionKey, + rcvKey :: SessionKey + } + +data TransportError + = TEBadBlock + | TEEncrypt + | TEDecrypt + | TEHandshake HandshakeError + deriving (Eq, Generic, Read, Show, Exception) + +data HandshakeError + = ENCRYPT + | DECRYPT + | VERSION + | RSA_KEY + | HEADER + | AES_KEYS + | BAD_HASH + | MAJOR_VERSION + | TERMINATED + deriving (Eq, Generic, Read, Show, Exception) + +instance Arbitrary TransportError where arbitrary = genericArbitraryU + +instance Arbitrary HandshakeError where arbitrary = genericArbitraryU + +transportErrorP :: Parser TransportError +transportErrorP = + "BLOCK" $> TEBadBlock + <|> "AES_ENCRYPT" $> TEEncrypt + <|> "AES_DECRYPT" $> TEDecrypt + <|> TEHandshake <$> parseRead1 + +serializeTransportError :: TransportError -> ByteString +serializeTransportError = \case + TEEncrypt -> "AES_ENCRYPT" + TEDecrypt -> "AES_DECRYPT" + TEBadBlock -> "BLOCK" + TEHandshake e -> bshow e + +tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ()) +tPutEncrypted THandle {handle = h, sndKey, blockSize} block = + encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case + Left _ -> pure $ Left TEEncrypt + Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg) + +tGetEncrypted :: THandle -> IO (Either TransportError ByteString) +tGetEncrypted THandle {handle = h, rcvKey, blockSize} = + B.hGet h blockSize >>= decryptBlock rcvKey >>= \case + Left _ -> pure $ Left TEDecrypt + Right "" -> ioe_EOF + Right msg -> pure $ Right msg + +encryptBlock :: SessionKey -> Int -> ByteString -> IO (Either C.CryptoError (AuthTag, ByteString)) +encryptBlock k@SessionKey {aesKey} size block = do + ivBytes <- makeNextIV k + runExceptT $ C.encryptAES aesKey ivBytes size block + +decryptBlock :: SessionKey -> ByteString -> IO (Either C.CryptoError ByteString) +decryptBlock k@SessionKey {aesKey} block = do + let (authTag, msg') = B.splitAt C.authTagSize block + ivBytes <- makeNextIV k + runExceptT $ C.decryptAES aesKey ivBytes msg' (C.bsToAuthTag authTag) + +makeNextIV :: SessionKey -> IO C.IV +makeNextIV SessionKey {baseIV, counter} = atomically $ do + c <- readTVar counter + writeTVar counter $ c + 1 + pure $ iv c where - trim_cr "" = "" - trim_cr s = if B.last s == '\r' then B.init s else s + (start, rest) = B.splitAt 4 $ C.unIV baseIV + iv c = C.IV $ (start `xor` encodeWord32 c) <> rest + +-- | implements server transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption +-- The numbers in function names refer to the steps in the document +serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle +serverHandshake h (k, pk) = do + liftIO sendHeaderAndPublicKey_1 + encryptedKeys <- receiveEncryptedKeys_4 + -- TODO server currently ignores blockSize returned by the client + -- this is reserved for future support of streams + ClientHandshake {blockSize = _, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys + th <- liftIO $ transportHandle h rcvKey sndKey transportBlockSize -- keys are swapped here + sendWelcome_6 th + pure th + where + sendHeaderAndPublicKey_1 :: IO () + sendHeaderAndPublicKey_1 = do + let sKey = C.encodePubKey k + header = ServerHeader {blockSize = transportBlockSize, keySize = B.length sKey} + B.hPut h $ binaryServerHeader header <> sKey + receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString + receiveEncryptedKeys_4 = + liftIO (B.hGet h $ C.publicKeySize k) >>= \case + "" -> throwE $ TEHandshake TERMINATED + ks -> pure ks + decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO ClientHandshake + decryptParseKeys_5 encKeys = + liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys) + >>= liftEither . parseClientHandshake + sendWelcome_6 :: THandle -> ExceptT TransportError IO () + sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " " + +-- | implements client transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption +-- The numbers in function names refer to the steps in the document +clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle +clientHandshake h keyHash = do + (k, blkSize) <- getHeaderAndPublicKey_1_2 + -- TODO currently client always uses the blkSize returned by the server + keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize + sendEncryptedKeys_4 k keys + th <- liftIO $ transportHandle h sndKey rcvKey blkSize + getWelcome_6 th >>= checkVersion + pure th + where + getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int) + getHeaderAndPublicKey_1_2 = do + header <- liftIO (B.hGet h serverHeaderSize) + ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header + when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $ + throwError $ TEHandshake HEADER + s <- liftIO $ B.hGet h keySize + maybe (pure ()) (validateKeyHash_2 s) keyHash + key <- liftEither $ parseKey s + pure (key, blockSize) + parseKey :: ByteString -> Either TransportError C.PublicKey + parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.binaryPubKeyP + validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () + validateKeyHash_2 k kHash + | C.getKeyHash k == kHash = pure () + | otherwise = throwE $ TEHandshake BAD_HASH + generateKeys_3 :: Int -> IO ClientHandshake + generateKeys_3 blkSize = ClientHandshake blkSize <$> generateKey <*> generateKey + generateKey :: IO SessionKey + generateKey = do + aesKey <- C.randomAesKey + baseIV <- C.randomIV + pure SessionKey {aesKey, baseIV, counter = undefined} + sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO () + sendEncryptedKeys_4 k keys = + liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys) + >>= liftIO . B.hPut h + getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion + getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th + parseSMPVersion :: ByteString -> Either TransportError SMPVersion + parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space) + checkVersion :: SMPVersion -> ExceptT TransportError IO () + checkVersion smpVersion = + when (major smpVersion > major currentSMPVersion) . throwE $ + TEHandshake MAJOR_VERSION + +data ServerHeader = ServerHeader {blockSize :: Int, keySize :: Int} + deriving (Eq, Show) + +binaryRsaTransport :: Int +binaryRsaTransport = 0 + +transportBlockSize :: Int +transportBlockSize = 4096 + +maxTransportBlockSize :: Int +maxTransportBlockSize = 65536 + +serverHeaderSize :: Int +serverHeaderSize = 8 + +binaryServerHeader :: ServerHeader -> ByteString +binaryServerHeader ServerHeader {blockSize, keySize} = + encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize + +serverHeaderP :: Parser ServerHeader +serverHeaderP = ServerHeader <$> int32 <* binaryRsaTransportP <*> int16 + +serializeClientHandshake :: ClientHandshake -> ByteString +serializeClientHandshake ClientHandshake {blockSize, sndKey, rcvKey} = + encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> serializeKey sndKey <> serializeKey rcvKey + where + serializeKey :: SessionKey -> ByteString + serializeKey SessionKey {aesKey, baseIV} = C.unKey aesKey <> C.unIV baseIV + +clientHandshakeP :: Parser ClientHandshake +clientHandshakeP = ClientHandshake <$> int32 <* binaryRsaTransportP <*> keyP <*> keyP + where + keyP :: Parser SessionKey + keyP = do + aesKey <- C.aesKeyP + baseIV <- C.ivP + pure SessionKey {aesKey, baseIV, counter = undefined} + +int32 :: Parser Int +int32 = decodeNum32 <$> A.take 4 + +int16 :: Parser Int +int16 = decodeNum16 <$> A.take 2 + +binaryRsaTransportP :: Parser () +binaryRsaTransportP = binaryRsa =<< int16 + where + binaryRsa :: Int -> Parser () + binaryRsa n + | n == binaryRsaTransport = pure () + | otherwise = fail "unknown transport mode" + +parseClientHandshake :: ByteString -> Either TransportError ClientHandshake +parseClientHandshake = parse clientHandshakeP $ TEHandshake AES_KEYS + +transportHandle :: Handle -> SessionKey -> SessionKey -> Int -> IO THandle +transportHandle h sk rk blockSize = do + sndCounter <- newTVarIO 0 + rcvCounter <- newTVarIO 0 + pure + THandle + { handle = h, + sndKey = sk {counter = sndCounter}, + rcvKey = rk {counter = rcvCounter}, + blockSize + } diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index e8397015d..2800e521e 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -31,19 +31,22 @@ raceAny_ = r [] r as (m : ms) = withAsync m $ \a -> r (a : as) ms r as [] = void $ waitAnyCancel as -infixl 4 <$$> +infixl 4 <$$>, <$?> (<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b) (<$$>) = fmap . fmap +(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b +f <$?> m = m >>= either fail pure . f + bshow :: Show a => a -> ByteString bshow = B.pack . show -liftIOEither :: (MonadUnliftIO m, MonadError e m) => IO (Either e a) -> m a +liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a liftIOEither a = liftIO a >>= liftEither -liftError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a +liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a liftError f = liftEitherError f . runExceptT -liftEitherError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a +liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a liftEitherError f a = liftIOEither (first f <$> a) diff --git a/src/Simplex/Messaging/errors.md b/src/Simplex/Messaging/errors.md new file mode 100644 index 000000000..6fba8bed4 --- /dev/null +++ b/src/Simplex/Messaging/errors.md @@ -0,0 +1,98 @@ +# Errors + +## Problems + +- using numbers and strings to indicate errors (in protocol and in code) - ErrorType, AgentErrorType, TransportError +- re-using the same type in multiple contexts (with some constructors not applicable to all contexts) - ErrorType + +## Error types + +### ErrorType (Protocol.hs) + +- BLOCK - incorrect block format or encoding +- CMD error - command is unknown or has invalid syntax, where `error` can be: + - PROHIBITED - server response sent from client or vice versa + - SYNTAX - error parsing command + - NO_AUTH - transmission has no required credentials (signature or queue ID) + - HAS_AUTH - transmission has not allowed credentials + - NO_QUEUE - transmission has not queue ID +- AUTH - command is not authorised (queue does not exist or signature verification failed). +- NO_MSG - acknowledging (ACK) the message without message +- INTERNAL - internal server error. +- DUPLICATE_ - it is used internally to signal that the queue ID is already used. This is NOT used in the protocol, instead INTERNAL is sent to the client. It has to be removed. + +### AgentErrorType (Agent/Transmission.hs) + +Some of these errors are not correctly serialized/parsed - see line 322 in Agent/Transmission.hs + +- CMD e - command or response error + - PROHIBITED - server response sent as client command (and vice versa) + - SYNTAX - command is unknown or has invalid syntax. + - NO_CONN - connection is required in the command (and absent) + - SIZE - incorrect message size of messages (when parsing SEND and MSG) + - LARGE -- message does not fit SMP block +- CONN e - connection errors + - UNKNOWN - connection alias not in database + - DUPLICATE - connection alias already exists + - SIMPLEX - connection is simplex, but operation requires another queue +- SMP ErrorType - forwarding SMP errors (SMPServerError) to the agent client +- BROKER e - SMP server errors + - RESPONSE ErrorType - invalid SMP server response + - UNEXPECTED - unexpected response + - NETWORK - network TCP connection error + - TRANSPORT TransportError -- handshake or other transport error + - TIMEOUT - command response timeout +- AGENT e - errors of other agents + - A_MESSAGE - SMP message failed to parse + - A_PROHIBITED - SMP message is prohibited with the current queue status + - A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header + - A_SIGNATURE - invalid RSA signature +- INTERNAL ByteString - agent implementation or dependency error + +### SMPClientError (Client.hs) + +- SMPServerError ErrorType - this is correctly parsed server ERR response. This error is forwarded to the agent client as `ERR SMP err` +- SMPResponseError ErrorType - this is invalid server response that failed to parse - forwarded to the client as `ERR BROKER RESPONSE`. +- SMPUnexpectedResponse - different response from what is expected to a given command, e.g. server should respond `IDS` or `ERR` to `NEW` command, other responses would result in this error - forwarded to the client as `ERR BROKER UNEXPECTED`. +- SMPResponseTimeout - used for TCP connection and command response timeouts -> `ERR BROKER TIMEOUT`. +- SMPNetworkError - fails to establish TCP connection -> `ERR BROKER NETWORK` +- SMPTransportError e - fails connection handshake or some other transport error -> `ERR BROKER TRANSPORT e` +- SMPSignatureError C.CryptoError - error when cryptographically "signing" the command. + +### StoreError (Agent/Store.hs) + +- SEInternal ByteString - signals exceptions in store actions. +- SEConnNotFound - connection alias not found (or both queues absent). +- SEConnDuplicate - connection alias already used. +- SEBadConnType ConnType - wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. `updateRcvConnWithSndQueue` and `updateSndConnWithRcvQueue` do not allow duplex connections - they would also return this error. +- SEBadQueueStatus - the intention was to pass current expected queue status in methods, as we always know what it should be at any stage of the protocol, and in case it does not match use this error. **Currently not used**. +- SENotImplemented - used in `getMsg` that is not implemented/used. + +### CryptoError (Crypto.hs) + +- RSAEncryptError R.Error - RSA encryption error +- RSADecryptError R.Error - RSA decryption error +- RSASignError R.Error - RSA signature error +- AESCipherError CE.CryptoError - AES initialization error +- CryptoIVError - IV generation error +- AESDecryptError - AES decryption error +- CryptoLargeMsgError - message does not fit in SMP block +- CryptoHeaderError String - failure parsing RSA-encrypted message header + +### TransportError (Transport.hs) + + - TEBadBlock - error parsing block + - TEEncrypt - block encryption error + - TEDecrypt - block decryption error + - TEHandshake HandshakeError + +### HandshakeError (Transport.hs) + + - ENCRYPT - encryption error + - DECRYPT - decryption error + - VERSION - error parsing protocol version + - RSA_KEY - error parsing RSA key + - AES_KEYS - error parsing AES keys + - BAD_HASH - not matching RSA key hash + - MAJOR_VERSION - lower agent version than protocol version + - TERMINATED - transport terminated diff --git a/stack.yaml b/stack.yaml index 7819ad877..4dd8b8bbf 100644 --- a/stack.yaml +++ b/stack.yaml @@ -35,9 +35,10 @@ packages: # forks / in-progress versions pinned to a git hash. For example: # extra-deps: - - sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002 + - cryptostore-0.2.1.0@sha256:9896e2984f36a1c8790f057fd5ce3da4cbcaf8aa73eb2d9277916886978c5b19,3881 - direct-sqlite-2.3.26@sha256:04e835402f1508abca383182023e4e2b9b86297b8533afbd4e57d1a5652e0c23,3718 - simple-logger-0.1.0@sha256:be8ede4bd251a9cac776533bae7fb643369ebd826eb948a9a18df1a8dd252ff8,1079 + - sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002 - terminal-0.2.0.0@sha256:de6770ecaae3197c66ac1f0db5a80cf5a5b1d3b64a66a05b50f442de5ad39570,2977 # - network-run-0.2.4@sha256:7dbb06def522dab413bce4a46af476820bffdff2071974736b06f52f4ab57c96,885 # - git: https://github.com/commercialhaskell/stack.git diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 3bcdd1314..6a1754f13 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -13,6 +13,7 @@ import Control.Concurrent import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient +import SMPClient (testKeyHashStr) import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import System.IO (Handle) @@ -79,7 +80,7 @@ h #:# err = tryGet `shouldReturn` () _ -> return () pattern Msg :: MsgBody -> ACommand 'Agent -pattern Msg m_body <- MSG {m_body} +pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk} testDuplexConnection :: Handle -> Handle -> IO () testDuplexConnection alice bob = do @@ -87,13 +88,13 @@ testDuplexConnection alice bob = do let qInfo' = serializeSmpQueueInfo qInfo bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON) alice <# ("", "bob", CON) - alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False - alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False + alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False + alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False - bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False + bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False - bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False + bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH)) @@ -127,24 +128,24 @@ testSubscrNotification (server, _) client = do client <# ("", "conn1", END) samplePublicKey :: ByteString -samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDlaIClbK39KzPJMljcpnYb2KlSoZ51AhwF5PH2CS+FStc3QzajiqfdOQPet23Hd9YC6pqyTQ7idntqgPrE7yKJF44lUhKlq8QS9KQcbK7W6t7F9uQFw44ceWd2eVf81UV04kQdKWJvC5Sz6jtSZNEfs9mVI8H0wi1amUvS6+7EDJbxikhcCRnFShFO9dUKRYXj6L2JVqXqO5cZgY9BScyneWIg6mhhsTcdDbITM6COlL+pF1f3TjDN+slyV+IzE+ap/9NkpsrCcI8KwwDpqEDmUUV/JQfmQ==,gj2UAiWzSj7iun0iXvI5iz5WEjaqngmB3SzQ5+iarixbaG15LFDtYs3pijG3eGfB1wIFgoP4D2z97vIWn8olT4uCTUClf29zGDDve07h/B3QG/4i0IDnio7MX3AbE8O6PKouqy/GLTfT4WxFUn423g80rpsVYd5oj+SCL2eaxIc=" +samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR SYNTAX 11") + it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX") describe "NEW" do describe "valid" do -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) -- TODO: add tests with defined connection alias xit "only server" $ ("211", "", "NEW localhost") >#>= \case ("211", "", "INV" : _) -> True; _ -> False it "with port" $ ("212", "", "NEW localhost:5000") >#>= \case ("212", "", "INV" : _) -> True; _ -> False - xit "with keyHash" $ ("213", "", "NEW localhost#1234") >#>= \case ("213", "", "INV" : _) -> True; _ -> False - it "with port and keyHash" $ ("214", "", "NEW localhost:5000#1234") >#>= \case ("214", "", "INV" : _) -> True; _ -> False + xit "with keyHash" $ ("213", "", "NEW localhost#" <> testKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False + it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> testKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias - it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11") - it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR SYNTAX 11") - it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR SYNTAX 11") + it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX") + it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR CMD SYNTAX") + it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR CMD SYNTAX") describe "JOIN" do describe "valid" do @@ -154,4 +155,4 @@ syntaxTests = do ("311", "", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "", "ERR SMP AUTH") describe "invalid" do -- TODO: JOIN is not merged yet - to be added - it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR SYNTAX 11") + it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 9bf9b6a8b..5442705f0 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,18 +1,22 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} module AgentTests.SQLiteTests (storeTests) where import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import Data.ByteString.Char8 (ByteString) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Time import Data.Word (Word32) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.QQ (sql) +import SMPClient (testKeyHash) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Transmission @@ -48,41 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e -- TODO add null port tests storeTests :: Spec storeTests = withStore do - describe "compiled as threadsafe" testCompiledThreadsafe - describe "foreign keys enabled" testForeignKeysEnabled + describe "store setup" do + testCompiledThreadsafe + testForeignKeysEnabled describe "store methods" do - describe "createRcvConn" testCreateRcvConn - describe "createSndConn" testCreateSndConn - describe "getAllConnAliases" testGetAllConnAliases - describe "getRcvQueue" testGetRcvQueue - describe "deleteConn" do - describe "RcvConnection" testDeleteRcvConn - describe "SndConnection" testDeleteSndConn - describe "DuplexConnection" testDeleteDuplexConn - describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex - describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex - describe "set queue status" do - describe "setRcvQueueStatus" testSetRcvQueueStatus - describe "setSndQueueStatus" testSetSndQueueStatus - describe "DuplexConnection" testSetQueueStatusDuplex - xdescribe "RcvQueue doesn't exist" testSetRcvQueueStatusNoQueue - xdescribe "SndQueue doesn't exist" testSetSndQueueStatusNoQueue - describe "createRcvMsg" do - describe "RcvQueue exists" testCreateRcvMsg - describe "RcvQueue doesn't exist" testCreateRcvMsgNoQueue - describe "createSndMsg" do - describe "SndQueue exists" testCreateSndMsg - describe "SndQueue doesn't exist" testCreateSndMsgNoQueue + describe "Queue and Connection management" do + describe "createRcvConn" do + testCreateRcvConn + testCreateRcvConnDuplicate + describe "createSndConn" do + testCreateSndConn + testCreateSndConnDuplicate + describe "getAllConnAliases" testGetAllConnAliases + describe "getRcvQueue" testGetRcvQueue + describe "deleteConn" do + testDeleteRcvConn + testDeleteSndConn + testDeleteDuplexConn + describe "upgradeRcvConnToDuplex" do + testUpgradeRcvConnToDuplex + describe "upgradeSndConnToDuplex" do + testUpgradeSndConnToDuplex + describe "set Queue status" do + describe "setRcvQueueStatus" do + testSetRcvQueueStatus + testSetRcvQueueStatusNoQueue + describe "setSndQueueStatus" do + testSetSndQueueStatus + testSetSndQueueStatusNoQueue + testSetQueueStatusDuplex + describe "Msg management" do + describe "create Msg" do + testCreateRcvMsg + testCreateSndMsg + testCreateRcvAndSndMsgs testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = do - it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do + it "compiled sqlite library should be threadsafe" $ \store -> do compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] compileOptions `shouldNotContain` [["THREADSAFE=0"]] testForeignKeysEnabled :: SpecWith SQLiteStore testForeignKeysEnabled = do - it "should throw error if foreign keys are enabled" $ \store -> do + it "foreign keys should be enabled" $ \store -> do let inconsistentQuery = [sql| INSERT INTO connections @@ -96,13 +109,13 @@ testForeignKeysEnabled = do rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "1234", connAlias = "conn1", - rcvPrivateKey = C.PrivateKey 1 2 3, + rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "2345", sndKey = Nothing, - decryptKey = C.PrivateKey 1 2 3, + decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New } @@ -110,12 +123,12 @@ rcvQueue1 = sndQueue1 :: SndQueue sndQueue1 = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "3456", connAlias = "conn1", - sndPrivateKey = C.PrivateKey 1 2 3, + sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.PrivateKey 1 2 3, + signKey = C.safePrivateKey (1, 2, 3), status = New } @@ -131,6 +144,13 @@ testCreateRcvConn = do getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) +testCreateRcvConnDuplicate :: SpecWith SQLiteStore +testCreateRcvConnDuplicate = do + it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 + createRcvConn store rcvQueue1 + `throwsError` SEConnDuplicate + testCreateSndConn :: SpecWith SQLiteStore testCreateSndConn = do it "should create SndConnection and add RcvQueue" $ \store -> do @@ -143,118 +163,113 @@ testCreateSndConn = do getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) +testCreateSndConnDuplicate :: SpecWith SQLiteStore +testCreateSndConnDuplicate = do + it "should throw error on attempt to create duplicate SndConnection" $ \store -> do + _ <- runExceptT $ createSndConn store sndQueue1 + createSndConn store sndQueue1 + `throwsError` SEConnDuplicate + testGetAllConnAliases :: SpecWith SQLiteStore testGetAllConnAliases = do it "should get all conn aliases" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - createSndConn store sndQueue1 {connAlias = "conn2"} - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"} getAllConnAliases store `returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias] testGetRcvQueue :: SpecWith SQLiteStore testGetRcvQueue = do it "should get RcvQueue" $ \store -> do - let smpServer = SMPServer "smp.simplex.im" (Just "5223") (Just "1234") + let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getRcvQueue store smpServer recipientId `returnsResult` rcvQueue1 testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = do it "should create RcvConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = do it "should create SndConnection and delete it" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = do it "should create DuplexConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = do it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 let anotherSndQueue = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "2345", connAlias = "conn1", - sndPrivateKey = C.PrivateKey 1 2 3, + sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.PrivateKey 1 2 3, + signKey = C.safePrivateKey (1, 2, 3), status = New } upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CSnd - upgradeSndConnToDuplex store "conn1" rcvQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1 upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CDuplex testUpgradeSndConnToDuplex :: SpecWith SQLiteStore testUpgradeSndConnToDuplex = do it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 let anotherRcvQueue = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "3456", connAlias = "conn1", - rcvPrivateKey = C.PrivateKey 1 2 3, + rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "4567", sndKey = Nothing, - decryptKey = C.PrivateKey 1 2 3, + decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New } upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CRcv - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CDuplex testSetRcvQueueStatus :: SpecWith SQLiteStore testSetRcvQueueStatus = do it "should update status of RcvQueue" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) setRcvQueueStatus store rcvQueue1 Confirmed @@ -265,8 +280,7 @@ testSetRcvQueueStatus = do testSetSndQueueStatus :: SpecWith SQLiteStore testSetSndQueueStatus = do it "should update status of SndQueue" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) setSndQueueStatus store sndQueue1 Confirmed @@ -277,10 +291,8 @@ testSetSndQueueStatus = do testSetQueueStatusDuplex :: SpecWith SQLiteStore testSetQueueStatusDuplex = do it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) setRcvQueueStatus store rcvQueue1 Secured @@ -290,61 +302,88 @@ testSetQueueStatusDuplex = do setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn - SCDuplex - ( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed} - ) + `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent RcvQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do setRcvQueueStatus store rcvQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEConnNotFound testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore testSetSndQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent SndQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do setSndQueueStatus store sndQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEConnNotFound + +hw :: ByteString +hw = encodeUtf8 "Hello world!" + +ts :: UTCTime +ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) + +mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData +mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = + RcvMsgData + { internalId, + internalRcvId, + internalTs = ts, + senderMeta = (externalSndId, ts), + brokerMeta = (brokerId, ts), + msgBody = hw, + internalHash, + externalPrevSndHash = "hash_from_sender", + msgIntegrity = MsgOk + } + +testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation +testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do + updateRcvIds store rcvQueue + `returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) + createRcvMsg store rcvQueue rcvMsgData + `returnsResult` () testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = do - it "should create a RcvMsg and return InternalId" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + it "should reserve internal ids and create a RcvMsg" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `returnsResult` InternalId 1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" -testCreateRcvMsgNoQueue :: SpecWith SQLiteStore -testCreateRcvMsgNoQueue = do - it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConn - createSndConn store sndQueue1 - `returnsResult` () - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConnType CSnd +mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData +mkSndMsgData internalId internalSndId internalHash = + SndMsgData + { internalId, + internalSndId, + internalTs = ts, + msgBody = hw, + internalHash + } + +testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation +testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do + updateSndIds store sndQueue + `returnsResult` (internalId, internalSndId, expectedPrevHash) + createSndMsg store sndQueue sndMsgData + `returnsResult` () testCreateSndMsg :: SpecWith SQLiteStore testCreateSndMsg = do - it "should create a SndMsg and return InternalId" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do + _ <- runExceptT $ createSndConn store sndQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `returnsResult` InternalId 1 + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" + testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" -testCreateSndMsgNoQueue :: SpecWith SQLiteStore -testCreateSndMsgNoQueue = do - it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConn - createRcvConn store rcvQueue1 - `returnsResult` () - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConnType CRcv +testCreateRcvAndSndMsgs :: SpecWith SQLiteStore +testCreateRcvAndSndMsgs = do + it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" + testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" + testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" diff --git a/tests/ProtocolErrorTests.hs b/tests/ProtocolErrorTests.hs new file mode 100644 index 000000000..82e4afbf7 --- /dev/null +++ b/tests/ProtocolErrorTests.hs @@ -0,0 +1,18 @@ +module ProtocolErrorTests where + +import Simplex.Messaging.Agent.Transmission (AgentErrorType, agentErrorTypeP, serializeAgentError) +import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Protocol (ErrorType, errorTypeP, serializeErrorType) +import Test.Hspec +import Test.Hspec.QuickCheck (modifyMaxSuccess) +import Test.QuickCheck + +protocolErrorTests :: Spec +protocolErrorTests = modifyMaxSuccess (const 1000) $ do + describe "errors parsing / serializing" $ do + it "should parse SMP protocol errors" . property $ \err -> + parseAll errorTypeP (serializeErrorType err) + == Right (err :: ErrorType) + it "should parse SMP agent errors" . property $ \err -> + parseAll agentErrorTypeP (serializeAgentError err) + == Right (err :: AgentErrorType) diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 5f7175871..9bf0bf187 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -6,23 +6,19 @@ module SMPAgentClient where -import Control.Monad import Control.Monad.IO.Unlift import Crypto.Random import Network.Socket (HostName, ServiceName) -import SMPClient (testPort, withSmpServer, withSmpServerThreadOn) +import SMPClient (serverBracket, testPort, withSmpServer, withSmpServerThreadOn) import Simplex.Messaging.Agent (runSMPAgentBlocking) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig) import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import UnliftIO.Directory -import qualified UnliftIO.Exception as E import UnliftIO.IO -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) agentTestHost :: HostName agentTestHost = "localhost" @@ -125,12 +121,10 @@ cfg = } withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a -withSmpAgentThreadOn (port', db') f = do - started <- newEmptyTMVarIO - E.bracket - (forkIOWithUnmask ($ runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'})) - (liftIO . killThread >=> const (removeFile db')) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x +withSmpAgentThreadOn (port', db') = + serverBracket + (\started -> runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'}) + (removeFile db') withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 56598673a..00e843119 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -1,23 +1,30 @@ {-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module SMPClient where +import Control.Monad (void) +import Control.Monad.Except (runExceptT) import Control.Monad.IO.Unlift import Crypto.Random +import Data.ByteString.Base64 (encode) +import qualified Data.ByteString.Char8 as B import Network.Socket +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server (runSMPServerBlocking) import Simplex.Messaging.Server.Env.STM +import Simplex.Messaging.Server.StoreLog (openReadStoreLog) import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import qualified UnliftIO.Exception as E -import UnliftIO.IO -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.Timeout (timeout) testHost :: HostName testHost = "localhost" @@ -25,13 +32,21 @@ testHost = "localhost" testPort :: ServiceName testPort = "5000" -testSMPClient :: MonadUnliftIO m => (Handle -> m a) -> m a -testSMPClient client = do - runTCPClient testHost testPort $ \h -> do - line <- liftIO $ getLn h - if line == "Welcome to SMP v0.2.0" - then client h - else error "not connected" +testKeyHashStr :: B.ByteString +testKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" + +testKeyHash :: Maybe C.KeyHash +testKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" + +testStoreLogFile :: FilePath +testStoreLogFile = "tests/tmp/smp-server-store.log" + +testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a +testSMPClient client = + runTCPClient testHost testPort $ \h -> + liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case + Right th -> client th + Left e -> error $ show e cfg :: ServerConfig cfg = @@ -39,16 +54,66 @@ cfg = { tcpPort = testPort, tbqSize = 1, queueIdBytes = 12, - msgIdBytes = 6 + msgIdBytes = 6, + storeLog = Nothing, + serverPrivateKey = + -- full RSA private key (only for tests) + "MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\ + \kaJ8chEhrtaUgXeSWGooWwqjXEUQE6RVbCC6QVo9VEBSP4xFwVVd9Fj7OsgfcXXh\ + \AqWxfctDcBZQ5jTUiJpdBc+Vz2ZkumVNl0W+j9kWm9nfkMLQj8c0cVSDxz4OKpZb\ + \qFuj0uzHkis7e7wsrKSKWLPg3M5ZXPZM1m9qn7SfJzDRDfJifamxWI7uz9XK2+Dp\ + \NkUQlGQgFJEv1cKN88JAwIqZ1s+TAQMQiB+4QZ2aNfSqGEzRJN7FMCKRK7pM0A9A\ + \PCnijyuImvKFxTdk8Bx1q+XNJzsY6fBrLWJZ+QKBgQCySG4tzlcEm+tOVWRcwrWh\ + \6zsczGZp9mbf9c8itRx6dlldSYuDG1qnddL70wuAZF2AgS1JZgvcRZECoZRoWP5q\ + \Kq2wvpTIYjFPpC39lxgUoA/DXKVKZZdan+gwaVPAPT54my1CS32VrOiAY4gVJ3LJ\ + \Mn1/FqZXUFQA326pau3loQKCAQEAoljmJMp88EZoy3HlHUbOjl5UEhzzVsU1TnQi\ + \QmPm+aWRe2qelhjW4aTvSVE5mAUJsN6UWTeMf4uvM69Z9I5pfw2pEm8x4+GxRibY\ + \iiwF2QNaLxxmzEHm1zQQPTgb39o8mgklhzFPill0JsnL3f6IkVwjFJofWSmpqEGs\ + \dFSMRSXUTVXh1p/o7QZrhpwO/475iWKVS7o48N/0Xp513re3aXw+DRNuVnFEaBIe\ + \TLvWM9Czn16ndAu1HYiTBuMvtRbAWnGZxU8ewzF4wlWK5tdIL5PTJDd1VhZJAKtB\ + \npDvJpwxzKmjAhcTmjx0ckMIWtdVaOVm/2gWCXDty2FEdg7koQKBgQDOUUguJ/i7\ + \q0jldWYRnVkotKnpInPdcEaodrehfOqYEHnvro9xlS6OeAS4Vz5AdH45zQ/4J3bV\ + \2cH66tNr18ebM9nL//t5G69i89R9W7szyUxCI3LmAIdi3oSEbmz5GQBaw4l6h9Wi\ + \n4FmFQaAXZrjQfO2qJcAHvWRsMp2pmqAGwKBgQDXaza0DRsKWywWznsHcmHa0cx8\ + \I4jxqGaQmLO7wBJRP1NSFrywy1QfYrVX9CTLBK4V3F0PCgZ01Qv94751CzN43TgF\ + \ebd/O9r5NjNTnOXzdWqETbCffLGd6kLgCMwPQWpM9ySVjXHWCGZsRAnF2F6M1O32\ + \43StIifvwJQFqSM3ewKBgCaW6y7sRY90Ua7283RErezd9EyT22BWlDlACrPu3FNC\ + \LtBf1j43uxBWBQrMLsHe2GtTV0xt9m0MfwZsm2gSsXcm4Xi4DJgfN+Z7rIlyy9UY\ + \PCDSdZiU1qSr+NrffDrXlfiAM1cUmCdUX7eKjp/ltkUHNaOGfSn5Pdr3MkAiD/Hf\ + \AoGBAKIdKCuOwuYlwjS9J+IRGuSSM4o+OxQdwGmcJDTCpyWb5dEk68e7xKIna3zf\ + \jc+H+QdMXv1nkRK9bZgYheXczsXaNZUSTwpxaEldzVD3hNvsXSgJRy9fqHwA4PBq\ + \vqiBHoO3RNbqg+2rmTMfDuXreME3S955ZiPZm4Z+T8Hj52mPAoGAQm5QH/gLFtY5\ + \+znqU/0G8V6BKISCQMxbbmTQVcTgGySrP2gVd+e4MWvUttaZykhWqs8rpr7mgpIY\ + \hul7Swx0SHFN3WpXu8uj+B6MLpRcCbDHO65qU4kQLs+IaXXsuuTjMvJ5LwjkZVrQ\ + \TmKzSAw7iVWwEUZR/PeiEKazqrpp9VU=" } +withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a +withSmpServerStoreLogOn port client = do + s <- liftIO $ openReadStoreLog testStoreLogFile + serverBracket + (\started -> runSMPServerBlocking started cfg {tcpPort = port, storeLog = Just s}) + (pure ()) + client + withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a -withSmpServerThreadOn port f = do +withSmpServerThreadOn port = + serverBracket + (\started -> runSMPServerBlocking started cfg {tcpPort = port}) + (pure ()) + +serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a +serverBracket process afterProcess f = do started <- newEmptyTMVarIO E.bracket - (forkIOWithUnmask ($ runSMPServerBlocking started cfg {tcpPort = port})) - (liftIO . killThread) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x + (forkIOWithUnmask ($ process started)) + (\t -> killThread t >> afterProcess >> waitFor started "stop") + (\t -> waitFor started "start" >> f t) + where + waitFor started s = + 5_000_000 `timeout` atomically (takeTMVar started) >>= \case + Nothing -> error $ "server did not " <> s + _ -> pure () withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a withSmpServerOn port = withSmpServerThreadOn port . const @@ -56,33 +121,43 @@ withSmpServerOn port = withSmpServerThreadOn port . const withSmpServer :: (MonadUnliftIO m, MonadRandom m) => m a -> m a withSmpServer = withSmpServerOn testPort -runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a +runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (THandle -> m a) -> m a runSmpTest test = withSmpServer $ testSMPClient test -runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([Handle] -> m a) -> m a +runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([THandle] -> m a) -> m a runSmpTestN nClients test = withSmpServer $ run nClients [] where - run :: Int -> [Handle] -> m a + run :: Int -> [THandle] -> m a run 0 hs = test hs run n hs = testSMPClient $ \h -> run (n - 1) (h : hs) smpServerTest :: RawTransmission -> IO RawTransmission smpServerTest cmd = runSmpTest $ \h -> tPutRaw h cmd >> tGetRaw h -smpTest :: (Handle -> IO ()) -> Expectation +smpTest :: (THandle -> IO ()) -> Expectation smpTest test' = runSmpTest test' `shouldReturn` () -smpTestN :: Int -> ([Handle] -> IO ()) -> Expectation +smpTestN :: Int -> ([THandle] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: (Handle -> Handle -> IO ()) -> Expectation +smpTest2 :: (THandle -> THandle -> IO ()) -> Expectation smpTest2 test' = smpTestN 2 _test where _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: (Handle -> Handle -> Handle -> IO ()) -> Expectation +smpTest3 :: (THandle -> THandle -> THandle -> IO ()) -> Expectation smpTest3 test' = smpTestN 3 _test where _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" + +tPutRaw :: THandle -> RawTransmission -> IO () +tPutRaw h (sig, corrId, queueId, command) = do + let t = B.intercalate " " [corrId, queueId, command] + void $ tPut h (C.Signature sig, t) + +tGetRaw :: THandle -> IO RawTransmission +tGetRaw h = do + ("", (CorrId corrId, qId, Right cmd)) <- tGet fromServer h + pure ("", corrId, encode qId, serializeCommand cmd) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 4a8c8766c..743d55671 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -4,22 +4,29 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} module ServerTests where +import Control.Concurrent (ThreadId, killThread) +import Control.Concurrent.STM +import Control.Exception (SomeException, try) +import Control.Monad.Except (forM_, runExceptT) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol -import System.IO (Handle) +import Simplex.Messaging.Transport +import System.Directory (removeFile) +import System.TimeIt (timeItT) import System.Timeout import Test.HUnit import Test.Hspec rsaKeySize :: Int -rsaKeySize = 1024 `div` 8 +rsaKeySize = 2048 `div` 8 serverTests :: Spec serverTests = do @@ -30,18 +37,20 @@ serverTests = do describe "SMP messages" do describe "duplex communication over 2 SMP connections" testDuplex describe "switch subscription to another SMP queue" testSwitchSub + describe "Store log" testWithStoreLog + describe "Timing of AUTH error" testTiming pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command))) -sendRecv :: Handle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError +sendRecv :: THandle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h -signSendRecv :: Handle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError +signSendRecv :: THandle -> C.SafePrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError signSendRecv h pk (corrId, qId, cmd) = do - let t = B.intercalate "\r\n" [corrId, encode qId, cmd] - Right sig <- C.sign pk t - tPut h (sig, t) + let t = B.intercalate " " [corrId, encode qId, cmd] + Right sig <- runExceptT $ C.sign pk t + _ <- tPut h (sig, t) tGet fromServer h cmdSEND :: ByteString -> ByteString @@ -61,7 +70,7 @@ testCreateSecure = Resp "abcd" rId1 (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) (rId1, "") #== "creates queue" - Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5\r\nhello") + Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5 hello ") (ok1, OK) #== "accepts unsigned SEND" (sId1, sId) #== "same queue ID in response 1" @@ -72,10 +81,10 @@ testCreateSecure = (ok4, OK) #== "replies OK when message acknowledged if no more messages" Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, "ACK") - (err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages" + (err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages" (sPub, sKey) <- C.generateKeyPair rsaKeySize - Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5\r\nhello") + Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ") (err1, ERR AUTH) #== "rejects signed SEND" (sId2, sId) #== "same queue ID in response 2" @@ -93,7 +102,7 @@ testCreateSecure = Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, keyCmd) (err4, ERR AUTH) #== "rejects KEY if already secured" - Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11\r\nhello again") + Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11 hello again ") (ok3, OK) #== "accepts signed SEND" Resp "" _ (MSG _ _ msg) <- tGet fromServer h @@ -102,7 +111,7 @@ testCreateSecure = Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, "ACK") (ok5, OK) #== "replies OK when message acknowledged 2" - Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5 hello ") (err5, ERR AUTH) #== "rejects unsigned SEND" testCreateDelete :: Spec @@ -117,10 +126,10 @@ testCreateDelete = Resp "bcda" _ ok1 <- signSendRecv rh rKey ("bcda", rId, "KEY " <> C.serializePubKey sPub) (ok1, OK) #== "secures queue" - Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello") + Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ") (ok2, OK) #== "accepts signed SEND" - Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7\r\nhello 2") + Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7 hello 2 ") (ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed" Resp "" _ (MSG _ _ msg1) <- tGet fromServer rh @@ -136,10 +145,10 @@ testCreateDelete = (ok3, OK) #== "suspends queue" (rId2, rId) #== "same queue ID in response 2" - Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5 hello ") (err3, ERR AUTH) #== "rejects signed SEND" - Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5\r\nhello") + Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5 hello ") (err4, ERR AUTH) #== "reject unsigned SEND too" Resp "bcda" _ ok4 <- signSendRecv rh rKey ("bcda", rId, "OFF") @@ -158,10 +167,10 @@ testCreateDelete = (ok6, OK) #== "deletes queue" (rId3, rId) #== "same queue ID in response 3" - Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello") + Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ") (err7, ERR AUTH) #== "rejects signed SEND when deleted" - Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5 hello ") (err8, ERR AUTH) #== "rejects unsigned SEND too when deleted" Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, "ACK") @@ -211,7 +220,7 @@ testDuplex = (aliceKey, C.serializePubKey asPub) #== "key received from Alice" Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, "KEY " <> aliceKey) - Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8\r\nhi alice") + Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8 hi alice ") Resp "" _ (MSG _ _ msg4) <- tGet fromServer alice Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, "ACK") @@ -229,7 +238,7 @@ testSwitchSub = smpTest3 \rh1 rh2 sh -> do (rPub, rKey) <- C.generateKeyPair rsaKeySize Resp "abcd" _ (IDS rId sId) <- signSendRecv rh1 rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) - Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5\r\ntest1") + Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5 test1 ") (ok1, OK) #== "sent test message 1" Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, cmdSEND "test2, no ACK") (ok2, OK) #== "sent test message 2" @@ -246,13 +255,13 @@ testSwitchSub = Resp "" _ end <- tGet fromServer rh1 (end, END) #== "unsubscribed the 1st TCP connection" - Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5\r\ntest3") + Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5 test3 ") Resp "" _ (MSG _ _ msg3) <- tGet fromServer rh2 (msg3, "test3") #== "delivered to the 2nd TCP connection" Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, "ACK") - (err, ERR PROHIBITED) #== "rejects ACK from the 1st TCP connection" + (err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection" Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, "ACK") (ok3, OK) #== "accepts ACK from the 2nd TCP connection" @@ -261,40 +270,128 @@ testSwitchSub = Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" +testWithStoreLog :: Spec +testWithStoreLog = + it "should store simplex queues to log and restore them after server restart" $ do + (sPub1, sKey1) <- C.generateKeyPair rsaKeySize + (sPub2, sKey2) <- C.generateKeyPair rsaKeySize + senderId1 <- newTVarIO "" + senderId2 <- newTVarIO "" + + withSmpServerStoreLogOn testPort . runTest $ \h -> do + (sId1, _, _) <- createAndSecureQueue h sPub1 + atomically $ writeTVar senderId1 sId1 + Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + Resp "" _ (MSG _ _ "hello") <- tGet fromServer h + + (sId2, rId2, rKey2) <- createAndSecureQueue h sPub2 + atomically $ writeTVar senderId2 sId2 + Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ") + Resp "" _ (MSG _ _ "hello too") <- tGet fromServer h + + Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, "DEL") + pure () + + logSize `shouldReturn` 5 + + withSmpServerThreadOn testPort . runTest $ \h -> do + sId1 <- readTVarIO senderId1 + -- fails if store log is disabled + Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + pure () + + withSmpServerStoreLogOn testPort . runTest $ \h -> do + -- this queue is restored + sId1 <- readTVarIO senderId1 + Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + -- this queue is removed - not restored + sId2 <- readTVarIO senderId2 + Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ") + pure () + + logSize `shouldReturn` 1 + removeFile testStoreLogFile + where + createAndSecureQueue :: THandle -> SenderPublicKey -> IO (SenderId, RecipientId, C.SafePrivateKey) + createAndSecureQueue h sPub = do + (rPub, rKey) <- C.generateKeyPair rsaKeySize + Resp "abcd" "" (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) + let keyCmd = "KEY " <> C.serializePubKey sPub + Resp "dabc" rId' OK <- signSendRecv h rKey ("dabc", rId, keyCmd) + (rId', rId) #== "same queue ID" + pure (sId, rId, rKey) + + runTest :: (THandle -> IO ()) -> ThreadId -> Expectation + runTest test' server = do + testSMPClient test' `shouldReturn` () + killThread server + + logSize :: IO Int + logSize = + try (length . B.lines <$> B.readFile testStoreLogFile) >>= \case + Right l -> pure l + Left (_ :: SomeException) -> logSize + +testTiming :: Spec +testTiming = + it "should have similar time for auth error whether queue exists or not" $ + smpTest2 \rh sh -> do + (rPub, rKey) <- C.generateKeyPair rsaKeySize + Resp "abcd" "" (IDS rId sId) <- signSendRecv rh rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) + + (sPub, sKey) <- C.generateKeyPair rsaKeySize + let keyCmd = "KEY " <> C.serializePubKey sPub + Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, keyCmd) + + Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, "SEND 5 hello ") + + timeNoQueue <- timeRepeat 25 $ do + Resp "dabc" _ (ERR AUTH) <- signSendRecv sh sKey ("dabc", rId, "SEND 5 hello ") + return () + timeWrongKey <- timeRepeat 25 $ do + Resp "cdab" _ (ERR AUTH) <- signSendRecv sh rKey ("cdab", sId, "SEND 5 hello ") + return () + abs (timeNoQueue - timeWrongKey) / timeNoQueue < 0.15 `shouldBe` True + where + timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const + +samplePubKey :: ByteString +samplePubKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" + syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR SYNTAX 2") + it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") describe "NEW" do - it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR SYNTAX 2") - it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR SYNTAX 3") - it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR SYNTAX 4") + it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "cdab", "", "NEW 1 " <> samplePubKey) >#> ("", "cdab", "", "ERR CMD SYNTAX") + it "no signature" $ ("", "dabc", "", "NEW " <> samplePubKey) >#> ("", "dabc", "", "ERR CMD NO_AUTH") + it "queue ID" $ ("1234", "abcd", "12345678", "NEW " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") describe "KEY" do - it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR SYNTAX 3") + it "valid syntax" $ ("1234", "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "bcda", "12345678", "ERR AUTH") + it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "bcda", "", "KEY " <> samplePubKey) >#> ("", "bcda", "", "ERR CMD NO_AUTH") noParamsSyntaxTest "SUB" noParamsSyntaxTest "ACK" noParamsSyntaxTest "OFF" noParamsSyntaxTest "DEL" describe "SEND" do - it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5\r\nhello") >#> ("", "cdab", "12345678", "ERR AUTH") - it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11\r\nhello there") >#> ("", "dabc", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") - it "no queue ID" $ ("1234", "bcda", "", "SEND 5\r\nhello") >#> ("", "bcda", "", "ERR SYNTAX 5") - it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4\r\nhello") >#> ("", "abcd", "12345678", "ERR SIZE") + it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH") + it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH") + it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") + it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE") + it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") describe "PING" do it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG") describe "broker response not allowed" do - it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR PROHIBITED") + it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED") where noParamsSyntaxTest :: ByteString -> Spec noParamsSyntaxTest cmd = describe (B.unpack cmd) do it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH") - it "parameters" $ ("1234", "bcda", "12345678", cmd <> " 1") >#> ("", "bcda", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3") + it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH") diff --git a/tests/Test.hs b/tests/Test.hs index 06f495d99..988b734f6 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,5 +1,6 @@ import AgentTests import MarkdownTests +import ProtocolErrorTests import ServerTests import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import Test.Hspec @@ -9,6 +10,7 @@ main = do createDirectoryIfMissing False "tests/tmp" hspec $ do describe "SimpleX markdown" markdownTests + describe "Protocol errors" protocolErrorTests describe "SMP server" serverTests describe "SMP client agent" agentTests removeDirectoryRecursive "tests/tmp"