agent schema/methods/types/store methods for notifications tokens (#348)

* agent schema/methods/types/store methods for notifications tokens

* register notification token on the server

* agent commands for notification tokens

* refactor initial servers from AgentConfig

* agent store functions for notification tokens

* server STM store methods for tokens

* fix protocol client for ntfs (use generic handshake), minimal server and agent tests

* server command to verify ntf token
This commit is contained in:
Evgeny Poberezkin
2022-04-08 08:47:04 +01:00
committed by GitHub
parent fb26916eea
commit f577fcdacf
25 changed files with 732 additions and 147 deletions
+1
View File
@@ -39,6 +39,7 @@ ntfServerCLIConfig =
NtfServerConfig
{ transports,
subIdBytes = 24,
regCodeBytes = 32,
clientQSize = 16,
subQSize = 64,
pushQSize = 128,
+9 -2
View File
@@ -11,7 +11,14 @@ import Simplex.Messaging.Agent.Server (runSMPAgent)
import Simplex.Messaging.Transport (TLS, Transport (..))
cfg :: AgentConfig
cfg = defaultAgentConfig {initialSMPServers = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"]}
cfg = defaultAgentConfig
servers :: InitialAgentServers
servers =
InitialAgentServers
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
ntf = []
}
logCfg :: LogConfig
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
@@ -20,4 +27,4 @@ main :: IO ()
main = do
putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig)
setLogLevel LogInfo -- LogError
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers
+2
View File
@@ -298,6 +298,8 @@ test-suite smp-server-test
CoreTests.EncodingTests
CoreTests.ProtocolErrorTests
CoreTests.VersionRangeTests
NtfClient,
NtfServerTests
ServerTests
SMPAgentClient
SMPClient
+107 -10
View File
@@ -48,6 +48,11 @@ module Simplex.Messaging.Agent
suspendConnection,
deleteConnection,
setSMPServers,
setNtfServers,
registerNtfToken,
verifyNtfToken,
enableNtfCron,
deleteNtfToken,
logConnection,
)
where
@@ -69,6 +74,7 @@ import Data.Maybe (isJust)
import qualified Data.Text as T
import Data.Time.Clock
import Data.Time.Clock.System (systemToUTCTime)
import Data.Word (Word16)
import Database.SQLite.Simple (SQLError)
import Simplex.Messaging.Agent.Client
import Simplex.Messaging.Agent.Env.SQLite
@@ -80,7 +86,8 @@ import Simplex.Messaging.Client (ServerTransmission)
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Ratchet as CR
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Protocol (DeviceToken)
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfTknStatus (..))
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, MsgBody)
import qualified Simplex.Messaging.Protocol as SMP
@@ -93,11 +100,11 @@ import qualified UnliftIO.Exception as E
import UnliftIO.STM
-- | Creates an SMP agent client instance
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m AgentClient
getSMPAgentClient cfg = newSMPAgentEnv cfg >>= runReaderT runAgent
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> InitialAgentServers -> m AgentClient
getSMPAgentClient cfg initServers = newSMPAgentEnv cfg >>= runReaderT runAgent
where
runAgent = do
c <- getAgentClient
c <- getAgentClient initServers
action <- async $ subscriber c `E.finally` disconnectAgentClient c
pure c {smpSubscriber = action}
@@ -150,10 +157,24 @@ deleteConnection c = withAgentEnv c . deleteConnection' c
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServer -> m ()
setSMPServers c = withAgentEnv c . setSMPServers' c
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers c = withAgentEnv c . setNtfServers' c
-- | Register device notifications token
registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
registerNtfToken c = withAgentEnv c . registerNtfToken' c
-- | Verify device notifications token
verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
verifyNtfToken c = withAgentEnv c .: verifyNtfToken' c
-- | Enable/disable periodic notifications
enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
enableNtfCron c = withAgentEnv c .: enableNtfCron' c
deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken c = withAgentEnv c . deleteNtfToken' c
withAgentEnv :: AgentClient -> ReaderT Env m a -> m a
withAgentEnv c = (`runReaderT` agentEnv c)
@@ -161,8 +182,8 @@ withAgentEnv c = (`runReaderT` agentEnv c)
-- withAgentClient c = withAgentLock c . withAgentEnv c
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient
getAgentClient = ask >>= atomically . newAgentClient
getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => InitialAgentServers -> m AgentClient
getAgentClient initServers = ask >>= atomically . newAgentClient initServers
logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m ()
logConnection c connected =
@@ -499,8 +520,73 @@ setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
setSMPServers' c servers = do
atomically $ writeTVar (smpServers c) servers
registerNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
registerNtfToken' c token = pure ()
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> m ()
registerNtfToken' c deviceToken =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
Just tkn@NtfToken {ntfTokenId, ntfTknAction} -> case (ntfTokenId, ntfTknAction) of
(Nothing, Just (NTARegister ntfPubKey)) -> registerToken tkn ntfPubKey
-- TODO request verification code again in case there is registration, but no verification code in DB - probably after some timeout?
(Just tknId, Just (NTAVerify code)) ->
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
(Just tknId, Just (NTACron interval)) ->
t tkn (NTActive, Just NTACheck) $ agentNtfEnableCron c tknId tkn interval
(Just _tknId, Just NTACheck) -> pure ()
(Just tknId, Just NTADelete) ->
t tkn (NTExpired, Nothing) $ agentNtfDeleteToken c tknId tkn
_ -> pure ()
_ ->
getNtfServer c >>= \case
Just ntfServer ->
asks (cmdSignAlg . config) >>= \case
C.SignAlg a -> do
(ntfPubKey, ntfPrivKey) <- liftIO $ C.generateSignatureKeyPair a
let tkn = newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey
withStore $ \st -> createNtfToken st tkn
registerToken tkn ntfPubKey
_ -> throwError $ CMD PROHIBITED
where
t tkn = withToken tkn Nothing
registerToken :: NtfToken -> C.APublicVerifyKey -> m ()
registerToken tkn ntfPubKey = do
(pubDhKey, privDhKey) <- liftIO C.generateKeyPair'
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
let dhSecret = C.dh' srvPubDhKey privDhKey
withStore $ \st -> updateNtfTokenRegistration st tkn tknId dhSecret
verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
verifyNtfToken' c deviceToken code =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
Just tkn@NtfToken {ntfTokenId = Just tknId} ->
withToken tkn (Just (NTConfirmed, NTAVerify code)) (NTActive, Just NTACheck) $
agentNtfVerifyToken c tknId tkn code
_ -> throwError $ CMD PROHIBITED
enableNtfCron' :: AgentMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
enableNtfCron' c deviceToken interval =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive} ->
withToken tkn (Just (NTActive, NTACron interval)) (NTActive, Just NTACheck) $
agentNtfEnableCron c tknId tkn interval
_ -> throwError $ CMD PROHIBITED
deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken' c deviceToken =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus} ->
withToken tkn (Just (ntfTknStatus, NTADelete)) (NTExpired, Nothing) $
agentNtfDeleteToken c tknId tkn
_ -> throwError $ CMD PROHIBITED
withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a
withToken tkn from_ (toStatus, toAction_) f = do
forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action)
res <- f
withStore $ \st -> updateNtfToken st tkn toStatus toAction_
pure res
setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers' c servers = do
atomically $ writeTVar (ntfServers c) servers
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
getSMPServer c = do
@@ -509,8 +595,19 @@ getSMPServer c = do
srv :| [] -> pure srv
servers -> do
gen <- asks randomServer
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
pure $ servers L.!! i
atomically . stateTVar gen $
first (servers L.!!) . randomR (0, L.length servers - 1)
getNtfServer :: AgentMonad m => AgentClient -> m (Maybe NtfServer)
getNtfServer c = do
ntfServers <- readTVarIO $ ntfServers c
case ntfServers of
[] -> pure Nothing
[srv] -> pure $ Just srv
servers -> do
gen <- asks randomServer
atomically . stateTVar gen $
first (Just . (servers !!)) . randomR (0, length servers - 1)
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
subscriber c@AgentClient {msgQ} = forever $ do
+33 -11
View File
@@ -8,7 +8,6 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Simplex.Messaging.Agent.Client
( AgentClient (..),
@@ -24,6 +23,10 @@ module Simplex.Messaging.Agent.Client
RetryInterval (..),
secureQueue,
sendAgentMessage,
agentNtfRegisterToken,
agentNtfVerifyToken,
agentNtfDeleteToken,
agentNtfEnableCron,
agentCbEncrypt,
agentCbDecrypt,
cryptoError,
@@ -51,6 +54,7 @@ import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import Data.Text.Encoding
import Data.Word (Word16)
import Simplex.Messaging.Agent.Env.SQLite
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
@@ -59,8 +63,8 @@ import Simplex.Messaging.Client
import Simplex.Messaging.Client.Agent ()
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Client (NtfClient)
import Simplex.Messaging.Notifications.Protocol (NtfResponse)
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
@@ -83,8 +87,9 @@ data AgentClient = AgentClient
subQ :: TBQueue (ATransmission 'Agent),
msgQ :: TBQueue (ServerTransmission BrokerMsg),
smpServers :: TVar (NonEmpty SMPServer),
ntfServers :: TVar [NtfServer],
smpClients :: TMap SMPServer SMPClientVar,
ntfClients :: TMap ProtocolServer NtfClientVar,
ntfClients :: TMap NtfServer NtfClientVar,
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
subscrConns :: TMap ConnId SMPServer,
@@ -99,13 +104,14 @@ data AgentClient = AgentClient
lock :: TMVar ()
}
newAgentClient :: Env -> STM AgentClient
newAgentClient agentEnv = do
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
newAgentClient InitialAgentServers {smp, ntf} agentEnv = do
let qSize = tbqSize $ config agentEnv
rcvQ <- newTBQueue qSize
subQ <- newTBQueue qSize
msgQ <- newTBQueue qSize
smpServers <- newTVar $ initialSMPServers (config agentEnv)
smpServers <- newTVar smp
ntfServers <- newTVar ntf
smpClients <- TM.empty
ntfClients <- TM.empty
subscrSrvrs <- TM.empty
@@ -118,7 +124,7 @@ newAgentClient agentEnv = do
asyncClients <- newTVar []
clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
lock <- newTMVar ()
return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
return AgentClient {rcvQ, subQ, msgQ, smpServers, ntfServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
@@ -206,7 +212,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
notifySub :: ACommand 'Agent -> ConnId -> IO ()
notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> ProtocolServer -> m NtfClient
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
getNtfServerClient c@AgentClient {ntfClients} srv =
atomically (getClientVar srv ntfClients)
>>= either
@@ -256,10 +262,10 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
tryConnectClient successAction retryAction =
tryError connectClient >>= \r -> case r of
Right smp -> do
Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar clientVar r
successAction smp
successAction client
Left e -> do
if e == BROKER NETWORK || e == BROKER TIMEOUT
then retryAction
@@ -464,6 +470,22 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
withClient c ntfServer $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do
cmNonce <- liftIO C.randomCbNonce
+19 -10
View File
@@ -7,7 +7,9 @@
module Simplex.Messaging.Agent.Env.SQLite
( AgentConfig (..),
InitialAgentServers (..),
defaultAgentConfig,
defaultReconnectInterval,
Env (..),
newSMPAgentEnv,
)
@@ -25,13 +27,18 @@ import Simplex.Messaging.Agent.Store.SQLite
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Client (NtfServer)
import Simplex.Messaging.Transport (TLS, Transport (..))
import System.Random (StdGen, newStdGen)
import UnliftIO.STM
data InitialAgentServers = InitialAgentServers
{ smp :: NonEmpty SMPServer,
ntf :: [NtfServer]
}
data AgentConfig = AgentConfig
{ tcpPort :: ServiceName,
initialSMPServers :: NonEmpty SMPServer,
cmdSignAlg :: C.SignAlg,
connIdBytes :: Int,
tbqSize :: Natural,
@@ -47,11 +54,20 @@ data AgentConfig = AgentConfig
certificateFile :: FilePath
}
defaultReconnectInterval :: RetryInterval
defaultReconnectInterval =
RetryInterval
{ initialInterval = second,
increaseAfter = 10 * second,
maxInterval = 10 * second
}
where
second = 1_000_000
defaultAgentConfig :: AgentConfig
defaultAgentConfig =
AgentConfig
{ tcpPort = "5224",
initialSMPServers = undefined, -- TODO move it elsewhere?
cmdSignAlg = C.SignAlg C.SEd448,
connIdBytes = 12,
tbqSize = 64,
@@ -60,12 +76,7 @@ defaultAgentConfig =
yesToMigrations = False,
smpCfg = defaultClientConfig {defaultTransport = ("5223", transport @TLS)},
ntfCfg = defaultClientConfig {defaultTransport = ("443", transport @TLS)},
reconnectInterval =
RetryInterval
{ initialInterval = second,
increaseAfter = 10 * second,
maxInterval = 10 * second
},
reconnectInterval = defaultReconnectInterval,
helloTimeout = 2 * nominalDay,
-- CA certificate private key is not needed for initialization
-- ! we do not generate these
@@ -73,8 +84,6 @@ defaultAgentConfig =
privateKeyFile = "/etc/opt/simplex-agent/agent.key",
certificateFile = "/etc/opt/simplex-agent/agent.crt"
}
where
second = 1_000_000
data Env = Env
{ config :: AgentConfig,
+6 -6
View File
@@ -31,17 +31,17 @@ import UnliftIO.STM
-- | Runs an SMP agent as a TCP service using passed configuration.
--
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
runSMPAgent t cfg = do
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> m ()
runSMPAgent t cfg initServers = do
started <- newEmptyTMVarIO
runSMPAgentBlocking t started cfg
runSMPAgentBlocking t started cfg initServers
-- | Runs an SMP agent as a TCP service using passed configuration with signalling.
--
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} = do
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> InitialAgentServers -> m ()
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers = do
runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
where
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
@@ -50,7 +50,7 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertifica
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
c <- getAgentClient
c <- getAgentClient initServers
logConnection c True
race_ (connectClient h c) (runAgentClient c)
`E.finally` disconnectAgentClient c
+8
View File
@@ -20,6 +20,8 @@ import Data.Type.Equality
import Simplex.Messaging.Agent.Protocol
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff, SkippedMsgKeys)
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfTknStatus, NtfTokenId)
import Simplex.Messaging.Protocol
( MsgBody,
MsgId,
@@ -75,6 +77,12 @@ class Monad m => MonadAgentStore s m where
getSkippedMsgKeys :: s -> ConnId -> m SkippedMsgKeys
updateRatchet :: s -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m ()
-- Notification device token persistence
createNtfToken :: s -> NtfToken -> m ()
getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken) -- return current token if it exists and mark any old tokens for deletion
updateNtfTokenRegistration :: s -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
updateNtfToken :: s -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m ()
-- * Queue types
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
+76 -13
View File
@@ -23,7 +23,6 @@ module Simplex.Messaging.Agent.Store.SQLite
connectSQLiteStore,
withConnection,
withTransaction,
fromTextField_,
firstRow,
)
where
@@ -41,14 +40,14 @@ import Data.Char (toLower)
import Data.Functor (($>))
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1)
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, ToRow, field)
import Data.Time.Clock (getCurrentTime)
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLError, ToRow, field)
import qualified Database.SQLite.Simple as DB
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.Internal (Field (..))
import Database.SQLite.Simple.Ok (Ok (Ok))
import Database.SQLite.Simple.QQ (sql)
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Agent.Protocol
@@ -59,7 +58,9 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys)
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers (blobFieldParser)
import Simplex.Messaging.Notifications.Client (NtfServer, NtfTknAction, NtfToken (..))
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfTknStatus (..), NtfTokenId)
import Simplex.Messaging.Parsers (blobFieldParser, fromTextField_)
import Simplex.Messaging.Protocol (MsgBody, ProtocolServer (..))
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.Util (bshow, liftIOEither)
@@ -565,6 +566,63 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
forM_ (M.assocs mks) $ \(msgN, mk) ->
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
createNtfToken :: SQLiteStore -> NtfToken -> m ()
createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} =
liftIO . withTransaction st $ \db -> do
upsertNtfServer_ db srv
DB.execute
db
[sql|
INSERT INTO ntf_tokens
(provider, device_token, ntf_host, ntf_port, tkn_id, tkn_priv_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|]
(provider, token, host, port, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction)
getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken)
getDeviceNtfToken st t@(DeviceToken provider token) =
liftIO . withTransaction st $ \db ->
fmap ntfToken . listToMaybe
<$> DB.query
db
[sql|
SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
t.tkn_id, t.tkn_priv_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action
FROM ntf_tokens t
JOIN ntf_servers s USING (ntf_host, ntf_port)
WHERE t.provider = ? AND t.device_token = ?
|]
(provider, token)
where
ntfToken (host, port, keyHash, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) =
let ntfServer = ProtocolServer {host, port, keyHash}
in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction}
updateNtfTokenRegistration :: SQLiteStore -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
updateNtfTokenRegistration st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret =
liftIO . withTransaction st $ \db -> do
updatedAt <- getCurrentTime
DB.execute
db
[sql|
UPDATE ntf_tokens
SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ?
WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
|]
(tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port)
updateNtfToken :: SQLiteStore -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m ()
updateNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction =
liftIO . withTransaction st $ \db -> do
updatedAt <- getCurrentTime
DB.execute
db
[sql|
UPDATE ntf_tokens
SET tkn_status = ?, tkn_action = ?, updated_at = ?
WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
|]
(tknStatus, tknAction, updatedAt, provider, token, host, port)
-- * Auxiliary helpers
instance ToField QueueStatus where toField = toField . serializeQueueStatus
@@ -611,14 +669,6 @@ instance ToField (SConnectionMode c) where toField = toField . connMode
instance FromField AConnectionMode where fromField = fromTextField_ $ fmap connMode' . connModeT
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
f@(Field (SQLText t) _) ->
case fromText t of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
f -> returnError ConversionFailed f "expecting SQLText column type"
listToEither :: e -> [a] -> Either e a
listToEither _ (x : _) = Right x
listToEither e _ = Left e
@@ -669,6 +719,19 @@ upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
|]
[":host" := host, ":port" := port, ":key_hash" := keyHash]
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
DB.executeNamed
db
[sql|
INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash)
ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET
ntf_host=excluded.ntf_host,
ntf_port=excluded.ntf_port,
ntf_key_hash=excluded.ntf_key_hash;
|]
[":host" := host, ":port" := port, ":key_hash" := keyHash]
-- * createRcvConn helpers
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
@@ -8,40 +8,29 @@ import Database.SQLite.Simple.QQ (sql)
m20220322_notifications :: Query
m20220322_notifications =
[sql|
ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB;
ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB;
ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB;
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id);
CREATE TABLE ntf_servers (
ntf_host TEXT NOT NULL,
ntf_port TEXT NOT NULL,
ntf_key_hash BLOB NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
PRIMARY KEY (ntf_host, ntf_port)
) WITHOUT ROWID;
CREATE TABLE ntf_subscriptions (
CREATE TABLE ntf_tokens (
provider TEXT NOT NULL, -- apn
device_token TEXT NOT NULL,
ntf_host TEXT NOT NULL,
ntf_port TEXT NOT NULL,
ntf_sub_id BLOB NOT NULL,
ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth
ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
ntf_token TEXT NOT NULL, -- or BLOB?
smp_host TEXT NOT NULL,
smp_port TEXT NOT NULL,
smp_ntf_id BLOB NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked
PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id),
tkn_id BLOB, -- token ID assigned by notifications server
tkn_priv_key BLOB NOT NULL, -- private key to sign token commands
tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications
tkn_status TEXT NOT NULL,
tkn_action BLOB,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now')), -- this is to check token status periodically to know when it was last checked
PRIMARY KEY (provider, device_token, ntf_host, ntf_port),
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
ON DELETE RESTRICT ON UPDATE CASCADE,
FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id)
ON DELETE RESTRICT ON UPDATE CASCADE
) WITHOUT ROWID;
|]
@@ -0,0 +1,38 @@
{-# LANGUAGE QuasiQuotes #-}
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220404_ntf_subscriptions_draft where
import Database.SQLite.Simple (Query)
import Database.SQLite.Simple.QQ (sql)
m20220404_ntf_subscriptions_draft :: Query
m20220404_ntf_subscriptions_draft =
[sql|
ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB;
ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB;
ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB;
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id);
CREATE TABLE ntf_subscriptions (
ntf_host TEXT NOT NULL,
ntf_port TEXT NOT NULL,
ntf_sub_id BLOB NOT NULL,
ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth
ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
ntf_token TEXT NOT NULL, -- or BLOB?
smp_host TEXT NOT NULL,
smp_port TEXT NOT NULL,
smp_ntf_id BLOB NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked
PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id),
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
ON DELETE RESTRICT ON UPDATE CASCADE,
FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id)
ON DELETE RESTRICT ON UPDATE CASCADE
) WITHOUT ROWID;
|]
+5 -5
View File
@@ -66,7 +66,7 @@ import Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError)
import Simplex.Messaging.Transport.Client (runTransportClient, smpClientHandshake)
import Simplex.Messaging.Transport.Client (runTransportClient)
import Simplex.Messaging.Transport.KeepAlive
import Simplex.Messaging.Transport.WebSockets (WS)
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
@@ -132,11 +132,11 @@ type Response msg = Either ProtocolClientError msg
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected =
(atomically mkSMPClient >>= runClient useTransport)
(atomically mkProtocolClient >>= runClient useTransport)
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
where
mkSMPClient :: STM (ProtocolClient msg)
mkSMPClient = do
mkProtocolClient :: STM (ProtocolClient msg)
mkProtocolClient = do
connected <- newTVar False
clientCorrId <- newTVar 0
sentCommands <- TM.empty
@@ -177,7 +177,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO ()
client _ c thVar h =
runExceptT (smpClientHandshake h $ keyHash protocolServer) >>= \case
runExceptT (protocolClientHandshake @msg h $ keyHash protocolServer) >>= \case
Left e -> atomically . putTMVar thVar . Left $ PCETransportError e
Right th@THandle {sessionId} -> do
atomically $ do
+7 -1
View File
@@ -2,7 +2,8 @@
{-# LANGUAGE OverloadedStrings #-}
module Simplex.Messaging.Encoding.String
( StrEncoding (..),
( TextEncoding (..),
StrEncoding (..),
Str (..),
strP_,
strToJSON,
@@ -26,11 +27,16 @@ import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum)
import qualified Data.List.NonEmpty as L
import Data.Text (Text)
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
import Data.Word (Word16)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Util ((<$?>))
class TextEncoding a where
textEncode :: a -> Text
textDecode :: Text -> Maybe a
-- | Serializing human-readable and (where possible) URI-friendly strings for SMP and SMP agent protocols
class StrEncoding a where
{-# MINIMAL strEncode, (strDecode | strP) #-}
+79 -14
View File
@@ -1,5 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -7,42 +9,50 @@ module Simplex.Messaging.Notifications.Client where
import Control.Monad.Except
import Control.Monad.Trans.Except
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Word (Word16)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Messaging.Client
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Parsers (blobFieldDecoder)
import Simplex.Messaging.Protocol (ProtocolServer)
type NtfServer = ProtocolServer
type NtfClient = ProtocolClient NtfResponse
registerNtfToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
registerNtfToken c pKey newTkn =
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
ntfRegisterToken c pKey newTkn =
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
NRId tknId dhKey -> pure (tknId, dhKey)
_ -> throwE PCEUnexpectedResponse
verifyNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegistrationCode -> ExceptT ProtocolClientError IO ()
verifyNtfToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
deleteNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
deleteNtfToken = okNtfCommand TDEL
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
ntfDeleteToken = okNtfCommand TDEL
enableNtfCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
enableNtfCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
createNtfSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519)
createNtfSubsciption c pKey newSub =
ntfCreateSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519)
ntfCreateSubsciption c pKey newSub =
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
NRId tknId dhKey -> pure (tknId, dhKey)
_ -> throwE PCEUnexpectedResponse
checkNtfSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
checkNtfSubscription c pKey subId =
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
ntfCheckSubscription c pKey subId =
sendNtfCommand c (Just pKey) subId SCHK >>= \case
NRStat stat -> pure stat
_ -> throwE PCEUnexpectedResponse
deleteNfgSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
deleteNfgSubscription = okNtfCommand SDEL
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
ntfDeleteSubscription = okNtfCommand SDEL
-- | Send notification server command
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
@@ -53,3 +63,58 @@ okNtfCommand cmd c pKey entId =
sendNtfCommand c (Just pKey) entId cmd >>= \case
NROk -> return ()
_ -> throwE PCEUnexpectedResponse
data NtfTknAction
= NTARegister C.APublicVerifyKey -- public key to send to the server
| NTAVerify NtfRegCode -- code to verify token
| NTACheck
| NTACron Word16
| NTADelete
deriving (Show)
instance Encoding NtfTknAction where
smpEncode = \case
NTARegister key -> smpEncode ('R', key)
NTAVerify code -> smpEncode ('V', code)
NTACheck -> "C"
NTACron interval -> smpEncode ('I', interval)
NTADelete -> "D"
smpP =
A.anyChar >>= \case
'R' -> NTARegister <$> smpP
'V' -> NTAVerify <$> smpP
'C' -> pure NTACheck
'I' -> NTACron <$> smpP
'D' -> pure NTADelete
_ -> fail "bad NtfTknAction"
instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode
instance ToField NtfTknAction where toField = toField . smpEncode
data NtfToken = NtfToken
{ deviceToken :: DeviceToken,
ntfServer :: NtfServer,
ntfTokenId :: Maybe NtfTokenId,
-- | key used by the ntf client to sign transmissions
ntfPrivKey :: C.APrivateSignKey,
-- | shared DH secret used to encrypt/decrypt notifications e2e
ntfDhSecret :: Maybe C.DhSecretX25519,
-- | token status
ntfTknStatus :: NtfTknStatus,
-- | pending token action and the earliest time
ntfTknAction :: Maybe NtfTknAction
}
deriving (Show)
newNtfToken :: DeviceToken -> NtfServer -> C.APrivateSignKey -> C.APublicVerifyKey -> NtfToken
newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey =
NtfToken
{ deviceToken,
ntfServer,
ntfTokenId = Nothing,
ntfPrivKey,
ntfDhSecret = Nothing,
ntfTknStatus = NTNew,
ntfTknAction = Just $ NTARegister ntfPubKey
}
+77 -10
View File
@@ -10,6 +10,7 @@
module Simplex.Messaging.Notifications.Protocol where
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
@@ -17,8 +18,13 @@ import Data.Kind
import Data.Maybe (isNothing)
import Data.Type.Equality
import Data.Word (Word16)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
import Simplex.Messaging.Parsers (fromTextField_)
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
import Simplex.Messaging.Util ((<$?>))
@@ -87,12 +93,31 @@ instance ProtocolMsgTag NtfCmdTag where
instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where
decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t)
type NtfRegistrationCode = ByteString
newtype NtfRegCode = NtfRegCode ByteString
deriving (Eq, Show)
instance Encoding NtfRegCode where
smpEncode (NtfRegCode code) = smpEncode code
smpP = NtfRegCode <$> smpP
instance StrEncoding NtfRegCode where
strEncode (NtfRegCode m) = strEncode m
strDecode s = NtfRegCode <$> strDecode s
strP = NtfRegCode <$> strP
instance FromJSON NtfRegCode where
parseJSON = strParseJSON "NtfRegCode"
instance ToJSON NtfRegCode where
toJSON = strToJSON
toEncoding = strToJEncoding
data NewNtfEntity (e :: NtfEntity) where
NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token
NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NewNtfEntity 'Subscription
deriving instance Show (NewNtfEntity e)
data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e)
instance NtfEntityI e => Encoding (NewNtfEntity e) where
@@ -111,6 +136,7 @@ instance Encoding ANewNtfEntity where
instance Protocol NtfResponse where
type ProtocolCommand NtfResponse = NtfCmd
protocolClientHandshake = ntfClientHandshake
protocolPing = NtfCmd SSubscription PING
protocolError = \case
NRErr e -> Just e
@@ -120,7 +146,7 @@ data NtfCommand (e :: NtfEntity) where
-- | register new device token for notifications
TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token
-- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token
TVFY :: NtfRegistrationCode -> NtfCommand 'Token
TVFY :: NtfRegCode -> NtfCommand 'Token
-- | delete token - all subscriptions will be removed and no more notifications will be sent
TDEL :: NtfCommand 'Token
-- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable
@@ -134,8 +160,12 @@ data NtfCommand (e :: NtfEntity) where
-- | keep-alive command
PING :: NtfCommand 'Subscription
deriving instance Show (NtfCommand e)
data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
deriving instance Show NtfCmd
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
type Tag (NtfCommand e) = NtfCommandTag e
encodeProtocol = \case
@@ -263,6 +293,7 @@ data SMPQueueNtf = SMPQueueNtf
notifierId :: NotifierId,
notifierKey :: NtfPrivateSignKey
}
deriving (Show)
instance Encoding SMPQueueNtf where
smpEncode SMPQueueNtf {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey)
@@ -270,17 +301,30 @@ instance Encoding SMPQueueNtf where
(smpServer, notifierId, notifierKey) <- smpP
pure $ SMPQueueNtf smpServer notifierId notifierKey
data PushPlatform = PPApple
data PushProvider = PPApple
deriving (Eq, Ord, Show)
instance Encoding PushPlatform where
instance Encoding PushProvider where
smpEncode = \case
PPApple -> "A"
smpP =
A.anyChar >>= \case
'A' -> pure PPApple
_ -> fail "bad PushPlatform"
_ -> fail "bad PushProvider"
data DeviceToken = DeviceToken PushPlatform ByteString
instance TextEncoding PushProvider where
textEncode = \case
PPApple -> "apple"
textDecode = \case
"apple" -> Just PPApple
_ -> Nothing
instance FromField PushProvider where fromField = fromTextField_ textDecode
instance ToField PushProvider where toField = toField . textEncode
data DeviceToken = DeviceToken PushProvider ByteString
deriving (Eq, Ord, Show)
instance Encoding DeviceToken where
smpEncode (DeviceToken p t) = smpEncode (p, t)
@@ -322,17 +366,40 @@ instance Encoding NtfSubStatus where
_ -> fail "bad NtfError"
data NtfTknStatus
= -- | state after registration (TNEW)
= -- | Token created in DB
NTNew
| -- | if initial notification or verification failed (push provider error)
| -- | state after registration (TNEW)
NTRegistered
| -- | if initial notification failed (push provider error) or verification failed
NTInvalid
| -- | if initial notification succeeded
| -- | Token confirmed via notification (accepted by push provider or verification code received by client)
NTConfirmed
| -- | after successful verification (TVFY)
NTActive
| -- | after it is no longer valid (push provider error)
NTExpired
deriving (Eq)
deriving (Eq, Show)
instance TextEncoding NtfTknStatus where
textEncode = \case
NTNew -> "new"
NTRegistered -> "registered"
NTInvalid -> "invalid"
NTConfirmed -> "confirmed"
NTActive -> "active"
NTExpired -> "expired"
textDecode = \case
"new" -> Just NTNew
"registered" -> Just NTRegistered
"invalid" -> Just NTInvalid
"confirmed" -> Just NTConfirmed
"active" -> Just NTActive
"expired" -> Just NTExpired
_ -> Nothing
instance FromField NtfTknStatus where fromField = fromTextField_ textDecode
instance ToField NtfTknStatus where toField = toField . textEncode
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
+24 -11
View File
@@ -14,8 +14,8 @@ import Control.Monad.Except
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader
import Crypto.Random (MonadRandom)
import qualified Data.Aeson as J
import Data.ByteString.Char8 (ByteString)
import Data.Functor (($>))
import Network.Socket (ServiceName)
import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
@@ -106,7 +106,12 @@ ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}}
ntfPush :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m ()
ntfPush NtfPushServer {pushQ} = forever $ do
atomically (readTBQueue pushQ) >>= \case
(NtfTknData {}, Notification {}) -> pure ()
(NtfTknData {tknStatus}, notification) -> do
liftIO $ print $ J.encode notification
-- TODO status update should happen after the token status successfully sent
case notification of
PNVerification _ -> atomically $ writeTVar tknStatus NTConfirmed
_ -> pure ()
runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
runNtfClientTransport th@THandle {sessionId} = do
@@ -122,7 +127,7 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
receive th NtfServerClient {rcvQ, sndQ} = forever $ do
t@(sig, signed, (corrId, subId, cmdOrError)) <- tGet th
t@(_, _, (corrId, subId, cmdOrError)) <- tGet th
case cmdOrError of
Left e -> write sndQ (corrId, subId, NRErr e)
Right cmd ->
@@ -144,7 +149,7 @@ verifyNtfTransmission ::
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
case cmd of
NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) ->
-- TODO check that token is not already in store
-- TODO check that token is not already in store, if it is - verify that the saved key is the same
pure $
if verifyCmdSignature sig_ signed k
then VRVerified (NtfReqNew corrId (ANE SToken n))
@@ -189,16 +194,21 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} =
(srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair'
let dhSecret = C.dh' dhPubKey srvDrivDhKey
tknId <- getId
regCode <- getRegCode
atomically $ do
tkn <- mkNtfTknData newTkn dhSecret
tkn <- mkNtfTknData newTkn dhSecret regCode
addNtfToken st tknId tkn
writeTBQueue pushQ (tkn, Notification)
-- pure (corrId, sId, NRSubId pubDhKey)
writeTBQueue pushQ (tkn, PNVerification regCode)
pure (corrId, "", NRId tknId srvDhPubKey)
NtfReqCmd SToken tkn (corrId, tknId, cmd) ->
NtfReqCmd SToken (NtfTkn NtfTknData {tknStatus, tknRegCode}) (corrId, tknId, cmd) -> do
status <- readTVarIO tknStatus
(corrId,tknId,) <$> case cmd of
TNEW newTkn -> pure NROk -- TODO when duplicate token sent
TVFY code -> pure NROk
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
atomically $ writeTVar tknStatus NTActive
pure NROk
| otherwise -> pure $ NRErr AUTH
TDEL -> pure NROk
TCRN int -> pure NROk
NtfReqNew corrId (ANE SSubscription newSub) -> pure (corrId, "", NROk)
@@ -209,8 +219,11 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} =
SDEL -> pure NROk
PING -> pure NRPong
getId :: m NtfEntityId
getId = do
n <- asks $ subIdBytes . config
getId = getRandomBytes =<< asks (subIdBytes . config)
getRegCode :: m NtfRegCode
getRegCode = NtfRegCode <$> (getRandomBytes =<< asks (regCodeBytes . config))
getRandomBytes :: Int -> m ByteString
getRandomBytes n = do
gVar <- asks idsDrg
atomically (randomBytes n gVar)
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
@@ -8,8 +9,11 @@ module Simplex.Messaging.Notifications.Server.Env where
import Control.Monad.IO.Unlift
import Crypto.Random
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson as J
import Data.ByteString.Char8 (ByteString)
import Data.X509.Validation (Fingerprint (..))
import GHC.Generics
import Network.Socket
import qualified Network.TLS as T
import Numeric.Natural
@@ -17,6 +21,7 @@ import Simplex.Messaging.Client.Agent
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Subscriptions
import Simplex.Messaging.Parsers (dropPrefix, taggedObjectJSON)
import Simplex.Messaging.Protocol (CorrId, Transmission)
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
@@ -25,6 +30,7 @@ import UnliftIO.STM
data NtfServerConfig = NtfServerConfig
{ transports :: [(ServiceName, ATransport)],
subIdBytes :: Int,
regCodeBytes :: Int,
clientQSize :: Natural,
subQSize :: Natural,
pushQSize :: Natural,
@@ -35,7 +41,15 @@ data NtfServerConfig = NtfServerConfig
certificateFile :: FilePath
}
data Notification = Notification
data PushNotification = PNVerification {code :: NtfRegCode} | PNPeriodic
deriving (Show, Generic)
instance FromJSON PushNotification where
parseJSON = J.genericParseJSON . taggedObjectJSON $ dropPrefix "PN"
instance ToJSON PushNotification where
toJSON = J.genericToJSON . taggedObjectJSON $ dropPrefix "PN"
toEncoding = J.genericToEncoding . taggedObjectJSON $ dropPrefix "PN"
data NtfEnv = NtfEnv
{ config :: NtfServerConfig,
@@ -70,7 +84,7 @@ newNtfSubscriber qSize smpAgentCfg = do
pure NtfSubscriber {smpAgent, subQ}
newtype NtfPushServer = NtfPushServer
{ pushQ :: TBQueue (NtfTknData, Notification)
{ pushQ :: TBQueue (NtfTknData, PushNotification)
}
newNtfPushServer :: Natural -> STM NtfPushServer
@@ -36,15 +36,16 @@ data NtfTknData = NtfTknData
{ token :: DeviceToken,
tknStatus :: TVar NtfTknStatus,
tknVerifyKey :: C.APublicVerifyKey,
tknDhSecret :: C.DhSecretX25519
tknDhSecret :: C.DhSecretX25519,
tknRegCode :: NtfRegCode
}
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> STM NtfTknData
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret = do
tknStatus <- newTVar NTNew
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret}
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret tknRegCode = do
tknStatus <- newTVar NTRegistered
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret, tknRegCode}
data NtfSubscriptionsStore = NtfSubscriptionsStore
-- data NtfSubscriptionsStore = NtfSubscriptionsStore
-- { subscriptions :: TMap NtfSubsciptionId NtfSubsciption,
-- activeSubscriptions :: TMap (SMPServer, NotifierId) NtfSubsciptionId
@@ -70,7 +71,9 @@ getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token))
getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st)
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
addNtfToken st tknId tkn = pure ()
addNtfToken st tknId tkn@NtfTknData {token} = do
TM.insert tknId tkn (tokens st)
TM.insert token tknId (tokenIds st)
-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e))
-- getNtfRec st ent entId = case ent of
+10
View File
@@ -13,6 +13,8 @@ import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (isAlphaNum, toLower)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock (UTCTime)
import Data.Time.ISO8601 (parseISO8601)
import Data.Typeable (Typeable)
@@ -81,6 +83,14 @@ blobFieldDecoder dec = \case
Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
f -> returnError ConversionFailed f "expecting SQLBlob column type"
fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a
fromTextField_ fromText = \case
f@(Field (SQLText t) _) ->
case fromText t of
Just x -> Ok x
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
f -> returnError ConversionFailed f "expecting SQLText column type"
fstToLower :: String -> String
fstToLower "" = ""
fstToLower (h : t) = toLower h : t
+5 -1
View File
@@ -1,3 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
@@ -115,7 +116,7 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), tGetBlock, tPutBlock)
import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), smpClientHandshake, tGetBlock, tPutBlock)
import Simplex.Messaging.Util (bshow, (<$?>))
import Simplex.Messaging.Version
import Test.QuickCheck (Arbitrary (..))
@@ -184,6 +185,7 @@ data RawTransmission = RawTransmission
entityId :: ByteString,
command :: ByteString
}
deriving (Show)
-- | unparsed sent SMP transmission with signature, without session ID.
type SignedRawTransmission = (Maybe C.ASignature, ByteString, ByteString, ByteString)
@@ -554,11 +556,13 @@ transmissionP = do
class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg)) => Protocol msg where
type ProtocolCommand msg = cmd | cmd -> msg
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> ExceptT TransportError IO (THandle c)
protocolPing :: ProtocolCommand msg
protocolError :: msg -> Maybe ErrorType
instance Protocol BrokerMsg where
type ProtocolCommand BrokerMsg = Cmd
protocolClientHandshake = smpClientHandshake
protocolPing = Cmd SSender PING
protocolError = \case
ERR e -> Just e
+28 -16
View File
@@ -9,11 +9,13 @@ module AgentTests.FunctionalAPITests (functionalAPITests) where
import Control.Monad.Except (ExceptT, runExceptT)
import Control.Monad.IO.Unlift
import NtfClient (withNtfServer)
import SMPAgentClient
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Notifications.Protocol (DeviceToken (DeviceToken), PushProvider (PPApple))
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
import Simplex.Messaging.Transport (ATransport (..))
import System.Timeout
@@ -48,11 +50,14 @@ functionalAPITests t = do
testAsyncServerOffline t
it "should notify after HELLO timeout" $
withSmpServer t testAsyncHelloTimeout
describe "Notification server" $ do
it "should register device token" $
withNtfServer t testNotificationToken
testAgentClient :: IO ()
testAgentClient = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, qInfo) <- createConnection alice SCMInvitation
aliceId <- joinConnection bob qInfo "bob's connInfo"
@@ -95,13 +100,13 @@ testAgentClient = do
testAsyncInitiatingOffline :: IO ()
testAsyncInitiatingOffline = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, cReq) <- createConnection alice SCMInvitation
disconnectAgentClient alice
aliceId <- joinConnection bob cReq "bob's connInfo"
alice' <- liftIO $ getSMPAgentClient cfg
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
subscribeConnection alice' bobId
("", _, CONF confId "bob's connInfo") <- get alice'
allowConnection alice' bobId confId "alice's connInfo"
@@ -113,15 +118,15 @@ testAsyncInitiatingOffline = do
testAsyncJoiningOfflineBeforeActivation :: IO ()
testAsyncJoiningOfflineBeforeActivation = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, qInfo) <- createConnection alice SCMInvitation
aliceId <- joinConnection bob qInfo "bob's connInfo"
disconnectAgentClient bob
("", _, CONF confId "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
subscribeConnection bob' aliceId
get alice ##> ("", bobId, CON)
get bob' ##> ("", aliceId, INFO "alice's connInfo")
@@ -131,18 +136,18 @@ testAsyncJoiningOfflineBeforeActivation = do
testAsyncBothOffline :: IO ()
testAsyncBothOffline = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
(bobId, cReq) <- createConnection alice SCMInvitation
disconnectAgentClient alice
aliceId <- joinConnection bob cReq "bob's connInfo"
disconnectAgentClient bob
alice' <- liftIO $ getSMPAgentClient cfg
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
subscribeConnection alice' bobId
("", _, CONF confId "bob's connInfo") <- get alice'
allowConnection alice' bobId confId "alice's connInfo"
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
subscribeConnection bob' aliceId
get alice' ##> ("", bobId, CON)
get bob' ##> ("", aliceId, INFO "alice's connInfo")
@@ -152,8 +157,8 @@ testAsyncBothOffline = do
testAsyncServerOffline :: ATransport -> IO ()
testAsyncServerOffline t = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
-- create connection and shutdown the server
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
runExceptT $ createConnection alice SCMInvitation
@@ -176,8 +181,8 @@ testAsyncServerOffline t = do
testAsyncHelloTimeout :: IO ()
testAsyncHelloTimeout = do
alice <- getSMPAgentClient cfg
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1}
alice <- getSMPAgentClient cfg initAgentServers
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1} initAgentServers
Right () <- runExceptT $ do
(_, cReq) <- createConnection alice SCMInvitation
disconnectAgentClient alice
@@ -185,6 +190,13 @@ testAsyncHelloTimeout = do
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
pure ()
testNotificationToken :: IO ()
testNotificationToken = do
alice <- getSMPAgentClient cfg initAgentServers
Right () <- runExceptT $ do
registerNtfToken alice $ DeviceToken PPApple "abcd"
pure ()
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings alice bobId bob aliceId = do
5 <- sendMessage alice bobId "hello"
+106
View File
@@ -0,0 +1,106 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module NtfClient where
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Unlift
import Crypto.Random
import Data.ByteString.Char8 (ByteString)
import Network.Socket
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
import Simplex.Messaging.Notifications.Server.Env
import Simplex.Messaging.Notifications.Transport
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.KeepAlive
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
import UnliftIO.Timeout (timeout)
testHost :: HostName
testHost = "localhost"
testPort :: ServiceName
testPort = "6001"
testKeyHash :: C.KeyHash
testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
testNtfClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a
testNtfClient client =
runTransportClient testHost testPort testKeyHash (Just defaultKeepAliveOpts) $ \h ->
liftIO (runExceptT $ ntfClientHandshake h testKeyHash) >>= \case
Right th -> client th
Left e -> error $ show e
cfg :: NtfServerConfig
cfg =
NtfServerConfig
{ transports = undefined,
subIdBytes = 24,
regCodeBytes = 32,
clientQSize = 1,
subQSize = 1,
pushQSize = 1,
smpAgentCfg = defaultSMPClientAgentConfig,
-- CA certificate private key is not needed for initialization
caCertificateFile = "tests/fixtures/ca.crt",
privateKeyFile = "tests/fixtures/server.key",
certificateFile = "tests/fixtures/server.crt"
}
withNtfServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
withNtfServerThreadOn t port' =
serverBracket
(\started -> runNtfServerBlocking started cfg {transports = [(port', t)]})
(pure ())
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
serverBracket process afterProcess f = do
started <- newEmptyTMVarIO
E.bracket
(forkIOWithUnmask ($ process started))
(\t -> killThread t >> afterProcess >> waitFor started "stop")
(\t -> waitFor started "start" >> f t)
where
waitFor started s =
5_000_000 `timeout` atomically (takeTMVar started) >>= \case
Nothing -> error $ "server did not " <> s
_ -> pure ()
withNtfServerOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> m a -> m a
withNtfServerOn t port' = withNtfServerThreadOn t port' . const
withNtfServer :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
withNtfServer t = withNtfServerOn t testPort
runNtfTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (THandle c -> m a) -> m a
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
ntfServerTest ::
forall c smp.
(Transport c, Encoding smp) =>
TProxy c ->
(Maybe C.ASignature, ByteString, ByteString, smp) ->
IO (Maybe C.ASignature, ByteString, ByteString, BrokerMsg)
ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
where
tPut' h (sig, corrId, queueId, smp) = do
let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp)
Right () <- tPut h (sig, t')
pure ()
tGet' h = do
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
pure (Nothing, corrId, qId, cmd)
+40
View File
@@ -0,0 +1,40 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module NtfServerTests where
import Data.ByteString.Char8 (ByteString)
import NtfClient
import ServerTests (sampleDhPubKey, samplePubKey, sampleSig)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Protocol
import Simplex.Messaging.Transport
import Test.Hspec
ntfServerTests :: ATransport -> Spec
ntfServerTests t = do
describe "notifications server protocol syntax" $ ntfSyntaxTests t
ntfSyntaxTests :: ATransport -> Spec
ntfSyntaxTests (ATransport t) = do
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
describe "NEW" $ do
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX)
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX)
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH)
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH)
where
(>#>) ::
Encoding smp =>
(Maybe C.ASignature, ByteString, ByteString, smp) ->
(Maybe C.ASignature, ByteString, ByteString, BrokerMsg) ->
Expectation
command >#> response = ntfServerTest t command `shouldReturn` response
+11 -4
View File
@@ -154,11 +154,17 @@ smpAgentTest1_1_1 test' =
_test [h] = test' h
_test _ = error "expected 1 handle"
initAgentServers :: InitialAgentServers
initAgentServers =
InitialAgentServers
{ smp = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"],
ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"]
}
cfg :: AgentConfig
cfg =
defaultAgentConfig
{ tcpPort = agentTestPort,
initialSMPServers = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"],
tbqSize = 1,
dbFile = testDB,
smpCfg =
@@ -167,7 +173,7 @@ cfg =
defaultTransport = (testPort, transport @TLS),
tcpTimeout = 500_000
},
reconnectInterval = (reconnectInterval defaultAgentConfig) {initialInterval = 50_000},
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
caCertificateFile = "tests/fixtures/ca.crt",
privateKeyFile = "tests/fixtures/server.key",
certificateFile = "tests/fixtures/server.crt"
@@ -175,9 +181,10 @@ cfg =
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
let cfg' = cfg {tcpPort = port', dbFile = db', initialSMPServers = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
let cfg' = cfg {tcpPort = port', dbFile = db'}
initServers' = initAgentServers {smp = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
in serverBracket
(\started -> runSMPAgentBlocking t started cfg')
(\started -> runSMPAgentBlocking t started cfg' initServers')
afterProcess
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
+2
View File
@@ -4,6 +4,7 @@ import AgentTests (agentTests)
import CoreTests.EncodingTests
import CoreTests.ProtocolErrorTests
import CoreTests.VersionRangeTests
import NtfServerTests (ntfServerTests)
import ServerTests
import Simplex.Messaging.Transport (TLS, Transport (..))
import Simplex.Messaging.Transport.WebSockets (WS)
@@ -20,5 +21,6 @@ main = do
describe "Version range" versionRangeTests
describe "SMP server via TLS" $ serverTests (transport @TLS)
describe "SMP server via WebSockets" $ serverTests (transport @WS)
describe "Ntf server via TLS" $ ntfServerTests (transport @TLS)
describe "SMP client agent" $ agentTests (transport @TLS)
removeDirectoryRecursive "tests/tmp"