From 0fbf406800119313cec668b40c2ae9e7afb39992 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 5 Apr 2021 13:10:16 +0100 Subject: [PATCH 01/17] transport encryption (#65) * transport encryption (WIP - using fixed key, parsing/serialization works, SMP tests fail) * transport encryption * transport encryption: separate keys to receive and to send, counter-based IVs * docs: update transport encryption and handshake * transport encryption handshake (TODO: validate key hash, welcome block, move keys to system environment) * change KeyHash type to newtype of Digest SHA256 * transport encryption: validate public key hash * send and receive welcome block with SMP version * refactor: parsing SMPServer * remove unused function * verify that client version is compatible with server version (major version is not smaller) * update (fix) SMP server tests --- apps/dog-food/ChatOptions.hs | 4 +- apps/dog-food/Main.hs | 5 +- rfcs/2021-01-26-crypto.md | 65 ++++--- rfcs/2021-03-18-groups.md | 25 +++ src/Simplex/Messaging/Agent.hs | 6 +- src/Simplex/Messaging/Agent/Client.hs | 6 +- src/Simplex/Messaging/Agent/Transmission.hs | 28 +-- src/Simplex/Messaging/Client.hs | 28 +-- src/Simplex/Messaging/Crypto.hs | 123 ++++++++---- src/Simplex/Messaging/Parsers.hs | 5 +- src/Simplex/Messaging/Protocol.hs | 103 +++++----- src/Simplex/Messaging/Server.hs | 15 +- src/Simplex/Messaging/Server/Env/STM.hs | 28 ++- src/Simplex/Messaging/Transport.hs | 196 +++++++++++++++++++- tests/AgentTests.hs | 5 +- tests/AgentTests/SQLiteTests.hs | 11 +- tests/SMPClient.hs | 48 +++-- tests/ServerTests.hs | 50 ++--- 18 files changed, 532 insertions(+), 219 deletions(-) create mode 100644 rfcs/2021-03-18-groups.md diff --git a/apps/dog-food/ChatOptions.hs b/apps/dog-food/ChatOptions.hs index 0a7ff89f0..0940c34b8 100644 --- a/apps/dog-food/ChatOptions.hs +++ b/apps/dog-food/ChatOptions.hs @@ -2,10 +2,10 @@ module ChatOptions (getChatOpts, ChatOpts (..)) where -import qualified Data.Attoparsec.ByteString.Char8 as A import qualified Data.ByteString.Char8 as B import Options.Applicative import Simplex.Messaging.Agent.Transmission (SMPServer (..), smpServerP) +import Simplex.Messaging.Parsers (parseAll) import System.FilePath (combine) import System.Info (os) import Types @@ -58,7 +58,7 @@ chatOpts appDir = | otherwise = TermModeEditor parseSMPServer :: ReadM SMPServer -parseSMPServer = eitherReader $ A.parseOnly (smpServerP <* A.endOfInput) . B.pack +parseSMPServer = eitherReader $ parseAll smpServerP . B.pack parseTermMode :: ReadM TermMode parseTermMode = maybeReader $ \case diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index ac33b81e5..bc2c1c198 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -26,6 +26,7 @@ import Simplex.Messaging.Agent.Client (AgentClient (..)) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (smpDefaultConfig) +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util (bshow, raceAny_) import System.Directory (getAppUserDataDirectory) import System.Exit (exitFailure) @@ -124,7 +125,7 @@ main = do t <- getChatClient smpServer user ct <- newChatTerminal (tbqSize cfg) user termMode -- setLogLevel LogInfo -- LogError - -- withGlobalLogging logCfg $ + -- withGlobalLogging logCfg $ do env <- newSMPAgentEnv cfg {dbFile = dbFileName} dogFoodChat t ct env @@ -172,7 +173,7 @@ newChatClient qSize smpServer name = do receiveFromChatTerm :: ChatClient -> ChatTerminal -> IO () receiveFromChatTerm t ct = forever $ do atomically (readTBQueue $ inputQ ct) - >>= processOrError . A.parseOnly (chatCommandP <* A.endOfInput) + >>= processOrError . parseAll chatCommandP where processOrError = \case Left err -> atomically . writeTBQueue (outQ t) . ErrorInput $ B.pack err diff --git a/rfcs/2021-01-26-crypto.md b/rfcs/2021-01-26-crypto.md index e9ea8aa49..c7f5f2c9d 100644 --- a/rfcs/2021-01-26-crypto.md +++ b/rfcs/2021-01-26-crypto.md @@ -16,49 +16,46 @@ For initial implementation I propose approach to be as simple as possible as lon One of the consideration is to use [noise protocol framework](https://noiseprotocol.org/noise.html), this section describes ad hoc protocol though. -During TCP session both client and server should use symmetric AES 256 bit encryption using the session key that will be established during the handshake. +During TCP session both client and server should use symmetric AES 256 bit encryption using two session keys and two base IVs that will be agreed during the handshake. Both client and the server should maintain two 32-bit word counters, one for sent and one for the received messages. The IV for each message should be computed by xor-ing the sequential message counter, starting from 0, with the first 32 bits of agreed base IV. TODO - explain it in a more formal way, also document how 32-bit word is encoded - with the most or least significant byte first (currently encodeWord32 from Network.Transport.Internal is used) -To establish the session key, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) and additional server ID (256 bits) in advance in order to be able to establish connection. +To establish the session keys and base IVs, the server should have an asymmetric key pair generated during server deployment and unknown to the clients. The users should know the key hash (256 bits) in advance in order to be able to establish connection. -The handshake sequence could be the following: +The handshake sequence is the following: -1. Once the connection is established, the server sends its public key to the client -2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. -3. If the hash is the same, the client should generate a random symmetric AES key and IV that will be used as a session key both by the client and the server. -4. The client then should encrypt this symmetric key with the public key that the server sent and send back to the server the result and the server ID also shared with the client in advance: `rsa-encrypt(aes-key, iv, server-id)`. -5. The server should decrypt the received key, IV and server id with its private key. -6. The server should compare the `server-id` sent by the client and if it does not match its ID terminate the connection. -7. In case of successful decryption and matching server ID, the server should send encrypted welcome header. +1. Once the connection is established, the server sends its public 2048 bit key to the client. TODO currently the key will be sent as a line terminated with CRLF, using ad-hoc key serialization we use. +2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required. +3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server. +4. The client then should encrypt these symmetric keys and base IVs with the public key that the server sent, and send to the server the result of the encryption: `rsa-encrypt(snd-aes-key, snd-base-iv, rcv-aes-key, rcv-base-iv)`. `snd-aes-key` and `snd-base-iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv-aes-key` and `rcv-base-iv` will be used by the client to decrypt **received** messages and by the server to encrypt them. +5. The server should decrypt the received keys and base IVs with its private key. +6. In case of successful decryption, the server should send encrypted welcome block (encrypted_welcome_block) that contains SMP protocol version. -```abnf -aes_welcome_header = aes_header_auth_tag aes_encrypted_header -welcome_header = smp_version ["," smp_mode] *SP ; decrypt(aes_encrypted_header) - 32 bytes -smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format - ; for example: v123.456.789-alpha.7 -smp_mode = smp_public / smp_authenticated -smp_public = %s"pub" ; public (default) - no auth to create and manage queues -smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues -aes_header_auth_tag = aes_auth_tag -aes_auth_tag = 16*16(OCTET) -``` - -No payload should follow this header, it is only used to confirm successful handshake and send the SMP protocol version that the server supports. - -All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES key and IV sent by the client during the handshake. +All the subsequent data both from the client and from the server should be sent encrypted using symmetric AES keys and base IVs (incremented by counters on both sides) sent by the client during the handshake. Each transport block sent by the client and the server has this syntax: ```abnf -transport_block = aes_header_auth_tag aes_encrypted_header aes_body_auth_tag aes_encrypted_body -aes_encrypted_header = 32*32(OCTET) -header = padded_body_size payload_size reserved ; decrypt(aes_encrypted_header) - 32 bytes +transport_block = aes_body_auth_tag aes_encrypted_body ; fixed at 8192 bits aes_encrypted_body = 1*OCTET -body = payload pad -padded_body_size = size ; body size in bytes -payload_size = size ; payload_size in bytes -size = 4*4(OCTET) -reserved = 24*24(OCTET) -aes_body_auth_tag = aes_auth_tag +aes_body_auth_tag = 16*16(OCTET) + +encrypted_welcome_block = transport_block +welcome_block = smp_version SP pad ; decrypt(encrypted_welcome_block) +smp_version = %s"v" 1*DIGIT "." 1*DIGIT "." 1*DIGIT ["-" 1*ALPHA "." 1*DIGIT] ; in semver format + ; for example: v123.456.789-alpha.7 +pad = 1*OCTET +``` + +## Possible future improvements/changes + +- server id (256 bits), so that only the users that have it can connect to the server. This ID will have to be passed to the server during the handshake +- block size agreed during handshake +- transport encryption protocol agreed during handshake +- welcome block containing SMP mode (smp_mode) + +```abnf +smp_mode = smp_public / smp_authenticated +smp_public = %s"pub" ; public (default) - no auth to create and manage queues +smp_authenticated = %s"auth" ; server authentication with AUTH command (TBD) is required to create and manage queues ``` ## Initial handshake diff --git a/rfcs/2021-03-18-groups.md b/rfcs/2021-03-18-groups.md new file mode 100644 index 000000000..6da1fba69 --- /dev/null +++ b/rfcs/2021-03-18-groups.md @@ -0,0 +1,25 @@ +# SMP agent groups + +## Problems + +- device/user profile synchronisation +- chat group communication + +Both problems would require message broadcast between a group of SMP agents. + +## Solution: basic symmetric groups via SMP agent protocol + +Additional commands and message envelopes to SMP agent protocol to provide an abstraction layer for device synchronisation and chat groups. + +The groups are fully symmetric, all agent who are members of the group have equal rights and can join and leave group at any time. + +All the information about the groups is stored only in agents, the commands are used to synchronise the group state between the agents. + +```abnf +group_command = create_group / add_to_group / remove_from_group / leave_group +group_response = group_created / added_to_group / removed_from_group +group_notification = added_to_group_by / removed_from_group_by / left_group +create_group = %s"GNEW " group_name ; cAlias must be empty +add_to_group = %s"GADD " group_name ; cAlias is the connection to add to the group +added_to_group = %s"GADDED " name ; cAlias is the connection added to the group +``` \ No newline at end of file diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index eb9330dd3..163166ca1 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -35,7 +35,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (CorrId (..), MsgBody, SenderPublicKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (putLn, runTCPServer) -import Simplex.Messaging.Util (liftError) +import Simplex.Messaging.Util (bshow, liftError) import System.IO (Handle) import UnliftIO.Async (race_) import UnliftIO.Exception (SomeException) @@ -94,7 +94,7 @@ send h c@AgentClient {sndQ} = forever $ do logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m () logClient AgentClient {clientId} dir (CorrId corrId, cAlias, cmd) = do - logInfo . decodeUtf8 $ B.unwords [B.pack $ show clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd] + logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, cAlias, B.takeWhile (/= ' ') $ serializeCommand cmd] client :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () client c@AgentClient {rcvQ, sndQ} st = forever $ do @@ -278,7 +278,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do removeSubscription c connAlias logServer "<--" c srv rId "END" notify connAlias END - _ -> logServer "<--" c srv rId $ "unexpected:" <> (B.pack . show) cmd + _ -> logServer "<--" c srv rId $ "unexpected:" <> bshow cmd where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 1226c36be..a8dba40fa 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -47,7 +47,7 @@ import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey) -import Simplex.Messaging.Util (liftError) +import Simplex.Messaging.Util (bshow, liftError) import UnliftIO.Concurrent import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E @@ -133,7 +133,7 @@ withSMP c srv action = logServerError :: AgentErrorType -> m a logServerError e = do - logServer "<--" c srv "" $ (B.pack . show) e + logServer "<--" c srv "" $ bshow e throwError e withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a @@ -196,7 +196,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = - logInfo . decodeUtf8 $ B.unwords ["A", "(" <> (B.pack . show) clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] + logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] showServer :: SMPServer -> ByteString showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 207809276..af4f7fba2 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -152,7 +152,7 @@ serializeSMPMessage = \case in smpMessage "" header body where messageHeader msgId ts prevMsgHash = - B.unwords [B.pack $ show msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash] + B.unwords [bshow msgId, B.pack $ formatISO8601Millis ts, encode prevMsgHash] smpMessage smpHeader aHeader aBody = B.intercalate "\n" [smpHeader, aHeader, aBody, ""] agentMessageP :: Parser AMessage @@ -173,11 +173,17 @@ smpQueueInfoP = "smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP) smpServerP :: Parser SMPServer -smpServerP = SMPServer <$> server <*> port <*> msgHash +smpServerP = SMPServer <$> server <*> port <*> kHash where server = B.unpack <$> A.takeTill (A.inClass ":# ") - port = A.char ':' *> (Just . show <$> (A.decimal :: Parser Int)) <|> pure Nothing - msgHash = A.char '#' *> (Just <$> base64P) <|> pure Nothing + port = fromChar ':' $ show <$> (A.decimal :: Parser Int) + kHash = fromChar '#' C.keyHashP + fromChar :: Char -> Parser a -> Parser (Maybe a) + fromChar ch parser = do + c <- A.peekChar + if c == Just ch + then A.char ch *> (Just <$> parser) + else pure Nothing parseAgentMessage :: ByteString -> Either AgentErrorType AMessage parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage @@ -194,17 +200,15 @@ serializeSmpQueueInfo (SMPQueueInfo srv qId ek) = serializeServer :: SMPServer -> ByteString serializeServer SMPServer {host, port, keyHash} = - B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack) keyHash + B.pack $ host <> maybe "" (':' :) port <> maybe "" (('#' :) . B.unpack . C.serializeKeyHash) keyHash data SMPServer = SMPServer { host :: HostName, port :: Maybe ServiceName, - keyHash :: Maybe KeyHash + keyHash :: Maybe C.KeyHash } deriving (Eq, Ord, Show) -type KeyHash = Encoded - type ConnAlias = ByteString type OtherPartyId = Encoded @@ -354,7 +358,7 @@ serializeCommand = \case OFF -> "OFF" DEL -> "DEL" CON -> "CON" - ERR e -> "ERR " <> B.pack (show e) + ERR e -> "ERR " <> bshow e OK -> "OK" where replyMode :: ReplyMode -> ByteString @@ -370,13 +374,13 @@ serializeCommand = \case MsgError e -> "ERR" <> case e of MsgSkipped fromMsgId toMsgId -> - B.unwords ["NO_ID", B.pack $ show fromMsgId, B.pack $ show toMsgId] - MsgBadId aMsgId -> "ID " <> B.pack (show aMsgId) + B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] + MsgBadId aMsgId -> "ID " <> bshow aMsgId MsgBadHash -> "HASH" -- TODO - save function as in the server Transmission - re-use? serializeMsg :: ByteString -> ByteString -serializeMsg body = B.pack (show $ B.length body) <> "\n" <> body +serializeMsg body = bshow (B.length body) <> "\n" <> body tPutRaw :: Handle -> ARawTransmission -> IO () tPutRaw h (corrId, connAlias, command) = do diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index a4737e399..757b42305 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -35,7 +36,6 @@ import Control.Monad.Trans.Class import Control.Monad.Trans.Except import qualified Crypto.PubKey.RSA.Types as RSA import Data.ByteString.Char8 (ByteString) -import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe @@ -46,7 +46,7 @@ import Simplex.Messaging.Agent.Transmission (SMPServer (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Util (liftEitherError, raceAny_) +import Simplex.Messaging.Util (bshow, liftEitherError, raceAny_) import System.IO import System.IO.Error import System.Timeout @@ -91,7 +91,7 @@ data Request = Request getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient getSMPClient - smpServer@SMPServer {host, port} + smpServer@SMPServer {host, port, keyHash} SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing} msgQ disconnected = do @@ -103,7 +103,7 @@ getSMPClient `finally` atomically (putTMVar started False) tcpTimeout `timeout` atomically (takeTMVar started) >>= \case Just True -> return c {action} - _ -> throwIO err + _ -> throwIO err -- TODO report handshake error too, not only connection timeout where err :: IOException err = mkIOError TimeExpired "connection timeout" Nothing Nothing @@ -128,18 +128,24 @@ getSMPClient } client :: SMPClient -> TMVar Bool -> Handle -> IO () - client c started h = do - _ <- getLn h -- "Welcome to SMP" + client c started h = + runExceptT (clientHandshake h keyHash) >>= \case + Right th -> clientTransport c started th + -- TODO report error instead of True/False + Left _ -> atomically $ putTMVar started False + + clientTransport :: SMPClient -> TMVar Bool -> THandle -> IO () + clientTransport c started th = do atomically $ do - modifyTVar (connected c) (const True) + writeTVar (connected c) True putTMVar started True - raceAny_ [send c h, process c, receive c h, ping c] + raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected - send :: SMPClient -> Handle -> IO () + send :: SMPClient -> THandle -> IO () send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h - receive :: SMPClient -> Handle -> IO () + receive :: SMPClient -> THandle -> IO () receive SMPClient {rcvQ} h = forever $ tGet fromServer h >>= atomically . writeTBQueue rcvQ ping :: SMPClient -> IO () @@ -241,7 +247,7 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do getNextCorrId = do i <- (+ 1) <$> readTVar clientCorrId writeTVar clientCorrId i - return . CorrId . B.pack $ show i + return . CorrId $ bshow i signTransmission :: ByteString -> ExceptT SMPClientError IO SignedRawTransmission signTransmission t = case pKey of diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 719056eae..f2aeb216c 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -10,17 +10,33 @@ module Simplex.Messaging.Crypto PublicKey (..), Signature (..), CryptoError (..), + KeyPair, + Key (..), + IV (..), + KeyHash (..), generateKeyPair, + publicKeySize, sign, verify, encrypt, decrypt, + encryptOAEP, + decryptOAEP, + encryptAES, + decryptAES, serializePrivKey, serializePubKey, - parsePrivKey, - parsePubKey, + serializeKeyHash, privKeyP, pubKeyP, + keyHashP, + authTagSize, + authTagToBS, + bsToAuthTag, + randomAesKey, + randomIV, + aesKeyP, + ivP, ) where @@ -30,7 +46,7 @@ import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import qualified Crypto.Cipher.Types as AES import qualified Crypto.Error as CE -import Crypto.Hash.Algorithms (SHA256 (..)) +import Crypto.Hash (Digest, SHA256 (..), digestFromByteString) import Crypto.Number.Generate (generateMax) import Crypto.Number.Prime (findPrimeFrom) import Crypto.Number.Serialize (i2osp, os2ip) @@ -53,7 +69,7 @@ import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) -import Simplex.Messaging.Parsers (base64P) +import Simplex.Messaging.Parsers (base64P, parseAll) import Simplex.Messaging.Util (bshow, liftEitherError, (<$$>)) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) @@ -71,14 +87,14 @@ instance ToField PublicKey where toField = toField . serializePubKey instance FromField PrivateKey where fromField f@(Field (SQLBlob b) _) = - case parsePrivKey b of + case parseAll privKeyP b of Right k -> Ok k Left e -> returnError ConversionFailed f ("couldn't parse PrivateKey field: " ++ e) fromField f = returnError ConversionFailed f "expecting SQLBlob column type" instance FromField PublicKey where fromField f@(Field (SQLBlob b) _) = - case parsePubKey b of + case parseAll pubKeyP b of Right k -> Ok k Left e -> returnError ConversionFailed f ("couldn't parse PublicKey field: " ++ e) fromField f = returnError ConversionFailed f "expecting SQLBlob column type" @@ -124,6 +140,9 @@ generateKeyPair size = loop then loop else return (PublicKey pub, privateKey s n d) +publicKeySize :: PublicKey -> Int +publicKeySize = R.public_size . rsaPublicKey + data Header = Header { aesKey :: Key, ivBytes :: IV, @@ -135,52 +154,83 @@ newtype Key = Key {unKey :: ByteString} newtype IV = IV {unIV :: ByteString} +newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show) + +instance IsString KeyHash where + fromString = either error id . parseAll keyHashP . fromString + +instance ToField KeyHash where toField = toField . serializeKeyHash + +instance FromField KeyHash where + fromField f@(Field (SQLBlob b) _) = + case parseAll keyHashP b of + Right k -> Ok k + Left e -> returnError ConversionFailed f ("couldn't parse KeyHash field: " ++ e) + fromField f = returnError ConversionFailed f "expecting SQLBlob column type" + +serializeKeyHash :: KeyHash -> ByteString +serializeKeyHash = encode . BA.convert . unKeyHash + +keyHashP :: Parser KeyHash +keyHashP = do + bs <- base64P + case digestFromByteString bs of + Just d -> pure $ KeyHash d + _ -> fail "invalid digest" + serializeHeader :: Header -> ByteString serializeHeader Header {aesKey, ivBytes, authTag, msgSize} = unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize headerP :: Parser Header headerP = do - aesKey <- Key <$> A.take aesKeySize - ivBytes <- IV <$> A.take (ivSize @AES256) + aesKey <- aesKeyP + ivBytes <- ivP authTag <- bsToAuthTag <$> A.take authTagSize msgSize <- fromIntegral . decodeWord32 <$> A.take 4 return Header {aesKey, ivBytes, authTag, msgSize} +aesKeyP :: Parser Key +aesKeyP = Key <$> A.take aesKeySize + +ivP :: Parser IV +ivP = IV <$> A.take (ivSize @AES256) + parseHeader :: ByteString -> Either CryptoError Header -parseHeader = first CryptoHeaderError . A.parseOnly (headerP <* A.endOfInput) +parseHeader = first CryptoHeaderError . parseAll headerP encrypt :: PublicKey -> Int -> ByteString -> ExceptT CryptoError IO ByteString encrypt k paddedSize msg = do - aesKey <- Key <$> randomBytes aesKeySize - ivBytes <- IV <$> randomBytes (ivSize @AES256) + aesKey <- liftIO randomAesKey + ivBytes <- liftIO randomIV + (authTag, msg') <- encryptAES aesKey ivBytes paddedSize msg + let header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg} + encHeader <- encryptOAEP k $ serializeHeader header + return $ encHeader <> msg' + +decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString +decrypt pk msg'' = do + let (encHeader, msg') = B.splitAt (private_size pk) msg'' + header <- decryptOAEP pk encHeader + Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header + msg <- decryptAES aesKey ivBytes msg' authTag + return $ B.take msgSize msg + +encryptAES :: Key -> IV -> Int -> ByteString -> ExceptT CryptoError IO (AES.AuthTag, ByteString) +encryptAES aesKey ivBytes paddedSize msg = do aead <- initAEAD @AES256 aesKey ivBytes msg' <- paddedMsg - let (authTag, msg'') = encryptAES aead msg' - header = Header {aesKey, ivBytes, authTag, msgSize = B.length msg} - encHeader <- encryptOAEP k $ serializeHeader header - return $ encHeader <> msg'' + return $ AES.aeadSimpleEncrypt aead B.empty msg' authTagSize where len = B.length msg paddedMsg | len >= paddedSize = throwE CryptoLargeMsgError | otherwise = return (msg <> B.replicate (paddedSize - len) '#') -decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString -decrypt pk msg'' = do - let (encHeader, msg') = B.splitAt (private_size pk) msg'' - header <- decryptOAEP pk encHeader - Header {aesKey, ivBytes, authTag, msgSize} <- ExceptT . return $ parseHeader header +decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString +decryptAES aesKey ivBytes msg authTag = do aead <- initAEAD @AES256 aesKey ivBytes - msg <- decryptAES aead msg' authTag - return $ B.take msgSize msg - -encryptAES :: AES.AEAD AES256 -> ByteString -> (AES.AuthTag, ByteString) -encryptAES aead plaintext = AES.aeadSimpleEncrypt aead B.empty plaintext authTagSize - -decryptAES :: AES.AEAD AES256 -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString -decryptAES aead ciphertext authTag = - maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty ciphertext authTag + maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c) initAEAD (Key aesKey) (IV ivBytes) = do @@ -189,15 +239,18 @@ initAEAD (Key aesKey) (IV ivBytes) = do cipher <- AES.cipherInit aesKey AES.aeadInit AES.AEAD_GCM cipher iv +randomAesKey :: IO Key +randomAesKey = Key <$> getRandomBytes aesKeySize + +randomIV :: IO IV +randomIV = IV <$> getRandomBytes (ivSize @AES256) + ivSize :: forall c. AES.BlockCipher c => Int ivSize = AES.blockSize (undefined :: c) makeIV :: AES.BlockCipher c => ByteString -> ExceptT CryptoError IO (AES.IV c) makeIV bs = maybeError CryptoIVError $ AES.makeIV bs -randomBytes :: Int -> ExceptT CryptoError IO ByteString -randomBytes n = ExceptT $ Right <$> getRandomBytes n - maybeError :: CryptoError -> Maybe a -> ExceptT CryptoError IO a maybeError e = maybe (throwE e) return @@ -253,12 +306,6 @@ privKeyP = do (private_size, private_n, private_d) <- keyParser_ return PrivateKey {private_size, private_n, private_d} -parsePubKey :: ByteString -> Either String PublicKey -parsePubKey = A.parseOnly (pubKeyP <* A.endOfInput) - -parsePrivKey :: ByteString -> Either String PrivateKey -parsePrivKey = A.parseOnly (privKeyP <* A.endOfInput) - keyParser_ :: Parser (Int, Integer, Integer) keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP where diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 461c38c5e..85cb2f3a6 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -20,4 +20,7 @@ tsISO8601P :: Parser UTCTime tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ') parse :: Parser a -> e -> (ByteString -> Either e a) -parse parser err = first (const err) . A.parseOnly (parser <* A.endOfInput) +parse parser err = first (const err) . parseAll parser + +parseAll :: Parser a -> (ByteString -> Either String a) +parseAll parser = A.parseOnly (parser <* A.endOfInput) diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 94f49e24e..ad73e751a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -13,7 +13,7 @@ module Simplex.Messaging.Protocol where import Control.Applicative ((<|>)) import Control.Monad -import Control.Monad.IO.Class +import Control.Monad.Except import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Base64 @@ -28,8 +28,6 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Transport import Simplex.Messaging.Util -import System.IO -import Text.Read data Party = Broker | Recipient | Sender deriving (Show) @@ -108,7 +106,7 @@ type MsgId = Encoded type MsgBody = ByteString -data ErrorType = PROHIBITED | SYNTAX Int | SIZE | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) +data ErrorType = PROHIBITED | SYNTAX Int | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) errBadTransmission :: Int errBadTransmission = 1 @@ -128,6 +126,16 @@ errNoQueueId = 5 errMessageBody :: Int errMessageBody = 6 +transmissionP :: Parser RawTransmission +transmissionP = do + signature <- segment + corrId <- segment + queueId <- segment + command <- A.takeByteString + return (signature, corrId, queueId, command) + where + segment = A.takeTill (== ' ') <* " " + commandP :: Parser Cmd commandP = "NEW " *> newCmd @@ -148,60 +156,47 @@ commandP = newCmd = Cmd SRecipient . NEW <$> C.pubKeyP idsResp = Cmd SBroker <$> (IDS <$> (base64P <* A.space) <*> base64P) keyCmd = Cmd SRecipient . KEY <$> C.pubKeyP - sendCmd = Cmd SSender . SEND <$> A.takeWhile A.isDigit + sendCmd = do + size <- A.decimal <* A.space + Cmd SSender . SEND <$> A.take size <* A.space message = do msgId <- base64P <* A.space ts <- tsISO8601P <* A.space - Cmd SBroker . MSG msgId ts <$> A.takeWhile A.isDigit + size <- A.decimal <* A.space + Cmd SBroker . MSG msgId ts <$> A.take size <* A.space serverError = Cmd SBroker . ERR <$> errorType errorType = "PROHIBITED" $> PROHIBITED <|> "SYNTAX " *> (SYNTAX <$> A.decimal) - <|> "SIZE" $> SIZE <|> "AUTH" $> AUTH <|> "INTERNAL" $> INTERNAL +-- TODO ignore the end of block, no need to parse it parseCommand :: ByteString -> Either ErrorType Cmd -parseCommand = parse commandP $ SYNTAX errBadSMPCommand +parseCommand = parse (commandP <* " " <* A.takeByteString) $ SYNTAX errBadSMPCommand serializeCommand :: Cmd -> ByteString serializeCommand = \case Cmd SRecipient (NEW rKey) -> "NEW " <> C.serializePubKey rKey Cmd SRecipient (KEY sKey) -> "KEY " <> C.serializePubKey sKey - Cmd SRecipient cmd -> B.pack $ show cmd - Cmd SSender (SEND msgBody) -> "SEND" <> serializeMsg msgBody + Cmd SRecipient cmd -> bshow cmd + Cmd SSender (SEND msgBody) -> "SEND " <> serializeMsg msgBody Cmd SSender PING -> "PING" Cmd SBroker (MSG msgId ts msgBody) -> - B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts] <> serializeMsg msgBody + B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody] Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId] - Cmd SBroker (ERR err) -> "ERR " <> B.pack (show err) - Cmd SBroker resp -> B.pack $ show resp + Cmd SBroker (ERR err) -> "ERR " <> bshow err + Cmd SBroker resp -> bshow resp where - serializeMsg msgBody = " " <> B.pack (show $ B.length msgBody) <> "\r\n" <> msgBody + serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " " -tPutRaw :: Handle -> RawTransmission -> IO () -tPutRaw h (signature, corrId, queueId, command) = do - putLn h signature - putLn h corrId - putLn h queueId - putLn h command - -tGetRaw :: Handle -> IO RawTransmission -tGetRaw h = do - signature <- getLn h - corrId <- getLn h - queueId <- getLn h - command <- getLn h - return (signature, corrId, queueId, command) - -tPut :: Handle -> SignedRawTransmission -> IO () -tPut h (C.Signature sig, t) = do - putLn h $ encode sig - putLn h t +tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ()) +tPut th (C.Signature sig, t) = + tPutEncrypted th $ encode sig <> " " <> t <> " " serializeTransmission :: Transmission -> ByteString serializeTransmission (CorrId corrId, queueId, command) = - B.intercalate "\r\n" [corrId, encode queueId, serializeCommand command] + B.intercalate " " [corrId, encode queueId, serializeCommand command] fromClient :: Cmd -> Either ErrorType Cmd fromClient = \case @@ -213,22 +208,28 @@ fromServer = \case cmd@(Cmd SBroker _) -> Right cmd _ -> Left PROHIBITED +tGetParse :: THandle -> IO (Either TransportError RawTransmission) +tGetParse th = (>>= parse transmissionP TransportParsingError) <$> tGetEncrypted th + -- | get client and server transmissions -- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used -tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> Handle -> m SignedTransmissionOrError -tGet fromParty h = do - (signature, corrId, queueId, command) <- liftIO $ tGetRaw h - let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId) - either (const $ tError corrId) tParseLoadBody decodedTransmission +tGet :: forall m. MonadIO m => (Cmd -> Either ErrorType Cmd) -> THandle -> m SignedTransmissionOrError +tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate where + decodeParseValidate :: Either TransportError RawTransmission -> m SignedTransmissionOrError + decodeParseValidate = \case + Right (signature, corrId, queueId, command) -> + let decodedTransmission = liftM2 (,corrId,,command) (decode signature) (decode queueId) + in either (const $ tError corrId) tParseValidate decodedTransmission + Left _ -> tError "" + tError :: ByteString -> m SignedTransmissionOrError tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission)) - tParseLoadBody :: RawTransmission -> m SignedTransmissionOrError - tParseLoadBody t@(sig, corrId, queueId, command) = do + tParseValidate :: RawTransmission -> m SignedTransmissionOrError + tParseValidate t@(sig, corrId, queueId, command) = do let cmd = parseCommand command >>= fromParty >>= tCredentials t - fullCmd <- either (return . Left) cmdWithMsgBody cmd - return (C.Signature sig, (CorrId corrId, queueId, fullCmd)) + return (C.Signature sig, (CorrId corrId, queueId, cmd)) tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd tCredentials (signature, _, queueId, _) cmd = case cmd of @@ -261,19 +262,3 @@ tGet fromParty h = do Cmd SRecipient _ | B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials | otherwise -> Right cmd - - cmdWithMsgBody :: Cmd -> m (Either ErrorType Cmd) - cmdWithMsgBody = \case - Cmd SSender (SEND sizeStr) -> - Cmd SSender . SEND <$$> getMsgBody sizeStr - Cmd SBroker (MSG msgId ts sizeStr) -> - Cmd SBroker . MSG msgId ts <$$> getMsgBody sizeStr - cmd -> return $ Right cmd - - getMsgBody :: MsgBody -> m (Either ErrorType MsgBody) - getMsgBody sizeStr = case readMaybe (B.unpack sizeStr) :: Maybe Int of - Just size -> liftIO $ do - body <- B.hGet h size - s <- getLn h - return $ if B.null s then Right body else Left SIZE - Nothing -> return $ Left INTERNAL diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index a76a916e9..130ff2190 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -15,6 +15,7 @@ module Simplex.Messaging.Server (runSMPServer, randomBytes) where import Control.Concurrent.STM (stateTVar) import Control.Monad +import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader import Crypto.Random @@ -59,11 +60,17 @@ runSMPServer cfg@ServerConfig {tcpPort} = do runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do - liftIO $ putLn h "Welcome to SMP v0.2.0" + keyPair <- asks serverKeyPair + liftIO (runExceptT $ serverHandshake h keyPair) >>= \case + Right th -> runClientTransport th + Left _ -> pure () + +runClientTransport :: (MonadUnliftIO m, MonadReader Env m) => THandle -> m () +runClientTransport th = do q <- asks $ tbqSize . config c <- atomically $ newClient q s <- asks server - raceAny_ [send h c, client c s, receive h c] + raceAny_ [send th c, client c s, receive th c] `finally` cancelSubscribers c cancelSubscribers :: MonadUnliftIO m => Client -> m () @@ -75,7 +82,7 @@ cancelSub = \case Sub {subThread = SubThread t} -> killThread t _ -> return () -receive :: (MonadUnliftIO m, MonadReader Env m) => Handle -> Client -> m () +receive :: (MonadUnliftIO m, MonadReader Env m) => THandle -> Client -> m () receive h Client {rcvQ} = forever $ do (signature, (corrId, queueId, cmdOrError)) <- tGet fromClient h t <- case cmdOrError of @@ -83,7 +90,7 @@ receive h Client {rcvQ} = forever $ do Right cmd -> verifyTransmission (signature, (corrId, queueId, cmd)) atomically $ writeTBQueue rcvQ t -send :: MonadUnliftIO m => Handle -> Client -> m () +send :: MonadUnliftIO m => THandle -> Client -> m () send h Client {sndQ} = forever $ do t <- atomically $ readTBQueue sndQ liftIO $ tPut h ("", serializeTransmission t) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index c2d170e05..126f00238 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -5,11 +5,13 @@ module Simplex.Messaging.Server.Env.STM where import Control.Concurrent (ThreadId) import Control.Monad.IO.Unlift +import qualified Crypto.PubKey.RSA as R import Crypto.Random import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Network.Socket (ServiceName) import Numeric.Natural +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.STM import Simplex.Messaging.Server.QueueStore.STM @@ -27,7 +29,9 @@ data Env = Env server :: Server, queueStore :: QueueStore, msgStore :: STMMsgStore, - idsDrg :: TVar ChaChaDRG + idsDrg :: TVar ChaChaDRG, + serverKeyPair :: C.KeyPair + -- serverId :: ByteString } data Server = Server @@ -72,4 +76,24 @@ newEnv config = do queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore idsDrg <- drgNew >>= newTVarIO - return Env {config, server, queueStore, msgStore, idsDrg} + -- TODO these keys should be set in the environment, not in the code + return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair} + where + serverKeyPair = + ( C.PublicKey + { rsaPublicKey = + R.PublicKey + { public_size = 256, + public_n = 24491401566218566997383105010202223087300892576089255259580984651333137614713737618097624532507176450266480395052797332730303098565954279378701980313049999952643146946493842983667770915603693980339519205455913124235423278419181501399080069195664300809453039371169996023512911587381435574254546266774756319955237750224266282550919563293672568339958353047135257914364920805066749904289452712976534358633568668875150094910205741579097517675339029147403213185924413178887675432745168542469043448659751499651038006514754218441022754807971535895895877162103157702709155894482782232155817331812261258282431796597840952464257, + public_e = 8750208418393523480444709183090020123776537336553019181250117771363000810675051423462439348759073000328325050011503730211252469588880505946970399702607609166796825215104414212088697348613726705621594590369250976359268097976909710311654938358716518878047036682173044667792903503207106314854901036618348367397 + } + }, + C.PrivateKey + { private_size = 256, + private_n = 24491401566218566997383105010202223087300892576089255259580984651333137614713737618097624532507176450266480395052797332730303098565954279378701980313049999952643146946493842983667770915603693980339519205455913124235423278419181501399080069195664300809453039371169996023512911587381435574254546266774756319955237750224266282550919563293672568339958353047135257914364920805066749904289452712976534358633568668875150094910205741579097517675339029147403213185924413178887675432745168542469043448659751499651038006514754218441022754807971535895895877162103157702709155894482782232155817331812261258282431796597840952464257, + private_d = 7597313014691047671352664508683652467940113991200105893460705315744177757772923044415828427601194535604492873282390112577565179730319668643740113323630387082584239892956534048712048059175569855278723311295064858148623887611800385925820852572241607131360121661598015161261779381845187797044113149447495567589968956065009916550602209418325870594974390014927949966324558614396231902374868077411836997835082564279358230227298823445650053370542685308691044175390251929540772677009245507450972026595993054141350350385685400540681305852935721245601287301749047921282924410369389293829570448007237832101875085500166095784749 + } + ) + +-- public key hash: +-- "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index d0e2aa0e9..994366b02 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,21 +1,37 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Transport where -import Control.Monad.IO.Class +import Control.Monad.Except import Control.Monad.IO.Unlift -import Control.Monad.Reader +import Control.Monad.Trans.Except (throwE) +import Crypto.Cipher.Types (AuthTag) +import Crypto.Hash (hash) +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first) +import Data.ByteArray (xor) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Set (Set) import qualified Data.Set as S +import Data.Word (Word32) import GHC.IO.Exception (IOErrorType (..)) +import GHC.IO.Handle.Internals (ioe_EOF) import Network.Socket +import Network.Transport.Internal (encodeWord32) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Parsers (parse, parseAll) +import Simplex.Messaging.Util (bshow, liftError) import System.IO import System.IO.Error import UnliftIO.Concurrent @@ -24,6 +40,8 @@ import qualified UnliftIO.Exception as E import qualified UnliftIO.IO as IO import UnliftIO.STM +-- * TCP transport + runTCPServer :: MonadUnliftIO m => ServiceName -> (Handle -> m ()) -> m () runTCPServer port server = do clients <- newTVarIO S.empty @@ -96,3 +114,177 @@ getLn h = trim_cr <$> B.hGetLine h where trim_cr "" = "" trim_cr s = if B.last s == '\r' then B.init s else s + +-- * Encrypted transport + +data SMPVersion = SMPVersion Int Int Int Int + deriving (Eq, Ord) + +major :: SMPVersion -> (Int, Int) +major (SMPVersion a b _ _) = (a, b) + +currentSMPVersion :: SMPVersion +currentSMPVersion = SMPVersion 0 2 0 0 + +serializeSMPVersion :: SMPVersion -> ByteString +serializeSMPVersion (SMPVersion a b c d) = B.intercalate "." [bshow a, bshow b, bshow c, bshow d] + +smpVersionP :: Parser SMPVersion +smpVersionP = + let ver = A.decimal <* A.char '.' + in SMPVersion <$> ver <*> ver <*> ver <*> A.decimal + +data THandle = THandle + { handle :: Handle, + sndKey :: SessionKey, + rcvKey :: SessionKey, + blockSize :: Int + } + +data SessionKey = SessionKey + { aesKey :: C.Key, + baseIV :: C.IV, + counter :: TVar Word32 + } + +data HandshakeKeys = HandshakeKeys + { sndKey :: SessionKey, + rcvKey :: SessionKey + } + +data TransportError + = TransportCryptoError C.CryptoError + | TransportParsingError + | TransportHandshakeError String + deriving (Eq, Show, Exception) + +tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ()) +tPutEncrypted THandle {handle = h, sndKey, blockSize} block = + encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case + Left e -> return . Left $ TransportCryptoError e + Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg) + +tGetEncrypted :: THandle -> IO (Either TransportError ByteString) +tGetEncrypted THandle {handle = h, rcvKey, blockSize} = + B.hGet h blockSize >>= decryptBlock rcvKey >>= \case + Left e -> pure . Left $ TransportCryptoError e + Right "" -> ioe_EOF + Right msg -> pure $ Right msg + +encryptBlock :: SessionKey -> Int -> ByteString -> IO (Either C.CryptoError (AuthTag, ByteString)) +encryptBlock k@SessionKey {aesKey} size block = do + ivBytes <- makeNextIV k + runExceptT $ C.encryptAES aesKey ivBytes size block + +decryptBlock :: SessionKey -> ByteString -> IO (Either C.CryptoError ByteString) +decryptBlock k@SessionKey {aesKey} block = do + let (authTag, msg') = B.splitAt C.authTagSize block + ivBytes <- makeNextIV k + runExceptT $ C.decryptAES aesKey ivBytes msg' (C.bsToAuthTag authTag) + +makeNextIV :: SessionKey -> IO C.IV +makeNextIV SessionKey {baseIV, counter} = atomically $ do + c <- readTVar counter + writeTVar counter $ c + 1 + pure $ iv c + where + (start, rest) = B.splitAt 4 $ C.unIV baseIV + iv c = C.IV $ (start `xor` encodeWord32 c) <> rest + +-- | implements server transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption +-- The numbers in function names refer to the steps in the document +serverHandshake :: Handle -> C.KeyPair -> ExceptT TransportError IO THandle +serverHandshake h (k, pk) = do + liftIO sendPublicKey_1 + encryptedKeys <- receiveEncryptedKeys_4 + HandshakeKeys {sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys + th <- liftIO $ transportHandle h rcvKey sndKey -- keys are swapped here + sendWelcome_6 th + pure th + where + sendPublicKey_1 :: IO () + sendPublicKey_1 = putLn h $ C.serializePubKey k + receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString + receiveEncryptedKeys_4 = + liftIO (B.hGet h $ C.publicKeySize k) >>= \case + "" -> throwE $ TransportHandshakeError "EOF" + ks -> pure ks + decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO HandshakeKeys + decryptParseKeys_5 encKeys = + liftError TransportCryptoError (C.decryptOAEP pk encKeys) + >>= liftEither . parseHandshakeKeys + sendWelcome_6 :: THandle -> ExceptT TransportError IO () + sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " " + +-- | implements client transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption +-- The numbers in function names refer to the steps in the document +clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle +clientHandshake h keyHash = do + k <- getPublicKey_1_2 + keys@HandshakeKeys {sndKey, rcvKey} <- liftIO generateKeys_3 + sendEncryptedKeys_4 k keys + th <- liftIO $ transportHandle h sndKey rcvKey + getWelcome_6 th >>= checkVersion + pure th + where + getPublicKey_1_2 :: ExceptT TransportError IO C.PublicKey + getPublicKey_1_2 = do + s <- liftIO $ getLn h + maybe (pure ()) (validateKeyHash_2 s) keyHash + liftEither $ parseKey s + parseKey :: ByteString -> Either TransportError C.PublicKey + parseKey = first TransportHandshakeError . parseAll C.pubKeyP + validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () + validateKeyHash_2 k (C.KeyHash kHash) + | hash k == kHash = pure () + | otherwise = throwE $ TransportHandshakeError "wrong key hash" + generateKeys_3 :: IO HandshakeKeys + generateKeys_3 = HandshakeKeys <$> generateKey <*> generateKey + generateKey :: IO SessionKey + generateKey = do + aesKey <- C.randomAesKey + baseIV <- C.randomIV + pure SessionKey {aesKey, baseIV, counter = undefined} + sendEncryptedKeys_4 :: C.PublicKey -> HandshakeKeys -> ExceptT TransportError IO () + sendEncryptedKeys_4 k keys = + liftError TransportCryptoError (C.encryptOAEP k $ serializeHandshakeKeys keys) + >>= liftIO . B.hPut h + getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion + getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th + parseSMPVersion :: ByteString -> Either TransportError SMPVersion + parseSMPVersion = first TransportHandshakeError . A.parseOnly (smpVersionP <* A.space) + checkVersion :: SMPVersion -> ExceptT TransportError IO () + checkVersion smpVersion = + when (major smpVersion > major currentSMPVersion) . throwE $ + TransportHandshakeError "SMP server version" + +serializeHandshakeKeys :: HandshakeKeys -> ByteString +serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = + serializeKey sndKey <> serializeKey rcvKey + where + serializeKey :: SessionKey -> ByteString + serializeKey SessionKey {aesKey, baseIV} = C.unKey aesKey <> C.unIV baseIV + +handshakeKeysP :: Parser HandshakeKeys +handshakeKeysP = HandshakeKeys <$> keyP <*> keyP + where + keyP :: Parser SessionKey + keyP = do + aesKey <- C.aesKeyP + baseIV <- C.ivP + pure SessionKey {aesKey, baseIV, counter = undefined} + +parseHandshakeKeys :: ByteString -> Either TransportError HandshakeKeys +parseHandshakeKeys = parse handshakeKeysP $ TransportHandshakeError "parsing keys" + +transportHandle :: Handle -> SessionKey -> SessionKey -> IO THandle +transportHandle h sk rk = do + sndCounter <- newTVarIO 0 + rcvCounter <- newTVarIO 0 + pure + THandle + { handle = h, + sndKey = sk {counter = sndCounter}, + rcvKey = rk {counter = rcvCounter}, + blockSize = 8192 + } diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index c3b7b0a5c..21fa7c063 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -13,6 +13,7 @@ import Control.Concurrent import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient +import SMPClient (teshKeyHashStr) import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import System.IO (Handle) @@ -139,8 +140,8 @@ syntaxTests = do -- TODO: add tests with defined connection alias xit "only server" $ ("211", "", "NEW localhost") >#>= \case ("211", "", "INV" : _) -> True; _ -> False it "with port" $ ("212", "", "NEW localhost:5000") >#>= \case ("212", "", "INV" : _) -> True; _ -> False - xit "with keyHash" $ ("213", "", "NEW localhost#1234") >#>= \case ("213", "", "INV" : _) -> True; _ -> False - it "with port and keyHash" $ ("214", "", "NEW localhost:5000#1234") >#>= \case ("214", "", "INV" : _) -> True; _ -> False + xit "with keyHash" $ ("213", "", "NEW localhost#" <> teshKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False + it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> teshKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index fe6593c88..38b759112 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -13,6 +13,7 @@ import Data.Time import Data.Word (Word32) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.QQ (sql) +import SMPClient (teshKeyHash) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Transmission @@ -96,7 +97,7 @@ testForeignKeysEnabled = do rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, rcvId = "1234", connAlias = "conn1", rcvPrivateKey = C.PrivateKey 1 2 3, @@ -110,7 +111,7 @@ rcvQueue1 = sndQueue1 :: SndQueue sndQueue1 = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, sndId = "3456", connAlias = "conn1", sndPrivateKey = C.PrivateKey 1 2 3, @@ -156,7 +157,7 @@ testGetAllConnAliases = do testGetRcvQueue :: SpecWith SQLiteStore testGetRcvQueue = do it "should get RcvQueue" $ \store -> do - let smpServer = SMPServer "smp.simplex.im" (Just "5223") (Just "1234") + let smpServer = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash let recipientId = "1234" createRcvConn store rcvQueue1 `returnsResult` () @@ -211,7 +212,7 @@ testUpgradeRcvConnToDuplex = do `returnsResult` () let anotherSndQueue = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, sndId = "2345", connAlias = "conn1", sndPrivateKey = C.PrivateKey 1 2 3, @@ -233,7 +234,7 @@ testUpgradeSndConnToDuplex = do `returnsResult` () let anotherRcvQueue = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") (Just "1234"), + { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, rcvId = "3456", connAlias = "conn1", rcvPrivateKey = C.PrivateKey 1 2 3, diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 0cde0e740..4c36c2db9 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -1,13 +1,19 @@ {-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module SMPClient where +import Control.Monad (void) +import Control.Monad.Except (runExceptT) import Control.Monad.IO.Unlift import Crypto.Random +import Data.ByteString.Base64 (encode) +import qualified Data.ByteString.Char8 as B import Network.Socket +import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server import Simplex.Messaging.Server.Env.STM @@ -15,7 +21,6 @@ import Simplex.Messaging.Transport import Test.Hspec import UnliftIO.Concurrent import qualified UnliftIO.Exception as E -import UnliftIO.IO testHost :: HostName testHost = "localhost" @@ -23,14 +28,19 @@ testHost = "localhost" testPort :: ServiceName testPort = "5000" -testSMPClient :: MonadUnliftIO m => (Handle -> m a) -> m a +teshKeyHashStr :: B.ByteString +teshKeyHashStr = "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" + +teshKeyHash :: Maybe C.KeyHash +teshKeyHash = Just "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" + +testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a testSMPClient client = do threadDelay 250_000 -- TODO hack: thread delay for SMP server to start - runTCPClient testHost testPort $ \h -> do - line <- liftIO $ getLn h - if line == "Welcome to SMP v0.2.0" - then client h - else error "not connected" + runTCPClient testHost testPort $ \h -> + liftIO (runExceptT $ clientHandshake h teshKeyHash) >>= \case + Right th -> client th + Left e -> error $ show e cfg :: ServerConfig cfg = @@ -53,33 +63,43 @@ withSmpServerOn port = withSmpServerThreadOn port . const withSmpServer :: (MonadUnliftIO m, MonadRandom m) => m a -> m a withSmpServer = withSmpServerOn testPort -runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (Handle -> m a) -> m a +runSmpTest :: (MonadUnliftIO m, MonadRandom m) => (THandle -> m a) -> m a runSmpTest test = withSmpServer $ testSMPClient test -runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([Handle] -> m a) -> m a +runSmpTestN :: forall m a. (MonadUnliftIO m, MonadRandom m) => Int -> ([THandle] -> m a) -> m a runSmpTestN nClients test = withSmpServer $ run nClients [] where - run :: Int -> [Handle] -> m a + run :: Int -> [THandle] -> m a run 0 hs = test hs run n hs = testSMPClient $ \h -> run (n - 1) (h : hs) smpServerTest :: RawTransmission -> IO RawTransmission smpServerTest cmd = runSmpTest $ \h -> tPutRaw h cmd >> tGetRaw h -smpTest :: (Handle -> IO ()) -> Expectation +smpTest :: (THandle -> IO ()) -> Expectation smpTest test' = runSmpTest test' `shouldReturn` () -smpTestN :: Int -> ([Handle] -> IO ()) -> Expectation +smpTestN :: Int -> ([THandle] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: (Handle -> Handle -> IO ()) -> Expectation +smpTest2 :: (THandle -> THandle -> IO ()) -> Expectation smpTest2 test' = smpTestN 2 _test where _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: (Handle -> Handle -> Handle -> IO ()) -> Expectation +smpTest3 :: (THandle -> THandle -> THandle -> IO ()) -> Expectation smpTest3 test' = smpTestN 3 _test where _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" + +tPutRaw :: THandle -> RawTransmission -> IO () +tPutRaw h (sig, corrId, queueId, command) = do + let t = B.intercalate " " [corrId, queueId, command] + void $ tPut h (C.Signature sig, t) + +tGetRaw :: THandle -> IO RawTransmission +tGetRaw h = do + ("", (CorrId corrId, qId, Right cmd)) <- tGet fromServer h + pure ("", corrId, encode qId, serializeCommand cmd) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 4a8c8766c..341bdbdf3 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -13,7 +13,7 @@ import qualified Data.ByteString.Char8 as B import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol -import System.IO (Handle) +import Simplex.Messaging.Transport import System.Timeout import Test.HUnit import Test.Hspec @@ -34,14 +34,14 @@ serverTests = do pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command))) -sendRecv :: Handle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError +sendRecv :: THandle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h -signSendRecv :: Handle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError +signSendRecv :: THandle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError signSendRecv h pk (corrId, qId, cmd) = do - let t = B.intercalate "\r\n" [corrId, encode qId, cmd] + let t = B.intercalate " " [corrId, encode qId, cmd] Right sig <- C.sign pk t - tPut h (sig, t) + _ <- tPut h (sig, t) tGet fromServer h cmdSEND :: ByteString -> ByteString @@ -61,7 +61,7 @@ testCreateSecure = Resp "abcd" rId1 (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) (rId1, "") #== "creates queue" - Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5\r\nhello") + Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, "SEND 5 hello ") (ok1, OK) #== "accepts unsigned SEND" (sId1, sId) #== "same queue ID in response 1" @@ -75,7 +75,7 @@ testCreateSecure = (err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages" (sPub, sKey) <- C.generateKeyPair rsaKeySize - Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5\r\nhello") + Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ") (err1, ERR AUTH) #== "rejects signed SEND" (sId2, sId) #== "same queue ID in response 2" @@ -93,7 +93,7 @@ testCreateSecure = Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, keyCmd) (err4, ERR AUTH) #== "rejects KEY if already secured" - Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11\r\nhello again") + Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, "SEND 11 hello again ") (ok3, OK) #== "accepts signed SEND" Resp "" _ (MSG _ _ msg) <- tGet fromServer h @@ -102,7 +102,7 @@ testCreateSecure = Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, "ACK") (ok5, OK) #== "replies OK when message acknowledged 2" - Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, "SEND 5 hello ") (err5, ERR AUTH) #== "rejects unsigned SEND" testCreateDelete :: Spec @@ -117,10 +117,10 @@ testCreateDelete = Resp "bcda" _ ok1 <- signSendRecv rh rKey ("bcda", rId, "KEY " <> C.serializePubKey sPub) (ok1, OK) #== "secures queue" - Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello") + Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ") (ok2, OK) #== "accepts signed SEND" - Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7\r\nhello 2") + Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, "SEND 7 hello 2 ") (ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed" Resp "" _ (MSG _ _ msg1) <- tGet fromServer rh @@ -136,10 +136,10 @@ testCreateDelete = (ok3, OK) #== "suspends queue" (rId2, rId) #== "same queue ID in response 2" - Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, "SEND 5 hello ") (err3, ERR AUTH) #== "rejects signed SEND" - Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5\r\nhello") + Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, "SEND 5 hello ") (err4, ERR AUTH) #== "reject unsigned SEND too" Resp "bcda" _ ok4 <- signSendRecv rh rKey ("bcda", rId, "OFF") @@ -158,10 +158,10 @@ testCreateDelete = (ok6, OK) #== "deletes queue" (rId3, rId) #== "same queue ID in response 3" - Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5\r\nhello") + Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, "SEND 5 hello ") (err7, ERR AUTH) #== "rejects signed SEND when deleted" - Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5\r\nhello") + Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, "SEND 5 hello ") (err8, ERR AUTH) #== "rejects unsigned SEND too when deleted" Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, "ACK") @@ -211,7 +211,7 @@ testDuplex = (aliceKey, C.serializePubKey asPub) #== "key received from Alice" Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, "KEY " <> aliceKey) - Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8\r\nhi alice") + Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, "SEND 8 hi alice ") Resp "" _ (MSG _ _ msg4) <- tGet fromServer alice Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, "ACK") @@ -229,7 +229,7 @@ testSwitchSub = smpTest3 \rh1 rh2 sh -> do (rPub, rKey) <- C.generateKeyPair rsaKeySize Resp "abcd" _ (IDS rId sId) <- signSendRecv rh1 rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) - Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5\r\ntest1") + Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, "SEND 5 test1 ") (ok1, OK) #== "sent test message 1" Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, cmdSEND "test2, no ACK") (ok2, OK) #== "sent test message 2" @@ -246,7 +246,7 @@ testSwitchSub = Resp "" _ end <- tGet fromServer rh1 (end, END) #== "unsubscribed the 1st TCP connection" - Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5\r\ntest3") + Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, "SEND 5 test3 ") Resp "" _ (MSG _ _ msg3) <- tGet fromServer rh2 (msg3, "test3") #== "delivered to the 2nd TCP connection" @@ -280,13 +280,13 @@ syntaxTests = do noParamsSyntaxTest "OFF" noParamsSyntaxTest "DEL" describe "SEND" do - it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5\r\nhello") >#> ("", "cdab", "12345678", "ERR AUTH") - it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11\r\nhello there") >#> ("", "dabc", "12345678", "ERR AUTH") + it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH") + it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH") it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") - it "no queue ID" $ ("1234", "bcda", "", "SEND 5\r\nhello") >#> ("", "bcda", "", "ERR SYNTAX 5") - it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4\r\nhello") >#> ("", "abcd", "12345678", "ERR SIZE") + it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR SYNTAX 5") + it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") + it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") + it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") describe "PING" do it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG") describe "broker response not allowed" do @@ -295,6 +295,6 @@ syntaxTests = do noParamsSyntaxTest :: ByteString -> Spec noParamsSyntaxTest cmd = describe (B.unpack cmd) do it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH") - it "parameters" $ ("1234", "bcda", "12345678", cmd <> " 1") >#> ("", "bcda", "12345678", "ERR SYNTAX 2") + it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR SYNTAX 2") it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3") it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3") From ad7329893679e3d2aff2a55be0e080c8cb1cafe9 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 7 Apr 2021 22:59:57 +0100 Subject: [PATCH 02/17] Read server keys from files or create if absent (#79) * move server keys to config * add server keys from files * create server keys if key files do not exist * validate loaded server key pair * refactor fromString functions * key files in /etc/opt/simplex --- apps/smp-server/Main.hs | 62 +++++++++++++++++++++++-- src/Simplex/Messaging/Crypto.hs | 50 +++++++++++++++----- src/Simplex/Messaging/Server.hs | 2 +- src/Simplex/Messaging/Server/Env/STM.hs | 31 ++----------- src/Simplex/Messaging/Transport.hs | 5 +- tests/SMPClient.hs | 6 ++- 6 files changed, 110 insertions(+), 46 deletions(-) diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index d43c97193..0c873b21c 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -1,7 +1,19 @@ +{-# LANGUAGE OverloadedStrings #-} + module Main where +import Control.Monad (when) +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.ByteString.Char8 as B +import Data.Char (toLower) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.Env.STM +import System.Directory (createDirectoryIfMissing, doesFileExist) +import System.Exit (exitFailure) +import System.FilePath (combine) +import System.IO (hFlush, stdout) cfg :: ServerConfig cfg = @@ -9,10 +21,54 @@ cfg = { tcpPort = "5223", tbqSize = 16, queueIdBytes = 12, - msgIdBytes = 6 + msgIdBytes = 6, + -- keys are loaded from files server_key.pub and server_key in ~/.simplex directory + serverKeyPair = undefined } +newKeySize :: Int +newKeySize = 2048 `div` 8 + +cfgDir :: FilePath +cfgDir = "/etc/opt/simplex" + main :: IO () main = do - putStrLn $ "Listening on port " ++ tcpPort cfg - runSMPServer cfg + (k, pk) <- readCreateKeys + B.putStrLn $ "SMP transport key hash: " <> publicKeyHash k + putStrLn $ "Listening on port " <> tcpPort cfg + runSMPServer cfg {serverKeyPair = (k, pk)} + +readCreateKeys :: IO C.KeyPair +readCreateKeys = do + createDirectoryIfMissing True cfgDir + let kPath = combine cfgDir "server_key.pub" + pkPath = combine cfgDir "server_key" + -- `||` is here to avoid creating keys and crash if one of two files exists + hasKeys <- (||) <$> doesFileExist kPath <*> doesFileExist pkPath + (if hasKeys then readKeys else createKeys) kPath pkPath + where + createKeys :: FilePath -> FilePath -> IO C.KeyPair + createKeys kPath pkPath = do + confirm + (k, pk) <- C.generateKeyPair newKeySize + B.writeFile kPath $ C.serializePubKey k + B.writeFile pkPath $ C.serializePrivKey pk + pure (k, pk) + confirm :: IO () + confirm = do + putStr "Generate new server key pair (y/N): " + hFlush stdout + ok <- getLine + when (map toLower ok /= "y") exitFailure + readKeys :: FilePath -> FilePath -> IO C.KeyPair + readKeys kPath pkPath = do + ks <- (,) <$> readKey kPath C.pubKeyP <*> readKey pkPath C.privKeyP + if C.validKeyPair ks then pure ks else putStrLn "invalid key pair" >> exitFailure + readKey :: FilePath -> Parser a -> IO a + readKey path parser = + let parseError = fail . ((path <> ": ") <>) + in B.readFile path >>= either parseError pure . parseAll parser . head . B.lines + +publicKeyHash :: C.PublicKey -> B.ByteString +publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.serializePubKey diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index f2aeb216c..75f2109c0 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -15,6 +16,7 @@ module Simplex.Messaging.Crypto IV (..), KeyHash (..), generateKeyPair, + validKeyPair, publicKeySize, sign, verify, @@ -27,6 +29,7 @@ module Simplex.Messaging.Crypto serializePrivKey, serializePubKey, serializeKeyHash, + getKeyHash, privKeyP, pubKeyP, keyHashP, @@ -46,8 +49,9 @@ import Control.Monad.Trans.Except import Crypto.Cipher.AES (AES256) import qualified Crypto.Cipher.Types as AES import qualified Crypto.Error as CE -import Crypto.Hash (Digest, SHA256 (..), digestFromByteString) +import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash) import Crypto.Number.Generate (generateMax) +import Crypto.Number.ModArithmetic (expFast) import Crypto.Number.Prime (findPrimeFrom) import Crypto.Number.Serialize (i2osp, os2ip) import qualified Crypto.PubKey.RSA as R @@ -81,6 +85,15 @@ data PrivateKey = PrivateKey } deriving (Eq, Show) +instance IsString PrivateKey where + fromString = parseString privKeyP + +instance IsString PublicKey where + fromString = parseString pubKeyP + +parseString :: Parser a -> (String -> a) +parseString parser = either error id . parseAll parser . fromString + instance ToField PrivateKey where toField = toField . serializePrivKey instance ToField PublicKey where toField = toField . serializePubKey @@ -140,6 +153,16 @@ generateKeyPair size = loop then loop else return (PublicKey pub, privateKey s n d) +validKeyPair :: KeyPair -> Bool +validKeyPair + ( PublicKey R.PublicKey {public_size, public_n = n, public_e = e}, + PrivateKey {private_size, private_n, private_d = d} + ) = + let m = 30577 + in public_size == private_size + && n == private_n + && m == expFast (expFast m d n) e n + publicKeySize :: PublicKey -> Int publicKeySize = R.public_size . rsaPublicKey @@ -157,7 +180,7 @@ newtype IV = IV {unIV :: ByteString} newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show) instance IsString KeyHash where - fromString = either error id . parseAll keyHashP . fromString + fromString = parseString keyHashP instance ToField KeyHash where toField = toField . serializeKeyHash @@ -178,6 +201,9 @@ keyHashP = do Just d -> pure $ KeyHash d _ -> fail "invalid digest" +getKeyHash :: ByteString -> KeyHash +getKeyHash = KeyHash . hash + serializeHeader :: Header -> ByteString serializeHeader Header {aesKey, ivBytes, authTag, msgSize} = unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize @@ -314,16 +340,16 @@ keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP rsaPrivateKey :: PrivateKey -> R.PrivateKey rsaPrivateKey pk = R.PrivateKey - { R.private_pub = + { private_pub = R.PublicKey - { R.public_size = private_size pk, - R.public_n = private_n pk, - R.public_e = undefined + { public_size = private_size pk, + public_n = private_n pk, + public_e = undefined }, - R.private_d = private_d pk, - R.private_p = 0, - R.private_q = 0, - R.private_dP = undefined, - R.private_dQ = undefined, - R.private_qinv = undefined + private_d = private_d pk, + private_p = 0, + private_q = 0, + private_dP = undefined, + private_dQ = undefined, + private_qinv = undefined } diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 130ff2190..69e3a0eec 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -60,7 +60,7 @@ runSMPServer cfg@ServerConfig {tcpPort} = do runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do - keyPair <- asks serverKeyPair + keyPair <- asks $ serverKeyPair . config liftIO (runExceptT $ serverHandshake h keyPair) >>= \case Right th -> runClientTransport th Left _ -> pure () diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 126f00238..cbfec9f3c 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -5,7 +5,6 @@ module Simplex.Messaging.Server.Env.STM where import Control.Concurrent (ThreadId) import Control.Monad.IO.Unlift -import qualified Crypto.PubKey.RSA as R import Crypto.Random import Data.Map.Strict (Map) import qualified Data.Map.Strict as M @@ -21,7 +20,9 @@ data ServerConfig = ServerConfig { tcpPort :: ServiceName, tbqSize :: Natural, queueIdBytes :: Int, - msgIdBytes :: Int + msgIdBytes :: Int, + serverKeyPair :: C.KeyPair + -- serverId :: ByteString } data Env = Env @@ -29,9 +30,7 @@ data Env = Env server :: Server, queueStore :: QueueStore, msgStore :: STMMsgStore, - idsDrg :: TVar ChaChaDRG, - serverKeyPair :: C.KeyPair - -- serverId :: ByteString + idsDrg :: TVar ChaChaDRG } data Server = Server @@ -76,24 +75,4 @@ newEnv config = do queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore idsDrg <- drgNew >>= newTVarIO - -- TODO these keys should be set in the environment, not in the code - return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair} - where - serverKeyPair = - ( C.PublicKey - { rsaPublicKey = - R.PublicKey - { public_size = 256, - public_n = 24491401566218566997383105010202223087300892576089255259580984651333137614713737618097624532507176450266480395052797332730303098565954279378701980313049999952643146946493842983667770915603693980339519205455913124235423278419181501399080069195664300809453039371169996023512911587381435574254546266774756319955237750224266282550919563293672568339958353047135257914364920805066749904289452712976534358633568668875150094910205741579097517675339029147403213185924413178887675432745168542469043448659751499651038006514754218441022754807971535895895877162103157702709155894482782232155817331812261258282431796597840952464257, - public_e = 8750208418393523480444709183090020123776537336553019181250117771363000810675051423462439348759073000328325050011503730211252469588880505946970399702607609166796825215104414212088697348613726705621594590369250976359268097976909710311654938358716518878047036682173044667792903503207106314854901036618348367397 - } - }, - C.PrivateKey - { private_size = 256, - private_n = 24491401566218566997383105010202223087300892576089255259580984651333137614713737618097624532507176450266480395052797332730303098565954279378701980313049999952643146946493842983667770915603693980339519205455913124235423278419181501399080069195664300809453039371169996023512911587381435574254546266774756319955237750224266282550919563293672568339958353047135257914364920805066749904289452712976534358633568668875150094910205741579097517675339029147403213185924413178887675432745168542469043448659751499651038006514754218441022754807971535895895877162103157702709155894482782232155817331812261258282431796597840952464257, - private_d = 7597313014691047671352664508683652467940113991200105893460705315744177757772923044415828427601194535604492873282390112577565179730319668643740113323630387082584239892956534048712048059175569855278723311295064858148623887611800385925820852572241607131360121661598015161261779381845187797044113149447495567589968956065009916550602209418325870594974390014927949966324558614396231902374868077411836997835082564279358230227298823445650053370542685308691044175390251929540772677009245507450972026595993054141350350385685400540681305852935721245601287301749047921282924410369389293829570448007237832101875085500166095784749 - } - ) - --- public key hash: --- "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" + return Env {config, server, queueStore, msgStore, idsDrg} diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 994366b02..46f95e050 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -15,7 +15,6 @@ import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except (throwE) import Crypto.Cipher.Types (AuthTag) -import Crypto.Hash (hash) import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) @@ -235,8 +234,8 @@ clientHandshake h keyHash = do parseKey :: ByteString -> Either TransportError C.PublicKey parseKey = first TransportHandshakeError . parseAll C.pubKeyP validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () - validateKeyHash_2 k (C.KeyHash kHash) - | hash k == kHash = pure () + validateKeyHash_2 k kHash + | C.getKeyHash k == kHash = pure () | otherwise = throwE $ TransportHandshakeError "wrong key hash" generateKeys_3 :: IO HandshakeKeys generateKeys_3 = HandshakeKeys <$> generateKey <*> generateKey diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 4c36c2db9..85902cf57 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -48,7 +48,11 @@ cfg = { tcpPort = testPort, tbqSize = 1, queueIdBytes = 12, - msgIdBytes = 6 + msgIdBytes = 6, + serverKeyPair = + ( "256,wgJfm+EgMI3MeGdZlNs+KEoMlO0bpvZ2sa7bK4zWGtWGWXoCq1m89gaMk+f+HZavNJbJmflqrviBAoCFtDrA5+xC4+mwGlU6mLWiWtpvxgRBtNBsuHg3l+oJv0giFNCxoscne3P6n4kaCQEbA1T6KdrsdvxcaqyqzbpI7SozLIzhy45gsVywJfzpu6GYHlYNizdBJtoX2r66v6jDQFX7/MVDG4Z84RRa8PzjzT0wXSY+nirwIy5uwD0V5jrwaB0S5re6UnL7aLp51zHLUHPI/C9okBIkjY9kyQg3mAYXOPxb0OlGf3ENWnVdPKG6WqYnC3SBMIEVd4rqqxoH4myTgQ==,DHXxHfufuxfbuReISV9tCNttWXm/EVXTTN//hHkW/1wPLppbpY6aOqW+SZWwGCodIdGvdPSmaY9W8kfftWQY9xCOOcpkrzZwYHppT995xBIoB30vXG01dyruebFr3HjurT+uUbRGnxNYGwZg3AjkcyQtMKmq1pANvOGsOUgeDiU=", + "256,wgJfm+EgMI3MeGdZlNs+KEoMlO0bpvZ2sa7bK4zWGtWGWXoCq1m89gaMk+f+HZavNJbJmflqrviBAoCFtDrA5+xC4+mwGlU6mLWiWtpvxgRBtNBsuHg3l+oJv0giFNCxoscne3P6n4kaCQEbA1T6KdrsdvxcaqyqzbpI7SozLIzhy45gsVywJfzpu6GYHlYNizdBJtoX2r66v6jDQFX7/MVDG4Z84RRa8PzjzT0wXSY+nirwIy5uwD0V5jrwaB0S5re6UnL7aLp51zHLUHPI/C9okBIkjY9kyQg3mAYXOPxb0OlGf3ENWnVdPKG6WqYnC3SBMIEVd4rqqxoH4myTgQ==,PC6r+lZm5vyVpOl6dS9SXv09iE1PZoav6yeUbqsK+FScwHiOMEOkTY2mUyTHZ99nA4l7grAo4RPS6UOQS07QtgD2siZyj6F6Z3qAiBGesiG3+tb59pQ/prhs+5Q7RBlRMulz5KEwFINUb4Wy9ft4oIL/JJT9iSnYtTuGGirUEjB6YGzLKQeTyhkWA0iN89C5Vx6drB/pHyu3Mu+uc0Rax0UPD47gsNmxPNWUM6xLlkpNAWnSOHcSJZ3SN4QDLLCeBfqkgDLYkE3vbwvz8drt+H2eLi8OzFErEdkkrXg/0VwNjfhpBTt8D4TX00I7XsVksh3b2BRHzLfHTbLGdExLLQ==" + ) } withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a From 00289391550f8229878103d0888b48643dbf9132 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Wed, 14 Apr 2021 21:20:08 +0100 Subject: [PATCH 03/17] standard X509/PKCS8 encoding for RSA keys (#98) * key encoding primitives (WIP) * use X509/PKCS8 to read/write server key files * make PrivateKey type class * clean up * remove separate public key file * specific import --- apps/smp-server/Main.hs | 56 ++--- package.yaml | 5 + src/Simplex/Messaging/Agent/Transmission.hs | 6 +- src/Simplex/Messaging/Client.hs | 4 +- src/Simplex/Messaging/Crypto.hs | 216 ++++++++++++-------- src/Simplex/Messaging/Parsers.hs | 7 +- src/Simplex/Messaging/Protocol.hs | 4 +- src/Simplex/Messaging/Server.hs | 2 +- src/Simplex/Messaging/Server/Env/STM.hs | 9 +- src/Simplex/Messaging/Transport.hs | 2 +- stack.yaml | 4 +- tests/AgentTests/SQLiteTests.hs | 16 +- tests/SMPClient.hs | 38 +++- tests/ServerTests.hs | 2 +- 14 files changed, 232 insertions(+), 139 deletions(-) diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 0c873b21c..47cecf056 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -1,13 +1,14 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Main where import Control.Monad (when) -import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Crypto.Store.PKCS8 as S import qualified Data.ByteString.Char8 as B import Data.Char (toLower) +import Data.X509 (PrivKey (PrivKeyRSA)) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.Env.STM import System.Directory (createDirectoryIfMissing, doesFileExist) @@ -22,8 +23,8 @@ cfg = tbqSize = 16, queueIdBytes = 12, msgIdBytes = 6, - -- keys are loaded from files server_key.pub and server_key in ~/.simplex directory - serverKeyPair = undefined + -- key is loaded from the file server_key in /etc/opt/simplex directory + serverPrivateKey = undefined } newKeySize :: Int @@ -34,41 +35,40 @@ cfgDir = "/etc/opt/simplex" main :: IO () main = do - (k, pk) <- readCreateKeys - B.putStrLn $ "SMP transport key hash: " <> publicKeyHash k + pk <- readCreateKey + B.putStrLn $ "SMP transport key hash: " <> publicKeyHash (C.publicKey pk) putStrLn $ "Listening on port " <> tcpPort cfg - runSMPServer cfg {serverKeyPair = (k, pk)} + runSMPServer cfg {serverPrivateKey = pk} -readCreateKeys :: IO C.KeyPair -readCreateKeys = do +readCreateKey :: IO C.FullPrivateKey +readCreateKey = do createDirectoryIfMissing True cfgDir - let kPath = combine cfgDir "server_key.pub" - pkPath = combine cfgDir "server_key" - -- `||` is here to avoid creating keys and crash if one of two files exists - hasKeys <- (||) <$> doesFileExist kPath <*> doesFileExist pkPath - (if hasKeys then readKeys else createKeys) kPath pkPath + let path = combine cfgDir "server_key" + hasKey <- doesFileExist path + (if hasKey then readKey else createKey) path where - createKeys :: FilePath -> FilePath -> IO C.KeyPair - createKeys kPath pkPath = do + createKey :: FilePath -> IO C.FullPrivateKey + createKey path = do confirm - (k, pk) <- C.generateKeyPair newKeySize - B.writeFile kPath $ C.serializePubKey k - B.writeFile pkPath $ C.serializePrivKey pk - pure (k, pk) + (_, pk) <- C.generateKeyPair newKeySize + S.writeKeyFile S.TraditionalFormat path [PrivKeyRSA $ C.rsaPrivateKey pk] + pure pk confirm :: IO () confirm = do putStr "Generate new server key pair (y/N): " hFlush stdout ok <- getLine when (map toLower ok /= "y") exitFailure - readKeys :: FilePath -> FilePath -> IO C.KeyPair - readKeys kPath pkPath = do - ks <- (,) <$> readKey kPath C.pubKeyP <*> readKey pkPath C.privKeyP - if C.validKeyPair ks then pure ks else putStrLn "invalid key pair" >> exitFailure - readKey :: FilePath -> Parser a -> IO a - readKey path parser = - let parseError = fail . ((path <> ": ") <>) - in B.readFile path >>= either parseError pure . parseAll parser . head . B.lines + readKey :: FilePath -> IO C.FullPrivateKey + readKey path = do + S.readKeyFile path >>= \case + [S.Unprotected (PrivKeyRSA pk)] -> pure $ C.FullPrivateKey pk + [_] -> errorExit "not RSA key" + [] -> errorExit "invalid key file format" + _ -> errorExit "more than one key" + where + errorExit :: String -> IO b + errorExit e = putStrLn (e <> ": " <> path) >> exitFailure publicKeyHash :: C.PublicKey -> B.ByteString publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.serializePubKey diff --git a/package.yaml b/package.yaml index 83000817e..414379b1b 100644 --- a/package.yaml +++ b/package.yaml @@ -12,6 +12,9 @@ extra-source-files: - README.md dependencies: + - ansi-terminal == 0.10.* + - asn1-encoding == 0.9.* + - asn1-types == 0.3.* - async == 2.2.* - attoparsec == 0.13.* - base >= 4.7 && < 5 @@ -35,6 +38,7 @@ dependencies: - transformers == 0.5.* - unliftio == 0.2.* - unliftio-core == 0.1.* + - x509 == 1.7.* library: source-dirs: src @@ -44,6 +48,7 @@ executables: source-dirs: apps/smp-server main: Main.hs dependencies: + - cryptostore == 0.2.* - simplex-messaging ghc-options: - -threaded diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index af4f7fba2..075fd7b33 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -21,7 +21,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Int (Int64) -import Data.Kind +import Data.Kind (Type) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 import Data.Type.Equality @@ -224,9 +224,9 @@ data ReplyMode = ReplyOff | ReplyOn | ReplyVia SMPServer deriving (Eq, Show) type EncryptionKey = C.PublicKey -type DecryptionKey = C.PrivateKey +type DecryptionKey = C.SafePrivateKey -type SignatureKey = C.PrivateKey +type SignatureKey = C.SafePrivateKey type VerificationKey = C.PublicKey diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 757b42305..ad7020e14 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -228,13 +228,13 @@ suspendSMPQueue = okSMPCommand $ Cmd SRecipient OFF deleteSMPQueue :: SMPClient -> RecipientPrivateKey -> QueueId -> ExceptT SMPClientError IO () deleteSMPQueue = okSMPCommand $ Cmd SRecipient DEL -okSMPCommand :: Cmd -> SMPClient -> C.PrivateKey -> QueueId -> ExceptT SMPClientError IO () +okSMPCommand :: Cmd -> SMPClient -> C.SafePrivateKey -> QueueId -> ExceptT SMPClientError IO () okSMPCommand cmd c pKey qId = sendSMPCommand c (Just pKey) qId cmd >>= \case Cmd _ OK -> return () _ -> throwE SMPUnexpectedResponse -sendSMPCommand :: SMPClient -> Maybe C.PrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd +sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ serializeTransmission (corrId, qId, cmd) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 75f2109c0..5b8b85b53 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -1,23 +1,28 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Crypto - ( PrivateKey (..), + ( PrivateKey (rsaPrivateKey), + SafePrivateKey, -- constructor is not exported + FullPrivateKey (..), PublicKey (..), Signature (..), CryptoError (..), - KeyPair, + SafeKeyPair, + FullKeyPair, Key (..), IV (..), KeyHash (..), generateKeyPair, - validKeyPair, + publicKey, publicKeySize, + safePrivateKey, sign, verify, encrypt, @@ -43,6 +48,7 @@ module Simplex.Messaging.Crypto ) where +import Control.Applicative ((<|>)) import Control.Exception (Exception) import Control.Monad.Except import Control.Monad.Trans.Except @@ -51,13 +57,15 @@ import qualified Crypto.Cipher.Types as AES import qualified Crypto.Error as CE import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash) import Crypto.Number.Generate (generateMax) -import Crypto.Number.ModArithmetic (expFast) import Crypto.Number.Prime (findPrimeFrom) -import Crypto.Number.Serialize (i2osp, os2ip) +import Crypto.Number.Serialize (os2ip) import qualified Crypto.PubKey.RSA as R import qualified Crypto.PubKey.RSA.OAEP as OAEP import qualified Crypto.PubKey.RSA.PSS as PSS import Crypto.Random (getRandomBytes) +import Data.ASN1.BinaryEncoding +import Data.ASN1.Encoding +import Data.ASN1.Types import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (first) @@ -66,53 +74,71 @@ import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Internal (c2w, w2c) +import Data.ByteString.Lazy (fromStrict, toStrict) import Data.String +import Data.Typeable (Typeable) +import Data.X509 import Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) -import Simplex.Messaging.Parsers (base64P, parseAll) -import Simplex.Messaging.Util (bshow, liftEitherError, (<$$>)) +import Simplex.Messaging.Parsers (base64P, base64StringP, parseAll) +import Simplex.Messaging.Util (liftEitherError, (<$$>)) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) -data PrivateKey = PrivateKey - { private_size :: Int, - private_n :: Integer, - private_d :: Integer - } - deriving (Eq, Show) +newtype SafePrivateKey = SafePrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) -instance IsString PrivateKey where - fromString = parseString privKeyP +newtype FullPrivateKey = FullPrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) + +class PrivateKey k where + rsaPrivateKey :: k -> R.PrivateKey + _privateKey :: R.PrivateKey -> k + mkPrivateKey :: R.PrivateKey -> k + +instance PrivateKey SafePrivateKey where + rsaPrivateKey = unPrivateKey + _privateKey = SafePrivateKey + mkPrivateKey R.PrivateKey {private_pub = k, private_d} = + safePrivateKey (R.public_size k, R.public_n k, private_d) + +instance PrivateKey FullPrivateKey where + rsaPrivateKey = unPrivateKey + _privateKey = FullPrivateKey + mkPrivateKey = FullPrivateKey + +instance IsString FullPrivateKey where + fromString = parseString decodePrivKey instance IsString PublicKey where - fromString = parseString pubKeyP + fromString = parseString decodePubKey -parseString :: Parser a -> (String -> a) -parseString parser = either error id . parseAll parser . fromString +parseString :: (ByteString -> Either String a) -> (String -> a) +parseString parse = either error id . parse . B.pack -instance ToField PrivateKey where toField = toField . serializePrivKey +instance ToField SafePrivateKey where toField = toField . serializePrivKey instance ToField PublicKey where toField = toField . serializePubKey -instance FromField PrivateKey where - fromField f@(Field (SQLBlob b) _) = - case parseAll privKeyP b of - Right k -> Ok k - Left e -> returnError ConversionFailed f ("couldn't parse PrivateKey field: " ++ e) - fromField f = returnError ConversionFailed f "expecting SQLBlob column type" +instance FromField SafePrivateKey where fromField = keyFromField privKeyP -instance FromField PublicKey where - fromField f@(Field (SQLBlob b) _) = - case parseAll pubKeyP b of - Right k -> Ok k - Left e -> returnError ConversionFailed f ("couldn't parse PublicKey field: " ++ e) - fromField f = returnError ConversionFailed f "expecting SQLBlob column type" +instance FromField PublicKey where fromField = keyFromField pubKeyP -type KeyPair = (PublicKey, PrivateKey) +keyFromField :: Typeable k => Parser k -> FieldParser k +keyFromField p = \case + f@(Field (SQLBlob b) _) -> + case parseAll p b of + Right k -> Ok k + Left e -> returnError ConversionFailed f ("couldn't parse key field: " ++ e) + f -> returnError ConversionFailed f "expecting SQLBlob column type" + +type KeyPair k = (PublicKey, k) + +type SafeKeyPair = (PublicKey, SafePrivateKey) + +type FullKeyPair = (PublicKey, FullPrivateKey) newtype Signature = Signature {unSignature :: ByteString} deriving (Eq, Show) @@ -139,29 +165,23 @@ aesKeySize = 256 `div` 8 authTagSize :: Int authTagSize = 128 `div` 8 -generateKeyPair :: Int -> IO KeyPair +generateKeyPair :: PrivateKey k => Int -> IO (KeyPair k) generateKeyPair size = loop where publicExponent = findPrimeFrom . (+ 3) <$> generateMax pubExpRange - privateKey s n d = PrivateKey {private_size = s, private_n = n, private_d = d} loop = do - (pub, priv) <- R.generate size =<< publicExponent - let s = R.public_size pub - n = R.public_n pub - d = R.private_d priv - in if d * d < n - then loop - else return (PublicKey pub, privateKey s n d) + (k, pk) <- R.generate size =<< publicExponent + let n = R.public_n k + d = R.private_d pk + if d * d < n + then loop + else pure (PublicKey k, mkPrivateKey pk) -validKeyPair :: KeyPair -> Bool -validKeyPair - ( PublicKey R.PublicKey {public_size, public_n = n, public_e = e}, - PrivateKey {private_size, private_n, private_d = d} - ) = - let m = 30577 - in public_size == private_size - && n == private_n - && m == expFast (expFast m d n) e n +rsaPrivateSize :: PrivateKey k => k -> Int +rsaPrivateSize = R.public_size . R.private_pub . rsaPrivateKey + +publicKey :: FullPrivateKey -> PublicKey +publicKey = PublicKey . R.private_pub . rsaPrivateKey publicKeySize :: PublicKey -> Int publicKeySize = R.public_size . rsaPublicKey @@ -180,7 +200,7 @@ newtype IV = IV {unIV :: ByteString} newtype KeyHash = KeyHash {unKeyHash :: Digest SHA256} deriving (Eq, Ord, Show) instance IsString KeyHash where - fromString = parseString keyHashP + fromString = parseString $ parseAll keyHashP instance ToField KeyHash where toField = toField . serializeKeyHash @@ -234,9 +254,9 @@ encrypt k paddedSize msg = do encHeader <- encryptOAEP k $ serializeHeader header return $ encHeader <> msg' -decrypt :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString +decrypt :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decrypt pk msg'' = do - let (encHeader, msg') = B.splitAt (private_size pk) msg'' + let (encHeader, msg') = B.splitAt (rsaPrivateSize pk) msg'' header <- decryptOAEP pk encHeader Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header msg <- decryptAES aesKey ivBytes msg' authTag @@ -297,7 +317,7 @@ encryptOAEP (PublicKey k) aesKey = liftEitherError CryptoRSAError $ OAEP.encrypt oaepParams k aesKey -decryptOAEP :: PrivateKey -> ByteString -> ExceptT CryptoError IO ByteString +decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decryptOAEP pk encKey = liftEitherError CryptoRSAError $ OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey @@ -305,51 +325,85 @@ decryptOAEP pk encKey = pssParams :: PSS.PSSParams SHA256 ByteString ByteString pssParams = PSS.defaultPSSParams SHA256 -sign :: PrivateKey -> ByteString -> IO (Either R.Error Signature) +sign :: PrivateKey k => k -> ByteString -> IO (Either R.Error Signature) sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig serializePubKey :: PublicKey -> ByteString -serializePubKey (PublicKey k) = serializeKey_ (R.public_size k, R.public_n k, R.public_e k) +serializePubKey k = "rsa:" <> encodePubKey k -serializePrivKey :: PrivateKey -> ByteString -serializePrivKey pk = serializeKey_ (private_size pk, private_n pk, private_d pk) - -serializeKey_ :: (Int, Integer, Integer) -> ByteString -serializeKey_ (size, n, ex) = bshow size <> "," <> encInt n <> "," <> encInt ex - where - encInt = encode . i2osp +serializePrivKey :: PrivateKey k => k -> ByteString +serializePrivKey pk = "rsa:" <> encodePrivKey pk pubKeyP :: Parser PublicKey -pubKeyP = do - (public_size, public_n, public_e) <- keyParser_ - return . PublicKey $ R.PublicKey {R.public_size, R.public_n, R.public_e} +pubKeyP = keyP decodePubKey <|> legacyPubKeyP -privKeyP :: Parser PrivateKey -privKeyP = do - (private_size, private_n, private_d) <- keyParser_ - return PrivateKey {private_size, private_n, private_d} +privKeyP :: PrivateKey k => Parser k +privKeyP = keyP decodePrivKey <|> legacyPrivKeyP -keyParser_ :: Parser (Int, Integer, Integer) -keyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP +keyP :: (ByteString -> Either String k) -> Parser k +keyP dec = either fail pure . dec =<< ("rsa:" *> base64StringP) + +legacyPubKeyP :: Parser PublicKey +legacyPubKeyP = do + (public_size, public_n, public_e) <- legacyKeyParser_ + return . PublicKey $ R.PublicKey {public_size, public_n, public_e} + +legacyPrivKeyP :: PrivateKey k => Parser k +legacyPrivKeyP = _privateKey . safeRsaPrivateKey <$> legacyKeyParser_ + +legacyKeyParser_ :: Parser (Int, Integer, Integer) +legacyKeyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP where intP = os2ip <$> base64P -rsaPrivateKey :: PrivateKey -> R.PrivateKey -rsaPrivateKey pk = +safePrivateKey :: (Int, Integer, Integer) -> SafePrivateKey +safePrivateKey = SafePrivateKey . safeRsaPrivateKey + +safeRsaPrivateKey :: (Int, Integer, Integer) -> R.PrivateKey +safeRsaPrivateKey (size, n, d) = R.PrivateKey { private_pub = R.PublicKey - { public_size = private_size pk, - public_n = private_n pk, - public_e = undefined + { public_size = size, + public_n = n, + public_e = 0 }, - private_d = private_d pk, + private_d = d, private_p = 0, private_q = 0, - private_dP = undefined, - private_dQ = undefined, - private_qinv = undefined + private_dP = 0, + private_dQ = 0, + private_qinv = 0 } + +encodePubKey :: PublicKey -> ByteString +encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey + +encodePrivKey :: PrivateKey k => k -> ByteString +encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey + +encodeKey :: ASN1Object a => a -> ByteString +encodeKey k = encode . toStrict . encodeASN1 DER $ toASN1 k [] + +decodePubKey :: ByteString -> Either String PublicKey +decodePubKey s = + decodeKey s >>= \case + (PubKeyRSA k, []) -> Right $ PublicKey k + r -> keyError r + +decodePrivKey :: PrivateKey k => ByteString -> Either String k +decodePrivKey s = + decodeKey s >>= \case + (PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk + r -> keyError r + +decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) +decodeKey s = fromASN1 =<< first show . decodeASN1 DER . fromStrict =<< decode s + +keyError :: (a, [ASN1]) -> Either String b +keyError = \case + (_, []) -> Left "not RSA key" + _ -> Left "more than one key" diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 85cb2f3a6..74a802435 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -11,10 +11,13 @@ import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) base64P :: Parser ByteString -base64P = do +base64P = either fail pure . decode =<< base64StringP + +base64StringP :: Parser ByteString +base64StringP = do str <- A.takeWhile1 (\c -> isAlphaNum c || c == '+' || c == '/') pad <- A.takeWhile (== '=') - either fail pure $ decode (str <> pad) + pure $ str <> pad tsISO8601P :: Parser UTCTime tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ') diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index ad73e751a..2a4ea433e 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -93,12 +93,12 @@ instance IsString CorrId where fromString = CorrId . fromString -- only used by Agent, kept here so its definition is close to respective public key -type RecipientPrivateKey = C.PrivateKey +type RecipientPrivateKey = C.SafePrivateKey type RecipientPublicKey = C.PublicKey -- only used by Agent, kept here so its definition is close to respective public key -type SenderPrivateKey = C.PrivateKey +type SenderPrivateKey = C.SafePrivateKey type SenderPublicKey = C.PublicKey diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 69e3a0eec..130ff2190 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -60,7 +60,7 @@ runSMPServer cfg@ServerConfig {tcpPort} = do runClient :: (MonadUnliftIO m, MonadReader Env m) => Handle -> m () runClient h = do - keyPair <- asks $ serverKeyPair . config + keyPair <- asks serverKeyPair liftIO (runExceptT $ serverHandshake h keyPair) >>= \case Right th -> runClientTransport th Left _ -> pure () diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index cbfec9f3c..4371fc95f 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -21,7 +21,7 @@ data ServerConfig = ServerConfig tbqSize :: Natural, queueIdBytes :: Int, msgIdBytes :: Int, - serverKeyPair :: C.KeyPair + serverPrivateKey :: C.FullPrivateKey -- serverId :: ByteString } @@ -30,7 +30,8 @@ data Env = Env server :: Server, queueStore :: QueueStore, msgStore :: STMMsgStore, - idsDrg :: TVar ChaChaDRG + idsDrg :: TVar ChaChaDRG, + serverKeyPair :: C.FullKeyPair } data Server = Server @@ -75,4 +76,6 @@ newEnv config = do queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore idsDrg <- drgNew >>= newTVarIO - return Env {config, server, queueStore, msgStore, idsDrg} + let pk = serverPrivateKey config + serverKeyPair = (C.publicKey pk, pk) + return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair} diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 46f95e050..38ace62d6 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -192,7 +192,7 @@ makeNextIV SessionKey {baseIV, counter} = atomically $ do -- | implements server transport handshake as per /rfcs/2021-01-26-crypto.md#transport-encryption -- The numbers in function names refer to the steps in the document -serverHandshake :: Handle -> C.KeyPair -> ExceptT TransportError IO THandle +serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle serverHandshake h (k, pk) = do liftIO sendPublicKey_1 encryptedKeys <- receiveEncryptedKeys_4 diff --git a/stack.yaml b/stack.yaml index 0926d5374..4dd8b8bbf 100644 --- a/stack.yaml +++ b/stack.yaml @@ -35,9 +35,11 @@ packages: # forks / in-progress versions pinned to a git hash. For example: # extra-deps: - - sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002 + - cryptostore-0.2.1.0@sha256:9896e2984f36a1c8790f057fd5ce3da4cbcaf8aa73eb2d9277916886978c5b19,3881 - direct-sqlite-2.3.26@sha256:04e835402f1508abca383182023e4e2b9b86297b8533afbd4e57d1a5652e0c23,3718 - simple-logger-0.1.0@sha256:be8ede4bd251a9cac776533bae7fb643369ebd826eb948a9a18df1a8dd252ff8,1079 + - sqlite-simple-0.4.18.0@sha256:3ceea56375c0a3590c814e411a4eb86943f8d31b93b110ca159c90689b6b39e5,3002 + - terminal-0.2.0.0@sha256:de6770ecaae3197c66ac1f0db5a80cf5a5b1d3b64a66a05b50f442de5ad39570,2977 # - network-run-0.2.4@sha256:7dbb06def522dab413bce4a46af476820bffdff2071974736b06f52f4ab57c96,885 # - git: https://github.com/commercialhaskell/stack.git # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 38b759112..885352cef 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -100,10 +100,10 @@ rcvQueue1 = { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, rcvId = "1234", connAlias = "conn1", - rcvPrivateKey = C.PrivateKey 1 2 3, + rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "2345", sndKey = Nothing, - decryptKey = C.PrivateKey 1 2 3, + decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New } @@ -114,9 +114,9 @@ sndQueue1 = { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, sndId = "3456", connAlias = "conn1", - sndPrivateKey = C.PrivateKey 1 2 3, + sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.PrivateKey 1 2 3, + signKey = C.safePrivateKey (1, 2, 3), status = New } @@ -215,9 +215,9 @@ testUpgradeRcvConnToDuplex = do { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, sndId = "2345", connAlias = "conn1", - sndPrivateKey = C.PrivateKey 1 2 3, + sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.PrivateKey 1 2 3, + signKey = C.safePrivateKey (1, 2, 3), status = New } upgradeRcvConnToDuplex store "conn1" anotherSndQueue @@ -237,10 +237,10 @@ testUpgradeSndConnToDuplex = do { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, rcvId = "3456", connAlias = "conn1", - rcvPrivateKey = C.PrivateKey 1 2 3, + rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "4567", sndKey = Nothing, - decryptKey = C.PrivateKey 1 2 3, + decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New } diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 85902cf57..50d814a89 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -29,10 +29,10 @@ testPort :: ServiceName testPort = "5000" teshKeyHashStr :: B.ByteString -teshKeyHashStr = "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" +teshKeyHashStr = "p1xa/XuzchgqomEL6RX+Me+fX096w50V7nJPAA0wpDE=" teshKeyHash :: Maybe C.KeyHash -teshKeyHash = Just "8Cvd+AYVxLpSsB/glEhVxkKuEzMNBFdAL5yr7p9DGGk=" +teshKeyHash = Just "p1xa/XuzchgqomEL6RX+Me+fX096w50V7nJPAA0wpDE=" testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a testSMPClient client = do @@ -49,10 +49,36 @@ cfg = tbqSize = 1, queueIdBytes = 12, msgIdBytes = 6, - serverKeyPair = - ( "256,wgJfm+EgMI3MeGdZlNs+KEoMlO0bpvZ2sa7bK4zWGtWGWXoCq1m89gaMk+f+HZavNJbJmflqrviBAoCFtDrA5+xC4+mwGlU6mLWiWtpvxgRBtNBsuHg3l+oJv0giFNCxoscne3P6n4kaCQEbA1T6KdrsdvxcaqyqzbpI7SozLIzhy45gsVywJfzpu6GYHlYNizdBJtoX2r66v6jDQFX7/MVDG4Z84RRa8PzjzT0wXSY+nirwIy5uwD0V5jrwaB0S5re6UnL7aLp51zHLUHPI/C9okBIkjY9kyQg3mAYXOPxb0OlGf3ENWnVdPKG6WqYnC3SBMIEVd4rqqxoH4myTgQ==,DHXxHfufuxfbuReISV9tCNttWXm/EVXTTN//hHkW/1wPLppbpY6aOqW+SZWwGCodIdGvdPSmaY9W8kfftWQY9xCOOcpkrzZwYHppT995xBIoB30vXG01dyruebFr3HjurT+uUbRGnxNYGwZg3AjkcyQtMKmq1pANvOGsOUgeDiU=", - "256,wgJfm+EgMI3MeGdZlNs+KEoMlO0bpvZ2sa7bK4zWGtWGWXoCq1m89gaMk+f+HZavNJbJmflqrviBAoCFtDrA5+xC4+mwGlU6mLWiWtpvxgRBtNBsuHg3l+oJv0giFNCxoscne3P6n4kaCQEbA1T6KdrsdvxcaqyqzbpI7SozLIzhy45gsVywJfzpu6GYHlYNizdBJtoX2r66v6jDQFX7/MVDG4Z84RRa8PzjzT0wXSY+nirwIy5uwD0V5jrwaB0S5re6UnL7aLp51zHLUHPI/C9okBIkjY9kyQg3mAYXOPxb0OlGf3ENWnVdPKG6WqYnC3SBMIEVd4rqqxoH4myTgQ==,PC6r+lZm5vyVpOl6dS9SXv09iE1PZoav6yeUbqsK+FScwHiOMEOkTY2mUyTHZ99nA4l7grAo4RPS6UOQS07QtgD2siZyj6F6Z3qAiBGesiG3+tb59pQ/prhs+5Q7RBlRMulz5KEwFINUb4Wy9ft4oIL/JJT9iSnYtTuGGirUEjB6YGzLKQeTyhkWA0iN89C5Vx6drB/pHyu3Mu+uc0Rax0UPD47gsNmxPNWUM6xLlkpNAWnSOHcSJZ3SN4QDLLCeBfqkgDLYkE3vbwvz8drt+H2eLi8OzFErEdkkrXg/0VwNjfhpBTt8D4TX00I7XsVksh3b2BRHzLfHTbLGdExLLQ==" - ) + serverPrivateKey = + -- full RSA private key (only for tests) + "MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\ + \kaJ8chEhrtaUgXeSWGooWwqjXEUQE6RVbCC6QVo9VEBSP4xFwVVd9Fj7OsgfcXXh\ + \AqWxfctDcBZQ5jTUiJpdBc+Vz2ZkumVNl0W+j9kWm9nfkMLQj8c0cVSDxz4OKpZb\ + \qFuj0uzHkis7e7wsrKSKWLPg3M5ZXPZM1m9qn7SfJzDRDfJifamxWI7uz9XK2+Dp\ + \NkUQlGQgFJEv1cKN88JAwIqZ1s+TAQMQiB+4QZ2aNfSqGEzRJN7FMCKRK7pM0A9A\ + \PCnijyuImvKFxTdk8Bx1q+XNJzsY6fBrLWJZ+QKBgQCySG4tzlcEm+tOVWRcwrWh\ + \6zsczGZp9mbf9c8itRx6dlldSYuDG1qnddL70wuAZF2AgS1JZgvcRZECoZRoWP5q\ + \Kq2wvpTIYjFPpC39lxgUoA/DXKVKZZdan+gwaVPAPT54my1CS32VrOiAY4gVJ3LJ\ + \Mn1/FqZXUFQA326pau3loQKCAQEAoljmJMp88EZoy3HlHUbOjl5UEhzzVsU1TnQi\ + \QmPm+aWRe2qelhjW4aTvSVE5mAUJsN6UWTeMf4uvM69Z9I5pfw2pEm8x4+GxRibY\ + \iiwF2QNaLxxmzEHm1zQQPTgb39o8mgklhzFPill0JsnL3f6IkVwjFJofWSmpqEGs\ + \dFSMRSXUTVXh1p/o7QZrhpwO/475iWKVS7o48N/0Xp513re3aXw+DRNuVnFEaBIe\ + \TLvWM9Czn16ndAu1HYiTBuMvtRbAWnGZxU8ewzF4wlWK5tdIL5PTJDd1VhZJAKtB\ + \npDvJpwxzKmjAhcTmjx0ckMIWtdVaOVm/2gWCXDty2FEdg7koQKBgQDOUUguJ/i7\ + \q0jldWYRnVkotKnpInPdcEaodrehfOqYEHnvro9xlS6OeAS4Vz5AdH45zQ/4J3bV\ + \2cH66tNr18ebM9nL//t5G69i89R9W7szyUxCI3LmAIdi3oSEbmz5GQBaw4l6h9Wi\ + \n4FmFQaAXZrjQfO2qJcAHvWRsMp2pmqAGwKBgQDXaza0DRsKWywWznsHcmHa0cx8\ + \I4jxqGaQmLO7wBJRP1NSFrywy1QfYrVX9CTLBK4V3F0PCgZ01Qv94751CzN43TgF\ + \ebd/O9r5NjNTnOXzdWqETbCffLGd6kLgCMwPQWpM9ySVjXHWCGZsRAnF2F6M1O32\ + \43StIifvwJQFqSM3ewKBgCaW6y7sRY90Ua7283RErezd9EyT22BWlDlACrPu3FNC\ + \LtBf1j43uxBWBQrMLsHe2GtTV0xt9m0MfwZsm2gSsXcm4Xi4DJgfN+Z7rIlyy9UY\ + \PCDSdZiU1qSr+NrffDrXlfiAM1cUmCdUX7eKjp/ltkUHNaOGfSn5Pdr3MkAiD/Hf\ + \AoGBAKIdKCuOwuYlwjS9J+IRGuSSM4o+OxQdwGmcJDTCpyWb5dEk68e7xKIna3zf\ + \jc+H+QdMXv1nkRK9bZgYheXczsXaNZUSTwpxaEldzVD3hNvsXSgJRy9fqHwA4PBq\ + \vqiBHoO3RNbqg+2rmTMfDuXreME3S955ZiPZm4Z+T8Hj52mPAoGAQm5QH/gLFtY5\ + \+znqU/0G8V6BKISCQMxbbmTQVcTgGySrP2gVd+e4MWvUttaZykhWqs8rpr7mgpIY\ + \hul7Swx0SHFN3WpXu8uj+B6MLpRcCbDHO65qU4kQLs+IaXXsuuTjMvJ5LwjkZVrQ\ + \TmKzSAw7iVWwEUZR/PeiEKazqrpp9VU=" } withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 341bdbdf3..24c6af342 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -37,7 +37,7 @@ pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker sendRecv :: THandle -> (ByteString, ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> tGet fromServer h -signSendRecv :: THandle -> C.PrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError +signSendRecv :: THandle -> C.SafePrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError signSendRecv h pk (corrId, qId, cmd) = do let t = B.intercalate " " [corrId, encode qId, cmd] Right sig <- C.sign pk t From 417066c462c5b5b6e20af4fddd1c57ba05d0169a Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 16 Apr 2021 18:48:13 +0100 Subject: [PATCH 04/17] change missing IDs message status syntax (#100) --- src/Simplex/Messaging/Agent/Transmission.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 075fd7b33..0c9c8bfad 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -164,8 +164,8 @@ agentMessageP = hello = HELLO <$> C.pubKeyP <*> ackMode reply = REPLY <$> smpQueueInfoP a_msg = do - size :: Int <- A.decimal - A_MSG <$> (A.endOfLine *> A.take size <* A.endOfLine) + size :: Int <- A.decimal <* A.endOfLine + A_MSG <$> A.take size <* A.endOfLine ackMode = " NO_ACK" $> AckMode Off <|> pure (AckMode On) smpQueueInfoP :: Parser SMPQueueInfo @@ -330,7 +330,7 @@ commandP = status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) msgErrorType = "ID " *> (MsgBadId <$> A.decimal) - <|> "NO_ID " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) + <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash parseCommand :: ByteString -> Either AgentErrorType ACmd From 5e3bc7ee6c78f8dbf9a92e99c67a105f12149b20 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 18 Apr 2021 18:37:54 +0100 Subject: [PATCH 05/17] improve error handling (#101) * inventory of error handling problems and types * Change SMP protocol errors syntax * connection errors in agent protocol (ERR CONN), STORE error -> AGENT error * include exception in SEInternal error * add MESSAGE errors, remove CRYPTO and SIZE errors * agent protocol SYNTAX and AGENT errors * BROKER errors * group all client command (and agent response) errors * BROKER TRANSPORT error * simplify Client * clean up * transport errors * simplify client * parse / serialize agent errors * differentiate crypto errors * update errors.md * make agent and SMP protocol errors consistent, simplify * update doc * test: parse / serialize protocol errors with QuickCheck * add String to internal error * exponential back-off when retrying to send HELLO * refactor Client.hs * replace fold with recursion in startTCPClient * fail test if server did not start, refactor * test: wait till TCP server stops * test: refactor waiting for server to stop * test: fail with error if server did not start/stop --- package.yaml | 3 + src/Simplex/Messaging/Agent.hs | 74 ++++++------ src/Simplex/Messaging/Agent/Client.hs | 51 +++++--- src/Simplex/Messaging/Agent/Store.hs | 11 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 12 +- src/Simplex/Messaging/Agent/Transmission.hs | 113 ++++++++++-------- src/Simplex/Messaging/Client.hs | 65 +++++----- src/Simplex/Messaging/Crypto.hs | 24 ++-- src/Simplex/Messaging/Parsers.hs | 15 +++ src/Simplex/Messaging/Protocol.hs | 80 +++++++------ src/Simplex/Messaging/Server.hs | 4 +- .../Messaging/Server/QueueStore/STM.hs | 2 +- src/Simplex/Messaging/Transport.hs | 82 +++++++++---- src/Simplex/Messaging/errors.md | 97 +++++++++++++++ tests/AgentTests.hs | 10 +- tests/AgentTests/SQLiteTests.hs | 26 ++-- tests/ProtocolErrorTests.hs | 18 +++ tests/SMPAgentClient.hs | 16 +-- tests/SMPClient.hs | 23 +++- tests/ServerTests.hs | 40 +++---- tests/Test.hs | 2 + 21 files changed, 494 insertions(+), 274 deletions(-) create mode 100644 src/Simplex/Messaging/errors.md create mode 100644 tests/ProtocolErrorTests.hs diff --git a/package.yaml b/package.yaml index f2fd486ec..722d647dc 100644 --- a/package.yaml +++ b/package.yaml @@ -24,11 +24,13 @@ dependencies: - cryptonite == 0.26.* - directory == 1.3.* - filepath == 1.4.* + - generic-random == 1.3.* - iso8601-time == 0.1.* - memory == 0.15.* - mtl - network == 3.1.* - network-transport == 0.5.* + - QuickCheck == 2.13.* - simple-logger == 0.1.* - sqlite-simple == 0.4.* - stm @@ -82,6 +84,7 @@ tests: - hspec-core == 2.7.* - HUnit == 1.6.* - random == 1.1.* + - QuickCheck == 2.13.* ghc-options: # - -haddock diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 6f952fd35..59ca26015 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -114,10 +114,15 @@ withStore :: withStore action = do runExceptT (action `E.catch` handleInternal) >>= \case Right c -> return c - Left _ -> throwError STORE + Left e -> throwError $ storeError e where handleInternal :: (MonadError StoreError m') => SomeException -> m' a - handleInternal _ = throwError SEInternal + handleInternal e = throwError . SEInternal $ bshow e + storeError :: StoreError -> AgentErrorType + storeError = \case + SEConnNotFound -> CONN UNKNOWN + SEConnDuplicate -> CONN DUPLICATE + e -> INTERNAL $ show e processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m () processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = @@ -156,9 +161,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st cAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> subscribe rq SomeConn _ (RcvConnection _ rq) -> subscribe rq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without RcvQueue - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where subscribe rq = subscribeQueue c rq cAlias >> respond' cAlias OK @@ -171,9 +174,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq SomeConn _ (SndConnection _ sq) -> sendMsg sq - -- TODO possibly there should be a separate error type trying - -- TODO to send the message to the connection without SndQueue - _ -> throwError PROHIBITED -- NOT_READY ? + _ -> throwError $ CONN SIMPLEX where sendMsg sq = do senderTs <- liftIO getCurrentTime @@ -186,7 +187,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> suspend rq SomeConn _ (RcvConnection _ rq) -> suspend rq - _ -> throwError PROHIBITED + _ -> throwError $ CONN SIMPLEX where suspend rq = suspendQueue c rq >> respond OK @@ -195,13 +196,13 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = withStore (getConn st connAlias) >>= \case SomeConn _ (DuplexConnection _ rq _) -> delete rq SomeConn _ (RcvConnection _ rq) -> delete rq - _ -> throwError PROHIBITED + _ -> delConn where + delConn = withStore (deleteConn st connAlias) >> respond OK delete rq = do deleteQueue c rq removeSubscription c connAlias - withStore (deleteConn st connAlias) - respond OK + delConn sendReplyQInfo :: SMPServer -> SndQueue -> m () sendReplyQInfo srv sq = do @@ -242,17 +243,14 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do -- TODO update sender key in the store? secureQueue c rq senderKey withStore $ setRcvQueueStatus st rq Secured - sendAck c rq - s -> - -- TODO maybe send notification to the user - liftIO . putStrLn $ "unexpected SMP confirmation, queue status " <> show s + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> case agentMessage of HELLO _verifyKey _ -> do logServer "<--" c srv rId "MSG " - -- TODO send status update to the user? - withStore $ setRcvQueueStatus st rq Active - sendAck c rq + case status of + Active -> notify connAlias . ERR $ AGENT A_PROHIBITED + _ -> withStore $ setRcvQueueStatus st rq Active REPLY qInfo -> do logServer "<--" c srv rId "MSG " -- TODO move senderKey inside SndQueue @@ -260,29 +258,33 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey notify connAlias CON - sendAck c rq A_MSG body -> do - logServer "<--" c srv rId "MSG " -- TODO check message status - recipientTs <- liftIO getCurrentTime - let m_sender = (senderMsgId, senderTimestamp) - let m_broker = (srvMsgId, srvTs) - recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker - notify connAlias $ - MSG - { m_status = MsgOk, - m_recipient = (unId recipientId, recipientTs), - m_sender, - m_broker, - m_body = body - } - sendAck c rq + logServer "<--" c srv rId "MSG " + case status of + Active -> do + recipientTs <- liftIO getCurrentTime + let m_sender = (senderMsgId, senderTimestamp) + let m_broker = (srvMsgId, srvTs) + recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker + notify connAlias $ + MSG + { m_status = MsgOk, + m_recipient = (unId recipientId, recipientTs), + m_sender, + m_broker, + m_body = body + } + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + sendAck c rq return () SMP.END -> do removeSubscription c connAlias logServer "<--" c srv rId "END" notify connAlias END - _ -> logServer "<--" c srv rId $ "unexpected:" <> bshow cmd + _ -> do + logServer "<--" c srv rId $ "unexpected: " <> bshow cmd + notify connAlias . ERR $ BROKER UNEXPECTED where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) @@ -295,7 +297,7 @@ connectToSendQueue c st sq senderKey verifyKey = do withStore $ setSndQueueStatus st sq Active decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString -decryptMessage decryptKey msg = liftError CRYPTO $ C.decrypt decryptKey msg +decryptMessage decryptKey msg = liftError cryptoError $ C.decrypt decryptKey msg newSendQueue :: (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index a8dba40fa..8c431f888 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -24,6 +24,7 @@ module Simplex.Messaging.Agent.Client deleteQueue, logServer, removeSubscription, + cryptoError, ) where @@ -47,7 +48,7 @@ import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey) -import Simplex.Messaging.Util (bshow, liftError) +import Simplex.Messaging.Util (bshow, liftEitherError, liftError) import UnliftIO.Concurrent import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E @@ -86,15 +87,17 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = newSMPClient = do smp <- connectClient logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - -- TODO how can agent know client lost the connection? atomically . modifyTVar smpClients $ M.insert srv smp return smp connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config - liftIO (getSMPClient srv cfg msgQ clientDisconnected) - `E.catch` \(_ :: IOException) -> throwError (BROKER smpErrTCPConnection) + liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected) + `E.catch` internalError + where + internalError :: IOException -> m SMPClient + internalError = throwError . INTERNAL . show clientDisconnected :: IO () clientDisconnected = do @@ -125,12 +128,6 @@ withSMP c srv action = runAction :: SMPClient -> m a runAction smp = liftError smpClientError $ action smp - smpClientError :: SMPClientError -> AgentErrorType - smpClientError = \case - SMPServerError e -> SMP e - -- TODO handle other errors - _ -> INTERNAL - logServerError :: AgentErrorType -> m a logServerError e = do logServer "<--" c srv "" $ bshow e @@ -143,6 +140,16 @@ withLogSMP c srv qId cmdStr action = do logServer "<--" c srv qId "OK" return res +smpClientError :: SMPClientError -> AgentErrorType +smpClientError = \case + SMPServerError e -> SMP e + SMPResponseError e -> BROKER $ RESPONSE e + SMPUnexpectedResponse -> BROKER UNEXPECTED + SMPResponseTimeout -> BROKER TIMEOUT + SMPNetworkError -> BROKER NETWORK + SMPTransportError e -> BROKER $ TRANSPORT e + e -> INTERNAL $ show e + newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo) newReceiveQueue c srv connAlias = do size <- asks $ rsaKeySize . config @@ -214,26 +221,26 @@ sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do mkConfirmation = do let msg = serializeSMPMessage $ SMPConfirmation senderKey paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encryptKey paddedSize msg + liftError cryptoError $ C.encrypt encryptKey paddedSize msg sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do msg <- mkHello $ AckMode On withLogSMP c server sndId "SEND (retrying)" $ - send 20 msg + send 8 100000 msg where mkHello :: AckMode -> m ByteString mkHello ackMode = do senderTs <- liftIO getCurrentTime mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode - send :: Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () - send 0 _ _ = throwE SMPResponseTimeout -- TODO different error - send retry msg smp = + send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () + send 0 _ _ _ = throwE $ SMPServerError AUTH + send retry delay msg smp = sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case SMPServerError AUTH -> do - threadDelay 100000 - send (retry - 1) msg smp + threadDelay delay + send (retry - 1) (delay * 3 `div` 2) msg smp e -> throwE e secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m () @@ -273,4 +280,12 @@ mkAgentMessage encKey senderTs agentMessage = do agentMessage } paddedSize <- asks paddedMsgSize - liftError CRYPTO $ C.encrypt encKey paddedSize msg + liftError cryptoError $ C.encrypt encKey paddedSize msg + +cryptoError :: C.CryptoError -> AgentErrorType +cryptoError = \case + C.CryptoLargeMsgError -> CMD LARGE + C.RSADecryptError _ -> AGENT A_ENCRYPTION + C.CryptoHeaderError _ -> AGENT A_ENCRYPTION + C.AESDecryptError -> AGENT A_ENCRYPTION + e -> INTERNAL $ show e diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6708bb13d..7521b9fae 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -10,6 +10,7 @@ module Simplex.Messaging.Agent.Store where import Control.Exception (Exception) +import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) import Data.Time (UTCTime) @@ -219,13 +220,11 @@ type InternalTs = UTCTime -- * Store errors --- TODO revise data StoreError - = SEInternal - | SENotFound - | SEBadConn + = SEInternal ByteString + | SEConnNotFound + | SEConnDuplicate | SEBadConnType ConnType - | SEBadQueueStatus - | SEBadQueueDirection + | SEBadQueueStatus -- not used, planned to check strictly | SENotImplemented -- TODO remove deriving (Eq, Show, Exception) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 0b83bbd77..1173e8215 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -95,7 +95,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto (Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) (Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ) (Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> throwError SEBadConn + _ -> throwError SEConnNotFound getAllConnAliases :: SQLiteStore -> m [ConnAlias] getAllConnAliases SQLiteStore {dbConn} = @@ -109,7 +109,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto retrieveRcvQueue dbConn host port rcvId case rcvQueue of Just rcvQ -> return rcvQ - _ -> throwError SENotFound + _ -> throwError SEConnNotFound deleteConn :: SQLiteStore -> ConnAlias -> m () deleteConn SQLiteStore {dbConn} connAlias = @@ -415,7 +415,7 @@ updateRcvConnWithSndQueue dbConn connAlias sndQueue = return $ Right () (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do @@ -443,7 +443,7 @@ updateSndConnWithRcvQueue dbConn connAlias rcvQueue = return $ Right () (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do @@ -508,7 +508,7 @@ insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId return $ Right internalId (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId) retrieveLastInternalIdsRcv_ dbConn connAlias = do @@ -599,7 +599,7 @@ insertSndMsg dbConn connAlias msgBody internalTs = updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId return $ Right internalId (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - _ -> return $ Left SEBadConn + _ -> return $ Left SEConnNotFound retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId) retrieveLastInternalIdsSnd_ dbConn connAlias = do diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 0c9c8bfad..02706da30 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -26,8 +27,9 @@ import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import Network.Socket -import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol @@ -37,12 +39,12 @@ import Simplex.Messaging.Protocol MsgBody, MsgId, SenderPublicKey, - errMessageBody, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport import Simplex.Messaging.Util import System.IO +import Test.QuickCheck (Arbitrary (..)) import Text.Read import UnliftIO.Exception @@ -125,7 +127,7 @@ data AMessage where deriving (Show) parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage -parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ SYNTAX errBadMessage +parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE where smpMessageP :: Parser SMPMessage smpMessageP = @@ -186,7 +188,7 @@ smpServerP = SMPServer <$> server <*> port <*> kHash else pure Nothing parseAgentMessage :: ByteString -> Either AgentErrorType AMessage -parseAgentMessage = parse agentMessageP $ SYNTAX errBadMessage +parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE serializeAgentMessage :: AMessage -> ByteString serializeAgentMessage = \case @@ -245,50 +247,53 @@ data MsgStatus = MsgOk | MsgError MsgErrorType data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash deriving (Eq, Show) +-- | error type used in errors sent to agent clients data AgentErrorType - = UNKNOWN - | PROHIBITED - | SYNTAX Int - | BROKER Natural - | SMP ErrorType - | CRYPTO C.CryptoError - | SIZE - | STORE - | INTERNAL -- etc. TODO SYNTAX Natural - deriving (Eq, Show, Exception) + = CMD CommandErrorType -- command errors + | CONN ConnectionErrorType -- connection state errors + | SMP ErrorType -- SMP protocol errors forwarded to agent clients + | BROKER BrokerErrorType -- SMP server errors + | AGENT SMPAgentError -- errors of other agents + | INTERNAL String -- agent implementation errors + deriving (Eq, Generic, Read, Show, Exception) -data AckStatus = AckOk | AckError AckErrorType - deriving (Show) +data CommandErrorType + = PROHIBITED -- command is prohibited + | SYNTAX -- command syntax is invalid + | NO_CONN -- connection alias is required with this command + | SIZE -- message size is not correct (no terminating space) + | LARGE -- message does not fit SMP block + deriving (Eq, Generic, Read, Show, Exception) -data AckErrorType = AckUnknown | AckProhibited | AckSyntax Int -- etc. - deriving (Show) +data ConnectionErrorType + = UNKNOWN -- connection alias not in database + | DUPLICATE -- connection alias already exists + | SIMPLEX -- connection is simplex, but operation requires another queue + deriving (Eq, Generic, Read, Show, Exception) -errBadEncoding :: Int -errBadEncoding = 10 +data BrokerErrorType + = RESPONSE ErrorType -- invalid server response (failed to parse) + | UNEXPECTED -- unexpected response + | NETWORK -- network error + | TRANSPORT TransportError -- handshake or other transport error + | TIMEOUT -- command response timeout + deriving (Eq, Generic, Read, Show, Exception) -errBadCommand :: Int -errBadCommand = 11 +data SMPAgentError + = A_MESSAGE -- possibly should include bytestring that failed to parse + | A_PROHIBITED -- possibly should include the prohibited SMP/agent message + | A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header + deriving (Eq, Generic, Read, Show, Exception) -errBadInvitation :: Int -errBadInvitation = 12 +instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU -errNoConnAlias :: Int -errNoConnAlias = 13 +instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU -errBadMessage :: Int -errBadMessage = 14 +instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU -errBadServer :: Int -errBadServer = 15 +instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU -smpErrTCPConnection :: Natural -smpErrTCPConnection = 1 - -smpErrCorrelationId :: Natural -smpErrCorrelationId = 2 - -smpUnexpectedResponse :: Natural -smpUnexpectedResponse = 3 +instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU commandP :: Parser ACmd commandP = @@ -319,9 +324,6 @@ commandP = m_sender <- "S=" *> partyMeta A.decimal m_body <- A.takeByteString return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body} - -- TODO other error types - agentError = ACmd SAgent . ERR <$> ("SMP " *> smpErrorType) - smpErrorType = "AUTH" $> SMP SMP.AUTH replyMode = " NO_REPLY" $> ReplyOff <|> A.space *> (ReplyVia <$> smpServerP) @@ -332,9 +334,10 @@ commandP = "ID " *> (MsgBadId <$> A.decimal) <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash + agentError = ACmd SAgent . ERR <$> agentErrorTypeP parseCommand :: ByteString -> Either AgentErrorType ACmd -parseCommand = parse commandP $ SYNTAX errBadCommand +parseCommand = parse commandP $ CMD SYNTAX serializeCommand :: ACommand p -> ByteString serializeCommand = \case @@ -358,7 +361,7 @@ serializeCommand = \case OFF -> "OFF" DEL -> "DEL" CON -> "CON" - ERR e -> "ERR " <> bshow e + ERR e -> "ERR " <> serializeAgentError e OK -> "OK" where replyMode :: ReplyMode -> ByteString @@ -378,7 +381,21 @@ serializeCommand = \case MsgBadId aMsgId -> "ID " <> bshow aMsgId MsgBadHash -> "HASH" --- TODO - save function as in the server Transmission - re-use? +agentErrorTypeP :: Parser AgentErrorType +agentErrorTypeP = + "SMP " *> (SMP <$> SMP.errorTypeP) + <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP) + <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) + <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) + <|> parseRead2 + +serializeAgentError :: AgentErrorType -> ByteString +serializeAgentError = \case + SMP e -> "SMP " <> SMP.serializeErrorType e + BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e + BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e + e -> bshow e + serializeMsg :: ByteString -> ByteString serializeMsg body = bshow (B.length body) <> "\n" <> body @@ -408,7 +425,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody fromParty :: ACmd -> Either AgentErrorType (ACommand p) fromParty (ACmd (p :: p1) cmd) = case testEquality party p of Just Refl -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED tConnAlias :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) tConnAlias (_, connAlias, _) cmd = case cmd of @@ -419,7 +436,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody ERR _ -> Right cmd -- other responses must have connAlias _ - | B.null connAlias -> Left $ SYNTAX errNoConnAlias + | B.null connAlias -> Left $ CMD NO_CONN | otherwise -> Right cmd cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) @@ -437,5 +454,5 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody Just size -> liftIO $ do body <- B.hGet h size s <- getLn h - return $ if B.null s then Right body else Left SIZE - Nothing -> return . Left $ SYNTAX errMessageBody + return $ if B.null s then Right body else Left $ CMD SIZE + Nothing -> return . Left $ CMD SYNTAX diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index ad7020e14..1efd16443 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -34,12 +34,10 @@ import Control.Exception import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Except -import qualified Crypto.PubKey.RSA.Types as RSA import Data.ByteString.Char8 (ByteString) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe -import GHC.IO.Exception (IOErrorType (..)) import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Transmission (SMPServer (..)) @@ -48,13 +46,13 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Transport import Simplex.Messaging.Util (bshow, liftEitherError, raceAny_) import System.IO -import System.IO.Error import System.Timeout data SMPClient = SMPClient { action :: Async (), connected :: TVar Bool, smpServer :: SMPServer, + tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TVar (Map CorrId Request), sndQ :: TBQueue SignedRawTransmission, @@ -78,7 +76,7 @@ smpDefaultConfig = SMPClientConfig { qSize = 16, defaultPort = "5223", - tcpTimeout = 2_000_000, + tcpTimeout = 4_000_000, smpPing = 30_000_000, blockSize = 8_192, -- 16_384, smpCommandSize = 256 @@ -86,28 +84,29 @@ smpDefaultConfig = data Request = Request { queueId :: QueueId, - responseVar :: TMVar (Either SMPClientError Cmd) + responseVar :: TMVar Response } -getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO SMPClient +type Response = Either SMPClientError Cmd + +getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient) getSMPClient smpServer@SMPServer {host, port, keyHash} SMPClientConfig {qSize, defaultPort, tcpTimeout, smpPing} msgQ disconnected = do c <- atomically mkSMPClient - started <- newEmptyTMVarIO + err <- newEmptyTMVarIO action <- async $ - runTCPClient host (fromMaybe defaultPort port) (client c started) - `finally` atomically (putTMVar started False) - tcpTimeout `timeout` atomically (takeTMVar started) >>= \case - Just True -> return c {action} - _ -> throwIO err -- TODO report handshake error too, not only connection timeout + runTCPClient host (fromMaybe defaultPort port) (client c err) + `finally` atomically (putTMVar err $ Just SMPNetworkError) + ok <- tcpTimeout `timeout` atomically (takeTMVar err) + pure $ case ok of + Just Nothing -> Right c {action} + Just (Just e) -> Left e + Nothing -> Left SMPNetworkError where - err :: IOException - err = mkIOError TimeExpired "connection timeout" Nothing Nothing - mkSMPClient :: STM SMPClient mkSMPClient = do connected <- newTVar False @@ -120,6 +119,7 @@ getSMPClient { action = undefined, connected, smpServer, + tcpTimeout, clientCorrId, sentCommands, sndQ, @@ -127,18 +127,17 @@ getSMPClient msgQ } - client :: SMPClient -> TMVar Bool -> Handle -> IO () - client c started h = + client :: SMPClient -> TMVar (Maybe SMPClientError) -> Handle -> IO () + client c err h = runExceptT (clientHandshake h keyHash) >>= \case - Right th -> clientTransport c started th - -- TODO report error instead of True/False - Left _ -> atomically $ putTMVar started False + Right th -> clientTransport c err th + Left e -> atomically . putTMVar err . Just $ SMPTransportError e - clientTransport :: SMPClient -> TMVar Bool -> THandle -> IO () - clientTransport c started th = do + clientTransport :: SMPClient -> TMVar (Maybe SMPClientError) -> THandle -> IO () + clientTransport c err th = do atomically $ do writeTVar (connected c) True - putTMVar started True + putTMVar err Nothing raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected @@ -171,7 +170,7 @@ getSMPClient Left e -> Left $ SMPResponseError e Right (Cmd _ (ERR e)) -> Left $ SMPServerError e Right r -> Right r - else Left SMPQueueIdError + else Left SMPUnexpectedResponse closeSMPClient :: SMPClient -> IO () closeSMPClient = uninterruptibleCancel . action @@ -179,11 +178,11 @@ closeSMPClient = uninterruptibleCancel . action data SMPClientError = SMPServerError ErrorType | SMPResponseError ErrorType - | SMPQueueIdError | SMPUnexpectedResponse | SMPResponseTimeout - | SMPCryptoError RSA.Error - | SMPClientError + | SMPNetworkError + | SMPTransportError TransportError + | SMPSignatureError C.CryptoError deriving (Eq, Show, Exception) createSMPQueue :: @@ -235,7 +234,7 @@ okSMPCommand cmd c pKey qId = _ -> throwE SMPUnexpectedResponse sendSMPCommand :: SMPClient -> Maybe C.SafePrivateKey -> QueueId -> Cmd -> ExceptT SMPClientError IO Cmd -sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do +sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ serializeTransmission (corrId, qId, cmd) ExceptT $ sendRecv corrId t @@ -253,14 +252,16 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId} pKey qId cmd = do signTransmission t = case pKey of Nothing -> return ("", t) Just pk -> do - sig <- liftEitherError SMPCryptoError $ C.sign pk t + sig <- liftEitherError SMPSignatureError $ C.sign pk t return (sig, t) -- two separate "atomically" needed to avoid blocking - sendRecv :: CorrId -> SignedRawTransmission -> IO (Either SMPClientError Cmd) - sendRecv corrId t = atomically (send corrId t) >>= atomically . takeTMVar + sendRecv :: CorrId -> SignedRawTransmission -> IO Response + sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar + where + withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a - send :: CorrId -> SignedRawTransmission -> STM (TMVar (Either SMPClientError Cmd)) + send :: CorrId -> SignedRawTransmission -> STM (TMVar Response) send corrId t = do r <- newEmptyTMVar modifyTVar sentCommands . M.insert corrId $ Request qId r diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 5b8b85b53..766c46632 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -68,7 +68,7 @@ import Data.ASN1.Encoding import Data.ASN1.Types import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Bifunctor (first) +import Data.Bifunctor (bimap, first) import qualified Data.ByteArray as BA import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -85,7 +85,7 @@ import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) import Simplex.Messaging.Parsers (base64P, base64StringP, parseAll) -import Simplex.Messaging.Util (liftEitherError, (<$$>)) +import Simplex.Messaging.Util (liftEitherError) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) @@ -148,10 +148,12 @@ instance IsString Signature where newtype Verified = Verified ByteString deriving (Show) data CryptoError - = CryptoRSAError R.Error - | CryptoCipherError CE.CryptoError + = RSAEncryptError R.Error + | RSADecryptError R.Error + | RSASignError R.Error + | AESCipherError CE.CryptoError | CryptoIVError - | CryptoDecryptError + | AESDecryptError | CryptoLargeMsgError | CryptoHeaderError String deriving (Eq, Show, Exception) @@ -276,7 +278,7 @@ encryptAES aesKey ivBytes paddedSize msg = do decryptAES :: Key -> IV -> ByteString -> AES.AuthTag -> ExceptT CryptoError IO ByteString decryptAES aesKey ivBytes msg authTag = do aead <- initAEAD @AES256 aesKey ivBytes - maybeError CryptoDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag + maybeError AESDecryptError $ AES.aeadSimpleDecrypt aead B.empty msg authTag initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c) initAEAD (Key aesKey) (IV ivBytes) = do @@ -307,26 +309,26 @@ bsToAuthTag :: ByteString -> AES.AuthTag bsToAuthTag = AES.AuthTag . BA.pack . map c2w . B.unpack cryptoFailable :: CE.CryptoFailable a -> ExceptT CryptoError IO a -cryptoFailable = liftEither . first CryptoCipherError . CE.eitherCryptoError +cryptoFailable = liftEither . first AESCipherError . CE.eitherCryptoError oaepParams :: OAEP.OAEPParams SHA256 ByteString ByteString oaepParams = OAEP.defaultOAEPParams SHA256 encryptOAEP :: PublicKey -> ByteString -> ExceptT CryptoError IO ByteString encryptOAEP (PublicKey k) aesKey = - liftEitherError CryptoRSAError $ + liftEitherError RSAEncryptError $ OAEP.encrypt oaepParams k aesKey decryptOAEP :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decryptOAEP pk encKey = - liftEitherError CryptoRSAError $ + liftEitherError RSADecryptError $ OAEP.decryptSafer oaepParams (rsaPrivateKey pk) encKey pssParams :: PSS.PSSParams SHA256 ByteString ByteString pssParams = PSS.defaultPSSParams SHA256 -sign :: PrivateKey k => k -> ByteString -> IO (Either R.Error Signature) -sign pk msg = Signature <$$> PSS.signSafer pssParams (rsaPrivateKey pk) msg +sign :: PrivateKey k => k -> ByteString -> IO (Either CryptoError Signature) +sign pk msg = bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 74a802435..4f96bae0b 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE OverloadedStrings #-} + module Simplex.Messaging.Parsers where import Data.Attoparsec.ByteString.Char8 (Parser) @@ -9,6 +11,7 @@ import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) +import Text.Read (readMaybe) base64P :: Parser ByteString base64P = either fail pure . decode =<< base64StringP @@ -27,3 +30,15 @@ parse parser err = first (const err) . parseAll parser parseAll :: Parser a -> (ByteString -> Either String a) parseAll parser = A.parseOnly (parser <* A.endOfInput) + +parseRead :: Read a => Parser ByteString -> Parser a +parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack) + +parseRead1 :: Read a => Parser a +parseRead1 = parseRead $ A.takeTill (== ' ') + +parseRead2 :: Read a => Parser a +parseRead2 = parseRead $ do + w1 <- A.takeTill (== ' ') <* A.char ' ' + w2 <- A.takeTill (== ' ') + pure $ w1 <> " " <> w2 diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 2a4ea433e..d7f139c18 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -24,10 +25,13 @@ import Data.Kind import Data.String import Data.Time.Clock import Data.Time.ISO8601 +import GHC.Generics (Generic) +import Generic.Random (genericArbitraryU) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers import Simplex.Messaging.Transport import Simplex.Messaging.Util +import Test.QuickCheck (Arbitrary (..)) data Party = Broker | Recipient | Sender deriving (Show) @@ -106,25 +110,26 @@ type MsgId = Encoded type MsgBody = ByteString -data ErrorType = PROHIBITED | SYNTAX Int | AUTH | INTERNAL | DUPLICATE deriving (Show, Eq) +data ErrorType + = BLOCK + | CMD CommandError + | AUTH + | NO_MSG + | INTERNAL + | DUPLICATE_ -- TODO remove, not part of SMP protocol + deriving (Eq, Generic, Read, Show) -errBadTransmission :: Int -errBadTransmission = 1 +data CommandError + = PROHIBITED + | SYNTAX + | NO_AUTH + | HAS_AUTH + | NO_QUEUE + deriving (Eq, Generic, Read, Show) -errBadSMPCommand :: Int -errBadSMPCommand = 2 +instance Arbitrary ErrorType where arbitrary = genericArbitraryU -errNoCredentials :: Int -errNoCredentials = 3 - -errHasCredentials :: Int -errHasCredentials = 4 - -errNoQueueId :: Int -errNoQueueId = 5 - -errMessageBody :: Int -errMessageBody = 6 +instance Arbitrary CommandError where arbitrary = genericArbitraryU transmissionP :: Parser RawTransmission transmissionP = do @@ -164,16 +169,11 @@ commandP = ts <- tsISO8601P <* A.space size <- A.decimal <* A.space Cmd SBroker . MSG msgId ts <$> A.take size <* A.space - serverError = Cmd SBroker . ERR <$> errorType - errorType = - "PROHIBITED" $> PROHIBITED - <|> "SYNTAX " *> (SYNTAX <$> A.decimal) - <|> "AUTH" $> AUTH - <|> "INTERNAL" $> INTERNAL + serverError = Cmd SBroker . ERR <$> errorTypeP -- TODO ignore the end of block, no need to parse it parseCommand :: ByteString -> Either ErrorType Cmd -parseCommand = parse (commandP <* " " <* A.takeByteString) $ SYNTAX errBadSMPCommand +parseCommand = parse (commandP <* " " <* A.takeByteString) $ CMD SYNTAX serializeCommand :: Cmd -> ByteString serializeCommand = \case @@ -185,11 +185,17 @@ serializeCommand = \case Cmd SBroker (MSG msgId ts msgBody) -> B.unwords ["MSG", encode msgId, B.pack $ formatISO8601Millis ts, serializeMsg msgBody] Cmd SBroker (IDS rId sId) -> B.unwords ["IDS", encode rId, encode sId] - Cmd SBroker (ERR err) -> "ERR " <> bshow err + Cmd SBroker (ERR err) -> "ERR " <> serializeErrorType err Cmd SBroker resp -> bshow resp where serializeMsg msgBody = bshow (B.length msgBody) <> " " <> msgBody <> " " +errorTypeP :: Parser ErrorType +errorTypeP = "CMD " *> (CMD <$> parseRead1) <|> parseRead1 + +serializeErrorType :: ErrorType -> ByteString +serializeErrorType = bshow + tPut :: THandle -> SignedRawTransmission -> IO (Either TransportError ()) tPut th (C.Signature sig, t) = tPutEncrypted th $ encode sig <> " " <> t <> " " @@ -200,16 +206,16 @@ serializeTransmission (CorrId corrId, queueId, command) = fromClient :: Cmd -> Either ErrorType Cmd fromClient = \case - Cmd SBroker _ -> Left PROHIBITED + Cmd SBroker _ -> Left $ CMD PROHIBITED cmd -> Right cmd fromServer :: Cmd -> Either ErrorType Cmd fromServer = \case cmd@(Cmd SBroker _) -> Right cmd - _ -> Left PROHIBITED + _ -> Left $ CMD PROHIBITED tGetParse :: THandle -> IO (Either TransportError RawTransmission) -tGetParse th = (>>= parse transmissionP TransportParsingError) <$> tGetEncrypted th +tGetParse th = (>>= parse transmissionP TEBadBlock) <$> tGetEncrypted th -- | get client and server transmissions -- `fromParty` is used to limit allowed senders - `fromClient` or `fromServer` should be used @@ -224,7 +230,7 @@ tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate Left _ -> tError "" tError :: ByteString -> m SignedTransmissionOrError - tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left $ SYNTAX errBadTransmission)) + tError corrId = return (C.Signature B.empty, (CorrId corrId, B.empty, Left BLOCK)) tParseValidate :: RawTransmission -> m SignedTransmissionOrError tParseValidate t@(sig, corrId, queueId, command) = do @@ -233,32 +239,32 @@ tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate tCredentials :: RawTransmission -> Cmd -> Either ErrorType Cmd tCredentials (signature, _, queueId, _) cmd = case cmd of - -- IDS response should not have queue ID + -- IDS response must not have queue ID Cmd SBroker (IDS _ _) -> Right cmd -- ERR response does not always have queue ID Cmd SBroker (ERR _) -> Right cmd - -- PONG response should not have queue ID + -- PONG response must not have queue ID Cmd SBroker PONG | B.null queueId -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other responses must have queue ID Cmd SBroker _ - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd -- NEW must NOT have signature or queue ID Cmd SRecipient (NEW _) - | B.null signature -> Left $ SYNTAX errNoCredentials - | not (B.null queueId) -> Left $ SYNTAX errHasCredentials + | B.null signature -> Left $ CMD NO_AUTH + | not (B.null queueId) -> Left $ CMD HAS_AUTH | otherwise -> Right cmd -- SEND must have queue ID, signature is not always required Cmd SSender (SEND _) - | B.null queueId -> Left $ SYNTAX errNoQueueId + | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd -- PING must not have queue ID or signature Cmd SSender PING | B.null queueId && B.null signature -> Right cmd - | otherwise -> Left $ SYNTAX errHasCredentials + | otherwise -> Left $ CMD HAS_AUTH -- other client commands must have both signature and queue ID Cmd SRecipient _ - | B.null signature || B.null queueId -> Left $ SYNTAX errNoCredentials + | B.null signature || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index c78eac3c1..3b6d6d8a3 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -165,7 +165,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = addQueueRetry n = do ids <- getIds atomically (addQueue st rKey ids) >>= \case - Left DUPLICATE -> addQueueRetry $ n - 1 + Left DUPLICATE_ -> addQueueRetry $ n - 1 Left e -> return $ Left e Right _ -> return $ Right ids @@ -200,7 +200,7 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s)) >>= \case Just (Just s) -> deliverMessage tryDelPeekMsg queueId s - _ -> return $ err PROHIBITED + _ -> return $ err NO_MSG withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a) withSub rId f = readTVar subscriptions >>= mapM f . M.lookup rId diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index dca596e82..86caff78f 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -32,7 +32,7 @@ instance MonadQueueStore QueueStore STM where addQueue store rKey ids@(rId, sId) = do cs@QueueStoreData {queues, senders} <- readTVar store if M.member rId queues || M.member sId senders - then return $ Left DUPLICATE + then return $ Left DUPLICATE_ else do writeTVar store $ cs diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 4d5d63e32..d0468886e 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -1,6 +1,7 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -11,6 +12,7 @@ module Simplex.Messaging.Transport where +import Control.Applicative ((<|>)) import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Trans.Except (throwE) @@ -21,18 +23,22 @@ import Data.Bifunctor (first) import Data.ByteArray (xor) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Functor (($>)) import Data.Set (Set) import qualified Data.Set as S import Data.Word (Word32) +import GHC.Generics (Generic) import GHC.IO.Exception (IOErrorType (..)) import GHC.IO.Handle.Internals (ioe_EOF) +import Generic.Random (genericArbitraryU) import Network.Socket import Network.Transport.Internal (encodeWord32) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Parsers (parse, parseAll) +import Simplex.Messaging.Parsers (parse, parseAll, parseRead1) import Simplex.Messaging.Util (bshow, liftError) import System.IO import System.IO.Error +import Test.QuickCheck (Arbitrary (..)) import UnliftIO.Concurrent import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E @@ -50,7 +56,10 @@ runTCPServer started port server = do atomically . modifyTVar clients $ S.insert tid where closeServer :: TVar (Set ThreadId) -> Socket -> IO () - closeServer clients sock = readTVarIO clients >>= mapM_ killThread >> close sock + closeServer clients sock = do + readTVarIO clients >>= mapM_ killThread + close sock + void . atomically $ tryPutTMVar started False startTCPServer :: TMVar Bool -> ServiceName -> IO Socket startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted @@ -76,9 +85,7 @@ runTCPClient host port client = do client h `E.finally` IO.hClose h startTCPClient :: HostName -> ServiceName -> IO Handle -startTCPClient host port = - withSocketsDo $ - resolve >>= foldM tryOpen (Left err) >>= either E.throwIO return -- replace fold with recursion +startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err where err :: IOException err = mkIOError NoSuchThing "no address" Nothing Nothing @@ -88,9 +95,10 @@ startTCPClient host port = let hints = defaultHints {addrSocketType = Stream} in getAddrInfo (Just hints) (Just host) (Just port) - tryOpen :: Exception e => Either e Handle -> AddrInfo -> IO (Either e Handle) - tryOpen (Left _) addr = E.try $ open addr - tryOpen h _ = return h + tryOpen :: IOException -> [AddrInfo] -> IO Handle + tryOpen e [] = E.throwIO e + tryOpen _ (addr : as) = + E.try (open addr) >>= either (`tryOpen` as) pure open :: AddrInfo -> IO Handle open addr = do @@ -153,21 +161,51 @@ data HandshakeKeys = HandshakeKeys } data TransportError - = TransportCryptoError C.CryptoError - | TransportParsingError - | TransportHandshakeError String - deriving (Eq, Show, Exception) + = TEBadBlock + | TEEncrypt + | TEDecrypt + | TEHandshake HandshakeError + deriving (Eq, Generic, Read, Show, Exception) + +data HandshakeError + = ENCRYPT + | DECRYPT + | VERSION + | RSA_KEY + | AES_KEYS + | BAD_HASH + | MAJOR_VERSION + | TERMINATED + deriving (Eq, Generic, Read, Show, Exception) + +instance Arbitrary TransportError where arbitrary = genericArbitraryU + +instance Arbitrary HandshakeError where arbitrary = genericArbitraryU + +transportErrorP :: Parser TransportError +transportErrorP = + "BLOCK" $> TEBadBlock + <|> "AES_ENCRYPT" $> TEEncrypt + <|> "AES_DECRYPT" $> TEDecrypt + <|> TEHandshake <$> parseRead1 + +serializeTransportError :: TransportError -> ByteString +serializeTransportError = \case + TEEncrypt -> "AES_ENCRYPT" + TEDecrypt -> "AES_DECRYPT" + TEBadBlock -> "BLOCK" + TEHandshake e -> bshow e tPutEncrypted :: THandle -> ByteString -> IO (Either TransportError ()) tPutEncrypted THandle {handle = h, sndKey, blockSize} block = encryptBlock sndKey (blockSize - C.authTagSize) block >>= \case - Left e -> return . Left $ TransportCryptoError e + Left _ -> pure $ Left TEEncrypt Right (authTag, msg) -> Right <$> B.hPut h (C.authTagToBS authTag <> msg) tGetEncrypted :: THandle -> IO (Either TransportError ByteString) tGetEncrypted THandle {handle = h, rcvKey, blockSize} = B.hGet h blockSize >>= decryptBlock rcvKey >>= \case - Left e -> pure . Left $ TransportCryptoError e + Left _ -> pure $ Left TEDecrypt Right "" -> ioe_EOF Right msg -> pure $ Right msg @@ -207,11 +245,11 @@ serverHandshake h (k, pk) = do receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString receiveEncryptedKeys_4 = liftIO (B.hGet h $ C.publicKeySize k) >>= \case - "" -> throwE $ TransportHandshakeError "EOF" + "" -> throwE $ TEHandshake TERMINATED ks -> pure ks decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO HandshakeKeys decryptParseKeys_5 encKeys = - liftError TransportCryptoError (C.decryptOAEP pk encKeys) + liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys) >>= liftEither . parseHandshakeKeys sendWelcome_6 :: THandle -> ExceptT TransportError IO () sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " " @@ -233,11 +271,11 @@ clientHandshake h keyHash = do maybe (pure ()) (validateKeyHash_2 s) keyHash liftEither $ parseKey s parseKey :: ByteString -> Either TransportError C.PublicKey - parseKey = first TransportHandshakeError . parseAll C.pubKeyP + parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.pubKeyP validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () validateKeyHash_2 k kHash | C.getKeyHash k == kHash = pure () - | otherwise = throwE $ TransportHandshakeError "wrong key hash" + | otherwise = throwE $ TEHandshake BAD_HASH generateKeys_3 :: IO HandshakeKeys generateKeys_3 = HandshakeKeys <$> generateKey <*> generateKey generateKey :: IO SessionKey @@ -247,16 +285,16 @@ clientHandshake h keyHash = do pure SessionKey {aesKey, baseIV, counter = undefined} sendEncryptedKeys_4 :: C.PublicKey -> HandshakeKeys -> ExceptT TransportError IO () sendEncryptedKeys_4 k keys = - liftError TransportCryptoError (C.encryptOAEP k $ serializeHandshakeKeys keys) + liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeHandshakeKeys keys) >>= liftIO . B.hPut h getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th parseSMPVersion :: ByteString -> Either TransportError SMPVersion - parseSMPVersion = first TransportHandshakeError . A.parseOnly (smpVersionP <* A.space) + parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space) checkVersion :: SMPVersion -> ExceptT TransportError IO () checkVersion smpVersion = when (major smpVersion > major currentSMPVersion) . throwE $ - TransportHandshakeError "SMP server version" + TEHandshake MAJOR_VERSION serializeHandshakeKeys :: HandshakeKeys -> ByteString serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = @@ -275,7 +313,7 @@ handshakeKeysP = HandshakeKeys <$> keyP <*> keyP pure SessionKey {aesKey, baseIV, counter = undefined} parseHandshakeKeys :: ByteString -> Either TransportError HandshakeKeys -parseHandshakeKeys = parse handshakeKeysP $ TransportHandshakeError "parsing keys" +parseHandshakeKeys = parse handshakeKeysP $ TEHandshake AES_KEYS transportHandle :: Handle -> SessionKey -> SessionKey -> IO THandle transportHandle h sk rk = do diff --git a/src/Simplex/Messaging/errors.md b/src/Simplex/Messaging/errors.md new file mode 100644 index 000000000..6e83e6d61 --- /dev/null +++ b/src/Simplex/Messaging/errors.md @@ -0,0 +1,97 @@ +# Errors + +## Problems + +- using numbers and strings to indicate errors (in protocol and in code) - ErrorType, AgentErrorType, TransportError +- re-using the same type in multiple contexts (with some constructors not applicable to all contexts) - ErrorType + +## Error types + +### ErrorType (Protocol.hs) + +- BLOCK - incorrect block format or encoding +- CMD error - command is unknown or has invalid syntax, where `error` can be: + - PROHIBITED - server response sent from client or vice versa + - SYNTAX - error parsing command + - NO_AUTH - transmission has no required credentials (signature or queue ID) + - HAS_AUTH - transmission has not allowed credentials + - NO_QUEUE - transmission has not queue ID +- AUTH - command is not authorised (queue does not exist or signature verification failed). +- NO_MSG - acknowledging (ACK) the message without message +- INTERNAL - internal server error. +- DUPLICATE_ - it is used internally to signal that the queue ID is already used. This is NOT used in the protocol, instead INTERNAL is sent to the client. It has to be removed. + +### AgentErrorType (Agent/Transmission.hs) + +Some of these errors are not correctly serialized/parsed - see line 322 in Agent/Transmission.hs + +- CMD e - command or response error + - PROHIBITED - server response sent as client command (and vice versa) + - SYNTAX - command is unknown or has invalid syntax. + - NO_CONN - connection is required in the command (and absent) + - SIZE - incorrect message size of messages (when parsing SEND and MSG) + - LARGE -- message does not fit SMP block +- CONN e - connection errors + - UNKNOWN - connection alias not in database + - DUPLICATE - connection alias already exists + - SIMPLEX - connection is simplex, but operation requires another queue +- SMP ErrorType - forwarding SMP errors (SMPServerError) to the agent client +- BROKER e - SMP server errors + - RESPONSE ErrorType - invalid SMP server response + - UNEXPECTED - unexpected response + - NETWORK - network TCP connection error + - TRANSPORT TransportError -- handshake or other transport error + - TIMEOUT - command response timeout +- AGENT e - errors of other agents + - A_MESSAGE - SMP message failed to parse + - A_PROHIBITED - SMP message is prohibited with the current queue status + - A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header +- INTERNAL ByteString - agent implementation or dependency error + +### SMPClientError (Client.hs) + +- SMPServerError ErrorType - this is correctly parsed server ERR response. This error is forwarded to the agent client as `ERR SMP err` +- SMPResponseError ErrorType - this is invalid server response that failed to parse - forwarded to the client as `ERR BROKER RESPONSE`. +- SMPUnexpectedResponse - different response from what is expected to a given command, e.g. server should respond `IDS` or `ERR` to `NEW` command, other responses would result in this error - forwarded to the client as `ERR BROKER UNEXPECTED`. +- SMPResponseTimeout - used for TCP connection and command response timeouts -> `ERR BROKER TIMEOUT`. +- SMPNetworkError - fails to establish TCP connection -> `ERR BROKER NETWORK` +- SMPTransportError e - fails connection handshake or some other transport error -> `ERR BROKER TRANSPORT e` +- SMPSignatureError C.CryptoError - error when cryptographically "signing" the command. + +### StoreError (Agent/Store.hs) + +- SEInternal ByteString - signals exceptions in store actions. +- SEConnNotFound - connection alias not found (or both queues absent). +- SEConnDuplicate - connection alias already used. +- SEBadConnType ConnType - wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. `updateRcvConnWithSndQueue` and `updateSndConnWithRcvQueue` do not allow duplex connections - they would also return this error. +- SEBadQueueStatus - the intention was to pass current expected queue status in methods, as we always know what it should be at any stage of the protocol, and in case it does not match use this error. **Currently not used**. +- SENotImplemented - used in `getMsg` that is not implemented/used. + +### CryptoError (Crypto.hs) + +- RSAEncryptError R.Error - RSA encryption error +- RSADecryptError R.Error - RSA decryption error +- RSASignError R.Error - RSA signature error +- AESCipherError CE.CryptoError - AES initialization error +- CryptoIVError - IV generation error +- AESDecryptError - AES decryption error +- CryptoLargeMsgError - message does not fit in SMP block +- CryptoHeaderError String - failure parsing RSA-encrypted message header + +### TransportError (Transport.hs) + + - TEBadBlock - error parsing block + - TEEncrypt - block encryption error + - TEDecrypt - block decryption error + - TEHandshake HandshakeError + +### HandshakeError (Transport.hs) + + - ENCRYPT - encryption error + - DECRYPT - decryption error + - VERSION - error parsing protocol version + - RSA_KEY - error parsing RSA key + - AES_KEYS - error parsing AES keys + - BAD_HASH - not matching RSA key hash + - MAJOR_VERSION - lower agent version than protocol version + - TERMINATED - transport terminated diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index ea33e6f2f..ac8686439 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -132,7 +132,7 @@ samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDl syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR SYNTAX 11") + it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX") describe "NEW" do describe "valid" do -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) @@ -143,9 +143,9 @@ syntaxTests = do it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> teshKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias - it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR SYNTAX 11") - it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR SYNTAX 11") - it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR SYNTAX 11") + it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX") + it "many parameters" $ ("222", "", "NEW localhost:5000 hi") >#> ("222", "", "ERR CMD SYNTAX") + it "invalid server keyHash" $ ("223", "", "NEW localhost:5000#1") >#> ("223", "", "ERR CMD SYNTAX") describe "JOIN" do describe "valid" do @@ -155,4 +155,4 @@ syntaxTests = do ("311", "", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "", "ERR SMP AUTH") describe "invalid" do -- TODO: JOIN is not merged yet - to be added - it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR SYNTAX 11") + it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index ef9f59e14..8f33d5965 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -66,14 +66,14 @@ storeTests = withStore do describe "setRcvQueueStatus" testSetRcvQueueStatus describe "setSndQueueStatus" testSetSndQueueStatus describe "DuplexConnection" testSetQueueStatusDuplex - xdescribe "RcvQueue doesn't exist" testSetRcvQueueStatusNoQueue - xdescribe "SndQueue doesn't exist" testSetSndQueueStatusNoQueue + xdescribe "RcvQueue does not exist" testSetRcvQueueStatusNoQueue + xdescribe "SndQueue does not exist" testSetSndQueueStatusNoQueue describe "createRcvMsg" do describe "RcvQueue exists" testCreateRcvMsg - describe "RcvQueue doesn't exist" testCreateRcvMsgNoQueue + describe "RcvQueue does not exist" testCreateRcvMsgNoQueue describe "createSndMsg" do describe "SndQueue exists" testCreateSndMsg - describe "SndQueue doesn't exist" testCreateSndMsgNoQueue + describe "SndQueue does not exist" testCreateSndMsgNoQueue testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = do @@ -175,7 +175,7 @@ testDeleteRcvConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = do @@ -188,7 +188,7 @@ testDeleteSndConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = do @@ -203,7 +203,7 @@ testDeleteDuplexConn = do `returnsResult` () -- TODO check queues are deleted as well getConn store "conn1" - `throwsError` SEBadConn + `throwsError` SEConnNotFound testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = do @@ -298,15 +298,15 @@ testSetQueueStatusDuplex = do testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent RcvQueue" $ \store -> do + it "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do setRcvQueueStatus store rcvQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEInternal "" testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore testSetSndQueueStatusNoQueue = do - it "should throw error on attempt to update status of nonexistent SndQueue" $ \store -> do + it "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do setSndQueueStatus store sndQueue1 Confirmed - `throwsError` SEInternal + `throwsError` SEInternal "" testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = do @@ -323,7 +323,7 @@ testCreateRcvMsgNoQueue = do it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConn + `throwsError` SEConnNotFound createSndConn store sndQueue1 `returnsResult` () createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) @@ -344,7 +344,7 @@ testCreateSndMsgNoQueue = do it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConn + `throwsError` SEConnNotFound createRcvConn store rcvQueue1 `returnsResult` () createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts diff --git a/tests/ProtocolErrorTests.hs b/tests/ProtocolErrorTests.hs new file mode 100644 index 000000000..82e4afbf7 --- /dev/null +++ b/tests/ProtocolErrorTests.hs @@ -0,0 +1,18 @@ +module ProtocolErrorTests where + +import Simplex.Messaging.Agent.Transmission (AgentErrorType, agentErrorTypeP, serializeAgentError) +import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Protocol (ErrorType, errorTypeP, serializeErrorType) +import Test.Hspec +import Test.Hspec.QuickCheck (modifyMaxSuccess) +import Test.QuickCheck + +protocolErrorTests :: Spec +protocolErrorTests = modifyMaxSuccess (const 1000) $ do + describe "errors parsing / serializing" $ do + it "should parse SMP protocol errors" . property $ \err -> + parseAll errorTypeP (serializeErrorType err) + == Right (err :: ErrorType) + it "should parse SMP agent errors" . property $ \err -> + parseAll agentErrorTypeP (serializeAgentError err) + == Right (err :: AgentErrorType) diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 5f7175871..9bf0bf187 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -6,23 +6,19 @@ module SMPAgentClient where -import Control.Monad import Control.Monad.IO.Unlift import Crypto.Random import Network.Socket (HostName, ServiceName) -import SMPClient (testPort, withSmpServer, withSmpServerThreadOn) +import SMPClient (serverBracket, testPort, withSmpServer, withSmpServerThreadOn) import Simplex.Messaging.Agent (runSMPAgentBlocking) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig) import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import UnliftIO.Directory -import qualified UnliftIO.Exception as E import UnliftIO.IO -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) agentTestHost :: HostName agentTestHost = "localhost" @@ -125,12 +121,10 @@ cfg = } withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> (ThreadId -> m a) -> m a -withSmpAgentThreadOn (port', db') f = do - started <- newEmptyTMVarIO - E.bracket - (forkIOWithUnmask ($ runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'})) - (liftIO . killThread >=> const (removeFile db')) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x +withSmpAgentThreadOn (port', db') = + serverBracket + (\started -> runSMPAgentBlocking started cfg {tcpPort = port', dbFile = db'}) + (removeFile db') withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => (ServiceName, String) -> m a -> m a withSmpAgentOn (port', db') = withSmpAgentThreadOn (port', db') . const diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 0e24a53f7..c6e6af28e 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -18,11 +18,11 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server (runSMPServerBlocking) import Simplex.Messaging.Server.Env.STM import Simplex.Messaging.Transport -import System.Timeout (timeout) import Test.Hspec import UnliftIO.Concurrent import qualified UnliftIO.Exception as E -import UnliftIO.STM (atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.Timeout (timeout) testHost :: HostName testHost = "localhost" @@ -83,12 +83,23 @@ cfg = } withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a -withSmpServerThreadOn port f = do +withSmpServerThreadOn port = + serverBracket + (\started -> runSMPServerBlocking started cfg {tcpPort = port}) + (pure ()) + +serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a +serverBracket process afterProcess f = do started <- newEmptyTMVarIO E.bracket - (forkIOWithUnmask ($ runSMPServerBlocking started cfg {tcpPort = port})) - (liftIO . killThread) - \x -> liftIO (5_000_000 `timeout` atomically (takeTMVar started)) >> f x + (forkIOWithUnmask ($ process started)) + (\t -> killThread t >> afterProcess >> waitFor started "stop") + (\t -> waitFor started "start" >> f t) + where + waitFor started s = + 5_000_000 `timeout` atomically (takeTMVar started) >>= \case + Nothing -> error $ "server did not " <> s + _ -> pure () withSmpServerOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> m a -> m a withSmpServerOn port = withSmpServerThreadOn port . const diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 24c6af342..50d3759cf 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -72,7 +72,7 @@ testCreateSecure = (ok4, OK) #== "replies OK when message acknowledged if no more messages" Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, "ACK") - (err6, ERR PROHIBITED) #== "replies ERR when message acknowledged without messages" + (err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages" (sPub, sKey) <- C.generateKeyPair rsaKeySize Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, "SEND 5 hello ") @@ -252,7 +252,7 @@ testSwitchSub = (msg3, "test3") #== "delivered to the 2nd TCP connection" Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, "ACK") - (err, ERR PROHIBITED) #== "rejects ACK from the 1st TCP connection" + (err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection" Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, "ACK") (ok3, OK) #== "accepts ACK from the 2nd TCP connection" @@ -263,18 +263,18 @@ testSwitchSub = syntaxTests :: Spec syntaxTests = do - it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR SYNTAX 2") + it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") describe "NEW" do - it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR SYNTAX 2") - it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR SYNTAX 3") - it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR SYNTAX 4") + it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR CMD SYNTAX") + it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR CMD NO_AUTH") + it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") describe "KEY" do it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR SYNTAX 3") + it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR CMD NO_AUTH") noParamsSyntaxTest "SUB" noParamsSyntaxTest "ACK" noParamsSyntaxTest "OFF" @@ -282,19 +282,19 @@ syntaxTests = do describe "SEND" do it "valid syntax 1" $ ("1234", "cdab", "12345678", "SEND 5 hello ") >#> ("", "cdab", "12345678", "ERR AUTH") it "valid syntax 2" $ ("1234", "dabc", "12345678", "SEND 11 hello there ") >#> ("", "dabc", "12345678", "ERR AUTH") - it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") - it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR SYNTAX 5") - it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR SYNTAX 2") - it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR SYNTAX 2") - it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR SYNTAX 2") + it "no parameters" $ ("1234", "abcd", "12345678", "SEND") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") + it "no queue ID" $ ("1234", "bcda", "", "SEND 5 hello ") >#> ("", "bcda", "", "ERR CMD NO_QUEUE") + it "bad message body 1" $ ("1234", "cdab", "12345678", "SEND 11 hello ") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") + it "bad message body 2" $ ("1234", "dabc", "12345678", "SEND hello ") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "bigger body" $ ("1234", "abcd", "12345678", "SEND 4 hello ") >#> ("", "abcd", "12345678", "ERR CMD SYNTAX") describe "PING" do it "valid syntax" $ ("", "abcd", "", "PING") >#> ("", "abcd", "", "PONG") describe "broker response not allowed" do - it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR PROHIBITED") + it "OK" $ ("1234", "bcda", "12345678", "OK") >#> ("", "bcda", "12345678", "ERR CMD PROHIBITED") where noParamsSyntaxTest :: ByteString -> Spec noParamsSyntaxTest cmd = describe (B.unpack cmd) do it "valid syntax" $ ("1234", "abcd", "12345678", cmd) >#> ("", "abcd", "12345678", "ERR AUTH") - it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR SYNTAX 2") - it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR SYNTAX 3") - it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR SYNTAX 3") + it "wrong terminator" $ ("1234", "bcda", "12345678", cmd <> "=") >#> ("", "bcda", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "cdab", "12345678", cmd) >#> ("", "cdab", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "dabc", "", cmd) >#> ("", "dabc", "", "ERR CMD NO_AUTH") diff --git a/tests/Test.hs b/tests/Test.hs index 06f495d99..988b734f6 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,5 +1,6 @@ import AgentTests import MarkdownTests +import ProtocolErrorTests import ServerTests import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive) import Test.Hspec @@ -9,6 +10,7 @@ main = do createDirectoryIfMissing False "tests/tmp" hspec $ do describe "SimpleX markdown" markdownTests + describe "Protocol errors" protocolErrorTests describe "SMP server" serverTests describe "SMP client agent" agentTests removeDirectoryRecursive "tests/tmp" From 40ad6db51a0209702f34953caff469410c2cbda1 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin <8711996+efim-poberezkin@users.noreply.github.com> Date: Mon, 19 Apr 2021 00:46:01 +0400 Subject: [PATCH 06/17] return error on creation of duplicate connection (#102) Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> --- src/Simplex/Messaging/Agent/Store/SQLite.hs | 20 +++++++++++------ tests/AgentTests/SQLiteTests.hs | 24 +++++++++++++++++++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 1173e8215..e2fd08e45 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} @@ -21,6 +22,7 @@ where import Control.Monad (when) import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO)) import Control.Monad.IO.Unlift (MonadUnliftIO) +import Data.Bifunctor (first) import Data.List (find) import Data.Maybe (fromMaybe) import Data.Text (isPrefixOf) @@ -38,7 +40,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema) import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Util (liftIOEither) +import Simplex.Messaging.Util (bshow, liftIOEither) import System.Exit (ExitCode (ExitFailure), exitWith) import System.FilePath (takeDirectory) import Text.Read (readMaybe) @@ -75,16 +77,20 @@ connectSQLiteStore dbFilePath = do liftIO $ DB.execute_ dbConn "PRAGMA foreign_keys = ON;" return SQLiteStore {dbFilePath, dbConn} +checkDuplicate :: (MonadUnliftIO m, MonadError StoreError m) => IO () -> m () +checkDuplicate action = liftIOEither $ first handleError <$> E.try action + where + handleError :: SQLError -> StoreError + handleError e + | DB.sqlError e == DB.ErrorConstraint = SEConnDuplicate + | otherwise = SEInternal $ bshow e + instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} rcvQueue = - liftIO $ - createRcvQueueAndConn dbConn rcvQueue + createRcvConn SQLiteStore {dbConn} = checkDuplicate . createRcvQueueAndConn dbConn createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} sndQueue = - liftIO $ - createSndQueueAndConn dbConn sndQueue + createSndConn SQLiteStore {dbConn} = checkDuplicate . createSndQueueAndConn dbConn getConn :: SQLiteStore -> ConnAlias -> m SomeConn getConn SQLiteStore {dbConn} connAlias = do diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 8f33d5965..a63073c13 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -52,8 +52,12 @@ storeTests = withStore do describe "compiled as threadsafe" testCompiledThreadsafe describe "foreign keys enabled" testForeignKeysEnabled describe "store methods" do - describe "createRcvConn" testCreateRcvConn - describe "createSndConn" testCreateSndConn + describe "createRcvConn" do + describe "unique" testCreateRcvConn + describe "duplicate" testCreateRcvConnDuplicate + describe "createSndConn" do + describe "unique" testCreateSndConn + describe "duplicate" testCreateSndConnDuplicate describe "getAllConnAliases" testGetAllConnAliases describe "getRcvQueue" testGetRcvQueue describe "deleteConn" do @@ -132,6 +136,14 @@ testCreateRcvConn = do getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) +testCreateRcvConnDuplicate :: SpecWith SQLiteStore +testCreateRcvConnDuplicate = do + it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do + createRcvConn store rcvQueue1 + `returnsResult` () + createRcvConn store rcvQueue1 + `throwsError` SEConnDuplicate + testCreateSndConn :: SpecWith SQLiteStore testCreateSndConn = do it "should create SndConnection and add RcvQueue" $ \store -> do @@ -144,6 +156,14 @@ testCreateSndConn = do getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) +testCreateSndConnDuplicate :: SpecWith SQLiteStore +testCreateSndConnDuplicate = do + it "should throw error on attempt to create duplicate SndConnection" $ \store -> do + createSndConn store sndQueue1 + `returnsResult` () + createSndConn store sndQueue1 + `throwsError` SEConnDuplicate + testGetAllConnAliases :: SpecWith SQLiteStore testGetAllConnAliases = do it "should get all conn aliases" $ \store -> do From 3187bc8140ab7d5e69871255a157f443c89b1ba7 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 19 Apr 2021 08:40:23 +0100 Subject: [PATCH 07/17] chat: add connection errors in chat, fix catch (#103) --- apps/dog-food/Main.hs | 6 ++++++ src/Simplex/Messaging/Agent.hs | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index b7127be0d..f4100e9a1 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -90,6 +90,7 @@ data ChatResponse | ReceivedMessage Contact ByteString | Disconnected Contact | YesYes + | ContactError ConnectionErrorType Contact | ErrorInput ByteString | ChatError AgentErrorType | NoChatResponse @@ -110,6 +111,10 @@ serializeChatResponse = \case ReceivedMessage c t -> prependFirst (ttyFromContact c) $ msgPlain t Disconnected c -> ["disconnected from " <> ttyContact c <> " - try \"/chat " <> bPlain (toBs c) <> "\""] YesYes -> ["you got it!"] + ContactError e c -> case e of + UNKNOWN -> ["no contact " <> ttyContact c] + DUPLICATE -> ["contact " <> ttyContact c <> " already exists"] + SIMPLEX -> ["contact " <> ttyContact c <> " did not accept invitation yet"] ErrorInput t -> ["invalid input: " <> bPlain t] ChatError e -> ["chat error: " <> plain (show e)] NoChatResponse -> [""] @@ -256,6 +261,7 @@ receiveFromAgent t ct c = forever . atomically $ do MSG {m_body} -> ReceivedMessage contact m_body SENT _ -> NoChatResponse OK -> Confirmation contact + ERR (CONN e) -> ContactError e contact ERR e -> ChatError e where contact = Contact a diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 59ca26015..3a946529c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -26,6 +26,7 @@ import qualified Data.ByteString.Char8 as B import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) import Data.Time.Clock +import Database.SQLite.Simple (SQLError) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Store @@ -39,7 +40,6 @@ import Simplex.Messaging.Transport (putLn, runTCPServer) import Simplex.Messaging.Util (bshow, liftError) import System.IO (Handle) import UnliftIO.Async (race_) -import UnliftIO.Exception (SomeException) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -116,7 +116,7 @@ withStore action = do Right c -> return c Left e -> throwError $ storeError e where - handleInternal :: (MonadError StoreError m') => SomeException -> m' a + handleInternal :: (MonadError StoreError m') => SQLError -> m' a handleInternal e = throwError . SEInternal $ bshow e storeError :: StoreError -> AgentErrorType storeError = \case From cddff787196bb47edaa1b88f549eec73470039ff Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 24 Apr 2021 12:46:57 +0100 Subject: [PATCH 08/17] binary X509 encoding for RSA key send during transport handshake (#105) --- apps/smp-server/Main.hs | 2 +- rfcs/2021-01-26-crypto.md | 7 ++- src/Simplex/Messaging/Crypto.hs | 37 ++++++++++------ src/Simplex/Messaging/Transport.hs | 68 +++++++++++++++++++++++------- src/Simplex/Messaging/Util.hs | 6 +-- tests/SMPClient.hs | 4 +- 6 files changed, 89 insertions(+), 35 deletions(-) diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 47cecf056..ee7801166 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -71,4 +71,4 @@ readCreateKey = do errorExit e = putStrLn (e <> ": " <> path) >> exitFailure publicKeyHash :: C.PublicKey -> B.ByteString -publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.serializePubKey +publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.binaryEncodePubKey diff --git a/rfcs/2021-01-26-crypto.md b/rfcs/2021-01-26-crypto.md index c7f5f2c9d..578a86c7f 100644 --- a/rfcs/2021-01-26-crypto.md +++ b/rfcs/2021-01-26-crypto.md @@ -22,7 +22,7 @@ To establish the session keys and base IVs, the server should have an asymmetric The handshake sequence is the following: -1. Once the connection is established, the server sends its public 2048 bit key to the client. TODO currently the key will be sent as a line terminated with CRLF, using ad-hoc key serialization we use. +1. Once the connection is established, the server sends transport_header and its public RSA key encoded in X509 binary format to the client. 2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required. 3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server. 4. The client then should encrypt these symmetric keys and base IVs with the public key that the server sent, and send to the server the result of the encryption: `rsa-encrypt(snd-aes-key, snd-base-iv, rcv-aes-key, rcv-base-iv)`. `snd-aes-key` and `snd-base-iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv-aes-key` and `rcv-base-iv` will be used by the client to decrypt **received** messages and by the server to encrypt them. @@ -34,6 +34,11 @@ All the subsequent data both from the client and from the server should be sent Each transport block sent by the client and the server has this syntax: ```abnf +transport_header = block_size protocol key_size +block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the client rejects if > 65536 bytes +protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key +key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard) + transport_block = aes_body_auth_tag aes_encrypted_body ; fixed at 8192 bits aes_encrypted_body = 1*OCTET aes_body_auth_tag = 16*16(OCTET) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 766c46632..ce92763e0 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -33,10 +33,12 @@ module Simplex.Messaging.Crypto decryptAES, serializePrivKey, serializePubKey, + binaryEncodePubKey, serializeKeyHash, getKeyHash, privKeyP, pubKeyP, + binaryPubKeyP, keyHashP, authTagSize, authTagToBS, @@ -179,8 +181,8 @@ generateKeyPair size = loop then loop else pure (PublicKey k, mkPrivateKey pk) -rsaPrivateSize :: PrivateKey k => k -> Int -rsaPrivateSize = R.public_size . R.private_pub . rsaPrivateKey +privateKeySize :: PrivateKey k => k -> Int +privateKeySize = R.public_size . R.private_pub . rsaPrivateKey publicKey :: FullPrivateKey -> PublicKey publicKey = PublicKey . R.private_pub . rsaPrivateKey @@ -258,7 +260,7 @@ encrypt k paddedSize msg = do decrypt :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO ByteString decrypt pk msg'' = do - let (encHeader, msg') = B.splitAt (rsaPrivateSize pk) msg'' + let (encHeader, msg') = B.splitAt (privateKeySize pk) msg'' header <- decryptOAEP pk encHeader Header {aesKey, ivBytes, authTag, msgSize} <- except $ parseHeader header msg <- decryptAES aesKey ivBytes msg' authTag @@ -342,6 +344,9 @@ serializePrivKey pk = "rsa:" <> encodePrivKey pk pubKeyP :: Parser PublicKey pubKeyP = keyP decodePubKey <|> legacyPubKeyP +binaryPubKeyP :: Parser PublicKey +binaryPubKeyP = either fail pure . binaryDecodePubKey =<< A.takeByteString + privKeyP :: PrivateKey k => Parser k privKeyP = keyP decodePrivKey <|> legacyPrivKeyP @@ -382,28 +387,34 @@ safeRsaPrivateKey (size, n, d) = } encodePubKey :: PublicKey -> ByteString -encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey +encodePubKey = encode . binaryEncodePubKey + +binaryEncodePubKey :: PublicKey -> ByteString +binaryEncodePubKey = binaryEncodeKey . PubKeyRSA . rsaPublicKey encodePrivKey :: PrivateKey k => k -> ByteString -encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey +encodePrivKey = encode . binaryEncodeKey . PrivKeyRSA . rsaPrivateKey -encodeKey :: ASN1Object a => a -> ByteString -encodeKey k = encode . toStrict . encodeASN1 DER $ toASN1 k [] +binaryEncodeKey :: ASN1Object a => a -> ByteString +binaryEncodeKey k = toStrict . encodeASN1 DER $ toASN1 k [] decodePubKey :: ByteString -> Either String PublicKey -decodePubKey s = - decodeKey s >>= \case +decodePubKey = binaryDecodePubKey <=< decode + +binaryDecodePubKey :: ByteString -> Either String PublicKey +binaryDecodePubKey = + binaryDecodeKey >=> \case (PubKeyRSA k, []) -> Right $ PublicKey k r -> keyError r decodePrivKey :: PrivateKey k => ByteString -> Either String k -decodePrivKey s = - decodeKey s >>= \case +decodePrivKey = + decode >=> binaryDecodeKey >=> \case (PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk r -> keyError r -decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) -decodeKey s = fromASN1 =<< first show . decodeASN1 DER . fromStrict =<< decode s +binaryDecodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) +binaryDecodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict keyError :: (a, [ASN1]) -> Either String b keyError = \case diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index d0468886e..7f18cc8ff 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -32,7 +32,7 @@ import GHC.IO.Exception (IOErrorType (..)) import GHC.IO.Handle.Internals (ioe_EOF) import Generic.Random (genericArbitraryU) import Network.Socket -import Network.Transport.Internal (encodeWord32) +import Network.Transport.Internal (decodeNum16, decodeNum32, encodeEnum16, encodeEnum32, encodeWord32) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Parsers (parse, parseAll, parseRead1) import Simplex.Messaging.Util (bshow, liftError) @@ -172,6 +172,7 @@ data HandshakeError | DECRYPT | VERSION | RSA_KEY + | HEADER | AES_KEYS | BAD_HASH | MAJOR_VERSION @@ -233,15 +234,18 @@ makeNextIV SessionKey {baseIV, counter} = atomically $ do -- The numbers in function names refer to the steps in the document serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle serverHandshake h (k, pk) = do - liftIO sendPublicKey_1 + liftIO sendHeaderAndPublicKey_1 encryptedKeys <- receiveEncryptedKeys_4 HandshakeKeys {sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys - th <- liftIO $ transportHandle h rcvKey sndKey -- keys are swapped here + th <- liftIO $ transportHandle h rcvKey sndKey transportBlockSize -- keys are swapped here sendWelcome_6 th pure th where - sendPublicKey_1 :: IO () - sendPublicKey_1 = putLn h $ C.serializePubKey k + sendHeaderAndPublicKey_1 :: IO () + sendHeaderAndPublicKey_1 = do + let sKey = C.binaryEncodePubKey k + header = TransportHeader {blockSize = transportBlockSize, keySize = B.length sKey} + B.hPut h $ binaryTransportHeader header <> sKey receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString receiveEncryptedKeys_4 = liftIO (B.hGet h $ C.publicKeySize k) >>= \case @@ -258,20 +262,25 @@ serverHandshake h (k, pk) = do -- The numbers in function names refer to the steps in the document clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle clientHandshake h keyHash = do - k <- getPublicKey_1_2 + (k, blkSize) <- getHeaderAndPublicKey_1_2 keys@HandshakeKeys {sndKey, rcvKey} <- liftIO generateKeys_3 sendEncryptedKeys_4 k keys - th <- liftIO $ transportHandle h sndKey rcvKey + th <- liftIO $ transportHandle h sndKey rcvKey blkSize getWelcome_6 th >>= checkVersion pure th where - getPublicKey_1_2 :: ExceptT TransportError IO C.PublicKey - getPublicKey_1_2 = do - s <- liftIO $ getLn h + getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int) + getHeaderAndPublicKey_1_2 = do + header <- liftIO (B.hGet h transportHeaderSize) + TransportHeader {blockSize, keySize} <- liftEither $ parse transportHeaderP (TEHandshake HEADER) header + when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $ + throwError $ TEHandshake HEADER + s <- liftIO $ B.hGet h keySize maybe (pure ()) (validateKeyHash_2 s) keyHash - liftEither $ parseKey s + key <- liftEither $ parseKey s + pure (key, blockSize) parseKey :: ByteString -> Either TransportError C.PublicKey - parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.pubKeyP + parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.binaryPubKeyP validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO () validateKeyHash_2 k kHash | C.getKeyHash k == kHash = pure () @@ -296,6 +305,35 @@ clientHandshake h keyHash = do when (major smpVersion > major currentSMPVersion) . throwE $ TEHandshake MAJOR_VERSION +data TransportHeader = TransportHeader {blockSize :: Int, keySize :: Int} + deriving (Eq, Show) + +binaryRsaTransport :: Int +binaryRsaTransport = 0 + +transportBlockSize :: Int +transportBlockSize = 8192 + +maxTransportBlockSize :: Int +maxTransportBlockSize = 65536 + +transportHeaderSize :: Int +transportHeaderSize = 8 + +binaryTransportHeader :: TransportHeader -> ByteString +binaryTransportHeader TransportHeader {blockSize, keySize} = + encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize + +transportHeaderP :: Parser TransportHeader +transportHeaderP = TransportHeader <$> int32 <* binaryRsaTransportP <*> int16 + where + int32 = decodeNum32 <$> A.take 4 + int16 = decodeNum16 <$> A.take 2 + binaryRsaTransportP = binaryRsa <$> int16 + binaryRsa :: Int -> Parser Int + binaryRsa 0 = pure 0 + binaryRsa _ = fail "unknown transport mode" + serializeHandshakeKeys :: HandshakeKeys -> ByteString serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = serializeKey sndKey <> serializeKey rcvKey @@ -315,8 +353,8 @@ handshakeKeysP = HandshakeKeys <$> keyP <*> keyP parseHandshakeKeys :: ByteString -> Either TransportError HandshakeKeys parseHandshakeKeys = parse handshakeKeysP $ TEHandshake AES_KEYS -transportHandle :: Handle -> SessionKey -> SessionKey -> IO THandle -transportHandle h sk rk = do +transportHandle :: Handle -> SessionKey -> SessionKey -> Int -> IO THandle +transportHandle h sk rk blockSize = do sndCounter <- newTVarIO 0 rcvCounter <- newTVarIO 0 pure @@ -324,5 +362,5 @@ transportHandle h sk rk = do { handle = h, sndKey = sk {counter = sndCounter}, rcvKey = rk {counter = rcvCounter}, - blockSize = 8192 + blockSize } diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index e8397015d..b05e7ff45 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -39,11 +39,11 @@ infixl 4 <$$> bshow :: Show a => a -> ByteString bshow = B.pack . show -liftIOEither :: (MonadUnliftIO m, MonadError e m) => IO (Either e a) -> m a +liftIOEither :: (MonadIO m, MonadError e m) => IO (Either e a) -> m a liftIOEither a = liftIO a >>= liftEither -liftError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a +liftError :: (MonadIO m, MonadError e' m) => (e -> e') -> ExceptT e IO a -> m a liftError f = liftEitherError f . runExceptT -liftEitherError :: (MonadUnliftIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a +liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a liftEitherError f a = liftIOEither (first f <$> a) diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index c6e6af28e..360771df8 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -31,10 +31,10 @@ testPort :: ServiceName testPort = "5000" teshKeyHashStr :: B.ByteString -teshKeyHashStr = "p1xa/XuzchgqomEL6RX+Me+fX096w50V7nJPAA0wpDE=" +teshKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" teshKeyHash :: Maybe C.KeyHash -teshKeyHash = Just "p1xa/XuzchgqomEL6RX+Me+fX096w50V7nJPAA0wpDE=" +teshKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a testSMPClient client = From 5fec6c1755a1802f42a4f1df805e8c2534deed8f Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 26 Apr 2021 20:05:46 +0100 Subject: [PATCH 09/17] Sign and verify agent messages (#106) * sign and verify agent messages with key sent in HELLO (TODO: hardcoded block size - should use size from handshake; verify signature of HELLO message itself; possibly, different MSG status if signature was not verified (currently ignored) or failed to verify (currently fails with AGENT A_ENCRYPTION - alternatively, change it to AGENT A_SIGNATURE)) * remove hardcoded block size, make it 4096 bytes * verify signature of HELLO message before it is added to RcvQueue * refactor * update doc * rename functions --- rfcs/2021-01-26-crypto.md | 8 +- src/Simplex/Messaging/Agent.hs | 15 ++- src/Simplex/Messaging/Agent/Client.hs | 107 ++++++++++++-------- src/Simplex/Messaging/Agent/Env/SQLite.hs | 12 +-- src/Simplex/Messaging/Agent/Store.hs | 1 + src/Simplex/Messaging/Agent/Store/SQLite.hs | 24 +++++ src/Simplex/Messaging/Agent/Transmission.hs | 1 + src/Simplex/Messaging/Client.hs | 40 ++++---- src/Simplex/Messaging/Crypto.hs | 4 +- src/Simplex/Messaging/Transport.hs | 14 ++- src/Simplex/Messaging/errors.md | 1 + tests/ServerTests.hs | 3 +- 12 files changed, 144 insertions(+), 86 deletions(-) diff --git a/rfcs/2021-01-26-crypto.md b/rfcs/2021-01-26-crypto.md index 578a86c7f..280d42e0a 100644 --- a/rfcs/2021-01-26-crypto.md +++ b/rfcs/2021-01-26-crypto.md @@ -39,7 +39,8 @@ block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the cl protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard) -transport_block = aes_body_auth_tag aes_encrypted_body ; fixed at 8192 bits +transport_block = aes_body_auth_tag aes_encrypted_body +; size is sent by server during handshake, usually 8192 bytes aes_encrypted_body = 1*OCTET aes_body_auth_tag = 16*16(OCTET) @@ -107,7 +108,9 @@ Symmetric keys are generated per message and encrypted with receiver's public ke The syntax of each encrypted message body is the following: ```abnf -encrypted_message_body = rsa_encrypted_header aes_encrypted_body +encrypted_message_body = rsa_signature encrypted_body +encrypted_body = rsa_encrypted_header aes_encrypted_body +rsa_signature = 256*256(OCTET) ; sign(encrypted_body) - assuming 2048 bit key size rsa_encrypted_header = 256*256(OCTET) ; encrypt(header) - assuming 2048 bit key size aes_encrypted_body = 1*OCTET ; encrypt(body) @@ -122,7 +125,6 @@ body = payload pad Future considerations: - Generation of symmetric keys per session and session rotation; -- Signature and verification of messages. ## E2E implementation diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 3a946529c..4c2bc5ded 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -37,7 +37,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (CorrId (..), MsgBody, SenderPublicKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (putLn, runTCPServer) -import Simplex.Messaging.Util (bshow, liftError) +import Simplex.Messaging.Util (bshow) import System.IO (Handle) import UnliftIO.Async (race_) import qualified UnliftIO.Exception as E @@ -227,11 +227,11 @@ subscriber c@AgentClient {msgQ} st = forever $ do processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m () processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do - rq@RcvQueue {connAlias, decryptKey, status} <- withStore $ getRcvQueue st srv rId + rq@RcvQueue {connAlias, status} <- withStore $ getRcvQueue st srv rId case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received - agentMsg <- liftEither . parseSMPMessage =<< decryptMessage decryptKey msgBody + agentMsg <- liftEither . parseSMPMessage =<< decryptAndVerify rq msgBody case agentMsg of SMPConfirmation senderKey -> do logServer "<--" c srv rId "MSG " @@ -246,11 +246,13 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do _ -> notify connAlias . ERR $ AGENT A_PROHIBITED SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> case agentMessage of - HELLO _verifyKey _ -> do + HELLO verifyKey _ -> do logServer "<--" c srv rId "MSG " case status of Active -> notify connAlias . ERR $ AGENT A_PROHIBITED - _ -> withStore $ setRcvQueueStatus st rq Active + _ -> do + void $ verifyMessage (Just verifyKey) msgBody + withStore $ setRcvQueueActive st rq verifyKey REPLY qInfo -> do logServer "<--" c srv rId "MSG " -- TODO move senderKey inside SndQueue @@ -296,9 +298,6 @@ connectToSendQueue c st sq senderKey verifyKey = do sendHello c sq verifyKey withStore $ setSndQueueStatus st sq Active -decryptMessage :: (MonadUnliftIO m, MonadError AgentErrorType m) => DecryptionKey -> ByteString -> m ByteString -decryptMessage decryptKey msg = liftError cryptoError $ C.decrypt decryptKey msg - newSendQueue :: (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8c431f888..fd66d3652 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -19,6 +19,8 @@ module Simplex.Messaging.Agent.Client sendHello, secureQueue, sendAgentMessage, + decryptAndVerify, + verifyMessage, sendAck, suspendQueue, deleteQueue, @@ -121,25 +123,31 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m () closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient -withSMP :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withSMP c srv action = - (getSMPServerClient c srv >>= runAction) `catchError` logServerError +withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a +withSMP_ c srv action = + (getSMPServerClient c srv >>= action) `catchError` logServerError where - runAction :: SMPClient -> m a - runAction smp = liftError smpClientError $ action smp - logServerError :: AgentErrorType -> m a logServerError e = do logServer "<--" c srv "" $ bshow e throwError e -withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withLogSMP c srv qId cmdStr action = do +withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a +withLogSMP_ c srv qId cmdStr action = do logServer "-->" c srv qId cmdStr - res <- withSMP c srv action + res <- withSMP_ c srv action logServer "<--" c srv qId "OK" return res +withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withSMP c srv action = withSMP_ c srv $ liftSMP . action + +withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a +withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action + +liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a +liftSMP = liftError smpClientError + smpClientError :: SMPClientError -> AgentErrorType smpClientError = \case SMPServerError e -> SMP e @@ -212,27 +220,24 @@ logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> m () -sendConfirmation c SndQueue {server, sndId, encryptKey} senderKey = do - msg <- mkConfirmation - withLogSMP c server sndId "SEND " $ \smp -> - sendSMPMessage smp Nothing sndId msg +sendConfirmation c sq@SndQueue {server, sndId} senderKey = + withLogSMP_ c server sndId "SEND " $ \smp -> do + msg <- mkConfirmation smp + liftSMP $ sendSMPMessage smp Nothing sndId msg where - mkConfirmation :: m MsgBody - mkConfirmation = do - let msg = serializeSMPMessage $ SMPConfirmation senderKey - paddedSize <- asks paddedMsgSize - liftError cryptoError $ C.encrypt encryptKey paddedSize msg + mkConfirmation :: SMPClient -> m MsgBody + mkConfirmation smp = encryptAndSign smp sq $ SMPConfirmation senderKey sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () -sendHello c SndQueue {server, sndId, sndPrivateKey, encryptKey} verifyKey = do - msg <- mkHello $ AckMode On - withLogSMP c server sndId "SEND (retrying)" $ - send 8 100000 msg +sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = + withLogSMP_ c server sndId "SEND (retrying)" $ \smp -> do + msg <- mkHello smp $ AckMode On + liftSMP $ send 8 100000 msg smp where - mkHello :: AckMode -> m ByteString - mkHello ackMode = do + mkHello :: SMPClient -> AckMode -> m ByteString + mkHello smp ackMode = do senderTs <- liftIO getCurrentTime - mkAgentMessage encryptKey senderTs $ HELLO verifyKey ackMode + mkAgentMessage smp sq senderTs $ HELLO verifyKey ackMode send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () send 0 _ _ _ = throwE $ SMPServerError AUTH @@ -264,23 +269,43 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = deleteSMPQueue smp rcvPrivateKey rcvId sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m () -sendAgentMessage c SndQueue {server, sndId, sndPrivateKey, encryptKey} senderTs agentMsg = do - msg <- mkAgentMessage encryptKey senderTs agentMsg - withLogSMP c server sndId "SEND " $ \smp -> - sendSMPMessage smp (Just sndPrivateKey) sndId msg +sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} senderTs agentMsg = + withLogSMP_ c server sndId "SEND " $ \smp -> do + msg <- mkAgentMessage smp sq senderTs agentMsg + liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg -mkAgentMessage :: AgentMonad m => EncryptionKey -> SenderTimestamp -> AMessage -> m ByteString -mkAgentMessage encKey senderTs agentMessage = do - let msg = - serializeSMPMessage - SMPMessage - { senderMsgId = 0, - senderTimestamp = senderTs, - previousMsgHash = "1234", -- TODO hash of the previous message - agentMessage - } - paddedSize <- asks paddedMsgSize - liftError cryptoError $ C.encrypt encKey paddedSize msg +mkAgentMessage :: AgentMonad m => SMPClient -> SndQueue -> SenderTimestamp -> AMessage -> m ByteString +mkAgentMessage smp sq senderTs agentMessage = do + encryptAndSign smp sq $ + SMPMessage + { senderMsgId = 0, + senderTimestamp = senderTs, + previousMsgHash = "1234", -- TODO hash of the previous message + agentMessage + } + +encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> SMPMessage -> m ByteString +encryptAndSign smp SndQueue {encryptKey, signKey} msg = do + paddedSize <- asks $ (blockSize smp -) . reservedMsgSize + liftError cryptoError $ do + enc <- C.encrypt encryptKey paddedSize $ serializeSMPMessage msg + C.Signature sig <- C.sign signKey enc + pure $ sig <> enc + +decryptAndVerify :: AgentMonad m => RcvQueue -> ByteString -> m ByteString +decryptAndVerify RcvQueue {decryptKey, verifyKey} msg = + verifyMessage verifyKey msg + >>= liftError cryptoError . C.decrypt decryptKey + +verifyMessage :: AgentMonad m => Maybe VerificationKey -> ByteString -> m ByteString +verifyMessage verifyKey msg = do + size <- asks $ rsaKeySize . config + let (sig, enc) = B.splitAt size msg + case verifyKey of + Nothing -> pure enc + Just k + | C.verify k (C.Signature sig) enc -> pure enc + | otherwise -> throwError $ AGENT A_SIGNATURE cryptoError :: C.CryptoError -> AgentErrorType cryptoError = \case diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 6ecd85ea2..b14bfb4a6 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -26,7 +26,7 @@ data Env = Env { config :: AgentConfig, idsDrg :: TVar ChaChaDRG, clientCounter :: TVar Int, - paddedMsgSize :: Int + reservedMsgSize :: Int } newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env @@ -34,10 +34,10 @@ newSMPAgentEnv config = do idsDrg <- drgNew >>= newTVarIO _ <- createSQLiteStore $ dbFile config clientCounter <- newTVarIO 0 - return Env {config, idsDrg, clientCounter, paddedMsgSize} + return Env {config, idsDrg, clientCounter, reservedMsgSize} where - -- one rsaKeySize is used by the RSA signature in each command, - -- another - by encrypted message body header + -- 1st rsaKeySize is used by the RSA signature in each command, + -- 2nd - by encrypted message body header + -- 3rd - by message signature -- smpCommandSize - is the estimated max size for SMP command, queueId, corrId - paddedMsgSize = blockSize smp - 2 * rsaKeySize config - smpCommandSize smp - smp = smpCfg config + reservedMsgSize = 3 * rsaKeySize config + smpCommandSize (smpCfg config) diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 7521b9fae..637b0fa81 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -39,6 +39,7 @@ class Monad m => MonadAgentStore s m where upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m () upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m () setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m () + setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m () setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () -- Msg management diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index e2fd08e45..452654537 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -137,6 +137,11 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto liftIO $ updateRcvQueueStatus dbConn rcvQueue status + setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () + setRcvQueueActive SQLiteStore {dbConn} rcvQueue verifyKey = + liftIO $ + updateRcvQueueActive dbConn rcvQueue verifyKey + setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () setSndQueueStatus SQLiteStore {dbConn} sndQueue status = liftIO $ @@ -477,6 +482,25 @@ updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} st |] [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] +-- * setRcvQueueActive helper + +-- ? throw error if queue doesn't exist? +updateRcvQueueActive :: DB.Connection -> RcvQueue -> VerificationKey -> IO () +updateRcvQueueActive dbConn RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET verify_key = :verify_key, status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [ ":verify_key" := Just verifyKey, + ":status" := Active, + ":host" := host, + ":port" := serializePort_ port, + ":rcv_id" := rcvId + ] + -- * setSndQueueStatus helper -- ? throw error if queue doesn't exist? diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 02706da30..3bc9a7592 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -283,6 +283,7 @@ data SMPAgentError = A_MESSAGE -- possibly should include bytestring that failed to parse | A_PROHIBITED -- possibly should include the prohibited SMP/agent message | A_ENCRYPTION -- cannot RSA/AES-decrypt or parse decrypted header + | A_SIGNATURE -- invalid RSA signature deriving (Eq, Generic, Read, Show, Exception) instance Arbitrary AgentErrorType where arbitrary = genericArbitraryU diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 1efd16443..8def2ab8a 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -9,7 +9,7 @@ {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Client - ( SMPClient, + ( SMPClient (blockSize), getSMPClient, closeSMPClient, createSMPQueue, @@ -44,7 +44,7 @@ import Simplex.Messaging.Agent.Transmission (SMPServer (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport -import Simplex.Messaging.Util (bshow, liftEitherError, raceAny_) +import Simplex.Messaging.Util (bshow, liftError, raceAny_) import System.IO import System.Timeout @@ -57,7 +57,8 @@ data SMPClient = SMPClient sentCommands :: TVar (Map CorrId Request), sndQ :: TBQueue SignedRawTransmission, rcvQ :: TBQueue SignedTransmissionOrError, - msgQ :: TBQueue SMPServerTransmission + msgQ :: TBQueue SMPServerTransmission, + blockSize :: Int } type SMPServerTransmission = (SMPServer, RecipientId, Command 'Broker) @@ -67,7 +68,6 @@ data SMPClientConfig = SMPClientConfig defaultPort :: ServiceName, tcpTimeout :: Int, smpPing :: Int, - blockSize :: Int, smpCommandSize :: Int } @@ -78,7 +78,6 @@ smpDefaultConfig = defaultPort = "5223", tcpTimeout = 4_000_000, smpPing = 30_000_000, - blockSize = 8_192, -- 16_384, smpCommandSize = 256 } @@ -96,15 +95,15 @@ getSMPClient msgQ disconnected = do c <- atomically mkSMPClient - err <- newEmptyTMVarIO + thVar <- newEmptyTMVarIO action <- async $ - runTCPClient host (fromMaybe defaultPort port) (client c err) - `finally` atomically (putTMVar err $ Just SMPNetworkError) - ok <- tcpTimeout `timeout` atomically (takeTMVar err) - pure $ case ok of - Just Nothing -> Right c {action} - Just (Just e) -> Left e + runTCPClient host (fromMaybe defaultPort port) (client c thVar) + `finally` atomically (putTMVar thVar $ Left SMPNetworkError) + tHandle <- tcpTimeout `timeout` atomically (takeTMVar thVar) + pure $ case tHandle of + Just (Right THandle {blockSize}) -> Right c {action, blockSize} + Just (Left e) -> Left e Nothing -> Left SMPNetworkError where mkSMPClient :: STM SMPClient @@ -117,6 +116,7 @@ getSMPClient return SMPClient { action = undefined, + blockSize = undefined, connected, smpServer, tcpTimeout, @@ -127,17 +127,17 @@ getSMPClient msgQ } - client :: SMPClient -> TMVar (Maybe SMPClientError) -> Handle -> IO () - client c err h = + client :: SMPClient -> TMVar (Either SMPClientError THandle) -> Handle -> IO () + client c thVar h = runExceptT (clientHandshake h keyHash) >>= \case - Right th -> clientTransport c err th - Left e -> atomically . putTMVar err . Just $ SMPTransportError e + Right th -> clientTransport c thVar th + Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e - clientTransport :: SMPClient -> TMVar (Maybe SMPClientError) -> THandle -> IO () - clientTransport c err th = do + clientTransport :: SMPClient -> TMVar (Either SMPClientError THandle) -> THandle -> IO () + clientTransport c thVar th = do atomically $ do writeTVar (connected c) True - putTMVar err Nothing + putTMVar thVar $ Right th raceAny_ [send c th, process c, receive c th, ping c] `finally` disconnected @@ -252,7 +252,7 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, tcpTimeout} pKey qId signTransmission t = case pKey of Nothing -> return ("", t) Just pk -> do - sig <- liftEitherError SMPSignatureError $ C.sign pk t + sig <- liftError SMPSignatureError $ C.sign pk t return (sig, t) -- two separate "atomically" needed to avoid blocking diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index ce92763e0..a6e5a3e07 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -329,8 +329,8 @@ decryptOAEP pk encKey = pssParams :: PSS.PSSParams SHA256 ByteString ByteString pssParams = PSS.defaultPSSParams SHA256 -sign :: PrivateKey k => k -> ByteString -> IO (Either CryptoError Signature) -sign pk msg = bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg +sign :: PrivateKey k => k -> ByteString -> ExceptT CryptoError IO Signature +sign pk msg = ExceptT $ bimap RSASignError Signature <$> PSS.signSafer pssParams (rsaPrivateKey pk) msg verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 7f18cc8ff..5411e3da4 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -311,8 +311,11 @@ data TransportHeader = TransportHeader {blockSize :: Int, keySize :: Int} binaryRsaTransport :: Int binaryRsaTransport = 0 +binaryRsaTransportBS :: ByteString +binaryRsaTransportBS = encodeEnum16 binaryRsaTransport + transportBlockSize :: Int -transportBlockSize = 8192 +transportBlockSize = 4096 maxTransportBlockSize :: Int maxTransportBlockSize = 65536 @@ -322,7 +325,7 @@ transportHeaderSize = 8 binaryTransportHeader :: TransportHeader -> ByteString binaryTransportHeader TransportHeader {blockSize, keySize} = - encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize + encodeEnum32 blockSize <> binaryRsaTransportBS <> encodeEnum16 keySize transportHeaderP :: Parser TransportHeader transportHeaderP = TransportHeader <$> int32 <* binaryRsaTransportP <*> int16 @@ -330,9 +333,10 @@ transportHeaderP = TransportHeader <$> int32 <* binaryRsaTransportP <*> int16 int32 = decodeNum32 <$> A.take 4 int16 = decodeNum16 <$> A.take 2 binaryRsaTransportP = binaryRsa <$> int16 - binaryRsa :: Int -> Parser Int - binaryRsa 0 = pure 0 - binaryRsa _ = fail "unknown transport mode" + binaryRsa :: Int -> Parser () + binaryRsa n + | n == binaryRsaTransport = pure () + | otherwise = fail "unknown transport mode" serializeHandshakeKeys :: HandshakeKeys -> ByteString serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = diff --git a/src/Simplex/Messaging/errors.md b/src/Simplex/Messaging/errors.md index 6e83e6d61..6fba8bed4 100644 --- a/src/Simplex/Messaging/errors.md +++ b/src/Simplex/Messaging/errors.md @@ -46,6 +46,7 @@ Some of these errors are not correctly serialized/parsed - see line 322 in Agent - A_MESSAGE - SMP message failed to parse - A_PROHIBITED - SMP message is prohibited with the current queue status - A_ENCRYPTION - cannot RSA/AES-decrypt or parse decrypted header + - A_SIGNATURE - invalid RSA signature - INTERNAL ByteString - agent implementation or dependency error ### SMPClientError (Client.hs) diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 50d3759cf..fa4c70f9e 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -7,6 +7,7 @@ module ServerTests where +import Control.Monad.Except (runExceptT) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -40,7 +41,7 @@ sendRecv h (sgn, corrId, qId, cmd) = tPutRaw h (sgn, corrId, encode qId, cmd) >> signSendRecv :: THandle -> C.SafePrivateKey -> (ByteString, ByteString, ByteString) -> IO SignedTransmissionOrError signSendRecv h pk (corrId, qId, cmd) = do let t = B.intercalate " " [corrId, encode qId, cmd] - Right sig <- C.sign pk t + Right sig <- runExceptT $ C.sign pk t _ <- tPut h (sig, t) tGet fromServer h From 816703527a72ec19dadb5cccf53fdcca5eb68e6d Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 26 Apr 2021 20:18:20 +0100 Subject: [PATCH 10/17] set different default server (#107) * set different default server * remove comment --- apps/dog-food/ChatOptions.hs | 5 +++-- src/Simplex/Messaging/Agent.hs | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/dog-food/ChatOptions.hs b/apps/dog-food/ChatOptions.hs index 7b501b3ff..8d0a0560f 100644 --- a/apps/dog-food/ChatOptions.hs +++ b/apps/dog-food/ChatOptions.hs @@ -1,4 +1,5 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} module ChatOptions (getChatOpts, ChatOpts (..)) where @@ -30,8 +31,8 @@ chatOpts appDir = ( long "server" <> short 's' <> metavar "SERVER" - <> help "SMP server to use (smp.simplex.im:5223)" - <> value (SMPServer "smp.simplex.im" (Just "5223") Nothing) + <> help "SMP server to use (smp1.simplex.im:5223#pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=)" + <> value (SMPServer "smp1.simplex.im" (Just "5223") (Just "pLdiGvm0jD1CMblnov6Edd/391OrYsShw+RgdfR0ChA=")) ) <*> option parseTermMode diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 4c2bc5ded..1bbdb87bd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -255,7 +255,6 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore $ setRcvQueueActive st rq verifyKey REPLY qInfo -> do logServer "<--" c srv rId "MSG " - -- TODO move senderKey inside SndQueue (sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey From afc09a6ec4b73e27082d3a8eebe8249cb87884d6 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 26 Apr 2021 20:34:28 +0100 Subject: [PATCH 11/17] Store log (#108) * StoreLog (WIP) * add log records to map * revert Protocol change * revert Server change * fix parseLogRecord * optionally save/restore queues to/from store log * refactor * refactor delQueueAndMsgs * move store log to /var/opt/simplex * use ini file --- apps/dog-food/Main.hs | 1 + apps/smp-server/Main.hs | 116 +++++++++++++++-- package.yaml | 2 + src/Simplex/Messaging/Server.hs | 31 ++++- src/Simplex/Messaging/Server/Env/STM.hs | 22 +++- src/Simplex/Messaging/Server/QueueStore.hs | 2 +- src/Simplex/Messaging/Server/StoreLog.hs | 140 +++++++++++++++++++++ src/Simplex/Messaging/Transport.hs | 9 +- tests/SMPClient.hs | 2 + 9 files changed, 302 insertions(+), 23 deletions(-) create mode 100644 src/Simplex/Messaging/Server/StoreLog.hs diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index f4100e9a1..bf874285d 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -109,6 +109,7 @@ serializeChatResponse = \case Connected c -> [ttyContact c <> " connected"] Confirmation c -> [ttyContact c <> " ok"] ReceivedMessage c t -> prependFirst (ttyFromContact c) $ msgPlain t + -- TODO either add command to re-connect or update message below Disconnected c -> ["disconnected from " <> ttyContact c <> " - try \"/chat " <> bPlain (toBs c) <> "\""] YesYes -> ["you got it!"] ContactError e c -> case e of diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index ee7801166..3a4952ace 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -1,20 +1,28 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} module Main where -import Control.Monad (when) +import Control.Monad (unless, when) import qualified Crypto.Store.PKCS8 as S import qualified Data.ByteString.Char8 as B import Data.Char (toLower) +import Data.Functor (($>)) +import Data.Ini (lookupValue, readIniFile) +import qualified Data.Text as T import Data.X509 (PrivKey (PrivKeyRSA)) +import Options.Applicative import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Server (runSMPServer) import Simplex.Messaging.Server.Env.STM +import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog) import System.Directory (createDirectoryIfMissing, doesFileExist) import System.Exit (exitFailure) import System.FilePath (combine) -import System.IO (hFlush, stdout) +import System.IO (IOMode (..), hFlush, stdout) cfg :: ServerConfig cfg = @@ -23,6 +31,7 @@ cfg = tbqSize = 16, queueIdBytes = 12, msgIdBytes = 6, + storeLog = Nothing, -- key is loaded from the file server_key in /etc/opt/simplex directory serverPrivateKey = undefined } @@ -33,12 +42,58 @@ newKeySize = 2048 `div` 8 cfgDir :: FilePath cfgDir = "/etc/opt/simplex" +logDir :: FilePath +logDir = "/var/opt/simplex" + +defaultStoreLogFile :: FilePath +defaultStoreLogFile = combine logDir "smp-server-store.log" + main :: IO () main = do + opts <- getServerOpts + putStrLn "SMP Server (-h for help)" + ini <- readCreateIni opts + storeLog <- openStoreLog ini pk <- readCreateKey - B.putStrLn $ "SMP transport key hash: " <> publicKeyHash (C.publicKey pk) - putStrLn $ "Listening on port " <> tcpPort cfg - runSMPServer cfg {serverPrivateKey = pk} + B.putStrLn $ "transport key hash: " <> publicKeyHash (C.publicKey pk) + putStrLn $ "listening on port " <> tcpPort cfg + runSMPServer cfg {serverPrivateKey = pk, storeLog} + +data IniOpts = IniOpts + { enableStoreLog :: Bool, + storeLogFile :: FilePath + } + +readCreateIni :: ServerOpts -> IO IniOpts +readCreateIni ServerOpts {configFile} = do + createDirectoryIfMissing True cfgDir + doesFileExist configFile >>= (`unless` createIni) + readIni + where + readIni :: IO IniOpts + readIni = do + ini <- either exitError pure =<< readIniFile configFile + let enableStoreLog = (== Right "on") $ lookupValue "STORE_LOG" "enable" ini + storeLogFile = either (const defaultStoreLogFile) T.unpack $ lookupValue "STORE_LOG" "file" ini + pure IniOpts {enableStoreLog, storeLogFile} + exitError e = do + putStrLn $ "error reading config file " <> configFile <> ": " <> e + exitFailure + createIni :: IO () + createIni = do + confirm $ "Save default ini file to " <> configFile + writeFile + configFile + "[STORE_LOG]\n\ + \# The server uses STM memory to store SMP queues and messages,\n\ + \# that will be lost on restart (e.g., as with redis).\n\ + \# This option enables saving SMP queues to append only log,\n\ + \# and restoring them when the server is started.\n\ + \# Log is compacted on start (deleted queues are removed).\n\ + \# The messages in the queues are not logged.\n\ + \\n\ + \# enable: on\n\ + \# file: /var/opt/simplex/smp-server-store.log\n" readCreateKey :: IO C.FullPrivateKey readCreateKey = do @@ -49,16 +104,10 @@ readCreateKey = do where createKey :: FilePath -> IO C.FullPrivateKey createKey path = do - confirm + confirm "Generate new server key pair" (_, pk) <- C.generateKeyPair newKeySize S.writeKeyFile S.TraditionalFormat path [PrivKeyRSA $ C.rsaPrivateKey pk] pure pk - confirm :: IO () - confirm = do - putStr "Generate new server key pair (y/N): " - hFlush stdout - ok <- getLine - when (map toLower ok /= "y") exitFailure readKey :: FilePath -> IO C.FullPrivateKey readKey path = do S.readKeyFile path >>= \case @@ -70,5 +119,48 @@ readCreateKey = do errorExit :: String -> IO b errorExit e = putStrLn (e <> ": " <> path) >> exitFailure +confirm :: String -> IO () +confirm msg = do + putStr $ msg <> " (y/N): " + hFlush stdout + ok <- getLine + when (map toLower ok /= "y") exitFailure + publicKeyHash :: C.PublicKey -> B.ByteString publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.binaryEncodePubKey + +openStoreLog :: IniOpts -> IO (Maybe (StoreLog 'ReadMode)) +openStoreLog IniOpts {enableStoreLog, storeLogFile = f} + | enableStoreLog = do + createDirectoryIfMissing True logDir + putStrLn ("store log: " <> f) + Just <$> openReadStoreLog f + | otherwise = putStrLn "store log disabled" $> Nothing + +newtype ServerOpts = ServerOpts + { configFile :: FilePath + } + +serverOpts :: Parser ServerOpts +serverOpts = + ServerOpts + <$> strOption + ( long "config" + <> short 'c' + <> metavar "INI_FILE" + <> help ("config file (" <> defaultIniFile <> ")") + <> value defaultIniFile + ) + where + defaultIniFile = combine cfgDir "smp-server.ini" + +getServerOpts :: IO ServerOpts +getServerOpts = execParser opts + where + opts = + info + (serverOpts <**> helper) + ( fullDesc + <> header "Simplex Messaging Protocol (SMP) Server" + <> progDesc "Start server with INI_FILE (created on first run)" + ) diff --git a/package.yaml b/package.yaml index 722d647dc..259d1bc4a 100644 --- a/package.yaml +++ b/package.yaml @@ -51,6 +51,8 @@ executables: main: Main.hs dependencies: - cryptostore == 0.2.* + - ini == 0.4.* + - optparse-applicative == 0.15.* - simplex-messaging ghc-options: - -threaded diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 3b6d6d8a3..1fb81bcff 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -31,6 +31,7 @@ import Simplex.Messaging.Server.MsgStore import Simplex.Messaging.Server.MsgStore.STM (MsgQueue) import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.STM (QueueStore) +import Simplex.Messaging.Server.StoreLog import Simplex.Messaging.Transport import Simplex.Messaging.Util import UnliftIO.Async @@ -147,8 +148,8 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = NEW rKey -> createQueue st rKey SUB -> subscribeQueue queueId ACK -> acknowledgeMsg - KEY sKey -> okResp <$> atomically (secureQueue st queueId sKey) - OFF -> okResp <$> atomically (suspendQueue st queueId) + KEY sKey -> secureQueue_ st sKey + OFF -> suspendQueue_ st DEL -> delQueueAndMsgs st where createQueue :: QueueStore -> RecipientPublicKey -> m Transmission @@ -158,7 +159,9 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = addSubscribe = addQueueRetry 3 >>= \case Left e -> return $ ERR e - Right (rId, sId) -> subscribeQueue rId $> IDS rId sId + Right (rId, sId) -> do + withLog (`logCreateById` rId) + subscribeQueue rId $> IDS rId sId addQueueRetry :: Int -> m (Either ErrorType (RecipientId, SenderId)) addQueueRetry 0 = return $ Left INTERNAL @@ -169,11 +172,27 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = Left e -> return $ Left e Right _ -> return $ Right ids + logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO () + logCreateById s rId = + atomically (getQueue st SRecipient rId) >>= \case + Right q -> logCreateQueue s q + _ -> pure () + getIds :: m (RecipientId, SenderId) getIds = do n <- asks $ queueIdBytes . config liftM2 (,) (randomId n) (randomId n) + secureQueue_ :: QueueStore -> SenderPublicKey -> m Transmission + secureQueue_ st sKey = do + withLog $ \s -> logSecureQueue s queueId sKey + okResp <$> atomically (secureQueue st queueId sKey) + + suspendQueue_ :: QueueStore -> m Transmission + suspendQueue_ st = do + withLog (`logDeleteQueue` queueId) + okResp <$> atomically (suspendQueue st queueId) + subscribeQueue :: RecipientId -> m Transmission subscribeQueue rId = atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId @@ -260,12 +279,18 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = delQueueAndMsgs :: QueueStore -> m Transmission delQueueAndMsgs st = do + withLog (`logDeleteQueue` queueId) ms <- asks msgStore atomically $ deleteQueue st queueId >>= \case Left e -> return $ err e Right _ -> delMsgQueue ms queueId $> ok + withLog :: (StoreLog 'WriteMode -> IO a) -> m () + withLog action = do + env <- ask + liftIO . mapM_ action $ storeLog (env :: Env) + ok :: Transmission ok = mkResp corrId queueId OK diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 4371fc95f..9a61e0243 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Server.Env.STM where @@ -13,7 +15,10 @@ import Numeric.Natural import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server.MsgStore.STM +import Simplex.Messaging.Server.QueueStore (QueueRec (..)) import Simplex.Messaging.Server.QueueStore.STM +import Simplex.Messaging.Server.StoreLog +import System.IO (IOMode (..)) import UnliftIO.STM data ServerConfig = ServerConfig @@ -21,6 +26,7 @@ data ServerConfig = ServerConfig tbqSize :: Natural, queueIdBytes :: Int, msgIdBytes :: Int, + storeLog :: Maybe (StoreLog 'ReadMode), serverPrivateKey :: C.FullPrivateKey -- serverId :: ByteString } @@ -31,7 +37,8 @@ data Env = Env queueStore :: QueueStore, msgStore :: STMMsgStore, idsDrg :: TVar ChaChaDRG, - serverKeyPair :: C.FullKeyPair + serverKeyPair :: C.FullKeyPair, + storeLog :: Maybe (StoreLog 'WriteMode) } data Server = Server @@ -70,12 +77,21 @@ newSubscription = do delivered <- newEmptyTMVar return Sub {subThread = NoSub, delivered} -newEnv :: (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env +newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env newEnv config = do server <- atomically $ newServer (tbqSize config) queueStore <- atomically newQueueStore msgStore <- atomically newMsgStore idsDrg <- drgNew >>= newTVarIO + s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig) let pk = serverPrivateKey config serverKeyPair = (C.publicKey pk, pk) - return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair} + return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s'} + where + restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode) + restoreQueues queueStore s = do + (queues, s') <- liftIO $ readWriteStoreLog s + atomically $ modifyTVar queueStore $ \d -> d {queues, senders = M.foldr' addSender M.empty queues} + pure s' + addSender :: QueueRec -> Map SenderId RecipientId -> Map SenderId RecipientId + addSender q = M.insert (senderId q) (recipientId q) diff --git a/src/Simplex/Messaging/Server/QueueStore.hs b/src/Simplex/Messaging/Server/QueueStore.hs index fd6783106..79eb2daee 100644 --- a/src/Simplex/Messaging/Server/QueueStore.hs +++ b/src/Simplex/Messaging/Server/QueueStore.hs @@ -15,7 +15,7 @@ data QueueRec = QueueRec status :: QueueStatus } -data QueueStatus = QueueActive | QueueOff +data QueueStatus = QueueActive | QueueOff deriving (Eq) class MonadQueueStore s m where addQueue :: s -> RecipientPublicKey -> (RecipientId, SenderId) -> m (Either ErrorType ()) diff --git a/src/Simplex/Messaging/Server/StoreLog.hs b/src/Simplex/Messaging/Server/StoreLog.hs new file mode 100644 index 000000000..5841b23c5 --- /dev/null +++ b/src/Simplex/Messaging/Server/StoreLog.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} + +module Simplex.Messaging.Server.StoreLog + ( StoreLog, -- constructors are not exported + openWriteStoreLog, + openReadStoreLog, + closeStoreLog, + logCreateQueue, + logSecureQueue, + logDeleteQueue, + readWriteStoreLog, + ) +where + +import Control.Applicative (optional, (<|>)) +import Control.Monad (unless) +import Data.Attoparsec.ByteString.Char8 (Parser) +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.Bifunctor (first, second) +import Data.ByteString.Base64 (encode) +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Either (partitionEithers) +import Data.Functor (($>)) +import Data.List (foldl') +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Parsers (base64P, parseAll) +import Simplex.Messaging.Protocol +import Simplex.Messaging.Server.QueueStore (QueueRec (..), QueueStatus (..)) +import Simplex.Messaging.Transport (trimCR) +import System.Directory (doesFileExist) +import System.IO + +-- | opaque container for file handle with a type-safe IOMode +-- constructors are not exported, openWriteStoreLog and openReadStoreLog should be used instead +data StoreLog (a :: IOMode) where + ReadStoreLog :: FilePath -> Handle -> StoreLog 'ReadMode + WriteStoreLog :: FilePath -> Handle -> StoreLog 'WriteMode + +data StoreLogRecord + = CreateQueue QueueRec + | SecureQueue QueueId SenderPublicKey + | DeleteQueue QueueId + +storeLogRecordP :: Parser StoreLogRecord +storeLogRecordP = + "CREATE " *> createQueueP + <|> "SECURE " *> secureQueueP + <|> "DELETE " *> (DeleteQueue <$> base64P) + where + createQueueP = CreateQueue <$> queueRecP + secureQueueP = SecureQueue <$> base64P <* A.space <*> C.pubKeyP + queueRecP = do + recipientId <- "rid=" *> base64P <* A.space + senderId <- "sid=" *> base64P <* A.space + recipientKey <- "rk=" *> C.pubKeyP <* A.space + senderKey <- "sk=" *> optional C.pubKeyP + pure QueueRec {recipientId, senderId, recipientKey, senderKey, status = QueueActive} + +serializeStoreLogRecord :: StoreLogRecord -> ByteString +serializeStoreLogRecord = \case + CreateQueue q -> "CREATE " <> serializeQueue q + SecureQueue rId sKey -> "SECURE " <> encode rId <> " " <> C.serializePubKey sKey + DeleteQueue rId -> "DELETE " <> encode rId + where + serializeQueue QueueRec {recipientId, senderId, recipientKey, senderKey} = + B.unwords + [ "rid=" <> encode recipientId, + "sid=" <> encode senderId, + "rk=" <> C.serializePubKey recipientKey, + "sk=" <> maybe "" C.serializePubKey senderKey + ] + +openWriteStoreLog :: FilePath -> IO (StoreLog 'WriteMode) +openWriteStoreLog f = WriteStoreLog f <$> openFile f WriteMode + +openReadStoreLog :: FilePath -> IO (StoreLog 'ReadMode) +openReadStoreLog f = do + doesFileExist f >>= (`unless` writeFile f "") + ReadStoreLog f <$> openFile f ReadMode + +closeStoreLog :: StoreLog a -> IO () +closeStoreLog = \case + WriteStoreLog _ h -> hClose h + ReadStoreLog _ h -> hClose h + +writeStoreLogRecord :: StoreLog 'WriteMode -> StoreLogRecord -> IO () +writeStoreLogRecord (WriteStoreLog _ h) r = do + B.hPutStrLn h $ serializeStoreLogRecord r + hFlush h + +logCreateQueue :: StoreLog 'WriteMode -> QueueRec -> IO () +logCreateQueue s = writeStoreLogRecord s . CreateQueue + +logSecureQueue :: StoreLog 'WriteMode -> QueueId -> SenderPublicKey -> IO () +logSecureQueue s qId sKey = writeStoreLogRecord s $ SecureQueue qId sKey + +logDeleteQueue :: StoreLog 'WriteMode -> QueueId -> IO () +logDeleteQueue s = writeStoreLogRecord s . DeleteQueue + +readWriteStoreLog :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode) +readWriteStoreLog s@(ReadStoreLog f _) = do + qs <- readQueues s + closeStoreLog s + s' <- openWriteStoreLog f + writeQueues s' qs + pure (qs, s') + +writeQueues :: StoreLog 'WriteMode -> Map RecipientId QueueRec -> IO () +writeQueues s = mapM_ (writeStoreLogRecord s . CreateQueue) . M.filter active + where + active QueueRec {status} = status == QueueActive + +type LogParsingError = (String, ByteString) + +readQueues :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec) +readQueues (ReadStoreLog _ h) = LB.hGetContents h >>= returnResult . procStoreLog + where + procStoreLog :: LB.ByteString -> ([LogParsingError], Map RecipientId QueueRec) + procStoreLog = second (foldl' procLogRecord M.empty) . partitionEithers . map parseLogRecord . LB.lines + returnResult :: ([LogParsingError], Map RecipientId QueueRec) -> IO (Map RecipientId QueueRec) + returnResult (errs, res) = mapM_ printError errs $> res + parseLogRecord :: LB.ByteString -> Either LogParsingError StoreLogRecord + parseLogRecord = (\s -> first (,s) $ parseAll storeLogRecordP s) . trimCR . LB.toStrict + procLogRecord :: Map RecipientId QueueRec -> StoreLogRecord -> Map RecipientId QueueRec + procLogRecord m = \case + CreateQueue q -> M.insert (recipientId q) q m + SecureQueue qId sKey -> M.adjust (\q -> q {senderKey = Just sKey}) qId m + DeleteQueue qId -> M.delete qId m + printError :: LogParsingError -> IO () + printError (e, s) = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 5411e3da4..f05731aa4 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -118,10 +118,11 @@ putLn :: Handle -> ByteString -> IO () putLn h = B.hPut h . (<> "\r\n") getLn :: Handle -> IO ByteString -getLn h = trim_cr <$> B.hGetLine h - where - trim_cr "" = "" - trim_cr s = if B.last s == '\r' then B.init s else s +getLn h = trimCR <$> B.hGetLine h + +trimCR :: ByteString -> ByteString +trimCR "" = "" +trimCR s = if B.last s == '\r' then B.init s else s -- * Encrypted transport diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 360771df8..3a92400fd 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} @@ -50,6 +51,7 @@ cfg = tbqSize = 1, queueIdBytes = 12, msgIdBytes = 6, + storeLog = Nothing, serverPrivateKey = -- full RSA private key (only for tests) "MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\ From 729cf10ad87034a60afaca6e726c2c6b6eacb4f2 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Thu, 29 Apr 2021 07:23:32 +0100 Subject: [PATCH 12/17] test: SMP server store log (#109) * test: SMP server store log * test: extend store log test: queue deletion, log compacted * test: check store log length in lines --- src/Simplex/Messaging/Server.hs | 11 +++--- tests/SMPClient.hs | 12 ++++++ tests/ServerTests.hs | 68 +++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 5 deletions(-) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 1fb81bcff..dd1f67e98 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -52,6 +52,7 @@ runSMPServerBlocking started cfg@ServerConfig {tcpPort} = do smpServer = do s <- asks server race_ (runTCPServer started tcpPort runClient) (serverThread s) + `finally` withLog closeStoreLog serverThread :: MonadUnliftIO m => Server -> m () serverThread Server {subscribedQ, subscribers} = forever . atomically $ do @@ -286,11 +287,6 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = Left e -> return $ err e Right _ -> delMsgQueue ms queueId $> ok - withLog :: (StoreLog 'WriteMode -> IO a) -> m () - withLog action = do - env <- ask - liftIO . mapM_ action $ storeLog (env :: Env) - ok :: Transmission ok = mkResp corrId queueId OK @@ -303,6 +299,11 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = msgCmd :: Message -> Command 'Broker msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody +withLog :: (MonadUnliftIO m, MonadReader Env m) => (StoreLog 'WriteMode -> IO a) -> m () +withLog action = do + env <- ask + liftIO . mapM_ action $ storeLog (env :: Env) + randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m Encoded randomId n = do gVar <- asks idsDrg diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 3a92400fd..c8eb3874a 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -18,6 +18,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Server (runSMPServerBlocking) import Simplex.Messaging.Server.Env.STM +import Simplex.Messaging.Server.StoreLog (openReadStoreLog) import Simplex.Messaging.Transport import Test.Hspec import UnliftIO.Concurrent @@ -37,6 +38,9 @@ teshKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" teshKeyHash :: Maybe C.KeyHash teshKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" +testStoreLogFile :: FilePath +testStoreLogFile = "tests/tmp/smp-server-store.log" + testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a testSMPClient client = runTCPClient testHost testPort $ \h -> @@ -84,6 +88,14 @@ cfg = \TmKzSAw7iVWwEUZR/PeiEKazqrpp9VU=" } +withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a +withSmpServerStoreLogOn port client = do + s <- liftIO $ openReadStoreLog testStoreLogFile + serverBracket + (\started -> runSMPServerBlocking started cfg {tcpPort = port, storeLog = Just s}) + (pure ()) + client + withSmpServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ServiceName -> (ThreadId -> m a) -> m a withSmpServerThreadOn port = serverBracket diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index fa4c70f9e..d081feb05 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -4,9 +4,13 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} module ServerTests where +import Control.Concurrent (ThreadId, killThread) +import Control.Concurrent.STM +import Control.Exception (SomeException, try) import Control.Monad.Except (runExceptT) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -15,6 +19,7 @@ import SMPClient import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport +import System.Directory (removeFile) import System.Timeout import Test.HUnit import Test.Hspec @@ -31,6 +36,7 @@ serverTests = do describe "SMP messages" do describe "duplex communication over 2 SMP connections" testDuplex describe "switch subscription to another SMP queue" testSwitchSub + describe "Store log" testWithStoreLog pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command))) @@ -262,6 +268,68 @@ testSwitchSub = Nothing -> return () Just _ -> error "nothing else is delivered to the 1st TCP connection" +testWithStoreLog :: Spec +testWithStoreLog = + it "should store simplex queues to log and restore them after server restart" $ do + (sPub1, sKey1) <- C.generateKeyPair rsaKeySize + (sPub2, sKey2) <- C.generateKeyPair rsaKeySize + senderId1 <- newTVarIO "" + senderId2 <- newTVarIO "" + + withSmpServerStoreLogOn testPort . runTest $ \h -> do + (sId1, _, _) <- createAndSecureQueue h sPub1 + atomically $ writeTVar senderId1 sId1 + Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + Resp "" _ (MSG _ _ "hello") <- tGet fromServer h + + (sId2, rId2, rKey2) <- createAndSecureQueue h sPub2 + atomically $ writeTVar senderId2 sId2 + Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ") + Resp "" _ (MSG _ _ "hello too") <- tGet fromServer h + + Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, "DEL") + pure () + + logSize `shouldReturn` 5 + + withSmpServerThreadOn testPort . runTest $ \h -> do + sId1 <- readTVarIO senderId1 + -- fails if store log is disabled + Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + pure () + + withSmpServerStoreLogOn testPort . runTest $ \h -> do + -- this queue is restored + sId1 <- readTVarIO senderId1 + Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, "SEND 5 hello ") + -- this queue is removed - not restored + sId2 <- readTVarIO senderId2 + Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, "SEND 9 hello too ") + pure () + + logSize `shouldReturn` 1 + removeFile testStoreLogFile + where + createAndSecureQueue :: THandle -> SenderPublicKey -> IO (SenderId, RecipientId, C.SafePrivateKey) + createAndSecureQueue h sPub = do + (rPub, rKey) <- C.generateKeyPair rsaKeySize + Resp "abcd" "" (IDS rId sId) <- signSendRecv h rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) + let keyCmd = "KEY " <> C.serializePubKey sPub + Resp "dabc" rId' OK <- signSendRecv h rKey ("dabc", rId, keyCmd) + (rId', rId) #== "same queue ID" + pure (sId, rId, rKey) + + runTest :: (THandle -> IO ()) -> ThreadId -> Expectation + runTest test' server = do + testSMPClient test' `shouldReturn` () + killThread server + + logSize :: IO Int + logSize = + try (length . B.lines <$> B.readFile testStoreLogFile) >>= \case + Right l -> pure l + Left (_ :: SomeException) -> logSize + syntaxTests :: Spec syntaxTests = do it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") From 6ceeb2c9db42e1c5383e12b4572511939ee91436 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 30 Apr 2021 09:13:18 +0100 Subject: [PATCH 13/17] save keys as binary to db, remove legacy encoding (#114) * save keys as binary to db, remove legacy encoding * import list --- apps/smp-server/Main.hs | 2 +- src/Simplex/Messaging/Crypto.hs | 73 +++++++++++------------------- src/Simplex/Messaging/Parsers.hs | 3 +- src/Simplex/Messaging/Transport.hs | 2 +- src/Simplex/Messaging/Util.hs | 5 +- tests/AgentTests.hs | 8 ++-- tests/AgentTests/SQLiteTests.hs | 12 ++--- tests/SMPClient.hs | 10 ++-- tests/ServerTests.hs | 17 ++++--- 9 files changed, 59 insertions(+), 73 deletions(-) diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index 3a4952ace..2de39e47c 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -127,7 +127,7 @@ confirm msg = do when (map toLower ok /= "y") exitFailure publicKeyHash :: C.PublicKey -> B.ByteString -publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.binaryEncodePubKey +publicKeyHash = C.serializeKeyHash . C.getKeyHash . C.encodePubKey openStoreLog :: IniOpts -> IO (Maybe (StoreLog 'ReadMode)) openStoreLog IniOpts {enableStoreLog, storeLogFile = f} diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index a6e5a3e07..b5022e96f 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -33,7 +33,7 @@ module Simplex.Messaging.Crypto decryptAES, serializePrivKey, serializePubKey, - binaryEncodePubKey, + encodePubKey, serializeKeyHash, getKeyHash, privKeyP, @@ -50,7 +50,6 @@ module Simplex.Messaging.Crypto ) where -import Control.Applicative ((<|>)) import Control.Exception (Exception) import Control.Monad.Except import Control.Monad.Trans.Except @@ -60,7 +59,6 @@ import qualified Crypto.Error as CE import Crypto.Hash (Digest, SHA256 (..), digestFromByteString, hash) import Crypto.Number.Generate (generateMax) import Crypto.Number.Prime (findPrimeFrom) -import Crypto.Number.Serialize (os2ip) import qualified Crypto.PubKey.RSA as R import qualified Crypto.PubKey.RSA.OAEP as OAEP import qualified Crypto.PubKey.RSA.PSS as PSS @@ -72,7 +70,7 @@ import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Bifunctor (bimap, first) import qualified Data.ByteArray as BA -import Data.ByteString.Base64 +import Data.ByteString.Base64 (decode, encode) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.ByteString.Internal (c2w, w2c) @@ -86,8 +84,8 @@ import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) import Network.Transport.Internal (decodeWord32, encodeWord32) -import Simplex.Messaging.Parsers (base64P, base64StringP, parseAll) -import Simplex.Messaging.Util (liftEitherError) +import Simplex.Messaging.Parsers (base64P, parseAll) +import Simplex.Messaging.Util (liftEitherError, (<$?>)) newtype PublicKey = PublicKey {rsaPublicKey :: R.PublicKey} deriving (Eq, Show) @@ -112,21 +110,21 @@ instance PrivateKey FullPrivateKey where mkPrivateKey = FullPrivateKey instance IsString FullPrivateKey where - fromString = parseString decodePrivKey + fromString = parseString (decode >=> decodePrivKey) instance IsString PublicKey where - fromString = parseString decodePubKey + fromString = parseString (decode >=> decodePubKey) parseString :: (ByteString -> Either String a) -> (String -> a) parseString parse = either error id . parse . B.pack -instance ToField SafePrivateKey where toField = toField . serializePrivKey +instance ToField SafePrivateKey where toField = toField . encodePrivKey -instance ToField PublicKey where toField = toField . serializePubKey +instance ToField PublicKey where toField = toField . encodePubKey -instance FromField SafePrivateKey where fromField = keyFromField privKeyP +instance FromField SafePrivateKey where fromField = keyFromField binaryPrivKeyP -instance FromField PublicKey where fromField = keyFromField pubKeyP +instance FromField PublicKey where fromField = keyFromField binaryPubKeyP keyFromField :: Typeable k => Parser k -> FieldParser k keyFromField p = \case @@ -336,35 +334,22 @@ verify :: PublicKey -> Signature -> ByteString -> Bool verify (PublicKey k) (Signature sig) msg = PSS.verify pssParams k msg sig serializePubKey :: PublicKey -> ByteString -serializePubKey k = "rsa:" <> encodePubKey k +serializePubKey = ("rsa:" <>) . encode . encodePubKey serializePrivKey :: PrivateKey k => k -> ByteString -serializePrivKey pk = "rsa:" <> encodePrivKey pk +serializePrivKey = ("rsa:" <>) . encode . encodePrivKey pubKeyP :: Parser PublicKey -pubKeyP = keyP decodePubKey <|> legacyPubKeyP +pubKeyP = decodePubKey <$?> ("rsa:" *> base64P) binaryPubKeyP :: Parser PublicKey -binaryPubKeyP = either fail pure . binaryDecodePubKey =<< A.takeByteString +binaryPubKeyP = decodePubKey <$?> A.takeByteString privKeyP :: PrivateKey k => Parser k -privKeyP = keyP decodePrivKey <|> legacyPrivKeyP +privKeyP = decodePrivKey <$?> ("rsa:" *> base64P) -keyP :: (ByteString -> Either String k) -> Parser k -keyP dec = either fail pure . dec =<< ("rsa:" *> base64StringP) - -legacyPubKeyP :: Parser PublicKey -legacyPubKeyP = do - (public_size, public_n, public_e) <- legacyKeyParser_ - return . PublicKey $ R.PublicKey {public_size, public_n, public_e} - -legacyPrivKeyP :: PrivateKey k => Parser k -legacyPrivKeyP = _privateKey . safeRsaPrivateKey <$> legacyKeyParser_ - -legacyKeyParser_ :: Parser (Int, Integer, Integer) -legacyKeyParser_ = (,,) <$> (A.decimal <* ",") <*> (intP <* ",") <*> intP - where - intP = os2ip <$> base64P +binaryPrivKeyP :: PrivateKey k => Parser k +binaryPrivKeyP = decodePrivKey <$?> A.takeByteString safePrivateKey :: (Int, Integer, Integer) -> SafePrivateKey safePrivateKey = SafePrivateKey . safeRsaPrivateKey @@ -387,34 +372,28 @@ safeRsaPrivateKey (size, n, d) = } encodePubKey :: PublicKey -> ByteString -encodePubKey = encode . binaryEncodePubKey - -binaryEncodePubKey :: PublicKey -> ByteString -binaryEncodePubKey = binaryEncodeKey . PubKeyRSA . rsaPublicKey +encodePubKey = encodeKey . PubKeyRSA . rsaPublicKey encodePrivKey :: PrivateKey k => k -> ByteString -encodePrivKey = encode . binaryEncodeKey . PrivKeyRSA . rsaPrivateKey +encodePrivKey = encodeKey . PrivKeyRSA . rsaPrivateKey -binaryEncodeKey :: ASN1Object a => a -> ByteString -binaryEncodeKey k = toStrict . encodeASN1 DER $ toASN1 k [] +encodeKey :: ASN1Object a => a -> ByteString +encodeKey k = toStrict . encodeASN1 DER $ toASN1 k [] decodePubKey :: ByteString -> Either String PublicKey -decodePubKey = binaryDecodePubKey <=< decode - -binaryDecodePubKey :: ByteString -> Either String PublicKey -binaryDecodePubKey = - binaryDecodeKey >=> \case +decodePubKey = + decodeKey >=> \case (PubKeyRSA k, []) -> Right $ PublicKey k r -> keyError r decodePrivKey :: PrivateKey k => ByteString -> Either String k decodePrivKey = - decode >=> binaryDecodeKey >=> \case + decodeKey >=> \case (PrivKeyRSA pk, []) -> Right $ mkPrivateKey pk r -> keyError r -binaryDecodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) -binaryDecodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict +decodeKey :: ASN1Object a => ByteString -> Either String (a, [ASN1]) +decodeKey = fromASN1 <=< first show . decodeASN1 DER . fromStrict keyError :: (a, [ASN1]) -> Either String b keyError = \case diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 4f96bae0b..25e2f32bb 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -11,10 +11,11 @@ import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) +import Simplex.Messaging.Util ((<$?>)) import Text.Read (readMaybe) base64P :: Parser ByteString -base64P = either fail pure . decode =<< base64StringP +base64P = decode <$?> base64StringP base64StringP :: Parser ByteString base64StringP = do diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index f05731aa4..d70139500 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -244,7 +244,7 @@ serverHandshake h (k, pk) = do where sendHeaderAndPublicKey_1 :: IO () sendHeaderAndPublicKey_1 = do - let sKey = C.binaryEncodePubKey k + let sKey = C.encodePubKey k header = TransportHeader {blockSize = transportBlockSize, keySize = B.length sKey} B.hPut h $ binaryTransportHeader header <> sKey receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index b05e7ff45..2800e521e 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -31,11 +31,14 @@ raceAny_ = r [] r as (m : ms) = withAsync m $ \a -> r (a : as) ms r as [] = void $ waitAnyCancel as -infixl 4 <$$> +infixl 4 <$$>, <$?> (<$$>) :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b) (<$$>) = fmap . fmap +(<$?>) :: MonadFail m => (a -> Either String b) -> m a -> m b +f <$?> m = m >>= either fail pure . f + bshow :: Show a => a -> ByteString bshow = B.pack . show diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index ac8686439..ad0c365d0 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -13,7 +13,7 @@ import Control.Concurrent import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient -import SMPClient (teshKeyHashStr) +import SMPClient (testKeyHashStr) import Simplex.Messaging.Agent.Transmission import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import System.IO (Handle) @@ -128,7 +128,7 @@ testSubscrNotification (server, _) client = do client <# ("", "conn1", END) samplePublicKey :: ByteString -samplePublicKey = "256,ppr3DCweAD3RTVFhU2j0u+DnYdqJl1qCdKLHIKsPl1xBzfmnzK0o9GEDlaIClbK39KzPJMljcpnYb2KlSoZ51AhwF5PH2CS+FStc3QzajiqfdOQPet23Hd9YC6pqyTQ7idntqgPrE7yKJF44lUhKlq8QS9KQcbK7W6t7F9uQFw44ceWd2eVf81UV04kQdKWJvC5Sz6jtSZNEfs9mVI8H0wi1amUvS6+7EDJbxikhcCRnFShFO9dUKRYXj6L2JVqXqO5cZgY9BScyneWIg6mhhsTcdDbITM6COlL+pF1f3TjDN+slyV+IzE+ap/9NkpsrCcI8KwwDpqEDmUUV/JQfmQ==,gj2UAiWzSj7iun0iXvI5iz5WEjaqngmB3SzQ5+iarixbaG15LFDtYs3pijG3eGfB1wIFgoP4D2z97vIWn8olT4uCTUClf29zGDDve07h/B3QG/4i0IDnio7MX3AbE8O6PKouqy/GLTfT4WxFUn423g80rpsVYd5oj+SCL2eaxIc=" +samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" syntaxTests :: Spec syntaxTests = do @@ -139,8 +139,8 @@ syntaxTests = do -- TODO: add tests with defined connection alias xit "only server" $ ("211", "", "NEW localhost") >#>= \case ("211", "", "INV" : _) -> True; _ -> False it "with port" $ ("212", "", "NEW localhost:5000") >#>= \case ("212", "", "INV" : _) -> True; _ -> False - xit "with keyHash" $ ("213", "", "NEW localhost#" <> teshKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False - it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> teshKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False + xit "with keyHash" $ ("213", "", "NEW localhost#" <> testKeyHashStr) >#>= \case ("213", "", "INV" : _) -> True; _ -> False + it "with port and keyHash" $ ("214", "", "NEW localhost:5000#" <> testKeyHashStr) >#>= \case ("214", "", "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias it "no parameters" $ ("221", "", "NEW") >#> ("221", "", "ERR CMD SYNTAX") diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index a63073c13..46132c832 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -13,7 +13,7 @@ import Data.Time import Data.Word (Word32) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.QQ (sql) -import SMPClient (teshKeyHash) +import SMPClient (testKeyHash) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Transmission @@ -101,7 +101,7 @@ testForeignKeysEnabled = do rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "1234", connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), @@ -115,7 +115,7 @@ rcvQueue1 = sndQueue1 :: SndQueue sndQueue1 = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "3456", connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), @@ -177,7 +177,7 @@ testGetAllConnAliases = do testGetRcvQueue :: SpecWith SQLiteStore testGetRcvQueue = do it "should get RcvQueue" $ \store -> do - let smpServer = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash + let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" createRcvConn store rcvQueue1 `returnsResult` () @@ -232,7 +232,7 @@ testUpgradeRcvConnToDuplex = do `returnsResult` () let anotherSndQueue = SndQueue - { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "2345", connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), @@ -254,7 +254,7 @@ testUpgradeSndConnToDuplex = do `returnsResult` () let anotherRcvQueue = RcvQueue - { server = SMPServer "smp.simplex.im" (Just "5223") teshKeyHash, + { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "3456", connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index c8eb3874a..00e843119 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -32,11 +32,11 @@ testHost = "localhost" testPort :: ServiceName testPort = "5000" -teshKeyHashStr :: B.ByteString -teshKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" +testKeyHashStr :: B.ByteString +testKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" -teshKeyHash :: Maybe C.KeyHash -teshKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" +testKeyHash :: Maybe C.KeyHash +testKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" testStoreLogFile :: FilePath testStoreLogFile = "tests/tmp/smp-server-store.log" @@ -44,7 +44,7 @@ testStoreLogFile = "tests/tmp/smp-server-store.log" testSMPClient :: MonadUnliftIO m => (THandle -> m a) -> m a testSMPClient client = runTCPClient testHost testPort $ \h -> - liftIO (runExceptT $ clientHandshake h teshKeyHash) >>= \case + liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case Right th -> client th Left e -> error $ show e diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index d081feb05..b1ead469c 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -330,20 +330,23 @@ testWithStoreLog = Right l -> pure l Left (_ :: SomeException) -> logSize +samplePubKey :: ByteString +samplePubKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" + syntaxTests :: Spec syntaxTests = do it "unknown command" $ ("", "abcd", "1234", "HELLO") >#> ("", "abcd", "1234", "ERR CMD SYNTAX") describe "NEW" do it "no parameters" $ ("1234", "bcda", "", "NEW") >#> ("", "bcda", "", "ERR CMD SYNTAX") - it "many parameters" $ ("1234", "cdab", "", "NEW 1 2") >#> ("", "cdab", "", "ERR CMD SYNTAX") - it "no signature" $ ("", "dabc", "", "NEW 3,1234,1234") >#> ("", "dabc", "", "ERR CMD NO_AUTH") - it "queue ID" $ ("1234", "abcd", "12345678", "NEW 3,1234,1234") >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") + it "many parameters" $ ("1234", "cdab", "", "NEW 1 " <> samplePubKey) >#> ("", "cdab", "", "ERR CMD SYNTAX") + it "no signature" $ ("", "dabc", "", "NEW " <> samplePubKey) >#> ("", "dabc", "", "ERR CMD NO_AUTH") + it "queue ID" $ ("1234", "abcd", "12345678", "NEW " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD HAS_AUTH") describe "KEY" do - it "valid syntax" $ ("1234", "bcda", "12345678", "KEY 3,4567,4567") >#> ("", "bcda", "12345678", "ERR AUTH") + it "valid syntax" $ ("1234", "bcda", "12345678", "KEY " <> samplePubKey) >#> ("", "bcda", "12345678", "ERR AUTH") it "no parameters" $ ("1234", "cdab", "12345678", "KEY") >#> ("", "cdab", "12345678", "ERR CMD SYNTAX") - it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 2") >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") - it "no signature" $ ("", "abcd", "12345678", "KEY 3,4567,4567") >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") - it "no queue ID" $ ("1234", "bcda", "", "KEY 3,4567,4567") >#> ("", "bcda", "", "ERR CMD NO_AUTH") + it "many parameters" $ ("1234", "dabc", "12345678", "KEY 1 " <> samplePubKey) >#> ("", "dabc", "12345678", "ERR CMD SYNTAX") + it "no signature" $ ("", "abcd", "12345678", "KEY " <> samplePubKey) >#> ("", "abcd", "12345678", "ERR CMD NO_AUTH") + it "no queue ID" $ ("1234", "bcda", "", "KEY " <> samplePubKey) >#> ("", "bcda", "", "ERR CMD NO_AUTH") noParamsSyntaxTest "SUB" noParamsSyntaxTest "ACK" noParamsSyntaxTest "OFF" From 6be48397033583d3aa44457cf2abb7dbd97859dc Mon Sep 17 00:00:00 2001 From: Efim Poberezkin <8711996+efim-poberezkin@users.noreply.github.com> Date: Sun, 2 May 2021 00:38:32 +0400 Subject: [PATCH 14/17] agent: verify msg integrity based on previous msg hash and id (#110) Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> --- apps/dog-food/Main.hs | 2 +- src/Simplex/Messaging/Agent.hs | 93 +++- src/Simplex/Messaging/Agent/Client.hs | 34 +- src/Simplex/Messaging/Agent/Store.hs | 47 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 497 ++++++++---------- .../Messaging/Agent/Store/SQLite/Schema.hs | 3 + src/Simplex/Messaging/Agent/Transmission.hs | 62 ++- src/Simplex/Messaging/Crypto.hs | 4 + tests/AgentTests.hs | 10 +- tests/AgentTests/SQLiteTests.hs | 225 ++++---- 10 files changed, 515 insertions(+), 462 deletions(-) diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index bf874285d..3517746a5 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -259,7 +259,7 @@ receiveFromAgent t ct c = forever . atomically $ do INV qInfo -> Invitation qInfo CON -> Connected contact END -> Disconnected contact - MSG {m_body} -> ReceivedMessage contact m_body + MSG {msgBody} -> ReceivedMessage contact msgBody SENT _ -> NoChatResponse OK -> Confirmation contact ERR (CONN e) -> ContactError e contact diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 1bbdb87bd..ca544a6fd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -177,10 +177,22 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = _ -> throwError $ CONN SIMPLEX where sendMsg sq = do - senderTs <- liftIO getCurrentTime - senderId <- withStore $ createSndMsg st connAlias msgBody senderTs - sendAgentMessage c sq senderTs $ A_MSG msgBody - respond $ SENT (unId senderId) + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq + let msgStr = + serializeSMPMessage + SMPMessage + { senderMsgId = unSndId internalSndId, + senderTimestamp = internalTs, + previousMsgHash, + agentMessage = A_MSG msgBody + } + msgHash = C.sha256Hash msgStr + withStore $ + createSndMsg st sq $ + SndMsgData {internalId, internalSndId, internalTs, msgBody, msgHash} + sendAgentMessage c sq msgStr + respond $ SENT (unId internalId) suspendConnection :: m () suspendConnection = @@ -208,8 +220,14 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = sendReplyQInfo srv sq = do (rq, qInfo) <- newReceiveQueue c srv connAlias withStore $ upgradeSndConnToDuplex st connAlias rq - senderTs <- liftIO getCurrentTime - sendAgentMessage c sq senderTs $ REPLY qInfo + senderTimestamp <- liftIO getCurrentTime + sendAgentMessage c sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = REPLY qInfo + } respond :: ACommand 'Agent -> m () respond = respond' connAlias @@ -231,7 +249,9 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received - agentMsg <- liftEither . parseSMPMessage =<< decryptAndVerify rq msgBody + msg <- decryptAndVerify rq msgBody + let msgHash = C.sha256Hash msg + agentMsg <- liftEither $ parseSMPMessage msg case agentMsg of SMPConfirmation senderKey -> do logServer "<--" c srv rId "MSG " @@ -244,7 +264,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do secureQueue c rq senderKey withStore $ setRcvQueueStatus st rq Secured _ -> notify connAlias . ERR $ AGENT A_PROHIBITED - SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> + SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> case agentMessage of HELLO verifyKey _ -> do logServer "<--" c srv rId "MSG " @@ -259,24 +279,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey notify connAlias CON - A_MSG body -> do - -- TODO check message status - logServer "<--" c srv rId "MSG " - case status of - Active -> do - recipientTs <- liftIO getCurrentTime - let m_sender = (senderMsgId, senderTimestamp) - let m_broker = (srvMsgId, srvTs) - recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker - notify connAlias $ - MSG - { m_status = MsgOk, - m_recipient = (unId recipientId, recipientTs), - m_sender, - m_broker, - m_body = body - } - _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash sendAck c rq return () SMP.END -> do @@ -289,6 +292,44 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) + agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () + agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do + logServer "<--" c srv rId "MSG " + case status of + Active -> do + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq + let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash + withStore $ + createRcvMsg st rq $ + RcvMsgData + { internalId, + internalRcvId, + internalTs, + senderMeta, + brokerMeta, + msgBody, + msgHash, + msgIntegrity + } + notify connAlias $ + MSG + { recipientMeta = (unId internalId, internalTs), + senderMeta, + brokerMeta, + msgBody, + msgIntegrity + } + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + where + checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity + checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash + | extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk + | extSndId < prevExtSndId = MsgError $ MsgBadId extSndId + | extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate + | extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1) + | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash + | otherwise = MsgError MsgDuplicate -- this case is not possible connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m () connectToSendQueue c st sq senderKey verifyKey = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index fd66d3652..30c510b28 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -226,7 +226,7 @@ sendConfirmation c sq@SndQueue {server, sndId} senderKey = liftSMP $ sendSMPMessage smp Nothing sndId msg where mkConfirmation :: SMPClient -> m MsgBody - mkConfirmation smp = encryptAndSign smp sq $ SMPConfirmation senderKey + mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = @@ -236,8 +236,14 @@ sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = where mkHello :: SMPClient -> AckMode -> m ByteString mkHello smp ackMode = do - senderTs <- liftIO getCurrentTime - mkAgentMessage smp sq senderTs $ HELLO verifyKey ackMode + senderTimestamp <- liftIO getCurrentTime + encryptAndSign smp sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = HELLO verifyKey ackMode + } send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () send 0 _ _ _ = throwE $ SMPServerError AUTH @@ -268,27 +274,17 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = withLogSMP c server rcvId "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId -sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m () -sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} senderTs agentMsg = +sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () +sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg = withLogSMP_ c server sndId "SEND " $ \smp -> do - msg <- mkAgentMessage smp sq senderTs agentMsg - liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg + msg' <- encryptAndSign smp sq msg + liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg' -mkAgentMessage :: AgentMonad m => SMPClient -> SndQueue -> SenderTimestamp -> AMessage -> m ByteString -mkAgentMessage smp sq senderTs agentMessage = do - encryptAndSign smp sq $ - SMPMessage - { senderMsgId = 0, - senderTimestamp = senderTs, - previousMsgHash = "1234", -- TODO hash of the previous message - agentMessage - } - -encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> SMPMessage -> m ByteString +encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString encryptAndSign smp SndQueue {encryptKey, signKey} msg = do paddedSize <- asks $ (blockSize smp -) . reservedMsgSize liftError cryptoError $ do - enc <- C.encrypt encryptKey paddedSize $ serializeSMPMessage msg + enc <- C.encrypt encryptKey paddedSize msg C.Signature sig <- C.sign signKey enc pure $ sig <> enc diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 637b0fa81..51b9db518 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -43,8 +43,12 @@ class Monad m => MonadAgentStore s m where setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () -- Msg management - createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId + updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m () + + updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + createSndMsg :: s -> SndQueue -> SndMsgData -> m () + getMsg :: s -> ConnAlias -> InternalId -> m Msg -- * Queue types @@ -104,6 +108,11 @@ data SConnType :: ConnType -> Type where SCSnd :: SConnType CSnd SCDuplex :: SConnType CDuplex +connType :: SConnType c -> ConnType +connType SCRcv = CRcv +connType SCSnd = CSnd +connType SCDuplex = CDuplex + deriving instance Eq (SConnType d) deriving instance Show (SConnType d) @@ -125,6 +134,40 @@ instance Eq SomeConn where deriving instance Show SomeConn +-- * Message integrity validation types + +type MsgHash = ByteString + +-- | Corresponds to `last_external_snd_msg_id` in `connections` table +type PrevExternalSndId = Int64 + +-- | Corresponds to `last_rcv_msg_hash` in `connections` table +type PrevRcvMsgHash = MsgHash + +-- | Corresponds to `last_snd_msg_hash` in `connections` table +type PrevSndMsgHash = MsgHash + +-- * Message data containers - used on Msg creation to reduce number of parameters + +data RcvMsgData = RcvMsgData + { internalId :: InternalId, + internalRcvId :: InternalRcvId, + internalTs :: InternalTs, + senderMeta :: (ExternalSndId, ExternalSndTs), + brokerMeta :: (BrokerId, BrokerTs), + msgBody :: MsgBody, + msgHash :: MsgHash, + msgIntegrity :: MsgIntegrity + } + +data SndMsgData = SndMsgData + { internalId :: InternalId, + internalSndId :: InternalSndId, + internalTs :: InternalTs, + msgBody :: MsgBody, + msgHash :: MsgHash + } + -- * Message types -- | A message in either direction that is stored by the agent. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 452654537..c48110065 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -28,17 +29,17 @@ import Data.Maybe (fromMaybe) import Data.Text (isPrefixOf) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) -import Database.SQLite.Simple as DB +import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field) +import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) -import Network.Socket (HostName, ServiceName) +import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema) import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) import System.Exit (ExitCode (ExitFailure), exitWith) @@ -87,75 +88,160 @@ checkDuplicate action = liftIOEither $ first handleError <$> E.try action instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} = checkDuplicate . createRcvQueueAndConn dbConn + createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertRcvQueue_ dbConn q + insertRcvConnection_ dbConn q createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} = checkDuplicate . createSndQueueAndConn dbConn + createSndConn SQLiteStore {dbConn} q@SndQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertSndQueue_ dbConn q + insertSndConnection_ dbConn q getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn SQLiteStore {dbConn} connAlias = do - queues <- - liftIO $ - retrieveConnQueues dbConn connAlias - case queues of - (Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) - (Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ) - (Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> throwError SEConnNotFound + getConn SQLiteStore {dbConn} connAlias = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias getAllConnAliases :: SQLiteStore -> m [ConnAlias] getAllConnAliases SQLiteStore {dbConn} = - liftIO $ - retrieveAllConnAliases dbConn + liftIO $ do + r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] + return (concat r) getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do - rcvQueue <- + r <- liftIO $ - retrieveRcvQueue dbConn host port rcvId - case rcvQueue of - Just rcvQ -> return rcvQ + DB.queryNamed + dbConn + [sql| + SELECT + s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, + q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status + FROM rcv_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; + |] + [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] + case r of + [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> + let srv = SMPServer hst (deserializePort_ prt) keyHash + in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status _ -> throwError SEConnNotFound deleteConn :: SQLiteStore -> ConnAlias -> m () deleteConn SQLiteStore {dbConn} connAlias = liftIO $ - deleteConnCascade dbConn connAlias + DB.executeNamed + dbConn + "DELETE FROM connections WHERE conn_alias = :conn_alias;" + [":conn_alias" := connAlias] upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m () - upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue = - liftIOEither $ - updateRcvConnWithSndQueue dbConn connAlias sndQueue + upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCRcv (RcvConnection _ _)) -> do + upsertServer_ dbConn server + insertSndQueue_ dbConn sq + updateConnWithSndQueue_ dbConn connAlias sq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m () - upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue = - liftIOEither $ - updateSndConnWithRcvQueue dbConn connAlias rcvQueue + upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCSnd (SndConnection _ _)) -> do + upsertServer_ dbConn server + insertRcvQueue_ dbConn rq + updateConnWithRcvQueue_ dbConn connAlias rq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () - setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status = + setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateRcvQueueStatus dbConn rcvQueue status + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () - setRcvQueueActive SQLiteStore {dbConn} rcvQueue verifyKey = + setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = + -- ? throw error if queue doesn't exist? liftIO $ - updateRcvQueueActive dbConn rcvQueue verifyKey + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET verify_key = :verify_key, status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [ ":verify_key" := Just verifyKey, + ":status" := Active, + ":host" := host, + ":port" := serializePort_ port, + ":rcv_id" := rcvId + ] setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () - setSndQueueStatus SQLiteStore {dbConn} sndQueue status = + setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateSndQueueStatus dbConn sndQueue status + DB.executeNamed + dbConn + [sql| + UPDATE snd_queues + SET status = :status + WHERE host = :host AND port = :port AND snd_id = :snd_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - liftIOEither $ - insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) + updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 + updateLastIdsRcv_ dbConn connAlias internalId internalRcvId + pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId - createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs = - liftIOEither $ - insertSndMsg dbConn connAlias msgBody internalTs + createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m () + createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData = + liftIO . DB.withTransaction dbConn $ do + insertRcvMsgBase_ dbConn connAlias rcvMsgData + insertRcvMsgDetails_ dbConn connAlias rcvMsgData + updateHashRcv_ dbConn connAlias rcvMsgData + + updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 + updateLastIdsSnd_ dbConn connAlias internalId internalSndId + pure (internalId, internalSndId, prevSndHash) + + createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m () + createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData = + liftIO . DB.withTransaction dbConn $ do + insertSndMsgBase_ dbConn connAlias sndMsgData + insertSndMsgDetails_ dbConn connAlias sndMsgData + updateHashSnd_ dbConn connAlias sndMsgData getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg getMsg _st _connAlias _id = throwError SENotImplemented @@ -228,13 +314,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do -- * createRcvConn helpers -createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO () -createRcvQueueAndConn dbConn rcvQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - insertRcvConnection_ dbConn rcvQueue - insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO () insertRcvQueue_ dbConn RcvQueue {..} = do let port_ = serializePort_ $ port server @@ -266,22 +345,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId] -- * createSndConn helpers -createSndQueueAndConn :: DB.Connection -> SndQueue -> IO () -createSndQueueAndConn dbConn sndQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - insertSndConnection_ dbConn sndQueue - insertSndQueue_ :: DB.Connection -> SndQueue -> IO () insertSndQueue_ dbConn SndQueue {..} = do let port_ = serializePort_ $ port server @@ -311,28 +384,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId] -- * getConn helpers -retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues dbConn connAlias = - DB.withTransaction -- Avoid inconsistent state between queue reads - dbConn - $ retrieveConnQueues_ dbConn connAlias - --- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap --- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction -retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues_ dbConn connAlias = do - rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias - sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias - return (rcvQ, sndQ) +getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn) +getConn_ dbConn connAlias = do + rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias + sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias + pure $ case (rQ, sQ) of + (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) + (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ) + (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ) + _ -> Left SEConnNotFound retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue) retrieveRcvQueueByConnAlias_ dbConn connAlias = do @@ -374,60 +444,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status _ -> return Nothing --- * getAllConnAliases helper - -retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias] -retrieveAllConnAliases dbConn = do - r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] - return (concat r) - --- * getRcvQueue helper - -retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue) -retrieveRcvQueue dbConn host port rcvId = do - r <- - DB.queryNamed - dbConn - [sql| - SELECT - s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, - q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status - FROM rcv_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; - |] - [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - case r of - [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do - let srv = SMPServer hst (deserializePort_ prt) keyHash - return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status - _ -> return Nothing - --- * deleteConn helper - -deleteConnCascade :: DB.Connection -> ConnAlias -> IO () -deleteConnCascade dbConn connAlias = - DB.executeNamed - dbConn - "DELETE FROM connections WHERE conn_alias = :conn_alias;" - [":conn_alias" := connAlias] - -- * upgradeRcvConnToDuplex helpers -updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ()) -updateRcvConnWithSndQueue dbConn connAlias sndQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, Nothing) -> do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - updateConnWithSndQueue_ dbConn connAlias sndQueue - return $ Right () - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEConnNotFound - updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do let port_ = serializePort_ $ port server @@ -442,20 +460,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do -- * upgradeSndConnToDuplex helpers -updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ()) -updateSndConnWithRcvQueue dbConn connAlias rcvQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Nothing, Just _sndQ) -> do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - updateConnWithRcvQueue_ dbConn connAlias rcvQueue - return $ Right () - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEConnNotFound - updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server @@ -468,93 +472,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do |] [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias] --- * setRcvQueueStatus helper +-- * updateRcvIds helpers --- ? throw error if queue doesn't exist? -updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () -updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE rcv_queues - SET status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - --- * setRcvQueueActive helper - --- ? throw error if queue doesn't exist? -updateRcvQueueActive :: DB.Connection -> RcvQueue -> VerificationKey -> IO () -updateRcvQueueActive dbConn RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = - DB.executeNamed - dbConn - [sql| - UPDATE rcv_queues - SET verify_key = :verify_key, status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [ ":verify_key" := Just verifyKey, - ":status" := Active, - ":host" := host, - ":port" := serializePort_ port, - ":rcv_id" := rcvId - ] - --- * setSndQueueStatus helper - --- ? throw error if queue doesn't exist? -updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () -updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE snd_queues - SET status = :status - WHERE host = :host AND port = :port AND snd_id = :snd_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - --- * createRcvMsg helpers - -insertRcvMsg :: - DB.Connection -> - ConnAlias -> - MsgBody -> - InternalTs -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO (Either StoreError InternalId) -insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, _) -> do - (lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody - insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) - updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId - return $ Right internalId - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - _ -> return $ Left SEConnNotFound - -retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId) -retrieveLastInternalIdsRcv_ dbConn connAlias = do - [(lastInternalId, lastInternalRcvId)] <- +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connAlias = do + [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.queryNamed dbConn [sql| - SELECT last_internal_msg_id, last_internal_rcv_msg_id + SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash FROM connections WHERE conn_alias = :conn_alias; |] [":conn_alias" := connAlias] - return (lastInternalId, lastInternalRcvId) + return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) -insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO () -insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do +updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = + DB.executeNamed + dbConn + [sql| + UPDATE connections + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_rcv_msg_id = :last_internal_rcv_msg_id + WHERE conn_alias = :conn_alias; + |] + [ ":last_internal_msg_id" := newInternalId, + ":last_internal_rcv_msg_id" := newInternalRcvId, + ":conn_alias" := connAlias + ] + +-- * createRcvMsg helpers + +insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do DB.executeNamed dbConn [sql| @@ -570,15 +521,8 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = ":body" := decodeUtf8 msgBody ] -insertRcvMsgDetails_ :: - DB.Connection -> - ConnAlias -> - InternalRcvId -> - InternalId -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO () -insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) = +insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = DB.executeNamed dbConn [sql| @@ -592,60 +536,65 @@ insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, e [ ":conn_alias" := connAlias, ":internal_rcv_id" := internalRcvId, ":internal_id" := internalId, - ":external_snd_id" := externalSndId, - ":external_snd_ts" := externalSndTs, - ":broker_id" := brokerId, - ":broker_ts" := brokerTs, + ":external_snd_id" := fst senderMeta, + ":external_snd_ts" := snd senderMeta, + ":broker_id" := fst brokerMeta, + ":broker_ts" := snd brokerMeta, ":rcv_status" := Received ] -updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () -updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = +updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +updateHashRcv_ dbConn connAlias RcvMsgData {..} = + DB.executeNamed + dbConn + -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved + [sql| + UPDATE connections + SET last_external_snd_msg_id = :last_external_snd_msg_id, + last_rcv_msg_hash = :last_rcv_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; + |] + [ ":last_external_snd_msg_id" := fst senderMeta, + ":last_rcv_msg_hash" := msgHash, + ":conn_alias" := connAlias, + ":last_internal_rcv_msg_id" := internalRcvId + ] + +-- * updateSndIds helpers + +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash) +retrieveLastIdsAndHashSnd_ dbConn connAlias = do + [(lastInternalId, lastInternalSndId, lastSndHash)] <- + DB.queryNamed + dbConn + [sql| + SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash + FROM connections + WHERE conn_alias = :conn_alias; + |] + [":conn_alias" := connAlias] + return (lastInternalId, lastInternalSndId, lastSndHash) + +updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = DB.executeNamed dbConn [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_snd_msg_id = :last_internal_snd_msg_id WHERE conn_alias = :conn_alias; |] [ ":last_internal_msg_id" := newInternalId, - ":last_internal_rcv_msg_id" := newInternalRcvId, + ":last_internal_snd_msg_id" := newInternalSndId, ":conn_alias" := connAlias ] -- * createSndMsg helpers -insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId) -insertSndMsg dbConn connAlias msgBody internalTs = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (_, Just _sndQ) -> do - (lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody - insertSndMsgDetails_ dbConn connAlias internalSndId internalId - updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId - return $ Right internalId - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - _ -> return $ Left SEConnNotFound - -retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId) -retrieveLastInternalIdsSnd_ dbConn connAlias = do - [(lastInternalId, lastInternalSndId)] <- - DB.queryNamed - dbConn - [sql| - SELECT last_internal_msg_id, last_internal_snd_msg_id - FROM connections - WHERE conn_alias = :conn_alias; - |] - [":conn_alias" := connAlias] - return (lastInternalId, lastInternalSndId) - -insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO () -insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do +insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do DB.executeNamed dbConn [sql| @@ -661,8 +610,8 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = ":body" := decodeUtf8 msgBody ] -insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO () -insertSndMsgDetails_ dbConn connAlias internalSndId internalId = +insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn [sql| @@ -677,16 +626,18 @@ insertSndMsgDetails_ dbConn connAlias internalSndId internalId = ":snd_status" := Created ] -updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () -updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId = +updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +updateHashSnd_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn + -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id - WHERE conn_alias = :conn_alias; + SET last_snd_msg_hash = :last_snd_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_alias" := connAlias + [ ":last_snd_msg_hash" := msgHash, + ":conn_alias" := connAlias, + ":last_internal_snd_msg_id" := internalSndId ] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs index 12d886a4e..959a071e9 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs @@ -90,6 +90,9 @@ connections = last_internal_msg_id INTEGER NOT NULL, last_internal_rcv_msg_id INTEGER NOT NULL, last_internal_snd_msg_id INTEGER NOT NULL, + last_external_snd_msg_id INTEGER NOT NULL, + last_rcv_msg_hash BLOB NOT NULL, + last_snd_msg_hash BLOB NOT NULL, PRIMARY KEY (conn_alias), FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id), FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 3bc9a7592..b29b919cf 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -13,7 +13,7 @@ module Simplex.Messaging.Agent.Transmission where -import Control.Applicative ((<|>)) +import Control.Applicative (optional, (<|>)) import Control.Monad.IO.Class import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A @@ -90,11 +90,11 @@ data ACommand (p :: AParty) where SEND :: MsgBody -> ACommand Client SENT :: AgentMsgId -> ACommand Agent MSG :: - { m_recipient :: (AgentMsgId, UTCTime), - m_broker :: (MsgId, UTCTime), - m_sender :: (AgentMsgId, UTCTime), - m_status :: MsgStatus, - m_body :: MsgBody + { recipientMeta :: (AgentMsgId, UTCTime), + brokerMeta :: (MsgId, UTCTime), + senderMeta :: (AgentMsgId, UTCTime), + msgIntegrity :: MsgIntegrity, + msgBody :: MsgBody } -> ACommand Agent -- ACK :: AgentMsgId -> ACommand Client @@ -142,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE SMPMessage <$> A.decimal <* A.space <*> tsISO8601P <* A.space - <*> base64P <* A.endOfLine + -- TODO previous message hash should become mandatory when we support HELLO and REPLY + -- (for HELLO it would be the hash of SMPConfirmation) + <*> (base64P <|> pure "") <* A.endOfLine <*> agentMessageP serializeSMPMessage :: SMPMessage -> ByteString @@ -175,17 +177,11 @@ smpQueueInfoP = "smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP) smpServerP :: Parser SMPServer -smpServerP = SMPServer <$> server <*> port <*> kHash +smpServerP = SMPServer <$> server <*> optional port <*> optional kHash where server = B.unpack <$> A.takeTill (A.inClass ":# ") - port = fromChar ':' $ show <$> (A.decimal :: Parser Int) - kHash = fromChar '#' C.keyHashP - fromChar :: Char -> Parser a -> Parser (Maybe a) - fromChar ch parser = do - c <- A.peekChar - if c == Just ch - then A.char ch *> (Just <$> parser) - else pure Nothing + port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit) + kHash = A.char '#' *> C.keyHashP parseAgentMessage :: ByteString -> Either AgentErrorType AMessage parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE @@ -241,10 +237,10 @@ type AgentMsgId = Int64 type SenderTimestamp = UTCTime -data MsgStatus = MsgOk | MsgError MsgErrorType +data MsgIntegrity = MsgOk | MsgError MsgErrorType deriving (Eq, Show) -data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash +data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate deriving (Eq, Show) -- | error type used in errors sent to agent clients @@ -319,22 +315,23 @@ commandP = sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal message = do - m_status <- status <* A.space - m_recipient <- "R=" *> partyMeta A.decimal - m_broker <- "B=" *> partyMeta base64P - m_sender <- "S=" *> partyMeta A.decimal - m_body <- A.takeByteString - return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body} + msgIntegrity <- integrity <* A.space + recipientMeta <- "R=" *> partyMeta A.decimal + brokerMeta <- "B=" *> partyMeta base64P + senderMeta <- "S=" *> partyMeta A.decimal + msgBody <- A.takeByteString + return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody} replyMode = " NO_REPLY" $> ReplyOff <|> A.space *> (ReplyVia <$> smpServerP) <|> pure ReplyOn partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space - status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + integrity = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) msgErrorType = "ID " *> (MsgBadId <$> A.decimal) <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash + <|> "DUPLICATE" $> MsgDuplicate agentError = ACmd SAgent . ERR <$> agentErrorTypeP parseCommand :: ByteString -> Either AgentErrorType ACmd @@ -350,14 +347,14 @@ serializeCommand = \case END -> "END" SEND msgBody -> "SEND " <> serializeMsg msgBody SENT mId -> "SENT " <> bshow mId - MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} -> + MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} -> B.unwords [ "MSG", - msgStatus m_status, + serializeMsgIntegrity msgIntegrity, "R=" <> bshow rmId <> "," <> showTs rTs, "B=" <> encode bmId <> "," <> showTs bTs, "S=" <> bshow smId <> "," <> showTs sTs, - serializeMsg m_body + serializeMsg msgBody ] OFF -> "OFF" DEL -> "DEL" @@ -372,15 +369,16 @@ serializeCommand = \case ReplyOn -> "" showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis - msgStatus :: MsgStatus -> ByteString - msgStatus = \case + serializeMsgIntegrity :: MsgIntegrity -> ByteString + serializeMsgIntegrity = \case MsgOk -> "OK" MsgError e -> - "ERR" <> case e of + "ERR " <> case e of MsgSkipped fromMsgId toMsgId -> B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] MsgBadId aMsgId -> "ID " <> bshow aMsgId MsgBadHash -> "HASH" + MsgDuplicate -> "DUPLICATE" agentErrorTypeP :: Parser AgentErrorType agentErrorTypeP = @@ -443,7 +441,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) cmdWithMsgBody = \case SEND body -> SEND <$$> getMsgBody body - MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body + MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body cmd -> return $ Right cmd -- TODO refactor with server diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index b5022e96f..19f37d78f 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -36,6 +36,7 @@ module Simplex.Messaging.Crypto encodePubKey, serializeKeyHash, getKeyHash, + sha256Hash, privKeyP, pubKeyP, binaryPubKeyP, @@ -226,6 +227,9 @@ keyHashP = do getKeyHash :: ByteString -> KeyHash getKeyHash = KeyHash . hash +sha256Hash :: ByteString -> ByteString +sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256) + serializeHeader :: Header -> ByteString serializeHeader Header {aesKey, ivBytes, authTag, msgSize} = unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index ad0c365d0..6a1754f13 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -80,7 +80,7 @@ h #:# err = tryGet `shouldReturn` () _ -> return () pattern Msg :: MsgBody -> ACommand 'Agent -pattern Msg m_body <- MSG {m_body} +pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk} testDuplexConnection :: Handle -> Handle -> IO () testDuplexConnection alice bob = do @@ -88,13 +88,13 @@ testDuplexConnection alice bob = do let qInfo' = serializeSmpQueueInfo qInfo bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON) alice <# ("", "bob", CON) - alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False - alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False + alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False + alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False - bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False + bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False - bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False + bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH)) diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 46132c832..450e568f3 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,12 +1,15 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} module AgentTests.SQLiteTests (storeTests) where import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import Data.ByteString.Char8 (ByteString) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Time @@ -49,45 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e -- TODO add null port tests storeTests :: Spec storeTests = withStore do - describe "compiled as threadsafe" testCompiledThreadsafe - describe "foreign keys enabled" testForeignKeysEnabled + describe "store setup" do + testCompiledThreadsafe + testForeignKeysEnabled describe "store methods" do - describe "createRcvConn" do - describe "unique" testCreateRcvConn - describe "duplicate" testCreateRcvConnDuplicate - describe "createSndConn" do - describe "unique" testCreateSndConn - describe "duplicate" testCreateSndConnDuplicate - describe "getAllConnAliases" testGetAllConnAliases - describe "getRcvQueue" testGetRcvQueue - describe "deleteConn" do - describe "RcvConnection" testDeleteRcvConn - describe "SndConnection" testDeleteSndConn - describe "DuplexConnection" testDeleteDuplexConn - describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex - describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex - describe "set queue status" do - describe "setRcvQueueStatus" testSetRcvQueueStatus - describe "setSndQueueStatus" testSetSndQueueStatus - describe "DuplexConnection" testSetQueueStatusDuplex - xdescribe "RcvQueue does not exist" testSetRcvQueueStatusNoQueue - xdescribe "SndQueue does not exist" testSetSndQueueStatusNoQueue - describe "createRcvMsg" do - describe "RcvQueue exists" testCreateRcvMsg - describe "RcvQueue does not exist" testCreateRcvMsgNoQueue - describe "createSndMsg" do - describe "SndQueue exists" testCreateSndMsg - describe "SndQueue does not exist" testCreateSndMsgNoQueue + describe "Queue and Connection management" do + describe "createRcvConn" do + testCreateRcvConn + testCreateRcvConnDuplicate + describe "createSndConn" do + testCreateSndConn + testCreateSndConnDuplicate + describe "getAllConnAliases" testGetAllConnAliases + describe "getRcvQueue" testGetRcvQueue + describe "deleteConn" do + testDeleteRcvConn + testDeleteSndConn + testDeleteDuplexConn + describe "upgradeRcvConnToDuplex" do + testUpgradeRcvConnToDuplex + describe "upgradeSndConnToDuplex" do + testUpgradeSndConnToDuplex + describe "set Queue status" do + describe "setRcvQueueStatus" do + testSetRcvQueueStatus + testSetRcvQueueStatusNoQueue + describe "setSndQueueStatus" do + testSetSndQueueStatus + testSetSndQueueStatusNoQueue + testSetQueueStatusDuplex + describe "Msg management" do + describe "create Msg" do + testCreateRcvMsg + testCreateSndMsg + testCreateRcvAndSndMsgs testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = do - it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do + it "compiled sqlite library should be threadsafe" $ \store -> do compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] compileOptions `shouldNotContain` [["THREADSAFE=0"]] testForeignKeysEnabled :: SpecWith SQLiteStore testForeignKeysEnabled = do - it "should throw error if foreign keys are enabled" $ \store -> do + it "foreign keys should be enabled" $ \store -> do let inconsistentQuery = [sql| INSERT INTO connections @@ -139,8 +147,7 @@ testCreateRcvConn = do testCreateRcvConnDuplicate :: SpecWith SQLiteStore testCreateRcvConnDuplicate = do it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 createRcvConn store rcvQueue1 `throwsError` SEConnDuplicate @@ -159,18 +166,15 @@ testCreateSndConn = do testCreateSndConnDuplicate :: SpecWith SQLiteStore testCreateSndConnDuplicate = do it "should throw error on attempt to create duplicate SndConnection" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 createSndConn store sndQueue1 `throwsError` SEConnDuplicate testGetAllConnAliases :: SpecWith SQLiteStore testGetAllConnAliases = do it "should get all conn aliases" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - createSndConn store sndQueue1 {connAlias = "conn2"} - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"} getAllConnAliases store `returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias] @@ -179,16 +183,14 @@ testGetRcvQueue = do it "should get RcvQueue" $ \store -> do let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getRcvQueue store smpServer recipientId `returnsResult` rcvQueue1 testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = do it "should create RcvConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) deleteConn store "conn1" @@ -200,8 +202,7 @@ testDeleteRcvConn = do testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = do it "should create SndConnection and delete it" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) deleteConn store "conn1" @@ -213,10 +214,8 @@ testDeleteSndConn = do testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = do it "should create DuplexConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) deleteConn store "conn1" @@ -228,8 +227,7 @@ testDeleteDuplexConn = do testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = do it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 let anotherSndQueue = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, @@ -242,16 +240,14 @@ testUpgradeRcvConnToDuplex = do } upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CSnd - upgradeSndConnToDuplex store "conn1" rcvQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1 upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CDuplex testUpgradeSndConnToDuplex :: SpecWith SQLiteStore testUpgradeSndConnToDuplex = do it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 let anotherRcvQueue = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, @@ -266,16 +262,14 @@ testUpgradeSndConnToDuplex = do } upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CRcv - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CDuplex testSetRcvQueueStatus :: SpecWith SQLiteStore testSetRcvQueueStatus = do it "should update status of RcvQueue" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) setRcvQueueStatus store rcvQueue1 Confirmed @@ -286,8 +280,7 @@ testSetRcvQueueStatus = do testSetSndQueueStatus :: SpecWith SQLiteStore testSetSndQueueStatus = do it "should update status of SndQueue" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) setSndQueueStatus store sndQueue1 Confirmed @@ -298,10 +291,8 @@ testSetSndQueueStatus = do testSetQueueStatusDuplex :: SpecWith SQLiteStore testSetQueueStatusDuplex = do it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) setRcvQueueStatus store rcvQueue1 Secured @@ -311,61 +302,87 @@ testSetQueueStatusDuplex = do setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn - SCDuplex - ( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed} - ) + `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = do - it "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do setRcvQueueStatus store rcvQueue1 Confirmed - `throwsError` SEInternal "" + `throwsError` SEConnNotFound testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore testSetSndQueueStatusNoQueue = do - it "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do setSndQueueStatus store sndQueue1 Confirmed - `throwsError` SEInternal "" + `throwsError` SEConnNotFound + +hw :: ByteString +hw = encodeUtf8 "Hello world!" + +ts :: UTCTime +ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) + +mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData +mkRcvMsgData internalId internalRcvId externalSndId brokerId msgHash = + RcvMsgData + { internalId, + internalRcvId, + internalTs = ts, + senderMeta = (externalSndId, ts), + brokerMeta = (brokerId, ts), + msgBody = hw, + msgHash = msgHash, + msgIntegrity = MsgOk + } + +testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation +testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do + updateRcvIds store rcvQueue + `returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) + createRcvMsg store rcvQueue rcvMsgData + `returnsResult` () testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = do - it "should create a RcvMsg and return InternalId" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + it "should reserve internal ids and create a RcvMsg" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `returnsResult` InternalId 1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" -testCreateRcvMsgNoQueue :: SpecWith SQLiteStore -testCreateRcvMsgNoQueue = do - it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEConnNotFound - createSndConn store sndQueue1 - `returnsResult` () - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConnType CSnd +mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData +mkSndMsgData internalId internalSndId msgHash = + SndMsgData + { internalId, + internalSndId, + internalTs = ts, + msgBody = hw, + msgHash = msgHash + } + +testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation +testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do + updateSndIds store sndQueue + `returnsResult` (internalId, internalSndId, expectedPrevHash) + createSndMsg store sndQueue sndMsgData + `returnsResult` () testCreateSndMsg :: SpecWith SQLiteStore testCreateSndMsg = do - it "should create a SndMsg and return InternalId" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do + _ <- runExceptT $ createSndConn store sndQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `returnsResult` InternalId 1 + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" + testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" -testCreateSndMsgNoQueue :: SpecWith SQLiteStore -testCreateSndMsgNoQueue = do - it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEConnNotFound - createRcvConn store rcvQueue1 - `returnsResult` () - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConnType CRcv +testCreateRcvAndSndMsgs :: SpecWith SQLiteStore +testCreateRcvAndSndMsgs = do + it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" + testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" + testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" From 829c198e5f2a5690172708a42dc82f6125e98407 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 1 May 2021 21:58:35 +0100 Subject: [PATCH 15/17] reserve block size (32 bit) and protocol (16 bit) in client handshake block (#116) * reserve block size (32 bit) and protocol (16 bit) in client handshake block * update function names * fix abnf --- rfcs/2021-01-26-crypto.md | 14 +++-- src/Simplex/Messaging/Transport.hs | 87 ++++++++++++++++-------------- 2 files changed, 58 insertions(+), 43 deletions(-) diff --git a/rfcs/2021-01-26-crypto.md b/rfcs/2021-01-26-crypto.md index 280d42e0a..6e397da18 100644 --- a/rfcs/2021-01-26-crypto.md +++ b/rfcs/2021-01-26-crypto.md @@ -22,10 +22,10 @@ To establish the session keys and base IVs, the server should have an asymmetric The handshake sequence is the following: -1. Once the connection is established, the server sends transport_header and its public RSA key encoded in X509 binary format to the client. +1. Once the connection is established, the server sends server_header and its public RSA key encoded in X509 binary format to the client. 2. The client compares the hash of the received key with the hash it already has (e.g. received as part of connection invitation or server in NEW command). If the hash does not match, the client must terminate the connection. TODO as the hash is optional in server syntax at the moment, hash comparison will be optional as well. Probably it should become required. 3. If the hash is the same, the client should generate random symmetric AES keys and base IVs that will be used as session keys/IVs by the client and the server. -4. The client then should encrypt these symmetric keys and base IVs with the public key that the server sent, and send to the server the result of the encryption: `rsa-encrypt(snd-aes-key, snd-base-iv, rcv-aes-key, rcv-base-iv)`. `snd-aes-key` and `snd-base-iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv-aes-key` and `rcv-base-iv` will be used by the client to decrypt **received** messages and by the server to encrypt them. +4. The client then should construct client_handshake block and send it to the server encrypted with the server public key: `rsa-encrypt(client_handshake)`. `snd_aes_key` and `snd_base_iv` will be used by the client to encrypt **sent** messages and by the server to decrypt them, `rcv_aes_key` and `rcv_base_iv` will be used by the client to decrypt **received** messages and by the server to encrypt them. 5. The server should decrypt the received keys and base IVs with its private key. 6. In case of successful decryption, the server should send encrypted welcome block (encrypted_welcome_block) that contains SMP protocol version. @@ -34,11 +34,19 @@ All the subsequent data both from the client and from the server should be sent Each transport block sent by the client and the server has this syntax: ```abnf -transport_header = block_size protocol key_size +server_header = block_size protocol key_size block_size = 4*4(OCTET) ; 4-byte block size sent by the server, currently the client rejects if > 65536 bytes protocol = 2*2(OCTET) ; currently it is 0, that means binary RSA key key_size = 2*2(OCTET) ; the encoded key size in bytes (binary encoded in X509 standard) +client_handshake = client_block_size client_protocol snd_aes_key snd_base_iv rcv_aes_key rcv_base_iv +client_block_size = 4*4(OCTET) ; 4-byte block size sent by the client, currently it is ignored by the server - reserved +client_protocol = 2*2(OCTET) ; currently it is 0 - reserved +snd_aes_key = 32*32(OCTET) +snd_base_iv = 16*16(OCTET) +rcv_aes_key = 32*32(OCTET) +rcv_base_iv = 16*16(OCTET) + transport_block = aes_body_auth_tag aes_encrypted_body ; size is sent by server during handshake, usually 8192 bytes aes_encrypted_body = 1*OCTET diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index d70139500..731f3d4d2 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -156,8 +156,9 @@ data SessionKey = SessionKey counter :: TVar Word32 } -data HandshakeKeys = HandshakeKeys - { sndKey :: SessionKey, +data ClientHandshake = ClientHandshake + { blockSize :: Int, + sndKey :: SessionKey, rcvKey :: SessionKey } @@ -237,7 +238,9 @@ serverHandshake :: Handle -> C.FullKeyPair -> ExceptT TransportError IO THandle serverHandshake h (k, pk) = do liftIO sendHeaderAndPublicKey_1 encryptedKeys <- receiveEncryptedKeys_4 - HandshakeKeys {sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys + -- TODO server currently ignores blockSize returned by the client + -- this is reserved for future support of streams + ClientHandshake {blockSize = _, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys th <- liftIO $ transportHandle h rcvKey sndKey transportBlockSize -- keys are swapped here sendWelcome_6 th pure th @@ -245,17 +248,17 @@ serverHandshake h (k, pk) = do sendHeaderAndPublicKey_1 :: IO () sendHeaderAndPublicKey_1 = do let sKey = C.encodePubKey k - header = TransportHeader {blockSize = transportBlockSize, keySize = B.length sKey} - B.hPut h $ binaryTransportHeader header <> sKey + header = ServerHeader {blockSize = transportBlockSize, keySize = B.length sKey} + B.hPut h $ binaryServerHeader header <> sKey receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString receiveEncryptedKeys_4 = liftIO (B.hGet h $ C.publicKeySize k) >>= \case "" -> throwE $ TEHandshake TERMINATED ks -> pure ks - decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO HandshakeKeys + decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO ClientHandshake decryptParseKeys_5 encKeys = liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP pk encKeys) - >>= liftEither . parseHandshakeKeys + >>= liftEither . parseClientHandshake sendWelcome_6 :: THandle -> ExceptT TransportError IO () sendWelcome_6 th = ExceptT . tPutEncrypted th $ serializeSMPVersion currentSMPVersion <> " " @@ -264,7 +267,8 @@ serverHandshake h (k, pk) = do clientHandshake :: Handle -> Maybe C.KeyHash -> ExceptT TransportError IO THandle clientHandshake h keyHash = do (k, blkSize) <- getHeaderAndPublicKey_1_2 - keys@HandshakeKeys {sndKey, rcvKey} <- liftIO generateKeys_3 + -- TODO currently client always uses the blkSize returned by the server + keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize sendEncryptedKeys_4 k keys th <- liftIO $ transportHandle h sndKey rcvKey blkSize getWelcome_6 th >>= checkVersion @@ -272,8 +276,8 @@ clientHandshake h keyHash = do where getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int) getHeaderAndPublicKey_1_2 = do - header <- liftIO (B.hGet h transportHeaderSize) - TransportHeader {blockSize, keySize} <- liftEither $ parse transportHeaderP (TEHandshake HEADER) header + header <- liftIO (B.hGet h serverHeaderSize) + ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $ throwError $ TEHandshake HEADER s <- liftIO $ B.hGet h keySize @@ -286,16 +290,16 @@ clientHandshake h keyHash = do validateKeyHash_2 k kHash | C.getKeyHash k == kHash = pure () | otherwise = throwE $ TEHandshake BAD_HASH - generateKeys_3 :: IO HandshakeKeys - generateKeys_3 = HandshakeKeys <$> generateKey <*> generateKey + generateKeys_3 :: Int -> IO ClientHandshake + generateKeys_3 blkSize = ClientHandshake blkSize <$> generateKey <*> generateKey generateKey :: IO SessionKey generateKey = do aesKey <- C.randomAesKey baseIV <- C.randomIV pure SessionKey {aesKey, baseIV, counter = undefined} - sendEncryptedKeys_4 :: C.PublicKey -> HandshakeKeys -> ExceptT TransportError IO () + sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO () sendEncryptedKeys_4 k keys = - liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeHandshakeKeys keys) + liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys) >>= liftIO . B.hPut h getWelcome_6 :: THandle -> ExceptT TransportError IO SMPVersion getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th @@ -306,48 +310,37 @@ clientHandshake h keyHash = do when (major smpVersion > major currentSMPVersion) . throwE $ TEHandshake MAJOR_VERSION -data TransportHeader = TransportHeader {blockSize :: Int, keySize :: Int} +data ServerHeader = ServerHeader {blockSize :: Int, keySize :: Int} deriving (Eq, Show) binaryRsaTransport :: Int binaryRsaTransport = 0 -binaryRsaTransportBS :: ByteString -binaryRsaTransportBS = encodeEnum16 binaryRsaTransport - transportBlockSize :: Int transportBlockSize = 4096 maxTransportBlockSize :: Int maxTransportBlockSize = 65536 -transportHeaderSize :: Int -transportHeaderSize = 8 +serverHeaderSize :: Int +serverHeaderSize = 8 -binaryTransportHeader :: TransportHeader -> ByteString -binaryTransportHeader TransportHeader {blockSize, keySize} = - encodeEnum32 blockSize <> binaryRsaTransportBS <> encodeEnum16 keySize +binaryServerHeader :: ServerHeader -> ByteString +binaryServerHeader ServerHeader {blockSize, keySize} = + encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> encodeEnum16 keySize -transportHeaderP :: Parser TransportHeader -transportHeaderP = TransportHeader <$> int32 <* binaryRsaTransportP <*> int16 - where - int32 = decodeNum32 <$> A.take 4 - int16 = decodeNum16 <$> A.take 2 - binaryRsaTransportP = binaryRsa <$> int16 - binaryRsa :: Int -> Parser () - binaryRsa n - | n == binaryRsaTransport = pure () - | otherwise = fail "unknown transport mode" +serverHeaderP :: Parser ServerHeader +serverHeaderP = ServerHeader <$> int32 <* binaryRsaTransportP <*> int16 -serializeHandshakeKeys :: HandshakeKeys -> ByteString -serializeHandshakeKeys HandshakeKeys {sndKey, rcvKey} = - serializeKey sndKey <> serializeKey rcvKey +serializeClientHandshake :: ClientHandshake -> ByteString +serializeClientHandshake ClientHandshake {blockSize, sndKey, rcvKey} = + encodeEnum32 blockSize <> encodeEnum16 binaryRsaTransport <> serializeKey sndKey <> serializeKey rcvKey where serializeKey :: SessionKey -> ByteString serializeKey SessionKey {aesKey, baseIV} = C.unKey aesKey <> C.unIV baseIV -handshakeKeysP :: Parser HandshakeKeys -handshakeKeysP = HandshakeKeys <$> keyP <*> keyP +clientHandshakeP :: Parser ClientHandshake +clientHandshakeP = ClientHandshake <$> int32 <* binaryRsaTransportP <*> keyP <*> keyP where keyP :: Parser SessionKey keyP = do @@ -355,8 +348,22 @@ handshakeKeysP = HandshakeKeys <$> keyP <*> keyP baseIV <- C.ivP pure SessionKey {aesKey, baseIV, counter = undefined} -parseHandshakeKeys :: ByteString -> Either TransportError HandshakeKeys -parseHandshakeKeys = parse handshakeKeysP $ TEHandshake AES_KEYS +int32 :: Parser Int +int32 = decodeNum32 <$> A.take 4 + +int16 :: Parser Int +int16 = decodeNum16 <$> A.take 2 + +binaryRsaTransportP :: Parser () +binaryRsaTransportP = binaryRsa =<< int16 + where + binaryRsa :: Int -> Parser () + binaryRsa n + | n == binaryRsaTransport = pure () + | otherwise = fail "unknown transport mode" + +parseClientHandshake :: ByteString -> Either TransportError ClientHandshake +parseClientHandshake = parse clientHandshakeP $ TEHandshake AES_KEYS transportHandle :: Handle -> SessionKey -> SessionKey -> Int -> IO THandle transportHandle h sk rk blockSize = do From 633b3a4bda3e241c4e26ab63e06da8f9ed827a2a Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 1 May 2021 22:07:25 +0100 Subject: [PATCH 16/17] mitigate timing attack to determine if queue exists (#117) * mitigate timing attack to determine if queue exists * remove timing for authenticated SEND command Co-authored-by: Efim Poberezkin <8711996+efim-poberezkin@users.noreply.github.com> --- package.yaml | 1 + src/Simplex/Messaging/Protocol.hs | 2 +- src/Simplex/Messaging/Server.hs | 26 ++++++++++++++------------ tests/ServerTests.hs | 29 +++++++++++++++++++++++++++-- 4 files changed, 43 insertions(+), 15 deletions(-) diff --git a/package.yaml b/package.yaml index 259d1bc4a..d89e592ca 100644 --- a/package.yaml +++ b/package.yaml @@ -87,6 +87,7 @@ tests: - HUnit == 1.6.* - random == 1.1.* - QuickCheck == 2.13.* + - timeit == 2.0.* ghc-options: # - -haddock diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index d7f139c18..30d20ec5a 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -251,7 +251,7 @@ tGet fromParty th = liftIO (tGetParse th) >>= decodeParseValidate Cmd SBroker _ | B.null queueId -> Left $ CMD NO_QUEUE | otherwise -> Right cmd - -- NEW must NOT have signature or queue ID + -- NEW must have signature but NOT queue ID Cmd SRecipient (NEW _) | B.null signature -> Left $ CMD NO_AUTH | not (B.null queueId) -> Left $ CMD HAS_AUTH diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index dd1f67e98..d0d736673 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -108,25 +108,27 @@ verifyTransmission (sig, t@(corrId, queueId, cmd)) = do (corrId,queueId,) <$> case cmd of Cmd SBroker _ -> return $ smpErr INTERNAL -- it can only be client command, because `fromClient` was used Cmd SRecipient (NEW k) -> return $ verifySignature k - Cmd SRecipient _ -> withQueueRec SRecipient $ verifySignature . recipientKey - Cmd SSender (SEND _) -> withQueueRec SSender $ verifySend sig . senderKey + Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey + Cmd SSender (SEND _) -> verifyCmd SSender $ verifySend sig . senderKey Cmd SSender PING -> return cmd where - withQueueRec :: SParty (p :: Party) -> (QueueRec -> Cmd) -> m Cmd - withQueueRec party f = do + verifyCmd :: SParty p -> (QueueRec -> Cmd) -> m Cmd + verifyCmd party f = do + (aKey, _) <- asks serverKeyPair -- any public key can be used to mitigate timing attack st <- asks queueStore - qr <- atomically $ getQueue st party queueId - return $ either smpErr f qr + q <- atomically $ getQueue st party queueId + pure $ either (const $ fakeVerify aKey) f q + fakeVerify :: C.PublicKey -> Cmd + fakeVerify aKey = if verify aKey then authErr else authErr verifySend :: C.Signature -> Maybe SenderPublicKey -> Cmd verifySend "" = maybe cmd (const authErr) verifySend _ = maybe authErr verifySignature verifySignature :: C.PublicKey -> Cmd - verifySignature key = - if C.verify key sig (serializeTransmission t) - then cmd - else authErr - - smpErr e = Cmd SBroker $ ERR e + verifySignature key = if verify key then cmd else authErr + verify :: C.PublicKey -> Bool + verify key = C.verify key sig (serializeTransmission t) + smpErr :: ErrorType -> Cmd + smpErr = Cmd SBroker . ERR authErr = smpErr AUTH client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m () diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index b1ead469c..743d55671 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -11,7 +11,7 @@ module ServerTests where import Control.Concurrent (ThreadId, killThread) import Control.Concurrent.STM import Control.Exception (SomeException, try) -import Control.Monad.Except (runExceptT) +import Control.Monad.Except (forM_, runExceptT) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -20,12 +20,13 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol import Simplex.Messaging.Transport import System.Directory (removeFile) +import System.TimeIt (timeItT) import System.Timeout import Test.HUnit import Test.Hspec rsaKeySize :: Int -rsaKeySize = 1024 `div` 8 +rsaKeySize = 2048 `div` 8 serverTests :: Spec serverTests = do @@ -37,6 +38,7 @@ serverTests = do describe "duplex communication over 2 SMP connections" testDuplex describe "switch subscription to another SMP queue" testSwitchSub describe "Store log" testWithStoreLog + describe "Timing of AUTH error" testTiming pattern Resp :: CorrId -> QueueId -> Command 'Broker -> SignedTransmissionOrError pattern Resp corrId queueId command <- ("", (corrId, queueId, Right (Cmd SBroker command))) @@ -330,6 +332,29 @@ testWithStoreLog = Right l -> pure l Left (_ :: SomeException) -> logSize +testTiming :: Spec +testTiming = + it "should have similar time for auth error whether queue exists or not" $ + smpTest2 \rh sh -> do + (rPub, rKey) <- C.generateKeyPair rsaKeySize + Resp "abcd" "" (IDS rId sId) <- signSendRecv rh rKey ("abcd", "", "NEW " <> C.serializePubKey rPub) + + (sPub, sKey) <- C.generateKeyPair rsaKeySize + let keyCmd = "KEY " <> C.serializePubKey sPub + Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, keyCmd) + + Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, "SEND 5 hello ") + + timeNoQueue <- timeRepeat 25 $ do + Resp "dabc" _ (ERR AUTH) <- signSendRecv sh sKey ("dabc", rId, "SEND 5 hello ") + return () + timeWrongKey <- timeRepeat 25 $ do + Resp "cdab" _ (ERR AUTH) <- signSendRecv sh rKey ("cdab", sId, "SEND 5 hello ") + return () + abs (timeNoQueue - timeWrongKey) / timeNoQueue < 0.15 `shouldBe` True + where + timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const + samplePubKey :: ByteString samplePubKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" From b7902ee4c877207c5e389911b09925be355864b1 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin <8711996+efim-poberezkin@users.noreply.github.com> Date: Sun, 2 May 2021 10:46:18 +0400 Subject: [PATCH 17/17] agent sqlite: store msg hashes and integrity (#118, #119, #120) Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> --- src/Simplex/Messaging/Agent.hs | 5 +-- src/Simplex/Messaging/Agent/Store.hs | 15 ++++++--- src/Simplex/Messaging/Agent/Store/SQLite.hs | 33 ++++++++++++++----- .../Messaging/Agent/Store/SQLite/Schema.hs | 4 +++ src/Simplex/Messaging/Agent/Transmission.hs | 30 +++++++++-------- src/Simplex/Messaging/Crypto.hs | 4 +-- tests/AgentTests/SQLiteTests.hs | 9 ++--- 7 files changed, 67 insertions(+), 33 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index ca544a6fd..e5b214033 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -190,7 +190,7 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = msgHash = C.sha256Hash msgStr withStore $ createSndMsg st sq $ - SndMsgData {internalId, internalSndId, internalTs, msgBody, msgHash} + SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} sendAgentMessage c sq msgStr respond $ SENT (unId internalId) @@ -309,7 +309,8 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do senderMeta, brokerMeta, msgBody, - msgHash, + internalHash = msgHash, + externalPrevSndHash = receivedPrevMsgHash, msgIntegrity } notify connAlias $ diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 51b9db518..778b6b6be 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -147,6 +147,7 @@ type PrevRcvMsgHash = MsgHash -- | Corresponds to `last_snd_msg_hash` in `connections` table type PrevSndMsgHash = MsgHash +-- ? merge/replace these with RcvMsg and SndMsg -- * Message data containers - used on Msg creation to reduce number of parameters data RcvMsgData = RcvMsgData @@ -156,7 +157,8 @@ data RcvMsgData = RcvMsgData senderMeta :: (ExternalSndId, ExternalSndTs), brokerMeta :: (BrokerId, BrokerTs), msgBody :: MsgBody, - msgHash :: MsgHash, + internalHash :: MsgHash, + externalPrevSndHash :: MsgHash, msgIntegrity :: MsgIntegrity } @@ -165,7 +167,7 @@ data SndMsgData = SndMsgData internalSndId :: InternalSndId, internalTs :: InternalTs, msgBody :: MsgBody, - msgHash :: MsgHash + internalHash :: MsgHash } -- * Message types @@ -194,7 +196,10 @@ data RcvMsg = RcvMsg -- | Timestamp of acknowledgement to sender, corresponds to `AcknowledgedToSender` status. -- Do not mix up with `externalSndTs` - timestamp created at sender before sending, -- which in its turn corresponds to `internalTs` in sending agent. - ackSenderTs :: AckSenderTs + ackSenderTs :: AckSenderTs, + -- | Hash of previous message as received from sender - stored for integrity forensics. + externalPrevSndHash :: MsgHash, + msgIntegrity :: MsgIntegrity } deriving (Eq, Show) @@ -254,7 +259,9 @@ data MsgBase = MsgBase -- due to a possibility of implementation errors in different agents. internalId :: InternalId, internalTs :: InternalTs, - msgBody :: MsgBody + msgBody :: MsgBody, + -- | Hash of the message as computed by agent. + internalHash :: MsgHash } deriving (Eq, Show) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index c48110065..ae9473e7d 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -40,6 +40,7 @@ import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema) import Simplex.Messaging.Agent.Transmission +import Simplex.Messaging.Parsers (parseAll) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) import System.Exit (ExitCode (ExitFailure), exitWith) @@ -276,6 +277,16 @@ instance ToField RcvMsgStatus where toField = toField . show instance ToField SndMsgStatus where toField = toField . show +instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity + +instance FromField MsgIntegrity where + fromField = \case + f@(Field (SQLBlob b) _) -> + case parseAll msgIntegrityP b of + Right k -> Ok k + Left e -> returnError ConversionFailed f ("can't parse msg integrity field: " ++ e) + f -> returnError ConversionFailed f "expecting SQLBlob column type" + fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a fromFieldToReadable_ = \case f@(Field (SQLText t) _) -> @@ -528,10 +539,12 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = [sql| INSERT INTO rcv_messages ( conn_alias, internal_rcv_id, internal_id, external_snd_id, external_snd_ts, - broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts) + broker_id, broker_ts, rcv_status, ack_brocker_ts, ack_sender_ts, + internal_hash, external_prev_snd_hash, integrity) VALUES (:conn_alias,:internal_rcv_id,:internal_id,:external_snd_id,:external_snd_ts, - :broker_id,:broker_ts,:rcv_status, NULL, NULL); + :broker_id,:broker_ts,:rcv_status, NULL, NULL, + :internal_hash,:external_prev_snd_hash,:integrity); |] [ ":conn_alias" := connAlias, ":internal_rcv_id" := internalRcvId, @@ -540,7 +553,10 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = ":external_snd_ts" := snd senderMeta, ":broker_id" := fst brokerMeta, ":broker_ts" := snd brokerMeta, - ":rcv_status" := Received + ":rcv_status" := Received, + ":internal_hash" := internalHash, + ":external_prev_snd_hash" := externalPrevSndHash, + ":integrity" := msgIntegrity ] updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () @@ -556,7 +572,7 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} = AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; |] [ ":last_external_snd_msg_id" := fst senderMeta, - ":last_rcv_msg_hash" := msgHash, + ":last_rcv_msg_hash" := internalHash, ":conn_alias" := connAlias, ":last_internal_rcv_msg_id" := internalRcvId ] @@ -616,14 +632,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = dbConn [sql| INSERT INTO snd_messages - ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts) + ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash) VALUES - (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL); + (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash); |] [ ":conn_alias" := connAlias, ":internal_snd_id" := internalSndId, ":internal_id" := internalId, - ":snd_status" := Created + ":snd_status" := Created, + ":internal_hash" := internalHash ] updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () @@ -637,7 +654,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} = WHERE conn_alias = :conn_alias AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] - [ ":last_snd_msg_hash" := msgHash, + [ ":last_snd_msg_hash" := internalHash, ":conn_alias" := connAlias, ":last_internal_snd_msg_id" := internalSndId ] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs index 959a071e9..99ba67ccf 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs @@ -138,6 +138,9 @@ rcvMessages = rcv_status TEXT NOT NULL, ack_brocker_ts TEXT, ack_sender_ts TEXT, + internal_hash BLOB NOT NULL, + external_prev_snd_hash BLOB NOT NULL, + integrity BLOB NOT NULL, PRIMARY KEY (conn_alias, internal_rcv_id), FOREIGN KEY (conn_alias, internal_id) REFERENCES messages (conn_alias, internal_id) @@ -155,6 +158,7 @@ sndMessages = snd_status TEXT NOT NULL, sent_ts TEXT, delivered_ts TEXT, + internal_hash BLOB NOT NULL, PRIMARY KEY (conn_alias, internal_snd_id), FOREIGN KEY (conn_alias, internal_id) REFERENCES messages (conn_alias, internal_id) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index b29b919cf..fbaff38b6 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -315,7 +315,7 @@ commandP = sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal message = do - msgIntegrity <- integrity <* A.space + msgIntegrity <- msgIntegrityP <* A.space recipientMeta <- "R=" *> partyMeta A.decimal brokerMeta <- "B=" *> partyMeta base64P senderMeta <- "S=" *> partyMeta A.decimal @@ -326,13 +326,16 @@ commandP = <|> A.space *> (ReplyVia <$> smpServerP) <|> pure ReplyOn partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space - integrity = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + agentError = ACmd SAgent . ERR <$> agentErrorTypeP + +msgIntegrityP :: Parser MsgIntegrity +msgIntegrityP = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + where msgErrorType = "ID " *> (MsgBadId <$> A.decimal) <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash <|> "DUPLICATE" $> MsgDuplicate - agentError = ACmd SAgent . ERR <$> agentErrorTypeP parseCommand :: ByteString -> Either AgentErrorType ACmd parseCommand = parse commandP $ CMD SYNTAX @@ -369,16 +372,17 @@ serializeCommand = \case ReplyOn -> "" showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis - serializeMsgIntegrity :: MsgIntegrity -> ByteString - serializeMsgIntegrity = \case - MsgOk -> "OK" - MsgError e -> - "ERR " <> case e of - MsgSkipped fromMsgId toMsgId -> - B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] - MsgBadId aMsgId -> "ID " <> bshow aMsgId - MsgBadHash -> "HASH" - MsgDuplicate -> "DUPLICATE" + +serializeMsgIntegrity :: MsgIntegrity -> ByteString +serializeMsgIntegrity = \case + MsgOk -> "OK" + MsgError e -> + "ERR " <> case e of + MsgSkipped fromMsgId toMsgId -> + B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] + MsgBadId aMsgId -> "ID " <> bshow aMsgId + MsgBadHash -> "HASH" + MsgDuplicate -> "DUPLICATE" agentErrorTypeP :: Parser AgentErrorType agentErrorTypeP = diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 19f37d78f..808699fe7 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -79,8 +79,8 @@ import Data.ByteString.Lazy (fromStrict, toStrict) import Data.String import Data.Typeable (Typeable) import Data.X509 -import Database.SQLite.Simple as DB -import Database.SQLite.Simple.FromField +import Database.SQLite.Simple (ResultError (..), SQLData (..)) +import Database.SQLite.Simple.FromField (FieldParser, FromField (..), returnError) import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.ToField (ToField (..)) diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 450e568f3..5442705f0 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -323,7 +323,7 @@ ts :: UTCTime ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData -mkRcvMsgData internalId internalRcvId externalSndId brokerId msgHash = +mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = RcvMsgData { internalId, internalRcvId, @@ -331,7 +331,8 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId msgHash = senderMeta = (externalSndId, ts), brokerMeta = (brokerId, ts), msgBody = hw, - msgHash = msgHash, + internalHash, + externalPrevSndHash = "hash_from_sender", msgIntegrity = MsgOk } @@ -351,13 +352,13 @@ testCreateRcvMsg = do testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData -mkSndMsgData internalId internalSndId msgHash = +mkSndMsgData internalId internalSndId internalHash = SndMsgData { internalId, internalSndId, internalTs = ts, msgBody = hw, - msgHash = msgHash + internalHash } testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation