Merge branch 'master' into sqlcipher

This commit is contained in:
Evgeny Poberezkin
2022-09-14 18:22:46 +01:00
14 changed files with 845 additions and 219 deletions
+2
View File
@@ -52,6 +52,7 @@ library
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
Simplex.Messaging.Client
Simplex.Messaging.Client.Agent
Simplex.Messaging.Crypto
@@ -81,6 +82,7 @@ library
Simplex.Messaging.Server.Stats
Simplex.Messaging.Server.StoreLog
Simplex.Messaging.TMap
Simplex.Messaging.TMap2
Simplex.Messaging.Transport
Simplex.Messaging.Transport.Client
Simplex.Messaging.Transport.HTTP2
+281 -77
View File
@@ -12,7 +12,6 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module : Simplex.Messaging.Agent
@@ -39,6 +38,10 @@ module Simplex.Messaging.Agent
disconnectAgentClient,
resumeAgentClient,
withAgentLock,
createConnectionAsync,
joinConnectionAsync,
allowConnectionAsync,
ackMessageAsync,
createConnection,
joinConnection,
allowConnection,
@@ -80,8 +83,9 @@ import Control.Monad.Reader
import Crypto.Random (MonadRandom)
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import Data.Composition ((.:), (.:.))
import Data.Composition ((.:), (.:.), (.::))
import Data.Functor (($>))
import Data.List (deleteFirstsBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
@@ -141,13 +145,29 @@ resumeAgentClient c = atomically $ writeTVar (active c) True
-- |
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
createConnectionAsync c corrId enableNtfs cMode = withAgentEnv c $ newConnAsync c corrId enableNtfs cMode
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnectionAsync c corrId enableNtfs = withAgentEnv c .: joinConnAsync c corrId enableNtfs
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c
-- | Acknowledge message (ACK command) asynchronously, no synchronous response
ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
ackMessageAsync c = withAgentEnv c .:. ackMessageAsync' c
-- | Create SMP agent connection (NEW command)
createConnection :: AgentErrorMonad m => AgentClient -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
createConnection c enableNtfs cMode = withAgentEnv c $ newConn c "" enableNtfs cMode
createConnection c enableNtfs cMode = withAgentEnv c $ newConn c "" False enableNtfs cMode
-- | Join SMP agent connection (JOIN command)
joinConnection :: AgentErrorMonad m => AgentClient -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" enableNtfs
joinConnection c enableNtfs = withAgentEnv c .: joinConn c "" False enableNtfs
-- | Allow connection to continue after CONF notification (LET command)
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
@@ -283,8 +303,8 @@ client c@AgentClient {rcvQ, subQ} = forever $ do
-- | execute any SMP agent command
processCommand :: forall m. AgentMonad m => AgentClient -> (ConnId, ACommand 'Client) -> m (ConnId, ACommand 'Agent)
processCommand c (connId, cmd) = case cmd of
NEW (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId True cMode
JOIN (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId True cReq connInfo
NEW enableNtfs (ACM cMode) -> second (INV . ACR cMode) <$> newConn c connId False enableNtfs cMode
JOIN enableNtfs (ACR _ cReq) connInfo -> (,OK) <$> joinConn c connId False enableNtfs cReq connInfo
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
@@ -295,15 +315,59 @@ processCommand c (connId, cmd) = case cmd of
DEL -> deleteConnection' c connId $> (connId, OK)
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
newConn c connId enableNtfs cMode = do
srv <- getSMPServer c
clientVRange <- asks $ smpClientVRange . config
(rq, qUri) <- newRcvQueue c srv clientVRange
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c corrId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
connId' <- withStore c $ \db -> createRcvConn db g cData rq cMode
let cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
connId <- withStore c $ \db -> createNewConn db g cData cMode
enqueueCommand c corrId connId Nothing $ NEW enableNtfs (ACM cMode)
pure connId
joinConnAsync :: AgentMonad m => AgentClient -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c corrId enableNtfs cReqUri@(CRInvitationUri (ConnReqUriData _ agentVRange _) _) cInfo = do
aVRange <- asks $ smpAgentVRange . config
case agentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ JOIN enableNtfs (ACR sConnectionMode cReqUri) cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _corrId _enableNtfs (CRContactUri _) _cInfo =
throwError $ CMD PROHIBITED
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnectionAsync' c corrId connId confId ownConnInfo =
withStore c (`getConn` connId) >>= \case
SomeConn _ (RcvConnection _ RcvQueue {server}) ->
enqueueCommand c corrId connId (Just server) $ LET confId ownConnInfo
_ -> throwError $ CMD PROHIBITED
ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> m ()
ackMessageAsync' c corrId connId msgId =
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection _ rq _) -> enqueueAck rq
SomeConn _ (RcvConnection _ rq) -> enqueueAck rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
where
enqueueAck :: RcvQueue -> m ()
enqueueAck RcvQueue {server} = do
enqueueCommand c corrId connId (Just server) $ ACK msgId
newConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> m (ConnId, ConnectionRequestUri c)
newConn c connId asyncMode enableNtfs cMode =
getSMPServer c >>= newConnSrv c connId asyncMode enableNtfs cMode
newConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> SConnectionMode c -> SMPServer -> m (ConnId, ConnectionRequestUri c)
newConnSrv c connId asyncMode enableNtfs cMode srv = do
clientVRange <- asks $ smpClientVRange . config
(rq, qUri) <- newRcvQueue c srv clientVRange
connId' <- setUpConn asyncMode rq
addSubscription c rq connId'
when enableNtfs $ do
ns <- asks ntfSupervisor
@@ -316,9 +380,26 @@ newConn c connId enableNtfs cMode = do
(pk1, pk2, e2eRcvParams) <- liftIO $ CR.generateE2EParams CR.e2eEncryptVersion
withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams CR.e2eEncryptVRange)
where
setUpConn True rq = do
withStore c $ \db -> updateNewConnRcv db connId rq
pure connId
setUpConn False rq = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing} -- connection mode is determined by the accepting agent
withStore c $ \db -> createRcvConn db g cData rq cMode
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo = do
joinConn :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c connId asyncMode enableNtfs cReq cInfo = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _ ->
getNextSMPServer c [smpServer (queueAddress :: SMPQueueAddress)]
_ -> getSMPServer c
joinConnSrv c connId asyncMode enableNtfs cReq cInfo srv
joinConnSrv :: AgentMonad m => AgentClient -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> SMPServer -> m ConnId
joinConnSrv c connId asyncMode enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUri :| _)) e2eRcvParamsUri) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
@@ -330,38 +411,47 @@ joinConn c connId enableNtfs (CRInvitationUri (ConnReqUriData _ agentVRange (qUr
(_, rcDHRs) <- liftIO C.generateKeyPair'
let rc = CR.initSndRatchet rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
sq <- newSndQueue qInfo
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS}
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData sq
liftIO $ createRatchet db connId' rc
pure connId'
connId' <- setUpConn asyncMode cData sq rc
let cData' = (cData :: ConnData) {connId = connId'}
tryError (confirmQueue aVersion c cData' sq cInfo $ Just e2eSndParams) >>= \case
tryError (confirmQueue aVersion c cData' sq srv cInfo $ Just e2eSndParams) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
Left e -> do
-- TODO recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
unless asyncMode $ withStore' c (`deleteConn` connId')
throwError e
where
setUpConn True _ sq rc =
withStore c $ \db -> runExceptT $ do
ExceptT $ updateNewConnSnd db connId sq
liftIO $ createRatchet db connId rc
pure connId
setUpConn False cData sq rc = do
g <- asks idsDrg
withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData sq
liftIO $ createRatchet db connId' rc
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConn c connId enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo = do
joinConnSrv c connId False enableNtfs (CRContactUri (ConnReqUriData _ agentVRange (qUri :| _))) cInfo srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
agentVRange `compatibleVersion` aVRange
) of
(Just qInfo, Just vrsn) -> do
(connId', cReq) <- newConn c connId enableNtfs SCMInvitation
(connId', cReq) <- newConnSrv c connId False enableNtfs SCMInvitation srv
sendInvitation c qInfo vrsn cReq cInfo
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrv _c _connId True _enableNtfs (CRContactUri _) _cInfo _srv = do
throwError $ CMD PROHIBITED
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} = do
srv <- getSMPServer c
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SMPServer -> m SMPQueueInfo
createReplyQueue c ConnData {connId, enableNtfs} SndQueue {smpClientVersion} srv = do
(rq, qUri) <- newRcvQueue c srv $ versionToRange smpClientVersion
let qInfo = toVersionT qUri smpClientVersion
addSubscription c rq connId
@@ -391,7 +481,7 @@ acceptContact' c connId enableNtfs invId ownConnInfo = do
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ ContactConnection {} -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConn c connId enableNtfs connReq ownConnInfo `catchError` \err -> do
joinConn c connId False enableNtfs connReq ownConnInfo `catchError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
@@ -411,18 +501,21 @@ processConfirmation c rq@RcvQueue {e2ePrivKey, smpClientVersion = v} SMPConfirma
-- | Subscribe to receive connection messages (SUB command) in Reader monad
subscribeConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
subscribeConnection' c connId =
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rq sq) -> do
resumeMsgDelivery c cData sq
subscribe rq
SomeConn _ (SndConnection cData sq) -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure ()
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
SomeConn _ (RcvConnection _ rq) -> subscribe rq
SomeConn _ (ContactConnection _ rq) -> subscribe rq
withStore c (`getConn` connId) >>= \conn -> do
resumeConnCmds c connId
case conn of
SomeConn _ (DuplexConnection cData rq sq) -> do
resumeMsgDelivery c cData sq
subscribe rq
SomeConn _ (SndConnection cData sq) -> do
resumeMsgDelivery c cData sq
case status (sq :: SndQueue) of
Confirmed -> pure ()
Active -> throwError $ CONN SIMPLEX
_ -> throwError $ INTERNAL "unexpected queue status"
SomeConn _ (RcvConnection _ rq) -> subscribe rq
SomeConn _ (ContactConnection _ rq) -> subscribe rq
SomeConn _ (NewConnection _) -> pure ()
where
subscribe :: RcvQueue -> m ()
subscribe rq = do
@@ -436,33 +529,34 @@ subscribeConnections' c connIds = do
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConn)
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(sndQs, rcvQs) = M.mapEither rcvOrSndQueue cs
sndRs = M.map (sndSubResult . fst) sndQs
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
srvRcvQs :: Map SMPServer (Map ConnId RcvQueue) = M.foldlWithKey' addRcvQueue M.empty rcvQs
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
mapM_ (resumeConnCmds c) $ M.keys cs
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
ns <- asks ntfSupervisor
tkn <- readTVarIO (ntfTkn ns)
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs
let rs = M.unions $ errs' : sndRs : rcvRs
let rs = M.unions $ errs' : subRs : rcvRs
notifyResultError rs
pure rs
where
rcvOrSndQueue :: SomeConn -> Either (SndQueue, ConnData) (RcvQueue, ConnData)
rcvOrSndQueue = \case
SomeConn _ (DuplexConnection cData rq _) -> Right (rq, cData)
SomeConn _ (SndConnection cData sq) -> Left (sq, cData)
SomeConn _ (RcvConnection cData rq) -> Right (rq, cData)
SomeConn _ (ContactConnection cData rq) -> Right (rq, cData)
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) RcvQueue
rcvQueueOrResult = \case
SomeConn _ (DuplexConnection _ rq _) -> Right rq
SomeConn _ (SndConnection _ sq) -> Left $ sndSubResult sq
SomeConn _ (RcvConnection _ rq) -> Right rq
SomeConn _ (ContactConnection _ rq) -> Right rq
SomeConn _ (NewConnection _) -> Left (Right ())
sndSubResult :: SndQueue -> Either AgentErrorType ()
sndSubResult sq = case status (sq :: SndQueue) of
Confirmed -> Right ()
Active -> Left $ CONN SIMPLEX
_ -> Left $ INTERNAL "unexpected queue status"
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs)
addRcvQueue :: Map SMPServer (Map ConnId RcvQueue) -> ConnId -> RcvQueue -> Map SMPServer (Map ConnId RcvQueue)
addRcvQueue m connId rq@RcvQueue {server} = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId RcvQueue) -> m (Map ConnId (Either AgentErrorType ()))
subscribe (srv, qs) = snd <$> subscribeQueues c srv qs
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
sendNtfCreate ns rcvRs =
forM_ (concatMap M.assocs rcvRs) $ \case
@@ -502,6 +596,7 @@ getConnectionMessage' c connId = do
SomeConn _ (RcvConnection _ rq) -> getQueueMessage c rq
SomeConn _ (ContactConnection _ rq) -> getQueueMessage c rq
SomeConn _ SndConnection {} -> throwError $ CONN SIMPLEX
SomeConn _ NewConnection {} -> throwError $ CMD PROHIBITED
getNotificationMessage' :: forall m. AgentMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta])
getNotificationMessage' c nonce encNtfInfo = do
@@ -541,11 +636,106 @@ sendMessage' c connId msgFlags msg =
enqueueMsg :: ConnData -> SndQueue -> m AgentMsgId
enqueueMsg cData sq = enqueueMessage c cData sq msgFlags $ A_MSG msg
-- / async command processing v v v
enqueueCommand :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> ACommand 'Client -> m ()
enqueueCommand c corrId connId server aCommand = do
resumeSrvCmds c server
commandId <- withStore' c $ \db -> createCommand db corrId connId server aCommand
queuePendingCommands c server [commandId]
resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
resumeSrvCmds c server =
unlessM (cmdProcessExists c server) $
async (runCommandProcessing c server)
>>= \a -> atomically (TM.insert server a $ asyncCmdProcesses c)
resumeConnCmds :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
resumeConnCmds c connId =
unlessM connQueued $
withStore' c (`getPendingCommands` connId)
>>= mapM_ (uncurry enqueueSrvCmds)
where
enqueueSrvCmds srv cmdIds = unlessM (cmdProcessExists c srv) $ do
a <- async (runCommandProcessing c srv)
atomically (TM.insert srv a $ asyncCmdProcesses c)
queuePendingCommands c srv cmdIds
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c)
cmdProcessExists :: AgentMonad m => AgentClient -> Maybe SMPServer -> m Bool
cmdProcessExists c srv = atomically $ TM.member srv (asyncCmdProcesses c)
queuePendingCommands :: AgentMonad m => AgentClient -> Maybe SMPServer -> [AsyncCmdId] -> m ()
queuePendingCommands c server cmdIds = atomically $ do
q <- getPendingCommandQ c server
mapM_ (writeTQueue q) cmdIds
getPendingCommandQ :: AgentClient -> Maybe SMPServer -> STM (TQueue AsyncCmdId)
getPendingCommandQ c server = do
maybe newMsgQueue pure =<< TM.lookup server (asyncCmdQueues c)
where
newMsgQueue = do
cq <- newTQueue
TM.insert server cq $ asyncCmdQueues c
pure cq
runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
runCommandProcessing c@AgentClient {subQ} server = do
cq <- atomically $ getPendingCommandQ c server
ri <- asks $ messageRetryInterval . config -- different retry interval?
forever $ do
atomically $ endAgentOperation c AOSndNetwork
cmdId <- atomically $ readTQueue cq
atomically $ beginAgentOperation c AOSndNetwork
E.try (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left (e :: E.SomeException) -> atomically $ writeTBQueue subQ ("", "", ERR . INTERNAL $ show e)
Right (corrId, connId, ACmd _ cmd) -> processCmd ri corrId connId cmdId cmd
where
processCmd :: RetryInterval -> ACorrId -> ConnId -> AsyncCmdId -> ACommand p -> m ()
processCmd ri corrId connId cmdId = \case
NEW enableNtfs (ACM cMode) -> do
usedSrvs <- newTVarIO ([] :: [SMPServer])
tryCommand . withNextSrv usedSrvs [] $ \srv -> do
(_, cReq) <- newConnSrv c connId True enableNtfs cMode srv
notify $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = SMPQueueUri {queueAddress} :| _} _)) connInfo -> do
let initUsed = [smpServer (queueAddress :: SMPQueueAddress)]
usedSrvs <- newTVarIO initUsed
tryCommand . withNextSrv usedSrvs initUsed $ \srv -> do
void $ joinConnSrv c connId True enableNtfs cReq connInfo srv
notify OK
LET confId ownCInfo -> tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId -> tryCommand $ ackMessage' c connId msgId >> notify OK
cmd -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd)
where
tryCommand action = withRetryInterval ri $ \loop ->
tryError action >>= \case
Left e
| temporaryAgentError e || e == BROKER HOST -> retryCommand loop
| otherwise -> notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
Right () -> withStore' c (`deleteCommand` cmdId)
retryCommand loop = do
-- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent
atomically $ endAgentOperation c AOSndNetwork
atomically $ beginAgentOperation c AOSndNetwork
loop
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, cmd)
withNextSrv :: TVar [SMPServer] -> [SMPServer] -> (SMPServer -> m ()) -> m ()
withNextSrv usedSrvs initUsed action = do
used <- readTVarIO usedSrvs
srv <- getNextSMPServer c used
atomically $ do
srvs <- readTVar $ smpServers c
let used' = if length used + 1 >= L.length srvs then initUsed else srv : used
writeTVar usedSrvs used'
action srv
-- ^ ^ ^ async command processing /
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage = do
resumeMsgDelivery c cData sq
msgId <- storeSentMsg
queuePendingMsgs c connId sq [msgId]
queuePendingMsgs c sq [msgId]
pure $ unId msgId
where
storeSentMsg :: m InternalId
@@ -565,28 +755,28 @@ enqueueMessage c cData@ConnData {connId, connAgentVersion} sq msgFlags aMessage
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
let qKey = (connId, server, sndId)
let qKey = (server, sndId)
unlessM (queueDelivering qKey) $
async (runSmpQueueMsgDelivery c cData sq)
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
unlessM connQueued $
withStore' c (`getPendingMsgs` connId)
>>= queuePendingMsgs c connId sq
>>= queuePendingMsgs c sq
where
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connMsgsQueued c)
queuePendingMsgs :: AgentMonad m => AgentClient -> ConnId -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c connId sq msgIds = atomically $ do
queuePendingMsgs :: AgentMonad m => AgentClient -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c sq msgIds = atomically $ do
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
-- s <- readTVar (msgDeliveryOp c)
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
q <- getPendingMsgQ c connId sq
q <- getPendingMsgQ c sq
mapM_ (writeTQueue q) msgIds
getPendingMsgQ :: AgentClient -> ConnId -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ c connId SndQueue {server, sndId} = do
let qKey = (connId, server, sndId)
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId)
getPendingMsgQ c SndQueue {server, sndId} = do
let qKey = (server, sndId)
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
where
newMsgQueue qKey = do
@@ -596,7 +786,7 @@ getPendingMsgQ c connId SndQueue {server, sndId} = do
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandshake} sq = do
mq <- atomically $ getPendingMsgQ c connId sq
mq <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
atomically $ endAgentOperation c AOSndNetwork
@@ -677,7 +867,8 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {connId, duplexHandsh
-- and this branch should never be reached as receive is created before the confirmation,
-- so the condition is not necessary here, strictly speaking.
_ -> unless (duplexHandshake == Just True) $ do
qInfo <- createReplyQueue c cData sq
srv <- getSMPServer c
qInfo <- createReplyQueue c cData sq srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
_ -> pure ()
@@ -703,6 +894,7 @@ ackMessage' c connId msgId = do
SomeConn _ (RcvConnection _ rq) -> ack rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
SomeConn _ (ContactConnection _ _) -> throwError $ CMD PROHIBITED
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
where
ack :: RcvQueue -> m ()
ack rq = do
@@ -721,6 +913,7 @@ suspendConnection' c connId =
SomeConn _ (RcvConnection _ rq) -> suspendQueue c rq
SomeConn _ (ContactConnection _ rq) -> suspendQueue c rq
SomeConn _ (SndConnection _ _) -> throwError $ CONN SIMPLEX
SomeConn _ (NewConnection _) -> throwError $ CMD PROHIBITED
-- | Delete SMP agent connection (DEL command) in Reader monad
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
@@ -730,6 +923,7 @@ deleteConnection' c connId =
SomeConn _ (RcvConnection _ rq) -> delete rq
SomeConn _ (ContactConnection _ rq) -> delete rq
SomeConn _ (SndConnection _ _) -> withStore' c (`deleteConn` connId)
SomeConn _ (NewConnection _) -> withStore' c (`deleteConn` connId)
where
delete :: RcvQueue -> m ()
delete rq = do
@@ -748,6 +942,7 @@ getConnectionServers' c connId = connServers <$> withStore c (`getConn` connId)
SomeConn _ (SndConnection _ SndQueue {server}) -> ConnectionStats {rcvServers = [], sndServers = [server]}
SomeConn _ (DuplexConnection _ RcvQueue {server = s1} SndQueue {server = s2}) -> ConnectionStats {rcvServers = [s1], sndServers = [s2]}
SomeConn _ (ContactConnection _ RcvQueue {server}) -> ConnectionStats {rcvServers = [server], sndServers = []}
SomeConn _ (NewConnection _) -> ConnectionStats {rcvServers = [], sndServers = []}
-- | Change servers to be used for creating new queues, in Reader monad
setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
@@ -984,14 +1179,23 @@ suspendAgent' c@AgentClient {agentState = as} maxDelay = do
suspendSendingAndDatabase c
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
getSMPServer c = do
smpServers <- readTVarIO $ smpServers c
case smpServers of
srv :| [] -> pure srv
servers -> do
gen <- asks randomServer
atomically . stateTVar gen $
first (servers L.!!) . randomR (0, L.length servers - 1)
getSMPServer c = readTVarIO (smpServers c) >>= pickServer
pickServer :: AgentMonad m => NonEmpty SMPServer -> m SMPServer
pickServer = \case
srv :| [] -> pure srv
servers -> do
gen <- asks randomServer
atomically $ (servers L.!!) <$> stateTVar gen (randomR (0, L.length servers - 1))
getNextSMPServer :: AgentMonad m => AgentClient -> [SMPServer] -> m SMPServer
getNextSMPServer c usedSrvs = do
srvs <- readTVarIO $ smpServers c
case L.nonEmpty $ deleteFirstsBy sameAddr (L.toList srvs) usedSrvs of
Just srvs' -> pickServer srvs'
_ -> pickServer srvs
where
sameAddr (SMPServer host port _) (SMPServer host' port' _) = host == host' && port == port'
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
@@ -1229,8 +1433,8 @@ connectReplyQueues c cData@ConnData {connId} ownConnInfo (qInfo :| _) = do
withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq ownConnInfo Nothing
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2eEncryption = do
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServer -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption = do
aMessage <- mkAgentMessage agentVersion
msg <- mkConfirmation aMessage
sendConfirmation c sq msg
@@ -1244,14 +1448,14 @@ confirmQueue (Compatible agentVersion) c cData@ConnData {connId} sq connInfo e2e
mkAgentMessage :: Version -> m AgentMessage
mkAgentMessage 1 = pure $ AgentConnInfo connInfo
mkAgentMessage _ = do
qInfo <- createReplyQueue c cData sq
qInfo <- createReplyQueue c cData sq srv
pure $ AgentConnInfoReply (qInfo :| []) connInfo
enqueueConfirmation :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
enqueueConfirmation c cData@ConnData {connId, connAgentVersion} sq connInfo e2eEncryption = do
resumeMsgDelivery c cData sq
msgId <- storeConfirmation
queuePendingMsgs c connId sq [msgId]
queuePendingMsgs c sq [msgId]
where
storeConfirmation :: m InternalId
storeConfirmation = withStore c $ \db -> runExceptT $ do
+40 -58
View File
@@ -90,7 +90,6 @@ import Data.Maybe (listToMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text.Encoding
import Data.Tuple (swap)
import Data.Word (Word16)
import qualified Database.SQLite.Simple as DB
import Simplex.Messaging.Agent.Env.SQLite
@@ -131,6 +130,8 @@ import Simplex.Messaging.Protocol
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.TMap2 (TMap2)
import qualified Simplex.Messaging.TMap2 as TM2
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
@@ -155,13 +156,15 @@ data AgentClient = AgentClient
ntfServers :: TVar [NtfServer],
ntfClients :: TMap NtfServer NtfClientVar,
useNetworkConfig :: TVar NetworkConfig,
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
subscrConns :: TVar (Set ConnId),
activeSubscrConns :: TMap ConnId SMPServer,
activeSubs :: TMap2 SMPServer ConnId RcvQueue,
pendingSubs :: TMap2 SMPServer ConnId RcvQueue,
connMsgsQueued :: TMap ConnId Bool,
smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId),
smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()),
smpQueueMsgQueues :: TMap (SMPServer, SMP.SenderId) (TQueue InternalId),
smpQueueMsgDeliveries :: TMap (SMPServer, SMP.SenderId) (Async ()),
connCmdsQueued :: TMap ConnId Bool,
asyncCmdQueues :: TMap (Maybe SMPServer) (TQueue AsyncCmdId),
asyncCmdProcesses :: TMap (Maybe SMPServer) (Async ()),
ntfNetworkOp :: TVar AgentOpState,
rcvNetworkOp :: TVar AgentOpState,
msgDeliveryOp :: TVar AgentOpState,
@@ -207,13 +210,15 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
ntfServers <- newTVar ntf
ntfClients <- TM.empty
useNetworkConfig <- newTVar netCfg
subscrSrvrs <- TM.empty
pendingSubscrSrvrs <- TM.empty
subscrConns <- newTVar S.empty
activeSubscrConns <- TM.empty
activeSubs <- TM2.empty
pendingSubs <- TM2.empty
connMsgsQueued <- TM.empty
smpQueueMsgQueues <- TM.empty
smpQueueMsgDeliveries <- TM.empty
connCmdsQueued <- TM.empty
asyncCmdQueues <- TM.empty
asyncCmdProcesses <- TM.empty
ntfNetworkOp <- newTVar $ AgentOpState False 0
rcvNetworkOp <- newTVar $ AgentOpState False 0
msgDeliveryOp <- newTVar $ AgentOpState False 0
@@ -225,7 +230,7 @@ newAgentClient InitialAgentServers {smp, ntf, netCfg} agentEnv = do
asyncClients <- newTVar []
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
lock <- newTMVar ()
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrSrvrs, pendingSubscrSrvrs, subscrConns, activeSubscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, useNetworkConfig, subscrConns, activeSubs, pendingSubs, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, connCmdsQueued, asyncCmdQueues, asyncCmdProcesses, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
agentStore :: AgentClient -> SQLiteStore
agentStore AgentClient {agentEnv = Env {store}} = store
@@ -264,26 +269,18 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
TM2.lookupDelete1 srv (activeSubs c) >>= mapM updateSubs
where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (activeSubscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case
Just v -> TM.union cs v
_ -> TM.insert srv cVar ps
TM2.insert1 srv cVar $ pendingSubs c
readTVar cVar
serverDown :: Map ConnId RcvQueue -> IO ()
serverDown cs = unless (M.null cs) $
whenM (readTVarIO active) $ do
let conns = M.keys cs
notifySub "" $ hostEvent DISCONNECT client
unless (null conns) . notifySub "" $ DOWN srv conns
serverDown cs = whenM (readTVarIO active) $ do
notifySub "" $ hostEvent DISCONNECT client
let conns = M.keys cs
unless (null conns) $ do
notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
@@ -301,7 +298,7 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
reconnectClient :: m ()
reconnectClient =
withAgentLock c $
atomically (TM.lookup srv (pendingSubscrSrvrs c) >>= mapM readTVar)
atomically (TM2.lookup1 srv (pendingSubs c) >>= mapM readTVar)
>>= mapM_ resubscribe
where
resubscribe :: Map ConnId RcvQueue -> m ()
@@ -407,12 +404,14 @@ closeAgentClient c = liftIO $ do
cancelActions $ reconnections c
cancelActions $ asyncClients c
cancelActions $ smpQueueMsgDeliveries c
clear subscrSrvrs
clear pendingSubscrSrvrs
cancelActions $ asyncCmdProcesses c
atomically . TM2.clear $ activeSubs c
atomically . TM2.clear $ pendingSubs c
clear subscrConns
clear activeSubscrConns
clear connMsgsQueued
clear smpQueueMsgQueues
clear connCmdsQueued
clear asyncCmdQueues
clear getMsgLocks
where
clear :: Monoid m => (AgentClient -> TVar m) -> IO ()
@@ -514,17 +513,17 @@ subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
addPendingSubscription c rq connId
TM2.insert server connId rq $ pendingSubs c
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
>>= either throwError pure
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq@RcvQueue {server} connId r = do
processSubResult c rq connId r = do
case r of
Left e ->
atomically . unless (temporaryClientError e) $
removePendingSubscription c server connId
TM2.delete connId (pendingSubs c)
_ -> addSubscription c rq connId
pure r
@@ -544,9 +543,9 @@ temporaryAgentError = \case
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ \q -> atomically $ do
modifyTVar (subscrConns c) . S.insert $ fst q
uncurry (addPendingSubscription c) $ swap q
forM_ qs_ $ \(connId, rq@RcvQueue {server}) -> atomically $ do
modifyTVar (subscrConns c) $ S.insert connId
TM2.insert server connId rq $ pendingSubs c
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
@@ -568,35 +567,18 @@ subscribeQueues c srv qs = do
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
TM.insert connId server $ activeSubscrConns c
modifyTVar (subscrConns c) $ S.insert connId
addSubs_ (subscrSrvrs c) rq connId
removePendingSubscription c server connId
TM2.insert server connId rq $ activeSubs c
TM2.delete connId $ pendingSubs c
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
hasActiveSubscription c connId = TM.member connId (activeSubscrConns c)
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
addSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> RcvQueue -> ConnId -> STM ()
addSubs_ ss rq@RcvQueue {server} connId =
TM.lookup server ss >>= \case
Just m -> TM.insert connId rq m
_ -> TM.singleton connId rq >>= \m -> TM.insert server m ss
hasActiveSubscription c connId = TM2.member connId $ activeSubs c
removeSubscription :: AgentClient -> ConnId -> STM ()
removeSubscription c connId = do
modifyTVar (subscrConns c) $ S.delete connId
server_ <- TM.lookupDelete connId $ activeSubscrConns c
mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_
removePendingSubscription :: AgentClient -> SMPServer -> ConnId -> STM ()
removePendingSubscription = removeSubs_ . pendingSubscrSrvrs
removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> STM ()
removeSubs_ ss server connId =
TM.lookup server ss >>= mapM_ (TM.delete connId)
TM2.delete connId $ activeSubs c
TM2.delete connId $ pendingSubs c
getSubscriptions :: AgentClient -> STM (Set ConnId)
getSubscriptions = readTVar . subscrConns
+209 -59
View File
@@ -40,9 +40,13 @@ module Simplex.Messaging.Agent.Protocol
-- * SMP agent protocol types
ConnInfo,
ACommand (..),
ACommandTag (..),
aCommandTag,
ACmd (..),
ACmdTag (..),
AParty (..),
SAParty (..),
APartyI (..),
MsgHash,
MsgMeta (..),
ConnectionStats (..),
@@ -92,6 +96,8 @@ module Simplex.Messaging.Agent.Protocol
serializeCommand,
connMode,
connMode',
networkCommandP,
dbCommandP,
commandP,
connModeT,
serializeQueueStatus,
@@ -117,7 +123,6 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.))
import Data.Functor (($>))
import Data.Int (Int64)
import Data.Kind (Type)
@@ -211,6 +216,12 @@ instance TestEquality SAParty where
testEquality SClient SClient = Just Refl
testEquality _ _ = Nothing
class APartyI (p :: AParty) where sAParty :: SAParty p
instance APartyI Agent where sAParty = SAgent
instance APartyI Client where sAParty = SClient
data ACmd = forall p. ACmd (SAParty p) (ACommand p)
deriving instance Show ACmd
@@ -219,9 +230,9 @@ type ConnInfo = ByteString
-- | Parameterized type for SMP agent protocol commands and responses from all participants.
data ACommand (p :: AParty) where
NEW :: AConnectionMode -> ACommand Client -- response INV
NEW :: Bool -> AConnectionMode -> ACommand Client -- response INV
INV :: AConnectionRequestUri -> ACommand Agent
JOIN :: AConnectionRequestUri -> ConnInfo -> ACommand Client -- response OK
JOIN :: Bool -> AConnectionRequestUri -> ConnInfo -> ACommand Client -- response OK
CONF :: ConfirmationId -> [SMPServer] -> ConnInfo -> ACommand Agent -- ConnInfo is from sender, [SMPServer] will be empty only in v1 handshake
LET :: ConfirmationId -> ConnInfo -> ACommand Client -- ConnInfo is from client
REQ :: InvitationId -> L.NonEmpty SMPServer -> ConnInfo -> ACommand Agent -- ConnInfo is from sender
@@ -253,6 +264,75 @@ deriving instance Eq (ACommand p)
deriving instance Show (ACommand p)
data ACmdTag = forall p. APartyI p => ACmdTag (SAParty p) (ACommandTag p)
data ACommandTag (p :: AParty) where
NEW_ :: ACommandTag Client
INV_ :: ACommandTag Agent
JOIN_ :: ACommandTag Client
CONF_ :: ACommandTag Agent
LET_ :: ACommandTag Client
REQ_ :: ACommandTag Agent
ACPT_ :: ACommandTag Client
RJCT_ :: ACommandTag Client
INFO_ :: ACommandTag Agent
CON_ :: ACommandTag Agent
SUB_ :: ACommandTag Client
END_ :: ACommandTag Agent
CONNECT_ :: ACommandTag Agent
DISCONNECT_ :: ACommandTag Agent
DOWN_ :: ACommandTag Agent
UP_ :: ACommandTag Agent
SEND_ :: ACommandTag Client
MID_ :: ACommandTag Agent
SENT_ :: ACommandTag Agent
MERR_ :: ACommandTag Agent
MSG_ :: ACommandTag Agent
ACK_ :: ACommandTag Client
OFF_ :: ACommandTag Client
DEL_ :: ACommandTag Client
CHK_ :: ACommandTag Client
STAT_ :: ACommandTag Agent
OK_ :: ACommandTag Agent
ERR_ :: ACommandTag Agent
SUSPENDED_ :: ACommandTag Agent
deriving instance Eq (ACommandTag p)
deriving instance Show (ACommandTag p)
aCommandTag :: ACommand p -> ACommandTag p
aCommandTag = \case
NEW {} -> NEW_
INV _ -> INV_
JOIN {} -> JOIN_
CONF {} -> CONF_
LET {} -> LET_
REQ {} -> REQ_
ACPT {} -> ACPT_
RJCT _ -> RJCT_
INFO _ -> INFO_
CON -> CON_
SUB -> SUB_
END -> END_
CONNECT {} -> CONNECT_
DISCONNECT {} -> DISCONNECT_
DOWN {} -> DOWN_
UP {} -> UP_
SEND {} -> SEND_
MID _ -> MID_
SENT _ -> SENT_
MERR {} -> MERR_
MSG {} -> MSG_
ACK _ -> ACK_
OFF -> OFF_
DEL -> DEL_
CHK -> CHK_
STAT _ -> STAT_
OK -> OK_
ERR _ -> ERR_
SUSPENDED -> SUSPENDED_
data ConnectionStats = ConnectionStats
{ rcvServers :: [SMPServer],
sndServers :: [SMPServer]
@@ -920,58 +1000,129 @@ instance Arbitrary BrokerErrorType where arbitrary = genericArbitraryU
instance Arbitrary SMPAgentError where arbitrary = genericArbitraryU
-- | SMP agent command and response parser for commands passed via network (only parses binary length)
networkCommandP :: Parser ACmd
networkCommandP = commandP A.takeByteString
-- | SMP agent command and response parser for commands stored in db (fully parses binary bodies)
dbCommandP :: Parser ACmd
dbCommandP = commandP $ A.take =<< (A.decimal <* "\n")
instance Encoding ACmdTag where
smpEncode (ACmdTag _ cmd) = smpEncode cmd
smpP =
A.takeTill (== ' ') >>= \case
"NEW" -> pure $ ACmdTag SClient NEW_
"INV" -> pure $ ACmdTag SAgent INV_
"JOIN" -> pure $ ACmdTag SClient JOIN_
"CONF" -> pure $ ACmdTag SAgent CONF_
"LET" -> pure $ ACmdTag SClient LET_
"REQ" -> pure $ ACmdTag SAgent REQ_
"ACPT" -> pure $ ACmdTag SClient ACPT_
"RJCT" -> pure $ ACmdTag SClient RJCT_
"INFO" -> pure $ ACmdTag SAgent INFO_
"CON" -> pure $ ACmdTag SAgent CON_
"SUB" -> pure $ ACmdTag SClient SUB_
"END" -> pure $ ACmdTag SAgent END_
"CONNECT" -> pure $ ACmdTag SAgent CONNECT_
"DISCONNECT" -> pure $ ACmdTag SAgent DISCONNECT_
"DOWN" -> pure $ ACmdTag SAgent DOWN_
"UP" -> pure $ ACmdTag SAgent UP_
"SEND" -> pure $ ACmdTag SClient SEND_
"MID" -> pure $ ACmdTag SAgent MID_
"SENT" -> pure $ ACmdTag SAgent SENT_
"MERR" -> pure $ ACmdTag SAgent MERR_
"MSG" -> pure $ ACmdTag SAgent MSG_
"ACK" -> pure $ ACmdTag SClient ACK_
"OFF" -> pure $ ACmdTag SClient OFF_
"DEL" -> pure $ ACmdTag SClient DEL_
"CHK" -> pure $ ACmdTag SClient CHK_
"STAT" -> pure $ ACmdTag SAgent STAT_
"OK" -> pure $ ACmdTag SAgent OK_
"ERR" -> pure $ ACmdTag SAgent ERR_
"SUSPENDED" -> pure $ ACmdTag SAgent SUSPENDED_
_ -> fail "bad ACmdTag"
instance APartyI p => Encoding (ACommandTag p) where
smpEncode = \case
NEW_ -> "NEW"
INV_ -> "INV"
JOIN_ -> "JOIN"
CONF_ -> "CONF"
LET_ -> "LET"
REQ_ -> "REQ"
ACPT_ -> "ACPT"
RJCT_ -> "RJCT"
INFO_ -> "INFO"
CON_ -> "CON"
SUB_ -> "SUB"
END_ -> "END"
CONNECT_ -> "CONNECT"
DISCONNECT_ -> "DISCONNECT"
DOWN_ -> "DOWN"
UP_ -> "UP"
SEND_ -> "SEND"
MID_ -> "MID"
SENT_ -> "SENT"
MERR_ -> "MERR"
MSG_ -> "MSG"
ACK_ -> "ACK"
OFF_ -> "OFF"
DEL_ -> "DEL"
CHK_ -> "CHK"
STAT_ -> "STAT"
OK_ -> "OK"
ERR_ -> "ERR"
SUSPENDED_ -> "SUSPENDED"
smpP = (\(ACmdTag _ t) -> checkParty t) <$?> smpP
checkParty :: forall t p p'. (APartyI p, APartyI p') => t p' -> Either String (t p)
checkParty x = case testEquality (sAParty @p) (sAParty @p') of
Just Refl -> Right x
Nothing -> Left "bad party"
-- | SMP agent command and response parser
commandP :: Parser ACmd
commandP =
"NEW " *> newCmd
<|> "INV " *> invResp
<|> "JOIN " *> joinCmd
<|> "CONF " *> confMsg
<|> "LET " *> letCmd
<|> "REQ " *> reqMsg
<|> "ACPT " *> acptCmd
<|> "RJCT " *> rjctCmd
<|> "INFO " *> infoCmd
<|> "SUB" $> ACmd SClient SUB
<|> "END" $> ACmd SAgent END
<|> "CONNECT " *> connectResp
<|> "DISCONNECT " *> disconnectResp
<|> "DOWN " *> downResp
<|> "UP " *> upResp
<|> "SEND " *> sendCmd
<|> "MID " *> msgIdResp
<|> "SENT " *> sentResp
<|> "MERR " *> msgErrResp
<|> "MSG " *> message
<|> "ACK " *> ackCmd
<|> "OFF" $> ACmd SClient OFF
<|> "DEL" $> ACmd SClient DEL
<|> "CHK" $> ACmd SClient CHK
<|> "STAT " *> statResp
<|> "ERR " *> agentError
<|> "CON" $> ACmd SAgent CON
<|> "OK" $> ACmd SAgent OK
commandP :: Parser ByteString -> Parser ACmd
commandP binaryP =
smpP
>>= \case
ACmdTag SClient cmd ->
ACmd SClient <$> case cmd of
NEW_ -> s (NEW <$> strP_ <*> strP)
JOIN_ -> s (JOIN <$> strP_ <*> strP_ <*> binaryP)
LET_ -> s (LET <$> A.takeTill (== ' ') <* A.space <*> binaryP)
ACPT_ -> s (ACPT <$> A.takeTill (== ' ') <* A.space <*> binaryP)
RJCT_ -> s (RJCT <$> A.takeByteString)
SUB_ -> pure SUB
SEND_ -> s (SEND <$> smpP <* A.space <*> binaryP)
ACK_ -> s (ACK <$> A.decimal)
OFF_ -> pure OFF
DEL_ -> pure DEL
CHK_ -> pure CHK
ACmdTag SAgent cmd ->
ACmd SAgent <$> case cmd of
INV_ -> s (INV <$> strP)
CONF_ -> s (CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> binaryP)
REQ_ -> s (REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> binaryP)
INFO_ -> s (INFO <$> binaryP)
CON_ -> pure CON
END_ -> pure END
CONNECT_ -> s (CONNECT <$> strP_ <*> strP)
DISCONNECT_ -> s (DISCONNECT <$> strP_ <*> strP)
DOWN_ -> s (DOWN <$> strP_ <*> connections)
UP_ -> s (UP <$> strP_ <*> connections)
MID_ -> s (MID <$> A.decimal)
SENT_ -> s (SENT <$> A.decimal)
MERR_ -> s (MERR <$> A.decimal <* A.space <*> strP)
MSG_ -> s (MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> binaryP)
STAT_ -> s (STAT <$> strP)
OK_ -> pure OK
ERR_ -> s (ERR <$> strP)
SUSPENDED_ -> pure SUSPENDED
where
newCmd = ACmd SClient . NEW <$> strP
invResp = ACmd SAgent . INV <$> strP
joinCmd = ACmd SClient .: JOIN <$> strP_ <*> A.takeByteString
confMsg = ACmd SAgent .:. CONF <$> A.takeTill (== ' ') <* A.space <*> strListP <* A.space <*> A.takeByteString
letCmd = ACmd SClient .: LET <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
reqMsg = ACmd SAgent .:. REQ <$> A.takeTill (== ' ') <* A.space <*> strP_ <*> A.takeByteString
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
infoCmd = ACmd SAgent . INFO <$> A.takeByteString
connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP
disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP
downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
upResp = ACmd SAgent .: UP <$> strP_ <*> connections
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
msgIdResp = ACmd SAgent . MID <$> A.decimal
sentResp = ACmd SAgent . SENT <$> A.decimal
msgErrResp = ACmd SAgent .: MERR <$> A.decimal <* A.space <*> strP
message = ACmd SAgent .:. MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> A.takeByteString
ackCmd = ACmd SClient . ACK <$> A.decimal
statResp = ACmd SAgent . STAT <$> strP
s :: Parser a -> Parser a
s p = A.space *> p
connections :: Parser [ConnId]
connections = strP `A.sepBy'` A.char ','
msgMetaP = do
integrity <- strP
@@ -980,17 +1131,16 @@ commandP =
sndMsgId <- " S=" *> A.decimal
pure MsgMeta {integrity, recipient, broker, sndMsgId}
partyMeta idParser = (,) <$> idParser <* A.char ',' <*> tsISO8601P
agentError = ACmd SAgent . ERR <$> strP
parseCommand :: ByteString -> Either AgentErrorType ACmd
parseCommand = parse commandP $ CMD SYNTAX
parseCommand = parse (commandP A.takeByteString) $ CMD SYNTAX
-- | Serialize SMP agent command.
serializeCommand :: ACommand p -> ByteString
serializeCommand = \case
NEW cMode -> "NEW " <> strEncode cMode
NEW ntfs cMode -> B.unwords ["NEW", strEncode ntfs, strEncode cMode]
INV cReq -> "INV " <> strEncode cReq
JOIN cReq cInfo -> B.unwords ["JOIN", strEncode cReq, serializeBinary cInfo]
JOIN ntfs cReq cInfo -> B.unwords ["JOIN", strEncode ntfs, strEncode cReq, serializeBinary cInfo]
CONF confId srvs cInfo -> B.unwords ["CONF", confId, strEncodeList srvs, serializeBinary cInfo]
LET confId cInfo -> B.unwords ["LET", confId, serializeBinary cInfo]
REQ invId srvs cInfo -> B.unwords ["REQ", invId, strEncode srvs, serializeBinary cInfo]
@@ -1068,7 +1218,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
tConnId :: ARawTransmission -> ACommand p -> Either AgentErrorType (ACommand p)
tConnId (_, connId, _) cmd = case cmd of
-- NEW, JOIN and ACPT have optional connId
NEW _ -> Right cmd
NEW _ _ -> Right cmd
JOIN {} -> Right cmd
ACPT {} -> Right cmd
-- ERROR response does not always have connId
@@ -1086,7 +1236,7 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
cmdWithMsgBody = \case
SEND msgFlags body -> SEND msgFlags <$$> getBody body
MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body
JOIN qUri cInfo -> JOIN qUri <$$> getBody cInfo
JOIN ntfs qUri cInfo -> JOIN ntfs qUri <$$> getBody cInfo
CONF confId srvs cInfo -> CONF confId srvs <$$> getBody cInfo
LET confId cInfo -> LET confId <$$> getBody cInfo
REQ invId srvs cInfo -> REQ invId srvs <$$> getBody cInfo
+8 -1
View File
@@ -92,7 +92,7 @@ data SndQueue = SndQueue
-- * Connection types
-- | Type of a connection.
data ConnType = CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
data ConnType = CNew | CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
-- | Connection of a specific type.
--
@@ -105,6 +105,7 @@ data ConnType = CRcv | CSnd | CDuplex | CContact deriving (Eq, Show)
-- - DuplexConnection is a connection that has both receive and send queues set up,
-- typically created by upgrading a receive or a send connection with a missing queue.
data Connection (d :: ConnType) where
NewConnection :: ConnData -> Connection CNew
RcvConnection :: ConnData -> RcvQueue -> Connection CRcv
SndConnection :: ConnData -> SndQueue -> Connection CSnd
DuplexConnection :: ConnData -> RcvQueue -> SndQueue -> Connection CDuplex
@@ -115,12 +116,14 @@ deriving instance Eq (Connection d)
deriving instance Show (Connection d)
data SConnType :: ConnType -> Type where
SCNew :: SConnType CNew
SCRcv :: SConnType CRcv
SCSnd :: SConnType CSnd
SCDuplex :: SConnType CDuplex
SCContact :: SConnType CContact
connType :: SConnType c -> ConnType
connType SCNew = CNew
connType SCRcv = CRcv
connType SCSnd = CSnd
connType SCDuplex = CDuplex
@@ -272,6 +275,8 @@ newtype InternalId = InternalId {unId :: Int64} deriving (Eq, Show)
type InternalTs = UTCTime
type AsyncCmdId = Int64
-- * Store errors
-- | Agent store error.
@@ -293,6 +298,8 @@ data StoreError
SEInvitationNotFound
| -- | Message not found
SEMsgNotFound
| -- | Command not found
SECmdNotFound
| -- | Currently not used. The intention was to pass current expected queue status in methods,
-- as we always know what it should be at any stage of the protocol,
-- and in case it does not match use this error.
+95 -1
View File
@@ -26,6 +26,9 @@ module Simplex.Messaging.Agent.Store.SQLite
sqlString,
-- * Queues and connections
createNewConn,
updateNewConnRcv,
updateNewConnSnd,
createRcvConn,
createSndConn,
getConn,
@@ -68,6 +71,11 @@ module Simplex.Messaging.Agent.Store.SQLite
getRatchet,
getSkippedMsgKeys,
updateRatchet,
-- Async commands
createCommand,
getPendingCommands,
getPendingCommand,
deleteCommand,
-- Notification device token persistence
createNtfToken,
getSavedNtfToken,
@@ -107,8 +115,10 @@ import Data.Bifunctor (second)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Base64.URL as U
import Data.Char (toLower)
import Data.Function (on)
import Data.Functor (($>))
import Data.List (foldl')
import Data.Int (Int64)
import Data.List (find, foldl', groupBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, listToMaybe)
@@ -265,6 +275,37 @@ createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of
ConnData {connId = ""} -> createWithRandomId gVar create
ConnData {connId} -> create connId $> Right connId
createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId)
createNewConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} cMode =
createConn_ gVar cData $ \connId -> do
DB.execute db "INSERT INTO connections (conn_id, conn_mode, smp_agent_version, enable_ntfs, duplex_handshake) VALUES (?, ?, ?, ?, ?)" (connId, cMode, connAgentVersion, enableNtfs, duplexHandshake)
updateNewConnRcv :: DB.Connection -> ConnId -> RcvQueue -> IO (Either StoreError ())
updateNewConnRcv db connId rq@RcvQueue {server} =
getConn db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn _ RcvConnection {}) -> updateConn -- to allow retries
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
where
updateConn :: IO (Either StoreError ())
updateConn = do
upsertServer_ db server
insertRcvQueue_ db connId rq
pure $ Right ()
updateNewConnSnd :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError ())
updateNewConnSnd db connId sq@SndQueue {server} =
getConn db connId $>>= \case
(SomeConn _ NewConnection {}) -> updateConn
(SomeConn _ SndConnection {}) -> updateConn -- to allow retries
(SomeConn c _) -> pure . Left . SEBadConnType $ connType c
where
updateConn :: IO (Either StoreError ())
updateConn = do
upsertServer_ db server
insertSndQueue_ db connId sq
pure $ Right ()
createRcvConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> IO (Either StoreError ConnId)
createRcvConn db gVar cData@ConnData {connAgentVersion, enableNtfs, duplexHandshake} q@RcvQueue {server} cMode =
createConn_ gVar cData $ \connId -> do
@@ -661,6 +702,50 @@ updateRatchet db connId rc skipped = do
forM_ (M.assocs mks) $ \(msgN, mk) ->
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> ACommand 'Client -> IO AsyncCmdId
createCommand db corrId connId srv cmd = do
DB.execute
db
"INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command) VALUES (?,?,?,?,?,?)"
(host_, port_, corrId, connId, aCommandTag cmd, cmd)
insertedRowId db
where
(host_, port_) =
case srv of
Just (SMPServer host port _) -> (Just host, Just port)
_ -> (Nothing, Nothing)
insertedRowId :: DB.Connection -> IO Int64
insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()"
getPendingCommands :: DB.Connection -> ConnId -> IO [(Maybe SMPServer, [AsyncCmdId])]
getPendingCommands db connId = do
map (\ids -> (fst $ head ids, map snd ids)) . groupBy ((==) `on` fst) . map srvCmdId
<$> DB.query
db
[sql|
SELECT c.host, c.port, s.key_hash, c.command_id
FROM commands c
LEFT JOIN servers s ON s.host = c.host AND s.port = c.port
WHERE conn_id = ?
ORDER BY c.host, c.port, c.command_id ASC
|]
(Only connId)
where
srvCmdId (host, port, keyHash, cmdId) = (SMPServer <$> host <*> port <*> keyHash, cmdId)
getPendingCommand :: DB.Connection -> AsyncCmdId -> IO (Either StoreError (ACorrId, ConnId, ACmd))
getPendingCommand db msgId = do
firstRow id SECmdNotFound $
DB.query
db
"SELECT corr_id, conn_id, command FROM commands WHERE command_id = ?"
(Only msgId)
deleteCommand :: DB.Connection -> AsyncCmdId -> IO ()
deleteCommand db cmdId =
DB.execute db "DELETE FROM commands WHERE command_id = ?" (Only cmdId)
createNtfToken :: DB.Connection -> NtfToken -> IO ()
createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do
upsertNtfServer_ db srv
@@ -1023,6 +1108,14 @@ instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1
instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
instance ToField (ACommand p) where toField = toField . serializeCommand
instance FromField ACmd where fromField = blobFieldParser dbCommandP
instance APartyI p => ToField (ACommandTag p) where toField = toField . smpEncode
instance FromField ACmdTag where fromField = blobFieldParser smpP
listToEither :: e -> [a] -> Either e a
listToEither _ (x : _) = Right x
listToEither e _ = Left e
@@ -1130,6 +1223,7 @@ getConn dbConn connId =
(Just rcvQ, Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection connData rcvQ)
(Nothing, Just sndQ, CMInvitation) -> Right $ SomeConn SCSnd (SndConnection connData sndQ)
(Just rcvQ, Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection connData rcvQ)
(Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection connData)
_ -> Left SEConnNotFound
getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode))
@@ -36,6 +36,7 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Transport.Client (TransportHost)
@@ -50,7 +51,8 @@ schemaMigrations =
("20220607_v2", m20220608_v2),
("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode),
("m20220811_onion_hosts", m20220811_onion_hosts),
("m20220817_connection_ntfs", m20220817_connection_ntfs)
("m20220817_connection_ntfs", m20220817_connection_ntfs),
("m20220905_commands", m20220905_commands)
]
-- | The list of migrations in ascending order by date
@@ -0,0 +1,23 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20220905_commands :: Query
m20220905_commands =
[sql|
CREATE TABLE commands (
command_id INTEGER PRIMARY KEY,
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
host TEXT,
port TEXT,
corr_id BLOB NOT NULL,
command_tag BLOB NOT NULL,
command BLOB NOT NULL,
agent_version INTEGER NOT NULL DEFAULT 1,
FOREIGN KEY (host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
);
|]
@@ -194,3 +194,15 @@ CREATE TABLE ntf_subscriptions(
FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers
ON DELETE RESTRICT ON UPDATE CASCADE
) WITHOUT ROWID;
CREATE TABLE commands(
command_id INTEGER PRIMARY KEY,
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
host TEXT,
port TEXT,
corr_id BLOB NOT NULL,
command_tag BLOB NOT NULL,
command BLOB NOT NULL,
agent_version INTEGER NOT NULL DEFAULT 1,
FOREIGN KEY(host, port) REFERENCES servers
ON DELETE RESTRICT ON UPDATE CASCADE
);
+5
View File
@@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
( TMap,
empty,
singleton,
clear,
Simplex.Messaging.TMap.null,
Simplex.Messaging.TMap.lookup,
member,
@@ -31,6 +32,10 @@ singleton :: k -> a -> STM (TMap k a)
singleton k v = newTVar $ M.singleton k v
{-# INLINE singleton #-}
clear :: TMap k a -> STM ()
clear m = writeTVar m M.empty
{-# INLINE clear #-}
null :: TMap k a -> STM Bool
null m = M.null <$> readTVar m
{-# INLINE null #-}
+82
View File
@@ -0,0 +1,82 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
module Simplex.Messaging.TMap2
( TMap2,
empty,
clear,
Simplex.Messaging.TMap2.lookup,
lookup1,
member,
insert,
insert1,
delete,
lookupDelete1,
)
where
import Control.Concurrent.STM
import Control.Monad (forM_, (>=>))
import qualified Data.Map.Strict as M
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (whenM, ($>>=))
-- | this type is designed for k2 being unique in the whole data, and k1 grouping multiple values with k2 keys.
-- It allows direct access via k1 to a group of k2 values and via k2 to one value
data TMap2 k1 k2 a = TMap2
{ _m1 :: TMap k1 (TMap k2 a),
_m2 :: TMap k2 k1
}
empty :: STM (TMap2 k1 k2 a)
empty = TMap2 <$> TM.empty <*> TM.empty
clear :: TMap2 k1 k2 a -> STM ()
clear TMap2 {_m1, _m2} = TM.clear _m1 >> TM.clear _m2
lookup :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM (Maybe a)
lookup k2 TMap2 {_m1, _m2} = do
TM.lookup k2 _m2 $>>= (`TM.lookup` _m1) $>>= TM.lookup k2
lookup1 :: Ord k1 => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookup1 k1 TMap2 {_m1} = TM.lookup k1 _m1
{-# INLINE lookup1 #-}
member :: Ord k2 => k2 -> TMap2 k1 k2 a -> STM Bool
member k2 TMap2 {_m2} = TM.member k2 _m2
{-# INLINE member #-}
insert :: (Ord k1, Ord k2) => k1 -> k2 -> a -> TMap2 k1 k2 a -> STM ()
insert k1 k2 v TMap2 {_m1, _m2} =
TM.lookup k2 _m2 >>= \case
Just k1'
| k1 == k1' -> _insert1
| otherwise -> _delete1 k1' k2 _m1 >> _insert2
_ -> _insert2
where
_insert1 =
TM.lookup k1 _m1 >>= \case
Just m -> TM.insert k2 v m
_ -> TM.singleton k2 v >>= \m -> TM.insert k1 m _m1
_insert2 = TM.insert k2 k1 _m2 >> _insert1
insert1 :: (Ord k1, Ord k2) => k1 -> TMap k2 a -> TMap2 k1 k2 a -> STM ()
insert1 k1 m' TMap2 {_m1, _m2} =
TM.lookup k1 _m1 >>= \case
Just m -> readTVar m' >>= (`TM.union` m)
_ -> TM.insert k1 m' _m1
delete :: (Ord k1, Ord k2) => k2 -> TMap2 k1 k2 a -> STM ()
delete k2 TMap2 {_m1, _m2} = TM.lookupDelete k2 _m2 >>= mapM_ (\k1 -> _delete1 k1 k2 _m1)
_delete1 :: (Ord k1, Ord k2) => k1 -> k2 -> TMap k1 (TMap k2 a) -> STM ()
_delete1 k1 k2 m1 =
TM.lookup k1 m1
>>= mapM_ (\m -> TM.delete k2 m >> whenM (TM.null m) (TM.delete k1 m1))
lookupDelete1 :: (Ord k1, Ord k2) => k1 -> TMap2 k1 k2 a -> STM (Maybe (TMap k2 a))
lookupDelete1 k1 TMap2 {_m1, _m2} = do
m_ <- TM.lookupDelete k1 _m1
forM_ m_ $ readTVar >=> modifyTVar' _m2 . flip M.withoutKeys . M.keysSet
pure m_
+21 -21
View File
@@ -131,9 +131,9 @@ pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
testDuplexConnection _ alice bob = do
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW INV")
("1", "bob", Right (INV cReq)) <- alice #: ("1", "bob", "NEW T INV")
let cReq' = strEncode cReq
bob #: ("11", "alice", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
("", "bob", Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
alice #: ("2", "bob", "LET " <> confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
bob <# ("", "alice", INFO "alice's connInfo")
@@ -164,9 +164,9 @@ testDuplexConnection _ alice bob = do
testDuplexConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
testDuplexConnRandomIds _ alice bob = do
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW INV")
("1", bobConn, Right (INV cReq)) <- alice #: ("1", "", "NEW T INV")
let cReq' = strEncode cReq
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> cReq' <> " 14\nbob's connInfo")
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " 14\nbob's connInfo")
("", bobConn', Right (CONF confId _ "bob's connInfo")) <- (alice <#:)
bobConn' `shouldBe` bobConn
alice #: ("2", bobConn, "LET " <> confId <> " 16\nalice's connInfo") =#> \case ("2", c, OK) -> c == bobConn; _ -> False
@@ -197,10 +197,10 @@ testDuplexConnRandomIds _ alice bob = do
testContactConnection :: Transport c => TProxy c -> c -> c -> c -> IO ()
testContactConnection _ alice bob tom = do
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW CON")
("1", "alice_contact", Right (INV cReq)) <- alice #: ("1", "alice_contact", "NEW T CON")
let cReq' = strEncode cReq
bob #: ("11", "alice", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice", OK)
("", "alice_contact", Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
alice #: ("2", "bob", "ACPT " <> aInvId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
("", "alice", Right (CONF bConfId _ "alice's connInfo")) <- (bob <#:)
@@ -213,7 +213,7 @@ testContactConnection _ alice bob tom = do
bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
tom #: ("21", "alice", "JOIN " <> cReq' <> " 14\ntom's connInfo") #> ("21", "alice", OK)
tom #: ("21", "alice", "JOIN T " <> cReq' <> " 14\ntom's connInfo") #> ("21", "alice", OK)
("", "alice_contact", Right (REQ aInvId' _ "tom's connInfo")) <- (alice <#:)
alice #: ("4", "tom", "ACPT " <> aInvId' <> " 16\nalice's connInfo") #> ("4", "tom", OK)
("", "alice", Right (CONF tConfId _ "alice's connInfo")) <- (tom <#:)
@@ -228,10 +228,10 @@ testContactConnection _ alice bob tom = do
testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
testContactConnRandomIds _ alice bob = do
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW CON")
("1", aliceContact, Right (INV cReq)) <- alice #: ("1", "", "NEW T CON")
let cReq' = strEncode cReq
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN " <> cReq' <> " 14\nbob's connInfo")
("11", aliceConn, Right OK) <- bob #: ("11", "", "JOIN T " <> cReq' <> " 14\nbob's connInfo")
("", aliceContact', Right (REQ aInvId _ "bob's connInfo")) <- (alice <#:)
aliceContact' `shouldBe` aliceContact
@@ -251,9 +251,9 @@ testContactConnRandomIds _ alice bob = do
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
testRejectContactRequest _ alice bob = do
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW CON")
("1", "a_contact", Right (INV cReq)) <- alice #: ("1", "a_contact", "NEW T CON")
let cReq' = strEncode cReq
bob #: ("11", "alice", "JOIN " <> cReq' <> " 10\nbob's info") #> ("11", "alice", OK)
bob #: ("11", "alice", "JOIN T " <> cReq' <> " 10\nbob's info") #> ("11", "alice", OK)
("", "a_contact", Right (REQ aInvId _ "bob's info")) <- (alice <#:)
-- RJCT must use correct contact connection
alice #: ("2a", "bob", "RJCT " <> aInvId) #> ("2a", "bob", ERR $ CONN NOT_FOUND)
@@ -282,7 +282,7 @@ testSubscription _ alice1 alice2 bob = do
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
testSubscrNotification t (server, _) client = do
client #: ("1", "conn1", "NEW INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
client #: ("1", "conn1", "NEW T INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
client #:# "nothing should be delivered to client before the server is killed"
killThread server
client <# ("", "", DOWN testSMPServer ["conn1"])
@@ -392,9 +392,9 @@ testConcurrentMsgDelivery :: Transport c => TProxy c -> c -> c -> IO ()
testConcurrentMsgDelivery _ alice bob = do
connect (alice, "alice") (bob, "bob")
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW INV")
("1", "bob2", Right (INV cReq)) <- alice #: ("1", "bob2", "NEW T INV")
let cReq' = strEncode cReq
bob #: ("11", "alice2", "JOIN " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice2", OK)
bob #: ("11", "alice2", "JOIN T " <> cReq' <> " 14\nbob's connInfo") #> ("11", "alice2", OK)
("", "bob2", Right (CONF _confId _ "bob's connInfo")) <- (alice <#:)
-- below commands would be needed to accept bob's connection, but alice does not
-- alice #: ("2", "bob", "LET " <> _confId <> " 16\nalice's connInfo") #> ("2", "bob", OK)
@@ -431,9 +431,9 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
connect (h1, name1) (h2, name2) = do
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW INV")
("c1", _, Right (INV cReq)) <- h1 #: ("c1", name2, "NEW T INV")
let cReq' = strEncode cReq
h2 #: ("c2", name1, "JOIN " <> cReq' <> " 5\ninfo2") #> ("c2", name1, OK)
h2 #: ("c2", name1, "JOIN T " <> cReq' <> " 5\ninfo2") #> ("c2", name1, OK)
("", _, Right (CONF connId _ "info2")) <- (h1 <#:)
h1 #: ("c3", name2, "LET " <> connId <> " 5\ninfo1") #> ("c3", name2, OK)
h2 <# ("", name1, INFO "info1")
@@ -452,9 +452,9 @@ sendMessage (h1, name1) (h2, name2) msg = do
-- connect' :: forall c. Transport c => c -> c -> IO (ByteString, ByteString)
-- connect' h1 h2 = do
-- ("c1", conn2, Right (INV cReq)) <- h1 #: ("c1", "", "NEW INV")
-- ("c1", conn2, Right (INV cReq)) <- h1 #: ("c1", "", "NEW T INV")
-- let cReq' = strEncode cReq
-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN " <> cReq' <> " 5\ninfo2")
-- ("c2", conn1, Right OK) <- h2 #: ("c2", "", "JOIN T " <> cReq' <> " 5\ninfo2")
-- ("", _, Right (REQ connId _ "info2")) <- (h1 <#:)
-- h1 #: ("c3", conn2, "ACPT " <> connId <> " 5\ninfo1") =#> \case ("c3", c, OK) -> c == conn2; _ -> False
-- h2 <# ("", conn1, INFO "info1")
@@ -471,17 +471,17 @@ syntaxTests t = do
describe "NEW" $ do
describe "valid" $ do
-- TODO: add tests with defined connection id
it "with correct parameter" $ ("211", "", "NEW INV") >#>= \case ("211", _, "INV" : _) -> True; _ -> False
it "with correct parameter" $ ("211", "", "NEW T INV") >#>= \case ("211", _, "INV" : _) -> True; _ -> False
describe "invalid" $ do
-- TODO: add tests with defined connection id
it "with incorrect parameter" $ ("222", "", "NEW hi") >#> ("222", "", "ERR CMD SYNTAX")
it "with incorrect parameter" $ ("222", "", "NEW T hi") >#> ("222", "", "ERR CMD SYNTAX")
describe "JOIN" $ do
describe "valid" $ do
it "using same server as in invitation" $
( "311",
"a",
"JOIN https://simpex.chat/invitation#/?smp=smp%3A%2F%2F"
"JOIN T https://simpex.chat/invitation#/?smp=smp%3A%2F%2F"
<> urlEncode True "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
<> "%40localhost%3A5001%2F3456-w%3D%3D%23"
<> urlEncode True sampleDhKey
+63
View File
@@ -114,6 +114,11 @@ functionalAPITests t = do
describe "Batching SMP commands" $ do
it "should subscribe to multiple subscriptions with batching" $
testBatchedSubscriptions t
describe "Async agent commands" $ do
it "should connect using async agent commands" $
withSmpServer t testAsyncCommands
it "should restore and complete async commands on restart" $
testAsyncCommandsRestore t
testAgentClient :: IO ()
testAgentClient = do
@@ -560,6 +565,64 @@ testBatchedSubscriptions t = do
killThread t1
pure res
testAsyncCommands :: IO ()
testAsyncCommands = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
bobId <- createConnectionAsync alice "1" True SCMInvitation
("1", bobId', INV (ACR _ qInfo)) <- get alice
liftIO $ bobId' `shouldBe` bobId
aliceId <- joinConnectionAsync bob "2" True qInfo "bob's connInfo"
("2", aliceId', OK) <- get bob
liftIO $ aliceId' `shouldBe` aliceId
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnectionAsync alice "3" bobId confId "alice's connInfo"
("3", _, OK) <- get alice
get alice ##> ("", bobId, CON)
get bob ##> ("", aliceId, INFO "alice's connInfo")
get bob ##> ("", aliceId, CON)
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello"
get alice ##> ("", bobId, SENT $ baseId + 1)
2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
get alice ##> ("", bobId, SENT $ baseId + 2)
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
ackMessageAsync bob "4" aliceId $ baseId + 1
("4", _, OK) <- get bob
get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False
ackMessageAsync bob "5" aliceId $ baseId + 2
("5", _, OK) <- get bob
3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too"
get bob ##> ("", aliceId, SENT $ baseId + 3)
4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1"
get bob ##> ("", aliceId, SENT $ baseId + 4)
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
ackMessageAsync alice "6" bobId $ baseId + 3
("6", _, OK) <- get alice
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
ackMessageAsync alice "7" bobId $ baseId + 4
("7", _, OK) <- get alice
pure ()
pure ()
where
baseId = 3
msgId = subtract baseId
testAsyncCommandsRestore :: ATransport -> IO ()
testAsyncCommandsRestore t = do
alice <- getSMPAgentClient agentCfg initAgentServers
Right bobId <- runExceptT $ createConnectionAsync alice "1" True SCMInvitation
liftIO $ noMessages alice "alice doesn't receive INV because server is down"
disconnectAgentClient alice
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
withSmpServerStoreLogOn t testPort $ \_ -> do
Right () <- runExceptT $ do
subscribeConnection alice' bobId
("1", _, INV _) <- get alice'
pure ()
pure ()
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4
+1 -1
View File
@@ -62,7 +62,7 @@ smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
where
get h = do
t@(_, _, cmdStr) <- tGetRaw h
case parseAll commandP cmdStr of
case parseAll networkCommandP cmdStr of
Right (ACmd SAgent CONNECT {}) -> get h
Right (ACmd SAgent DISCONNECT {}) -> get h
_ -> pure t