introduction protocol (#156)

* commands to support introduction

* agent messages / envelopes to support introductions

* introductions and invitations table; insert record with random unique ID

* store class methods and types for introductions

* process INTRO and ACPT commands for connection introductions

* fix tests: add MonadFail constraint, remove OK response to JOIN

* process agent messages for introductions

* ICON notification when introduction is completed

* replace multiway if with case

* correction

* support random connection IDs

* save additional connection fields, refactor create connection funcs

* refactor

* refactor

* test duplex connection with random IDs

* store methods for introductions

* test introduction

* fix parsing of CON agent message

* test introduction with random connection IDs

* broadcast with random connection and broadcast IDs

* clean up sql
This commit is contained in:
Evgeny Poberezkin
2021-06-11 21:33:13 +01:00
committed by GitHub
parent 46c3589604
commit ab89963f45
11 changed files with 1156 additions and 413 deletions

View File

@@ -1,9 +1,8 @@
CREATE TABLE IF NOT EXISTS broadcasts (
broadcast_id BLOB NOT NULL,
PRIMARY KEY (broadcast_id)
CREATE TABLE broadcasts (
broadcast_id BLOB NOT NULL PRIMARY KEY
) WITHOUT ROWID;
CREATE TABLE IF NOT EXISTS broadcast_connections (
CREATE TABLE broadcast_connections (
broadcast_id BLOB NOT NULL REFERENCES broadcasts (broadcast_id) ON DELETE CASCADE,
conn_alias BLOB NOT NULL REFERENCES connections (conn_alias),
PRIMARY KEY (broadcast_id, conn_alias)

View File

@@ -0,0 +1,27 @@
CREATE TABLE conn_intros (
intro_id BLOB NOT NULL PRIMARY KEY,
to_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE,
to_info BLOB, -- info about "to" connection sent to "re" connection
to_status TEXT NOT NULL DEFAULT '', -- '', INV, CON
re_conn BLOB NOT NULL REFERENCES connections (conn_alias) ON DELETE CASCADE,
re_info BLOB NOT NULL, -- info about "re" connection sent to "to" connection
re_status TEXT NOT NULL DEFAULT '', -- '', INV, CON
queue_info BLOB
) WITHOUT ROWID;
CREATE TABLE conn_invitations (
inv_id BLOB NOT NULL PRIMARY KEY,
via_conn BLOB REFERENCES connections (conn_alias) ON DELETE SET NULL,
external_intro_id BLOB NOT NULL,
conn_info BLOB, -- info about another connection
queue_info BLOB, -- NULL if it's an initial introduction
conn_id BLOB REFERENCES connections (conn_alias) -- created connection
ON DELETE CASCADE
DEFERRABLE INITIALLY DEFERRED,
status TEXT DEFAULT '' -- '', 'ACPT', 'CON'
) WITHOUT ROWID;
ALTER TABLE connections
ADD via_inv BLOB REFERENCES conn_invitations (inv_id) ON DELETE RESTRICT;
ALTER TABLE connections
ADD conn_level INTEGER DEFAULT 0;

View File

@@ -61,7 +61,7 @@ import UnliftIO.STM
-- | Runs an SMP agent as a TCP service using passed configuration.
--
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
runSMPAgent t cfg = do
started <- newEmptyTMVarIO
runSMPAgentBlocking t started cfg
@@ -70,10 +70,10 @@ runSMPAgent t cfg = do
--
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
where
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do
liftIO $ putLn h "Welcome to SMP v0.3.2 agent"
c <- getSMPAgentClient
@@ -97,7 +97,7 @@ logConnection c connected =
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
runSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
runSMPAgentClient c = do
db <- asks $ dbFile . config
s1 <- liftIO $ connectSQLiteStore db
@@ -128,7 +128,7 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a ->
logClient AgentClient {clientId} dir (ATransmission corrId entity cmd) = do
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, serializeEntity entity, B.takeWhile (/= ' ') $ serializeCommand cmd]
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
client c@AgentClient {rcvQ, sndQ} st = forever loop
where
loop :: m ()
@@ -147,6 +147,8 @@ withStore action = do
Right c -> return c
Left e -> throwError $ storeError e
where
-- TODO when parsing exception happens in store, the agent hangs;
-- changing SQLError to SomeException does not help
handleInternal :: (MonadError StoreError m') => SQLError -> m' a
handleInternal e = throwError . SEInternal $ bshow e
storeError :: StoreError -> AgentErrorType
@@ -173,23 +175,28 @@ unsupportedEntity c _ corrId entity _ =
processConnCommand ::
forall c m. (AgentMonad m, EntityCommand 'Conn_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> ACommand 'Client c -> m ()
processConnCommand c@AgentClient {sndQ} st corrId conn = \case
NEW -> createNewConnection conn
JOIN smpQueueInfo replyMode -> joinConnection conn smpQueueInfo replyMode
processConnCommand c@AgentClient {sndQ} st corrId conn@(Conn connId) = \case
NEW -> createNewConnection Nothing 0 >>= uncurry respond
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK)
INTRO reEntity reInfo -> makeIntroduction reEntity reInfo
ACPT inv eInfo -> acceptInvitation inv eInfo
SUB -> subscribeConnection conn
SUBALL -> subscribeAll
SEND msgBody -> sendMessage c st corrId conn msgBody
SEND msgBody -> sendClientMessage c st corrId conn msgBody
OFF -> suspendConnection conn
DEL -> deleteConnection conn
where
createNewConnection :: Entity 'Conn_ -> m ()
createNewConnection (Conn cId) = do
createNewConnection :: Maybe InvitationId -> Int -> m (Entity 'Conn_, ACommand 'Agent 'INV_)
createNewConnection viaInv connLevel = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
-- make connId Maybe?
srv <- getSMPServer
(rq, qInfo) <- newReceiveQueue c srv cId
withStore $ createRcvConn st rq
respond conn $ INV qInfo
(rq, qInfo) <- newReceiveQueue c srv
g <- asks idsDrg
let cData = ConnData {connId, viaInv, connLevel}
connId' <- withStore $ createRcvConn st g cData rq
addSubscription c rq connId'
pure (Conn connId', INV qInfo)
getSMPServer :: m SMPServer
getSMPServer =
@@ -200,16 +207,47 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
pure $ servers L.!! i
joinConnection :: Entity 'Conn_ -> SMPQueueInfo -> ReplyMode -> m ()
joinConnection (Conn cId) qInfo (ReplyMode replyMode) = do
-- TODO create connection alias if not passed
-- make connAlias Maybe?
(sq, senderKey, verifyKey) <- newSendQueue qInfo cId
withStore $ createSndConn st sq
joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m (Entity 'Conn_)
joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do
(sq, senderKey, verifyKey) <- newSendQueue qInfo
g <- asks idsDrg
let cData = ConnData {connId, viaInv, connLevel}
connId' <- withStore $ createSndConn st g cData sq
connectToSendQueue c st sq senderKey verifyKey
when (replyMode == On) $ createReplyQueue cId sq
-- TODO this response is disabled to avoid two responses in terminal client (OK + CON),
-- respond conn OK
when (replyMode == On) $ createReplyQueue connId' sq
pure $ Conn connId'
makeIntroduction :: IntroEntity -> EntityInfo -> m ()
makeIntroduction (IE reEntity) reInfo = case reEntity of
Conn reConn ->
withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
g <- asks idsDrg
introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo}
sendControlMessage c sq $ A_INTRO (IE (Conn introId)) reInfo
respond conn OK
_ -> throwError $ CONN SIMPLEX
_ -> throwError $ CMD UNSUPPORTED
acceptInvitation :: IntroEntity -> EntityInfo -> m ()
acceptInvitation (IE invEntity) eInfo = case invEntity of
Conn invId -> do
withStore (getInvitation st invId) >>= \case
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
withStore (getConn st viaConn) >>= \case
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
Nothing -> do
(conn', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1)
withStore $ addInvitationConn st invId $ fromConn conn'
sendControlMessage c sq $ A_INV (Conn externalIntroId) qInfo' eInfo
respond conn' OK
Just qInfo' -> do
conn' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
withStore $ addInvitationConn st invId $ fromConn conn'
respond conn' OK
_ -> throwError $ CONN SIMPLEX
_ -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD UNSUPPORTED
subscribeConnection :: Entity 'Conn_ -> m ()
subscribeConnection conn'@(Conn cId) =
@@ -222,7 +260,7 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
subscribeAll :: m ()
subscribeAll = withStore (getAllConnAliases st) >>= mapM_ (subscribeConnection . Conn)
subscribeAll = withStore (getAllConnIds st) >>= mapM_ (subscribeConnection . Conn)
suspendConnection :: Entity 'Conn_ -> m ()
suspendConnection (Conn cId) =
@@ -246,25 +284,30 @@ processConnCommand c@AgentClient {sndQ} st corrId conn = \case
removeSubscription c cId
delConn
createReplyQueue :: ByteString -> SndQueue -> m ()
createReplyQueue :: ConnId -> SndQueue -> m ()
createReplyQueue cId sq = do
srv <- getSMPServer
(rq, qInfo) <- newReceiveQueue c srv cId
(rq, qInfo) <- newReceiveQueue c srv
addSubscription c rq cId
withStore $ upgradeSndConnToDuplex st cId rq
senderTimestamp <- liftIO getCurrentTime
sendAgentMessage c sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage = REPLY qInfo
}
sendControlMessage c sq $ REPLY qInfo
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue sndQ $ ATransmission corrId ent resp
sendMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
sendMessage c st corrId (Conn cId) msgBody =
sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m ()
sendControlMessage c sq agentMessage = do
senderTimestamp <- liftIO getCurrentTime
sendAgentMessage c sq . serializeSMPMessage $
SMPMessage
{ senderMsgId = 0,
senderTimestamp,
previousMsgHash = "",
agentMessage
}
sendClientMessage :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Conn_ -> MsgBody -> m ()
sendClientMessage c st corrId (Conn cId) msgBody =
withStore (getConn st cId) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
SomeConn _ (SndConnection _ sq) -> sendMsg sq
@@ -273,7 +316,7 @@ sendMessage c st corrId (Conn cId) msgBody =
sendMsg :: SndQueue -> m ()
sendMsg sq = do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st sq
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st cId
let msgStr =
serializeSMPMessage
SMPMessage
@@ -284,7 +327,7 @@ sendMessage c st corrId (Conn cId) msgBody =
}
msgHash = C.sha256Hash msgStr
withStore $
createSndMsg st sq $
createSndMsg st cId $
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
sendAgentMessage c sq msgStr
atomically . writeTBQueue (sndQ c) $ ATransmission corrId (Conn cId) $ SENT (unId internalId)
@@ -292,15 +335,21 @@ sendMessage c st corrId (Conn cId) msgBody =
processBroadcastCommand ::
forall c m. (AgentMonad m, EntityCommand 'Broadcast_ c) => AgentClient -> SQLiteStore -> ACorrId -> Entity 'Broadcast_ -> ACommand 'Client c -> m ()
processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
NEW -> withStore (createBcast st bId) >> ok
NEW -> createNewBroadcast
ADD (Conn cId) -> withStore (addBcastConn st bId cId) >> ok
REM (Conn cId) -> withStore (removeBcastConn st bId cId) >> ok
LS -> withStore (getBcast st bId) >>= respond bcast . MS . map Conn
SEND msgBody -> withStore (getBcast st bId) >>= mapM_ (sendMsg msgBody) >> respond bcast (SENT 0)
DEL -> withStore (deleteBcast st bId) >> ok
where
sendMsg :: MsgBody -> ConnAlias -> m ()
sendMsg msgBody cId = sendMessage c st corrId (Conn cId) msgBody
createNewBroadcast :: m ()
createNewBroadcast = do
g <- asks idsDrg
bId' <- withStore $ createBcast st g bId
respond (Broadcast bId') OK
sendMsg :: MsgBody -> ConnId -> m ()
sendMsg msgBody cId = sendClientMessage c st corrId (Conn cId) msgBody
ok :: m ()
ok = respond bcast OK
@@ -308,7 +357,7 @@ processBroadcastCommand c st corrId bcast@(Broadcast bId) = \case
respond :: EntityCommand t c' => Entity t -> ACommand 'Agent c' -> m ()
respond ent resp = atomically . writeTBQueue (sndQ c) $ ATransmission corrId ent resp
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
subscriber c@AgentClient {msgQ} st = forever $ do
-- TODO this will only process messages and notifications
t <- atomically $ readTBQueue msgQ
@@ -319,29 +368,33 @@ subscriber c@AgentClient {msgQ} st = forever $ do
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
withStore (getRcvConn st srv rId) >>= \case
SomeConn SCDuplex (DuplexConnection _ rq _) -> processSMP SCDuplex rq
SomeConn SCRcv (RcvConnection _ rq) -> processSMP SCRcv rq
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq
_ -> atomically . writeTBQueue sndQ $ ATransmission "" (Conn "") (ERR $ CONN SIMPLEX)
where
processSMP :: SConnType c -> RcvQueue -> m ()
processSMP cType rq@RcvQueue {connAlias, status} =
processSMP :: SConnType c -> ConnData -> RcvQueue -> m ()
processSMP cType ConnData {connId} rq@RcvQueue {status} =
case cmd of
SMP.MSG srvMsgId srvTs msgBody -> do
-- TODO deduplicate with previously received
msg <- decryptAndVerify rq msgBody
let msgHash = C.sha256Hash msg
agentMsg <- liftEither $ parseSMPMessage msg
case agentMsg of
SMPConfirmation senderKey -> smpConfirmation senderKey
SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
case parseSMPMessage msg of
Left e -> notify $ ERR e
Right (SMPConfirmation senderKey) -> smpConfirmation senderKey
Right SMPMessage {agentMessage, senderMsgId, senderTimestamp, previousMsgHash} ->
case agentMessage of
HELLO verifyKey _ -> helloMsg verifyKey msgBody
REPLY qInfo -> replyMsg qInfo
A_MSG body -> agentClientMsg previousMsgHash (senderMsgId, senderTimestamp) (srvMsgId, srvTs) body msgHash
A_INTRO entity eInfo -> introMsg entity eInfo
A_INV conn qInfo cInfo -> invMsg conn qInfo cInfo
A_REQ conn qInfo cInfo -> reqMsg conn qInfo cInfo
A_CON conn -> conMsg conn
sendAck c rq
return ()
SMP.END -> do
removeSubscription c connAlias
removeSubscription c connId
logServer "<--" c srv rId "END"
notify END
_ -> do
@@ -349,7 +402,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
notify . ERR $ BROKER UNEXPECTED
where
notify :: EntityCommand 'Conn_ c => ACommand 'Agent c -> m ()
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connAlias) msg
notify msg = atomically . writeTBQueue sndQ $ ATransmission "" (Conn connId) msg
prohibited :: m ()
prohibited = notify . ERR $ AGENT A_PROHIBITED
@@ -376,7 +429,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
void $ verifyMessage (Just verifyKey) msgBody
withStore $ setRcvQueueActive st rq verifyKey
case cType of
SCDuplex -> notify CON
SCDuplex -> connected
_ -> pure ()
replyMsg :: SMPQueueInfo -> m ()
@@ -384,22 +437,87 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
logServer "<--" c srv rId "MSG <REPLY>"
case cType of
SCRcv -> do
(sq, senderKey, verifyKey) <- newSendQueue qInfo connAlias
withStore $ upgradeRcvConnToDuplex st connAlias sq
(sq, senderKey, verifyKey) <- newSendQueue qInfo
withStore $ upgradeRcvConnToDuplex st connId sq
connectToSendQueue c st sq senderKey verifyKey
notify CON
connected
_ -> prohibited
connected :: m ()
connected = do
withStore (getConnInvitation st connId) >>= \case
Just (Invitation {invId, externalIntroId}, DuplexConnection _ _ sq) -> do
withStore $ setInvitationStatus st invId InvCon
sendControlMessage c sq $ A_CON (Conn externalIntroId)
_ -> pure ()
notify CON
introMsg :: IntroEntity -> EntityInfo -> m ()
introMsg (IE entity) entityInfo = do
logServer "<--" c srv rId "MSG <INTRO>"
case (cType, entity) of
(SCDuplex, intro@Conn {}) -> createInv intro Nothing entityInfo
_ -> prohibited
invMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
invMsg (Conn introId) qInfo toInfo = do
logServer "<--" c srv rId "MSG <INV>"
case cType of
SCDuplex ->
withStore (getIntro st introId) >>= \case
Introduction {toConn, toStatus = IntroNew, reConn, reStatus = IntroNew}
| toConn /= connId -> prohibited
| otherwise ->
withStore (addIntroInvitation st introId toInfo qInfo >> getConn st reConn) >>= \case
SomeConn _ (DuplexConnection _ _ sq) -> do
sendControlMessage c sq $ A_REQ (Conn introId) qInfo toInfo
withStore $ setIntroReStatus st introId IntroInv
_ -> prohibited
_ -> prohibited
_ -> prohibited
reqMsg :: Entity 'Conn_ -> SMPQueueInfo -> EntityInfo -> m ()
reqMsg intro qInfo connInfo = do
logServer "<--" c srv rId "MSG <REQ>"
case cType of
SCDuplex -> createInv intro (Just qInfo) connInfo
_ -> prohibited
createInv :: Entity 'Conn_ -> Maybe SMPQueueInfo -> EntityInfo -> m ()
createInv (Conn externalIntroId) qInfo entityInfo = do
g <- asks idsDrg
let newInv = NewInvitation {viaConn = connId, externalIntroId, entityInfo, qInfo}
invId <- withStore $ createInvitation st g newInv
notify $ REQ (IE (Conn invId)) entityInfo
conMsg :: Entity 'Conn_ -> m ()
conMsg (Conn introId) = do
logServer "<--" c srv rId "MSG <CON>"
withStore (getIntro st introId) >>= \case
Introduction {toConn, toStatus, reConn, reStatus}
| toConn == connId && toStatus == IntroInv -> do
withStore $ setIntroToStatus st introId IntroCon
sendConMsg toConn reConn reStatus
| reConn == connId && reStatus == IntroInv -> do
withStore $ setIntroReStatus st introId IntroCon
sendConMsg toConn reConn toStatus
| otherwise -> prohibited
where
sendConMsg :: ConnId -> ConnId -> IntroStatus -> m ()
sendConMsg toConn reConn IntroCon =
atomically . writeTBQueue sndQ $ ATransmission "" (Conn toConn) $ ICON $ IE (Conn reConn)
sendConMsg _ _ _ = pure ()
agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
logServer "<--" c srv rId "MSG <MSG>"
case status of
Active -> do
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st rq
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- withStore $ updateRcvIds st connId
let msgIntegrity = checkMsgIntegrity prevExtSndId (fst senderMeta) prevRcvMsgHash receivedPrevMsgHash
withStore $
createRcvMsg st rq $
createRcvMsg st connId $
RcvMsgData
{ internalId,
internalRcvId,
@@ -438,8 +556,8 @@ connectToSendQueue c st sq senderKey verifyKey = do
withStore $ setSndQueueStatus st sq Active
newSendQueue ::
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> ConnAlias -> m (SndQueue, SenderPublicKey, VerificationKey)
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey)
newSendQueue (SMPQueueInfo smpServer senderId encryptKey) = do
size <- asks $ rsaKeySize . config
(senderKey, sndPrivateKey) <- liftIO $ C.generateKeyPair size
(verifyKey, signKey) <- liftIO $ C.generateKeyPair size
@@ -447,7 +565,6 @@ newSendQueue (SMPQueueInfo smpServer senderId encryptKey) connAlias = do
SndQueue
{ server = smpServer,
sndId = senderId,
connAlias,
sndPrivateKey,
encryptKey,
signKey,

View File

@@ -15,6 +15,7 @@ module Simplex.Messaging.Agent.Client
closeSMPServerClients,
newReceiveQueue,
subscribeQueue,
addSubscription,
sendConfirmation,
sendHello,
secureQueue,
@@ -61,8 +62,8 @@ data AgentClient = AgentClient
sndQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue SMPServerTransmission,
smpClients :: TVar (Map SMPServer SMPClient),
subscrSrvrs :: TVar (Map SMPServer (Set ConnAlias)),
subscrConns :: TVar (Map ConnAlias SMPServer),
subscrSrvrs :: TVar (Map SMPServer (Set ConnId)),
subscrConns :: TVar (Map ConnId SMPServer),
clientId :: Int
}
@@ -78,7 +79,7 @@ newAgentClient cc AgentConfig {tbqSize} = do
writeTVar cc clientId
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m)
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
@@ -106,7 +107,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
removeSubs >>= mapM_ (mapM_ notifySub)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeSubs :: IO (Maybe (Set ConnAlias))
removeSubs :: IO (Maybe (Set ConnId))
removeSubs = atomically $ do
modifyTVar smpClients $ M.delete srv
cs <- M.lookup srv <$> readTVar (subscrSrvrs c)
@@ -117,7 +118,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
deleteKeys :: Ord k => Set k -> Map k a -> Map k a
deleteKeys ks m = S.foldr' M.delete m ks
notifySub :: ConnAlias -> IO ()
notifySub :: ConnId -> IO ()
notifySub connAlias = atomically . writeTBQueue (sndQ c) $ ATransmission "" (Conn connAlias) END
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
@@ -158,8 +159,8 @@ smpClientError = \case
SMPTransportError e -> BROKER $ TRANSPORT e
e -> INTERNAL $ show e
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> ConnAlias -> m (RcvQueue, SMPQueueInfo)
newReceiveQueue c srv connAlias = do
newReceiveQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueInfo)
newReceiveQueue c srv = do
size <- asks $ rsaKeySize . config
(recipientKey, rcvPrivateKey) <- liftIO $ C.generateKeyPair size
logServer "-->" c srv "" "NEW"
@@ -170,7 +171,6 @@ newReceiveQueue c srv connAlias = do
RcvQueue
{ server = srv,
rcvId,
connAlias,
rcvPrivateKey,
sndId = Just sId,
sndKey = Nothing,
@@ -178,25 +178,24 @@ newReceiveQueue c srv connAlias = do
verifyKey = Nothing,
status = New
}
addSubscription c rq connAlias
return (rq, SMPQueueInfo srv sId encryptKey)
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnAlias -> m ()
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connAlias = do
withLogSMP c server rcvId "SUB" $ \smp ->
subscribeSMPQueue smp rcvPrivateKey rcvId
addSubscription c rq connAlias
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnAlias -> m ()
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c RcvQueue {server} connAlias = atomically $ do
modifyTVar (subscrConns c) $ M.insert connAlias server
modifyTVar (subscrSrvrs c) $ M.alter (Just . addSub) server
where
addSub :: Maybe (Set ConnAlias) -> Set ConnAlias
addSub :: Maybe (Set ConnId) -> Set ConnId
addSub (Just cs) = S.insert connAlias cs
addSub _ = S.singleton connAlias
removeSubscription :: AgentMonad m => AgentClient -> ConnAlias -> m ()
removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m ()
removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically $ do
cs <- readTVar subscrConns
writeTVar subscrConns $ M.delete connAlias cs
@@ -204,7 +203,7 @@ removeSubscription AgentClient {subscrConns, subscrSrvrs} connAlias = atomically
(modifyTVar subscrSrvrs . M.alter (>>= delSub))
(M.lookup connAlias cs)
where
delSub :: Set ConnAlias -> Maybe (Set ConnAlias)
delSub :: Set ConnId -> Maybe (Set ConnId)
delSub cs =
let cs' = S.delete connAlias cs
in if S.null cs' then Nothing else Just cs'

View File

@@ -33,6 +33,8 @@ module Simplex.Messaging.Agent.Protocol
Entity (..),
EntityTag (..),
AnEntity (..),
IntroEntity (..),
EntityInfo,
EntityCommand,
entityCommand,
ACommand (..),
@@ -53,7 +55,7 @@ module Simplex.Messaging.Agent.Protocol
ATransmission (..),
ATransmissionOrError (..),
ARawTransmission,
ConnAlias,
ConnId,
ReplyMode (..),
AckMode (..),
OnOff (..),
@@ -171,6 +173,13 @@ deriving instance Eq (Entity t)
deriving instance Show (Entity t)
instance TestEquality Entity where
testEquality (Conn c) (Conn c') = refl c c'
testEquality (OpenConn c) (OpenConn c') = refl c c'
testEquality (Broadcast c) (Broadcast c') = refl c c'
testEquality (AGroup c) (AGroup c') = refl c c'
testEquality _ _ = Nothing
entityId :: Entity t -> ByteString
entityId = \case
Conn bs -> bs
@@ -195,7 +204,11 @@ type family EntityCommand (t :: EntityTag) (c :: ACmdTag) :: Constraint where
EntityCommand Conn_ NEW_ = ()
EntityCommand Conn_ INV_ = ()
EntityCommand Conn_ JOIN_ = ()
EntityCommand Conn_ INTRO_ = ()
EntityCommand Conn_ REQ_ = ()
EntityCommand Conn_ ACPT_ = ()
EntityCommand Conn_ CON_ = ()
EntityCommand Conn_ ICON_ = ()
EntityCommand Conn_ SUB_ = ()
EntityCommand Conn_ SUBALL_ = ()
EntityCommand Conn_ END_ = ()
@@ -226,7 +239,11 @@ entityCommand = \case
NEW -> Just Dict
INV _ -> Just Dict
JOIN {} -> Just Dict
INTRO {} -> Just Dict
REQ {} -> Just Dict
ACPT {} -> Just Dict
CON -> Just Dict
ICON {} -> Just Dict
SUB -> Just Dict
SUBALL -> Just Dict
END -> Just Dict
@@ -258,7 +275,11 @@ data ACmdTag
= NEW_
| INV_
| JOIN_
| INTRO_
| REQ_
| ACPT_
| CON_
| ICON_
| SUB_
| SUBALL_
| END_
@@ -274,15 +295,31 @@ data ACmdTag
| OK_
| ERR_
type family Introduction (t :: EntityTag) :: Constraint where
Introduction Conn_ = ()
Introduction OpenConn_ = ()
Introduction AGroup_ = ()
Introduction t = (Int ~ Bool, TypeError (Text "Entity " :<>: ShowType t :<>: Text " cannot be INTRO'd to"))
data IntroEntity = forall t. Introduction t => IE (Entity t)
instance Eq IntroEntity where
IE e1 == IE e2 = isJust $ testEquality e1 e2
deriving instance Show IntroEntity
type EntityInfo = ByteString
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
data ACommand (p :: AParty) (c :: ACmdTag) where
NEW :: ACommand Client NEW_ -- response INV
INV :: SMPQueueInfo -> ACommand Agent INV_
JOIN :: SMPQueueInfo -> ReplyMode -> ACommand Client JOIN_ -- response OK
INTRO :: IntroEntity -> EntityInfo -> ACommand Client INTRO_
REQ :: IntroEntity -> EntityInfo -> ACommand Agent INTRO_
ACPT :: IntroEntity -> EntityInfo -> ACommand Client ACPT_
CON :: ACommand Agent CON_ -- notification that connection is established
-- TODO currently it automatically allows whoever sends the confirmation
-- CONF :: OtherPartyId -> ACommand Agent
-- LET :: OtherPartyId -> ACommand Client
ICON :: IntroEntity -> ACommand Agent ICON_
SUB :: ACommand Client SUB_
SUBALL :: ACommand Client SUBALL_ -- TODO should be moved to chat protocol - hack for subscribing to all
END :: ACommand Agent END_
@@ -318,6 +355,7 @@ instance TestEquality (ACommand p) where
testEquality c@INV {} c'@INV {} = refl c c'
testEquality c@JOIN {} c'@JOIN {} = refl c c'
testEquality CON CON = Just Refl
testEquality c@ICON {} c'@ICON {} = refl c c'
testEquality SUB SUB = Just Refl
testEquality SUBALL SUBALL = Just Refl
testEquality END END = Just Refl
@@ -334,7 +372,7 @@ instance TestEquality (ACommand p) where
testEquality c@ERR {} c'@ERR {} = refl c c'
testEquality _ _ = Nothing
refl :: Eq (f a) => f a -> f a -> Maybe (a :~: a)
refl :: Eq a => a -> a -> Maybe (t :~: t)
refl x x' = if x == x' then Just Refl else Nothing
-- | SMP message formats.
@@ -366,6 +404,14 @@ data AMessage where
REPLY :: SMPQueueInfo -> AMessage
-- | agent envelope for the client message
A_MSG :: MsgBody -> AMessage
-- | agent message for introduction
A_INTRO :: IntroEntity -> EntityInfo -> AMessage
-- | agent envelope for the sent invitation
A_INV :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
-- | agent envelope for the forwarded invitation
A_REQ :: Entity Conn_ -> SMPQueueInfo -> EntityInfo -> AMessage
-- | agent message for intro/group request
A_CON :: Entity Conn_ -> AMessage
deriving (Show)
-- | Parse SMP message.
@@ -408,12 +454,22 @@ agentMessageP =
"HELLO " *> hello
<|> "REPLY " *> reply
<|> "MSG " *> a_msg
<|> "INTRO " *> a_intro
<|> "INV " *> a_inv
<|> "REQ " *> a_req
<|> "CON " *> a_con
where
hello = HELLO <$> C.pubKeyP <*> ackMode
reply = REPLY <$> smpQueueInfoP
a_msg = do
a_msg = A_MSG <$> binaryBody
a_intro = A_INTRO <$> introEntityP <* A.space <*> binaryBody
a_inv = invP A_INV
a_req = invP A_REQ
a_con = A_CON <$> connEntityP
invP f = f <$> connEntityP <* A.space <*> smpQueueInfoP <* A.space <*> binaryBody
binaryBody = do
size :: Int <- A.decimal <* A.endOfLine
A_MSG <$> A.take size <* A.endOfLine
A.take size <* A.endOfLine
ackMode = AckMode <$> (" NO_ACK" $> Off <|> pure On)
-- | SMP queue information parser.
@@ -434,6 +490,13 @@ serializeAgentMessage = \case
HELLO verifyKey ackMode -> "HELLO " <> C.serializePubKey verifyKey <> if ackMode == AckMode Off then " NO_ACK" else ""
REPLY qInfo -> "REPLY " <> serializeSmpQueueInfo qInfo
A_MSG body -> "MSG " <> serializeMsg body <> "\n"
A_INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo <> "\n"
A_INV conn qInfo eInfo -> "INV " <> serializeInv conn qInfo eInfo
A_REQ conn qInfo eInfo -> "REQ " <> serializeInv conn qInfo eInfo
A_CON conn -> "CON " <> serializeEntity conn
where
serializeInv conn qInfo eInfo =
B.intercalate " " [serializeEntity conn, serializeSmpQueueInfo qInfo, serializeMsg eInfo] <> "\n"
-- | Serialize SMP queue information that is sent out-of-band.
serializeSmpQueueInfo :: SMPQueueInfo -> ByteString
@@ -457,7 +520,7 @@ instance IsString SMPServer where
fromString = parseString . parseAll $ smpServerP
-- | SMP agent connection alias.
type ConnAlias = ByteString
type ConnId = ByteString
-- | Connection modes.
data OnOff = On | Off deriving (Eq, Show, Read)
@@ -614,10 +677,19 @@ anEntityP =
<|> "B:" $> AE . Broadcast
<|> "G:" $> AE . AGroup
)
<*> A.takeTill (== ' ')
<*> A.takeTill wordEnd
entityConnP :: Parser (Entity Conn_)
entityConnP = "C:" *> (Conn <$> A.takeTill (== ' '))
connEntityP :: Parser (Entity Conn_)
connEntityP = "C:" *> (Conn <$> A.takeTill wordEnd)
introEntityP :: Parser IntroEntity
introEntityP =
($)
<$> ( "C:" $> IE . Conn
<|> "O:" $> IE . OpenConn
<|> "G:" $> IE . AGroup
)
<*> A.takeTill wordEnd
serializeEntity :: Entity t -> ByteString
serializeEntity = \case
@@ -632,6 +704,9 @@ commandP =
"NEW" $> ACmd SClient NEW
<|> "INV " *> invResp
<|> "JOIN " *> joinCmd
<|> "INTRO " *> introCmd
<|> "REQ " *> reqCmd
<|> "ACPT " *> acptCmd
<|> "SUB" $> ACmd SClient SUB
<|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all
<|> "END" $> ACmd SAgent END
@@ -645,16 +720,21 @@ commandP =
<|> "LS" $> ACmd SClient LS
<|> "MS " *> membersResp
<|> "ERR " *> agentError
<|> "ICON " *> iconMsg
<|> "CON" $> ACmd SAgent CON
<|> "OK" $> ACmd SAgent OK
where
invResp = ACmd SAgent . INV <$> smpQueueInfoP
joinCmd = ACmd SClient <$> (JOIN <$> smpQueueInfoP <*> replyMode)
introCmd = ACmd SClient <$> introP INTRO
reqCmd = ACmd SAgent <$> introP REQ
acptCmd = ACmd SClient <$> introP ACPT
sendCmd = ACmd SClient . SEND <$> A.takeByteString
sentResp = ACmd SAgent . SENT <$> A.decimal
addCmd = ACmd SClient . ADD <$> entityConnP
removeCmd = ACmd SClient . REM <$> entityConnP
membersResp = ACmd SAgent . MS <$> (entityConnP `A.sepBy'` A.char ' ')
addCmd = ACmd SClient . ADD <$> connEntityP
removeCmd = ACmd SClient . REM <$> connEntityP
membersResp = ACmd SAgent . MS <$> (connEntityP `A.sepBy'` A.char ' ')
iconMsg = ACmd SAgent . ICON <$> introEntityP
message = do
msgIntegrity <- msgIntegrityP <* A.space
recipientMeta <- "R=" *> partyMeta A.decimal
@@ -662,6 +742,7 @@ commandP =
senderMeta <- "S=" *> partyMeta A.decimal
msgBody <- A.takeByteString
return $ ACmd SAgent MSG {recipientMeta, brokerMeta, senderMeta, msgIntegrity, msgBody}
introP f = f <$> introEntityP <* A.space <*> A.takeByteString
replyMode = ReplyMode <$> (" NO_REPLY" $> Off <|> pure On)
partyMeta idParser = (,) <$> idParser <* "," <*> tsISO8601P <* A.space
agentError = ACmd SAgent . ERR <$> agentErrorTypeP
@@ -685,6 +766,9 @@ serializeCommand = \case
NEW -> "NEW"
INV qInfo -> "INV " <> serializeSmpQueueInfo qInfo
JOIN qInfo rMode -> "JOIN " <> serializeSmpQueueInfo qInfo <> replyMode rMode
INTRO (IE entity) eInfo -> "INTRO " <> serializeIntro entity eInfo
REQ (IE entity) eInfo -> "REQ " <> serializeIntro entity eInfo
ACPT (IE entity) eInfo -> "ACPT " <> serializeIntro entity eInfo
SUB -> "SUB"
SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all
END -> "END"
@@ -706,6 +790,7 @@ serializeCommand = \case
LS -> "LS"
MS cs -> "MS " <> B.intercalate " " (map serializeEntity cs)
CON -> "CON"
ICON (IE entity) -> "ICON " <> serializeEntity entity
ERR e -> "ERR " <> serializeAgentError e
OK -> "OK"
where
@@ -716,6 +801,9 @@ serializeCommand = \case
showTs :: UTCTime -> ByteString
showTs = B.pack . formatISO8601Millis
serializeIntro :: Entity t -> ByteString -> ByteString
serializeIntro entity eInfo = serializeEntity entity <> " " <> serializeMsg eInfo
-- | Serialize message integrity validation result.
serializeMsgIntegrity :: MsgIntegrity -> ByteString
serializeMsgIntegrity = \case
@@ -794,9 +882,10 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
hasEntityId :: AnEntity -> APartyCmd p -> Either AgentErrorType (APartyCmd p)
hasEntityId (AE entity) (APartyCmd cmd) =
APartyCmd <$> case cmd of
-- NEW and JOIN have optional entity
-- NEW, JOIN and ACPT have optional entity
NEW -> Right cmd
JOIN _ _ -> Right cmd
JOIN {} -> Right cmd
ACPT {} -> Right cmd
-- ERROR response does not always have entity
ERR _ -> Right cmd
-- other responses must have entity
@@ -818,6 +907,9 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
APartyCmd <$$> case cmd of
SEND body -> SEND <$$> getMsgBody body
MSG agentMsgId srvTS agentTS integrity body -> MSG agentMsgId srvTS agentTS integrity <$$> getMsgBody body
INTRO entity eInfo -> INTRO entity <$$> getMsgBody eInfo
REQ entity eInfo -> REQ entity <$$> getMsgBody eInfo
ACPT entity eInfo -> ACPT entity <$$> getMsgBody eInfo
_ -> pure $ Right cmd
-- TODO refactor with server

View File

@@ -3,16 +3,21 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
module Simplex.Messaging.Agent.Store where
import Control.Concurrent.STM (TVar)
import Control.Exception (Exception)
import Crypto.Random (ChaChaDRG)
import Data.ByteString.Char8 (ByteString)
import Data.Int (Int64)
import Data.Kind (Type)
import Data.Text (Text)
import Data.Time (UTCTime)
import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
@@ -30,33 +35,45 @@ import qualified Simplex.Messaging.Protocol as SMP
-- | Store class type. Defines store access methods for implementations.
class Monad m => MonadAgentStore s m where
-- Queue and Connection management
createRcvConn :: s -> RcvQueue -> m ()
createSndConn :: s -> SndQueue -> m ()
getConn :: s -> ConnAlias -> m SomeConn
getAllConnAliases :: s -> m [ConnAlias] -- TODO remove - hack for subscribing to all
createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
getConn :: s -> ConnId -> m SomeConn
getAllConnIds :: s -> m [ConnId] -- TODO remove - hack for subscribing to all
getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn
deleteConn :: s -> ConnAlias -> m ()
upgradeRcvConnToDuplex :: s -> ConnAlias -> SndQueue -> m ()
upgradeSndConnToDuplex :: s -> ConnAlias -> RcvQueue -> m ()
deleteConn :: s -> ConnId -> m ()
upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m ()
upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m ()
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
setRcvQueueActive :: s -> RcvQueue -> VerificationKey -> m ()
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
-- Msg management
updateRcvIds :: s -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
createRcvMsg :: s -> RcvQueue -> RcvMsgData -> m ()
updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
createRcvMsg :: s -> ConnId -> RcvMsgData -> m ()
updateSndIds :: s -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
createSndMsg :: s -> SndQueue -> SndMsgData -> m ()
updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
getMsg :: s -> ConnAlias -> InternalId -> m Msg
getMsg :: s -> ConnId -> InternalId -> m Msg
-- Broadcasts
createBcast :: s -> BroadcastId -> m ()
addBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
removeBcastConn :: s -> BroadcastId -> ConnAlias -> m ()
createBcast :: s -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
addBcastConn :: s -> BroadcastId -> ConnId -> m ()
removeBcastConn :: s -> BroadcastId -> ConnId -> m ()
deleteBcast :: s -> BroadcastId -> m ()
getBcast :: s -> BroadcastId -> m [ConnAlias]
getBcast :: s -> BroadcastId -> m [ConnId]
-- Introductions
createIntro :: s -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
getIntro :: s -> IntroId -> m Introduction
addIntroInvitation :: s -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
setIntroToStatus :: s -> IntroId -> IntroStatus -> m ()
setIntroReStatus :: s -> IntroId -> IntroStatus -> m ()
createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
getInvitation :: s -> InvitationId -> m Invitation
addInvitationConn :: s -> InvitationId -> ConnId -> m ()
getConnInvitation :: s -> ConnId -> m (Maybe (Invitation, Connection CDuplex))
setInvitationStatus :: s -> InvitationId -> InvitationStatus -> m ()
-- * Queue types
@@ -64,7 +81,6 @@ class Monad m => MonadAgentStore s m where
data RcvQueue = RcvQueue
{ server :: SMPServer,
rcvId :: SMP.RecipientId,
connAlias :: ConnAlias,
rcvPrivateKey :: RecipientPrivateKey,
sndId :: Maybe SMP.SenderId,
sndKey :: Maybe SenderPublicKey,
@@ -78,7 +94,6 @@ data RcvQueue = RcvQueue
data SndQueue = SndQueue
{ server :: SMPServer,
sndId :: SMP.SenderId,
connAlias :: ConnAlias,
sndPrivateKey :: SenderPrivateKey,
encryptKey :: EncryptionKey,
signKey :: SignatureKey,
@@ -102,9 +117,9 @@ data ConnType = CRcv | CSnd | CDuplex deriving (Eq, Show)
-- - DuplexConnection is a connection that has both receive and send queues set up,
-- typically created by upgrading a receive or a send connection with a missing queue.
data Connection (d :: ConnType) where
RcvConnection :: ConnAlias -> RcvQueue -> Connection CRcv
SndConnection :: ConnAlias -> SndQueue -> Connection CSnd
DuplexConnection :: ConnAlias -> RcvQueue -> SndQueue -> Connection CDuplex
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
SndConnection :: ConnData -> SndQueue -> Connection CSnd
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
deriving instance Eq (Connection d)
@@ -141,6 +156,9 @@ instance Eq SomeConn where
deriving instance Show SomeConn
data ConnData = ConnData {connId :: ConnId, viaInv :: Maybe InvitationId, connLevel :: Int}
deriving (Eq, Show)
-- * Message integrity validation types
type MsgHash = ByteString
@@ -263,7 +281,7 @@ type DeliveredTs = UTCTime
-- | Base message data independent of direction.
data MsgBase = MsgBase
{ connAlias :: ConnAlias,
{ connAlias :: ConnId,
-- | Monotonically increasing id of a message per connection, internal to the agent.
-- Internal Id preserves ordering between both received and sent messages, and is needed
-- to track the order of the conversation (which can be different for the sender / receiver)
@@ -281,12 +299,87 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
type InternalTs = UTCTime
-- * Introduction types
data NewIntroduction = NewIntroduction
{ toConn :: ConnId,
reConn :: ConnId,
reInfo :: ByteString
}
data Introduction = Introduction
{ introId :: IntroId,
toConn :: ConnId,
toInfo :: Maybe ByteString,
toStatus :: IntroStatus,
reConn :: ConnId,
reInfo :: ByteString,
reStatus :: IntroStatus,
qInfo :: Maybe SMPQueueInfo
}
data IntroStatus = IntroNew | IntroInv | IntroCon
deriving (Eq)
serializeIntroStatus :: IntroStatus -> Text
serializeIntroStatus = \case
IntroNew -> ""
IntroInv -> "INV"
IntroCon -> "CON"
introStatusT :: Text -> Maybe IntroStatus
introStatusT = \case
"" -> Just IntroNew
"INV" -> Just IntroInv
"CON" -> Just IntroCon
_ -> Nothing
type IntroId = ByteString
data NewInvitation = NewInvitation
{ viaConn :: ConnId,
externalIntroId :: IntroId,
entityInfo :: EntityInfo,
qInfo :: Maybe SMPQueueInfo
}
data Invitation = Invitation
{ invId :: InvitationId,
viaConn :: ConnId,
externalIntroId :: IntroId,
entityInfo :: EntityInfo,
qInfo :: Maybe SMPQueueInfo,
connId :: Maybe ConnId,
status :: InvitationStatus
}
deriving (Show)
data InvitationStatus = InvNew | InvAcpt | InvCon
deriving (Eq, Show)
serializeInvStatus :: InvitationStatus -> Text
serializeInvStatus = \case
InvNew -> ""
InvAcpt -> "ACPT"
InvCon -> "CON"
invStatusT :: Text -> Maybe InvitationStatus
invStatusT = \case
"" -> Just InvNew
"ACPT" -> Just InvAcpt
"CON" -> Just InvCon
_ -> Nothing
type InvitationId = ByteString
-- * Store errors
-- | Agent store error.
data StoreError
= -- | IO exceptions in store actions.
SEInternal ByteString
| -- | failed to generate unique random ID
SEUniqueID
| -- | Connection alias not found (or both queues absent).
SEConnNotFound
| -- | Connection alias already used.
@@ -298,6 +391,10 @@ data StoreError
SEBcastNotFound
| -- | Broadcast ID already used.
SEBcastDuplicate
| -- | Introduction ID not found.
SEIntroNotFound
| -- | Invitation ID not found.
SEInvitationNotFound
| -- | Currently not used. The intention was to pass current expected queue status in methods,
-- as we always know what it should be at any stage of the protocol,
-- and in case it does not match use this error.

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@@ -11,6 +12,7 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
@@ -22,14 +24,19 @@ module Simplex.Messaging.Agent.Store.SQLite
where
import Control.Concurrent (threadDelay)
import Control.Monad (unless, when)
import Control.Concurrent.STM (TVar, atomically, stateTVar)
import Control.Monad (join, unless, when)
import Control.Monad.Except (MonadError (throwError), MonadIO (liftIO))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
import Data.Bifunctor (first)
import Data.ByteString (ByteString)
import Data.ByteString.Base64 (encode)
import Data.Char (toLower)
import Data.Functor (($>))
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Text (isPrefixOf)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, field)
@@ -66,8 +73,8 @@ createSQLiteStore dbFilePath = do
let dbDir = takeDirectory dbFilePath
createDirectoryIfMissing False dbDir
store <- connectSQLiteStore dbFilePath
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
let threadsafeOption = find (isPrefixOf "THREADSAFE=") (concat compileOptions)
compileOptions <- DB.query_ (dbConn store) "pragma COMPILE_OPTIONS;" :: IO [[Text]]
let threadsafeOption = find (T.isPrefixOf "THREADSAFE=") (concat compileOptions)
case threadsafeOption of
Just "THREADSAFE=0" -> confirmOrExit "SQLite compiled with non-threadsafe code."
Nothing -> putStrLn "Warning: SQLite THREADSAFE compile option not found"
@@ -107,13 +114,16 @@ connectSQLiteStore dbFilePath = do
|]
pure SQLiteStore {dbFilePath, dbConn, dbNew}
checkConstraint :: StoreError -> IO () -> IO (Either StoreError ())
checkConstraint err action = first handleError <$> E.try action
where
handleError :: SQLError -> StoreError
handleError e
| DB.sqlError e == DB.ErrorConstraint = err
| otherwise = SEInternal $ bshow e
checkConstraint :: StoreError -> IO a -> IO (Either StoreError a)
checkConstraint err action = first (handleSQLError err) <$> E.try action
checkConstraint' :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a)
checkConstraint' err action = action `E.catch` (pure . Left . handleSQLError err)
handleSQLError :: StoreError -> SQLError -> StoreError
handleSQLError err e
| DB.sqlError e == DB.ErrorConstraint = err
| otherwise = SEInternal $ bshow e
withTransaction :: forall a. DB.Connection -> IO a -> IO a
withTransaction db a = loop 100 100_000
@@ -128,33 +138,43 @@ withTransaction db a = loop 100 100_000
else E.throwIO e
instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteStore m where
createRcvConn :: SQLiteStore -> RcvQueue -> m ()
createRcvConn SQLiteStore {dbConn} q@RcvQueue {server} =
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertRcvQueue_ dbConn q
insertRcvConnection_ dbConn q
createRcvConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> RcvQueue -> m ConnId
createRcvConn SQLiteStore {dbConn} gVar cData q@RcvQueue {server} =
-- TODO if schema has to be restarted, this function can be refactored
-- to create connection first using createWithRandomId
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
getConnId_ dbConn gVar cData >>= traverse create
where
create :: ConnId -> IO ConnId
create connId = do
upsertServer_ dbConn server
insertRcvQueue_ dbConn connId q
insertRcvConnection_ dbConn cData {connId} q
pure connId
createSndConn :: SQLiteStore -> SndQueue -> m ()
createSndConn SQLiteStore {dbConn} q@SndQueue {server} =
liftIOEither $
checkConstraint SEConnDuplicate $
withTransaction dbConn $ do
upsertServer_ dbConn server
insertSndQueue_ dbConn q
insertSndConnection_ dbConn q
createSndConn :: SQLiteStore -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
createSndConn SQLiteStore {dbConn} gVar cData q@SndQueue {server} =
-- TODO if schema has to be restarted, this function can be refactored
-- to create connection first using createWithRandomId
liftIOEither . checkConstraint' SEConnDuplicate . withTransaction dbConn $
getConnId_ dbConn gVar cData >>= traverse create
where
create :: ConnId -> IO ConnId
create connId = do
upsertServer_ dbConn server
insertSndQueue_ dbConn connId q
insertSndConnection_ dbConn cData {connId} q
pure connId
getConn :: SQLiteStore -> ConnAlias -> m SomeConn
getConn SQLiteStore {dbConn} connAlias =
getConn :: SQLiteStore -> ConnId -> m SomeConn
getConn SQLiteStore {dbConn} connId =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias
getConn_ dbConn connId
getAllConnAliases :: SQLiteStore -> m [ConnAlias]
getAllConnAliases SQLiteStore {dbConn} =
getAllConnIds :: SQLiteStore -> m [ConnId]
getAllConnIds SQLiteStore {dbConn} =
liftIO $ do
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnAlias]]
r <- DB.query_ dbConn "SELECT conn_alias FROM connections;" :: IO [[ConnId]]
return (concat r)
getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn
@@ -169,37 +189,37 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|]
[":host" := host, ":port" := serializePort_ port, ":rcv_id" := rcvId]
>>= \case
[Only connAlias] -> getConn_ dbConn connAlias
[Only connId] -> getConn_ dbConn connId
_ -> pure $ Left SEConnNotFound
deleteConn :: SQLiteStore -> ConnAlias -> m ()
deleteConn SQLiteStore {dbConn} connAlias =
deleteConn :: SQLiteStore -> ConnId -> m ()
deleteConn SQLiteStore {dbConn} connId =
liftIO $
DB.executeNamed
dbConn
"DELETE FROM connections WHERE conn_alias = :conn_alias;"
[":conn_alias" := connAlias]
[":conn_alias" := connId]
upgradeRcvConnToDuplex :: SQLiteStore -> ConnAlias -> SndQueue -> m ()
upgradeRcvConnToDuplex SQLiteStore {dbConn} connAlias sq@SndQueue {server} =
upgradeRcvConnToDuplex :: SQLiteStore -> ConnId -> SndQueue -> m ()
upgradeRcvConnToDuplex SQLiteStore {dbConn} connId sq@SndQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Right (SomeConn _ RcvConnection {}) -> do
upsertServer_ dbConn server
insertSndQueue_ dbConn sq
updateConnWithSndQueue_ dbConn connAlias sq
insertSndQueue_ dbConn connId sq
updateConnWithSndQueue_ dbConn connId sq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
upgradeSndConnToDuplex :: SQLiteStore -> ConnAlias -> RcvQueue -> m ()
upgradeSndConnToDuplex SQLiteStore {dbConn} connAlias rq@RcvQueue {server} =
upgradeSndConnToDuplex :: SQLiteStore -> ConnId -> RcvQueue -> m ()
upgradeSndConnToDuplex SQLiteStore {dbConn} connId rq@RcvQueue {server} =
liftIOEither . withTransaction dbConn $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Right (SomeConn _ SndConnection {}) -> do
upsertServer_ dbConn server
insertRcvQueue_ dbConn rq
updateConnWithRcvQueue_ dbConn connAlias rq
insertRcvQueue_ dbConn connId rq
updateConnWithRcvQueue_ dbConn connId rq
pure $ Right ()
Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c
_ -> pure $ Left SEConnNotFound
@@ -248,83 +268,233 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|]
[":status" := status, ":host" := host, ":port" := serializePort_ port, ":snd_id" := sndId]
updateRcvIds :: SQLiteStore -> RcvQueue -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds SQLiteStore {dbConn} RcvQueue {connAlias} =
updateRcvIds :: SQLiteStore -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
updateRcvIds SQLiteStore {dbConn} connId =
liftIO . withTransaction dbConn $ do
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connAlias
(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ dbConn connId
let internalId = InternalId $ unId lastInternalId + 1
internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1
updateLastIdsRcv_ dbConn connAlias internalId internalRcvId
updateLastIdsRcv_ dbConn connId internalId internalRcvId
pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash)
createRcvMsg :: SQLiteStore -> RcvQueue -> RcvMsgData -> m ()
createRcvMsg SQLiteStore {dbConn} RcvQueue {connAlias} rcvMsgData =
createRcvMsg :: SQLiteStore -> ConnId -> RcvMsgData -> m ()
createRcvMsg SQLiteStore {dbConn} connId rcvMsgData =
liftIO . withTransaction dbConn $ do
insertRcvMsgBase_ dbConn connAlias rcvMsgData
insertRcvMsgDetails_ dbConn connAlias rcvMsgData
updateHashRcv_ dbConn connAlias rcvMsgData
insertRcvMsgBase_ dbConn connId rcvMsgData
insertRcvMsgDetails_ dbConn connId rcvMsgData
updateHashRcv_ dbConn connId rcvMsgData
updateSndIds :: SQLiteStore -> SndQueue -> m (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds SQLiteStore {dbConn} SndQueue {connAlias} =
updateSndIds :: SQLiteStore -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
updateSndIds SQLiteStore {dbConn} connId =
liftIO . withTransaction dbConn $ do
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connAlias
(lastInternalId, lastInternalSndId, prevSndHash) <- retrieveLastIdsAndHashSnd_ dbConn connId
let internalId = InternalId $ unId lastInternalId + 1
internalSndId = InternalSndId $ unSndId lastInternalSndId + 1
updateLastIdsSnd_ dbConn connAlias internalId internalSndId
updateLastIdsSnd_ dbConn connId internalId internalSndId
pure (internalId, internalSndId, prevSndHash)
createSndMsg :: SQLiteStore -> SndQueue -> SndMsgData -> m ()
createSndMsg SQLiteStore {dbConn} SndQueue {connAlias} sndMsgData =
createSndMsg :: SQLiteStore -> ConnId -> SndMsgData -> m ()
createSndMsg SQLiteStore {dbConn} connId sndMsgData =
liftIO . withTransaction dbConn $ do
insertSndMsgBase_ dbConn connAlias sndMsgData
insertSndMsgDetails_ dbConn connAlias sndMsgData
updateHashSnd_ dbConn connAlias sndMsgData
insertSndMsgBase_ dbConn connId sndMsgData
insertSndMsgDetails_ dbConn connId sndMsgData
updateHashSnd_ dbConn connId sndMsgData
getMsg :: SQLiteStore -> ConnAlias -> InternalId -> m Msg
getMsg :: SQLiteStore -> ConnId -> InternalId -> m Msg
getMsg _st _connAlias _id = throwError SENotImplemented
createBcast :: SQLiteStore -> BroadcastId -> m ()
createBcast SQLiteStore {dbConn} bId =
liftIOEither $
checkConstraint SEBcastDuplicate $
DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
createBcast :: SQLiteStore -> TVar ChaChaDRG -> BroadcastId -> m BroadcastId
createBcast SQLiteStore {dbConn} gVar bcastId = liftIOEither $ case bcastId of
"" -> createWithRandomId gVar create
bId -> checkConstraint SEBcastDuplicate $ create bId $> bId
where
create bId = DB.execute dbConn "INSERT INTO broadcasts (broadcast_id) VALUES (?);" (Only bId)
addBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
addBcastConn SQLiteStore {dbConn} bId connAlias =
addBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
addBcastConn SQLiteStore {dbConn} bId connId =
liftIOEither . checkBroadcast dbConn bId $
getConn_ dbConn connAlias >>= \case
getConn_ dbConn connId >>= \case
Left _ -> pure $ Left SEConnNotFound
Right (SomeConn _ RcvConnection {}) -> pure . Left $ SEBadConnType CRcv
Right _ ->
checkConstraint SEConnDuplicate $
DB.execute
dbConn
"INSERT INTO broadcast_connections (broadcast_id, conn_alias) VALUES (?, ?);"
(bId, connAlias)
[sql|
INSERT INTO broadcast_connections
(broadcast_id, conn_alias) VALUES (?, ?);
|]
(bId, connId)
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnAlias -> m ()
removeBcastConn SQLiteStore {dbConn} bId connAlias =
removeBcastConn :: SQLiteStore -> BroadcastId -> ConnId -> m ()
removeBcastConn SQLiteStore {dbConn} bId connId =
liftIOEither . checkBroadcast dbConn bId $
bcastConnExists_ dbConn bId connAlias >>= \case
bcastConnExists_ dbConn bId connId >>= \case
False -> pure $ Left SEConnNotFound
_ ->
Right
<$> DB.execute
dbConn
"DELETE FROM broadcast_connections WHERE broadcast_id = ? AND conn_alias = ?;"
(bId, connAlias)
[sql|
DELETE FROM broadcast_connections
WHERE broadcast_id = ? AND conn_alias = ?;
|]
(bId, connId)
deleteBcast :: SQLiteStore -> BroadcastId -> m ()
deleteBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right <$> DB.execute dbConn "DELETE FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
getBcast :: SQLiteStore -> BroadcastId -> m [ConnAlias]
getBcast :: SQLiteStore -> BroadcastId -> m [ConnId]
getBcast SQLiteStore {dbConn} bId =
liftIOEither . checkBroadcast dbConn bId $
Right . map fromOnly
<$> DB.query dbConn "SELECT conn_alias FROM broadcast_connections WHERE broadcast_id = ?;" (Only bId)
createIntro :: SQLiteStore -> TVar ChaChaDRG -> NewIntroduction -> m IntroId
createIntro SQLiteStore {dbConn} gVar NewIntroduction {toConn, reConn, reInfo} =
liftIOEither . createWithRandomId gVar $ \introId ->
DB.execute
dbConn
[sql|
INSERT INTO conn_intros
(intro_id, to_conn, re_conn, re_info) VALUES (?, ?, ?, ?);
|]
(introId, toConn, reConn, reInfo)
getIntro :: SQLiteStore -> IntroId -> m Introduction
getIntro SQLiteStore {dbConn} introId =
liftIOEither $
intro
<$> DB.query
dbConn
[sql|
SELECT to_conn, to_info, to_status, re_conn, re_info, re_status, queue_info
FROM conn_intros
WHERE intro_id = ?;
|]
(Only introId)
where
intro [(toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo)] =
Right $ Introduction {introId, toConn, toInfo, toStatus, reConn, reInfo, reStatus, qInfo}
intro _ = Left SEIntroNotFound
addIntroInvitation :: SQLiteStore -> IntroId -> EntityInfo -> SMPQueueInfo -> m ()
addIntroInvitation SQLiteStore {dbConn} introId toInfo qInfo =
liftIO $
DB.executeNamed
dbConn
[sql|
UPDATE conn_intros
SET to_info = :to_info,
queue_info = :queue_info,
to_status = :to_status
WHERE intro_id = :intro_id;
|]
[ ":to_info" := toInfo,
":queue_info" := Just qInfo,
":to_status" := IntroInv,
":intro_id" := introId
]
setIntroToStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
setIntroToStatus SQLiteStore {dbConn} introId toStatus =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_intros
SET to_status = ?
WHERE intro_id = ?;
|]
(toStatus, introId)
setIntroReStatus :: SQLiteStore -> IntroId -> IntroStatus -> m ()
setIntroReStatus SQLiteStore {dbConn} introId reStatus =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_intros
SET re_status = ?
WHERE intro_id = ?;
|]
(reStatus, introId)
createInvitation :: SQLiteStore -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
createInvitation SQLiteStore {dbConn} gVar NewInvitation {viaConn, externalIntroId, entityInfo, qInfo} =
liftIOEither . createWithRandomId gVar $ \invId ->
DB.execute
dbConn
[sql|
INSERT INTO conn_invitations
(inv_id, via_conn, external_intro_id, conn_info, queue_info) VALUES (?, ?, ?, ?, ?);
|]
(invId, viaConn, externalIntroId, entityInfo, qInfo)
getInvitation :: SQLiteStore -> InvitationId -> m Invitation
getInvitation SQLiteStore {dbConn} invId =
liftIOEither $
invitation
<$> DB.query
dbConn
[sql|
SELECT via_conn, external_intro_id, conn_info, queue_info, conn_id, status
FROM conn_invitations
WHERE inv_id = ?;
|]
(Only invId)
where
invitation [(viaConn, externalIntroId, entityInfo, qInfo, connId, status)] =
Right $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId, status}
invitation _ = Left SEInvitationNotFound
addInvitationConn :: SQLiteStore -> InvitationId -> ConnId -> m ()
addInvitationConn SQLiteStore {dbConn} invId connId =
liftIO $
DB.executeNamed
dbConn
[sql|
UPDATE conn_invitations
SET conn_id = :conn_id, status = :status
WHERE inv_id = :inv_id;
|]
[":conn_id" := connId, ":status" := InvAcpt, ":inv_id" := invId]
getConnInvitation :: SQLiteStore -> ConnId -> m (Maybe (Invitation, Connection 'CDuplex))
getConnInvitation SQLiteStore {dbConn} cId =
liftIO . withTransaction dbConn $
DB.query
dbConn
[sql|
SELECT inv_id, via_conn, external_intro_id, conn_info, queue_info, status
FROM conn_invitations
WHERE conn_id = ?;
|]
(Only cId)
>>= fmap join . traverse getViaConn . invitation
where
invitation [(invId, viaConn, externalIntroId, entityInfo, qInfo, status)] =
Just $ Invitation {invId, viaConn, externalIntroId, entityInfo, qInfo, connId = Just cId, status}
invitation _ = Nothing
getViaConn :: Invitation -> IO (Maybe (Invitation, Connection 'CDuplex))
getViaConn inv@Invitation {viaConn} = fmap (inv,) . duplexConn <$> getConn_ dbConn viaConn
duplexConn :: Either StoreError SomeConn -> Maybe (Connection 'CDuplex)
duplexConn (Right (SomeConn SCDuplex conn)) = Just conn
duplexConn _ = Nothing
setInvitationStatus :: SQLiteStore -> InvitationId -> InvitationStatus -> m ()
setInvitationStatus SQLiteStore {dbConn} invId status =
liftIO $
DB.execute
dbConn
[sql|
UPDATE conn_invitations
SET status = ? WHERE inv_id = ?;
|]
(status, invId)
-- * Auxiliary helpers
-- ? replace with ToField? - it's easy to forget to use this
@@ -337,7 +507,7 @@ deserializePort_ port = Just port
instance ToField QueueStatus where toField = toField . show
instance FromField QueueStatus where fromField = fromFieldToReadable_
instance FromField QueueStatus where fromField = fromTextField_ $ readMaybe . T.unpack
instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
@@ -359,13 +529,24 @@ instance ToField MsgIntegrity where toField = toField . serializeMsgIntegrity
instance FromField MsgIntegrity where fromField = blobFieldParser msgIntegrityP
fromFieldToReadable_ :: forall a. (Read a, E.Typeable a) => Field -> Ok a
fromFieldToReadable_ = \case
instance ToField IntroStatus where toField = toField . serializeIntroStatus
instance FromField IntroStatus where fromField = fromTextField_ introStatusT
instance ToField InvitationStatus where toField = toField . serializeInvStatus
instance FromField InvitationStatus where fromField = fromTextField_ invStatusT
instance ToField SMPQueueInfo where toField = toField . serializeSmpQueueInfo
instance FromField SMPQueueInfo where fromField = blobFieldParser smpQueueInfoP
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
f@(Field (SQLText t) _) ->
let str = T.unpack t
in case readMaybe str of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid string: " <> str)
case fromText t of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
f -> returnError ConversionFailed f "expecting SQLText column type"
{- ORMOLU_DISABLE -}
@@ -397,8 +578,8 @@ upsertServer_ dbConn SMPServer {host, port, keyHash} = do
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> RcvQueue -> IO ()
insertRcvQueue_ dbConn RcvQueue {..} = do
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
insertRcvQueue_ dbConn connId RcvQueue {..} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -411,7 +592,7 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
[ ":host" := host server,
":port" := port_,
":rcv_id" := rcvId,
":conn_alias" := connAlias,
":conn_alias" := connId,
":rcv_private_key" := rcvPrivateKey,
":snd_id" := sndId,
":snd_key" := sndKey,
@@ -420,26 +601,29 @@ insertRcvQueue_ dbConn RcvQueue {..} = do
":status" := status
]
insertRcvConnection_ :: DB.Connection -> RcvQueue -> IO ()
insertRcvConnection_ dbConn RcvQueue {server, rcvId, connAlias} = do
insertRcvConnection_ :: DB.Connection -> ConnData -> RcvQueue -> IO ()
insertRcvConnection_ dbConn ConnData {connId, viaInv, connLevel} RcvQueue {server, rcvId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL,
0, 0, 0, 0, x'', x'');
(:conn_alias,:rcv_host,:rcv_port,:rcv_id, NULL, NULL, NULL, :via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId]
[ ":conn_alias" := connId,
":rcv_host" := host server,
":rcv_port" := port_,
":rcv_id" := rcvId,
":via_inv" := viaInv,
":conn_level" := connLevel
]
-- * createSndConn helpers
insertSndQueue_ :: DB.Connection -> SndQueue -> IO ()
insertSndQueue_ dbConn SndQueue {..} = do
insertSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
insertSndQueue_ dbConn connId SndQueue {..} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -452,85 +636,96 @@ insertSndQueue_ dbConn SndQueue {..} = do
[ ":host" := host server,
":port" := port_,
":snd_id" := sndId,
":conn_alias" := connAlias,
":conn_alias" := connId,
":snd_private_key" := sndPrivateKey,
":encrypt_key" := encryptKey,
":sign_key" := signKey,
":status" := status
]
insertSndConnection_ :: DB.Connection -> SndQueue -> IO ()
insertSndConnection_ dbConn SndQueue {server, sndId, connAlias} = do
insertSndConnection_ :: DB.Connection -> ConnData -> SndQueue -> IO ()
insertSndConnection_ dbConn ConnData {connId, viaInv, connLevel} SndQueue {server, sndId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
[sql|
INSERT INTO connections
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id,
last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id,
last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
( conn_alias, rcv_host, rcv_port, rcv_id, snd_host, snd_port, snd_id, via_inv, conn_level, last_internal_msg_id, last_internal_rcv_msg_id, last_internal_snd_msg_id, last_external_snd_msg_id, last_rcv_msg_hash, last_snd_msg_hash)
VALUES
(:conn_alias, NULL, NULL, NULL,:snd_host,:snd_port,:snd_id,
0, 0, 0, 0, x'', x'');
(:conn_alias, NULL, NULL, NULL, :snd_host,:snd_port,:snd_id,:via_inv,:conn_level, 0, 0, 0, 0, x'', x'');
|]
[":conn_alias" := connAlias, ":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId]
[ ":conn_alias" := connId,
":snd_host" := host server,
":snd_port" := port_,
":snd_id" := sndId,
":via_inv" := viaInv,
":conn_level" := connLevel
]
-- * getConn helpers
getConn_ :: DB.Connection -> ConnAlias -> IO (Either StoreError SomeConn)
getConn_ dbConn connAlias = do
rQ <- retrieveRcvQueueByConnAlias_ dbConn connAlias
sQ <- retrieveSndQueueByConnAlias_ dbConn connAlias
pure $ case (rQ, sQ) of
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connAlias rcvQ sndQ)
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connAlias rcvQ)
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connAlias sndQ)
_ -> Left SEConnNotFound
getConn_ :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn)
getConn_ dbConn connId =
getConnData_ dbConn connId >>= \case
Nothing -> pure $ Left SEConnNotFound
Just connData -> do
rQ <- getRcvQueueByConnAlias_ dbConn connId
sQ <- getSndQueueByConnAlias_ dbConn connId
pure $ case (rQ, sQ) of
(Just rcvQ, Just sndQ) -> Right $ SomeConn SCDuplex (DuplexConnection connData rcvQ sndQ)
(Just rcvQ, Nothing) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
(Nothing, Just sndQ) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
_ -> Left SEConnNotFound
retrieveRcvQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe RcvQueue)
retrieveRcvQueueByConnAlias_ dbConn connAlias = do
r <-
DB.queryNamed
getConnData_ :: DB.Connection -> ConnId -> IO (Maybe ConnData)
getConnData_ dbConn connId =
connData
<$> DB.query dbConn "SELECT via_inv, conn_level FROM connections WHERE conn_alias = ?;" (Only connId)
where
connData [(viaInv, connLevel)] = Just ConnData {connId, viaInv, connLevel}
connData _ = Nothing
getRcvQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe RcvQueue)
getRcvQueueByConnAlias_ dbConn connId =
rcvQueue
<$> DB.query
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.rcv_id, q.conn_alias, q.rcv_private_key,
SELECT s.key_hash, q.host, q.port, q.rcv_id, q.rcv_private_key,
q.snd_id, q.snd_key, q.decrypt_key, q.verify_key, q.status
FROM rcv_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_alias = :conn_alias;
WHERE q.conn_alias = ?;
|]
[":conn_alias" := connAlias]
case r of
[(keyHash, host, port, rcvId, cAlias, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] -> do
(Only connId)
where
rcvQueue [(keyHash, host, port, rcvId, rcvPrivateKey, sndId, sndKey, decryptKey, verifyKey, status)] =
let srv = SMPServer host (deserializePort_ port) keyHash
return . Just $ RcvQueue srv rcvId cAlias rcvPrivateKey sndId sndKey decryptKey verifyKey status
_ -> return Nothing
in Just $ RcvQueue srv rcvId rcvPrivateKey sndId sndKey decryptKey verifyKey status
rcvQueue _ = Nothing
retrieveSndQueueByConnAlias_ :: DB.Connection -> ConnAlias -> IO (Maybe SndQueue)
retrieveSndQueueByConnAlias_ dbConn connAlias = do
r <-
DB.queryNamed
getSndQueueByConnAlias_ :: DB.Connection -> ConnId -> IO (Maybe SndQueue)
getSndQueueByConnAlias_ dbConn connId =
sndQueue
<$> DB.query
dbConn
[sql|
SELECT
s.key_hash, q.host, q.port, q.snd_id, q.conn_alias,
q.snd_private_key, q.encrypt_key, q.sign_key, q.status
SELECT s.key_hash, q.host, q.port, q.snd_id, q.snd_private_key, q.encrypt_key, q.sign_key, q.status
FROM snd_queues q
INNER JOIN servers s ON q.host = s.host AND q.port = s.port
WHERE q.conn_alias = :conn_alias;
WHERE q.conn_alias = ?;
|]
[":conn_alias" := connAlias]
case r of
[(keyHash, host, port, sndId, cAlias, sndPrivateKey, encryptKey, signKey, status)] -> do
(Only connId)
where
sndQueue [(keyHash, host, port, sndId, sndPrivateKey, encryptKey, signKey, status)] =
let srv = SMPServer host (deserializePort_ port) keyHash
return . Just $ SndQueue srv sndId cAlias sndPrivateKey encryptKey signKey status
_ -> return Nothing
in Just $ SndQueue srv sndId sndPrivateKey encryptKey signKey status
sndQueue _ = Nothing
-- * upgradeRcvConnToDuplex helpers
updateConnWithSndQueue_ :: DB.Connection -> ConnAlias -> SndQueue -> IO ()
updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
updateConnWithSndQueue_ :: DB.Connection -> ConnId -> SndQueue -> IO ()
updateConnWithSndQueue_ dbConn connId SndQueue {server, sndId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -539,12 +734,12 @@ updateConnWithSndQueue_ dbConn connAlias SndQueue {server, sndId} = do
SET snd_host = :snd_host, snd_port = :snd_port, snd_id = :snd_id
WHERE conn_alias = :conn_alias;
|]
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connAlias]
[":snd_host" := host server, ":snd_port" := port_, ":snd_id" := sndId, ":conn_alias" := connId]
-- * upgradeSndConnToDuplex helpers
updateConnWithRcvQueue_ :: DB.Connection -> ConnAlias -> RcvQueue -> IO ()
updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
updateConnWithRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
updateConnWithRcvQueue_ dbConn connId RcvQueue {server, rcvId} = do
let port_ = serializePort_ $ port server
DB.executeNamed
dbConn
@@ -553,12 +748,12 @@ updateConnWithRcvQueue_ dbConn connAlias RcvQueue {server, rcvId} = do
SET rcv_host = :rcv_host, rcv_port = :rcv_port, rcv_id = :rcv_id
WHERE conn_alias = :conn_alias;
|]
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connAlias]
[":rcv_host" := host server, ":rcv_port" := port_, ":rcv_id" := rcvId, ":conn_alias" := connId]
-- * updateRcvIds helpers
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connAlias = do
retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
retrieveLastIdsAndHashRcv_ dbConn connId = do
[(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <-
DB.queryNamed
dbConn
@@ -567,11 +762,11 @@ retrieveLastIdsAndHashRcv_ dbConn connAlias = do
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
[":conn_alias" := connId]
return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)
updateLastIdsRcv_ :: DB.Connection -> ConnAlias -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO ()
updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId =
DB.executeNamed
dbConn
[sql|
@@ -582,13 +777,13 @@ updateLastIdsRcv_ dbConn connAlias newInternalId newInternalRcvId =
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_rcv_msg_id" := newInternalRcvId,
":conn_alias" := connAlias
":conn_alias" := connId
]
-- * createRcvMsg helpers
insertRcvMsgBase_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgBase_ dbConn connId RcvMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -597,15 +792,15 @@ insertRcvMsgBase_ dbConn connAlias RcvMsgData {..} = do
VALUES
(:conn_alias,:internal_id,:internal_ts,:internal_rcv_id, NULL,:body);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_rcv_id" := internalRcvId,
":body" := decodeUtf8 msgBody
]
insertRcvMsgDetails_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
insertRcvMsgDetails_ dbConn connId RcvMsgData {..} =
DB.executeNamed
dbConn
[sql|
@@ -618,7 +813,7 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
:broker_id,:broker_ts,:rcv_status, NULL, NULL,
:internal_hash,:external_prev_snd_hash,:integrity);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_rcv_id" := internalRcvId,
":internal_id" := internalId,
":external_snd_id" := fst senderMeta,
@@ -631,8 +826,8 @@ insertRcvMsgDetails_ dbConn connAlias RcvMsgData {..} =
":integrity" := msgIntegrity
]
updateHashRcv_ :: DB.Connection -> ConnAlias -> RcvMsgData -> IO ()
updateHashRcv_ dbConn connAlias RcvMsgData {..} =
updateHashRcv_ :: DB.Connection -> ConnId -> RcvMsgData -> IO ()
updateHashRcv_ dbConn connId RcvMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved
@@ -645,14 +840,14 @@ updateHashRcv_ dbConn connAlias RcvMsgData {..} =
|]
[ ":last_external_snd_msg_id" := fst senderMeta,
":last_rcv_msg_hash" := internalHash,
":conn_alias" := connAlias,
":conn_alias" := connId,
":last_internal_rcv_msg_id" := internalRcvId
]
-- * updateSndIds helpers
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnAlias -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connAlias = do
retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (InternalId, InternalSndId, PrevSndMsgHash)
retrieveLastIdsAndHashSnd_ dbConn connId = do
[(lastInternalId, lastInternalSndId, lastSndHash)] <-
DB.queryNamed
dbConn
@@ -661,11 +856,11 @@ retrieveLastIdsAndHashSnd_ dbConn connAlias = do
FROM connections
WHERE conn_alias = :conn_alias;
|]
[":conn_alias" := connAlias]
[":conn_alias" := connId]
return (lastInternalId, lastInternalSndId, lastSndHash)
updateLastIdsSnd_ :: DB.Connection -> ConnAlias -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO ()
updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId =
DB.executeNamed
dbConn
[sql|
@@ -676,13 +871,13 @@ updateLastIdsSnd_ dbConn connAlias newInternalId newInternalSndId =
|]
[ ":last_internal_msg_id" := newInternalId,
":last_internal_snd_msg_id" := newInternalSndId,
":conn_alias" := connAlias
":conn_alias" := connId
]
-- * createSndMsg helpers
insertSndMsgBase_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgBase_ dbConn connId SndMsgData {..} = do
DB.executeNamed
dbConn
[sql|
@@ -691,15 +886,15 @@ insertSndMsgBase_ dbConn connAlias SndMsgData {..} = do
VALUES
(:conn_alias,:internal_id,:internal_ts, NULL,:internal_snd_id,:body);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_id" := internalId,
":internal_ts" := internalTs,
":internal_snd_id" := internalSndId,
":body" := decodeUtf8 msgBody
]
insertSndMsgDetails_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
insertSndMsgDetails_ dbConn connId SndMsgData {..} =
DB.executeNamed
dbConn
[sql|
@@ -708,15 +903,15 @@ insertSndMsgDetails_ dbConn connAlias SndMsgData {..} =
VALUES
(:conn_alias,:internal_snd_id,:internal_id,:snd_status, NULL, NULL,:internal_hash);
|]
[ ":conn_alias" := connAlias,
[ ":conn_alias" := connId,
":internal_snd_id" := internalSndId,
":internal_id" := internalId,
":snd_status" := Created,
":internal_hash" := internalHash
]
updateHashSnd_ :: DB.Connection -> ConnAlias -> SndMsgData -> IO ()
updateHashSnd_ dbConn connAlias SndMsgData {..} =
updateHashSnd_ :: DB.Connection -> ConnId -> SndMsgData -> IO ()
updateHashSnd_ dbConn connId SndMsgData {..} =
DB.executeNamed
dbConn
-- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved
@@ -727,7 +922,7 @@ updateHashSnd_ dbConn connAlias SndMsgData {..} =
AND last_internal_snd_msg_id = :last_internal_snd_msg_id;
|]
[ ":last_snd_msg_hash" := internalHash,
":conn_alias" := connAlias,
":conn_alias" := connId,
":last_internal_snd_msg_id" := internalSndId
]
@@ -745,10 +940,10 @@ bcastExists_ dbConn bId = not . null <$> queryBcast
queryBcast :: IO [Only BroadcastId]
queryBcast = DB.query dbConn "SELECT broadcast_id FROM broadcasts WHERE broadcast_id = ?;" (Only bId)
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnAlias -> IO Bool
bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
bcastConnExists_ :: DB.Connection -> BroadcastId -> ConnId -> IO Bool
bcastConnExists_ dbConn bId connId = not . null <$> queryBcastConn
where
queryBcastConn :: IO [(BroadcastId, ConnAlias)]
queryBcastConn :: IO [(BroadcastId, ConnId)]
queryBcastConn =
DB.query
dbConn
@@ -757,4 +952,37 @@ bcastConnExists_ dbConn bId connAlias = not . null <$> queryBcastConn
FROM broadcast_connections
WHERE broadcast_id = ? AND conn_alias = ?;
|]
(bId, connAlias)
(bId, connId)
-- create record with a random ID
getConnId_ :: DB.Connection -> TVar ChaChaDRG -> ConnData -> IO (Either StoreError ConnId)
getConnId_ dbConn gVar ConnData {connId = ""} = getUniqueRandomId gVar $ getConnData_ dbConn
getConnId_ _ _ ConnData {connId} = pure $ Right connId
getUniqueRandomId :: TVar ChaChaDRG -> (ByteString -> IO (Maybe a)) -> IO (Either StoreError ByteString)
getUniqueRandomId gVar get = tryGet 3
where
tryGet :: Int -> IO (Either StoreError ByteString)
tryGet 0 = pure $ Left SEUniqueID
tryGet n = do
id' <- randomId gVar 12
get id' >>= \case
Nothing -> pure $ Right id'
Just _ -> tryGet (n - 1)
createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString)
createWithRandomId gVar create = tryCreate 3
where
tryCreate :: Int -> IO (Either StoreError ByteString)
tryCreate 0 = pure $ Left SEUniqueID
tryCreate n = do
id' <- randomId gVar 12
E.try (create id') >>= \case
Right _ -> pure $ Right id'
Left e
| DB.sqlError e == DB.ErrorConstraint -> tryCreate (n - 1)
| otherwise -> pure . Left . SEInternal $ bshow e
randomId :: TVar ChaChaDRG -> Int -> IO ByteString
randomId gVar n = encode <$> (atomically . stateTVar gVar $ randomBytesGenerate n)

View File

@@ -30,7 +30,7 @@ base64StringP = do
pure $ str <> pad
tsISO8601P :: Parser UTCTime
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill (== ' ')
tsISO8601P = maybe (fail "timestamp") pure . parseISO8601 . B.unpack =<< A.takeTill wordEnd
parse :: Parser a -> e -> (ByteString -> Either e a)
parse parser err = first (const err) . parseAll parser
@@ -42,14 +42,17 @@ parseRead :: Read a => Parser ByteString -> Parser a
parseRead = (>>= maybe (fail "cannot read") pure . readMaybe . B.unpack)
parseRead1 :: Read a => Parser a
parseRead1 = parseRead $ A.takeTill (== ' ')
parseRead1 = parseRead $ A.takeTill wordEnd
parseRead2 :: Read a => Parser a
parseRead2 = parseRead $ do
w1 <- A.takeTill (== ' ') <* A.char ' '
w2 <- A.takeTill (== ' ')
w1 <- A.takeTill wordEnd <* A.char ' '
w2 <- A.takeTill wordEnd
pure $ w1 <> " " <> w2
wordEnd :: Char -> Bool
wordEnd c = c == ' ' || c == '\n'
parseString :: (ByteString -> Either String a) -> (String -> a)
parseString p = either error id . p . B.pack

View File

@@ -5,6 +5,7 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PostfixOperators #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
@@ -17,6 +18,7 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import SMPAgentClient
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.Store (InvitationId)
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
import System.Timeout
@@ -29,10 +31,16 @@ agentTests (ATransport t) = do
describe "Establishing duplex connection" do
it "should connect via one server and one agent" $
smpAgentTest2_1_1 $ testDuplexConnection t
it "should connect via one server and one agent (random IDs)" $
smpAgentTest2_1_1 $ testDuplexConnRandomIds t
it "should connect via one server and 2 agents" $
smpAgentTest2_2_1 $ testDuplexConnection t
it "should connect via one server and 2 agents (random IDs)" $
smpAgentTest2_2_1 $ testDuplexConnRandomIds t
it "should connect via 2 servers and 2 agents" $
smpAgentTest2_2_2 $ testDuplexConnection t
it "should connect via 2 servers and 2 agents (random IDs)" $
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
describe "Connection subscriptions" do
it "should connect via one server and one agent" $
smpAgentTest3_1_1 $ testSubscription t
@@ -41,6 +49,13 @@ agentTests (ATransport t) = do
describe "Broadcast" do
it "should create broadcast and send messages" $
smpAgentTest3 $ testBroadcast t
it "should create broadcast and send messages (random IDs)" $
smpAgentTest3 $ testBroadcastRandomIds t
describe "Introduction" do
it "should send and accept introduction" $
smpAgentTest3 $ testIntroduction t
it "should send and accept introduction (random IDs)" $
smpAgentTest3 $ testIntroductionRandomIds t
type TestTransmission p = (ACorrId, ByteString, APartyCmd p)
@@ -54,9 +69,13 @@ testTE (ATransmissionOrError corrId entity cmdOrErr) =
Right cmd -> Right $ APartyCmd cmd
Left e -> Left e
-- | receive message to handle `h`
(<#:) :: Transport c => c -> IO (TestTransmissionOrError 'Agent)
(<#:) h = testTE <$> tGet SAgent h
-- | send transmission `t` to handle `h` and get response
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (TestTransmissionOrError 'Agent)
h #: t = tPutRaw h t >> testTE <$> tGet SAgent h
h #: t = tPutRaw h t >> (h <#:)
-- | action and expected response
-- `h #:t #> r` is the test that sends `t` to `h` and validates that the response is `r`
@@ -75,11 +94,11 @@ correctTransmission (corrId, cAlias, cmdOrErr) = case cmdOrErr of
-- | receive message to handle `h` and validate that it is the expected one
(<#) :: Transport c => c -> TestTransmission' 'Agent c' -> Expectation
h <# (corrId, cAlias, cmd) = tGet SAgent h >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd))) . testTE
h <# (corrId, cAlias, cmd) = (h <#:) >>= (`shouldBe` (corrId, cAlias, Right (APartyCmd cmd)))
-- | receive message to handle `h` and validate it using predicate `p`
(<#=) :: Transport c => c -> (TestTransmission 'Agent -> Bool) -> Expectation
h <#= p = tGet SAgent h >>= (`shouldSatisfy` p . correctTransmission . testTE)
h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission)
-- | test that nothing is delivered to handle `h` during 10ms
(#:#) :: Transport c => c -> String -> Expectation
@@ -90,53 +109,75 @@ h #:# err = tryGet `shouldReturn` ()
Just _ -> error err
_ -> return ()
pattern Msg :: MsgBody -> APartyCmd 'Agent
pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk}
pattern Sent :: AgentMsgId -> APartyCmd 'Agent
pattern Sent msgId <- APartyCmd (SENT msgId)
pattern Inv :: SMPQueueInfo -> APartyCmd 'Agent
pattern Inv invitation <- APartyCmd (INV invitation)
pattern Msg :: MsgBody -> APartyCmd 'Agent
pattern Msg msgBody <- APartyCmd MSG {msgBody, msgIntegrity = MsgOk}
pattern Inv :: SMPQueueInfo -> Either AgentErrorType (APartyCmd 'Agent)
pattern Inv invitation <- Right (APartyCmd (INV invitation))
pattern Req :: InvitationId -> EntityInfo -> Either AgentErrorType (APartyCmd 'Agent)
pattern Req invId eInfo <- Right (APartyCmd (REQ (IE (Conn invId)) eInfo))
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
testDuplexConnection _ alice bob = do
("1", "C:bob", Right (Inv qInfo)) <- alice #: ("1", "C:bob", "NEW")
("1", "C:bob", Inv qInfo) <- alice #: ("1", "C:bob", "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON)
alice <# ("", "C:bob", CON)
alice #: ("2", "C:bob", "SEND :hello") =#> \case ("2", "C:bob", Sent 1) -> True; _ -> False
alice #: ("3", "C:bob", "SEND :how are you?") =#> \case ("3", "C:bob", Sent 2) -> True; _ -> False
alice #: ("2", "C:bob", "SEND :hello") #> ("2", "C:bob", SENT 1)
alice #: ("3", "C:bob", "SEND :how are you?") #> ("3", "C:bob", SENT 2)
bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
bob <#= \case ("", "C:alice", Msg "how are you?") -> True; _ -> False
bob #: ("14", "C:alice", "SEND 9\nhello too") =#> \case ("14", "C:alice", Sent 3) -> True; _ -> False
bob #: ("14", "C:alice", "SEND 9\nhello too") #> ("14", "C:alice", SENT 3)
alice <#= \case ("", "C:bob", Msg "hello too") -> True; _ -> False
bob #: ("15", "C:alice", "SEND 9\nmessage 1") =#> \case ("15", "C:alice", Sent 4) -> True; _ -> False
bob #: ("15", "C:alice", "SEND 9\nmessage 1") #> ("15", "C:alice", SENT 4)
alice <#= \case ("", "C:bob", Msg "message 1") -> True; _ -> False
alice #: ("5", "C:bob", "OFF") #> ("5", "C:bob", OK)
bob #: ("17", "C:alice", "SEND 9\nmessage 3") #> ("17", "C:alice", ERR (SMP AUTH))
alice #: ("6", "C:bob", "DEL") #> ("6", "C:bob", OK)
alice #:# "nothing else should be delivered to alice"
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
testDuplexConnRandomIds _ alice bob = do
("1", bobConn, Inv qInfo) <- alice #: ("1", "C:", "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
("", aliceConn, Right (APartyCmd CON)) <- bob #: ("11", "C:", "JOIN " <> qInfo')
alice <# ("", bobConn, CON)
alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, SENT 1)
alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, SENT 2)
bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False
bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False
bob #: ("14", aliceConn, "SEND 9\nhello too") #> ("14", aliceConn, SENT 3)
alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False
bob #: ("15", aliceConn, "SEND 9\nmessage 1") #> ("15", aliceConn, SENT 4)
alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
bob #: ("17", aliceConn, "SEND 9\nmessage 3") #> ("17", aliceConn, ERR (SMP AUTH))
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
alice #:# "nothing else should be delivered to alice"
testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO ()
testSubscription _ alice1 alice2 bob = do
("1", "C:bob", Right (Inv qInfo)) <- alice1 #: ("1", "C:bob", "NEW")
("1", "C:bob", Inv qInfo) <- alice1 #: ("1", "C:bob", "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
bob #: ("11", "C:alice", "JOIN " <> qInfo') #> ("", "C:alice", CON)
bob #: ("12", "C:alice", "SEND 5\nhello") =#> \case ("12", "C:alice", Sent _) -> True; _ -> False
bob #: ("13", "C:alice", "SEND 11\nhello again") =#> \case ("13", "C:alice", Sent _) -> True; _ -> False
bob #: ("12", "C:alice", "SEND 5\nhello") #> ("12", "C:alice", SENT 1)
bob #: ("13", "C:alice", "SEND 11\nhello again") #> ("13", "C:alice", SENT 2)
alice1 <# ("", "C:bob", CON)
alice1 <#= \case ("", "C:bob", Msg "hello") -> True; _ -> False
alice1 <#= \case ("", "C:bob", Msg "hello again") -> True; _ -> False
alice2 #: ("21", "C:bob", "SUB") #> ("21", "C:bob", OK)
alice1 <# ("", "C:bob", END)
bob #: ("14", "C:alice", "SEND 2\nhi") =#> \case ("14", "C:alice", Sent _) -> True; _ -> False
bob #: ("14", "C:alice", "SEND 2\nhi") #> ("14", "C:alice", SENT 3)
alice2 <#= \case ("", "C:bob", Msg "hi") -> True; _ -> False
alice1 #:# "nothing else should be delivered to alice1"
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
testSubscrNotification _ (server, _) client = do
client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", Inv _) -> True; _ -> False
client #: ("1", "C:conn1", "NEW") =#> \case ("1", "C:conn1", APartyCmd INV {}) -> True; _ -> False
client #:# "nothing should be delivered to client before the server is killed"
killThread server
client <# ("", "C:conn1", END)
@@ -156,8 +197,8 @@ testBroadcast _ alice bob tom = do
alice #: ("e3", "B:team", "ADD C:unknown") #> ("e3", "B:team", ERR $ CONN NOT_FOUND)
alice #: ("e4", "B:team", "ADD C:bob") #> ("e4", "B:team", ERR $ CONN DUPLICATE)
-- send message
alice #: ("4", "B:team", "SEND 5\nhello") #> ("4", "C:bob", SENT 1)
alice <# ("4", "C:tom", SENT 1)
alice #: ("4", "B:team", "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False
alice <#= \case ("4", c, Sent 1) -> c == "C:bob" || c == "C:tom"; _ -> False
alice <# ("4", "B:team", SENT 0)
bob <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
tom <#= \case ("", "C:alice", Msg "hello") -> True; _ -> False
@@ -177,13 +218,104 @@ testBroadcast _ alice bob tom = do
-- commands with errors
alice #: ("e8", "B:team", "DEL") #> ("e8", "B:team", ERR $ BCAST B_NOT_FOUND)
alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND)
where
connect :: (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (Inv qInfo)) <- h1 #: ("c1", "C:" <> name2, "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') =#> \case ("", c1, APartyCmd CON) -> c1 == "C:" <> name1; _ -> False
h1 <#= \case ("", c2, APartyCmd CON) -> c2 == "C:" <> name2; _ -> False
testBroadcastRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
testBroadcastRandomIds _ alice bob tom = do
-- establish connections
(aliceB, bobA) <- alice `connect'` bob
(aliceT, tomA) <- alice `connect'` tom
-- create and set up broadcast
("1", team, Right (APartyCmd OK)) <- alice #: ("1", "B:", "NEW")
alice #: ("2", team, "ADD " <> bobA) #> ("2", team, OK)
alice #: ("3", team, "ADD " <> tomA) #> ("3", team, OK)
-- commands with errors
alice #: ("e1", team, "NEW") #> ("e1", team, ERR $ BCAST B_DUPLICATE)
alice #: ("e2", "B:group", "ADD " <> bobA) #> ("e2", "B:group", ERR $ BCAST B_NOT_FOUND)
alice #: ("e3", team, "ADD C:unknown") #> ("e3", team, ERR $ CONN NOT_FOUND)
alice #: ("e4", team, "ADD " <> bobA) #> ("e4", team, ERR $ CONN DUPLICATE)
-- send message
alice #: ("4", team, "SEND 5\nhello") =#> \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False
alice <#= \case ("4", c, Sent 1) -> c == bobA || c == tomA; _ -> False
alice <# ("4", team, SENT 0)
bob <#= \case ("", c, Msg "hello") -> c == aliceB; _ -> False
tom <#= \case ("", c, Msg "hello") -> c == aliceT; _ -> False
-- remove one connection
alice #: ("5", team, "REM " <> tomA) #> ("5", team, OK)
alice #: ("6", team, "SEND 11\nhello again") #> ("6", bobA, SENT 2)
alice <# ("6", team, SENT 0)
bob <#= \case ("", c, Msg "hello again") -> c == aliceB; _ -> False
tom #:# "nothing delivered to tom"
-- commands with errors
alice #: ("e5", "B:group", "REM " <> bobA) #> ("e5", "B:group", ERR $ BCAST B_NOT_FOUND)
alice #: ("e6", team, "REM C:unknown") #> ("e6", team, ERR $ CONN NOT_FOUND)
alice #: ("e7", team, "REM " <> tomA) #> ("e7", team, ERR $ CONN NOT_FOUND)
-- delete broadcast
alice #: ("7", team, "DEL") #> ("7", team, OK)
alice #: ("8", team, "SEND 11\ntry sending") #> ("8", team, ERR $ BCAST B_NOT_FOUND)
-- commands with errors
alice #: ("e8", team, "DEL") #> ("e8", team, ERR $ BCAST B_NOT_FOUND)
alice #: ("e9", "B:group", "DEL") #> ("e9", "B:group", ERR $ BCAST B_NOT_FOUND)
testIntroduction :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
testIntroduction _ alice bob tom = do
-- establish connections
(alice, "alice") `connect` (bob, "bob")
(alice, "alice") `connect` (tom, "tom")
-- send introduction of tom to bob
alice #: ("1", "C:bob", "INTRO C:tom 8\nmeet tom") #> ("1", "C:bob", OK)
("", "C:alice", Req invId1 "meet tom") <- (bob <#:)
bob #: ("2", "C:tom_via_alice", "ACPT C:" <> invId1 <> " 7\nI'm bob") #> ("2", "C:tom_via_alice", OK)
("", "C:alice", Req invId2 "I'm bob") <- (tom <#:)
-- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter
-- or this should be another command, not ACPT
tom #: ("3", "C:bob_via_alice", "ACPT C:" <> invId2 <> " 8\ntom here") #> ("3", "C:bob_via_alice", OK)
tom <# ("", "C:bob_via_alice", CON)
bob <# ("", "C:tom_via_alice", CON)
alice <# ("", "C:bob", ICON (IE (Conn "tom")))
-- they can message each other now
tom #: ("4", "C:bob_via_alice", "SEND :hello") #> ("4", "C:bob_via_alice", SENT 1)
bob <#= \case ("", "C:tom_via_alice", Msg "hello") -> True; _ -> False
bob #: ("5", "C:tom_via_alice", "SEND 9\nhello too") #> ("5", "C:tom_via_alice", SENT 2)
tom <#= \case ("", "C:bob_via_alice", Msg "hello too") -> True; _ -> False
testIntroductionRandomIds :: forall c. Transport c => TProxy c -> c -> c -> c -> IO ()
testIntroductionRandomIds _ alice bob tom = do
-- establish connections
(aliceB, bobA) <- alice `connect'` bob
(aliceT, tomA) <- alice `connect'` tom
-- send introduction of tom to bob
alice #: ("1", bobA, "INTRO " <> tomA <> " 8\nmeet tom") #> ("1", bobA, OK)
("", aliceB', Req invId1 "meet tom") <- (bob <#:)
aliceB' `shouldBe` aliceB
("2", tomB, Right (APartyCmd OK)) <- bob #: ("2", "C:", "ACPT C:" <> invId1 <> " 7\nI'm bob")
("", aliceT', Req invId2 "I'm bob") <- (tom <#:)
aliceT' `shouldBe` aliceT
-- TODO info "tom here" is not used, either JOIN command also should have eInfo parameter
-- or this should be another command, not ACPT
("3", bobT, Right (APartyCmd OK)) <- tom #: ("3", "C:", "ACPT C:" <> invId2 <> " 8\ntom here")
tom <# ("", bobT, CON)
bob <# ("", tomB, CON)
alice <# ("", bobA, ICON . IE . Conn $ B.drop 2 tomA)
-- they can message each other now
tom #: ("4", bobT, "SEND :hello") #> ("4", bobT, SENT 1)
bob <#= \case ("", c, Msg "hello") -> c == tomB; _ -> False
bob #: ("5", tomB, "SEND 9\nhello too") #> ("5", tomB, SENT 2)
tom <#= \case ("", c, Msg "hello too") -> c == bobT; _ -> False
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Inv qInfo) <- h1 #: ("c1", "C:" <> name2, "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
h2 #: ("c2", "C:" <> name1, "JOIN " <> qInfo') #> ("", "C:" <> name1, CON)
h1 <# ("", "C:" <> name2, CON)
connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
connect' h1 h2 = do
("c1", conn2, Inv qInfo) <- h1 #: ("c1", "C:", "NEW")
let qInfo' = serializeSmpQueueInfo qInfo
("", conn1, Right (APartyCmd CON)) <- h2 #: ("c2", "C:", "JOIN " <> qInfo')
h1 <# ("", conn2, CON)
pure (conn1, conn2)
samplePublicKey :: ByteString
samplePublicKey = "rsa:MIIBoDANBgkqhkiG9w0BAQEFAAOCAY0AMIIBiAKCAQEAtn1NI2tPoOGSGfad0aUg0tJ0kG2nzrIPGLiz8wb3dQSJC9xkRHyzHhEE8Kmy2cM4q7rNZIlLcm4M7oXOTe7SC4x59bLQG9bteZPKqXu9wk41hNamV25PWQ4zIcIRmZKETVGbwN7jFMpH7wxLdI1zzMArAPKXCDCJ5ctWh4OWDI6OR6AcCtEj+toCI6N6pjxxn5VigJtwiKhxYpoUJSdNM60wVEDCSUrZYBAuDH8pOxPfP+Tm4sokaFDTIG3QJFzOjC+/9nW4MUjAOFll9PCp9kaEFHJ/YmOYKMWNOCCPvLS6lxA83i0UaardkNLNoFS5paWfTlroxRwOC2T6PwO2ywKBgDjtXcSED61zK1seocQMyGRINnlWdhceD669kIHju/f6kAayvYKW3/lbJNXCmyinAccBosO08/0sUxvtuniIo18kfYJE0UmP1ReCjhMP+O+yOmwZJini/QelJk/Pez8IIDDWnY1qYQsN/q7ocjakOYrpGG7mig6JMFpDJtD6istR"
@@ -205,7 +337,7 @@ syntaxTests t = do
-- TODO: ERROR no connection alias in the response (it does not generate it yet if not provided)
-- TODO: add tests with defined connection alias
it "using same server as in invitation" $
("311", "C:", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:", "ERR SMP AUTH")
("311", "C:a", "JOIN smp::localhost:5000::1234::" <> samplePublicKey) >#> ("311", "C:a", "ERR SMP AUTH")
describe "invalid" do
-- TODO: JOIN is not merged yet - to be added
it "no parameters" $ ("321", "C:", "JOIN") >#> ("321", "C:", "ERR CMD SYNTAX")

View File

@@ -8,9 +8,11 @@
module AgentTests.SQLiteTests (storeTests) where
import Control.Concurrent.Async (concurrently_)
import Control.Concurrent.STM (newTVarIO)
import Control.Monad (replicateM_)
import Control.Monad.Except (ExceptT, runExceptT)
import qualified Crypto.PubKey.RSA as R
import Crypto.Random (drgNew)
import Data.ByteString.Char8 (ByteString)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
@@ -73,11 +75,13 @@ storeTests = do
describe "Queue and Connection management" do
describe "createRcvConn" do
testCreateRcvConn
testCreateRcvConnRandomId
testCreateRcvConnDuplicate
describe "createSndConn" do
testCreateSndConn
testCreateSndConnRandomID
testCreateSndConnDuplicate
describe "getAllConnAliases" testGetAllConnAliases
describe "getAllConnIds" testGetAllConnIds
describe "getRcvConn" testGetRcvConn
describe "deleteConn" do
testDeleteRcvConn
@@ -104,14 +108,16 @@ storeTests = do
testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore)
testConcurrentWrites =
it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do
_ <- runExceptT $ createRcvConn s1 rcvQueue1
concurrently_ (runTest s1) (runTest s2)
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn s1 g cData1 rcvQueue1
let ConnData {connId} = cData1
concurrently_ (runTest s1 connId) (runTest s2 connId)
where
runTest :: SQLiteStore -> IO (Either StoreError ())
runTest store = runExceptT . replicateM_ 100 $ do
(internalId, internalRcvId, _, _) <- updateRcvIds store rcvQueue1
runTest :: SQLiteStore -> ConnId -> IO (Either StoreError ())
runTest store connId = runExceptT . replicateM_ 100 $ do
(internalId, internalRcvId, _, _) <- updateRcvIds store connId
let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy"
createRcvMsg store rcvQueue1 rcvMsgData
createRcvMsg store connId rcvMsgData
testCompiledThreadsafe :: SpecWith SQLiteStore
testCompiledThreadsafe =
@@ -132,12 +138,14 @@ testForeignKeysEnabled =
DB.execute_ (dbConn store) inconsistentQuery
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
cData1 :: ConnData
cData1 = ConnData {connId = "conn1", viaInv = Nothing, connLevel = 1}
rcvQueue1 :: RcvQueue
rcvQueue1 =
RcvQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
rcvId = "1234",
connAlias = "conn1",
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
sndId = Just "2345",
sndKey = Nothing,
@@ -151,7 +159,6 @@ sndQueue1 =
SndQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
sndId = "3456",
connAlias = "conn1",
sndPrivateKey = C.safePrivateKey (1, 2, 3),
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
signKey = C.safePrivateKey (1, 2, 3),
@@ -161,64 +168,95 @@ sndQueue1 =
testCreateRcvConn :: SpecWith SQLiteStore
testCreateRcvConn =
it "should create RcvConnection and add SndQueue" $ \store -> do
createRcvConn store rcvQueue1
`returnsResult` ()
g <- newTVarIO =<< drgNew
createRcvConn store g cData1 rcvQueue1
`returnsResult` "conn1"
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
upgradeRcvConnToDuplex store "conn1" sndQueue1
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
testCreateRcvConnRandomId :: SpecWith SQLiteStore
testCreateRcvConnRandomId =
it "should create RcvConnection and add SndQueue with random ID" $ \store -> do
g <- newTVarIO =<< drgNew
Right connId <- runExceptT $ createRcvConn store g cData1 {connId = ""} rcvQueue1
getConn store connId
`returnsResult` SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1)
upgradeRcvConnToDuplex store connId sndQueue1
`returnsResult` ()
getConn store connId
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
testCreateRcvConnDuplicate =
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
createRcvConn store g cData1 rcvQueue1
`throwsError` SEConnDuplicate
testCreateSndConn :: SpecWith SQLiteStore
testCreateSndConn =
it "should create SndConnection and add RcvQueue" $ \store -> do
createSndConn store sndQueue1
`returnsResult` ()
g <- newTVarIO =<< drgNew
createSndConn store g cData1 sndQueue1
`returnsResult` "conn1"
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
upgradeSndConnToDuplex store "conn1" rcvQueue1
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
testCreateSndConnRandomID :: SpecWith SQLiteStore
testCreateSndConnRandomID =
it "should create SndConnection and add RcvQueue with random ID" $ \store -> do
g <- newTVarIO =<< drgNew
Right connId <- runExceptT $ createSndConn store g cData1 {connId = ""} sndQueue1
getConn store connId
`returnsResult` SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1)
upgradeSndConnToDuplex store connId rcvQueue1
`returnsResult` ()
getConn store connId
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
testCreateSndConnDuplicate :: SpecWith SQLiteStore
testCreateSndConnDuplicate =
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
createSndConn store sndQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
createSndConn store g cData1 sndQueue1
`throwsError` SEConnDuplicate
testGetAllConnAliases :: SpecWith SQLiteStore
testGetAllConnAliases =
testGetAllConnIds :: SpecWith SQLiteStore
testGetAllConnIds =
it "should get all conn aliases" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
_ <- runExceptT $ createSndConn store sndQueue1 {connAlias = "conn2"}
getAllConnAliases store
`returnsResult` ["conn1" :: ConnAlias, "conn2" :: ConnAlias]
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
_ <- runExceptT $ createSndConn store g cData1 {connId = "conn2"} sndQueue1
getAllConnIds store
`returnsResult` ["conn1" :: ConnId, "conn2" :: ConnId]
testGetRcvConn :: SpecWith SQLiteStore
testGetRcvConn =
it "should get connection using rcv queue id and server" $ \store -> do
let smpServer = SMPServer "smp.simplex.im" (Just "5223") testKeyHash
let recipientId = "1234"
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
getRcvConn store smpServer recipientId
`returnsResult` SomeConn SCRcv (RcvConnection (connAlias (rcvQueue1 :: RcvQueue)) rcvQueue1)
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
testDeleteRcvConn :: SpecWith SQLiteStore
testDeleteRcvConn =
it "should create RcvConnection and delete it" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
@@ -228,9 +266,10 @@ testDeleteRcvConn =
testDeleteSndConn :: SpecWith SQLiteStore
testDeleteSndConn =
it "should create SndConnection and delete it" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
@@ -240,10 +279,11 @@ testDeleteSndConn =
testDeleteDuplexConn :: SpecWith SQLiteStore
testDeleteDuplexConn =
it "should create DuplexConnection and delete it" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
deleteConn store "conn1"
`returnsResult` ()
-- TODO check queues are deleted as well
@@ -253,12 +293,12 @@ testDeleteDuplexConn =
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
testUpgradeRcvConnToDuplex =
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
let anotherSndQueue =
SndQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
sndId = "2345",
connAlias = "conn1",
sndPrivateKey = C.safePrivateKey (1, 2, 3),
encryptKey = C.PublicKey $ R.PublicKey 1 2 3,
signKey = C.safePrivateKey (1, 2, 3),
@@ -273,12 +313,12 @@ testUpgradeRcvConnToDuplex =
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
testUpgradeSndConnToDuplex =
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
let anotherRcvQueue =
RcvQueue
{ server = SMPServer "smp.simplex.im" (Just "5223") testKeyHash,
rcvId = "3456",
connAlias = "conn1",
rcvPrivateKey = C.safePrivateKey (1, 2, 3),
sndId = Just "4567",
sndKey = Nothing,
@@ -295,40 +335,43 @@ testUpgradeSndConnToDuplex =
testSetRcvQueueStatus :: SpecWith SQLiteStore
testSetRcvQueueStatus =
it "should update status of RcvQueue" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1)
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
setRcvQueueStatus store rcvQueue1 Confirmed
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCRcv (RcvConnection "conn1" rcvQueue1 {status = Confirmed})
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed})
testSetSndQueueStatus :: SpecWith SQLiteStore
testSetSndQueueStatus =
it "should update status of SndQueue" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1)
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
setSndQueueStatus store sndQueue1 Confirmed
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCSnd (SndConnection "conn1" sndQueue1 {status = Confirmed})
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed})
testSetQueueStatusDuplex :: SpecWith SQLiteStore
testSetQueueStatusDuplex =
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 sndQueue1)
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
setRcvQueueStatus store rcvQueue1 Secured
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1)
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1)
setSndQueueStatus store sndQueue1 Confirmed
`returnsResult` ()
getConn store "conn1"
`returnsResult` SomeConn SCDuplex (DuplexConnection "conn1" rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
testSetRcvQueueStatusNoQueue :: SpecWith SQLiteStore
testSetRcvQueueStatusNoQueue =
@@ -362,20 +405,22 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
msgIntegrity = MsgOk
}
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> RcvQueue -> RcvMsgData -> Expectation
testCreateRcvMsg' store expectedPrevSndId expectedPrevHash rcvQueue rcvMsgData@RcvMsgData {..} = do
updateRcvIds store rcvQueue
testCreateRcvMsg' :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
testCreateRcvMsg' st expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
updateRcvIds st connId
`returnsResult` (internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
createRcvMsg store rcvQueue rcvMsgData
createRcvMsg st connId rcvMsgData
`returnsResult` ()
testCreateRcvMsg :: SpecWith SQLiteStore
testCreateRcvMsg =
it "should reserve internal ids and create a RcvMsg" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
it "should reserve internal ids and create a RcvMsg" $ \st -> do
g <- newTVarIO =<< drgNew
let ConnData {connId} = cData1
_ <- runExceptT $ createRcvConn st g cData1 rcvQueue1
-- TODO getMsg to check message
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
testCreateRcvMsg' store 1 "hash_dummy" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
testCreateRcvMsg' st 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
testCreateRcvMsg' st 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
mkSndMsgData internalId internalSndId internalHash =
@@ -387,29 +432,33 @@ mkSndMsgData internalId internalSndId internalHash =
internalHash
}
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> SndQueue -> SndMsgData -> Expectation
testCreateSndMsg' store expectedPrevHash sndQueue sndMsgData@SndMsgData {..} = do
updateSndIds store sndQueue
testCreateSndMsg' :: SQLiteStore -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation
testCreateSndMsg' store expectedPrevHash connId sndMsgData@SndMsgData {..} = do
updateSndIds store connId
`returnsResult` (internalId, internalSndId, expectedPrevHash)
createSndMsg store sndQueue sndMsgData
createSndMsg store connId sndMsgData
`returnsResult` ()
testCreateSndMsg :: SpecWith SQLiteStore
testCreateSndMsg =
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
_ <- runExceptT $ createSndConn store sndQueue1
g <- newTVarIO =<< drgNew
let ConnData {connId} = cData1
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
-- TODO getMsg to check message
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
testCreateSndMsg' store "hash_dummy" sndQueue1 $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
testCreateSndMsg' store "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
testCreateRcvAndSndMsgs =
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
_ <- runExceptT $ createRcvConn store rcvQueue1
g <- newTVarIO =<< drgNew
let ConnData {connId} = cData1
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
testCreateRcvMsg' store 0 "" rcvQueue1 $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
testCreateRcvMsg' store 1 "rcv_hash_1" rcvQueue1 $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
testCreateSndMsg' store "" sndQueue1 $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
testCreateRcvMsg' store 2 "rcv_hash_2" rcvQueue1 $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
testCreateSndMsg' store "snd_hash_1" sndQueue1 $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
testCreateSndMsg' store "snd_hash_2" sndQueue1 $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
testCreateRcvMsg' store 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
testCreateRcvMsg' store 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
testCreateSndMsg' store "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
testCreateRcvMsg' store 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
testCreateSndMsg' store "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
testCreateSndMsg' store "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"

View File

@@ -54,12 +54,12 @@ testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
runSmpAgentTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
where
t = transport @c
runSmpAgentServerTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
runSmpAgentServerTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
runSmpAgentServerTest test =
withSmpServerThreadOn t testPort $
\server -> withSmpAgentThreadOn t (agentTestPort, testPort, testDB) $
@@ -70,7 +70,7 @@ runSmpAgentServerTest test =
smpAgentServerTest :: Transport c => ((ThreadId, ThreadId) -> c -> IO ()) -> Expectation
smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` ()
runSmpAgentTestN :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
runSmpAgentTestN :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
runSmpAgentTestN agents test = withSmpServer t $ run agents []
where
run :: [(ServiceName, ServiceName, String)] -> [c] -> m a
@@ -78,7 +78,7 @@ runSmpAgentTestN agents test = withSmpServer t $ run agents []
run (a@(p, _, _) : as) hs = withSmpAgentOn t a $ testSMPAgentClientOn p $ \h -> run as (h : hs)
t = transport @c
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
runSmpAgentTestN_1 nClients test = withSmpServer t . withSmpAgent t $ run nClients []
where
run :: Int -> [c] -> m a
@@ -156,17 +156,17 @@ cfg =
}
}
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn t (port', smpPort', db') =
let cfg' = cfg {tcpPort = port', dbFile = db', smpServers = L.fromList [SMPServer "localhost" (Just smpPort') testKeyHash]}
in serverBracket
(\started -> runSMPAgentBlocking t started cfg')
(removeFile db')
withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
withSmpAgentOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
withSmpAgent :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
withSmpAgent :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m) => ServiceName -> (c -> m a) -> m a