mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-13 18:43:11 +00:00
Merge branch 'master' into sqlcipher
This commit is contained in:
@@ -52,6 +52,7 @@ library
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
Simplex.Messaging.Crypto
|
||||
@@ -81,6 +82,7 @@ library
|
||||
Simplex.Messaging.Server.Stats
|
||||
Simplex.Messaging.Server.StoreLog
|
||||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.TMap2
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Client
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
|
||||
+281
-77
@@ -12,7 +12,6 @@
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.Agent
|
||||
@@ -39,6 +38,10 @@ module Simplex.Messaging.Agent
|
||||
disconnectAgentClient,
|
||||
resumeAgentClient,
|
||||
withAgentLock,
|
||||
createConnectionAsync,
|
||||
joinConnectionAsync,
|
||||
allowConnectionAsync,
|
||||
ackMessageAsync,
|
||||
createConnection,
|
||||
joinConnection,
|
||||
allowConnection,
|
||||
@@ -80,8 +83,9 @@ import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.Composition ((.:), (.:.), (.::))
|
||||
import Data.Functor (($>))
|
||||
import Data.List (deleteFirstsBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
@@ -141,13 +145,29 @@ resumeAgentClient c = atomically $ writeTVar (active c) True
|
||||
-- |
|
||||
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
|
||||
|
||||
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
|
||||
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
|
||||
|
||||
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
|
||||
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
|
||||
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c
|
||||
|
||||
-- | Acknowledge message (ACK command) asynchronously, no synchronous response
|
||||
ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c
|
||||
|
||||
-- | Create SMP agent connection (NEW command)
|
||||
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
|
||||
createConnection c enableNtfs cMode = withAgentEnv c $ newConn c "" enableNtfs cMode
|
||||
createConnection c enableNtfs cMode = withAgentEnv c $ newConn c "" False enableNtfs cMode
|
||||
|
||||
-- | Join SMP agent connection (JOIN command)
|
||||
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" enableNtfs
|
||||
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
|
||||
|
||||
-- | Allow connection to continue after CONF notification (LET command)
|
||||
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
@@ -283,8 +303,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
|
||||
-- | execute any SMP agent command
|
||||
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
|
||||
processCommand c (connId, cmd) = case cmd of
|
||||
NEW (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId True cMode
|
||||
JOIN (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId True cReq connInfo
|
||||
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode
|
||||
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
|
||||
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
|
||||
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
|
||||
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
|
||||
@@ -295,15 +315,59 @@ processCommand c (connId, cmd) = case cmd of
|
||||
DEL -> deleteConnection' c connId $> (connId, OK)
|
||||
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId enableNtfs cMode = do
|
||||
srv <- getSMPServer c
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
(rq, qUri) <- newRcvQueue c srv clientVRange
|
||||
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
|
||||
newConnAsync c corrId enableNtfs cMode = do
|
||||
g <- asks idsDrg
|
||||
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
|
||||
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
|
||||
connId' <- withStore c $ \db -> createRcvConn db g cData rq cMode
|
||||
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
|
||||
connId <- withStore c $ \db -> createNewConn db g cData cMode
|
||||
enqueueCommand c corrId connId Nothing $ NEW enableNtfs (ACM cMode)
|
||||
pure connId
|
||||
|
||||
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri (ConnReqUriData _ agentVRange _) _) cInfo = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
case agentVRange `compatibleVersion` aVRange of
|
||||
Just (Compatible connAgentVersion) -> do
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
|
||||
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
|
||||
enqueueCommand c corrId connId Nothing $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
|
||||
pure connId
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
|
||||
allowConnectionAsync' c corrId connId confId ownConnInfo =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (RcvConnection _ RcvQueue {server}) ->
|
||||
enqueueCommand c corrId connId (Just server) $ LET confId ownConnInfo
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
|
||||
ackMessageAsync' c corrId connId msgId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> enqueueAck rq
|
||||
SomeConn _ (RcvConnection _ rq) -> enqueueAck rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
|
||||
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
enqueueAck :: RcvQueue -> m ()
|
||||
enqueueAck RcvQueue {server} = do
|
||||
enqueueCommand c corrId connId (Just server) $ ACK msgId
|
||||
|
||||
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
|
||||
newConn c connId asyncMode enableNtfs cMode =
|
||||
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
|
||||
|
||||
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
|
||||
newConnSrv c connId asyncMode enableNtfs cMode srv = do
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
(rq, qUri) <- newRcvQueue c srv clientVRange
|
||||
connId' <- setUpConn asyncMode rq
|
||||
addSubscription c rq connId'
|
||||
when enableNtfs $ do
|
||||
ns <- asks ntfSupervisor
|
||||
@@ -316,9 +380,26 @@ newConn c connId enableNtfs cMode = do
|
||||
(pk1, pk2, e2eRcvParams) <- liftIO $ CR.generateE2EParams CR.e2eEncryptVersion
|
||||
withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2
|
||||
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams CR.e2eEncryptVRange)
|
||||
where
|
||||
setUpConn True rq = do
|
||||
withStore c $ \db -> updateNewConnRcv db connId rq
|
||||
pure connId
|
||||
setUpConn False rq = do
|
||||
g <- asks idsDrg
|
||||
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
|
||||
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
|
||||
withStore c $ \db -> createRcvConn db g cData rq cMode
|
||||
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do
|
||||
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
|
||||
joinConn c connId asyncMode enableNtfs cReq cInfo = do
|
||||
srv <- case cReq of
|
||||
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
|
||||
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
|
||||
_ -> getSMPServer c
|
||||
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
|
||||
|
||||
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId
|
||||
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
@@ -330,38 +411,47 @@ joinConn c connId enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUr
|
||||
(_, rcDHRs) <- liftIO C.generateKeyPair'
|
||||
let rc = CR.initSndRatchet rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
|
||||
sq <- newSndQueue qInfo
|
||||
g <- asks idsDrg
|
||||
let duplexHS = connAgentVersion /= 1
|
||||
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
|
||||
connId' <- withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData sq
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
connId' <- setUpConn asyncMode cData sq rc
|
||||
let cData' = (cData :: ConnData) {connId = connId'}
|
||||
tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case
|
||||
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
|
||||
Right _ -> do
|
||||
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
|
||||
pure connId'
|
||||
Left e -> do
|
||||
-- TODO recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
|
||||
withStore' c (`deleteConn` connId')
|
||||
unless asyncMode $ withStore' c (`deleteConn` connId')
|
||||
throwError e
|
||||
where
|
||||
setUpConn True _ sq rc =
|
||||
withStore c $ \db -> runExceptT $ do
|
||||
ExceptT $ updateNewConnSnd db connId sq
|
||||
liftIO $ createRatchet db connId rc
|
||||
pure connId
|
||||
setUpConn False cData sq rc = do
|
||||
g <- asks idsDrg
|
||||
withStore c $ \db -> runExceptT $ do
|
||||
connId' <- ExceptT $ createSndConn db g cData sq
|
||||
liftIO $ createRatchet db connId' rc
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConn c connId enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do
|
||||
joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do
|
||||
aVRange <- asks $ smpAgentVRange . config
|
||||
clientVRange <- asks $ smpClientVRange . config
|
||||
case ( qUri `compatibleVersion` clientVRange,
|
||||
agentVRange `compatibleVersion` aVRange
|
||||
) of
|
||||
(Just qInfo, Just vrsn) -> do
|
||||
(connId', cReq) <- newConn c connId enableNtfs SCMInvitation
|
||||
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv
|
||||
sendInvitation c qInfo vrsn cReq cInfo
|
||||
pure connId'
|
||||
_ -> throwError $ AGENT A_VERSION
|
||||
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
|
||||
throwError $ CMD PROHIBITED
|
||||
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} = do
|
||||
srv <- getSMPServer c
|
||||
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
|
||||
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
|
||||
(rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion
|
||||
let qInfo = toVersionT qUri smpClientVersion
|
||||
addSubscription c rq connId
|
||||
@@ -391,7 +481,7 @@ acceptContact' c connId enableNtfs invId ownConnInfo = do
|
||||
withStore c (`getConn` contactConnId) >>= \case
|
||||
SomeConn _ ContactConnection {} -> do
|
||||
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
|
||||
joinConn c connId enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
|
||||
withStore' c (`unacceptInvitation` invId)
|
||||
throwError err
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
@@ -411,18 +501,21 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma
|
||||
-- | Subscribe to receive connection messages (SUB command) in Reader monad
|
||||
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection' c connId =
|
||||
withStore c (`getConn` connId) >>= \case
|
||||
SomeConn _ (DuplexConnection cData rq sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
subscribe rq
|
||||
SomeConn _ (SndConnection cData sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure ()
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (ContactConnection _ rq) -> subscribe rq
|
||||
withStore c (`getConn` connId) >>= \conn -> do
|
||||
resumeConnCmds c connId
|
||||
case conn of
|
||||
SomeConn _ (DuplexConnection cData rq sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
subscribe rq
|
||||
SomeConn _ (SndConnection cData sq) -> do
|
||||
resumeMsgDelivery c cData sq
|
||||
case status (sq :: SndQueue) of
|
||||
Confirmed -> pure ()
|
||||
Active -> throwError $ CONN SIMPLEX
|
||||
_ -> throwError $ INTERNAL "unexpected queue status"
|
||||
SomeConn _ (RcvConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (ContactConnection _ rq) -> subscribe rq
|
||||
SomeConn _ (NewConnection _) -> pure ()
|
||||
where
|
||||
subscribe :: RcvQueue -> m ()
|
||||
subscribe rq = do
|
||||
@@ -436,33 +529,34 @@ subscribeConnections' c connIds = do
|
||||
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConn)
|
||||
let (errs, cs) = M.mapEither id conns
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(sndQs, rcvQs) = M.mapEither rcvOrSndQueue cs
|
||||
sndRs = M.map (sndSubResult . fst) sndQs
|
||||
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
|
||||
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
|
||||
srvRcvQs :: Map SMPServer (Map ConnId RcvQueue) = M.foldlWithKey' addRcvQueue M.empty rcvQs
|
||||
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
|
||||
mapM_ (resumeConnCmds c) $ M.keys cs
|
||||
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
ns <- asks ntfSupervisor
|
||||
tkn <- readTVarIO (ntfTkn ns)
|
||||
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs
|
||||
let rs = M.unions $ errs' : sndRs : rcvRs
|
||||
let rs = M.unions $ errs' : subRs : rcvRs
|
||||
notifyResultError rs
|
||||
pure rs
|
||||
where
|
||||
rcvOrSndQueue :: SomeConn -> Either (SndQueue, ConnData) (RcvQueue, ConnData)
|
||||
rcvOrSndQueue = \case
|
||||
SomeConn _ (DuplexConnection cData rq _) -> Right (rq, cData)
|
||||
SomeConn _ (SndConnection cData sq) -> Left (sq, cData)
|
||||
SomeConn _ (RcvConnection cData rq) -> Right (rq, cData)
|
||||
SomeConn _ (ContactConnection cData rq) -> Right (rq, cData)
|
||||
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) RcvQueue
|
||||
rcvQueueOrResult = \case
|
||||
SomeConn _ (DuplexConnection _ rq _) -> Right rq
|
||||
SomeConn _ (SndConnection _ sq) -> Left $ sndSubResult sq
|
||||
SomeConn _ (RcvConnection _ rq) -> Right rq
|
||||
SomeConn _ (ContactConnection _ rq) -> Right rq
|
||||
SomeConn _ (NewConnection _) -> Left (Right ())
|
||||
sndSubResult :: SndQueue -> Either AgentErrorType ()
|
||||
sndSubResult sq = case status (sq :: SndQueue) of
|
||||
Confirmed -> Right ()
|
||||
Active -> Left $ CONN SIMPLEX
|
||||
_ -> Left $ INTERNAL "unexpected queue status"
|
||||
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
|
||||
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
|
||||
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs)
|
||||
addRcvQueue :: Map SMPServer (Map ConnId RcvQueue) -> ConnId -> RcvQueue -> Map SMPServer (Map ConnId RcvQueue)
|
||||
addRcvQueue m connId rq@RcvQueue {server} = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
|
||||
subscribe :: (SMPServer, Map ConnId RcvQueue) -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
|
||||
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
|
||||
sendNtfCreate ns rcvRs =
|
||||
forM_ (concatMap M.assocs rcvRs) $ \case
|
||||
@@ -502,6 +596,7 @@ getConnectionMessage' c connId = do
|
||||
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq
|
||||
SomeConn _ (ContactConnection _ rq) -> getQueueMessage c rq
|
||||
SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ NewConnection {} -> throwError $ CMD PROHIBITED
|
||||
|
||||
getNotificationMessage' :: forall m. AgentMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta])
|
||||
getNotificationMessage' c nonce encNtfInfo = do
|
||||
@@ -541,11 +636,106 @@ sendMessage' c connId msgFlags msg =
|
||||
enqueueMsg :: ConnData -> SndQueue -> m AgentMsgId
|
||||
enqueueMsg cData sq = enqueueMessage c cData sq msgFlags $ A_MSG msg
|
||||
|
||||
-- / async command processing v v v
|
||||
|
||||
enqueueCommand :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> ACommand 'Client -> m ()
|
||||
enqueueCommand c corrId connId server aCommand = do
|
||||
resumeSrvCmds c server
|
||||
commandId <- withStore' c $ \db -> createCommand db corrId connId server aCommand
|
||||
queuePendingCommands c server [commandId]
|
||||
|
||||
resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
|
||||
resumeSrvCmds c server =
|
||||
unlessM (cmdProcessExists c server) $
|
||||
async (runCommandProcessing c server)
|
||||
>>= \a -> atomically (TM.insert server a $ asyncCmdProcesses c)
|
||||
|
||||
resumeConnCmds :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
resumeConnCmds c connId =
|
||||
unlessM connQueued $
|
||||
withStore' c (`getPendingCommands` connId)
|
||||
>>= mapM_ (uncurry enqueueSrvCmds)
|
||||
where
|
||||
enqueueSrvCmds srv cmdIds = unlessM (cmdProcessExists c srv) $ do
|
||||
a <- async (runCommandProcessing c srv)
|
||||
atomically (TM.insert srv a $ asyncCmdProcesses c)
|
||||
queuePendingCommands c srv cmdIds
|
||||
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c)
|
||||
|
||||
cmdProcessExists :: AgentMonad m => AgentClient -> Maybe SMPServer -> m Bool
|
||||
cmdProcessExists c srv = atomically $ TM.member srv (asyncCmdProcesses c)
|
||||
|
||||
queuePendingCommands :: AgentMonad m => AgentClient -> Maybe SMPServer -> [AsyncCmdId] -> m ()
|
||||
queuePendingCommands c server cmdIds = atomically $ do
|
||||
q <- getPendingCommandQ c server
|
||||
mapM_ (writeTQueue q) cmdIds
|
||||
|
||||
getPendingCommandQ :: AgentClient -> Maybe SMPServer -> STM (TQueue AsyncCmdId)
|
||||
getPendingCommandQ c server = do
|
||||
maybe newMsgQueue pure =<< TM.lookup server (asyncCmdQueues c)
|
||||
where
|
||||
newMsgQueue = do
|
||||
cq <- newTQueue
|
||||
TM.insert server cq $ asyncCmdQueues c
|
||||
pure cq
|
||||
|
||||
runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
|
||||
runCommandProcessing c@AgentClient {subQ} server = do
|
||||
cq <- atomically $ getPendingCommandQ c server
|
||||
ri <- asks $ messageRetryInterval . config -- different retry interval?
|
||||
forever $ do
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
cmdId <- atomically $ readTQueue cq
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
|
||||
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
|
||||
Right (corrId, connId, ACmd _ cmd) -> processCmd ri corrId connId cmdId cmd
|
||||
where
|
||||
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> ACommand p -> m ()
|
||||
processCmd ri corrId connId cmdId = \case
|
||||
NEW enableNtfs (ACM cMode) -> do
|
||||
usedSrvs <- newTVarIO ([] :: [SMPServer])
|
||||
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
|
||||
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
|
||||
notify $ INV (ACR cMode cReq)
|
||||
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do
|
||||
let initUsed = [smpServer (queueAddress :: SMPQueueAddress)]
|
||||
usedSrvs <- newTVarIO initUsed
|
||||
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
|
||||
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
|
||||
notify OK
|
||||
LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
|
||||
ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK
|
||||
cmd -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd)
|
||||
where
|
||||
tryCommand action = withRetryInterval ri $ \loop ->
|
||||
tryError action >>= \case
|
||||
Left e
|
||||
| temporaryAgentError e || e == BROKER HOST -> retryCommand loop
|
||||
| otherwise -> notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
|
||||
Right () -> withStore' c (`deleteCommand` cmdId)
|
||||
retryCommand loop = do
|
||||
-- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
atomically $ beginAgentOperation c AOSndNetwork
|
||||
loop
|
||||
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd)
|
||||
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
|
||||
withNextSrv usedSrvs initUsed action = do
|
||||
used <- readTVarIO usedSrvs
|
||||
srv <- getNextSMPServer c used
|
||||
atomically $ do
|
||||
srvs <- readTVar $ smpServers c
|
||||
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
|
||||
writeTVar usedSrvs used'
|
||||
action srv
|
||||
-- ^ ^ ^ async command processing /
|
||||
|
||||
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
|
||||
enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do
|
||||
resumeMsgDelivery c cData sq
|
||||
msgId <- storeSentMsg
|
||||
queuePendingMsgs c connId sq [msgId]
|
||||
queuePendingMsgs c sq [msgId]
|
||||
pure $ unId msgId
|
||||
where
|
||||
storeSentMsg :: m InternalId
|
||||
@@ -565,28 +755,28 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage
|
||||
|
||||
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
|
||||
let qKey = (connId, server, sndId)
|
||||
let qKey = (server, sndId)
|
||||
unlessM (queueDelivering qKey) $
|
||||
async (runSmpQueueMsgDelivery c cData sq)
|
||||
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
|
||||
unlessM connQueued $
|
||||
withStore' c (`getPendingMsgs` connId)
|
||||
>>= queuePendingMsgs c connId sq
|
||||
>>= queuePendingMsgs c sq
|
||||
where
|
||||
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
|
||||
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c)
|
||||
|
||||
queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> [InternalId] -> m ()
|
||||
queuePendingMsgs c connId sq msgIds = atomically $ do
|
||||
queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m ()
|
||||
queuePendingMsgs c sq msgIds = atomically $ do
|
||||
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
|
||||
-- s <- readTVar (msgDeliveryOp c)
|
||||
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
|
||||
q <- getPendingMsgQ c connId sq
|
||||
q <- getPendingMsgQ c sq
|
||||
mapM_ (writeTQueue q) msgIds
|
||||
|
||||
getPendingMsgQ :: AgentClient -> ConnId -> SndQueue -> STM (TQueue InternalId)
|
||||
getPendingMsgQ c connId SndQueue {server, sndId} = do
|
||||
let qKey = (connId, server, sndId)
|
||||
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
|
||||
getPendingMsgQ c SndQueue {server, sndId} = do
|
||||
let qKey = (server, sndId)
|
||||
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
|
||||
where
|
||||
newMsgQueue qKey = do
|
||||
@@ -596,7 +786,7 @@ getPendingMsgQ c connId SndQueue {server, sndId} = do
|
||||
|
||||
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
|
||||
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
|
||||
mq <- atomically $ getPendingMsgQ c connId sq
|
||||
mq <- atomically $ getPendingMsgQ c sq
|
||||
ri <- asks $ messageRetryInterval . config
|
||||
forever $ do
|
||||
atomically $ endAgentOperation c AOSndNetwork
|
||||
@@ -677,7 +867,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
|
||||
-- and this branch should never be reached as receive is created before the confirmation,
|
||||
-- so the condition is not necessary here, strictly speaking.
|
||||
_ -> unless (duplexHandshake == Just True) $ do
|
||||
qInfo <- createReplyQueue c cData sq
|
||||
srv <- getSMPServer c
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
|
||||
AM_A_MSG_ -> notify $ SENT mId
|
||||
_ -> pure ()
|
||||
@@ -703,6 +894,7 @@ ackMessage' c connId msgId = do
|
||||
SomeConn _ (RcvConnection _ rq) -> ack rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
|
||||
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
ack :: RcvQueue -> m ()
|
||||
ack rq = do
|
||||
@@ -721,6 +913,7 @@ suspendConnection' c connId =
|
||||
SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq
|
||||
SomeConn _ (ContactConnection _ rq) -> suspendQueue c rq
|
||||
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
|
||||
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
|
||||
|
||||
-- | Delete SMP agent connection (DEL command) in Reader monad
|
||||
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
@@ -730,6 +923,7 @@ deleteConnection' c connId =
|
||||
SomeConn _ (RcvConnection _ rq) -> delete rq
|
||||
SomeConn _ (ContactConnection _ rq) -> delete rq
|
||||
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId)
|
||||
SomeConn _ (NewConnection _) -> withStore' c (`deleteConn` connId)
|
||||
where
|
||||
delete :: RcvQueue -> m ()
|
||||
delete rq = do
|
||||
@@ -748,6 +942,7 @@ getConnectionServers' c connId = connServers <$> withStore c (`getConn` connId)
|
||||
SomeConn _ (SndConnection _ SndQueue {server}) -> ConnectionStats {rcvServers = [], sndServers = [server]}
|
||||
SomeConn _ (DuplexConnection _ RcvQueue {server = s1} SndQueue {server = s2}) -> ConnectionStats {rcvServers = [s1], sndServers = [s2]}
|
||||
SomeConn _ (ContactConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
|
||||
SomeConn _ (NewConnection _) -> ConnectionStats {rcvServers = [], sndServers = []}
|
||||
|
||||
-- | Change servers to be used for creating new queues, in Reader monad
|
||||
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
@@ -984,14 +1179,23 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do
|
||||
suspendSendingAndDatabase c
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
|
||||
getSMPServer c = do
|
||||
smpServers <- readTVarIO $ smpServers c
|
||||
case smpServers of
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
atomically . stateTVar gen $
|
||||
first (servers L.!!) . randomR (0, L.length servers - 1)
|
||||
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
|
||||
|
||||
pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer
|
||||
pickServer = \case
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
|
||||
|
||||
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
|
||||
getNextSMPServer c usedSrvs = do
|
||||
srvs <- readTVarIO $ smpServers c
|
||||
case L.nonEmpty $ deleteFirstsBy sameAddr (L.toList srvs) usedSrvs of
|
||||
Just srvs' -> pickServer srvs'
|
||||
_ -> pickServer srvs
|
||||
where
|
||||
sameAddr (SMPServer host port _) (SMPServer host' port' _) = host == host' && port == port'
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
@@ -1229,8 +1433,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
|
||||
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
|
||||
enqueueConfirmation c cData sq ownConnInfo Nothing
|
||||
|
||||
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do
|
||||
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
|
||||
aMessage <- mkAgentMessage agentVersion
|
||||
msg <- mkConfirmation aMessage
|
||||
sendConfirmation c sq msg
|
||||
@@ -1244,14 +1448,14 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e
|
||||
mkAgentMessage :: Version -> m AgentMessage
|
||||
mkAgentMessage 1 = pure $ AgentConnInfo connInfo
|
||||
mkAgentMessage _ = do
|
||||
qInfo <- createReplyQueue c cData sq
|
||||
qInfo <- createReplyQueue c cData sq srv
|
||||
pure $ AgentConnInfoReply (qInfo :| []) connInfo
|
||||
|
||||
enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
|
||||
enqueueConfirmation c cData@ConnData {connId, connAgentVersion} sq connInfo e2eEncryption = do
|
||||
resumeMsgDelivery c cData sq
|
||||
msgId <- storeConfirmation
|
||||
queuePendingMsgs c connId sq [msgId]
|
||||
queuePendingMsgs c sq [msgId]
|
||||
where
|
||||
storeConfirmation :: m InternalId
|
||||
storeConfirmation = withStore c $ \db -> runExceptT $ do
|
||||
|
||||
@@ -90,7 +90,6 @@ import Data.Maybe (listToMaybe)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text.Encoding
|
||||
import Data.Tuple (swap)
|
||||
import Data.Word (Word16)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -131,6 +130,8 @@ import Simplex.Messaging.Protocol
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.TMap2 (TMap2)
|
||||
import qualified Simplex.Messaging.TMap2 as TM2
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
@@ -155,13 +156,15 @@ data AgentClient = AgentClient
|
||||
ntfServers :: TVar [NtfServer],
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
useNetworkConfig :: TVar NetworkConfig,
|
||||
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
subscrConns :: TVar (Set ConnId),
|
||||
activeSubscrConns :: TMap ConnId SMPServer,
|
||||
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
|
||||
connMsgsQueued :: TMap ConnId Bool,
|
||||
smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()),
|
||||
smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()),
|
||||
connCmdsQueued :: TMap ConnId Bool,
|
||||
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
|
||||
asyncCmdProcesses :: TMap (Maybe SMPServer) (Async ()),
|
||||
ntfNetworkOp :: TVar AgentOpState,
|
||||
rcvNetworkOp :: TVar AgentOpState,
|
||||
msgDeliveryOp :: TVar AgentOpState,
|
||||
@@ -207,13 +210,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
ntfServers <- newTVar ntf
|
||||
ntfClients <- TM.empty
|
||||
useNetworkConfig <- newTVar netCfg
|
||||
subscrSrvrs <- TM.empty
|
||||
pendingSubscrSrvrs <- TM.empty
|
||||
subscrConns <- newTVar S.empty
|
||||
activeSubscrConns <- TM.empty
|
||||
activeSubs <- TM2.empty
|
||||
pendingSubs <- TM2.empty
|
||||
connMsgsQueued <- TM.empty
|
||||
smpQueueMsgQueues <- TM.empty
|
||||
smpQueueMsgDeliveries <- TM.empty
|
||||
connCmdsQueued <- TM.empty
|
||||
asyncCmdQueues <- TM.empty
|
||||
asyncCmdProcesses <- TM.empty
|
||||
ntfNetworkOp <- newTVar $ AgentOpState False 0
|
||||
rcvNetworkOp <- newTVar $ AgentOpState False 0
|
||||
msgDeliveryOp <- newTVar $ AgentOpState False 0
|
||||
@@ -225,7 +230,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
|
||||
lock <- newTMVar ()
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
|
||||
|
||||
agentStore :: AgentClient -> SQLiteStore
|
||||
agentStore AgentClient {agentEnv = Env {store}} = store
|
||||
@@ -264,26 +269,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union cs v
|
||||
_ -> TM.insert srv cVar ps
|
||||
TM2.insert1 srv cVar $ pendingSubs c
|
||||
readTVar cVar
|
||||
|
||||
serverDown :: Map ConnId RcvQueue -> IO ()
|
||||
serverDown cs = unless (M.null cs) $
|
||||
whenM (readTVarIO active) $ do
|
||||
let conns = M.keys cs
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
unless (null conns) . notifySub "" $ DOWN srv conns
|
||||
serverDown cs = whenM (readTVarIO active) $ do
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
let conns = M.keys cs
|
||||
unless (null conns) $ do
|
||||
notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unliftIO u reconnectServer
|
||||
|
||||
@@ -301,7 +298,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withAgentLock c $
|
||||
atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar)
|
||||
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
|
||||
>>= mapM_ resubscribe
|
||||
where
|
||||
resubscribe :: Map ConnId RcvQueue -> m ()
|
||||
@@ -407,12 +404,14 @@ closeAgentClient c = liftIO $ do
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
cancelActions $ smpQueueMsgDeliveries c
|
||||
clear subscrSrvrs
|
||||
clear pendingSubscrSrvrs
|
||||
cancelActions $ asyncCmdProcesses c
|
||||
atomically . TM2.clear $ activeSubs c
|
||||
atomically . TM2.clear $ pendingSubs c
|
||||
clear subscrConns
|
||||
clear activeSubscrConns
|
||||
clear connMsgsQueued
|
||||
clear smpQueueMsgQueues
|
||||
clear connCmdsQueued
|
||||
clear asyncCmdQueues
|
||||
clear getMsgLocks
|
||||
where
|
||||
clear :: Monoid m => (AgentClient -> TVar m) -> IO ()
|
||||
@@ -514,17 +513,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
addPendingSubscription c rq connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
|
||||
>>= either throwError pure
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq@RcvQueue {server} connId r = do
|
||||
processSubResult c rq connId r = do
|
||||
case r of
|
||||
Left e ->
|
||||
atomically . unless (temporaryClientError e) $
|
||||
removePendingSubscription c server connId
|
||||
TM2.delete connId (pendingSubs c)
|
||||
_ -> addSubscription c rq connId
|
||||
pure r
|
||||
|
||||
@@ -544,9 +543,9 @@ temporaryAgentError = \case
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
|
||||
forM_ qs_ $ \q -> atomically $ do
|
||||
modifyTVar (subscrConns c) . S.insert $ fst q
|
||||
uncurry (addPendingSubscription c) $ swap q
|
||||
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
TM2.insert server connId rq $ pendingSubs c
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
@@ -568,35 +567,18 @@ subscribeQueues c srv qs = do
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
TM.insert connId server $ activeSubscrConns c
|
||||
modifyTVar (subscrConns c) $ S.insert connId
|
||||
addSubs_ (subscrSrvrs c) rq connId
|
||||
removePendingSubscription c server connId
|
||||
TM2.insert server connId rq $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
|
||||
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
|
||||
hasActiveSubscription c connId = TM.member connId (activeSubscrConns c)
|
||||
|
||||
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
|
||||
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
|
||||
|
||||
addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM ()
|
||||
addSubs_ ss rq@RcvQueue {server} connId =
|
||||
TM.lookup server ss >>= \case
|
||||
Just m -> TM.insert connId rq m
|
||||
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
|
||||
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
|
||||
|
||||
removeSubscription :: AgentClient -> ConnId -> STM ()
|
||||
removeSubscription c connId = do
|
||||
modifyTVar (subscrConns c) $ S.delete connId
|
||||
server_ <- TM.lookupDelete connId $ activeSubscrConns c
|
||||
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
|
||||
|
||||
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
|
||||
removePendingSubscription = removeSubs_ . pendingSubscrSrvrs
|
||||
|
||||
removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
|
||||
removeSubs_ ss server connId =
|
||||
TM.lookup server ss >>= mapM_ (TM.delete connId)
|
||||
TM2.delete connId $ activeSubs c
|
||||
TM2.delete connId $ pendingSubs c
|
||||
|
||||
getSubscriptions :: AgentClient -> STM (Set ConnId)
|
||||
getSubscriptions = readTVar . subscrConns
|
||||
|
||||
@@ -40,9 +40,13 @@ module Simplex.Messaging.Agent.Protocol
|
||||
-- * SMP agent protocol types
|
||||
ConnInfo,
|
||||
ACommand (..),
|
||||
ACommandTag (..),
|
||||
aCommandTag,
|
||||
ACmd (..),
|
||||
ACmdTag (..),
|
||||
AParty (..),
|
||||
SAParty (..),
|
||||
APartyI (..),
|
||||
MsgHash,
|
||||
MsgMeta (..),
|
||||
ConnectionStats (..),
|
||||
@@ -92,6 +96,8 @@ module Simplex.Messaging.Agent.Protocol
|
||||
serializeCommand,
|
||||
connMode,
|
||||
connMode',
|
||||
networkCommandP,
|
||||
dbCommandP,
|
||||
commandP,
|
||||
connModeT,
|
||||
serializeQueueStatus,
|
||||
@@ -117,7 +123,6 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
@@ -211,6 +216,12 @@ instance TestEquality SAParty where
|
||||
testEquality SClient SClient = Just Refl
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
class APartyI (p :: AParty) where sAParty :: SAParty p
|
||||
|
||||
instance APartyI Agent where sAParty = SAgent
|
||||
|
||||
instance APartyI Client where sAParty = SClient
|
||||
|
||||
data ACmd = forall p. ACmd (SAParty p) (ACommand p)
|
||||
|
||||
deriving instance Show ACmd
|
||||
@@ -219,9 +230,9 @@ type ConnInfo = ByteString
|
||||
|
||||
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
|
||||
data ACommand (p :: AParty) where
|
||||
NEW :: AConnectionMode -> ACommand Client -- response INV
|
||||
NEW :: Bool -> AConnectionMode -> ACommand Client -- response INV
|
||||
INV :: AConnectionRequestUri -> ACommand Agent
|
||||
JOIN :: AConnectionRequestUri -> ConnInfo -> ACommand Client -- response OK
|
||||
JOIN :: Bool -> AConnectionRequestUri -> ConnInfo -> ACommand Client -- response OK
|
||||
CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
|
||||
LET :: ConfirmationId -> ConnInfo -> ACommand Client -- ConnInfo is from client
|
||||
REQ :: InvitationId -> L.NonEmpty SMPServer -> ConnInfo -> ACommand Agent -- ConnInfo is from sender
|
||||
@@ -253,6 +264,75 @@ deriving instance Eq (ACommand p)
|
||||
|
||||
deriving instance Show (ACommand p)
|
||||
|
||||
data ACmdTag = forall p. APartyI p => ACmdTag (SAParty p) (ACommandTag p)
|
||||
|
||||
data ACommandTag (p :: AParty) where
|
||||
NEW_ :: ACommandTag Client
|
||||
INV_ :: ACommandTag Agent
|
||||
JOIN_ :: ACommandTag Client
|
||||
CONF_ :: ACommandTag Agent
|
||||
LET_ :: ACommandTag Client
|
||||
REQ_ :: ACommandTag Agent
|
||||
ACPT_ :: ACommandTag Client
|
||||
RJCT_ :: ACommandTag Client
|
||||
INFO_ :: ACommandTag Agent
|
||||
CON_ :: ACommandTag Agent
|
||||
SUB_ :: ACommandTag Client
|
||||
END_ :: ACommandTag Agent
|
||||
CONNECT_ :: ACommandTag Agent
|
||||
DISCONNECT_ :: ACommandTag Agent
|
||||
DOWN_ :: ACommandTag Agent
|
||||
UP_ :: ACommandTag Agent
|
||||
SEND_ :: ACommandTag Client
|
||||
MID_ :: ACommandTag Agent
|
||||
SENT_ :: ACommandTag Agent
|
||||
MERR_ :: ACommandTag Agent
|
||||
MSG_ :: ACommandTag Agent
|
||||
ACK_ :: ACommandTag Client
|
||||
OFF_ :: ACommandTag Client
|
||||
DEL_ :: ACommandTag Client
|
||||
CHK_ :: ACommandTag Client
|
||||
STAT_ :: ACommandTag Agent
|
||||
OK_ :: ACommandTag Agent
|
||||
ERR_ :: ACommandTag Agent
|
||||
SUSPENDED_ :: ACommandTag Agent
|
||||
|
||||
deriving instance Eq (ACommandTag p)
|
||||
|
||||
deriving instance Show (ACommandTag p)
|
||||
|
||||
aCommandTag :: ACommand p -> ACommandTag p
|
||||
aCommandTag = \case
|
||||
NEW {} -> NEW_
|
||||
INV _ -> INV_
|
||||
JOIN {} -> JOIN_
|
||||
CONF {} -> CONF_
|
||||
LET {} -> LET_
|
||||
REQ {} -> REQ_
|
||||
ACPT {} -> ACPT_
|
||||
RJCT _ -> RJCT_
|
||||
INFO _ -> INFO_
|
||||
CON -> CON_
|
||||
SUB -> SUB_
|
||||
END -> END_
|
||||
CONNECT {} -> CONNECT_
|
||||
DISCONNECT {} -> DISCONNECT_
|
||||
DOWN {} -> DOWN_
|
||||
UP {} -> UP_
|
||||
SEND {} -> SEND_
|
||||
MID _ -> MID_
|
||||
SENT _ -> SENT_
|
||||
MERR {} -> MERR_
|
||||
MSG {} -> MSG_
|
||||
ACK _ -> ACK_
|
||||
OFF -> OFF_
|
||||
DEL -> DEL_
|
||||
CHK -> CHK_
|
||||
STAT _ -> STAT_
|
||||
OK -> OK_
|
||||
ERR _ -> ERR_
|
||||
SUSPENDED -> SUSPENDED_
|
||||
|
||||
data ConnectionStats = ConnectionStats
|
||||
{ rcvServers :: [SMPServer],
|
||||
sndServers :: [SMPServer]
|
||||
@@ -920,58 +1000,129 @@ instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
|
||||
|
||||
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
|
||||
|
||||
-- | SMP agent command and response parser for commands passed via network (only parses binary length)
|
||||
networkCommandP :: Parser ACmd
|
||||
networkCommandP = commandP A.takeByteString
|
||||
|
||||
-- | SMP agent command and response parser for commands stored in db (fully parses binary bodies)
|
||||
dbCommandP :: Parser ACmd
|
||||
dbCommandP = commandP $ A.take =<< (A.decimal <* "\n")
|
||||
|
||||
instance Encoding ACmdTag where
|
||||
smpEncode (ACmdTag _ cmd) = smpEncode cmd
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure $ ACmdTag SClient NEW_
|
||||
"INV" -> pure $ ACmdTag SAgent INV_
|
||||
"JOIN" -> pure $ ACmdTag SClient JOIN_
|
||||
"CONF" -> pure $ ACmdTag SAgent CONF_
|
||||
"LET" -> pure $ ACmdTag SClient LET_
|
||||
"REQ" -> pure $ ACmdTag SAgent REQ_
|
||||
"ACPT" -> pure $ ACmdTag SClient ACPT_
|
||||
"RJCT" -> pure $ ACmdTag SClient RJCT_
|
||||
"INFO" -> pure $ ACmdTag SAgent INFO_
|
||||
"CON" -> pure $ ACmdTag SAgent CON_
|
||||
"SUB" -> pure $ ACmdTag SClient SUB_
|
||||
"END" -> pure $ ACmdTag SAgent END_
|
||||
"CONNECT" -> pure $ ACmdTag SAgent CONNECT_
|
||||
"DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_
|
||||
"DOWN" -> pure $ ACmdTag SAgent DOWN_
|
||||
"UP" -> pure $ ACmdTag SAgent UP_
|
||||
"SEND" -> pure $ ACmdTag SClient SEND_
|
||||
"MID" -> pure $ ACmdTag SAgent MID_
|
||||
"SENT" -> pure $ ACmdTag SAgent SENT_
|
||||
"MERR" -> pure $ ACmdTag SAgent MERR_
|
||||
"MSG" -> pure $ ACmdTag SAgent MSG_
|
||||
"ACK" -> pure $ ACmdTag SClient ACK_
|
||||
"OFF" -> pure $ ACmdTag SClient OFF_
|
||||
"DEL" -> pure $ ACmdTag SClient DEL_
|
||||
"CHK" -> pure $ ACmdTag SClient CHK_
|
||||
"STAT" -> pure $ ACmdTag SAgent STAT_
|
||||
"OK" -> pure $ ACmdTag SAgent OK_
|
||||
"ERR" -> pure $ ACmdTag SAgent ERR_
|
||||
"SUSPENDED" -> pure $ ACmdTag SAgent SUSPENDED_
|
||||
_ -> fail "bad ACmdTag"
|
||||
|
||||
instance APartyI p => Encoding (ACommandTag p) where
|
||||
smpEncode = \case
|
||||
NEW_ -> "NEW"
|
||||
INV_ -> "INV"
|
||||
JOIN_ -> "JOIN"
|
||||
CONF_ -> "CONF"
|
||||
LET_ -> "LET"
|
||||
REQ_ -> "REQ"
|
||||
ACPT_ -> "ACPT"
|
||||
RJCT_ -> "RJCT"
|
||||
INFO_ -> "INFO"
|
||||
CON_ -> "CON"
|
||||
SUB_ -> "SUB"
|
||||
END_ -> "END"
|
||||
CONNECT_ -> "CONNECT"
|
||||
DISCONNECT_ -> "DISCONNECT"
|
||||
DOWN_ -> "DOWN"
|
||||
UP_ -> "UP"
|
||||
SEND_ -> "SEND"
|
||||
MID_ -> "MID"
|
||||
SENT_ -> "SENT"
|
||||
MERR_ -> "MERR"
|
||||
MSG_ -> "MSG"
|
||||
ACK_ -> "ACK"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
CHK_ -> "CHK"
|
||||
STAT_ -> "STAT"
|
||||
OK_ -> "OK"
|
||||
ERR_ -> "ERR"
|
||||
SUSPENDED_ -> "SUSPENDED"
|
||||
smpP = (\(ACmdTag _ t) -> checkParty t) <$?> smpP
|
||||
|
||||
checkParty :: forall t p p'. (APartyI p, APartyI p') => t p' -> Either String (t p)
|
||||
checkParty x = case testEquality (sAParty @p) (sAParty @p') of
|
||||
Just Refl -> Right x
|
||||
Nothing -> Left "bad party"
|
||||
|
||||
-- | SMP agent command and response parser
|
||||
commandP :: Parser ACmd
|
||||
commandP =
|
||||
"NEW " *> newCmd
|
||||
<|> "INV " *> invResp
|
||||
<|> "JOIN " *> joinCmd
|
||||
<|> "CONF " *> confMsg
|
||||
<|> "LET " *> letCmd
|
||||
<|> "REQ " *> reqMsg
|
||||
<|> "ACPT " *> acptCmd
|
||||
<|> "RJCT " *> rjctCmd
|
||||
<|> "INFO " *> infoCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "END" $> ACmd SAgent END
|
||||
<|> "CONNECT " *> connectResp
|
||||
<|> "DISCONNECT " *> disconnectResp
|
||||
<|> "DOWN " *> downResp
|
||||
<|> "UP " *> upResp
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "MID " *> msgIdResp
|
||||
<|> "SENT " *> sentResp
|
||||
<|> "MERR " *> msgErrResp
|
||||
<|> "MSG " *> message
|
||||
<|> "ACK " *> ackCmd
|
||||
<|> "OFF" $> ACmd SClient OFF
|
||||
<|> "DEL" $> ACmd SClient DEL
|
||||
<|> "CHK" $> ACmd SClient CHK
|
||||
<|> "STAT " *> statResp
|
||||
<|> "ERR " *> agentError
|
||||
<|> "CON" $> ACmd SAgent CON
|
||||
<|> "OK" $> ACmd SAgent OK
|
||||
commandP :: Parser ByteString -> Parser ACmd
|
||||
commandP binaryP =
|
||||
smpP
|
||||
>>= \case
|
||||
ACmdTag SClient cmd ->
|
||||
ACmd SClient <$> case cmd of
|
||||
NEW_ -> s (NEW <$> strP_ <*> strP)
|
||||
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> binaryP)
|
||||
LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP)
|
||||
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP)
|
||||
RJCT_ -> s (RJCT <$> A.takeByteString)
|
||||
SUB_ -> pure SUB
|
||||
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
|
||||
ACK_ -> s (ACK <$> A.decimal)
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
CHK_ -> pure CHK
|
||||
ACmdTag SAgent cmd ->
|
||||
ACmd SAgent <$> case cmd of
|
||||
INV_ -> s (INV <$> strP)
|
||||
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP)
|
||||
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP)
|
||||
INFO_ -> s (INFO <$> binaryP)
|
||||
CON_ -> pure CON
|
||||
END_ -> pure END
|
||||
CONNECT_ -> s (CONNECT <$> strP_ <*> strP)
|
||||
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
|
||||
DOWN_ -> s (DOWN <$> strP_ <*> connections)
|
||||
UP_ -> s (UP <$> strP_ <*> connections)
|
||||
MID_ -> s (MID <$> A.decimal)
|
||||
SENT_ -> s (SENT <$> A.decimal)
|
||||
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
|
||||
MSG_ -> s (MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> binaryP)
|
||||
STAT_ -> s (STAT <$> strP)
|
||||
OK_ -> pure OK
|
||||
ERR_ -> s (ERR <$> strP)
|
||||
SUSPENDED_ -> pure SUSPENDED
|
||||
where
|
||||
newCmd = ACmd SClient . NEW <$> strP
|
||||
invResp = ACmd SAgent . INV <$> strP
|
||||
joinCmd = ACmd SClient .: JOIN <$> strP_ <*> A.takeByteString
|
||||
confMsg = ACmd SAgent .:. CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> A.takeByteString
|
||||
letCmd = ACmd SClient .: LET <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
|
||||
reqMsg = ACmd SAgent .:. REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> A.takeByteString
|
||||
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
|
||||
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
|
||||
infoCmd = ACmd SAgent . INFO <$> A.takeByteString
|
||||
connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP
|
||||
disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP
|
||||
downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
|
||||
upResp = ACmd SAgent .: UP <$> strP_ <*> connections
|
||||
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
|
||||
msgIdResp = ACmd SAgent . MID <$> A.decimal
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
msgErrResp = ACmd SAgent .: MERR <$> A.decimal <* A.space <*> strP
|
||||
message = ACmd SAgent .:. MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> A.takeByteString
|
||||
ackCmd = ACmd SClient . ACK <$> A.decimal
|
||||
statResp = ACmd SAgent . STAT <$> strP
|
||||
s :: Parser a -> Parser a
|
||||
s p = A.space *> p
|
||||
connections :: Parser [ConnId]
|
||||
connections = strP `A.sepBy'` A.char ','
|
||||
msgMetaP = do
|
||||
integrity <- strP
|
||||
@@ -980,17 +1131,16 @@ commandP =
|
||||
sndMsgId <- " S=" *> A.decimal
|
||||
pure MsgMeta {integrity, recipient, broker, sndMsgId}
|
||||
partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P
|
||||
agentError = ACmd SAgent . ERR <$> strP
|
||||
|
||||
parseCommand :: ByteString -> Either AgentErrorType ACmd
|
||||
parseCommand = parse commandP $ CMD SYNTAX
|
||||
parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
|
||||
|
||||
-- | Serialize SMP agent command.
|
||||
serializeCommand :: ACommand p -> ByteString
|
||||
serializeCommand = \case
|
||||
NEW cMode -> "NEW " <> strEncode cMode
|
||||
NEW ntfs cMode -> B.unwords ["NEW", strEncode ntfs, strEncode cMode]
|
||||
INV cReq -> "INV " <> strEncode cReq
|
||||
JOIN cReq cInfo -> B.unwords ["JOIN", strEncode cReq, serializeBinary cInfo]
|
||||
JOIN ntfs cReq cInfo -> B.unwords ["JOIN", strEncode ntfs, strEncode cReq, serializeBinary cInfo]
|
||||
CONF confId srvs cInfo -> B.unwords ["CONF", confId, strEncodeList srvs, serializeBinary cInfo]
|
||||
LET confId cInfo -> B.unwords ["LET", confId, serializeBinary cInfo]
|
||||
REQ invId srvs cInfo -> B.unwords ["REQ", invId, strEncode srvs, serializeBinary cInfo]
|
||||
@@ -1068,7 +1218,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
tConnId :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p)
|
||||
tConnId (_, connId, _) cmd = case cmd of
|
||||
-- NEW, JOIN and ACPT have optional connId
|
||||
NEW _ -> Right cmd
|
||||
NEW _ _ -> Right cmd
|
||||
JOIN {} -> Right cmd
|
||||
ACPT {} -> Right cmd
|
||||
-- ERROR response does not always have connId
|
||||
@@ -1086,7 +1236,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
||||
cmdWithMsgBody = \case
|
||||
SEND msgFlags body -> SEND msgFlags <$$> getBody body
|
||||
MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body
|
||||
JOIN qUri cInfo -> JOIN qUri <$$> getBody cInfo
|
||||
JOIN ntfs qUri cInfo -> JOIN ntfs qUri <$$> getBody cInfo
|
||||
CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo
|
||||
LET confId cInfo -> LET confId <$$> getBody cInfo
|
||||
REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo
|
||||
|
||||
@@ -92,7 +92,7 @@ data SndQueue = SndQueue
|
||||
-- * Connection types
|
||||
|
||||
-- | Type of a connection.
|
||||
data ConnType = CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
|
||||
data ConnType = CNew | CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
|
||||
|
||||
-- | Connection of a specific type.
|
||||
--
|
||||
@@ -105,6 +105,7 @@ data ConnType = CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
|
||||
-- - DuplexConnection is a connection that has both receive and send queues set up,
|
||||
-- typically created by upgrading a receive or a send connection with a missing queue.
|
||||
data Connection (d :: ConnType) where
|
||||
NewConnection :: ConnData -> Connection CNew
|
||||
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
|
||||
SndConnection :: ConnData -> SndQueue -> Connection CSnd
|
||||
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
|
||||
@@ -115,12 +116,14 @@ deriving instance Eq (Connection d)
|
||||
deriving instance Show (Connection d)
|
||||
|
||||
data SConnType :: ConnType -> Type where
|
||||
SCNew :: SConnType CNew
|
||||
SCRcv :: SConnType CRcv
|
||||
SCSnd :: SConnType CSnd
|
||||
SCDuplex :: SConnType CDuplex
|
||||
SCContact :: SConnType CContact
|
||||
|
||||
connType :: SConnType c -> ConnType
|
||||
connType SCNew = CNew
|
||||
connType SCRcv = CRcv
|
||||
connType SCSnd = CSnd
|
||||
connType SCDuplex = CDuplex
|
||||
@@ -272,6 +275,8 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
|
||||
|
||||
type InternalTs = UTCTime
|
||||
|
||||
type AsyncCmdId = Int64
|
||||
|
||||
-- * Store errors
|
||||
|
||||
-- | Agent store error.
|
||||
@@ -293,6 +298,8 @@ data StoreError
|
||||
SEInvitationNotFound
|
||||
| -- | Message not found
|
||||
SEMsgNotFound
|
||||
| -- | Command not found
|
||||
SECmdNotFound
|
||||
| -- | Currently not used. The intention was to pass current expected queue status in methods,
|
||||
-- as we always know what it should be at any stage of the protocol,
|
||||
-- and in case it does not match use this error.
|
||||
|
||||
@@ -26,6 +26,9 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
sqlString,
|
||||
|
||||
-- * Queues and connections
|
||||
createNewConn,
|
||||
updateNewConnRcv,
|
||||
updateNewConnSnd,
|
||||
createRcvConn,
|
||||
createSndConn,
|
||||
getConn,
|
||||
@@ -68,6 +71,11 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
getRatchet,
|
||||
getSkippedMsgKeys,
|
||||
updateRatchet,
|
||||
-- Async commands
|
||||
createCommand,
|
||||
getPendingCommands,
|
||||
getPendingCommand,
|
||||
deleteCommand,
|
||||
-- Notification device token persistence
|
||||
createNtfToken,
|
||||
getSavedNtfToken,
|
||||
@@ -107,8 +115,10 @@ import Data.Bifunctor (second)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.Char (toLower)
|
||||
import Data.Function (on)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (foldl')
|
||||
import Data.Int (Int64)
|
||||
import Data.List (find, foldl', groupBy)
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromMaybe, listToMaybe)
|
||||
@@ -265,6 +275,37 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
|
||||
ConnData {connId = ""} -> createWithRandomId gVar create
|
||||
ConnData {connId} -> create connId $> Right connId
|
||||
|
||||
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
|
||||
|
||||
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
|
||||
updateNewConnRcv db connId rq@RcvQueue {server} =
|
||||
getConn db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
where
|
||||
updateConn :: IO (Either StoreError ())
|
||||
updateConn = do
|
||||
upsertServer_ db server
|
||||
insertRcvQueue_ db connId rq
|
||||
pure $ Right ()
|
||||
|
||||
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
|
||||
updateNewConnSnd db connId sq@SndQueue {server} =
|
||||
getConn db connId $>>= \case
|
||||
(SomeConn _ NewConnection {}) -> updateConn
|
||||
(SomeConn _ SndConnection {}) -> updateConn -- to allow retries
|
||||
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
|
||||
where
|
||||
updateConn :: IO (Either StoreError ())
|
||||
updateConn = do
|
||||
upsertServer_ db server
|
||||
insertSndQueue_ db connId sq
|
||||
pure $ Right ()
|
||||
|
||||
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
|
||||
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
|
||||
createConn_ gVar cData $ \connId -> do
|
||||
@@ -661,6 +702,50 @@ updateRatchet db connId rc skipped = do
|
||||
forM_ (M.assocs mks) $ \(msgN, mk) ->
|
||||
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
||||
|
||||
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> ACommand 'Client -> IO AsyncCmdId
|
||||
createCommand db corrId connId srv cmd = do
|
||||
DB.execute
|
||||
db
|
||||
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command) VALUES (?,?,?,?,?,?)"
|
||||
(host_, port_, corrId, connId, aCommandTag cmd, cmd)
|
||||
insertedRowId db
|
||||
where
|
||||
(host_, port_) =
|
||||
case srv of
|
||||
Just (SMPServer host port _) -> (Just host, Just port)
|
||||
_ -> (Nothing, Nothing)
|
||||
|
||||
insertedRowId :: DB.Connection -> IO Int64
|
||||
insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()"
|
||||
|
||||
getPendingCommands :: DB.Connection -> ConnId -> IO [(Maybe SMPServer, [AsyncCmdId])]
|
||||
getPendingCommands db connId = do
|
||||
map (\ids -> (fst $ head ids, map snd ids)) . groupBy ((==) `on` fst) . map srvCmdId
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT c.host, c.port, s.key_hash, c.command_id
|
||||
FROM commands c
|
||||
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
|
||||
WHERE conn_id = ?
|
||||
ORDER BY c.host, c.port, c.command_id ASC
|
||||
|]
|
||||
(Only connId)
|
||||
where
|
||||
srvCmdId (host, port, keyHash, cmdId) = (SMPServer <$> host <*> port <*> keyHash, cmdId)
|
||||
|
||||
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError (ACorrId, ConnId, ACmd))
|
||||
getPendingCommand db msgId = do
|
||||
firstRow id SECmdNotFound $
|
||||
DB.query
|
||||
db
|
||||
"SELECT corr_id, conn_id, command FROM commands WHERE command_id = ?"
|
||||
(Only msgId)
|
||||
|
||||
deleteCommand :: DB.Connection -> AsyncCmdId -> IO ()
|
||||
deleteCommand db cmdId =
|
||||
DB.execute db "DELETE FROM commands WHERE command_id = ?" (Only cmdId)
|
||||
|
||||
createNtfToken :: DB.Connection -> NtfToken -> IO ()
|
||||
createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do
|
||||
upsertNtfServer_ db srv
|
||||
@@ -1023,6 +1108,14 @@ instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1
|
||||
|
||||
instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
|
||||
|
||||
instance ToField (ACommand p) where toField = toField . serializeCommand
|
||||
|
||||
instance FromField ACmd where fromField = blobFieldParser dbCommandP
|
||||
|
||||
instance APartyI p => ToField (ACommandTag p) where toField = toField . smpEncode
|
||||
|
||||
instance FromField ACmdTag where fromField = blobFieldParser smpP
|
||||
|
||||
listToEither :: e -> [a] -> Either e a
|
||||
listToEither _ (x : _) = Right x
|
||||
listToEither e _ = Left e
|
||||
@@ -1130,6 +1223,7 @@ getConn dbConn connId =
|
||||
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
|
||||
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
|
||||
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ)
|
||||
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection connData)
|
||||
_ -> Left SEConnNotFound
|
||||
|
||||
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
|
||||
|
||||
@@ -36,6 +36,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
|
||||
@@ -50,7 +51,8 @@ schemaMigrations =
|
||||
("20220607_v2", m20220608_v2),
|
||||
("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode),
|
||||
("m20220811_onion_hosts", m20220811_onion_hosts),
|
||||
("m20220817_connection_ntfs", m20220817_connection_ntfs)
|
||||
("m20220817_connection_ntfs", m20220817_connection_ntfs),
|
||||
("m20220905_commands", m20220905_commands)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220905_commands :: Query
|
||||
m20220905_commands =
|
||||
[sql|
|
||||
CREATE TABLE commands (
|
||||
command_id INTEGER PRIMARY KEY,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
host TEXT,
|
||||
port TEXT,
|
||||
corr_id BLOB NOT NULL,
|
||||
command_tag BLOB NOT NULL,
|
||||
command BLOB NOT NULL,
|
||||
agent_version INTEGER NOT NULL DEFAULT 1,
|
||||
FOREIGN KEY (host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
|]
|
||||
@@ -194,3 +194,15 @@ CREATE TABLE ntf_subscriptions(
|
||||
FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE commands(
|
||||
command_id INTEGER PRIMARY KEY,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
host TEXT,
|
||||
port TEXT,
|
||||
corr_id BLOB NOT NULL,
|
||||
command_tag BLOB NOT NULL,
|
||||
command BLOB NOT NULL,
|
||||
agent_version INTEGER NOT NULL DEFAULT 1,
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
@@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
|
||||
( TMap,
|
||||
empty,
|
||||
singleton,
|
||||
clear,
|
||||
Simplex.Messaging.TMap.null,
|
||||
Simplex.Messaging.TMap.lookup,
|
||||
member,
|
||||
@@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a)
|
||||
singleton k v = newTVar $ M.singleton k v
|
||||
{-# INLINE singleton #-}
|
||||
|
||||
clear :: TMap k a -> STM ()
|
||||
clear m = writeTVar m M.empty
|
||||
{-# INLINE clear #-}
|
||||
|
||||
null :: TMap k a -> STM Bool
|
||||
null m = M.null <$> readTVar m
|
||||
{-# INLINE null #-}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.TMap2
|
||||
( TMap2,
|
||||
empty,
|
||||
clear,
|
||||
Simplex.Messaging.TMap2.lookup,
|
||||
lookup1,
|
||||
member,
|
||||
insert,
|
||||
insert1,
|
||||
delete,
|
||||
lookupDelete1,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (forM_, (>=>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (whenM, ($>>=))
|
||||
|
||||
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
|
||||
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
|
||||
data TMap2 k1 k2 a = TMap2
|
||||
{ _m1 :: TMap k1 (TMap k2 a),
|
||||
_m2 :: TMap k2 k1
|
||||
}
|
||||
|
||||
empty :: STM (TMap2 k1 k2 a)
|
||||
empty = TMap2 <$> TM.empty <*> TM.empty
|
||||
|
||||
clear :: TMap2 k1 k2 a -> STM ()
|
||||
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
|
||||
|
||||
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
|
||||
lookup k2 TMap2 {_m1, _m2} = do
|
||||
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
|
||||
|
||||
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
|
||||
{-# INLINE lookup1 #-}
|
||||
|
||||
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
|
||||
member k2 TMap2 {_m2} = TM.member k2 _m2
|
||||
{-# INLINE member #-}
|
||||
|
||||
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
|
||||
insert k1 k2 v TMap2 {_m1, _m2} =
|
||||
TM.lookup k2 _m2 >>= \case
|
||||
Just k1'
|
||||
| k1 == k1' -> _insert1
|
||||
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
|
||||
_ -> _insert2
|
||||
where
|
||||
_insert1 =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> TM.insert k2 v m
|
||||
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
|
||||
_insert2 = TM.insert k2 k1 _m2 >> _insert1
|
||||
|
||||
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
|
||||
insert1 k1 m' TMap2 {_m1, _m2} =
|
||||
TM.lookup k1 _m1 >>= \case
|
||||
Just m -> readTVar m' >>= (`TM.union` m)
|
||||
_ -> TM.insert k1 m' _m1
|
||||
|
||||
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
|
||||
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
|
||||
|
||||
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
|
||||
_delete1 k1 k2 m1 =
|
||||
TM.lookup k1 m1
|
||||
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
|
||||
|
||||
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
|
||||
lookupDelete1 k1 TMap2 {_m1, _m2} = do
|
||||
m_ <- TM.lookupDelete k1 _m1
|
||||
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
|
||||
pure m_
|
||||
+21
-21
@@ -131,9 +131,9 @@ pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
|
||||
|
||||
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = do
|
||||
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW INV")
|
||||
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
|
||||
alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
bob <# ("", "alice", INFO "alice's connInfo")
|
||||
@@ -164,9 +164,9 @@ testDuplexConnection _ alice bob = do
|
||||
|
||||
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnRandomIds _ alice bob = do
|
||||
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW INV")
|
||||
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV")
|
||||
let cReq' = strEncode cReq
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> cReq' <> " 14\nbob's connInfo")
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " 14\nbob's connInfo")
|
||||
("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
|
||||
bobConn' `shouldBe` bobConn
|
||||
alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False
|
||||
@@ -197,10 +197,10 @@ testDuplexConnRandomIds _ alice bob = do
|
||||
|
||||
testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testContactConnection _ alice bob tom = do
|
||||
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW CON")
|
||||
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON")
|
||||
let cReq' = strEncode cReq
|
||||
|
||||
bob #: ("11", "alice", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
|
||||
("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
|
||||
alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
|
||||
@@ -213,7 +213,7 @@ testContactConnection _ alice bob tom = do
|
||||
bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False
|
||||
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
|
||||
|
||||
tom #: ("21", "alice", "JOIN " <> cReq' <> " 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
tom #: ("21", "alice", "JOIN T " <> cReq' <> " 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:)
|
||||
alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK)
|
||||
("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:)
|
||||
@@ -228,10 +228,10 @@ testContactConnection _ alice bob tom = do
|
||||
|
||||
testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testContactConnRandomIds _ alice bob = do
|
||||
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW CON")
|
||||
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON")
|
||||
let cReq' = strEncode cReq
|
||||
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> cReq' <> " 14\nbob's connInfo")
|
||||
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " 14\nbob's connInfo")
|
||||
("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
|
||||
aliceContact' `shouldBe` aliceContact
|
||||
|
||||
@@ -251,9 +251,9 @@ testContactConnRandomIds _ alice bob = do
|
||||
|
||||
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testRejectContactRequest _ alice bob = do
|
||||
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW CON")
|
||||
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice", "JOIN " <> cReq' <> " 10\nbob's info") #> ("11", "alice", OK)
|
||||
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 10\nbob's info") #> ("11", "alice", OK)
|
||||
("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:)
|
||||
-- RJCT must use correct contact connection
|
||||
alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND)
|
||||
@@ -282,7 +282,7 @@ testSubscription _ alice1 alice2 bob = do
|
||||
|
||||
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
|
||||
testSubscrNotification t (server, _) client = do
|
||||
client #: ("1", "conn1", "NEW INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
|
||||
client #: ("1", "conn1", "NEW T INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
|
||||
client #:# "nothing should be delivered to client before the server is killed"
|
||||
killThread server
|
||||
client <# ("", "", DOWN testSMPServer ["conn1"])
|
||||
@@ -392,9 +392,9 @@ testConcurrentMsgDelivery :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testConcurrentMsgDelivery _ alice bob = do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
|
||||
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW INV")
|
||||
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV")
|
||||
let cReq' = strEncode cReq
|
||||
bob #: ("11", "alice2", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice2", OK)
|
||||
bob #: ("11", "alice2", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice2", OK)
|
||||
("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:)
|
||||
-- below commands would be needed to accept bob's connection, but alice does not
|
||||
-- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
|
||||
@@ -431,9 +431,9 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
|
||||
|
||||
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW INV")
|
||||
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV")
|
||||
let cReq' = strEncode cReq
|
||||
h2 #: ("c2", name1, "JOIN " <> cReq' <> " 5\ninfo2") #> ("c2", name1, OK)
|
||||
h2 #: ("c2", name1, "JOIN T " <> cReq' <> " 5\ninfo2") #> ("c2", name1, OK)
|
||||
("", _, Right (CONF connId _ "info2")) <- (h1 <#:)
|
||||
h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK)
|
||||
h2 <# ("", name1, INFO "info1")
|
||||
@@ -452,9 +452,9 @@ sendMessage (h1, name1) (h2, name2) msg = do
|
||||
|
||||
-- connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
|
||||
-- connect' h1 h2 = do
|
||||
-- ("c1", conn2, Right (INV cReq)) <- h1 #: ("c1", "", "NEW INV")
|
||||
-- ("c1", conn2, Right (INV cReq)) <- h1 #: ("c1", "", "NEW T INV")
|
||||
-- let cReq' = strEncode cReq
|
||||
-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN " <> cReq' <> " 5\ninfo2")
|
||||
-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN T " <> cReq' <> " 5\ninfo2")
|
||||
-- ("", _, Right (REQ connId _ "info2")) <- (h1 <#:)
|
||||
-- h1 #: ("c3", conn2, "ACPT " <> connId <> " 5\ninfo1") =#> \case ("c3", c, OK) -> c == conn2; _ -> False
|
||||
-- h2 <# ("", conn1, INFO "info1")
|
||||
@@ -471,17 +471,17 @@ syntaxTests t = do
|
||||
describe "NEW" $ do
|
||||
describe "valid" $ do
|
||||
-- TODO: add tests with defined connection id
|
||||
it "with correct parameter" $ ("211", "", "NEW INV") >#>= \case ("211", _, "INV" : _) -> True; _ -> False
|
||||
it "with correct parameter" $ ("211", "", "NEW T INV") >#>= \case ("211", _, "INV" : _) -> True; _ -> False
|
||||
describe "invalid" $ do
|
||||
-- TODO: add tests with defined connection id
|
||||
it "with incorrect parameter" $ ("222", "", "NEW hi") >#> ("222", "", "ERR CMD SYNTAX")
|
||||
it "with incorrect parameter" $ ("222", "", "NEW T hi") >#> ("222", "", "ERR CMD SYNTAX")
|
||||
|
||||
describe "JOIN" $ do
|
||||
describe "valid" $ do
|
||||
it "using same server as in invitation" $
|
||||
( "311",
|
||||
"a",
|
||||
"JOIN https://simpex.chat/invitation#/?smp=smp%3A%2F%2F"
|
||||
"JOIN T https://simpex.chat/invitation#/?smp=smp%3A%2F%2F"
|
||||
<> urlEncode True "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
<> "%40localhost%3A5001%2F3456-w%3D%3D%23"
|
||||
<> urlEncode True sampleDhKey
|
||||
|
||||
@@ -114,6 +114,11 @@ functionalAPITests t = do
|
||||
describe "Batching SMP commands" $ do
|
||||
it "should subscribe to multiple subscriptions with batching" $
|
||||
testBatchedSubscriptions t
|
||||
describe "Async agent commands" $ do
|
||||
it "should connect using async agent commands" $
|
||||
withSmpServer t testAsyncCommands
|
||||
it "should restore and complete async commands on restart" $
|
||||
testAsyncCommandsRestore t
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
@@ -560,6 +565,64 @@ testBatchedSubscriptions t = do
|
||||
killThread t1
|
||||
pure res
|
||||
|
||||
testAsyncCommands :: IO ()
|
||||
testAsyncCommands = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
bobId <- createConnectionAsync alice "1" True SCMInvitation
|
||||
("1", bobId', INV (ACR _ qInfo)) <- get alice
|
||||
liftIO $ bobId' `shouldBe` bobId
|
||||
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
|
||||
("2", aliceId', OK) <- get bob
|
||||
liftIO $ aliceId' `shouldBe` aliceId
|
||||
("", _, CONF confId _ "bob's connInfo") <- get alice
|
||||
allowConnectionAsync alice "3" bobId confId "alice's connInfo"
|
||||
("3", _, OK) <- get alice
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
|
||||
1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 1)
|
||||
2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 2)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessageAsync bob "4" aliceId $ baseId + 1
|
||||
("4", _, OK) <- get bob
|
||||
get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False
|
||||
ackMessageAsync bob "5" aliceId $ baseId + 2
|
||||
("5", _, OK) <- get bob
|
||||
3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 3)
|
||||
4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 4)
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
ackMessageAsync alice "6" bobId $ baseId + 3
|
||||
("6", _, OK) <- get alice
|
||||
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
|
||||
ackMessageAsync alice "7" bobId $ baseId + 4
|
||||
("7", _, OK) <- get alice
|
||||
pure ()
|
||||
pure ()
|
||||
where
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testAsyncCommandsRestore :: ATransport -> IO ()
|
||||
testAsyncCommandsRestore t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation
|
||||
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
|
||||
disconnectAgentClient alice
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
withSmpServerStoreLogOn t testPort $ \_ -> do
|
||||
Right () <- runExceptT $ do
|
||||
subscribeConnection alice' bobId
|
||||
("1", _, INV _) <- get alice'
|
||||
pure ()
|
||||
pure ()
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings = exchangeGreetingsMsgId 4
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
|
||||
where
|
||||
get h = do
|
||||
t@(_, _, cmdStr) <- tGetRaw h
|
||||
case parseAll commandP cmdStr of
|
||||
case parseAll networkCommandP cmdStr of
|
||||
Right (ACmd SAgent CONNECT {}) -> get h
|
||||
Right (ACmd SAgent DISCONNECT {}) -> get h
|
||||
_ -> pure t
|
||||
|
||||
Reference in New Issue
Block a user