Mirror of https://github.com/simplex-chat/simplexmq.git
Synced 2026-03-30 14:16:00 +00:00
392 lines · 15 KiB · Haskell
{-# LANGUAGE ConstraintKinds #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE InstanceSigs #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
|
|
module Simplex.Messaging.Agent.Client
|
|
( AgentClient (..),
|
|
newAgentClient,
|
|
AgentMonad,
|
|
withAgentLock,
|
|
closeAgentClient,
|
|
newRcvQueue,
|
|
subscribeQueue,
|
|
addSubscription,
|
|
sendConfirmation,
|
|
sendInvitation,
|
|
RetryInterval (..),
|
|
secureQueue,
|
|
sendAgentMessage,
|
|
agentCbEncrypt,
|
|
agentCbDecrypt,
|
|
cryptoError,
|
|
sendAck,
|
|
suspendQueue,
|
|
deleteQueue,
|
|
logServer,
|
|
removeSubscription,
|
|
)
|
|
where
|
|
|
|
import Control.Concurrent.Async (Async, async, uninterruptibleCancel)
|
|
import Control.Concurrent.STM (stateTVar)
|
|
import Control.Logger.Simple
|
|
import Control.Monad.Except
|
|
import Control.Monad.IO.Unlift
|
|
import Control.Monad.Reader
|
|
import Data.Bifunctor (first)
|
|
import Data.ByteString.Base64
|
|
import Data.ByteString.Char8 (ByteString)
|
|
import qualified Data.ByteString.Char8 as B
|
|
import Data.Map.Strict (Map)
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Maybe (isNothing)
|
|
import Data.Set (Set)
|
|
import qualified Data.Set as S
|
|
import Data.Text.Encoding
|
|
import Simplex.Messaging.Agent.Env.SQLite
|
|
import Simplex.Messaging.Agent.Protocol
|
|
import Simplex.Messaging.Agent.RetryInterval
|
|
import Simplex.Messaging.Agent.Store
|
|
import Simplex.Messaging.Client
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Encoding
|
|
import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
|
import qualified Simplex.Messaging.Protocol as SMP
|
|
import Simplex.Messaging.Util (bshow, liftEitherError, liftError)
|
|
import Simplex.Messaging.Version
|
|
import UnliftIO.Exception (Exception, IOException)
|
|
import qualified UnliftIO.Exception as E
|
|
import UnliftIO.STM
|
|
|
|
-- | State of a single agent client: its transmission queues, the SMP server
-- connections it holds, its subscription bookkeeping and the background
-- threads it owns.
data AgentClient = AgentClient
  { -- | Bounded queue of 'Client' transmissions (capacity 'tbqSize', see 'newAgentClient').
    rcvQ :: TBQueue (ATransmission 'Client),
    -- | Bounded queue of 'Agent' transmissions; subscription events (DOWN, UP, ERR)
    -- are written here by 'getSMPServerClient'.
    subQ :: TBQueue (ATransmission 'Agent),
    -- | Queue handed to 'getSMPClient' for every server connection this client opens.
    msgQ :: TBQueue SMPServerTransmission,
    -- | Connected SMP clients, at most one per server.
    smpClients :: TVar (Map SMPServer SMPClient),
    -- | Subscribed receive queues, grouped by server and then by connection.
    subscrSrvrs :: TVar (Map SMPServer (Map ConnId RcvQueue)),
    -- | Reverse index of 'subscrSrvrs': the server of each subscribed connection.
    subscrConns :: TVar (Map ConnId SMPServer),
    -- | NOTE(review): only initialized in this module; its meaning is defined by callers.
    connMsgsQueued :: TVar (Map ConnId Bool),
    -- | Per (connection, server, sender ID) queues of 'InternalId's; presumably
    -- pending message deliveries — confirm against callers.
    smpQueueMsgQueues :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId)),
    -- | Per (connection, server, sender ID) delivery threads; cancelled by 'closeAgentClient'.
    smpQueueMsgDeliveries :: TVar (Map (ConnId, SMPServer, SMP.SenderId) (Async ())),
    -- | Reconnection threads started after a server disconnects; cancelled by 'closeAgentClient'.
    reconnections :: TVar [Async ()],
    -- | Numeric client ID taken from the environment's 'clientCounter'; shown in log lines.
    clientId :: Int,
    -- | Agent environment this client was created from.
    agentEnv :: Env,
    -- | Subscriber thread. 'newAgentClient' does not start it, so it must be
    -- assigned by the caller before it is used.
    smpSubscriber :: Async (),
    -- | Mutex used by 'withAgentLock'.
    lock :: TMVar ()
  }
|
|
|
|
-- | Create the state of a new agent client in a single STM transaction.
--
-- * All bounded queues get the capacity 'tbqSize' from the agent config.
-- * All maps and the reconnection list start empty; the lock starts full,
--   i.e. unlocked.
-- * 'clientId' is the value of the environment's 'clientCounter' after it
--   has been incremented.
-- * 'smpSubscriber' is NOT started here. The field holds a placeholder that
--   fails with a descriptive message if it is forced before the caller
--   replaces it with the real subscriber thread.
newAgentClient :: Env -> STM AgentClient
newAgentClient agentEnv = do
  let qSize = tbqSize $ config agentEnv
  rcvQ <- newTBQueue qSize
  subQ <- newTBQueue qSize
  msgQ <- newTBQueue qSize
  smpClients <- newTVar M.empty
  subscrSrvrs <- newTVar M.empty
  subscrConns <- newTVar M.empty
  connMsgsQueued <- newTVar M.empty
  smpQueueMsgQueues <- newTVar M.empty
  smpQueueMsgDeliveries <- newTVar M.empty
  reconnections <- newTVar []
  -- increment the shared counter and use the new value as this client's ID
  clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
  lock <- newTMVar ()
  return
    AgentClient
      { rcvQ,
        subQ,
        msgQ,
        smpClients,
        subscrSrvrs,
        subscrConns,
        connMsgsQueued,
        smpQueueMsgQueues,
        smpQueueMsgDeliveries,
        reconnections,
        clientId,
        agentEnv,
        -- placeholder: a bare 'undefined' would give no hint of what went wrong
        smpSubscriber = error "newAgentClient: smpSubscriber was used before it was initialized",
        lock
      }
|
|
|
|
-- | Agent monad: 'MonadUnliftIO', 'MonadReader' 'Env' and 'MonadError' 'AgentErrorType'.
--
-- Constraint synonym for the monads agent operations run in. Such a monad
-- can run and unlift 'IO', read the agent 'Env', and throw or catch
-- 'AgentErrorType'.
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
|
|
|
-- | Wrapper that carries an 'ExceptT' error value through 'IO' as an
-- exception. It is used by the 'MonadUnliftIO' instance for 'ExceptT'.
newtype InternalException e = InternalException {unInternalException :: e}
  deriving (Eq, Show)

-- | Any exception type can be thrown while wrapped in 'InternalException'.
instance Exception e => Exception (InternalException e)
|
|
|
|
-- | Orphan 'MonadUnliftIO' instance for 'ExceptT' (the orphan warning is
-- suppressed by the @-fno-warn-orphans@ pragma of this module).
--
-- The 'ExceptT' error channel is not an exception, so this instance bridges it:
--
-- * every 'ExceptT' action run through the provided runner turns a 'Left'
--   result into a thrown @'InternalException' e@;
-- * the whole 'IO' computation is wrapped in 'E.try' for
--   @'InternalException' e@, and a caught exception becomes a 'Left' of the
--   outer 'ExceptT' again.
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
  withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
  withRunInIO exceptToIO =
    -- catch the wrapped error thrown below and re-expose it as an ExceptT failure
    withExceptT unInternalException . ExceptT . E.try $
      withRunInIO $ \run ->
        -- run each ExceptT action in the base monad, throwing a Left as an exception
        exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
|
|
|
-- | Get the SMP client connected to @srv@, connecting first if there is none.
--
-- A newly connected client is stored in 'smpClients'. When that connection
-- is later lost, the disconnect callback does three things:
--
-- 1. It removes the client and all subscriptions on @srv@ from the client state.
-- 2. It sends DOWN to 'subQ' for each affected connection.
-- 3. It starts a background thread, tracked in 'reconnections', that keeps
--    retrying at the configured 'reconnectInterval' to resubscribe those
--    connections. For each connection it sends UP when resubscribed, or ERR
--    when the server refuses the subscription.
--
-- NOTE(review): the lookup and the later insert in 'newSMPClient' are not
-- atomic, so two concurrent callers may both connect to the same server and
-- the later insert wins.
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
  readTVarIO smpClients
    >>= maybe newSMPClient return . M.lookup srv
  where
    -- connect, log and register a new client for @srv@
    newSMPClient :: m SMPClient
    newSMPClient = do
      smp <- connectClient
      logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
      atomically . modifyTVar smpClients $ M.insert srv smp
      return smp

    -- open the connection; SMP client errors and IO exceptions become AgentErrorType
    connectClient :: m SMPClient
    connectClient = do
      cfg <- asks $ smpCfg . config
      u <- askUnliftIO
      liftEitherError smpClientError (getSMPClient srv cfg msgQ $ clientDisconnected u)
        `E.catch` internalError
      where
        internalError :: IOException -> m SMPClient
        internalError = throwError . INTERNAL . show

    -- callback run by the SMP client when the connection to @srv@ is lost
    clientDisconnected :: UnliftIO m -> IO ()
    clientDisconnected u = do
      removeClientSubs >>= (`forM_` serverDown u)
      logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv

    -- drop the client and every subscription on @srv@; return the dropped subscriptions
    removeClientSubs :: IO (Maybe (Map ConnId RcvQueue))
    removeClientSubs = atomically $ do
      modifyTVar smpClients $ M.delete srv
      cs <- M.lookup srv <$> readTVar (subscrSrvrs c)
      modifyTVar (subscrSrvrs c) $ M.delete srv
      modifyTVar (subscrConns c) $ maybe id (deleteKeys . M.keysSet) cs
      return cs
      where
        deleteKeys :: Ord k => Set k -> Map k a -> Map k a
        deleteKeys ks m = S.foldr' M.delete m ks

    -- report DOWN for the lost subscriptions and start reconnecting in the background
    serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO ()
    serverDown u cs = unless (M.null cs) $ do
      mapM_ (notifySub DOWN) $ M.keysSet cs
      a <- async . unliftIO u $ tryReconnectClient cs
      atomically $ modifyTVar (reconnections c) (a :)

    -- retry 'reconnectClient' until it completes without an AgentErrorType
    tryReconnectClient :: Map ConnId RcvQueue -> m ()
    tryReconnectClient cs = do
      ri <- asks $ reconnectInterval . config
      withRetryInterval ri $ \loop ->
        reconnectClient cs `catchError` const loop

    -- Resubscribe every connection in @cs@ that is not subscribed at the moment.
    -- A server error for one queue is reported as ERR for that connection and
    -- does not abort the others. Any other client error (network, timeout, ...)
    -- is re-thrown so that 'tryReconnectClient' retries.
    reconnectClient :: Map ConnId RcvQueue -> m ()
    reconnectClient cs =
      withAgentLock c . withSMP c srv $ \smp -> do
        subs <- readTVarIO $ subscrConns c
        forM_ (M.toList cs) $ \(connId, rq) ->
          when (isNothing $ M.lookup connId subs) $
            resubscribe smp connId rq `catchError` \case
              SMPServerError e -> liftIO $ notifySub (ERR $ SMP e) connId
              e -> throwError e

    -- Subscribe one queue. Only after the server accepts is the subscription
    -- recorded and UP sent, so a refused queue is not reported as both ERR and UP.
    resubscribe :: SMPClient -> ConnId -> RcvQueue -> ExceptT SMPClientError IO ()
    resubscribe smp connId rq@RcvQueue {rcvPrivateKey, rcvId} = do
      subscribeSMPQueue smp rcvPrivateKey rcvId
      addSubscription c rq connId
      liftIO $ notifySub UP connId

    -- post a subscription event for a connection to 'subQ'
    notifySub :: ACommand 'Agent -> ConnId -> IO ()
    notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
|
|
|
-- | Shut down an agent client. It closes every SMP server connection, then
-- cancels the pending reconnection threads and the message delivery threads.
closeAgentClient :: MonadUnliftIO m => AgentClient -> m ()
closeAgentClient c =
  liftIO $
    closeSMPServerClients c
      >> cancelActions (reconnections c)
      >> cancelActions (smpQueueMsgDeliveries c)
|
|
|
|
-- | Close every SMP client currently registered in 'smpClients'.
closeSMPServerClients :: AgentClient -> IO ()
closeSMPServerClients c = do
  clients <- readTVarIO (smpClients c)
  forM_ clients closeSMPClient
|
|
|
|
-- | Cancel, with 'uninterruptibleCancel', every 'Async' currently held in
-- the given 'TVar'.
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
cancelActions actionsVar = do
  actions <- readTVarIO actionsVar
  forM_ actions uninterruptibleCancel
|
|
|
|
-- | Run an action while holding the client's 'lock'. The lock is taken
-- before the action and put back afterwards, also when the action throws.
withAgentLock :: MonadUnliftIO m => AgentClient -> m a -> m a
withAgentLock AgentClient {lock} action =
  E.bracket_ acquire release action
  where
    acquire = void . atomically $ takeTMVar lock
    release = atomically $ putTMVar lock ()
|
|
|
|
-- | Run an action with the SMP client for the given server, connecting first
-- if needed. An 'AgentErrorType' raised while connecting or by the action is
-- logged as an incoming ("<--") server line and then re-thrown.
withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a
withSMP_ c srv action =
  run `catchError` \err -> do
    logServer "<--" c srv "" $ bshow err
    throwError err
  where
    run = getSMPServerClient c srv >>= action
|
|
|
|
-- | Like 'withSMP_', but logs the outgoing command ("-->") first and an "OK"
-- reply ("<--") once the action has succeeded.
withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a
withLogSMP_ c srv qId cmdStr action = do
  logServer "-->" c srv qId cmdStr
  withSMP_ c srv action <* logServer "<--" c srv qId "OK"
|
|
|
|
-- | 'withSMP_' for actions in the SMP client monad. Their 'SMPClientError's
-- are converted to 'AgentErrorType' by 'liftSMP'.
withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withSMP c srv action = withSMP_ c srv (\smp -> liftSMP (action smp))
|
|
|
|
-- | 'withLogSMP_' for actions in the SMP client monad. Their
-- 'SMPClientError's are converted to 'AgentErrorType' by 'liftSMP'.
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
withLogSMP c srv qId cmdStr action =
  withLogSMP_ c srv qId cmdStr (\smp -> liftSMP (action smp))
|
|
|
|
-- | Run an SMP client action in the agent monad, mapping its errors with
-- 'smpClientError'.
liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a
liftSMP action = liftError smpClientError action
|
|
|
|
-- | Map an SMP client error to an agent error. Server errors become 'SMP',
-- protocol and transport failures become 'BROKER' variants, and anything
-- else is reported as 'INTERNAL' with its 'show' text.
smpClientError :: SMPClientError -> AgentErrorType
smpClientError (SMPServerError e) = SMP e
smpClientError (SMPResponseError e) = BROKER $ RESPONSE e
smpClientError SMPUnexpectedResponse = BROKER UNEXPECTED
smpClientError SMPResponseTimeout = BROKER TIMEOUT
smpClientError SMPNetworkError = BROKER NETWORK
smpClientError (SMPTransportError e) = BROKER $ TRANSPORT e
smpClientError e = INTERNAL $ show e
|
|
|
|
-- | Create a new receive queue on the given server. The recipient key pair
-- uses the signature algorithm configured in 'cmdSignAlg'.
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueUri)
newRcvQueue c srv = do
  signAlg <- asks (cmdSignAlg . config)
  case signAlg of
    C.SignAlg alg -> newRcvQueue_ alg c srv
|
|
|
|
-- | Create a receive queue on @srv@. Three key pairs are generated for it: a
-- recipient signature key pair (algorithm @a@), a DH key pair for the
-- server-side secret and a DH key pair for end-to-end encryption.
--
-- Returns the queue record and the 'SMPQueueUri' to hand to the sender.
-- The record has status 'New' and no E2E shared secret yet. The URI carries
-- the server, sender ID, client version range and public E2E DH key.
newRcvQueue_ ::
  (C.SignatureAlgorithm a, C.AlgorithmI a, AgentMonad m) =>
  C.SAlgorithm a ->
  AgentClient ->
  SMPServer ->
  m (RcvQueue, SMPQueueUri)
newRcvQueue_ a c srv = do
  (recipientPubKey, recipientPrivKey) <- liftIO $ C.generateSignatureKeyPair a
  (srvDhPubKey, srvDhPrivKey) <- liftIO C.generateKeyPair'
  (e2ePubKey, e2ePrivKey) <- liftIO C.generateKeyPair'
  logServer "-->" c srv "" "NEW"
  QIK {rcvId = newRcvId, sndId = newSndId, rcvPublicDhKey = srvRcvDhKey} <-
    withSMP c srv $ \smp -> createSMPQueue smp recipientPrivKey recipientPubKey srvDhPubKey
  logServer "<--" c srv "" . B.unwords $ ["IDS", logSecret newRcvId, logSecret newSndId]
  let queue =
        RcvQueue
          { server = srv,
            rcvId = newRcvId,
            rcvPrivateKey = recipientPrivKey,
            -- secret shared with the server, from its DH key and our private DH key
            rcvDhSecret = C.dh' srvRcvDhKey srvDhPrivKey,
            e2ePrivKey,
            e2eDhSecret = Nothing,
            sndId = Just newSndId,
            status = New
          }
      queueUri = SMPQueueUri srv newSndId SMP.smpClientVRange e2ePubKey
  pure (queue, queueUri)
|
|
|
|
-- | Subscribe to a receive queue on its server (logged as "SUB"). After the
-- server accepts, the subscription is recorded for the connection.
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId =
  withLogSMP c server rcvId "SUB" subscribe >> addSubscription c rq connId
  where
    subscribe smp = subscribeSMPQueue smp rcvPrivateKey rcvId
|
|
|
|
-- | Record that connection @connId@ is subscribed via receive queue @rq@.
-- In one transaction it maps the connection to the queue's server and adds
-- the queue under that server, replacing any earlier entry for the same
-- connection.
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
  modifyTVar (subscrConns c) $ M.insert connId server
  modifyTVar (subscrSrvrs c) $
    M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server
|
|
|
|
-- | Forget the subscription of connection @connId@. It removes the connection
-- from 'subscrConns' and from its server's entry in 'subscrSrvrs'. A server
-- entry left without connections is removed as well.
removeSubscription :: AgentMonad m => AgentClient -> ConnId -> m ()
removeSubscription AgentClient {subscrConns, subscrSrvrs} connId = atomically $ do
  conns <- readTVar subscrConns
  writeTVar subscrConns (M.delete connId conns)
  case M.lookup connId conns of
    Nothing -> pure ()
    Just srv -> modifyTVar subscrSrvrs $ M.update dropConn srv
  where
    dropConn queues =
      let remaining = M.delete connId queues
       in if M.null remaining then Nothing else Just remaining
|
|
|
|
-- | Log one line of client/server traffic at info level, in the form
-- @A (clientId) dir host[:port] : queueIdPrefix command@. The queue ID is
-- shortened with 'logSecret'.
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
logServer dir AgentClient {clientId} srv qId cmdStr =
  logInfo . decodeUtf8 $ B.unwords parts
  where
    parts :: [ByteString]
    parts = ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
|
|
|
-- | Render a server as @host@, or @host:port@ when a port is set.
showServer :: SMPServer -> ByteString
showServer srv = B.pack $ case port srv of
  Nothing -> host srv
  Just p -> host srv <> ":" <> p
|
|
|
|
-- | Shorten an identifier for logging to the base64 encoding of its first
-- 3 bytes. Presumably this keeps full queue IDs out of the logs.
logSecret :: ByteString -> ByteString
logSecret = encode . B.take 3
|
|
|
|
-- TODO maybe package E2ERatchetParams into SMPConfirmation

-- | Send an already-encrypted confirmation to sender queue @sndId@ without a
-- sender signature key (logged as "SEND <CONF>").
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendConfirmation c SndQueue {server, sndId} encConfirmation =
  withLogSMP c server sndId "SEND <CONF>" $ \smp ->
    sendSMPMessage smp Nothing sndId encConfirmation
|
|
|
|
-- | Send a connection invitation to the queue described by the compatible
-- queue info (logged as "SEND <INV>"). The invitation carries the agent
-- version, the connection request and the connection info. It is encrypted
-- once to the queue's DH public key and sent without a sender signature key.
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo =
  withLogSMP_ c smpServer senderId "SEND <INV>" $ \smp ->
    encryptedInvitation >>= liftSMP . sendSMPMessage smp Nothing senderId
  where
    -- this is only encrypted with per-queue E2E, not with double ratchet
    encryptedInvitation :: m ByteString
    encryptedInvitation = agentCbEncryptOnce dhPublicKey $ smpEncode clientMessage
    clientMessage =
      SMP.ClientMessage SMP.PHEmpty . smpEncode $
        AgentInvitation {agentVersion = smpAgentVersion, connReq, connInfo}
|
|
|
|
-- | Secure a receive queue with the sender's public verification key
-- (logged as "KEY <key>").
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
  withLogSMP c server rcvId "KEY <key>" secure
  where
    secure smp = secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
|
|
|
-- | Acknowledge the current message on a receive queue (logged as "ACK").
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> m ()
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} =
  withLogSMP c server rcvId "ACK" ack
  where
    ack smp = ackSMPMessage smp rcvPrivateKey rcvId
|
|
|
|
-- | Suspend a receive queue on its server (logged as "OFF").
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
  withLogSMP c server rcvId "OFF" suspend
  where
    suspend smp = suspendSMPQueue smp rcvPrivateKey rcvId
|
|
|
|
-- | Delete a receive queue on its server (logged as "DEL").
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
  withLogSMP c server rcvId "DEL" delete
  where
    delete smp = deleteSMPQueue smp rcvPrivateKey rcvId
|
|
|
|
-- TODO this is just wrong

-- | Send an agent message to a send queue (logged as "SEND <MSG>"). The
-- message is wrapped in an 'SMP.ClientMessage' with an empty header and
-- encrypted with 'agentCbEncrypt' (no public key in the envelope header).
-- It is then sent signed with the queue's sender private key.
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
  withLogSMP_ c server sndId "SEND <MSG>" $ \smp ->
    encrypted >>= liftSMP . sendSMPMessage smp (Just sndPrivateKey) sndId
  where
    encrypted :: m ByteString
    encrypted = agentCbEncrypt sq Nothing . smpEncode $ SMP.ClientMessage SMP.PHEmpty agentMsg
|
|
|
|
-- | Encrypt @msg@ with the send queue's E2E DH secret and a random nonce,
-- using 'SMP.e2eEncMessageLength' as the crypto-box length parameter. The
-- result is an encoded 'SMP.ClientMsgEnvelope' whose header carries the
-- maximum supported client version and the optional public key. Crypto
-- failures are raised as agent errors via 'cryptoError'.
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do
  nonce <- liftIO C.randomCbNonce
  encBody <-
    either (throwError . cryptoError) pure $
      C.cbEncrypt e2eDhSecret nonce msg SMP.e2eEncMessageLength
  -- TODO per-queue client version
  let header = SMP.PubHeader (maxVersion SMP.smpClientVRange) e2ePubKey
      envelope = SMP.ClientMsgEnvelope {cmHeader = header, cmNonce = nonce, cmEncBody = encBody}
  pure $ smpEncode envelope
|
|
|
|
-- add encoding as AgentInvitation'?

-- | Encrypt @msg@ to the recipient DH public key with a freshly generated,
-- single-use sender DH key pair and a random nonce. The sender's public key
-- goes into the envelope header so the message can be decrypted without a
-- prior key exchange. Crypto failures are raised via 'cryptoError'.
agentCbEncryptOnce :: AgentMonad m => C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncryptOnce dhRcvPubKey msg = do
  (senderPubKey, senderPrivKey) <- liftIO C.generateKeyPair'
  nonce <- liftIO C.randomCbNonce
  let sharedSecret = C.dh' dhRcvPubKey senderPrivKey
  encBody <-
    either (throwError . cryptoError) pure $
      C.cbEncrypt sharedSecret nonce msg SMP.e2eEncMessageLength
  -- TODO per-queue client version
  let header = SMP.PubHeader (maxVersion SMP.smpClientVRange) (Just senderPubKey)
  pure . smpEncode $ SMP.ClientMsgEnvelope {cmHeader = header, cmNonce = nonce, cmEncBody = encBody}
|
|
|
|
-- | NaCl crypto-box decryption. It is used both for messages received from
-- the server and for the per-queue E2E-encrypted sender messages inside
-- them. Crypto failures are raised as agent errors via 'cryptoError'.
agentCbDecrypt :: AgentMonad m => C.DhSecretX25519 -> C.CbNonce -> ByteString -> m ByteString
agentCbDecrypt dhSecret nonce msg =
  either (throwError . cryptoError) pure $ C.cbDecrypt dhSecret nonce msg
|
|
|
|
-- | Map a crypto error to an agent error. An oversized message becomes
-- @CMD LARGE@, header and decryption failures become @AGENT A_ENCRYPTION@,
-- and anything else is reported as 'INTERNAL' with its 'show' text.
cryptoError :: C.CryptoError -> AgentErrorType
cryptoError C.CryptoLargeMsgError = CMD LARGE
cryptoError (C.CryptoHeaderError _) = AGENT A_ENCRYPTION
cryptoError C.AESDecryptError = AGENT A_ENCRYPTION
cryptoError C.CBDecryptError = AGENT A_ENCRYPTION
cryptoError e = INTERNAL $ show e
|