diff --git a/.gitignore b/.gitignore index 83446fcf4..965b1e528 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ *.lock -*.cabal *.db *.db.bak *.session.sql diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index d2e7ae835..bb0685549 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -8,20 +8,10 @@ import Control.Logger.Simple import qualified Data.List.NonEmpty as L import Simplex.Messaging.Agent (runSMPAgent) import Simplex.Messaging.Agent.Env.SQLite -import Simplex.Messaging.Client (smpDefaultConfig) import Simplex.Messaging.Transport (TCP, Transport (..)) cfg :: AgentConfig -cfg = - AgentConfig - { tcpPort = "5224", - smpServers = L.fromList ["localhost:5223#bU0K+bRg24xWW//lS0umO1Zdw/SXqpJNtm1/RrPLViE="], - rsaKeySize = 2048 `div` 8, - connIdBytes = 12, - tbqSize = 16, - dbFile = "smp-agent.db", - smpCfg = smpDefaultConfig - } +cfg = defaultAgentConfig {smpServers = L.fromList ["localhost:5223#bU0K+bRg24xWW//lS0umO1Zdw/SXqpJNtm1/RrPLViE="]} logCfg :: LogConfig logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} diff --git a/apps/smp-server/Main.hs b/apps/smp-server/Main.hs index c29f4800e..b71bed5d4 100644 --- a/apps/smp-server/Main.hs +++ b/apps/smp-server/Main.hs @@ -29,19 +29,25 @@ import System.Directory (createDirectoryIfMissing, doesFileExist, removeFile) import System.Exit (exitFailure) import System.FilePath (combine) import System.IO (IOMode (..), hFlush, stdout) +import Text.Read (readEither) defaultServerPort :: ServiceName defaultServerPort = "5223" +defaultBlockSize :: Int +defaultBlockSize = 4096 + serverConfig :: ServerConfig serverConfig = ServerConfig { tbqSize = 16, + msgQueueQuota = 256, queueIdBytes = 12, msgIdBytes = 6, -- below parameters are set based on ini file /etc/opt/simplex/smp-server.ini transports = undefined, storeLog = undefined, + blockSize = undefined, serverPrivateKey = undefined } @@ -96,9 +102,9 @@ getConfig opts = do pure $ makeConfig ini pk storeLog makeConfig :: IniOpts -> C.FullPrivateKey -> Maybe (StoreLog 'ReadMode) -> ServerConfig -makeConfig IniOpts {serverPort, enableWebsockets} pk storeLog = +makeConfig IniOpts {serverPort, blockSize, enableWebsockets} pk storeLog = let transports = (serverPort, transport @TCP) : [("80", transport @WS) | enableWebsockets] - in serverConfig {serverPrivateKey = pk, storeLog, transports} + in serverConfig {serverPrivateKey = pk, storeLog, blockSize, transports} printConfig :: ServerConfig -> IO () printConfig ServerConfig {serverPrivateKey, storeLog} = do @@ -139,6 +145,7 @@ data IniOpts = IniOpts storeLogFile :: FilePath, serverKeyFile :: FilePath, serverPort :: ServiceName, + blockSize :: Int, enableWebsockets :: Bool } @@ -151,7 +158,8 @@ readIni = do serverKeyFile = opt defaultKeyFile "TRANSPORT" "key_file" ini serverPort = opt defaultServerPort "TRANSPORT" "port" ini enableWebsockets = (== Right "on") $ lookupValue "TRANSPORT" "websockets" ini - pure IniOpts {enableStoreLog, storeLogFile, serverKeyFile, serverPort, enableWebsockets} + blockSize <- liftEither . readEither $ opt (show defaultBlockSize) "TRANSPORT" "block_size" ini + pure IniOpts {enableStoreLog, storeLogFile, serverKeyFile, serverPort, blockSize, enableWebsockets} where opt :: String -> Text -> Text -> Ini -> String opt def section key ini = either (const def) T.unpack $ lookupValue section key ini @@ -177,6 +185,9 @@ createIni ServerOpts {enableStoreLog} = do <> "\n\ \# port: " <> defaultServerPort + <> "\n\ + \# block_size: " + <> show defaultBlockSize <> "\n\ \websockets: on\n" pure @@ -185,6 +196,7 @@ createIni ServerOpts {enableStoreLog} = do storeLogFile = defaultStoreLogFile, serverKeyFile = defaultKeyFile, serverPort = defaultServerPort, + blockSize = defaultBlockSize, enableWebsockets = True } @@ -222,7 +234,7 @@ confirm msg = do when (map toLower ok /= "y") exitFailure serverKeyHash :: C.FullPrivateKey -> B.ByteString -serverKeyHash = encode . C.unKeyHash . C.publicKeyHash . C.publicKey +serverKeyHash = encode . C.unKeyHash . C.publicKeyHash . C.publicKey' openStoreLog :: ServerOpts -> IniOpts -> IO (Maybe (StoreLog 'ReadMode)) openStoreLog ServerOpts {enableStoreLog = l} IniOpts {enableStoreLog = l', storeLogFile = f} diff --git a/migrations/20210101_initial.sql b/migrations/20210101_initial.sql index 050d4b4d1..75716865a 100644 --- a/migrations/20210101_initial.sql +++ b/migrations/20210101_initial.sql @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS messages( internal_ts TEXT NOT NULL, internal_rcv_id INTEGER, internal_snd_id INTEGER, - body TEXT NOT NULL, + body TEXT NOT NULL, -- deprecated PRIMARY KEY (conn_alias, internal_id), FOREIGN KEY (conn_alias) REFERENCES connections (conn_alias) diff --git a/migrations/20210529_broadcasts.sql b/migrations/20210529_broadcasts.sql deleted file mode 100644 index 3095f0572..000000000 --- a/migrations/20210529_broadcasts.sql +++ /dev/null @@ -1,10 +0,0 @@ -CREATE TABLE IF NOT EXISTS broadcasts ( - broadcast_id BLOB NOT NULL, - PRIMARY KEY (broadcast_id) -) WITHOUT ROWID; - -CREATE TABLE IF NOT EXISTS broadcast_connections ( - broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE, - conn_alias BLOB NOT NULL REFERENCES connections (conn_alias), - PRIMARY KEY (broadcast_id, conn_alias) -) WITHOUT ROWID; diff --git a/migrations/20210624_confirmations.sql b/migrations/20210624_confirmations.sql new file mode 100644 index 000000000..f7b1e8e85 --- /dev/null +++ b/migrations/20210624_confirmations.sql @@ -0,0 +1,9 @@ +CREATE TABLE conn_confirmations ( + confirmation_id BLOB NOT NULL PRIMARY KEY, + conn_alias BLOB NOT NULL REFERENCES connections ON DELETE CASCADE, + sender_key BLOB NOT NULL, + sender_conn_info BLOB NOT NULL, + accepted INTEGER NOT NULL, + own_conn_info BLOB, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +) WITHOUT ROWID; diff --git a/migrations/20210809_snd_messages.sql b/migrations/20210809_snd_messages.sql new file mode 100644 index 000000000..ef1624ec3 --- /dev/null +++ b/migrations/20210809_snd_messages.sql @@ -0,0 +1,3 @@ +ALTER TABLE messages ADD msg_body BLOB NOT NULL DEFAULT x''; -- this field replaces body TEXT +-- TODO possibly migrate the data from body if it is possible in migration +ALTER TABLE snd_messages ADD previous_msg_hash BLOB NOT NULL DEFAULT x''; diff --git a/package.yaml b/package.yaml index a12a0333a..88889931c 100644 --- a/package.yaml +++ b/package.yaml @@ -22,7 +22,7 @@ extra-source-files: - CHANGELOG.md dependencies: - - ansi-terminal == 0.10.* + - ansi-terminal >= 0.10 && < 0.12 - asn1-encoding == 0.9.* - asn1-types == 0.3.* - async == 2.2.* @@ -30,21 +30,22 @@ dependencies: - base >= 4.7 && < 5 - base64-bytestring >= 1.0 && < 1.3 - bytestring == 0.10.* - - constraints == 0.12.* + - composition == 1.0.* + - constraints >= 0.12 && < 0.14 - containers == 0.6.* - - cryptonite == 0.27.* + - cryptonite >= 0.27 && < 0.30 - direct-sqlite == 2.3.* - directory == 1.3.* - file-embed == 0.0.14.* - filepath == 1.4.* - - generic-random == 1.3.* + - generic-random >= 1.3 && < 1.5 - iso8601-time == 0.1.* - memory == 0.15.* - mtl == 2.2.* - network == 3.1.* - network-transport == 0.5.* - QuickCheck == 2.14.* - - random == 1.1.* + - random >= 1.1 && < 1.3 - simple-logger == 0.1.* - sqlite-simple == 0.4.* - stm == 2.5.* @@ -67,7 +68,7 @@ executables: dependencies: - cryptostore == 0.2.* - ini == 0.4.* - - optparse-applicative == 0.15.* + - optparse-applicative >= 0.15 && < 0.17 - simplexmq ghc-options: - -threaded @@ -89,7 +90,6 @@ tests: - hspec == 2.7.* - hspec-core == 2.7.* - HUnit == 1.6.* - - random == 1.1.* - QuickCheck == 2.14.* - timeit == 2.0.* diff --git a/protocol/agent-protocol.md b/protocol/agent-protocol.md index 94ee8d914..9d80beef0 100644 --- a/protocol/agent-protocol.md +++ b/protocol/agent-protocol.md @@ -16,9 +16,10 @@ - [Client commands and server responses](#client-commands-and-server-responses) - [NEW command and INV response](#new-command-and-inv-response) - [JOIN command](#join-command) - - [CON notification](#con-notification) + - [REQ notification and ACPT command](#req-notification-and-acpt-command) + - [INFO and CON notifications](#info-and-con-notifications) - [SUB command](#sub-command) - - [SEND command and SENT response](#send-command-and-sent-response) + - [SEND command and MID, SENT and MERR responses](#send-command-and-mid-sent-and-merr-responses) - [MSG notification](#msg-notification) - [END notification](#end-notification) - [OFF command](#off-command) @@ -73,18 +74,22 @@ SMP agent protocol has 3 main parts: The procedure of establishing a duplex connection is explained on the example of Alice and Bob creating a bi-directional connection comprised of two unidirectional (simplex) queues, using SMP agents (A and B) to facilitate it, and two different SMP servers (which could be the same server). It is shown on the diagram above and has these steps: 1. Alice requests the new connection from the SMP agent A using `NEW` command. -2. Agent A creates an SMP queue on the server (using [SMP protocol](./simplex-messaging.md)) and responds to Alice with the invitation that contains queue information and the encryption key Bob's agent B should use. The invitation format is described in [Connection invitation](#connection-invitation). +2. Agent A creates an SMP connection on the server (using [SMP protocol](./simplex-messaging.md)) and responds to Alice with the invitation that contains queue information and the encryption key Bob's agent B should use. The invitation format is described in [Connection invitation](#connection-invitation). 3. Alice sends the invitation to Bob via any secure channel they have (out-of-band message). 4. Bob sends `JOIN` command with the invitation as a parameter to agent B to accept the connection. -5. Establishing Alice's SMP queue (with SMP protocol commands): - - Agent B sends unauthenticated message to SMP queue with ephemeral key that will be used to authenticate commands to the queue, as described in SMP protocol. - - Agent A receives the KEY and secures the queue. +5. Establishing Alice's SMP connection (with SMP protocol commands): + - Agent B sends an "SMP confirmation" to the SMP queue specified in the invitation - SMP confirmation is an unauthenticated message with an ephemeral key that will be used to authenticate Bob's commands to the queue, as described in SMP protocol, and Bob's info. + - Agent A receives the SMP confirmation containing Bob's key and info. + - Agent A notifies Alice sending REQ notification with Bob's info. + - Alice accepts connection request with ACPT command. + - Agent A secures the queue. - Agent B tries sending authenticated SMP SEND command with agent `HELLO` message until it succeeds. Once it succeeds, Bob's agent "knows" the queue is secured. 6. Agent B creates a new SMP queue on the server. 7. Establish Bob's SMP queue: - Agent B sends `REPLY` message with the invitation to this 2nd queue to Alice's agent (via the 1st queue). - - Agent A having received this `REPLY` message sends unauthenticated message to SMP queue with Alice agent's ephemeral key that will be used to authenticate commands to the queue, as described in SMP protocol. - - Bob's agent receives the key and secures the queue. + - Agent A, having received this `REPLY` message, sends unauthenticated message to SMP queue with Alice agent's ephemeral key that will be used to authenticate Alice's commands to the queue, as described in SMP protocol, and Alice's info. + - Bob's agent receives the key and Alice's information and secures the queue. + - Bob's agent sends the notification `INFO` with Alice's information to Bob. - Alice's agent keeps sending `HELLO` message until it succeeds. 8. Agents A and B notify Alice and Bob that connection is established. - Once sending `HELLO` succeeds, Alice's agent sends to Alice `CON` notification that confirms that now both parties can communicate. @@ -193,13 +198,25 @@ cId = encoded cName = 1*(ALPHA / DIGIT / "_" / "-") agentCommand = (userCmd / agentMsg) CRLF -userCmd = newCmd / joinCmd / subscribeCmd / sendCmd / acknowledgeCmd / suspendCmd / deleteCmd -agentMsg = invitation / connected / unsubscribed / message / sent / received / ok / error +userCmd = newCmd / joinCmd / acceptCmd / subscribeCmd / sendCmd / acknowledgeCmd / suspendCmd / deleteCmd +agentMsg = invitation / connRequest / connInfo / connected / unsubscribed / connDown / connUp / messageId / sent / messageError / message / received / ok / error newCmd = %s"NEW" [SP %s"NO_ACK"] ; response is `invitation` or `error` +; NO_ACK parameter currently not supported invitation = %s"INV" SP ; `queueInfo` is the same as in out-of-band message, see SMP protocol +connRequest = %s"REQ" SP confirmationId SP msgBody +; msgBody here is any binary information identifying connection request + +confirmationId = 1*DIGIT + +acceptCmd = %s"ACPT" SP confirmationId SP msgBody +; msgBody here is any binary information identifying connecting party + +connInfo = %s"INFO" SP msgBody +; msgBody here is any binary information identifying connecting party + connected = %s"CON" subscribeCmd = %s"SUB" ; response is `ok` or `error` @@ -208,6 +225,12 @@ unsubscribed = %s"END" ; when another agent (or another client of the same agent) ; subscribes to the same SMP queue on the server +connDown = %s"DOWN" +; lost connection (e.g. because of Internet connectivity or server is down) + +connUp = %s"UP" +; restored connection + joinCmd = %s"JOIN" SP [SP %s"NO_REPLY"] [SP %s"NO_ACK"] ; `queueInfo` is the same as in out-of-band message, see SMP protocol ; response is `connected` or `error` @@ -225,18 +248,22 @@ binaryMsg = size CRLF msgBody CRLF ; the last CRLF is in addition to CRLF in the size = 1*DIGIT ; size in bytes msgBody = *OCTET ; any content of specified size - safe for binary +messageId = %s"MID" SP agentMsgId + sent = %s"SENT" SP agentMsgId +messageError = %s"MERR" SP agentMsgId SP + message = %s"MSG" SP msgIntegrity SP recipientMeta SP brokerMeta SP senderMeta SP binaryMsg recipientMeta = %s"R=" agentMsgId "," agentTimestamp ; receiving agent message metadata brokerMeta = %s"B=" brokerMsgId "," brokerTimestamp ; broker (server) message metadata senderMeta = %s"S=" agentMsgId "," agentTimestamp ; sending agent message metadata brokerMsgId = encoded brokerTimestamp = -msgIntegrity = ok / messageError +msgIntegrity = ok / msgIntegrityError -messageError = %s"ERR" SP messageErrorType -messageErrorType = skippedMsgErr / badMsgIdErr / badHashErr +msgIntegrityError = %s"ERR" SP msgIntegrityErrorType +msgIntegrityErrorType = skippedMsgErr / badMsgIdErr / badHashErr skippedMsgErr = %s"NO_ID" SP missingFromMsgId SP missingToMsgId badMsgIdErr = %s"ID" SP previousMsgId ; ID is lower than the previous @@ -247,7 +274,6 @@ missingToMsgId = agentMsgId previousMsgId = agentMsgId acknowledgeCmd = %s"ACK" SP agentMsgId ; ID assigned by receiving agent (in MSG "R") -; currently not implemented received = %s"RCVD" SP agentMsgId ; ID assigned by sending agent (in SENT response) ; currently not implemented @@ -261,27 +287,41 @@ error = %s"ERR" SP #### NEW command and INV response -`NEW` command is used to create a connection and an invitation to be sent out-of-band to another protocol user. It should be used by the client of the agent that initiates creating a duplex connection. +`NEW` command is used to create a connection and an invitation to be sent out-of-band to another protocol user (the joining party). It should be used by the client of the agent that initiates creating a duplex connection (the initiating party). -`INV` response is sent by the agent to the client. +`INV` response is sent by the agent to the client of the initiating party. #### JOIN command -It is used to create a connection and accept the invitation received out-of-band. It should be used by the client of the agent that accepts the connection. +It is used to create a connection and accept the invitation received out-of-band. It should be used by the client of the agent that accepts the connection (the joining party). -#### CON notification +#### REQ notification and ACPT command -It is sent by both agents managing duplex connection to their clients once the connection is established and ready to accept client messages. +When the joining party uses `JOIN` command, the initiating party will receive `REQ` notification with some numeric identifier and an additional binary information, that can be used to identify the joining party or for any other purpose. + +To continue with the connection the initiating party should use `ACPT` command. + +#### INFO and CON notifications + +After the initiating party proceeds with the connection using `ACPT` command, the joining party will receive `INFO` notification that can be used to identify the initiating party or for any other purpose. + +Once the connection is established and ready to accept client messages, both agents will send `CON` notification to their clients. #### SUB command This command can be used by the client to resume receiving messages from the connection that was created in another TCP/client session. Agent response to this command can be `OK` or `ERR` in case connection does not exist (or can only be used to send connections - e.g. when the reply queue was not created). -#### SEND command and SENT response +#### SEND command and MID, SENT and MERR responses -`SEND` command is used to the client to send messages +`SEND` command is used by the client to send messages. -`SENT` response is sent by the agent to confirm that the message was delivered to the SMP server. Message ID in this response is the sequential message number that includes both sent and received messages in the connection. +`MID` notification with the message ID (the sequential message number that includes both sent and received messages in the connection) is sent to the client to confirm that the message is accepted by the agent, before it is sent to the SMP server. + +`SENT` response is sent by the agent to confirm that the message was delivered to the SMP server. This notification contains the same message ID as `MID` notification. `SENT` notification, depending on network availability, can be sent at any time later, potentially in the next client session. + +In case of the failure to send the message for any other reason than network connection or message queue quota - e.g. authentication error (`ERR AUTH`) or syntax error (`ERR CMD error`), the agent will send to the client `MERR` notification with the message ID, and this message delivery will no longer be attempted. + +In case of client disconnecting from the agent, the pending messages will not be sent until the client re-connects to the agent and subscribes to the connection that has pending messages. #### MSG notification @@ -294,6 +334,12 @@ It is sent by the agent to the client when agent receives the message from the S It is sent by the agent to the client when agent receives SMP protocol `END` notification from SMP server. It indicates that another agent has subscribed to the same SMP queue on the server and the server terminated the subscription of the current agent. +#### DOWN and UP notifications + +These notifications are sent when server or network connection is, respectively, `DOWN` or back `UP`. + +All the subscriptions made in the current client session will be automatically resumed when `UP` notification is received. + #### OFF command It is used to suspend the receiving SMP queue - sender will no longer be able to send the messages to the connection, but the recipient can retrieve the remaining messages. Agent response to this command can be `OK` or `ERR`. This command is irreversible. diff --git a/protocol/diagrams/duplex-messaging/duplex-creating.mmd b/protocol/diagrams/duplex-messaging/duplex-creating.mmd index d31a639e0..d8df2d199 100644 --- a/protocol/diagrams/duplex-messaging/duplex-creating.mmd +++ b/protocol/diagrams/duplex-messaging/duplex-creating.mmd @@ -27,11 +27,14 @@ sequenceDiagram note over BA: status: NONE/NEW note over BA, AA: 5. establish Alice's SMP queue - BA ->> AS: SEND: KEY: sender's server key + BA ->> AS: SEND: Bob's info and sender server key (SMP confirmation) note over BA: status: NONE/CONFIRMED activate BA - AS ->> AA: MSG: KEY: sender's server key + AS ->> AA: MSG: Bob's info and
sender server key note over AA: status: CONFIRMED/NONE + AA ->> AS: ACK: confirm message + AA ->> A: REQ: connection request ID
and Bob's info + A ->> AA: ACPT: accept connection request,
send Alice's info AA ->> AS: KEY: secure queue note over AA: status: SECURED/NONE @@ -40,6 +43,7 @@ sequenceDiagram note over BA: status: NONE/ACTIVE AS ->> AA: MSG: HELLO: Alice's agent
knows Bob can send note over AA: status: ACTIVE/NONE + AA ->> AS: ACK: confirm message note over BA, BS: 6. create Bob's SMP queue BA ->> BS: NEW: create SMP queue @@ -51,12 +55,15 @@ sequenceDiagram note over BA: status: PENDING/ACTIVE AS ->> AA: MSG: REPLY: invitation
to connect note over AA: status: ACTIVE/NEW + AA ->> AS: ACK: confirm message - AA ->> BS: SEND: KEY: sender's server key + AA ->> BS: SEND: Alice's info and sender's server key note over AA: status: ACTIVE/CONFIRMED activate AA - BS ->> BA: MSG: KEY: sender's server key + BS ->> BA: MSG: Alice's info and
sender's server key note over BA: status: CONFIRMED/ACTIVE + BA ->> B: INFO: Alice's info + BA ->> BS: ACK: confirm message BA ->> BS: KEY: secure queue note over BA: status: SECURED/ACTIVE @@ -65,6 +72,7 @@ sequenceDiagram note over AA: status: ACTIVE/ACTIVE BS ->> BA: MSG: HELLO: Bob's agent
knows Alice can send note over BA: status: ACTIVE/ACTIVE + BA ->> BS: ACK: confirm message note over A, B: 8. notify users about connection success AA ->> A: CON: connected diff --git a/protocol/diagrams/duplex-messaging/duplex-creating.svg b/protocol/diagrams/duplex-messaging/duplex-creating.svg index c92600134..138935d3b 100644 --- a/protocol/diagrams/duplex-messaging/duplex-creating.svg +++ b/protocol/diagrams/duplex-messaging/duplex-creating.svg @@ -1 +1 @@ -AliceAlice'sagentAlice'sserverBob'sserverBob'sagentBobstatus (receive/send): NONE/NONE1. request connection from agentNEW: createduplex connection2. create Alice's SMP queueNEW: create SMP queueIDS: SMP queue IDsstatus: NEW/NONEINV: invitationto connectstatus: PENDING/NONE3. out-of-band invitationOOB: invitation to connect4. accept connectionJOIN:via invitation infostatus: NONE/NEW5. establish Alice's SMP queueSEND: KEY: sender's server keystatus: NONE/CONFIRMEDMSG: KEY: sender's server keystatus: CONFIRMED/NONEKEY: secure queuestatus: SECURED/NONESEND: HELLO: try sending until successfulstatus: NONE/ACTIVEMSG: HELLO: Alice's agentknows Bob can sendstatus: ACTIVE/NONE6. create Bob's SMP queueNEW: create SMP queueIDS: SMP queue IDsstatus: NEW/ACTIVE7. establish Bob's SMP queueSEND: REPLY: invitation to the connectstatus: PENDING/ACTIVEMSG: REPLY: invitationto connectstatus: ACTIVE/NEWSEND: KEY: sender's server keystatus: ACTIVE/CONFIRMEDMSG: KEY: sender's server keystatus: CONFIRMED/ACTIVEKEY: secure queuestatus: SECURED/ACTIVESEND: HELLO: try sending until successfulstatus: ACTIVE/ACTIVEMSG: HELLO: Bob's agentknows Alice can sendstatus: ACTIVE/ACTIVE8. notify users about connection successCON: connectedCON: connectedAliceAlice'sagentAlice'sserverBob'sserverBob'sagentBob \ No newline at end of file +AliceAlice'sagentAlice'sserverBob'sserverBob'sagentBobstatus (receive/send): NONE/NONE1. request connection from agentNEW: createduplex connection2. create Alice's SMP queueNEW: create SMP queueIDS: SMP queue IDsstatus: NEW/NONEINV: invitationto connectstatus: PENDING/NONE3. out-of-band invitationOOB: invitation to connect4. accept connectionJOIN:via invitation infostatus: NONE/NEW5. establish Alice's SMP queueSEND: Bob's info and sender server key (SMP confirmation)status: NONE/CONFIRMEDMSG: Bob's info andsender server keystatus: CONFIRMED/NONEACK: confirm messageREQ: connection request IDand Bob's infoACPT: accept connection request,send Alice's infoKEY: secure queuestatus: SECURED/NONESEND: HELLO: try sending until successfulstatus: NONE/ACTIVEMSG: HELLO: Alice's agentknows Bob can sendstatus: ACTIVE/NONEACK: confirm message6. create Bob's SMP queueNEW: create SMP queueIDS: SMP queue IDsstatus: NEW/ACTIVE7. establish Bob's SMP queueSEND: REPLY: invitation to the connectstatus: PENDING/ACTIVEMSG: REPLY: invitationto connectstatus: ACTIVE/NEWACK: confirm messageSEND: Alice's info and sender's server keystatus: ACTIVE/CONFIRMEDMSG: Alice's info andsender's server keystatus: CONFIRMED/ACTIVEINFO: Alice's infoACK: confirm messageKEY: secure queuestatus: SECURED/ACTIVESEND: HELLO: try sending until successfulstatus: ACTIVE/ACTIVEMSG: HELLO: Bob's agentknows Alice can sendstatus: ACTIVE/ACTIVEACK: confirm message8. notify users about connection successCON: connectedCON: connectedAliceAlice'sagentAlice'sserverBob'sserverBob'sagentBob \ No newline at end of file diff --git a/protocol/simplex-messaging.md b/protocol/simplex-messaging.md index a2eadd488..884ae98ec 100644 --- a/protocol/simplex-messaging.md +++ b/protocol/simplex-messaging.md @@ -410,7 +410,7 @@ secure = %s"KEY" SP senderKey senderKey = %s"rsa:" x509encoded ; the sender's RSA public key for this queue ``` -`senderKey` is received from the sender as part of the first message - see [Send Message Command](#send-message-command). +`senderKey` is received from the sender as part of the first message - see [Send Message](#send-message) command. Once the queue is secured only signed messages can be sent to it. @@ -535,7 +535,8 @@ No further messages should be delivered to unsubscribed transport connection. - transmission has no required signature or queue ID (`NO_AUTH`) - transmission has unexpected credentials (`HAS_AUTH`) - transmission has no required queue ID (`NO_QUEUE`) -- authentication error (`AUTH`) - incorrect signature, unknown (or suspended) queue, sender's ID is used in place of recipient's and vice versa, and some other cases (see [Send message command](#send-message-command)). +- authentication error (`AUTH`) - incorrect signature, unknown (or suspended) queue, sender's ID is used in place of recipient's and vice versa, and some other cases (see [Send message](#send-message) command). +- message queue quota exceeded error (`QUOTA`) - too many messages were sent to the message queue. Further messages can only be sent after the recipient retrieves the messages. - incorrect message body size (`SIZE`). - internal server error (`INTERNAL`). diff --git a/simplexmq.cabal b/simplexmq.cabal new file mode 100644 index 000000000..c34c18ef6 --- /dev/null +++ b/simplexmq.cabal @@ -0,0 +1,250 @@ +cabal-version: 1.12 + +-- This file has been generated from package.yaml by hpack version 0.34.4. +-- +-- see: https://github.com/sol/hpack +-- +-- hash: 5169db4a4922766c79f08cbdb91d4c765520372273ab432569eba25a253a8dbb + +name: simplexmq +version: 0.3.2 +synopsis: SimpleXMQ message broker +description: This package includes <./docs/Simplex-Messaging-Server.html server>, + <./docs/Simplex-Messaging-Client.html client> and + <./docs/Simplex-Messaging-Agent.html agent> for SMP protocols: + . + * + * + . + See built with SimpleXMQ broker. +category: Chat, Network, Web, System, Cryptography +homepage: https://github.com/simplex-chat/simplexmq#readme +author: simplex.chat +maintainer: chat@simplex.chat +copyright: 2020 simplex.chat +license: AGPL-3 +license-file: LICENSE +build-type: Simple +extra-source-files: + README.md + CHANGELOG.md + +library + exposed-modules: + Simplex.Messaging.Agent + Simplex.Messaging.Agent.Client + Simplex.Messaging.Agent.Env.SQLite + Simplex.Messaging.Agent.Protocol + Simplex.Messaging.Agent.RetryInterval + Simplex.Messaging.Agent.Store + Simplex.Messaging.Agent.Store.SQLite + Simplex.Messaging.Agent.Store.SQLite.Migrations + Simplex.Messaging.Client + Simplex.Messaging.Crypto + Simplex.Messaging.Parsers + Simplex.Messaging.Protocol + Simplex.Messaging.Server + Simplex.Messaging.Server.Env.STM + Simplex.Messaging.Server.MsgStore + Simplex.Messaging.Server.MsgStore.STM + Simplex.Messaging.Server.QueueStore + Simplex.Messaging.Server.QueueStore.STM + Simplex.Messaging.Server.StoreLog + Simplex.Messaging.Transport + Simplex.Messaging.Transport.WebSockets + Simplex.Messaging.Util + other-modules: + Paths_simplexmq + hs-source-dirs: + src + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns + build-depends: + QuickCheck ==2.14.* + , ansi-terminal >=0.10 && <0.12 + , asn1-encoding ==0.9.* + , asn1-types ==0.3.* + , async ==2.2.* + , attoparsec ==0.13.* + , base >=4.7 && <5 + , base64-bytestring >=1.0 && <1.3 + , bytestring ==0.10.* + , composition ==1.0.* + , constraints >=0.12 && <0.14 + , containers ==0.6.* + , cryptonite >=0.27 && <0.30 + , direct-sqlite ==2.3.* + , directory ==1.3.* + , file-embed ==0.0.14.* + , filepath ==1.4.* + , generic-random >=1.3 && <1.5 + , iso8601-time ==0.1.* + , memory ==0.15.* + , mtl ==2.2.* + , network ==3.1.* + , network-transport ==0.5.* + , random >=1.1 && <1.3 + , simple-logger ==0.1.* + , sqlite-simple ==0.4.* + , stm ==2.5.* + , template-haskell ==2.16.* + , text ==1.2.* + , time ==1.9.* + , transformers ==0.5.* + , unliftio ==0.2.* + , unliftio-core ==0.2.* + , websockets ==0.12.* + , x509 ==1.7.* + default-language: Haskell2010 + +executable smp-agent + main-is: Main.hs + other-modules: + Paths_simplexmq + hs-source-dirs: + apps/smp-agent + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded + build-depends: + QuickCheck ==2.14.* + , ansi-terminal >=0.10 && <0.12 + , asn1-encoding ==0.9.* + , asn1-types ==0.3.* + , async ==2.2.* + , attoparsec ==0.13.* + , base >=4.7 && <5 + , base64-bytestring >=1.0 && <1.3 + , bytestring ==0.10.* + , composition ==1.0.* + , constraints >=0.12 && <0.14 + , containers ==0.6.* + , cryptonite >=0.27 && <0.30 + , direct-sqlite ==2.3.* + , directory ==1.3.* + , file-embed ==0.0.14.* + , filepath ==1.4.* + , generic-random >=1.3 && <1.5 + , iso8601-time ==0.1.* + , memory ==0.15.* + , mtl ==2.2.* + , network ==3.1.* + , network-transport ==0.5.* + , random >=1.1 && <1.3 + , simple-logger ==0.1.* + , simplexmq + , sqlite-simple ==0.4.* + , stm ==2.5.* + , template-haskell ==2.16.* + , text ==1.2.* + , time ==1.9.* + , transformers ==0.5.* + , unliftio ==0.2.* + , unliftio-core ==0.2.* + , websockets ==0.12.* + , x509 ==1.7.* + default-language: Haskell2010 + +executable smp-server + main-is: Main.hs + other-modules: + Paths_simplexmq + hs-source-dirs: + apps/smp-server + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded + build-depends: + QuickCheck ==2.14.* + , ansi-terminal >=0.10 && <0.12 + , asn1-encoding ==0.9.* + , asn1-types ==0.3.* + , async ==2.2.* + , attoparsec ==0.13.* + , base >=4.7 && <5 + , base64-bytestring >=1.0 && <1.3 + , bytestring ==0.10.* + , composition ==1.0.* + , constraints >=0.12 && <0.14 + , containers ==0.6.* + , cryptonite >=0.27 && <0.30 + , cryptostore ==0.2.* + , direct-sqlite ==2.3.* + , directory ==1.3.* + , file-embed ==0.0.14.* + , filepath ==1.4.* + , generic-random >=1.3 && <1.5 + , ini ==0.4.* + , iso8601-time ==0.1.* + , memory ==0.15.* + , mtl ==2.2.* + , network ==3.1.* + , network-transport ==0.5.* + , optparse-applicative >=0.15 && <0.17 + , random >=1.1 && <1.3 + , simple-logger ==0.1.* + , simplexmq + , sqlite-simple ==0.4.* + , stm ==2.5.* + , template-haskell ==2.16.* + , text ==1.2.* + , time ==1.9.* + , transformers ==0.5.* + , unliftio ==0.2.* + , unliftio-core ==0.2.* + , websockets ==0.12.* + , x509 ==1.7.* + default-language: Haskell2010 + +test-suite smp-server-test + type: exitcode-stdio-1.0 + main-is: Test.hs + other-modules: + AgentTests + AgentTests.FunctionalAPITests + AgentTests.SQLiteTests + ProtocolErrorTests + ServerTests + SMPAgentClient + SMPClient + Paths_simplexmq + hs-source-dirs: + tests + ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns + build-depends: + HUnit ==1.6.* + , QuickCheck ==2.14.* + , ansi-terminal >=0.10 && <0.12 + , asn1-encoding ==0.9.* + , asn1-types ==0.3.* + , async ==2.2.* + , attoparsec ==0.13.* + , base >=4.7 && <5 + , base64-bytestring >=1.0 && <1.3 + , bytestring ==0.10.* + , composition ==1.0.* + , constraints >=0.12 && <0.14 + , containers ==0.6.* + , cryptonite >=0.27 && <0.30 + , direct-sqlite ==2.3.* + , directory ==1.3.* + , file-embed ==0.0.14.* + , filepath ==1.4.* + , generic-random >=1.3 && <1.5 + , hspec ==2.7.* + , hspec-core ==2.7.* + , iso8601-time ==0.1.* + , memory ==0.15.* + , mtl ==2.2.* + , network ==3.1.* + , network-transport ==0.5.* + , random >=1.1 && <1.3 + , simple-logger ==0.1.* + , simplexmq + , sqlite-simple ==0.4.* + , stm ==2.5.* + , template-haskell ==2.16.* + , text ==1.2.* + , time ==1.9.* + , timeit ==2.0.* + , transformers ==0.5.* + , unliftio ==0.2.* + , unliftio-core ==0.2.* + , websockets ==0.12.* + , x509 ==1.7.* + default-language: Haskell2010 diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 25bc15c91..6b003250c 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1,12 +1,17 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} -- | -- Module : Simplex.Messaging.Agent @@ -21,10 +26,29 @@ -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent - ( runSMPAgent, + ( -- * SMP agent over TCP + runSMPAgent, runSMPAgentBlocking, + + -- * queue-based SMP agent + getAgentClient, + runAgentClient, + + -- * SMP agent functional API + AgentClient (..), + AgentMonad, + AgentErrorMonad, getSMPAgentClient, - runSMPAgentClient, + disconnectAgentClient, -- used in tests + withAgentLock, + createConnection, + joinConnection, + acceptConnection, + subscribeConnection, + sendMessage, + ackMessage, + suspendConnection, + deleteConnection, ) where @@ -34,10 +58,16 @@ import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) +import Data.Bifunctor (second) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Composition ((.:), (.:.)) +import Data.Functor (($>)) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Maybe (isJust) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) import Data.Time.Clock @@ -45,16 +75,17 @@ import Database.SQLite.Simple (SQLError) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore, connectSQLiteStore) +import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) import Simplex.Messaging.Client (SMPServerTransmission) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (MsgBody, SenderPublicKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), runTransportServer) -import Simplex.Messaging.Util (bshow) +import Simplex.Messaging.Util (bshow, tryError) import System.Random (randomR) -import UnliftIO.Async (race_) +import UnliftIO.Async (Async, async, race_) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -76,17 +107,66 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReader smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' () smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do liftIO $ putLn h "Welcome to SMP v0.3.2 agent" - c <- getSMPAgentClient + c <- getAgentClient logConnection c True - race_ (connectClient h c) (runSMPAgentClient c) - `E.finally` (closeSMPServerClients c >> logConnection c False) + race_ (connectClient h c) (runAgentClient c) + `E.finally` disconnectAgentClient c --- | Creates an SMP agent instance that receives commands and sends responses via 'TBQueue's. -getSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient -getSMPAgentClient = do - n <- asks clientCounter - cfg <- asks config - atomically $ newAgentClient n cfg +-- | Creates an SMP agent client instance +getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m AgentClient +getSMPAgentClient cfg = newSMPAgentEnv cfg >>= runReaderT runAgent + where + runAgent = do + c <- getAgentClient + action <- async $ subscriber c `E.finally` disconnectAgentClient c + pure c {smpSubscriber = action} + +disconnectAgentClient :: MonadUnliftIO m => AgentClient -> m () +disconnectAgentClient c = closeAgentClient c >> logConnection c False + +-- | +type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m) + +-- | Create SMP agent connection (NEW command) +createConnection :: AgentErrorMonad m => AgentClient -> m (ConnId, SMPQueueInfo) +createConnection c = withAgentEnv c $ newConn c "" + +-- | Join SMP agent connection (JOIN command) +joinConnection :: AgentErrorMonad m => AgentClient -> SMPQueueInfo -> ConnInfo -> m ConnId +joinConnection c = withAgentEnv c .: joinConn c "" + +-- | Approve confirmation (LET command) +acceptConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () +acceptConnection c = withAgentEnv c .:. acceptConnection' c + +-- | Subscribe to receive connection messages (SUB command) +subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +subscribeConnection c = withAgentEnv c . subscribeConnection' c + +-- | Send message to the connection (SEND command) +sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgBody -> m AgentMsgId +sendMessage c = withAgentEnv c .: sendMessage' c + +ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> m () +ackMessage c = withAgentEnv c .: ackMessage' c + +-- | Suspend SMP agent connection (OFF command) +suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +suspendConnection c = withAgentEnv c . suspendConnection' c + +-- | Delete SMP agent connection (DEL command) +deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () +deleteConnection c = withAgentEnv c . deleteConnection' c + +withAgentEnv :: AgentClient -> ReaderT Env m a -> m a +withAgentEnv c = (`runReaderT` agentEnv c) + +-- withAgentClient :: AgentErrorMonad m => AgentClient -> ReaderT Env m a -> m a +-- withAgentClient c = withAgentLock c . withAgentEnv c + +-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. +getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient +getAgentClient = ask >>= atomically . newAgentClient connectClient :: Transport c => MonadUnliftIO m => c -> AgentClient -> m () connectClient h c = race_ (send h c) (receive h c) @@ -97,56 +177,51 @@ logConnection c connected = in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"] -- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's. -runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () -runSMPAgentClient c = do - db <- asks $ dbFile . config - s1 <- liftIO $ connectSQLiteStore db - s2 <- liftIO $ connectSQLiteStore db - race_ (subscriber c s1) (client c s2) +runAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () +runAgentClient c = race_ (subscriber c) (client c) receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m () -receive h c@AgentClient {rcvQ, sndQ} = forever loop +receive h c@AgentClient {rcvQ, subQ} = forever $ do + (corrId, connId, cmdOrErr) <- tGet SClient h + case cmdOrErr of + Right cmd -> write rcvQ (corrId, connId, cmd) + Left e -> write subQ (corrId, connId, ERR e) where - loop :: m () - loop = do - ATransmissionOrError corrId entity cmdOrErr <- tGet SClient h - case cmdOrErr of - Right cmd -> write rcvQ $ ATransmission corrId entity cmd - Left e -> write sndQ $ ATransmission corrId entity $ ERR e write :: TBQueue (ATransmission p) -> ATransmission p -> m () write q t = do logClient c "-->" t atomically $ writeTBQueue q t send :: (Transport c, MonadUnliftIO m) => c -> AgentClient -> m () -send h c@AgentClient {sndQ} = forever $ do - t <- atomically $ readTBQueue sndQ +send h c@AgentClient {subQ} = forever $ do + t <- atomically $ readTBQueue subQ tPut h t logClient c "<--" t logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a -> m () -logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do - logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd] +logClient AgentClient {clientId} dir (corrId, connId, cmd) = do + logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd] -client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () -client c@AgentClient {rcvQ, sndQ} st = forever loop - where - loop :: m () - loop = do - t@(ATransmission corrId entity _) <- atomically $ readTBQueue rcvQ - runExceptT (processCommand c st t) >>= \case - Left e -> atomically . writeTBQueue sndQ $ ATransmission corrId entity (ERR e) - Right _ -> pure () +client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () +client c@AgentClient {rcvQ, subQ} = forever $ do + (corrId, connId, cmd) <- atomically $ readTBQueue rcvQ + withAgentLock c (runExceptT $ processCommand c (connId, cmd)) + >>= atomically . writeTBQueue subQ . \case + Left e -> (corrId, connId, ERR e) + Right (connId', resp) -> (corrId, connId', resp) withStore :: AgentMonad m => - (forall m'. (MonadUnliftIO m', MonadError StoreError m') => m' a) -> + (forall m'. (MonadUnliftIO m', MonadError StoreError m') => SQLiteStore -> m' a) -> m a withStore action = do - runExceptT (action `E.catch` handleInternal) >>= \case + st <- asks store + runExceptT (action st `E.catch` handleInternal) >>= \case Right c -> return c Left e -> throwError $ storeError e where + -- TODO when parsing exception happens in store, the agent hangs; + -- changing SQLError to SomeException does not help handleInternal :: (MonadError StoreError m') => SQLError -> m' a handleInternal e = throwError . SEInternal $ bshow e storeError :: StoreError -> AgentErrorType @@ -155,216 +230,319 @@ withStore action = do SEConnDuplicate -> CONN DUPLICATE SEBadConnType CRcv -> CONN SIMPLEX SEBadConnType CSnd -> CONN SIMPLEX - SEBcastNotFound -> BCAST B_NOT_FOUND - SEBcastDuplicate -> BCAST B_DUPLICATE e -> INTERNAL $ show e -processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m () -processCommand c st (ATransmission corrId entity cmd) = process c st corrId entity cmd +-- | execute any SMP agent command +processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent) +processCommand c (connId, cmd) = case cmd of + NEW -> second INV <$> newConn c connId + JOIN smpQueueInfo connInfo -> (,OK) <$> joinConn c connId smpQueueInfo connInfo + ACPT confId ownConnInfo -> acceptConnection' c connId confId ownConnInfo $> (connId, OK) + SUB -> subscribeConnection' c connId $> (connId, OK) + SEND msgBody -> (connId,) . MID <$> sendMessage' c connId msgBody + ACK msgId -> ackMessage' c connId msgId $> (connId, OK) + OFF -> suspendConnection' c connId $> (connId, OK) + DEL -> deleteConnection' c connId $> (connId, OK) + +newConn :: AgentMonad m => AgentClient -> ConnId -> m (ConnId, SMPQueueInfo) +newConn c connId = do + srv <- getSMPServer + (rq, qInfo) <- newRcvQueue c srv + g <- asks idsDrg + let cData = ConnData {connId} + connId' <- withStore $ \st -> createRcvConn st g cData rq + addSubscription c rq connId' + pure (connId', qInfo) + +joinConn :: AgentMonad m => AgentClient -> ConnId -> SMPQueueInfo -> ConnInfo -> m ConnId +joinConn c connId qInfo cInfo = do + (sq, senderKey, verifyKey) <- newSndQueue qInfo + g <- asks idsDrg + cfg <- asks config + let cData = ConnData {connId} + connId' <- withStore $ \st -> createSndConn st g cData sq + confirmQueue c sq senderKey cInfo + activateQueueJoining c connId' sq verifyKey $ retryInterval cfg + pure connId' + +activateQueueJoining :: forall m. AgentMonad m => AgentClient -> ConnId -> SndQueue -> VerificationKey -> RetryInterval -> m () +activateQueueJoining c connId sq verifyKey retryInterval = + activateQueue c connId sq verifyKey retryInterval createReplyQueue where - process = case entity of - Conn _ -> processConnCommand - Broadcast _ -> processBroadcastCommand - _ -> unsupportedEntity + createReplyQueue :: m () + createReplyQueue = do + srv <- getSMPServer + (rq, qInfo') <- newRcvQueue c srv + addSubscription c rq connId + withStore $ \st -> upgradeSndConnToDuplex st connId rq + sendControlMessage c sq $ REPLY qInfo' -unsupportedEntity :: AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity t -> ACommand 'Client c -> m () -unsupportedEntity c _ corrId entity _ = - atomically . writeTBQueue (sndQ c) . ATransmission corrId entity . ERR $ CMD UNSUPPORTED +-- | Approve confirmation (LET command) in Reader monad +acceptConnection' :: AgentMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m () +acceptConnection' c connId confId ownConnInfo = + withStore (`getConn` connId) >>= \case + SomeConn SCRcv (RcvConnection _ rq) -> do + AcceptedConfirmation {senderKey} <- withStore $ \st -> acceptConfirmation st confId ownConnInfo + processConfirmation c rq senderKey + _ -> throwError $ CMD PROHIBITED -processConnCommand :: - forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m () -processConnCommand c@AgentClient {sndQ} st corrId conn = \case - NEW -> createNewConnection conn - JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode - SUB -> subscribeConnection conn - SUBALL -> subscribeAll - SEND msgBody -> sendMessage c st corrId conn msgBody - OFF -> suspendConnection conn - DEL -> deleteConnection conn +processConfirmation :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m () +processConfirmation c rq sndKey = do + withStore $ \st -> setRcvQueueStatus st rq Confirmed + secureQueue c rq sndKey + withStore $ \st -> setRcvQueueStatus st rq Secured + +-- | Subscribe to receive connection messages (SUB command) in Reader monad +subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () +subscribeConnection' c connId = + withStore (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ rq sq) -> do + resumeDelivery sq + case status (sq :: SndQueue) of + Confirmed -> withVerifyKey sq $ \verifyKey -> do + conf <- withStore (`getAcceptedConfirmation` connId) + secureQueue c rq $ senderKey (conf :: AcceptedConfirmation) + withStore $ \st -> setRcvQueueStatus st rq Secured + activateSecuredQueue rq sq verifyKey + Secured -> withVerifyKey sq $ activateSecuredQueue rq sq + Active -> subscribeQueue c rq connId + _ -> throwError $ INTERNAL "unexpected queue status" + SomeConn _ (SndConnection _ sq) -> do + resumeDelivery sq + case status (sq :: SndQueue) of + Confirmed -> withVerifyKey sq $ \verifyKey -> + activateQueueJoining c connId sq verifyKey =<< resumeInterval + Active -> throwError $ CONN SIMPLEX + _ -> throwError $ INTERNAL "unexpected queue status" + SomeConn _ (RcvConnection _ rq) -> subscribeQueue c rq connId where - createNewConnection :: Entity 'Conn_ -> m () - createNewConnection (Conn cId) = do - -- TODO create connection alias if not passed - -- make connAlias Maybe? - srv <- getSMPServer - (rq, qInfo) <- newReceiveQueue c srv cId - withStore $ createRcvConn st rq - respond conn $ INV qInfo + resumeDelivery :: SndQueue -> m () + resumeDelivery SndQueue {server} = do + wasDelivering <- resumeMsgDelivery c connId server + unless wasDelivering $ do + pending <- withStore (`getPendingMsgs` connId) + queuePendingMsgs c connId pending + withVerifyKey :: SndQueue -> (C.PublicKey -> m ()) -> m () + withVerifyKey sq action = + let err = throwError $ INTERNAL "missing signing key public counterpart" + in maybe err action . C.publicKey $ signKey sq + activateSecuredQueue :: RcvQueue -> SndQueue -> C.PublicKey -> m () + activateSecuredQueue rq sq verifyKey = do + activateQueueInitiating c connId sq verifyKey =<< resumeInterval + subscribeQueue c rq connId + resumeInterval :: m RetryInterval + resumeInterval = do + r <- asks $ retryInterval . config + pure r {initialInterval = 5_000_000} - getSMPServer :: m SMPServer - getSMPServer = - asks (smpServers . config) >>= \case - srv :| [] -> pure srv - servers -> do - gen <- asks randomServer - i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1) - pure $ servers L.!! i - - joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m () - joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do - -- TODO create connection alias if not passed - -- make connAlias Maybe? - (sq, senderKey, verifyKey) <- newSendQueue qInfo cId - withStore $ createSndConn st sq - connectToSendQueue c st sq senderKey verifyKey - when (replyMode == On) $ createReplyQueue cId sq - -- TODO this response is disabled to avoid two responses in terminal client (OK + CON), - -- respond conn OK - - subscribeConnection :: Entity 'Conn_ -> m () - subscribeConnection conn'@(Conn cId) = - withStore (getConn st cId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> subscribe rq - SomeConn _ (RcvConnection _ rq) -> subscribe rq - _ -> throwError $ CONN SIMPLEX - where - subscribe rq = subscribeQueue c rq cId >> respond conn' OK - - -- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct - subscribeAll :: m () - subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn) - - suspendConnection :: Entity 'Conn_ -> m () - suspendConnection (Conn cId) = - withStore (getConn st cId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> suspend rq - SomeConn _ (RcvConnection _ rq) -> suspend rq - _ -> throwError $ CONN SIMPLEX - where - suspend rq = suspendQueue c rq >> respond conn OK - - deleteConnection :: Entity 'Conn_ -> m () - deleteConnection (Conn cId) = - withStore (getConn st cId) >>= \case - SomeConn _ (DuplexConnection _ rq _) -> delete rq - SomeConn _ (RcvConnection _ rq) -> delete rq - _ -> delConn - where - delConn = withStore (deleteConn st cId) >> respond conn OK - delete rq = do - deleteQueue c rq - removeSubscription c cId - delConn - - createReplyQueue :: ByteString -> SndQueue -> m () - createReplyQueue cId sq = do - srv <- getSMPServer - (rq, qInfo) <- newReceiveQueue c srv cId - withStore $ upgradeSndConnToDuplex st cId rq - senderTimestamp <- liftIO getCurrentTime - sendAgentMessage c sq . serializeSMPMessage $ - SMPMessage - { senderMsgId = 0, - senderTimestamp, - previousMsgHash = "", - agentMessage = REPLY qInfo - } - - respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () - respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp - -sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m () -sendMessage c st corrId (Conn cId) msgBody = - withStore (getConn st cId) >>= \case - SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq - SomeConn _ (SndConnection _ sq) -> sendMsg sq +-- | Send message to the connection (SEND command) in Reader monad +sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgBody -> m AgentMsgId +sendMessage' c connId msg = + withStore (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ _ sq) -> enqueueMessage sq + SomeConn _ (SndConnection _ sq) -> enqueueMessage sq _ -> throwError $ CONN SIMPLEX where - sendMsg :: SndQueue -> m () - sendMsg sq = do - internalTs <- liftIO getCurrentTime - (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq - let msgStr = - serializeSMPMessage - SMPMessage - { senderMsgId = unSndId internalSndId, - senderTimestamp = internalTs, - previousMsgHash, - agentMessage = A_MSG msgBody - } - msgHash = C.sha256Hash msgStr - withStore $ - createSndMsg st sq $ - SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash} - sendAgentMessage c sq msgStr - atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId) + enqueueMessage :: SndQueue -> m AgentMsgId + enqueueMessage SndQueue {server} = do + msgId <- storeSentMsg + wasDelivering <- resumeMsgDelivery c connId server + pending <- + if wasDelivering + then pure [PendingMsg {connId, msgId}] + else withStore (`getPendingMsgs` connId) + queuePendingMsgs c connId pending + pure $ unId msgId + where + storeSentMsg :: m InternalId + storeSentMsg = do + internalTs <- liftIO getCurrentTime + withStore $ \st -> do + (internalId, internalSndId, previousMsgHash) <- updateSndIds st connId + let msgBody = + serializeSMPMessage + SMPMessage + { senderMsgId = unSndId internalSndId, + senderTimestamp = internalTs, + previousMsgHash, + agentMessage = A_MSG msg + } + internalHash = C.sha256Hash msgBody + msgData = SndMsgData {..} + createSndMsg st connId msgData + pure internalId -processBroadcastCommand :: - forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m () -processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case - NEW -> withStore (createBcast st bId) >> ok - ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok - REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok - LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn - SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0) - DEL -> withStore (deleteBcast st bId) >> ok +resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnId -> SMPServer -> m Bool +resumeMsgDelivery c connId srv = do + void $ resume srv (srvMsgDeliveries c) $ runSrvMsgDelivery c srv + resume connId (connMsgDeliveries c) $ runMsgDelivery c connId srv where - sendMsg :: MsgBody -> ConnAlias -> m () - sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody + resume :: Ord a => a -> TVar (Map a (Async ())) -> m () -> m Bool + resume key actionMap actionProcess = do + isDelivering <- isJust . M.lookup key <$> readTVarIO actionMap + unless isDelivering $ + async actionProcess + >>= atomically . modifyTVar actionMap . M.insert key + pure isDelivering - ok :: m () - ok = respond bcast OK +queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> [PendingMsg] -> m () +queuePendingMsgs c connId pending = + atomically $ getPendingMsgQ connId (connMsgQueues c) >>= forM_ pending . writeTQueue - respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m () - respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp +getPendingMsgQ :: Ord a => a -> TVar (Map a (TQueue PendingMsg)) -> STM (TQueue PendingMsg) +getPendingMsgQ key queueMap = do + maybe newMsgQueue pure . M.lookup key =<< readTVar queueMap + where + newMsgQueue :: STM (TQueue PendingMsg) + newMsgQueue = do + mq <- newTQueue + modifyTVar queueMap $ M.insert key mq + pure mq -subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m () -subscriber c@AgentClient {msgQ} st = forever $ do - -- TODO this will only process messages and notifications +runMsgDelivery :: AgentMonad m => AgentClient -> ConnId -> SMPServer -> m () +runMsgDelivery c connId srv = do + mq <- atomically . getPendingMsgQ connId $ connMsgQueues c + smq <- atomically . getPendingMsgQ srv $ srvMsgQueues c + forever . atomically $ readTQueue mq >>= writeTQueue smq + +runSrvMsgDelivery :: forall m. AgentMonad m => AgentClient -> SMPServer -> m () +runSrvMsgDelivery c@AgentClient {subQ} srv = do + mq <- atomically . getPendingMsgQ srv $ srvMsgQueues c + ri <- asks $ reconnectInterval . config + forever $ do + PendingMsg {connId, msgId} <- atomically $ readTQueue mq + let mId = unId msgId + withStore (\st -> E.try $ getPendingMsgData st connId msgId) >>= \case + Left (e :: E.SomeException) -> + notify connId $ MERR mId (INTERNAL $ show e) + Right (sq, msgBody) -> do + withRetryInterval ri $ \loop -> do + tryError (sendAgentMessage c sq msgBody) >>= \case + Left e -> case e of + SMP SMP.QUOTA -> loop + SMP {} -> notify connId $ MERR mId e + CMD {} -> notify connId $ MERR mId e + _ -> loop + Right () -> do + notify connId $ SENT mId + withStore $ \st -> updateSndMsgStatus st connId msgId SndMsgSent + where + notify :: ConnId -> ACommand 'Agent -> m () + notify connId cmd = atomically $ writeTBQueue subQ ("", connId, cmd) + +ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> m () +ackMessage' c connId msgId = do + withStore (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ rq _) -> ack rq + SomeConn _ (RcvConnection _ rq) -> ack rq + _ -> throwError $ CONN SIMPLEX + where + ack :: RcvQueue -> m () + ack rq = do + let mId = InternalId msgId + withStore $ \st -> checkRcvMsg st connId mId + sendAck c rq + withStore $ \st -> updateRcvMsgAck st connId mId + +-- | Suspend SMP agent connection (OFF command) in Reader monad +suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m () +suspendConnection' c connId = + withStore (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ rq _) -> suspendQueue c rq + SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq + _ -> throwError $ CONN SIMPLEX + +-- | Delete SMP agent connection (DEL command) in Reader monad +deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m () +deleteConnection' c connId = + withStore (`getConn` connId) >>= \case + SomeConn _ (DuplexConnection _ rq _) -> delete rq + SomeConn _ (RcvConnection _ rq) -> delete rq + _ -> withStore (`deleteConn` connId) + where + delete :: RcvQueue -> m () + delete rq = do + deleteQueue c rq + removeSubscription c connId + withStore (`deleteConn` connId) + +getSMPServer :: AgentMonad m => m SMPServer +getSMPServer = + asks (smpServers . config) >>= \case + srv :| [] -> pure srv + servers -> do + gen <- asks randomServer + i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1) + pure $ servers L.!! i + +sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m () +sendControlMessage c sq agentMessage = do + senderTimestamp <- liftIO getCurrentTime + sendAgentMessage c sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage + } + +subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () +subscriber c@AgentClient {msgQ} = forever $ do t <- atomically $ readTBQueue msgQ - runExceptT (processSMPTransmission c st t) >>= \case + withAgentLock c (runExceptT $ processSMPTransmission c t) >>= \case Left e -> liftIO $ print e Right _ -> return () -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m () -processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do - withStore (getRcvConn st srv rId) >>= \case - SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq - SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq - _ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX) +processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SMPServerTransmission -> m () +processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do + withStore (\st -> getRcvConn st srv rId) >>= \case + SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq + SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq + _ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND) where - processSMP :: SConnType c -> RcvQueue -> m () - processSMP cType rq@RcvQueue {connAlias, status} = + processSMP :: SConnType c -> ConnData -> RcvQueue -> m () + processSMP cType ConnData {connId} rq@RcvQueue {status} = case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received msg <- decryptAndVerify rq msgBody let msgHash = C.sha256Hash msg - agentMsg <- liftEither $ parseSMPMessage msg - case agentMsg of - SMPConfirmation senderKey -> smpConfirmation senderKey - SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> + case parseSMPMessage msg of + Left e -> notify $ ERR e + Right (SMPConfirmation senderKey cInfo) -> smpConfirmation senderKey cInfo >> sendAck c rq + Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> case agentMessage of - HELLO verifyKey _ -> helloMsg verifyKey msgBody - REPLY qInfo -> replyMsg qInfo + HELLO verifyKey _ -> helloMsg verifyKey msgBody >> sendAck c rq + REPLY qInfo -> replyMsg qInfo >> sendAck c rq A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash - sendAck c rq - return () SMP.END -> do - removeSubscription c connAlias + removeSubscription c connId logServer "<--" c srv rId "END" notify END _ -> do logServer "<--" c srv rId $ "unexpected: " <> bshow cmd notify . ERR $ BROKER UNEXPECTED where - notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m () - notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg + notify :: ACommand 'Agent -> m () + notify msg = atomically $ writeTBQueue subQ ("", connId, msg) prohibited :: m () prohibited = notify . ERR $ AGENT A_PROHIBITED - smpConfirmation :: SenderPublicKey -> m () - smpConfirmation senderKey = do + smpConfirmation :: SenderPublicKey -> ConnInfo -> m () + smpConfirmation senderKey cInfo = do logServer "<--" c srv rId "MSG " case status of - New -> do - -- TODO currently it automatically allows whoever sends the confirmation - -- Commands CONF and LET are not supported in v0.2 - withStore $ setRcvQueueStatus st rq Confirmed - -- TODO update sender key in the store? - secureQueue c rq senderKey - withStore $ setRcvQueueStatus st rq Secured + New -> case cType of + SCRcv -> do + g <- asks idsDrg + let newConfirmation = NewConfirmation {connId, senderKey, senderConnInfo = cInfo} + confId <- withStore $ \st -> createConfirmation st g newConfirmation + notify $ REQ confId cInfo + SCDuplex -> do + notify $ INFO cInfo + processConfirmation c rq senderKey + _ -> prohibited _ -> prohibited helloMsg :: SenderPublicKey -> ByteString -> m () @@ -374,9 +552,9 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do Active -> prohibited _ -> do void $ verifyMessage (Just verifyKey) msgBody - withStore $ setRcvQueueActive st rq verifyKey + withStore $ \st -> setRcvQueueActive st rq verifyKey case cType of - SCDuplex -> notify CON + SCDuplex -> notifyConnected c connId _ -> pure () replyMsg :: SMPQueueInfo -> m () @@ -384,42 +562,26 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do logServer "<--" c srv rId "MSG " case cType of SCRcv -> do - (sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias - withStore $ upgradeRcvConnToDuplex st connAlias sq - connectToSendQueue c st sq senderKey verifyKey - notify CON + AcceptedConfirmation {ownConnInfo} <- withStore (`getAcceptedConfirmation` connId) + (sq, senderKey, verifyKey) <- newSndQueue qInfo + withStore $ \st -> upgradeRcvConnToDuplex st connId sq + confirmQueue c sq senderKey ownConnInfo + withStore (`removeConfirmations` connId) + cfg <- asks config + activateQueueInitiating c connId sq verifyKey $ retryInterval cfg _ -> prohibited agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () - agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do + agentClientMsg externalPrevSndHash sender broker msgBody internalHash = do logServer "<--" c srv rId "MSG " - case status of - Active -> do - internalTs <- liftIO getCurrentTime - (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq - let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash - withStore $ - createRcvMsg st rq $ - RcvMsgData - { internalId, - internalRcvId, - internalTs, - senderMeta, - brokerMeta, - msgBody, - internalHash = msgHash, - externalPrevSndHash = receivedPrevMsgHash, - msgIntegrity - } - notify - MSG - { recipientMeta = (unId internalId, internalTs), - senderMeta, - brokerMeta, - msgBody, - msgIntegrity - } - _ -> prohibited + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore (`updateRcvIds` connId) + let integrity = checkMsgIntegrity prevExtSndId (fst sender) prevRcvMsgHash externalPrevSndHash + recipient = (unId internalId, internalTs) + msgMeta = MsgMeta {integrity, recipient, sender, broker} + rcvMsg = RcvMsgData {..} + withStore $ \st -> createRcvMsg st connId rcvMsg + notify $ MSG msgMeta msgBody checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash @@ -430,16 +592,39 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash | otherwise = MsgError MsgDuplicate -- this case is not possible -connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m () -connectToSendQueue c st sq senderKey verifyKey = do - sendConfirmation c sq senderKey - withStore $ setSndQueueStatus st sq Confirmed - sendHello c sq verifyKey - withStore $ setSndQueueStatus st sq Active +confirmQueue :: AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> ConnInfo -> m () +confirmQueue c sq senderKey cInfo = do + sendConfirmation c sq senderKey cInfo + withStore $ \st -> setSndQueueStatus st sq Confirmed -newSendQueue :: - (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey) -newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do +activateQueueInitiating :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> VerificationKey -> RetryInterval -> m () +activateQueueInitiating c connId sq verifyKey retryInterval = + activateQueue c connId sq verifyKey retryInterval $ notifyConnected c connId + +activateQueue :: forall m. AgentMonad m => AgentClient -> ConnId -> SndQueue -> VerificationKey -> RetryInterval -> m () -> m () +activateQueue c connId sq verifyKey retryInterval afterActivation = + getActivation c connId >>= \case + Nothing -> async runActivation >>= addActivation c connId + Just _ -> pure () + where + runActivation :: m () + runActivation = do + sendHello c sq verifyKey retryInterval + withStore $ \st -> setSndQueueStatus st sq Active + removeActivation c connId + removeVerificationKey + afterActivation + removeVerificationKey :: m () + removeVerificationKey = + let safeSignKey = C.removePublicKey $ signKey sq + in withStore $ \st -> updateSignKey st sq safeSignKey + +notifyConnected :: AgentMonad m => AgentClient -> ConnId -> m () +notifyConnected c connId = atomically $ writeTBQueue (subQ c) ("", connId, CON) + +newSndQueue :: + (MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey) +newSndQueue (SMPQueueInfo smpServer senderId encryptKey) = do size <- asks $ rsaKeySize . config (senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size (verifyKey, signKey) <- liftIO $ C.generateKeyPair size @@ -447,7 +632,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do SndQueue { server = smpServer, sndId = senderId, - connAlias, sndPrivateKey, encryptKey, signKey, diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 2c6bea6f1..cc9cfb340 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -11,11 +11,13 @@ module Simplex.Messaging.Agent.Client ( AgentClient (..), newAgentClient, AgentMonad, - getSMPServerClient, - closeSMPServerClients, - newReceiveQueue, + withAgentLock, + closeAgentClient, + newRcvQueue, subscribeQueue, + addSubscription, sendConfirmation, + RetryInterval (..), sendHello, secureQueue, sendAgentMessage, @@ -27,9 +29,14 @@ module Simplex.Messaging.Agent.Client logServer, removeSubscription, cryptoError, + addActivation, + getActivation, + removeActivation, ) where +import Control.Concurrent.Async (Async, async, uninterruptibleCancel) +import Control.Concurrent.STM (stateTVar) import Control.Logger.Simple import Control.Monad.Except import Control.Monad.IO.Unlift @@ -40,44 +47,62 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Maybe (isNothing) import Data.Set (Set) import qualified Data.Set as S import Data.Text.Encoding import Data.Time.Clock import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey) import Simplex.Messaging.Util (bshow, liftEitherError, liftError) -import UnliftIO.Concurrent import UnliftIO.Exception (IOException) import qualified UnliftIO.Exception as E import UnliftIO.STM data AgentClient = AgentClient { rcvQ :: TBQueue (ATransmission 'Client), - sndQ :: TBQueue (ATransmission 'Agent), + subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue SMPServerTransmission, smpClients :: TVar (Map SMPServer SMPClient), - subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)), - subscrConns :: TVar (Map ConnAlias SMPServer), - clientId :: Int + subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)), + subscrConns :: TVar (Map ConnId SMPServer), + activations :: TVar (Map ConnId (Async ())), -- activations of send queues in progress + connMsgQueues :: TVar (Map ConnId (TQueue PendingMsg)), + connMsgDeliveries :: TVar (Map ConnId (Async ())), + srvMsgQueues :: TVar (Map SMPServer (TQueue PendingMsg)), + srvMsgDeliveries :: TVar (Map SMPServer (Async ())), + reconnections :: TVar [Async ()], + clientId :: Int, + agentEnv :: Env, + smpSubscriber :: Async (), + lock :: TMVar () } -newAgentClient :: TVar Int -> AgentConfig -> STM AgentClient -newAgentClient cc AgentConfig {tbqSize} = do - rcvQ <- newTBQueue tbqSize - sndQ <- newTBQueue tbqSize - msgQ <- newTBQueue tbqSize +newAgentClient :: Env -> STM AgentClient +newAgentClient agentEnv = do + let qSize = tbqSize $ config agentEnv + rcvQ <- newTBQueue qSize + subQ <- newTBQueue qSize + msgQ <- newTBQueue qSize smpClients <- newTVar M.empty subscrSrvrs <- newTVar M.empty subscrConns <- newTVar M.empty - clientId <- (+ 1) <$> readTVar cc - writeTVar cc clientId - return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId} + activations <- newTVar M.empty + connMsgQueues <- newTVar M.empty + connMsgDeliveries <- newTVar M.empty + srvMsgQueues <- newTVar M.empty + srvMsgDeliveries <- newTVar M.empty + reconnections <- newTVar [] + clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) + lock <- newTMVar () + return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, activations, connMsgQueues, connMsgDeliveries, srvMsgQueues, srvMsgDeliveries, reconnections, clientId, agentEnv, smpSubscriber = undefined, lock} +-- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient @@ -95,33 +120,76 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config - liftEitherError smpClientError (getSMPClient srv cfg msgQ clientDisconnected) + u <- askUnliftIO + liftEitherError smpClientError (getSMPClient srv cfg msgQ $ clientDisconnected u) `E.catch` internalError where internalError :: IOException -> m SMPClient internalError = throwError . INTERNAL . show - clientDisconnected :: IO () - clientDisconnected = do - removeSubs >>= mapM_ (mapM_ notifySub) + clientDisconnected :: UnliftIO m -> IO () + clientDisconnected u = do + removeClientSubs >>= (`forM_` serverDown u) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - removeSubs :: IO (Maybe (Set ConnAlias)) - removeSubs = atomically $ do + removeClientSubs :: IO (Maybe (Map ConnId RcvQueue)) + removeClientSubs = atomically $ do modifyTVar smpClients $ M.delete srv cs <- M.lookup srv <$> readTVar (subscrSrvrs c) modifyTVar (subscrSrvrs c) $ M.delete srv - modifyTVar (subscrConns c) $ maybe id deleteKeys cs + modifyTVar (subscrConns c) $ maybe id (deleteKeys . M.keysSet) cs return cs where deleteKeys :: Ord k => Set k -> Map k a -> Map k a deleteKeys ks m = S.foldr' M.delete m ks - notifySub :: ConnAlias -> IO () - notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END + serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () + serverDown u cs = unless (M.null cs) $ do + mapM_ (notifySub DOWN) $ M.keysSet cs + a <- async . unliftIO u $ tryReconnectClient cs + atomically $ modifyTVar (reconnections c) (a :) -closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m () -closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient + tryReconnectClient :: Map ConnId RcvQueue -> m () + tryReconnectClient cs = do + ri <- asks $ reconnectInterval . config + withRetryInterval ri $ \loop -> + reconnectClient cs `catchError` const loop + + reconnectClient :: Map ConnId RcvQueue -> m () + reconnectClient cs = do + withAgentLock c . withSMP c srv $ \smp -> do + subs <- readTVarIO $ subscrConns c + forM_ (M.toList cs) $ \(connId, rq@RcvQueue {rcvPrivateKey, rcvId}) -> + when (isNothing $ M.lookup connId subs) $ do + subscribeSMPQueue smp rcvPrivateKey rcvId + `catchError` \case + SMPServerError e -> liftIO $ notifySub (ERR $ SMP e) connId + e -> throwError e + addSubscription c rq connId + liftIO $ notifySub UP connId + + notifySub :: ACommand 'Agent -> ConnId -> IO () + notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) + +closeAgentClient :: MonadUnliftIO m => AgentClient -> m () +closeAgentClient c = liftIO $ do + closeSMPServerClients c + cancelActions $ activations c + cancelActions $ reconnections c + cancelActions $ connMsgDeliveries c + cancelActions $ srvMsgDeliveries c + +closeSMPServerClients :: AgentClient -> IO () +closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ closeSMPClient + +cancelActions :: Foldable f => TVar (f (Async ())) -> IO () +cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel + +withAgentLock :: MonadUnliftIO m => AgentClient -> m a -> m a +withAgentLock AgentClient {lock} = + E.bracket_ + (void . atomically $ takeTMVar lock) + (atomically $ putTMVar lock ()) withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a withSMP_ c srv action = @@ -158,8 +226,8 @@ smpClientError = \case SMPTransportError e -> BROKER $ TRANSPORT e e -> INTERNAL $ show e -newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo) -newReceiveQueue c srv connAlias = do +newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo) +newRcvQueue c srv = do size <- asks $ rsaKeySize . config (recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size logServer "-->" c srv "" "NEW" @@ -170,44 +238,50 @@ newReceiveQueue c srv connAlias = do RcvQueue { server = srv, rcvId, - connAlias, rcvPrivateKey, sndId = Just sId, - sndKey = Nothing, decryptKey, verifyKey = Nothing, status = New } - addSubscription c rq connAlias return (rq, SMPQueueInfo srv sId encryptKey) -subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m () -subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do +subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () +subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do withLogSMP c server rcvId "SUB" $ \smp -> subscribeSMPQueue smp rcvPrivateKey rcvId - addSubscription c rq connAlias + addSubscription c rq connId -addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m () -addSubscription c RcvQueue {server} connAlias = atomically $ do - modifyTVar (subscrConns c) $ M.insert connAlias server +addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () +addSubscription c rq@RcvQueue {server} connId = atomically $ do + modifyTVar (subscrConns c) $ M.insert connId server modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server where - addSub :: Maybe (Set ConnAlias) -> Set ConnAlias - addSub (Just cs) = S.insert connAlias cs - addSub _ = S.singleton connAlias + addSub :: Maybe (Map ConnId RcvQueue) -> Map ConnId RcvQueue + addSub (Just cs) = M.insert connId rq cs + addSub _ = M.singleton connId rq -removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m () -removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do +removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m () +removeSubscription AgentClient {subscrConns, subscrSrvrs} connId = atomically $ do cs <- readTVar subscrConns - writeTVar subscrConns $ M.delete connAlias cs + writeTVar subscrConns $ M.delete connId cs mapM_ (modifyTVar subscrSrvrs . M.alter (>>= delSub)) - (M.lookup connAlias cs) + (M.lookup connId cs) where - delSub :: Set ConnAlias -> Maybe (Set ConnAlias) + delSub :: Map ConnId RcvQueue -> Maybe (Map ConnId RcvQueue) delSub cs = - let cs' = S.delete connAlias cs - in if S.null cs' then Nothing else Just cs' + let cs' = M.delete connId cs + in if M.null cs' then Nothing else Just cs' + +addActivation :: MonadUnliftIO m => AgentClient -> ConnId -> Async () -> m () +addActivation c connId a = atomically . modifyTVar (activations c) $ M.insert connId a + +getActivation :: MonadUnliftIO m => AgentClient -> ConnId -> m (Maybe (Async ())) +getActivation c connId = M.lookup connId <$> readTVarIO (activations c) + +removeActivation :: MonadUnliftIO m => AgentClient -> ConnId -> m () +removeActivation c connId = atomically . modifyTVar (activations c) $ M.delete connId logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = @@ -219,20 +293,23 @@ showServer srv = B.pack $ host srv <> maybe "" (":" <>) (port srv) logSecret :: ByteString -> ByteString logSecret bs = encode $ B.take 3 bs -sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> m () -sendConfirmation c sq@SndQueue {server, sndId} senderKey = +sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> ConnInfo -> m () +sendConfirmation c sq@SndQueue {server, sndId} senderKey cInfo = withLogSMP_ c server sndId "SEND " $ \smp -> do msg <- mkConfirmation smp liftSMP $ sendSMPMessage smp Nothing sndId msg where mkConfirmation :: SMPClient -> m MsgBody - mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey + mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey cInfo -sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () -sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = +sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> RetryInterval -> m () +sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey ri = withLogSMP_ c server sndId "SEND (retrying)" $ \smp -> do msg <- mkHello smp $ AckMode On - liftSMP $ send 8 100000 msg smp + liftSMP . withRetryInterval ri $ \loop -> + sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case + SMPServerError AUTH -> loop + e -> throwE e where mkHello :: SMPClient -> AckMode -> m ByteString mkHello smp ackMode = do @@ -245,15 +322,6 @@ sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = agentMessage = HELLO verifyKey ackMode } - send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () - send 0 _ _ _ = throwE $ SMPServerError AUTH - send retry delay msg smp = - sendSMPMessage smp (Just sndPrivateKey) sndId msg `catchE` \case - SMPServerError AUTH -> do - threadDelay delay - send (retry - 1) (delay * 3 `div` 2) msg smp - e -> throwE e - secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SenderPublicKey -> m () secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey = withLogSMP c server rcvId "KEY " $ \smp -> diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index 13445643a..6a063d4dd 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NumericUnderscores #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} module Simplex.Messaging.Agent.Env.SQLite where @@ -11,7 +12,9 @@ import Data.List.NonEmpty (NonEmpty) import Network.Socket import Numeric.Natural import Simplex.Messaging.Agent.Protocol (SMPServer) +import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store.SQLite +import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import System.Random (StdGen, newStdGen) import UnliftIO.STM @@ -23,11 +26,43 @@ data AgentConfig = AgentConfig connIdBytes :: Int, tbqSize :: Natural, dbFile :: FilePath, - smpCfg :: SMPClientConfig + dbPoolSize :: Int, + smpCfg :: SMPClientConfig, + retryInterval :: RetryInterval, + reconnectInterval :: RetryInterval } +minute :: Int +minute = 60_000_000 + +defaultAgentConfig :: AgentConfig +defaultAgentConfig = + AgentConfig + { tcpPort = "5224", + smpServers = undefined, + rsaKeySize = 2048 `div` 8, + connIdBytes = 12, + tbqSize = 16, + dbFile = "smp-agent.db", + dbPoolSize = 4, + smpCfg = smpDefaultConfig, + retryInterval = + RetryInterval + { initialInterval = 1_000_000, + increaseAfter = minute, + maxInterval = 10 * minute + }, + reconnectInterval = + RetryInterval + { initialInterval = 1_000_000, + increaseAfter = 10_000_000, + maxInterval = 10_000_000 + } + } + data Env = Env { config :: AgentConfig, + store :: SQLiteStore, idsDrg :: TVar ChaChaDRG, clientCounter :: TVar Int, reservedMsgSize :: Int, @@ -35,15 +70,15 @@ data Env = Env } newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env -newSMPAgentEnv config = do +newSMPAgentEnv cfg = do idsDrg <- newTVarIO =<< drgNew - _ <- liftIO $ createSQLiteStore $ dbFile config + store <- liftIO $ createSQLiteStore (dbFile cfg) (dbPoolSize cfg) Migrations.app clientCounter <- newTVarIO 0 randomServer <- newTVarIO =<< liftIO newStdGen - return Env {config, idsDrg, clientCounter, reservedMsgSize, randomServer} + return Env {config = cfg, store, idsDrg, clientCounter, reservedMsgSize, randomServer} where -- 1st rsaKeySize is used by the RSA signature in each command, -- 2nd - by encrypted message body header -- 3rd - by message signature -- smpCommandSize - is the estimated max size for SMP command, queueId, corrId - reservedMsgSize = 3 * rsaKeySize config + smpCommandSize (smpCfg config) + reservedMsgSize = 3 * rsaKeySize cfg + smpCommandSize (smpCfg cfg) diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 07f135440..518b4219d 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -10,9 +10,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} @@ -30,16 +28,12 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md module Simplex.Messaging.Agent.Protocol ( -- * SMP agent protocol types - Entity (..), - EntityTag (..), - AnEntity (..), - EntityCommand, - entityCommand, + ConnInfo, ACommand (..), - ACmdTag (..), AParty (..), - APartyCmd (..), SAParty (..), + MsgHash, + MsgMeta (..), SMPMessage (..), AMessage (..), SMPServer (..), @@ -47,14 +41,15 @@ module Simplex.Messaging.Agent.Protocol AgentErrorType (..), CommandErrorType (..), ConnectionErrorType (..), - BroadcastErrorType (..), BrokerErrorType (..), SMPAgentError (..), - ATransmission (..), - ATransmissionOrError (..), + ATransmission, + ATransmissionOrError, ARawTransmission, - ConnAlias, - ReplyMode (..), + ConnId, + ConfirmationId, + IntroId, + InvitationId, AckMode (..), OnOff (..), MsgIntegrity (..), @@ -69,14 +64,12 @@ module Simplex.Messaging.Agent.Protocol -- * Parse and serialize serializeCommand, - serializeEntity, serializeSMPMessage, serializeMsgIntegrity, serializeServer, serializeSmpQueueInfo, serializeAgentError, commandP, - anEntityP, parseSMPMessage, smpServerP, smpQueueInfoP, @@ -98,18 +91,15 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Constraint (Dict (..)) import Data.Functor (($>)) import Data.Int (Int64) -import Data.Kind (Constraint, Type) -import Data.Maybe (isJust) +import Data.Kind (Type) import Data.String (IsString (..)) import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 import Data.Type.Equality import Data.Typeable () import GHC.Generics (Generic) -import GHC.TypeLits (ErrorMessage (..), TypeError) import Generic.Random (genericArbitraryU) import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C @@ -131,12 +121,10 @@ import UnliftIO.Exception type ARawTransmission = (ByteString, ByteString, ByteString) -- | Parsed SMP agent protocol transmission. -data ATransmission p = forall t c. EntityCommand t c => ATransmission ACorrId (Entity t) (ACommand p c) +type ATransmission p = (ACorrId, ConnId, ACommand p) -- | SMP agent protocol transmission or transmission error. -data ATransmissionOrError p = forall t c. EntityCommand t c => ATransmissionOrError ACorrId (Entity t) (Either AgentErrorType (ACommand p c)) - -deriving instance Show (ATransmissionOrError p) +type ATransmissionOrError p = (ACorrId, ConnId, Either AgentErrorType (ACommand p)) type ACorrId = ByteString @@ -158,190 +146,64 @@ instance TestEquality SAParty where testEquality SClient SClient = Just Refl testEquality _ _ = Nothing --- | SMP agent protocol entity types -data EntityTag = Conn_ | OpenConn_ | Broadcast_ | AGroup_ - -data Entity :: EntityTag -> Type where - Conn :: {fromConn :: ByteString} -> Entity Conn_ - OpenConn :: {fromOpenConn :: ByteString} -> Entity OpenConn_ - Broadcast :: {fromBroadcast :: ByteString} -> Entity Broadcast_ - AGroup :: {fromAGroup :: ByteString} -> Entity AGroup_ - -deriving instance Eq (Entity t) - -deriving instance Show (Entity t) - -entityId :: Entity t -> ByteString -entityId = \case - Conn bs -> bs - OpenConn bs -> bs - Broadcast bs -> bs - AGroup bs -> bs - -data AnEntity = forall t. AE (Entity t) - -data ACmd = forall (p :: AParty) (c :: ACmdTag). ACmd (SAParty p) (ACommand p c) +data ACmd = forall p. ACmd (SAParty p) (ACommand p) deriving instance Show ACmd -data APartyCmd (p :: AParty) = forall c. APartyCmd (ACommand p c) - -instance Eq (APartyCmd p) where - APartyCmd c1 == APartyCmd c2 = isJust $ testEquality c1 c2 - -deriving instance Show (APartyCmd p) - -type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where - EntityCommand Conn_ NEW_ = () - EntityCommand Conn_ INV_ = () - EntityCommand Conn_ JOIN_ = () - EntityCommand Conn_ CON_ = () - EntityCommand Conn_ SUB_ = () - EntityCommand Conn_ SUBALL_ = () - EntityCommand Conn_ END_ = () - EntityCommand Conn_ SEND_ = () - EntityCommand Conn_ SENT_ = () - EntityCommand Conn_ MSG_ = () - EntityCommand Conn_ OFF_ = () - EntityCommand Conn_ DEL_ = () - EntityCommand Conn_ OK_ = () - EntityCommand Conn_ ERR_ = () - EntityCommand Broadcast_ NEW_ = () - EntityCommand Broadcast_ ADD_ = () - EntityCommand Broadcast_ REM_ = () - EntityCommand Broadcast_ LS_ = () - EntityCommand Broadcast_ MS_ = () - EntityCommand Broadcast_ SEND_ = () - EntityCommand Broadcast_ SENT_ = () - EntityCommand Broadcast_ DEL_ = () - EntityCommand Broadcast_ OK_ = () - EntityCommand Broadcast_ ERR_ = () - EntityCommand _ ERR_ = () - EntityCommand t c = - (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " does not support command " :<>: ShowType c)) - -entityCommand :: Entity t -> ACommand p c -> Maybe (Dict (EntityCommand t c)) -entityCommand = \case - Conn _ -> \case - NEW -> Just Dict - INV _ -> Just Dict - JOIN {} -> Just Dict - CON -> Just Dict - SUB -> Just Dict - SUBALL -> Just Dict - END -> Just Dict - SEND _ -> Just Dict - SENT _ -> Just Dict - MSG {} -> Just Dict - OFF -> Just Dict - DEL -> Just Dict - OK -> Just Dict - ERR _ -> Just Dict - _ -> Nothing - Broadcast _ -> \case - NEW -> Just Dict - ADD _ -> Just Dict - REM _ -> Just Dict - LS -> Just Dict - MS _ -> Just Dict - SEND _ -> Just Dict - SENT _ -> Just Dict - DEL -> Just Dict - OK -> Just Dict - ERR _ -> Just Dict - _ -> Nothing - _ -> \case - ERR _ -> Just Dict - _ -> Nothing - -data ACmdTag - = NEW_ - | INV_ - | JOIN_ - | CON_ - | SUB_ - | SUBALL_ - | END_ - | SEND_ - | SENT_ - | MSG_ - | OFF_ - | DEL_ - | ADD_ - | REM_ - | LS_ - | MS_ - | OK_ - | ERR_ +type ConnInfo = ByteString -- | Parameterized type for SMP agent protocol commands and responses from all participants. -data ACommand (p :: AParty) (c :: ACmdTag) where - NEW :: ACommand Client NEW_ -- response INV - INV :: SMPQueueInfo -> ACommand Agent INV_ - JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK - CON :: ACommand Agent CON_ -- notification that connection is established - -- TODO currently it automatically allows whoever sends the confirmation - -- CONF :: OtherPartyId -> ACommand Agent - -- LET :: OtherPartyId -> ACommand Client - SUB :: ACommand Client SUB_ - SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all - END :: ACommand Agent END_ +data ACommand (p :: AParty) where + NEW :: ACommand Client -- response INV + INV :: SMPQueueInfo -> ACommand Agent + JOIN :: SMPQueueInfo -> ConnInfo -> ACommand Client -- response OK + REQ :: ConfirmationId -> ConnInfo -> ACommand Agent -- ConnInfo is from sender + ACPT :: ConfirmationId -> ConnInfo -> ACommand Client -- ConnInfo is from client + INFO :: ConnInfo -> ACommand Agent + CON :: ACommand Agent -- notification that connection is established + SUB :: ACommand Client + END :: ACommand Agent + DOWN :: ACommand Agent + UP :: ACommand Agent -- QST :: QueueDirection -> ACommand Client -- STAT :: QueueDirection -> Maybe QueueStatus -> Maybe SubMode -> ACommand Agent - SEND :: MsgBody -> ACommand Client SEND_ - SENT :: AgentMsgId -> ACommand Agent SENT_ - MSG :: - { recipientMeta :: (AgentMsgId, UTCTime), - brokerMeta :: (MsgId, UTCTime), - senderMeta :: (AgentMsgId, UTCTime), - msgIntegrity :: MsgIntegrity, - msgBody :: MsgBody - } -> - ACommand Agent MSG_ - -- ACK :: AgentMsgId -> ACommand Client + SEND :: MsgBody -> ACommand Client + MID :: AgentMsgId -> ACommand Agent + SENT :: AgentMsgId -> ACommand Agent + MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent + MSG :: MsgMeta -> MsgBody -> ACommand Agent + ACK :: AgentMsgId -> ACommand Client -- RCVD :: AgentMsgId -> ACommand Agent - OFF :: ACommand Client MSG_ - DEL :: ACommand Client DEL_ - ADD :: Entity Conn_ -> ACommand Client ADD_ - REM :: Entity Conn_ -> ACommand Client REM_ - LS :: ACommand Client LS_ - MS :: [Entity Conn_] -> ACommand Agent MS_ - OK :: ACommand Agent OK_ - ERR :: AgentErrorType -> ACommand Agent ERR_ + OFF :: ACommand Client + DEL :: ACommand Client + OK :: ACommand Agent + ERR :: AgentErrorType -> ACommand Agent -deriving instance Eq (ACommand p c) +deriving instance Eq (ACommand p) -deriving instance Show (ACommand p c) +deriving instance Show (ACommand p) -instance TestEquality (ACommand p) where - testEquality NEW NEW = Just Refl - testEquality c@INV {} c'@INV {} = refl c c' - testEquality c@JOIN {} c'@JOIN {} = refl c c' - testEquality CON CON = Just Refl - testEquality SUB SUB = Just Refl - testEquality SUBALL SUBALL = Just Refl - testEquality END END = Just Refl - testEquality c@SEND {} c'@SEND {} = refl c c' - testEquality c@SENT {} c'@SENT {} = refl c c' - testEquality c@MSG {} c'@MSG {} = refl c c' - testEquality OFF OFF = Just Refl - testEquality DEL DEL = Just Refl - testEquality c@ADD {} c'@ADD {} = refl c c' - testEquality c@REM {} c'@REM {} = refl c c' - testEquality c@LS {} c'@LS {} = refl c c' - testEquality c@MS {} c'@MS {} = refl c c' - testEquality OK OK = Just Refl - testEquality c@ERR {} c'@ERR {} = refl c c' - testEquality _ _ = Nothing +type MsgHash = ByteString -refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a) -refl x x' = if x == x' then Just Refl else Nothing +-- | Agent message metadata sent to the client +data MsgMeta = MsgMeta + { integrity :: MsgIntegrity, + recipient :: (AgentMsgId, UTCTime), + broker :: (MsgId, UTCTime), + sender :: (AgentMsgId, UTCTime) + } + deriving (Eq, Show) -- | SMP message formats. data SMPMessage = -- | SMP confirmation -- (see ) - SMPConfirmation SenderPublicKey + SMPConfirmation + { -- | sender's public key to use for authentication of sender's commands at the recepient's server + senderKey :: SenderPublicKey, + -- | sender's information to be associated with the connection, e.g. sender's profile information + connInfo :: ConnInfo + } | -- | Agent message header and envelope for client messages -- (see ) SMPMessage @@ -350,7 +212,7 @@ data SMPMessage -- | timestamp from the sending agent senderTimestamp :: SenderTimestamp, -- | digest of the previous message - previousMsgHash :: ByteString, + previousMsgHash :: MsgHash, -- | messages sent between agents once queue is secured agentMessage :: AMessage } @@ -373,12 +235,10 @@ parseSMPMessage :: ByteString -> Either AgentErrorType SMPMessage parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE where smpMessageP :: Parser SMPMessage - smpMessageP = - smpConfirmationP <* A.endOfLine - <|> A.endOfLine *> smpClientMessageP + smpMessageP = A.endOfLine *> smpClientMessageP <|> smpConfirmationP smpConfirmationP :: Parser SMPMessage - smpConfirmationP = SMPConfirmation <$> ("KEY " *> C.pubKeyP <* A.endOfLine) + smpConfirmationP = "KEY " *> (SMPConfirmation <$> C.pubKeyP <* A.endOfLine <* A.endOfLine <*> binaryBodyP <* A.endOfLine) smpClientMessageP :: Parser SMPMessage smpClientMessageP = @@ -393,7 +253,7 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE -- | Serialize SMP message. serializeSMPMessage :: SMPMessage -> ByteString serializeSMPMessage = \case - SMPConfirmation sKey -> smpMessage ("KEY " <> C.serializePubKey sKey) "" "" + SMPConfirmation sKey cInfo -> smpMessage ("KEY " <> C.serializePubKey sKey) "" (serializeBinary cInfo) <> "\n" SMPMessage {senderMsgId, senderTimestamp, previousMsgHash, agentMessage} -> let header = messageHeader senderMsgId senderTimestamp previousMsgHash body = serializeAgentMessage agentMessage @@ -411,9 +271,7 @@ agentMessageP = where hello = HELLO <$> C.pubKeyP <*> ackMode reply = REPLY <$> smpQueueInfoP - a_msg = do - size :: Int <- A.decimal <* A.endOfLine - A_MSG <$> A.take size <* A.endOfLine + a_msg = A_MSG <$> binaryBodyP <* A.endOfLine ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On) -- | SMP queue information parser. @@ -425,7 +283,7 @@ smpQueueInfoP = smpServerP :: Parser SMPServer smpServerP = SMPServer <$> server <*> optional port <*> optional kHash where - server = B.unpack <$> A.takeWhile1 (A.notInClass ":# ") + server = B.unpack <$> A.takeWhile1 (A.notInClass ":#,; ") port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit) kHash = C.KeyHash <$> (A.char '#' *> base64P) @@ -433,7 +291,7 @@ serializeAgentMessage :: AMessage -> ByteString serializeAgentMessage = \case HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else "" REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo - A_MSG body -> "MSG " <> serializeMsg body <> "\n" + A_MSG body -> "MSG " <> serializeBinary body <> "\n" -- | Serialize SMP queue information that is sent out-of-band. serializeSmpQueueInfo :: SMPQueueInfo -> ByteString @@ -457,7 +315,13 @@ instance IsString SMPServer where fromString = parseString . parseAll $ smpServerP -- | SMP agent connection alias. -type ConnAlias = ByteString +type ConnId = ByteString + +type ConfirmationId = ByteString + +type IntroId = ByteString + +type InvitationId = ByteString -- | Connection modes. data OnOff = On | Off deriving (Eq, Show, Read) @@ -471,9 +335,6 @@ newtype AckMode = AckMode OnOff deriving (Eq, Show) data SMPQueueInfo = SMPQueueInfo SMPServer SMP.SenderId EncryptionKey deriving (Eq, Show) --- | Connection reply mode (used in JOIN command). -newtype ReplyMode = ReplyMode OnOff deriving (Eq, Show) - -- | Public key used to E2E encrypt SMP messages. type EncryptionKey = C.PublicKey @@ -481,7 +342,7 @@ type EncryptionKey = C.PublicKey type DecryptionKey = C.SafePrivateKey -- | Private key used to sign SMP commands -type SignatureKey = C.SafePrivateKey +type SignatureKey = C.APrivateKey -- | Public key used by SMP server to authorize (verify) SMP commands. type VerificationKey = C.PublicKey @@ -520,8 +381,6 @@ data AgentErrorType CMD CommandErrorType | -- | connection errors CONN ConnectionErrorType - | -- | broadcast errors - BCAST BroadcastErrorType | -- | SMP protocol errors forwarded to agent clients SMP ErrorType | -- | SMP server errors @@ -536,14 +395,10 @@ data AgentErrorType data CommandErrorType = -- | command is prohibited in this context PROHIBITED - | -- | command is not supported by this entity - UNSUPPORTED | -- | command syntax is invalid SYNTAX - | -- | cannot parse entity - BAD_ENTITY | -- | entity ID is required with this command - NO_ENTITY + NO_CONN | -- | message size is not correct (no terminating space) SIZE | -- | message does not fit in SMP block @@ -560,14 +415,6 @@ data ConnectionErrorType SIMPLEX deriving (Eq, Generic, Read, Show, Exception) --- | Broadcast error -data BroadcastErrorType - = -- | broadcast ID is not in the database - B_NOT_FOUND - | -- | broadcast ID already exists - B_DUPLICATE - deriving (Eq, Generic, Read, Show, Exception) - -- | SMP server errors. data BrokerErrorType = -- | invalid server response (failed to parse) @@ -600,70 +447,53 @@ instance Arbitrary CommandErrorType where arbitrary = genericArbitraryU instance Arbitrary ConnectionErrorType where arbitrary = genericArbitraryU -instance Arbitrary BroadcastErrorType where arbitrary = genericArbitraryU - instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU -anEntityP :: Parser AnEntity -anEntityP = - ($) - <$> ( "C:" $> AE . Conn - <|> "O:" $> AE . OpenConn - <|> "B:" $> AE . Broadcast - <|> "G:" $> AE . AGroup - ) - <*> A.takeTill (== ' ') - -entityConnP :: Parser (Entity Conn_) -entityConnP = "C:" *> (Conn <$> A.takeTill (== ' ')) - -serializeEntity :: Entity t -> ByteString -serializeEntity = \case - Conn s -> "C:" <> s - OpenConn s -> "O:" <> s - Broadcast s -> "B:" <> s - AGroup s -> "G:" <> s - -- | SMP agent command and response parser commandP :: Parser ACmd commandP = "NEW" $> ACmd SClient NEW <|> "INV " *> invResp <|> "JOIN " *> joinCmd + <|> "REQ " *> reqCmd + <|> "ACPT " *> acptCmd + <|> "INFO " *> infoCmd <|> "SUB" $> ACmd SClient SUB - <|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all <|> "END" $> ACmd SAgent END + <|> "DOWN" $> ACmd SAgent DOWN + <|> "UP" $> ACmd SAgent UP <|> "SEND " *> sendCmd + <|> "MID " *> msgIdResp <|> "SENT " *> sentResp + <|> "MERR " *> msgErrResp <|> "MSG " *> message + <|> "ACK " *> ackCmd <|> "OFF" $> ACmd SClient OFF <|> "DEL" $> ACmd SClient DEL - <|> "ADD " *> addCmd - <|> "REM " *> removeCmd - <|> "LS" $> ACmd SClient LS - <|> "MS " *> membersResp <|> "ERR " *> agentError <|> "CON" $> ACmd SAgent CON <|> "OK" $> ACmd SAgent OK where invResp = ACmd SAgent . INV <$> smpQueueInfoP - joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode) + joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <* A.space <*> A.takeByteString) + reqCmd = ACmd SAgent <$> (REQ <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString) + acptCmd = ACmd SClient <$> (ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString) + infoCmd = ACmd SAgent . INFO <$> A.takeByteString sendCmd = ACmd SClient . SEND <$> A.takeByteString + msgIdResp = ACmd SAgent . MID <$> A.decimal sentResp = ACmd SAgent . SENT <$> A.decimal - addCmd = ACmd SClient . ADD <$> entityConnP - removeCmd = ACmd SClient . REM <$> entityConnP - membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ') - message = do - msgIntegrity <- msgIntegrityP <* A.space - recipientMeta <- "R=" *> partyMeta A.decimal - brokerMeta <- "B=" *> partyMeta base64P - senderMeta <- "S=" *> partyMeta A.decimal - msgBody <- A.takeByteString - return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody} - replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On) - partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space + msgErrResp = ACmd SAgent <$> (MERR <$> A.decimal <* A.space <*> agentErrorTypeP) + message = ACmd SAgent <$> (MSG <$> msgMetaP <* A.space <*> A.takeByteString) + ackCmd = ACmd SClient . ACK <$> A.decimal + msgMetaP = do + integrity <- msgIntegrityP + recipient <- " R=" *> partyMeta A.decimal + broker <- " B=" *> partyMeta base64P + sender <- " S=" *> partyMeta A.decimal + pure MsgMeta {integrity, recipient, broker, sender} + partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P agentError = ACmd SAgent . ERR <$> agentErrorTypeP -- | Message integrity validation result parser. @@ -680,41 +510,41 @@ parseCommand :: ByteString -> Either AgentErrorType ACmd parseCommand = parse commandP $ CMD SYNTAX -- | Serialize SMP agent command. -serializeCommand :: ACommand p c -> ByteString +serializeCommand :: ACommand p -> ByteString serializeCommand = \case NEW -> "NEW" INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo - JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode + JOIN qInfo cInfo -> "JOIN " <> serializeSmpQueueInfo qInfo <> " " <> serializeBinary cInfo + REQ confId cInfo -> "REQ " <> confId <> " " <> serializeBinary cInfo + ACPT confId cInfo -> "ACPT " <> confId <> " " <> serializeBinary cInfo + INFO cInfo -> "INFO " <> serializeBinary cInfo SUB -> "SUB" - SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all END -> "END" - SEND msgBody -> "SEND " <> serializeMsg msgBody + DOWN -> "DOWN" + UP -> "UP" + SEND msgBody -> "SEND " <> serializeBinary msgBody + MID mId -> "MID " <> bshow mId SENT mId -> "SENT " <> bshow mId - MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} -> - B.unwords - [ "MSG", - serializeMsgIntegrity msgIntegrity, - "R=" <> bshow rmId <> "," <> showTs rTs, - "B=" <> encode bmId <> "," <> showTs bTs, - "S=" <> bshow smId <> "," <> showTs sTs, - serializeMsg msgBody - ] + MERR mId e -> "MERR " <> bshow mId <> " " <> serializeAgentError e + MSG msgMeta msgBody -> + "MSG " <> serializeMsgMeta msgMeta <> " " <> serializeBinary msgBody + ACK mId -> "ACK " <> bshow mId OFF -> "OFF" DEL -> "DEL" - ADD c -> "ADD " <> serializeEntity c - REM c -> "REM " <> serializeEntity c - LS -> "LS" - MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs) CON -> "CON" ERR e -> "ERR " <> serializeAgentError e OK -> "OK" where - replyMode :: ReplyMode -> ByteString - replyMode = \case - ReplyMode Off -> " NO_REPLY" - ReplyMode On -> "" showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis + serializeMsgMeta :: MsgMeta -> ByteString + serializeMsgMeta MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sender = (smId, sTs)} = + B.unwords + [ serializeMsgIntegrity integrity, + "R=" <> bshow rmId <> "," <> showTs rTs, + "B=" <> encode bmId <> "," <> showTs bTs, + "S=" <> bshow smId <> "," <> showTs sTs + ] -- | Serialize message integrity validation result. serializeMsgIntegrity :: MsgIntegrity -> ByteString @@ -732,7 +562,6 @@ serializeMsgIntegrity = \case agentErrorTypeP :: Parser AgentErrorType agentErrorTypeP = "SMP " *> (SMP <$> SMP.errorTypeP) - <|> "BCAST " *> (BCAST <$> bcastErrorP) <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> SMP.errorTypeP) <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) <|> "INTERNAL " *> (INTERNAL <$> parseRead A.takeByteString) @@ -742,21 +571,17 @@ agentErrorTypeP = serializeAgentError :: AgentErrorType -> ByteString serializeAgentError = \case SMP e -> "SMP " <> SMP.serializeErrorType e - BCAST e -> "BCAST " <> serializeBcastError e BROKER (RESPONSE e) -> "BROKER RESPONSE " <> SMP.serializeErrorType e BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e e -> bshow e -bcastErrorP :: Parser BroadcastErrorType -bcastErrorP = "NOT_FOUND" $> B_NOT_FOUND <|> "DUPLICATE" $> B_DUPLICATE +binaryBodyP :: Parser ByteString +binaryBodyP = do + size :: Int <- A.decimal <* A.endOfLine + A.take size -serializeBcastError :: BroadcastErrorType -> ByteString -serializeBcastError = \case - B_NOT_FOUND -> "NOT_FOUND" - B_DUPLICATE -> "DUPLICATE" - -serializeMsg :: ByteString -> ByteString -serializeMsg body = bshow (B.length body) <> "\n" <> body +serializeBinary :: ByteString -> ByteString +serializeBinary body = bshow (B.length body) <> "\n" <> body -- | Send raw (unparsed) SMP agent protocol transmission to TCP connection. tPutRaw :: Transport c => c -> ARawTransmission -> IO () @@ -771,59 +596,50 @@ tGetRaw h = (,,) <$> getLn h <*> getLn h <*> getLn h -- | Send SMP agent protocol command (or response) to TCP connection. tPut :: (Transport c, MonadIO m) => c -> ATransmission p -> m () -tPut h (ATransmission corrId ent cmd) = - liftIO $ tPutRaw h (corrId, serializeEntity ent, serializeCommand cmd) +tPut h (corrId, connAlias, command) = + liftIO $ tPutRaw h (corrId, connAlias, serializeCommand command) -- | Receive client and agent transmissions from TCP connection. tGet :: forall c m p. (Transport c, MonadIO m) => SAParty p -> c -> m (ATransmissionOrError p) tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody where tParseLoadBody :: ARawTransmission -> m (ATransmissionOrError p) - tParseLoadBody (corrId, entityStr, command) = - case parseAll anEntityP entityStr of - Left _ -> pure $ ATransmissionOrError @_ @_ @ERR_ corrId (Conn "") $ Left $ CMD BAD_ENTITY - Right entity -> do - let cmd = parseCommand command >>= fromParty >>= hasEntityId entity - makeTransmission corrId entity <$> either (pure . Left) cmdWithMsgBody cmd + tParseLoadBody t@(corrId, connId, command) = do + let cmd = parseCommand command >>= fromParty >>= tConnId t + fullCmd <- either (return . Left) cmdWithMsgBody cmd + return (corrId, connId, fullCmd) - fromParty :: ACmd -> Either AgentErrorType (APartyCmd p) + fromParty :: ACmd -> Either AgentErrorType (ACommand p) fromParty (ACmd (p :: p1) cmd) = case testEquality party p of - Just Refl -> Right $ APartyCmd cmd + Just Refl -> Right cmd _ -> Left $ CMD PROHIBITED - hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p) - hasEntityId (AE entity) (APartyCmd cmd) = - APartyCmd <$> case cmd of - -- NEW and JOIN have optional entity - NEW -> Right cmd - JOIN _ _ -> Right cmd - -- ERROR response does not always have entity - ERR _ -> Right cmd - -- other responses must have entity - _ - | B.null (entityId entity) -> Left $ CMD NO_ENTITY - | otherwise -> Right cmd + tConnId :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p) + tConnId (_, connId, _) cmd = case cmd of + -- NEW, JOIN and ACPT have optional connId + NEW -> Right cmd + JOIN {} -> Right cmd + -- ERROR response does not always have connId + ERR _ -> Right cmd + -- other responses must have connId + _ + | B.null connId -> Left $ CMD NO_CONN + | otherwise -> Right cmd - makeTransmission :: ACorrId -> AnEntity -> Either AgentErrorType (APartyCmd p) -> ATransmissionOrError p - makeTransmission corrId (AE entity) = \case - Left e -> err e - Right (APartyCmd cmd) -> case entityCommand entity cmd of - Just Dict -> ATransmissionOrError corrId entity $ Right cmd - _ -> err $ CMD UNSUPPORTED - where - err e = ATransmissionOrError @_ @_ @ERR_ corrId entity $ Left e - - cmdWithMsgBody :: APartyCmd p -> m (Either AgentErrorType (APartyCmd p)) - cmdWithMsgBody (APartyCmd cmd) = - APartyCmd <$$> case cmd of - SEND body -> SEND <$$> getMsgBody body - MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body - _ -> pure $ Right cmd + cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) + cmdWithMsgBody = \case + SEND body -> SEND <$$> getBody body + MSG msgMeta body -> MSG msgMeta <$$> getBody body + JOIN qInfo cInfo -> JOIN qInfo <$$> getBody cInfo + REQ confId cInfo -> REQ confId <$$> getBody cInfo + ACPT confId cInfo -> ACPT confId <$$> getBody cInfo + INFO cInfo -> INFO <$$> getBody cInfo + cmd -> pure $ Right cmd -- TODO refactor with server - getMsgBody :: MsgBody -> m (Either AgentErrorType MsgBody) - getMsgBody msgBody = - case B.unpack msgBody of + getBody :: ByteString -> m (Either AgentErrorType ByteString) + getBody binary = + case B.unpack binary of ':' : body -> return . Right $ B.pack body str -> case readMaybe str :: Maybe Int of Just size -> liftIO $ do diff --git a/src/Simplex/Messaging/Agent/RetryInterval.hs b/src/Simplex/Messaging/Agent/RetryInterval.hs new file mode 100644 index 000000000..048b9e09c --- /dev/null +++ b/src/Simplex/Messaging/Agent/RetryInterval.hs @@ -0,0 +1,28 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Agent.RetryInterval where + +import Control.Concurrent (threadDelay) +import Control.Monad.IO.Class (MonadIO, liftIO) + +data RetryInterval = RetryInterval + { initialInterval :: Int, + increaseAfter :: Int, + maxInterval :: Int + } + +withRetryInterval :: forall m. MonadIO m => RetryInterval -> (m () -> m ()) -> m () +withRetryInterval RetryInterval {initialInterval, increaseAfter, maxInterval} action = + callAction 0 initialInterval + where + callAction :: Int -> Int -> m () + callAction elapsedTime delay = action loop + where + loop = do + let newDelay = + if elapsedTime < increaseAfter || delay == maxInterval + then delay + else min (delay * 3 `div` 2) maxInterval + liftIO $ threadDelay delay + callAction (elapsedTime + delay) newDelay diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 6d3dc606f..fd8b3ced6 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -9,7 +9,9 @@ module Simplex.Messaging.Agent.Store where +import Control.Concurrent.STM (TVar) import Control.Exception (Exception) +import Crypto.Random (ChaChaDRG) import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Kind (Type) @@ -30,33 +32,36 @@ import qualified Simplex.Messaging.Protocol as SMP -- | Store class type. Defines store access methods for implementations. class Monad m => MonadAgentStore s m where -- Queue and Connection management - createRcvConn :: s -> RcvQueue -> m () - createSndConn :: s -> SndQueue -> m () - getConn :: s -> ConnAlias -> m SomeConn - getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all + createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId + createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId + getConn :: s -> ConnId -> m SomeConn + getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn - deleteConn :: s -> ConnAlias -> m () - upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m () - upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m () + deleteConn :: s -> ConnId -> m () + upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m () + upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m () setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m () setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m () setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () + updateSignKey :: s -> SndQueue -> SignatureKey -> m () + + -- Confirmations + createConfirmation :: s -> TVar ChaChaDRG -> NewConfirmation -> m ConfirmationId + acceptConfirmation :: s -> ConfirmationId -> ConnInfo -> m AcceptedConfirmation + getAcceptedConfirmation :: s -> ConnId -> m AcceptedConfirmation + removeConfirmations :: s -> ConnId -> m () -- Msg management - updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) - createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m () - - updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) - createSndMsg :: s -> SndQueue -> SndMsgData -> m () - - getMsg :: s -> ConnAlias -> InternalId -> m Msg - - -- Broadcasts - createBcast :: s -> BroadcastId -> m () - addBcastConn :: s -> BroadcastId -> ConnAlias -> m () - removeBcastConn :: s -> BroadcastId -> ConnAlias -> m () - deleteBcast :: s -> BroadcastId -> m () - getBcast :: s -> BroadcastId -> m [ConnAlias] + updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + createRcvMsg :: s -> ConnId -> RcvMsgData -> m () + updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash) + createSndMsg :: s -> ConnId -> SndMsgData -> m () + updateSndMsgStatus :: s -> ConnId -> InternalId -> SndMsgStatus -> m () + getPendingMsgData :: s -> ConnId -> InternalId -> m (SndQueue, MsgBody) + getPendingMsgs :: s -> ConnId -> m [PendingMsg] + getMsg :: s -> ConnId -> InternalId -> m Msg + checkRcvMsg :: s -> ConnId -> InternalId -> m () + updateRcvMsgAck :: s -> ConnId -> InternalId -> m () -- * Queue types @@ -64,10 +69,8 @@ class Monad m => MonadAgentStore s m where data RcvQueue = RcvQueue { server :: SMPServer, rcvId :: SMP.RecipientId, - connAlias :: ConnAlias, rcvPrivateKey :: RecipientPrivateKey, sndId :: Maybe SMP.SenderId, - sndKey :: Maybe SenderPublicKey, decryptKey :: DecryptionKey, verifyKey :: Maybe VerificationKey, status :: QueueStatus @@ -78,7 +81,6 @@ data RcvQueue = RcvQueue data SndQueue = SndQueue { server :: SMPServer, sndId :: SMP.SenderId, - connAlias :: ConnAlias, sndPrivateKey :: SenderPrivateKey, encryptKey :: EncryptionKey, signKey :: SignatureKey, @@ -102,9 +104,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show) -- - DuplexConnection is a connection that has both receive and send queues set up, -- typically created by upgrading a receive or a send connection with a missing queue. data Connection (d :: ConnType) where - RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv - SndConnection :: ConnAlias -> SndQueue -> Connection CSnd - DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex + RcvConnection :: ConnData -> RcvQueue -> Connection CRcv + SndConnection :: ConnData -> SndQueue -> Connection CSnd + DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex deriving instance Eq (Connection d) @@ -141,9 +143,26 @@ instance Eq SomeConn where deriving instance Show SomeConn --- * Message integrity validation types +newtype ConnData = ConnData {connId :: ConnId} + deriving (Eq, Show) -type MsgHash = ByteString +-- * Confirmation types + +data NewConfirmation = NewConfirmation + { connId :: ConnId, + senderKey :: SenderPublicKey, + senderConnInfo :: ConnInfo + } + +data AcceptedConfirmation = AcceptedConfirmation + { confirmationId :: ConfirmationId, + connId :: ConnId, + senderKey :: SenderPublicKey, + senderConnInfo :: ConnInfo, + ownConnInfo :: ConnInfo + } + +-- * Message integrity validation types -- | Corresponds to `last_external_snd_msg_id` in `connections` table type PrevExternalSndId = Int64 @@ -159,15 +178,11 @@ type PrevSndMsgHash = MsgHash -- * Message data containers - used on Msg creation to reduce number of parameters data RcvMsgData = RcvMsgData - { internalId :: InternalId, - internalRcvId :: InternalRcvId, - internalTs :: InternalTs, - senderMeta :: (ExternalSndId, ExternalSndTs), - brokerMeta :: (BrokerId, BrokerTs), + { msgMeta :: MsgMeta, msgBody :: MsgBody, + internalRcvId :: InternalRcvId, internalHash :: MsgHash, - externalPrevSndHash :: MsgHash, - msgIntegrity :: MsgIntegrity + externalPrevSndHash :: MsgHash } data SndMsgData = SndMsgData @@ -175,9 +190,16 @@ data SndMsgData = SndMsgData internalSndId :: InternalSndId, internalTs :: InternalTs, msgBody :: MsgBody, - internalHash :: MsgHash + internalHash :: MsgHash, + previousMsgHash :: MsgHash } +data PendingMsg = PendingMsg + { connId :: ConnId, + msgId :: InternalId + } + deriving (Show) + -- * Broadcast types type BroadcastId = ByteString @@ -252,9 +274,9 @@ data SndMsg = SndMsg newtype InternalSndId = InternalSndId {unSndId :: Int64} deriving (Eq, Show) data SndMsgStatus - = Created - | Sent - | Delivered + = SndMsgCreated + | SndMsgSent + | SndMsgDelivered deriving (Eq, Show) type SentTs = UTCTime @@ -263,7 +285,7 @@ type DeliveredTs = UTCTime -- | Base message data independent of direction. data MsgBase = MsgBase - { connAlias :: ConnAlias, + { connAlias :: ConnId, -- | Monotonically increasing id of a message per connection, internal to the agent. -- Internal Id preserves ordering between both received and sent messages, and is needed -- to track the order of the conversation (which can be different for the sender / receiver) @@ -287,6 +309,8 @@ type InternalTs = UTCTime data StoreError = -- | IO exceptions in store actions. SEInternal ByteString + | -- | failed to generate unique random ID + SEUniqueID | -- | Connection alias not found (or both queues absent). SEConnNotFound | -- | Connection alias already used. @@ -294,10 +318,10 @@ data StoreError | -- | Wrong connection type, e.g. "send" connection when "receive" or "duplex" is expected, or vice versa. -- 'upgradeRcvConnToDuplex' and 'upgradeSndConnToDuplex' do not allow duplex connections - they would also return this error. SEBadConnType ConnType - | -- | Broadcast ID not found. - SEBcastNotFound - | -- | Broadcast ID already used. - SEBcastDuplicate + | -- | Confirmation not found. + SEConfirmationNotFound + | -- | Message not found + SEMsgNotFound | -- | Currently not used. The intention was to pass current expected queue status in methods, -- as we always know what it should be at any stage of the protocol, -- and in case it does not match use this error. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 9dcf7edd1..5d63af27c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -11,6 +12,7 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -18,20 +20,25 @@ module Simplex.Messaging.Agent.Store.SQLite ( SQLiteStore (..), createSQLiteStore, connectSQLiteStore, + withConnection, + withTransaction, + fromTextField_, ) where import Control.Concurrent (threadDelay) -import Control.Monad (unless, when) -import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO)) +import Control.Concurrent.STM +import Control.Exception (bracket) +import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) -import Data.Bifunctor (first) +import Crypto.Random (ChaChaDRG, randomBytesGenerate) +import Data.ByteString (ByteString) +import Data.ByteString.Base64 (encode) import Data.Char (toLower) import Data.List (find) import Data.Maybe (fromMaybe) -import Data.Text (isPrefixOf) +import Data.Text (Text) import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8) import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField @@ -42,8 +49,10 @@ import Database.SQLite.Simple.ToField (ToField (..)) import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration) import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Parsers (blobFieldParser) +import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) @@ -57,35 +66,40 @@ import qualified UnliftIO.Exception as E data SQLiteStore = SQLiteStore { dbFilePath :: FilePath, - dbConn :: DB.Connection, + dbConnPool :: TBQueue DB.Connection, dbNew :: Bool } -createSQLiteStore :: FilePath -> IO SQLiteStore -createSQLiteStore dbFilePath = do +createSQLiteStore :: FilePath -> Int -> [Migration] -> IO SQLiteStore +createSQLiteStore dbFilePath poolSize migrations = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing False dbDir - store <- connectSQLiteStore dbFilePath - compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] - let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions) + st <- connectSQLiteStore dbFilePath poolSize + checkThreadsafe st + migrateSchema st migrations + pure st + +checkThreadsafe :: SQLiteStore -> IO () +checkThreadsafe st = withConnection st $ \db -> do + compileOptions <- DB.query_ db "pragma COMPILE_OPTIONS;" :: IO [[Text]] + let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions) case threadsafeOption of Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code." Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found" _ -> return () - migrateSchema store - pure store -migrateSchema :: SQLiteStore -> IO () -migrateSchema SQLiteStore {dbConn, dbFilePath, dbNew} = do - Migrations.initialize dbConn - Migrations.get dbConn Migrations.app >>= \case +migrateSchema :: SQLiteStore -> [Migration] -> IO () +migrateSchema st migrations = withConnection st $ \db -> do + Migrations.initialize db + Migrations.get db migrations >>= \case Left e -> confirmOrExit $ "Database error: " <> e Right [] -> pure () Right ms -> do - unless dbNew $ do + unless (dbNew st) $ do confirmOrExit "The app has a newer version than the database - it will be backed up and upgraded." - copyFile dbFilePath $ dbFilePath <> ".bak" - Migrations.run dbConn ms + let f = dbFilePath st + copyFile f (f <> ".bak") + Migrations.run db ms confirmOrExit :: String -> IO () confirmOrExit s = do @@ -95,73 +109,90 @@ confirmOrExit s = do ok <- getLine when (map toLower ok /= "y") exitFailure -connectSQLiteStore :: FilePath -> IO SQLiteStore -connectSQLiteStore dbFilePath = do +connectSQLiteStore :: FilePath -> Int -> IO SQLiteStore +connectSQLiteStore dbFilePath poolSize = do dbNew <- not <$> doesFileExist dbFilePath - dbConn <- DB.open dbFilePath - DB.execute_ - dbConn - [sql| - PRAGMA foreign_keys = ON; - PRAGMA journal_mode = WAL; - |] - pure SQLiteStore {dbFilePath, dbConn, dbNew} + dbConnPool <- newTBQueueIO $ toEnum poolSize + replicateM_ poolSize $ + connectDB dbFilePath >>= atomically . writeTBQueue dbConnPool + pure SQLiteStore {dbFilePath, dbConnPool, dbNew} -checkConstraint :: StoreError -> IO () -> IO (Either StoreError ()) -checkConstraint err action = first handleError <$> E.try action - where - handleError :: SQLError -> StoreError - handleError e - | DB.sqlError e == DB.ErrorConstraint = err - | otherwise = SEInternal $ bshow e +connectDB :: FilePath -> IO DB.Connection +connectDB path = do + dbConn <- DB.open path + DB.execute_ dbConn "PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;" + pure dbConn -withTransaction :: forall a. DB.Connection -> IO a -> IO a -withTransaction db a = loop 100 100_000 +checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) +checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err) + +handleSQLError :: StoreError -> SQLError -> StoreError +handleSQLError err e + | DB.sqlError e == DB.ErrorConstraint = err + | otherwise = SEInternal $ bshow e + +withConnection :: SQLiteStore -> (DB.Connection -> IO a) -> IO a +withConnection SQLiteStore {dbConnPool} = + bracket + (atomically $ readTBQueue dbConnPool) + (atomically . writeTBQueue dbConnPool) + +withTransaction :: forall a. SQLiteStore -> (DB.Connection -> IO a) -> IO a +withTransaction st action = withConnection st $ loop 100 100_000 where - loop :: Int -> Int -> IO a - loop t tLim = - DB.withImmediateTransaction db a `E.catch` \(e :: SQLError) -> + loop :: Int -> Int -> DB.Connection -> IO a + loop t tLim db = + DB.withImmediateTransaction db (action db) `E.catch` \(e :: SQLError) -> if tLim > t && DB.sqlError e == DB.ErrorBusy then do threadDelay t - loop (t * 9 `div` 8) (tLim - t) + loop (t * 9 `div` 8) (tLim - t) db else E.throwIO e instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where - createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = - liftIOEither $ - checkConstraint SEConnDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertRcvQueue_ dbConn q - insertRcvConnection_ dbConn q + createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId + createRcvConn st gVar cData q@RcvQueue {server} = + -- TODO if schema has to be restarted, this function can be refactored + -- to create connection first using createWithRandomId + liftIOEither . checkConstraint SEConnDuplicate . withTransaction st $ \db -> + getConnId_ db gVar cData >>= traverse (create db) + where + create :: DB.Connection -> ConnId -> IO ConnId + create db connId = do + upsertServer_ db server + insertRcvQueue_ db connId q + insertRcvConnection_ db cData {connId} q + pure connId - createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} q@SndQueue {server} = - liftIOEither $ - checkConstraint SEConnDuplicate $ - withTransaction dbConn $ do - upsertServer_ dbConn server - insertSndQueue_ dbConn q - insertSndConnection_ dbConn q + createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId + createSndConn st gVar cData q@SndQueue {server} = + -- TODO if schema has to be restarted, this function can be refactored + -- to create connection first using createWithRandomId + liftIOEither . checkConstraint SEConnDuplicate . withTransaction st $ \db -> + getConnId_ db gVar cData >>= traverse (create db) + where + create :: DB.Connection -> ConnId -> IO ConnId + create db connId = do + upsertServer_ db server + insertSndQueue_ db connId q + insertSndConnection_ db cData {connId} q + pure connId - getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn SQLiteStore {dbConn} connAlias = - liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias + getConn :: SQLiteStore -> ConnId -> m SomeConn + getConn st connId = + liftIOEither . withTransaction st $ \db -> + getConn_ db connId - getAllConnAliases :: SQLiteStore -> m [ConnAlias] - getAllConnAliases SQLiteStore {dbConn} = - liftIO $ do - r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] - return (concat r) + getAllConnIds :: SQLiteStore -> m [ConnId] + getAllConnIds st = + liftIO . withTransaction st $ \db -> + concat <$> (DB.query_ db "SELECT conn_alias FROM connections;" :: IO [[ConnId]]) getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn - getRcvConn SQLiteStore {dbConn} SMPServer {host, port} rcvId = - liftIOEither . withTransaction dbConn $ + getRcvConn st SMPServer {host, port} rcvId = + liftIOEither . withTransaction st $ \db -> DB.queryNamed - dbConn + db [sql| SELECT q.conn_alias FROM rcv_queues q @@ -169,47 +200,47 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] >>= \case - [Only connAlias] -> getConn_ dbConn connAlias + [Only connId] -> getConn_ db connId _ -> pure $ Left SEConnNotFound - deleteConn :: SQLiteStore -> ConnAlias -> m () - deleteConn SQLiteStore {dbConn} connAlias = - liftIO $ + deleteConn :: SQLiteStore -> ConnId -> m () + deleteConn st connId = + liftIO . withTransaction st $ \db -> DB.executeNamed - dbConn + db "DELETE FROM connections WHERE conn_alias = :conn_alias;" - [":conn_alias" := connAlias] + [":conn_alias" := connId] - upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m () - upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = - liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias >>= \case + upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m () + upgradeRcvConnToDuplex st connId sq@SndQueue {server} = + liftIOEither . withTransaction st $ \db -> + getConn_ db connId >>= \case Right (SomeConn _ RcvConnection {}) -> do - upsertServer_ dbConn server - insertSndQueue_ dbConn sq - updateConnWithSndQueue_ dbConn connAlias sq + upsertServer_ db server + insertSndQueue_ db connId sq + updateConnWithSndQueue_ db connId sq pure $ Right () Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound - upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m () - upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = - liftIOEither . withTransaction dbConn $ - getConn_ dbConn connAlias >>= \case + upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m () + upgradeSndConnToDuplex st connId rq@RcvQueue {server} = + liftIOEither . withTransaction st $ \db -> + getConn_ db connId >>= \case Right (SomeConn _ SndConnection {}) -> do - upsertServer_ dbConn server - insertRcvQueue_ dbConn rq - updateConnWithRcvQueue_ dbConn connAlias rq + upsertServer_ db server + insertRcvQueue_ db connId rq + updateConnWithRcvQueue_ db connId rq pure $ Right () Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c _ -> pure $ Left SEConnNotFound setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () - setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status = + setRcvQueueStatus st RcvQueue {rcvId, server = SMPServer {host, port}} status = -- ? throw error if queue does not exist? - liftIO $ + liftIO . withTransaction st $ \db -> DB.executeNamed - dbConn + db [sql| UPDATE rcv_queues SET status = :status @@ -218,11 +249,11 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () - setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = + setRcvQueueActive st RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = -- ? throw error if queue does not exist? - liftIO $ + liftIO . withTransaction st $ \db -> DB.executeNamed - dbConn + db [sql| UPDATE rcv_queues SET verify_key = :verify_key, status = :status @@ -236,11 +267,11 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto ] setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () - setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status = + setSndQueueStatus st SndQueue {sndId, server = SMPServer {host, port}} status = -- ? throw error if queue does not exist? - liftIO $ + liftIO . withTransaction st $ \db -> DB.executeNamed - dbConn + db [sql| UPDATE snd_queues SET status = :status @@ -248,82 +279,193 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) - updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} = - liftIO . withTransaction dbConn $ do - (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias + updateSignKey :: SQLiteStore -> SndQueue -> SignatureKey -> m () + updateSignKey st SndQueue {sndId, server = SMPServer {host, port}} signatureKey = + liftIO . withTransaction st $ \db -> + DB.executeNamed + db + [sql| + UPDATE snd_queues + SET sign_key = :sign_key + WHERE host = :host AND port = :port AND snd_id = :snd_id; + |] + [":sign_key" := signatureKey, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] + + createConfirmation :: SQLiteStore -> TVar ChaChaDRG -> NewConfirmation -> m ConfirmationId + createConfirmation st gVar NewConfirmation {connId, senderKey, senderConnInfo} = + liftIOEither . withTransaction st $ \db -> + createWithRandomId gVar $ \confirmationId -> + DB.execute + db + [sql| + INSERT INTO conn_confirmations + (confirmation_id, conn_alias, sender_key, sender_conn_info, accepted) VALUES (?, ?, ?, ?, 0); + |] + (confirmationId, connId, senderKey, senderConnInfo) + + acceptConfirmation :: SQLiteStore -> ConfirmationId -> ConnInfo -> m AcceptedConfirmation + acceptConfirmation st confirmationId ownConnInfo = + liftIOEither . withTransaction st $ \db -> do + DB.executeNamed + db + [sql| + UPDATE conn_confirmations + SET accepted = 1, + own_conn_info = :own_conn_info + WHERE confirmation_id = :confirmation_id; + |] + [ ":own_conn_info" := ownConnInfo, + ":confirmation_id" := confirmationId + ] + confirmation + <$> DB.query + db + [sql| + SELECT conn_alias, sender_key, sender_conn_info + FROM conn_confirmations + WHERE confirmation_id = ?; + |] + (Only confirmationId) + where + confirmation [(connId, senderKey, senderConnInfo)] = + Right $ AcceptedConfirmation {confirmationId, connId, senderKey, senderConnInfo, ownConnInfo} + confirmation _ = Left SEConfirmationNotFound + + getAcceptedConfirmation :: SQLiteStore -> ConnId -> m AcceptedConfirmation + getAcceptedConfirmation st connId = + liftIOEither . withTransaction st $ \db -> + confirmation + <$> DB.query + db + [sql| + SELECT confirmation_id, sender_key, sender_conn_info, own_conn_info + FROM conn_confirmations + WHERE conn_alias = ? AND accepted = 1; + |] + (Only connId) + where + confirmation [(confirmationId, senderKey, senderConnInfo, ownConnInfo)] = + Right $ AcceptedConfirmation {confirmationId, connId, senderKey, senderConnInfo, ownConnInfo} + confirmation _ = Left SEConfirmationNotFound + + removeConfirmations :: SQLiteStore -> ConnId -> m () + removeConfirmations st connId = + liftIO . withTransaction st $ \db -> + DB.executeNamed + db + [sql| + DELETE FROM conn_confirmations + WHERE conn_alias = :conn_alias; + |] + [":conn_alias" := connId] + + updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + updateRcvIds st connId = + liftIO . withTransaction st $ \db -> do + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ db connId let internalId = InternalId $ unId lastInternalId + 1 internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - updateLastIdsRcv_ dbConn connAlias internalId internalRcvId + updateLastIdsRcv_ db connId internalId internalRcvId pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m () - createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData = - liftIO . withTransaction dbConn $ do - insertRcvMsgBase_ dbConn connAlias rcvMsgData - insertRcvMsgDetails_ dbConn connAlias rcvMsgData - updateHashRcv_ dbConn connAlias rcvMsgData + createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m () + createRcvMsg st connId rcvMsgData = + liftIO . withTransaction st $ \db -> do + insertRcvMsgBase_ db connId rcvMsgData + insertRcvMsgDetails_ db connId rcvMsgData + updateHashRcv_ db connId rcvMsgData - updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) - updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} = - liftIO . withTransaction dbConn $ do - (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias + updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash) + updateSndIds st connId = + liftIO . withTransaction st $ \db -> do + (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ db connId let internalId = InternalId $ unId lastInternalId + 1 internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - updateLastIdsSnd_ dbConn connAlias internalId internalSndId + updateLastIdsSnd_ db connId internalId internalSndId pure (internalId, internalSndId, prevSndHash) - createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m () - createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData = - liftIO . withTransaction dbConn $ do - insertSndMsgBase_ dbConn connAlias sndMsgData - insertSndMsgDetails_ dbConn connAlias sndMsgData - updateHashSnd_ dbConn connAlias sndMsgData + createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m () + createSndMsg st connId sndMsgData = + liftIO . withTransaction st $ \db -> do + insertSndMsgBase_ db connId sndMsgData + insertSndMsgDetails_ db connId sndMsgData + updateHashSnd_ db connId sndMsgData - getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg - getMsg _st _connAlias _id = throwError SENotImplemented + updateSndMsgStatus :: SQLiteStore -> ConnId -> InternalId -> SndMsgStatus -> m () + updateSndMsgStatus st connId msgId msgStatus = + liftIO . withTransaction st $ \db -> + DB.executeNamed + db + [sql| + UPDATE snd_messages + SET snd_status = :snd_status + WHERE conn_alias = :conn_alias AND internal_id = :internal_id + |] + [ ":conn_alias" := connId, + ":internal_id" := msgId, + ":snd_status" := msgStatus + ] - createBcast :: SQLiteStore -> BroadcastId -> m () - createBcast SQLiteStore {dbConn} bId = - liftIOEither $ - checkConstraint SEBcastDuplicate $ - DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId) + getPendingMsgData :: SQLiteStore -> ConnId -> InternalId -> m (SndQueue, MsgBody) + getPendingMsgData st connId msgId = + liftIOEither . withTransaction st $ \db -> runExceptT $ do + sq <- ExceptT $ sndQueue <$> getSndQueueByConnAlias_ db connId + msgBody <- + ExceptT $ + sndMsgData + <$> DB.query + db + [sql| + SELECT m.msg_body + FROM messages m + JOIN snd_messages s ON s.conn_alias = m.conn_alias AND s.internal_id = m.internal_id + WHERE m.conn_alias = ? AND m.internal_id = ? + |] + (connId, msgId) + pure (sq, msgBody) + where + sndMsgData :: [Only MsgBody] -> Either StoreError MsgBody + sndMsgData [Only msgBody] = Right msgBody + sndMsgData _ = Left SEMsgNotFound + sndQueue :: Maybe SndQueue -> Either StoreError SndQueue + sndQueue = maybe (Left SEConnNotFound) Right - addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () - addBcastConn SQLiteStore {dbConn} bId connAlias = - liftIOEither . checkBroadcast dbConn bId $ - getConn_ dbConn connAlias >>= \case - Left _ -> pure $ Left SEConnNotFound - Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv - Right _ -> - checkConstraint SEConnDuplicate $ - DB.execute - dbConn - "INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);" - (bId, connAlias) + getPendingMsgs :: SQLiteStore -> ConnId -> m [PendingMsg] + getPendingMsgs st connId = + liftIO . withTransaction st $ \db -> + map (PendingMsg connId . fromOnly) + <$> DB.query db "SELECT internal_id FROM snd_messages WHERE conn_alias = ? AND snd_status = ?" (connId, SndMsgCreated) - removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m () - removeBcastConn SQLiteStore {dbConn} bId connAlias = - liftIOEither . checkBroadcast dbConn bId $ - bcastConnExists_ dbConn bId connAlias >>= \case - False -> pure $ Left SEConnNotFound - _ -> - Right - <$> DB.execute - dbConn - "DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;" - (bId, connAlias) + getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg + getMsg _st _connId _id = throwError SENotImplemented - deleteBcast :: SQLiteStore -> BroadcastId -> m () - deleteBcast SQLiteStore {dbConn} bId = - liftIOEither . checkBroadcast dbConn bId $ - Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId) + checkRcvMsg :: SQLiteStore -> ConnId -> InternalId -> m () + checkRcvMsg st connId msgId = + liftIOEither . withTransaction st $ \db -> + hasMsg + <$> DB.query + db + [sql| + SELECT conn_alias, internal_id + FROM rcv_messages + WHERE conn_alias = ? AND internal_id = ? + |] + (connId, msgId) + where + hasMsg :: [(ConnId, InternalId)] -> Either StoreError () + hasMsg r = if null r then Left SEMsgNotFound else Right () - getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias] - getBcast SQLiteStore {dbConn} bId = - liftIOEither . checkBroadcast dbConn bId $ - Right . map fromOnly - <$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId) + updateRcvMsgAck :: SQLiteStore -> ConnId -> InternalId -> m () + updateRcvMsgAck st connId msgId = + liftIO . withTransaction st $ \db -> do + DB.execute + db + [sql| + UPDATE rcv_messages + SET rcv_status = ?, ack_brocker_ts = datetime('now') + WHERE conn_alias = ? AND internal_id = ? + |] + (AcknowledgedToBroker, connId, msgId) -- * Auxiliary helpers @@ -337,7 +479,7 @@ deserializePort_ port = Just port instance ToField QueueStatus where toField = toField . show -instance FromField QueueStatus where fromField = fromFieldToReadable_ +instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack instance ToField InternalRcvId where toField (InternalRcvId x) = toField x @@ -359,13 +501,16 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP -fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a -fromFieldToReadable_ = \case +instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo + +instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP + +fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a +fromTextField_ fromText = \case f@(Field (SQLText t) _) -> - let str = T.unpack t - in case readMaybe str of - Just x -> Ok x - _ -> returnError ConversionFailed f ("invalid string: " <> str) + case fromText t of + Just x -> Ok x + _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) f -> returnError ConversionFailed f "expecting SQLText column type" {- ORMOLU_DISABLE -} @@ -397,49 +542,49 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do -- * createRcvConn helpers -insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO () -insertRcvQueue_ dbConn RcvQueue {..} = do +insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () +insertRcvQueue_ dbConn connId RcvQueue {..} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn [sql| INSERT INTO rcv_queues - ( host, port, rcv_id, conn_alias, rcv_private_key, snd_id, snd_key, decrypt_key, verify_key, status) + ( host, port, rcv_id, conn_alias, rcv_private_key, snd_id, decrypt_key, verify_key, status) VALUES - (:host,:port,:rcv_id,:conn_alias,:rcv_private_key,:snd_id,:snd_key,:decrypt_key,:verify_key,:status); + (:host,:port,:rcv_id,:conn_alias,:rcv_private_key,:snd_id,:decrypt_key,:verify_key,:status); |] [ ":host" := host server, ":port" := port_, ":rcv_id" := rcvId, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":rcv_private_key" := rcvPrivateKey, ":snd_id" := sndId, - ":snd_key" := sndKey, ":decrypt_key" := decryptKey, ":verify_key" := verifyKey, ":status" := status ] -insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO () -insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do +insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO () +insertRcvConnection_ dbConn ConnData {connId} RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn [sql| INSERT INTO connections - ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, - last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) + ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES - (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, - 0, 0, 0, 0, x'', x''); + (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, 0, 0, 0, 0, x'', x''); |] - [":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId] + [ ":conn_alias" := connId, + ":rcv_host" := host server, + ":rcv_port" := port_, + ":rcv_id" := rcvId + ] -- * createSndConn helpers -insertSndQueue_ :: DB.Connection -> SndQueue -> IO () -insertSndQueue_ dbConn SndQueue {..} = do +insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () +insertSndQueue_ dbConn connId SndQueue {..} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -452,85 +597,94 @@ insertSndQueue_ dbConn SndQueue {..} = do [ ":host" := host server, ":port" := port_, ":snd_id" := sndId, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":snd_private_key" := sndPrivateKey, ":encrypt_key" := encryptKey, ":sign_key" := signKey, ":status" := status ] -insertSndConnection_ :: DB.Connection -> SndQueue -> IO () -insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do +insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO () +insertSndConnection_ dbConn ConnData {connId} SndQueue {server, sndId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn [sql| INSERT INTO connections - ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, - last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) + ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES - (:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id, - 0, 0, 0, 0, x'', x''); + (:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id, 0, 0, 0, 0, x'', x''); |] - [":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId] + [ ":conn_alias" := connId, + ":snd_host" := host server, + ":snd_port" := port_, + ":snd_id" := sndId + ] -- * getConn helpers -getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn) -getConn_ dbConn connAlias = do - rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias - sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias - pure $ case (rQ, sQ) of - (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) - (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ) - (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> Left SEConnNotFound +getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getConn_ dbConn connId = + getConnData_ dbConn connId >>= \case + Nothing -> pure $ Left SEConnNotFound + Just connData -> do + rQ <- getRcvQueueByConnAlias_ dbConn connId + sQ <- getSndQueueByConnAlias_ dbConn connId + pure $ case (rQ, sQ) of + (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ) + (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ) + (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ) + _ -> Left SEConnNotFound -retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue) -retrieveRcvQueueByConnAlias_ dbConn connAlias = do - r <- - DB.queryNamed +getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData) +getConnData_ dbConn connId' = + connData + <$> DB.query dbConn "SELECT conn_alias FROM connections WHERE conn_alias = ?;" (Only connId') + where + connData [Only connId] = Just ConnData {connId} + connData _ = Nothing + +getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue) +getRcvQueueByConnAlias_ dbConn connId = + rcvQueue + <$> DB.query dbConn [sql| - SELECT - s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, - q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status + SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key, + q.snd_id, q.decrypt_key, q.verify_key, q.status FROM rcv_queues q INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.conn_alias = :conn_alias; + WHERE q.conn_alias = ?; |] - [":conn_alias" := connAlias] - case r of - [(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do + (Only connId) + where + rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, decryptKey, verifyKey, status)] = let srv = SMPServer host (deserializePort_ port) keyHash - return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status - _ -> return Nothing + in Just $ RcvQueue srv rcvId rcvPrivateKey sndId decryptKey verifyKey status + rcvQueue _ = Nothing -retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue) -retrieveSndQueueByConnAlias_ dbConn connAlias = do - r <- - DB.queryNamed +getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue) +getSndQueueByConnAlias_ dbConn connId = + sndQueue + <$> DB.query dbConn [sql| - SELECT - s.key_hash, q.host, q.port, q.snd_id, q.conn_alias, - q.snd_private_key, q.encrypt_key, q.sign_key, q.status + SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status FROM snd_queues q INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.conn_alias = :conn_alias; + WHERE q.conn_alias = ?; |] - [":conn_alias" := connAlias] - case r of - [(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do + (Only connId) + where + sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] = let srv = SMPServer host (deserializePort_ port) keyHash - return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status - _ -> return Nothing + in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status + sndQueue _ = Nothing -- * upgradeRcvConnToDuplex helpers -updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () -updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do +updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO () +updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -539,12 +693,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id WHERE conn_alias = :conn_alias; |] - [":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias] + [":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId] -- * upgradeSndConnToDuplex helpers -updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () -updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do +updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () +updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server DB.executeNamed dbConn @@ -553,12 +707,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id WHERE conn_alias = :conn_alias; |] - [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias] + [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId] -- * updateRcvIds helpers -retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) -retrieveLastIdsAndHashRcv_ dbConn connAlias = do +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connId = do [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.queryNamed dbConn @@ -567,11 +721,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do FROM connections WHERE conn_alias = :conn_alias; |] - [":conn_alias" := connAlias] + [":conn_alias" := connId] return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) -updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () -updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = +updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = DB.executeNamed dbConn [sql| @@ -582,30 +736,32 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = |] [ ":last_internal_msg_id" := newInternalId, ":last_internal_rcv_msg_id" := newInternalRcvId, - ":conn_alias" := connAlias + ":conn_alias" := connId ] -- * createRcvMsg helpers -insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do +insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgBody, internalRcvId} = do + let MsgMeta {recipient = (internalId, internalTs)} = msgMeta DB.executeNamed dbConn [sql| INSERT INTO messages - ( conn_alias, internal_id, internal_ts, internal_rcv_id, internal_snd_id, body) + ( conn_alias, internal_id, internal_ts, internal_rcv_id, internal_snd_id, body, msg_body) VALUES - (:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body); + (:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL, '',:msg_body); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_id" := internalId, ":internal_ts" := internalTs, ":internal_rcv_id" := internalRcvId, - ":body" := decodeUtf8 msgBody + ":msg_body" := msgBody ] -insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = +insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connId RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash} = do + let MsgMeta {integrity, recipient, sender, broker} = msgMeta DB.executeNamed dbConn [sql| @@ -618,21 +774,21 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = :broker_id,:broker_ts,:rcv_status, NULL, NULL, :internal_hash,:external_prev_snd_hash,:integrity); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_rcv_id" := internalRcvId, - ":internal_id" := internalId, - ":external_snd_id" := fst senderMeta, - ":external_snd_ts" := snd senderMeta, - ":broker_id" := fst brokerMeta, - ":broker_ts" := snd brokerMeta, + ":internal_id" := fst recipient, + ":external_snd_id" := fst sender, + ":external_snd_ts" := snd sender, + ":broker_id" := fst broker, + ":broker_ts" := snd broker, ":rcv_status" := Received, ":internal_hash" := internalHash, ":external_prev_snd_hash" := externalPrevSndHash, - ":integrity" := msgIntegrity + ":integrity" := integrity ] -updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () -updateHashRcv_ dbConn connAlias RcvMsgData {..} = +updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +updateHashRcv_ dbConn connId RcvMsgData {msgMeta, internalHash, internalRcvId} = DB.executeNamed dbConn -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved @@ -643,16 +799,16 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} = WHERE conn_alias = :conn_alias AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; |] - [ ":last_external_snd_msg_id" := fst senderMeta, + [ ":last_external_snd_msg_id" := fst (sender msgMeta), ":last_rcv_msg_hash" := internalHash, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":last_internal_rcv_msg_id" := internalRcvId ] -- * updateSndIds helpers -retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash) -retrieveLastIdsAndHashSnd_ dbConn connAlias = do +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash) +retrieveLastIdsAndHashSnd_ dbConn connId = do [(lastInternalId, lastInternalSndId, lastSndHash)] <- DB.queryNamed dbConn @@ -661,11 +817,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do FROM connections WHERE conn_alias = :conn_alias; |] - [":conn_alias" := connAlias] + [":conn_alias" := connId] return (lastInternalId, lastInternalSndId, lastSndHash) -updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () -updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = +updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = DB.executeNamed dbConn [sql| @@ -676,47 +832,48 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = |] [ ":last_internal_msg_id" := newInternalId, ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_alias" := connAlias + ":conn_alias" := connId ] -- * createSndMsg helpers -insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do +insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgBase_ dbConn connId SndMsgData {..} = do DB.executeNamed dbConn [sql| INSERT INTO messages - ( conn_alias, internal_id, internal_ts, internal_rcv_id, internal_snd_id, body) + ( conn_alias, internal_id, internal_ts, internal_rcv_id, internal_snd_id, body, msg_body) VALUES - (:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body); + (:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id, '',:msg_body); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_id" := internalId, ":internal_ts" := internalTs, ":internal_snd_id" := internalSndId, - ":body" := decodeUtf8 msgBody + ":msg_body" := msgBody ] -insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = +insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connId SndMsgData {..} = DB.executeNamed dbConn [sql| INSERT INTO snd_messages - ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash) + ( conn_alias, internal_snd_id, internal_id, snd_status, sent_ts, delivered_ts, internal_hash, previous_msg_hash) VALUES - (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash); + (:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash,:previous_msg_hash); |] - [ ":conn_alias" := connAlias, + [ ":conn_alias" := connId, ":internal_snd_id" := internalSndId, ":internal_id" := internalId, - ":snd_status" := Created, - ":internal_hash" := internalHash + ":snd_status" := SndMsgCreated, + ":internal_hash" := internalHash, + ":previous_msg_hash" := previousMsgHash ] -updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () -updateHashSnd_ dbConn connAlias SndMsgData {..} = +updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +updateHashSnd_ dbConn connId SndMsgData {..} = DB.executeNamed dbConn -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved @@ -727,34 +884,39 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} = AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] [ ":last_snd_msg_hash" := internalHash, - ":conn_alias" := connAlias, + ":conn_alias" := connId, ":last_internal_snd_msg_id" := internalSndId ] --- * Broadcast helpers +-- create record with a random ID -checkBroadcast :: DB.Connection -> BroadcastId -> IO (Either StoreError a) -> IO (Either StoreError a) -checkBroadcast dbConn bId action = - withTransaction dbConn $ do - ok <- bcastExists_ dbConn bId - if ok then action else pure $ Left SEBcastNotFound +getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId) +getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn +getConnId_ _ _ ConnData {connId} = pure $ Right connId -bcastExists_ :: DB.Connection -> BroadcastId -> IO Bool -bcastExists_ dbConn bId = not . null <$> queryBcast +getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString) +getUniqueRandomId gVar get = tryGet 3 where - queryBcast :: IO [Only BroadcastId] - queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId) + tryGet :: Int -> IO (Either StoreError ByteString) + tryGet 0 = pure $ Left SEUniqueID + tryGet n = do + id' <- randomId gVar 12 + get id' >>= \case + Nothing -> pure $ Right id' + Just _ -> tryGet (n - 1) -bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool -bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn +createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString) +createWithRandomId gVar create = tryCreate 3 where - queryBcastConn :: IO [(BroadcastId, ConnAlias)] - queryBcastConn = - DB.query - dbConn - [sql| - SELECT broadcast_id, conn_alias - FROM broadcast_connections - WHERE broadcast_id = ? AND conn_alias = ?; - |] - (bId, connAlias) + tryCreate :: Int -> IO (Either StoreError ByteString) + tryCreate 0 = pure $ Left SEUniqueID + tryCreate n = do + id' <- randomId gVar 12 + E.try (create id') >>= \case + Right _ -> pure $ Right id' + Left e + | DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1) + | otherwise -> pure . Left . SEInternal $ bshow e + +randomId :: TVar ChaChaDRG -> Int -> IO ByteString +randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index b022ea5bf..4e6128493 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -7,7 +7,8 @@ {-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.Store.SQLite.Migrations - ( app, + ( Migration (..), + app, initialize, get, run, diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 140b33f88..87b340aae 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -53,9 +53,10 @@ import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Except import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Data.Maybe +import Data.Maybe (fromMaybe) import Network.Socket (ServiceName) import Numeric.Natural import Simplex.Messaging.Agent.Protocol (SMPServer (..)) @@ -64,7 +65,7 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Transport (ATransport (..), TCP, THandle (..), TProxy, Transport (..), TransportError, clientHandshake, runTransportClient) import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, liftError, raceAny_) -import System.Timeout +import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- @@ -98,6 +99,9 @@ data SMPClientConfig = SMPClientConfig tcpTimeout :: Int, -- | period for SMP ping commands (microseconds) smpPing :: Int, + -- | SMP transport block size, Nothing - the block size will be set by the server. + -- Allowed sizes are 4, 8, 16, 32, 64 KiB (* 1024 bytes). + smpBlockSize :: Maybe Int, -- | estimated maximum size of SMP command excluding message body, -- determines the maximum allowed message size smpCommandSize :: Int @@ -111,6 +115,7 @@ smpDefaultConfig = defaultTransport = ("5223", transport @TCP), tcpTimeout = 4_000_000, smpPing = 30_000_000, + smpBlockSize = Just 8192, smpCommandSize = 256 } @@ -127,7 +132,7 @@ type Response = Either SMPClientError Cmd -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient) -getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ disconnected = +getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing, smpBlockSize} msgQ disconnected = atomically mkSMPClient >>= runClient useTransport where mkSMPClient :: STM SMPClient @@ -172,7 +177,7 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ dis client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError Int) -> c -> IO () client _ c thVar h = - runExceptT (clientHandshake h $ keyHash smpServer) >>= \case + runExceptT (clientHandshake h smpBlockSize $ keyHash smpServer) >>= \case Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e Right th -> do atomically $ do @@ -195,22 +200,27 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ dis process :: SMPClient -> IO () process SMPClient {rcvQ, sentCommands} = forever $ do (_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ - cs <- readTVarIO sentCommands - case M.lookup corrId cs of - Nothing -> do - case respOrErr of - Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) - -- TODO send everything else to errQ and log in agent - _ -> return () - Just Request {queueId, responseVar} -> atomically $ do - modifyTVar sentCommands $ M.delete corrId - putTMVar responseVar $ - if queueId == qId - then case respOrErr of - Left e -> Left $ SMPResponseError e - Right (Cmd _ (ERR e)) -> Left $ SMPServerError e - Right r -> Right r - else Left SMPUnexpectedResponse + if B.null $ bs corrId + then sendMsg qId respOrErr + else do + cs <- readTVarIO sentCommands + case M.lookup corrId cs of + Nothing -> sendMsg qId respOrErr + Just Request {queueId, responseVar} -> atomically $ do + modifyTVar sentCommands $ M.delete corrId + putTMVar responseVar $ + if queueId == qId + then case respOrErr of + Left e -> Left $ SMPResponseError e + Right (Cmd _ (ERR e)) -> Left $ SMPServerError e + Right r -> Right r + else Left SMPUnexpectedResponse + + sendMsg :: QueueId -> Either ErrorType Cmd -> IO () + sendMsg qId = \case + Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) + -- TODO send everything else to errQ and log in agent + _ -> return () -- | Disconnects SMP client from the server and terminates client threads. closeSMPClient :: SMPClient -> IO () diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index c56161712..10e7ad5b5 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -20,18 +20,20 @@ -- . module Simplex.Messaging.Crypto ( -- * RSA keys - PrivateKey (rsaPrivateKey), - SafePrivateKey, -- constructor is not exported + PrivateKey (rsaPrivateKey, publicKey), + SafePrivateKey (..), -- constructor is not exported FullPrivateKey (..), + APrivateKey (..), PublicKey (..), SafeKeyPair, FullKeyPair, KeyHash (..), generateKeyPair, - publicKey, + publicKey', publicKeySize, validKeySize, safePrivateKey, + removePublicKey, -- * E2E hybrid encryption scheme encrypt, @@ -121,6 +123,9 @@ newtype SafePrivateKey = SafePrivateKey {unPrivateKey :: R.PrivateKey} deriving -- | A newtype of 'Crypto.PubKey.RSA.PrivateKey' (with PublicKey inside). newtype FullPrivateKey = FullPrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) +-- | A newtype of 'Crypto.PubKey.RSA.PrivateKey' (PublicKey may be inside). +newtype APrivateKey = APrivateKey {unPrivateKey :: R.PrivateKey} deriving (Eq, Show) + -- | Type-class used for both private key types: SafePrivateKey and FullPrivateKey. class PrivateKey k where -- unwraps 'Crypto.PubKey.RSA.PrivateKey' @@ -132,16 +137,36 @@ class PrivateKey k where -- smart constructor removing public key from SafePrivateKey but keeping it in FullPrivateKey mkPrivateKey :: R.PrivateKey -> k + -- extracts public key from private key + publicKey :: k -> Maybe PublicKey + +-- | Remove public key exponent from APrivateKey. +removePublicKey :: APrivateKey -> APrivateKey +removePublicKey (APrivateKey R.PrivateKey {private_pub = k, private_d}) = + APrivateKey $ unPrivateKey (safePrivateKey (R.public_size k, R.public_n k, private_d) :: SafePrivateKey) + instance PrivateKey SafePrivateKey where rsaPrivateKey = unPrivateKey _privateKey = SafePrivateKey mkPrivateKey R.PrivateKey {private_pub = k, private_d} = safePrivateKey (R.public_size k, R.public_n k, private_d) + publicKey _ = Nothing instance PrivateKey FullPrivateKey where rsaPrivateKey = unPrivateKey _privateKey = FullPrivateKey mkPrivateKey = FullPrivateKey + publicKey = Just . PublicKey . R.private_pub . rsaPrivateKey + +instance PrivateKey APrivateKey where + rsaPrivateKey = unPrivateKey + _privateKey = APrivateKey + mkPrivateKey = APrivateKey + publicKey pk = + let k = R.private_pub $ rsaPrivateKey pk + in if R.public_e k == 0 + then Nothing + else Just $ PublicKey k instance IsString FullPrivateKey where fromString = parseString (decode >=> decodePrivKey) @@ -151,10 +176,14 @@ instance IsString PublicKey where instance ToField SafePrivateKey where toField = toField . encodePrivKey +instance ToField APrivateKey where toField = toField . encodePrivKey + instance ToField PublicKey where toField = toField . encodePubKey instance FromField SafePrivateKey where fromField = blobFieldParser binaryPrivKeyP +instance FromField APrivateKey where fromField = blobFieldParser binaryPrivKeyP + instance FromField PublicKey where fromField = blobFieldParser binaryPubKeyP -- | Tuple of RSA 'PublicKey' and 'PrivateKey'. @@ -217,8 +246,8 @@ generateKeyPair size = loop privateKeySize :: PrivateKey k => k -> Int privateKeySize = R.public_size . R.private_pub . rsaPrivateKey -publicKey :: FullPrivateKey -> PublicKey -publicKey = PublicKey . R.private_pub . rsaPrivateKey +publicKey' :: FullPrivateKey -> PublicKey +publicKey' = PublicKey . R.private_pub . rsaPrivateKey publicKeySize :: PublicKey -> Int publicKeySize = R.public_size . rsaPublicKey @@ -227,6 +256,7 @@ validKeySize :: Int -> Bool validKeySize = \case 128 -> True 256 -> True + 384 -> True 512 -> True _ -> False diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 2b9522e3e..5e7741c29 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -30,7 +30,7 @@ base64StringP = do pure $ str <> pad tsISO8601P :: Parser UTCTime -tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ') +tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd parse :: Parser a -> e -> (ByteString -> Either e a) parse parser err = first (const err) . parseAll parser @@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack) parseRead1 :: Read a => Parser a -parseRead1 = parseRead $ A.takeTill (== ' ') +parseRead1 = parseRead $ A.takeTill wordEnd parseRead2 :: Read a => Parser a parseRead2 = parseRead $ do - w1 <- A.takeTill (== ' ') <* A.char ' ' - w2 <- A.takeTill (== ' ') + w1 <- A.takeTill wordEnd <* A.char ' ' + w2 <- A.takeTill wordEnd pure $ w1 <> " " <> w2 +wordEnd :: Char -> Bool +wordEnd c = c == ' ' || c == '\n' + parseString :: (ByteString -> Either String a) -> (String -> a) parseString p = either error id . p . B.pack diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 3d438d540..85d6e8369 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -192,6 +192,8 @@ data ErrorType CMD CommandError | -- | command authorization error - bad signature or non-existing SMP queue AUTH + | -- | SMP queue capacity is exceeded on the server + QUOTA | -- | ACK command is sent without message to be acknowledged NO_MSG | -- | internal server error diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 211317cc6..8c475f9d1 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -90,7 +90,8 @@ runSMPServerBlocking started cfg@ServerConfig {transports} = do runClient :: (Transport c, MonadUnliftIO m, MonadReader Env m) => TProxy c -> c -> m () runClient _ h = do keyPair <- asks serverKeyPair - liftIO (runExceptT $ serverHandshake h keyPair) >>= \case + ServerConfig {blockSize} <- asks config + liftIO (runExceptT $ serverHandshake h blockSize keyPair) >>= \case Right th -> runClientTransport th Left _ -> pure () @@ -157,6 +158,7 @@ verifyTransmission (sig, t@(corrId, queueId, cmd)) = do cryptoVerify $ case sigLen of 128 -> dummyKey128 256 -> dummyKey256 + 384 -> dummyKey384 512 -> dummyKey512 _ -> dummyKey256 sigLen = B.length $ C.unSignature sig @@ -169,6 +171,9 @@ dummyKey128 = "MIIBIDANBgkqhkiG9w0BAQEFAAOCAQ0AMIIBCAKBgQC2oeA7s4roXN5K2N6022I1/ dummyKey256 :: C.PublicKey dummyKey256 = "MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAxwmTvaqmdTbkfUGNi8Yu0L/T4cxuOlQlx3zGZ9X9Qx0+oZjknWK+QHrdWTcpS+zH4Hi7fP6kanOQoQ90Hj6Ghl57VU1GEdUPywSw4i1/7t0Wv9uT9Q2ktHp2rqVo3xkC9IVIpL7EZAxdRviIN2OsOB3g4a/F1ZpjxcAaZeOMUugiAX1+GtkLuE0Xn4neYjCaOghLxQTdhybN70VtnkiQLx/X9NjkDIl/spYGm3tQFMyYKkP6IWoEpj0926hJ0fmlmhy8tAOhlZsb/baW5cgkEZ3E9jVVrySCgQzoLQgma610FIISRpRJbSyv26jU7MkMxiyuBiDaFOORkXFttoKbtQKBgEbDS9II2brsz+vfI7uP8atFcawkE52cx4M1UWQhqb1H3tBiRl+qO+dMq1pPQF2bW7dlZAWYzS4W/367bTAuALHBDGB8xi1P4Njhh9vaOgTvuqrHG9NJQ85BLy0qGw8rjIWSIXVmVpfrXFJ8po5l04UE258Ll2yocv3QRQmddQW9" +dummyKey384 :: C.PublicKey +dummyKey384 = "MIICITANBgkqhkiG9w0BAQEFAAOCAg4AMIICCQKCAYEAthExp77lSFBMB0RedjgKIU+oNH5lMGdMqDCG0E5Ly7X49rFpfDMMN08GDIgvzg9kcwV3ScbPcjUE19wmAShX9f9k3w38KM3wmIBKSiuCREQl0V3xAYp1SYwiAkMNSSwxuIkDEeSOR56WdEcZvqbB4lY9MQlUv70KriPDxZaqKCTKslUezXHQuYPQX6eMnGFK7hxz5Kl5MajV52d+5iXsa8CA+m/e1KVnbelCO+xhN89xG8ALt0CJ9k5Wwo3myLgXi4dmNankCmg8jkh+7y2ywkzxMwH1JydDtV/FLzkbZsbPR2w93TNrTq1RJOuqMyh0VtdBSpxNW/Ft988TkkX2BAWzx82INw7W6/QbHGNtHNB995R4sgeYy8QbEpNGBhQnfQh7yRWygLTVXWKApQzzfCeIoDDWUS7dMv/zXoasAnpDBj+6UhHv3BHrps7kBvRyZQ2d/nUuAqiGd43ljJ++n6vNyFLgZoiV7HLia/FOGMkdt7j92CNmFHxiT6Xl7kRHAoGBAPNoWny2O7LBxzAKMLmQVHBAiKp6RMx+7URvtQDHDHPaZ7F3MvtvmYWwGzund3cQFAaV1EkJoYeI3YRuj6xdXgMyMaP54On++btArb6jUtZuvlC98qE8dEEHQNh+7TsCiMU+ivbeKFxS9A/B7OVedoMnPoJWhatbA9zB/6L1GNPh" + dummyKey512 :: C.PublicKey dummyKey512 = "MIICoDANBgkqhkiG9w0BAQEFAAOCAo0AMIICiAKCAgEArkCY9DuverJ4mmzDektv9aZMFyeRV46WZK9NsOBKEc+1ncqMs+LhLti9asKNgUBRbNzmbOe0NYYftrUpwnATaenggkTFxxbJ4JGJuGYbsEdFWkXSvrbWGtM8YUmn5RkAGme12xQ89bSM4VoJAGnrYPHwmcQd+KYCPZvTUsxaxgrJTX65ejHN9BsAn8XtGViOtHTDJO9yUMD2WrJvd7wnNa+0ugEteDLzMU++xS98VC+uA1vfauUqi3yXVchdfrLdVUuM+JE0gUEXCgzjuHkaoHiaGNiGhdPYoAJJdOKQOIHAKdk7Th6OPhirPhc9XYNB4O8JDthKhNtfokvFIFlC4QBRzJhpLIENaEBDt08WmgpOnecZB/CuxkqqOrNa8j5K5jNrtXAI67W46VEC2jeQy/gZwb64Zit2A4D00xXzGbQTPGj4ehcEMhLx5LSCygViEf0w0tN3c3TEyUcgPzvECd2ZVpQLr9Z4a07Ebr+YSuxcHhjg4Rg1VyJyOTTvaCBGm5X2B3+tI4NUttmikIHOYpBnsLmHY2BgfH2KcrIsDyAhInXmTFr/L2+erFarUnlfATd2L8Ti43TNHDedO6k6jI5Gyi62yPwjqPLEIIK8l+pIeNfHJ3pPmjhHBfzFcQLMMMXffHWNK8kWklrQXK+4j4HiPcTBvlO1FEtG9nEIZhUCgYA4a6WtI2k5YNli1C89GY5rGUY7RP71T6RWri/D3Lz9T7GvU+FemAyYmsvCQwqijUOur0uLvwSP8VdxpSUcrjJJSWur2hrPWzWlu0XbNaeizxpFeKbQP+zSrWJ1z8RwfAeUjShxt8q1TuqGqY10wQyp3nyiTGvS+KwZVj5h5qx8NQ==" @@ -292,16 +297,19 @@ client clnt@Client {subscriptions, rcvQ, sndQ} Server {subscribedQ} = QueueActive -> do ms <- asks msgStore msg <- mkMessage + quota <- asks $ msgQueueQuota . config atomically $ do - q <- getMsgQueue ms (recipientId qr) - writeMsg q msg - return ok + q <- getMsgQueue ms (recipientId qr) quota + isFull q >>= \case + False -> writeMsg q msg $> ok + True -> pure $ err QUOTA deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m Transmission deliverMessage tryPeek rId = \case Sub {subThread = NoSub} -> do ms <- asks msgStore - q <- atomically $ getMsgQueue ms rId + quota <- asks $ msgQueueQuota . config + q <- atomically $ getMsgQueue ms rId quota atomically (tryPeek q) >>= \case Nothing -> forkSub q $> ok Just msg -> atomically setDelivered $> mkResp corrId rId (msgCmd msg) diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 61873af27..5c397096b 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -25,9 +25,11 @@ import UnliftIO.STM data ServerConfig = ServerConfig { transports :: [(ServiceName, ATransport)], tbqSize :: Natural, + msgQueueQuota :: Natural, queueIdBytes :: Int, msgIdBytes :: Int, storeLog :: Maybe (StoreLog 'ReadMode), + blockSize :: Int, serverPrivateKey :: C.FullPrivateKey -- serverId :: ByteString } @@ -86,7 +88,7 @@ newEnv config = do idsDrg <- drgNew >>= newTVarIO s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig) let pk = serverPrivateKey config - serverKeyPair = (C.publicKey pk, pk) + serverKeyPair = (C.publicKey' pk, pk) return Env {config, server, queueStore, msgStore, idsDrg, serverKeyPair, storeLog = s'} where restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode) diff --git a/src/Simplex/Messaging/Server/MsgStore.hs b/src/Simplex/Messaging/Server/MsgStore.hs index e2cb8791a..3d729af60 100644 --- a/src/Simplex/Messaging/Server/MsgStore.hs +++ b/src/Simplex/Messaging/Server/MsgStore.hs @@ -3,6 +3,7 @@ module Simplex.Messaging.Server.MsgStore where import Data.Time.Clock +import Numeric.Natural import Simplex.Messaging.Protocol (Encoded, MsgBody, RecipientId) data Message = Message @@ -12,10 +13,11 @@ data Message = Message } class MonadMsgStore s q m | s -> q where - getMsgQueue :: s -> RecipientId -> m q + getMsgQueue :: s -> RecipientId -> Natural -> m q delMsgQueue :: s -> RecipientId -> m () class MonadMsgQueue q m where + isFull :: q -> m Bool writeMsg :: q -> Message -> m () -- non blocking tryPeekMsg :: q -> m (Maybe Message) -- non blocking peekMsg :: q -> m Message -- blocking diff --git a/src/Simplex/Messaging/Server/MsgStore/STM.hs b/src/Simplex/Messaging/Server/MsgStore/STM.hs index f5b0e670f..6d0fb63a0 100644 --- a/src/Simplex/Messaging/Server/MsgStore/STM.hs +++ b/src/Simplex/Messaging/Server/MsgStore/STM.hs @@ -8,11 +8,12 @@ module Simplex.Messaging.Server.MsgStore.STM where import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Numeric.Natural import Simplex.Messaging.Protocol (RecipientId) import Simplex.Messaging.Server.MsgStore import UnliftIO.STM -newtype MsgQueue = MsgQueue {msgQueue :: TQueue Message} +newtype MsgQueue = MsgQueue {msgQueue :: TBQueue Message} newtype MsgStoreData = MsgStoreData {messages :: Map RecipientId MsgQueue} @@ -22,13 +23,13 @@ newMsgStore :: STM STMMsgStore newMsgStore = newTVar $ MsgStoreData M.empty instance MonadMsgStore STMMsgStore MsgQueue STM where - getMsgQueue :: STMMsgStore -> RecipientId -> STM MsgQueue - getMsgQueue store rId = do + getMsgQueue :: STMMsgStore -> RecipientId -> Natural -> STM MsgQueue + getMsgQueue store rId quota = do m <- messages <$> readTVar store maybe (newQ m) return $ M.lookup rId m where newQ m' = do - q <- MsgQueue <$> newTQueue + q <- MsgQueue <$> newTBQueue quota writeTVar store . MsgStoreData $ M.insert rId q m' return q @@ -37,15 +38,18 @@ instance MonadMsgStore STMMsgStore MsgQueue STM where modifyTVar store $ MsgStoreData . M.delete rId . messages instance MonadMsgQueue MsgQueue STM where + isFull :: MsgQueue -> STM Bool + isFull = isFullTBQueue . msgQueue + writeMsg :: MsgQueue -> Message -> STM () - writeMsg = writeTQueue . msgQueue + writeMsg = writeTBQueue . msgQueue tryPeekMsg :: MsgQueue -> STM (Maybe Message) - tryPeekMsg = tryPeekTQueue . msgQueue + tryPeekMsg = tryPeekTBQueue . msgQueue peekMsg :: MsgQueue -> STM Message - peekMsg = peekTQueue . msgQueue + peekMsg = peekTBQueue . msgQueue -- atomic delete (== read) last and peek next message if available tryDelPeekMsg :: MsgQueue -> STM (Maybe Message) - tryDelPeekMsg (MsgQueue q) = tryReadTQueue q >> tryPeekTQueue q + tryDelPeekMsg (MsgQueue q) = tryReadTBQueue q >> tryPeekTBQueue q diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index 18e05260b..d8869f436 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -63,6 +63,7 @@ import Data.ByteArray (xor) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) +import Data.Maybe(fromMaybe) import Data.Set (Set) import qualified Data.Set as S import Data.Word (Word32) @@ -340,21 +341,21 @@ makeNextIV SessionKey {baseIV, counter} = atomically $ do -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -- -- The numbers in function names refer to the steps in the document. -serverHandshake :: forall c. Transport c => c -> C.FullKeyPair -> ExceptT TransportError IO (THandle c) -serverHandshake c (k, pk) = do +serverHandshake :: forall c. Transport c => c -> Int -> C.FullKeyPair -> ExceptT TransportError IO (THandle c) +serverHandshake c srvBlockSize (k, pk) = do + checkValidBlockSize srvBlockSize liftIO sendHeaderAndPublicKey_1 encryptedKeys <- receiveEncryptedKeys_4 - -- TODO server currently ignores blockSize returned by the client - -- this is reserved for future support of streams - ClientHandshake {blockSize = _, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys - th <- liftIO $ transportHandle c rcvKey sndKey transportBlockSize -- keys are swapped here + ClientHandshake {blockSize, sndKey, rcvKey} <- decryptParseKeys_5 encryptedKeys + checkValidBlockSize blockSize + th <- liftIO $ transportHandle c rcvKey sndKey blockSize -- keys are swapped here sendWelcome_6 th pure th where sendHeaderAndPublicKey_1 :: IO () sendHeaderAndPublicKey_1 = do let sKey = C.encodePubKey k - header = ServerHeader {blockSize = transportBlockSize, keySize = B.length sKey} + header = ServerHeader {blockSize = srvBlockSize, keySize = B.length sKey} cPut c $ binaryServerHeader header cPut c sKey receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString @@ -374,13 +375,14 @@ serverHandshake c (k, pk) = do -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a -- -- The numbers in function names refer to the steps in the document. -clientHandshake :: forall c. Transport c => c -> Maybe C.KeyHash -> ExceptT TransportError IO (THandle c) -clientHandshake c keyHash = do +clientHandshake :: forall c. Transport c => c -> Maybe Int -> Maybe C.KeyHash -> ExceptT TransportError IO (THandle c) +clientHandshake c blkSize_ keyHash = do + mapM_ checkValidBlockSize blkSize_ (k, blkSize) <- getHeaderAndPublicKey_1_2 - -- TODO currently client always uses the blkSize returned by the server - keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize - sendEncryptedKeys_4 k keys - th <- liftIO $ transportHandle c sndKey rcvKey blkSize + let clientBlkSize = fromMaybe blkSize blkSize_ + chs@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 clientBlkSize + sendEncryptedKeys_4 k chs + th <- liftIO $ transportHandle c sndKey rcvKey clientBlkSize getWelcome_6 th >>= checkVersion pure th where @@ -388,8 +390,7 @@ clientHandshake c keyHash = do getHeaderAndPublicKey_1_2 = do header <- liftIO (cGet c serverHeaderSize) ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header - when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $ - throwError $ TEHandshake HEADER + checkValidBlockSize blockSize s <- liftIO $ cGet c keySize maybe (pure ()) (validateKeyHash_2 s) keyHash key <- liftEither $ parseKey s @@ -408,8 +409,8 @@ clientHandshake c keyHash = do baseIV <- C.randomIV pure SessionKey {aesKey, baseIV, counter = undefined} sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO () - sendEncryptedKeys_4 k keys = - liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys) + sendEncryptedKeys_4 k chs = + liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake chs) >>= liftIO . cPut c getWelcome_6 :: THandle c -> ExceptT TransportError IO SMPVersion getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th @@ -420,17 +421,18 @@ clientHandshake c keyHash = do when (major smpVersion > major currentSMPVersion) . throwE $ TEHandshake MAJOR_VERSION +checkValidBlockSize :: Int -> ExceptT TransportError IO () +checkValidBlockSize blkSize = + when (blkSize `notElem` transportBlockSizes) . throwError $ TEHandshake HEADER + data ServerHeader = ServerHeader {blockSize :: Int, keySize :: Int} deriving (Eq, Show) binaryRsaTransport :: Int binaryRsaTransport = 0 -transportBlockSize :: Int -transportBlockSize = 4096 - -maxTransportBlockSize :: Int -maxTransportBlockSize = 65536 +transportBlockSizes :: [Int] +transportBlockSizes = map (* 1024) [4, 8, 16, 32, 64] serverHeaderSize :: Int serverHeaderSize = 8 diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 2800e521e..5bd05c4a9 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -50,3 +50,6 @@ liftError f = liftEitherError f . runExceptT liftEitherError :: (MonadIO m, MonadError e' m) => (e -> e') -> IO (Either e a) -> m a liftEitherError f a = liftIOEither (first f <$> a) + +tryError :: MonadError e m => m a -> m (Either e a) +tryError action = (Right <$> action) `catchError` (pure . Left) diff --git a/stack.yaml b/stack.yaml index ae97d2a94..70267dd80 100644 --- a/stack.yaml +++ b/stack.yaml @@ -17,7 +17,7 @@ # # resolver: ./custom-snapshot.yaml # resolver: https://example.com/snapshots/2018-01-01.yaml -resolver: lts-17.12 +resolver: lts-18.0 # User packages to be built. # Various formats can be used as shown in the example below. diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index a3d9d184f..545bcece4 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -2,84 +2,87 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE PostfixOperators #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} +{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} -module AgentTests where +module AgentTests (agentTests) where +import AgentTests.FunctionalAPITests (functionalAPITests) import AgentTests.SQLiteTests (storeTests) import Control.Concurrent import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import SMPAgentClient +import SMPClient (testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..)) +import System.Directory (removeFile) import System.Timeout import Test.Hspec agentTests :: ATransport -> Spec agentTests (ATransport t) = do + describe "Functional API" $ functionalAPITests (ATransport t) describe "SQLite store" storeTests describe "SMP agent protocol syntax" $ syntaxTests t describe "Establishing duplex connection" do it "should connect via one server and one agent" $ smpAgentTest2_1_1 $ testDuplexConnection t + it "should connect via one server and one agent (random IDs)" $ + smpAgentTest2_1_1 $ testDuplexConnRandomIds t it "should connect via one server and 2 agents" $ smpAgentTest2_2_1 $ testDuplexConnection t + it "should connect via one server and 2 agents (random IDs)" $ + smpAgentTest2_2_1 $ testDuplexConnRandomIds t it "should connect via 2 servers and 2 agents" $ smpAgentTest2_2_2 $ testDuplexConnection t + it "should connect via 2 servers and 2 agents (random IDs)" $ + smpAgentTest2_2_2 $ testDuplexConnRandomIds t describe "Connection subscriptions" do it "should connect via one server and one agent" $ smpAgentTest3_1_1 $ testSubscription t it "should send notifications to client when server disconnects" $ smpAgentServerTest $ testSubscrNotification t - describe "Broadcast" do - it "should create broadcast and send messages" $ - smpAgentTest3 $ testBroadcast t + describe "Message delivery" do + it "should deliver messages after losing server connection and re-connecting" $ + smpAgentTest2_2_2_needs_server $ testMsgDeliveryServerRestart t + it "should deliver pending messages after agent restarting" $ + smpAgentTest1_1_1 $ testMsgDeliveryAgentRestart t -type TestTransmission p = (ACorrId, ByteString, APartyCmd p) - -type TestTransmission' p c = (ACorrId, ByteString, ACommand p c) - -type TestTransmissionOrError p = (ACorrId, ByteString, Either AgentErrorType (APartyCmd p)) - -testTE :: ATransmissionOrError p -> TestTransmissionOrError p -testTE (ATransmissionOrError corrId entity cmdOrErr) = - (corrId,serializeEntity entity,) $ case cmdOrErr of - Right cmd -> Right $ APartyCmd cmd - Left e -> Left e +-- | receive message to handle `h` +(<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent) +(<#:) = tGet SAgent -- | send transmission `t` to handle `h` and get response -(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (TestTransmissionOrError 'Agent) -h #: t = tPutRaw h t >> testTE <$> tGet SAgent h +(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent) +h #: t = tPutRaw h t >> (<#:) h -- | action and expected response -- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r` -(#>) :: IO (TestTransmissionOrError 'Agent) -> TestTransmission' 'Agent c -> Expectation -action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right (APartyCmd cmd)) +(#>) :: IO (ATransmissionOrError 'Agent) -> ATransmission 'Agent -> Expectation +action #> (corrId, cAlias, cmd) = action `shouldReturn` (corrId, cAlias, Right cmd) -- | action and predicate for the response -- `h #:t =#> p` is the test that sends `t` to `h` and validates the response using `p` -(=#>) :: IO (TestTransmissionOrError 'Agent) -> (TestTransmission 'Agent -> Bool) -> Expectation +(=#>) :: IO (ATransmissionOrError 'Agent) -> (ATransmission 'Agent -> Bool) -> Expectation action =#> p = action >>= (`shouldSatisfy` p . correctTransmission) -correctTransmission :: TestTransmissionOrError p -> TestTransmission p +correctTransmission :: ATransmissionOrError a -> ATransmission a correctTransmission (corrId, cAlias, cmdOrErr) = case cmdOrErr of Right cmd -> (corrId, cAlias, cmd) Left e -> error $ show e -- | receive message to handle `h` and validate that it is the expected one -(<#) :: Transport c => c -> TestTransmission' 'Agent c' -> Expectation -h <# (corrId, cAlias, cmd) = tGet SAgent h >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd))) . testTE +(<#) :: Transport c => c -> ATransmission 'Agent -> Expectation +h <# (corrId, cAlias, cmd) = (h <#:) `shouldReturn` (corrId, cAlias, Right cmd) -- | receive message to handle `h` and validate it using predicate `p` -(<#=) :: Transport c => c -> (TestTransmission 'Agent -> Bool) -> Expectation -h <#= p = tGet SAgent h >>= (`shouldSatisfy` p . correctTransmission . testTE) +(<#=) :: Transport c => c -> (ATransmission 'Agent -> Bool) -> Expectation +h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission) -- | test that nothing is delivered to handle `h` during 10ms (#:#) :: Transport c => c -> String -> Expectation @@ -90,125 +93,207 @@ h #:# err = tryGet `shouldReturn` () Just _ -> error err _ -> return () -pattern Msg :: MsgBody -> APartyCmd 'Agent -pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk} - -pattern Sent :: AgentMsgId -> APartyCmd 'Agent -pattern Sent msgId <- APartyCmd (SENT msgId) - -pattern Inv :: SMPQueueInfo -> APartyCmd 'Agent -pattern Inv invitation <- APartyCmd (INV invitation) +pattern Msg :: MsgBody -> ACommand 'Agent +pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} msgBody testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO () testDuplexConnection _ alice bob = do - ("1", "C:bob", Right (Inv qInfo)) <- alice #: ("1", "C:bob", "NEW") + ("1", "bob", Right (INV qInfo)) <- alice #: ("1", "bob", "NEW") let qInfo' = serializeSmpQueueInfo qInfo - bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON) - alice <# ("", "C:bob", CON) - alice #: ("2", "C:bob", "SEND :hello") =#> \case ("2", "C:bob", Sent 1) -> True; _ -> False - alice #: ("3", "C:bob", "SEND :how are you?") =#> \case ("3", "C:bob", Sent 2) -> True; _ -> False - bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False - bob <#= \case ("", "C:alice", Msg "how are you?") -> True; _ -> False - bob #: ("14", "C:alice", "SEND 9\nhello too") =#> \case ("14", "C:alice", Sent 3) -> True; _ -> False - alice <#= \case ("", "C:bob", Msg "hello too") -> True; _ -> False - bob #: ("15", "C:alice", "SEND 9\nmessage 1") =#> \case ("15", "C:alice", Sent 4) -> True; _ -> False - alice <#= \case ("", "C:bob", Msg "message 1") -> True; _ -> False - alice #: ("5", "C:bob", "OFF") #> ("5", "C:bob", OK) - bob #: ("17", "C:alice", "SEND 9\nmessage 3") #> ("17", "C:alice", ERR (SMP AUTH)) - alice #: ("6", "C:bob", "DEL") #> ("6", "C:bob", OK) + bob #: ("11", "alice", "JOIN " <> qInfo' <> " 14\nbob's connInfo") #> ("11", "alice", OK) + ("", "bob", Right (REQ confId "bob's connInfo")) <- (alice <#:) + alice #: ("2", "bob", "ACPT " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK) + bob <# ("", "alice", INFO "alice's connInfo") + bob <# ("", "alice", CON) + alice <# ("", "bob", CON) + alice #: ("3", "bob", "SEND :hello") #> ("3", "bob", MID 1) + alice <# ("", "bob", SENT 1) + alice #: ("4", "bob", "SEND :how are you?") #> ("4", "bob", MID 2) + alice <# ("", "bob", SENT 2) + bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False + bob #: ("12", "alice", "ACK 1") #> ("12", "alice", OK) + bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False + bob #: ("13", "alice", "ACK 2") #> ("13", "alice", OK) + bob #: ("14", "alice", "SEND 9\nhello too") #> ("14", "alice", MID 3) + bob <# ("", "alice", SENT 3) + alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False + alice #: ("3a", "bob", "ACK 3") #> ("3a", "bob", OK) + bob #: ("15", "alice", "SEND 9\nmessage 1") #> ("15", "alice", MID 4) + bob <# ("", "alice", SENT 4) + alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False + alice #: ("4a", "bob", "ACK 4") #> ("4a", "bob", OK) + alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) + bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", MID 5) + bob <# ("", "alice", MERR 5 (SMP AUTH)) + alice #: ("6", "bob", "DEL") #> ("6", "bob", OK) + alice #:# "nothing else should be delivered to alice" + +testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO () +testDuplexConnRandomIds _ alice bob = do + ("1", bobConn, Right (INV qInfo)) <- alice #: ("1", "", "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + ("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> qInfo' <> " 14\nbob's connInfo") + ("", bobConn', Right (REQ confId "bob's connInfo")) <- (alice <#:) + bobConn' `shouldBe` bobConn + alice #: ("2", bobConn, "ACPT " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False + bob <# ("", aliceConn, INFO "alice's connInfo") + bob <# ("", aliceConn, CON) + alice <# ("", bobConn, CON) + alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, MID 1) + alice <# ("", bobConn, SENT 1) + alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, MID 2) + alice <# ("", bobConn, SENT 2) + bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False + bob #: ("12", aliceConn, "ACK 1") #> ("12", aliceConn, OK) + bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False + bob #: ("13", aliceConn, "ACK 2") #> ("13", aliceConn, OK) + bob #: ("14", aliceConn, "SEND 9\nhello too") #> ("14", aliceConn, MID 3) + bob <# ("", aliceConn, SENT 3) + alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False + alice #: ("3a", bobConn, "ACK 3") #> ("3a", bobConn, OK) + bob #: ("15", aliceConn, "SEND 9\nmessage 1") #> ("15", aliceConn, MID 4) + bob <# ("", aliceConn, SENT 4) + alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False + alice #: ("4a", bobConn, "ACK 4") #> ("4a", bobConn, OK) + alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK) + bob #: ("17", aliceConn, "SEND 9\nmessage 3") #> ("17", aliceConn, MID 5) + bob <# ("", aliceConn, MERR 5 (SMP AUTH)) + alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK) alice #:# "nothing else should be delivered to alice" testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO () testSubscription _ alice1 alice2 bob = do - ("1", "C:bob", Right (Inv qInfo)) <- alice1 #: ("1", "C:bob", "NEW") - let qInfo' = serializeSmpQueueInfo qInfo - bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON) - bob #: ("12", "C:alice", "SEND 5\nhello") =#> \case ("12", "C:alice", Sent _) -> True; _ -> False - bob #: ("13", "C:alice", "SEND 11\nhello again") =#> \case ("13", "C:alice", Sent _) -> True; _ -> False - alice1 <# ("", "C:bob", CON) - alice1 <#= \case ("", "C:bob", Msg "hello") -> True; _ -> False - alice1 <#= \case ("", "C:bob", Msg "hello again") -> True; _ -> False - alice2 #: ("21", "C:bob", "SUB") #> ("21", "C:bob", OK) - alice1 <# ("", "C:bob", END) - bob #: ("14", "C:alice", "SEND 2\nhi") =#> \case ("14", "C:alice", Sent _) -> True; _ -> False - alice2 <#= \case ("", "C:bob", Msg "hi") -> True; _ -> False + (alice1, "alice") `connect` (bob, "bob") + bob #: ("12", "alice", "SEND 5\nhello") #> ("12", "alice", MID 1) + bob <# ("", "alice", SENT 1) + bob #: ("13", "alice", "SEND 11\nhello again") #> ("13", "alice", MID 2) + bob <# ("", "alice", SENT 2) + alice1 <#= \case ("", "bob", Msg "hello") -> True; _ -> False + alice1 #: ("1", "bob", "ACK 1") #> ("1", "bob", OK) + alice1 <#= \case ("", "bob", Msg "hello again") -> True; _ -> False + alice1 #: ("2", "bob", "ACK 2") #> ("2", "bob", OK) + alice2 #: ("21", "bob", "SUB") #> ("21", "bob", OK) + alice1 <# ("", "bob", END) + bob #: ("14", "alice", "SEND 2\nhi") #> ("14", "alice", MID 3) + bob <# ("", "alice", SENT 3) + alice2 <#= \case ("", "bob", Msg "hi") -> True; _ -> False + alice2 #: ("22", "bob", "ACK 3") #> ("22", "bob", OK) alice1 #:# "nothing else should be delivered to alice1" testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO () -testSubscrNotification _ (server, _) client = do - client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", Inv _) -> True; _ -> False +testSubscrNotification t (server, _) client = do + client #: ("1", "conn1", "NEW") =#> \case ("1", "conn1", INV {}) -> True; _ -> False client #:# "nothing should be delivered to client before the server is killed" killThread server - client <# ("", "C:conn1", END) + client <# ("", "conn1", DOWN) + withSmpServer (ATransport t) $ + client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue -testBroadcast :: forall c. Transport c => TProxy c -> c -> c -> c -> IO () -testBroadcast _ alice bob tom = do - -- establish connections - (alice, "alice") `connect` (bob, "bob") - (alice, "alice") `connect` (tom, "tom") - -- create and set up broadcast - alice #: ("1", "B:team", "NEW") #> ("1", "B:team", OK) - alice #: ("2", "B:team", "ADD C:bob") #> ("2", "B:team", OK) - alice #: ("3", "B:team", "ADD C:tom") #> ("3", "B:team", OK) - -- commands with errors - alice #: ("e1", "B:team", "NEW") #> ("e1", "B:team", ERR $ BCAST B_DUPLICATE) - alice #: ("e2", "B:group", "ADD C:bob") #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND) - alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND) - alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE) - -- send message - alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1) - alice <# ("4", "C:tom", SENT 1) - alice <# ("4", "B:team", SENT 0) - bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False - tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False - -- remove one connection - alice #: ("5", "B:team", "REM C:tom") #> ("5", "B:team", OK) - alice #: ("6", "B:team", "SEND 11\nhello again") #> ("6", "C:bob", SENT 2) - alice <# ("6", "B:team", SENT 0) - bob <#= \case ("", "C:alice", Msg "hello again") -> True; _ -> False - tom #:# "nothing delivered to tom" - -- commands with errors - alice #: ("e5", "B:group", "REM C:bob") #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND) - alice #: ("e6", "B:team", "REM C:unknown") #> ("e6", "B:team", ERR $ CONN NOT_FOUND) - alice #: ("e7", "B:team", "REM C:tom") #> ("e7", "B:team", ERR $ CONN NOT_FOUND) - -- delete broadcast - alice #: ("7", "B:team", "DEL") #> ("7", "B:team", OK) - alice #: ("8", "B:team", "SEND 11\ntry sending") #> ("8", "B:team", ERR $ BCAST B_NOT_FOUND) - -- commands with errors - alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND) - alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND) +testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO () +testMsgDeliveryServerRestart t alice bob = do + withServer $ do + connect (alice, "alice") (bob, "bob") + bob #: ("1", "alice", "SEND 2\nhi") #> ("1", "alice", MID 1) + bob <# ("", "alice", SENT 1) + alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False + alice #: ("11", "bob", "ACK 1") #> ("11", "bob", OK) + alice #:# "nothing else delivered before the server is killed" + + alice <# ("", "bob", DOWN) + bob #: ("2", "alice", "SEND 11\nhello again") #> ("2", "alice", MID 2) + bob #:# "nothing else delivered before the server is restarted" + alice #:# "nothing else delivered before the server is restarted" + + withServer $ do + bob <# ("", "alice", SENT 2) + alice <# ("", "bob", UP) + alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False + alice #: ("12", "bob", "ACK 2") #> ("12", "bob", OK) + + removeFile testStoreLogFile where - connect :: (c, ByteString) -> (c, ByteString) -> IO () - connect (h1, name1) (h2, name2) = do - ("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW") - let qInfo' = serializeSmpQueueInfo qInfo - h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False - h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False + withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` () + +testMsgDeliveryAgentRestart :: Transport c => TProxy c -> c -> IO () +testMsgDeliveryAgentRestart t bob = do + withAgent $ \alice -> do + withServer $ do + connect (bob, "bob") (alice, "alice") + alice #: ("1", "bob", "SEND 5\nhello") #> ("1", "bob", MID 1) + alice <# ("", "bob", SENT 1) + bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False + bob #: ("11", "alice", "ACK 1") #> ("11", "alice", OK) + bob #:# "nothing else delivered before the server is down" + + bob <# ("", "alice", DOWN) + alice #: ("2", "bob", "SEND 11\nhello again") #> ("2", "bob", MID 2) + alice #:# "nothing else delivered before the server is restarted" + bob #:# "nothing else delivered before the server is restarted" + + withAgent $ \alice -> do + withServer $ do + tPutRaw alice ("3", "bob", "SUB") + alice <#= \case + (corrId, "bob", cmd) -> + (corrId == "3" && cmd == OK) + || (corrId == "" && cmd == SENT 2) + _ -> False + bob <# ("", "alice", UP) + bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False + bob #: ("12", "alice", "ACK 2") #> ("12", "alice", OK) + + removeFile testStoreLogFile + removeFile testDB + where + withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` () + withAgent = withSmpAgentThreadOn_ (ATransport t) (agentTestPort, testPort, testDB) (pure ()) . const . testSMPAgentClientOn agentTestPort + +connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO () +connect (h1, name1) (h2, name2) = do + ("c1", _, Right (INV qInfo)) <- h1 #: ("c1", name2, "NEW") + let qInfo' = serializeSmpQueueInfo qInfo + h2 #: ("c2", name1, "JOIN " <> qInfo' <> " 5\ninfo2") #> ("c2", name1, OK) + ("", _, Right (REQ connId "info2")) <- (h1 <#:) + h1 #: ("c3", name2, "ACPT " <> connId <> " 5\ninfo1") #> ("c3", name2, OK) + h2 <# ("", name1, INFO "info1") + h2 <# ("", name1, CON) + h1 <# ("", name2, CON) + +-- connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString) +-- connect' h1 h2 = do +-- ("c1", conn2, Right (INV qInfo)) <- h1 #: ("c1", "", "NEW") +-- let qInfo' = serializeSmpQueueInfo qInfo +-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN " <> qInfo' <> " 5\ninfo2") +-- ("", _, Right (REQ connId "info2")) <- (h1 <#:) +-- h1 #: ("c3", conn2, "ACPT " <> connId <> " 5\ninfo1") =#> \case ("c3", c, OK) -> c == conn2; _ -> False +-- h2 <# ("", conn1, INFO "info1") +-- h2 <# ("", conn1, CON) +-- h1 <# ("", conn2, CON) +-- pure (conn1, conn2) samplePublicKey :: ByteString samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR" syntaxTests :: forall c. Transport c => TProxy c -> Spec syntaxTests t = do - it "unknown command" $ ("1", "C:5678", "HELLO") >#> ("1", "C:5678", "ERR CMD SYNTAX") + it "unknown command" $ ("1", "5678", "HELLO") >#> ("1", "5678", "ERR CMD SYNTAX") describe "NEW" do describe "valid" do - -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) -- TODO: add tests with defined connection alias - xit "without parameters" $ ("211", "C:", "NEW") >#>= \case ("211", "C:", "INV" : _) -> True; _ -> False + it "without parameters" $ ("211", "", "NEW") >#>= \case ("211", _, "INV" : _) -> True; _ -> False describe "invalid" do -- TODO: add tests with defined connection alias - it "with parameters" $ ("222", "C:", "NEW hi") >#> ("222", "C:", "ERR CMD SYNTAX") + it "with parameters" $ ("222", "", "NEW hi") >#> ("222", "", "ERR CMD SYNTAX") describe "JOIN" do describe "valid" do -- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided) -- TODO: add tests with defined connection alias it "using same server as in invitation" $ - ("311", "C:", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:", "ERR SMP AUTH") + ("311", "a", "JOIN smp::localhost:5000::1234::" <> samplePublicKey <> " 14\nbob's connInfo") >#> ("311", "a", "ERR SMP AUTH") describe "invalid" do -- TODO: JOIN is not merged yet - to be added - it "no parameters" $ ("321", "C:", "JOIN") >#> ("321", "C:", "ERR CMD SYNTAX") + it "no parameters" $ ("321", "", "JOIN") >#> ("321", "", "ERR CMD SYNTAX") where -- simple test for one command with the expected response (>#>) :: ARawTransmission -> ARawTransmission -> Expectation diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs new file mode 100644 index 000000000..008e0c14b --- /dev/null +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -0,0 +1,163 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-} + +module AgentTests.FunctionalAPITests (functionalAPITests) where + +import Control.Monad.Except (ExceptT, runExceptT) +import Control.Monad.IO.Unlift +import SMPAgentClient +import SMPClient (withSmpServer) +import Simplex.Messaging.Agent +import Simplex.Messaging.Agent.Env.SQLite (dbFile) +import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) +import Simplex.Messaging.Transport (ATransport (..)) +import System.Timeout +import Test.Hspec +import UnliftIO.STM + +(##>) :: MonadIO m => m (ATransmission 'Agent) -> ATransmission 'Agent -> m () +a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t) + +(=##>) :: MonadIO m => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m () +a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p) + +get :: MonadIO m => AgentClient -> m (ATransmission 'Agent) +get c = atomically (readTBQueue $ subQ c) + +pattern Msg :: MsgBody -> ACommand 'Agent +pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} msgBody + +functionalAPITests :: ATransport -> Spec +functionalAPITests t = do + describe "Establishing duplex connection" $ + it "should connect via one server using SMP agent clients" $ + withSmpServer t testAgentClient + describe "Establishing connection asynchronously" $ do + it "should connect with initiating client going offline" $ + withSmpServer t testAsyncInitiatingOffline + it "should connect with joining client going offline before its queue activation" $ + withSmpServer t testAsyncJoiningOfflineBeforeActivation + -- TODO a valid test case but not trivial to implement, probably requires some agent rework + xit "should connect with joining client going offline after its queue activation" $ + withSmpServer t testAsyncJoiningOfflineAfterActivation + it "should connect with both clients going offline" $ + withSmpServer t testAsyncBothOffline + +testAgentClient :: IO () +testAgentClient = do + alice <- getSMPAgentClient cfg + bob <- getSMPAgentClient cfg {dbFile = testDB2} + Right () <- runExceptT $ do + (bobId, qInfo) <- createConnection alice + aliceId <- joinConnection bob qInfo "bob's connInfo" + ("", _, REQ confId "bob's connInfo") <- get alice + acceptConnection alice bobId confId "alice's connInfo" + get alice ##> ("", bobId, CON) + get bob ##> ("", aliceId, INFO "alice's connInfo") + get bob ##> ("", aliceId, CON) + 1 <- sendMessage alice bobId "hello" + get alice ##> ("", bobId, SENT 1) + 2 <- sendMessage alice bobId "how are you?" + get alice ##> ("", bobId, SENT 2) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId 1 + get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False + ackMessage bob aliceId 2 + 3 <- sendMessage bob aliceId "hello too" + get bob ##> ("", aliceId, SENT 3) + 4 <- sendMessage bob aliceId "message 1" + get bob ##> ("", aliceId, SENT 4) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId 3 + get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False + ackMessage alice bobId 4 + suspendConnection alice bobId + 5 <- sendMessage bob aliceId "message 2" + get bob ##> ("", aliceId, MERR 5 (SMP AUTH)) + deleteConnection alice bobId + liftIO $ noMessages alice "nothing else should be delivered to alice" + pure () + where + noMessages :: AgentClient -> String -> Expectation + noMessages c err = tryGet `shouldReturn` () + where + tryGet = + 10000 `timeout` get c >>= \case + Just _ -> error err + _ -> return () + +testAsyncInitiatingOffline :: IO () +testAsyncInitiatingOffline = do + alice <- getSMPAgentClient cfg + bob <- getSMPAgentClient cfg {dbFile = testDB2} + Right () <- runExceptT $ do + (bobId, qInfo) <- createConnection alice + disconnectAgentClient alice + aliceId <- joinConnection bob qInfo "bob's connInfo" + alice' <- liftIO $ getSMPAgentClient cfg + subscribeConnection alice' bobId + ("", _, REQ confId "bob's connInfo") <- get alice' + acceptConnection alice' bobId confId "alice's connInfo" + get alice' ##> ("", bobId, CON) + get bob ##> ("", aliceId, INFO "alice's connInfo") + get bob ##> ("", aliceId, CON) + exchangeGreetings alice' bobId bob aliceId + pure () + +testAsyncJoiningOfflineBeforeActivation :: IO () +testAsyncJoiningOfflineBeforeActivation = do + alice <- getSMPAgentClient cfg + bob <- getSMPAgentClient cfg {dbFile = testDB2} + Right () <- runExceptT $ do + (bobId, qInfo) <- createConnection alice + aliceId <- joinConnection bob qInfo "bob's connInfo" + disconnectAgentClient bob + ("", _, REQ confId "bob's connInfo") <- get alice + acceptConnection alice bobId confId "alice's connInfo" + bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} + subscribeConnection bob' aliceId + get alice ##> ("", bobId, CON) + get bob' ##> ("", aliceId, INFO "alice's connInfo") + get bob' ##> ("", aliceId, CON) + exchangeGreetings alice bobId bob' aliceId + pure () + +testAsyncJoiningOfflineAfterActivation :: IO () +testAsyncJoiningOfflineAfterActivation = error "not implemented" + +testAsyncBothOffline :: IO () +testAsyncBothOffline = do + alice <- getSMPAgentClient cfg + bob <- getSMPAgentClient cfg {dbFile = testDB2} + Right () <- runExceptT $ do + (bobId, qInfo) <- createConnection alice + disconnectAgentClient alice + aliceId <- joinConnection bob qInfo "bob's connInfo" + disconnectAgentClient bob + alice' <- liftIO $ getSMPAgentClient cfg + subscribeConnection alice' bobId + ("", _, REQ confId "bob's connInfo") <- get alice' + acceptConnection alice' bobId confId "alice's connInfo" + bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} + subscribeConnection bob' aliceId + get alice' ##> ("", bobId, CON) + get bob' ##> ("", aliceId, INFO "alice's connInfo") + get bob' ##> ("", aliceId, CON) + exchangeGreetings alice' bobId bob' aliceId + pure () + +exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetings alice bobId bob aliceId = do + 1 <- sendMessage alice bobId "hello" + get alice ##> ("", bobId, SENT 1) + get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False + ackMessage bob aliceId 1 + 2 <- sendMessage bob aliceId "hello too" + get bob ##> ("", aliceId, SENT 2) + get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False + ackMessage alice bobId 2 diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 834720645..26d652ad9 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,5 +1,6 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} @@ -8,9 +9,11 @@ module AgentTests.SQLiteTests (storeTests) where import Control.Concurrent.Async (concurrently_) +import Control.Concurrent.STM import Control.Monad (replicateM_) import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import Crypto.Random (drgNew) import Data.ByteString.Char8 (ByteString) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) @@ -22,8 +25,9 @@ import SMPClient (testKeyHash) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite +import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import qualified Simplex.Messaging.Crypto as C -import System.Random (Random (randomIO)) +import System.Random import Test.Hspec import UnliftIO.Directory (removeFile) @@ -39,7 +43,7 @@ withStore2 = before connect2 . after (removeStore . fst) connect2 :: IO (SQLiteStore, SQLiteStore) connect2 = do s1 <- createStore - s2 <- connectSQLiteStore $ dbFilePath s1 + s2 <- connectSQLiteStore (dbFilePath s1) 4 pure (s1, s2) createStore :: IO SQLiteStore @@ -47,12 +51,15 @@ createStore = do -- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous -- IO operations on multiple similarly named files; error seems to be environment specific r <- randomIO :: IO Word32 - createSQLiteStore $ testDB <> show r + createSQLiteStore (testDB <> show r) 4 Migrations.app removeStore :: SQLiteStore -> IO () removeStore store = do - DB.close $ dbConn store + close store removeFile $ dbFilePath store + where + close :: SQLiteStore -> IO () + close st = mapM_ DB.close =<< atomically (flushTBQueue $ dbConnPool st) returnsResult :: (Eq a, Eq e, Show a, Show e) => ExceptT e IO a -> a -> Expectation action `returnsResult` r = runExceptT action `shouldReturn` Right r @@ -73,11 +80,13 @@ storeTests = do describe "Queue and Connection management" do describe "createRcvConn" do testCreateRcvConn + testCreateRcvConnRandomId testCreateRcvConnDuplicate describe "createSndConn" do testCreateSndConn + testCreateSndConnRandomID testCreateSndConnDuplicate - describe "getAllConnAliases" testGetAllConnAliases + describe "getAllConnIds" testGetAllConnIds describe "getRcvConn" testGetRcvConn describe "deleteConn" do testDeleteRcvConn @@ -104,24 +113,29 @@ storeTests = do testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore) testConcurrentWrites = it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do - _ <- runExceptT $ createRcvConn s1 rcvQueue1 - concurrently_ (runTest s1) (runTest s2) + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn s1 g cData1 rcvQueue1 + let ConnData {connId} = cData1 + concurrently_ (runTest s1 connId) (runTest s2 connId) where - runTest :: SQLiteStore -> IO (Either StoreError ()) - runTest store = runExceptT . replicateM_ 100 $ do - (internalId, internalRcvId, _, _) <- updateRcvIds store rcvQueue1 + runTest :: SQLiteStore -> ConnId -> IO (Either StoreError ()) + runTest store connId = runExceptT . replicateM_ 100 $ do + (internalId, internalRcvId, _, _) <- updateRcvIds store connId let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy" - createRcvMsg store rcvQueue1 rcvMsgData + createRcvMsg store connId rcvMsgData testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = - it "compiled sqlite library should be threadsafe" $ \store -> do - compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] + it "compiled sqlite library should be threadsafe" . withStoreConnection $ \db -> do + compileOptions <- DB.query_ db "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] compileOptions `shouldNotContain` [["THREADSAFE=0"]] +withStoreConnection :: (DB.Connection -> IO a) -> SQLiteStore -> IO a +withStoreConnection = flip withConnection + testForeignKeysEnabled :: SpecWith SQLiteStore testForeignKeysEnabled = - it "foreign keys should be enabled" $ \store -> do + it "foreign keys should be enabled" . withStoreConnection $ \db -> do let inconsistentQuery = [sql| INSERT INTO connections @@ -129,18 +143,19 @@ testForeignKeysEnabled = VALUES ("conn1", "smp.simplex.im", "5223", "1234", "smp.simplex.im", "5223", "2345"); |] - DB.execute_ (dbConn store) inconsistentQuery + DB.execute_ db inconsistentQuery `shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint) +cData1 :: ConnData +cData1 = ConnData {connId = "conn1"} + rcvQueue1 :: RcvQueue rcvQueue1 = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "1234", - connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "2345", - sndKey = Nothing, decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New @@ -151,74 +166,104 @@ sndQueue1 = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "3456", - connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.safePrivateKey (1, 2, 3), + signKey = C.APrivateKey $ C.unPrivateKey (C.safePrivateKey (1, 2, 3) :: C.SafePrivateKey), status = New } testCreateRcvConn :: SpecWith SQLiteStore testCreateRcvConn = it "should create RcvConnection and add SndQueue" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + g <- newTVarIO =<< drgNew + createRcvConn store g cData1 rcvQueue1 + `returnsResult` "conn1" getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) upgradeRcvConnToDuplex store "conn1" sndQueue1 `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) + +testCreateRcvConnRandomId :: SpecWith SQLiteStore +testCreateRcvConnRandomId = + it "should create RcvConnection and add SndQueue with random ID" $ \store -> do + g <- newTVarIO =<< drgNew + Right connId <- runExceptT $ createRcvConn store g cData1 {connId = ""} rcvQueue1 + getConn store connId + `returnsResult` SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1) + upgradeRcvConnToDuplex store connId sndQueue1 + `returnsResult` () + getConn store connId + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1) testCreateRcvConnDuplicate :: SpecWith SQLiteStore testCreateRcvConnDuplicate = it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 - createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 + createRcvConn store g cData1 rcvQueue1 `throwsError` SEConnDuplicate testCreateSndConn :: SpecWith SQLiteStore testCreateSndConn = it "should create SndConnection and add RcvQueue" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + g <- newTVarIO =<< drgNew + createSndConn store g cData1 sndQueue1 + `returnsResult` "conn1" getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) upgradeSndConnToDuplex store "conn1" rcvQueue1 `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) + +testCreateSndConnRandomID :: SpecWith SQLiteStore +testCreateSndConnRandomID = + it "should create SndConnection and add RcvQueue with random ID" $ \store -> do + g <- newTVarIO =<< drgNew + Right connId <- runExceptT $ createSndConn store g cData1 {connId = ""} sndQueue1 + getConn store connId + `returnsResult` SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1) + upgradeSndConnToDuplex store connId rcvQueue1 + `returnsResult` () + getConn store connId + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1) testCreateSndConnDuplicate :: SpecWith SQLiteStore testCreateSndConnDuplicate = it "should throw error on attempt to create duplicate SndConnection" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 - createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 + createSndConn store g cData1 sndQueue1 `throwsError` SEConnDuplicate -testGetAllConnAliases :: SpecWith SQLiteStore -testGetAllConnAliases = +testGetAllConnIds :: SpecWith SQLiteStore +testGetAllConnIds = it "should get all conn aliases" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 - _ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"} - getAllConnAliases store - `returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias] + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 + _ <- runExceptT $ createSndConn store g cData1 {connId = "conn2"} sndQueue1 + getAllConnIds store + `returnsResult` ["conn1" :: ConnId, "conn2" :: ConnId] testGetRcvConn :: SpecWith SQLiteStore testGetRcvConn = it "should get connection using rcv queue id and server" $ \store -> do let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getRcvConn store smpServer recipientId - `returnsResult` SomeConn SCRcv (RcvConnection (connAlias (rcvQueue1 :: RcvQueue)) rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = it "should create RcvConnection and delete it" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -228,9 +273,10 @@ testDeleteRcvConn = testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = it "should create SndConnection and delete it" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -240,10 +286,11 @@ testDeleteSndConn = testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = it "should create DuplexConnection and delete it" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) deleteConn store "conn1" `returnsResult` () -- TODO check queues are deleted as well @@ -253,15 +300,15 @@ testDeleteDuplexConn = testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 let anotherSndQueue = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, sndId = "2345", - connAlias = "conn1", sndPrivateKey = C.safePrivateKey (1, 2, 3), encryptKey = C.PublicKey $ R.PublicKey 1 2 3, - signKey = C.safePrivateKey (1, 2, 3), + signKey = C.APrivateKey $ C.unPrivateKey (C.safePrivateKey (1, 2, 3) :: C.SafePrivateKey), status = New } upgradeRcvConnToDuplex store "conn1" anotherSndQueue @@ -273,15 +320,14 @@ testUpgradeRcvConnToDuplex = testUpgradeSndConnToDuplex :: SpecWith SQLiteStore testUpgradeSndConnToDuplex = it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 let anotherRcvQueue = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, rcvId = "3456", - connAlias = "conn1", rcvPrivateKey = C.safePrivateKey (1, 2, 3), sndId = Just "4567", - sndKey = Nothing, decryptKey = C.safePrivateKey (1, 2, 3), verifyKey = Nothing, status = New @@ -295,40 +341,43 @@ testUpgradeSndConnToDuplex = testSetRcvQueueStatus :: SpecWith SQLiteStore testSetRcvQueueStatus = it "should update status of RcvQueue" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1) setRcvQueueStatus store rcvQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed}) testSetSndQueueStatus :: SpecWith SQLiteStore testSetSndQueueStatus = it "should update status of SndQueue" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1) setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed}) testSetQueueStatusDuplex :: SpecWith SQLiteStore testSetQueueStatusDuplex = it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1) setRcvQueueStatus store rcvQueue1 Secured `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1) setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) + `returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = @@ -351,31 +400,36 @@ ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash = RcvMsgData - { internalId, - internalRcvId, - internalTs = ts, - senderMeta = (externalSndId, ts), - brokerMeta = (brokerId, ts), + { internalRcvId, + msgMeta = + MsgMeta + { integrity = MsgOk, + recipient = (unId internalId, ts), + sender = (externalSndId, ts), + broker = (brokerId, ts) + }, msgBody = hw, internalHash, - externalPrevSndHash = "hash_from_sender", - msgIntegrity = MsgOk + externalPrevSndHash = "hash_from_sender" } -testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation -testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do - updateRcvIds store rcvQueue - `returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) - createRcvMsg store rcvQueue rcvMsgData +testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation +testCreateRcvMsg' st expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do + let MsgMeta {recipient = (internalId, _)} = msgMeta + updateRcvIds st connId + `returnsResult` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) + createRcvMsg st connId rcvMsgData `returnsResult` () testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = - it "should reserve internal ids and create a RcvMsg" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + it "should reserve internal ids and create a RcvMsg" $ \st -> do + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createRcvConn st g cData1 rcvQueue1 -- TODO getMsg to check message - testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" - testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" + testCreateRcvMsg' st 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg' st 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData mkSndMsgData internalId internalSndId internalHash = @@ -384,32 +438,37 @@ mkSndMsgData internalId internalSndId internalHash = internalSndId, internalTs = ts, msgBody = hw, - internalHash + internalHash, + previousMsgHash = internalHash } -testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation -testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do - updateSndIds store sndQueue +testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation +testCreateSndMsg' store expectedPrevHash connId sndMsgData@SndMsgData {..} = do + updateSndIds store connId `returnsResult` (internalId, internalSndId, expectedPrevHash) - createSndMsg store sndQueue sndMsgData + createSndMsg store connId sndMsgData `returnsResult` () testCreateSndMsg :: SpecWith SQLiteStore testCreateSndMsg = it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do - _ <- runExceptT $ createSndConn store sndQueue1 + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createSndConn store g cData1 sndQueue1 -- TODO getMsg to check message - testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" - testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" + testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" + testCreateSndMsg' store "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" testCreateRcvAndSndMsgs :: SpecWith SQLiteStore testCreateRcvAndSndMsgs = it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do - _ <- runExceptT $ createRcvConn store rcvQueue1 + g <- newTVarIO =<< drgNew + let ConnData {connId} = cData1 + _ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 - testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" - testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" - testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" - testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" - testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" - testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" + testCreateRcvMsg' store 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg' store 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" + testCreateRcvMsg' store 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateSndMsg' store "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" + testCreateSndMsg' store "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3" diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 918b276f0..f967368b1 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -24,6 +24,7 @@ import SMPClient import Simplex.Messaging.Agent (runSMPAgentBlocking) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig) import Simplex.Messaging.Transport import Test.Hspec @@ -95,11 +96,15 @@ smpAgentTestN_1 n test' = runSmpAgentTestN_1 n test' `shouldReturn` () smpAgentTest2_2_2 :: forall c. Transport c => (c -> c -> IO ()) -> Expectation smpAgentTest2_2_2 test' = withSmpServerOn (transport @c) testPort2 $ - smpAgentTestN - [ (agentTestPort, testPort, testDB), - (agentTestPort2, testPort2, testDB2) - ] - _test + smpAgentTest2_2_2_needs_server test' + +smpAgentTest2_2_2_needs_server :: forall c. Transport c => (c -> c -> IO ()) -> Expectation +smpAgentTest2_2_2_needs_server test' = + smpAgentTestN + [ (agentTestPort, testPort, testDB), + (agentTestPort2, testPort2, testDB2) + ] + _test where _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" @@ -139,13 +144,20 @@ smpAgentTest3_1_1 test' = smpAgentTestN_1 3 _test _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" +smpAgentTest1_1_1 :: forall c. Transport c => (c -> IO ()) -> Expectation +smpAgentTest1_1_1 test' = + smpAgentTestN + [(agentTestPort2, testPort2, testDB2)] + _test + where + _test [h] = test' h + _test _ = error "expected 1 handle" + cfg :: AgentConfig cfg = - AgentConfig + defaultAgentConfig { tcpPort = agentTestPort, smpServers = L.fromList ["localhost:5000#KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8="], - rsaKeySize = 2048 `div` 8, - connIdBytes = 12, tbqSize = 1, dbFile = testDB, smpCfg = @@ -153,15 +165,19 @@ cfg = { qSize = 1, defaultTransport = (testPort, transport @TCP), tcpTimeout = 500_000 - } + }, + retryInterval = (retryInterval defaultAgentConfig) {initialInterval = 50_000} } -withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a -withSmpAgentThreadOn t (port', smpPort', db') = +withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess = let cfg' = cfg {tcpPort = port', dbFile = db', smpServers = L.fromList [SMPServer "localhost" (Just smpPort') testKeyHash]} in serverBracket (\started -> runSMPAgentBlocking t started cfg') - (removeFile db') + afterProcess + +withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a +withSmpAgentThreadOn t a@(_, _, db') = withSmpAgentThreadOn_ t a $ removeFile db' withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 5b7dad9c2..58a5d5163 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -39,6 +39,9 @@ testPort2 = "5001" testKeyHashStr :: B.ByteString testKeyHashStr = "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" +testBlockSize :: Maybe Int +testBlockSize = Just 8192 + testKeyHash :: Maybe C.KeyHash testKeyHash = Just "KXNE1m2E1m0lm92WGKet9CL6+lO742Vy5G6nsrkvgs8=" @@ -48,7 +51,7 @@ testStoreLogFile = "tests/tmp/smp-server-store.log" testSMPClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a testSMPClient client = runTransportClient testHost testPort $ \h -> - liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case + liftIO (runExceptT $ clientHandshake h testBlockSize testKeyHash) >>= \case Right th -> client th Left e -> error $ show e @@ -57,9 +60,11 @@ cfg = ServerConfig { transports = undefined, tbqSize = 1, + msgQueueQuota = 4, queueIdBytes = 12, msgIdBytes = 6, storeLog = Nothing, + blockSize = 8192, serverPrivateKey = -- full RSA private key (only for tests) "MIIFIwIBAAKCAQEArZyrri/NAwt5buvYjwu+B/MQeJUszDBpRgVqNddlI9kNwDXu\ diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index bca83c073..a3d93093b 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -337,12 +337,19 @@ testTiming (ATransport t) = (testSameTiming rh sh) [ (128, 128, 100), (128, 256, 25), + (128, 384, 15), -- (128, 512, 15), (256, 128, 100), - (256, 256, 25) + (256, 256, 25), + (256, 384, 15), -- (256, 512, 15), + (384, 128, 100), + (384, 256, 25), + (384, 384, 15) + -- (384, 512, 15), -- (512, 128, 100), -- (512, 256, 25), + -- (512, 384, 15), -- (512, 512, 15) ] where diff --git a/tests/Test.hs b/tests/Test.hs index 64ee00d5c..b27b86d59 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,6 +1,6 @@ {-# LANGUAGE TypeApplications #-} -import AgentTests +import AgentTests (agentTests) import ProtocolErrorTests import ServerTests import Simplex.Messaging.Transport (TCP, Transport (..))