mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 20:45:52 +00:00
agent: verify msg integrity based on previous msg hash and id (#110)
Co-authored-by: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com>
This commit is contained in:
@@ -259,7 +259,7 @@ receiveFromAgent t ct c = forever . atomically $ do
|
||||
INV qInfo -> Invitation qInfo
|
||||
CON -> Connected contact
|
||||
END -> Disconnected contact
|
||||
MSG {m_body} -> ReceivedMessage contact m_body
|
||||
MSG {msgBody} -> ReceivedMessage contact msgBody
|
||||
SENT _ -> NoChatResponse
|
||||
OK -> Confirmation contact
|
||||
ERR (CONN e) -> ContactError e contact
|
||||
|
||||
@@ -177,10 +177,22 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
sendMsg sq = do
|
||||
senderTs <- liftIO getCurrentTime
|
||||
senderId <- withStore $ createSndMsg st connAlias msgBody senderTs
|
||||
sendAgentMessage c sq senderTs $ A_MSG msgBody
|
||||
respond $ SENT (unId senderId)
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
{ senderMsgId = unSndId internalSndId,
|
||||
senderTimestamp = internalTs,
|
||||
previousMsgHash,
|
||||
agentMessage = A_MSG msgBody
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st sq $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
respond $ SENT (unId internalId)
|
||||
|
||||
suspendConnection :: m ()
|
||||
suspendConnection =
|
||||
@@ -208,8 +220,14 @@ processCommand c@AgentClient {sndQ} st (corrId, connAlias, cmd) =
|
||||
sendReplyQInfo srv sq = do
|
||||
(rq, qInfo) <- newReceiveQueue c srv connAlias
|
||||
withStore $ upgradeSndConnToDuplex st connAlias rq
|
||||
senderTs <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq senderTs $ REPLY qInfo
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = REPLY qInfo
|
||||
}
|
||||
|
||||
respond :: ACommand 'Agent -> m ()
|
||||
respond = respond' connAlias
|
||||
@@ -231,7 +249,9 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
case cmd of
|
||||
SMP.MSG srvMsgId srvTs msgBody -> do
|
||||
-- TODO deduplicate with previously received
|
||||
agentMsg <- liftEither . parseSMPMessage =<< decryptAndVerify rq msgBody
|
||||
msg <- decryptAndVerify rq msgBody
|
||||
let msgHash = C.sha256Hash msg
|
||||
agentMsg <- liftEither $ parseSMPMessage msg
|
||||
case agentMsg of
|
||||
SMPConfirmation senderKey -> do
|
||||
logServer "<--" c srv rId "MSG <KEY>"
|
||||
@@ -244,7 +264,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
secureQueue c rq senderKey
|
||||
withStore $ setRcvQueueStatus st rq Secured
|
||||
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp} ->
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case agentMessage of
|
||||
HELLO verifyKey _ -> do
|
||||
logServer "<--" c srv rId "MSG <HELLO>"
|
||||
@@ -259,24 +279,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
withStore $ upgradeRcvConnToDuplex st connAlias sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
notify connAlias CON
|
||||
A_MSG body -> do
|
||||
-- TODO check message status
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
case status of
|
||||
Active -> do
|
||||
recipientTs <- liftIO getCurrentTime
|
||||
let m_sender = (senderMsgId, senderTimestamp)
|
||||
let m_broker = (srvMsgId, srvTs)
|
||||
recipientId <- withStore $ createRcvMsg st connAlias body recipientTs m_sender m_broker
|
||||
notify connAlias $
|
||||
MSG
|
||||
{ m_status = MsgOk,
|
||||
m_recipient = (unId recipientId, recipientTs),
|
||||
m_sender,
|
||||
m_broker,
|
||||
m_body = body
|
||||
}
|
||||
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
A_MSG body -> agentClientMsg rq previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
|
||||
sendAck c rq
|
||||
return ()
|
||||
SMP.END -> do
|
||||
@@ -289,6 +292,44 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
where
|
||||
notify :: ConnAlias -> ACommand 'Agent -> m ()
|
||||
notify connAlias msg = atomically $ writeTBQueue sndQ ("", connAlias, msg)
|
||||
agentClientMsg :: RcvQueue -> PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
|
||||
agentClientMsg rq@RcvQueue {connAlias, status} receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
case status of
|
||||
Active -> do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
|
||||
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash
|
||||
withStore $
|
||||
createRcvMsg st rq $
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
internalTs,
|
||||
senderMeta,
|
||||
brokerMeta,
|
||||
msgBody,
|
||||
msgHash,
|
||||
msgIntegrity
|
||||
}
|
||||
notify connAlias $
|
||||
MSG
|
||||
{ recipientMeta = (unId internalId, internalTs),
|
||||
senderMeta,
|
||||
brokerMeta,
|
||||
msgBody,
|
||||
msgIntegrity
|
||||
}
|
||||
_ -> notify connAlias . ERR $ AGENT A_PROHIBITED
|
||||
where
|
||||
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> MsgIntegrity
|
||||
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash
|
||||
| extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk
|
||||
| extSndId < prevExtSndId = MsgError $ MsgBadId extSndId
|
||||
| extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate
|
||||
| extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1)
|
||||
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m ()
|
||||
connectToSendQueue c st sq senderKey verifyKey = do
|
||||
|
||||
@@ -226,7 +226,7 @@ sendConfirmation c sq@SndQueue {server, sndId} senderKey =
|
||||
liftSMP $ sendSMPMessage smp Nothing sndId msg
|
||||
where
|
||||
mkConfirmation :: SMPClient -> m MsgBody
|
||||
mkConfirmation smp = encryptAndSign smp sq $ SMPConfirmation senderKey
|
||||
mkConfirmation smp = encryptAndSign smp sq . serializeSMPMessage $ SMPConfirmation senderKey
|
||||
|
||||
sendHello :: forall m. AgentMonad m => AgentClient -> SndQueue -> VerificationKey -> m ()
|
||||
sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey =
|
||||
@@ -236,8 +236,14 @@ sendHello c sq@SndQueue {server, sndId, sndPrivateKey} verifyKey =
|
||||
where
|
||||
mkHello :: SMPClient -> AckMode -> m ByteString
|
||||
mkHello smp ackMode = do
|
||||
senderTs <- liftIO getCurrentTime
|
||||
mkAgentMessage smp sq senderTs $ HELLO verifyKey ackMode
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
encryptAndSign smp sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = HELLO verifyKey ackMode
|
||||
}
|
||||
|
||||
send :: Int -> Int -> ByteString -> SMPClient -> ExceptT SMPClientError IO ()
|
||||
send 0 _ _ _ = throwE $ SMPServerError AUTH
|
||||
@@ -268,27 +274,17 @@ deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> SenderTimestamp -> AMessage -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} senderTs agentMsg =
|
||||
sendAgentMessage :: AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msg =
|
||||
withLogSMP_ c server sndId "SEND <message>" $ \smp -> do
|
||||
msg <- mkAgentMessage smp sq senderTs agentMsg
|
||||
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
msg' <- encryptAndSign smp sq msg
|
||||
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg'
|
||||
|
||||
mkAgentMessage :: AgentMonad m => SMPClient -> SndQueue -> SenderTimestamp -> AMessage -> m ByteString
|
||||
mkAgentMessage smp sq senderTs agentMessage = do
|
||||
encryptAndSign smp sq $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp = senderTs,
|
||||
previousMsgHash = "1234", -- TODO hash of the previous message
|
||||
agentMessage
|
||||
}
|
||||
|
||||
encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> SMPMessage -> m ByteString
|
||||
encryptAndSign :: AgentMonad m => SMPClient -> SndQueue -> ByteString -> m ByteString
|
||||
encryptAndSign smp SndQueue {encryptKey, signKey} msg = do
|
||||
paddedSize <- asks $ (blockSize smp -) . reservedMsgSize
|
||||
liftError cryptoError $ do
|
||||
enc <- C.encrypt encryptKey paddedSize $ serializeSMPMessage msg
|
||||
enc <- C.encrypt encryptKey paddedSize msg
|
||||
C.Signature sig <- C.sign signKey enc
|
||||
pure $ sig <> enc
|
||||
|
||||
|
||||
@@ -43,8 +43,12 @@ class Monad m => MonadAgentStore s m where
|
||||
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
|
||||
|
||||
-- Msg management
|
||||
createRcvMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
|
||||
createSndMsg :: s -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
|
||||
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
|
||||
|
||||
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
|
||||
|
||||
getMsg :: s -> ConnAlias -> InternalId -> m Msg
|
||||
|
||||
-- * Queue types
|
||||
@@ -104,6 +108,11 @@ data SConnType :: ConnType -> Type where
|
||||
SCSnd :: SConnType CSnd
|
||||
SCDuplex :: SConnType CDuplex
|
||||
|
||||
connType :: SConnType c -> ConnType
|
||||
connType SCRcv = CRcv
|
||||
connType SCSnd = CSnd
|
||||
connType SCDuplex = CDuplex
|
||||
|
||||
deriving instance Eq (SConnType d)
|
||||
|
||||
deriving instance Show (SConnType d)
|
||||
@@ -125,6 +134,40 @@ instance Eq SomeConn where
|
||||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
-- * Message integrity validation types
|
||||
|
||||
type MsgHash = ByteString
|
||||
|
||||
-- | Corresponds to `last_external_snd_msg_id` in `connections` table
|
||||
type PrevExternalSndId = Int64
|
||||
|
||||
-- | Corresponds to `last_rcv_msg_hash` in `connections` table
|
||||
type PrevRcvMsgHash = MsgHash
|
||||
|
||||
-- | Corresponds to `last_snd_msg_hash` in `connections` table
|
||||
type PrevSndMsgHash = MsgHash
|
||||
|
||||
-- * Message data containers - used on Msg creation to reduce number of parameters
|
||||
|
||||
data RcvMsgData = RcvMsgData
|
||||
{ internalId :: InternalId,
|
||||
internalRcvId :: InternalRcvId,
|
||||
internalTs :: InternalTs,
|
||||
senderMeta :: (ExternalSndId, ExternalSndTs),
|
||||
brokerMeta :: (BrokerId, BrokerTs),
|
||||
msgBody :: MsgBody,
|
||||
msgHash :: MsgHash,
|
||||
msgIntegrity :: MsgIntegrity
|
||||
}
|
||||
|
||||
data SndMsgData = SndMsgData
|
||||
{ internalId :: InternalId,
|
||||
internalSndId :: InternalSndId,
|
||||
internalTs :: InternalTs,
|
||||
msgBody :: MsgBody,
|
||||
msgHash :: MsgHash
|
||||
}
|
||||
|
||||
-- * Message types
|
||||
|
||||
-- | A message in either direction that is stored by the agent.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@@ -28,17 +29,17 @@ import Data.Maybe (fromMaybe)
|
||||
import Data.Text (isPrefixOf)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), SQLData (..), SQLError, field)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Schema (createSchema)
|
||||
import Simplex.Messaging.Agent.Transmission
|
||||
import Simplex.Messaging.Protocol (MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (bshow, liftIOEither)
|
||||
import System.Exit (ExitCode (ExitFailure), exitWith)
|
||||
@@ -87,75 +88,160 @@ checkDuplicate action = liftIOEither $ first handleError <$> E.try action
|
||||
|
||||
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
|
||||
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
|
||||
createRcvConn SQLiteStore {dbConn} = checkDuplicate . createRcvQueueAndConn dbConn
|
||||
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
|
||||
checkDuplicate $
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn q
|
||||
insertRcvConnection_ dbConn q
|
||||
|
||||
createSndConn :: SQLiteStore -> SndQueue -> m ()
|
||||
createSndConn SQLiteStore {dbConn} = checkDuplicate . createSndQueueAndConn dbConn
|
||||
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
|
||||
checkDuplicate $
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn q
|
||||
insertSndConnection_ dbConn q
|
||||
|
||||
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connAlias = do
|
||||
queues <-
|
||||
liftIO $
|
||||
retrieveConnQueues dbConn connAlias
|
||||
case queues of
|
||||
(Just rcvQ, Just sndQ) -> return $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> return $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> return $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> throwError SEConnNotFound
|
||||
getConn SQLiteStore {dbConn} connAlias =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias
|
||||
|
||||
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
|
||||
getAllConnAliases SQLiteStore {dbConn} =
|
||||
liftIO $
|
||||
retrieveAllConnAliases dbConn
|
||||
liftIO $ do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
return (concat r)
|
||||
|
||||
getRcvQueue :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m RcvQueue
|
||||
getRcvQueue SQLiteStore {dbConn} SMPServer {host, port} rcvId = do
|
||||
rcvQueue <-
|
||||
r <-
|
||||
liftIO $
|
||||
retrieveRcvQueue dbConn host port rcvId
|
||||
case rcvQueue of
|
||||
Just rcvQ -> return rcvQ
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
case r of
|
||||
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] ->
|
||||
let srv = SMPServer hst (deserializePort_ prt) keyHash
|
||||
in pure $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> throwError SEConnNotFound
|
||||
|
||||
deleteConn :: SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connAlias =
|
||||
liftIO $
|
||||
deleteConnCascade dbConn connAlias
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sndQueue =
|
||||
liftIOEither $
|
||||
updateRcvConnWithSndQueue dbConn connAlias sndQueue
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
Right (SomeConn SCRcv (RcvConnection _ _)) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn sq
|
||||
updateConnWithSndQueue_ dbConn connAlias sq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rcvQueue =
|
||||
liftIOEither $
|
||||
updateSndConnWithRcvQueue dbConn connAlias rcvQueue
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
|
||||
liftIOEither . DB.withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
Right (SomeConn SCSnd (SndConnection _ _)) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn rq
|
||||
updateConnWithRcvQueue_ dbConn connAlias rq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueStatus SQLiteStore {dbConn} rcvQueue status =
|
||||
setRcvQueueStatus SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} status =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
updateRcvQueueStatus dbConn rcvQueue status
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
|
||||
setRcvQueueActive :: SQLiteStore -> RcvQueue -> VerificationKey -> m ()
|
||||
setRcvQueueActive SQLiteStore {dbConn} rcvQueue verifyKey =
|
||||
setRcvQueueActive SQLiteStore {dbConn} RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
updateRcvQueueActive dbConn rcvQueue verifyKey
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET verify_key = :verify_key, status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[ ":verify_key" := Just verifyKey,
|
||||
":status" := Active,
|
||||
":host" := host,
|
||||
":port" := serializePort_ port,
|
||||
":rcv_id" := rcvId
|
||||
]
|
||||
|
||||
setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m ()
|
||||
setSndQueueStatus SQLiteStore {dbConn} sndQueue status =
|
||||
setSndQueueStatus SQLiteStore {dbConn} SndQueue {sndId, server = SMPServer {host, port}} status =
|
||||
-- ? throw error if queue doesn't exist?
|
||||
liftIO $
|
||||
updateSndQueueStatus dbConn sndQueue status
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE snd_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
createRcvMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> m InternalId
|
||||
createRcvMsg SQLiteStore {dbConn} connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
liftIOEither $
|
||||
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs)
|
||||
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> ConnAlias -> MsgBody -> InternalTs -> m InternalId
|
||||
createSndMsg SQLiteStore {dbConn} connAlias msgBody internalTs =
|
||||
liftIOEither $
|
||||
insertSndMsg dbConn connAlias msgBody internalTs
|
||||
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
insertRcvMsgBase_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
|
||||
updateHashRcv_ dbConn connAlias rcvMsgData
|
||||
|
||||
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
pure (internalId, internalSndId, prevSndHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
|
||||
liftIO . DB.withTransaction dbConn $ do
|
||||
insertSndMsgBase_ dbConn connAlias sndMsgData
|
||||
insertSndMsgDetails_ dbConn connAlias sndMsgData
|
||||
updateHashSnd_ dbConn connAlias sndMsgData
|
||||
|
||||
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg _st _connAlias _id = throwError SENotImplemented
|
||||
@@ -228,13 +314,6 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
createRcvQueueAndConn :: DB.Connection -> RcvQueue -> IO ()
|
||||
createRcvQueueAndConn dbConn rcvQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
|
||||
insertRcvQueue_ dbConn rcvQueue
|
||||
insertRcvConnection_ dbConn rcvQueue
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -266,22 +345,16 @@ insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
|
||||
0, 0, 0);
|
||||
0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
createSndQueueAndConn :: DB.Connection -> SndQueue -> IO ()
|
||||
createSndQueueAndConn dbConn sndQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
upsertServer_ dbConn (server (sndQueue :: SndQueue))
|
||||
insertSndQueue_ dbConn sndQueue
|
||||
insertSndConnection_ dbConn sndQueue
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn SndQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -311,28 +384,25 @@ insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id)
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
|
||||
0, 0, 0);
|
||||
0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
retrieveConnQueues :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
|
||||
retrieveConnQueues dbConn connAlias =
|
||||
DB.withTransaction -- Avoid inconsistent state between queue reads
|
||||
dbConn
|
||||
$ retrieveConnQueues_ dbConn connAlias
|
||||
|
||||
-- Separate transactionless version of retrieveConnQueues to be reused in other functions that already wrap
|
||||
-- multiple statements in transaction - otherwise they'd be attempting to start a transaction within a transaction
|
||||
retrieveConnQueues_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue, Maybe SndQueue)
|
||||
retrieveConnQueues_ dbConn connAlias = do
|
||||
rcvQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sndQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
return (rcvQ, sndQ)
|
||||
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connAlias = do
|
||||
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
|
||||
@@ -374,60 +444,8 @@ retrieveSndQueueByConnAlias_ dbConn connAlias = do
|
||||
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
|
||||
_ -> return Nothing
|
||||
|
||||
-- * getAllConnAliases helper
|
||||
|
||||
retrieveAllConnAliases :: DB.Connection -> IO [ConnAlias]
|
||||
retrieveAllConnAliases dbConn = do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
return (concat r)
|
||||
|
||||
-- * getRcvQueue helper
|
||||
|
||||
retrieveRcvQueue :: DB.Connection -> HostName -> Maybe ServiceName -> SMP.RecipientId -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueue dbConn host port rcvId = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.host = :host AND q.port = :port AND q.rcv_id = :rcv_id;
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
case r of
|
||||
[(keyHash, hst, prt, rId, connAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
|
||||
let srv = SMPServer hst (deserializePort_ prt) keyHash
|
||||
return . Just $ RcvQueue srv rId connAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> return Nothing
|
||||
|
||||
-- * deleteConn helper
|
||||
|
||||
deleteConnCascade :: DB.Connection -> ConnAlias -> IO ()
|
||||
deleteConnCascade dbConn connAlias =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
|
||||
-- * upgradeRcvConnToDuplex helpers
|
||||
|
||||
updateRcvConnWithSndQueue :: DB.Connection -> ConnAlias -> SndQueue -> IO (Either StoreError ())
|
||||
updateRcvConnWithSndQueue dbConn connAlias sndQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Just _rcvQ, Nothing) -> do
|
||||
upsertServer_ dbConn (server (sndQueue :: SndQueue))
|
||||
insertSndQueue_ dbConn sndQueue
|
||||
updateConnWithSndQueue_ dbConn connAlias sndQueue
|
||||
return $ Right ()
|
||||
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
|
||||
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
|
||||
_ -> return $ Left SEConnNotFound
|
||||
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -442,20 +460,6 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
|
||||
-- * upgradeSndConnToDuplex helpers
|
||||
|
||||
updateSndConnWithRcvQueue :: DB.Connection -> ConnAlias -> RcvQueue -> IO (Either StoreError ())
|
||||
updateSndConnWithRcvQueue dbConn connAlias rcvQueue =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Nothing, Just _sndQ) -> do
|
||||
upsertServer_ dbConn (server (rcvQueue :: RcvQueue))
|
||||
insertRcvQueue_ dbConn rcvQueue
|
||||
updateConnWithRcvQueue_ dbConn connAlias rcvQueue
|
||||
return $ Right ()
|
||||
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
|
||||
(Just _rcvQ, Just _sndQ) -> return $ Left (SEBadConnType CDuplex)
|
||||
_ -> return $ Left SEConnNotFound
|
||||
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
@@ -468,93 +472,40 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
|]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
|
||||
|
||||
-- * setRcvQueueStatus helper
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
-- ? throw error if queue doesn't exist?
|
||||
updateRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO ()
|
||||
updateRcvQueueStatus dbConn RcvQueue {rcvId, server = SMPServer {host, port}} status =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
|
||||
-- * setRcvQueueActive helper
|
||||
|
||||
-- ? throw error if queue doesn't exist?
|
||||
updateRcvQueueActive :: DB.Connection -> RcvQueue -> VerificationKey -> IO ()
|
||||
updateRcvQueueActive dbConn RcvQueue {rcvId, server = SMPServer {host, port}} verifyKey =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE rcv_queues
|
||||
SET verify_key = :verify_key, status = :status
|
||||
WHERE host = :host AND port = :port AND rcv_id = :rcv_id;
|
||||
|]
|
||||
[ ":verify_key" := Just verifyKey,
|
||||
":status" := Active,
|
||||
":host" := host,
|
||||
":port" := serializePort_ port,
|
||||
":rcv_id" := rcvId
|
||||
]
|
||||
|
||||
-- * setSndQueueStatus helper
|
||||
|
||||
-- ? throw error if queue doesn't exist?
|
||||
updateSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO ()
|
||||
updateSndQueueStatus dbConn SndQueue {sndId, server = SMPServer {host, port}} status =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE snd_queues
|
||||
SET status = :status
|
||||
WHERE host = :host AND port = :port AND snd_id = :snd_id;
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsg ::
|
||||
DB.Connection ->
|
||||
ConnAlias ->
|
||||
MsgBody ->
|
||||
InternalTs ->
|
||||
(ExternalSndId, ExternalSndTs) ->
|
||||
(BrokerId, BrokerTs) ->
|
||||
IO (Either StoreError InternalId)
|
||||
insertRcvMsg dbConn connAlias msgBody internalTs (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(Just _rcvQ, _) -> do
|
||||
(lastInternalId, lastInternalRcvId) <- retrieveLastInternalIdsRcv_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
let internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody
|
||||
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs)
|
||||
updateLastInternalIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
return $ Right internalId
|
||||
(Nothing, Just _sndQ) -> return $ Left (SEBadConnType CSnd)
|
||||
_ -> return $ Left SEConnNotFound
|
||||
|
||||
retrieveLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId)
|
||||
retrieveLastInternalIdsRcv_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalRcvId)] <-
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id
|
||||
SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalRcvId)
|
||||
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalRcvId -> MsgBody -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody = do
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id,
|
||||
last_internal_rcv_msg_id = :last_internal_rcv_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":conn_alias" := connAlias
|
||||
]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -570,15 +521,8 @@ insertRcvMsgBase_ dbConn connAlias internalId internalTs internalRcvId msgBody =
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertRcvMsgDetails_ ::
|
||||
DB.Connection ->
|
||||
ConnAlias ->
|
||||
InternalRcvId ->
|
||||
InternalId ->
|
||||
(ExternalSndId, ExternalSndTs) ->
|
||||
(BrokerId, BrokerTs) ->
|
||||
IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, externalSndTs) (brokerId, brokerTs) =
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -592,60 +536,65 @@ insertRcvMsgDetails_ dbConn connAlias internalRcvId internalId (externalSndId, e
|
||||
[ ":conn_alias" := connAlias,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":internal_id" := internalId,
|
||||
":external_snd_id" := externalSndId,
|
||||
":external_snd_ts" := externalSndTs,
|
||||
":broker_id" := brokerId,
|
||||
":broker_ts" := brokerTs,
|
||||
":external_snd_id" := fst senderMeta,
|
||||
":external_snd_ts" := snd senderMeta,
|
||||
":broker_id" := fst brokerMeta,
|
||||
":broker_ts" := snd brokerMeta,
|
||||
":rcv_status" := Received
|
||||
]
|
||||
|
||||
updateLastInternalIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastInternalIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_external_snd_msg_id = :last_external_snd_msg_id,
|
||||
last_rcv_msg_hash = :last_rcv_msg_hash
|
||||
WHERE conn_alias = :conn_alias
|
||||
AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id;
|
||||
|]
|
||||
[ ":last_external_snd_msg_id" := fst senderMeta,
|
||||
":last_rcv_msg_hash" := msgHash,
|
||||
":conn_alias" := connAlias,
|
||||
":last_internal_rcv_msg_id" := internalRcvId
|
||||
]
|
||||
|
||||
-- * updateSndIds helpers
|
||||
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalSndId, lastSndHash)
|
||||
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id, last_internal_rcv_msg_id = :last_internal_rcv_msg_id
|
||||
SET last_internal_msg_id = :last_internal_msg_id,
|
||||
last_internal_snd_msg_id = :last_internal_snd_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
]
|
||||
|
||||
-- * createSndMsg helpers
|
||||
|
||||
insertSndMsg :: DB.Connection -> ConnAlias -> MsgBody -> InternalTs -> IO (Either StoreError InternalId)
|
||||
insertSndMsg dbConn connAlias msgBody internalTs =
|
||||
DB.withTransaction dbConn $ do
|
||||
queues <- retrieveConnQueues_ dbConn connAlias
|
||||
case queues of
|
||||
(_, Just _sndQ) -> do
|
||||
(lastInternalId, lastInternalSndId) <- retrieveLastInternalIdsSnd_ dbConn connAlias
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
let internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody
|
||||
insertSndMsgDetails_ dbConn connAlias internalSndId internalId
|
||||
updateLastInternalIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
return $ Right internalId
|
||||
(Just _rcvQ, Nothing) -> return $ Left (SEBadConnType CRcv)
|
||||
_ -> return $ Left SEConnNotFound
|
||||
|
||||
retrieveLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId)
|
||||
retrieveLastInternalIdsSnd_ dbConn connAlias = do
|
||||
[(lastInternalId, lastInternalSndId)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT last_internal_msg_id, last_internal_snd_msg_id
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
return (lastInternalId, lastInternalSndId)
|
||||
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> InternalId -> InternalTs -> InternalSndId -> MsgBody -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody = do
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -661,8 +610,8 @@ insertSndMsgBase_ dbConn connAlias internalId internalTs internalSndId msgBody =
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> InternalSndId -> InternalId -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias internalSndId internalId =
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -677,16 +626,18 @@ insertSndMsgDetails_ dbConn connAlias internalSndId internalId =
|
||||
":snd_status" := Created
|
||||
]
|
||||
|
||||
updateLastInternalIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastInternalIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
|
||||
[sql|
|
||||
UPDATE connections
|
||||
SET last_internal_msg_id = :last_internal_msg_id, last_internal_snd_msg_id = :last_internal_snd_msg_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
SET last_snd_msg_hash = :last_snd_msg_hash
|
||||
WHERE conn_alias = :conn_alias
|
||||
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
[ ":last_snd_msg_hash" := msgHash,
|
||||
":conn_alias" := connAlias,
|
||||
":last_internal_snd_msg_id" := internalSndId
|
||||
]
|
||||
|
||||
@@ -90,6 +90,9 @@ connections =
|
||||
last_internal_msg_id INTEGER NOT NULL,
|
||||
last_internal_rcv_msg_id INTEGER NOT NULL,
|
||||
last_internal_snd_msg_id INTEGER NOT NULL,
|
||||
last_external_snd_msg_id INTEGER NOT NULL,
|
||||
last_rcv_msg_hash BLOB NOT NULL,
|
||||
last_snd_msg_hash BLOB NOT NULL,
|
||||
PRIMARY KEY (conn_alias),
|
||||
FOREIGN KEY (rcv_host, rcv_port, rcv_id) REFERENCES rcv_queues (host, port, rcv_id),
|
||||
FOREIGN KEY (snd_host, snd_port, snd_id) REFERENCES snd_queues (host, port, snd_id)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
module Simplex.Messaging.Agent.Transmission where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Applicative (optional, (<|>))
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Attoparsec.ByteString.Char8 (Parser)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
@@ -90,11 +90,11 @@ data ACommand (p :: AParty) where
|
||||
SEND :: MsgBody -> ACommand Client
|
||||
SENT :: AgentMsgId -> ACommand Agent
|
||||
MSG ::
|
||||
{ m_recipient :: (AgentMsgId, UTCTime),
|
||||
m_broker :: (MsgId, UTCTime),
|
||||
m_sender :: (AgentMsgId, UTCTime),
|
||||
m_status :: MsgStatus,
|
||||
m_body :: MsgBody
|
||||
{ recipientMeta :: (AgentMsgId, UTCTime),
|
||||
brokerMeta :: (MsgId, UTCTime),
|
||||
senderMeta :: (AgentMsgId, UTCTime),
|
||||
msgIntegrity :: MsgIntegrity,
|
||||
msgBody :: MsgBody
|
||||
} ->
|
||||
ACommand Agent
|
||||
-- ACK :: AgentMsgId -> ACommand Client
|
||||
@@ -142,7 +142,9 @@ parseSMPMessage = parse (smpMessageP <* A.endOfLine) $ AGENT A_MESSAGE
|
||||
SMPMessage
|
||||
<$> A.decimal <* A.space
|
||||
<*> tsISO8601P <* A.space
|
||||
<*> base64P <* A.endOfLine
|
||||
-- TODO previous message hash should become mandatory when we support HELLO and REPLY
|
||||
-- (for HELLO it would be the hash of SMPConfirmation)
|
||||
<*> (base64P <|> pure "") <* A.endOfLine
|
||||
<*> agentMessageP
|
||||
|
||||
serializeSMPMessage :: SMPMessage -> ByteString
|
||||
@@ -175,17 +177,11 @@ smpQueueInfoP =
|
||||
"smp::" *> (SMPQueueInfo <$> smpServerP <* "::" <*> base64P <* "::" <*> C.pubKeyP)
|
||||
|
||||
smpServerP :: Parser SMPServer
|
||||
smpServerP = SMPServer <$> server <*> port <*> kHash
|
||||
smpServerP = SMPServer <$> server <*> optional port <*> optional kHash
|
||||
where
|
||||
server = B.unpack <$> A.takeTill (A.inClass ":# ")
|
||||
port = fromChar ':' $ show <$> (A.decimal :: Parser Int)
|
||||
kHash = fromChar '#' C.keyHashP
|
||||
fromChar :: Char -> Parser a -> Parser (Maybe a)
|
||||
fromChar ch parser = do
|
||||
c <- A.peekChar
|
||||
if c == Just ch
|
||||
then A.char ch *> (Just <$> parser)
|
||||
else pure Nothing
|
||||
port = A.char ':' *> (B.unpack <$> A.takeWhile1 A.isDigit)
|
||||
kHash = A.char '#' *> C.keyHashP
|
||||
|
||||
parseAgentMessage :: ByteString -> Either AgentErrorType AMessage
|
||||
parseAgentMessage = parse agentMessageP $ AGENT A_MESSAGE
|
||||
@@ -241,10 +237,10 @@ type AgentMsgId = Int64
|
||||
|
||||
type SenderTimestamp = UTCTime
|
||||
|
||||
data MsgStatus = MsgOk | MsgError MsgErrorType
|
||||
data MsgIntegrity = MsgOk | MsgError MsgErrorType
|
||||
deriving (Eq, Show)
|
||||
|
||||
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash
|
||||
data MsgErrorType = MsgSkipped AgentMsgId AgentMsgId | MsgBadId AgentMsgId | MsgBadHash | MsgDuplicate
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | error type used in errors sent to agent clients
|
||||
@@ -319,22 +315,23 @@ commandP =
|
||||
sendCmd = ACmd SClient . SEND <$> A.takeByteString
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
message = do
|
||||
m_status <- status <* A.space
|
||||
m_recipient <- "R=" *> partyMeta A.decimal
|
||||
m_broker <- "B=" *> partyMeta base64P
|
||||
m_sender <- "S=" *> partyMeta A.decimal
|
||||
m_body <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {m_recipient, m_broker, m_sender, m_status, m_body}
|
||||
msgIntegrity <- integrity <* A.space
|
||||
recipientMeta <- "R=" *> partyMeta A.decimal
|
||||
brokerMeta <- "B=" *> partyMeta base64P
|
||||
senderMeta <- "S=" *> partyMeta A.decimal
|
||||
msgBody <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
|
||||
replyMode =
|
||||
" NO_REPLY" $> ReplyOff
|
||||
<|> A.space *> (ReplyVia <$> smpServerP)
|
||||
<|> pure ReplyOn
|
||||
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
|
||||
status = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
|
||||
integrity = "OK" $> MsgOk <|> "ERR " *> (MsgError <$> msgErrorType)
|
||||
msgErrorType =
|
||||
"ID " *> (MsgBadId <$> A.decimal)
|
||||
<|> "IDS " *> (MsgSkipped <$> A.decimal <* A.space <*> A.decimal)
|
||||
<|> "HASH" $> MsgBadHash
|
||||
<|> "DUPLICATE" $> MsgDuplicate
|
||||
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
|
||||
|
||||
parseCommand :: ByteString -> Either AgentErrorType ACmd
|
||||
@@ -350,14 +347,14 @@ serializeCommand = \case
|
||||
END -> "END"
|
||||
SEND msgBody -> "SEND " <> serializeMsg msgBody
|
||||
SENT mId -> "SENT " <> bshow mId
|
||||
MSG {m_recipient = (rmId, rTs), m_broker = (bmId, bTs), m_sender = (smId, sTs), m_status, m_body} ->
|
||||
MSG {recipientMeta = (rmId, rTs), brokerMeta = (bmId, bTs), senderMeta = (smId, sTs), msgIntegrity, msgBody} ->
|
||||
B.unwords
|
||||
[ "MSG",
|
||||
msgStatus m_status,
|
||||
serializeMsgIntegrity msgIntegrity,
|
||||
"R=" <> bshow rmId <> "," <> showTs rTs,
|
||||
"B=" <> encode bmId <> "," <> showTs bTs,
|
||||
"S=" <> bshow smId <> "," <> showTs sTs,
|
||||
serializeMsg m_body
|
||||
serializeMsg msgBody
|
||||
]
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
@@ -372,15 +369,16 @@ serializeCommand = \case
|
||||
ReplyOn -> ""
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
msgStatus :: MsgStatus -> ByteString
|
||||
msgStatus = \case
|
||||
serializeMsgIntegrity :: MsgIntegrity -> ByteString
|
||||
serializeMsgIntegrity = \case
|
||||
MsgOk -> "OK"
|
||||
MsgError e ->
|
||||
"ERR" <> case e of
|
||||
"ERR " <> case e of
|
||||
MsgSkipped fromMsgId toMsgId ->
|
||||
B.unwords ["NO_ID", bshow fromMsgId, bshow toMsgId]
|
||||
MsgBadId aMsgId -> "ID " <> bshow aMsgId
|
||||
MsgBadHash -> "HASH"
|
||||
MsgDuplicate -> "DUPLICATE"
|
||||
|
||||
agentErrorTypeP :: Parser AgentErrorType
|
||||
agentErrorTypeP =
|
||||
@@ -443,7 +441,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p))
|
||||
cmdWithMsgBody = \case
|
||||
SEND body -> SEND <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS status body -> MSG agentMsgId srvTS agentTS status <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
|
||||
cmd -> return $ Right cmd
|
||||
|
||||
-- TODO refactor with server
|
||||
|
||||
@@ -36,6 +36,7 @@ module Simplex.Messaging.Crypto
|
||||
encodePubKey,
|
||||
serializeKeyHash,
|
||||
getKeyHash,
|
||||
sha256Hash,
|
||||
privKeyP,
|
||||
pubKeyP,
|
||||
binaryPubKeyP,
|
||||
@@ -226,6 +227,9 @@ keyHashP = do
|
||||
getKeyHash :: ByteString -> KeyHash
|
||||
getKeyHash = KeyHash . hash
|
||||
|
||||
sha256Hash :: ByteString -> ByteString
|
||||
sha256Hash = BA.convert . (hash :: ByteString -> Digest SHA256)
|
||||
|
||||
serializeHeader :: Header -> ByteString
|
||||
serializeHeader Header {aesKey, ivBytes, authTag, msgSize} =
|
||||
unKey aesKey <> unIV ivBytes <> authTagToBS authTag <> (encodeWord32 . fromIntegral) msgSize
|
||||
|
||||
@@ -80,7 +80,7 @@ h #:# err = tryGet `shouldReturn` ()
|
||||
_ -> return ()
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg m_body <- MSG {m_body}
|
||||
pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk}
|
||||
|
||||
testDuplexConnection :: Handle -> Handle -> IO ()
|
||||
testDuplexConnection alice bob = do
|
||||
@@ -88,13 +88,13 @@ testDuplexConnection alice bob = do
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", CON)
|
||||
alice <# ("", "bob", CON)
|
||||
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT _) -> True; _ -> False
|
||||
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT _) -> True; _ -> False
|
||||
alice #: ("2", "bob", "SEND :hello") =#> \case ("2", "bob", SENT 1) -> True; _ -> False
|
||||
alice #: ("3", "bob", "SEND :how are you?") =#> \case ("3", "bob", SENT 2) -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT _) -> True; _ -> False
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") =#> \case ("14", "alice", SENT 3) -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
|
||||
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT _) -> True; _ -> False
|
||||
bob #: ("15", "alice", "SEND 9\nmessage 1") =#> \case ("15", "alice", SENT 4) -> True; _ -> False
|
||||
alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False
|
||||
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
|
||||
bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", ERR (SMP AUTH))
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
|
||||
module AgentTests.SQLiteTests (storeTests) where
|
||||
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import qualified Crypto.PubKey.RSA as R
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Time
|
||||
@@ -49,45 +52,50 @@ action `throwsError` e = runExceptT action `shouldReturn` Left e
|
||||
-- TODO add null port tests
|
||||
storeTests :: Spec
|
||||
storeTests = withStore do
|
||||
describe "compiled as threadsafe" testCompiledThreadsafe
|
||||
describe "foreign keys enabled" testForeignKeysEnabled
|
||||
describe "store setup" do
|
||||
testCompiledThreadsafe
|
||||
testForeignKeysEnabled
|
||||
describe "store methods" do
|
||||
describe "createRcvConn" do
|
||||
describe "unique" testCreateRcvConn
|
||||
describe "duplicate" testCreateRcvConnDuplicate
|
||||
describe "createSndConn" do
|
||||
describe "unique" testCreateSndConn
|
||||
describe "duplicate" testCreateSndConnDuplicate
|
||||
describe "getAllConnAliases" testGetAllConnAliases
|
||||
describe "getRcvQueue" testGetRcvQueue
|
||||
describe "deleteConn" do
|
||||
describe "RcvConnection" testDeleteRcvConn
|
||||
describe "SndConnection" testDeleteSndConn
|
||||
describe "DuplexConnection" testDeleteDuplexConn
|
||||
describe "upgradeRcvConnToDuplex" testUpgradeRcvConnToDuplex
|
||||
describe "upgradeSndConnToDuplex" testUpgradeSndConnToDuplex
|
||||
describe "set queue status" do
|
||||
describe "setRcvQueueStatus" testSetRcvQueueStatus
|
||||
describe "setSndQueueStatus" testSetSndQueueStatus
|
||||
describe "DuplexConnection" testSetQueueStatusDuplex
|
||||
xdescribe "RcvQueue does not exist" testSetRcvQueueStatusNoQueue
|
||||
xdescribe "SndQueue does not exist" testSetSndQueueStatusNoQueue
|
||||
describe "createRcvMsg" do
|
||||
describe "RcvQueue exists" testCreateRcvMsg
|
||||
describe "RcvQueue does not exist" testCreateRcvMsgNoQueue
|
||||
describe "createSndMsg" do
|
||||
describe "SndQueue exists" testCreateSndMsg
|
||||
describe "SndQueue does not exist" testCreateSndMsgNoQueue
|
||||
describe "Queue and Connection management" do
|
||||
describe "createRcvConn" do
|
||||
testCreateRcvConn
|
||||
testCreateRcvConnDuplicate
|
||||
describe "createSndConn" do
|
||||
testCreateSndConn
|
||||
testCreateSndConnDuplicate
|
||||
describe "getAllConnAliases" testGetAllConnAliases
|
||||
describe "getRcvQueue" testGetRcvQueue
|
||||
describe "deleteConn" do
|
||||
testDeleteRcvConn
|
||||
testDeleteSndConn
|
||||
testDeleteDuplexConn
|
||||
describe "upgradeRcvConnToDuplex" do
|
||||
testUpgradeRcvConnToDuplex
|
||||
describe "upgradeSndConnToDuplex" do
|
||||
testUpgradeSndConnToDuplex
|
||||
describe "set Queue status" do
|
||||
describe "setRcvQueueStatus" do
|
||||
testSetRcvQueueStatus
|
||||
testSetRcvQueueStatusNoQueue
|
||||
describe "setSndQueueStatus" do
|
||||
testSetSndQueueStatus
|
||||
testSetSndQueueStatusNoQueue
|
||||
testSetQueueStatusDuplex
|
||||
describe "Msg management" do
|
||||
describe "create Msg" do
|
||||
testCreateRcvMsg
|
||||
testCreateSndMsg
|
||||
testCreateRcvAndSndMsgs
|
||||
|
||||
testCompiledThreadsafe :: SpecWith SQLiteStore
|
||||
testCompiledThreadsafe = do
|
||||
it "should throw error if compiled sqlite library is not threadsafe" $ \store -> do
|
||||
it "compiled sqlite library should be threadsafe" $ \store -> do
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
|
||||
compileOptions `shouldNotContain` [["THREADSAFE=0"]]
|
||||
|
||||
testForeignKeysEnabled :: SpecWith SQLiteStore
|
||||
testForeignKeysEnabled = do
|
||||
it "should throw error if foreign keys are enabled" $ \store -> do
|
||||
it "foreign keys should be enabled" $ \store -> do
|
||||
let inconsistentQuery =
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
@@ -139,8 +147,7 @@ testCreateRcvConn = do
|
||||
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateRcvConnDuplicate = do
|
||||
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
createRcvConn store rcvQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
@@ -159,18 +166,15 @@ testCreateSndConn = do
|
||||
testCreateSndConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateSndConnDuplicate = do
|
||||
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
createSndConn store sndQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
testGetAllConnAliases :: SpecWith SQLiteStore
|
||||
testGetAllConnAliases = do
|
||||
it "should get all conn aliases" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
createSndConn store sndQueue1 {connAlias = "conn2"}
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"}
|
||||
getAllConnAliases store
|
||||
`returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias]
|
||||
|
||||
@@ -179,16 +183,14 @@ testGetRcvQueue = do
|
||||
it "should get RcvQueue" $ \store -> do
|
||||
let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash
|
||||
let recipientId = "1234"
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getRcvQueue store smpServer recipientId
|
||||
`returnsResult` rcvQueue1
|
||||
|
||||
testDeleteRcvConn :: SpecWith SQLiteStore
|
||||
testDeleteRcvConn = do
|
||||
it "should create RcvConnection and delete it" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
deleteConn store "conn1"
|
||||
@@ -200,8 +202,7 @@ testDeleteRcvConn = do
|
||||
testDeleteSndConn :: SpecWith SQLiteStore
|
||||
testDeleteSndConn = do
|
||||
it "should create SndConnection and delete it" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
@@ -213,10 +214,8 @@ testDeleteSndConn = do
|
||||
testDeleteDuplexConn :: SpecWith SQLiteStore
|
||||
testDeleteDuplexConn = do
|
||||
it "should create DuplexConnection and delete it" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
@@ -228,8 +227,7 @@ testDeleteDuplexConn = do
|
||||
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeRcvConnToDuplex = do
|
||||
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
@@ -242,16 +240,14 @@ testUpgradeRcvConnToDuplex = do
|
||||
}
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CSnd
|
||||
upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
|
||||
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeSndConnToDuplex = do
|
||||
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
@@ -266,16 +262,14 @@ testUpgradeSndConnToDuplex = do
|
||||
}
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CRcv
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
|
||||
testSetRcvQueueStatus :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatus = do
|
||||
it "should update status of RcvQueue" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
@@ -286,8 +280,7 @@ testSetRcvQueueStatus = do
|
||||
testSetSndQueueStatus :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatus = do
|
||||
it "should update status of SndQueue" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
@@ -298,10 +291,8 @@ testSetSndQueueStatus = do
|
||||
testSetQueueStatusDuplex :: SpecWith SQLiteStore
|
||||
testSetQueueStatusDuplex = do
|
||||
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Secured
|
||||
@@ -311,61 +302,87 @@ testSetQueueStatusDuplex = do
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn
|
||||
SCDuplex
|
||||
( DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}
|
||||
)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
|
||||
|
||||
testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatusNoQueue = do
|
||||
it "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do
|
||||
xit "should throw error on attempt to update status of non-existent RcvQueue" $ \store -> do
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
`throwsError` SEInternal ""
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
testSetSndQueueStatusNoQueue :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatusNoQueue = do
|
||||
it "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do
|
||||
xit "should throw error on attempt to update status of non-existent SndQueue" $ \store -> do
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`throwsError` SEInternal ""
|
||||
`throwsError` SEConnNotFound
|
||||
|
||||
hw :: ByteString
|
||||
hw = encodeUtf8 "Hello world!"
|
||||
|
||||
ts :: UTCTime
|
||||
ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
|
||||
mkRcvMsgData :: InternalId -> InternalRcvId -> ExternalSndId -> BrokerId -> MsgHash -> RcvMsgData
|
||||
mkRcvMsgData internalId internalRcvId externalSndId brokerId msgHash =
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
internalTs = ts,
|
||||
senderMeta = (externalSndId, ts),
|
||||
brokerMeta = (brokerId, ts),
|
||||
msgBody = hw,
|
||||
msgHash = msgHash,
|
||||
msgIntegrity = MsgOk
|
||||
}
|
||||
|
||||
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do
|
||||
updateRcvIds store rcvQueue
|
||||
`returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg store rcvQueue rcvMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateRcvMsg :: SpecWith SQLiteStore
|
||||
testCreateRcvMsg = do
|
||||
it "should create a RcvMsg and return InternalId" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
it "should reserve internal ids and create a RcvMsg" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
-- TODO getMsg to check message
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`returnsResult` InternalId 1
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
|
||||
testCreateRcvMsgNoQueue :: SpecWith SQLiteStore
|
||||
testCreateRcvMsgNoQueue = do
|
||||
it "should throw error on attempt to create a RcvMsg w/t a RcvQueue" $ \store -> do
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`throwsError` SEConnNotFound
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
createRcvMsg store "conn1" (encodeUtf8 "Hello world!") ts (1, ts) ("1", ts)
|
||||
`throwsError` SEBadConnType CSnd
|
||||
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
|
||||
mkSndMsgData internalId internalSndId msgHash =
|
||||
SndMsgData
|
||||
{ internalId,
|
||||
internalSndId,
|
||||
internalTs = ts,
|
||||
msgBody = hw,
|
||||
msgHash = msgHash
|
||||
}
|
||||
|
||||
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation
|
||||
testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds store sndQueue
|
||||
`returnsResult` (internalId, internalSndId, expectedPrevHash)
|
||||
createSndMsg store sndQueue sndMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateSndMsg :: SpecWith SQLiteStore
|
||||
testCreateSndMsg = do
|
||||
it "should create a SndMsg and return InternalId" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
-- TODO getMsg to check message
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`returnsResult` InternalId 1
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
|
||||
testCreateSndMsgNoQueue :: SpecWith SQLiteStore
|
||||
testCreateSndMsgNoQueue = do
|
||||
it "should throw error on attempt to create a SndMsg w/t a SndQueue" $ \store -> do
|
||||
let ts = UTCTime (fromGregorian 2021 02 24) (secondsToDiffTime 0)
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`throwsError` SEConnNotFound
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
createSndMsg store "conn1" (encodeUtf8 "Hello world!") ts
|
||||
`throwsError` SEBadConnType CRcv
|
||||
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
|
||||
testCreateRcvAndSndMsgs = do
|
||||
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
|
||||
Reference in New Issue
Block a user