From f577fcdacf15cac9f763555e4d97f6b5dd542bcf Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Fri, 8 Apr 2022 08:47:04 +0100 Subject: [PATCH] agent schema/methods/types/store methods for notifications tokens (#348) * agent schema/methods/types/store methods for notifications tokens * register notification token on the server * agent commands for notification tokens * refactor initial servers from AgentConfig * agent store functions for notification tokens * server STM store methods for tokens * fix protocol client for ntfs (use generic handshake), minimal server and agent tests * server command to verify ntf token --- apps/ntf-server/Main.hs | 1 + apps/smp-agent/Main.hs | 11 +- simplexmq.cabal | 2 + src/Simplex/Messaging/Agent.hs | 117 ++++++++++++++++-- src/Simplex/Messaging/Agent/Client.hs | 44 +++++-- src/Simplex/Messaging/Agent/Env/SQLite.hs | 29 +++-- src/Simplex/Messaging/Agent/Server.hs | 12 +- src/Simplex/Messaging/Agent/Store.hs | 8 ++ src/Simplex/Messaging/Agent/Store/SQLite.hs | 89 +++++++++++-- .../Migrations/M20220322_notifications.hs | 37 ++---- .../M20220404_ntf_subscriptions_draft.hs | 38 ++++++ src/Simplex/Messaging/Client.hs | 10 +- src/Simplex/Messaging/Encoding/String.hs | 8 +- src/Simplex/Messaging/Notifications/Client.hs | 93 +++++++++++--- .../Messaging/Notifications/Protocol.hs | 87 +++++++++++-- src/Simplex/Messaging/Notifications/Server.hs | 35 ++++-- .../Messaging/Notifications/Server/Env.hs | 18 ++- .../Notifications/Server/Subscriptions.hs | 17 +-- src/Simplex/Messaging/Parsers.hs | 10 ++ src/Simplex/Messaging/Protocol.hs | 6 +- tests/AgentTests/FunctionalAPITests.hs | 44 ++++--- tests/NtfClient.hs | 106 ++++++++++++++++ tests/NtfServerTests.hs | 40 ++++++ tests/SMPAgentClient.hs | 15 ++- tests/Test.hs | 2 + 25 files changed, 732 insertions(+), 147 deletions(-) create mode 100644 src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220404_ntf_subscriptions_draft.hs create mode 100644 tests/NtfClient.hs create mode 100644 tests/NtfServerTests.hs diff --git a/apps/ntf-server/Main.hs b/apps/ntf-server/Main.hs index 8426093e5..928b9e514 100644 --- a/apps/ntf-server/Main.hs +++ b/apps/ntf-server/Main.hs @@ -39,6 +39,7 @@ ntfServerCLIConfig = NtfServerConfig { transports, subIdBytes = 24, + regCodeBytes = 32, clientQSize = 16, subQSize = 64, pushQSize = 128, diff --git a/apps/smp-agent/Main.hs b/apps/smp-agent/Main.hs index 32e038f56..72ca7f073 100644 --- a/apps/smp-agent/Main.hs +++ b/apps/smp-agent/Main.hs @@ -11,7 +11,14 @@ import Simplex.Messaging.Agent.Server (runSMPAgent) import Simplex.Messaging.Transport (TLS, Transport (..)) cfg :: AgentConfig -cfg = defaultAgentConfig {initialSMPServers = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"]} +cfg = defaultAgentConfig + +servers :: InitialAgentServers +servers = + InitialAgentServers + { smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"], + ntf = [] + } logCfg :: LogConfig logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} @@ -20,4 +27,4 @@ main :: IO () main = do putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig) setLogLevel LogInfo -- LogError - withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg + withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers diff --git a/simplexmq.cabal b/simplexmq.cabal index 66089a3e3..6bd072f5b 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -298,6 +298,8 @@ test-suite smp-server-test CoreTests.EncodingTests CoreTests.ProtocolErrorTests CoreTests.VersionRangeTests + NtfClient, + NtfServerTests ServerTests SMPAgentClient SMPClient diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 821f5fe5a..812fb6e72 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -48,6 +48,11 @@ module Simplex.Messaging.Agent suspendConnection, deleteConnection, setSMPServers, + setNtfServers, + registerNtfToken, + verifyNtfToken, + enableNtfCron, + deleteNtfToken, logConnection, ) where @@ -69,6 +74,7 @@ import Data.Maybe (isJust) import qualified Data.Text as T import Data.Time.Clock import Data.Time.Clock.System (systemToUTCTime) +import Data.Word (Word16) import Database.SQLite.Simple (SQLError) import Simplex.Messaging.Agent.Client import Simplex.Messaging.Agent.Env.SQLite @@ -80,7 +86,8 @@ import Simplex.Messaging.Client (ServerTransmission) import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding -import Simplex.Messaging.Notifications.Protocol (DeviceToken) +import Simplex.Messaging.Notifications.Client +import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfTknStatus (..)) import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Protocol (BrokerMsg, MsgBody) import qualified Simplex.Messaging.Protocol as SMP @@ -93,11 +100,11 @@ import qualified UnliftIO.Exception as E import UnliftIO.STM -- | Creates an SMP agent client instance -getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m AgentClient -getSMPAgentClient cfg = newSMPAgentEnv cfg >>= runReaderT runAgent +getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> InitialAgentServers -> m AgentClient +getSMPAgentClient cfg initServers = newSMPAgentEnv cfg >>= runReaderT runAgent where runAgent = do - c <- getAgentClient + c <- getAgentClient initServers action <- async $ subscriber c `E.finally` disconnectAgentClient c pure c {smpSubscriber = action} @@ -150,10 +157,24 @@ deleteConnection c = withAgentEnv c . deleteConnection' c setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServer -> m () setSMPServers c = withAgentEnv c . setSMPServers' c +setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m () +setNtfServers c = withAgentEnv c . setNtfServers' c + -- | Register device notifications token registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () registerNtfToken c = withAgentEnv c . registerNtfToken' c +-- | Verify device notifications token +verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m () +verifyNtfToken c = withAgentEnv c .: verifyNtfToken' c + +-- | Enable/disable periodic notifications +enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m () +enableNtfCron c = withAgentEnv c .: enableNtfCron' c + +deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () +deleteNtfToken c = withAgentEnv c . deleteNtfToken' c + withAgentEnv :: AgentClient -> ReaderT Env m a -> m a withAgentEnv c = (`runReaderT` agentEnv c) @@ -161,8 +182,8 @@ withAgentEnv c = (`runReaderT` agentEnv c) -- withAgentClient c = withAgentLock c . withAgentEnv c -- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's. -getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient -getAgentClient = ask >>= atomically . newAgentClient +getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => InitialAgentServers -> m AgentClient +getAgentClient initServers = ask >>= atomically . newAgentClient initServers logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m () logConnection c connected = @@ -499,8 +520,73 @@ setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m () setSMPServers' c servers = do atomically $ writeTVar (smpServers c) servers -registerNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () -registerNtfToken' c token = pure () +registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> m () +registerNtfToken' c deviceToken = + withStore (`getDeviceNtfToken` deviceToken) >>= \case + Just tkn@NtfToken {ntfTokenId, ntfTknAction} -> case (ntfTokenId, ntfTknAction) of + (Nothing, Just (NTARegister ntfPubKey)) -> registerToken tkn ntfPubKey + -- TODO request verification code again in case there is registration, but no verification code in DB - probably after some timeout? + (Just tknId, Just (NTAVerify code)) -> + t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code + (Just tknId, Just (NTACron interval)) -> + t tkn (NTActive, Just NTACheck) $ agentNtfEnableCron c tknId tkn interval + (Just _tknId, Just NTACheck) -> pure () + (Just tknId, Just NTADelete) -> + t tkn (NTExpired, Nothing) $ agentNtfDeleteToken c tknId tkn + _ -> pure () + _ -> + getNtfServer c >>= \case + Just ntfServer -> + asks (cmdSignAlg . config) >>= \case + C.SignAlg a -> do + (ntfPubKey, ntfPrivKey) <- liftIO $ C.generateSignatureKeyPair a + let tkn = newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey + withStore $ \st -> createNtfToken st tkn + registerToken tkn ntfPubKey + _ -> throwError $ CMD PROHIBITED + where + t tkn = withToken tkn Nothing + registerToken :: NtfToken -> C.APublicVerifyKey -> m () + registerToken tkn ntfPubKey = do + (pubDhKey, privDhKey) <- liftIO C.generateKeyPair' + (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey + let dhSecret = C.dh' srvPubDhKey privDhKey + withStore $ \st -> updateNtfTokenRegistration st tkn tknId dhSecret + +verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m () +verifyNtfToken' c deviceToken code = + withStore (`getDeviceNtfToken` deviceToken) >>= \case + Just tkn@NtfToken {ntfTokenId = Just tknId} -> + withToken tkn (Just (NTConfirmed, NTAVerify code)) (NTActive, Just NTACheck) $ + agentNtfVerifyToken c tknId tkn code + _ -> throwError $ CMD PROHIBITED + +enableNtfCron' :: AgentMonad m => AgentClient -> DeviceToken -> Word16 -> m () +enableNtfCron' c deviceToken interval = + withStore (`getDeviceNtfToken` deviceToken) >>= \case + Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive} -> + withToken tkn (Just (NTActive, NTACron interval)) (NTActive, Just NTACheck) $ + agentNtfEnableCron c tknId tkn interval + _ -> throwError $ CMD PROHIBITED + +deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () +deleteNtfToken' c deviceToken = + withStore (`getDeviceNtfToken` deviceToken) >>= \case + Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus} -> + withToken tkn (Just (ntfTknStatus, NTADelete)) (NTExpired, Nothing) $ + agentNtfDeleteToken c tknId tkn + _ -> throwError $ CMD PROHIBITED + +withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a +withToken tkn from_ (toStatus, toAction_) f = do + forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action) + res <- f + withStore $ \st -> updateNtfToken st tkn toStatus toAction_ + pure res + +setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m () +setNtfServers' c servers = do + atomically $ writeTVar (ntfServers c) servers getSMPServer :: AgentMonad m => AgentClient -> m SMPServer getSMPServer c = do @@ -509,8 +595,19 @@ getSMPServer c = do srv :| [] -> pure srv servers -> do gen <- asks randomServer - i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1) - pure $ servers L.!! i + atomically . stateTVar gen $ + first (servers L.!!) . randomR (0, L.length servers - 1) + +getNtfServer :: AgentMonad m => AgentClient -> m (Maybe NtfServer) +getNtfServer c = do + ntfServers <- readTVarIO $ ntfServers c + case ntfServers of + [] -> pure Nothing + [srv] -> pure $ Just srv + servers -> do + gen <- asks randomServer + atomically . stateTVar gen $ + first (Just . (servers !!)) . randomR (0, length servers - 1) subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m () subscriber c@AgentClient {msgQ} = forever $ do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 321fc82ab..fc37334db 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -8,7 +8,6 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeSynonymInstances #-} module Simplex.Messaging.Agent.Client ( AgentClient (..), @@ -24,6 +23,10 @@ module Simplex.Messaging.Agent.Client RetryInterval (..), secureQueue, sendAgentMessage, + agentNtfRegisterToken, + agentNtfVerifyToken, + agentNtfDeleteToken, + agentNtfEnableCron, agentCbEncrypt, agentCbDecrypt, cryptoError, @@ -51,6 +54,7 @@ import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (isNothing) import Data.Text.Encoding +import Data.Word (Word16) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval @@ -59,8 +63,8 @@ import Simplex.Messaging.Client import Simplex.Messaging.Client.Agent () import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding -import Simplex.Messaging.Notifications.Client (NtfClient) -import Simplex.Messaging.Notifications.Protocol (NtfResponse) +import Simplex.Messaging.Notifications.Client +import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) @@ -83,8 +87,9 @@ data AgentClient = AgentClient subQ :: TBQueue (ATransmission 'Agent), msgQ :: TBQueue (ServerTransmission BrokerMsg), smpServers :: TVar (NonEmpty SMPServer), + ntfServers :: TVar [NtfServer], smpClients :: TMap SMPServer SMPClientVar, - ntfClients :: TMap ProtocolServer NtfClientVar, + ntfClients :: TMap NtfServer NtfClientVar, subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue), subscrConns :: TMap ConnId SMPServer, @@ -99,13 +104,14 @@ data AgentClient = AgentClient lock :: TMVar () } -newAgentClient :: Env -> STM AgentClient -newAgentClient agentEnv = do +newAgentClient :: InitialAgentServers -> Env -> STM AgentClient +newAgentClient InitialAgentServers {smp, ntf} agentEnv = do let qSize = tbqSize $ config agentEnv rcvQ <- newTBQueue qSize subQ <- newTBQueue qSize msgQ <- newTBQueue qSize - smpServers <- newTVar $ initialSMPServers (config agentEnv) + smpServers <- newTVar smp + ntfServers <- newTVar ntf smpClients <- TM.empty ntfClients <- TM.empty subscrSrvrs <- TM.empty @@ -118,7 +124,7 @@ newAgentClient agentEnv = do asyncClients <- newTVar [] clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1) lock <- newTMVar () - return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} + return AgentClient {rcvQ, subQ, msgQ, smpServers, ntfServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock} -- | Agent monad with MonadReader Env and MonadError AgentErrorType type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m) @@ -206,7 +212,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = notifySub :: ACommand 'Agent -> ConnId -> IO () notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd) -getNtfServerClient :: forall m. AgentMonad m => AgentClient -> ProtocolServer -> m NtfClient +getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient getNtfServerClient c@AgentClient {ntfClients} srv = atomically (getClientVar srv ntfClients) >>= either @@ -256,10 +262,10 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a tryConnectClient successAction retryAction = tryError connectClient >>= \r -> case r of - Right smp -> do + Right client -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv atomically $ putTMVar clientVar r - successAction smp + successAction client Left e -> do if e == BROKER NETWORK || e == BROKER TIMEOUT then retryAction @@ -464,6 +470,22 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg = msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg +agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519) +agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey = + withClient c ntfServer $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey) + +agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m () +agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code = + withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code + +agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m () +agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} = + withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId + +agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m () +agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval = + withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval + agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do cmNonce <- liftIO C.randomCbNonce diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index af400ff81..bd1207426 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -7,7 +7,9 @@ module Simplex.Messaging.Agent.Env.SQLite ( AgentConfig (..), + InitialAgentServers (..), defaultAgentConfig, + defaultReconnectInterval, Env (..), newSMPAgentEnv, ) @@ -25,13 +27,18 @@ import Simplex.Messaging.Agent.Store.SQLite import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Notifications.Client (NtfServer) import Simplex.Messaging.Transport (TLS, Transport (..)) import System.Random (StdGen, newStdGen) import UnliftIO.STM +data InitialAgentServers = InitialAgentServers + { smp :: NonEmpty SMPServer, + ntf :: [NtfServer] + } + data AgentConfig = AgentConfig { tcpPort :: ServiceName, - initialSMPServers :: NonEmpty SMPServer, cmdSignAlg :: C.SignAlg, connIdBytes :: Int, tbqSize :: Natural, @@ -47,11 +54,20 @@ data AgentConfig = AgentConfig certificateFile :: FilePath } +defaultReconnectInterval :: RetryInterval +defaultReconnectInterval = + RetryInterval + { initialInterval = second, + increaseAfter = 10 * second, + maxInterval = 10 * second + } + where + second = 1_000_000 + defaultAgentConfig :: AgentConfig defaultAgentConfig = AgentConfig { tcpPort = "5224", - initialSMPServers = undefined, -- TODO move it elsewhere? cmdSignAlg = C.SignAlg C.SEd448, connIdBytes = 12, tbqSize = 64, @@ -60,12 +76,7 @@ defaultAgentConfig = yesToMigrations = False, smpCfg = defaultClientConfig {defaultTransport = ("5223", transport @TLS)}, ntfCfg = defaultClientConfig {defaultTransport = ("443", transport @TLS)}, - reconnectInterval = - RetryInterval - { initialInterval = second, - increaseAfter = 10 * second, - maxInterval = 10 * second - }, + reconnectInterval = defaultReconnectInterval, helloTimeout = 2 * nominalDay, -- CA certificate private key is not needed for initialization -- ! we do not generate these @@ -73,8 +84,6 @@ defaultAgentConfig = privateKeyFile = "/etc/opt/simplex-agent/agent.key", certificateFile = "/etc/opt/simplex-agent/agent.crt" } - where - second = 1_000_000 data Env = Env { config :: AgentConfig, diff --git a/src/Simplex/Messaging/Agent/Server.hs b/src/Simplex/Messaging/Agent/Server.hs index b3c70c9ac..7193cd772 100644 --- a/src/Simplex/Messaging/Agent/Server.hs +++ b/src/Simplex/Messaging/Agent/Server.hs @@ -31,17 +31,17 @@ import UnliftIO.STM -- | Runs an SMP agent as a TCP service using passed configuration. -- -- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs -runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m () -runSMPAgent t cfg = do +runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> m () +runSMPAgent t cfg initServers = do started <- newEmptyTMVarIO - runSMPAgentBlocking t started cfg + runSMPAgentBlocking t started cfg initServers -- | Runs an SMP agent as a TCP service using passed configuration with signalling. -- -- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True) -- and when it is disconnected from the TCP socket once the server thread is killed (False). -runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m () -runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} = do +runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> InitialAgentServers -> m () +runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers = do runReaderT (smpAgent t) =<< newSMPAgentEnv cfg where smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' () @@ -50,7 +50,7 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertifica tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion - c <- getAgentClient + c <- getAgentClient initServers logConnection c True race_ (connectClient h c) (runAgentClient c) `E.finally` disconnectAgentClient c diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 4e20594a8..efcd87d3c 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -20,6 +20,8 @@ import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff, SkippedMsgKeys) +import Simplex.Messaging.Notifications.Client +import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfTknStatus, NtfTokenId) import Simplex.Messaging.Protocol ( MsgBody, MsgId, @@ -75,6 +77,12 @@ class Monad m => MonadAgentStore s m where getSkippedMsgKeys :: s -> ConnId -> m SkippedMsgKeys updateRatchet :: s -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m () + -- Notification device token persistence + createNtfToken :: s -> NtfToken -> m () + getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken) -- return current token if it exists and mark any old tokens for deletion + updateNtfTokenRegistration :: s -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m () + updateNtfToken :: s -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m () + -- * Queue types -- | A receive queue. SMP queue through which the agent receives messages from a sender. diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 3f42fd139..52235f852 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -23,7 +23,6 @@ module Simplex.Messaging.Agent.Store.SQLite connectSQLiteStore, withConnection, withTransaction, - fromTextField_, firstRow, ) where @@ -41,14 +40,14 @@ import Data.Char (toLower) import Data.Functor (($>)) import Data.List (find, foldl') import qualified Data.Map.Strict as M +import Data.Maybe (listToMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) -import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, ToRow, field) +import Data.Time.Clock (getCurrentTime) +import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLError, ToRow, field) import qualified Database.SQLite.Simple as DB import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.Internal (Field (..)) -import Database.SQLite.Simple.Ok (Ok (Ok)) import Database.SQLite.Simple.QQ (sql) import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol @@ -59,7 +58,9 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (blobFieldParser) +import Simplex.Messaging.Notifications.Client (NtfServer, NtfTknAction, NtfToken (..)) +import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfTknStatus (..), NtfTokenId) +import Simplex.Messaging.Parsers (blobFieldParser, fromTextField_) import Simplex.Messaging.Protocol (MsgBody, ProtocolServer (..)) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Util (bshow, liftIOEither) @@ -565,6 +566,63 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto forM_ (M.assocs mks) $ \(msgN, mk) -> DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk) + createNtfToken :: SQLiteStore -> NtfToken -> m () + createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} = + liftIO . withTransaction st $ \db -> do + upsertNtfServer_ db srv + DB.execute + db + [sql| + INSERT INTO ntf_tokens + (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_priv_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + |] + (provider, token, host, port, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) + + getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken) + getDeviceNtfToken st t@(DeviceToken provider token) = + liftIO . withTransaction st $ \db -> + fmap ntfToken . listToMaybe + <$> DB.query + db + [sql| + SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + t.tkn_id, t.tkn_priv_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action + FROM ntf_tokens t + JOIN ntf_servers s USING (ntf_host, ntf_port) + WHERE t.provider = ? AND t.device_token = ? + |] + (provider, token) + where + ntfToken (host, port, keyHash, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) = + let ntfServer = ProtocolServer {host, port, keyHash} + in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} + + updateNtfTokenRegistration :: SQLiteStore -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m () + updateNtfTokenRegistration st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = + liftIO . withTransaction st $ \db -> do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) + + updateNtfToken :: SQLiteStore -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m () + updateNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction = + liftIO . withTransaction st $ \db -> do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_status = ?, tkn_action = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (tknStatus, tknAction, updatedAt, provider, token, host, port) + -- * Auxiliary helpers instance ToField QueueStatus where toField = toField . serializeQueueStatus @@ -611,14 +669,6 @@ instance ToField (SConnectionMode c) where toField = toField . connMode instance FromField AConnectionMode where fromField = fromTextField_ $ fmap connMode' . connModeT -fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a -fromTextField_ fromText = \case - f@(Field (SQLText t) _) -> - case fromText t of - Just x -> Ok x - _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) - f -> returnError ConversionFailed f "expecting SQLText column type" - listToEither :: e -> [a] -> Either e a listToEither _ (x : _) = Right x listToEither e _ = Left e @@ -669,6 +719,19 @@ upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do |] [":host" := host, ":port" := port, ":key_hash" := keyHash] +upsertNtfServer_ :: DB.Connection -> NtfServer -> IO () +upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do + DB.executeNamed + db + [sql| + INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash) + ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET + ntf_host=excluded.ntf_host, + ntf_port=excluded.ntf_port, + ntf_key_hash=excluded.ntf_key_hash; + |] + [":host" := host, ":port" := port, ":key_hash" := keyHash] + -- * createRcvConn helpers insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO () diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs index 8fd996810..3587e05b1 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs @@ -8,40 +8,29 @@ import Database.SQLite.Simple.QQ (sql) m20220322_notifications :: Query m20220322_notifications = [sql| -ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB; - -ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB; - -ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB; - -CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id); - CREATE TABLE ntf_servers ( ntf_host TEXT NOT NULL, ntf_port TEXT NOT NULL, ntf_key_hash BLOB NOT NULL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), PRIMARY KEY (ntf_host, ntf_port) ) WITHOUT ROWID; -CREATE TABLE ntf_subscriptions ( +CREATE TABLE ntf_tokens ( + provider TEXT NOT NULL, -- apn + device_token TEXT NOT NULL, ntf_host TEXT NOT NULL, ntf_port TEXT NOT NULL, - ntf_sub_id BLOB NOT NULL, - ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth - ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete - ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours - ntf_token TEXT NOT NULL, -- or BLOB? - smp_host TEXT NOT NULL, - smp_port TEXT NOT NULL, - smp_ntf_id BLOB NOT NULL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked - PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id), + tkn_id BLOB, -- token ID assigned by notifications server + tkn_priv_key BLOB NOT NULL, -- private key to sign token commands + tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications + tkn_status TEXT NOT NULL, + tkn_action BLOB, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), -- this is to check token status periodically to know when it was last checked + PRIMARY KEY (provider, device_token, ntf_host, ntf_port), FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers - ON DELETE RESTRICT ON UPDATE CASCADE, - FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id) ON DELETE RESTRICT ON UPDATE CASCADE ) WITHOUT ROWID; |] diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220404_ntf_subscriptions_draft.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220404_ntf_subscriptions_draft.hs new file mode 100644 index 000000000..5f58ca0fd --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220404_ntf_subscriptions_draft.hs @@ -0,0 +1,38 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220404_ntf_subscriptions_draft where + +import Database.SQLite.Simple (Query) +import Database.SQLite.Simple.QQ (sql) + +m20220404_ntf_subscriptions_draft :: Query +m20220404_ntf_subscriptions_draft = + [sql| +ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB; + +ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB; + +ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB; + +CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id); + +CREATE TABLE ntf_subscriptions ( + ntf_host TEXT NOT NULL, + ntf_port TEXT NOT NULL, + ntf_sub_id BLOB NOT NULL, + ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth + ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete + ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours + ntf_token TEXT NOT NULL, -- or BLOB? + smp_host TEXT NOT NULL, + smp_port TEXT NOT NULL, + smp_ntf_id BLOB NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked + PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id), + FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers + ON DELETE RESTRICT ON UPDATE CASCADE, + FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id) + ON DELETE RESTRICT ON UPDATE CASCADE +) WITHOUT ROWID; +|] \ No newline at end of file diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 75ce01126..4174e8555 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -66,7 +66,7 @@ import Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError) -import Simplex.Messaging.Transport.Client (runTransportClient, smpClientHandshake) +import Simplex.Messaging.Transport.Client (runTransportClient) import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.WebSockets (WS) import Simplex.Messaging.Util (bshow, liftError, raceAny_) @@ -132,11 +132,11 @@ type Response msg = Either ProtocolClientError msg -- as 'SMPServerTransmission' includes server information. getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected = - (atomically mkSMPClient >>= runClient useTransport) + (atomically mkProtocolClient >>= runClient useTransport) `catch` \(e :: IOException) -> pure . Left $ PCEIOError e where - mkSMPClient :: STM (ProtocolClient msg) - mkSMPClient = do + mkProtocolClient :: STM (ProtocolClient msg) + mkProtocolClient = do connected <- newTVar False clientCorrId <- newTVar 0 sentCommands <- TM.empty @@ -177,7 +177,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO () client _ c thVar h = - runExceptT (smpClientHandshake h $ keyHash protocolServer) >>= \case + runExceptT (protocolClientHandshake @msg h $ keyHash protocolServer) >>= \case Left e -> atomically . putTMVar thVar . Left $ PCETransportError e Right th@THandle {sessionId} -> do atomically $ do diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index 9e7d32c3f..52b26ef1d 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -2,7 +2,8 @@ {-# LANGUAGE OverloadedStrings #-} module Simplex.Messaging.Encoding.String - ( StrEncoding (..), + ( TextEncoding (..), + StrEncoding (..), Str (..), strP_, strToJSON, @@ -26,11 +27,16 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum) import qualified Data.List.NonEmpty as L +import Data.Text (Text) import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Word (Word16) import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Util ((<$?>)) +class TextEncoding a where + textEncode :: a -> Text + textDecode :: Text -> Maybe a + -- | Serializing human-readable and (where possible) URI-friendly strings for SMP and SMP agent protocols class StrEncoding a where {-# MINIMAL strEncode, (strDecode | strP) #-} diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index 48fc89b1e..76667f9fb 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -7,42 +9,50 @@ module Simplex.Messaging.Notifications.Client where import Control.Monad.Except import Control.Monad.Trans.Except +import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Word (Word16) +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Parsers (blobFieldDecoder) +import Simplex.Messaging.Protocol (ProtocolServer) + +type NtfServer = ProtocolServer type NtfClient = ProtocolClient NtfResponse -registerNtfToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519) -registerNtfToken c pKey newTkn = +ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519) +ntfRegisterToken c pKey newTkn = sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case NRId tknId dhKey -> pure (tknId, dhKey) _ -> throwE PCEUnexpectedResponse -verifyNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegistrationCode -> ExceptT ProtocolClientError IO () -verifyNtfToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId +ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO () +ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId -deleteNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO () -deleteNtfToken = okNtfCommand TDEL +ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO () +ntfDeleteToken = okNtfCommand TDEL -enableNtfCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO () -enableNtfCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId +ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO () +ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId -createNtfSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519) -createNtfSubsciption c pKey newSub = +ntfCreateSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519) +ntfCreateSubsciption c pKey newSub = sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case NRId tknId dhKey -> pure (tknId, dhKey) _ -> throwE PCEUnexpectedResponse -checkNtfSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus -checkNtfSubscription c pKey subId = +ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus +ntfCheckSubscription c pKey subId = sendNtfCommand c (Just pKey) subId SCHK >>= \case NRStat stat -> pure stat _ -> throwE PCEUnexpectedResponse -deleteNfgSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO () -deleteNfgSubscription = okNtfCommand SDEL +ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO () +ntfDeleteSubscription = okNtfCommand SDEL -- | Send notification server command sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse @@ -53,3 +63,58 @@ okNtfCommand cmd c pKey entId = sendNtfCommand c (Just pKey) entId cmd >>= \case NROk -> return () _ -> throwE PCEUnexpectedResponse + +data NtfTknAction + = NTARegister C.APublicVerifyKey -- public key to send to the server + | NTAVerify NtfRegCode -- code to verify token + | NTACheck + | NTACron Word16 + | NTADelete + deriving (Show) + +instance Encoding NtfTknAction where + smpEncode = \case + NTARegister key -> smpEncode ('R', key) + NTAVerify code -> smpEncode ('V', code) + NTACheck -> "C" + NTACron interval -> smpEncode ('I', interval) + NTADelete -> "D" + smpP = + A.anyChar >>= \case + 'R' -> NTARegister <$> smpP + 'V' -> NTAVerify <$> smpP + 'C' -> pure NTACheck + 'I' -> NTACron <$> smpP + 'D' -> pure NTADelete + _ -> fail "bad NtfTknAction" + +instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode + +instance ToField NtfTknAction where toField = toField . smpEncode + +data NtfToken = NtfToken + { deviceToken :: DeviceToken, + ntfServer :: NtfServer, + ntfTokenId :: Maybe NtfTokenId, + -- | key used by the ntf client to sign transmissions + ntfPrivKey :: C.APrivateSignKey, + -- | shared DH secret used to encrypt/decrypt notifications e2e + ntfDhSecret :: Maybe C.DhSecretX25519, + -- | token status + ntfTknStatus :: NtfTknStatus, + -- | pending token action and the earliest time + ntfTknAction :: Maybe NtfTknAction + } + deriving (Show) + +newNtfToken :: DeviceToken -> NtfServer -> C.APrivateSignKey -> C.APublicVerifyKey -> NtfToken +newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey = + NtfToken + { deviceToken, + ntfServer, + ntfTokenId = Nothing, + ntfPrivKey, + ntfDhSecret = Nothing, + ntfTknStatus = NTNew, + ntfTknAction = Just $ NTARegister ntfPubKey + } \ No newline at end of file diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 1fac10e7e..17c6ece56 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -10,6 +10,7 @@ module Simplex.Messaging.Notifications.Protocol where +import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -17,8 +18,13 @@ import Data.Kind import Data.Maybe (isNothing) import Data.Type.Equality import Data.Word (Word16) +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Notifications.Transport (ntfClientHandshake) +import Simplex.Messaging.Parsers (fromTextField_) import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) import Simplex.Messaging.Util ((<$?>)) @@ -87,12 +93,31 @@ instance ProtocolMsgTag NtfCmdTag where instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t) -type NtfRegistrationCode = ByteString +newtype NtfRegCode = NtfRegCode ByteString + deriving (Eq, Show) + +instance Encoding NtfRegCode where + smpEncode (NtfRegCode code) = smpEncode code + smpP = NtfRegCode <$> smpP + +instance StrEncoding NtfRegCode where + strEncode (NtfRegCode m) = strEncode m + strDecode s = NtfRegCode <$> strDecode s + strP = NtfRegCode <$> strP + +instance FromJSON NtfRegCode where + parseJSON = strParseJSON "NtfRegCode" + +instance ToJSON NtfRegCode where + toJSON = strToJSON + toEncoding = strToJEncoding data NewNtfEntity (e :: NtfEntity) where NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NewNtfEntity 'Subscription +deriving instance Show (NewNtfEntity e) + data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e) instance NtfEntityI e => Encoding (NewNtfEntity e) where @@ -111,6 +136,7 @@ instance Encoding ANewNtfEntity where instance Protocol NtfResponse where type ProtocolCommand NtfResponse = NtfCmd + protocolClientHandshake = ntfClientHandshake protocolPing = NtfCmd SSubscription PING protocolError = \case NRErr e -> Just e @@ -120,7 +146,7 @@ data NtfCommand (e :: NtfEntity) where -- | register new device token for notifications TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token -- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token - TVFY :: NtfRegistrationCode -> NtfCommand 'Token + TVFY :: NtfRegCode -> NtfCommand 'Token -- | delete token - all subscriptions will be removed and no more notifications will be sent TDEL :: NtfCommand 'Token -- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable @@ -134,8 +160,12 @@ data NtfCommand (e :: NtfEntity) where -- | keep-alive command PING :: NtfCommand 'Subscription +deriving instance Show (NtfCommand e) + data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e) +deriving instance Show NtfCmd + instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where type Tag (NtfCommand e) = NtfCommandTag e encodeProtocol = \case @@ -263,6 +293,7 @@ data SMPQueueNtf = SMPQueueNtf notifierId :: NotifierId, notifierKey :: NtfPrivateSignKey } + deriving (Show) instance Encoding SMPQueueNtf where smpEncode SMPQueueNtf {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey) @@ -270,17 +301,30 @@ instance Encoding SMPQueueNtf where (smpServer, notifierId, notifierKey) <- smpP pure $ SMPQueueNtf smpServer notifierId notifierKey -data PushPlatform = PPApple +data PushProvider = PPApple + deriving (Eq, Ord, Show) -instance Encoding PushPlatform where +instance Encoding PushProvider where smpEncode = \case PPApple -> "A" smpP = A.anyChar >>= \case 'A' -> pure PPApple - _ -> fail "bad PushPlatform" + _ -> fail "bad PushProvider" -data DeviceToken = DeviceToken PushPlatform ByteString +instance TextEncoding PushProvider where + textEncode = \case + PPApple -> "apple" + textDecode = \case + "apple" -> Just PPApple + _ -> Nothing + +instance FromField PushProvider where fromField = fromTextField_ textDecode + +instance ToField PushProvider where toField = toField . textEncode + +data DeviceToken = DeviceToken PushProvider ByteString + deriving (Eq, Ord, Show) instance Encoding DeviceToken where smpEncode (DeviceToken p t) = smpEncode (p, t) @@ -322,17 +366,40 @@ instance Encoding NtfSubStatus where _ -> fail "bad NtfError" data NtfTknStatus - = -- | state after registration (TNEW) + = -- | Token created in DB NTNew - | -- | if initial notification or verification failed (push provider error) + | -- | state after registration (TNEW) + NTRegistered + | -- | if initial notification failed (push provider error) or verification failed NTInvalid - | -- | if initial notification succeeded + | -- | Token confirmed via notification (accepted by push provider or verification code received by client) NTConfirmed | -- | after successful verification (TVFY) NTActive | -- | after it is no longer valid (push provider error) NTExpired - deriving (Eq) + deriving (Eq, Show) + +instance TextEncoding NtfTknStatus where + textEncode = \case + NTNew -> "new" + NTRegistered -> "registered" + NTInvalid -> "invalid" + NTConfirmed -> "confirmed" + NTActive -> "active" + NTExpired -> "expired" + textDecode = \case + "new" -> Just NTNew + "registered" -> Just NTRegistered + "invalid" -> Just NTInvalid + "confirmed" -> Just NTConfirmed + "active" -> Just NTActive + "expired" -> Just NTExpired + _ -> Nothing + +instance FromField NtfTknStatus where fromField = fromTextField_ textDecode + +instance ToField NtfTknStatus where toField = toField . textEncode checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e) checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 13d4671e6..916c64111 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -14,8 +14,8 @@ import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) +import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) -import Data.Functor (($>)) import Network.Socket (ServiceName) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C @@ -106,7 +106,12 @@ ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} ntfPush :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m () ntfPush NtfPushServer {pushQ} = forever $ do atomically (readTBQueue pushQ) >>= \case - (NtfTknData {}, Notification {}) -> pure () + (NtfTknData {tknStatus}, notification) -> do + liftIO $ print $ J.encode notification + -- TODO status update should happen after the token status successfully sent + case notification of + PNVerification _ -> atomically $ writeTVar tknStatus NTConfirmed + _ -> pure () runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m () runNtfClientTransport th@THandle {sessionId} = do @@ -122,7 +127,7 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m () receive th NtfServerClient {rcvQ, sndQ} = forever $ do - t@(sig, signed, (corrId, subId, cmdOrError)) <- tGet th + t@(_, _, (corrId, subId, cmdOrError)) <- tGet th case cmdOrError of Left e -> write sndQ (corrId, subId, NRErr e) Right cmd -> @@ -144,7 +149,7 @@ verifyNtfTransmission :: verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do case cmd of NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) -> - -- TODO check that token is not already in store + -- TODO check that token is not already in store, if it is - verify that the saved key is the same pure $ if verifyCmdSignature sig_ signed k then VRVerified (NtfReqNew corrId (ANE SToken n)) @@ -189,16 +194,21 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} = (srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair' let dhSecret = C.dh' dhPubKey srvDrivDhKey tknId <- getId + regCode <- getRegCode atomically $ do - tkn <- mkNtfTknData newTkn dhSecret + tkn <- mkNtfTknData newTkn dhSecret regCode addNtfToken st tknId tkn - writeTBQueue pushQ (tkn, Notification) - -- pure (corrId, sId, NRSubId pubDhKey) + writeTBQueue pushQ (tkn, PNVerification regCode) pure (corrId, "", NRId tknId srvDhPubKey) - NtfReqCmd SToken tkn (corrId, tknId, cmd) -> + NtfReqCmd SToken (NtfTkn NtfTknData {tknStatus, tknRegCode}) (corrId, tknId, cmd) -> do + status <- readTVarIO tknStatus (corrId,tknId,) <$> case cmd of TNEW newTkn -> pure NROk -- TODO when duplicate token sent - TVFY code -> pure NROk + TVFY code -- this allows repeated verification for cases when client connection dropped before server response + | (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do + atomically $ writeTVar tknStatus NTActive + pure NROk + | otherwise -> pure $ NRErr AUTH TDEL -> pure NROk TCRN int -> pure NROk NtfReqNew corrId (ANE SSubscription newSub) -> pure (corrId, "", NROk) @@ -209,8 +219,11 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} = SDEL -> pure NROk PING -> pure NRPong getId :: m NtfEntityId - getId = do - n <- asks $ subIdBytes . config + getId = getRandomBytes =<< asks (subIdBytes . config) + getRegCode :: m NtfRegCode + getRegCode = NtfRegCode <$> (getRandomBytes =<< asks (regCodeBytes . config)) + getRandomBytes :: Int -> m ByteString + getRandomBytes n = do gVar <- asks idsDrg atomically (randomBytes n gVar) diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 8b2a416a1..e76c63375 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -8,8 +9,11 @@ module Simplex.Messaging.Notifications.Server.Env where import Control.Monad.IO.Unlift import Crypto.Random +import Data.Aeson (FromJSON, ToJSON) +import qualified Data.Aeson as J import Data.ByteString.Char8 (ByteString) import Data.X509.Validation (Fingerprint (..)) +import GHC.Generics import Network.Socket import qualified Network.TLS as T import Numeric.Natural @@ -17,6 +21,7 @@ import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Subscriptions +import Simplex.Messaging.Parsers (dropPrefix, taggedObjectJSON) import Simplex.Messaging.Protocol (CorrId, Transmission) import Simplex.Messaging.Transport (ATransport) import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams) @@ -25,6 +30,7 @@ import UnliftIO.STM data NtfServerConfig = NtfServerConfig { transports :: [(ServiceName, ATransport)], subIdBytes :: Int, + regCodeBytes :: Int, clientQSize :: Natural, subQSize :: Natural, pushQSize :: Natural, @@ -35,7 +41,15 @@ data NtfServerConfig = NtfServerConfig certificateFile :: FilePath } -data Notification = Notification +data PushNotification = PNVerification {code :: NtfRegCode} | PNPeriodic + deriving (Show, Generic) + +instance FromJSON PushNotification where + parseJSON = J.genericParseJSON . taggedObjectJSON $ dropPrefix "PN" + +instance ToJSON PushNotification where + toJSON = J.genericToJSON . taggedObjectJSON $ dropPrefix "PN" + toEncoding = J.genericToEncoding . taggedObjectJSON $ dropPrefix "PN" data NtfEnv = NtfEnv { config :: NtfServerConfig, @@ -70,7 +84,7 @@ newNtfSubscriber qSize smpAgentCfg = do pure NtfSubscriber {smpAgent, subQ} newtype NtfPushServer = NtfPushServer - { pushQ :: TBQueue (NtfTknData, Notification) + { pushQ :: TBQueue (NtfTknData, PushNotification) } newNtfPushServer :: Natural -> STM NtfPushServer diff --git a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs index e5cbd26fc..fa5790d85 100644 --- a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs +++ b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs @@ -36,15 +36,16 @@ data NtfTknData = NtfTknData { token :: DeviceToken, tknStatus :: TVar NtfTknStatus, tknVerifyKey :: C.APublicVerifyKey, - tknDhSecret :: C.DhSecretX25519 + tknDhSecret :: C.DhSecretX25519, + tknRegCode :: NtfRegCode } -mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> STM NtfTknData -mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret = do - tknStatus <- newTVar NTNew - pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret} +mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData +mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret tknRegCode = do + tknStatus <- newTVar NTRegistered + pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret, tknRegCode} -data NtfSubscriptionsStore = NtfSubscriptionsStore +-- data NtfSubscriptionsStore = NtfSubscriptionsStore -- { subscriptions :: TMap NtfSubsciptionId NtfSubsciption, -- activeSubscriptions :: TMap (SMPServer, NotifierId) NtfSubsciptionId @@ -70,7 +71,9 @@ getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token)) getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st) addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () -addNtfToken st tknId tkn = pure () +addNtfToken st tknId tkn@NtfTknData {token} = do + TM.insert tknId tkn (tokens st) + TM.insert token tknId (tokenIds st) -- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e)) -- getNtfRec st ent entId = case ent of diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index ab4c1a4b5..ed0606fb2 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -13,6 +13,8 @@ import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Char (isAlphaNum, toLower) +import Data.Text (Text) +import qualified Data.Text as T import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) import Data.Typeable (Typeable) @@ -81,6 +83,14 @@ blobFieldDecoder dec = \case Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e) f -> returnError ConversionFailed f "expecting SQLBlob column type" +fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a +fromTextField_ fromText = \case + f@(Field (SQLText t) _) -> + case fromText t of + Just x -> Ok x + _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) + f -> returnError ConversionFailed f "expecting SQLText column type" + fstToLower :: String -> String fstToLower "" = "" fstToLower (h : t) = toLower h : t diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 0e94b5192..869855ee0 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -115,7 +116,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers -import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), tGetBlock, tPutBlock) +import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), smpClientHandshake, tGetBlock, tPutBlock) import Simplex.Messaging.Util (bshow, (<$?>)) import Simplex.Messaging.Version import Test.QuickCheck (Arbitrary (..)) @@ -184,6 +185,7 @@ data RawTransmission = RawTransmission entityId :: ByteString, command :: ByteString } + deriving (Show) -- | unparsed sent SMP transmission with signature, without session ID. type SignedRawTransmission = (Maybe C.ASignature, ByteString, ByteString, ByteString) @@ -554,11 +556,13 @@ transmissionP = do class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg)) => Protocol msg where type ProtocolCommand msg = cmd | cmd -> msg + protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> ExceptT TransportError IO (THandle c) protocolPing :: ProtocolCommand msg protocolError :: msg -> Maybe ErrorType instance Protocol BrokerMsg where type ProtocolCommand BrokerMsg = Cmd + protocolClientHandshake = smpClientHandshake protocolPing = Cmd SSender PING protocolError = \case ERR e -> Just e diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 1ebe91f24..a37bcff19 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -9,11 +9,13 @@ module AgentTests.FunctionalAPITests (functionalAPITests) where import Control.Monad.Except (ExceptT, runExceptT) import Control.Monad.IO.Unlift +import NtfClient (withNtfServer) import SMPAgentClient import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Notifications.Protocol (DeviceToken (DeviceToken), PushProvider (PPApple)) import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import Simplex.Messaging.Transport (ATransport (..)) import System.Timeout @@ -48,11 +50,14 @@ functionalAPITests t = do testAsyncServerOffline t it "should notify after HELLO timeout" $ withSmpServer t testAsyncHelloTimeout + describe "Notification server" $ do + it "should register device token" $ + withNtfServer t testNotificationToken testAgentClient :: IO () testAgentClient = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers Right () <- runExceptT $ do (bobId, qInfo) <- createConnection alice SCMInvitation aliceId <- joinConnection bob qInfo "bob's connInfo" @@ -95,13 +100,13 @@ testAgentClient = do testAsyncInitiatingOffline :: IO () testAsyncInitiatingOffline = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers Right () <- runExceptT $ do (bobId, cReq) <- createConnection alice SCMInvitation disconnectAgentClient alice aliceId <- joinConnection bob cReq "bob's connInfo" - alice' <- liftIO $ getSMPAgentClient cfg + alice' <- liftIO $ getSMPAgentClient cfg initAgentServers subscribeConnection alice' bobId ("", _, CONF confId "bob's connInfo") <- get alice' allowConnection alice' bobId confId "alice's connInfo" @@ -113,15 +118,15 @@ testAsyncInitiatingOffline = do testAsyncJoiningOfflineBeforeActivation :: IO () testAsyncJoiningOfflineBeforeActivation = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers Right () <- runExceptT $ do (bobId, qInfo) <- createConnection alice SCMInvitation aliceId <- joinConnection bob qInfo "bob's connInfo" disconnectAgentClient bob ("", _, CONF confId "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" - bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} + bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers subscribeConnection bob' aliceId get alice ##> ("", bobId, CON) get bob' ##> ("", aliceId, INFO "alice's connInfo") @@ -131,18 +136,18 @@ testAsyncJoiningOfflineBeforeActivation = do testAsyncBothOffline :: IO () testAsyncBothOffline = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers Right () <- runExceptT $ do (bobId, cReq) <- createConnection alice SCMInvitation disconnectAgentClient alice aliceId <- joinConnection bob cReq "bob's connInfo" disconnectAgentClient bob - alice' <- liftIO $ getSMPAgentClient cfg + alice' <- liftIO $ getSMPAgentClient cfg initAgentServers subscribeConnection alice' bobId ("", _, CONF confId "bob's connInfo") <- get alice' allowConnection alice' bobId confId "alice's connInfo" - bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} + bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers subscribeConnection bob' aliceId get alice' ##> ("", bobId, CON) get bob' ##> ("", aliceId, INFO "alice's connInfo") @@ -152,8 +157,8 @@ testAsyncBothOffline = do testAsyncServerOffline :: ATransport -> IO () testAsyncServerOffline t = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers -- create connection and shutdown the server Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ createConnection alice SCMInvitation @@ -176,8 +181,8 @@ testAsyncServerOffline t = do testAsyncHelloTimeout :: IO () testAsyncHelloTimeout = do - alice <- getSMPAgentClient cfg - bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1} + alice <- getSMPAgentClient cfg initAgentServers + bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1} initAgentServers Right () <- runExceptT $ do (_, cReq) <- createConnection alice SCMInvitation disconnectAgentClient alice @@ -185,6 +190,13 @@ testAsyncHelloTimeout = do get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED) pure () +testNotificationToken :: IO () +testNotificationToken = do + alice <- getSMPAgentClient cfg initAgentServers + Right () <- runExceptT $ do + registerNtfToken alice $ DeviceToken PPApple "abcd" + pure () + exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetings alice bobId bob aliceId = do 5 <- sendMessage alice bobId "hello" diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs new file mode 100644 index 000000000..bb7f6dec6 --- /dev/null +++ b/tests/NtfClient.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module NtfClient where + +import Control.Monad.Except (runExceptT) +import Control.Monad.IO.Unlift +import Crypto.Random +import Data.ByteString.Char8 (ByteString) +import Network.Socket +import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Server (runNtfServerBlocking) +import Simplex.Messaging.Notifications.Server.Env +import Simplex.Messaging.Notifications.Transport +import Simplex.Messaging.Protocol +import Simplex.Messaging.Transport +import Simplex.Messaging.Transport.Client +import Simplex.Messaging.Transport.KeepAlive +import UnliftIO.Concurrent +import qualified UnliftIO.Exception as E +import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar) +import UnliftIO.Timeout (timeout) + +testHost :: HostName +testHost = "localhost" + +testPort :: ServiceName +testPort = "6001" + +testKeyHash :: C.KeyHash +testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" + +testNtfClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a +testNtfClient client = + runTransportClient testHost testPort testKeyHash (Just defaultKeepAliveOpts) $ \h -> + liftIO (runExceptT $ ntfClientHandshake h testKeyHash) >>= \case + Right th -> client th + Left e -> error $ show e + +cfg :: NtfServerConfig +cfg = + NtfServerConfig + { transports = undefined, + subIdBytes = 24, + regCodeBytes = 32, + clientQSize = 1, + subQSize = 1, + pushQSize = 1, + smpAgentCfg = defaultSMPClientAgentConfig, + -- CA certificate private key is not needed for initialization + caCertificateFile = "tests/fixtures/ca.crt", + privateKeyFile = "tests/fixtures/server.key", + certificateFile = "tests/fixtures/server.crt" + } + +withNtfServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a +withNtfServerThreadOn t port' = + serverBracket + (\started -> runNtfServerBlocking started cfg {transports = [(port', t)]}) + (pure ()) + +serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a +serverBracket process afterProcess f = do + started <- newEmptyTMVarIO + E.bracket + (forkIOWithUnmask ($ process started)) + (\t -> killThread t >> afterProcess >> waitFor started "stop") + (\t -> waitFor started "start" >> f t) + where + waitFor started s = + 5_000_000 `timeout` atomically (takeTMVar started) >>= \case + Nothing -> error $ "server did not " <> s + _ -> pure () + +withNtfServerOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> m a -> m a +withNtfServerOn t port' = withNtfServerThreadOn t port' . const + +withNtfServer :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a +withNtfServer t = withNtfServerOn t testPort + +runNtfTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (THandle c -> m a) -> m a +runNtfTest test = withNtfServer (transport @c) $ testNtfClient test + +ntfServerTest :: + forall c smp. + (Transport c, Encoding smp) => + TProxy c -> + (Maybe C.ASignature, ByteString, ByteString, smp) -> + IO (Maybe C.ASignature, ByteString, ByteString, BrokerMsg) +ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h + where + tPut' h (sig, corrId, queueId, smp) = do + let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp) + Right () <- tPut h (sig, t') + pure () + tGet' h = do + (Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h + pure (Nothing, corrId, qId, cmd) \ No newline at end of file diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs new file mode 100644 index 000000000..93195e8ac --- /dev/null +++ b/tests/NtfServerTests.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module NtfServerTests where + +import Data.ByteString.Char8 (ByteString) +import NtfClient +import ServerTests (sampleDhPubKey, samplePubKey, sampleSig) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding +import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Protocol +import Simplex.Messaging.Transport +import Test.Hspec + +ntfServerTests :: ATransport -> Spec +ntfServerTests t = do + describe "notifications server protocol syntax" $ ntfSyntaxTests t + +ntfSyntaxTests :: ATransport -> Spec +ntfSyntaxTests (ATransport t) = do + it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN) + describe "NEW" $ do + it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX) + it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX) + it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH) + it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH) + where + (>#>) :: + Encoding smp => + (Maybe C.ASignature, ByteString, ByteString, smp) -> + (Maybe C.ASignature, ByteString, ByteString, BrokerMsg) -> + Expectation + command >#> response = ntfServerTest t command `shouldReturn` response \ No newline at end of file diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index e232ffdbd..304ba9674 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -154,11 +154,17 @@ smpAgentTest1_1_1 test' = _test [h] = test' h _test _ = error "expected 1 handle" +initAgentServers :: InitialAgentServers +initAgentServers = + InitialAgentServers + { smp = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"], + ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"] + } + cfg :: AgentConfig cfg = defaultAgentConfig { tcpPort = agentTestPort, - initialSMPServers = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"], tbqSize = 1, dbFile = testDB, smpCfg = @@ -167,7 +173,7 @@ cfg = defaultTransport = (testPort, transport @TLS), tcpTimeout = 500_000 }, - reconnectInterval = (reconnectInterval defaultAgentConfig) {initialInterval = 50_000}, + reconnectInterval = defaultReconnectInterval {initialInterval = 50_000}, caCertificateFile = "tests/fixtures/ca.crt", privateKeyFile = "tests/fixtures/server.key", certificateFile = "tests/fixtures/server.crt" @@ -175,9 +181,10 @@ cfg = withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess = - let cfg' = cfg {tcpPort = port', dbFile = db', initialSMPServers = L.fromList [SMPServer "localhost" smpPort' testKeyHash]} + let cfg' = cfg {tcpPort = port', dbFile = db'} + initServers' = initAgentServers {smp = L.fromList [SMPServer "localhost" smpPort' testKeyHash]} in serverBracket - (\started -> runSMPAgentBlocking t started cfg') + (\started -> runSMPAgentBlocking t started cfg' initServers') afterProcess withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a diff --git a/tests/Test.hs b/tests/Test.hs index 03eab361d..bdd67c0dc 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -4,6 +4,7 @@ import AgentTests (agentTests) import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests import CoreTests.VersionRangeTests +import NtfServerTests (ntfServerTests) import ServerTests import Simplex.Messaging.Transport (TLS, Transport (..)) import Simplex.Messaging.Transport.WebSockets (WS) @@ -20,5 +21,6 @@ main = do describe "Version range" versionRangeTests describe "SMP server via TLS" $ serverTests (transport @TLS) describe "SMP server via WebSockets" $ serverTests (transport @WS) + describe "Ntf server via TLS" $ ntfServerTests (transport @TLS) describe "SMP client agent" $ agentTests (transport @TLS) removeDirectoryRecursive "tests/tmp"