mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-04-27 12:55:16 +00:00
agent schema/methods/types/store methods for notifications tokens (#348)
* agent schema/methods/types/store methods for notifications tokens * register notification token on the server * agent commands for notification tokens * refactor initial servers from AgentConfig * agent store functions for notification tokens * server STM store methods for tokens * fix protocol client for ntfs (use generic handshake), minimal server and agent tests * server command to verify ntf token
This commit is contained in:
committed by
GitHub
parent
fb26916eea
commit
f577fcdacf
@@ -39,6 +39,7 @@ ntfServerCLIConfig =
|
||||
NtfServerConfig
|
||||
{ transports,
|
||||
subIdBytes = 24,
|
||||
regCodeBytes = 32,
|
||||
clientQSize = 16,
|
||||
subQSize = 64,
|
||||
pushQSize = 128,
|
||||
|
||||
@@ -11,7 +11,14 @@ import Simplex.Messaging.Agent.Server (runSMPAgent)
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
|
||||
cfg :: AgentConfig
|
||||
cfg = defaultAgentConfig {initialSMPServers = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"]}
|
||||
cfg = defaultAgentConfig
|
||||
|
||||
servers :: InitialAgentServers
|
||||
servers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
|
||||
ntf = []
|
||||
}
|
||||
|
||||
logCfg :: LogConfig
|
||||
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
@@ -20,4 +27,4 @@ main :: IO ()
|
||||
main = do
|
||||
putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig)
|
||||
setLogLevel LogInfo -- LogError
|
||||
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg
|
||||
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers
|
||||
|
||||
@@ -298,6 +298,8 @@ test-suite smp-server-test
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.VersionRangeTests
|
||||
NtfClient,
|
||||
NtfServerTests
|
||||
ServerTests
|
||||
SMPAgentClient
|
||||
SMPClient
|
||||
|
||||
+107
-10
@@ -48,6 +48,11 @@ module Simplex.Messaging.Agent
|
||||
suspendConnection,
|
||||
deleteConnection,
|
||||
setSMPServers,
|
||||
setNtfServers,
|
||||
registerNtfToken,
|
||||
verifyNtfToken,
|
||||
enableNtfCron,
|
||||
deleteNtfToken,
|
||||
logConnection,
|
||||
)
|
||||
where
|
||||
@@ -69,6 +74,7 @@ import Data.Maybe (isJust)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock
|
||||
import Data.Time.Clock.System (systemToUTCTime)
|
||||
import Data.Word (Word16)
|
||||
import Database.SQLite.Simple (SQLError)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
@@ -80,7 +86,8 @@ import Simplex.Messaging.Client (ServerTransmission)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken)
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfTknStatus (..))
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
@@ -93,11 +100,11 @@ import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
-- | Creates an SMP agent client instance
|
||||
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> m AgentClient
|
||||
getSMPAgentClient cfg = newSMPAgentEnv cfg >>= runReaderT runAgent
|
||||
getSMPAgentClient :: (MonadRandom m, MonadUnliftIO m) => AgentConfig -> InitialAgentServers -> m AgentClient
|
||||
getSMPAgentClient cfg initServers = newSMPAgentEnv cfg >>= runReaderT runAgent
|
||||
where
|
||||
runAgent = do
|
||||
c <- getAgentClient
|
||||
c <- getAgentClient initServers
|
||||
action <- async $ subscriber c `E.finally` disconnectAgentClient c
|
||||
pure c {smpSubscriber = action}
|
||||
|
||||
@@ -150,10 +157,24 @@ deleteConnection c = withAgentEnv c . deleteConnection' c
|
||||
setSMPServers :: AgentErrorMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
setSMPServers c = withAgentEnv c . setSMPServers' c
|
||||
|
||||
setNtfServers :: AgentErrorMonad m => AgentClient -> [NtfServer] -> m ()
|
||||
setNtfServers c = withAgentEnv c . setNtfServers' c
|
||||
|
||||
-- | Register device notifications token
|
||||
registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken c = withAgentEnv c . registerNtfToken' c
|
||||
|
||||
-- | Verify device notifications token
|
||||
verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
|
||||
verifyNtfToken c = withAgentEnv c .: verifyNtfToken' c
|
||||
|
||||
-- | Enable/disable periodic notifications
|
||||
enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
|
||||
enableNtfCron c = withAgentEnv c .: enableNtfCron' c
|
||||
|
||||
deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
|
||||
deleteNtfToken c = withAgentEnv c . deleteNtfToken' c
|
||||
|
||||
withAgentEnv :: AgentClient -> ReaderT Env m a -> m a
|
||||
withAgentEnv c = (`runReaderT` agentEnv c)
|
||||
|
||||
@@ -161,8 +182,8 @@ withAgentEnv c = (`runReaderT` agentEnv c)
|
||||
-- withAgentClient c = withAgentLock c . withAgentEnv c
|
||||
|
||||
-- | Creates an SMP agent client instance that receives commands and sends responses via 'TBQueue's.
|
||||
getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => m AgentClient
|
||||
getAgentClient = ask >>= atomically . newAgentClient
|
||||
getAgentClient :: (MonadUnliftIO m, MonadReader Env m) => InitialAgentServers -> m AgentClient
|
||||
getAgentClient initServers = ask >>= atomically . newAgentClient initServers
|
||||
|
||||
logConnection :: MonadUnliftIO m => AgentClient -> Bool -> m ()
|
||||
logConnection c connected =
|
||||
@@ -499,8 +520,73 @@ setSMPServers' :: AgentMonad m => AgentClient -> NonEmpty SMPServer -> m ()
|
||||
setSMPServers' c servers = do
|
||||
atomically $ writeTVar (smpServers c) servers
|
||||
|
||||
registerNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken' c token = pure ()
|
||||
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken' c deviceToken =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId, ntfTknAction} -> case (ntfTokenId, ntfTknAction) of
|
||||
(Nothing, Just (NTARegister ntfPubKey)) -> registerToken tkn ntfPubKey
|
||||
-- TODO request verification code again in case there is registration, but no verification code in DB - probably after some timeout?
|
||||
(Just tknId, Just (NTAVerify code)) ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
|
||||
(Just tknId, Just (NTACron interval)) ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfEnableCron c tknId tkn interval
|
||||
(Just _tknId, Just NTACheck) -> pure ()
|
||||
(Just tknId, Just NTADelete) ->
|
||||
t tkn (NTExpired, Nothing) $ agentNtfDeleteToken c tknId tkn
|
||||
_ -> pure ()
|
||||
_ ->
|
||||
getNtfServer c >>= \case
|
||||
Just ntfServer ->
|
||||
asks (cmdSignAlg . config) >>= \case
|
||||
C.SignAlg a -> do
|
||||
(ntfPubKey, ntfPrivKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
let tkn = newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey
|
||||
withStore $ \st -> createNtfToken st tkn
|
||||
registerToken tkn ntfPubKey
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
t tkn = withToken tkn Nothing
|
||||
registerToken :: NtfToken -> C.APublicVerifyKey -> m ()
|
||||
registerToken tkn ntfPubKey = do
|
||||
(pubDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
|
||||
let dhSecret = C.dh' srvPubDhKey privDhKey
|
||||
withStore $ \st -> updateNtfTokenRegistration st tkn tknId dhSecret
|
||||
|
||||
verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
|
||||
verifyNtfToken' c deviceToken code =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId} ->
|
||||
withToken tkn (Just (NTConfirmed, NTAVerify code)) (NTActive, Just NTACheck) $
|
||||
agentNtfVerifyToken c tknId tkn code
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
enableNtfCron' :: AgentMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
|
||||
enableNtfCron' c deviceToken interval =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive} ->
|
||||
withToken tkn (Just (NTActive, NTACron interval)) (NTActive, Just NTACheck) $
|
||||
agentNtfEnableCron c tknId tkn interval
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
deleteNtfToken' c deviceToken =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus} ->
|
||||
withToken tkn (Just (ntfTknStatus, NTADelete)) (NTExpired, Nothing) $
|
||||
agentNtfDeleteToken c tknId tkn
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a
|
||||
withToken tkn from_ (toStatus, toAction_) f = do
|
||||
forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action)
|
||||
res <- f
|
||||
withStore $ \st -> updateNtfToken st tkn toStatus toAction_
|
||||
pure res
|
||||
|
||||
setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m ()
|
||||
setNtfServers' c servers = do
|
||||
atomically $ writeTVar (ntfServers c) servers
|
||||
|
||||
getSMPServer :: AgentMonad m => AgentClient -> m SMPServer
|
||||
getSMPServer c = do
|
||||
@@ -509,8 +595,19 @@ getSMPServer c = do
|
||||
srv :| [] -> pure srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
i <- atomically . stateTVar gen $ randomR (0, L.length servers - 1)
|
||||
pure $ servers L.!! i
|
||||
atomically . stateTVar gen $
|
||||
first (servers L.!!) . randomR (0, L.length servers - 1)
|
||||
|
||||
getNtfServer :: AgentMonad m => AgentClient -> m (Maybe NtfServer)
|
||||
getNtfServer c = do
|
||||
ntfServers <- readTVarIO $ ntfServers c
|
||||
case ntfServers of
|
||||
[] -> pure Nothing
|
||||
[srv] -> pure $ Just srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
atomically . stateTVar gen $
|
||||
first (Just . (servers !!)) . randomR (0, length servers - 1)
|
||||
|
||||
subscriber :: (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
subscriber c@AgentClient {msgQ} = forever $ do
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Client
|
||||
( AgentClient (..),
|
||||
@@ -24,6 +23,10 @@ module Simplex.Messaging.Agent.Client
|
||||
RetryInterval (..),
|
||||
secureQueue,
|
||||
sendAgentMessage,
|
||||
agentNtfRegisterToken,
|
||||
agentNtfVerifyToken,
|
||||
agentNtfDeleteToken,
|
||||
agentNtfEnableCron,
|
||||
agentCbEncrypt,
|
||||
agentCbDecrypt,
|
||||
cryptoError,
|
||||
@@ -51,6 +54,7 @@ import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Text.Encoding
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
@@ -59,8 +63,8 @@ import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Client (NtfClient)
|
||||
import Simplex.Messaging.Notifications.Protocol (NtfResponse)
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
@@ -83,8 +87,9 @@ data AgentClient = AgentClient
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
smpServers :: TVar (NonEmpty SMPServer),
|
||||
ntfServers :: TVar [NtfServer],
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
ntfClients :: TMap ProtocolServer NtfClientVar,
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
subscrConns :: TMap ConnId SMPServer,
|
||||
@@ -99,13 +104,14 @@ data AgentClient = AgentClient
|
||||
lock :: TMVar ()
|
||||
}
|
||||
|
||||
newAgentClient :: Env -> STM AgentClient
|
||||
newAgentClient agentEnv = do
|
||||
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
|
||||
newAgentClient InitialAgentServers {smp, ntf} agentEnv = do
|
||||
let qSize = tbqSize $ config agentEnv
|
||||
rcvQ <- newTBQueue qSize
|
||||
subQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpServers <- newTVar $ initialSMPServers (config agentEnv)
|
||||
smpServers <- newTVar smp
|
||||
ntfServers <- newTVar ntf
|
||||
smpClients <- TM.empty
|
||||
ntfClients <- TM.empty
|
||||
subscrSrvrs <- TM.empty
|
||||
@@ -118,7 +124,7 @@ newAgentClient agentEnv = do
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
|
||||
lock <- newTMVar ()
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpServers, ntfServers, smpClients, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
|
||||
|
||||
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
@@ -206,7 +212,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
notifySub :: ACommand 'Agent -> ConnId -> IO ()
|
||||
notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> ProtocolServer -> m NtfClient
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {ntfClients} srv =
|
||||
atomically (getClientVar srv ntfClients)
|
||||
>>= either
|
||||
@@ -256,10 +262,10 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar clientVar r
|
||||
successAction smp
|
||||
successAction client
|
||||
Left e -> do
|
||||
if e == BROKER NETWORK || e == BROKER TIMEOUT
|
||||
then retryAction
|
||||
@@ -464,6 +470,22 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
|
||||
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
|
||||
withClient c ntfServer $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
|
||||
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
|
||||
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
|
||||
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do
|
||||
cmNonce <- liftIO C.randomCbNonce
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
|
||||
module Simplex.Messaging.Agent.Env.SQLite
|
||||
( AgentConfig (..),
|
||||
InitialAgentServers (..),
|
||||
defaultAgentConfig,
|
||||
defaultReconnectInterval,
|
||||
Env (..),
|
||||
newSMPAgentEnv,
|
||||
)
|
||||
@@ -25,13 +27,18 @@ import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Client (NtfServer)
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO.STM
|
||||
|
||||
data InitialAgentServers = InitialAgentServers
|
||||
{ smp :: NonEmpty SMPServer,
|
||||
ntf :: [NtfServer]
|
||||
}
|
||||
|
||||
data AgentConfig = AgentConfig
|
||||
{ tcpPort :: ServiceName,
|
||||
initialSMPServers :: NonEmpty SMPServer,
|
||||
cmdSignAlg :: C.SignAlg,
|
||||
connIdBytes :: Int,
|
||||
tbqSize :: Natural,
|
||||
@@ -47,11 +54,20 @@ data AgentConfig = AgentConfig
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
|
||||
defaultReconnectInterval :: RetryInterval
|
||||
defaultReconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
}
|
||||
where
|
||||
second = 1_000_000
|
||||
|
||||
defaultAgentConfig :: AgentConfig
|
||||
defaultAgentConfig =
|
||||
AgentConfig
|
||||
{ tcpPort = "5224",
|
||||
initialSMPServers = undefined, -- TODO move it elsewhere?
|
||||
cmdSignAlg = C.SignAlg C.SEd448,
|
||||
connIdBytes = 12,
|
||||
tbqSize = 64,
|
||||
@@ -60,12 +76,7 @@ defaultAgentConfig =
|
||||
yesToMigrations = False,
|
||||
smpCfg = defaultClientConfig {defaultTransport = ("5223", transport @TLS)},
|
||||
ntfCfg = defaultClientConfig {defaultTransport = ("443", transport @TLS)},
|
||||
reconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
},
|
||||
reconnectInterval = defaultReconnectInterval,
|
||||
helloTimeout = 2 * nominalDay,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
-- ! we do not generate these
|
||||
@@ -73,8 +84,6 @@ defaultAgentConfig =
|
||||
privateKeyFile = "/etc/opt/simplex-agent/agent.key",
|
||||
certificateFile = "/etc/opt/simplex-agent/agent.crt"
|
||||
}
|
||||
where
|
||||
second = 1_000_000
|
||||
|
||||
data Env = Env
|
||||
{ config :: AgentConfig,
|
||||
|
||||
@@ -31,17 +31,17 @@ import UnliftIO.STM
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent t cfg = do
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> m ()
|
||||
runSMPAgent t cfg initServers = do
|
||||
started <- newEmptyTMVarIO
|
||||
runSMPAgentBlocking t started cfg
|
||||
runSMPAgentBlocking t started cfg initServers
|
||||
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration with signalling.
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> InitialAgentServers -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers = do
|
||||
runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
@@ -50,7 +50,7 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertifica
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
c <- getAgentClient
|
||||
c <- getAgentClient initServers
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runAgentClient c)
|
||||
`E.finally` disconnectAgentClient c
|
||||
|
||||
@@ -20,6 +20,8 @@ import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff, SkippedMsgKeys)
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfTknStatus, NtfTokenId)
|
||||
import Simplex.Messaging.Protocol
|
||||
( MsgBody,
|
||||
MsgId,
|
||||
@@ -75,6 +77,12 @@ class Monad m => MonadAgentStore s m where
|
||||
getSkippedMsgKeys :: s -> ConnId -> m SkippedMsgKeys
|
||||
updateRatchet :: s -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m ()
|
||||
|
||||
-- Notification device token persistence
|
||||
createNtfToken :: s -> NtfToken -> m ()
|
||||
getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken) -- return current token if it exists and mark any old tokens for deletion
|
||||
updateNtfTokenRegistration :: s -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
|
||||
updateNtfToken :: s -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m ()
|
||||
|
||||
-- * Queue types
|
||||
|
||||
-- | A receive queue. SMP queue through which the agent receives messages from a sender.
|
||||
|
||||
@@ -23,7 +23,6 @@ module Simplex.Messaging.Agent.Store.SQLite
|
||||
connectSQLiteStore,
|
||||
withConnection,
|
||||
withTransaction,
|
||||
fromTextField_,
|
||||
firstRow,
|
||||
)
|
||||
where
|
||||
@@ -41,14 +40,14 @@ import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (find, foldl')
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (listToMaybe)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeLatin1)
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLData (..), SQLError, ToRow, field)
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Database.SQLite.Simple (FromRow, NamedParam (..), Only (..), SQLError, ToRow, field)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.Internal (Field (..))
|
||||
import Database.SQLite.Simple.Ok (Ok (Ok))
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
@@ -59,7 +58,9 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (blobFieldParser)
|
||||
import Simplex.Messaging.Notifications.Client (NtfServer, NtfTknAction, NtfToken (..))
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfTknStatus (..), NtfTokenId)
|
||||
import Simplex.Messaging.Parsers (blobFieldParser, fromTextField_)
|
||||
import Simplex.Messaging.Protocol (MsgBody, ProtocolServer (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Util (bshow, liftIOEither)
|
||||
@@ -565,6 +566,63 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
forM_ (M.assocs mks) $ \(msgN, mk) ->
|
||||
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
||||
|
||||
createNtfToken :: SQLiteStore -> NtfToken -> m ()
|
||||
createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} =
|
||||
liftIO . withTransaction st $ \db -> do
|
||||
upsertNtfServer_ db srv
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO ntf_tokens
|
||||
(provider, device_token, ntf_host, ntf_port, tkn_id, tkn_priv_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
|]
|
||||
(provider, token, host, port, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction)
|
||||
|
||||
getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken)
|
||||
getDeviceNtfToken st t@(DeviceToken provider token) =
|
||||
liftIO . withTransaction st $ \db ->
|
||||
fmap ntfToken . listToMaybe
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
|
||||
t.tkn_id, t.tkn_priv_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action
|
||||
FROM ntf_tokens t
|
||||
JOIN ntf_servers s USING (ntf_host, ntf_port)
|
||||
WHERE t.provider = ? AND t.device_token = ?
|
||||
|]
|
||||
(provider, token)
|
||||
where
|
||||
ntfToken (host, port, keyHash, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) =
|
||||
let ntfServer = ProtocolServer {host, port, keyHash}
|
||||
in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction}
|
||||
|
||||
updateNtfTokenRegistration :: SQLiteStore -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
|
||||
updateNtfTokenRegistration st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret =
|
||||
liftIO . withTransaction st $ \db -> do
|
||||
updatedAt <- getCurrentTime
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE ntf_tokens
|
||||
SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ?
|
||||
WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
|
||||
|]
|
||||
(tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port)
|
||||
|
||||
updateNtfToken :: SQLiteStore -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m ()
|
||||
updateNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction =
|
||||
liftIO . withTransaction st $ \db -> do
|
||||
updatedAt <- getCurrentTime
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
UPDATE ntf_tokens
|
||||
SET tkn_status = ?, tkn_action = ?, updated_at = ?
|
||||
WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
|
||||
|]
|
||||
(tknStatus, tknAction, updatedAt, provider, token, host, port)
|
||||
|
||||
-- * Auxiliary helpers
|
||||
|
||||
instance ToField QueueStatus where toField = toField . serializeQueueStatus
|
||||
@@ -611,14 +669,6 @@ instance ToField (SConnectionMode c) where toField = toField . connMode
|
||||
|
||||
instance FromField AConnectionMode where fromField = fromTextField_ $ fmap connMode' . connModeT
|
||||
|
||||
fromTextField_ :: (E.Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
case fromText t of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
listToEither :: e -> [a] -> Either e a
|
||||
listToEither _ (x : _) = Right x
|
||||
listToEither e _ = Left e
|
||||
@@ -669,6 +719,19 @@ upsertServer_ dbConn ProtocolServer {host, port, keyHash} = do
|
||||
|]
|
||||
[":host" := host, ":port" := port, ":key_hash" := keyHash]
|
||||
|
||||
upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
|
||||
upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
|
||||
DB.executeNamed
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash)
|
||||
ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET
|
||||
ntf_host=excluded.ntf_host,
|
||||
ntf_port=excluded.ntf_port,
|
||||
ntf_key_hash=excluded.ntf_key_hash;
|
||||
|]
|
||||
[":host" := host, ":port" := port, ":key_hash" := keyHash]
|
||||
|
||||
-- * createRcvConn helpers
|
||||
|
||||
insertRcvQueue_ :: DB.Connection -> ConnId -> RcvQueue -> IO ()
|
||||
|
||||
@@ -8,40 +8,29 @@ import Database.SQLite.Simple.QQ (sql)
|
||||
m20220322_notifications :: Query
|
||||
m20220322_notifications =
|
||||
[sql|
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB;
|
||||
|
||||
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id);
|
||||
|
||||
CREATE TABLE ntf_servers (
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_key_hash BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (ntf_host, ntf_port)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE ntf_subscriptions (
|
||||
CREATE TABLE ntf_tokens (
|
||||
provider TEXT NOT NULL, -- apn
|
||||
device_token TEXT NOT NULL,
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_sub_id BLOB NOT NULL,
|
||||
ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth
|
||||
ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete
|
||||
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
|
||||
ntf_token TEXT NOT NULL, -- or BLOB?
|
||||
smp_host TEXT NOT NULL,
|
||||
smp_port TEXT NOT NULL,
|
||||
smp_ntf_id BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked
|
||||
PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id),
|
||||
tkn_id BLOB, -- token ID assigned by notifications server
|
||||
tkn_priv_key BLOB NOT NULL, -- private key to sign token commands
|
||||
tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications
|
||||
tkn_status TEXT NOT NULL,
|
||||
tkn_action BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')), -- this is to check token status periodically to know when it was last checked
|
||||
PRIMARY KEY (provider, device_token, ntf_host, ntf_port),
|
||||
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id)
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
|]
|
||||
|
||||
+38
@@ -0,0 +1,38 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220404_ntf_subscriptions_draft where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220404_ntf_subscriptions_draft :: Query
|
||||
m20220404_ntf_subscriptions_draft =
|
||||
[sql|
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB;
|
||||
|
||||
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id);
|
||||
|
||||
CREATE TABLE ntf_subscriptions (
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_sub_id BLOB NOT NULL,
|
||||
ntf_sub_status TEXT NOT NULL, -- new, created, active, pending, error_auth
|
||||
ntf_sub_action TEXT, -- if there is an action required on this subscription: create / check / token / delete
|
||||
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
|
||||
ntf_token TEXT NOT NULL, -- or BLOB?
|
||||
smp_host TEXT NOT NULL,
|
||||
smp_port TEXT NOT NULL,
|
||||
smp_ntf_id BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL, -- this is to check subscription status periodically to know when it was last checked
|
||||
PRIMARY KEY (ntf_host, ntf_port, ntf_sub_id),
|
||||
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
FOREIGN KEY (smp_host, smp_port, smp_ntf_id) REFERENCES rcv_queues (host, port, ntf_id)
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
|]
|
||||
@@ -66,7 +66,7 @@ import Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TLS, TProxy, Transport (..), TransportError)
|
||||
import Simplex.Messaging.Transport.Client (runTransportClient, smpClientHandshake)
|
||||
import Simplex.Messaging.Transport.Client (runTransportClient)
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
|
||||
@@ -132,11 +132,11 @@ type Response msg = Either ProtocolClientError msg
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtocolServer -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected =
|
||||
(atomically mkSMPClient >>= runClient useTransport)
|
||||
(atomically mkProtocolClient >>= runClient useTransport)
|
||||
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
|
||||
where
|
||||
mkSMPClient :: STM (ProtocolClient msg)
|
||||
mkSMPClient = do
|
||||
mkProtocolClient :: STM (ProtocolClient msg)
|
||||
mkProtocolClient = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- TM.empty
|
||||
@@ -177,7 +177,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO ()
|
||||
client _ c thVar h =
|
||||
runExceptT (smpClientHandshake h $ keyHash protocolServer) >>= \case
|
||||
runExceptT (protocolClientHandshake @msg h $ keyHash protocolServer) >>= \case
|
||||
Left e -> atomically . putTMVar thVar . Left $ PCETransportError e
|
||||
Right th@THandle {sessionId} -> do
|
||||
atomically $ do
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Encoding.String
|
||||
( StrEncoding (..),
|
||||
( TextEncoding (..),
|
||||
StrEncoding (..),
|
||||
Str (..),
|
||||
strP_,
|
||||
strToJSON,
|
||||
@@ -26,11 +27,16 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Text (Text)
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
class TextEncoding a where
|
||||
textEncode :: a -> Text
|
||||
textDecode :: Text -> Maybe a
|
||||
|
||||
-- | Serializing human-readable and (where possible) URI-friendly strings for SMP and SMP agent protocols
|
||||
class StrEncoding a where
|
||||
{-# MINIMAL strEncode, (strDecode | strP) #-}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
@@ -7,42 +9,50 @@ module Simplex.Messaging.Notifications.Client where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Word (Word16)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer)
|
||||
|
||||
type NtfServer = ProtocolServer
|
||||
|
||||
type NtfClient = ProtocolClient NtfResponse
|
||||
|
||||
registerNtfToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
registerNtfToken c pKey newTkn =
|
||||
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
ntfRegisterToken c pKey newTkn =
|
||||
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
|
||||
NRId tknId dhKey -> pure (tknId, dhKey)
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
verifyNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegistrationCode -> ExceptT ProtocolClientError IO ()
|
||||
verifyNtfToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
|
||||
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
|
||||
deleteNtfToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
deleteNtfToken = okNtfCommand TDEL
|
||||
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteToken = okNtfCommand TDEL
|
||||
|
||||
enableNtfCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
|
||||
enableNtfCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
|
||||
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
|
||||
ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
|
||||
|
||||
createNtfSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519)
|
||||
createNtfSubsciption c pKey newSub =
|
||||
ntfCreateSubsciption :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO (NtfSubscriptionId, C.PublicKeyX25519)
|
||||
ntfCreateSubsciption c pKey newSub =
|
||||
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
|
||||
NRId tknId dhKey -> pure (tknId, dhKey)
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
checkNtfSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
checkNtfSubscription c pKey subId =
|
||||
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
ntfCheckSubscription c pKey subId =
|
||||
sendNtfCommand c (Just pKey) subId SCHK >>= \case
|
||||
NRStat stat -> pure stat
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
deleteNfgSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
deleteNfgSubscription = okNtfCommand SDEL
|
||||
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteSubscription = okNtfCommand SDEL
|
||||
|
||||
-- | Send notification server command
|
||||
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
|
||||
@@ -53,3 +63,58 @@ okNtfCommand cmd c pKey entId =
|
||||
sendNtfCommand c (Just pKey) entId cmd >>= \case
|
||||
NROk -> return ()
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
data NtfTknAction
|
||||
= NTARegister C.APublicVerifyKey -- public key to send to the server
|
||||
| NTAVerify NtfRegCode -- code to verify token
|
||||
| NTACheck
|
||||
| NTACron Word16
|
||||
| NTADelete
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfTknAction where
|
||||
smpEncode = \case
|
||||
NTARegister key -> smpEncode ('R', key)
|
||||
NTAVerify code -> smpEncode ('V', code)
|
||||
NTACheck -> "C"
|
||||
NTACron interval -> smpEncode ('I', interval)
|
||||
NTADelete -> "D"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'R' -> NTARegister <$> smpP
|
||||
'V' -> NTAVerify <$> smpP
|
||||
'C' -> pure NTACheck
|
||||
'I' -> NTACron <$> smpP
|
||||
'D' -> pure NTADelete
|
||||
_ -> fail "bad NtfTknAction"
|
||||
|
||||
instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode
|
||||
|
||||
instance ToField NtfTknAction where toField = toField . smpEncode
|
||||
|
||||
data NtfToken = NtfToken
|
||||
{ deviceToken :: DeviceToken,
|
||||
ntfServer :: NtfServer,
|
||||
ntfTokenId :: Maybe NtfTokenId,
|
||||
-- | key used by the ntf client to sign transmissions
|
||||
ntfPrivKey :: C.APrivateSignKey,
|
||||
-- | shared DH secret used to encrypt/decrypt notifications e2e
|
||||
ntfDhSecret :: Maybe C.DhSecretX25519,
|
||||
-- | token status
|
||||
ntfTknStatus :: NtfTknStatus,
|
||||
-- | pending token action and the earliest time
|
||||
ntfTknAction :: Maybe NtfTknAction
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newNtfToken :: DeviceToken -> NtfServer -> C.APrivateSignKey -> C.APublicVerifyKey -> NtfToken
|
||||
newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey =
|
||||
NtfToken
|
||||
{ deviceToken,
|
||||
ntfServer,
|
||||
ntfTokenId = Nothing,
|
||||
ntfPrivKey,
|
||||
ntfDhSecret = Nothing,
|
||||
ntfTknStatus = NTNew,
|
||||
ntfTknAction = Just $ NTARegister ntfPubKey
|
||||
}
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Protocol where
|
||||
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
@@ -17,8 +18,13 @@ import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
|
||||
import Simplex.Messaging.Parsers (fromTextField_)
|
||||
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
@@ -87,12 +93,31 @@ instance ProtocolMsgTag NtfCmdTag where
|
||||
instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where
|
||||
decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t)
|
||||
|
||||
type NtfRegistrationCode = ByteString
|
||||
newtype NtfRegCode = NtfRegCode ByteString
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding NtfRegCode where
|
||||
smpEncode (NtfRegCode code) = smpEncode code
|
||||
smpP = NtfRegCode <$> smpP
|
||||
|
||||
instance StrEncoding NtfRegCode where
|
||||
strEncode (NtfRegCode m) = strEncode m
|
||||
strDecode s = NtfRegCode <$> strDecode s
|
||||
strP = NtfRegCode <$> strP
|
||||
|
||||
instance FromJSON NtfRegCode where
|
||||
parseJSON = strParseJSON "NtfRegCode"
|
||||
|
||||
instance ToJSON NtfRegCode where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
data NewNtfEntity (e :: NtfEntity) where
|
||||
NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token
|
||||
NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NewNtfEntity 'Subscription
|
||||
|
||||
deriving instance Show (NewNtfEntity e)
|
||||
|
||||
data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e)
|
||||
|
||||
instance NtfEntityI e => Encoding (NewNtfEntity e) where
|
||||
@@ -111,6 +136,7 @@ instance Encoding ANewNtfEntity where
|
||||
|
||||
instance Protocol NtfResponse where
|
||||
type ProtocolCommand NtfResponse = NtfCmd
|
||||
protocolClientHandshake = ntfClientHandshake
|
||||
protocolPing = NtfCmd SSubscription PING
|
||||
protocolError = \case
|
||||
NRErr e -> Just e
|
||||
@@ -120,7 +146,7 @@ data NtfCommand (e :: NtfEntity) where
|
||||
-- | register new device token for notifications
|
||||
TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token
|
||||
-- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token
|
||||
TVFY :: NtfRegistrationCode -> NtfCommand 'Token
|
||||
TVFY :: NtfRegCode -> NtfCommand 'Token
|
||||
-- | delete token - all subscriptions will be removed and no more notifications will be sent
|
||||
TDEL :: NtfCommand 'Token
|
||||
-- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable
|
||||
@@ -134,8 +160,12 @@ data NtfCommand (e :: NtfEntity) where
|
||||
-- | keep-alive command
|
||||
PING :: NtfCommand 'Subscription
|
||||
|
||||
deriving instance Show (NtfCommand e)
|
||||
|
||||
data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
|
||||
|
||||
deriving instance Show NtfCmd
|
||||
|
||||
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
type Tag (NtfCommand e) = NtfCommandTag e
|
||||
encodeProtocol = \case
|
||||
@@ -263,6 +293,7 @@ data SMPQueueNtf = SMPQueueNtf
|
||||
notifierId :: NotifierId,
|
||||
notifierKey :: NtfPrivateSignKey
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding SMPQueueNtf where
|
||||
smpEncode SMPQueueNtf {smpServer, notifierId, notifierKey} = smpEncode (smpServer, notifierId, notifierKey)
|
||||
@@ -270,17 +301,30 @@ instance Encoding SMPQueueNtf where
|
||||
(smpServer, notifierId, notifierKey) <- smpP
|
||||
pure $ SMPQueueNtf smpServer notifierId notifierKey
|
||||
|
||||
data PushPlatform = PPApple
|
||||
data PushProvider = PPApple
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding PushPlatform where
|
||||
instance Encoding PushProvider where
|
||||
smpEncode = \case
|
||||
PPApple -> "A"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'A' -> pure PPApple
|
||||
_ -> fail "bad PushPlatform"
|
||||
_ -> fail "bad PushProvider"
|
||||
|
||||
data DeviceToken = DeviceToken PushPlatform ByteString
|
||||
instance TextEncoding PushProvider where
|
||||
textEncode = \case
|
||||
PPApple -> "apple"
|
||||
textDecode = \case
|
||||
"apple" -> Just PPApple
|
||||
_ -> Nothing
|
||||
|
||||
instance FromField PushProvider where fromField = fromTextField_ textDecode
|
||||
|
||||
instance ToField PushProvider where toField = toField . textEncode
|
||||
|
||||
data DeviceToken = DeviceToken PushProvider ByteString
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding DeviceToken where
|
||||
smpEncode (DeviceToken p t) = smpEncode (p, t)
|
||||
@@ -322,17 +366,40 @@ instance Encoding NtfSubStatus where
|
||||
_ -> fail "bad NtfError"
|
||||
|
||||
data NtfTknStatus
|
||||
= -- | state after registration (TNEW)
|
||||
= -- | Token created in DB
|
||||
NTNew
|
||||
| -- | if initial notification or verification failed (push provider error)
|
||||
| -- | state after registration (TNEW)
|
||||
NTRegistered
|
||||
| -- | if initial notification failed (push provider error) or verification failed
|
||||
NTInvalid
|
||||
| -- | if initial notification succeeded
|
||||
| -- | Token confirmed via notification (accepted by push provider or verification code received by client)
|
||||
NTConfirmed
|
||||
| -- | after successful verification (TVFY)
|
||||
NTActive
|
||||
| -- | after it is no longer valid (push provider error)
|
||||
NTExpired
|
||||
deriving (Eq)
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance TextEncoding NtfTknStatus where
|
||||
textEncode = \case
|
||||
NTNew -> "new"
|
||||
NTRegistered -> "registered"
|
||||
NTInvalid -> "invalid"
|
||||
NTConfirmed -> "confirmed"
|
||||
NTActive -> "active"
|
||||
NTExpired -> "expired"
|
||||
textDecode = \case
|
||||
"new" -> Just NTNew
|
||||
"registered" -> Just NTRegistered
|
||||
"invalid" -> Just NTInvalid
|
||||
"confirmed" -> Just NTConfirmed
|
||||
"active" -> Just NTActive
|
||||
"expired" -> Just NTExpired
|
||||
_ -> Nothing
|
||||
|
||||
instance FromField NtfTknStatus where fromField = fromTextField_ textDecode
|
||||
|
||||
instance ToField NtfTknStatus where toField = toField . textEncode
|
||||
|
||||
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
|
||||
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
|
||||
|
||||
@@ -14,8 +14,8 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import qualified Data.Aeson as J
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Functor (($>))
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -106,7 +106,12 @@ ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}}
|
||||
ntfPush :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m ()
|
||||
ntfPush NtfPushServer {pushQ} = forever $ do
|
||||
atomically (readTBQueue pushQ) >>= \case
|
||||
(NtfTknData {}, Notification {}) -> pure ()
|
||||
(NtfTknData {tknStatus}, notification) -> do
|
||||
liftIO $ print $ J.encode notification
|
||||
-- TODO status update should happen after the token status successfully sent
|
||||
case notification of
|
||||
PNVerification _ -> atomically $ writeTVar tknStatus NTConfirmed
|
||||
_ -> pure ()
|
||||
|
||||
runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
|
||||
runNtfClientTransport th@THandle {sessionId} = do
|
||||
@@ -122,7 +127,7 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
|
||||
receive th NtfServerClient {rcvQ, sndQ} = forever $ do
|
||||
t@(sig, signed, (corrId, subId, cmdOrError)) <- tGet th
|
||||
t@(_, _, (corrId, subId, cmdOrError)) <- tGet th
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, subId, NRErr e)
|
||||
Right cmd ->
|
||||
@@ -144,7 +149,7 @@ verifyNtfTransmission ::
|
||||
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
case cmd of
|
||||
NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) ->
|
||||
-- TODO check that token is not already in store
|
||||
-- TODO check that token is not already in store, if it is - verify that the saved key is the same
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed k
|
||||
then VRVerified (NtfReqNew corrId (ANE SToken n))
|
||||
@@ -189,16 +194,21 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} =
|
||||
(srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair'
|
||||
let dhSecret = C.dh' dhPubKey srvDrivDhKey
|
||||
tknId <- getId
|
||||
regCode <- getRegCode
|
||||
atomically $ do
|
||||
tkn <- mkNtfTknData newTkn dhSecret
|
||||
tkn <- mkNtfTknData newTkn dhSecret regCode
|
||||
addNtfToken st tknId tkn
|
||||
writeTBQueue pushQ (tkn, Notification)
|
||||
-- pure (corrId, sId, NRSubId pubDhKey)
|
||||
writeTBQueue pushQ (tkn, PNVerification regCode)
|
||||
pure (corrId, "", NRId tknId srvDhPubKey)
|
||||
NtfReqCmd SToken tkn (corrId, tknId, cmd) ->
|
||||
NtfReqCmd SToken (NtfTkn NtfTknData {tknStatus, tknRegCode}) (corrId, tknId, cmd) -> do
|
||||
status <- readTVarIO tknStatus
|
||||
(corrId,tknId,) <$> case cmd of
|
||||
TNEW newTkn -> pure NROk -- TODO when duplicate token sent
|
||||
TVFY code -> pure NROk
|
||||
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
|
||||
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
|
||||
atomically $ writeTVar tknStatus NTActive
|
||||
pure NROk
|
||||
| otherwise -> pure $ NRErr AUTH
|
||||
TDEL -> pure NROk
|
||||
TCRN int -> pure NROk
|
||||
NtfReqNew corrId (ANE SSubscription newSub) -> pure (corrId, "", NROk)
|
||||
@@ -209,8 +219,11 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ} NtfPushServer {pushQ} =
|
||||
SDEL -> pure NROk
|
||||
PING -> pure NRPong
|
||||
getId :: m NtfEntityId
|
||||
getId = do
|
||||
n <- asks $ subIdBytes . config
|
||||
getId = getRandomBytes =<< asks (subIdBytes . config)
|
||||
getRegCode :: m NtfRegCode
|
||||
getRegCode = NtfRegCode <$> (getRandomBytes =<< asks (regCodeBytes . config))
|
||||
getRandomBytes :: Int -> m ByteString
|
||||
getRandomBytes n = do
|
||||
gVar <- asks idsDrg
|
||||
atomically (randomBytes n gVar)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
@@ -8,8 +9,11 @@ module Simplex.Messaging.Notifications.Server.Env where
|
||||
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.Aeson (FromJSON, ToJSON)
|
||||
import qualified Data.Aeson as J
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import GHC.Generics
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural
|
||||
@@ -17,6 +21,7 @@ import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Subscriptions
|
||||
import Simplex.Messaging.Parsers (dropPrefix, taggedObjectJSON)
|
||||
import Simplex.Messaging.Protocol (CorrId, Transmission)
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
@@ -25,6 +30,7 @@ import UnliftIO.STM
|
||||
data NtfServerConfig = NtfServerConfig
|
||||
{ transports :: [(ServiceName, ATransport)],
|
||||
subIdBytes :: Int,
|
||||
regCodeBytes :: Int,
|
||||
clientQSize :: Natural,
|
||||
subQSize :: Natural,
|
||||
pushQSize :: Natural,
|
||||
@@ -35,7 +41,15 @@ data NtfServerConfig = NtfServerConfig
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
|
||||
data Notification = Notification
|
||||
data PushNotification = PNVerification {code :: NtfRegCode} | PNPeriodic
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance FromJSON PushNotification where
|
||||
parseJSON = J.genericParseJSON . taggedObjectJSON $ dropPrefix "PN"
|
||||
|
||||
instance ToJSON PushNotification where
|
||||
toJSON = J.genericToJSON . taggedObjectJSON $ dropPrefix "PN"
|
||||
toEncoding = J.genericToEncoding . taggedObjectJSON $ dropPrefix "PN"
|
||||
|
||||
data NtfEnv = NtfEnv
|
||||
{ config :: NtfServerConfig,
|
||||
@@ -70,7 +84,7 @@ newNtfSubscriber qSize smpAgentCfg = do
|
||||
pure NtfSubscriber {smpAgent, subQ}
|
||||
|
||||
newtype NtfPushServer = NtfPushServer
|
||||
{ pushQ :: TBQueue (NtfTknData, Notification)
|
||||
{ pushQ :: TBQueue (NtfTknData, PushNotification)
|
||||
}
|
||||
|
||||
newNtfPushServer :: Natural -> STM NtfPushServer
|
||||
|
||||
@@ -36,15 +36,16 @@ data NtfTknData = NtfTknData
|
||||
{ token :: DeviceToken,
|
||||
tknStatus :: TVar NtfTknStatus,
|
||||
tknVerifyKey :: C.APublicVerifyKey,
|
||||
tknDhSecret :: C.DhSecretX25519
|
||||
tknDhSecret :: C.DhSecretX25519,
|
||||
tknRegCode :: NtfRegCode
|
||||
}
|
||||
|
||||
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> STM NtfTknData
|
||||
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret = do
|
||||
tknStatus <- newTVar NTNew
|
||||
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret}
|
||||
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
|
||||
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret tknRegCode = do
|
||||
tknStatus <- newTVar NTRegistered
|
||||
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret, tknRegCode}
|
||||
|
||||
data NtfSubscriptionsStore = NtfSubscriptionsStore
|
||||
-- data NtfSubscriptionsStore = NtfSubscriptionsStore
|
||||
|
||||
-- { subscriptions :: TMap NtfSubsciptionId NtfSubsciption,
|
||||
-- activeSubscriptions :: TMap (SMPServer, NotifierId) NtfSubsciptionId
|
||||
@@ -70,7 +71,9 @@ getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token))
|
||||
getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st)
|
||||
|
||||
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
|
||||
addNtfToken st tknId tkn = pure ()
|
||||
addNtfToken st tknId tkn@NtfTknData {token} = do
|
||||
TM.insert tknId tkn (tokens st)
|
||||
TM.insert token tknId (tokenIds st)
|
||||
|
||||
-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e))
|
||||
-- getNtfRec st ent entId = case ent of
|
||||
|
||||
@@ -13,6 +13,8 @@ import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum, toLower)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.ISO8601 (parseISO8601)
|
||||
import Data.Typeable (Typeable)
|
||||
@@ -81,6 +83,14 @@ blobFieldDecoder dec = \case
|
||||
Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
|
||||
f -> returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
case fromText t of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
fstToLower :: String -> String
|
||||
fstToLower "" = ""
|
||||
fstToLower (h : t) = toLower h : t
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
@@ -115,7 +116,7 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), tGetBlock, tPutBlock)
|
||||
import Simplex.Messaging.Transport (THandle (..), Transport, TransportError (..), smpClientHandshake, tGetBlock, tPutBlock)
|
||||
import Simplex.Messaging.Util (bshow, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
@@ -184,6 +185,7 @@ data RawTransmission = RawTransmission
|
||||
entityId :: ByteString,
|
||||
command :: ByteString
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
-- | unparsed sent SMP transmission with signature, without session ID.
|
||||
type SignedRawTransmission = (Maybe C.ASignature, ByteString, ByteString, ByteString)
|
||||
@@ -554,11 +556,13 @@ transmissionP = do
|
||||
|
||||
class (ProtocolEncoding msg, ProtocolEncoding (ProtocolCommand msg)) => Protocol msg where
|
||||
type ProtocolCommand msg = cmd | cmd -> msg
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> ExceptT TransportError IO (THandle c)
|
||||
protocolPing :: ProtocolCommand msg
|
||||
protocolError :: msg -> Maybe ErrorType
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
type ProtocolCommand BrokerMsg = Cmd
|
||||
protocolClientHandshake = smpClientHandshake
|
||||
protocolPing = Cmd SSender PING
|
||||
protocolError = \case
|
||||
ERR e -> Just e
|
||||
|
||||
@@ -9,11 +9,13 @@ module AgentTests.FunctionalAPITests (functionalAPITests) where
|
||||
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import NtfClient (withNtfServer)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (DeviceToken), PushProvider (PPApple))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Transport (ATransport (..))
|
||||
import System.Timeout
|
||||
@@ -48,11 +50,14 @@ functionalAPITests t = do
|
||||
testAsyncServerOffline t
|
||||
it "should notify after HELLO timeout" $
|
||||
withSmpServer t testAsyncHelloTimeout
|
||||
describe "Notification server" $ do
|
||||
it "should register device token" $
|
||||
withNtfServer t testNotificationToken
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
@@ -95,13 +100,13 @@ testAgentClient = do
|
||||
|
||||
testAsyncInitiatingOffline :: IO ()
|
||||
testAsyncInitiatingOffline = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
alice' <- liftIO $ getSMPAgentClient cfg
|
||||
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
@@ -113,15 +118,15 @@ testAsyncInitiatingOffline = do
|
||||
|
||||
testAsyncJoiningOfflineBeforeActivation :: IO ()
|
||||
testAsyncJoiningOfflineBeforeActivation = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -131,18 +136,18 @@ testAsyncJoiningOfflineBeforeActivation = do
|
||||
|
||||
testAsyncBothOffline :: IO ()
|
||||
testAsyncBothOffline = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient cfg
|
||||
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice' ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
@@ -152,8 +157,8 @@ testAsyncBothOffline = do
|
||||
|
||||
testAsyncServerOffline :: ATransport -> IO ()
|
||||
testAsyncServerOffline t = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
-- create connection and shutdown the server
|
||||
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
|
||||
runExceptT $ createConnection alice SCMInvitation
|
||||
@@ -176,8 +181,8 @@ testAsyncServerOffline t = do
|
||||
|
||||
testAsyncHelloTimeout :: IO ()
|
||||
testAsyncHelloTimeout = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1}
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
@@ -185,6 +190,13 @@ testAsyncHelloTimeout = do
|
||||
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
|
||||
pure ()
|
||||
|
||||
testNotificationToken :: IO ()
|
||||
testNotificationToken = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
registerNtfToken alice $ DeviceToken PPApple "abcd"
|
||||
pure ()
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings alice bobId bob aliceId = do
|
||||
5 <- sendMessage alice bobId "hello"
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module NtfClient where
|
||||
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
|
||||
import UnliftIO.Timeout (timeout)
|
||||
|
||||
testHost :: HostName
|
||||
testHost = "localhost"
|
||||
|
||||
testPort :: ServiceName
|
||||
testPort = "6001"
|
||||
|
||||
testKeyHash :: C.KeyHash
|
||||
testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a
|
||||
testNtfClient client =
|
||||
runTransportClient testHost testPort testKeyHash (Just defaultKeepAliveOpts) $ \h ->
|
||||
liftIO (runExceptT $ ntfClientHandshake h testKeyHash) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
cfg :: NtfServerConfig
|
||||
cfg =
|
||||
NtfServerConfig
|
||||
{ transports = undefined,
|
||||
subIdBytes = 24,
|
||||
regCodeBytes = 32,
|
||||
clientQSize = 1,
|
||||
subQSize = 1,
|
||||
pushQSize = 1,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
}
|
||||
|
||||
withNtfServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withNtfServerThreadOn t port' =
|
||||
serverBracket
|
||||
(\started -> runNtfServerBlocking started cfg {transports = [(port', t)]})
|
||||
(pure ())
|
||||
|
||||
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
|
||||
serverBracket process afterProcess f = do
|
||||
started <- newEmptyTMVarIO
|
||||
E.bracket
|
||||
(forkIOWithUnmask ($ process started))
|
||||
(\t -> killThread t >> afterProcess >> waitFor started "stop")
|
||||
(\t -> waitFor started "start" >> f t)
|
||||
where
|
||||
waitFor started s =
|
||||
5_000_000 `timeout` atomically (takeTMVar started) >>= \case
|
||||
Nothing -> error $ "server did not " <> s
|
||||
_ -> pure ()
|
||||
|
||||
withNtfServerOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> m a -> m a
|
||||
withNtfServerOn t port' = withNtfServerThreadOn t port' . const
|
||||
|
||||
withNtfServer :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withNtfServer t = withNtfServerOn t testPort
|
||||
|
||||
runNtfTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (THandle c -> m a) -> m a
|
||||
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
|
||||
|
||||
ntfServerTest ::
|
||||
forall c smp.
|
||||
(Transport c, Encoding smp) =>
|
||||
TProxy c ->
|
||||
(Maybe C.ASignature, ByteString, ByteString, smp) ->
|
||||
IO (Maybe C.ASignature, ByteString, ByteString, BrokerMsg)
|
||||
ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
||||
where
|
||||
tPut' h (sig, corrId, queueId, smp) = do
|
||||
let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp)
|
||||
Right () <- tPut h (sig, t')
|
||||
pure ()
|
||||
tGet' h = do
|
||||
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
@@ -0,0 +1,40 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module NtfServerTests where
|
||||
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import NtfClient
|
||||
import ServerTests (sampleDhPubKey, samplePubKey, sampleSig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Test.Hspec
|
||||
|
||||
ntfServerTests :: ATransport -> Spec
|
||||
ntfServerTests t = do
|
||||
describe "notifications server protocol syntax" $ ntfSyntaxTests t
|
||||
|
||||
ntfSyntaxTests :: ATransport -> Spec
|
||||
ntfSyntaxTests (ATransport t) = do
|
||||
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
|
||||
describe "NEW" $ do
|
||||
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX)
|
||||
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX)
|
||||
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH)
|
||||
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH)
|
||||
where
|
||||
(>#>) ::
|
||||
Encoding smp =>
|
||||
(Maybe C.ASignature, ByteString, ByteString, smp) ->
|
||||
(Maybe C.ASignature, ByteString, ByteString, BrokerMsg) ->
|
||||
Expectation
|
||||
command >#> response = ntfServerTest t command `shouldReturn` response
|
||||
+11
-4
@@ -154,11 +154,17 @@ smpAgentTest1_1_1 test' =
|
||||
_test [h] = test' h
|
||||
_test _ = error "expected 1 handle"
|
||||
|
||||
initAgentServers :: InitialAgentServers
|
||||
initAgentServers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"],
|
||||
ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"]
|
||||
}
|
||||
|
||||
cfg :: AgentConfig
|
||||
cfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = agentTestPort,
|
||||
initialSMPServers = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"],
|
||||
tbqSize = 1,
|
||||
dbFile = testDB,
|
||||
smpCfg =
|
||||
@@ -167,7 +173,7 @@ cfg =
|
||||
defaultTransport = (testPort, transport @TLS),
|
||||
tcpTimeout = 500_000
|
||||
},
|
||||
reconnectInterval = (reconnectInterval defaultAgentConfig) {initialInterval = 50_000},
|
||||
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
@@ -175,9 +181,10 @@ cfg =
|
||||
|
||||
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db', initialSMPServers = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db'}
|
||||
initServers' = initAgentServers {smp = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg')
|
||||
(\started -> runSMPAgentBlocking t started cfg' initServers')
|
||||
afterProcess
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
|
||||
@@ -4,6 +4,7 @@ import AgentTests (agentTests)
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.VersionRangeTests
|
||||
import NtfServerTests (ntfServerTests)
|
||||
import ServerTests
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
@@ -20,5 +21,6 @@ main = do
|
||||
describe "Version range" versionRangeTests
|
||||
describe "SMP server via TLS" $ serverTests (transport @TLS)
|
||||
describe "SMP server via WebSockets" $ serverTests (transport @WS)
|
||||
describe "Ntf server via TLS" $ ntfServerTests (transport @TLS)
|
||||
describe "SMP client agent" $ agentTests (transport @TLS)
|
||||
removeDirectoryRecursive "tests/tmp"
|
||||
|
||||
Reference in New Issue
Block a user