mirror of
https://github.com/simplex-chat/simplexmq.git
synced 2026-05-14 16:15:12 +00:00
interval notifications (TCRN command) (#352)
* notifications: periodic notifications * agent: allow repeat token registrations, delete old tokens from notification server (e.g., when database is moved to another device) * decrypt token verification code in the agent * check token status, send TCRN on registration if it was enabled * fix http2/apns response handling for error responses (also, APNS seems not to send content-length header?)
This commit is contained in:
committed by
GitHub
parent
45ddecc4b8
commit
9d8a9c4fe4
@@ -2,6 +2,7 @@
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServer)
|
||||
import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..))
|
||||
@@ -15,8 +16,13 @@ cfgPath = "/etc/opt/simplex-notifications"
|
||||
logPath :: FilePath
|
||||
logPath = "/var/opt/simplex-notifications"
|
||||
|
||||
logCfg :: LogConfig
|
||||
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
||||
main :: IO ()
|
||||
main = protocolServerCLI ntfServerCLIConfig runNtfServer
|
||||
main = do
|
||||
setLogLevel LogDebug -- TODO change to LogError in production
|
||||
withGlobalLogging logCfg $ protocolServerCLI ntfServerCLIConfig runNtfServer
|
||||
|
||||
ntfServerCLIConfig :: ServerCLIConfig NtfServerConfig
|
||||
ntfServerCLIConfig =
|
||||
|
||||
@@ -63,7 +63,7 @@ import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.Bifunctor (first, second)
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.Functor (($>))
|
||||
@@ -87,7 +87,7 @@ import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Crypto.Ratchet as CR
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfTknStatus (..))
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..))
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
@@ -165,8 +165,8 @@ registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken c = withAgentEnv c . registerNtfToken' c
|
||||
|
||||
-- | Verify device notifications token
|
||||
verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
|
||||
verifyNtfToken c = withAgentEnv c .: verifyNtfToken' c
|
||||
verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> ByteString -> C.CbNonce -> m ()
|
||||
verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c
|
||||
|
||||
-- | Enable/disable periodic notifications
|
||||
enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
|
||||
@@ -523,60 +523,78 @@ setSMPServers' c servers = do
|
||||
registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
registerNtfToken' c deviceToken =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId, ntfTknAction} -> case (ntfTokenId, ntfTknAction) of
|
||||
(Nothing, Just (NTARegister ntfPubKey)) -> registerToken tkn ntfPubKey
|
||||
-- TODO request verification code again in case there is registration, but no verification code in DB - probably after some timeout?
|
||||
(Just tknId, Just (NTAVerify code)) ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
|
||||
(Just tknId, Just (NTACron interval)) ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfEnableCron c tknId tkn interval
|
||||
(Just _tknId, Just NTACheck) -> pure ()
|
||||
(Just tknId, Just NTADelete) ->
|
||||
t tkn (NTExpired, Nothing) $ agentNtfDeleteToken c tknId tkn
|
||||
_ -> pure ()
|
||||
(Just tkn@NtfToken {ntfTokenId, ntfTknStatus, ntfTknAction}, prevTokens) -> do
|
||||
mapM_ (deleteToken_ c) prevTokens
|
||||
case (ntfTokenId, ntfTknAction) of
|
||||
(Nothing, Just NTARegister) -> registerToken tkn
|
||||
-- TODO minimal time before repeat registration
|
||||
(Just _, Nothing) -> when (ntfTknStatus == NTRegistered) $ registerToken tkn
|
||||
(Just tknId, Just (NTAVerify code)) ->
|
||||
t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code
|
||||
(Just tknId, Just (NTACron interval)) ->
|
||||
t tkn (cronSuccess interval) $ agentNtfEnableCron c tknId tkn interval
|
||||
(Just _tknId, Just NTACheck) -> pure () -- TODO
|
||||
-- agentNtfCheckToken c tknId tkn >>= \case
|
||||
(Just tknId, Just NTADelete) -> do
|
||||
agentNtfDeleteToken c tknId tkn
|
||||
withStore $ \st -> removeNtfToken st tkn
|
||||
_ -> pure ()
|
||||
_ ->
|
||||
getNtfServer c >>= \case
|
||||
Just ntfServer ->
|
||||
asks (cmdSignAlg . config) >>= \case
|
||||
C.SignAlg a -> do
|
||||
(ntfPubKey, ntfPrivKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
let tkn = newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey
|
||||
tknKeys <- liftIO $ C.generateSignatureKeyPair a
|
||||
dhKeys <- liftIO C.generateKeyPair'
|
||||
let tkn = newNtfToken deviceToken ntfServer tknKeys dhKeys
|
||||
withStore $ \st -> createNtfToken st tkn
|
||||
registerToken tkn ntfPubKey
|
||||
registerToken tkn
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
where
|
||||
t tkn = withToken tkn Nothing
|
||||
registerToken :: NtfToken -> C.APublicVerifyKey -> m ()
|
||||
registerToken tkn ntfPubKey = do
|
||||
(pubDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
registerToken :: NtfToken -> m ()
|
||||
registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do
|
||||
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
|
||||
let dhSecret = C.dh' srvPubDhKey privDhKey
|
||||
withStore $ \st -> updateNtfTokenRegistration st tkn tknId dhSecret
|
||||
|
||||
verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m ()
|
||||
verifyNtfToken' c deviceToken code =
|
||||
-- TODO decrypt verification code
|
||||
verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> ByteString -> C.CbNonce -> m ()
|
||||
verifyNtfToken' c deviceToken code nonce =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId} ->
|
||||
withToken tkn (Just (NTConfirmed, NTAVerify code)) (NTActive, Just NTACheck) $
|
||||
agentNtfVerifyToken c tknId tkn code
|
||||
(Just tkn@NtfToken {ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret}, _) -> do
|
||||
code' <- liftEither . bimap cryptoError NtfRegCode $ C.cbDecrypt dhSecret nonce code
|
||||
withToken tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $
|
||||
agentNtfVerifyToken c tknId tkn code'
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
enableNtfCron' :: AgentMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
|
||||
enableNtfCron' c deviceToken interval =
|
||||
enableNtfCron' c deviceToken interval = do
|
||||
when (interval < 20) . throwError $ CMD PROHIBITED
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive} ->
|
||||
withToken tkn (Just (NTActive, NTACron interval)) (NTActive, Just NTACheck) $
|
||||
(Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive}, _) ->
|
||||
withToken tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $
|
||||
agentNtfEnableCron c tknId tkn interval
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
cronSuccess :: Word16 -> (NtfTknStatus, Maybe NtfTknAction)
|
||||
cronSuccess interval
|
||||
| interval == 0 = (NTActive, Just NTACheck)
|
||||
| otherwise = (NTActive, Just $ NTACron interval)
|
||||
|
||||
deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
|
||||
deleteNtfToken' c deviceToken =
|
||||
withStore (`getDeviceNtfToken` deviceToken) >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus} ->
|
||||
withToken tkn (Just (ntfTknStatus, NTADelete)) (NTExpired, Nothing) $
|
||||
agentNtfDeleteToken c tknId tkn
|
||||
(Just tkn, _) -> deleteToken_ c tkn
|
||||
_ -> throwError $ CMD PROHIBITED
|
||||
|
||||
deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m ()
|
||||
deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
|
||||
forM_ ntfTokenId $ \tknId -> do
|
||||
withStore $ \st -> updateNtfToken st tkn ntfTknStatus (Just NTADelete)
|
||||
agentNtfDeleteToken c tknId tkn
|
||||
withStore $ \st -> removeNtfToken st tkn
|
||||
|
||||
withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a
|
||||
withToken tkn from_ (toStatus, toAction_) f = do
|
||||
forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action)
|
||||
|
||||
@@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client
|
||||
sendAgentMessage,
|
||||
agentNtfRegisterToken,
|
||||
agentNtfVerifyToken,
|
||||
agentNtfCheckToken,
|
||||
agentNtfDeleteToken,
|
||||
agentNtfEnableCron,
|
||||
agentCbEncrypt,
|
||||
@@ -478,6 +479,10 @@ agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken ->
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
|
||||
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
@@ -81,9 +81,10 @@ class Monad m => MonadAgentStore s m where
|
||||
createNtfToken :: s -> NtfToken -> m ()
|
||||
|
||||
-- TODO this should also return old tokens so that they are deleted from the server
|
||||
getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken) -- return current token if it exists
|
||||
getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken, [NtfToken]) -- return current token if it exists
|
||||
updateNtfTokenRegistration :: s -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
|
||||
updateNtfToken :: s -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m ()
|
||||
removeNtfToken :: s -> NtfToken -> m ()
|
||||
|
||||
-- * Queue types
|
||||
|
||||
|
||||
@@ -33,12 +33,12 @@ import Control.Exception (bracket)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Crypto.Random (ChaChaDRG, randomBytesGenerate)
|
||||
import Data.Bifunctor (second)
|
||||
import Data.Bifunctor (first, second)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.Char (toLower)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (find, foldl')
|
||||
import Data.List (find, foldl', partition)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (listToMaybe)
|
||||
import Data.Text (Text)
|
||||
@@ -567,35 +567,36 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk)
|
||||
|
||||
createNtfToken :: SQLiteStore -> NtfToken -> m ()
|
||||
createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} =
|
||||
createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction} =
|
||||
liftIO . withTransaction st $ \db -> do
|
||||
upsertNtfServer_ db srv
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
INSERT INTO ntf_tokens
|
||||
(provider, device_token, ntf_host, ntf_port, tkn_id, tkn_priv_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
(provider, device_token, ntf_host, ntf_port, tkn_id, tkn_pub_key, tkn_priv_key, tkn_pub_dh_key, tkn_priv_dh_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
|]
|
||||
(provider, token, host, port, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction)
|
||||
(provider, token, host, port, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction)
|
||||
|
||||
getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken)
|
||||
getDeviceNtfToken st t@(DeviceToken provider token) =
|
||||
liftIO . withTransaction st $ \db ->
|
||||
fmap ntfToken . listToMaybe
|
||||
<$> DB.query
|
||||
db
|
||||
[sql|
|
||||
SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
|
||||
t.tkn_id, t.tkn_priv_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action
|
||||
FROM ntf_tokens t
|
||||
JOIN ntf_servers s USING (ntf_host, ntf_port)
|
||||
WHERE t.provider = ? AND t.device_token = ?
|
||||
|]
|
||||
(provider, token)
|
||||
getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken, [NtfToken])
|
||||
getDeviceNtfToken st t =
|
||||
liftIO . withTransaction st $ \db -> do
|
||||
tokens <-
|
||||
map ntfToken
|
||||
<$> DB.query_
|
||||
db
|
||||
[sql|
|
||||
SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash,
|
||||
t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action
|
||||
FROM ntf_tokens t
|
||||
JOIN ntf_servers s USING (ntf_host, ntf_port)
|
||||
|]
|
||||
pure . first listToMaybe $ partition ((t ==) . deviceToken) tokens
|
||||
where
|
||||
ntfToken (host, port, keyHash, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) =
|
||||
ntfToken (host, port, keyHash, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) =
|
||||
let ntfServer = ProtocolServer {host, port, keyHash}
|
||||
in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction}
|
||||
ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey)
|
||||
in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction}
|
||||
|
||||
updateNtfTokenRegistration :: SQLiteStore -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m ()
|
||||
updateNtfTokenRegistration st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret =
|
||||
@@ -623,6 +624,17 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto
|
||||
|]
|
||||
(tknStatus, tknAction, updatedAt, provider, token, host, port)
|
||||
|
||||
removeNtfToken :: SQLiteStore -> NtfToken -> m ()
|
||||
removeNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} =
|
||||
liftIO . withTransaction st $ \db ->
|
||||
DB.execute
|
||||
db
|
||||
[sql|
|
||||
DELETE FROM ntf_tokens
|
||||
WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ?
|
||||
|]
|
||||
(provider, token, host, port)
|
||||
|
||||
-- * Auxiliary helpers
|
||||
|
||||
instance ToField QueueStatus where toField = toField . serializeQueueStatus
|
||||
|
||||
@@ -23,7 +23,10 @@ CREATE TABLE ntf_tokens (
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
tkn_id BLOB, -- token ID assigned by notifications server
|
||||
tkn_priv_key BLOB NOT NULL, -- private key to sign token commands
|
||||
tkn_pub_key BLOB NOT NULL, -- client's public key to verify token commands (used by server, for repeat registraions)
|
||||
tkn_priv_key BLOB NOT NULL, -- client's private key to sign token commands
|
||||
tkn_pub_dh_key BLOB NOT NULL, -- client's public DH key (for repeat registraions)
|
||||
tkn_priv_dh_key BLOB NOT NULL, -- client's private DH key (for repeat registraions)
|
||||
tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications
|
||||
tkn_status TEXT NOT NULL,
|
||||
tkn_action BLOB,
|
||||
|
||||
@@ -52,6 +52,7 @@ module Simplex.Messaging.Crypto
|
||||
CryptoPublicKey (..),
|
||||
CryptoPrivateKey (..),
|
||||
KeyPair,
|
||||
ASignatureKeyPair,
|
||||
DhSecret (..),
|
||||
DhSecretX25519,
|
||||
ADhSecret (..),
|
||||
@@ -870,6 +871,14 @@ cbDecrypt secret (CbNonce nonce) packet
|
||||
newtype CbNonce = CbNonce {unCbNonce :: ByteString}
|
||||
deriving (Show)
|
||||
|
||||
instance StrEncoding CbNonce where
|
||||
strEncode (CbNonce s) = strEncode s
|
||||
strP = cbNonce <$> strP
|
||||
|
||||
instance ToJSON CbNonce where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
cbNonce :: ByteString -> CbNonce
|
||||
cbNonce s
|
||||
| len == 24 = CbNonce s
|
||||
|
||||
@@ -33,6 +33,12 @@ ntfRegisterToken c pKey newTkn =
|
||||
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
|
||||
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
|
||||
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus
|
||||
ntfCheckToken c pKey tknId =
|
||||
sendNtfCommand c (Just pKey) tknId TCHK >>= \case
|
||||
NRTkn stat -> pure stat
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteToken = okNtfCommand TDEL
|
||||
|
||||
@@ -48,7 +54,7 @@ ntfCreateSubsciption c pKey newSub =
|
||||
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
ntfCheckSubscription c pKey subId =
|
||||
sendNtfCommand c (Just pKey) subId SCHK >>= \case
|
||||
NRStat stat -> pure stat
|
||||
NRSub stat -> pure stat
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
@@ -65,7 +71,7 @@ okNtfCommand cmd c pKey entId =
|
||||
_ -> throwE PCEUnexpectedResponse
|
||||
|
||||
data NtfTknAction
|
||||
= NTARegister C.APublicVerifyKey -- public key to send to the server
|
||||
= NTARegister
|
||||
| NTAVerify NtfRegCode -- code to verify token
|
||||
| NTACheck
|
||||
| NTACron Word16
|
||||
@@ -74,14 +80,14 @@ data NtfTknAction
|
||||
|
||||
instance Encoding NtfTknAction where
|
||||
smpEncode = \case
|
||||
NTARegister key -> smpEncode ('R', key)
|
||||
NTARegister -> "R"
|
||||
NTAVerify code -> smpEncode ('V', code)
|
||||
NTACheck -> "C"
|
||||
NTACron interval -> smpEncode ('I', interval)
|
||||
NTADelete -> "D"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'R' -> NTARegister <$> smpP
|
||||
'R' -> pure NTARegister
|
||||
'V' -> NTAVerify <$> smpP
|
||||
'C' -> pure NTACheck
|
||||
'I' -> NTACron <$> smpP
|
||||
@@ -96,8 +102,12 @@ data NtfToken = NtfToken
|
||||
{ deviceToken :: DeviceToken,
|
||||
ntfServer :: NtfServer,
|
||||
ntfTokenId :: Maybe NtfTokenId,
|
||||
-- | key used by the ntf server to verify transmissions
|
||||
ntfPubKey :: C.APublicVerifyKey,
|
||||
-- | key used by the ntf client to sign transmissions
|
||||
ntfPrivKey :: C.APrivateSignKey,
|
||||
-- | client's DH keys (to repeat registration if necessary)
|
||||
ntfDhKeys :: C.KeyPair 'C.X25519,
|
||||
-- | shared DH secret used to encrypt/decrypt notifications e2e
|
||||
ntfDhSecret :: Maybe C.DhSecretX25519,
|
||||
-- | token status
|
||||
@@ -107,14 +117,16 @@ data NtfToken = NtfToken
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newNtfToken :: DeviceToken -> NtfServer -> C.APrivateSignKey -> C.APublicVerifyKey -> NtfToken
|
||||
newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey =
|
||||
newNtfToken :: DeviceToken -> NtfServer -> C.ASignatureKeyPair -> C.KeyPair 'C.X25519 -> NtfToken
|
||||
newNtfToken deviceToken ntfServer (ntfPubKey, ntfPrivKey) ntfDhKeys =
|
||||
NtfToken
|
||||
{ deviceToken,
|
||||
ntfServer,
|
||||
ntfTokenId = Nothing,
|
||||
ntfPubKey,
|
||||
ntfPrivKey,
|
||||
ntfDhKeys,
|
||||
ntfDhSecret = Nothing,
|
||||
ntfTknStatus = NTNew,
|
||||
ntfTknAction = Just $ NTARegister ntfPubKey
|
||||
}
|
||||
ntfTknAction = Just NTARegister
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
@@ -51,6 +52,7 @@ instance NtfEntityI 'Subscription where sNtfEntity = SSubscription
|
||||
data NtfCommandTag (e :: NtfEntity) where
|
||||
TNEW_ :: NtfCommandTag 'Token
|
||||
TVFY_ :: NtfCommandTag 'Token
|
||||
TCHK_ :: NtfCommandTag 'Token
|
||||
TDEL_ :: NtfCommandTag 'Token
|
||||
TCRN_ :: NtfCommandTag 'Token
|
||||
SNEW_ :: NtfCommandTag 'Subscription
|
||||
@@ -66,6 +68,7 @@ instance NtfEntityI e => Encoding (NtfCommandTag e) where
|
||||
smpEncode = \case
|
||||
TNEW_ -> "TNEW"
|
||||
TVFY_ -> "TVFY"
|
||||
TCHK_ -> "TCHK"
|
||||
TDEL_ -> "TDEL"
|
||||
TCRN_ -> "TCRN"
|
||||
SNEW_ -> "SNEW"
|
||||
@@ -82,6 +85,7 @@ instance ProtocolMsgTag NtfCmdTag where
|
||||
decodeTag = \case
|
||||
"TNEW" -> Just $ NCT SToken TNEW_
|
||||
"TVFY" -> Just $ NCT SToken TVFY_
|
||||
"TCHK" -> Just $ NCT SToken TCHK_
|
||||
"TDEL" -> Just $ NCT SToken TDEL_
|
||||
"TCRN" -> Just $ NCT SToken TCRN_
|
||||
"SNEW" -> Just $ NCT SSubscription SNEW_
|
||||
@@ -147,6 +151,8 @@ data NtfCommand (e :: NtfEntity) where
|
||||
TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token
|
||||
-- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token
|
||||
TVFY :: NtfRegCode -> NtfCommand 'Token
|
||||
-- | check token status
|
||||
TCHK :: NtfCommand 'Token
|
||||
-- | delete token - all subscriptions will be removed and no more notifications will be sent
|
||||
TDEL :: NtfCommand 'Token
|
||||
-- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable
|
||||
@@ -171,6 +177,7 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
encodeProtocol = \case
|
||||
TNEW newTkn -> e (TNEW_, ' ', newTkn)
|
||||
TVFY code -> e (TVFY_, ' ', code)
|
||||
TCHK -> e TCHK_
|
||||
TDEL -> e TDEL_
|
||||
TCRN int -> e (TCRN_, ' ', int)
|
||||
SNEW newSub -> e (SNEW_, ' ', newSub)
|
||||
@@ -209,6 +216,7 @@ instance ProtocolEncoding NtfCmd where
|
||||
NtfCmd SToken <$> case tag of
|
||||
TNEW_ -> TNEW <$> _smpP
|
||||
TVFY_ -> TVFY <$> _smpP
|
||||
TCHK_ -> pure TCHK
|
||||
TDEL_ -> pure TDEL
|
||||
TCRN_ -> TCRN <$> _smpP
|
||||
NCT SSubscription tag ->
|
||||
@@ -224,7 +232,8 @@ data NtfResponseTag
|
||||
= NRId_
|
||||
| NROk_
|
||||
| NRErr_
|
||||
| NRStat_
|
||||
| NRTkn_
|
||||
| NRSub_
|
||||
| NRPong_
|
||||
deriving (Show)
|
||||
|
||||
@@ -233,7 +242,8 @@ instance Encoding NtfResponseTag where
|
||||
NRId_ -> "ID"
|
||||
NROk_ -> "OK"
|
||||
NRErr_ -> "ERR"
|
||||
NRStat_ -> "STAT"
|
||||
NRTkn_ -> "TKN"
|
||||
NRSub_ -> "SUB"
|
||||
NRPong_ -> "PONG"
|
||||
smpP = messageTagP
|
||||
|
||||
@@ -242,7 +252,8 @@ instance ProtocolMsgTag NtfResponseTag where
|
||||
"ID" -> Just NRId_
|
||||
"OK" -> Just NROk_
|
||||
"ERR" -> Just NRErr_
|
||||
"STAT" -> Just NRStat_
|
||||
"TKN" -> Just NRTkn_
|
||||
"SUB" -> Just NRSub_
|
||||
"PONG" -> Just NRPong_
|
||||
_ -> Nothing
|
||||
|
||||
@@ -250,7 +261,8 @@ data NtfResponse
|
||||
= NRId NtfEntityId C.PublicKeyX25519
|
||||
| NROk
|
||||
| NRErr ErrorType
|
||||
| NRStat NtfSubStatus
|
||||
| NRTkn NtfTknStatus
|
||||
| NRSub NtfSubStatus
|
||||
| NRPong
|
||||
|
||||
instance ProtocolEncoding NtfResponse where
|
||||
@@ -259,7 +271,8 @@ instance ProtocolEncoding NtfResponse where
|
||||
NRId entId dhKey -> e (NRId_, ' ', entId, dhKey)
|
||||
NROk -> e NROk_
|
||||
NRErr err -> e (NRErr_, ' ', err)
|
||||
NRStat stat -> e (NRStat_, ' ', stat)
|
||||
NRTkn stat -> e (NRTkn_, ' ', stat)
|
||||
NRSub stat -> e (NRSub_, ' ', stat)
|
||||
NRPong -> e NRPong_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
@@ -269,7 +282,8 @@ instance ProtocolEncoding NtfResponse where
|
||||
NRId_ -> NRId <$> _smpP <*> smpP
|
||||
NROk_ -> pure NROk
|
||||
NRErr_ -> NRErr <$> _smpP
|
||||
NRStat_ -> NRStat <$> _smpP
|
||||
NRTkn_ -> NRTkn <$> _smpP
|
||||
NRSub_ -> NRSub <$> _smpP
|
||||
NRPong_ -> pure NRPong
|
||||
|
||||
checkCredentials (_, _, entId, _) cmd = case cmd of
|
||||
@@ -301,22 +315,22 @@ instance Encoding SMPQueueNtf where
|
||||
(smpServer, notifierId, notifierKey) <- smpP
|
||||
pure $ SMPQueueNtf smpServer notifierId notifierKey
|
||||
|
||||
data PushProvider = PPApple
|
||||
data PushProvider = PPApns
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding PushProvider where
|
||||
smpEncode = \case
|
||||
PPApple -> "A"
|
||||
PPApns -> "A"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'A' -> pure PPApple
|
||||
'A' -> pure PPApns
|
||||
_ -> fail "bad PushProvider"
|
||||
|
||||
instance TextEncoding PushProvider where
|
||||
textEncode = \case
|
||||
PPApple -> "apple"
|
||||
PPApns -> "apple"
|
||||
textDecode = \case
|
||||
"apple" -> Just PPApple
|
||||
"apple" -> Just PPApns
|
||||
_ -> Nothing
|
||||
|
||||
instance FromField PushProvider where fromField = fromTextField_ textDecode
|
||||
@@ -363,7 +377,7 @@ instance Encoding NtfSubStatus where
|
||||
"ACTIVE" -> pure NSActive
|
||||
"END" -> pure NSEnd
|
||||
"SMP_AUTH" -> pure NSSMPAuth
|
||||
_ -> fail "bad NtfError"
|
||||
_ -> fail "bad NtfSubStatus"
|
||||
|
||||
data NtfTknStatus
|
||||
= -- | Token created in DB
|
||||
@@ -380,26 +394,27 @@ data NtfTknStatus
|
||||
NTExpired
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance TextEncoding NtfTknStatus where
|
||||
textEncode = \case
|
||||
NTNew -> "new"
|
||||
NTRegistered -> "registered"
|
||||
NTInvalid -> "invalid"
|
||||
NTConfirmed -> "confirmed"
|
||||
NTActive -> "active"
|
||||
NTExpired -> "expired"
|
||||
textDecode = \case
|
||||
"new" -> Just NTNew
|
||||
"registered" -> Just NTRegistered
|
||||
"invalid" -> Just NTInvalid
|
||||
"confirmed" -> Just NTConfirmed
|
||||
"active" -> Just NTActive
|
||||
"expired" -> Just NTExpired
|
||||
_ -> Nothing
|
||||
instance Encoding NtfTknStatus where
|
||||
smpEncode = \case
|
||||
NTNew -> "NEW"
|
||||
NTRegistered -> "REGISTERED"
|
||||
NTInvalid -> "INVALID"
|
||||
NTConfirmed -> "CONFIRMED"
|
||||
NTActive -> "ACTIVE"
|
||||
NTExpired -> "EXPIRED"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure NTNew
|
||||
"REGISTERED" -> pure NTRegistered
|
||||
"INVALID" -> pure NTInvalid
|
||||
"CONFIRMED" -> pure NTConfirmed
|
||||
"ACTIVE" -> pure NTActive
|
||||
"EXPIRED" -> pure NTExpired
|
||||
_ -> fail "bad NtfTknStatus"
|
||||
|
||||
instance FromField NtfTknStatus where fromField = fromTextField_ textDecode
|
||||
instance FromField NtfTknStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8
|
||||
|
||||
instance ToField NtfTknStatus where toField = toField . textEncode
|
||||
instance ToField NtfTknStatus where toField = toField . decodeLatin1 . smpEncode
|
||||
|
||||
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
|
||||
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
|
||||
|
||||
@@ -10,11 +10,13 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Server where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Text as T
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
@@ -26,9 +28,12 @@ import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), SignedTransmission, Transmission, encodeTransmission, tGet, tPut)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TProxy, Transport)
|
||||
import Simplex.Messaging.Transport.Server (runTransportServer)
|
||||
import Simplex.Messaging.Util
|
||||
import UnliftIO (async, uninterruptibleCancel)
|
||||
import UnliftIO.Concurrent (threadDelay)
|
||||
import UnliftIO.Exception
|
||||
import UnliftIO.STM
|
||||
|
||||
@@ -106,7 +111,7 @@ ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}}
|
||||
ntfPush :: MonadUnliftIO m => NtfPushServer -> m ()
|
||||
ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do
|
||||
(tkn@NtfTknData {token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ)
|
||||
liftIO $ putStrLn $ "sending push notification to " <> show pp
|
||||
logDebug $ "sending push notification to " <> T.pack (show pp)
|
||||
status <- readTVarIO tknStatus
|
||||
case (status, ntf) of
|
||||
(_, PNVerification _) -> do
|
||||
@@ -116,13 +121,13 @@ ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do
|
||||
(NTActive, PNCheckMessages) -> do
|
||||
deliverNotification pp tkn ntf
|
||||
_ -> do
|
||||
liftIO $ putStrLn "bad notification token status"
|
||||
logError "bad notification token status"
|
||||
where
|
||||
deliverNotification :: PushProvider -> PushProviderClient
|
||||
deliverNotification pp tkn ntf = do
|
||||
deliver <- liftIO $ getPushClient s pp
|
||||
-- TODO retry later based on the error
|
||||
deliver tkn ntf `catchError` \e -> liftIO (putStrLn $ "Push provider error (" <> show pp <> "): " <> show e) >> throwError e
|
||||
deliver tkn ntf `catchError` \e -> logError (T.pack $ "Push provider error (" <> show pp <> "): " <> show e) >> throwError e
|
||||
|
||||
runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
|
||||
runNtfClientTransport th@THandle {sessionId} = do
|
||||
@@ -138,14 +143,14 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
|
||||
receive th NtfServerClient {rcvQ, sndQ} = forever $ do
|
||||
t@(_, _, (corrId, subId, cmdOrError)) <- tGet th
|
||||
liftIO $ putStrLn "receive"
|
||||
t@(_, _, (corrId, entId, cmdOrError)) <- tGet th
|
||||
logDebug "received transmission"
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, subId, NRErr e)
|
||||
Left e -> write sndQ (corrId, entId, NRErr e)
|
||||
Right cmd ->
|
||||
verifyNtfTransmission t cmd >>= \case
|
||||
VRVerified req -> write rcvQ req
|
||||
VRFailed -> write sndQ (corrId, subId, NRErr AUTH)
|
||||
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
@@ -159,41 +164,31 @@ data VerificationResult = VRVerified NtfRequest | VRFailed
|
||||
verifyNtfTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult
|
||||
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
st <- asks store
|
||||
case cmd of
|
||||
NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) ->
|
||||
-- TODO check that token is not already in store, if it is - verify that the saved key is the same
|
||||
NtfCmd SToken c@(TNEW n@(NewNtfTkn _ k _)) -> do
|
||||
r_ <- atomically $ getNtfToken st entId
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed k
|
||||
then VRVerified (NtfReqNew corrId (ANE SToken n))
|
||||
then case r_ of
|
||||
Just r@(NtfTkn NtfTknData {tknVerifyKey})
|
||||
| k == tknVerifyKey -> tknCmd r c
|
||||
| otherwise -> VRFailed
|
||||
_ -> VRVerified (NtfReqNew corrId (ANE SToken n))
|
||||
else VRFailed
|
||||
NtfCmd SToken c -> do
|
||||
st <- asks store
|
||||
atomically (getNtfToken st entId) >>= \case
|
||||
Just r@(NtfTkn NtfTknData {tknVerifyKey}) ->
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed tknVerifyKey
|
||||
then VRVerified (NtfReqCmd SToken r (corrId, entId, c))
|
||||
else VRFailed
|
||||
_ -> pure VRFailed -- TODO dummy verification
|
||||
r_ <- atomically $ getNtfToken st entId
|
||||
pure $ case r_ of
|
||||
Just r@(NtfTkn NtfTknData {tknVerifyKey})
|
||||
| verifyCmdSignature sig_ signed tknVerifyKey -> tknCmd r c
|
||||
| otherwise -> VRFailed
|
||||
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
_ -> pure VRFailed
|
||||
|
||||
-- do
|
||||
-- st <- asks store
|
||||
-- case cmd of
|
||||
-- NCSubCreate tokenId smpQueue -> verifyCreateCmd verifyKey newSub <$> atomically (getNtfSubViaSMPQueue st smpQueue)
|
||||
-- _ -> verifySubCmd <$> atomically (getNtfSub st subId)
|
||||
-- where
|
||||
-- verifyCreateCmd k newSub sub_
|
||||
-- | verifyCmdSignature sig_ signed k = case sub_ of
|
||||
-- Just sub -> if k == subVerifyKey sub then VRCommand sub else VRFail
|
||||
-- _ -> VRCreate newSub
|
||||
-- | otherwise = VRFail
|
||||
-- verifySubCmd = \case
|
||||
-- Just sub -> if verifyCmdSignature sig_ signed $ subVerifyKey sub then VRCommand sub else VRFail
|
||||
-- _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFail
|
||||
where
|
||||
tknCmd r c = VRVerified (NtfReqCmd SToken r (corrId, entId, c))
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m ()
|
||||
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ} =
|
||||
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ, intervalNotifiers} =
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
>>= processCommand
|
||||
@@ -202,30 +197,68 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push
|
||||
processCommand :: NtfRequest -> m (Transmission NtfResponse)
|
||||
processCommand = \case
|
||||
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
|
||||
liftIO $ putStrLn "TNEW"
|
||||
logDebug "TNEW - new token"
|
||||
st <- asks store
|
||||
(srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair'
|
||||
let dhSecret = C.dh' dhPubKey srvDrivDhKey
|
||||
ks@(srvDhPubKey, srvDhPrivKey) <- liftIO C.generateKeyPair'
|
||||
let dhSecret = C.dh' dhPubKey srvDhPrivKey
|
||||
tknId <- getId
|
||||
regCode <- getRegCode
|
||||
atomically $ do
|
||||
tkn <- mkNtfTknData newTkn dhSecret regCode
|
||||
tkn <- mkNtfTknData newTkn ks dhSecret regCode
|
||||
addNtfToken st tknId tkn
|
||||
writeTBQueue pushQ (tkn, PNVerification regCode)
|
||||
pure (corrId, "", NRId tknId srvDhPubKey)
|
||||
NtfReqCmd SToken (NtfTkn NtfTknData {tknStatus, tknRegCode}) (corrId, tknId, cmd) -> do
|
||||
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do
|
||||
status <- readTVarIO tknStatus
|
||||
(corrId,tknId,) <$> case cmd of
|
||||
TNEW _newTkn -> do
|
||||
liftIO $ putStrLn "TNEW'"
|
||||
pure NROk -- TODO when duplicate token sent
|
||||
TNEW (NewNtfTkn _ _ dhPubKey) -> do
|
||||
logDebug "TNEW - registered token"
|
||||
let dhSecret = C.dh' dhPubKey srvDhPrivKey
|
||||
-- it is required that DH secret is the same, to avoid failed verifications if notification is delaying
|
||||
if tknDhSecret == dhSecret
|
||||
then do
|
||||
atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode)
|
||||
pure $ NRId tknId srvDhPubKey
|
||||
else pure $ NRErr AUTH
|
||||
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
|
||||
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
|
||||
logDebug "TVFY - token verified"
|
||||
atomically $ writeTVar tknStatus NTActive
|
||||
pure NROk
|
||||
| otherwise -> pure $ NRErr AUTH
|
||||
TDEL -> pure NROk
|
||||
TCRN _int -> pure NROk
|
||||
| otherwise -> do
|
||||
logDebug "TVFY - incorrect code or token status"
|
||||
pure $ NRErr AUTH
|
||||
TCHK -> pure $ NRTkn status
|
||||
TDEL -> do
|
||||
logDebug "TDEL"
|
||||
st <- asks store
|
||||
atomically $ deleteNtfToken st tknId
|
||||
pure NROk
|
||||
TCRN 0 ->
|
||||
logDebug "TCRN 0"
|
||||
>> atomically (TM.lookupDelete tknId intervalNotifiers)
|
||||
>>= mapM_ (uninterruptibleCancel . action)
|
||||
>> pure NROk
|
||||
TCRN int
|
||||
| int < 20 -> pure $ NRErr QUOTA
|
||||
| otherwise -> do
|
||||
logDebug "TCRN"
|
||||
atomically (TM.lookup tknId intervalNotifiers) >>= \case
|
||||
Nothing -> runIntervalNotifier int
|
||||
Just IntervalNotifier {interval, action} ->
|
||||
unless (interval == int) $ do
|
||||
uninterruptibleCancel action
|
||||
runIntervalNotifier int
|
||||
pure NROk
|
||||
where
|
||||
runIntervalNotifier interval = do
|
||||
action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60
|
||||
let notifier = IntervalNotifier {action, token = tkn, interval}
|
||||
atomically $ TM.insert tknId notifier intervalNotifiers
|
||||
where
|
||||
intervalNotifier delay = forever $ do
|
||||
threadDelay delay
|
||||
atomically $ writeTBQueue pushQ (tkn, PNCheckMessages)
|
||||
NtfReqNew corrId (ANE SSubscription _newSub) -> pure (corrId, "", NROk)
|
||||
NtfReqCmd SSubscription _sub (corrId, subId, cmd) ->
|
||||
(corrId,subId,) <$> case cmd of
|
||||
|
||||
@@ -7,9 +7,11 @@
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Env where
|
||||
|
||||
import Control.Concurrent.Async (Async)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Word (Word16)
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
@@ -59,7 +61,7 @@ newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsCo
|
||||
subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg
|
||||
pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig
|
||||
-- TODO not creating APNS client on start to pass CI test, has to be replaced with mock APNS server
|
||||
-- void . liftIO $ newPushClient pushServer PPApple
|
||||
-- void . liftIO $ newPushClient pushServer PPApns
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
pure NtfEnv {config, subscriber, pushServer, store, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp}
|
||||
@@ -78,20 +80,28 @@ newNtfSubscriber qSize smpAgentCfg = do
|
||||
data NtfPushServer = NtfPushServer
|
||||
{ pushQ :: TBQueue (NtfTknData, PushNotification),
|
||||
pushClients :: TMap PushProvider PushProviderClient,
|
||||
intervalNotifiers :: TMap NtfTokenId IntervalNotifier,
|
||||
apnsConfig :: APNSPushClientConfig
|
||||
}
|
||||
|
||||
data IntervalNotifier = IntervalNotifier
|
||||
{ action :: Async (),
|
||||
token :: NtfTknData,
|
||||
interval :: Word16
|
||||
}
|
||||
|
||||
newNtfPushServer :: Natural -> APNSPushClientConfig -> STM NtfPushServer
|
||||
newNtfPushServer qSize apnsConfig = do
|
||||
pushQ <- newTBQueue qSize
|
||||
pushClients <- TM.empty
|
||||
pure NtfPushServer {pushQ, pushClients, apnsConfig}
|
||||
intervalNotifiers <- TM.empty
|
||||
pure NtfPushServer {pushQ, pushClients, intervalNotifiers, apnsConfig}
|
||||
|
||||
newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
|
||||
newPushClient NtfPushServer {apnsConfig, pushClients} = \case
|
||||
PPApple -> do
|
||||
PPApns -> do
|
||||
c <- apnsPushProviderClient <$> createAPNSPushClient apnsConfig
|
||||
atomically $ TM.insert PPApple c pushClients
|
||||
atomically $ TM.insert PPApns c pushClients
|
||||
pure c
|
||||
|
||||
getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Push.APNS where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Crypto.Hash.Algorithms (SHA256 (..))
|
||||
import qualified Crypto.PubKey.ECC.ECDSA as EC
|
||||
@@ -20,6 +24,7 @@ import Data.Aeson (FromJSON, ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
@@ -27,7 +32,6 @@ import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
import Data.Int (Int64)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8With)
|
||||
@@ -226,7 +230,7 @@ getApnsJWTToken APNSPushClient {apnsCfg = APNSPushClientConfig {appTeamId, token
|
||||
atomically $ writeTVar jwtToken t
|
||||
pure signedJWT'
|
||||
where
|
||||
jwtTokenAge (JWTToken _ JWTClaims {iat}) = (iat -) . systemSeconds <$> getSystemTime
|
||||
jwtTokenAge (JWTToken _ JWTClaims {iat}) = subtract iat . systemSeconds <$> getSystemTime
|
||||
|
||||
mkApnsJWTToken :: Text -> JWTHeader -> EC.PrivateKey -> IO (JWTToken, SignedJWTToken)
|
||||
mkApnsJWTToken appTeamId jwtHeader privateKey = do
|
||||
@@ -256,15 +260,15 @@ apnsNotification :: NtfTknData -> C.CbNonce -> Int -> PushNotification -> Either
|
||||
apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
||||
PNVerification (NtfRegCode code) ->
|
||||
encrypt code $ \code' ->
|
||||
apn APNSBackground {contentAvailable = 1} . Just $ J.object ["verification" .= code']
|
||||
apn APNSBackground {contentAvailable = 1} . Just $ J.object ["verification" .= code', "nonce" .= nonce]
|
||||
PNMessage srv nId ->
|
||||
encrypt (strEncode srv <> "/" <> strEncode nId) $ \ntfQueue ->
|
||||
apn apnMutableContent . Just $ J.object ["checkMessage" .= ntfQueue]
|
||||
apn apnMutableContent . Just $ J.object ["checkMessage" .= ntfQueue, "nonce" .= nonce]
|
||||
PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing
|
||||
PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True]
|
||||
where
|
||||
encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification
|
||||
encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
|
||||
encrypt ntfData f = f . safeDecodeUtf8 . B64.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
|
||||
apn aps notificationData = APNSNotification {aps, notificationData}
|
||||
apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or some other app event", category = Nothing}
|
||||
apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
|
||||
@@ -300,40 +304,39 @@ data PushProviderError
|
||||
|
||||
type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProviderError IO ()
|
||||
|
||||
newtype APNSErrorReponse = APNSErrorReponse {reason :: Text}
|
||||
-- this is not a newtype on purpose to have a correct JSON encoding as a record
|
||||
data APNSErrorReponse = APNSErrorReponse {reason :: Text}
|
||||
deriving (Generic, FromJSON)
|
||||
|
||||
apnsPushProviderClient :: APNSPushClient -> PushProviderClient
|
||||
apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken PPApple tknStr} pn = do
|
||||
apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken PPApns tknStr} pn = do
|
||||
http2 <- liftHTTPS2 $ getApnsHTTP2Client c
|
||||
nonce <- atomically $ C.pseudoRandomCbNonce nonceDrg
|
||||
apnsNtf <- liftEither $ first PPCryptoError $ apnsNotification tkn nonce (paddedNtfLength apnsCfg) pn
|
||||
liftIO $ putStrLn $ "APNS notification: " <> show apnsNtf
|
||||
req <- liftIO $ apnsRequest c tknStr apnsNtf
|
||||
liftIO $ putStrLn $ "APNS request: " <> show req
|
||||
HTTP2Response {response, respBody} <- liftHTTPS2 $ sendRequest http2 req
|
||||
let status = H.responseStatus response
|
||||
reason = fromMaybe "" $ J.decodeStrict' =<< respBody
|
||||
liftIO $ putStrLn $ "APNS response: " <> show status <> " " <> T.unpack reason
|
||||
result status reason
|
||||
reason' = maybe "?" reason $ J.decodeStrict' respBody
|
||||
logDebug $ "APNS response: " <> T.pack (show status) <> " " <> reason'
|
||||
result status reason'
|
||||
where
|
||||
result :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
result status reason
|
||||
result status reason'
|
||||
| status == Just N.ok200 = pure ()
|
||||
| status == Just N.badRequest400 =
|
||||
case reason of
|
||||
case reason' of
|
||||
"BadDeviceToken" -> throwError PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
|
||||
"TopicDisallowed" -> throwError PPPermanentError
|
||||
_ -> err status reason
|
||||
| status == Just N.forbidden403 = case reason of
|
||||
_ -> err status reason'
|
||||
| status == Just N.forbidden403 = case reason' of
|
||||
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwError PPPermanentError
|
||||
_ -> err status reason
|
||||
_ -> err status reason'
|
||||
| status == Just N.gone410 = throwError PPTokenInvalid
|
||||
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater
|
||||
-- Just tooManyRequests429 -> TODO TooManyRequests - too many requests for the same token
|
||||
| otherwise = err status reason
|
||||
| otherwise = err status reason'
|
||||
err :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
err s r = throwError $ PPResponseError s r
|
||||
liftHTTPS2 a = ExceptT $ first PPConnection <$> a
|
||||
|
||||
@@ -28,14 +28,15 @@ data NtfTknData = NtfTknData
|
||||
{ token :: DeviceToken,
|
||||
tknStatus :: TVar NtfTknStatus,
|
||||
tknVerifyKey :: C.APublicVerifyKey,
|
||||
tknDhKeys :: C.KeyPair 'C.X25519,
|
||||
tknDhSecret :: C.DhSecretX25519,
|
||||
tknRegCode :: NtfRegCode
|
||||
}
|
||||
|
||||
mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
|
||||
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret tknRegCode = do
|
||||
mkNtfTknData :: NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
|
||||
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do
|
||||
tknStatus <- newTVar NTRegistered
|
||||
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret, tknRegCode}
|
||||
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode}
|
||||
|
||||
-- data NtfSubscriptionsStore = NtfSubscriptionsStore
|
||||
|
||||
@@ -64,8 +65,12 @@ getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st)
|
||||
|
||||
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
|
||||
addNtfToken st tknId tkn@NtfTknData {token} = do
|
||||
TM.insert tknId tkn (tokens st)
|
||||
TM.insert token tknId (tokenIds st)
|
||||
TM.insert tknId tkn $ tokens st
|
||||
TM.insert token tknId $ tokenIds st
|
||||
|
||||
deleteNtfToken :: NtfStore -> NtfTokenId -> STM ()
|
||||
deleteNtfToken st tknId = do
|
||||
TM.lookupDelete tknId (tokens st) >>= mapM_ (\NtfTknData {token} -> TM.delete token $ tokenIds st)
|
||||
|
||||
-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e))
|
||||
-- getNtfRec st ent entId = case ent of
|
||||
|
||||
@@ -9,11 +9,13 @@ module Simplex.Messaging.Transport.Client.HTTP2 where
|
||||
import Control.Concurrent.Async
|
||||
import Control.Exception (IOException, catch, finally)
|
||||
import qualified Control.Exception as E
|
||||
import Control.Logger.Simple (logDebug)
|
||||
import Control.Monad.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Default (def)
|
||||
import Data.Maybe (isNothing)
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import Foreign (mallocBytes)
|
||||
import Network.HPACK (BufferSize, HeaderTable)
|
||||
@@ -41,13 +43,12 @@ data HTTPS2Client = HTTPS2Client
|
||||
|
||||
data HTTP2Response = HTTP2Response
|
||||
{ response :: Response,
|
||||
respBody :: Maybe ByteString,
|
||||
respBody :: ByteString,
|
||||
respTrailers :: Maybe HeaderTable
|
||||
}
|
||||
|
||||
data HTTP2SClientConfig = HTTP2SClientConfig
|
||||
{ qSize :: Natural,
|
||||
maxBody :: Int,
|
||||
connTimeout :: Int,
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
caStoreFile :: FilePath,
|
||||
@@ -59,8 +60,7 @@ defaultHTTP2SClientConfig :: HTTP2SClientConfig
|
||||
defaultHTTP2SClientConfig =
|
||||
HTTP2SClientConfig
|
||||
{ qSize = 64,
|
||||
maxBody = 500000,
|
||||
connTimeout = 5000000,
|
||||
connTimeout = 10000000,
|
||||
tcpKeepAlive = Nothing,
|
||||
caStoreFile = "/etc/ssl/cert.pem",
|
||||
suportedTLSParams =
|
||||
@@ -112,24 +112,14 @@ getHTTPS2Client host port config@HTTP2SClientConfig {tcpKeepAlive, connTimeout,
|
||||
(req, respVar) <- atomically $ readTBQueue reqQ
|
||||
sendReq req $ \r -> do
|
||||
let writeResp respBody respTrailers = atomically $ putTMVar respVar HTTP2Response {response = r, respBody, respTrailers}
|
||||
case H.responseBodySize r of
|
||||
Just sz ->
|
||||
if sz <= maxBody config
|
||||
then do
|
||||
respBody <- getResponseBody r "" sz
|
||||
respTrailers <- join <$> mapM (const $ H.getResponseTrailers r) respBody
|
||||
writeResp respBody respTrailers
|
||||
else writeResp Nothing Nothing
|
||||
_ -> writeResp Nothing Nothing
|
||||
respBody <- getResponseBody r ""
|
||||
respTrailers <- H.getResponseTrailers r
|
||||
writeResp respBody respTrailers
|
||||
|
||||
getResponseBody :: Response -> ByteString -> Int -> IO (Maybe ByteString)
|
||||
getResponseBody r s sz =
|
||||
H.getResponseBodyChunk r >>= \chunk -> do
|
||||
if chunk == ""
|
||||
then pure (if B.length s == sz then Just s else Nothing)
|
||||
else do
|
||||
let s' = s <> chunk
|
||||
if B.length s' > sz then pure Nothing else getResponseBody r s' sz
|
||||
getResponseBody :: Response -> ByteString -> IO ByteString
|
||||
getResponseBody r s =
|
||||
H.getResponseBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getResponseBody r $ s <> chunk
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeHTTPS2Client :: HTTPS2Client -> IO ()
|
||||
|
||||
@@ -15,7 +15,7 @@ import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (DeviceToken), PushProvider (PPApple))
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), PushProvider (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Transport (ATransport (..))
|
||||
import System.Timeout
|
||||
@@ -194,7 +194,7 @@ testNotificationToken :: IO ()
|
||||
testNotificationToken = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
registerNtfToken alice $ DeviceToken PPApple "abcd"
|
||||
registerNtfToken alice $ DeviceToken PPApns "abcd"
|
||||
pure ()
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
|
||||
Reference in New Issue
Block a user