mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-24 19:05:24 +00:00
SMP agent: functional API (#159)
* SMP agent: functional API (WIP) * functional API for SMP agent, tests * fix ICON message parameter * use stateTVar
This commit is contained in:
committed by
GitHub
parent
bf5561c89c
commit
d5f324cb5c
@@ -18,10 +18,10 @@ CREATE TABLE conn_invitations (
|
||||
conn_id BLOB REFERENCES connections (conn_alias) -- created connection
|
||||
ON DELETE CASCADE
|
||||
DEFERRABLE INITIALLY DEFERRED,
|
||||
status TEXT DEFAULT '' -- '', 'ACPT', 'CON'
|
||||
status TEXT NOT NULL DEFAULT '' -- '', 'ACPT', 'CON'
|
||||
) WITHOUT ROWID;
|
||||
|
||||
ALTER TABLE connections
|
||||
ADD via_inv BLOB REFERENCES conn_invitations (inv_id) ON DELETE RESTRICT;
|
||||
ALTER TABLE connections
|
||||
ADD conn_level INTEGER DEFAULT 0;
|
||||
ADD conn_level INTEGER NOT NULL DEFAULT 0;
|
||||
|
||||
+271
-182
@@ -1,12 +1,15 @@
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.Agent
|
||||
@@ -21,10 +24,34 @@
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
|
||||
module Simplex.Messaging.Agent
|
||||
( runSMPAgent,
|
||||
( -- * SMP agent over TCP
|
||||
runSMPAgent,
|
||||
runSMPAgentBlocking,
|
||||
|
||||
-- * queue-based SMP agent
|
||||
getAgentClient,
|
||||
runAgentClient,
|
||||
|
||||
-- * SMP agent functional API
|
||||
AgentMonad,
|
||||
AgentErrorMonad,
|
||||
getSMPAgentClient,
|
||||
runSMPAgentClient,
|
||||
createConnection,
|
||||
joinConnection,
|
||||
sendIntroduction,
|
||||
acceptInvitation,
|
||||
subscribeConnection,
|
||||
sendMessage,
|
||||
suspendConnection,
|
||||
deleteConnection,
|
||||
createConnection',
|
||||
joinConnection',
|
||||
sendIntroduction',
|
||||
acceptInvitation',
|
||||
subscribeConnection',
|
||||
sendMessage',
|
||||
suspendConnection',
|
||||
deleteConnection',
|
||||
)
|
||||
where
|
||||
|
||||
@@ -34,8 +61,10 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.Bifunctor (second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Text as T
|
||||
@@ -54,14 +83,14 @@ import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy, Transport (..), runTransportServer)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import System.Random (randomR)
|
||||
import UnliftIO.Async (race_)
|
||||
import UnliftIO.Async (Async, async, race_)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent t cfg = do
|
||||
started <- newEmptyTMVarIO
|
||||
runSMPAgentBlocking t started cfg
|
||||
@@ -70,23 +99,83 @@ runSMPAgent t cfg = do
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadFail m, MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort} = runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadFail m', MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
smpAgent _ = runTransportServer started tcpPort $ \(h :: c) -> do
|
||||
liftIO $ putLn h "Welcome to SMP v0.3.2 agent"
|
||||
c <- getSMPAgentClient
|
||||
c <- getAgentClient
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runSMPAgentClient c)
|
||||
`E.finally` (closeSMPServerClients c >> logConnection c False)
|
||||
race_ (connectClient h c) (runAgentClient c)
|
||||
`E.finally` disconnectServers c
|
||||
|
||||
-- | Creates an SMP agent instance that receives commands and sends responses via 'TBQueue's.
|
||||
getSMPAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient
|
||||
getSMPAgentClient = do
|
||||
n <- asks clientCounter
|
||||
cfg <- asks config
|
||||
atomically $ newAgentClient n cfg
|
||||
-- | Creates an SMP agent client instance
|
||||
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m (Async (), AgentClient)
|
||||
getSMPAgentClient cfg = newSMPAgentEnv cfg >>= runReaderT runAgent
|
||||
where
|
||||
runAgent = do
|
||||
c <- getAgentClient
|
||||
st <- agentDB
|
||||
action <- async $ subscriber c st `E.finally` disconnectServers c
|
||||
pure (action, c)
|
||||
|
||||
disconnectServers :: MonadUnliftIO m => AgentClient -> m ()
|
||||
disconnectServers c = closeSMPServerClients c >> logConnection c False
|
||||
|
||||
-- |
|
||||
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
|
||||
|
||||
-- | Create SMP agent connection (NEW command) in Reader monad
|
||||
createConnection' :: AgentMonad m => AgentClient -> m (ConnId, SMPQueueInfo)
|
||||
createConnection' c = newConn c "" Nothing 0
|
||||
|
||||
-- | Create SMP agent connection (NEW command)
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> m (ConnId, SMPQueueInfo)
|
||||
createConnection c = createConnection' c `runReaderT` agentEnv c
|
||||
|
||||
-- | Join SMP agent connection (JOIN command) in Reader monad
|
||||
joinConnection' :: AgentMonad m => AgentClient -> SMPQueueInfo -> m ConnId
|
||||
joinConnection' c qInfo = joinConn c "" qInfo (ReplyMode On) Nothing 0
|
||||
|
||||
-- | Join SMP agent connection (JOIN command)
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> SMPQueueInfo -> m ConnId
|
||||
joinConnection c qInfo = joinConnection' c qInfo `runReaderT` agentEnv c
|
||||
|
||||
-- | Accept invitation (ACPT command) in Reader monad
|
||||
acceptInvitation' :: AgentMonad m => AgentClient -> InvitationId -> ConnInfo -> m ConnId
|
||||
acceptInvitation' c = acceptInv c ""
|
||||
|
||||
-- | Accept invitation (ACPT command)
|
||||
acceptInvitation :: AgentErrorMonad m => AgentClient -> InvitationId -> ConnInfo -> m ConnId
|
||||
acceptInvitation c invId cInfo = acceptInvitation c invId cInfo `runReaderT` agentEnv c
|
||||
|
||||
-- | Send introduction of the second connection the first (INTRO command)
|
||||
sendIntroduction :: AgentErrorMonad m => AgentClient -> ConnId -> ConnId -> ConnInfo -> m ()
|
||||
sendIntroduction c toConn reConn reInfo = sendIntroduction' c toConn reConn reInfo `runReaderT` agentEnv c
|
||||
|
||||
-- | Subscribe to receive connection messages (SUB command)
|
||||
subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection c connId = subscribeConnection' c connId `runReaderT` agentEnv c
|
||||
|
||||
-- | Send message to the connection (SEND command)
|
||||
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgBody -> m InternalId
|
||||
sendMessage c connId msgBody = sendMessage' c connId msgBody `runReaderT` agentEnv c
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command)
|
||||
suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
suspendConnection c connId = suspendConnection' c connId `runReaderT` agentEnv c
|
||||
|
||||
-- | Delete SMP agent connection (DEL command)
|
||||
deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnection c connId = deleteConnection' c connId `runReaderT` agentEnv c
|
||||
|
||||
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
|
||||
getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient
|
||||
getAgentClient = do
|
||||
store <- agentDB
|
||||
env <- ask
|
||||
atomically $ newAgentClient store env
|
||||
|
||||
connectClient :: Transport c => MonadUnliftIO m => c -> AgentClient -> m ()
|
||||
connectClient h c = race_ (send h c) (receive h c)
|
||||
@@ -97,30 +186,31 @@ logConnection c connected =
|
||||
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
|
||||
|
||||
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
|
||||
runSMPAgentClient :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runSMPAgentClient c = do
|
||||
runAgentClient :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runAgentClient c = do
|
||||
st <- agentDB
|
||||
race_ (subscriber c st) (client c)
|
||||
|
||||
agentDB :: (MonadUnliftIO m, MonadReader Env m) => m SQLiteStore
|
||||
agentDB = do
|
||||
db <- asks $ dbFile . config
|
||||
s1 <- liftIO $ connectSQLiteStore db
|
||||
s2 <- liftIO $ connectSQLiteStore db
|
||||
race_ (subscriber c s1) (client c s2)
|
||||
liftIO $ connectSQLiteStore db
|
||||
|
||||
receive :: forall c m. (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
|
||||
receive h c@AgentClient {rcvQ, sndQ} = forever loop
|
||||
receive h c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
(corrId, connId, cmdOrErr) <- tGet SClient h
|
||||
case cmdOrErr of
|
||||
Right cmd -> write rcvQ (corrId, connId, cmd)
|
||||
Left e -> write subQ (corrId, connId, ERR e)
|
||||
where
|
||||
loop :: m ()
|
||||
loop = do
|
||||
(corrId, connId, cmdOrErr) <- tGet SClient h
|
||||
case cmdOrErr of
|
||||
Right cmd -> write rcvQ (corrId, connId, cmd)
|
||||
Left e -> write sndQ (corrId, connId, ERR e)
|
||||
write :: TBQueue (ATransmission p) -> ATransmission p -> m ()
|
||||
write q t = do
|
||||
logClient c "-->" t
|
||||
atomically $ writeTBQueue q t
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => c -> AgentClient -> m ()
|
||||
send h c@AgentClient {sndQ} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
send h c@AgentClient {subQ} = forever $ do
|
||||
t <- atomically $ readTBQueue subQ
|
||||
tPut h t
|
||||
logClient c "<--" t
|
||||
|
||||
@@ -128,15 +218,13 @@ logClient :: MonadUnliftIO m => AgentClient -> ByteString -> ATransmission a ->
|
||||
logClient AgentClient {clientId} dir (corrId, connId, cmd) = do
|
||||
logInfo . decodeUtf8 $ B.unwords [bshow clientId, dir, "A :", corrId, connId, B.takeWhile (/= ' ') $ serializeCommand cmd]
|
||||
|
||||
client :: forall m. (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
client c@AgentClient {rcvQ, sndQ} st = forever loop
|
||||
where
|
||||
loop :: m ()
|
||||
loop = do
|
||||
t@(corrId, connId, _) <- atomically $ readTBQueue rcvQ
|
||||
runExceptT (processCommand c st t) >>= \case
|
||||
Left e -> atomically $ writeTBQueue sndQ (corrId, connId, ERR e)
|
||||
Right _ -> pure ()
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
client c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
(corrId, connId, cmd) <- atomically $ readTBQueue rcvQ
|
||||
runExceptT (processCommand c (connId, cmd))
|
||||
>>= atomically . writeTBQueue subQ . \case
|
||||
Left e -> (corrId, connId, ERR e)
|
||||
Right (connId', resp) -> (corrId, connId', resp)
|
||||
|
||||
withStore ::
|
||||
AgentMonad m =>
|
||||
@@ -159,148 +247,148 @@ withStore action = do
|
||||
SEBadConnType CSnd -> CONN SIMPLEX
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> ATransmission 'Client -> m ()
|
||||
processCommand c@AgentClient {sndQ} st (corrId, connId, cmd) = case cmd of
|
||||
NEW -> createNewConnection Nothing 0 >>= uncurry respond
|
||||
JOIN smpQueueInfo replyMode -> joinConnection smpQueueInfo replyMode Nothing 0 >> pure () -- >>= (`respond` OK)
|
||||
INTRO reConnId reInfo -> makeIntroduction reConnId reInfo
|
||||
ACPT invId connInfo -> acceptInvitation invId connInfo
|
||||
SUB -> subscribeConnection connId
|
||||
SUBALL -> subscribeAll
|
||||
SEND msgBody -> sendMessage msgBody
|
||||
OFF -> suspendConnection
|
||||
DEL -> deleteConnection
|
||||
-- | execute any SMP agent command
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
|
||||
processCommand c (connId, cmd) = case cmd of
|
||||
NEW -> second INV <$> newConn c connId Nothing 0
|
||||
JOIN smpQueueInfo replyMode -> (,OK) <$> joinConn c connId smpQueueInfo replyMode Nothing 0
|
||||
INTRO reConnId reInfo -> sendIntroduction' c connId reConnId reInfo $> (connId, OK)
|
||||
ACPT invId connInfo -> (,OK) <$> acceptInv c connId invId connInfo
|
||||
SUB -> subscribeConnection' c connId $> (connId, OK)
|
||||
SEND msgBody -> (connId,) . SENT . unId <$> sendMessage' c connId msgBody
|
||||
OFF -> suspendConnection' c connId $> (connId, OK)
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Maybe InvitationId -> Int -> m (ConnId, SMPQueueInfo)
|
||||
newConn c connId viaInv connLevel = do
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createRcvConn st g cData rq
|
||||
addSubscription c rq connId'
|
||||
pure (connId', qInfo)
|
||||
where
|
||||
createNewConnection :: Maybe InvitationId -> Int -> m (ConnId, ACommand 'Agent)
|
||||
createNewConnection viaInv connLevel = do
|
||||
-- TODO create connection alias if not passed
|
||||
-- make connId Maybe?
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createRcvConn st g cData rq
|
||||
addSubscription c rq connId'
|
||||
pure (connId', INV qInfo)
|
||||
|
||||
getSMPServer :: m SMPServer
|
||||
getSMPServer =
|
||||
asks (smpServers . config) >>= \case
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
|
||||
pure $ servers L.!! i
|
||||
|
||||
joinConnection :: SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m ConnId
|
||||
joinConnection qInfo (ReplyMode replyMode) viaInv connLevel = do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createSndConn st g cData sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
when (replyMode == On) $ createReplyQueue connId' sq
|
||||
pure connId'
|
||||
|
||||
makeIntroduction :: IntroId -> ConnInfo -> m ()
|
||||
makeIntroduction reConn reInfo =
|
||||
withStore ((,) <$> getConn st connId <*> getConn st reConn) >>= \case
|
||||
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
|
||||
g <- asks idsDrg
|
||||
introId <- withStore $ createIntro st g NewIntroduction {toConn = connId, reConn, reInfo}
|
||||
sendControlMessage c sq $ A_INTRO introId reInfo
|
||||
respond connId OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
|
||||
acceptInvitation :: InvitationId -> ConnInfo -> m ()
|
||||
acceptInvitation invId connInfo =
|
||||
withStore (getInvitation st invId) >>= \case
|
||||
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
|
||||
withStore (getConn st viaConn) >>= \case
|
||||
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
|
||||
Nothing -> do
|
||||
(connId', INV qInfo') <- createNewConnection (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId connId'
|
||||
sendControlMessage c sq $ A_INV externalIntroId qInfo' connInfo
|
||||
respond connId' OK
|
||||
Just qInfo' -> do
|
||||
connId' <- joinConnection qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId connId'
|
||||
respond connId' OK
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
subscribeConnection :: ConnId -> m ()
|
||||
subscribeConnection cId =
|
||||
withStore (getConn st cId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> subscribe rq
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
subscribe rq = subscribeQueue c rq cId >> respond cId OK
|
||||
|
||||
-- TODO remove - hack for subscribing to all; respond' and parameterization of subscribeConnection are byproduct
|
||||
subscribeAll :: m ()
|
||||
subscribeAll = withStore (getAllConnIds st) >>= mapM_ subscribeConnection
|
||||
|
||||
sendMessage :: MsgBody -> m ()
|
||||
sendMessage msgBody =
|
||||
withStore (getConn st connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg sq
|
||||
SomeConn _ (SndConnection _ sq) -> sendMsg sq
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
sendMsg :: SndQueue -> m ()
|
||||
sendMsg sq = do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st connId
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
{ senderMsgId = unSndId internalSndId,
|
||||
senderTimestamp = internalTs,
|
||||
previousMsgHash,
|
||||
agentMessage = A_MSG msgBody
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st connId $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
atomically $ writeTBQueue sndQ (corrId, connId, SENT $ unId internalId)
|
||||
|
||||
suspendConnection :: m ()
|
||||
suspendConnection =
|
||||
withStore (getConn st connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> suspend rq
|
||||
SomeConn _ (RcvConnection _ rq) -> suspend rq
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
suspend rq = suspendQueue c rq >> respond connId OK
|
||||
|
||||
deleteConnection :: m ()
|
||||
deleteConnection =
|
||||
withStore (getConn st connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> delete rq
|
||||
SomeConn _ (RcvConnection _ rq) -> delete rq
|
||||
_ -> delConn
|
||||
where
|
||||
delConn = withStore (deleteConn st connId) >> respond connId OK
|
||||
delete rq = do
|
||||
deleteQueue c rq
|
||||
removeSubscription c connId
|
||||
delConn
|
||||
st = store c
|
||||
|
||||
joinConn :: forall m. AgentMonad m => AgentClient -> ConnId -> SMPQueueInfo -> ReplyMode -> Maybe InvitationId -> Int -> m ConnId
|
||||
joinConn c connId qInfo (ReplyMode replyMode) viaInv connLevel = do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
g <- asks idsDrg
|
||||
let cData = ConnData {connId, viaInv, connLevel}
|
||||
connId' <- withStore $ createSndConn st g cData sq
|
||||
connectToSendQueue c sq senderKey verifyKey
|
||||
when (replyMode == On) $ createReplyQueue connId' sq
|
||||
pure connId'
|
||||
where
|
||||
st = store c
|
||||
createReplyQueue :: ConnId -> SndQueue -> m ()
|
||||
createReplyQueue cId sq = do
|
||||
srv <- getSMPServer
|
||||
(rq, qInfo) <- newReceiveQueue c srv
|
||||
(rq, qInfo') <- newReceiveQueue c srv
|
||||
addSubscription c rq cId
|
||||
withStore $ upgradeSndConnToDuplex st cId rq
|
||||
sendControlMessage c sq $ REPLY qInfo
|
||||
sendControlMessage c sq $ REPLY qInfo'
|
||||
|
||||
respond :: ConnId -> ACommand 'Agent -> m ()
|
||||
respond cId resp = atomically . writeTBQueue sndQ $ (corrId, cId, resp)
|
||||
-- | Send introduction of the second connection the first (INTRO command) in Reader monad
|
||||
sendIntroduction' :: AgentMonad m => AgentClient -> ConnId -> ConnId -> ConnInfo -> m ()
|
||||
sendIntroduction' c toConn reConn reInfo =
|
||||
withStore ((,) <$> getConn st toConn <*> getConn st reConn) >>= \case
|
||||
(SomeConn _ (DuplexConnection _ _ sq), SomeConn _ DuplexConnection {}) -> do
|
||||
g <- asks idsDrg
|
||||
introId <- withStore $ createIntro st g NewIntroduction {toConn, reConn, reInfo}
|
||||
sendControlMessage c sq $ A_INTRO introId reInfo
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
st = store c
|
||||
|
||||
acceptInv :: AgentMonad m => AgentClient -> ConnId -> InvitationId -> ConnInfo -> m ConnId
|
||||
acceptInv c connId invId connInfo =
|
||||
withStore (getInvitation st invId) >>= \case
|
||||
Invitation {viaConn, qInfo, externalIntroId, status = InvNew} ->
|
||||
withStore (getConn st viaConn) >>= \case
|
||||
SomeConn _ (DuplexConnection ConnData {connLevel} _ sq) -> case qInfo of
|
||||
Nothing -> do
|
||||
(connId', qInfo') <- newConn c connId (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId connId'
|
||||
sendControlMessage c sq $ A_INV externalIntroId qInfo' connInfo
|
||||
pure connId'
|
||||
Just qInfo' -> do
|
||||
connId' <- joinConn c connId qInfo' (ReplyMode On) (Just invId) (connLevel + 1)
|
||||
withStore $ addInvitationConn st invId connId'
|
||||
pure connId'
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
st = store c
|
||||
|
||||
-- | Subscribe to receive connection messages (SUB command) in Reader monad
|
||||
subscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection' c connId =
|
||||
withStore (getConn (store c) connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> subscribeQueue c rq connId
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribeQueue c rq connId
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
|
||||
-- | Send message to the connection (SEND command) in Reader monad
|
||||
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgBody -> m InternalId
|
||||
sendMessage' c connId msgBody =
|
||||
withStore (getConn st connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ _ sq) -> sendMsg_ sq
|
||||
SomeConn _ (SndConnection _ sq) -> sendMsg_ sq
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
where
|
||||
st = store c
|
||||
sendMsg_ :: SndQueue -> m InternalId
|
||||
sendMsg_ sq = do
|
||||
internalTs <- liftIO getCurrentTime
|
||||
(internalId, internalSndId, previousMsgHash) <- withStore $ updateSndIds st connId
|
||||
let msgStr =
|
||||
serializeSMPMessage
|
||||
SMPMessage
|
||||
{ senderMsgId = unSndId internalSndId,
|
||||
senderTimestamp = internalTs,
|
||||
previousMsgHash,
|
||||
agentMessage = A_MSG msgBody
|
||||
}
|
||||
msgHash = C.sha256Hash msgStr
|
||||
withStore $
|
||||
createSndMsg st connId $
|
||||
SndMsgData {internalId, internalSndId, internalTs, msgBody, internalHash = msgHash}
|
||||
sendAgentMessage c sq msgStr
|
||||
pure internalId
|
||||
|
||||
-- | Suspend SMP agent connection (OFF command) in Reader monad
|
||||
suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
suspendConnection' c connId =
|
||||
withStore (getConn (store c) connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> suspendQueue c rq
|
||||
SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq
|
||||
_ -> throwError $ CONN SIMPLEX
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) in Reader monad
|
||||
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
deleteConnection' c connId =
|
||||
withStore (getConn st connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> delete rq
|
||||
SomeConn _ (RcvConnection _ rq) -> delete rq
|
||||
_ -> withStore (deleteConn st connId)
|
||||
where
|
||||
st = store c
|
||||
delete :: RcvQueue -> m ()
|
||||
delete rq = do
|
||||
deleteQueue c rq
|
||||
removeSubscription c connId
|
||||
withStore (deleteConn st connId)
|
||||
|
||||
getSMPServer :: AgentMonad m => m SMPServer
|
||||
getSMPServer =
|
||||
asks (smpServers . config) >>= \case
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
|
||||
pure $ servers L.!! i
|
||||
|
||||
sendControlMessage :: AgentMonad m => AgentClient -> SndQueue -> AMessage -> m ()
|
||||
sendControlMessage c sq agentMessage = do
|
||||
@@ -313,20 +401,19 @@ sendControlMessage c sq agentMessage = do
|
||||
agentMessage
|
||||
}
|
||||
|
||||
subscriber :: (MonadFail m, MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> SQLiteStore -> m ()
|
||||
subscriber c@AgentClient {msgQ} st = forever $ do
|
||||
-- TODO this will only process messages and notifications
|
||||
t <- atomically $ readTBQueue msgQ
|
||||
runExceptT (processSMPTransmission c st t) >>= \case
|
||||
Left e -> liftIO $ print e
|
||||
Right _ -> return ()
|
||||
|
||||
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SQLiteStore -> SMPServerTransmission -> m ()
|
||||
processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
processSMPTransmission c@AgentClient {subQ} st (srv, rId, cmd) = do
|
||||
withStore (getRcvConn st srv rId) >>= \case
|
||||
SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq
|
||||
SomeConn SCRcv (RcvConnection cData rq) -> processSMP SCRcv cData rq
|
||||
_ -> atomically $ writeTBQueue sndQ ("", "", ERR $ CONN NOT_FOUND)
|
||||
_ -> atomically $ writeTBQueue subQ ("", "", ERR $ CONN NOT_FOUND)
|
||||
where
|
||||
processSMP :: SConnType c -> ConnData -> RcvQueue -> m ()
|
||||
processSMP cType ConnData {connId} rq@RcvQueue {status} =
|
||||
@@ -358,7 +445,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
notify . ERR $ BROKER UNEXPECTED
|
||||
where
|
||||
notify :: ACommand 'Agent -> m ()
|
||||
notify msg = atomically $ writeTBQueue sndQ ("", connId, msg)
|
||||
notify msg = atomically $ writeTBQueue subQ ("", connId, msg)
|
||||
|
||||
prohibited :: m ()
|
||||
prohibited = notify . ERR $ AGENT A_PROHIBITED
|
||||
@@ -369,7 +456,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
case status of
|
||||
New -> do
|
||||
-- TODO currently it automatically allows whoever sends the confirmation
|
||||
-- Commands CONF and LET are not supported in v0.2
|
||||
-- TODO create invitation and send REQ
|
||||
withStore $ setRcvQueueStatus st rq Confirmed
|
||||
-- TODO update sender key in the store?
|
||||
secureQueue c rq senderKey
|
||||
@@ -395,7 +482,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
SCRcv -> do
|
||||
(sq, senderKey, verifyKey) <- newSendQueue qInfo
|
||||
withStore $ upgradeRcvConnToDuplex st connId sq
|
||||
connectToSendQueue c st sq senderKey verifyKey
|
||||
connectToSendQueue c sq senderKey verifyKey
|
||||
connected
|
||||
_ -> prohibited
|
||||
|
||||
@@ -460,7 +547,7 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
| otherwise -> prohibited
|
||||
where
|
||||
sendConMsg :: ConnId -> ConnId -> m ()
|
||||
sendConMsg toConn reConn = atomically $ writeTBQueue sndQ ("", toConn, ICON reConn)
|
||||
sendConMsg toConn reConn = atomically $ writeTBQueue subQ ("", toConn, ICON reConn)
|
||||
|
||||
agentClientMsg :: PrevRcvMsgHash -> (ExternalSndId, ExternalSndTs) -> (BrokerId, BrokerTs) -> MsgBody -> MsgHash -> m ()
|
||||
agentClientMsg receivedPrevMsgHash senderMeta brokerMeta msgBody msgHash = do
|
||||
@@ -502,12 +589,14 @@ processSMPTransmission c@AgentClient {sndQ} st (srv, rId, cmd) = do
|
||||
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
|
||||
| otherwise = MsgError MsgDuplicate -- this case is not possible
|
||||
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SQLiteStore -> SndQueue -> SenderPublicKey -> VerificationKey -> m ()
|
||||
connectToSendQueue c st sq senderKey verifyKey = do
|
||||
connectToSendQueue :: AgentMonad m => AgentClient -> SndQueue -> SenderPublicKey -> VerificationKey -> m ()
|
||||
connectToSendQueue c sq senderKey verifyKey = do
|
||||
sendConfirmation c sq senderKey
|
||||
withStore $ setSndQueueStatus st sq Confirmed
|
||||
sendHello c sq verifyKey
|
||||
withStore $ setSndQueueStatus st sq Active
|
||||
where
|
||||
st = store c
|
||||
|
||||
newSendQueue ::
|
||||
(MonadUnliftIO m, MonadReader Env m) => SMPQueueInfo -> m (SndQueue, SenderPublicKey, VerificationKey)
|
||||
|
||||
@@ -31,6 +31,7 @@ module Simplex.Messaging.Agent.Client
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
@@ -48,6 +49,7 @@ import Data.Time.Clock
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore)
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgBody, QueueId, SenderPublicKey)
|
||||
@@ -59,27 +61,30 @@ import UnliftIO.STM
|
||||
|
||||
data AgentClient = AgentClient
|
||||
{ rcvQ :: TBQueue (ATransmission 'Client),
|
||||
sndQ :: TBQueue (ATransmission 'Agent),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
smpClients :: TVar (Map SMPServer SMPClient),
|
||||
subscrSrvrs :: TVar (Map SMPServer (Set ConnId)),
|
||||
subscrConns :: TVar (Map ConnId SMPServer),
|
||||
clientId :: Int
|
||||
clientId :: Int,
|
||||
store :: SQLiteStore,
|
||||
agentEnv :: Env
|
||||
}
|
||||
|
||||
newAgentClient :: TVar Int -> AgentConfig -> STM AgentClient
|
||||
newAgentClient cc AgentConfig {tbqSize} = do
|
||||
rcvQ <- newTBQueue tbqSize
|
||||
sndQ <- newTBQueue tbqSize
|
||||
msgQ <- newTBQueue tbqSize
|
||||
newAgentClient :: SQLiteStore -> Env -> STM AgentClient
|
||||
newAgentClient store agentEnv = do
|
||||
let qSize = tbqSize $ config agentEnv
|
||||
rcvQ <- newTBQueue qSize
|
||||
subQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpClients <- newTVar M.empty
|
||||
subscrSrvrs <- newTVar M.empty
|
||||
subscrConns <- newTVar M.empty
|
||||
clientId <- (+ 1) <$> readTVar cc
|
||||
writeTVar cc clientId
|
||||
return AgentClient {rcvQ, sndQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId}
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpClients, subscrSrvrs, subscrConns, clientId, store, agentEnv}
|
||||
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m, MonadFail m)
|
||||
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
@@ -119,7 +124,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
deleteKeys ks m = S.foldr' M.delete m ks
|
||||
|
||||
notifySub :: ConnId -> IO ()
|
||||
notifySub connId = atomically $ writeTBQueue (sndQ c) ("", connId, END)
|
||||
notifySub connId = atomically $ writeTBQueue (subQ c) ("", connId, END)
|
||||
|
||||
closeSMPServerClients :: MonadUnliftIO m => AgentClient -> m ()
|
||||
closeSMPServerClients c = liftIO $ readTVarIO (smpClients c) >>= mapM_ closeSMPClient
|
||||
|
||||
@@ -161,7 +161,6 @@ data ACommand (p :: AParty) where
|
||||
CON :: ACommand Agent -- notification that connection is established
|
||||
ICON :: ConnId -> ACommand Agent
|
||||
SUB :: ACommand Client
|
||||
SUBALL :: ACommand Client -- TODO should be moved to chat protocol - hack for subscribing to all
|
||||
END :: ACommand Agent
|
||||
-- QST :: QueueDirection -> ACommand Client
|
||||
-- STAT :: QueueDirection -> Maybe QueueStatus -> Maybe SubMode -> ACommand Agent
|
||||
@@ -478,7 +477,6 @@ commandP =
|
||||
<|> "REQ " *> reqCmd
|
||||
<|> "ACPT " *> acptCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "SUBALL" $> ACmd SClient SUBALL -- TODO remove - hack for subscribing to all
|
||||
<|> "END" $> ACmd SAgent END
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "SENT " *> sentResp
|
||||
@@ -533,7 +531,6 @@ serializeCommand = \case
|
||||
REQ invId cInfo -> "REQ " <> invId <> " " <> serializeMsg cInfo
|
||||
ACPT invId cInfo -> "ACPT " <> invId <> " " <> serializeMsg cInfo
|
||||
SUB -> "SUB"
|
||||
SUBALL -> "SUBALL" -- TODO remove - hack for subscribing to all
|
||||
END -> "END"
|
||||
SEND msgBody -> "SEND " <> serializeMsg msgBody
|
||||
SENT mId -> "SENT " <> bshow mId
|
||||
@@ -549,7 +546,7 @@ serializeCommand = \case
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
CON -> "CON"
|
||||
ICON introId -> "ICON " <> introId
|
||||
ICON connId -> "ICON " <> connId
|
||||
ERR e -> "ERR " <> serializeAgentError e
|
||||
OK -> "OK"
|
||||
where
|
||||
|
||||
@@ -53,9 +53,10 @@ import Control.Monad
|
||||
import Control.Monad.Trans.Class
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Protocol (SMPServer (..))
|
||||
@@ -64,7 +65,7 @@ import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport (ATransport (..), TCP, THandle (..), TProxy, Transport (..), TransportError, clientHandshake, runTransportClient)
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
|
||||
import System.Timeout
|
||||
import System.Timeout (timeout)
|
||||
|
||||
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
|
||||
--
|
||||
@@ -195,22 +196,27 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, smpPing} msgQ dis
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
cs <- readTVarIO sentCommands
|
||||
case M.lookup corrId cs of
|
||||
Nothing -> do
|
||||
case respOrErr of
|
||||
Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
Just Request {queueId, responseVar} -> atomically $ do
|
||||
modifyTVar sentCommands $ M.delete corrId
|
||||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPUnexpectedResponse
|
||||
if B.null $ bs corrId
|
||||
then sendMsg qId respOrErr
|
||||
else do
|
||||
cs <- readTVarIO sentCommands
|
||||
case M.lookup corrId cs of
|
||||
Nothing -> sendMsg qId respOrErr
|
||||
Just Request {queueId, responseVar} -> atomically $ do
|
||||
modifyTVar sentCommands $ M.delete corrId
|
||||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (Cmd _ (ERR e)) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPUnexpectedResponse
|
||||
|
||||
sendMsg :: QueueId -> Either ErrorType Cmd -> IO ()
|
||||
sendMsg qId = \case
|
||||
Right (Cmd SBroker cmd) -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd)
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
|
||||
-- | Disconnects SMP client from the server and terminates client threads.
|
||||
closeSMPClient :: SMPClient -> IO ()
|
||||
|
||||
+57
-15
@@ -8,19 +8,28 @@
|
||||
{-# LANGUAGE PostfixOperators #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
|
||||
module AgentTests where
|
||||
|
||||
import AgentTests.SQLiteTests (storeTests)
|
||||
import Control.Concurrent
|
||||
import Control.Monad.Except (catchError, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import SMPAgentClient
|
||||
import SMPClient (withSmpServer)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite (dbFile)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.Store (InternalId (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TProxy (..), Transport (..))
|
||||
import System.Timeout
|
||||
import Test.Hspec
|
||||
import UnliftIO.STM
|
||||
|
||||
agentTests :: ATransport -> Spec
|
||||
agentTests (ATransport t) = do
|
||||
@@ -39,6 +48,8 @@ agentTests (ATransport t) = do
|
||||
smpAgentTest2_2_2 $ testDuplexConnection t
|
||||
it "should connect via 2 servers and 2 agents (random IDs)" $
|
||||
smpAgentTest2_2_2 $ testDuplexConnRandomIds t
|
||||
it "should connect via one server using SMP agent clients" $
|
||||
withSmpServer (ATransport t) testAgentClient
|
||||
describe "Connection subscriptions" do
|
||||
it "should connect via one server and one agent" $
|
||||
smpAgentTest3_1_1 $ testSubscription t
|
||||
@@ -75,7 +86,7 @@ correctTransmission (corrId, cAlias, cmdOrErr) = case cmdOrErr of
|
||||
|
||||
-- | receive message to handle `h` and validate that it is the expected one
|
||||
(<#) :: Transport c => c -> ATransmission 'Agent -> Expectation
|
||||
h <# (corrId, cAlias, cmd) = (h <#:) >>= (`shouldBe` (corrId, cAlias, Right cmd))
|
||||
h <# (corrId, cAlias, cmd) = (h <#:) `shouldReturn` (corrId, cAlias, Right cmd)
|
||||
|
||||
-- | receive message to handle `h` and validate it using predicate `p`
|
||||
(<#=) :: Transport c => c -> (ATransmission 'Agent -> Bool) -> Expectation
|
||||
@@ -93,17 +104,12 @@ h #:# err = tryGet `shouldReturn` ()
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg msgBody <- MSG {msgBody, msgIntegrity = MsgOk}
|
||||
|
||||
-- pattern Inv :: SMPQueueInfo -> Either AgentErrorType (ACommand 'Agent)
|
||||
-- pattern Inv invitation <- Right (INV invitation)
|
||||
|
||||
-- pattern Req :: InvitationId -> ConnInfo -> Either AgentErrorType (ACommand 'Agent)
|
||||
-- pattern Req invId cInfo <- Right (REQ invId cInfo)
|
||||
|
||||
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = do
|
||||
("1", "bob", Right (INV qInfo)) <- alice #: ("1", "bob", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("", "alice", CON)
|
||||
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("11", "alice", OK)
|
||||
bob <# ("", "alice", CON)
|
||||
alice <# ("", "bob", CON)
|
||||
alice #: ("2", "bob", "SEND :hello") #> ("2", "bob", SENT 1)
|
||||
alice #: ("3", "bob", "SEND :how are you?") #> ("3", "bob", SENT 2)
|
||||
@@ -118,11 +124,48 @@ testDuplexConnection _ alice bob = do
|
||||
alice #: ("6", "bob", "DEL") #> ("6", "bob", OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
(_, alice) <- getSMPAgentClient cfg
|
||||
(_, bob) <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice
|
||||
aliceId <- joinConnection bob qInfo
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, CON)
|
||||
InternalId 1 <- sendMessage alice bobId "hello"
|
||||
InternalId 2 <- sendMessage alice bobId "how are you?"
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False
|
||||
InternalId 3 <- sendMessage bob aliceId "hello too"
|
||||
InternalId 4 <- sendMessage bob aliceId "message 1"
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
|
||||
suspendConnection alice bobId
|
||||
InternalId 0 <- sendMessage bob aliceId "message 2" `catchError` \(SMP AUTH) -> pure $ InternalId 0
|
||||
deleteConnection alice bobId
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
pure ()
|
||||
where
|
||||
(##>) :: MonadIO m => m (ATransmission 'Agent) -> ATransmission 'Agent -> m ()
|
||||
a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)
|
||||
(=##>) :: MonadIO m => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m ()
|
||||
a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p)
|
||||
noMessages :: AgentClient -> String -> Expectation
|
||||
noMessages c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
get c = atomically (readTBQueue $ subQ c)
|
||||
|
||||
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = do
|
||||
("1", bobConn, Right (INV qInfo)) <- alice #: ("1", "", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
("", aliceConn, Right CON) <- bob #: ("11", "", "JOIN " <> qInfo')
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> qInfo')
|
||||
bob <# ("", aliceConn, CON)
|
||||
alice <# ("", bobConn, CON)
|
||||
alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, SENT 1)
|
||||
alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, SENT 2)
|
||||
@@ -139,12 +182,9 @@ testDuplexConnRandomIds _ alice bob = do
|
||||
|
||||
testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testSubscription _ alice1 alice2 bob = do
|
||||
("1", "bob", Right (INV qInfo)) <- alice1 #: ("1", "bob", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
bob #: ("11", "alice", "JOIN " <> qInfo') #> ("", "alice", CON)
|
||||
(alice1, "alice") `connect` (bob, "bob")
|
||||
bob #: ("12", "alice", "SEND 5\nhello") #> ("12", "alice", SENT 1)
|
||||
bob #: ("13", "alice", "SEND 11\nhello again") #> ("13", "alice", SENT 2)
|
||||
alice1 <# ("", "bob", CON)
|
||||
alice1 <#= \case ("", "bob", Msg "hello") -> True; _ -> False
|
||||
alice1 <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
|
||||
alice2 #: ("21", "bob", "SUB") #> ("21", "bob", OK)
|
||||
@@ -210,14 +250,16 @@ connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Right (INV qInfo)) <- h1 #: ("c1", name2, "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
h2 #: ("c2", name1, "JOIN " <> qInfo') #> ("", name1, CON)
|
||||
h2 #: ("c2", name1, "JOIN " <> qInfo') #> ("c2", name1, OK)
|
||||
h2 <# ("", name1, CON)
|
||||
h1 <# ("", name2, CON)
|
||||
|
||||
connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
|
||||
connect' h1 h2 = do
|
||||
("c1", conn2, Right (INV qInfo)) <- h1 #: ("c1", "", "NEW")
|
||||
let qInfo' = serializeSmpQueueInfo qInfo
|
||||
("", conn1, Right CON) <- h2 #: ("c2", "", "JOIN " <> qInfo')
|
||||
("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN " <> qInfo')
|
||||
h2 <# ("", conn1, CON)
|
||||
h1 <# ("", conn2, CON)
|
||||
pure (conn1, conn2)
|
||||
|
||||
|
||||
@@ -54,12 +54,12 @@ testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
|
||||
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
|
||||
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h
|
||||
|
||||
runSmpAgentTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
|
||||
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (c -> m a) -> m a
|
||||
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
|
||||
where
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentServerTest :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
|
||||
runSmpAgentServerTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => ((ThreadId, ThreadId) -> c -> m a) -> m a
|
||||
runSmpAgentServerTest test =
|
||||
withSmpServerThreadOn t testPort $
|
||||
\server -> withSmpAgentThreadOn t (agentTestPort, testPort, testDB) $
|
||||
@@ -70,7 +70,7 @@ runSmpAgentServerTest test =
|
||||
smpAgentServerTest :: Transport c => ((ThreadId, ThreadId) -> c -> IO ()) -> Expectation
|
||||
smpAgentServerTest test' = runSmpAgentServerTest test' `shouldReturn` ()
|
||||
|
||||
runSmpAgentTestN :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => [(ServiceName, ServiceName, String)] -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN agents test = withSmpServer t $ run agents []
|
||||
where
|
||||
run :: [(ServiceName, ServiceName, String)] -> [c] -> m a
|
||||
@@ -78,7 +78,7 @@ runSmpAgentTestN agents test = withSmpServer t $ run agents []
|
||||
run (a@(p, _, _) : as) hs = withSmpAgentOn t a $ testSMPAgentClientOn p $ \h -> run as (h : hs)
|
||||
t = transport @c
|
||||
|
||||
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadFail m, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN_1 :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => Int -> ([c] -> m a) -> m a
|
||||
runSmpAgentTestN_1 nClients test = withSmpServer t . withSmpAgent t $ run nClients []
|
||||
where
|
||||
run :: Int -> [c] -> m a
|
||||
@@ -156,17 +156,17 @@ cfg =
|
||||
}
|
||||
}
|
||||
|
||||
withSmpAgentThreadOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn t (port', smpPort', db') =
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db', smpServers = L.fromList [SMPServer "localhost" (Just smpPort') testKeyHash]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg')
|
||||
(removeFile db')
|
||||
|
||||
withSmpAgentOn :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
|
||||
withSmpAgentOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m a -> m a
|
||||
withSmpAgentOn t (port', smpPort', db') = withSmpAgentThreadOn t (port', smpPort', db') . const
|
||||
|
||||
withSmpAgent :: (MonadFail m, MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withSmpAgent :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
||||
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m) => ServiceName -> (c -> m a) -> m a
|
||||
|
||||
Reference in New Issue
Block a user