From d31958855f7072340136a15dc508e082cfcfecfd Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 2 Apr 2022 16:14:19 +0100 Subject: [PATCH] ntf server implementation, updated ntf protocol, ntf client based on refactored protocol client, bare-bones SMP agent to manage ntf connections (to connect to ntf server) (#338) * process ntf server commands * when subscription is re-created and it was ENDed, resubscribe to SMP * SMPClientAgent draft * SMPClientAgent: remove double tracking of subscriptions * subscriber frame * PING error now throws error to restart SMPClient for more reliable re-connection (#342) * increase TCP timeout to 5 sec * add pragmas and vacuum db (#343) * vacuum in each connection to enable auto-vacuum (#344) * update protocol, token verification * refactor SMPClient to ProtocoClient, to use with notification server protocol * notification server client, managing notification clients in the agent * stub for push payload Co-authored-by: JRoberts <8711996+jr-simplex@users.noreply.github.com> --- .../diagrams/notifications/register-token.mmd | 30 ++ simplexmq.cabal | 2 + src/Simplex/Messaging/Agent.hs | 16 +- src/Simplex/Messaging/Agent/Client.hs | 264 ++++++++------- src/Simplex/Messaging/Agent/Env/SQLite.hs | 7 +- src/Simplex/Messaging/Agent/Protocol.hs | 7 +- src/Simplex/Messaging/Agent/Store/SQLite.hs | 32 +- src/Simplex/Messaging/Client.hs | 209 ++++++------ src/Simplex/Messaging/Client/Agent.hs | 301 +++++++++++++++++ src/Simplex/Messaging/Notifications/Client.hs | 55 +++ .../Messaging/Notifications/Protocol.hs | 312 ++++++++++++++---- src/Simplex/Messaging/Notifications/Server.hs | 218 ++++++++++-- .../Messaging/Notifications/Server/Env.hs | 59 +++- .../Messaging/Notifications/Server/Push.hs | 11 + .../Notifications/Server/Subscriptions.hs | 114 ++++++- src/Simplex/Messaging/Protocol.hs | 76 +++-- src/Simplex/Messaging/Server.hs | 2 +- src/Simplex/Messaging/TMap.hs | 7 + src/Simplex/Messaging/Transport/KeepAlive.hs | 2 +- tests/AgentTests/ConnectionRequestTests.hs | 4 +- tests/SMPAgentClient.hs | 4 +- 21 files changed, 1337 insertions(+), 395 deletions(-) create mode 100644 protocol/diagrams/notifications/register-token.mmd create mode 100644 src/Simplex/Messaging/Client/Agent.hs create mode 100644 src/Simplex/Messaging/Notifications/Client.hs create mode 100644 src/Simplex/Messaging/Notifications/Server/Push.hs diff --git a/protocol/diagrams/notifications/register-token.mmd b/protocol/diagrams/notifications/register-token.mmd new file mode 100644 index 000000000..bceef296f --- /dev/null +++ b/protocol/diagrams/notifications/register-token.mmd @@ -0,0 +1,30 @@ +sequenceDiagram + participant M as mobile app + participant C as chat core + participant A as agent + participant P as push server + participant APN as APN + + note over M, APN: get device token + M ->> APN: registerForRemoteNotifications() + APN ->> M: device token + + note over M, P: register device token with push server + M ->> C: /_ntf register + C ->> A: registerNtfToken() + A ->> P: TNEW + P ->> A: ID (tokenId) + A ->> C: registered + C ->> M: registered + + note over M, APN: verify device token + P ->> APN: E2E encrypted code
in background
notification + APN ->> M: deliver background notification with e2ee verification token + M ->> C: /_ntf verify + C ->> A: verifyNtfToken() + A ->> P: TVFY code + P ->> A: OK / ERR + A ->> C: verified + C ->> M: verified + + note over M, APN: now token ID can be used diff --git a/simplexmq.cabal b/simplexmq.cabal index c5687addf..3b9e8eb7e 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -43,10 +43,12 @@ library Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications Simplex.Messaging.Client + Simplex.Messaging.Client.Agent Simplex.Messaging.Crypto Simplex.Messaging.Crypto.Ratchet Simplex.Messaging.Encoding Simplex.Messaging.Encoding.String + Simplex.Messaging.Notifications.Client Simplex.Messaging.Notifications.Protocol Simplex.Messaging.Notifications.Server Simplex.Messaging.Notifications.Server.Env diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index f1ecad909..821f5fe5a 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -76,12 +76,13 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore) -import Simplex.Messaging.Client (SMPServerTransmission) +import Simplex.Messaging.Client (ServerTransmission) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Protocol (DeviceToken) import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (MsgBody) +import Simplex.Messaging.Protocol (BrokerMsg, MsgBody) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM) @@ -149,6 +150,10 @@ deleteConnection c = withAgentEnv c . deleteConnection' c setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServer -> m () setSMPServers c = withAgentEnv c . setSMPServers' c +-- | Register device notifications token +registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () +registerNtfToken c = withAgentEnv c . registerNtfToken' c + withAgentEnv :: AgentClient -> ReaderT Env m a -> m a withAgentEnv c = (`runReaderT` agentEnv c) @@ -490,10 +495,13 @@ deleteConnection' c connId = withStore (`deleteConn` connId) -- | Change servers to be used for creating new queues, in Reader monad -setSMPServers' :: forall m. AgentMonad m => AgentClient -> NonEmpty SMPServer -> m () +setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m () setSMPServers' c servers = do atomically $ writeTVar (smpServers c) servers +registerNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () +registerNtfToken' c token = pure () + getSMPServer :: AgentMonad m => AgentClient -> m SMPServer getSMPServer c = do smpServers <- readTVarIO $ smpServers c @@ -511,7 +519,7 @@ subscriber c@AgentClient {msgQ} = forever $ do Left e -> liftIO $ print e Right _ -> return () -processSMPTransmission :: forall m. AgentMonad m => AgentClient -> SMPServerTransmission -> m () +processSMPTransmission :: forall m. AgentMonad m => AgentClient -> ServerTransmission BrokerMsg -> m () processSMPTransmission c@AgentClient {subQ} (srv, rId, cmd) = do withStore (\st -> getRcvConn st srv rId) >>= \case SomeConn SCDuplex (DuplexConnection cData rq _) -> processSMP SCDuplex cData rq diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 96bcd4529..321fc82ab 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -2,14 +2,13 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# LANGUAGE TypeSynonymInstances #-} module Simplex.Messaging.Agent.Client ( AgentClient (..), @@ -57,9 +56,12 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store import Simplex.Messaging.Client +import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding -import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey) +import Simplex.Messaging.Notifications.Client (NtfClient) +import Simplex.Messaging.Notifications.Protocol (NtfResponse) +import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -67,18 +69,22 @@ import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, when import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, forConcurrently_) -import UnliftIO.Exception (Exception, IOException) import qualified UnliftIO.Exception as E import UnliftIO.STM +type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg)) + type SMPClientVar = TMVar (Either AgentErrorType SMPClient) +type NtfClientVar = TMVar (Either AgentErrorType NtfClient) + data AgentClient = AgentClient { rcvQ :: TBQueue (ATransmission 'Client), subQ :: TBQueue (ATransmission 'Agent), - msgQ :: TBQueue SMPServerTransmission, + msgQ :: TBQueue (ServerTransmission BrokerMsg), smpServers :: TVar (NonEmpty SMPServer), smpClients :: TMap SMPServer SMPClientVar, + ntfClients :: TMap ProtocolServer NtfClientVar, subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TMap ConnId SMPServer, @@ -101,6 +107,7 @@ newAgentClient agentEnv = do msgQ <- newTBQueue qSize smpServers <- newTVar $ initialSMPServers (config agentEnv) smpClients <- TM.empty + ntfClients <- TM.empty subscrSrvrs <- TM.empty pendingSubscrSrvrs <- TM.empty subscrConns <- TM.empty @@ -111,80 +118,30 @@ newAgentClient agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) lock <- newTMVar () - return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} + return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} -- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) -newtype InternalException e = InternalException {unInternalException :: e} - deriving (Eq, Show) +class ProtocolServerClient msg where + getProtocolServerClient :: AgentMonad m => AgentClient -> ProtocolServer -> m (ProtocolClient msg) -instance Exception e => Exception (InternalException e) +instance ProtocolServerClient BrokerMsg where getProtocolServerClient = getSMPServerClient -instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where - withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b - withRunInIO exceptToIO = - withExceptT unInternalException . ExceptT . E.try $ - withRunInIO $ \run -> - exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT) +instance ProtocolServerClient NtfResponse where getProtocolServerClient = getNtfServerClient getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient c@AgentClient {smpClients, msgQ} srv = - atomically getClientVar >>= either newSMPClient waitForSMPClient + atomically (getClientVar srv smpClients) + >>= either + (newProtocolClient c srv smpClients connectClient reconnectClient) + (waitForProtocolClient smpCfg) where - getClientVar :: STM (Either SMPClientVar SMPClientVar) - getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients - - newClientVar :: STM SMPClientVar - newClientVar = do - smpVar <- newEmptyTMVar - TM.insert srv smpVar smpClients - pure smpVar - - waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient - waitForSMPClient smpVar = do - SMPClientConfig {tcpTimeout} <- asks $ smpCfg . config - smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar) - liftEither $ case smpClient_ of - Just (Right smpClient) -> Right smpClient - Just (Left e) -> Left e - Nothing -> Left $ BROKER TIMEOUT - - newSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient - newSMPClient smpVar = tryConnectClient pure tryConnectAsync - where - tryConnectClient :: (SMPClient -> m a) -> m () -> m a - tryConnectClient successAction retryAction = - tryError connectClient >>= \r -> case r of - Right smp -> do - logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv - atomically $ putTMVar smpVar r - successAction smp - Left e -> do - if e == BROKER NETWORK || e == BROKER TIMEOUT - then retryAction - else atomically $ do - putTMVar smpVar (Left e) - TM.delete srv smpClients - throwError e - tryConnectAsync :: m () - tryConnectAsync = do - a <- async connectAsync - atomically $ modifyTVar' (asyncClients c) (a :) - connectAsync :: m () - connectAsync = do - ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop - connectClient :: m SMPClient connectClient = do cfg <- asks $ smpCfg . config u <- askUnliftIO - liftEitherError smpClientError (getSMPClient srv cfg msgQ $ clientDisconnected u) - `E.catch` internalError - where - internalError :: IOException -> m SMPClient - internalError = throwError . INTERNAL . show + liftEitherError protocolClientError (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) clientDisconnected :: UnliftIO m -> IO () clientDisconnected u = do @@ -194,13 +151,14 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) removeClientAndSubs = atomically $ do TM.delete srv smpClients - cVar_ <- TM.lookupDelete srv $ subscrSrvrs c - forM cVar_ $ \cVar -> do - cs <- readTVar cVar - modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) - addPendingSubs cVar cs - pure cs + TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs where + updateSubs cVar = do + cs <- readTVar cVar + modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) + addPendingSubs cVar cs + pure cs + addPendingSubs cVar cs = do let ps = pendingSubscrSrvrs c TM.lookup srv ps >>= \case @@ -225,30 +183,100 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = reconnectClient :: m () reconnectClient = - withAgentLock c . withSMP c srv $ \smp -> do + withAgentLock c . withClient c srv $ \smp -> do cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c) forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) -> whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $ subscribe_ smp sub `catchError` handleError connId where - subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO () + subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT ProtocolClientError IO () subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do subscribeSMPQueue smp rcvPrivateKey rcvId addSubscription c rq connId liftIO $ notifySub UP connId - handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO () + handleError :: ConnId -> ProtocolClientError -> ExceptT ProtocolClientError IO () handleError connId = \case - e@SMPResponseTimeout -> throwError e - e@SMPNetworkError -> throwError e + e@PCEResponseTimeout -> throwError e + e@PCENetworkError -> throwError e e -> do - liftIO $ notifySub (ERR $ smpClientError e) connId + liftIO $ notifySub (ERR $ protocolClientError e) connId atomically $ removePendingSubscription c srv connId notifySub :: ACommand 'Agent -> ConnId -> IO () notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) -closeAgentClient :: MonadUnliftIO m => AgentClient -> m () +getNtfServerClient :: forall m. AgentMonad m => AgentClient -> ProtocolServer -> m NtfClient +getNtfServerClient c@AgentClient {ntfClients} srv = + atomically (getClientVar srv ntfClients) + >>= either + (newProtocolClient c srv ntfClients connectClient $ pure ()) + (waitForProtocolClient ntfCfg) + where + connectClient :: m NtfClient + connectClient = do + cfg <- asks $ ntfCfg . config + liftEitherError protocolClientError (getProtocolClient srv cfg Nothing clientDisconnected) + + clientDisconnected :: IO () + clientDisconnected = do + atomically $ TM.delete srv ntfClients + logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv + +getClientVar :: forall a. ProtocolServer -> TMap ProtocolServer (TMVar a) -> STM (Either (TMVar a) (TMVar a)) +getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients + where + newClientVar :: STM (TMVar a) + newClientVar = do + var <- newEmptyTMVar + TM.insert srv var clients + pure var + +waitForProtocolClient :: AgentMonad m => (AgentConfig -> ProtocolClientConfig) -> ClientVar msg -> m (ProtocolClient msg) +waitForProtocolClient clientConfig clientVar = do + ProtocolClientConfig {tcpTimeout} <- asks $ clientConfig . config + client_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar clientVar) + liftEither $ case client_ of + Just (Right smpClient) -> Right smpClient + Just (Left e) -> Left e + Nothing -> Left $ BROKER TIMEOUT + +newProtocolClient :: + forall msg m. + AgentMonad m => + AgentClient -> + ProtocolServer -> + TMap ProtocolServer (ClientVar msg) -> + m (ProtocolClient msg) -> + m () -> + ClientVar msg -> + m (ProtocolClient msg) +newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync + where + tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a + tryConnectClient successAction retryAction = + tryError connectClient >>= \r -> case r of + Right smp -> do + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv + atomically $ putTMVar clientVar r + successAction smp + Left e -> do + if e == BROKER NETWORK || e == BROKER TIMEOUT + then retryAction + else atomically $ do + putTMVar clientVar (Left e) + TM.delete srv clients + throwError e + tryConnectAsync :: m () + tryConnectAsync = do + a <- async connectAsync + atomically $ modifyTVar' (asyncClients c) (a :) + connectAsync :: m () + connectAsync = do + ri <- asks $ reconnectInterval . config + withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop + +closeAgentClient :: MonadIO m => AgentClient -> m () closeAgentClient c = liftIO $ do closeSMPServerClients c cancelActions $ reconnections c @@ -260,7 +288,7 @@ closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeCli where closeClient smpVar = atomically (readTMVar smpVar) >>= \case - Right smp -> closeSMPClient smp `E.catch` \(_ :: E.SomeException) -> pure () + Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure () _ -> pure () cancelActions :: Foldable f => TVar (f (Async ())) -> IO () @@ -272,40 +300,40 @@ withAgentLock AgentClient {lock} = (void . atomically $ takeTMVar lock) (atomically $ putTMVar lock ()) -withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a -withSMP_ c srv action = - (getSMPServerClient c srv >>= action) `catchError` logServerError +withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> m a) -> m a +withClient_ c srv action = (getProtocolServerClient c srv >>= action) `catchError` logServerError where logServerError :: AgentErrorType -> m a logServerError e = do logServer "<--" c srv "" $ bshow e throwError e -withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a -withLogSMP_ c srv qId cmdStr action = do +withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a +withLogClient_ c srv qId cmdStr action = do logServer "-->" c srv qId cmdStr - res <- withSMP_ c srv action + res <- withClient_ c srv action logServer "<--" c srv qId "OK" return res -withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withSMP c srv action = withSMP_ c srv $ liftSMP . action +withClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withClient c srv action = withClient_ c srv $ liftClient . action -withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a -withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action +withLogClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient . action -liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a -liftSMP = liftError smpClientError +liftClient :: AgentMonad m => ExceptT ProtocolClientError IO a -> m a +liftClient = liftError protocolClientError -smpClientError :: SMPClientError -> AgentErrorType -smpClientError = \case - SMPServerError e -> SMP e - SMPResponseError e -> BROKER $ RESPONSE e - SMPUnexpectedResponse -> BROKER UNEXPECTED - SMPResponseTimeout -> BROKER TIMEOUT - SMPNetworkError -> BROKER NETWORK - SMPTransportError e -> BROKER $ TRANSPORT e - e -> INTERNAL $ show e +protocolClientError :: ProtocolClientError -> AgentErrorType +protocolClientError = \case + PCEProtocolError e -> SMP e + PCEResponseError e -> BROKER $ RESPONSE e + PCEUnexpectedResponse -> BROKER UNEXPECTED + PCEResponseTimeout -> BROKER TIMEOUT + PCENetworkError -> BROKER NETWORK + PCETransportError e -> BROKER $ TRANSPORT e + e@PCESignatureError {} -> INTERNAL $ show e + e@PCEIOError {} -> INTERNAL $ show e newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueUri) newRcvQueue c srv = @@ -324,7 +352,7 @@ newRcvQueue_ a c srv = do (e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair' logServer "-->" c srv "" "NEW" QIK {rcvId, sndId, rcvPublicDhKey} <- - withSMP c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey + withClient c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId] let rq = RcvQueue @@ -342,15 +370,15 @@ newRcvQueue_ a c srv = do subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do atomically $ addPendingSubscription c rq connId - withLogSMP c server rcvId "SUB" $ \smp -> do + withLogClient c server rcvId "SUB" $ \smp -> do liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case Left e -> do - atomically . when (e /= SMPNetworkError && e /= SMPResponseTimeout) $ + atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ removePendingSubscription c server connId throwError e Right _ -> addSubscription c rq connId -addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m () +addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do TM.insert connId server $ subscrConns c addSubs_ (subscrSrvrs c) rq connId @@ -365,7 +393,7 @@ addSubs_ ss rq@RcvQueue {server} connId = Just m -> TM.insert connId rq m _ -> TM.singleton connId rq >>= \m -> TM.insert server m ss -removeSubscription :: MonadUnliftIO m => AgentClient -> ConnId -> m () +removeSubscription :: MonadIO m => AgentClient -> ConnId -> m () removeSubscription c@AgentClient {subscrConns} connId = atomically $ do server_ <- TM.lookupDelete connId subscrConns mapM_ (\server -> removeSubs_ (subscrSrvrs c) server connId) server_ @@ -377,12 +405,12 @@ removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> S removeSubs_ ss server connId = TM.lookup server ss >>= mapM_ (TM.delete connId) -logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () +logServer :: MonadIO m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m () logServer dir AgentClient {clientId} srv qId cmdStr = logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr] showServer :: SMPServer -> ByteString -showServer SMPServer {host, port} = +showServer ProtocolServer {host, port} = B.pack $ host <> if null port then "" else ':' : port logSecret :: ByteString -> ByteString @@ -390,17 +418,17 @@ logSecret bs = encode $ B.take 3 bs sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation = - withLogSMP_ c server sndId "SEND " $ \smp -> do + withLogClient_ c server sndId "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg - liftSMP $ sendSMPMessage smp Nothing sndId msg + liftClient $ sendSMPMessage smp Nothing sndId msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo = - withLogSMP_ c smpServer senderId "SEND " $ \smp -> do + withLogClient_ c smpServer senderId "SEND " $ \smp -> do msg <- mkInvitation - liftSMP $ sendSMPMessage smp Nothing senderId msg + liftClient $ sendSMPMessage smp Nothing senderId msg where mkInvitation :: m ByteString -- this is only encrypted with per-queue E2E, not with double ratchet @@ -411,30 +439,30 @@ sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) co secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m () secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey = - withLogSMP c server rcvId "KEY " $ \smp -> + withLogClient c server rcvId "KEY " $ \smp -> secureSMPQueue smp rcvPrivateKey rcvId senderKey sendAck :: AgentMonad m => AgentClient -> RcvQueue -> m () sendAck c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogSMP c server rcvId "ACK" $ \smp -> + withLogClient c server rcvId "ACK" $ \smp -> ackSMPMessage smp rcvPrivateKey rcvId suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogSMP c server rcvId "OFF" $ \smp -> + withLogClient c server rcvId "OFF" $ \smp -> suspendSMPQueue smp rcvPrivateKey rcvId deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m () deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} = - withLogSMP c server rcvId "DEL" $ \smp -> + withLogClient c server rcvId "DEL" $ \smp -> deleteSMPQueue smp rcvPrivateKey rcvId sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m () sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg = - withLogSMP_ c server sndId "SEND " $ \smp -> do + withLogClient_ c server sndId "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg - liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg + liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index a1f03cf6f..66f510ff6 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -23,6 +23,7 @@ import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client +import Simplex.Messaging.Client.Agent (SMPClientAgentConfig, defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import System.Random (StdGen, newStdGen) import UnliftIO.STM @@ -36,7 +37,8 @@ data AgentConfig = AgentConfig dbFile :: FilePath, dbPoolSize :: Int, yesToMigrations :: Bool, - smpCfg :: SMPClientConfig, + smpCfg :: ProtocolClientConfig, + ntfCfg :: ProtocolClientConfig, reconnectInterval :: RetryInterval, helloTimeout :: NominalDiffTime, caCertificateFile :: FilePath, @@ -55,7 +57,8 @@ defaultAgentConfig = dbFile = "smp-agent.db", dbPoolSize = 4, yesToMigrations = False, - smpCfg = smpDefaultConfig, + smpCfg = defaultClientConfig, + ntfCfg = defaultClientConfig, reconnectInterval = RetryInterval { initialInterval = second, diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index fa1d46a8a..dbe47d908 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -7,6 +7,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -49,7 +50,8 @@ module Simplex.Messaging.Agent.Protocol AgentMessageType (..), APrivHeader (..), AMessage (..), - SMPServer (..), + SMPServer, + pattern SMPServer, SrvLoc (..), SMPQueueUri (..), SMPQueueInfo (..), @@ -131,9 +133,10 @@ import Simplex.Messaging.Protocol ( ErrorType, MsgBody, MsgId, - SMPServer (..), + SMPServer, SndPublicVerifyKey, SrvLoc (..), + pattern SMPServer, ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTransportError, transportErrorP) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 295e28620..89a3c0e69 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -9,6 +9,7 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -59,7 +60,7 @@ import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), Skipp import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldParser) -import Simplex.Messaging.Protocol (MsgBody) +import Simplex.Messaging.Protocol (MsgBody, ProtocolServer (..)) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) @@ -127,9 +128,26 @@ connectSQLiteStore dbFilePath poolSize = do connectDB :: FilePath -> IO DB.Connection connectDB path = do dbConn <- DB.open path - DB.execute_ dbConn "PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;" + DB.execute_ dbConn "PRAGMA foreign_keys = ON;" + -- DB.execute_ dbConn "PRAGMA trusted_schema = OFF;" + DB.execute_ dbConn "PRAGMA secure_delete = ON;" + DB.execute_ dbConn "PRAGMA auto_vacuum = FULL;" + DB.execute_ dbConn "VACUUM;" + -- _printPragmas dbConn path pure dbConn +_printPragmas :: DB.Connection -> FilePath -> IO () +_printPragmas db path = do + foreign_keys <- DB.query_ db "PRAGMA foreign_keys;" :: IO [[Int]] + print $ path <> " foreign_keys: " <> show foreign_keys + -- when run via sqlite-simple query for trusted_schema seems to return empty list + trusted_schema <- DB.query_ db "PRAGMA trusted_schema;" :: IO [[Int]] + print $ path <> " trusted_schema: " <> show trusted_schema + secure_delete <- DB.query_ db "PRAGMA secure_delete;" :: IO [[Int]] + print $ path <> " secure_delete: " <> show secure_delete + auto_vacuum <- DB.query_ db "PRAGMA auto_vacuum;" :: IO [[Int]] + print $ path <> " auto_vacuum: " <> show auto_vacuum + checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err) @@ -190,7 +208,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto getConn_ db connId getRcvConn :: SQLiteStore -> SMPServer -> SMP.RecipientId -> m SomeConn - getRcvConn st SMPServer {host, port} rcvId = + getRcvConn st ProtocolServer {host, port} rcvId = liftIOEither . withTransaction st $ \db -> DB.queryNamed db @@ -235,7 +253,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto _ -> pure $ Left SEConnNotFound setRcvQueueStatus :: SQLiteStore -> RcvQueue -> QueueStatus -> m () - setRcvQueueStatus st RcvQueue {rcvId, server = SMPServer {host, port}} status = + setRcvQueueStatus st RcvQueue {rcvId, server = ProtocolServer {host, port}} status = -- ? throw error if queue does not exist? liftIO . withTransaction st $ \db -> DB.executeNamed @@ -248,7 +266,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto [":status" := status, ":host" := host, ":port" := port, ":rcv_id" := rcvId] setRcvQueueConfirmedE2E :: SQLiteStore -> RcvQueue -> C.DhSecretX25519 -> m () - setRcvQueueConfirmedE2E st RcvQueue {rcvId, server = SMPServer {host, port}} e2eDhSecret = + setRcvQueueConfirmedE2E st RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret = liftIO . withTransaction st $ \db -> DB.executeNamed db @@ -266,7 +284,7 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto ] setSndQueueStatus :: SQLiteStore -> SndQueue -> QueueStatus -> m () - setSndQueueStatus st SndQueue {sndId, server = SMPServer {host, port}} status = + setSndQueueStatus st SndQueue {sndId, server = ProtocolServer {host, port}} status = -- ? throw error if queue does not exist? liftIO . withTransaction st $ \db -> DB.executeNamed @@ -640,7 +658,7 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, -- * Server upsert helper upsertServer_ :: DB.Connection -> SMPServer -> IO () -upsertServer_ dbConn SMPServer {host, port, keyHash} = do +upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do DB.executeNamed dbConn [sql| diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index df28de2c2..db30d3085 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -23,9 +24,10 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md module Simplex.Messaging.Client ( -- * Connect (disconnect) client to (from) SMP server + ProtocolClient, SMPClient, - getSMPClient, - closeSMPClient, + getProtocolClient, + closeProtocolClient, -- * SMP protocol command functions createSMPQueue, @@ -37,13 +39,13 @@ module Simplex.Messaging.Client ackSMPMessage, suspendSMPQueue, deleteSMPQueue, - sendSMPCommand, + sendProtocolCommand, -- * Supporting types and client configuration - SMPClientError (..), - SMPClientConfig (..), - smpDefaultConfig, - SMPServerTransmission, + ProtocolClientError (..), + ProtocolClientConfig (..), + defaultClientConfig, + ServerTransmission, ) where @@ -60,7 +62,7 @@ import Data.Maybe (fromMaybe) import Network.Socket (ServiceName) import Numeric.Natural import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Protocol +import Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError) @@ -72,31 +74,30 @@ import System.Timeout (timeout) -- | 'SMPClient' is a handle used to send commands to a specific SMP server. -- --- The only exported selector is blockSize that is negotiated --- with the server during the TCP transport handshake. --- -- Use 'getSMPClient' to connect to an SMP server and create a client handle. -data SMPClient = SMPClient +data ProtocolClient msg = ProtocolClient { action :: Async (), connected :: TVar Bool, sessionId :: ByteString, - smpServer :: SMPServer, + protocolServer :: ProtocolServer, tcpTimeout :: Int, clientCorrId :: TVar Natural, - sentCommands :: TMap CorrId Request, + sentCommands :: TMap CorrId (Request msg), sndQ :: TBQueue SentRawTransmission, - rcvQ :: TBQueue (SignedTransmission BrokerMsg), - msgQ :: TBQueue SMPServerTransmission + rcvQ :: TBQueue (SignedTransmission msg), + msgQ :: Maybe (TBQueue (ServerTransmission msg)) } --- | Type synonym for transmission from some SPM server queue. -type SMPServerTransmission = (SMPServer, RecipientId, BrokerMsg) +type SMPClient = ProtocolClient SMP.BrokerMsg --- | SMP client configuration. -data SMPClientConfig = SMPClientConfig +-- | Type synonym for transmission from some SPM server queue. +type ServerTransmission msg = (ProtocolServer, QueueId, msg) + +-- | protocol client configuration. +data ProtocolClientConfig = ProtocolClientConfig { -- | size of TBQueue to use for server commands and responses qSize :: Natural, - -- | default SMP server port if port is not specified in SMPServer + -- | default server port if port is not specified in ProtocolServer defaultTransport :: (ServiceName, ATransport), -- | timeout of TCP commands (microseconds) tcpTimeout :: Int, @@ -106,34 +107,35 @@ data SMPClientConfig = SMPClientConfig smpPing :: Int } --- | Default SMP client configuration. -smpDefaultConfig :: SMPClientConfig -smpDefaultConfig = - SMPClientConfig +-- | Default protocol client configuration. +defaultClientConfig :: ProtocolClientConfig +defaultClientConfig = + ProtocolClientConfig { qSize = 64, defaultTransport = ("5223", transport @TLS), - tcpTimeout = 4_000_000, + tcpTimeout = 5_000_000, tcpKeepAlive = Just defaultKeepAliveOpts, - smpPing = 600_000_000 -- 10min + smpPing = 300_000_000 -- 5 min } -data Request = Request +data Request msg = Request { queueId :: QueueId, - responseVar :: TMVar Response + responseVar :: TMVar (Response msg) } -type Response = Either SMPClientError BrokerMsg +type Response msg = Either ProtocolClientError msg --- | Connects to 'SMPServer' using passed client configuration +-- | Connects to 'ProtocolServer' using passed client configuration -- and queue for messages and notifications. -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient) -getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected = - atomically mkSMPClient >>= runClient useTransport +getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) +getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected = + (atomically mkSMPClient >>= runClient useTransport) + `catch` \(e :: IOException) -> pure . Left $ PCEIOError e where - mkSMPClient :: STM SMPClient + mkSMPClient :: STM (ProtocolClient msg) mkSMPClient = do connected <- newTVar False clientCorrId <- newTVar 0 @@ -141,11 +143,11 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp sndQ <- newTBQueue qSize rcvQ <- newTBQueue qSize return - SMPClient + ProtocolClient { action = undefined, sessionId = undefined, connected, - smpServer, + protocolServer, tcpTimeout, clientCorrId, sentCommands, @@ -154,50 +156,51 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp msgQ } - runClient :: (ServiceName, ATransport) -> SMPClient -> IO (Either SMPClientError SMPClient) + runClient :: (ServiceName, ATransport) -> ProtocolClient msg -> IO (Either ProtocolClientError (ProtocolClient msg)) runClient (port', ATransport t) c = do thVar <- newEmptyTMVarIO action <- async $ - runTransportClient (host smpServer) port' (keyHash smpServer) tcpKeepAlive (client t c thVar) - `finally` atomically (putTMVar thVar $ Left SMPNetworkError) + runTransportClient (host protocolServer) port' (keyHash protocolServer) tcpKeepAlive (client t c thVar) + `finally` atomically (putTMVar thVar $ Left PCENetworkError) th_ <- tcpTimeout `timeout` atomically (takeTMVar thVar) pure $ case th_ of Just (Right THandle {sessionId}) -> Right c {action, sessionId} Just (Left e) -> Left e - Nothing -> Left SMPNetworkError + Nothing -> Left PCENetworkError useTransport :: (ServiceName, ATransport) - useTransport = case port smpServer of + useTransport = case port protocolServer of "" -> defaultTransport cfg "80" -> ("80", transport @WS) p -> (p, transport @TLS) - client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError (THandle c)) -> c -> IO () + client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO () client _ c thVar h = - runExceptT (smpClientHandshake h $ keyHash smpServer) >>= \case - Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e + runExceptT (smpClientHandshake h $ keyHash protocolServer) >>= \case + Left e -> atomically . putTMVar thVar . Left $ PCETransportError e Right th@THandle {sessionId} -> do atomically $ do writeTVar (connected c) True putTMVar thVar $ Right th - let c' = c {sessionId} :: SMPClient + let c' = c {sessionId} :: ProtocolClient msg + -- TODO remove ping if 0 is passed (or Nothing?) raceAny_ [send c' th, process c', receive c' th, ping c'] `finally` disconnected - send :: Transport c => SMPClient -> THandle c -> IO () - send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h + send :: Transport c => ProtocolClient msg -> THandle c -> IO () + send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h - receive :: Transport c => SMPClient -> THandle c -> IO () - receive SMPClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ + receive :: Transport c => ProtocolClient msg -> THandle c -> IO () + receive ProtocolClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ - ping :: SMPClient -> IO () + ping :: ProtocolClient msg -> IO () ping c = forever $ do threadDelay smpPing - runExceptT $ sendSMPCommand c Nothing "" PING + void . either throwIO pure =<< runExceptT (sendProtocolCommand c Nothing "" protocolPing) - process :: SMPClient -> IO () - process SMPClient {rcvQ, sentCommands} = forever $ do + process :: ProtocolClient msg -> IO () + process ProtocolClient {rcvQ, sentCommands} = forever $ do (_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ if B.null $ bs corrId then sendMsg qId respOrErr @@ -209,45 +212,48 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp putTMVar responseVar $ if queueId == qId then case respOrErr of - Left e -> Left $ SMPResponseError e - Right (ERR e) -> Left $ SMPServerError e - Right r -> Right r - else Left SMPUnexpectedResponse + Left e -> Left $ PCEResponseError e + Right r -> case protocolError r of + Just e -> Left $ PCEProtocolError e + _ -> Right r + else Left PCEUnexpectedResponse - sendMsg :: QueueId -> Either ErrorType BrokerMsg -> IO () + sendMsg :: QueueId -> Either ErrorType msg -> IO () sendMsg qId = \case - Right cmd -> atomically $ writeTBQueue msgQ (smpServer, qId, cmd) + Right cmd -> atomically $ mapM_ (`writeTBQueue` (protocolServer, qId, cmd)) msgQ -- TODO send everything else to errQ and log in agent _ -> return () --- | Disconnects SMP client from the server and terminates client threads. -closeSMPClient :: SMPClient -> IO () -closeSMPClient = uninterruptibleCancel . action +-- | Disconnects client from the server and terminates client threads. +closeProtocolClient :: ProtocolClient msg -> IO () +closeProtocolClient = uninterruptibleCancel . action -- | SMP client error type. -data SMPClientError +data ProtocolClientError = -- | Correctly parsed SMP server ERR response. -- This error is forwarded to the agent client as `ERR SMP err`. - SMPServerError ErrorType + PCEProtocolError ErrorType | -- | Invalid server response that failed to parse. -- Forwarded to the agent client as `ERR BROKER RESPONSE`. - SMPResponseError ErrorType + PCEResponseError ErrorType | -- | Different response from what is expected to a certain SMP command, -- e.g. server should respond `IDS` or `ERR` to `NEW` command, -- other responses would result in this error. -- Forwarded to the agent client as `ERR BROKER UNEXPECTED`. - SMPUnexpectedResponse + PCEUnexpectedResponse | -- | Used for TCP connection and command response timeouts. -- Forwarded to the agent client as `ERR BROKER TIMEOUT`. - SMPResponseTimeout + PCEResponseTimeout | -- | Failure to establish TCP connection. -- Forwarded to the agent client as `ERR BROKER NETWORK`. - SMPNetworkError + PCENetworkError | -- | TCP transport handshake or some other transport error. -- Forwarded to the agent client as `ERR BROKER TRANSPORT e`. - SMPTransportError TransportError + PCETransportError TransportError | -- | Error when cryptographically "signing" the command. - SMPSignatureError C.CryptoError + PCESignatureError C.CryptoError + | -- | IO Error + PCEIOError IOException deriving (Eq, Show, Exception) -- | Create a new SMP queue. @@ -258,92 +264,95 @@ createSMPQueue :: RcvPrivateSignKey -> RcvPublicVerifyKey -> RcvPublicDhKey -> - ExceptT SMPClientError IO QueueIdsKeys + ExceptT ProtocolClientError IO QueueIdsKeys createSMPQueue c rpKey rKey dhKey = sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey) >>= \case IDS qik -> pure qik - _ -> throwE SMPUnexpectedResponse + _ -> throwE PCEUnexpectedResponse -- | Subscribe to the SMP queue. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue -subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO () -subscribeSMPQueue c@SMPClient {smpServer, msgQ} rpKey rId = +subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO () +subscribeSMPQueue c@ProtocolClient {protocolServer, msgQ} rpKey rId = sendSMPCommand c (Just rpKey) rId SUB >>= \case OK -> return () cmd@MSG {} -> - lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd) - _ -> throwE SMPUnexpectedResponse + lift . atomically $ mapM_ (`writeTBQueue` (protocolServer, rId, cmd)) msgQ + _ -> throwE PCEUnexpectedResponse -- | Subscribe to the SMP queue notifications. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications -subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO () +subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO () subscribeSMPQueueNotifications = okSMPCommand NSUB -- | Secure the SMP queue by adding a sender public key. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command -secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO () +secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO () secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId -- | Enable notifications for the queue for push notifications server. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command -enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT SMPClientError IO NotifierId +enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT ProtocolClientError IO NotifierId enableSMPQueueNotifications c rpKey rId notifierKey = sendSMPCommand c (Just rpKey) rId (NKEY notifierKey) >>= \case NID nId -> pure nId - _ -> throwE SMPUnexpectedResponse + _ -> throwE PCEUnexpectedResponse -- | Send SMP message. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message -sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT SMPClientError IO () +sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT ProtocolClientError IO () sendSMPMessage c spKey sId msg = sendSMPCommand c spKey sId (SEND msg) >>= \case OK -> pure () - _ -> throwE SMPUnexpectedResponse + _ -> throwE PCEUnexpectedResponse -- | Acknowledge message delivery (server deletes the message). -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery -ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () -ackSMPMessage c@SMPClient {smpServer, msgQ} rpKey rId = +ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () +ackSMPMessage c@ProtocolClient {protocolServer, msgQ} rpKey rId = sendSMPCommand c (Just rpKey) rId ACK >>= \case OK -> return () cmd@MSG {} -> - lift . atomically $ writeTBQueue msgQ (smpServer, rId, cmd) - _ -> throwE SMPUnexpectedResponse + lift . atomically $ mapM_ (`writeTBQueue` (protocolServer, rId, cmd)) msgQ + _ -> throwE PCEUnexpectedResponse -- | Irreversibly suspend SMP queue. -- The existing messages from the queue will still be delivered. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue -suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () +suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () suspendSMPQueue = okSMPCommand OFF -- | Irreversibly delete SMP queue and all messages in it. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue -deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO () +deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () deleteSMPQueue = okSMPCommand DEL -okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO () +okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO () okSMPCommand cmd c pKey qId = sendSMPCommand c (Just pKey) qId cmd >>= \case OK -> return () - _ -> throwE SMPUnexpectedResponse + _ -> throwE PCEUnexpectedResponse -- | Send SMP command --- TODO sign all requests (SEND of SMP confirmation would be signed with the same key that is passed to the recipient) -sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg -sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do +sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg +sendSMPCommand c pKey qId = sendProtocolCommand c pKey qId . Cmd sParty + +-- | Send Protocol command +sendProtocolCommand :: forall msg. ProtocolEncoding (ProtocolCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtocolCommand msg -> ExceptT ProtocolClientError IO msg +sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do corrId <- lift_ getNextCorrId t <- signTransmission $ encodeTransmission sessionId (corrId, qId, cmd) ExceptT $ sendRecv corrId t where - lift_ :: STM a -> ExceptT SMPClientError IO a + lift_ :: STM a -> ExceptT ProtocolClientError IO a lift_ action = ExceptT $ Right <$> atomically action getNextCorrId :: STM CorrId @@ -351,20 +360,20 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeou i <- stateTVar clientCorrId $ \i -> (i, i + 1) pure . CorrId $ bshow i - signTransmission :: ByteString -> ExceptT SMPClientError IO SentRawTransmission + signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission signTransmission t = case pKey of Nothing -> return (Nothing, t) Just pk -> do - sig <- liftError SMPSignatureError $ C.sign pk t + sig <- liftError PCESignatureError $ C.sign pk t return (Just sig, t) -- two separate "atomically" needed to avoid blocking - sendRecv :: CorrId -> SentRawTransmission -> IO Response + sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg) sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar where - withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a + withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a - send :: CorrId -> SentRawTransmission -> STM (TMVar Response) + send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg)) send corrId t = do r <- newEmptyTMVar TM.insert corrId (Request qId r) sentCommands diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs new file mode 100644 index 000000000..7a62dc465 --- /dev/null +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -0,0 +1,301 @@ +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Client.Agent where + +import Control.Concurrent (forkIO) +import Control.Concurrent.Async (Async, uninterruptibleCancel) +import Control.Logger.Simple +import Control.Monad.Except +import Control.Monad.IO.Unlift +import Control.Monad.Trans.Except +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Set (Set) +import Data.Text.Encoding +import Numeric.Natural +import Simplex.Messaging.Agent.RetryInterval +import Simplex.Messaging.Client +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer) +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util (tryE, whenM) +import System.Timeout (timeout) +import UnliftIO (async, forConcurrently_) +import UnliftIO.Exception (Exception) +import qualified UnliftIO.Exception as E +import UnliftIO.STM + +type SMPClientVar = TMVar (Either ProtocolClientError SMPClient) + +data SMPClientAgentEvent + = CAConnected SMPServer + | CADisconnected SMPServer (Set SMPSub) + | CAReconnected SMPServer + | CAResubscribed SMPServer SMPSub + | CASubError SMPServer SMPSub ProtocolClientError + +data SMPSubParty = SPRecipient | SPNotifier + deriving (Eq, Ord) + +type SMPSub = (SMPSubParty, QueueId) + +-- type SMPServerSub = (SMPServer, SMPSub) + +data SMPClientAgentConfig = SMPClientAgentConfig + { smpCfg :: ProtocolClientConfig, + reconnectInterval :: RetryInterval, + msgQSize :: Natural, + agentQSize :: Natural + } + +defaultSMPClientAgentConfig :: SMPClientAgentConfig +defaultSMPClientAgentConfig = + SMPClientAgentConfig + { smpCfg = defaultClientConfig, + reconnectInterval = + RetryInterval + { initialInterval = second, + increaseAfter = 10 * second, + maxInterval = 10 * second + }, + msgQSize = 64, + agentQSize = 64 + } + where + second = 1000000 + +data SMPClientAgent = SMPClientAgent + { agentCfg :: SMPClientAgentConfig, + msgQ :: TBQueue (ServerTransmission BrokerMsg), + agentQ :: TBQueue SMPClientAgentEvent, + smpClients :: TMap SMPServer SMPClientVar, + srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey), + pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey), + reconnections :: TVar [Async ()], + asyncClients :: TVar [Async ()] + } + +newtype InternalException e = InternalException {unInternalException :: e} + deriving (Eq, Show) + +instance Exception e => Exception (InternalException e) + +instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where + withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b + withRunInIO exceptToIO = + withExceptT unInternalException . ExceptT . E.try $ + withRunInIO $ \run -> + exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT) + +newSMPClientAgent :: SMPClientAgentConfig -> STM SMPClientAgent +newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do + msgQ <- newTBQueue msgQSize + agentQ <- newTBQueue agentQSize + smpClients <- TM.empty + srvSubs <- TM.empty + pendingSrvSubs <- TM.empty + reconnections <- newTVar [] + asyncClients <- newTVar [] + pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients} + +getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient +getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = + atomically getClientVar >>= either newSMPClient waitForSMPClient + where + getClientVar :: STM (Either SMPClientVar SMPClientVar) + getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients + + newClientVar :: STM SMPClientVar + newClientVar = do + smpVar <- newEmptyTMVar + TM.insert srv smpVar smpClients + pure smpVar + + waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient + waitForSMPClient smpVar = do + let ProtocolClientConfig {tcpTimeout} = smpCfg agentCfg + smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar) + liftEither $ case smpClient_ of + Just (Right smpClient) -> Right smpClient + Just (Left e) -> Left e + Nothing -> Left PCEResponseTimeout + + newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient + newSMPClient smpVar = tryConnectClient pure tryConnectAsync + where + tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a + tryConnectClient successAction retryAction = + tryE connectClient >>= \r -> case r of + Right smp -> do + logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv + atomically $ putTMVar smpVar r + successAction smp + Left e -> do + if e == PCENetworkError || e == PCEResponseTimeout + then retryAction + else atomically $ do + putTMVar smpVar (Left e) + TM.delete srv smpClients + throwE e + tryConnectAsync :: ExceptT ProtocolClientError IO () + tryConnectAsync = do + a <- async connectAsync + atomically $ modifyTVar' (asyncClients ca) (a :) + connectAsync :: ExceptT ProtocolClientError IO () + connectAsync = + withRetryInterval (reconnectInterval agentCfg) $ \loop -> + void $ tryConnectClient (const reconnectClient) loop + + connectClient :: ExceptT ProtocolClientError IO SMPClient + connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected + + clientDisconnected :: IO () + clientDisconnected = do + removeClientAndSubs >>= (`forM_` serverDown) + logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv + + removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateSignKey)) + removeClientAndSubs = atomically $ do + TM.delete srv smpClients + TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs + where + updateSubs sVar = do + ss <- readTVar sVar + addPendingSubs sVar ss + pure ss + + addPendingSubs sVar ss = do + let ps = pendingSrvSubs ca + TM.lookup srv ps >>= \case + Just v -> TM.union ss v + _ -> TM.insert srv sVar ps + + serverDown :: Map SMPSub C.APrivateSignKey -> IO () + serverDown ss = unless (M.null ss) . void . runExceptT $ do + notify . CADisconnected srv $ M.keysSet ss + reconnectServer + + reconnectServer :: ExceptT ProtocolClientError IO () + reconnectServer = do + a <- async tryReconnectClient + atomically $ modifyTVar' (reconnections ca) (a :) + + tryReconnectClient :: ExceptT ProtocolClientError IO () + tryReconnectClient = do + withRetryInterval (reconnectInterval agentCfg) $ \loop -> + reconnectClient `catchE` const loop + + reconnectClient :: ExceptT ProtocolClientError IO () + reconnectClient = do + withSMP ca srv $ \smp -> do + notify $ CAReconnected srv + cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca) + forConcurrently_ (maybe [] M.assocs cs) $ \sub@(s, _) -> + whenM (atomically $ hasSub (srvSubs ca) srv s) $ + subscribe_ smp sub `catchE` handleError s + where + subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () + subscribe_ smp sub@(s, _) = do + smpSubscribe smp sub + atomically $ addSubscription ca srv sub + notify $ CAResubscribed srv s + + handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO () + handleError s = \case + e@PCEResponseTimeout -> throwE e + e@PCENetworkError -> throwE e + e -> do + notify $ CASubError srv s e + atomically $ removePendingSubscription ca srv s + + notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO () + notify evt = atomically $ writeTBQueue (agentQ ca) evt + +closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m () +closeSMPClientAgent c = liftIO $ do + closeSMPServerClients c + cancelActions $ reconnections c + cancelActions $ asyncClients c + +closeSMPServerClients :: SMPClientAgent -> IO () +closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient) + where + closeClient smpVar = + atomically (readTMVar smpVar) >>= \case + Right smp -> closeProtocolClient smp `E.catch` \(_ :: E.SomeException) -> pure () + _ -> pure () + +cancelActions :: Foldable f => TVar (f (Async ())) -> IO () +cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel + +withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a +withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError + where + logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a + logSMPError e = do + liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e + throwE e + +subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () +subscribeQueue ca srv sub = do + atomically $ addPendingSubscription ca srv sub + withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError + where + subscribe_ smp = do + smpSubscribe smp sub + atomically $ addSubscription ca srv sub + + handleError e = do + atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ + removePendingSubscription ca srv $ fst sub + throwE e + +showServer :: SMPServer -> ByteString +showServer ProtocolServer {host, port} = + B.pack $ host <> if null port then "" else ':' : port + +smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO () +smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId + where + subscribe_ = case party of + SPRecipient -> subscribeSMPQueue + SPNotifier -> subscribeSMPQueueNotifications + +addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM () +addSubscription ca srv sub = do + addSub_ (srvSubs ca) srv sub + removePendingSubscription ca srv $ fst sub + +addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM () +addPendingSubscription = addSub_ . pendingSrvSubs + +addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM () +addSub_ subs srv (s, key) = + TM.lookup srv subs >>= \case + Just m -> TM.insert s key m + _ -> TM.singleton s key >>= \v -> TM.insert srv v subs + +removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () +removeSubscription = removeSub_ . srvSubs + +removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () +removePendingSubscription = removeSub_ . pendingSrvSubs + +removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM () +removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) + +getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateSignKey) +getSubKey subs srv s = fmap join . mapM (TM.lookup s) =<< TM.lookup srv subs + +hasSub :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM Bool +hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs new file mode 100644 index 000000000..48fc89b1e --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Notifications.Client where + +import Control.Monad.Except +import Control.Monad.Trans.Except +import Data.Word (Word16) +import Simplex.Messaging.Client +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Notifications.Protocol + +type NtfClient = ProtocolClient NtfResponse + +registerNtfToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519) +registerNtfToken c pKey newTkn = + sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case + NRId tknId dhKey -> pure (tknId, dhKey) + _ -> throwE PCEUnexpectedResponse + +verifyNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegistrationCode -> ExceptT ProtocolClientError IO () +verifyNtfToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId + +deleteNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO () +deleteNtfToken = okNtfCommand TDEL + +enableNtfCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO () +enableNtfCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId + +createNtfSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519) +createNtfSubsciption c pKey newSub = + sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case + NRId tknId dhKey -> pure (tknId, dhKey) + _ -> throwE PCEUnexpectedResponse + +checkNtfSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus +checkNtfSubscription c pKey subId = + sendNtfCommand c (Just pKey) subId SCHK >>= \case + NRStat stat -> pure stat + _ -> throwE PCEUnexpectedResponse + +deleteNfgSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO () +deleteNfgSubscription = okNtfCommand SDEL + +-- | Send notification server command +sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse +sendNtfCommand c pKey entId = sendProtocolCommand c pKey entId . NtfCmd sNtfEntity + +okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO () +okNtfCommand cmd c pKey entId = + sendNtfCommand c (Just pKey) entId cmd >>= \case + NROk -> return () + _ -> throwE PCEUnexpectedResponse diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index da66280fd..1fac10e7e 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -1,6 +1,11 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} module Simplex.Messaging.Notifications.Protocol where @@ -8,154 +13,333 @@ module Simplex.Messaging.Notifications.Protocol where import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Kind import Data.Maybe (isNothing) +import Data.Type.Equality +import Data.Word (Word16) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding -import Simplex.Messaging.Protocol +import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) +import Simplex.Messaging.Util ((<$?>)) -data NtfCommandTag - = NCCreate_ - | NCCheck_ - | NCToken_ - | NCDelete_ +data NtfEntity = Token | Subscription deriving (Show) -instance Encoding NtfCommandTag where +data SNtfEntity :: NtfEntity -> Type where + SToken :: SNtfEntity 'Token + SSubscription :: SNtfEntity 'Subscription + +instance TestEquality SNtfEntity where + testEquality SToken SToken = Just Refl + testEquality SSubscription SSubscription = Just Refl + testEquality _ _ = Nothing + +deriving instance Show (SNtfEntity e) + +class NtfEntityI (e :: NtfEntity) where sNtfEntity :: SNtfEntity e + +instance NtfEntityI 'Token where sNtfEntity = SToken + +instance NtfEntityI 'Subscription where sNtfEntity = SSubscription + +data NtfCommandTag (e :: NtfEntity) where + TNEW_ :: NtfCommandTag 'Token + TVFY_ :: NtfCommandTag 'Token + TDEL_ :: NtfCommandTag 'Token + TCRN_ :: NtfCommandTag 'Token + SNEW_ :: NtfCommandTag 'Subscription + SCHK_ :: NtfCommandTag 'Subscription + SDEL_ :: NtfCommandTag 'Subscription + PING_ :: NtfCommandTag 'Subscription + +deriving instance Show (NtfCommandTag e) + +data NtfCmdTag = forall e. NtfEntityI e => NCT (SNtfEntity e) (NtfCommandTag e) + +instance NtfEntityI e => Encoding (NtfCommandTag e) where smpEncode = \case - NCCreate_ -> "CREATE" - NCCheck_ -> "CHECK" - NCToken_ -> "TOKEN" - NCDelete_ -> "DELETE" + TNEW_ -> "TNEW" + TVFY_ -> "TVFY" + TDEL_ -> "TDEL" + TCRN_ -> "TCRN" + SNEW_ -> "SNEW" + SCHK_ -> "SCHK" + SDEL_ -> "SDEL" + PING_ -> "PING" smpP = messageTagP -instance ProtocolMsgTag NtfCommandTag where +instance Encoding NtfCmdTag where + smpEncode (NCT _ t) = smpEncode t + smpP = messageTagP + +instance ProtocolMsgTag NtfCmdTag where decodeTag = \case - "CREATE" -> Just NCCreate_ - "CHECK" -> Just NCCheck_ - "TOKEN" -> Just NCToken_ - "DELETE" -> Just NCDelete_ + "TNEW" -> Just $ NCT SToken TNEW_ + "TVFY" -> Just $ NCT SToken TVFY_ + "TDEL" -> Just $ NCT SToken TDEL_ + "TCRN" -> Just $ NCT SToken TCRN_ + "SNEW" -> Just $ NCT SSubscription SNEW_ + "SCHK" -> Just $ NCT SSubscription SCHK_ + "SDEL" -> Just $ NCT SSubscription SDEL_ + "PING" -> Just $ NCT SSubscription PING_ _ -> Nothing -data NtfCommand - = NCCreate DeviceToken SMPQueueNtfUri C.APublicVerifyKey C.PublicKeyX25519 - | NCCheck - | NCToken DeviceToken - | NCDelete +instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where + decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t) -instance Protocol NtfCommand where - type Tag NtfCommand = NtfCommandTag +type NtfRegistrationCode = ByteString + +data NewNtfEntity (e :: NtfEntity) where + NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token + NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NewNtfEntity 'Subscription + +data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e) + +instance NtfEntityI e => Encoding (NewNtfEntity e) where + smpEncode = \case + NewNtfTkn tkn verifyKey dhPubKey -> smpEncode ('T', tkn, verifyKey, dhPubKey) + NewNtfSub tknId smpQueue -> smpEncode ('S', tknId, smpQueue) + smpP = (\(ANE _ c) -> checkEntity c) <$?> smpP + +instance Encoding ANewNtfEntity where + smpEncode (ANE _ e) = smpEncode e + smpP = + A.anyChar >>= \case + 'T' -> ANE SToken <$> (NewNtfTkn <$> smpP <*> smpP <*> smpP) + 'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP) + _ -> fail "bad ANewNtfEntity" + +instance Protocol NtfResponse where + type ProtocolCommand NtfResponse = NtfCmd + protocolPing = NtfCmd SSubscription PING + protocolError = \case + NRErr e -> Just e + _ -> Nothing + +data NtfCommand (e :: NtfEntity) where + -- | register new device token for notifications + TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token + -- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token + TVFY :: NtfRegistrationCode -> NtfCommand 'Token + -- | delete token - all subscriptions will be removed and no more notifications will be sent + TDEL :: NtfCommand 'Token + -- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable + TCRN :: Word16 -> NtfCommand 'Token + -- | create SMP subscription + SNEW :: NewNtfEntity 'Subscription -> NtfCommand 'Subscription + -- | check SMP subscription status (response is STAT) + SCHK :: NtfCommand 'Subscription + -- | delete SMP subscription + SDEL :: NtfCommand 'Subscription + -- | keep-alive command + PING :: NtfCommand 'Subscription + +data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) + +instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where + type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol = \case - NCCreate token smpQueue verifyKey dhKey -> e (NCCreate_, ' ', token, smpQueue, verifyKey, dhKey) - NCCheck -> e NCCheck_ - NCToken token -> e (NCToken_, ' ', token) - NCDelete -> e NCDelete_ + TNEW newTkn -> e (TNEW_, ' ', newTkn) + TVFY code -> e (TVFY_, ' ', code) + TDEL -> e TDEL_ + TCRN int -> e (TCRN_, ' ', int) + SNEW newSub -> e (SNEW_, ' ', newSub) + SCHK -> e SCHK_ + SDEL -> e SDEL_ + PING -> e PING_ where e :: Encoding a => a -> ByteString e = smpEncode - protocolP = \case - NCCreate_ -> NCCreate <$> _smpP <*> smpP <*> smpP <*> smpP - NCCheck_ -> pure NCCheck - NCToken_ -> NCToken <$> _smpP - NCDelete_ -> pure NCDelete + protocolP tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP (NCT (sNtfEntity @e) tag) - checkCredentials (sig, _, subId, _) cmd = case cmd of - -- CREATE must have signature but NOT subscription ID - NCCreate {} - | isNothing sig -> Left $ CMD NO_AUTH - | not (B.null subId) -> Left $ CMD HAS_AUTH - | otherwise -> Right cmd - -- other client commands must have both signature and subscription ID + checkCredentials (sig, _, entityId, _) cmd = case cmd of + -- TNEW and SNEW must have signature but NOT token/subscription IDs + TNEW {} -> sigNoEntity + SNEW {} -> sigNoEntity + PING + | isNothing sig && B.null entityId -> Right cmd + | otherwise -> Left $ CMD HAS_AUTH + -- other client commands must have both signature and entity ID _ - | isNothing sig || B.null subId -> Left $ CMD NO_AUTH + | isNothing sig || B.null entityId -> Left $ CMD NO_AUTH | otherwise -> Right cmd + where + sigNoEntity + | isNothing sig = Left $ CMD NO_AUTH + | not (B.null entityId) = Left $ CMD HAS_AUTH + | otherwise = Right cmd + +instance ProtocolEncoding NtfCmd where + type Tag NtfCmd = NtfCmdTag + encodeProtocol (NtfCmd _ c) = encodeProtocol c + + protocolP = \case + NCT SToken tag -> + NtfCmd SToken <$> case tag of + TNEW_ -> TNEW <$> _smpP + TVFY_ -> TVFY <$> _smpP + TDEL_ -> pure TDEL + TCRN_ -> TCRN <$> _smpP + NCT SSubscription tag -> + NtfCmd SSubscription <$> case tag of + SNEW_ -> SNEW <$> _smpP + SCHK_ -> pure SCHK + SDEL_ -> pure SDEL + PING_ -> pure PING + + checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c data NtfResponseTag - = NRSubId_ + = NRId_ | NROk_ | NRErr_ | NRStat_ + | NRPong_ deriving (Show) instance Encoding NtfResponseTag where smpEncode = \case - NRSubId_ -> "ID" + NRId_ -> "ID" NROk_ -> "OK" NRErr_ -> "ERR" NRStat_ -> "STAT" + NRPong_ -> "PONG" smpP = messageTagP instance ProtocolMsgTag NtfResponseTag where decodeTag = \case - "ID" -> Just NRSubId_ + "ID" -> Just NRId_ "OK" -> Just NROk_ "ERR" -> Just NRErr_ "STAT" -> Just NRStat_ + "PONG" -> Just NRPong_ _ -> Nothing data NtfResponse - = NRSubId C.PublicKeyX25519 + = NRId NtfEntityId C.PublicKeyX25519 | NROk | NRErr ErrorType - | NRStat NtfStatus + | NRStat NtfSubStatus + | NRPong -instance Protocol NtfResponse where +instance ProtocolEncoding NtfResponse where type Tag NtfResponse = NtfResponseTag encodeProtocol = \case - NRSubId dhKey -> e (NRSubId_, ' ', dhKey) + NRId entId dhKey -> e (NRId_, ' ', entId, dhKey) NROk -> e NROk_ NRErr err -> e (NRErr_, ' ', err) NRStat stat -> e (NRStat_, ' ', stat) + NRPong -> e NRPong_ where e :: Encoding a => a -> ByteString e = smpEncode protocolP = \case - NRSubId_ -> NRSubId <$> _smpP + NRId_ -> NRId <$> _smpP <*> smpP NROk_ -> pure NROk NRErr_ -> NRErr <$> _smpP NRStat_ -> NRStat <$> _smpP + NRPong_ -> pure NRPong - checkCredentials (_, _, subId, _) cmd = case cmd of - -- ERR response does not always have subscription ID + checkCredentials (_, _, entId, _) cmd = case cmd of + -- ID response must not have queue ID + NRId {} -> noEntity + -- ERR response does not always have entity ID NRErr _ -> Right cmd - -- other server responses must have subscription ID + -- PONG response must not have queue ID + NRPong -> noEntity + -- other server responses must have entity ID _ - | B.null subId -> Left $ CMD NO_ENTITY + | B.null entId -> Left $ CMD NO_ENTITY | otherwise -> Right cmd + where + noEntity + | B.null entId = Right cmd + | otherwise = Left $ CMD HAS_AUTH -data SMPQueueNtfUri = SMPQueueNtfUri - { smpServer :: SMPServer, +data SMPQueueNtf = SMPQueueNtf + { smpServer :: ProtocolServer, notifierId :: NotifierId, notifierKey :: NtfPrivateSignKey } -instance Encoding SMPQueueNtfUri where - smpEncode SMPQueueNtfUri {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey) +instance Encoding SMPQueueNtf where + smpEncode SMPQueueNtf {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey) smpP = do (smpServer, notifierId, notifierKey) <- smpP - pure $ SMPQueueNtfUri smpServer notifierId notifierKey + pure $ SMPQueueNtf smpServer notifierId notifierKey -newtype DeviceToken = DeviceToken ByteString +data PushPlatform = PPApple + +instance Encoding PushPlatform where + smpEncode = \case + PPApple -> "A" + smpP = + A.anyChar >>= \case + 'A' -> pure PPApple + _ -> fail "bad PushPlatform" + +data DeviceToken = DeviceToken PushPlatform ByteString instance Encoding DeviceToken where - smpEncode (DeviceToken t) = smpEncode t - smpP = DeviceToken <$> smpP + smpEncode (DeviceToken p t) = smpEncode (p, t) + smpP = DeviceToken <$> smpP <*> smpP -type NtfSubsciptionId = ByteString +type NtfEntityId = ByteString -data NtfStatus = NSPending | NSActive | NSEnd | NSSMPAuth +type NtfSubscriptionId = NtfEntityId -instance Encoding NtfStatus where +type NtfTokenId = NtfEntityId + +data NtfSubStatus + = -- | state after SNEW + NSNew + | -- | pending connection/subscription to SMP server + NSPending + | -- | connected and subscribed to SMP server + NSActive + | -- | NEND received (we currently do not support it) + NSEnd + | -- | SMP AUTH error + NSSMPAuth + deriving (Eq) + +instance Encoding NtfSubStatus where smpEncode = \case - NSPending -> "PENDING" + NSNew -> "NEW" + NSPending -> "PENDING" -- e.g. after SMP server disconnect/timeout while ntf server is retrying to connect NSActive -> "ACTIVE" NSEnd -> "END" NSSMPAuth -> "SMP_AUTH" smpP = A.takeTill (== ' ') >>= \case + "NEW" -> pure NSNew "PENDING" -> pure NSPending "ACTIVE" -> pure NSActive "END" -> pure NSEnd "SMP_AUTH" -> pure NSSMPAuth _ -> fail "bad NtfError" + +data NtfTknStatus + = -- | state after registration (TNEW) + NTNew + | -- | if initial notification or verification failed (push provider error) + NTInvalid + | -- | if initial notification succeeded + NTConfirmed + | -- | after successful verification (TVFY) + NTActive + | -- | after it is no longer valid (push provider error) + NTExpired + deriving (Eq) + +checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e) +checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of + Just Refl -> Right c + Nothing -> Left "bad command party" + +checkEntity' :: forall t p p'. (NtfEntityI p, NtfEntityI p') => t p' -> Maybe (t p) +checkEntity' c = case testEquality (sNtfEntity @p) (sNtfEntity @p') of + Just Refl -> Just c + _ -> Nothing diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index dfdaae801..13d4671e6 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} module Simplex.Messaging.Notifications.Server where @@ -12,13 +15,16 @@ import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) import Data.ByteString.Char8 (ByteString) +import Data.Functor (($>)) import Network.Socket (ServiceName) +import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Subscriptions import Simplex.Messaging.Notifications.Transport -import Simplex.Messaging.Protocol (ErrorType (..), Transmission, encodeTransmission, tGet, tPut) +import Simplex.Messaging.Protocol (ErrorType (..), SignedTransmission, Transmission, encodeTransmission, tGet, tPut) +import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Transport (ATransport (..), THandle (..), TProxy, Transport) import Simplex.Messaging.Transport.Server (runTransportServer) @@ -37,7 +43,10 @@ runNtfServerBlocking started cfg@NtfServerConfig {transports} = do runReaderT ntfServer env where ntfServer :: (MonadUnliftIO m', MonadReader NtfEnv m') => m' () - ntfServer = raceAny_ (map runServer transports) + ntfServer = do + s <- asks subscriber + ps <- asks pushServer + raceAny_ (ntfSubscriber s : ntfPush ps : map runServer transports) runServer :: (MonadUnliftIO m', MonadReader NtfEnv m') => (ServiceName, ATransport) -> m' () runServer (tcpPort, ATransport t) = do @@ -51,11 +60,61 @@ runNtfServerBlocking started cfg@NtfServerConfig {transports} = do Right th -> runNtfClientTransport th Left _ -> pure () +ntfSubscriber :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> m () +ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = do + raceAny_ [subscribe, receiveSMP, receiveAgent] + where + subscribe :: m () + subscribe = forever $ do + atomically (readTBQueue subQ) >>= \case + NtfSub NtfSubData {smpQueue} -> do + let SMPQueueNtf {smpServer, notifierId, notifierKey} = smpQueue + liftIO (runExceptT $ subscribeQueue ca smpServer ((SPNotifier, notifierId), notifierKey)) >>= \case + Right _ -> pure () -- update subscription status + Left e -> pure () + + receiveSMP :: m () + receiveSMP = forever $ do + (srv, ntfId, msg) <- atomically $ readTBQueue msgQ + case msg of + SMP.NMSG -> do + -- check when the last NMSG was received from this queue + -- update timestamp + -- check what was the last hidden notification was sent (and whether to this queue) + -- decide whether it should be sent as hidden or visible + -- construct and possibly encrypt notification + -- send it + pure () + _ -> pure () + pure () + + receiveAgent = + forever $ + atomically (readTBQueue agentQ) >>= \case + CAConnected _ -> pure () + CADisconnected srv subs -> do + -- update subscription statuses + pure () + CAReconnected _ -> pure () + CAResubscribed srv sub -> do + -- update subscription status + pure () + CASubError srv sub err -> do + -- update subscription status + pure () + +ntfPush :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m () +ntfPush NtfPushServer {pushQ} = forever $ do + atomically (readTBQueue pushQ) >>= \case + (NtfTknData {}, Notification {}) -> pure () + runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m () runNtfClientTransport th@THandle {sessionId} = do - q <- asks $ tbqSize . config - c <- atomically $ newNtfServerClient q sessionId - raceAny_ [send th c, client c, receive th c] + qSize <- asks $ clientQSize . config + c <- atomically $ newNtfServerClient qSize sessionId + s <- asks subscriber + ps <- asks pushServer + raceAny_ [send th c, client c s ps, receive th c] `finally` clientDisconnected c clientDisconnected :: MonadUnliftIO m => NtfServerClient -> m () @@ -63,14 +122,13 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m () receive th NtfServerClient {rcvQ, sndQ} = forever $ do - (sig, signed, (corrId, queueId, cmdOrError)) <- tGet th + t@(sig, signed, (corrId, subId, cmdOrError)) <- tGet th case cmdOrError of - Left e -> write sndQ (corrId, queueId, NRErr e) - Right cmd -> do - verified <- verifyTransmission sig signed queueId cmd - if verified - then write rcvQ (corrId, queueId, cmd) - else write sndQ (corrId, queueId, NRErr AUTH) + Left e -> write sndQ (corrId, subId, NRErr e) + Right cmd -> + verifyNtfTransmission t cmd >>= \case + VRVerified req -> write rcvQ req + VRFailed -> write sndQ (corrId, subId, NRErr AUTH) where write q t = atomically $ writeTBQueue q t @@ -79,33 +137,121 @@ send h NtfServerClient {sndQ, sessionId} = forever $ do t <- atomically $ readTBQueue sndQ liftIO $ tPut h (Nothing, encodeTransmission sessionId t) -verifyTransmission :: - forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => Maybe C.ASignature -> ByteString -> NtfSubsciptionId -> NtfCommand -> m Bool -verifyTransmission sig_ signed subId cmd = do - case cmd of - NCCreate _ _ k _ -> pure $ verifyCmdSignature sig_ signed k - _ -> do - st <- asks store - verifySubCmd <$> atomically (getNtfSubscription st subId) - where - verifySubCmd = \case - Right sub -> verifyCmdSignature sig_ signed $ subVerifyKey sub - Left _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` False +data VerificationResult = VRVerified NtfRequest | VRFailed -client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> m () -client NtfServerClient {rcvQ, sndQ} = +verifyNtfTransmission :: + forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult +verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do + case cmd of + NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) -> + -- TODO check that token is not already in store + pure $ + if verifyCmdSignature sig_ signed k + then VRVerified (NtfReqNew corrId (ANE SToken n)) + else VRFailed + NtfCmd SToken c -> do + st <- asks store + atomically (getNtfToken st entId) >>= \case + Just r@(NtfTkn NtfTknData {tknVerifyKey}) -> + pure $ + if verifyCmdSignature sig_ signed tknVerifyKey + then VRVerified (NtfReqCmd SToken r (corrId, entId, c)) + else VRFailed + _ -> pure VRFailed -- TODO dummy verification + _ -> pure VRFailed + +-- do +-- st <- asks store +-- case cmd of +-- NCSubCreate tokenId smpQueue -> verifyCreateCmd verifyKey newSub <$> atomically (getNtfSubViaSMPQueue st smpQueue) +-- _ -> verifySubCmd <$> atomically (getNtfSub st subId) +-- where +-- verifyCreateCmd k newSub sub_ +-- | verifyCmdSignature sig_ signed k = case sub_ of +-- Just sub -> if k == subVerifyKey sub then VRCommand sub else VRFail +-- _ -> VRCreate newSub +-- | otherwise = VRFail +-- verifySubCmd = \case +-- Just sub -> if verifyCmdSignature sig_ signed $ subVerifyKey sub then VRCommand sub else VRFail +-- _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFail + +client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m () +client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} = forever $ atomically (readTBQueue rcvQ) >>= processCommand >>= atomically . writeTBQueue sndQ where - processCommand :: Transmission NtfCommand -> m (Transmission NtfResponse) - processCommand (corrId, subId, cmd) = case cmd of - NCCreate _token _smpQueue _verifyKey _dhKey -> do - pure (corrId, subId, NROk) - NCCheck -> do - pure (corrId, subId, NROk) - NCToken _token -> do - pure (corrId, subId, NROk) - NCDelete -> do - pure (corrId, subId, NROk) + processCommand :: NtfRequest -> m (Transmission NtfResponse) + processCommand = \case + NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do + st <- asks store + (srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair' + let dhSecret = C.dh' dhPubKey srvDrivDhKey + tknId <- getId + atomically $ do + tkn <- mkNtfTknData newTkn dhSecret + addNtfToken st tknId tkn + writeTBQueue pushQ (tkn, Notification) + -- pure (corrId, sId, NRSubId pubDhKey) + pure (corrId, "", NRId tknId srvDhPubKey) + NtfReqCmd SToken tkn (corrId, tknId, cmd) -> + (corrId,tknId,) <$> case cmd of + TNEW newTkn -> pure NROk -- TODO when duplicate token sent + TVFY code -> pure NROk + TDEL -> pure NROk + TCRN int -> pure NROk + NtfReqNew corrId (ANE SSubscription newSub) -> pure (corrId, "", NROk) + NtfReqCmd SSubscription sub (corrId, subId, cmd) -> + (corrId,subId,) <$> case cmd of + SNEW newSub -> pure NROk + SCHK -> pure NROk + SDEL -> pure NROk + PING -> pure NRPong + getId :: m NtfEntityId + getId = do + n <- asks $ subIdBytes . config + gVar <- asks idsDrg + atomically (randomBytes n gVar) + +-- NReqCreate corrId tokenId smpQueue -> pure (corrId, "", NROk) +-- do +-- st <- asks store +-- (pubDhKey, privDhKey) <- liftIO C.generateKeyPair' +-- let dhSecret = C.dh' dhPubKey privDhKey +-- sub <- atomically $ mkNtfSubsciption smpQueue token verifyKey dhSecret +-- addSubRetry 3 st sub >>= \case +-- Nothing -> pure (corrId, "", NRErr INTERNAL) +-- Just sId -> do +-- atomically $ writeTBQueue subQ sub +-- pure (corrId, sId, NRSubId pubDhKey) +-- where +-- addSubRetry :: Int -> NtfSubscriptionsStore -> NtfSubsciption -> m (Maybe NtfSubsciptionId) +-- addSubRetry 0 _ _ = pure Nothing +-- addSubRetry n st sub = do +-- sId <- getId +-- -- create QueueRec record with these ids and keys +-- atomically (addNtfSub st sId sub) >>= \case +-- Nothing -> addSubRetry (n - 1) st sub +-- _ -> pure $ Just sId +-- getId :: m NtfSubsciptionId +-- getId = do +-- n <- asks $ subIdBytes . config +-- gVar <- asks idsDrg +-- atomically (randomBytes n gVar) +-- NReqCommand sub@NtfSubsciption {tokenId, subStatus} (corrId, subId, cmd) -> +-- (corrId,subId,) <$> case cmd of +-- NCSubCreate tokenId smpQueue -> pure NROk +-- do +-- st <- asks store +-- (pubDhKey, privDhKey) <- liftIO C.generateKeyPair' +-- let dhSecret = C.dh' (dhPubKey newSub) privDhKey +-- atomically (updateNtfSub st sub newSub dhSecret) >>= \case +-- Nothing -> pure $ NRErr INTERNAL +-- _ -> atomically $ do +-- whenM ((== NSEnd) <$> readTVar status) $ writeTBQueue subQ sub +-- pure $ NRSubId pubDhKey +-- NCSubCheck -> NRStat <$> readTVarIO subStatus +-- NCSubDelete -> do +-- st <- asks store +-- atomically (deleteNtfSub st subId) $> NROk diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 3ce4d2b8d..fec2aafb6 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -1,4 +1,7 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} module Simplex.Messaging.Notifications.Server.Env where @@ -6,26 +9,27 @@ module Simplex.Messaging.Notifications.Server.Env where import Control.Monad.IO.Unlift import Crypto.Random import Data.ByteString.Char8 (ByteString) -import qualified Data.Map.Strict as M import Data.X509.Validation (Fingerprint (..)) import Network.Socket import qualified Network.TLS as T import Numeric.Natural import Simplex.Messaging.Agent.RetryInterval -import Simplex.Messaging.Client +import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Subscriptions -import Simplex.Messaging.Protocol (Transmission) +import Simplex.Messaging.Protocol (CorrId, Transmission) import Simplex.Messaging.Transport (ATransport) import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) import UnliftIO.STM data NtfServerConfig = NtfServerConfig { transports :: [(ServiceName, ATransport)], - subscriptionIdBytes :: Int, - tbqSize :: Natural, - smpCfg :: SMPClientConfig, + subIdBytes :: Int, + clientQSize :: Natural, + subQSize :: Natural, + pushQSize :: Natural, + smpAgentCfg :: SMPClientAgentConfig, reconnectInterval :: RetryInterval, -- CA certificate private key is not needed for initialization caCertificateFile :: FilePath, @@ -33,26 +37,55 @@ data NtfServerConfig = NtfServerConfig certificateFile :: FilePath } +data Notification = Notification + data NtfEnv = NtfEnv { config :: NtfServerConfig, - serverIdentity :: C.KeyHash, - store :: NtfSubscriptions, + subscriber :: NtfSubscriber, + pushServer :: NtfPushServer, + store :: NtfStore, idsDrg :: TVar ChaChaDRG, + serverIdentity :: C.KeyHash, tlsServerParams :: T.ServerParams, serverIdentity :: C.KeyHash } newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv -newNtfServerEnv config@NtfServerConfig {caCertificateFile, certificateFile, privateKeyFile} = do +newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, caCertificateFile, certificateFile, privateKeyFile} = do idsDrg <- newTVarIO =<< drgNew - store <- newTVarIO M.empty + store <- atomically newNtfStore + subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg + pushServer <- atomically $ newNtfPushServer pushQSize tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile - let serverIdentity = C.KeyHash fp - pure NtfEnv {config, store, idsDrg, tlsServerParams, serverIdentity} + pure NtfEnv {config, subscriber, pushServer, store, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp} + +data NtfSubscriber = NtfSubscriber + { subQ :: TBQueue (NtfEntityRec 'Subscription), + smpAgent :: SMPClientAgent + } + +newNtfSubscriber :: Natural -> SMPClientAgentConfig -> STM NtfSubscriber +newNtfSubscriber qSize smpAgentCfg = do + smpAgent <- newSMPClientAgent smpAgentCfg + subQ <- newTBQueue qSize + pure NtfSubscriber {smpAgent, subQ} + +newtype NtfPushServer = NtfPushServer + { pushQ :: TBQueue (NtfTknData, Notification) + } + +newNtfPushServer :: Natural -> STM NtfPushServer +newNtfPushServer qSize = do + pushQ <- newTBQueue qSize + pure NtfPushServer {pushQ} + +data NtfRequest + = NtfReqNew CorrId ANewNtfEntity + | forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e)) data NtfServerClient = NtfServerClient - { rcvQ :: TBQueue (Transmission NtfCommand), + { rcvQ :: TBQueue NtfRequest, sndQ :: TBQueue (Transmission NtfResponse), sessionId :: ByteString, connected :: TVar Bool diff --git a/src/Simplex/Messaging/Notifications/Server/Push.hs b/src/Simplex/Messaging/Notifications/Server/Push.hs new file mode 100644 index 000000000..87b475f2a --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Push.hs @@ -0,0 +1,11 @@ +module Simplex.Messaging.Notifications.Server.Push where + +import Control.Concurrent.STM +import Data.ByteString.Char8 (ByteString) +import Simplex.Messaging.Protocol (NotifierId, SMPServer) + +data NtfPushPayload = NPVerification ByteString | NPNotification SMPServer NotifierId | NPPing + +class PushProvider p where + newPushProvider :: STM p + requestBody :: p -> NtfPushPayload -> ByteString -- ? diff --git a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs index 5658e336a..e5cbd26fc 100644 --- a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs +++ b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs @@ -1,25 +1,109 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE TupleSections #-} + module Simplex.Messaging.Notifications.Server.Subscriptions where import Control.Concurrent.STM +import Control.Monad +import Crypto.PubKey.Curve25519 (dhSecret) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Set (Set) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Protocol (ErrorType (..), NotifierId, NtfPrivateSignKey, SMPServer) +import Simplex.Messaging.Protocol (ErrorType (..), NotifierId, NtfPrivateSignKey, ProtocolServer) +import Simplex.Messaging.TMap (TMap) +import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Util ((<$$>)) -type NtfSubscriptionsData = Map NtfSubsciptionId NtfSubsciptionRec - -type NtfSubscriptions = TVar NtfSubscriptionsData - -data NtfSubsciptionRec = NtfSubsciptionRec - { smpServer :: SMPServer, - notifierId :: NotifierId, - notifierKey :: NtfPrivateSignKey, - token :: DeviceToken, - status :: TVar NtfStatus, - subVerifyKey :: C.APublicVerifyKey, - subDHSecret :: C.DhSecretX25519 +data NtfStore = NtfStore + { tokens :: TMap NtfTokenId NtfTknData, + tokenIds :: TMap DeviceToken NtfTokenId } -getNtfSubscription :: NtfSubscriptions -> NtfSubsciptionId -> STM (Either ErrorType NtfSubsciptionRec) -getNtfSubscription st subId = maybe (Left AUTH) Right . M.lookup subId <$> readTVar st +newNtfStore :: STM NtfStore +newNtfStore = do + tokens <- TM.empty + tokenIds <- TM.empty + pure NtfStore {tokens, tokenIds} + +data NtfTknData = NtfTknData + { token :: DeviceToken, + tknStatus :: TVar NtfTknStatus, + tknVerifyKey :: C.APublicVerifyKey, + tknDhSecret :: C.DhSecretX25519 + } + +mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> STM NtfTknData +mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret = do + tknStatus <- newTVar NTNew + pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret} + +data NtfSubscriptionsStore = NtfSubscriptionsStore + +-- { subscriptions :: TMap NtfSubsciptionId NtfSubsciption, +-- activeSubscriptions :: TMap (SMPServer, NotifierId) NtfSubsciptionId +-- } +-- do +-- subscriptions <- newTVar M.empty +-- activeSubscriptions <- newTVar M.empty +-- pure NtfSubscriptionsStore {subscriptions, activeSubscriptions} + +data NtfSubData = NtfSubData + { smpQueue :: SMPQueueNtf, + tokenId :: NtfTokenId, + subStatus :: TVar NtfSubStatus + } + +data NtfEntityRec (e :: NtfEntity) where + NtfTkn :: NtfTknData -> NtfEntityRec 'Token + NtfSub :: NtfSubData -> NtfEntityRec 'Subscription + +data ANtfEntityRec = forall e. NtfEntityI e => NER (SNtfEntity e) (NtfEntityRec e) + +getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token)) +getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st) + +addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () +addNtfToken st tknId tkn = pure () + +-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e)) +-- getNtfRec st ent entId = case ent of +-- SToken -> NtfTkn <$$> TM.lookup entId (tokens st) +-- SSubscription -> pure Nothing + +-- getNtfVerifyKey :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e, C.APublicVerifyKey)) +-- getNtfVerifyKey st ent entId = +-- getNtfRec st ent entId >>= \case +-- Just r@(NtfTkn NtfTknData {tknVerifyKey}) -> pure $ Just (r, tknVerifyKey) +-- Just r@(NtfSub NtfSubData {tokenId}) -> +-- getNtfRec st SToken tokenId >>= \case +-- Just (NtfTkn NtfTknData {tknVerifyKey}) -> pure $ Just (r, tknVerifyKey) +-- _ -> pure Nothing +-- _ -> pure Nothing + +-- mkNtfSubsciption :: SMPQueueNtf -> NtfTokenId -> STM NtfSubsciption +-- mkNtfSubsciption smpQueue tokenId = do +-- subStatus <- newTVar NSNew +-- pure NtfSubsciption {smpQueue, tokenId, subStatus} + +-- getNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> STM (Maybe NtfSubsciption) +-- getNtfSub st subId = pure Nothing -- maybe (pure $ Left AUTH) (fmap Right . readTVar) . M.lookup subId . subscriptions =<< readTVar st + +-- getNtfSubViaSMPQueue :: NtfSubscriptionsStore -> SMPQueueNtf -> STM (Maybe NtfSubsciption) +-- getNtfSubViaSMPQueue st smpQueue = pure Nothing + +-- -- replace keeping status +-- updateNtfSub :: NtfSubscriptionsStore -> NtfSubsciption -> SMPQueueNtf -> NtfTokenId -> C.DhSecretX25519 -> STM (Maybe ()) +-- updateNtfSub st sub smpQueue tokenId dhSecret = pure Nothing + +-- addNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> NtfSubsciption -> STM (Maybe ()) +-- addNtfSub st subId sub = pure Nothing + +-- deleteNtfSub :: NtfSubscriptionsStore -> NtfSubsciptionId -> STM () +-- deleteNtfSub st subId = pure () diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index dbe5a98a0..0e94b5192 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,23 +1,25 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-} -- | --- Module : Simplex.Messaging.Protocol +-- Module : Simplex.Messaging.ProtocolEncoding -- Copyright : (c) simplex.chat -- License : AGPL-3 -- @@ -37,7 +39,7 @@ module Simplex.Messaging.Protocol e2eEncMessageLength, -- * SMP protocol types - Protocol (..), + ProtocolEncoding (..), Command (..), Party (..), Cmd (..), @@ -55,7 +57,10 @@ module Simplex.Messaging.Protocol PubHeader (..), ClientMessage (..), PrivHeader (..), - SMPServer (..), + Protocol (..), + ProtocolServer (..), + SMPServer, + pattern SMPServer, SrvLoc (..), CorrId (..), QueueId, @@ -163,7 +168,7 @@ data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p) deriving instance Show Cmd -- | Parsed SMP transmission without signature, size and session ID. -type Transmission c = (CorrId, QueueId, c) +type Transmission c = (CorrId, EntityId, c) -- | signed parsed transmission, with original raw bytes and parsing error. type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c)) @@ -196,7 +201,9 @@ type SenderId = QueueId type NotifierId = QueueId -- | SMP queue ID on the server. -type QueueId = ByteString +type QueueId = EntityId + +type EntityId = ByteString -- | Parameterized type for SMP protocol commands from all clients. data Command (p :: Party) where @@ -266,7 +273,7 @@ class ProtocolMsgTag t where messageTagP :: ProtocolMsgTag t => Parser t messageTagP = - maybe (fail "bad command") pure . decodeTag + maybe (fail "bad message") pure . decodeTag =<< (A.takeTill (== ' ') <* optional A.space) instance PartyI p => Encoding (CommandTag p) where @@ -374,34 +381,39 @@ instance Encoding ClientMessage where smpEncode (ClientMessage h msg) = smpEncode h <> msg smpP = ClientMessage <$> smpP <*> A.takeByteString +type SMPServer = ProtocolServer + +pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer +pattern SMPServer host port keyHash = ProtocolServer host port keyHash + -- | SMP server location and transport key digest (hash). -data SMPServer = SMPServer +data ProtocolServer = ProtocolServer { host :: HostName, port :: ServiceName, keyHash :: C.KeyHash } deriving (Eq, Ord, Show) -instance IsString SMPServer where +instance IsString ProtocolServer where fromString = parseString strDecode -instance Encoding SMPServer where - smpEncode SMPServer {host, port, keyHash} = +instance Encoding ProtocolServer where + smpEncode ProtocolServer {host, port, keyHash} = smpEncode (host, port, keyHash) smpP = do (host, port, keyHash) <- smpP - pure SMPServer {host, port, keyHash} + pure ProtocolServer {host, port, keyHash} -instance StrEncoding SMPServer where - strEncode SMPServer {host, port, keyHash} = +instance StrEncoding ProtocolServer where + strEncode ProtocolServer {host, port, keyHash} = "smp://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port) strP = do _ <- "smp://" keyHash <- strP <* A.char '@' SrvLoc host port <- strP - pure SMPServer {host, port, keyHash} + pure ProtocolServer {host, port, keyHash} -instance ToJSON SMPServer where +instance ToJSON ProtocolServer where toJSON = strToJSON toEncoding = strToJEncoding @@ -540,13 +552,25 @@ transmissionP = do command <- A.takeByteString pure RawTransmission {signature, signed, sessId, corrId, entityId, command} -class Protocol msg where +class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg)) => Protocol msg where + type ProtocolCommand msg = cmd | cmd -> msg + protocolPing :: ProtocolCommand msg + protocolError :: msg -> Maybe ErrorType + +instance Protocol BrokerMsg where + type ProtocolCommand BrokerMsg = Cmd + protocolPing = Cmd SSender PING + protocolError = \case + ERR e -> Just e + _ -> Nothing + +class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where type Tag msg encodeProtocol :: msg -> ByteString protocolP :: Tag msg -> Parser msg checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg -instance PartyI p => Protocol (Command p) where +instance PartyI p => ProtocolEncoding (Command p) where type Tag (Command p) = CommandTag p encodeProtocol = \case NEW rKey dhKey -> e (NEW_, ' ', rKey, dhKey) @@ -584,7 +608,7 @@ instance PartyI p => Protocol (Command p) where | isNothing sig || B.null queueId -> Left $ CMD NO_AUTH | otherwise -> Right cmd -instance Protocol Cmd where +instance ProtocolEncoding Cmd where type Tag Cmd = CmdTag encodeProtocol (Cmd _ c) = encodeProtocol c @@ -606,7 +630,7 @@ instance Protocol Cmd where checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c -instance Protocol BrokerMsg where +instance ProtocolEncoding BrokerMsg where type Tag BrokerMsg = BrokerMsgTag encodeProtocol = \case IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh) @@ -632,7 +656,7 @@ instance Protocol BrokerMsg where PONG_ -> pure PONG checkCredentials (_, _, queueId, _) cmd = case cmd of - -- IDS response must not have queue ID + -- IDS response should not have queue ID IDS _ -> Right cmd -- ERR response does not always have queue ID ERR _ -> Right cmd @@ -649,7 +673,7 @@ _smpP :: Encoding a => Parser a _smpP = A.space *> smpP -- | Parse SMP protocol commands and broker messages -parseProtocol :: (Protocol msg, ProtocolMsgTag (Tag msg)) => ByteString -> Either ErrorType msg +parseProtocol :: ProtocolEncoding msg => ByteString -> Either ErrorType msg parseProtocol s = let (tag, params) = B.break (== ' ') s in case decodeTag tag of @@ -712,7 +736,7 @@ instance Encoding CommandError where tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t -encodeTransmission :: Protocol c => ByteString -> Transmission c -> ByteString +encodeTransmission :: ProtocolEncoding c => ByteString -> Transmission c -> ByteString encodeTransmission sessionId (CorrId corrId, queueId, command) = smpEncode (sessionId, corrId, queueId) <> encodeProtocol command @@ -721,11 +745,7 @@ tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmissi tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: - forall cmd c m. - (Protocol cmd, ProtocolMsgTag (Tag cmd), Transport c, MonadIO m) => - THandle c -> - m (SignedTransmission cmd) +tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd) tGet th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate where decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index c76a2aaa3..c3d62bb02 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -23,7 +23,7 @@ -- and optional append only log of SMP queue records. -- -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md -module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, verifyCmdSignature, dummyVerifyCmd) where +module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking, verifyCmdSignature, dummyVerifyCmd, randomBytes) where import Control.Monad import Control.Monad.Except diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index a6584903e..b5bc01253 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -11,6 +11,7 @@ module Simplex.Messaging.TMap adjust, update, alter, + alterF, union, ) where @@ -65,6 +66,12 @@ alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM () alter f k m = modifyTVar' m $ M.alter f k {-# INLINE alter #-} +alterF :: Ord k => (Maybe a -> STM (Maybe a)) -> k -> TMap k a -> STM () +alterF f k m = do + mv <- M.alterF f k =<< readTVar m + writeTVar m $! mv +{-# INLINE alterF #-} + union :: Ord k => Map k a -> TMap k a -> STM () union m' m = modifyTVar' m $ M.union m' {-# INLINE union #-} diff --git a/src/Simplex/Messaging/Transport/KeepAlive.hs b/src/Simplex/Messaging/Transport/KeepAlive.hs index aa308492d..c949394a0 100644 --- a/src/Simplex/Messaging/Transport/KeepAlive.hs +++ b/src/Simplex/Messaging/Transport/KeepAlive.hs @@ -52,7 +52,7 @@ foreign import capi "netinet/tcp.h value TCP_KEEPINTVL" _TCP_KEEPINTVL :: CInt foreign import capi "netinet/tcp.h value TCP_KEEPCNT" _TCP_KEEPCNT :: CInt -#endif +#endif setSocketKeepAlive :: Socket -> KeepAliveOpts -> IO () setSocketKeepAlive sock KeepAliveOpts {keepCnt, keepIdle, keepIntvl} = do diff --git a/tests/AgentTests/ConnectionRequestTests.hs b/tests/AgentTests/ConnectionRequestTests.hs index cbab22793..1b1598413 100644 --- a/tests/AgentTests/ConnectionRequestTests.hs +++ b/tests/AgentTests/ConnectionRequestTests.hs @@ -11,7 +11,7 @@ import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Protocol (smpClientVRange) +import Simplex.Messaging.Protocol (ProtocolServer (..), smpClientVRange) import Simplex.Messaging.Version import Test.Hspec @@ -20,7 +20,7 @@ uri = "smp.simplex.im" srv :: SMPServer srv = - SMPServer + ProtocolServer { host = "smp.simplex.im", port = "5223", keyHash = C.KeyHash "\215m\248\251" diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 2d7b82375..e232ffdbd 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -24,7 +24,7 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Server (runSMPAgentBlocking) -import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig) +import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.KeepAlive @@ -162,7 +162,7 @@ cfg = tbqSize = 1, dbFile = testDB, smpCfg = - smpDefaultConfig + defaultClientConfig { qSize = 1, defaultTransport = (testPort, transport @TLS), tcpTimeout = 500_000