From 6be48397033583d3aa44457cf2abb7dbd97859dc Mon Sep 17 00:00:00 2001 From: Efim Poberezkin <8711996+efim-poberezkin@users.noreply.github.com> Date: Sun, 2 May 2021 00:38:32 +0400 Subject: [PATCH] agent: verify msg integrity based on previous msg hash and id (#110) Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> --- apps/dog-food/Main.hs | 2 +- src/Simplex/Messaging/Agent.hs | 93 +++- src/Simplex/Messaging/Agent/Client.hs | 34 +- src/Simplex/Messaging/Agent/Store.hs | 47 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 497 ++++++++---------- .../Messaging/Agent/Store/SQLite/Schema.hs | 3 + src/Simplex/Messaging/Agent/Transmission.hs | 62 ++- src/Simplex/Messaging/Crypto.hs | 4 + tests/AgentTests.hs | 10 +- tests/AgentTests/SQLiteTests.hs | 225 ++++---- 10 files changed, 515 insertions(+), 462 deletions(-) diff --git a/apps/dog-food/Main.hs b/apps/dog-food/Main.hs index bf874285d..3517746a5 100644 --- a/apps/dog-food/Main.hs +++ b/apps/dog-food/Main.hs @@ -259,7 +259,7 @@ receiveFromAgent t ct c = forever . atomically $ do INV qInfo -> Invitation qInfo CON -> Connected contact END -> Disconnected contact - MSG {m_body} -> ReceivedMessage contact m_body + MSG {msgBody} -> ReceivedMessage contact msgBody SENT _ -> NoChatResponse OK -> Confirmation contact ERR (CONN e) -> ContactError e contact diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 1bbdb87bd..ca544a6fd 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -177,10 +177,22 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = _ -> throwError $ CONN SIMPLEX where sendMsg sq = do - senderTs <- liftIO getCurrentTime - senderId <- withStore $ createSndMsg st connAlias msgBody senderTs - sendAgentMessage c sq senderTs $ A_MSG msgBody - respond $ SENT (unId senderId) + internalTs <- liftIO getCurrentTime + (internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq + let msgStr = + serializeSMPMessage + SMPMessage + { senderMsgId = unSndId internalSndId, + senderTimestamp = internalTs, + previousMsgHash, + agentMessage = A_MSG msgBody + } + msgHash = C.sha256Hash msgStr + withStore $ + createSndMsg st sq $ + SndMsgData {internalId, internalSndId, internalTs, msgBody, msgHash} + sendAgentMessage c sq msgStr + respond $ SENT (unId internalId) suspendConnection :: m () suspendConnection = @@ -208,8 +220,14 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) = sendReplyQInfo srv sq = do (rq, qInfo) <- newReceiveQueue c srv connAlias withStore $ upgradeSndConnToDuplex st connAlias rq - senderTs <- liftIO getCurrentTime - sendAgentMessage c sq senderTs $ REPLY qInfo + senderTimestamp <- liftIO getCurrentTime + sendAgentMessage c sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = REPLY qInfo + } respond :: ACommand 'Agent -> m () respond = respond' connAlias @@ -231,7 +249,9 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do case cmd of SMP.MSG srvMsgId srvTs msgBody -> do -- TODO deduplicate with previously received - agentMsg <- liftEither . parseSMPMessage =<< decryptAndVerify rq msgBody + msg <- decryptAndVerify rq msgBody + let msgHash = C.sha256Hash msg + agentMsg <- liftEither $ parseSMPMessage msg case agentMsg of SMPConfirmation senderKey -> do logServer "<--" c srv rId "MSG " @@ -244,7 +264,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do secureQueue c rq senderKey withStore $ setRcvQueueStatus st rq Secured _ -> notify connAlias . ERR $ AGENT A_PROHIBITED - SMPMessage {agentMessage, senderMsgId, senderTimestamp} -> + SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} -> case agentMessage of HELLO verifyKey _ -> do logServer "<--" c srv rId "MSG " @@ -259,24 +279,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do withStore $ upgradeRcvConnToDuplex st connAlias sq connectToSendQueue c st sq senderKey verifyKey notify connAlias CON - A_MSG body -> do - -- TODO check message status - logServer "<--" c srv rId "MSG " - case status of - Active -> do - recipientTs <- liftIO getCurrentTime - let m_sender = (senderMsgId, senderTimestamp) - let m_broker = (srvMsgId, srvTs) - recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker - notify connAlias $ - MSG - { m_status = MsgOk, - m_recipient = (unId recipientId, recipientTs), - m_sender, - m_broker, - m_body = body - } - _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash sendAck c rq return () SMP.END -> do @@ -289,6 +292,44 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do where notify :: ConnAlias -> ACommand 'Agent -> m () notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg) + agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m () + agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do + logServer "<--" c srv rId "MSG " + case status of + Active -> do + internalTs <- liftIO getCurrentTime + (internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq + let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash + withStore $ + createRcvMsg st rq $ + RcvMsgData + { internalId, + internalRcvId, + internalTs, + senderMeta, + brokerMeta, + msgBody, + msgHash, + msgIntegrity + } + notify connAlias $ + MSG + { recipientMeta = (unId internalId, internalTs), + senderMeta, + brokerMeta, + msgBody, + msgIntegrity + } + _ -> notify connAlias . ERR $ AGENT A_PROHIBITED + where + checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity + checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash + | extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk + | extSndId < prevExtSndId = MsgError $ MsgBadId extSndId + | extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate + | extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1) + | internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash + | otherwise = MsgError MsgDuplicate -- this case is not possible connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m () connectToSendQueue c st sq senderKey verifyKey = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index fd66d3652..30c510b28 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -226,7 +226,7 @@ sendConfirmation c sq@SndQueue {server, sndId} senderKey = liftSMP $ sendSMPMessage smp Nothing sndId msg where mkConfirmation :: SMPClient -> m MsgBody - mkConfirmation smp = encryptAndSign smp sq $ SMPConfirmation senderKey + mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m () sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = @@ -236,8 +236,14 @@ sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey = where mkHello :: SMPClient -> AckMode -> m ByteString mkHello smp ackMode = do - senderTs <- liftIO getCurrentTime - mkAgentMessage smp sq senderTs $ HELLO verifyKey ackMode + senderTimestamp <- liftIO getCurrentTime + encryptAndSign smp sq . serializeSMPMessage $ + SMPMessage + { senderMsgId = 0, + senderTimestamp, + previousMsgHash = "", + agentMessage = HELLO verifyKey ackMode + } send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO () send 0 _ _ _ = throwE $ SMPServerError AUTH @@ -268,27 +274,17 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = withLogSMP c server rcvId "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId -sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m () -sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} senderTs agentMsg = +sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () +sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg = withLogSMP_ c server sndId "SEND " $ \smp -> do - msg <- mkAgentMessage smp sq senderTs agentMsg - liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg + msg' <- encryptAndSign smp sq msg + liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg' -mkAgentMessage :: AgentMonad m => SMPClient -> SndQueue -> SenderTimestamp -> AMessage -> m ByteString -mkAgentMessage smp sq senderTs agentMessage = do - encryptAndSign smp sq $ - SMPMessage - { senderMsgId = 0, - senderTimestamp = senderTs, - previousMsgHash = "1234", -- TODO hash of the previous message - agentMessage - } - -encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> SMPMessage -> m ByteString +encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString encryptAndSign smp SndQueue {encryptKey, signKey} msg = do paddedSize <- asks $ (blockSize smp -) . reservedMsgSize liftError cryptoError $ do - enc <- C.encrypt encryptKey paddedSize $ serializeSMPMessage msg + enc <- C.encrypt encryptKey paddedSize msg C.Signature sig <- C.sign signKey enc pure $ sig <> enc diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 637b0fa81..51b9db518 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -43,8 +43,12 @@ class Monad m => MonadAgentStore s m where setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m () -- Msg management - createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId + updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m () + + updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + createSndMsg :: s -> SndQueue -> SndMsgData -> m () + getMsg :: s -> ConnAlias -> InternalId -> m Msg -- * Queue types @@ -104,6 +108,11 @@ data SConnType :: ConnType -> Type where SCSnd :: SConnType CSnd SCDuplex :: SConnType CDuplex +connType :: SConnType c -> ConnType +connType SCRcv = CRcv +connType SCSnd = CSnd +connType SCDuplex = CDuplex + deriving instance Eq (SConnType d) deriving instance Show (SConnType d) @@ -125,6 +134,40 @@ instance Eq SomeConn where deriving instance Show SomeConn +-- * Message integrity validation types + +type MsgHash = ByteString + +-- | Corresponds to `last_external_snd_msg_id` in `connections` table +type PrevExternalSndId = Int64 + +-- | Corresponds to `last_rcv_msg_hash` in `connections` table +type PrevRcvMsgHash = MsgHash + +-- | Corresponds to `last_snd_msg_hash` in `connections` table +type PrevSndMsgHash = MsgHash + +-- * Message data containers - used on Msg creation to reduce number of parameters + +data RcvMsgData = RcvMsgData + { internalId :: InternalId, + internalRcvId :: InternalRcvId, + internalTs :: InternalTs, + senderMeta :: (ExternalSndId, ExternalSndTs), + brokerMeta :: (BrokerId, BrokerTs), + msgBody :: MsgBody, + msgHash :: MsgHash, + msgIntegrity :: MsgIntegrity + } + +data SndMsgData = SndMsgData + { internalId :: InternalId, + internalSndId :: InternalSndId, + internalTs :: InternalTs, + msgBody :: MsgBody, + msgHash :: MsgHash + } + -- * Message types -- | A message in either direction that is stored by the agent. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 452654537..c48110065 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -28,17 +29,17 @@ import Data.Maybe (fromMaybe) import Data.Text (isPrefixOf) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8) -import Database.SQLite.Simple as DB +import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field) +import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) -import Network.Socket (HostName, ServiceName) +import Network.Socket (ServiceName) import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema) import Simplex.Messaging.Agent.Transmission -import Simplex.Messaging.Protocol (MsgBody) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) import System.Exit (ExitCode (ExitFailure), exitWith) @@ -87,75 +88,160 @@ checkDuplicate action = liftIOEither $ first handleError <$> E.try action instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where createRcvConn :: SQLiteStore -> RcvQueue -> m () - createRcvConn SQLiteStore {dbConn} = checkDuplicate . createRcvQueueAndConn dbConn + createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertRcvQueue_ dbConn q + insertRcvConnection_ dbConn q createSndConn :: SQLiteStore -> SndQueue -> m () - createSndConn SQLiteStore {dbConn} = checkDuplicate . createSndQueueAndConn dbConn + createSndConn SQLiteStore {dbConn} q@SndQueue {server} = + checkDuplicate $ + DB.withTransaction dbConn $ do + upsertServer_ dbConn server + insertSndQueue_ dbConn q + insertSndConnection_ dbConn q getConn :: SQLiteStore -> ConnAlias -> m SomeConn - getConn SQLiteStore {dbConn} connAlias = do - queues <- - liftIO $ - retrieveConnQueues dbConn connAlias - case queues of - (Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) - (Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ) - (Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ) - _ -> throwError SEConnNotFound + getConn SQLiteStore {dbConn} connAlias = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias getAllConnAliases :: SQLiteStore -> m [ConnAlias] getAllConnAliases SQLiteStore {dbConn} = - liftIO $ - retrieveAllConnAliases dbConn + liftIO $ do + r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] + return (concat r) getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do - rcvQueue <- + r <- liftIO $ - retrieveRcvQueue dbConn host port rcvId - case rcvQueue of - Just rcvQ -> return rcvQ + DB.queryNamed + dbConn + [sql| + SELECT + s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, + q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status + FROM rcv_queues q + INNER JOIN servers s ON q.host = s.host AND q.port = s.port + WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; + |] + [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] + case r of + [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> + let srv = SMPServer hst (deserializePort_ prt) keyHash + in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status _ -> throwError SEConnNotFound deleteConn :: SQLiteStore -> ConnAlias -> m () deleteConn SQLiteStore {dbConn} connAlias = liftIO $ - deleteConnCascade dbConn connAlias + DB.executeNamed + dbConn + "DELETE FROM connections WHERE conn_alias = :conn_alias;" + [":conn_alias" := connAlias] upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m () - upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue = - liftIOEither $ - updateRcvConnWithSndQueue dbConn connAlias sndQueue + upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCRcv (RcvConnection _ _)) -> do + upsertServer_ dbConn server + insertSndQueue_ dbConn sq + updateConnWithSndQueue_ dbConn connAlias sq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m () - upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue = - liftIOEither $ - updateSndConnWithRcvQueue dbConn connAlias rcvQueue + upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} = + liftIOEither . DB.withTransaction dbConn $ + getConn_ dbConn connAlias >>= \case + Right (SomeConn SCSnd (SndConnection _ _)) -> do + upsertServer_ dbConn server + insertRcvQueue_ dbConn rq + updateConnWithRcvQueue_ dbConn connAlias rq + pure $ Right () + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () - setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status = + setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateRcvQueueStatus dbConn rcvQueue status + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m () - setRcvQueueActive SQLiteStore {dbConn} rcvQueue verifyKey = + setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = + -- ? throw error if queue doesn't exist? liftIO $ - updateRcvQueueActive dbConn rcvQueue verifyKey + DB.executeNamed + dbConn + [sql| + UPDATE rcv_queues + SET verify_key = :verify_key, status = :status + WHERE host = :host AND port = :port AND rcv_id = :rcv_id; + |] + [ ":verify_key" := Just verifyKey, + ":status" := Active, + ":host" := host, + ":port" := serializePort_ port, + ":rcv_id" := rcvId + ] setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () - setSndQueueStatus SQLiteStore {dbConn} sndQueue status = + setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status = + -- ? throw error if queue doesn't exist? liftIO $ - updateSndQueueStatus dbConn sndQueue status + DB.executeNamed + dbConn + [sql| + UPDATE snd_queues + SET status = :status + WHERE host = :host AND port = :port AND snd_id = :snd_id; + |] + [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId - createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - liftIOEither $ - insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) + updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) + updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 + updateLastIdsRcv_ dbConn connAlias internalId internalRcvId + pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId - createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs = - liftIOEither $ - insertSndMsg dbConn connAlias msgBody internalTs + createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m () + createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData = + liftIO . DB.withTransaction dbConn $ do + insertRcvMsgBase_ dbConn connAlias rcvMsgData + insertRcvMsgDetails_ dbConn connAlias rcvMsgData + updateHashRcv_ dbConn connAlias rcvMsgData + + updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash) + updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} = + liftIO . DB.withTransaction dbConn $ do + (lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias + let internalId = InternalId $ unId lastInternalId + 1 + internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 + updateLastIdsSnd_ dbConn connAlias internalId internalSndId + pure (internalId, internalSndId, prevSndHash) + + createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m () + createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData = + liftIO . DB.withTransaction dbConn $ do + insertSndMsgBase_ dbConn connAlias sndMsgData + insertSndMsgDetails_ dbConn connAlias sndMsgData + updateHashSnd_ dbConn connAlias sndMsgData getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg getMsg _st _connAlias _id = throwError SENotImplemented @@ -228,13 +314,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do -- * createRcvConn helpers -createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO () -createRcvQueueAndConn dbConn rcvQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - insertRcvConnection_ dbConn rcvQueue - insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO () insertRcvQueue_ dbConn RcvQueue {..} = do let port_ = serializePort_ $ port server @@ -266,22 +345,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId] -- * createSndConn helpers -createSndQueueAndConn :: DB.Connection -> SndQueue -> IO () -createSndQueueAndConn dbConn sndQueue = - DB.withTransaction dbConn $ do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - insertSndConnection_ dbConn sndQueue - insertSndQueue_ :: DB.Connection -> SndQueue -> IO () insertSndQueue_ dbConn SndQueue {..} = do let port_ = serializePort_ $ port server @@ -311,28 +384,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do [sql| INSERT INTO connections ( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, - last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id) + last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, + last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash) VALUES (:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id, - 0, 0, 0); + 0, 0, 0, 0, x'', x''); |] [":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId] -- * getConn helpers -retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues dbConn connAlias = - DB.withTransaction -- Avoid inconsistent state between queue reads - dbConn - $ retrieveConnQueues_ dbConn connAlias - --- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap --- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction -retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue) -retrieveConnQueues_ dbConn connAlias = do - rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias - sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias - return (rcvQ, sndQ) +getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn) +getConn_ dbConn connAlias = do + rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias + sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias + pure $ case (rQ, sQ) of + (Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ) + (Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ) + (Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ) + _ -> Left SEConnNotFound retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue) retrieveRcvQueueByConnAlias_ dbConn connAlias = do @@ -374,60 +444,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status _ -> return Nothing --- * getAllConnAliases helper - -retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias] -retrieveAllConnAliases dbConn = do - r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]] - return (concat r) - --- * getRcvQueue helper - -retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue) -retrieveRcvQueue dbConn host port rcvId = do - r <- - DB.queryNamed - dbConn - [sql| - SELECT - s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key, - q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status - FROM rcv_queues q - INNER JOIN servers s ON q.host = s.host AND q.port = s.port - WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id; - |] - [":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - case r of - [(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do - let srv = SMPServer hst (deserializePort_ prt) keyHash - return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status - _ -> return Nothing - --- * deleteConn helper - -deleteConnCascade :: DB.Connection -> ConnAlias -> IO () -deleteConnCascade dbConn connAlias = - DB.executeNamed - dbConn - "DELETE FROM connections WHERE conn_alias = :conn_alias;" - [":conn_alias" := connAlias] - -- * upgradeRcvConnToDuplex helpers -updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ()) -updateRcvConnWithSndQueue dbConn connAlias sndQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, Nothing) -> do - upsertServer_ dbConn (server (sndQueue :: SndQueue)) - insertSndQueue_ dbConn sndQueue - updateConnWithSndQueue_ dbConn connAlias sndQueue - return $ Right () - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEConnNotFound - updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO () updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do let port_ = serializePort_ $ port server @@ -442,20 +460,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do -- * upgradeSndConnToDuplex helpers -updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ()) -updateSndConnWithRcvQueue dbConn connAlias rcvQueue = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Nothing, Just _sndQ) -> do - upsertServer_ dbConn (server (rcvQueue :: RcvQueue)) - insertRcvQueue_ dbConn rcvQueue - updateConnWithRcvQueue_ dbConn connAlias rcvQueue - return $ Right () - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - (Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex) - _ -> return $ Left SEConnNotFound - updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO () updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do let port_ = serializePort_ $ port server @@ -468,93 +472,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do |] [":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias] --- * setRcvQueueStatus helper +-- * updateRcvIds helpers --- ? throw error if queue doesn't exist? -updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () -updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE rcv_queues - SET status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId] - --- * setRcvQueueActive helper - --- ? throw error if queue doesn't exist? -updateRcvQueueActive :: DB.Connection -> RcvQueue -> VerificationKey -> IO () -updateRcvQueueActive dbConn RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey = - DB.executeNamed - dbConn - [sql| - UPDATE rcv_queues - SET verify_key = :verify_key, status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [ ":verify_key" := Just verifyKey, - ":status" := Active, - ":host" := host, - ":port" := serializePort_ port, - ":rcv_id" := rcvId - ] - --- * setSndQueueStatus helper - --- ? throw error if queue doesn't exist? -updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () -updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status = - DB.executeNamed - dbConn - [sql| - UPDATE snd_queues - SET status = :status - WHERE host = :host AND port = :port AND snd_id = :snd_id; - |] - [":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId] - --- * createRcvMsg helpers - -insertRcvMsg :: - DB.Connection -> - ConnAlias -> - MsgBody -> - InternalTs -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO (Either StoreError InternalId) -insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (Just _rcvQ, _) -> do - (lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody - insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) - updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId - return $ Right internalId - (Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd) - _ -> return $ Left SEConnNotFound - -retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId) -retrieveLastInternalIdsRcv_ dbConn connAlias = do - [(lastInternalId, lastInternalRcvId)] <- +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connAlias = do + [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.queryNamed dbConn [sql| - SELECT last_internal_msg_id, last_internal_rcv_msg_id + SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash FROM connections WHERE conn_alias = :conn_alias; |] [":conn_alias" := connAlias] - return (lastInternalId, lastInternalRcvId) + return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) -insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO () -insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do +updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = + DB.executeNamed + dbConn + [sql| + UPDATE connections + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_rcv_msg_id = :last_internal_rcv_msg_id + WHERE conn_alias = :conn_alias; + |] + [ ":last_internal_msg_id" := newInternalId, + ":last_internal_rcv_msg_id" := newInternalRcvId, + ":conn_alias" := connAlias + ] + +-- * createRcvMsg helpers + +insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do DB.executeNamed dbConn [sql| @@ -570,15 +521,8 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = ":body" := decodeUtf8 msgBody ] -insertRcvMsgDetails_ :: - DB.Connection -> - ConnAlias -> - InternalRcvId -> - InternalId -> - (ExternalSndId, ExternalSndTs) -> - (BrokerId, BrokerTs) -> - IO () -insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) = +insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} = DB.executeNamed dbConn [sql| @@ -592,60 +536,65 @@ insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, e [ ":conn_alias" := connAlias, ":internal_rcv_id" := internalRcvId, ":internal_id" := internalId, - ":external_snd_id" := externalSndId, - ":external_snd_ts" := externalSndTs, - ":broker_id" := brokerId, - ":broker_ts" := brokerTs, + ":external_snd_id" := fst senderMeta, + ":external_snd_ts" := snd senderMeta, + ":broker_id" := fst brokerMeta, + ":broker_ts" := snd brokerMeta, ":rcv_status" := Received ] -updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO () -updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId = +updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO () +updateHashRcv_ dbConn connAlias RcvMsgData {..} = + DB.executeNamed + dbConn + -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved + [sql| + UPDATE connections + SET last_external_snd_msg_id = :last_external_snd_msg_id, + last_rcv_msg_hash = :last_rcv_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; + |] + [ ":last_external_snd_msg_id" := fst senderMeta, + ":last_rcv_msg_hash" := msgHash, + ":conn_alias" := connAlias, + ":last_internal_rcv_msg_id" := internalRcvId + ] + +-- * updateSndIds helpers + +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash) +retrieveLastIdsAndHashSnd_ dbConn connAlias = do + [(lastInternalId, lastInternalSndId, lastSndHash)] <- + DB.queryNamed + dbConn + [sql| + SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash + FROM connections + WHERE conn_alias = :conn_alias; + |] + [":conn_alias" := connAlias] + return (lastInternalId, lastInternalSndId, lastSndHash) + +updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId = DB.executeNamed dbConn [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id + SET last_internal_msg_id = :last_internal_msg_id, + last_internal_snd_msg_id = :last_internal_snd_msg_id WHERE conn_alias = :conn_alias; |] [ ":last_internal_msg_id" := newInternalId, - ":last_internal_rcv_msg_id" := newInternalRcvId, + ":last_internal_snd_msg_id" := newInternalSndId, ":conn_alias" := connAlias ] -- * createSndMsg helpers -insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId) -insertSndMsg dbConn connAlias msgBody internalTs = - DB.withTransaction dbConn $ do - queues <- retrieveConnQueues_ dbConn connAlias - case queues of - (_, Just _sndQ) -> do - (lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias - let internalId = InternalId $ unId lastInternalId + 1 - let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody - insertSndMsgDetails_ dbConn connAlias internalSndId internalId - updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId - return $ Right internalId - (Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv) - _ -> return $ Left SEConnNotFound - -retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId) -retrieveLastInternalIdsSnd_ dbConn connAlias = do - [(lastInternalId, lastInternalSndId)] <- - DB.queryNamed - dbConn - [sql| - SELECT last_internal_msg_id, last_internal_snd_msg_id - FROM connections - WHERE conn_alias = :conn_alias; - |] - [":conn_alias" := connAlias] - return (lastInternalId, lastInternalSndId) - -insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO () -insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do +insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do DB.executeNamed dbConn [sql| @@ -661,8 +610,8 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = ":body" := decodeUtf8 msgBody ] -insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO () -insertSndMsgDetails_ dbConn connAlias internalSndId internalId = +insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn [sql| @@ -677,16 +626,18 @@ insertSndMsgDetails_ dbConn connAlias internalSndId internalId = ":snd_status" := Created ] -updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO () -updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId = +updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO () +updateHashSnd_ dbConn connAlias SndMsgData {..} = DB.executeNamed dbConn + -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved [sql| UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id - WHERE conn_alias = :conn_alias; + SET last_snd_msg_hash = :last_snd_msg_hash + WHERE conn_alias = :conn_alias + AND last_internal_snd_msg_id = :last_internal_snd_msg_id; |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_alias" := connAlias + [ ":last_snd_msg_hash" := msgHash, + ":conn_alias" := connAlias, + ":last_internal_snd_msg_id" := internalSndId ] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs index 12d886a4e..959a071e9 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Schema.hs @@ -90,6 +90,9 @@ connections = last_internal_msg_id INTEGER NOT NULL, last_internal_rcv_msg_id INTEGER NOT NULL, last_internal_snd_msg_id INTEGER NOT NULL, + last_external_snd_msg_id INTEGER NOT NULL, + last_rcv_msg_hash BLOB NOT NULL, + last_snd_msg_hash BLOB NOT NULL, PRIMARY KEY (conn_alias), FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id), FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id) diff --git a/src/Simplex/Messaging/Agent/Transmission.hs b/src/Simplex/Messaging/Agent/Transmission.hs index 3bc9a7592..b29b919cf 100644 --- a/src/Simplex/Messaging/Agent/Transmission.hs +++ b/src/Simplex/Messaging/Agent/Transmission.hs @@ -13,7 +13,7 @@ module Simplex.Messaging.Agent.Transmission where -import Control.Applicative ((<|>)) +import Control.Applicative (optional, (<|>)) import Control.Monad.IO.Class import Data.Attoparsec.ByteString.Char8 (Parser) import qualified Data.Attoparsec.ByteString.Char8 as A @@ -90,11 +90,11 @@ data ACommand (p :: AParty) where SEND :: MsgBody -> ACommand Client SENT :: AgentMsgId -> ACommand Agent MSG :: - { m_recipient :: (AgentMsgId, UTCTime), - m_broker :: (MsgId, UTCTime), - m_sender :: (AgentMsgId, UTCTime), - m_status :: MsgStatus, - m_body :: MsgBody + { recipientMeta :: (AgentMsgId, UTCTime), + brokerMeta :: (MsgId, UTCTime), + senderMeta :: (AgentMsgId, UTCTime), + msgIntegrity :: MsgIntegrity, + msgBody :: MsgBody } -> ACommand Agent -- ACK :: AgentMsgId -> ACommand Client @@ -142,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE SMPMessage <$> A.decimal <* A.space <*> tsISO8601P <* A.space - <*> base64P <* A.endOfLine + -- TODO previous message hash should become mandatory when we support HELLO and REPLY + -- (for HELLO it would be the hash of SMPConfirmation) + <*> (base64P <|> pure "") <* A.endOfLine <*> agentMessageP serializeSMPMessage :: SMPMessage -> ByteString @@ -175,17 +177,11 @@ smpQueueInfoP = "smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP) smpServerP :: Parser SMPServer -smpServerP = SMPServer <$> server <*> port <*> kHash +smpServerP = SMPServer <$> server <*> optional port <*> optional kHash where server = B.unpack <$> A.takeTill (A.inClass ":# ") - port = fromChar ':' $ show <$> (A.decimal :: Parser Int) - kHash = fromChar '#' C.keyHashP - fromChar :: Char -> Parser a -> Parser (Maybe a) - fromChar ch parser = do - c <- A.peekChar - if c == Just ch - then A.char ch *> (Just <$> parser) - else pure Nothing + port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit) + kHash = A.char '#' *> C.keyHashP parseAgentMessage :: ByteString -> Either AgentErrorType AMessage parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE @@ -241,10 +237,10 @@ type AgentMsgId = Int64 type SenderTimestamp = UTCTime -data MsgStatus = MsgOk | MsgError MsgErrorType +data MsgIntegrity = MsgOk | MsgError MsgErrorType deriving (Eq, Show) -data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash +data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate deriving (Eq, Show) -- | error type used in errors sent to agent clients @@ -319,22 +315,23 @@ commandP = sendCmd = ACmd SClient . SEND <$> A.takeByteString sentResp = ACmd SAgent . SENT <$> A.decimal message = do - m_status <- status <* A.space - m_recipient <- "R=" *> partyMeta A.decimal - m_broker <- "B=" *> partyMeta base64P - m_sender <- "S=" *> partyMeta A.decimal - m_body <- A.takeByteString - return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body} + msgIntegrity <- integrity <* A.space + recipientMeta <- "R=" *> partyMeta A.decimal + brokerMeta <- "B=" *> partyMeta base64P + senderMeta <- "S=" *> partyMeta A.decimal + msgBody <- A.takeByteString + return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody} replyMode = " NO_REPLY" $> ReplyOff <|> A.space *> (ReplyVia <$> smpServerP) <|> pure ReplyOn partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space - status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) + integrity = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType) msgErrorType = "ID " *> (MsgBadId <$> A.decimal) <|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal) <|> "HASH" $> MsgBadHash + <|> "DUPLICATE" $> MsgDuplicate agentError = ACmd SAgent . ERR <$> agentErrorTypeP parseCommand :: ByteString -> Either AgentErrorType ACmd @@ -350,14 +347,14 @@ serializeCommand = \case END -> "END" SEND msgBody -> "SEND " <> serializeMsg msgBody SENT mId -> "SENT " <> bshow mId - MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} -> + MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} -> B.unwords [ "MSG", - msgStatus m_status, + serializeMsgIntegrity msgIntegrity, "R=" <> bshow rmId <> "," <> showTs rTs, "B=" <> encode bmId <> "," <> showTs bTs, "S=" <> bshow smId <> "," <> showTs sTs, - serializeMsg m_body + serializeMsg msgBody ] OFF -> "OFF" DEL -> "DEL" @@ -372,15 +369,16 @@ serializeCommand = \case ReplyOn -> "" showTs :: UTCTime -> ByteString showTs = B.pack . formatISO8601Millis - msgStatus :: MsgStatus -> ByteString - msgStatus = \case + serializeMsgIntegrity :: MsgIntegrity -> ByteString + serializeMsgIntegrity = \case MsgOk -> "OK" MsgError e -> - "ERR" <> case e of + "ERR " <> case e of MsgSkipped fromMsgId toMsgId -> B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId] MsgBadId aMsgId -> "ID " <> bshow aMsgId MsgBadHash -> "HASH" + MsgDuplicate -> "DUPLICATE" agentErrorTypeP :: Parser AgentErrorType agentErrorTypeP = @@ -443,7 +441,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p)) cmdWithMsgBody = \case SEND body -> SEND <$$> getMsgBody body - MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body + MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body cmd -> return $ Right cmd -- TODO refactor with server diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index b5022e96f..19f37d78f 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -36,6 +36,7 @@ module Simplex.Messaging.Crypto encodePubKey, serializeKeyHash, getKeyHash, + sha256Hash, privKeyP, pubKeyP, binaryPubKeyP, @@ -226,6 +227,9 @@ keyHashP = do getKeyHash :: ByteString -> KeyHash getKeyHash = KeyHash . hash +sha256Hash :: ByteString -> ByteString +sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256) + serializeHeader :: Header -> ByteString serializeHeader Header {aesKey, ivBytes, authTag, msgSize} = unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index ad0c365d0..6a1754f13 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -80,7 +80,7 @@ h #:# err = tryGet `shouldReturn` () _ -> return () pattern Msg :: MsgBody -> ACommand 'Agent -pattern Msg m_body <- MSG {m_body} +pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk} testDuplexConnection :: Handle -> Handle -> IO () testDuplexConnection alice bob = do @@ -88,13 +88,13 @@ testDuplexConnection alice bob = do let qInfo' = serializeSmpQueueInfo qInfo bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON) alice <# ("", "bob", CON) - alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False - alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False + alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False + alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False - bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False + bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False - bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False + bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False alice #: ("5", "bob", "OFF") #> ("5", "bob", OK) bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH)) diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 46132c832..450e568f3 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -1,12 +1,15 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} module AgentTests.SQLiteTests (storeTests) where import Control.Monad.Except (ExceptT, runExceptT) import qualified Crypto.PubKey.RSA as R +import Data.ByteString.Char8 (ByteString) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Time @@ -49,45 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e -- TODO add null port tests storeTests :: Spec storeTests = withStore do - describe "compiled as threadsafe" testCompiledThreadsafe - describe "foreign keys enabled" testForeignKeysEnabled + describe "store setup" do + testCompiledThreadsafe + testForeignKeysEnabled describe "store methods" do - describe "createRcvConn" do - describe "unique" testCreateRcvConn - describe "duplicate" testCreateRcvConnDuplicate - describe "createSndConn" do - describe "unique" testCreateSndConn - describe "duplicate" testCreateSndConnDuplicate - describe "getAllConnAliases" testGetAllConnAliases - describe "getRcvQueue" testGetRcvQueue - describe "deleteConn" do - describe "RcvConnection" testDeleteRcvConn - describe "SndConnection" testDeleteSndConn - describe "DuplexConnection" testDeleteDuplexConn - describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex - describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex - describe "set queue status" do - describe "setRcvQueueStatus" testSetRcvQueueStatus - describe "setSndQueueStatus" testSetSndQueueStatus - describe "DuplexConnection" testSetQueueStatusDuplex - xdescribe "RcvQueue does not exist" testSetRcvQueueStatusNoQueue - xdescribe "SndQueue does not exist" testSetSndQueueStatusNoQueue - describe "createRcvMsg" do - describe "RcvQueue exists" testCreateRcvMsg - describe "RcvQueue does not exist" testCreateRcvMsgNoQueue - describe "createSndMsg" do - describe "SndQueue exists" testCreateSndMsg - describe "SndQueue does not exist" testCreateSndMsgNoQueue + describe "Queue and Connection management" do + describe "createRcvConn" do + testCreateRcvConn + testCreateRcvConnDuplicate + describe "createSndConn" do + testCreateSndConn + testCreateSndConnDuplicate + describe "getAllConnAliases" testGetAllConnAliases + describe "getRcvQueue" testGetRcvQueue + describe "deleteConn" do + testDeleteRcvConn + testDeleteSndConn + testDeleteDuplexConn + describe "upgradeRcvConnToDuplex" do + testUpgradeRcvConnToDuplex + describe "upgradeSndConnToDuplex" do + testUpgradeSndConnToDuplex + describe "set Queue status" do + describe "setRcvQueueStatus" do + testSetRcvQueueStatus + testSetRcvQueueStatusNoQueue + describe "setSndQueueStatus" do + testSetSndQueueStatus + testSetSndQueueStatusNoQueue + testSetQueueStatusDuplex + describe "Msg management" do + describe "create Msg" do + testCreateRcvMsg + testCreateSndMsg + testCreateRcvAndSndMsgs testCompiledThreadsafe :: SpecWith SQLiteStore testCompiledThreadsafe = do - it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do + it "compiled sqlite library should be threadsafe" $ \store -> do compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] compileOptions `shouldNotContain` [["THREADSAFE=0"]] testForeignKeysEnabled :: SpecWith SQLiteStore testForeignKeysEnabled = do - it "should throw error if foreign keys are enabled" $ \store -> do + it "foreign keys should be enabled" $ \store -> do let inconsistentQuery = [sql| INSERT INTO connections @@ -139,8 +147,7 @@ testCreateRcvConn = do testCreateRcvConnDuplicate :: SpecWith SQLiteStore testCreateRcvConnDuplicate = do it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 createRcvConn store rcvQueue1 `throwsError` SEConnDuplicate @@ -159,18 +166,15 @@ testCreateSndConn = do testCreateSndConnDuplicate :: SpecWith SQLiteStore testCreateSndConnDuplicate = do it "should throw error on attempt to create duplicate SndConnection" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 createSndConn store sndQueue1 `throwsError` SEConnDuplicate testGetAllConnAliases :: SpecWith SQLiteStore testGetAllConnAliases = do it "should get all conn aliases" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - createSndConn store sndQueue1 {connAlias = "conn2"} - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"} getAllConnAliases store `returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias] @@ -179,16 +183,14 @@ testGetRcvQueue = do it "should get RcvQueue" $ \store -> do let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash let recipientId = "1234" - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getRcvQueue store smpServer recipientId `returnsResult` rcvQueue1 testDeleteRcvConn :: SpecWith SQLiteStore testDeleteRcvConn = do it "should create RcvConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) deleteConn store "conn1" @@ -200,8 +202,7 @@ testDeleteRcvConn = do testDeleteSndConn :: SpecWith SQLiteStore testDeleteSndConn = do it "should create SndConnection and delete it" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) deleteConn store "conn1" @@ -213,10 +214,8 @@ testDeleteSndConn = do testDeleteDuplexConn :: SpecWith SQLiteStore testDeleteDuplexConn = do it "should create DuplexConnection and delete it" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) deleteConn store "conn1" @@ -228,8 +227,7 @@ testDeleteDuplexConn = do testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore testUpgradeRcvConnToDuplex = do it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 let anotherSndQueue = SndQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, @@ -242,16 +240,14 @@ testUpgradeRcvConnToDuplex = do } upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CSnd - upgradeSndConnToDuplex store "conn1" rcvQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1 upgradeRcvConnToDuplex store "conn1" anotherSndQueue `throwsError` SEBadConnType CDuplex testUpgradeSndConnToDuplex :: SpecWith SQLiteStore testUpgradeSndConnToDuplex = do it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 let anotherRcvQueue = RcvQueue { server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash, @@ -266,16 +262,14 @@ testUpgradeSndConnToDuplex = do } upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CRcv - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 upgradeSndConnToDuplex store "conn1" anotherRcvQueue `throwsError` SEBadConnType CDuplex testSetRcvQueueStatus :: SpecWith SQLiteStore testSetRcvQueueStatus = do it "should update status of RcvQueue" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 getConn store "conn1" `returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1) setRcvQueueStatus store rcvQueue1 Confirmed @@ -286,8 +280,7 @@ testSetRcvQueueStatus = do testSetSndQueueStatus :: SpecWith SQLiteStore testSetSndQueueStatus = do it "should update status of SndQueue" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + _ <- runExceptT $ createSndConn store sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1) setSndQueueStatus store sndQueue1 Confirmed @@ -298,10 +291,8 @@ testSetSndQueueStatus = do testSetQueueStatusDuplex :: SpecWith SQLiteStore testSetQueueStatusDuplex = do it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () - upgradeRcvConnToDuplex store "conn1" sndQueue1 - `returnsResult` () + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 getConn store "conn1" `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1) setRcvQueueStatus store rcvQueue1 Secured @@ -311,61 +302,87 @@ testSetQueueStatusDuplex = do setSndQueueStatus store sndQueue1 Confirmed `returnsResult` () getConn store "conn1" - `returnsResult` SomeConn - SCDuplex - ( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed} - ) + `returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}) testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore testSetRcvQueueStatusNoQueue = do - it "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do setRcvQueueStatus store rcvQueue1 Confirmed - `throwsError` SEInternal "" + `throwsError` SEConnNotFound testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore testSetSndQueueStatusNoQueue = do - it "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do + xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do setSndQueueStatus store sndQueue1 Confirmed - `throwsError` SEInternal "" + `throwsError` SEConnNotFound + +hw :: ByteString +hw = encodeUtf8 "Hello world!" + +ts :: UTCTime +ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) + +mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData +mkRcvMsgData internalId internalRcvId externalSndId brokerId msgHash = + RcvMsgData + { internalId, + internalRcvId, + internalTs = ts, + senderMeta = (externalSndId, ts), + brokerMeta = (brokerId, ts), + msgBody = hw, + msgHash = msgHash, + msgIntegrity = MsgOk + } + +testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation +testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do + updateRcvIds store rcvQueue + `returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash) + createRcvMsg store rcvQueue rcvMsgData + `returnsResult` () testCreateRcvMsg :: SpecWith SQLiteStore testCreateRcvMsg = do - it "should create a RcvMsg and return InternalId" $ \store -> do - createRcvConn store rcvQueue1 - `returnsResult` () + it "should reserve internal ids and create a RcvMsg" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `returnsResult` InternalId 1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy" + testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy" -testCreateRcvMsgNoQueue :: SpecWith SQLiteStore -testCreateRcvMsgNoQueue = do - it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEConnNotFound - createSndConn store sndQueue1 - `returnsResult` () - createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts) - `throwsError` SEBadConnType CSnd +mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData +mkSndMsgData internalId internalSndId msgHash = + SndMsgData + { internalId, + internalSndId, + internalTs = ts, + msgBody = hw, + msgHash = msgHash + } + +testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation +testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do + updateSndIds store sndQueue + `returnsResult` (internalId, internalSndId, expectedPrevHash) + createSndMsg store sndQueue sndMsgData + `returnsResult` () testCreateSndMsg :: SpecWith SQLiteStore testCreateSndMsg = do - it "should create a SndMsg and return InternalId" $ \store -> do - createSndConn store sndQueue1 - `returnsResult` () + it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do + _ <- runExceptT $ createSndConn store sndQueue1 -- TODO getMsg to check message - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `returnsResult` InternalId 1 + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" + testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" -testCreateSndMsgNoQueue :: SpecWith SQLiteStore -testCreateSndMsgNoQueue = do - it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do - let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0) - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEConnNotFound - createRcvConn store rcvQueue1 - `returnsResult` () - createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts - `throwsError` SEBadConnType CRcv +testCreateRcvAndSndMsgs :: SpecWith SQLiteStore +testCreateRcvAndSndMsgs = do + it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do + _ <- runExceptT $ createRcvConn store rcvQueue1 + _ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1 + testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1" + testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2" + testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1" + testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3" + testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2" + testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"