Files
simplexmq/src/Simplex/Messaging/Agent.hs
2023-10-24 16:55:57 +04:00

2448 lines
131 KiB
Haskell

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-}
-- |
-- Module : Simplex.Messaging.Agent
-- Copyright : (c) simplex.chat
-- License : AGPL-3
--
-- Maintainer : chat@simplex.chat
-- Stability : experimental
-- Portability : non-portable
--
-- This module defines SMP protocol agent with SQLite persistence.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
module Simplex.Messaging.Agent
( -- * queue-based SMP agent
getAgentClient,
runAgentClient,
-- * SMP agent functional API
AgentClient (..),
AgentMonad,
AgentErrorMonad,
SubscriptionsInfo (..),
getSMPAgentClient,
disconnectAgentClient,
resumeAgentClient,
withConnLock,
withInvLock,
createUser,
deleteUser,
createConnectionAsync,
joinConnectionAsync,
allowConnectionAsync,
acceptContactAsync,
ackMessageAsync,
switchConnectionAsync,
deleteConnectionAsync,
deleteConnectionsAsync,
createConnection,
joinConnection,
allowConnection,
acceptContact,
rejectContact,
subscribeConnection,
subscribeConnections,
getConnectionMessage,
getNotificationMessage,
resubscribeConnection,
resubscribeConnections,
sendMessage,
ackMessage,
switchConnection,
abortConnectionSwitch,
synchronizeRatchet,
suspendConnection,
deleteConnection,
deleteConnections,
getConnectionServers,
getConnectionRatchetAdHash,
setProtocolServers,
testProtocolServer,
setNtfServers,
setNetworkConfig,
getNetworkConfig,
reconnectAllServers,
registerNtfToken,
verifyNtfToken,
checkNtfToken,
deleteNtfToken,
getNtfToken,
getNtfTokenData,
toggleConnectionNtfs,
xftpStartWorkers,
xftpReceiveFile,
xftpDeleteRcvFile,
xftpSendFile,
xftpDeleteSndFileInternal,
xftpDeleteSndFileRemote,
foregroundAgent,
suspendAgent,
execAgentStoreSQL,
getAgentMigrations,
debugAgentLocks,
getAgentStats,
resetAgentStats,
getAgentSubscriptions,
logConnection,
)
where
import Control.Logger.Simple (logError, logInfo, showText)
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader
import Crypto.Random (MonadRandom)
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Composition ((.:), (.:.), (.::))
import Data.Foldable (foldl')
import Data.Functor (($>))
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock
import Data.Time.Clock.System (systemToUTCTime)
import Simplex.FileTransfer.Agent (closeXFTPAgent, deleteSndFileInternal, deleteSndFileRemote, startXFTPWorkers, toFSFilePath, xftpDeleteRcvFile', xftpReceiveFile', xftpSendFile')
import Simplex.FileTransfer.Description (ValidFileDescription)
import Simplex.FileTransfer.Protocol (FileParty (..))
import Simplex.FileTransfer.Util (removePath)
import Simplex.Messaging.Agent.Client
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Lock (withLock)
import Simplex.Messaging.Agent.NtfSubSupervisor
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Store
import Simplex.Messaging.Agent.Store.SQLite
import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client (ProtocolClient (..), ServerTransmission)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs)
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..), NtfTokenId)
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..))
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, EntityId, ErrorType (AUTH), MsgBody, MsgFlags (..), NtfServer, ProtoServerWithAuth, ProtocolTypeI (..), SMPMsgMeta, SProtocolType (..), SndPublicVerifyKey, SubscriptionMode (..), UserProtocol, XFTPServerWithAuth)
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import UnliftIO.Async (async, race_)
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
import UnliftIO.STM
-- import GHC.Conc (unsafeIOToSTM)
-- | Creates an SMP agent client instance
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> InitialAgentServers -> SQLiteStore -> m AgentClient
getSMPAgentClient cfg initServers store =
liftIO (newSMPAgentEnv cfg store) >>= runReaderT runAgent
where
runAgent = do
c <- getAgentClient initServers
void $ raceAny_ [subscriber c, runNtfSupervisor c, cleanupManager c] `forkFinally` const (disconnectAgentClient c)
pure c
disconnectAgentClient :: MonadUnliftIO m => AgentClient -> m ()
disconnectAgentClient c@AgentClient {agentEnv = Env {ntfSupervisor = ns, xftpAgent = xa}} = do
closeAgentClient c
closeNtfSupervisor ns
closeXFTPAgent xa
logConnection c False
resumeAgentClient :: MonadIO m => AgentClient -> m ()
resumeAgentClient c = atomically $ writeTVar (active c) True
type AgentErrorMonad m = (MonadUnliftIO m, MonadError AgentErrorType m)
createUser :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> m UserId
createUser c = withAgentEnv c .: createUser' c
-- | Delete user record optionally deleting all user's connections on SMP servers
deleteUser :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> m ()
deleteUser c = withAgentEnv c .: deleteUser' c
-- | Create SMP agent connection (NEW command) asynchronously, synchronous response is new connection id
createConnectionAsync :: forall m c. (AgentErrorMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId
createConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .: newConnAsync c userId aCorrId enableNtfs
-- | Join SMP agent connection (JOIN command) asynchronously, synchronous response is new connection id
joinConnectionAsync :: AgentErrorMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnectionAsync c userId aCorrId enableNtfs = withAgentEnv c .:. joinConnAsync c userId aCorrId enableNtfs
-- | Allow connection to continue after CONF notification (LET command), no synchronous response
allowConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnectionAsync c = withAgentEnv c .:: allowConnectionAsync' c
-- | Accept contact after REQ notification (ACPT command) asynchronously, synchronous response is new connection id
acceptContactAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId
acceptContactAsync c aCorrId enableNtfs = withAgentEnv c .:. acceptContactAsync' c aCorrId enableNtfs
-- | Acknowledge message (ACK command) asynchronously, no synchronous response
ackMessageAsync :: forall m. AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m ()
ackMessageAsync c = withAgentEnv c .:: ackMessageAsync' c
-- | Switch connection to the new receive queue
switchConnectionAsync :: AgentErrorMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats
switchConnectionAsync c = withAgentEnv c .: switchConnectionAsync' c
-- | Delete SMP agent connection (DEL command) asynchronously, no synchronous response
deleteConnectionAsync :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
deleteConnectionAsync c = withAgentEnv c . deleteConnectionAsync' c
-- -- | Delete SMP agent connections using batch commands asynchronously, no synchronous response
deleteConnectionsAsync :: AgentErrorMonad m => AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync c = withAgentEnv c . deleteConnectionsAsync' c
-- | Create SMP agent connection (NEW command)
createConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
createConnection c userId enableNtfs = withAgentEnv c .:. newConn c userId "" enableNtfs
-- | Join SMP agent connection (JOIN command)
joinConnection :: AgentErrorMonad m => AgentClient -> UserId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnection c userId enableNtfs = withAgentEnv c .:. joinConn c userId "" enableNtfs
-- | Allow connection to continue after CONF notification (LET command)
allowConnection :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnection c = withAgentEnv c .:. allowConnection' c
-- | Accept contact after REQ notification (ACPT command)
acceptContact :: AgentErrorMonad m => AgentClient -> Bool -> ConfirmationId -> ConnInfo -> SubscriptionMode -> m ConnId
acceptContact c enableNtfs = withAgentEnv c .:. acceptContact' c "" enableNtfs
-- | Reject contact (RJCT command)
rejectContact :: AgentErrorMonad m => AgentClient -> ConnId -> ConfirmationId -> m ()
rejectContact c = withAgentEnv c .: rejectContact' c
-- | Subscribe to receive connection messages (SUB command)
subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
subscribeConnection c = withAgentEnv c . subscribeConnection' c
-- | Subscribe to receive connection messages from multiple connections, batching commands when possible
subscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections c = withAgentEnv c . subscribeConnections' c
-- | Get connection message (GET command)
getConnectionMessage :: AgentErrorMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
getConnectionMessage c = withAgentEnv c . getConnectionMessage' c
-- | Get connection message for received notification
getNotificationMessage :: AgentErrorMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta])
getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c
resubscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
resubscribeConnection c = withAgentEnv c . resubscribeConnection' c
resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
-- | Send message to the connection (SEND command)
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage c = withAgentEnv c .:. sendMessage' c
ackMessage :: AgentErrorMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m ()
ackMessage c = withAgentEnv c .:. ackMessage' c
-- | Switch connection to the new receive queue
switchConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection c = withAgentEnv c . switchConnection' c
-- | Abort switching connection to the new receive queue
abortConnectionSwitch :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
abortConnectionSwitch c = withAgentEnv c . abortConnectionSwitch' c
-- | Re-synchronize connection ratchet keys
synchronizeRatchet :: AgentErrorMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats
synchronizeRatchet c = withAgentEnv c .: synchronizeRatchet' c
-- | Suspend SMP agent connection (OFF command)
suspendConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
suspendConnection c = withAgentEnv c . suspendConnection' c
-- | Delete SMP agent connection (DEL command)
deleteConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
deleteConnection c = withAgentEnv c . deleteConnection' c
-- | Delete multiple connections, batching commands when possible
deleteConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnections c = withAgentEnv c . deleteConnections' c
-- | get servers used for connection
getConnectionServers :: AgentErrorMonad m => AgentClient -> ConnId -> m ConnectionStats
getConnectionServers c = withAgentEnv c . getConnectionServers' c
-- | get connection ratchet associated data hash for verification (should match peer AD hash)
getConnectionRatchetAdHash :: AgentErrorMonad m => AgentClient -> ConnId -> m ByteString
getConnectionRatchetAdHash c = withAgentEnv c . getConnectionRatchetAdHash' c
-- | Change servers to be used for creating new queues
setProtocolServers :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentErrorMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m ()
setProtocolServers c = withAgentEnv c .: setProtocolServers' c
-- | Test protocol server
testProtocolServer :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentErrorMonad m) => AgentClient -> UserId -> ProtoServerWithAuth p -> m (Maybe ProtocolTestFailure)
testProtocolServer c userId srv = withAgentEnv c $ case protocolTypeI @p of
SPSMP -> runSMPServerTest c userId srv
SPXFTP -> runXFTPServerTest c userId srv
setNtfServers :: MonadUnliftIO m => AgentClient -> [NtfServer] -> m ()
setNtfServers c = withAgentEnv c . setNtfServers' c
-- | set SOCKS5 proxy on/off and optionally set TCP timeout
setNetworkConfig :: MonadUnliftIO m => AgentClient -> NetworkConfig -> m ()
setNetworkConfig c cfg' = do
cfg <- atomically $ do
swapTVar (useNetworkConfig c) cfg'
when (cfg /= cfg') $ reconnectAllServers c
getNetworkConfig :: AgentErrorMonad m => AgentClient -> m NetworkConfig
getNetworkConfig = readTVarIO . useNetworkConfig
reconnectAllServers :: MonadUnliftIO m => AgentClient -> m ()
reconnectAllServers c = liftIO $ do
closeProtocolServerClients c smpClients
closeProtocolServerClients c ntfClients
-- | Register device notifications token
registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
registerNtfToken c = withAgentEnv c .: registerNtfToken' c
-- | Verify device notifications token
verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> m ()
verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c
checkNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m NtfTknStatus
checkNtfToken c = withAgentEnv c . checkNtfToken' c
deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken c = withAgentEnv c . deleteNtfToken' c
getNtfToken :: AgentErrorMonad m => AgentClient -> m (DeviceToken, NtfTknStatus, NotificationsMode)
getNtfToken c = withAgentEnv c $ getNtfToken' c
getNtfTokenData :: AgentErrorMonad m => AgentClient -> m NtfToken
getNtfTokenData c = withAgentEnv c $ getNtfTokenData' c
-- | Set connection notifications on/off
toggleConnectionNtfs :: AgentErrorMonad m => AgentClient -> ConnId -> Bool -> m ()
toggleConnectionNtfs c = withAgentEnv c .: toggleConnectionNtfs' c
xftpStartWorkers :: AgentErrorMonad m => AgentClient -> Maybe FilePath -> m ()
xftpStartWorkers c = withAgentEnv c . startXFTPWorkers c
-- | Receive XFTP file
xftpReceiveFile :: AgentErrorMonad m => AgentClient -> UserId -> ValidFileDescription 'FRecipient -> Maybe CryptoFileArgs -> m RcvFileId
xftpReceiveFile c = withAgentEnv c .:. xftpReceiveFile' c
-- | Delete XFTP rcv file (deletes work files from file system and db records)
xftpDeleteRcvFile :: AgentErrorMonad m => AgentClient -> RcvFileId -> m ()
xftpDeleteRcvFile c = withAgentEnv c . xftpDeleteRcvFile' c
-- | Send XFTP file
xftpSendFile :: AgentErrorMonad m => AgentClient -> UserId -> CryptoFile -> Int -> m SndFileId
xftpSendFile c = withAgentEnv c .:. xftpSendFile' c
-- | Delete XFTP snd file internally (deletes work files from file system and db records)
xftpDeleteSndFileInternal :: AgentErrorMonad m => AgentClient -> SndFileId -> m ()
xftpDeleteSndFileInternal c = withAgentEnv c . deleteSndFileInternal c
-- | Delete XFTP snd file chunks on servers
xftpDeleteSndFileRemote :: AgentErrorMonad m => AgentClient -> UserId -> SndFileId -> ValidFileDescription 'FSender -> m ()
xftpDeleteSndFileRemote c = withAgentEnv c .:. deleteSndFileRemote c
-- | Activate operations
foregroundAgent :: MonadUnliftIO m => AgentClient -> m ()
foregroundAgent c = withAgentEnv c $ foregroundAgent' c
-- | Suspend operations with max delay to deliver pending messages
suspendAgent :: MonadUnliftIO m => AgentClient -> Int -> m ()
suspendAgent c = withAgentEnv c . suspendAgent' c
execAgentStoreSQL :: AgentErrorMonad m => AgentClient -> Text -> m [Text]
execAgentStoreSQL c = withAgentEnv c . execAgentStoreSQL' c
getAgentMigrations :: AgentErrorMonad m => AgentClient -> m [UpMigration]
getAgentMigrations c = withAgentEnv c $ getAgentMigrations' c
debugAgentLocks :: MonadUnliftIO m => AgentClient -> m AgentLocks
debugAgentLocks c = withAgentEnv c $ debugAgentLocks' c
getAgentStats :: MonadIO m => AgentClient -> m [(AgentStatsKey, Int)]
getAgentStats c = readTVarIO (agentStats c) >>= mapM (\(k, cnt) -> (k,) <$> readTVarIO cnt) . M.assocs
resetAgentStats :: MonadIO m => AgentClient -> m ()
resetAgentStats = atomically . TM.clear . agentStats
withAgentEnv :: AgentClient -> ReaderT Env m a -> m a
withAgentEnv c = (`runReaderT` agentEnv c)
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
getAgentClient :: AgentMonad' m => InitialAgentServers -> m AgentClient
getAgentClient initServers = ask >>= atomically . newAgentClient initServers
logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m ()
logConnection c connected =
let event = if connected then "connected to" else "disconnected from"
in logInfo $ T.unwords ["client", showText (clientId c), event, "Agent"]
-- | Runs an SMP agent instance that receives commands and sends responses via 'TBQueue's.
runAgentClient :: AgentMonad' m => AgentClient -> m ()
runAgentClient c = race_ (subscriber c) (client c)
client :: forall m. AgentMonad' m => AgentClient -> m ()
client c@AgentClient {rcvQ, subQ} = forever $ do
(corrId, entId, cmd) <- atomically $ readTBQueue rcvQ
runExceptT (processCommand c (entId, cmd))
>>= atomically . writeTBQueue subQ . \case
Left e -> (corrId, entId, APC SAEConn $ ERR e)
Right (entId', resp) -> (corrId, entId', resp)
-- | execute any SMP agent command
processCommand :: forall m. AgentMonad m => AgentClient -> (EntityId, APartyCmd 'Client) -> m (EntityId, APartyCmd 'Agent)
processCommand c (connId, APC e cmd) =
second (APC e) <$> case cmd of
NEW enableNtfs (ACM cMode) subMode -> second (INV . ACR cMode) <$> newConn c userId connId enableNtfs cMode Nothing subMode
JOIN enableNtfs (ACR _ cReq) subMode connInfo -> (,OK) <$> joinConn c userId connId enableNtfs cReq connInfo subMode
LET confId ownCInfo -> allowConnection' c connId confId ownCInfo $> (connId, OK)
ACPT invId ownCInfo -> (,OK) <$> acceptContact' c connId True invId ownCInfo SMSubscribe
RJCT invId -> rejectContact' c connId invId $> (connId, OK)
SUB -> subscribeConnection' c connId $> (connId, OK)
SEND msgFlags msgBody -> (connId,) . MID <$> sendMessage' c connId msgFlags msgBody
ACK msgId rcptInfo_ -> ackMessage' c connId msgId rcptInfo_ $> (connId, OK)
SWCH -> switchConnection' c connId $> (connId, OK)
OFF -> suspendConnection' c connId $> (connId, OK)
DEL -> deleteConnection' c connId $> (connId, OK)
CHK -> (connId,) . STAT <$> getConnectionServers' c connId
where
-- command interface does not support different users
userId :: UserId
userId = 1
createUser' :: AgentMonad m => AgentClient -> NonEmpty SMPServerWithAuth -> NonEmpty XFTPServerWithAuth -> m UserId
createUser' c smp xftp = do
userId <- withStore' c createUserRecord
atomically $ TM.insert userId smp $ smpServers c
atomically $ TM.insert userId xftp $ xftpServers c
pure userId
deleteUser' :: AgentMonad m => AgentClient -> UserId -> Bool -> m ()
deleteUser' c userId delSMPQueues = do
if delSMPQueues
then withStore c (`setUserDeleted` userId) >>= deleteConnectionsAsync_ delUser c
else withStore c (`deleteUserRecord` userId)
atomically $ TM.delete userId $ smpServers c
where
delUser =
whenM (withStore' c (`deleteUserWithoutConns` userId)) . atomically $
writeTBQueue (subQ c) ("", "", APC SAENone $ DEL_USER userId)
newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> SubscriptionMode -> m ConnId
newConnAsync c userId corrId enableNtfs cMode subMode = do
connId <- newConnNoQueues c userId "" enableNtfs cMode
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ NEW enableNtfs (ACM cMode) subMode
pure connId
newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId
newConnNoQueues c userId connId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
-- connection mode is determined by the accepting agent
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
withStore c $ \db -> createNewConn db g cData cMode
joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo subMode = do
withInvLock c (B.unpack . strEncode $ cReqUri) "joinConnAsync" $ do
aVRange <- asks $ smpAgentVRange . config
case crAgentVRange `compatibleVersion` aVRange of
Just (Compatible connAgentVersion) -> do
g <- asks idsDrg
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
connId <- withStore c $ \db -> createNewConn db g cData SCMInvitation
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn $ JOIN enableNtfs (ACR sConnectionMode cReqUri) subMode cInfo
pure connId
_ -> throwError $ AGENT A_VERSION
joinConnAsync _c _userId _corrId _enableNtfs (CRContactUri _) _subMode _cInfo =
throwError $ CMD PROHIBITED
allowConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnectionAsync' c corrId connId confId ownConnInfo =
withStore c (`getConn` connId) >>= \case
SomeConn _ (RcvConnection _ RcvQueue {server}) ->
enqueueCommand c corrId connId (Just server) $ AClientCommand $ APC SAEConn $ LET confId ownConnInfo
_ -> throwError $ CMD PROHIBITED
acceptContactAsync' :: AgentMonad m => AgentClient -> ACorrId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId
acceptContactAsync' c corrId enableNtfs invId ownConnInfo subMode = do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConnAsync c userId corrId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
ackMessageAsync' :: forall m. AgentMonad m => AgentClient -> ACorrId -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m ()
ackMessageAsync' c corrId connId msgId rcptInfo_ = do
SomeConn cType _ <- withStore c (`getConn` connId)
case cType of
SCDuplex -> enqueueAck
SCRcv -> enqueueAck
SCSnd -> throwError $ CONN SIMPLEX
SCContact -> throwError $ CMD PROHIBITED
SCNew -> throwError $ CMD PROHIBITED
where
enqueueAck :: m ()
enqueueAck = do
let mId = InternalId msgId
RcvMsg {msgType} <- withStoreCtx "ackMessageAsync': getRcvMsg" c $ \db -> getRcvMsg db connId mId
when (isJust rcptInfo_ && msgType /= AM_A_MSG_) $ throwError $ CMD PROHIBITED
(RcvQueue {server}, _) <- withStoreCtx "ackMessageAsync': setMsgUserAck" c $ \db -> setMsgUserAck db connId mId
enqueueCommand c corrId connId (Just server) . AClientCommand $ APC SAEConn $ ACK msgId rcptInfo_
deleteConnectionAsync' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
deleteConnectionAsync' c connId = deleteConnectionsAsync' c [connId]
deleteConnectionsAsync' :: AgentMonad m => AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync' = deleteConnectionsAsync_ $ pure ()
deleteConnectionsAsync_ :: forall m. AgentMonad m => m () -> AgentClient -> [ConnId] -> m ()
deleteConnectionsAsync_ onSuccess c connIds = case connIds of
[] -> onSuccess
_ -> do
(_, rqs, connIds') <- prepareDeleteConnections_ getConns c connIds
withStore' c $ forM_ connIds' . setConnDeleted
void . forkIO $
withLock (deleteLock c) "deleteConnectionsAsync" $
deleteConnQueues c True rqs >> onSuccess
-- | Add connection to the new receive queue
switchConnectionAsync' :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> m ConnectionStats
switchConnectionAsync' c corrId connId =
withConnLock c connId "switchConnectionAsync" $
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rqs@(rq :| _rqs) sqs)
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
| otherwise -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
enqueueCommand c corrId connId Nothing $ AClientCommand $ APC SAEConn SWCH
let rqs' = updatedQs rq1 rqs
pure . connectionStats $ DuplexConnection cData rqs' sqs
_ -> throwError $ CMD PROHIBITED
newConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> m (ConnId, ConnectionRequestUri c)
newConn c userId connId enableNtfs cMode clientData subMode =
getSMPServer c userId >>= newConnSrv c userId connId enableNtfs cMode clientData subMode
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c userId connId enableNtfs cMode clientData subMode srv = do
connId' <- newConnNoQueues c userId connId enableNtfs cMode
newRcvConnSrv c userId connId' enableNtfs cMode clientData subMode srv
newRcvConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> Maybe CRClientData -> SubscriptionMode -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newRcvConnSrv c userId connId enableNtfs cMode clientData subMode srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
(rq, qUri) <- newRcvQueue c userId connId srv smpClientVRange subMode `catchAgentError` \e -> liftIO (print e) >> throwError e
void . withStore c $ \db -> updateNewConnRcv db connId rq
case subMode of
SMOnlyCreate -> pure ()
SMSubscribe -> addSubscription c rq
when enableNtfs $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
let crData = ConnReqUriData CRSSimplex smpAgentVRange [qUri] clientData
case cMode of
SCMContact -> pure (connId, CRContactUri crData)
SCMInvitation -> do
(pk1, pk2, e2eRcvParams) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
withStore' c $ \db -> createRatchetX3dhKeys db connId pk1 pk2
pure (connId, CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> m ConnId
joinConn c userId connId enableNtfs cReq cInfo subMode = do
srv <- case cReq of
CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _ ->
getNextServer c userId [qServer q]
_ -> getSMPServer c userId
joinConnSrv c userId connId enableNtfs cReq cInfo subMode srv
startJoinInvitation :: AgentMonad m => UserId -> ConnId -> Bool -> ConnectionRequestUri 'CMInvitation -> m (Compatible Version, ConnData, SndQueue, CR.Ratchet 'C.X448, CR.E2ERatchetParams 'C.X448)
startJoinInvitation userId connId enableNtfs (CRInvitationUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)} e2eRcvParamsUri) = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
case ( qUri `compatibleVersion` smpClientVRange,
e2eRcvParamsUri `compatibleVersion` e2eEncryptVRange,
crAgentVRange `compatibleVersion` smpAgentVRange
) of
(Just qInfo, Just (Compatible e2eRcvParams@(CR.E2ERatchetParams _ _ rcDHRr)), Just aVersion@(Compatible connAgentVersion)) -> do
(pk1, pk2, e2eSndParams) <- liftIO . CR.generateE2EParams $ version e2eRcvParams
(_, rcDHRs) <- liftIO C.generateKeyPair'
let rc = CR.initSndRatchet e2eEncryptVRange rcDHRr rcDHRs $ CR.x3dhSnd pk1 pk2 e2eRcvParams
q <- newSndQueue userId "" qInfo
let duplexHS = connAgentVersion /= 1
cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Just duplexHS, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk}
pure (aVersion, cData, q, rc, e2eSndParams)
_ -> throwError $ AGENT A_VERSION
joinConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ConnId
joinConnSrv c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv =
withInvLock c (B.unpack . strEncode $ inv) "joinConnSrv" $ do
(aVersion, cData@ConnData {connAgentVersion}, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
g <- asks idsDrg
connId' <- withStore c $ \db -> runExceptT $ do
connId' <- ExceptT $ createSndConn db g cData q
liftIO $ createRatchet db connId' rc
pure connId'
let sq = (q :: SndQueue) {connId = connId'}
cData' = (cData :: ConnData) {connId = connId'}
duplexHS = connAgentVersion /= 1
tryError (confirmQueue aVersion c cData' sq srv cInfo (Just e2eSndParams) subMode) >>= \case
Right _ -> do
unless duplexHS . void $ enqueueMessage c cData' sq SMP.noMsgFlags HELLO
pure connId'
Left e -> do
-- possible improvement: recovery for failure on network timeout, see rfcs/2022-04-20-smp-conf-timeout-recovery.md
withStore' c (`deleteConn` connId')
throwError e
joinConnSrv c userId connId enableNtfs (CRContactUri ConnReqUriData {crAgentVRange, crSmpQueues = (qUri :| _)}) cInfo subMode srv = do
aVRange <- asks $ smpAgentVRange . config
clientVRange <- asks $ smpClientVRange . config
case ( qUri `compatibleVersion` clientVRange,
crAgentVRange `compatibleVersion` aVRange
) of
(Just qInfo, Just vrsn) -> do
(connId', cReq) <- newConnSrv c userId connId enableNtfs SCMInvitation Nothing subMode srv
sendInvitation c userId qInfo vrsn cReq cInfo
pure connId'
_ -> throwError $ AGENT A_VERSION
joinConnSrvAsync :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> ConnectionRequestUri c -> ConnInfo -> SubscriptionMode -> SMPServerWithAuth -> m ()
joinConnSrvAsync c userId connId enableNtfs inv@CRInvitationUri {} cInfo subMode srv = do
(aVersion, cData, q, rc, e2eSndParams) <- startJoinInvitation userId connId enableNtfs inv
dbQueueId <- withStore c $ \db -> runExceptT $ do
liftIO $ createRatchet db connId rc
ExceptT $ updateNewConnSnd db connId q
let q' = (q :: SndQueue) {dbQueueId}
confirmQueueAsync aVersion c cData q' srv cInfo (Just e2eSndParams) subMode
joinConnSrvAsync _c _userId _connId _enableNtfs (CRContactUri _) _cInfo _subMode _srv = do
throwError $ CMD PROHIBITED
createReplyQueue :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> SubscriptionMode -> SMPServerWithAuth -> m SMPQueueInfo
createReplyQueue c ConnData {userId, connId, enableNtfs} SndQueue {smpClientVersion} subMode srv = do
(rq, qUri) <- newRcvQueue c userId connId srv (versionToRange smpClientVersion) subMode
let qInfo = toVersionT qUri smpClientVersion
case subMode of
SMOnlyCreate -> pure ()
SMSubscribe -> addSubscription c rq
void . withStore c $ \db -> upgradeSndConnToDuplex db connId rq
when enableNtfs $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
pure qInfo
-- | Approve confirmation (LET command) in Reader monad
allowConnection' :: AgentMonad m => AgentClient -> ConnId -> ConfirmationId -> ConnInfo -> m ()
allowConnection' c connId confId ownConnInfo = withConnLock c connId "allowConnection" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (RcvConnection _ rq@RcvQueue {server, rcvId, e2ePrivKey, smpClientVersion = v}) -> do
senderKey <- withStore c $ \db -> runExceptT $ do
AcceptedConfirmation {ratchetState, senderConf = SMPConfirmation {senderKey, e2ePubKey, smpClientVersion = v'}} <- ExceptT $ acceptConfirmation db confId ownConnInfo
liftIO $ createRatchet db connId ratchetState
let dhSecret = C.dh' e2ePubKey e2ePrivKey
liftIO $ setRcvQueueConfirmedE2E db rq dhSecret $ min v v'
pure senderKey
enqueueCommand c "" connId (Just server) . AInternalCommand $ ICAllowSecure rcvId senderKey
_ -> throwError $ CMD PROHIBITED
-- | Accept contact (ACPT command) in Reader monad
acceptContact' :: AgentMonad m => AgentClient -> ConnId -> Bool -> InvitationId -> ConnInfo -> SubscriptionMode -> m ConnId
acceptContact' c connId enableNtfs invId ownConnInfo subMode = withConnLock c connId "acceptContact" $ do
Invitation {contactConnId, connReq} <- withStore c (`getInvitation` invId)
withStore c (`getConn` contactConnId) >>= \case
SomeConn _ (ContactConnection ConnData {userId} _) -> do
withStore' c $ \db -> acceptInvitation db invId ownConnInfo
joinConn c userId connId enableNtfs connReq ownConnInfo subMode `catchAgentError` \err -> do
withStore' c (`unacceptInvitation` invId)
throwError err
_ -> throwError $ CMD PROHIBITED
-- | Reject contact (RJCT command) in Reader monad
rejectContact' :: AgentMonad m => AgentClient -> ConnId -> InvitationId -> m ()
rejectContact' c contactConnId invId =
withStore c $ \db -> deleteInvitation db contactConnId invId
-- | Subscribe to receive connection messages (SUB command) in Reader monad
subscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
subscribeConnection' c connId = toConnResult connId =<< subscribeConnections' c [connId]
toConnResult :: AgentMonad m => ConnId -> Map ConnId (Either AgentErrorType ()) -> m ()
toConnResult connId rs = case M.lookup connId rs of
Just (Right ()) -> when (M.size rs > 1) $ logError $ T.pack $ "too many results " <> show (M.size rs)
Just (Left e) -> throwError e
_ -> throwError $ INTERNAL $ "no result for connection " <> B.unpack connId
type QCmdResult = (QueueStatus, Either AgentErrorType ())
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections' _ [] = pure M.empty
subscribeConnections' c connIds = do
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConns` connIds)
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(subRs, rcvQs) = M.mapEither rcvQueueOrResult cs
mapM_ (mapM_ (\(cData, sqs) -> mapM_ (resumeMsgDelivery c cData) sqs) . sndQueue) cs
mapM_ (resumeConnCmds c) $ M.keys cs
rcvRs <- connResults <$> subscribeQueues c (concat $ M.elems rcvQs)
ns <- asks ntfSupervisor
tkn <- readTVarIO (ntfTkn ns)
when (instantNotifications tkn) . void . forkIO $ sendNtfCreate ns rcvRs conns
let rs = M.unions ([errs', subRs, rcvRs] :: [Map ConnId (Either AgentErrorType ())])
notifyResultError rs
pure rs
where
rcvQueueOrResult :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
rcvQueueOrResult (SomeConn _ conn) = case conn of
DuplexConnection _ rqs _ -> Right $ L.toList rqs
SndConnection _ sq -> Left $ sndSubResult sq
RcvConnection _ rq -> Right [rq]
ContactConnection _ rq -> Right [rq]
NewConnection _ -> Left (Right ())
sndSubResult :: SndQueue -> Either AgentErrorType ()
sndSubResult SndQueue {status} = case status of
Confirmed -> Right ()
Active -> Left $ CONN SIMPLEX
_ -> Left $ INTERNAL "unexpected queue status"
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
connResults = M.map snd . foldl' addResult M.empty
where
-- collects results by connection ID
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
-- combines two results for one connection, by using only Active queues (if there is at least one Active queue)
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
combineRes r' _ = Just r'
order :: QCmdResult -> Int
order (Active, Right _) = 1
order (Active, _) = 2
order (_, Right _) = 3
order _ = 4
sendNtfCreate :: NtfSupervisor -> Map ConnId (Either AgentErrorType ()) -> Map ConnId (Either StoreError SomeConn) -> m ()
sendNtfCreate ns rcvRs conns =
forM_ (M.assocs rcvRs) $ \case
(connId, Right _) -> forM_ (M.lookup connId conns) $ \case
Right (SomeConn _ conn) -> do
let cmd = if enableNtfs $ toConnData conn then NSCCreate else NSCDelete
atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd)
_ -> pure ()
_ -> pure ()
sndQueue :: SomeConn -> Maybe (ConnData, NonEmpty SndQueue)
sndQueue (SomeConn _ conn) = case conn of
DuplexConnection cData _ sqs -> Just (cData, sqs)
SndConnection cData sq -> Just (cData, [sq])
_ -> Nothing
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
notifyResultError rs = do
let actual = M.size rs
expected = length connIds
when (actual /= expected) . atomically $
writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
resubscribeConnection' c connId = toConnResult connId =<< resubscribeConnections' c [connId]
resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
resubscribeConnections' _ [] = pure M.empty
resubscribeConnections' c connIds = do
let r = M.fromList . zip connIds . repeat $ Right ()
connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds
-- union is left-biased, so results returned by subscribeConnections' take precedence
(`M.union` r) <$> subscribeConnections' c connIds'
getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
getConnectionMessage' c connId = do
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection _ (rq :| _) _ -> getQueueMessage c rq
RcvConnection _ rq -> getQueueMessage c rq
ContactConnection _ rq -> getQueueMessage c rq
SndConnection _ _ -> throwError $ CONN SIMPLEX
NewConnection _ -> throwError $ CMD PROHIBITED
getNotificationMessage' :: forall m. AgentMonad m => AgentClient -> C.CbNonce -> ByteString -> m (NotificationInfo, [SMPMsgMeta])
getNotificationMessage' c nonce encNtfInfo = do
withStore' c getActiveNtfToken >>= \case
Just NtfToken {ntfDhSecret = Just dhSecret} -> do
ntfData <- agentCbDecrypt dhSecret nonce encNtfInfo
PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} <- liftEither (parse strP (INTERNAL "error parsing PNMessageData") ntfData)
(ntfConnId, rcvNtfDhSecret) <- withStore c (`getNtfRcvQueue` smpQueue)
ntfMsgMeta <- (eitherToMaybe . smpDecode <$> agentCbDecrypt rcvNtfDhSecret nmsgNonce encNMsgMeta) `catchAgentError` \_ -> pure Nothing
maxMsgs <- asks $ ntfMaxMessages . config
(NotificationInfo {ntfConnId, ntfTs, ntfMsgMeta},) <$> getNtfMessages ntfConnId maxMsgs ntfMsgMeta []
_ -> throwError $ CMD PROHIBITED
where
getNtfMessages ntfConnId maxMs nMeta ms
| length ms < maxMs =
getConnectionMessage' c ntfConnId >>= \case
Just m@SMP.SMPMsgMeta {msgId, msgTs, msgFlags} -> case nMeta of
Just SMP.NMsgMeta {msgId = msgId', msgTs = msgTs'}
| msgId == msgId' || msgTs > msgTs' -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_
| SMP.notification msgFlags -> pure $ reverse (m : ms)
| otherwise -> getMsg (m : ms)
_ -> pure $ reverse ms
| otherwise = pure $ reverse ms
where
getMsg = getNtfMessages ntfConnId maxMs nMeta
-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage' c connId msgFlags msg = withConnLock c connId "sendMessage" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData _ sqs -> enqueueMsgs cData sqs
SndConnection cData sq -> enqueueMsgs cData [sq]
_ -> throwError $ CONN SIMPLEX
where
enqueueMsgs :: ConnData -> NonEmpty SndQueue -> m AgentMsgId
enqueueMsgs cData sqs = do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
enqueueMessages c cData sqs msgFlags $ A_MSG msg
-- / async command processing v v v
enqueueCommand :: AgentMonad m => AgentClient -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> m ()
enqueueCommand c corrId connId server aCommand = do
resumeSrvCmds c server
commandId <- withStore c $ \db -> createCommand db corrId connId server aCommand
queuePendingCommands c server [commandId]
resumeSrvCmds :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
resumeSrvCmds c server =
unlessM (cmdProcessExists c server) $
async (runCommandProcessing c server)
>>= \a -> atomically (TM.insert server a $ asyncCmdProcesses c)
resumeConnCmds :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
resumeConnCmds c connId =
unlessM connQueued $
withStore' c (`getPendingCommands` connId)
>>= mapM_ (uncurry enqueueConnCmds)
where
enqueueConnCmds srv cmdIds = do
resumeSrvCmds c srv
queuePendingCommands c srv cmdIds
connQueued = atomically $ isJust <$> TM.lookupInsert connId True (connCmdsQueued c)
cmdProcessExists :: AgentMonad' m => AgentClient -> Maybe SMPServer -> m Bool
cmdProcessExists c srv = atomically $ TM.member srv (asyncCmdProcesses c)
queuePendingCommands :: AgentMonad' m => AgentClient -> Maybe SMPServer -> [AsyncCmdId] -> m ()
queuePendingCommands c server cmdIds = atomically $ do
q <- getPendingCommandQ c server
mapM_ (writeTQueue q) cmdIds
getPendingCommandQ :: AgentClient -> Maybe SMPServer -> STM (TQueue AsyncCmdId)
getPendingCommandQ c server = do
maybe newMsgQueue pure =<< TM.lookup server (asyncCmdQueues c)
where
newMsgQueue = do
cq <- newTQueue
TM.insert server cq $ asyncCmdQueues c
pure cq
runCommandProcessing :: forall m. AgentMonad m => AgentClient -> Maybe SMPServer -> m ()
runCommandProcessing c@AgentClient {subQ} server_ = do
cq <- atomically $ getPendingCommandQ c server_
ri <- asks $ messageRetryInterval . config -- different retry interval?
forever $ do
atomically $ endAgentOperation c AOSndNetwork
atomically $ throwWhenInactive c
cmdId <- atomically $ readTQueue cq
atomically $ beginAgentOperation c AOSndNetwork
tryAgentError (withStore c $ \db -> getPendingCommand db cmdId) >>= \case
Left e -> atomically $ writeTBQueue subQ ("", "", APC SAEConn $ ERR e)
Right cmd -> processCmd (riFast ri) cmdId cmd
where
processCmd :: RetryInterval -> AsyncCmdId -> PendingCommand -> m ()
processCmd ri cmdId PendingCommand {corrId, userId, connId, command} = case command of
AClientCommand (APC _ cmd) -> case cmd of
NEW enableNtfs (ACM cMode) subMode -> noServer $ do
usedSrvs <- newTVarIO ([] :: [SMPServer])
tryCommand . withNextSrv c userId usedSrvs [] $ \srv -> do
(_, cReq) <- newRcvConnSrv c userId connId enableNtfs cMode Nothing subMode srv
notify $ INV (ACR cMode cReq)
JOIN enableNtfs (ACR _ cReq@(CRInvitationUri ConnReqUriData {crSmpQueues = q :| _} _)) subMode connInfo -> noServer $ do
let initUsed = [qServer q]
usedSrvs <- newTVarIO initUsed
tryCommand . withNextSrv c userId usedSrvs initUsed $ \srv -> do
joinConnSrvAsync c userId connId enableNtfs cReq connInfo subMode srv
notify OK
LET confId ownCInfo -> withServer' . tryCommand $ allowConnection' c connId confId ownCInfo >> notify OK
ACK msgId rcptInfo_ -> withServer' . tryCommand $ ackMessage' c connId msgId rcptInfo_ >> notify OK
SWCH ->
noServer . tryCommand . withConnLock c connId "switchConnection" $
withStore c (`getConn` connId) >>= \case
SomeConn _ conn@(DuplexConnection _ (replaced :| _rqs) _) ->
switchDuplexConnection c conn replaced >>= notify . SWITCH QDRcv SPStarted
_ -> throwError $ CMD PROHIBITED
DEL -> withServer' . tryCommand $ deleteConnection' c connId >> notify OK
_ -> notify $ ERR $ INTERNAL $ "unsupported async command " <> show (aCommandTag cmd)
AInternalCommand cmd -> case cmd of
ICAckDel rId srvMsgId msgId -> withServer $ \srv -> tryWithLock "ICAckDel" $ ack srv rId srvMsgId >> withStore' c (\db -> deleteMsg db connId msgId)
ICAck rId srvMsgId -> withServer $ \srv -> tryWithLock "ICAck" $ ack srv rId srvMsgId
ICAllowSecure _rId senderKey -> withServer' . tryWithLock "ICAllowSecure" $ do
(SomeConn _ conn, AcceptedConfirmation {senderConf, ownConnInfo}) <-
withStore c $ \db -> runExceptT $ (,) <$> ExceptT (getConn db connId) <*> ExceptT (getAcceptedConfirmation db connId)
case conn of
RcvConnection cData rq -> do
secure rq senderKey
mapM_ (connectReplyQueues c cData ownConnInfo) (L.nonEmpty $ smpReplyQueues senderConf)
_ -> throwError $ INTERNAL $ "incorrect connection type " <> show (internalCmdTag cmd)
ICDuplexSecure _rId senderKey -> withServer' . tryWithLock "ICDuplexSecure" . withDuplexConn $ \(DuplexConnection cData (rq :| _) (sq :| _)) -> do
secure rq senderKey
when (duplexHandshake cData == Just True) . void $
enqueueMessage c cData sq SMP.MsgFlags {notification = True} HELLO
-- ICDeleteConn is no longer used, but it can be present in old client databases
ICDeleteConn -> withStore' c (`deleteCommand` cmdId)
ICDeleteRcvQueue rId -> withServer $ \srv -> tryWithLock "ICDeleteRcvQueue" $ do
rq <- withStore c (\db -> getDeletedRcvQueue db connId srv rId)
deleteQueue c rq
withStore' c (`deleteConnRcvQueue` rq)
ICQSecure rId senderKey ->
withServer $ \srv -> tryWithLock "ICQSecure" . withDuplexConn $ \(DuplexConnection cData rqs sqs) ->
case find (sameQueue (srv, rId)) rqs of
Just rq'@RcvQueue {server, sndId, status, dbReplaceQueueId = Just replaceQId} ->
case find (\q -> replaceQId == q.dbQueueId) rqs of
Just rq1 -> when (status == Confirmed) $ do
secureQueue c rq' senderKey
withStore' c $ \db -> setRcvQueueStatus db rq' Secured
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QUSE [((server, sndId), True)]
rq1' <- withStore' c $ \db -> setRcvSwitchStatus db rq1 $ Just RSSendingQUSE
let rqs' = updatedQs rq1' rqs
conn' = DuplexConnection cData rqs' sqs
notify . SWITCH QDRcv SPSecured $ connectionStats conn'
_ -> internalErr "ICQSecure: no switching queue found"
_ -> internalErr "ICQSecure: queue address not found in connection"
ICQDelete rId -> do
withServer $ \srv -> tryWithLock "ICQDelete" . withDuplexConn $ \(DuplexConnection cData rqs sqs) -> do
case removeQ (srv, rId) rqs of
Nothing -> internalErr "ICQDelete: queue address not found in connection"
Just (rq'@RcvQueue {primary}, rq'' : rqs')
| primary -> internalErr "ICQDelete: cannot delete primary rcv queue"
| otherwise -> do
checkRQSwchStatus rq' RSReceivedMessage
tryError (deleteQueue c rq') >>= \case
Right () -> finalizeSwitch
Left e
| temporaryOrHostError e -> throwError e
| otherwise -> finalizeSwitch >> throwError e
where
finalizeSwitch = do
withStore' c $ \db -> deleteConnRcvQueue db rq'
when (enableNtfs cData) $ do
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
let conn' = DuplexConnection cData (rq'' :| rqs') sqs
notify $ SWITCH QDRcv SPCompleted $ connectionStats conn'
_ -> internalErr "ICQDelete: cannot delete the only queue in connection"
where
ack srv rId srvMsgId = do
rq <- withStore c $ \db -> getRcvQueue db connId srv rId
ackQueueMessage c rq srvMsgId
secure :: RcvQueue -> SMP.SndPublicVerifyKey -> m ()
secure rq senderKey = do
secureQueue c rq senderKey
withStore' c $ \db -> setRcvQueueStatus db rq Secured
where
withServer a = case server_ of
Just srv -> a srv
_ -> internalErr "command requires server"
withServer' = withServer . const
noServer a = case server_ of
Nothing -> a
_ -> internalErr "command requires no server"
withDuplexConn :: (Connection 'CDuplex -> m ()) -> m ()
withDuplexConn a =
withStore c (`getConn` connId) >>= \case
SomeConn _ conn@DuplexConnection {} -> a conn
_ -> internalErr "command requires duplex connection"
tryCommand action = withRetryInterval ri $ \_ loop ->
tryError action >>= \case
Left e
| temporaryOrHostError e -> retrySndOp c loop
| otherwise -> cmdError e
Right () -> withStore' c (`deleteCommand` cmdId)
tryWithLock name = tryCommand . withConnLock c connId name
internalErr s = cmdError $ INTERNAL $ s <> ": " <> show (agentCommandTag command)
cmdError e = notify (ERR e) >> withStore' c (`deleteCommand` cmdId)
notify :: forall e. AEntityI e => ACommand 'Agent e -> m ()
notify cmd = atomically $ writeTBQueue subQ (corrId, connId, APC (sAEntity @e) cmd)
-- ^ ^ ^ async command processing /
enqueueMessages :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessages c cData sqs msgFlags aMessage = do
when (ratchetSyncSendProhibited cData) $ throwError $ INTERNAL "enqueueMessages: ratchet is not synchronized"
enqueueMessages' c cData sqs msgFlags aMessage
enqueueMessages' :: AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessages' c cData (sq :| sqs) msgFlags aMessage = do
msgId <- enqueueMessage c cData sq msgFlags aMessage
mapM_ (enqueueSavedMessage c cData msgId) $
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
pure msgId
enqueueMessage :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> MsgFlags -> AMessage -> m AgentMsgId
enqueueMessage c cData@ConnData {connId} sq msgFlags aMessage = do
resumeMsgDelivery c cData sq
aVRange <- asks $ smpAgentVRange . config
msgId <- storeSentMsg $ maxVersion aVRange
queuePendingMsgs c sq [msgId]
pure $ unId msgId
where
storeSentMsg :: Version -> m InternalId
storeSentMsg agentVersion = withStore c $ \db -> runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
let privHeader = APrivHeader (unSndId internalSndId) prevMsgHash
agentMsg = AgentMessage privHeader aMessage
agentMsgStr = smpEncode agentMsg
internalHash = C.sha256Hash agentMsgStr
encAgentMessage <- agentRatchetEncrypt db connId agentMsgStr e2eEncUserMsgLength
let msgBody = smpEncode $ AgentMsgEnvelope {agentVersion, encAgentMessage}
msgType = agentMessageType agentMsg
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgFlags, msgBody, internalHash, prevMsgHash}
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
enqueueSavedMessage :: AgentMonad m => AgentClient -> ConnData -> AgentMsgId -> SndQueue -> m ()
enqueueSavedMessage c cData@ConnData {connId} msgId sq = do
resumeMsgDelivery c cData sq
let mId = InternalId msgId
queuePendingMsgs c sq [mId]
withStore' c $ \db -> createSndMsgDelivery db connId sq mId
resumeMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
resumeMsgDelivery c cData@ConnData {connId} sq@SndQueue {server, sndId} = do
let qKey = (server, sndId)
unlessM (queueDelivering qKey) $
async (runSmpQueueMsgDelivery c cData sq)
>>= \a -> atomically (TM.insert qKey a $ smpQueueMsgDeliveries c)
unlessM msgsQueued $
withStore' c (\db -> getPendingMsgs db connId sq)
>>= queuePendingMsgs c sq
where
queueDelivering qKey = atomically $ TM.member qKey (smpQueueMsgDeliveries c)
msgsQueued = atomically $ isJust <$> TM.lookupInsert (server, sndId) True (pendingMsgsQueued c)
queuePendingMsgs :: AgentMonad' m => AgentClient -> SndQueue -> [InternalId] -> m ()
queuePendingMsgs c sq msgIds = atomically $ do
modifyTVar' (msgDeliveryOp c) $ \s -> s {opsInProgress = opsInProgress s + length msgIds}
-- s <- readTVar (msgDeliveryOp c)
-- unsafeIOToSTM $ putStrLn $ "msgDeliveryOp: " <> show (opsInProgress s)
(mq, _) <- getPendingMsgQ c sq
mapM_ (writeTQueue mq) msgIds
getPendingMsgQ :: AgentClient -> SndQueue -> STM (TQueue InternalId, TMVar ())
getPendingMsgQ c SndQueue {server, sndId} = do
let qKey = (server, sndId)
maybe (newMsgQueue qKey) pure =<< TM.lookup qKey (smpQueueMsgQueues c)
where
newMsgQueue qKey = do
q <- (,) <$> newTQueue <*> newEmptyTMVar
TM.insert qKey q $ smpQueueMsgQueues c
pure q
runSmpQueueMsgDelivery :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> m ()
runSmpQueueMsgDelivery c@AgentClient {subQ} cData@ConnData {userId, connId, duplexHandshake} sq = do
(mq, qLock) <- atomically $ getPendingMsgQ c sq
ri <- asks $ messageRetryInterval . config
forever $ do
atomically $ endAgentOperation c AOSndNetwork
atomically $ throwWhenInactive c
atomically $ throwWhenNoDelivery c sq
msgId <- atomically $ readTQueue mq
atomically $ beginAgentOperation c AOSndNetwork
atomically $ endAgentOperation c AOMsgDelivery -- this operation begins in queuePendingMsgs
let mId = unId msgId
tryAgentError (withStore c $ \db -> getPendingMsgData db connId msgId) >>= \case
Left e -> notify $ MERR mId e
Right (rq_, PendingMsgData {msgType, msgBody, msgFlags, msgRetryState, internalTs}) -> do
let ri' = maybe id updateRetryInterval2 msgRetryState ri
withRetryLock2 ri' qLock $ \riState loop -> do
resp <- tryError $ case msgType of
AM_CONN_INFO -> sendConfirmation c sq msgBody
AM_CONN_INFO_REPLY -> sendConfirmation c sq msgBody
_ -> sendAgentMessage c sq msgFlags msgBody
case resp of
Left e -> do
let err = if msgType == AM_A_MSG_ then MERR mId e else ERR e
case e of
SMP SMP.QUOTA -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
_ -> retrySndMsg RISlow
SMP SMP.AUTH -> case msgType of
AM_CONN_INFO -> connError msgId NOT_AVAILABLE
AM_CONN_INFO_REPLY -> connError msgId NOT_AVAILABLE
AM_RATCHET_INFO -> connError msgId NOT_AVAILABLE
AM_HELLO_
-- in duplexHandshake mode (v2) HELLO is only sent once, without retrying,
-- because the queue must be secured by the time the confirmation or the first HELLO is received
| duplexHandshake == Just True -> connErr
| otherwise ->
ifM (msgExpired helloTimeout) connErr (retrySndMsg RIFast)
where
connErr = case rq_ of
-- party initiating connection
Just _ -> connError msgId NOT_AVAILABLE
-- party joining connection
_ -> connError msgId NOT_ACCEPTED
AM_REPLY_ -> notifyDel msgId err
AM_A_MSG_ -> notifyDel msgId err
AM_A_RCVD_ -> notifyDel msgId err
AM_QCONT_ -> notifyDel msgId err
AM_QADD_ -> qError msgId "QADD: AUTH"
AM_QKEY_ -> qError msgId "QKEY: AUTH"
AM_QUSE_ -> qError msgId "QUSE: AUTH"
AM_QTEST_ -> qError msgId "QTEST: AUTH"
AM_EREADY_ -> notifyDel msgId err
_
-- for other operations BROKER HOST is treated as a permanent error (e.g., when connecting to the server),
-- the message sending would be retried
| temporaryOrHostError e -> do
let timeoutSel = if msgType == AM_HELLO_ then helloTimeout else messageTimeout
ifM (msgExpired timeoutSel) (notifyDel msgId err) (retrySndMsg RIFast)
| otherwise -> notifyDel msgId err
where
msgExpired timeoutSel = do
msgTimeout <- asks $ timeoutSel . config
currentTime <- liftIO getCurrentTime
pure $ diffUTCTime currentTime internalTs > msgTimeout
retrySndMsg riMode = do
withStore' c $ \db -> updatePendingMsgRIState db connId msgId riState
retrySndOp c $ loop riMode
Right () -> do
case msgType of
AM_CONN_INFO -> setConfirmed
AM_CONN_INFO_REPLY -> setConfirmed
AM_RATCHET_INFO -> pure ()
AM_REPLY_ -> pure ()
AM_HELLO_ -> do
withStore' c $ \db -> setSndQueueStatus db sq Active
case rq_ of
-- party initiating connection (in v1)
Just RcvQueue {status} ->
-- it is unclear why subscribeQueue was needed here,
-- message delivery can only be enabled for queues that were created in the current session or subscribed
-- subscribeQueue c rq connId
--
-- If initiating party were to send CON to the user without waiting for reply HELLO (to reduce handshake time),
-- it would lead to the non-deterministic internal ID of the first sent message, at to some other race conditions,
-- because it can be sent before HELLO is received
-- With `status == Active` condition, CON is sent here only by the accepting party, that previously received HELLO
when (status == Active) $ notify CON
-- Party joining connection sends REPLY after HELLO in v1,
-- it is an error to send REPLY in duplexHandshake mode (v2),
-- and this branch should never be reached as receive is created before the confirmation,
-- so the condition is not necessary here, strictly speaking.
_ -> unless (duplexHandshake == Just True) $ do
srv <- getSMPServer c userId
qInfo <- createReplyQueue c cData sq SMSubscribe srv
void . enqueueMessage c cData sq SMP.noMsgFlags $ REPLY [qInfo]
AM_A_MSG_ -> notify $ SENT mId
AM_A_RCVD_ -> pure ()
AM_QCONT_ -> pure ()
AM_QADD_ -> pure ()
AM_QKEY_ -> do
SomeConn _ conn <- withStore c (`getConn` connId)
notify . SWITCH QDSnd SPConfirmed $ connectionStats conn
AM_QUSE_ -> pure ()
AM_QTEST_ -> withConnLock c connId "runSmpQueueMsgDelivery AM_QTEST_" $ do
withStore' c $ \db -> setSndQueueStatus db sq Active
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData' rqs sqs -> do
-- remove old snd queue from connection once QTEST is sent to the new queue
let addr = qAddress sq
case findQ addr sqs of
-- this is the same queue where this loop delivers messages to but with updated state
Just SndQueue {dbReplaceQueueId = Just replacedId, primary} ->
-- second part of this condition is a sanity check because dbReplaceQueueId cannot point to the same queue, see switchConnection'
case removeQP (\sq' -> sq'.dbQueueId == replacedId && not (sameQueue addr sq')) sqs of
Nothing -> internalErr msgId "sent QTEST: queue not found in connection"
Just (sq', sq'' : sqs') -> do
checkSQSwchStatus sq' SSSendingQTEST
-- remove the delivery from the map to stop the thread when the delivery loop is complete
atomically $ TM.delete (qAddress sq') $ smpQueueMsgQueues c
withStore' c $ \db -> do
when primary $ setSndQueuePrimary db connId sq
deletePendingMsgs db connId sq'
deleteConnSndQueue db connId sq'
let sqs'' = sq'' :| sqs'
conn' = DuplexConnection cData' rqs sqs''
notify . SWITCH QDSnd SPCompleted $ connectionStats conn'
_ -> internalErr msgId "sent QTEST: there is only one queue in connection"
_ -> internalErr msgId "sent QTEST: queue not in connection or not replacing another queue"
_ -> internalErr msgId "QTEST sent not in duplex connection"
AM_EREADY_ -> pure ()
delMsgKeep (msgType == AM_A_MSG_) msgId
where
setConfirmed = do
withStore' c $ \db -> do
setSndQueueStatus db sq Confirmed
when (isJust rq_) $ removeConfirmations db connId
unless (duplexHandshake == Just True) . void $ enqueueMessage c cData sq SMP.noMsgFlags HELLO
where
delMsg :: InternalId -> m ()
delMsg = delMsgKeep False
delMsgKeep :: Bool -> InternalId -> m ()
delMsgKeep keepForReceipt msgId = withStore' c $ \db -> deleteSndMsgDelivery db connId sq msgId keepForReceipt
notify :: forall e. AEntityI e => ACommand 'Agent e -> m ()
notify cmd = atomically $ writeTBQueue subQ ("", connId, APC (sAEntity @e) cmd)
notifyDel :: AEntityI e => InternalId -> ACommand 'Agent e -> m ()
notifyDel msgId cmd = notify cmd >> delMsg msgId
connError msgId = notifyDel msgId . ERR . CONN
qError msgId = notifyDel msgId . ERR . AGENT . A_QUEUE
internalErr msgId = notifyDel msgId . ERR . INTERNAL
retrySndOp :: AgentMonad m => AgentClient -> m () -> m ()
retrySndOp c loop = do
-- end... is in a separate atomically because if begin... blocks, SUSPENDED won't be sent
atomically $ endAgentOperation c AOSndNetwork
atomically $ throwWhenInactive c
atomically $ beginAgentOperation c AOSndNetwork
loop
ackMessage' :: forall m. AgentMonad m => AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> m ()
ackMessage' c connId msgId rcptInfo_ = withConnLock c connId "ackMessage" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection {} -> ack >> sendRcpt conn >> del
RcvConnection {} -> ack >> del
SndConnection {} -> throwError $ CONN SIMPLEX
ContactConnection {} -> throwError $ CMD PROHIBITED
NewConnection _ -> throwError $ CMD PROHIBITED
where
ack :: m ()
ack = do
-- the stored message was delivered via a specific queue, the rest failed to decrypt and were already acknowledged
(rq, srvMsgId) <- withStoreCtx "ackMessage': setMsgUserAck" c $ \db -> setMsgUserAck db connId $ InternalId msgId
ackQueueMessage c rq srvMsgId
del :: m ()
del = withStoreCtx' "ackMessage': deleteMsg" c $ \db -> deleteMsg db connId $ InternalId msgId
sendRcpt :: Connection 'CDuplex -> m ()
sendRcpt (DuplexConnection cData _ sqs) = do
msg@RcvMsg {msgType, msgReceipt} <- withStoreCtx "ackMessage': getRcvMsg" c $ \db -> getRcvMsg db connId $ InternalId msgId
case rcptInfo_ of
Just rcptInfo -> do
unless (msgType == AM_A_MSG_) $ throwError (CMD PROHIBITED)
when (messageRcptsSupported cData) $ do
let RcvMsg {msgMeta = MsgMeta {sndMsgId}, internalHash} = msg
rcpt = A_RCVD [AMessageReceipt {agentMsgId = sndMsgId, msgHash = internalHash, rcptInfo}]
void $ enqueueMessages c cData sqs SMP.MsgFlags {notification = False} rcpt
Nothing -> case (msgType, msgReceipt) of
-- only remove sent message if receipt hash was Ok, both to debug and for future redundancy
(AM_A_RCVD_, Just MsgReceipt {agentMsgId = sndMsgId, msgRcptStatus = MROk}) ->
withStoreCtx' "ackMessage': deleteDeliveredSndMsg" c $ \db -> deleteDeliveredSndMsg db connId $ InternalId sndMsgId
_ -> pure ()
switchConnection' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
switchConnection' c connId =
withConnLock c connId "switchConnection" $
withStore c (`getConn` connId) >>= \case
SomeConn _ conn@(DuplexConnection cData rqs@(rq :| _rqs) _)
| isJust (switchingRQ rqs) -> throwError $ CMD PROHIBITED
| otherwise -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
rq' <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSwitchStarted
switchDuplexConnection c conn rq'
_ -> throwError $ CMD PROHIBITED
switchDuplexConnection :: AgentMonad m => AgentClient -> Connection 'CDuplex -> RcvQueue -> m ConnectionStats
switchDuplexConnection c (DuplexConnection cData@ConnData {connId, userId} rqs sqs) rq@RcvQueue {server, dbQueueId, sndId} = do
checkRQSwchStatus rq RSSwitchStarted
clientVRange <- asks $ smpClientVRange . config
-- try to get the server that is different from all queues, or at least from the primary rcv queue
srvAuth@(ProtoServerWithAuth srv _) <- getNextServer c userId $ map qServer (L.toList rqs) <> map qServer (L.toList sqs)
srv' <- if srv == server then getNextServer c userId [server] else pure srvAuth
(q, qUri) <- newRcvQueue c userId connId srv' clientVRange SMSubscribe
let rq' = (q :: RcvQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
void . withStore c $ \db -> addConnRcvQueue db connId rq'
addSubscription c rq'
void . enqueueMessages c cData sqs SMP.noMsgFlags $ QADD [(qUri, Just (server, sndId))]
rq1 <- withStore' c $ \db -> setRcvSwitchStatus db rq $ Just RSSendingQADD
let rqs' = updatedQs rq1 rqs <> [rq']
pure . connectionStats $ DuplexConnection cData rqs' sqs
abortConnectionSwitch' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
abortConnectionSwitch' c connId =
withConnLock c connId "abortConnectionSwitch" $
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rqs sqs) -> case switchingRQ rqs of
Just rq
| canAbortRcvSwitch rq -> do
when (ratchetSyncSendProhibited cData) $ throwError $ CMD PROHIBITED
-- multiple queues to which the connections switches were possible when repeating switch was allowed
let (delRqs, keepRqs) = L.partition (\q -> Just rq.dbQueueId == q.dbReplaceQueueId) rqs
case L.nonEmpty keepRqs of
Just rqs' -> do
rq' <- withStore' c $ \db -> do
mapM_ (setRcvQueueDeleted db) delRqs
setRcvSwitchStatus db rq Nothing
forM_ delRqs $ \RcvQueue {server, rcvId} -> enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICDeleteRcvQueue rcvId
let rqs'' = updatedQs rq' rqs'
conn' = DuplexConnection cData rqs'' sqs
pure $ connectionStats conn'
_ -> throwError $ INTERNAL "won't delete all rcv queues in connection"
| otherwise -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
synchronizeRatchet' :: AgentMonad m => AgentClient -> ConnId -> Bool -> m ConnectionStats
synchronizeRatchet' c connId force = withConnLock c connId "synchronizeRatchet" $ do
withStore c (`getConn` connId) >>= \case
SomeConn _ (DuplexConnection cData rqs sqs)
| ratchetSyncAllowed cData || force -> do
-- check queues are not switching?
AgentConfig {e2eEncryptVRange} <- asks config
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
void $ enqueueRatchetKeyMsgs c cData sqs e2eParams
withStore' c $ \db -> do
setConnRatchetSync db connId RSStarted
setRatchetX3dhKeys db connId pk1 pk2 k1 k2
let cData' = cData {ratchetSyncState = RSStarted} :: ConnData
conn' = DuplexConnection cData' rqs sqs
pure $ connectionStats conn'
| otherwise -> throwError $ CMD PROHIBITED
_ -> throwError $ CMD PROHIBITED
ackQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> SMP.MsgId -> m ()
ackQueueMessage c rq srvMsgId =
sendAck c rq srvMsgId `catchAgentError` \case
SMP SMP.NO_MSG -> pure ()
e -> throwError e
-- | Suspend SMP agent connection (OFF command) in Reader monad
suspendConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
suspendConnection' c connId = withConnLock c connId "suspendConnection" $ do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection _ rqs _ -> mapM_ (suspendQueue c) rqs
RcvConnection _ rq -> suspendQueue c rq
ContactConnection _ rq -> suspendQueue c rq
SndConnection _ _ -> throwError $ CONN SIMPLEX
NewConnection _ -> throwError $ CMD PROHIBITED
-- | Delete SMP agent connection (DEL command) in Reader monad
-- unlike deleteConnectionAsync, this function does not mark connection as deleted in case of deletion failure
-- currently it is used only in tests
deleteConnection' :: forall m. AgentMonad m => AgentClient -> ConnId -> m ()
deleteConnection' c connId = toConnResult connId =<< deleteConnections' c [connId]
connRcvQueues :: Connection d -> [RcvQueue]
connRcvQueues = \case
DuplexConnection _ rqs _ -> L.toList rqs
RcvConnection _ rq -> [rq]
ContactConnection _ rq -> [rq]
SndConnection _ _ -> []
NewConnection _ -> []
disableConn :: AgentMonad m => AgentClient -> ConnId -> m ()
disableConn c connId = do
atomically $ removeSubscription c connId
ns <- asks ntfSupervisor
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCDelete)
-- Unlike deleteConnectionsAsync, this function does not mark connections as deleted in case of deletion failure.
deleteConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnections' = deleteConnections_ getConns False
deleteDeletedConns :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
deleteDeletedConns = deleteConnections_ getDeletedConns True
prepareDeleteConnections_ ::
forall m.
AgentMonad m =>
(DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) ->
AgentClient ->
[ConnId] ->
m (Map ConnId (Either AgentErrorType ()), [RcvQueue], [ConnId])
prepareDeleteConnections_ getConnections c connIds = do
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (`getConnections` connIds)
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(delRs, rcvQs) = M.mapEither rcvQueues cs
rqs = concat $ M.elems rcvQs
connIds' = M.keys rcvQs
forM_ connIds' $ disableConn c
withStore' c $ forM_ (M.keys delRs) . deleteConn
pure (errs' <> delRs, rqs, connIds')
where
rcvQueues :: SomeConn -> Either (Either AgentErrorType ()) [RcvQueue]
rcvQueues (SomeConn _ conn) = case connRcvQueues conn of
[] -> Left $ Right ()
rqs -> Right rqs
deleteConnQueues :: forall m. AgentMonad m => AgentClient -> Bool -> [RcvQueue] -> m (Map ConnId (Either AgentErrorType ()))
deleteConnQueues c ntf rqs = do
rs <- connResults <$> (deleteQueueRecs =<< deleteQueues c rqs)
forM_ (M.assocs rs) $ \case
(connId, Right _) -> withStore' c (`deleteConn` connId) >> notify ("", connId, APC SAEConn DEL_CONN)
_ -> pure ()
pure rs
where
deleteQueueRecs :: [(RcvQueue, Either AgentErrorType ())] -> m [(RcvQueue, Either AgentErrorType ())]
deleteQueueRecs rs = do
maxErrs <- asks $ deleteErrorCount . config
forM rs $ \(rq, r) -> do
r' <- case r of
Right _ -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq Nothing $> r
Left e
| temporaryOrHostError e && deleteErrors rq + 1 < maxErrs -> withStore' c (`incRcvDeleteErrors` rq) $> r
| otherwise -> withStore' c (`deleteConnRcvQueue` rq) >> notifyRQ rq (Just e) $> Right ()
pure (rq, r')
notifyRQ rq e_ = notify ("", rq.connId, APC SAEConn $ DEL_RCVQ (qServer rq) (queueId rq) e_)
notify = when ntf . atomically . writeTBQueue (subQ c)
connResults :: [(RcvQueue, Either AgentErrorType ())] -> Map ConnId (Either AgentErrorType ())
connResults = M.map snd . foldl' addResult M.empty
where
-- collects results by connection ID
addResult :: Map ConnId QCmdResult -> (RcvQueue, Either AgentErrorType ()) -> Map ConnId QCmdResult
addResult rs (RcvQueue {connId, status}, r) = M.alter (combineRes (status, r)) connId rs
-- combines two results for one connection, by prioritizing errors in Active queues
combineRes :: QCmdResult -> Maybe QCmdResult -> Maybe QCmdResult
combineRes r' (Just r) = Just $ if order r <= order r' then r else r'
combineRes r' _ = Just r'
order :: QCmdResult -> Int
order (Active, Left _) = 1
order (_, Left _) = 2
order _ = 3
deleteConnections_ ::
forall m.
AgentMonad m =>
(DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn]) ->
Bool ->
AgentClient ->
[ConnId] ->
m (Map ConnId (Either AgentErrorType ()))
deleteConnections_ _ _ _ [] = pure M.empty
deleteConnections_ getConnections ntf c connIds = do
(rs, rqs, _) <- prepareDeleteConnections_ getConnections c connIds
rcvRs <- deleteConnQueues c ntf rqs
let rs' = M.union rs rcvRs
notifyResultError rs'
pure rs'
where
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
notifyResultError rs = do
let actual = M.size rs
expected = length connIds
when (actual /= expected) . atomically $
writeTBQueue (subQ c) ("", "", APC SAEConn $ ERR $ INTERNAL $ "deleteConnections result size: " <> show actual <> ", expected " <> show expected)
getConnectionServers' :: AgentMonad m => AgentClient -> ConnId -> m ConnectionStats
getConnectionServers' c connId = do
SomeConn _ conn <- withStore c (`getConn` connId)
pure $ connectionStats conn
getConnectionRatchetAdHash' :: AgentMonad m => AgentClient -> ConnId -> m ByteString
getConnectionRatchetAdHash' c connId = do
CR.Ratchet {rcAD = Str rcAD} <- withStore c (`getRatchet` connId)
pure $ C.sha256Hash rcAD
connectionStats :: Connection c -> ConnectionStats
connectionStats = \case
RcvConnection cData rq ->
(stats cData) {rcvQueuesInfo = [rcvQueueInfo rq]}
SndConnection cData sq ->
(stats cData) {sndQueuesInfo = [sndQueueInfo sq]}
DuplexConnection cData rqs sqs ->
(stats cData) {rcvQueuesInfo = map rcvQueueInfo $ L.toList rqs, sndQueuesInfo = map sndQueueInfo $ L.toList sqs}
ContactConnection cData rq ->
(stats cData) {rcvQueuesInfo = [rcvQueueInfo rq]}
NewConnection cData ->
stats cData
where
stats cData@ConnData {connAgentVersion, ratchetSyncState} =
ConnectionStats
{ connAgentVersion,
rcvQueuesInfo = [],
sndQueuesInfo = [],
ratchetSyncState,
ratchetSyncSupported = ratchetSyncSupported' cData
}
-- | Change servers to be used for creating new queues, in Reader monad
setProtocolServers' :: forall p m. (ProtocolTypeI p, UserProtocol p, AgentMonad m) => AgentClient -> UserId -> NonEmpty (ProtoServerWithAuth p) -> m ()
setProtocolServers' c userId srvs = atomically $ TM.insert userId srvs (userServers c)
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus
registerNtfToken' c suppliedDeviceToken suppliedNtfMode =
withStore' c getSavedNtfToken >>= \case
Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId, ntfTknStatus, ntfTknAction, ntfMode = savedNtfMode} -> do
status <- case (ntfTokenId, ntfTknAction) of
(Nothing, Just NTARegister) -> do
when (savedDeviceToken /= suppliedDeviceToken) $ withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken
registerToken tkn $> NTRegistered
-- possible improvement: add minimal time before repeat registration
(Just tknId, Nothing)
| savedDeviceToken == suppliedDeviceToken ->
when (ntfTknStatus == NTRegistered) (registerToken tkn) $> NTRegistered
| otherwise -> replaceToken tknId
(Just tknId, Just (NTAVerify code))
| savedDeviceToken == suppliedDeviceToken ->
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
| otherwise -> replaceToken tknId
(Just tknId, Just NTACheck)
| savedDeviceToken == suppliedDeviceToken -> do
ns <- asks ntfSupervisor
atomically $ nsUpdateToken ns tkn {ntfMode = suppliedNtfMode}
when (ntfTknStatus == NTActive) $ do
cron <- asks $ ntfCron . config
agentNtfEnableCron c tknId tkn cron
when (suppliedNtfMode == NMInstant) $ initializeNtfSubs c
when (suppliedNtfMode == NMPeriodic && savedNtfMode == NMInstant) $ deleteNtfSubs c NSCDelete
-- possible improvement: get updated token status from the server, or maybe TCRON could return the current status
pure ntfTknStatus
| otherwise -> replaceToken tknId
(Just tknId, Just NTADelete) -> do
agentNtfDeleteToken c tknId tkn
withStore' c (`removeNtfToken` tkn)
ns <- asks ntfSupervisor
atomically $ nsRemoveNtfToken ns
pure NTExpired
_ -> pure ntfTknStatus
withStore' c $ \db -> updateNtfMode db tkn suppliedNtfMode
pure status
where
replaceToken :: NtfTokenId -> m NtfTknStatus
replaceToken tknId = do
ns <- asks ntfSupervisor
tryReplace ns `catchAgentError` \e ->
if temporaryOrHostError e
then throwError e
else do
withStore' c $ \db -> removeNtfToken db tkn
atomically $ nsRemoveNtfToken ns
createToken
where
tryReplace ns = do
agentNtfReplaceToken c tknId tkn suppliedDeviceToken
withStore' c $ \db -> updateDeviceToken db tkn suppliedDeviceToken
atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode}
pure NTRegistered
_ -> createToken
where
t tkn = withToken c tkn Nothing
createToken :: m NtfTknStatus
createToken =
getNtfServer c >>= \case
Just ntfServer ->
asks (cmdSignAlg . config) >>= \case
C.SignAlg a -> do
tknKeys <- liftIO $ C.generateSignatureKeyPair a
dhKeys <- liftIO C.generateKeyPair'
let tkn = newNtfToken suppliedDeviceToken ntfServer tknKeys dhKeys suppliedNtfMode
withStore' c (`createNtfToken` tkn)
registerToken tkn
pure NTRegistered
_ -> throwError $ CMD PROHIBITED
registerToken :: NtfToken -> m ()
registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
let dhSecret = C.dh' srvPubDhKey privDhKey
withStore' c $ \db -> updateNtfTokenRegistration db tkn tknId dhSecret
ns <- asks ntfSupervisor
atomically $ nsUpdateToken ns tkn {deviceToken = suppliedDeviceToken, ntfTknStatus = NTRegistered, ntfMode = suppliedNtfMode}
verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> C.CbNonce -> ByteString -> m ()
verifyNtfToken' c deviceToken nonce code =
withStore' c getSavedNtfToken >>= \case
Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret, ntfMode} -> do
when (deviceToken /= savedDeviceToken) . throwError $ CMD PROHIBITED
code' <- liftEither . bimap cryptoError NtfRegCode $ C.cbDecrypt dhSecret nonce code
toStatus <-
withToken c tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $
agentNtfVerifyToken c tknId tkn code'
when (toStatus == NTActive) $ do
cron <- asks $ ntfCron . config
agentNtfEnableCron c tknId tkn cron
when (ntfMode == NMInstant) $ initializeNtfSubs c
_ -> throwError $ CMD PROHIBITED
checkNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m NtfTknStatus
checkNtfToken' c deviceToken =
withStore' c getSavedNtfToken >>= \case
Just tkn@NtfToken {deviceToken = savedDeviceToken, ntfTokenId = Just tknId} -> do
when (deviceToken /= savedDeviceToken) . throwError $ CMD PROHIBITED
agentNtfCheckToken c tknId tkn
_ -> throwError $ CMD PROHIBITED
deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken' c deviceToken =
withStore' c getSavedNtfToken >>= \case
Just tkn@NtfToken {deviceToken = savedDeviceToken} -> do
when (deviceToken /= savedDeviceToken) . throwError $ CMD PROHIBITED
deleteToken_ c tkn
deleteNtfSubs c NSCSmpDelete
_ -> throwError $ CMD PROHIBITED
getNtfToken' :: AgentMonad m => AgentClient -> m (DeviceToken, NtfTknStatus, NotificationsMode)
getNtfToken' c =
withStore' c getSavedNtfToken >>= \case
Just NtfToken {deviceToken, ntfTknStatus, ntfMode} -> pure (deviceToken, ntfTknStatus, ntfMode)
_ -> throwError $ CMD PROHIBITED
getNtfTokenData' :: AgentMonad m => AgentClient -> m NtfToken
getNtfTokenData' c =
withStore' c getSavedNtfToken >>= \case
Just tkn -> pure tkn
_ -> throwError $ CMD PROHIBITED
-- | Set connection notifications, in Reader monad
toggleConnectionNtfs' :: forall m. AgentMonad m => AgentClient -> ConnId -> Bool -> m ()
toggleConnectionNtfs' c connId enable = do
SomeConn _ conn <- withStore c (`getConn` connId)
case conn of
DuplexConnection cData _ _ -> toggle cData
RcvConnection cData _ -> toggle cData
ContactConnection cData _ -> toggle cData
_ -> throwError $ CONN SIMPLEX
where
toggle :: ConnData -> m ()
toggle cData
| enableNtfs cData == enable = pure ()
| otherwise = do
withStore' c $ \db -> setConnectionNtfs db connId enable
ns <- asks ntfSupervisor
let cmd = if enable then NSCCreate else NSCDelete
atomically $ sendNtfSubCommand ns (connId, cmd)
deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m ()
deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
ns <- asks ntfSupervisor
forM_ ntfTokenId $ \tknId -> do
let ntfTknAction = Just NTADelete
withStore' c $ \db -> updateNtfToken db tkn ntfTknStatus ntfTknAction
atomically $ nsUpdateToken ns tkn {ntfTknStatus, ntfTknAction}
agentNtfDeleteToken c tknId tkn `catchAgentError` \case
NTF AUTH -> pure ()
e -> throwError e
withStore' c $ \db -> removeNtfToken db tkn
atomically $ nsRemoveNtfToken ns
withToken :: AgentMonad m => AgentClient -> NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m NtfTknStatus
withToken c tkn@NtfToken {deviceToken, ntfMode} from_ (toStatus, toAction_) f = do
ns <- asks ntfSupervisor
forM_ from_ $ \(status, action) -> do
withStore' c $ \db -> updateNtfToken db tkn status (Just action)
atomically $ nsUpdateToken ns tkn {ntfTknStatus = status, ntfTknAction = Just action}
tryError f >>= \case
Right _ -> do
withStore' c $ \db -> updateNtfToken db tkn toStatus toAction_
let updatedToken = tkn {ntfTknStatus = toStatus, ntfTknAction = toAction_}
atomically $ nsUpdateToken ns updatedToken
pure toStatus
Left e@(NTF AUTH) -> do
withStore' c $ \db -> removeNtfToken db tkn
atomically $ nsRemoveNtfToken ns
void $ registerNtfToken' c deviceToken ntfMode
throwError e
Left e -> throwError e
initializeNtfSubs :: AgentMonad m => AgentClient -> m ()
initializeNtfSubs c = sendNtfConnCommands c NSCCreate
deleteNtfSubs :: AgentMonad m => AgentClient -> NtfSupervisorCommand -> m ()
deleteNtfSubs c deleteCmd = do
ns <- asks ntfSupervisor
void . atomically . flushTBQueue $ ntfSubQ ns
sendNtfConnCommands c deleteCmd
sendNtfConnCommands :: AgentMonad m => AgentClient -> NtfSupervisorCommand -> m ()
sendNtfConnCommands c cmd = do
ns <- asks ntfSupervisor
connIds <- atomically $ getSubscriptions c
forM_ connIds $ \connId -> do
withStore' c (`getConnData` connId) >>= \case
Just (ConnData {enableNtfs}, _) ->
when enableNtfs . atomically $ writeTBQueue (ntfSubQ ns) (connId, cmd)
_ ->
atomically $ writeTBQueue (subQ c) ("", connId, APC SAEConn $ ERR $ INTERNAL "no connection data")
setNtfServers' :: AgentMonad' m => AgentClient -> [NtfServer] -> m ()
setNtfServers' c = atomically . writeTVar (ntfServers c)
foregroundAgent' :: AgentMonad' m => AgentClient -> m ()
foregroundAgent' c = do
atomically $ writeTVar (agentState c) ASForeground
mapM_ activate $ reverse agentOperations
where
activate opSel = atomically $ modifyTVar' (opSel c) $ \s -> s {opSuspended = False}
suspendAgent' :: AgentMonad' m => AgentClient -> Int -> m ()
suspendAgent' c 0 = do
atomically $ writeTVar (agentState c) ASSuspended
mapM_ suspend agentOperations
where
suspend opSel = atomically $ modifyTVar' (opSel c) $ \s -> s {opSuspended = True}
suspendAgent' c@AgentClient {agentState = as} maxDelay = do
state <-
atomically $ do
writeTVar as ASSuspending
suspendOperation c AONtfNetwork $ pure ()
suspendOperation c AORcvNetwork $
suspendOperation c AOMsgDelivery $
suspendSendingAndDatabase c
readTVar as
when (state == ASSuspending) . void . forkIO $ do
threadDelay maxDelay
-- liftIO $ putStrLn "suspendAgent after timeout"
atomically . whenSuspending c $ do
-- unsafeIOToSTM $ putStrLn $ "in timeout: suspendSendingAndDatabase"
suspendSendingAndDatabase c
execAgentStoreSQL' :: AgentMonad m => AgentClient -> Text -> m [Text]
execAgentStoreSQL' c sql = withStore' c (`execSQL` sql)
getAgentMigrations' :: AgentMonad m => AgentClient -> m [UpMigration]
getAgentMigrations' c = map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn)
debugAgentLocks' :: AgentMonad' m => AgentClient -> m AgentLocks
debugAgentLocks' AgentClient {connLocks = cs, invLocks = is, reconnectLocks = rs, deleteLock = d} = do
connLocks <- getLocks cs
invLocks <- getLocks is
srvLocks <- getLocks rs
delLock <- atomically $ tryReadTMVar d
pure AgentLocks {connLocks, invLocks, srvLocks, delLock}
where
getLocks ls = atomically $ M.mapKeys (B.unpack . strEncode) . M.mapMaybe id <$> (mapM tryReadTMVar =<< readTVar ls)
getSMPServer :: AgentMonad m => AgentClient -> UserId -> m SMPServerWithAuth
getSMPServer c userId = withUserServers c userId pickServer
subscriber :: AgentMonad' m => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
t <- atomically $ readTBQueue msgQ
agentOperationBracket c AORcvNetwork waitUntilActive $
runExceptT (processSMPTransmission c t) >>= \case
Left e -> liftIO $ print e
Right _ -> return ()
cleanupManager :: forall m. AgentMonad' m => AgentClient -> m ()
cleanupManager c@AgentClient {subQ} = do
delay <- asks (initialCleanupDelay . config)
liftIO $ threadDelay' delay
int <- asks (cleanupInterval . config)
ttl <- asks $ storedMsgDataTTL . config
forever $ do
run ERR deleteConns
run ERR $ withStore' c (`deleteRcvMsgHashesExpired` ttl)
run ERR $ withStore' c (`deleteSndMsgsExpired` ttl)
run ERR $ withStore' c (`deleteRatchetKeyHashesExpired` ttl)
run RFERR deleteRcvFilesExpired
run RFERR deleteRcvFilesDeleted
run RFERR deleteRcvFilesTmpPaths
run SFERR deleteSndFilesExpired
run SFERR deleteSndFilesDeleted
run SFERR deleteSndFilesPrefixPaths
run SFERR deleteExpiredReplicasForDeletion
liftIO $ threadDelay' int
where
run :: forall e. AEntityI e => (AgentErrorType -> ACommand 'Agent e) -> ExceptT AgentErrorType m () -> m ()
run err a = do
void . runExceptT $ a `catchAgentError` (notify "" . err)
step <- asks $ cleanupStepInterval . config
liftIO $ threadDelay step
deleteConns =
withLock (deleteLock c) "cleanupManager" $ do
void $ withStore' c getDeletedConnIds >>= deleteDeletedConns c
withStore' c deleteUsersWithoutConns >>= mapM_ (notify "" . DEL_USER)
deleteRcvFilesExpired = do
rcvFilesTTL <- asks $ rcvFilesTTL . config
rcvExpired <- withStore' c (`getRcvFilesExpired` rcvFilesTTL)
forM_ rcvExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
removePath =<< toFSFilePath p
withStore' c (`deleteRcvFile'` dbId)
deleteRcvFilesDeleted = do
rcvDeleted <- withStore' c getCleanupRcvFilesDeleted
forM_ rcvDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
removePath =<< toFSFilePath p
withStore' c (`deleteRcvFile'` dbId)
deleteRcvFilesTmpPaths = do
rcvTmpPaths <- withStore' c getCleanupRcvFilesTmpPaths
forM_ rcvTmpPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . RFERR) $ do
removePath =<< toFSFilePath p
withStore' c (`updateRcvFileNoTmpPath` dbId)
deleteSndFilesExpired = do
sndFilesTTL <- asks $ sndFilesTTL . config
sndExpired <- withStore' c (`getSndFilesExpired` sndFilesTTL)
forM_ sndExpired $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
forM_ p $ removePath <=< toFSFilePath
withStore' c (`deleteSndFile'` dbId)
deleteSndFilesDeleted = do
sndDeleted <- withStore' c getCleanupSndFilesDeleted
forM_ sndDeleted $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
forM_ p $ removePath <=< toFSFilePath
withStore' c (`deleteSndFile'` dbId)
deleteSndFilesPrefixPaths = do
sndPrefixPaths <- withStore' c getCleanupSndFilesPrefixPaths
forM_ sndPrefixPaths $ \(dbId, entId, p) -> flip catchAgentError (notify entId . SFERR) $ do
removePath =<< toFSFilePath p
withStore' c (`updateSndFileNoPrefixPath` dbId)
deleteExpiredReplicasForDeletion = do
rcvFilesTTL <- asks $ rcvFilesTTL . config
withStore' c (`deleteDeletedSndChunkReplicasExpired` rcvFilesTTL)
notify :: forall e. AEntityI e => EntityId -> ACommand 'Agent e -> ExceptT AgentErrorType m ()
notify entId cmd = atomically $ writeTBQueue subQ ("", entId, APC (sAEntity @e) cmd)
-- | make sure to ACK or throw in each message processing branch
-- it cannot be finally, unfortunately, as sometimes it needs to be ACK+DEL
processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m ()
processSMPTransmission c@AgentClient {smpClients, subQ} (tSess@(_, srv, _), v, sessId, rId, cmd) = do
(rq, SomeConn _ conn) <- withStore c (\db -> getRcvConn db srv rId)
processSMP rq conn $ toConnData conn
where
processSMP :: forall c. RcvQueue -> Connection c -> ConnData -> m ()
processSMP
rq@RcvQueue {e2ePrivKey, e2eDhSecret, status}
conn
cData@ConnData {userId, connId, duplexHandshake, connAgentVersion, ratchetSyncState = rss} =
withConnLock c connId "processSMP" $ case cmd of
SMP.MSG msg@SMP.RcvMessage {msgId = srvMsgId} ->
handleNotifyAck $
decryptSMPMessage v rq msg >>= \case
SMP.ClientRcvMsgBody {msgTs = srvTs, msgFlags, msgBody} -> processClientMsg srvTs msgFlags msgBody
SMP.ClientRcvMsgQuota {} -> queueDrained >> ack
where
queueDrained = case conn of
DuplexConnection _ _ sqs -> void $ enqueueMessages c cData sqs SMP.noMsgFlags $ QCONT (sndAddress rq)
_ -> pure ()
processClientMsg srvTs msgFlags msgBody = do
clientMsg@SMP.ClientMsgEnvelope {cmHeader = SMP.PubHeader phVer e2ePubKey_} <-
parseMessage msgBody
clientVRange <- asks $ smpClientVRange . config
unless (phVer `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case (e2eDhSecret, e2ePubKey_) of
(Nothing, Just e2ePubKey) -> do
let e2eDh = C.dh' e2ePubKey e2ePrivKey
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHConfirmation senderKey, AgentConfirmation {e2eEncryption_, encConnInfo, agentVersion}) ->
smpConfirmation conn senderKey e2ePubKey e2eEncryption_ encConnInfo phVer agentVersion >> ack
(SMP.PHEmpty, AgentInvitation {connReq, connInfo}) ->
smpInvitation conn connReq connInfo >> ack
_ -> prohibited >> ack
(Just e2eDh, Nothing) -> do
decryptClientMessage e2eDh clientMsg >>= \case
(SMP.PHEmpty, AgentRatchetKey {agentVersion, e2eEncryption}) -> do
conn' <- updateConnVersion conn cData agentVersion
qDuplex conn' "AgentRatchetKey" $ newRatchetKey e2eEncryption
ack
(SMP.PHEmpty, AgentMsgEnvelope {agentVersion, encAgentMessage}) -> do
conn' <- updateConnVersion conn cData agentVersion
-- primary queue is set as Active in helloMsg, below is to set additional queues Active
let RcvQueue {primary, dbReplaceQueueId} = rq
unless (status == Active) . withStore' c $ \db -> setRcvQueueStatus db rq Active
case (conn', dbReplaceQueueId) of
(DuplexConnection _ rqs _, Just replacedId) -> do
when primary . withStore' c $ \db -> setRcvQueuePrimary db connId rq
case find (\q -> replacedId == q.dbQueueId) rqs of
Just rq'@RcvQueue {server, rcvId} -> do
checkRQSwchStatus rq' RSSendingQUSE
void $ withStore' c $ \db -> setRcvSwitchStatus db rq' $ Just RSReceivedMessage
enqueueCommand c "" connId (Just server) $ AInternalCommand $ ICQDelete rcvId
_ -> notify . ERR . AGENT $ A_QUEUE "replaced RcvQueue not found in connection"
_ -> pure ()
let encryptedMsgHash = C.sha256Hash encAgentMessage
tryError (agentClientMsg encryptedMsgHash) >>= \case
Right (Just (msgId, msgMeta, aMessage, rcPrev)) -> do
conn'' <- resetRatchetSync
case aMessage of
HELLO -> helloMsg conn'' >> ackDel msgId
REPLY cReq -> replyMsg conn'' cReq >> ackDel msgId
-- note that there is no ACK sent for A_MSG, it is sent with agent's user ACK command
A_MSG body -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
A_RCVD rcpts -> qDuplex conn'' "RCVD" $ messagesRcvd rcpts msgMeta
QCONT addr -> qDuplexAckDel conn'' "QCONT" $ continueSending addr
QADD qs -> qDuplexAckDel conn'' "QADD" $ qAddMsg qs
QKEY qs -> qDuplexAckDel conn'' "QKEY" $ qKeyMsg qs
QUSE qs -> qDuplexAckDel conn'' "QUSE" $ qUseMsg qs
-- no action needed for QTEST
-- any message in the new queue will mark it active and trigger deletion of the old queue
QTEST _ -> logServer "<--" c srv rId "MSG <QTEST>" >> ackDel msgId
EREADY _ -> qDuplexAckDel conn'' "EREADY" $ ereadyMsg rcPrev
where
qDuplexAckDel :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplexAckDel conn'' name a = qDuplex conn'' name a >> ackDel msgId
resetRatchetSync :: m (Connection c)
resetRatchetSync
| rss `notElem` ([RSOk, RSStarted] :: [RatchetSyncState]) = do
let cData'' = (toConnData conn') {ratchetSyncState = RSOk} :: ConnData
conn'' = updateConnection cData'' conn'
notify . RSYNC RSOk Nothing $ connectionStats conn''
withStore' c $ \db -> setConnRatchetSync db connId RSOk
pure conn''
| otherwise = pure conn'
Right _ -> prohibited >> ack
Left e@(AGENT A_DUPLICATE) -> do
withStoreCtx' "processSMP: getLastMsg" c (\db -> getLastMsg db connId srvMsgId) >>= \case
Just RcvMsg {internalId, msgMeta, msgBody = agentMsgBody, userAck}
| userAck -> ackDel internalId
| otherwise -> do
liftEither (parse smpP (AGENT A_MESSAGE) agentMsgBody) >>= \case
AgentMessage _ (A_MSG body) -> do
logServer "<--" c srv rId "MSG <MSG>"
notify $ MSG msgMeta msgFlags body
_ -> pure ()
_ -> checkDuplicateHash e encryptedMsgHash >> ack
Left (AGENT (A_CRYPTO e)) -> do
exists <- withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash
unless exists notifySync
ack
where
notifySync :: m ()
notifySync = qDuplex conn' "AGENT A_CRYPTO error" $ \connDuplex -> do
let rss' = cryptoErrToSyncState e
when (rss `elem` ([RSOk, RSAllowed, RSRequired] :: [RatchetSyncState])) $ do
let cData'' = (toConnData conn') {ratchetSyncState = rss'} :: ConnData
conn'' = updateConnection cData'' connDuplex
notify . RSYNC rss' (Just e) $ connectionStats conn''
withStore' c $ \db -> setConnRatchetSync db connId rss'
Left e -> checkDuplicateHash e encryptedMsgHash >> ack
where
checkDuplicateHash :: AgentErrorType -> ByteString -> m ()
checkDuplicateHash e encryptedMsgHash =
unlessM (withStore' c $ \db -> checkRcvMsgHashExists db connId encryptedMsgHash) $
throwError e
agentClientMsg :: ByteString -> m (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448))
agentClientMsg encryptedMsgHash = withStore c $ \db -> runExceptT $ do
rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY
agentMsgBody <- agentRatchetDecrypt' db connId rc encAgentMessage
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
let msgType = agentMessageType agentMsg
internalHash = C.sha256Hash agentMsgBody
internalTs <- liftIO getCurrentTime
(internalId, internalRcvId, prevExtSndId, prevRcvMsgHash) <- liftIO $ updateRcvIds db connId
let integrity = checkMsgIntegrity prevExtSndId sndMsgId prevRcvMsgHash prevMsgHash
recipient = (unId internalId, internalTs)
broker = (srvMsgId, systemToUTCTime srvTs)
msgMeta = MsgMeta {integrity, recipient, broker, sndMsgId}
rcvMsg = RcvMsgData {msgMeta, msgType, msgFlags, msgBody = agentMsgBody, internalRcvId, internalHash, externalPrevSndHash = prevMsgHash, encryptedMsgHash}
liftIO $ createRcvMsg db connId rq rcvMsg
pure $ Just (internalId, msgMeta, aMessage, rc)
_ -> pure Nothing
_ -> prohibited >> ack
_ -> prohibited >> ack
updateConnVersion :: Connection c -> ConnData -> Version -> m (Connection c)
updateConnVersion conn' cData' msgAgentVersion = do
aVRange <- asks $ smpAgentVRange . config
let msgAVRange = fromMaybe (versionToRange msgAgentVersion) $ safeVersionRange (minVersion aVRange) msgAgentVersion
case msgAVRange `compatibleVersion` aVRange of
Just (Compatible av)
| av > connAgentVersion -> do
withStore' c $ \db -> setConnAgentVersion db connId av
let cData'' = cData' {connAgentVersion = av} :: ConnData
pure $ updateConnection cData'' conn'
| otherwise -> pure conn'
Nothing -> pure conn'
ack :: m ()
ack = enqueueCmd $ ICAck rId srvMsgId
ackDel :: InternalId -> m ()
ackDel = enqueueCmd . ICAckDel rId srvMsgId
handleNotifyAck :: m () -> m ()
handleNotifyAck m = m `catchAgentError` \e -> notify (ERR e) >> ack
SMP.END ->
atomically (TM.lookup tSess smpClients $>>= tryReadTMVar >>= processEND)
>>= logServer "<--" c srv rId
where
processEND = \case
Just (Right clnt)
| sessId == sessionId clnt -> do
removeSubscription c connId
notify' END
pure "END"
| otherwise -> ignored
_ -> ignored
ignored = pure "END from disconnected client - ignored"
_ -> do
logServer "<--" c srv rId $ "unexpected: " <> bshow cmd
notify . ERR $ BROKER (B.unpack $ strEncode srv) UNEXPECTED
where
notify :: forall e. AEntityI e => ACommand 'Agent e -> m ()
notify = atomically . notify'
notify' :: forall e. AEntityI e => ACommand 'Agent e -> STM ()
notify' msg = writeTBQueue subQ ("", connId, APC (sAEntity @e) msg)
prohibited :: m ()
prohibited = notify . ERR $ AGENT A_PROHIBITED
enqueueCmd :: InternalCommand -> m ()
enqueueCmd = enqueueCommand c "" connId (Just srv) . AInternalCommand
decryptClientMessage :: C.DhSecretX25519 -> SMP.ClientMsgEnvelope -> m (SMP.PrivHeader, AgentMsgEnvelope)
decryptClientMessage e2eDh SMP.ClientMsgEnvelope {cmNonce, cmEncBody} = do
clientMsg <- agentCbDecrypt e2eDh cmNonce cmEncBody
SMP.ClientMessage privHeader clientBody <- parseMessage clientMsg
agentEnvelope <- parseMessage clientBody
-- Version check is removed here, because when connecting via v1 contact address the agent still sends v2 message,
-- to allow duplexHandshake mode, in case the receiving agent was updated to v2 after the address was created.
-- aVRange <- asks $ smpAgentVRange . config
-- if agentVersion agentEnvelope `isCompatible` aVRange
-- then pure (privHeader, agentEnvelope)
-- else throwError $ AGENT A_VERSION
pure (privHeader, agentEnvelope)
parseMessage :: Encoding a => ByteString -> m a
parseMessage = liftEither . parse smpP (AGENT A_MESSAGE)
smpConfirmation :: Connection c -> C.APublicVerifyKey -> C.PublicKeyX25519 -> Maybe (CR.E2ERatchetParams 'C.X448) -> ByteString -> Version -> Version -> m ()
smpConfirmation conn' senderKey e2ePubKey e2eEncryption encConnInfo smpClientVersion agentVersion = do
logServer "<--" c srv rId "MSG <CONF>"
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
unless
(agentVersion `isCompatible` smpAgentVRange && smpClientVersion `isCompatible` smpClientVRange)
(throwError $ AGENT A_VERSION)
case status of
New -> case (conn', e2eEncryption) of
-- party initiating connection
(RcvConnection {}, Just e2eSndParams@(CR.E2ERatchetParams e2eVersion _ _)) -> do
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION)
(pk1, rcDHRs) <- withStore c (`getRatchetX3dhKeys` connId)
let rc = CR.initRcvRatchet e2eEncryptVRange rcDHRs $ CR.x3dhRcv pk1 rcDHRs e2eSndParams
(agentMsgBody_, rc', skipped) <- liftError cryptoError $ CR.rcDecrypt rc M.empty encConnInfo
case (agentMsgBody_, skipped) of
(Right agentMsgBody, CR.SMDNoChange) ->
parseMessage agentMsgBody >>= \case
AgentConnInfo connInfo ->
processConf connInfo SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues = [], smpClientVersion} False
AgentConnInfoReply smpQueues connInfo ->
processConf connInfo SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues = L.toList smpQueues, smpClientVersion} True
_ -> prohibited
where
processConf connInfo senderConf duplexHS = do
let newConfirmation = NewConfirmation {connId, senderConf, ratchetState = rc'}
g <- asks idsDrg
confId <- withStore c $ \db -> do
setHandshakeVersion db connId agentVersion duplexHS
createConfirmation db g newConfirmation
let srvs = map qServer $ smpReplyQueues senderConf
notify $ CONF confId srvs connInfo
_ -> prohibited
-- party accepting connection
(DuplexConnection _ (RcvQueue {smpClientVersion = v'} :| _) _, Nothing) -> do
withStore c (\db -> runExceptT $ agentRatchetDecrypt db connId encConnInfo) >>= parseMessage >>= \case
AgentConnInfo connInfo -> do
notify $ INFO connInfo
let dhSecret = C.dh' e2ePubKey e2ePrivKey
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq dhSecret $ min v' smpClientVersion
enqueueCmd $ ICDuplexSecure rId senderKey
_ -> prohibited
_ -> prohibited
_ -> prohibited
helloMsg :: Connection c -> m ()
helloMsg conn' = do
logServer "<--" c srv rId "MSG <HELLO>"
case status of
Active -> prohibited
_ ->
case conn' of
DuplexConnection _ _ (sq@SndQueue {status = sndStatus} :| _)
-- `sndStatus == Active` when HELLO was previously sent, and this is the reply HELLO
-- this branch is executed by the accepting party in duplexHandshake mode (v2)
-- and by the initiating party in v1
-- Also see comment where HELLO is sent.
| sndStatus == Active -> notify CON
| duplexHandshake == Just True -> enqueueDuplexHello sq
| otherwise -> pure ()
_ -> pure ()
where
enqueueDuplexHello :: SndQueue -> m ()
enqueueDuplexHello sq = do
let cData' = toConnData conn'
void $ enqueueMessage c cData' sq SMP.MsgFlags {notification = True} HELLO
replyMsg :: Connection c -> NonEmpty SMPQueueInfo -> m ()
replyMsg conn' smpQueues = do
logServer "<--" c srv rId "MSG <REPLY>"
case duplexHandshake of
Just True -> prohibited
_ -> case conn' of
RcvConnection {} -> do
AcceptedConfirmation {ownConnInfo} <- withStore c (`getAcceptedConfirmation` connId)
let cData' = toConnData conn'
connectReplyQueues c cData' ownConnInfo smpQueues `catchAgentError` (notify . ERR)
_ -> prohibited
continueSending :: (SMPServer, SMP.SenderId) -> Connection 'CDuplex -> m ()
continueSending addr (DuplexConnection _ _ sqs) =
case findQ addr sqs of
Just sq -> do
logServer "<--" c srv rId "MSG <QCONT>"
atomically $ do
(_, qLock) <- getPendingMsgQ c sq
void $ tryPutTMVar qLock ()
Nothing -> qError "QCONT: queue address not found"
messagesRcvd :: NonEmpty AMessageReceipt -> MsgMeta -> Connection 'CDuplex -> m ()
messagesRcvd rcpts msgMeta@MsgMeta {broker = (srvMsgId, _)} _ = do
logServer "<--" c srv rId "MSG <RCPT>"
rs <- forM rcpts $ \rcpt -> clientReceipt rcpt `catchAgentError` \e -> notify (ERR e) $> Nothing
case L.nonEmpty . catMaybes $ L.toList rs of
Just rs' -> notify $ RCVD msgMeta rs' -- client must ACK once processed
Nothing -> enqueueCmd $ ICAck rId srvMsgId
where
clientReceipt :: AMessageReceipt -> m (Maybe MsgReceipt)
clientReceipt AMessageReceipt {agentMsgId, msgHash} = do
let sndMsgId = InternalSndId agentMsgId
SndMsg {internalId = InternalId msgId, msgType, internalHash, msgReceipt} <- withStoreCtx "messagesRcvd: getSndMsgViaRcpt" c $ \db -> getSndMsgViaRcpt db connId sndMsgId
if msgType /= AM_A_MSG_
then notify (ERR $ AGENT A_PROHIBITED) $> Nothing -- unexpected message type for receipt
else case msgReceipt of
Just MsgReceipt {msgRcptStatus = MROk} -> pure Nothing -- already notified with MROk status
_ -> do
let msgRcptStatus = if msgHash == internalHash then MROk else MRBadMsgHash
rcpt = MsgReceipt {agentMsgId = msgId, msgRcptStatus}
withStore' c $ \db -> updateSndMsgRcpt db connId sndMsgId rcpt
pure $ Just rcpt
-- processed by queue sender
qAddMsg :: NonEmpty (SMPQueueUri, Maybe SndQAddr) -> Connection 'CDuplex -> m ()
qAddMsg ((_, Nothing) :| _) _ = qError "adding queue without switching is not supported"
qAddMsg ((qUri, Just addr) :| _) (DuplexConnection cData' rqs sqs) = do
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
clientVRange <- asks $ smpClientVRange . config
case qUri `compatibleVersion` clientVRange of
Just qInfo@(Compatible sqInfo@SMPQueueInfo {queueAddress}) ->
case (findQ (qAddress sqInfo) sqs, findQ addr sqs) of
(Just _, _) -> qError "QADD: queue address is already used in connection"
(_, Just sq@SndQueue {dbQueueId}) -> do
let (delSqs, keepSqs) = L.partition (\q -> Just dbQueueId == q.dbReplaceQueueId) sqs
case L.nonEmpty keepSqs of
Just sqs' -> do
-- move inside case?
withStore' c $ \db -> mapM_ (deleteConnSndQueue db connId) delSqs
sq_@SndQueue {sndPublicKey, e2ePubKey} <- newSndQueue userId connId qInfo
let sq'' = (sq_ :: SndQueue) {primary = True, dbQueueId, dbReplaceQueueId = Just dbQueueId}
dbId <- withStore c $ \db -> addConnSndQueue db connId sq''
let sq2 = (sq'' :: SndQueue) {dbQueueId = dbId}
case (sndPublicKey, e2ePubKey) of
(Just sndPubKey, Just dhPublicKey) -> do
logServer "<--" c srv rId $ "MSG <QADD> " <> logSecret (senderId queueAddress)
let sqInfo' = (sqInfo :: SMPQueueInfo) {queueAddress = queueAddress {dhPublicKey}}
void . enqueueMessages c cData' sqs SMP.noMsgFlags $ QKEY [(sqInfo', sndPubKey)]
sq1 <- withStore' c $ \db -> setSndSwitchStatus db sq $ Just SSSendingQKEY
let sqs'' = updatedQs sq1 sqs' <> [sq2]
conn' = DuplexConnection cData' rqs sqs''
notify . SWITCH QDSnd SPStarted $ connectionStats conn'
_ -> qError "absent sender keys"
_ -> qError "QADD: won't delete all snd queues in connection"
_ -> qError "QADD: replaced queue address is not found in connection"
_ -> throwError $ AGENT A_VERSION
-- processed by queue recipient
qKeyMsg :: NonEmpty (SMPQueueInfo, SndPublicVerifyKey) -> Connection 'CDuplex -> m ()
qKeyMsg ((qInfo, senderKey) :| _) conn'@(DuplexConnection cData' rqs _) = do
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
clientVRange <- asks $ smpClientVRange . config
unless (qInfo `isCompatible` clientVRange) . throwError $ AGENT A_VERSION
case findRQ (smpServer, senderId) rqs of
Just rq'@RcvQueue {rcvId, e2ePrivKey = dhPrivKey, smpClientVersion = cVer, status = status'}
| status' == New || status' == Confirmed -> do
checkRQSwchStatus rq RSSendingQADD
logServer "<--" c srv rId $ "MSG <QKEY> " <> logSecret senderId
let dhSecret = C.dh' dhPublicKey dhPrivKey
withStore' c $ \db -> setRcvQueueConfirmedE2E db rq' dhSecret $ min cVer cVer'
enqueueCommand c "" connId (Just smpServer) $ AInternalCommand $ ICQSecure rcvId senderKey
notify . SWITCH QDRcv SPConfirmed $ connectionStats conn'
| otherwise -> qError "QKEY: queue already secured"
_ -> qError "QKEY: queue address not found in connection"
where
SMPQueueInfo cVer' SMPQueueAddress {smpServer, senderId, dhPublicKey} = qInfo
-- processed by queue sender
-- mark queue as Secured and to start sending messages to it
qUseMsg :: NonEmpty ((SMPServer, SMP.SenderId), Bool) -> Connection 'CDuplex -> m ()
-- NOTE: does not yet support the change of the primary status during the rotation
qUseMsg ((addr, _primary) :| _) (DuplexConnection cData' rqs sqs) = do
when (ratchetSyncSendProhibited cData') $ throwError $ AGENT (A_QUEUE "ratchet is not synchronized")
case findQ addr sqs of
Just sq'@SndQueue {dbReplaceQueueId = Just replaceQId} -> do
case find (\q -> replaceQId == q.dbQueueId) sqs of
Just sq1 -> do
checkSQSwchStatus sq1 SSSendingQKEY
logServer "<--" c srv rId $ "MSG <QUSE> " <> logSecret (snd addr)
withStore' c $ \db -> setSndQueueStatus db sq' Secured
let sq'' = (sq' :: SndQueue) {status = Secured}
-- sending QTEST to the new queue only, the old one will be removed if sent successfully
void $ enqueueMessages c cData' [sq''] SMP.noMsgFlags $ QTEST [addr]
sq1' <- withStore' c $ \db -> setSndSwitchStatus db sq1 $ Just SSSendingQTEST
let sqs' = updatedQs sq1' sqs
conn' = DuplexConnection cData' rqs sqs'
notify . SWITCH QDSnd SPSecured $ connectionStats conn'
_ -> qError "QUSE: switching SndQueue not found in connection"
_ -> qError "QUSE: switched queue address not found in connection"
qError :: String -> m ()
qError = throwError . AGENT . A_QUEUE
ereadyMsg :: CR.RatchetX448 -> Connection 'CDuplex -> m ()
ereadyMsg rcPrev (DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) = do
let CR.Ratchet {rcSnd} = rcPrev
-- if ratchet was initialized as receiving, it means EREADY wasn't sent on key negotiation
when (isNothing rcSnd) . void $
enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} (EREADY lastExternalSndId)
smpInvitation :: Connection c -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
smpInvitation conn' connReq@(CRInvitationUri crData _) cInfo = do
logServer "<--" c srv rId "MSG <KEY>"
case conn' of
ContactConnection {} -> do
g <- asks idsDrg
let newInv = NewInvitation {contactConnId = connId, connReq, recipientConnInfo = cInfo}
invId <- withStore c $ \db -> createInvitation db g newInv
let srvs = L.map qServer $ crSmpQueues crData
notify $ REQ invId srvs cInfo
_ -> prohibited
qDuplex :: Connection c -> String -> (Connection 'CDuplex -> m ()) -> m ()
qDuplex conn' name action = case conn' of
DuplexConnection {} -> action conn'
_ -> qError $ name <> ": message must be sent to duplex connection"
newRatchetKey :: CR.E2ERatchetParams 'C.X448 -> Connection 'CDuplex -> m ()
newRatchetKey e2eOtherPartyParams@(CR.E2ERatchetParams e2eVersion k1Rcv k2Rcv) conn'@(DuplexConnection cData'@ConnData {lastExternalSndId} _ sqs) =
unlessM ratchetExists $ do
AgentConfig {e2eEncryptVRange} <- asks config
unless (e2eVersion `isCompatible` e2eEncryptVRange) (throwError $ AGENT A_VERSION)
keys <- getSendRatchetKeys
initRatchet e2eEncryptVRange keys
notifyAgreed
where
rkHashRcv = rkHash k1Rcv k2Rcv
rkHash k1 k2 = C.sha256Hash $ C.pubKeyBytes k1 <> C.pubKeyBytes k2
ratchetExists :: m Bool
ratchetExists = withStore' c $ \db -> do
exists <- checkRatchetKeyHashExists db connId rkHashRcv
unless exists $ addProcessedRatchetKeyHash db connId rkHashRcv
pure exists
getSendRatchetKeys :: m (C.PrivateKeyX448, C.PrivateKeyX448, C.PublicKeyX448, C.PublicKeyX448)
getSendRatchetKeys
| rss == RSStarted = withStore c (`getRatchetX3dhKeys'` connId)
| otherwise = do
(pk1, pk2, e2eParams@(CR.E2ERatchetParams _ k1 k2)) <- liftIO . CR.generateE2EParams $ version e2eOtherPartyParams
void $ enqueueRatchetKeyMsgs c cData' sqs e2eParams
pure (pk1, pk2, k1, k2)
notifyAgreed :: m ()
notifyAgreed = do
let cData'' = cData' {ratchetSyncState = RSAgreed} :: ConnData
conn'' = updateConnection cData'' conn'
notify . RSYNC RSAgreed Nothing $ connectionStats conn''
recreateRatchet :: CR.Ratchet 'C.X448 -> m ()
recreateRatchet rc = withStore' c $ \db -> do
setConnRatchetSync db connId RSAgreed
deleteRatchet db connId
createRatchet db connId rc
-- compare public keys `k1` in AgentRatchetKey messages sent by self and other party
-- to determine ratchet initilization ordering
initRatchet :: VersionRange -> (C.PrivateKeyX448, C.PrivateKeyX448, C.PublicKeyX448, C.PublicKeyX448) -> m ()
initRatchet e2eEncryptVRange (pk1, pk2, k1, k2)
| rkHash k1 k2 <= rkHashRcv = do
recreateRatchet $ CR.initRcvRatchet e2eEncryptVRange pk2 $ CR.x3dhRcv pk1 pk2 e2eOtherPartyParams
| otherwise = do
(_, rcDHRs) <- liftIO C.generateKeyPair'
recreateRatchet $ CR.initSndRatchet e2eEncryptVRange k2Rcv rcDHRs $ CR.x3dhSnd pk1 pk2 e2eOtherPartyParams
void . enqueueMessages' c cData' sqs SMP.MsgFlags {notification = True} $ EREADY lastExternalSndId
checkMsgIntegrity :: PrevExternalSndId -> ExternalSndId -> PrevRcvMsgHash -> ByteString -> MsgIntegrity
checkMsgIntegrity prevExtSndId extSndId internalPrevMsgHash receivedPrevMsgHash
| extSndId == prevExtSndId + 1 && internalPrevMsgHash == receivedPrevMsgHash = MsgOk
| extSndId < prevExtSndId = MsgError $ MsgBadId extSndId
| extSndId == prevExtSndId = MsgError MsgDuplicate -- ? deduplicate
| extSndId > prevExtSndId + 1 = MsgError $ MsgSkipped (prevExtSndId + 1) (extSndId - 1)
| internalPrevMsgHash /= receivedPrevMsgHash = MsgError MsgBadHash
| otherwise = MsgError MsgDuplicate -- this case is not possible
checkRQSwchStatus :: AgentMonad m => RcvQueue -> RcvSwitchStatus -> m ()
checkRQSwchStatus rq@RcvQueue {rcvSwchStatus} expected =
unless (rcvSwchStatus == Just expected) $ switchStatusError rq expected rcvSwchStatus
checkSQSwchStatus :: AgentMonad m => SndQueue -> SndSwitchStatus -> m ()
checkSQSwchStatus sq@SndQueue {sndSwchStatus} expected =
unless (sndSwchStatus == Just expected) $ switchStatusError sq expected sndSwchStatus
switchStatusError :: (SMPQueueRec q, AgentMonad m, Show a) => q -> a -> Maybe a -> m ()
switchStatusError q expected actual =
throwError . INTERNAL $
("unexpected switch status, queueId=" <> show (queueId q))
<> (", expected=" <> show expected)
<> (", actual=" <> show actual)
connectReplyQueues :: AgentMonad m => AgentClient -> ConnData -> ConnInfo -> NonEmpty SMPQueueInfo -> m ()
connectReplyQueues c cData@ConnData {userId, connId} ownConnInfo (qInfo :| _) = do
clientVRange <- asks $ smpClientVRange . config
case qInfo `proveCompatible` clientVRange of
Nothing -> throwError $ AGENT A_VERSION
Just qInfo' -> do
sq <- newSndQueue userId connId qInfo'
dbQueueId <- withStore c $ \db -> upgradeRcvConnToDuplex db connId sq
enqueueConfirmation c cData sq {dbQueueId} ownConnInfo Nothing
confirmQueueAsync :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m ()
confirmQueueAsync v c cData sq srv connInfo e2eEncryption_ subMode = do
resumeMsgDelivery c cData sq
msgId <- storeConfirmation c cData sq e2eEncryption_ =<< mkAgentConfirmation v c cData sq srv connInfo subMode
queuePendingMsgs c sq [msgId]
confirmQueue :: forall m. AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> SubscriptionMode -> m ()
confirmQueue v@(Compatible agentVersion) c cData@ConnData {connId} sq srv connInfo e2eEncryption_ subMode = do
msg <- mkConfirmation =<< mkAgentConfirmation v c cData sq srv connInfo subMode
sendConfirmation c sq msg
withStore' c $ \db -> setSndQueueStatus db sq Confirmed
where
mkConfirmation :: AgentMessage -> m MsgBody
mkConfirmation aMessage = withStore c $ \db -> runExceptT $ do
void . liftIO $ updateSndIds db connId
encConnInfo <- agentRatchetEncrypt db connId (smpEncode aMessage) e2eEncConnInfoLength
pure . smpEncode $ AgentConfirmation {agentVersion, e2eEncryption_, encConnInfo}
mkAgentConfirmation :: AgentMonad m => Compatible Version -> AgentClient -> ConnData -> SndQueue -> SMPServerWithAuth -> ConnInfo -> SubscriptionMode -> m AgentMessage
mkAgentConfirmation (Compatible agentVersion) c cData sq srv connInfo subMode
| agentVersion == 1 = pure $ AgentConnInfo connInfo
| otherwise = do
qInfo <- createReplyQueue c cData sq subMode srv
pure $ AgentConnInfoReply (qInfo :| []) connInfo
enqueueConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> ConnInfo -> Maybe (CR.E2ERatchetParams 'C.X448) -> m ()
enqueueConfirmation c cData sq connInfo e2eEncryption_ = do
resumeMsgDelivery c cData sq
msgId <- storeConfirmation c cData sq e2eEncryption_ $ AgentConnInfo connInfo
queuePendingMsgs c sq [msgId]
storeConfirmation :: AgentMonad m => AgentClient -> ConnData -> SndQueue -> Maybe (CR.E2ERatchetParams 'C.X448) -> AgentMessage -> m InternalId
storeConfirmation c ConnData {connId, connAgentVersion} sq e2eEncryption_ agentMsg = withStore c $ \db -> runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
let agentMsgStr = smpEncode agentMsg
internalHash = C.sha256Hash agentMsgStr
encConnInfo <- agentRatchetEncrypt db connId agentMsgStr e2eEncConnInfoLength
let msgBody = smpEncode $ AgentConfirmation {agentVersion = connAgentVersion, e2eEncryption_, encConnInfo}
msgType = agentMessageType agentMsg
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash}
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
enqueueRatchetKeyMsgs :: forall m. AgentMonad m => AgentClient -> ConnData -> NonEmpty SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId
enqueueRatchetKeyMsgs c cData (sq :| sqs) e2eEncryption = do
msgId <- enqueueRatchetKey c cData sq e2eEncryption
mapM_ (enqueueSavedMessage c cData msgId) $
filter (\SndQueue {status} -> status == Secured || status == Active) sqs
pure msgId
enqueueRatchetKey :: forall m. AgentMonad m => AgentClient -> ConnData -> SndQueue -> CR.E2ERatchetParams 'C.X448 -> m AgentMsgId
enqueueRatchetKey c cData@ConnData {connId} sq e2eEncryption = do
resumeMsgDelivery c cData sq
aVRange <- asks $ smpAgentVRange . config
msgId <- storeRatchetKey $ maxVersion aVRange
queuePendingMsgs c sq [msgId]
pure $ unId msgId
where
storeRatchetKey :: Version -> m InternalId
storeRatchetKey agentVersion = withStore c $ \db -> runExceptT $ do
internalTs <- liftIO getCurrentTime
(internalId, internalSndId, prevMsgHash) <- liftIO $ updateSndIds db connId
let agentMsg = AgentRatchetInfo ""
agentMsgStr = smpEncode agentMsg
internalHash = C.sha256Hash agentMsgStr
let msgBody = smpEncode $ AgentRatchetKey {agentVersion, e2eEncryption, info = agentMsgStr}
msgType = agentMessageType agentMsg
msgData = SndMsgData {internalId, internalSndId, internalTs, msgType, msgBody, msgFlags = SMP.MsgFlags {notification = True}, internalHash, prevMsgHash}
liftIO $ createSndMsg db connId msgData
liftIO $ createSndMsgDelivery db connId sq internalId
pure internalId
-- encoded AgentMessage -> encoded EncAgentMessage
agentRatchetEncrypt :: DB.Connection -> ConnId -> ByteString -> Int -> ExceptT StoreError IO ByteString
agentRatchetEncrypt db connId msg paddedLen = do
rc <- ExceptT $ getRatchet db connId
(encMsg, rc') <- liftE (SEAgentError . cryptoError) $ CR.rcEncrypt rc paddedLen msg
liftIO $ updateRatchet db connId rc' CR.SMDNoChange
pure encMsg
-- encoded EncAgentMessage -> encoded AgentMessage
agentRatchetDecrypt :: DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO ByteString
agentRatchetDecrypt db connId encAgentMsg = do
rc <- ExceptT $ getRatchet db connId
agentRatchetDecrypt' db connId rc encAgentMsg
agentRatchetDecrypt' :: DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO ByteString
agentRatchetDecrypt' db connId rc encAgentMsg = do
skipped <- liftIO $ getSkippedMsgKeys db connId
(agentMsgBody_, rc', skippedDiff) <- liftE (SEAgentError . cryptoError) $ CR.rcDecrypt rc skipped encAgentMsg
liftIO $ updateRatchet db connId rc' skippedDiff
liftEither $ first (SEAgentError . cryptoError) agentMsgBody_
newSndQueue :: (MonadUnliftIO m, MonadReader Env m) => UserId -> ConnId -> Compatible SMPQueueInfo -> m SndQueue
newSndQueue userId connId (Compatible (SMPQueueInfo smpClientVersion SMPQueueAddress {smpServer, senderId, dhPublicKey = rcvE2ePubDhKey})) = do
C.SignAlg a <- asks $ cmdSignAlg . config
(sndPublicKey, sndPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
(e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
pure
SndQueue
{ userId,
connId,
server = smpServer,
sndId = senderId,
sndPublicKey = Just sndPublicKey,
sndPrivateKey,
e2eDhSecret = C.dh' rcvE2ePubDhKey e2ePrivKey,
e2ePubKey = Just e2ePubKey,
status = New,
dbQueueId = 0,
primary = True,
dbReplaceQueueId = Nothing,
sndSwchStatus = Nothing,
smpClientVersion
}