mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-03-30 20:45:52 +00:00
introduction protocol (#156)
* commands to support introduction * agent messages / envelopes to support introductions * introductions and invitations table; insert record with random unique ID * store class methods and types for introductions * process INTRO and ACPT commands for connection introductions * fix tests: add MonadFail constraint, remove OK response to JOIN * process agent messages for introductions * ICON notification when introduction is completed * replace multiway if with case * correction * support random connection IDs * save additional connection fields, refactor create connection funcs * refactor * refactor * test duplex connection with random IDs * store methods for introductions * test introduction * fix parsing of CON agent message * test introduction with random connection IDs * broadcast with random connection and broadcast IDs * clean up sql
This commit is contained in:
committed by
GitHub
parent
46c3589604
commit
ab89963f45
@@ -1,9 +1,8 @@
|
||||
CREATE TABLE IF NOT EXISTS broadcasts (
|
||||
broadcast_id BLOB NOT NULL,
|
||||
PRIMARY KEY (broadcast_id)
|
||||
CREATE TABLE broadcasts (
|
||||
broadcast_id BLOB NOT NULL PRIMARY KEY
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS broadcast_connections (
|
||||
CREATE TABLE broadcast_connections (
|
||||
broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE,
|
||||
conn_alias BLOB NOT NULL REFERENCES connections (conn_alias),
|
||||
PRIMARY KEY (broadcast_id, conn_alias)
|
||||
|
||||
27
migrations/20210602_introductions.sql
Normal file
27
migrations/20210602_introductions.sql
Normal file
@@ -0,0 +1,27 @@
|
||||
CREATE TABLE conn_intros (
|
||||
intro_id BLOB NOT NULL PRIMARY KEY,
|
||||
to_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE,
|
||||
to_info BLOB, -- info about "to" connection sent to "re" connection
|
||||
to_status TEXT NOT NULL DEFAULT '', -- '', INV, CON
|
||||
re_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE,
|
||||
re_info BLOB NOT NULL, -- info about "re" connection sent to "to" connection
|
||||
re_status TEXT NOT NULL DEFAULT '', -- '', INV, CON
|
||||
queue_info BLOB
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE conn_invitations (
|
||||
inv_id BLOB NOT NULL PRIMARY KEY,
|
||||
via_conn BLOB REFERENCES connections (conn_alias) ON DELETE SET NULL,
|
||||
external_intro_id BLOB NOT NULL,
|
||||
conn_info BLOB, -- info about another connection
|
||||
queue_info BLOB, -- NULL if it's an initial introduction
|
||||
conn_id BLOB REFERENCES connections (conn_alias) -- created connection
|
||||
ON DELETE CASCADE
|
||||
DEFERRABLE INITIALLY DEFERRED,
|
||||
status TEXT DEFAULT '' -- '', 'ACPT', 'CON'
|
||||
) WITHOUT ROWID;
|
||||
|
||||
ALTER TABLE connections
|
||||
ADD via_inv BLOB REFERENCES conn_invitations (inv_id) ON DELETE RESTRICT;
|
||||
ALTER TABLE connections
|
||||
ADD conn_level INTEGER DEFAULT 0;
|
||||
@@ -61,7 +61,7 @@ import UnliftIO.STM
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent t cfg = do
|
||||
started <- newEmptyTMVarIO
|
||||
runSMPAgentBlocking t started cfg
|
||||
@@ -70,10 +70,10 @@ runSMPAgent t cfg = do
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do
|
||||
liftIO $ putLn h "Welcome to SMP v0.3.2 agent"
|
||||
c <- getSMPAgentClient
|
||||
@@ -97,7 +97,7 @@ logConnection c connected =
|
||||
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
|
||||
|
||||
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
|
||||
runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runSMPAgentClient c = do
|
||||
db <- asks $ dbFile . config
|
||||
s1 <- liftIO $ connectSQLiteStore db
|
||||
@@ -128,7 +128,7 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a ->
|
||||
logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client c@AgentClient {rcvQ, sndQ} st = forever loop
|
||||
where
|
||||
loop :: m ()
|
||||
@@ -147,6 +147,8 @@ withStore action = do
|
||||
Right c -> return c
|
||||
Left e -> throwError $ storeError e
|
||||
where
|
||||
-- TODO when parsing exception happens in store, the agent hangs;
|
||||
-- changing SQLError to SomeException does not help
|
||||
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
|
||||
handleInternal e = throwError . SEInternal $ bshow e
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
@@ -173,23 +175,28 @@ unsupportedEntity c _ corrId entity _ =
|
||||
|
||||
processConnCommand ::
|
||||
forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m ()
|
||||
processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
NEW -> createNewConnection conn
|
||||
JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode
|
||||
processConnCommand c@AgentClient {sndQ} st corrId conn@(Conn connId) = \case
|
||||
NEW -> createNewConnection Nothing 0 >>= uncurry respond
|
||||
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK)
|
||||
INTRO reEntity reInfo -> makeIntroduction reEntity reInfo
|
||||
ACPT inv eInfo -> acceptInvitation inv eInfo
|
||||
SUB -> subscribeConnection conn
|
||||
SUBALL -> subscribeAll
|
||||
SEND msgBody -> sendMessage c st corrId conn msgBody
|
||||
SEND msgBody -> sendClientMessage c st corrId conn msgBody
|
||||
OFF -> suspendConnection conn
|
||||
DEL -> deleteConnection conn
|
||||
where
|
||||
createNewConnection :: Entity 'Conn_ -> m ()
|
||||
createNewConnection (Conn cId) = do
|
||||
createNewConnection :: Maybe InvitationId -> Int -> m (Entity 'Conn_, ACommand 'Agent 'INV_)
|
||||
createNewConnection viaInv connLevel = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connAlias Maybe?
|
||||
-- make connId Maybe?
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv cId
|
||||
withStore $ createRcvConn st rq
|
||||
respond conn $ INV qInfo
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createRcvConn st g cData rq
|
||||
addSubscription c rq connId'
|
||||
pure (Conn connId', INV qInfo)
|
||||
|
||||
getSMPServer :: m SMPServer
|
||||
getSMPServer =
|
||||
@@ -200,16 +207,47 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
|
||||
pure $ servers L.!! i
|
||||
|
||||
joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m ()
|
||||
joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connAlias Maybe?
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo cId
|
||||
withStore $ createSndConn st sq
|
||||
joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m (Entity 'Conn_)
|
||||
joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createSndConn st g cData sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
when (replyMode == On) $ createReplyQueue cId sq
|
||||
-- TODO this response is disabled to avoid two responses in terminal client (OK + CON),
|
||||
-- respond conn OK
|
||||
when (replyMode == On) $ createReplyQueue connId' sq
|
||||
pure $ Conn connId'
|
||||
|
||||
makeIntroduction :: IntroEntity -> EntityInfo -> m ()
|
||||
makeIntroduction (IE reEntity) reInfo = case reEntity of
|
||||
Conn reConn ->
|
||||
withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case
|
||||
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
|
||||
g <- asks idsDrg
|
||||
introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo}
|
||||
sendControlMessage c sq $ A_INTRO (IE (Conn introId)) reInfo
|
||||
respond conn OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD UNSUPPORTED
|
||||
|
||||
acceptInvitation :: IntroEntity -> EntityInfo -> m ()
|
||||
acceptInvitation (IE invEntity) eInfo = case invEntity of
|
||||
Conn invId -> do
|
||||
withStore (getInvitation st invId) >>= \case
|
||||
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
|
||||
withStore (getConn st viaConn) >>= \case
|
||||
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
|
||||
Nothing -> do
|
||||
(conn', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId $ fromConn conn'
|
||||
sendControlMessage c sq $ A_INV (Conn externalIntroId) qInfo' eInfo
|
||||
respond conn' OK
|
||||
Just qInfo' -> do
|
||||
conn' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId $ fromConn conn'
|
||||
respond conn' OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
_ -> throwError $ CMD UNSUPPORTED
|
||||
|
||||
subscribeConnection :: Entity 'Conn_ -> m ()
|
||||
subscribeConnection conn'@(Conn cId) =
|
||||
@@ -222,7 +260,7 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
|
||||
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
|
||||
subscribeAll :: m ()
|
||||
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn)
|
||||
subscribeAll = withStore (getAllConnIds st) >>= mapM_ (subscribeConnection . Conn)
|
||||
|
||||
suspendConnection :: Entity 'Conn_ -> m ()
|
||||
suspendConnection (Conn cId) =
|
||||
@@ -246,25 +284,30 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
|
||||
removeSubscription c cId
|
||||
delConn
|
||||
|
||||
createReplyQueue :: ByteString -> SndQueue -> m ()
|
||||
createReplyQueue :: ConnId -> SndQueue -> m ()
|
||||
createReplyQueue cId sq = do
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv cId
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
addSubscription c rq cId
|
||||
withStore $ upgradeSndConnToDuplex st cId rq
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage = REPLY qInfo
|
||||
}
|
||||
sendControlMessage c sq $ REPLY qInfo
|
||||
|
||||
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
|
||||
respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp
|
||||
|
||||
sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
|
||||
sendMessage c st corrId (Conn cId) msgBody =
|
||||
sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m ()
|
||||
sendControlMessage c sq agentMessage = do
|
||||
senderTimestamp <- liftIO getCurrentTime
|
||||
sendAgentMessage c sq . serializeSMPMessage $
|
||||
SMPMessage
|
||||
{ senderMsgId = 0,
|
||||
senderTimestamp,
|
||||
previousMsgHash = "",
|
||||
agentMessage
|
||||
}
|
||||
|
||||
sendClientMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
|
||||
sendClientMessage c st corrId (Conn cId) msgBody =
|
||||
withStore (getConn st cId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
|
||||
SomeConn _ (SndConnection _ sq) -> sendMsg sq
|
||||
@@ -273,7 +316,7 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
sendMsg :: SndQueue -> m ()
|
||||
sendMsg sq = do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st cId
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
@@ -284,7 +327,7 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st sq $
|
||||
createSndMsg st cId $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId)
|
||||
@@ -292,15 +335,21 @@ sendMessage c st corrId (Conn cId) msgBody =
|
||||
processBroadcastCommand ::
|
||||
forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m ()
|
||||
processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
|
||||
NEW -> withStore (createBcast st bId) >> ok
|
||||
NEW -> createNewBroadcast
|
||||
ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok
|
||||
REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok
|
||||
LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn
|
||||
SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0)
|
||||
DEL -> withStore (deleteBcast st bId) >> ok
|
||||
where
|
||||
sendMsg :: MsgBody -> ConnAlias -> m ()
|
||||
sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody
|
||||
createNewBroadcast :: m ()
|
||||
createNewBroadcast = do
|
||||
g <- asks idsDrg
|
||||
bId' <- withStore $ createBcast st g bId
|
||||
respond (Broadcast bId') OK
|
||||
|
||||
sendMsg :: MsgBody -> ConnId -> m ()
|
||||
sendMsg msgBody cId = sendClientMessage c st corrId (Conn cId) msgBody
|
||||
|
||||
ok :: m ()
|
||||
ok = respond bcast OK
|
||||
@@ -308,7 +357,7 @@ processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
|
||||
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
|
||||
respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
-- TODO this will only process messages and notifications
|
||||
t <- atomically $ readTBQueue msgQ
|
||||
@@ -319,29 +368,33 @@ subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
withStore (getRcvConn st srv rId) >>= \case
|
||||
SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq
|
||||
SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq
|
||||
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
|
||||
SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq
|
||||
_ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX)
|
||||
where
|
||||
processSMP :: SConnType c -> RcvQueue -> m ()
|
||||
processSMP cType rq@RcvQueue {connAlias, status} =
|
||||
processSMP :: SConnType c -> ConnData -> RcvQueue -> m ()
|
||||
processSMP cType ConnData {connId} rq@RcvQueue {status} =
|
||||
case cmd of
|
||||
SMP.MSG srvMsgId srvTs msgBody -> do
|
||||
-- TODO deduplicate with previously received
|
||||
msg <- decryptAndVerify rq msgBody
|
||||
let msgHash = C.sha256Hash msg
|
||||
agentMsg <- liftEither $ parseSMPMessage msg
|
||||
case agentMsg of
|
||||
SMPConfirmation senderKey -> smpConfirmation senderKey
|
||||
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case parseSMPMessage msg of
|
||||
Left e -> notify $ ERR e
|
||||
Right (SMPConfirmation senderKey) -> smpConfirmation senderKey
|
||||
Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
|
||||
case agentMessage of
|
||||
HELLO verifyKey _ -> helloMsg verifyKey msgBody
|
||||
REPLY qInfo -> replyMsg qInfo
|
||||
A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
|
||||
A_INTRO entity eInfo -> introMsg entity eInfo
|
||||
A_INV conn qInfo cInfo -> invMsg conn qInfo cInfo
|
||||
A_REQ conn qInfo cInfo -> reqMsg conn qInfo cInfo
|
||||
A_CON conn -> conMsg conn
|
||||
sendAck c rq
|
||||
return ()
|
||||
SMP.END -> do
|
||||
removeSubscription c connAlias
|
||||
removeSubscription c connId
|
||||
logServer "<--" c srv rId "END"
|
||||
notify END
|
||||
_ -> do
|
||||
@@ -349,7 +402,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
notify . ERR $ BROKER UNEXPECTED
|
||||
where
|
||||
notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m ()
|
||||
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg
|
||||
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connId) msg
|
||||
|
||||
prohibited :: m ()
|
||||
prohibited = notify . ERR $ AGENT A_PROHIBITED
|
||||
@@ -376,7 +429,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
void $ verifyMessage (Just verifyKey) msgBody
|
||||
withStore $ setRcvQueueActive st rq verifyKey
|
||||
case cType of
|
||||
SCDuplex -> notify CON
|
||||
SCDuplex -> connected
|
||||
_ -> pure ()
|
||||
|
||||
replyMsg :: SMPQueueInfo -> m ()
|
||||
@@ -384,22 +437,87 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
logServer "<--" c srv rId "MSG <REPLY>"
|
||||
case cType of
|
||||
SCRcv -> do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
|
||||
withStore $ upgradeRcvConnToDuplex st connAlias sq
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
withStore $ upgradeRcvConnToDuplex st connId sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
notify CON
|
||||
connected
|
||||
_ -> prohibited
|
||||
|
||||
connected :: m ()
|
||||
connected = do
|
||||
withStore (getConnInvitation st connId) >>= \case
|
||||
Just (Invitation {invId, externalIntroId}, DuplexConnection _ _ sq) -> do
|
||||
withStore $ setInvitationStatus st invId InvCon
|
||||
sendControlMessage c sq $ A_CON (Conn externalIntroId)
|
||||
_ -> pure ()
|
||||
notify CON
|
||||
|
||||
introMsg :: IntroEntity -> EntityInfo -> m ()
|
||||
introMsg (IE entity) entityInfo = do
|
||||
logServer "<--" c srv rId "MSG <INTRO>"
|
||||
case (cType, entity) of
|
||||
(SCDuplex, intro@Conn {}) -> createInv intro Nothing entityInfo
|
||||
_ -> prohibited
|
||||
|
||||
invMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
|
||||
invMsg (Conn introId) qInfo toInfo = do
|
||||
logServer "<--" c srv rId "MSG <INV>"
|
||||
case cType of
|
||||
SCDuplex ->
|
||||
withStore (getIntro st introId) >>= \case
|
||||
Introduction {toConn, toStatus = IntroNew, reConn, reStatus = IntroNew}
|
||||
| toConn /= connId -> prohibited
|
||||
| otherwise ->
|
||||
withStore (addIntroInvitation st introId toInfo qInfo >> getConn st reConn) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> do
|
||||
sendControlMessage c sq $ A_REQ (Conn introId) qInfo toInfo
|
||||
withStore $ setIntroReStatus st introId IntroInv
|
||||
_ -> prohibited
|
||||
_ -> prohibited
|
||||
_ -> prohibited
|
||||
|
||||
reqMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
|
||||
reqMsg intro qInfo connInfo = do
|
||||
logServer "<--" c srv rId "MSG <REQ>"
|
||||
case cType of
|
||||
SCDuplex -> createInv intro (Just qInfo) connInfo
|
||||
_ -> prohibited
|
||||
|
||||
createInv :: Entity 'Conn_ -> Maybe SMPQueueInfo -> EntityInfo -> m ()
|
||||
createInv (Conn externalIntroId) qInfo entityInfo = do
|
||||
g <- asks idsDrg
|
||||
let newInv = NewInvitation {viaConn = connId, externalIntroId, entityInfo, qInfo}
|
||||
invId <- withStore $ createInvitation st g newInv
|
||||
notify $ REQ (IE (Conn invId)) entityInfo
|
||||
|
||||
conMsg :: Entity 'Conn_ -> m ()
|
||||
conMsg (Conn introId) = do
|
||||
logServer "<--" c srv rId "MSG <CON>"
|
||||
withStore (getIntro st introId) >>= \case
|
||||
Introduction {toConn, toStatus, reConn, reStatus}
|
||||
| toConn == connId && toStatus == IntroInv -> do
|
||||
withStore $ setIntroToStatus st introId IntroCon
|
||||
sendConMsg toConn reConn reStatus
|
||||
| reConn == connId && reStatus == IntroInv -> do
|
||||
withStore $ setIntroReStatus st introId IntroCon
|
||||
sendConMsg toConn reConn toStatus
|
||||
| otherwise -> prohibited
|
||||
where
|
||||
sendConMsg :: ConnId -> ConnId -> IntroStatus -> m ()
|
||||
sendConMsg toConn reConn IntroCon =
|
||||
atomically . writeTBQueue sndQ $ ATransmission "" (Conn toConn) $ ICON $ IE (Conn reConn)
|
||||
sendConMsg _ _ _ = pure ()
|
||||
|
||||
agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
|
||||
agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
|
||||
logServer "<--" c srv rId "MSG <MSG>"
|
||||
case status of
|
||||
Active -> do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
|
||||
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st connId
|
||||
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash
|
||||
withStore $
|
||||
createRcvMsg st rq $
|
||||
createRcvMsg st connId $
|
||||
RcvMsgData
|
||||
{ internalId,
|
||||
internalRcvId,
|
||||
@@ -438,8 +556,8 @@ connectToSendQueue c st sq senderKey verifyKey = do
|
||||
withStore $ setSndQueueStatus st sq Active
|
||||
|
||||
newSendQueue ::
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
(senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size
|
||||
(verifyKey, signKey) <- liftIO $ C.generateKeyPair size
|
||||
@@ -447,7 +565,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
|
||||
SndQueue
|
||||
{ server = smpServer,
|
||||
sndId = senderId,
|
||||
connAlias,
|
||||
sndPrivateKey,
|
||||
encryptKey,
|
||||
signKey,
|
||||
|
||||
@@ -15,6 +15,7 @@ module Simplex.Messaging.Agent.Client
|
||||
closeSMPServerClients,
|
||||
newReceiveQueue,
|
||||
subscribeQueue,
|
||||
addSubscription,
|
||||
sendConfirmation,
|
||||
sendHello,
|
||||
secureQueue,
|
||||
@@ -61,8 +62,8 @@ data AgentClient = AgentClient
|
||||
sndQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
smpClients :: TVar (Map SMPServer SMPClient),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
|
||||
subscrConns :: TVar (Map ConnAlias SMPServer),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnId)),
|
||||
subscrConns :: TVar (Map ConnId SMPServer),
|
||||
clientId :: Int
|
||||
}
|
||||
|
||||
@@ -78,7 +79,7 @@ newAgentClient cc AgentConfig {tbqSize} = do
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
@@ -106,7 +107,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
removeSubs >>= mapM_ (mapM_ notifySub)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeSubs :: IO (Maybe (Set ConnAlias))
|
||||
removeSubs :: IO (Maybe (Set ConnId))
|
||||
removeSubs = atomically $ do
|
||||
modifyTVar smpClients $ M.delete srv
|
||||
cs <- M.lookup srv <$> readTVar (subscrSrvrs c)
|
||||
@@ -117,7 +118,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
deleteKeys :: Ord k => Set k -> Map k a -> Map k a
|
||||
deleteKeys ks m = S.foldr' M.delete m ks
|
||||
|
||||
notifySub :: ConnAlias -> IO ()
|
||||
notifySub :: ConnId -> IO ()
|
||||
notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END
|
||||
|
||||
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
|
||||
@@ -158,8 +159,8 @@ smpClientError = \case
|
||||
SMPTransportError e -> BROKER $ TRANSPORT e
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv connAlias = do
|
||||
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo)
|
||||
newReceiveQueue c srv = do
|
||||
size <- asks $ rsaKeySize . config
|
||||
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size
|
||||
logServer "-->" c srv "" "NEW"
|
||||
@@ -170,7 +171,6 @@ newReceiveQueue c srv connAlias = do
|
||||
RcvQueue
|
||||
{ server = srv,
|
||||
rcvId,
|
||||
connAlias,
|
||||
rcvPrivateKey,
|
||||
sndId = Just sId,
|
||||
sndKey = Nothing,
|
||||
@@ -178,25 +178,24 @@ newReceiveQueue c srv connAlias = do
|
||||
verifyKey = Nothing,
|
||||
status = New
|
||||
}
|
||||
addSubscription c rq connAlias
|
||||
return (rq, SMPQueueInfo srv sId encryptKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m ()
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do
|
||||
withLogSMP c server rcvId "SUB" $ \smp ->
|
||||
subscribeSMPQueue smp rcvPrivateKey rcvId
|
||||
addSubscription c rq connAlias
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m ()
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c RcvQueue {server} connAlias = atomically $ do
|
||||
modifyTVar (subscrConns c) $ M.insert connAlias server
|
||||
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
|
||||
where
|
||||
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
|
||||
addSub :: Maybe (Set ConnId) -> Set ConnId
|
||||
addSub (Just cs) = S.insert connAlias cs
|
||||
addSub _ = S.singleton connAlias
|
||||
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
|
||||
removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
|
||||
cs <- readTVar subscrConns
|
||||
writeTVar subscrConns $ M.delete connAlias cs
|
||||
@@ -204,7 +203,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
|
||||
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
|
||||
(M.lookup connAlias cs)
|
||||
where
|
||||
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
|
||||
delSub :: Set ConnId -> Maybe (Set ConnId)
|
||||
delSub cs =
|
||||
let cs' = S.delete connAlias cs
|
||||
in if S.null cs' then Nothing else Just cs'
|
||||
|
||||
@@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol
|
||||
Entity (..),
|
||||
EntityTag (..),
|
||||
AnEntity (..),
|
||||
IntroEntity (..),
|
||||
EntityInfo,
|
||||
EntityCommand,
|
||||
entityCommand,
|
||||
ACommand (..),
|
||||
@@ -53,7 +55,7 @@ module Simplex.Messaging.Agent.Protocol
|
||||
ATransmission (..),
|
||||
ATransmissionOrError (..),
|
||||
ARawTransmission,
|
||||
ConnAlias,
|
||||
ConnId,
|
||||
ReplyMode (..),
|
||||
AckMode (..),
|
||||
OnOff (..),
|
||||
@@ -171,6 +173,13 @@ deriving instance Eq (Entity t)
|
||||
|
||||
deriving instance Show (Entity t)
|
||||
|
||||
instance TestEquality Entity where
|
||||
testEquality (Conn c) (Conn c') = refl c c'
|
||||
testEquality (OpenConn c) (OpenConn c') = refl c c'
|
||||
testEquality (Broadcast c) (Broadcast c') = refl c c'
|
||||
testEquality (AGroup c) (AGroup c') = refl c c'
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
entityId :: Entity t -> ByteString
|
||||
entityId = \case
|
||||
Conn bs -> bs
|
||||
@@ -195,7 +204,11 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where
|
||||
EntityCommand Conn_ NEW_ = ()
|
||||
EntityCommand Conn_ INV_ = ()
|
||||
EntityCommand Conn_ JOIN_ = ()
|
||||
EntityCommand Conn_ INTRO_ = ()
|
||||
EntityCommand Conn_ REQ_ = ()
|
||||
EntityCommand Conn_ ACPT_ = ()
|
||||
EntityCommand Conn_ CON_ = ()
|
||||
EntityCommand Conn_ ICON_ = ()
|
||||
EntityCommand Conn_ SUB_ = ()
|
||||
EntityCommand Conn_ SUBALL_ = ()
|
||||
EntityCommand Conn_ END_ = ()
|
||||
@@ -226,7 +239,11 @@ entityCommand = \case
|
||||
NEW -> Just Dict
|
||||
INV _ -> Just Dict
|
||||
JOIN {} -> Just Dict
|
||||
INTRO {} -> Just Dict
|
||||
REQ {} -> Just Dict
|
||||
ACPT {} -> Just Dict
|
||||
CON -> Just Dict
|
||||
ICON {} -> Just Dict
|
||||
SUB -> Just Dict
|
||||
SUBALL -> Just Dict
|
||||
END -> Just Dict
|
||||
@@ -258,7 +275,11 @@ data ACmdTag
|
||||
= NEW_
|
||||
| INV_
|
||||
| JOIN_
|
||||
| INTRO_
|
||||
| REQ_
|
||||
| ACPT_
|
||||
| CON_
|
||||
| ICON_
|
||||
| SUB_
|
||||
| SUBALL_
|
||||
| END_
|
||||
@@ -274,15 +295,31 @@ data ACmdTag
|
||||
| OK_
|
||||
| ERR_
|
||||
|
||||
type family Introduction (t :: EntityTag) :: Constraint where
|
||||
Introduction Conn_ = ()
|
||||
Introduction OpenConn_ = ()
|
||||
Introduction AGroup_ = ()
|
||||
Introduction t = (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " cannot be INTRO'd to"))
|
||||
|
||||
data IntroEntity = forall t. Introduction t => IE (Entity t)
|
||||
|
||||
instance Eq IntroEntity where
|
||||
IE e1 == IE e2 = isJust $ testEquality e1 e2
|
||||
|
||||
deriving instance Show IntroEntity
|
||||
|
||||
type EntityInfo = ByteString
|
||||
|
||||
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
|
||||
data ACommand (p :: AParty) (c :: ACmdTag) where
|
||||
NEW :: ACommand Client NEW_ -- response INV
|
||||
INV :: SMPQueueInfo -> ACommand Agent INV_
|
||||
JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK
|
||||
INTRO :: IntroEntity -> EntityInfo -> ACommand Client INTRO_
|
||||
REQ :: IntroEntity -> EntityInfo -> ACommand Agent INTRO_
|
||||
ACPT :: IntroEntity -> EntityInfo -> ACommand Client ACPT_
|
||||
CON :: ACommand Agent CON_ -- notification that connection is established
|
||||
-- TODO currently it automatically allows whoever sends the confirmation
|
||||
-- CONF :: OtherPartyId -> ACommand Agent
|
||||
-- LET :: OtherPartyId -> ACommand Client
|
||||
ICON :: IntroEntity -> ACommand Agent ICON_
|
||||
SUB :: ACommand Client SUB_
|
||||
SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all
|
||||
END :: ACommand Agent END_
|
||||
@@ -318,6 +355,7 @@ instance TestEquality (ACommand p) where
|
||||
testEquality c@INV {} c'@INV {} = refl c c'
|
||||
testEquality c@JOIN {} c'@JOIN {} = refl c c'
|
||||
testEquality CON CON = Just Refl
|
||||
testEquality c@ICON {} c'@ICON {} = refl c c'
|
||||
testEquality SUB SUB = Just Refl
|
||||
testEquality SUBALL SUBALL = Just Refl
|
||||
testEquality END END = Just Refl
|
||||
@@ -334,7 +372,7 @@ instance TestEquality (ACommand p) where
|
||||
testEquality c@ERR {} c'@ERR {} = refl c c'
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a)
|
||||
refl :: Eq a => a -> a -> Maybe (t :~: t)
|
||||
refl x x' = if x == x' then Just Refl else Nothing
|
||||
|
||||
-- | SMP message formats.
|
||||
@@ -366,6 +404,14 @@ data AMessage where
|
||||
REPLY :: SMPQueueInfo -> AMessage
|
||||
-- | agent envelope for the client message
|
||||
A_MSG :: MsgBody -> AMessage
|
||||
-- | agent message for introduction
|
||||
A_INTRO :: IntroEntity -> EntityInfo -> AMessage
|
||||
-- | agent envelope for the sent invitation
|
||||
A_INV :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
|
||||
-- | agent envelope for the forwarded invitation
|
||||
A_REQ :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
|
||||
-- | agent message for intro/group request
|
||||
A_CON :: Entity Conn_ -> AMessage
|
||||
deriving (Show)
|
||||
|
||||
-- | Parse SMP message.
|
||||
@@ -408,12 +454,22 @@ agentMessageP =
|
||||
"HELLO " *> hello
|
||||
<|> "REPLY " *> reply
|
||||
<|> "MSG " *> a_msg
|
||||
<|> "INTRO " *> a_intro
|
||||
<|> "INV " *> a_inv
|
||||
<|> "REQ " *> a_req
|
||||
<|> "CON " *> a_con
|
||||
where
|
||||
hello = HELLO <$> C.pubKeyP <*> ackMode
|
||||
reply = REPLY <$> smpQueueInfoP
|
||||
a_msg = do
|
||||
a_msg = A_MSG <$> binaryBody
|
||||
a_intro = A_INTRO <$> introEntityP <* A.space <*> binaryBody
|
||||
a_inv = invP A_INV
|
||||
a_req = invP A_REQ
|
||||
a_con = A_CON <$> connEntityP
|
||||
invP f = f <$> connEntityP <* A.space <*> smpQueueInfoP <* A.space <*> binaryBody
|
||||
binaryBody = do
|
||||
size :: Int <- A.decimal <* A.endOfLine
|
||||
A_MSG <$> A.take size <* A.endOfLine
|
||||
A.take size <* A.endOfLine
|
||||
ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On)
|
||||
|
||||
-- | SMP queue information parser.
|
||||
@@ -434,6 +490,13 @@ serializeAgentMessage = \case
|
||||
HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else ""
|
||||
REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo
|
||||
A_MSG body -> "MSG " <> serializeMsg body <> "\n"
|
||||
A_INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo <> "\n"
|
||||
A_INV conn qInfo eInfo -> "INV " <> serializeInv conn qInfo eInfo
|
||||
A_REQ conn qInfo eInfo -> "REQ " <> serializeInv conn qInfo eInfo
|
||||
A_CON conn -> "CON " <> serializeEntity conn
|
||||
where
|
||||
serializeInv conn qInfo eInfo =
|
||||
B.intercalate " " [serializeEntity conn, serializeSmpQueueInfo qInfo, serializeMsg eInfo] <> "\n"
|
||||
|
||||
-- | Serialize SMP queue information that is sent out-of-band.
|
||||
serializeSmpQueueInfo :: SMPQueueInfo -> ByteString
|
||||
@@ -457,7 +520,7 @@ instance IsString SMPServer where
|
||||
fromString = parseString . parseAll $ smpServerP
|
||||
|
||||
-- | SMP agent connection alias.
|
||||
type ConnAlias = ByteString
|
||||
type ConnId = ByteString
|
||||
|
||||
-- | Connection modes.
|
||||
data OnOff = On | Off deriving (Eq, Show, Read)
|
||||
@@ -614,10 +677,19 @@ anEntityP =
|
||||
<|> "B:" $> AE . Broadcast
|
||||
<|> "G:" $> AE . AGroup
|
||||
)
|
||||
<*> A.takeTill (== ' ')
|
||||
<*> A.takeTill wordEnd
|
||||
|
||||
entityConnP :: Parser (Entity Conn_)
|
||||
entityConnP = "C:" *> (Conn <$> A.takeTill (== ' '))
|
||||
connEntityP :: Parser (Entity Conn_)
|
||||
connEntityP = "C:" *> (Conn <$> A.takeTill wordEnd)
|
||||
|
||||
introEntityP :: Parser IntroEntity
|
||||
introEntityP =
|
||||
($)
|
||||
<$> ( "C:" $> IE . Conn
|
||||
<|> "O:" $> IE . OpenConn
|
||||
<|> "G:" $> IE . AGroup
|
||||
)
|
||||
<*> A.takeTill wordEnd
|
||||
|
||||
serializeEntity :: Entity t -> ByteString
|
||||
serializeEntity = \case
|
||||
@@ -632,6 +704,9 @@ commandP =
|
||||
"NEW" $> ACmd SClient NEW
|
||||
<|> "INV " *> invResp
|
||||
<|> "JOIN " *> joinCmd
|
||||
<|> "INTRO " *> introCmd
|
||||
<|> "REQ " *> reqCmd
|
||||
<|> "ACPT " *> acptCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all
|
||||
<|> "END" $> ACmd SAgent END
|
||||
@@ -645,16 +720,21 @@ commandP =
|
||||
<|> "LS" $> ACmd SClient LS
|
||||
<|> "MS " *> membersResp
|
||||
<|> "ERR " *> agentError
|
||||
<|> "ICON " *> iconMsg
|
||||
<|> "CON" $> ACmd SAgent CON
|
||||
<|> "OK" $> ACmd SAgent OK
|
||||
where
|
||||
invResp = ACmd SAgent . INV <$> smpQueueInfoP
|
||||
joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode)
|
||||
introCmd = ACmd SClient <$> introP INTRO
|
||||
reqCmd = ACmd SAgent <$> introP REQ
|
||||
acptCmd = ACmd SClient <$> introP ACPT
|
||||
sendCmd = ACmd SClient . SEND <$> A.takeByteString
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
addCmd = ACmd SClient . ADD <$> entityConnP
|
||||
removeCmd = ACmd SClient . REM <$> entityConnP
|
||||
membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ')
|
||||
addCmd = ACmd SClient . ADD <$> connEntityP
|
||||
removeCmd = ACmd SClient . REM <$> connEntityP
|
||||
membersResp = ACmd SAgent . MS <$> (connEntityP `A.sepBy'` A.char ' ')
|
||||
iconMsg = ACmd SAgent . ICON <$> introEntityP
|
||||
message = do
|
||||
msgIntegrity <- msgIntegrityP <* A.space
|
||||
recipientMeta <- "R=" *> partyMeta A.decimal
|
||||
@@ -662,6 +742,7 @@ commandP =
|
||||
senderMeta <- "S=" *> partyMeta A.decimal
|
||||
msgBody <- A.takeByteString
|
||||
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
|
||||
introP f = f <$> introEntityP <* A.space <*> A.takeByteString
|
||||
replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On)
|
||||
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
|
||||
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
|
||||
@@ -685,6 +766,9 @@ serializeCommand = \case
|
||||
NEW -> "NEW"
|
||||
INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo
|
||||
JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode
|
||||
INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo
|
||||
REQ (IE entity) eInfo -> "REQ " <> serializeIntro entity eInfo
|
||||
ACPT (IE entity) eInfo -> "ACPT " <> serializeIntro entity eInfo
|
||||
SUB -> "SUB"
|
||||
SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all
|
||||
END -> "END"
|
||||
@@ -706,6 +790,7 @@ serializeCommand = \case
|
||||
LS -> "LS"
|
||||
MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs)
|
||||
CON -> "CON"
|
||||
ICON (IE entity) -> "ICON " <> serializeEntity entity
|
||||
ERR e -> "ERR " <> serializeAgentError e
|
||||
OK -> "OK"
|
||||
where
|
||||
@@ -716,6 +801,9 @@ serializeCommand = \case
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
|
||||
serializeIntro :: Entity t -> ByteString -> ByteString
|
||||
serializeIntro entity eInfo = serializeEntity entity <> " " <> serializeMsg eInfo
|
||||
|
||||
-- | Serialize message integrity validation result.
|
||||
serializeMsgIntegrity :: MsgIntegrity -> ByteString
|
||||
serializeMsgIntegrity = \case
|
||||
@@ -794,9 +882,10 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p)
|
||||
hasEntityId (AE entity) (APartyCmd cmd) =
|
||||
APartyCmd <$> case cmd of
|
||||
-- NEW and JOIN have optional entity
|
||||
-- NEW, JOIN and ACPT have optional entity
|
||||
NEW -> Right cmd
|
||||
JOIN _ _ -> Right cmd
|
||||
JOIN {} -> Right cmd
|
||||
ACPT {} -> Right cmd
|
||||
-- ERROR response does not always have entity
|
||||
ERR _ -> Right cmd
|
||||
-- other responses must have entity
|
||||
@@ -818,6 +907,9 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
APartyCmd <$$> case cmd of
|
||||
SEND body -> SEND <$$> getMsgBody body
|
||||
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
|
||||
INTRO entity eInfo -> INTRO entity <$$> getMsgBody eInfo
|
||||
REQ entity eInfo -> REQ entity <$$> getMsgBody eInfo
|
||||
ACPT entity eInfo -> ACPT entity <$$> getMsgBody eInfo
|
||||
_ -> pure $ Right cmd
|
||||
|
||||
-- TODO refactor with server
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store where
|
||||
|
||||
import Control.Concurrent.STM (TVar)
|
||||
import Control.Exception (Exception)
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
import Data.Text (Text)
|
||||
import Data.Time (UTCTime)
|
||||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
@@ -30,33 +35,45 @@ import qualified Simplex.Messaging.Protocol as SMP
|
||||
-- | Store class type. Defines store access methods for implementations.
|
||||
class Monad m => MonadAgentStore s m where
|
||||
-- Queue and Connection management
|
||||
createRcvConn :: s -> RcvQueue -> m ()
|
||||
createSndConn :: s -> SndQueue -> m ()
|
||||
getConn :: s -> ConnAlias -> m SomeConn
|
||||
getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all
|
||||
createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
|
||||
createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
|
||||
getConn :: s -> ConnId -> m SomeConn
|
||||
getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all
|
||||
getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
deleteConn :: s -> ConnAlias -> m ()
|
||||
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
|
||||
deleteConn :: s -> ConnId -> m ()
|
||||
upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m ()
|
||||
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
|
||||
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
|
||||
|
||||
-- Msg management
|
||||
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
|
||||
updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> ConnId -> RcvMsgData -> m ()
|
||||
|
||||
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
|
||||
updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
|
||||
|
||||
getMsg :: s -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg :: s -> ConnId -> InternalId -> m Msg
|
||||
|
||||
-- Broadcasts
|
||||
createBcast :: s -> BroadcastId -> m ()
|
||||
addBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
|
||||
removeBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
|
||||
createBcast :: s -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
|
||||
addBcastConn :: s -> BroadcastId -> ConnId -> m ()
|
||||
removeBcastConn :: s -> BroadcastId -> ConnId -> m ()
|
||||
deleteBcast :: s -> BroadcastId -> m ()
|
||||
getBcast :: s -> BroadcastId -> m [ConnAlias]
|
||||
getBcast :: s -> BroadcastId -> m [ConnId]
|
||||
|
||||
-- Introductions
|
||||
createIntro :: s -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
|
||||
getIntro :: s -> IntroId -> m Introduction
|
||||
addIntroInvitation :: s -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
|
||||
setIntroToStatus :: s -> IntroId -> IntroStatus -> m ()
|
||||
setIntroReStatus :: s -> IntroId -> IntroStatus -> m ()
|
||||
createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
|
||||
getInvitation :: s -> InvitationId -> m Invitation
|
||||
addInvitationConn :: s -> InvitationId -> ConnId -> m ()
|
||||
getConnInvitation :: s -> ConnId -> m (Maybe (Invitation, Connection CDuplex))
|
||||
setInvitationStatus :: s -> InvitationId -> InvitationStatus -> m ()
|
||||
|
||||
-- * Queue types
|
||||
|
||||
@@ -64,7 +81,6 @@ class Monad m => MonadAgentStore s m where
|
||||
data RcvQueue = RcvQueue
|
||||
{ server :: SMPServer,
|
||||
rcvId :: SMP.RecipientId,
|
||||
connAlias :: ConnAlias,
|
||||
rcvPrivateKey :: RecipientPrivateKey,
|
||||
sndId :: Maybe SMP.SenderId,
|
||||
sndKey :: Maybe SenderPublicKey,
|
||||
@@ -78,7 +94,6 @@ data RcvQueue = RcvQueue
|
||||
data SndQueue = SndQueue
|
||||
{ server :: SMPServer,
|
||||
sndId :: SMP.SenderId,
|
||||
connAlias :: ConnAlias,
|
||||
sndPrivateKey :: SenderPrivateKey,
|
||||
encryptKey :: EncryptionKey,
|
||||
signKey :: SignatureKey,
|
||||
@@ -102,9 +117,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show)
|
||||
-- - DuplexConnection is a connection that has both receive and send queues set up,
|
||||
-- typically created by upgrading a receive or a send connection with a missing queue.
|
||||
data Connection (d :: ConnType) where
|
||||
RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnAlias -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnData -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
|
||||
deriving instance Eq (Connection d)
|
||||
|
||||
@@ -141,6 +156,9 @@ instance Eq SomeConn where
|
||||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
data ConnData = ConnData {connId :: ConnId, viaInv :: Maybe InvitationId, connLevel :: Int}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- * Message integrity validation types
|
||||
|
||||
type MsgHash = ByteString
|
||||
@@ -263,7 +281,7 @@ type DeliveredTs = UTCTime
|
||||
|
||||
-- | Base message data independent of direction.
|
||||
data MsgBase = MsgBase
|
||||
{ connAlias :: ConnAlias,
|
||||
{ connAlias :: ConnId,
|
||||
-- | Monotonically increasing id of a message per connection, internal to the agent.
|
||||
-- Internal Id preserves ordering between both received and sent messages, and is needed
|
||||
-- to track the order of the conversation (which can be different for the sender / receiver)
|
||||
@@ -281,12 +299,87 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
|
||||
|
||||
type InternalTs = UTCTime
|
||||
|
||||
-- * Introduction types
|
||||
|
||||
data NewIntroduction = NewIntroduction
|
||||
{ toConn :: ConnId,
|
||||
reConn :: ConnId,
|
||||
reInfo :: ByteString
|
||||
}
|
||||
|
||||
data Introduction = Introduction
|
||||
{ introId :: IntroId,
|
||||
toConn :: ConnId,
|
||||
toInfo :: Maybe ByteString,
|
||||
toStatus :: IntroStatus,
|
||||
reConn :: ConnId,
|
||||
reInfo :: ByteString,
|
||||
reStatus :: IntroStatus,
|
||||
qInfo :: Maybe SMPQueueInfo
|
||||
}
|
||||
|
||||
data IntroStatus = IntroNew | IntroInv | IntroCon
|
||||
deriving (Eq)
|
||||
|
||||
serializeIntroStatus :: IntroStatus -> Text
|
||||
serializeIntroStatus = \case
|
||||
IntroNew -> ""
|
||||
IntroInv -> "INV"
|
||||
IntroCon -> "CON"
|
||||
|
||||
introStatusT :: Text -> Maybe IntroStatus
|
||||
introStatusT = \case
|
||||
"" -> Just IntroNew
|
||||
"INV" -> Just IntroInv
|
||||
"CON" -> Just IntroCon
|
||||
_ -> Nothing
|
||||
|
||||
type IntroId = ByteString
|
||||
|
||||
data NewInvitation = NewInvitation
|
||||
{ viaConn :: ConnId,
|
||||
externalIntroId :: IntroId,
|
||||
entityInfo :: EntityInfo,
|
||||
qInfo :: Maybe SMPQueueInfo
|
||||
}
|
||||
|
||||
data Invitation = Invitation
|
||||
{ invId :: InvitationId,
|
||||
viaConn :: ConnId,
|
||||
externalIntroId :: IntroId,
|
||||
entityInfo :: EntityInfo,
|
||||
qInfo :: Maybe SMPQueueInfo,
|
||||
connId :: Maybe ConnId,
|
||||
status :: InvitationStatus
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data InvitationStatus = InvNew | InvAcpt | InvCon
|
||||
deriving (Eq, Show)
|
||||
|
||||
serializeInvStatus :: InvitationStatus -> Text
|
||||
serializeInvStatus = \case
|
||||
InvNew -> ""
|
||||
InvAcpt -> "ACPT"
|
||||
InvCon -> "CON"
|
||||
|
||||
invStatusT :: Text -> Maybe InvitationStatus
|
||||
invStatusT = \case
|
||||
"" -> Just InvNew
|
||||
"ACPT" -> Just InvAcpt
|
||||
"CON" -> Just InvCon
|
||||
_ -> Nothing
|
||||
|
||||
type InvitationId = ByteString
|
||||
|
||||
-- * Store errors
|
||||
|
||||
-- | Agent store error.
|
||||
data StoreError
|
||||
= -- | IO exceptions in store actions.
|
||||
SEInternal ByteString
|
||||
| -- | failed to generate unique random ID
|
||||
SEUniqueID
|
||||
| -- | Connection alias not found (or both queues absent).
|
||||
SEConnNotFound
|
||||
| -- | Connection alias already used.
|
||||
@@ -298,6 +391,10 @@ data StoreError
|
||||
SEBcastNotFound
|
||||
| -- | Broadcast ID already used.
|
||||
SEBcastDuplicate
|
||||
| -- | Introduction ID not found.
|
||||
SEIntroNotFound
|
||||
| -- | Invitation ID not found.
|
||||
SEInvitationNotFound
|
||||
| -- | Currently not used. The intention was to pass current expected queue status in methods,
|
||||
-- as we always know what it should be at any stage of the protocol,
|
||||
-- and in case it does not match use this error.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
@@ -11,6 +12,7 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
@@ -22,14 +24,19 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Monad (unless, when)
|
||||
import Control.Concurrent.STM (TVar, atomically, stateTVar)
|
||||
import Control.Monad (join, unless, when)
|
||||
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.ByteString.Base64 (encode)
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (find)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Text (isPrefixOf)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field)
|
||||
@@ -66,8 +73,8 @@ createSQLiteStore dbFilePath = do
|
||||
let dbDir = takeDirectory dbFilePath
|
||||
createDirectoryIfMissing False dbDir
|
||||
store <- connectSQLiteStore dbFilePath
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
|
||||
let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions)
|
||||
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[Text]]
|
||||
let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions)
|
||||
case threadsafeOption of
|
||||
Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code."
|
||||
Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found"
|
||||
@@ -107,13 +114,16 @@ connectSQLiteStore dbFilePath = do
|
||||
|]
|
||||
pure SQLiteStore {dbFilePath, dbConn, dbNew}
|
||||
|
||||
checkConstraint :: StoreError -> IO () -> IO (Either StoreError ())
|
||||
checkConstraint err action = first handleError <$> E.try action
|
||||
where
|
||||
handleError :: SQLError -> StoreError
|
||||
handleError e
|
||||
| DB.sqlError e == DB.ErrorConstraint = err
|
||||
| otherwise = SEInternal $ bshow e
|
||||
checkConstraint :: StoreError -> IO a -> IO (Either StoreError a)
|
||||
checkConstraint err action = first (handleSQLError err) <$> E.try action
|
||||
|
||||
checkConstraint' :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
|
||||
checkConstraint' err action = action `E.catch` (pure . Left . handleSQLError err)
|
||||
|
||||
handleSQLError :: StoreError -> SQLError -> StoreError
|
||||
handleSQLError err e
|
||||
| DB.sqlError e == DB.ErrorConstraint = err
|
||||
| otherwise = SEInternal $ bshow e
|
||||
|
||||
withTransaction :: forall a. DB.Connection -> IO a -> IO a
|
||||
withTransaction db a = loop 100 100_000
|
||||
@@ -128,33 +138,43 @@ withTransaction db a = loop 100 100_000
|
||||
else E.throwIO e
|
||||
|
||||
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
|
||||
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
|
||||
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
|
||||
liftIOEither $
|
||||
checkConstraint SEConnDuplicate $
|
||||
withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn q
|
||||
insertRcvConnection_ dbConn q
|
||||
createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
|
||||
createRcvConn SQLiteStore {dbConn} gVar cData q@RcvQueue {server} =
|
||||
-- TODO if schema has to be restarted, this function can be refactored
|
||||
-- to create connection first using createWithRandomId
|
||||
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
|
||||
getConnId_ dbConn gVar cData >>= traverse create
|
||||
where
|
||||
create :: ConnId -> IO ConnId
|
||||
create connId = do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn connId q
|
||||
insertRcvConnection_ dbConn cData {connId} q
|
||||
pure connId
|
||||
|
||||
createSndConn :: SQLiteStore -> SndQueue -> m ()
|
||||
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
|
||||
liftIOEither $
|
||||
checkConstraint SEConnDuplicate $
|
||||
withTransaction dbConn $ do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn q
|
||||
insertSndConnection_ dbConn q
|
||||
createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
|
||||
createSndConn SQLiteStore {dbConn} gVar cData q@SndQueue {server} =
|
||||
-- TODO if schema has to be restarted, this function can be refactored
|
||||
-- to create connection first using createWithRandomId
|
||||
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
|
||||
getConnId_ dbConn gVar cData >>= traverse create
|
||||
where
|
||||
create :: ConnId -> IO ConnId
|
||||
create connId = do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn connId q
|
||||
insertSndConnection_ dbConn cData {connId} q
|
||||
pure connId
|
||||
|
||||
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connAlias =
|
||||
getConn :: SQLiteStore -> ConnId -> m SomeConn
|
||||
getConn SQLiteStore {dbConn} connId =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias
|
||||
getConn_ dbConn connId
|
||||
|
||||
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
|
||||
getAllConnAliases SQLiteStore {dbConn} =
|
||||
getAllConnIds :: SQLiteStore -> m [ConnId]
|
||||
getAllConnIds SQLiteStore {dbConn} =
|
||||
liftIO $ do
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
|
||||
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnId]]
|
||||
return (concat r)
|
||||
|
||||
getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
@@ -169,37 +189,37 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
|]
|
||||
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
|
||||
>>= \case
|
||||
[Only connAlias] -> getConn_ dbConn connAlias
|
||||
[Only connId] -> getConn_ dbConn connId
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
deleteConn :: SQLiteStore -> ConnAlias -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connAlias =
|
||||
deleteConn :: SQLiteStore -> ConnId -> m ()
|
||||
deleteConn SQLiteStore {dbConn} connId =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
|
||||
upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m ()
|
||||
upgradeRcvConnToDuplex SQLiteStore {dbConn} connId sq@SndQueue {server} =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Right (SomeConn _ RcvConnection {}) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertSndQueue_ dbConn sq
|
||||
updateConnWithSndQueue_ dbConn connAlias sq
|
||||
insertSndQueue_ dbConn connId sq
|
||||
updateConnWithSndQueue_ dbConn connId sq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
|
||||
upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m ()
|
||||
upgradeSndConnToDuplex SQLiteStore {dbConn} connId rq@RcvQueue {server} =
|
||||
liftIOEither . withTransaction dbConn $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Right (SomeConn _ SndConnection {}) -> do
|
||||
upsertServer_ dbConn server
|
||||
insertRcvQueue_ dbConn rq
|
||||
updateConnWithRcvQueue_ dbConn connAlias rq
|
||||
insertRcvQueue_ dbConn connId rq
|
||||
updateConnWithRcvQueue_ dbConn connId rq
|
||||
pure $ Right ()
|
||||
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
_ -> pure $ Left SEConnNotFound
|
||||
@@ -248,83 +268,233 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
|]
|
||||
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
|
||||
|
||||
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
|
||||
updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
updateRcvIds SQLiteStore {dbConn} connId =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
|
||||
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connId
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
|
||||
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
|
||||
updateLastIdsRcv_ dbConn connId internalId internalRcvId
|
||||
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
|
||||
createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m ()
|
||||
createRcvMsg SQLiteStore {dbConn} connId rcvMsgData =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
insertRcvMsgBase_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
|
||||
updateHashRcv_ dbConn connAlias rcvMsgData
|
||||
insertRcvMsgBase_ dbConn connId rcvMsgData
|
||||
insertRcvMsgDetails_ dbConn connId rcvMsgData
|
||||
updateHashRcv_ dbConn connId rcvMsgData
|
||||
|
||||
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
|
||||
updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
updateSndIds SQLiteStore {dbConn} connId =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
|
||||
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connId
|
||||
let internalId = InternalId $ unId lastInternalId + 1
|
||||
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
|
||||
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
|
||||
updateLastIdsSnd_ dbConn connId internalId internalSndId
|
||||
pure (internalId, internalSndId, prevSndHash)
|
||||
|
||||
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
|
||||
createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m ()
|
||||
createSndMsg SQLiteStore {dbConn} connId sndMsgData =
|
||||
liftIO . withTransaction dbConn $ do
|
||||
insertSndMsgBase_ dbConn connAlias sndMsgData
|
||||
insertSndMsgDetails_ dbConn connAlias sndMsgData
|
||||
updateHashSnd_ dbConn connAlias sndMsgData
|
||||
insertSndMsgBase_ dbConn connId sndMsgData
|
||||
insertSndMsgDetails_ dbConn connId sndMsgData
|
||||
updateHashSnd_ dbConn connId sndMsgData
|
||||
|
||||
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
|
||||
getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg
|
||||
getMsg _st _connAlias _id = throwError SENotImplemented
|
||||
|
||||
createBcast :: SQLiteStore -> BroadcastId -> m ()
|
||||
createBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither $
|
||||
checkConstraint SEBcastDuplicate $
|
||||
DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
|
||||
createBcast :: SQLiteStore -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
|
||||
createBcast SQLiteStore {dbConn} gVar bcastId = liftIOEither $ case bcastId of
|
||||
"" -> createWithRandomId gVar create
|
||||
bId -> checkConstraint SEBcastDuplicate $ create bId $> bId
|
||||
where
|
||||
create bId = DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
|
||||
|
||||
addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
|
||||
addBcastConn SQLiteStore {dbConn} bId connAlias =
|
||||
addBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
|
||||
addBcastConn SQLiteStore {dbConn} bId connId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
getConn_ dbConn connAlias >>= \case
|
||||
getConn_ dbConn connId >>= \case
|
||||
Left _ -> pure $ Left SEConnNotFound
|
||||
Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv
|
||||
Right _ ->
|
||||
checkConstraint SEConnDuplicate $
|
||||
DB.execute
|
||||
dbConn
|
||||
"INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);"
|
||||
(bId, connAlias)
|
||||
[sql|
|
||||
INSERT INTO broadcast_connections
|
||||
(broadcast_id, conn_alias) VALUES (?, ?);
|
||||
|]
|
||||
(bId, connId)
|
||||
|
||||
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
|
||||
removeBcastConn SQLiteStore {dbConn} bId connAlias =
|
||||
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
|
||||
removeBcastConn SQLiteStore {dbConn} bId connId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
bcastConnExists_ dbConn bId connAlias >>= \case
|
||||
bcastConnExists_ dbConn bId connId >>= \case
|
||||
False -> pure $ Left SEConnNotFound
|
||||
_ ->
|
||||
Right
|
||||
<$> DB.execute
|
||||
dbConn
|
||||
"DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;"
|
||||
(bId, connAlias)
|
||||
[sql|
|
||||
DELETE FROM broadcast_connections
|
||||
WHERE broadcast_id = ? AND conn_alias = ?;
|
||||
|]
|
||||
(bId, connId)
|
||||
|
||||
deleteBcast :: SQLiteStore -> BroadcastId -> m ()
|
||||
deleteBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias]
|
||||
getBcast :: SQLiteStore -> BroadcastId -> m [ConnId]
|
||||
getBcast SQLiteStore {dbConn} bId =
|
||||
liftIOEither . checkBroadcast dbConn bId $
|
||||
Right . map fromOnly
|
||||
<$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
createIntro :: SQLiteStore -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
|
||||
createIntro SQLiteStore {dbConn} gVar NewIntroduction {toConn, reConn, reInfo} =
|
||||
liftIOEither . createWithRandomId gVar $ \introId ->
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO conn_intros
|
||||
(intro_id, to_conn, re_conn, re_info) VALUES (?, ?, ?, ?);
|
||||
|]
|
||||
(introId, toConn, reConn, reInfo)
|
||||
|
||||
getIntro :: SQLiteStore -> IntroId -> m Introduction
|
||||
getIntro SQLiteStore {dbConn} introId =
|
||||
liftIOEither $
|
||||
intro
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT to_conn, to_info, to_status, re_conn, re_info, re_status, queue_info
|
||||
FROM conn_intros
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(Only introId)
|
||||
where
|
||||
intro [(toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo)] =
|
||||
Right $ Introduction {introId, toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo}
|
||||
intro _ = Left SEIntroNotFound
|
||||
|
||||
addIntroInvitation :: SQLiteStore -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
|
||||
addIntroInvitation SQLiteStore {dbConn} introId toInfo qInfo =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET to_info = :to_info,
|
||||
queue_info = :queue_info,
|
||||
to_status = :to_status
|
||||
WHERE intro_id = :intro_id;
|
||||
|]
|
||||
[ ":to_info" := toInfo,
|
||||
":queue_info" := Just qInfo,
|
||||
":to_status" := IntroInv,
|
||||
":intro_id" := introId
|
||||
]
|
||||
|
||||
setIntroToStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
|
||||
setIntroToStatus SQLiteStore {dbConn} introId toStatus =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET to_status = ?
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(toStatus, introId)
|
||||
|
||||
setIntroReStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
|
||||
setIntroReStatus SQLiteStore {dbConn} introId reStatus =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_intros
|
||||
SET re_status = ?
|
||||
WHERE intro_id = ?;
|
||||
|]
|
||||
(reStatus, introId)
|
||||
|
||||
createInvitation :: SQLiteStore -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
|
||||
createInvitation SQLiteStore {dbConn} gVar NewInvitation {viaConn, externalIntroId, entityInfo, qInfo} =
|
||||
liftIOEither . createWithRandomId gVar $ \invId ->
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO conn_invitations
|
||||
(inv_id, via_conn, external_intro_id, conn_info, queue_info) VALUES (?, ?, ?, ?, ?);
|
||||
|]
|
||||
(invId, viaConn, externalIntroId, entityInfo, qInfo)
|
||||
|
||||
getInvitation :: SQLiteStore -> InvitationId -> m Invitation
|
||||
getInvitation SQLiteStore {dbConn} invId =
|
||||
liftIOEither $
|
||||
invitation
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT via_conn, external_intro_id, conn_info, queue_info, conn_id, status
|
||||
FROM conn_invitations
|
||||
WHERE inv_id = ?;
|
||||
|]
|
||||
(Only invId)
|
||||
where
|
||||
invitation [(viaConn, externalIntroId, entityInfo, qInfo, connId, status)] =
|
||||
Right $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId, status}
|
||||
invitation _ = Left SEInvitationNotFound
|
||||
|
||||
addInvitationConn :: SQLiteStore -> InvitationId -> ConnId -> m ()
|
||||
addInvitationConn SQLiteStore {dbConn} invId connId =
|
||||
liftIO $
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_invitations
|
||||
SET conn_id = :conn_id, status = :status
|
||||
WHERE inv_id = :inv_id;
|
||||
|]
|
||||
[":conn_id" := connId, ":status" := InvAcpt, ":inv_id" := invId]
|
||||
|
||||
getConnInvitation :: SQLiteStore -> ConnId -> m (Maybe (Invitation, Connection 'CDuplex))
|
||||
getConnInvitation SQLiteStore {dbConn} cId =
|
||||
liftIO . withTransaction dbConn $
|
||||
DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT inv_id, via_conn, external_intro_id, conn_info, queue_info, status
|
||||
FROM conn_invitations
|
||||
WHERE conn_id = ?;
|
||||
|]
|
||||
(Only cId)
|
||||
>>= fmap join . traverse getViaConn . invitation
|
||||
where
|
||||
invitation [(invId, viaConn, externalIntroId, entityInfo, qInfo, status)] =
|
||||
Just $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId = Just cId, status}
|
||||
invitation _ = Nothing
|
||||
getViaConn :: Invitation -> IO (Maybe (Invitation, Connection 'CDuplex))
|
||||
getViaConn inv@Invitation {viaConn} = fmap (inv,) . duplexConn <$> getConn_ dbConn viaConn
|
||||
duplexConn :: Either StoreError SomeConn -> Maybe (Connection 'CDuplex)
|
||||
duplexConn (Right (SomeConn SCDuplex conn)) = Just conn
|
||||
duplexConn _ = Nothing
|
||||
|
||||
setInvitationStatus :: SQLiteStore -> InvitationId -> InvitationStatus -> m ()
|
||||
setInvitationStatus SQLiteStore {dbConn} invId status =
|
||||
liftIO $
|
||||
DB.execute
|
||||
dbConn
|
||||
[sql|
|
||||
UPDATE conn_invitations
|
||||
SET status = ? WHERE inv_id = ?;
|
||||
|]
|
||||
(status, invId)
|
||||
|
||||
-- * Auxiliary helpers
|
||||
|
||||
-- ? replace with ToField? - it's easy to forget to use this
|
||||
@@ -337,7 +507,7 @@ deserializePort_ port = Just port
|
||||
|
||||
instance ToField QueueStatus where toField = toField . show
|
||||
|
||||
instance FromField QueueStatus where fromField = fromFieldToReadable_
|
||||
instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack
|
||||
|
||||
instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
|
||||
|
||||
@@ -359,13 +529,24 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
|
||||
|
||||
instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP
|
||||
|
||||
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
|
||||
fromFieldToReadable_ = \case
|
||||
instance ToField IntroStatus where toField = toField . serializeIntroStatus
|
||||
|
||||
instance FromField IntroStatus where fromField = fromTextField_ introStatusT
|
||||
|
||||
instance ToField InvitationStatus where toField = toField . serializeInvStatus
|
||||
|
||||
instance FromField InvitationStatus where fromField = fromTextField_ invStatusT
|
||||
|
||||
instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo
|
||||
|
||||
instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP
|
||||
|
||||
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
let str = T.unpack t
|
||||
in case readMaybe str of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid string: " <> str)
|
||||
case fromText t of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
{- ORMOLU_DISABLE -}
|
||||
@@ -397,8 +578,8 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
insertRcvQueue_ dbConn connId RcvQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -411,7 +592,7 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
[ ":host" := host server,
|
||||
":port" := port_,
|
||||
":rcv_id" := rcvId,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":rcv_private_key" := rcvPrivateKey,
|
||||
":snd_id" := sndId,
|
||||
":snd_key" := sndKey,
|
||||
@@ -420,26 +601,29 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
|
||||
":status" := status
|
||||
]
|
||||
|
||||
insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO ()
|
||||
insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
|
||||
insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO ()
|
||||
insertRcvConnection_ dbConn ConnData {connId, viaInv, connLevel} RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
|
||||
0, 0, 0, 0, x'', x'');
|
||||
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, :via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
|
||||
[ ":conn_alias" := connId,
|
||||
":rcv_host" := host server,
|
||||
":rcv_port" := port_,
|
||||
":rcv_id" := rcvId,
|
||||
":via_inv" := viaInv,
|
||||
":conn_level" := connLevel
|
||||
]
|
||||
|
||||
-- * createSndConn helpers
|
||||
|
||||
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn SndQueue {..} = do
|
||||
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
insertSndQueue_ dbConn connId SndQueue {..} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -452,85 +636,96 @@ insertSndQueue_ dbConn SndQueue {..} = do
|
||||
[ ":host" := host server,
|
||||
":port" := port_,
|
||||
":snd_id" := sndId,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":snd_private_key" := sndPrivateKey,
|
||||
":encrypt_key" := encryptKey,
|
||||
":sign_key" := signKey,
|
||||
":status" := status
|
||||
]
|
||||
|
||||
insertSndConnection_ :: DB.Connection -> SndQueue -> IO ()
|
||||
insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
|
||||
insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO ()
|
||||
insertSndConnection_ dbConn ConnData {connId, viaInv, connLevel} SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
INSERT INTO connections
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
|
||||
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
|
||||
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
|
||||
VALUES
|
||||
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
|
||||
0, 0, 0, 0, x'', x'');
|
||||
(:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id,:via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|
||||
|]
|
||||
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
|
||||
[ ":conn_alias" := connId,
|
||||
":snd_host" := host server,
|
||||
":snd_port" := port_,
|
||||
":snd_id" := sndId,
|
||||
":via_inv" := viaInv,
|
||||
":conn_level" := connLevel
|
||||
]
|
||||
|
||||
-- * getConn helpers
|
||||
|
||||
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connAlias = do
|
||||
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
|
||||
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
|
||||
getConn_ dbConn connId =
|
||||
getConnData_ dbConn connId >>= \case
|
||||
Nothing -> pure $ Left SEConnNotFound
|
||||
Just connData -> do
|
||||
rQ <- getRcvQueueByConnAlias_ dbConn connId
|
||||
sQ <- getSndQueueByConnAlias_ dbConn connId
|
||||
pure $ case (rQ, sQ) of
|
||||
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
|
||||
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
|
||||
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
|
||||
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData)
|
||||
getConnData_ dbConn connId =
|
||||
connData
|
||||
<$> DB.query dbConn "SELECT via_inv, conn_level FROM connections WHERE conn_alias = ?;" (Only connId)
|
||||
where
|
||||
connData [(viaInv, connLevel)] = Just ConnData {connId, viaInv, connLevel}
|
||||
connData _ = Nothing
|
||||
|
||||
getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
|
||||
getRcvQueueByConnAlias_ dbConn connId =
|
||||
rcvQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
|
||||
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key,
|
||||
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
|
||||
FROM rcv_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_alias = :conn_alias;
|
||||
WHERE q.conn_alias = ?;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
|
||||
(Only connId)
|
||||
where
|
||||
rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] =
|
||||
let srv = SMPServer host (deserializePort_ port) keyHash
|
||||
return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
_ -> return Nothing
|
||||
in Just $ RcvQueue srv rcvId rcvPrivateKey sndId sndKey decryptKey verifyKey status
|
||||
rcvQueue _ = Nothing
|
||||
|
||||
retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue)
|
||||
retrieveSndQueueByConnAlias_ dbConn connAlias = do
|
||||
r <-
|
||||
DB.queryNamed
|
||||
getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
|
||||
getSndQueueByConnAlias_ dbConn connId =
|
||||
sndQueue
|
||||
<$> DB.query
|
||||
dbConn
|
||||
[sql|
|
||||
SELECT
|
||||
s.key_hash, q.host, q.port, q.snd_id, q.conn_alias,
|
||||
q.snd_private_key, q.encrypt_key, q.sign_key, q.status
|
||||
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status
|
||||
FROM snd_queues q
|
||||
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
|
||||
WHERE q.conn_alias = :conn_alias;
|
||||
WHERE q.conn_alias = ?;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
case r of
|
||||
[(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do
|
||||
(Only connId)
|
||||
where
|
||||
sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] =
|
||||
let srv = SMPServer host (deserializePort_ port) keyHash
|
||||
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
|
||||
_ -> return Nothing
|
||||
in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status
|
||||
sndQueue _ = Nothing
|
||||
|
||||
-- * upgradeRcvConnToDuplex helpers
|
||||
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
|
||||
updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -539,12 +734,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
|
||||
SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias]
|
||||
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId]
|
||||
|
||||
-- * upgradeSndConnToDuplex helpers
|
||||
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do
|
||||
let port_ = serializePort_ $ port server
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
@@ -553,12 +748,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
|
||||
SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
|
||||
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId]
|
||||
|
||||
-- * updateRcvIds helpers
|
||||
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
retrieveLastIdsAndHashRcv_ dbConn connId = do
|
||||
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
@@ -567,11 +762,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
|
||||
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO ()
|
||||
updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -582,13 +777,13 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_rcv_msg_id" := newInternalRcvId,
|
||||
":conn_alias" := connAlias
|
||||
":conn_alias" := connId
|
||||
]
|
||||
|
||||
-- * createRcvMsg helpers
|
||||
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgBase_ dbConn connId RcvMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -597,15 +792,15 @@ insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
|
||||
VALUES
|
||||
(:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
insertRcvMsgDetails_ dbConn connId RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -618,7 +813,7 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
|
||||
:internal_hash,:external_prev_snd_hash,:integrity);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_rcv_id" := internalRcvId,
|
||||
":internal_id" := internalId,
|
||||
":external_snd_id" := fst senderMeta,
|
||||
@@ -631,8 +826,8 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
|
||||
":integrity" := msgIntegrity
|
||||
]
|
||||
|
||||
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
|
||||
updateHashRcv_ dbConn connId RcvMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
|
||||
@@ -645,14 +840,14 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|
||||
|]
|
||||
[ ":last_external_snd_msg_id" := fst senderMeta,
|
||||
":last_rcv_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":last_internal_rcv_msg_id" := internalRcvId
|
||||
]
|
||||
|
||||
-- * updateSndIds helpers
|
||||
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
retrieveLastIdsAndHashSnd_ dbConn connId = do
|
||||
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
|
||||
DB.queryNamed
|
||||
dbConn
|
||||
@@ -661,11 +856,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do
|
||||
FROM connections
|
||||
WHERE conn_alias = :conn_alias;
|
||||
|]
|
||||
[":conn_alias" := connAlias]
|
||||
[":conn_alias" := connId]
|
||||
return (lastInternalId, lastInternalSndId, lastSndHash)
|
||||
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
|
||||
updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -676,13 +871,13 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|
||||
|]
|
||||
[ ":last_internal_msg_id" := newInternalId,
|
||||
":last_internal_snd_msg_id" := newInternalSndId,
|
||||
":conn_alias" := connAlias
|
||||
":conn_alias" := connId
|
||||
]
|
||||
|
||||
-- * createSndMsg helpers
|
||||
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -691,15 +886,15 @@ insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
|
||||
VALUES
|
||||
(:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_id" := internalId,
|
||||
":internal_ts" := internalTs,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":body" := decodeUtf8 msgBody
|
||||
]
|
||||
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
[sql|
|
||||
@@ -708,15 +903,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
|
||||
VALUES
|
||||
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|
||||
|]
|
||||
[ ":conn_alias" := connAlias,
|
||||
[ ":conn_alias" := connId,
|
||||
":internal_snd_id" := internalSndId,
|
||||
":internal_id" := internalId,
|
||||
":snd_status" := Created,
|
||||
":internal_hash" := internalHash
|
||||
]
|
||||
|
||||
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
|
||||
updateHashSnd_ dbConn connId SndMsgData {..} =
|
||||
DB.executeNamed
|
||||
dbConn
|
||||
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
|
||||
@@ -727,7 +922,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} =
|
||||
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|
||||
|]
|
||||
[ ":last_snd_msg_hash" := internalHash,
|
||||
":conn_alias" := connAlias,
|
||||
":conn_alias" := connId,
|
||||
":last_internal_snd_msg_id" := internalSndId
|
||||
]
|
||||
|
||||
@@ -745,10 +940,10 @@ bcastExists_ dbConn bId = not . null <$> queryBcast
|
||||
queryBcast :: IO [Only BroadcastId]
|
||||
queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
|
||||
|
||||
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool
|
||||
bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
|
||||
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnId -> IO Bool
|
||||
bcastConnExists_ dbConn bId connId = not . null <$> queryBcastConn
|
||||
where
|
||||
queryBcastConn :: IO [(BroadcastId, ConnAlias)]
|
||||
queryBcastConn :: IO [(BroadcastId, ConnId)]
|
||||
queryBcastConn =
|
||||
DB.query
|
||||
dbConn
|
||||
@@ -757,4 +952,37 @@ bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
|
||||
FROM broadcast_connections
|
||||
WHERE broadcast_id = ? AND conn_alias = ?;
|
||||
|]
|
||||
(bId, connAlias)
|
||||
(bId, connId)
|
||||
|
||||
-- create record with a random ID
|
||||
|
||||
getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId)
|
||||
getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn
|
||||
getConnId_ _ _ ConnData {connId} = pure $ Right connId
|
||||
|
||||
getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString)
|
||||
getUniqueRandomId gVar get = tryGet 3
|
||||
where
|
||||
tryGet :: Int -> IO (Either StoreError ByteString)
|
||||
tryGet 0 = pure $ Left SEUniqueID
|
||||
tryGet n = do
|
||||
id' <- randomId gVar 12
|
||||
get id' >>= \case
|
||||
Nothing -> pure $ Right id'
|
||||
Just _ -> tryGet (n - 1)
|
||||
|
||||
createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString)
|
||||
createWithRandomId gVar create = tryCreate 3
|
||||
where
|
||||
tryCreate :: Int -> IO (Either StoreError ByteString)
|
||||
tryCreate 0 = pure $ Left SEUniqueID
|
||||
tryCreate n = do
|
||||
id' <- randomId gVar 12
|
||||
E.try (create id') >>= \case
|
||||
Right _ -> pure $ Right id'
|
||||
Left e
|
||||
| DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1)
|
||||
| otherwise -> pure . Left . SEInternal $ bshow e
|
||||
|
||||
randomId :: TVar ChaChaDRG -> Int -> IO ByteString
|
||||
randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n)
|
||||
|
||||
@@ -30,7 +30,7 @@ base64StringP = do
|
||||
pure $ str <> pad
|
||||
|
||||
tsISO8601P :: Parser UTCTime
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
|
||||
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
|
||||
|
||||
parse :: Parser a -> e -> (ByteString -> Either e a)
|
||||
parse parser err = first (const err) . parseAll parser
|
||||
@@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a
|
||||
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
|
||||
|
||||
parseRead1 :: Read a => Parser a
|
||||
parseRead1 = parseRead $ A.takeTill (== ' ')
|
||||
parseRead1 = parseRead $ A.takeTill wordEnd
|
||||
|
||||
parseRead2 :: Read a => Parser a
|
||||
parseRead2 = parseRead $ do
|
||||
w1 <- A.takeTill (== ' ') <* A.char ' '
|
||||
w2 <- A.takeTill (== ' ')
|
||||
w1 <- A.takeTill wordEnd <* A.char ' '
|
||||
w2 <- A.takeTill wordEnd
|
||||
pure $ w1 <> " " <> w2
|
||||
|
||||
wordEnd :: Char -> Bool
|
||||
wordEnd c = c == ' ' || c == '\n'
|
||||
|
||||
parseString :: (ByteString -> Either String a) -> (String -> a)
|
||||
parseString p = either error id . p . B.pack
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PostfixOperators #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
@@ -17,6 +18,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import SMPAgentClient
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store (InvitationId)
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
|
||||
import System.Timeout
|
||||
@@ -29,10 +31,16 @@ agentTests (ATransport t) = do
|
||||
describe "Establishing duplex connection" do
|
||||
it "should connect via one server and one agent" $
|
||||
smpAgentTest2_1_1 $ testDuplexConnection t
|
||||
it "should connect via one server and one agent (random IDs)" $
|
||||
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via one server and 2 agents" $
|
||||
smpAgentTest2_2_1 $ testDuplexConnection t
|
||||
it "should connect via one server and 2 agents (random IDs)" $
|
||||
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
|
||||
it "should connect via 2 servers and 2 agents" $
|
||||
smpAgentTest2_2_2 $ testDuplexConnection t
|
||||
it "should connect via 2 servers and 2 agents (random IDs)" $
|
||||
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
|
||||
describe "Connection subscriptions" do
|
||||
it "should connect via one server and one agent" $
|
||||
smpAgentTest3_1_1 $ testSubscription t
|
||||
@@ -41,6 +49,13 @@ agentTests (ATransport t) = do
|
||||
describe "Broadcast" do
|
||||
it "should create broadcast and send messages" $
|
||||
smpAgentTest3 $ testBroadcast t
|
||||
it "should create broadcast and send messages (random IDs)" $
|
||||
smpAgentTest3 $ testBroadcastRandomIds t
|
||||
describe "Introduction" do
|
||||
it "should send and accept introduction" $
|
||||
smpAgentTest3 $ testIntroduction t
|
||||
it "should send and accept introduction (random IDs)" $
|
||||
smpAgentTest3 $ testIntroductionRandomIds t
|
||||
|
||||
type TestTransmission p = (ACorrId, ByteString, APartyCmd p)
|
||||
|
||||
@@ -54,9 +69,13 @@ testTE (ATransmissionOrError corrId entity cmdOrErr) =
|
||||
Right cmd -> Right $ APartyCmd cmd
|
||||
Left e -> Left e
|
||||
|
||||
-- | receive message to handle `h`
|
||||
(<#:) :: Transport c => c -> IO (TestTransmissionOrError 'Agent)
|
||||
(<#:) h = testTE <$> tGet SAgent h
|
||||
|
||||
-- | send transmission `t` to handle `h` and get response
|
||||
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (TestTransmissionOrError 'Agent)
|
||||
h #: t = tPutRaw h t >> testTE <$> tGet SAgent h
|
||||
h #: t = tPutRaw h t >> (h <#:)
|
||||
|
||||
-- | action and expected response
|
||||
-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r`
|
||||
@@ -75,11 +94,11 @@ correctTransmission (corrId, cAlias, cmdOrErr) = case cmdOrErr of
|
||||
|
||||
-- | receive message to handle `h` and validate that it is the expected one
|
||||
(<#) :: Transport c => c -> TestTransmission' 'Agent c' -> Expectation
|
||||
h <# (corrId, cAlias, cmd) = tGet SAgent h >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd))) . testTE
|
||||
h <# (corrId, cAlias, cmd) = (h <#:) >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd)))
|
||||
|
||||
-- | receive message to handle `h` and validate it using predicate `p`
|
||||
(<#=) :: Transport c => c -> (TestTransmission 'Agent -> Bool) -> Expectation
|
||||
h <#= p = tGet SAgent h >>= (`shouldSatisfy` p . correctTransmission . testTE)
|
||||
h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission)
|
||||
|
||||
-- | test that nothing is delivered to handle `h` during 10ms
|
||||
(#:#) :: Transport c => c -> String -> Expectation
|
||||
@@ -90,53 +109,75 @@ h #:# err = tryGet `shouldReturn` ()
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
|
||||
pattern Msg :: MsgBody -> APartyCmd 'Agent
|
||||
pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk}
|
||||
|
||||
pattern Sent :: AgentMsgId -> APartyCmd 'Agent
|
||||
pattern Sent msgId <- APartyCmd (SENT msgId)
|
||||
|
||||
pattern Inv :: SMPQueueInfo -> APartyCmd 'Agent
|
||||
pattern Inv invitation <- APartyCmd (INV invitation)
|
||||
pattern Msg :: MsgBody -> APartyCmd 'Agent
|
||||
pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk}
|
||||
|
||||
pattern Inv :: SMPQueueInfo -> Either AgentErrorType (APartyCmd 'Agent)
|
||||
pattern Inv invitation <- Right (APartyCmd (INV invitation))
|
||||
|
||||
pattern Req :: InvitationId -> EntityInfo -> Either AgentErrorType (APartyCmd 'Agent)
|
||||
pattern Req invId eInfo <- Right (APartyCmd (REQ (IE (Conn invId)) eInfo))
|
||||
|
||||
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = do
|
||||
("1", "C:bob", Right (Inv qInfo)) <- alice #: ("1", "C:bob", "NEW")
|
||||
("1", "C:bob", Inv qInfo) <- alice #: ("1", "C:bob", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON)
|
||||
alice <# ("", "C:bob", CON)
|
||||
alice #: ("2", "C:bob", "SEND :hello") =#> \case ("2", "C:bob", Sent 1) -> True; _ -> False
|
||||
alice #: ("3", "C:bob", "SEND :how are you?") =#> \case ("3", "C:bob", Sent 2) -> True; _ -> False
|
||||
alice #: ("2", "C:bob", "SEND :hello") #> ("2", "C:bob", SENT 1)
|
||||
alice #: ("3", "C:bob", "SEND :how are you?") #> ("3", "C:bob", SENT 2)
|
||||
bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
|
||||
bob <#= \case ("", "C:alice", Msg "how are you?") -> True; _ -> False
|
||||
bob #: ("14", "C:alice", "SEND 9\nhello too") =#> \case ("14", "C:alice", Sent 3) -> True; _ -> False
|
||||
bob #: ("14", "C:alice", "SEND 9\nhello too") #> ("14", "C:alice", SENT 3)
|
||||
alice <#= \case ("", "C:bob", Msg "hello too") -> True; _ -> False
|
||||
bob #: ("15", "C:alice", "SEND 9\nmessage 1") =#> \case ("15", "C:alice", Sent 4) -> True; _ -> False
|
||||
bob #: ("15", "C:alice", "SEND 9\nmessage 1") #> ("15", "C:alice", SENT 4)
|
||||
alice <#= \case ("", "C:bob", Msg "message 1") -> True; _ -> False
|
||||
alice #: ("5", "C:bob", "OFF") #> ("5", "C:bob", OK)
|
||||
bob #: ("17", "C:alice", "SEND 9\nmessage 3") #> ("17", "C:alice", ERR (SMP AUTH))
|
||||
alice #: ("6", "C:bob", "DEL") #> ("6", "C:bob", OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = do
|
||||
("1", bobConn, Inv qInfo) <- alice #: ("1", "C:", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
("", aliceConn, Right (APartyCmd CON)) <- bob #: ("11", "C:", "JOIN " <> qInfo')
|
||||
alice <# ("", bobConn, CON)
|
||||
alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, SENT 1)
|
||||
alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, SENT 2)
|
||||
bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False
|
||||
bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False
|
||||
bob #: ("14", aliceConn, "SEND 9\nhello too") #> ("14", aliceConn, SENT 3)
|
||||
alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False
|
||||
bob #: ("15", aliceConn, "SEND 9\nmessage 1") #> ("15", aliceConn, SENT 4)
|
||||
alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False
|
||||
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
|
||||
bob #: ("17", aliceConn, "SEND 9\nmessage 3") #> ("17", aliceConn, ERR (SMP AUTH))
|
||||
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testSubscription _ alice1 alice2 bob = do
|
||||
("1", "C:bob", Right (Inv qInfo)) <- alice1 #: ("1", "C:bob", "NEW")
|
||||
("1", "C:bob", Inv qInfo) <- alice1 #: ("1", "C:bob", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON)
|
||||
bob #: ("12", "C:alice", "SEND 5\nhello") =#> \case ("12", "C:alice", Sent _) -> True; _ -> False
|
||||
bob #: ("13", "C:alice", "SEND 11\nhello again") =#> \case ("13", "C:alice", Sent _) -> True; _ -> False
|
||||
bob #: ("12", "C:alice", "SEND 5\nhello") #> ("12", "C:alice", SENT 1)
|
||||
bob #: ("13", "C:alice", "SEND 11\nhello again") #> ("13", "C:alice", SENT 2)
|
||||
alice1 <# ("", "C:bob", CON)
|
||||
alice1 <#= \case ("", "C:bob", Msg "hello") -> True; _ -> False
|
||||
alice1 <#= \case ("", "C:bob", Msg "hello again") -> True; _ -> False
|
||||
alice2 #: ("21", "C:bob", "SUB") #> ("21", "C:bob", OK)
|
||||
alice1 <# ("", "C:bob", END)
|
||||
bob #: ("14", "C:alice", "SEND 2\nhi") =#> \case ("14", "C:alice", Sent _) -> True; _ -> False
|
||||
bob #: ("14", "C:alice", "SEND 2\nhi") #> ("14", "C:alice", SENT 3)
|
||||
alice2 <#= \case ("", "C:bob", Msg "hi") -> True; _ -> False
|
||||
alice1 #:# "nothing else should be delivered to alice1"
|
||||
|
||||
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
|
||||
testSubscrNotification _ (server, _) client = do
|
||||
client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", Inv _) -> True; _ -> False
|
||||
client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", APartyCmd INV {}) -> True; _ -> False
|
||||
client #:# "nothing should be delivered to client before the server is killed"
|
||||
killThread server
|
||||
client <# ("", "C:conn1", END)
|
||||
@@ -156,8 +197,8 @@ testBroadcast _ alice bob tom = do
|
||||
alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND)
|
||||
alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE)
|
||||
-- send message
|
||||
alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1)
|
||||
alice <# ("4", "C:tom", SENT 1)
|
||||
alice #: ("4", "B:team", "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False
|
||||
alice <#= \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False
|
||||
alice <# ("4", "B:team", SENT 0)
|
||||
bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
|
||||
tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
|
||||
@@ -177,13 +218,104 @@ testBroadcast _ alice bob tom = do
|
||||
-- commands with errors
|
||||
alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND)
|
||||
alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND)
|
||||
where
|
||||
connect :: (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False
|
||||
h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False
|
||||
|
||||
testBroadcastRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testBroadcastRandomIds _ alice bob tom = do
|
||||
-- establish connections
|
||||
(aliceB, bobA) <- alice `connect'` bob
|
||||
(aliceT, tomA) <- alice `connect'` tom
|
||||
-- create and set up broadcast
|
||||
("1", team, Right (APartyCmd OK)) <- alice #: ("1", "B:", "NEW")
|
||||
alice #: ("2", team, "ADD " <> bobA) #> ("2", team, OK)
|
||||
alice #: ("3", team, "ADD " <> tomA) #> ("3", team, OK)
|
||||
-- commands with errors
|
||||
alice #: ("e1", team, "NEW") #> ("e1", team, ERR $ BCAST B_DUPLICATE)
|
||||
alice #: ("e2", "B:group", "ADD " <> bobA) #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND)
|
||||
alice #: ("e3", team, "ADD C:unknown") #> ("e3", team, ERR $ CONN NOT_FOUND)
|
||||
alice #: ("e4", team, "ADD " <> bobA) #> ("e4", team, ERR $ CONN DUPLICATE)
|
||||
-- send message
|
||||
alice #: ("4", team, "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False
|
||||
alice <#= \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False
|
||||
alice <# ("4", team, SENT 0)
|
||||
bob <#= \case ("", c, Msg "hello") -> c == aliceB; _ -> False
|
||||
tom <#= \case ("", c, Msg "hello") -> c == aliceT; _ -> False
|
||||
-- remove one connection
|
||||
alice #: ("5", team, "REM " <> tomA) #> ("5", team, OK)
|
||||
alice #: ("6", team, "SEND 11\nhello again") #> ("6", bobA, SENT 2)
|
||||
alice <# ("6", team, SENT 0)
|
||||
bob <#= \case ("", c, Msg "hello again") -> c == aliceB; _ -> False
|
||||
tom #:# "nothing delivered to tom"
|
||||
-- commands with errors
|
||||
alice #: ("e5", "B:group", "REM " <> bobA) #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND)
|
||||
alice #: ("e6", team, "REM C:unknown") #> ("e6", team, ERR $ CONN NOT_FOUND)
|
||||
alice #: ("e7", team, "REM " <> tomA) #> ("e7", team, ERR $ CONN NOT_FOUND)
|
||||
-- delete broadcast
|
||||
alice #: ("7", team, "DEL") #> ("7", team, OK)
|
||||
alice #: ("8", team, "SEND 11\ntry sending") #> ("8", team, ERR $ BCAST B_NOT_FOUND)
|
||||
-- commands with errors
|
||||
alice #: ("e8", team, "DEL") #> ("e8", team, ERR $ BCAST B_NOT_FOUND)
|
||||
alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND)
|
||||
|
||||
testIntroduction :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testIntroduction _ alice bob tom = do
|
||||
-- establish connections
|
||||
(alice, "alice") `connect` (bob, "bob")
|
||||
(alice, "alice") `connect` (tom, "tom")
|
||||
-- send introduction of tom to bob
|
||||
alice #: ("1", "C:bob", "INTRO C:tom 8\nmeet tom") #> ("1", "C:bob", OK)
|
||||
("", "C:alice", Req invId1 "meet tom") <- (bob <#:)
|
||||
bob #: ("2", "C:tom_via_alice", "ACPT C:" <> invId1 <> " 7\nI'm bob") #> ("2", "C:tom_via_alice", OK)
|
||||
("", "C:alice", Req invId2 "I'm bob") <- (tom <#:)
|
||||
-- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter
|
||||
-- or this should be another command, not ACPT
|
||||
tom #: ("3", "C:bob_via_alice", "ACPT C:" <> invId2 <> " 8\ntom here") #> ("3", "C:bob_via_alice", OK)
|
||||
tom <# ("", "C:bob_via_alice", CON)
|
||||
bob <# ("", "C:tom_via_alice", CON)
|
||||
alice <# ("", "C:bob", ICON (IE (Conn "tom")))
|
||||
-- they can message each other now
|
||||
tom #: ("4", "C:bob_via_alice", "SEND :hello") #> ("4", "C:bob_via_alice", SENT 1)
|
||||
bob <#= \case ("", "C:tom_via_alice", Msg "hello") -> True; _ -> False
|
||||
bob #: ("5", "C:tom_via_alice", "SEND 9\nhello too") #> ("5", "C:tom_via_alice", SENT 2)
|
||||
tom <#= \case ("", "C:bob_via_alice", Msg "hello too") -> True; _ -> False
|
||||
|
||||
testIntroductionRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testIntroductionRandomIds _ alice bob tom = do
|
||||
-- establish connections
|
||||
(aliceB, bobA) <- alice `connect'` bob
|
||||
(aliceT, tomA) <- alice `connect'` tom
|
||||
-- send introduction of tom to bob
|
||||
alice #: ("1", bobA, "INTRO " <> tomA <> " 8\nmeet tom") #> ("1", bobA, OK)
|
||||
("", aliceB', Req invId1 "meet tom") <- (bob <#:)
|
||||
aliceB' `shouldBe` aliceB
|
||||
("2", tomB, Right (APartyCmd OK)) <- bob #: ("2", "C:", "ACPT C:" <> invId1 <> " 7\nI'm bob")
|
||||
("", aliceT', Req invId2 "I'm bob") <- (tom <#:)
|
||||
aliceT' `shouldBe` aliceT
|
||||
-- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter
|
||||
-- or this should be another command, not ACPT
|
||||
("3", bobT, Right (APartyCmd OK)) <- tom #: ("3", "C:", "ACPT C:" <> invId2 <> " 8\ntom here")
|
||||
tom <# ("", bobT, CON)
|
||||
bob <# ("", tomB, CON)
|
||||
alice <# ("", bobA, ICON . IE . Conn $ B.drop 2 tomA)
|
||||
-- they can message each other now
|
||||
tom #: ("4", bobT, "SEND :hello") #> ("4", bobT, SENT 1)
|
||||
bob <#= \case ("", c, Msg "hello") -> c == tomB; _ -> False
|
||||
bob #: ("5", tomB, "SEND 9\nhello too") #> ("5", tomB, SENT 2)
|
||||
tom <#= \case ("", c, Msg "hello too") -> c == bobT; _ -> False
|
||||
|
||||
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Inv qInfo) <- h1 #: ("c1", "C:" <> name2, "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') #> ("", "C:" <> name1, CON)
|
||||
h1 <# ("", "C:" <> name2, CON)
|
||||
|
||||
connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
|
||||
connect' h1 h2 = do
|
||||
("c1", conn2, Inv qInfo) <- h1 #: ("c1", "C:", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
("", conn1, Right (APartyCmd CON)) <- h2 #: ("c2", "C:", "JOIN " <> qInfo')
|
||||
h1 <# ("", conn2, CON)
|
||||
pure (conn1, conn2)
|
||||
|
||||
samplePublicKey :: ByteString
|
||||
samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
|
||||
@@ -205,7 +337,7 @@ syntaxTests t = do
|
||||
-- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided)
|
||||
-- TODO: add tests with defined connection alias
|
||||
it "using same server as in invitation" $
|
||||
("311", "C:", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:", "ERR SMP AUTH")
|
||||
("311", "C:a", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:a", "ERR SMP AUTH")
|
||||
describe "invalid" do
|
||||
-- TODO: JOIN is not merged yet - to be added
|
||||
it "no parameters" $ ("321", "C:", "JOIN") >#> ("321", "C:", "ERR CMD SYNTAX")
|
||||
|
||||
@@ -8,9 +8,11 @@
|
||||
module AgentTests.SQLiteTests (storeTests) where
|
||||
|
||||
import Control.Concurrent.Async (concurrently_)
|
||||
import Control.Concurrent.STM (newTVarIO)
|
||||
import Control.Monad (replicateM_)
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import qualified Crypto.PubKey.RSA as R
|
||||
import Crypto.Random (drgNew)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
@@ -73,11 +75,13 @@ storeTests = do
|
||||
describe "Queue and Connection management" do
|
||||
describe "createRcvConn" do
|
||||
testCreateRcvConn
|
||||
testCreateRcvConnRandomId
|
||||
testCreateRcvConnDuplicate
|
||||
describe "createSndConn" do
|
||||
testCreateSndConn
|
||||
testCreateSndConnRandomID
|
||||
testCreateSndConnDuplicate
|
||||
describe "getAllConnAliases" testGetAllConnAliases
|
||||
describe "getAllConnIds" testGetAllConnIds
|
||||
describe "getRcvConn" testGetRcvConn
|
||||
describe "deleteConn" do
|
||||
testDeleteRcvConn
|
||||
@@ -104,14 +108,16 @@ storeTests = do
|
||||
testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore)
|
||||
testConcurrentWrites =
|
||||
it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do
|
||||
_ <- runExceptT $ createRcvConn s1 rcvQueue1
|
||||
concurrently_ (runTest s1) (runTest s2)
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn s1 g cData1 rcvQueue1
|
||||
let ConnData {connId} = cData1
|
||||
concurrently_ (runTest s1 connId) (runTest s2 connId)
|
||||
where
|
||||
runTest :: SQLiteStore -> IO (Either StoreError ())
|
||||
runTest store = runExceptT . replicateM_ 100 $ do
|
||||
(internalId, internalRcvId, _, _) <- updateRcvIds store rcvQueue1
|
||||
runTest :: SQLiteStore -> ConnId -> IO (Either StoreError ())
|
||||
runTest store connId = runExceptT . replicateM_ 100 $ do
|
||||
(internalId, internalRcvId, _, _) <- updateRcvIds store connId
|
||||
let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy"
|
||||
createRcvMsg store rcvQueue1 rcvMsgData
|
||||
createRcvMsg store connId rcvMsgData
|
||||
|
||||
testCompiledThreadsafe :: SpecWith SQLiteStore
|
||||
testCompiledThreadsafe =
|
||||
@@ -132,12 +138,14 @@ testForeignKeysEnabled =
|
||||
DB.execute_ (dbConn store) inconsistentQuery
|
||||
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
|
||||
|
||||
cData1 :: ConnData
|
||||
cData1 = ConnData {connId = "conn1", viaInv = Nothing, connLevel = 1}
|
||||
|
||||
rcvQueue1 :: RcvQueue
|
||||
rcvQueue1 =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
rcvId = "1234",
|
||||
connAlias = "conn1",
|
||||
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
sndId = Just "2345",
|
||||
sndKey = Nothing,
|
||||
@@ -151,7 +159,6 @@ sndQueue1 =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
sndId = "3456",
|
||||
connAlias = "conn1",
|
||||
sndPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
|
||||
signKey = C.safePrivateKey (1, 2, 3),
|
||||
@@ -161,64 +168,95 @@ sndQueue1 =
|
||||
testCreateRcvConn :: SpecWith SQLiteStore
|
||||
testCreateRcvConn =
|
||||
it "should create RcvConnection and add SndQueue" $ \store -> do
|
||||
createRcvConn store rcvQueue1
|
||||
`returnsResult` ()
|
||||
g <- newTVarIO =<< drgNew
|
||||
createRcvConn store g cData1 rcvQueue1
|
||||
`returnsResult` "conn1"
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateRcvConnRandomId :: SpecWith SQLiteStore
|
||||
testCreateRcvConnRandomId =
|
||||
it "should create RcvConnection and add SndQueue with random ID" $ \store -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- runExceptT $ createRcvConn store g cData1 {connId = ""} rcvQueue1
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1)
|
||||
upgradeRcvConnToDuplex store connId sndQueue1
|
||||
`returnsResult` ()
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateRcvConnDuplicate =
|
||||
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
createRcvConn store g cData1 rcvQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
testCreateSndConn :: SpecWith SQLiteStore
|
||||
testCreateSndConn =
|
||||
it "should create SndConnection and add RcvQueue" $ \store -> do
|
||||
createSndConn store sndQueue1
|
||||
`returnsResult` ()
|
||||
g <- newTVarIO =<< drgNew
|
||||
createSndConn store g cData1 sndQueue1
|
||||
`returnsResult` "conn1"
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateSndConnRandomID :: SpecWith SQLiteStore
|
||||
testCreateSndConnRandomID =
|
||||
it "should create SndConnection and add RcvQueue with random ID" $ \store -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- runExceptT $ createSndConn store g cData1 {connId = ""} sndQueue1
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1)
|
||||
upgradeSndConnToDuplex store connId rcvQueue1
|
||||
`returnsResult` ()
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
|
||||
|
||||
testCreateSndConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateSndConnDuplicate =
|
||||
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
createSndConn store sndQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
createSndConn store g cData1 sndQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
|
||||
testGetAllConnAliases :: SpecWith SQLiteStore
|
||||
testGetAllConnAliases =
|
||||
testGetAllConnIds :: SpecWith SQLiteStore
|
||||
testGetAllConnIds =
|
||||
it "should get all conn aliases" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
_ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"}
|
||||
getAllConnAliases store
|
||||
`returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias]
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
_ <- runExceptT $ createSndConn store g cData1 {connId = "conn2"} sndQueue1
|
||||
getAllConnIds store
|
||||
`returnsResult` ["conn1" :: ConnId, "conn2" :: ConnId]
|
||||
|
||||
testGetRcvConn :: SpecWith SQLiteStore
|
||||
testGetRcvConn =
|
||||
it "should get connection using rcv queue id and server" $ \store -> do
|
||||
let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash
|
||||
let recipientId = "1234"
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
getRcvConn store smpServer recipientId
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection (connAlias (rcvQueue1 :: RcvQueue)) rcvQueue1)
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
|
||||
testDeleteRcvConn :: SpecWith SQLiteStore
|
||||
testDeleteRcvConn =
|
||||
it "should create RcvConnection and delete it" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
@@ -228,9 +266,10 @@ testDeleteRcvConn =
|
||||
testDeleteSndConn :: SpecWith SQLiteStore
|
||||
testDeleteSndConn =
|
||||
it "should create SndConnection and delete it" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
@@ -240,10 +279,11 @@ testDeleteSndConn =
|
||||
testDeleteDuplexConn :: SpecWith SQLiteStore
|
||||
testDeleteDuplexConn =
|
||||
it "should create DuplexConnection and delete it" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
-- TODO check queues are deleted as well
|
||||
@@ -253,12 +293,12 @@ testDeleteDuplexConn =
|
||||
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeRcvConnToDuplex =
|
||||
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
sndId = "2345",
|
||||
connAlias = "conn1",
|
||||
sndPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
|
||||
signKey = C.safePrivateKey (1, 2, 3),
|
||||
@@ -273,12 +313,12 @@ testUpgradeRcvConnToDuplex =
|
||||
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeSndConnToDuplex =
|
||||
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
|
||||
rcvId = "3456",
|
||||
connAlias = "conn1",
|
||||
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
|
||||
sndId = Just "4567",
|
||||
sndKey = Nothing,
|
||||
@@ -295,40 +335,43 @@ testUpgradeSndConnToDuplex =
|
||||
testSetRcvQueueStatus :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatus =
|
||||
it "should update status of RcvQueue" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1 {status = Confirmed})
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed})
|
||||
|
||||
testSetSndQueueStatus :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatus =
|
||||
it "should update status of SndQueue" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1 {status = Confirmed})
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed})
|
||||
|
||||
testSetQueueStatusDuplex :: SpecWith SQLiteStore
|
||||
testSetQueueStatusDuplex =
|
||||
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Secured
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1)
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
|
||||
|
||||
testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatusNoQueue =
|
||||
@@ -362,20 +405,22 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
|
||||
msgIntegrity = MsgOk
|
||||
}
|
||||
|
||||
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do
|
||||
updateRcvIds store rcvQueue
|
||||
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg' st expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
|
||||
updateRcvIds st connId
|
||||
`returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg store rcvQueue rcvMsgData
|
||||
createRcvMsg st connId rcvMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateRcvMsg :: SpecWith SQLiteStore
|
||||
testCreateRcvMsg =
|
||||
it "should reserve internal ids and create a RcvMsg" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
it "should reserve internal ids and create a RcvMsg" $ \st -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createRcvConn st g cData1 rcvQueue1
|
||||
-- TODO getMsg to check message
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
testCreateRcvMsg' st 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg' st 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
|
||||
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
|
||||
mkSndMsgData internalId internalSndId internalHash =
|
||||
@@ -387,29 +432,33 @@ mkSndMsgData internalId internalSndId internalHash =
|
||||
internalHash
|
||||
}
|
||||
|
||||
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation
|
||||
testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds store sndQueue
|
||||
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation
|
||||
testCreateSndMsg' store expectedPrevHash connId sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds store connId
|
||||
`returnsResult` (internalId, internalSndId, expectedPrevHash)
|
||||
createSndMsg store sndQueue sndMsgData
|
||||
createSndMsg store connId sndMsgData
|
||||
`returnsResult` ()
|
||||
|
||||
testCreateSndMsg :: SpecWith SQLiteStore
|
||||
testCreateSndMsg =
|
||||
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
|
||||
_ <- runExceptT $ createSndConn store sndQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
-- TODO getMsg to check message
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg' store "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
|
||||
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
|
||||
testCreateRcvAndSndMsgs =
|
||||
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
|
||||
_ <- runExceptT $ createRcvConn store rcvQueue1
|
||||
g <- newTVarIO =<< drgNew
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
testCreateRcvMsg' store 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg' store 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg' store 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg' store "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg' store "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
|
||||
@@ -54,12 +54,12 @@ testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
|
||||
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
|
||||
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h
|
||||
|
||||
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
|
||||
runSmpAgentTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
|
||||
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
|
||||
where
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentServerTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
|
||||
runSmpAgentServerTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
|
||||
runSmpAgentServerTest test =
|
||||
withSmpServerThreadOn t testPort $
|
||||
\server -> withSmpAgentThreadOn t (agentTestPort, testPort, testDB) $
|
||||
@@ -70,7 +70,7 @@ runSmpAgentServerTest test =
|
||||
smpAgentServerTest :: Transport c => ((ThreadId, ThreadId) -> c -> IO ()) -> Expectation
|
||||
smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` ()
|
||||
|
||||
runSmpAgentTestN :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN agents test = withSmpServer t $ run agents []
|
||||
where
|
||||
run :: [(ServiceName, ServiceName, String)] -> [c] -> m a
|
||||
@@ -78,7 +78,7 @@ runSmpAgentTestN agents test = withSmpServer t $ run agents []
|
||||
run (a@(p, _, _) : as) hs = withSmpAgentOn t a $ testSMPAgentClientOn p $ \h -> run as (h : hs)
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN_1 nClients test = withSmpServer t . withSmpAgent t $ run nClients []
|
||||
where
|
||||
run :: Int -> [c] -> m a
|
||||
@@ -156,17 +156,17 @@ cfg =
|
||||
}
|
||||
}
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn t (port', smpPort', db') =
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db', smpServers = L.fromList [SMPServer "localhost" (Just smpPort') testKeyHash]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg')
|
||||
(removeFile db')
|
||||
|
||||
withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
|
||||
withSmpAgentOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
|
||||
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
|
||||
|
||||
withSmpAgent :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withSmpAgent :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m) => ServiceName -> (c -> m a) -> m a
|
||||
|
||||
Reference in New Issue
Block a user