diff --git a/apps/ntf-server/Main.hs b/apps/ntf-server/Main.hs index 478077e25..41625d6f2 100644 --- a/apps/ntf-server/Main.hs +++ b/apps/ntf-server/Main.hs @@ -2,6 +2,7 @@ module Main where +import Control.Logger.Simple import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig) import Simplex.Messaging.Notifications.Server (runNtfServer) import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..)) @@ -15,8 +16,13 @@ cfgPath = "/etc/opt/simplex-notifications" logPath :: FilePath logPath = "/var/opt/simplex-notifications" +logCfg :: LogConfig +logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} + main :: IO () -main = protocolServerCLI ntfServerCLIConfig runNtfServer +main = do + setLogLevel LogDebug -- TODO change to LogError in production + withGlobalLogging logCfg $ protocolServerCLI ntfServerCLIConfig runNtfServer ntfServerCLIConfig :: ServerCLIConfig NtfServerConfig ntfServerCLIConfig = diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 705b3d248..4ed89185a 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -63,7 +63,7 @@ import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) -import Data.Bifunctor (first, second) +import Data.Bifunctor (bimap, first, second) import Data.ByteString.Char8 (ByteString) import Data.Composition ((.:), (.:.)) import Data.Functor (($>)) @@ -87,7 +87,7 @@ import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Ratchet as CR import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Client -import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfTknStatus (..)) +import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..)) import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Protocol (BrokerMsg, MsgBody) import qualified Simplex.Messaging.Protocol as SMP @@ -165,8 +165,8 @@ registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () registerNtfToken c = withAgentEnv c . registerNtfToken' c -- | Verify device notifications token -verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m () -verifyNtfToken c = withAgentEnv c .: verifyNtfToken' c +verifyNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> ByteString -> C.CbNonce -> m () +verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c -- | Enable/disable periodic notifications enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m () @@ -523,60 +523,78 @@ setSMPServers' c servers = do registerNtfToken' :: forall m. AgentMonad m => AgentClient -> DeviceToken -> m () registerNtfToken' c deviceToken = withStore (`getDeviceNtfToken` deviceToken) >>= \case - Just tkn@NtfToken {ntfTokenId, ntfTknAction} -> case (ntfTokenId, ntfTknAction) of - (Nothing, Just (NTARegister ntfPubKey)) -> registerToken tkn ntfPubKey - -- TODO request verification code again in case there is registration, but no verification code in DB - probably after some timeout? - (Just tknId, Just (NTAVerify code)) -> - t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code - (Just tknId, Just (NTACron interval)) -> - t tkn (NTActive, Just NTACheck) $ agentNtfEnableCron c tknId tkn interval - (Just _tknId, Just NTACheck) -> pure () - (Just tknId, Just NTADelete) -> - t tkn (NTExpired, Nothing) $ agentNtfDeleteToken c tknId tkn - _ -> pure () + (Just tkn@NtfToken {ntfTokenId, ntfTknStatus, ntfTknAction}, prevTokens) -> do + mapM_ (deleteToken_ c) prevTokens + case (ntfTokenId, ntfTknAction) of + (Nothing, Just NTARegister) -> registerToken tkn + -- TODO minimal time before repeat registration + (Just _, Nothing) -> when (ntfTknStatus == NTRegistered) $ registerToken tkn + (Just tknId, Just (NTAVerify code)) -> + t tkn (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code + (Just tknId, Just (NTACron interval)) -> + t tkn (cronSuccess interval) $ agentNtfEnableCron c tknId tkn interval + (Just _tknId, Just NTACheck) -> pure () -- TODO + -- agentNtfCheckToken c tknId tkn >>= \case + (Just tknId, Just NTADelete) -> do + agentNtfDeleteToken c tknId tkn + withStore $ \st -> removeNtfToken st tkn + _ -> pure () _ -> getNtfServer c >>= \case Just ntfServer -> asks (cmdSignAlg . config) >>= \case C.SignAlg a -> do - (ntfPubKey, ntfPrivKey) <- liftIO $ C.generateSignatureKeyPair a - let tkn = newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey + tknKeys <- liftIO $ C.generateSignatureKeyPair a + dhKeys <- liftIO C.generateKeyPair' + let tkn = newNtfToken deviceToken ntfServer tknKeys dhKeys withStore $ \st -> createNtfToken st tkn - registerToken tkn ntfPubKey + registerToken tkn _ -> throwError $ CMD PROHIBITED where t tkn = withToken tkn Nothing - registerToken :: NtfToken -> C.APublicVerifyKey -> m () - registerToken tkn ntfPubKey = do - (pubDhKey, privDhKey) <- liftIO C.generateKeyPair' + registerToken :: NtfToken -> m () + registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey let dhSecret = C.dh' srvPubDhKey privDhKey withStore $ \st -> updateNtfTokenRegistration st tkn tknId dhSecret -verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> NtfRegCode -> m () -verifyNtfToken' c deviceToken code = +-- TODO decrypt verification code +verifyNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> ByteString -> C.CbNonce -> m () +verifyNtfToken' c deviceToken code nonce = withStore (`getDeviceNtfToken` deviceToken) >>= \case - Just tkn@NtfToken {ntfTokenId = Just tknId} -> - withToken tkn (Just (NTConfirmed, NTAVerify code)) (NTActive, Just NTACheck) $ - agentNtfVerifyToken c tknId tkn code + (Just tkn@NtfToken {ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret}, _) -> do + code' <- liftEither . bimap cryptoError NtfRegCode $ C.cbDecrypt dhSecret nonce code + withToken tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $ + agentNtfVerifyToken c tknId tkn code' _ -> throwError $ CMD PROHIBITED enableNtfCron' :: AgentMonad m => AgentClient -> DeviceToken -> Word16 -> m () -enableNtfCron' c deviceToken interval = +enableNtfCron' c deviceToken interval = do + when (interval < 20) . throwError $ CMD PROHIBITED withStore (`getDeviceNtfToken` deviceToken) >>= \case - Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive} -> - withToken tkn (Just (NTActive, NTACron interval)) (NTActive, Just NTACheck) $ + (Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive}, _) -> + withToken tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $ agentNtfEnableCron c tknId tkn interval _ -> throwError $ CMD PROHIBITED +cronSuccess :: Word16 -> (NtfTknStatus, Maybe NtfTknAction) +cronSuccess interval + | interval == 0 = (NTActive, Just NTACheck) + | otherwise = (NTActive, Just $ NTACron interval) + deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () deleteNtfToken' c deviceToken = withStore (`getDeviceNtfToken` deviceToken) >>= \case - Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus} -> - withToken tkn (Just (ntfTknStatus, NTADelete)) (NTExpired, Nothing) $ - agentNtfDeleteToken c tknId tkn + (Just tkn, _) -> deleteToken_ c tkn _ -> throwError $ CMD PROHIBITED +deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m () +deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do + forM_ ntfTokenId $ \tknId -> do + withStore $ \st -> updateNtfToken st tkn ntfTknStatus (Just NTADelete) + agentNtfDeleteToken c tknId tkn + withStore $ \st -> removeNtfToken st tkn + withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a withToken tkn from_ (toStatus, toAction_) f = do forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8d51a3c9b..8108f36fd 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -25,6 +25,7 @@ module Simplex.Messaging.Agent.Client sendAgentMessage, agentNtfRegisterToken, agentNtfVerifyToken, + agentNtfCheckToken, agentNtfDeleteToken, agentNtfEnableCron, agentCbEncrypt, @@ -478,6 +479,10 @@ agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code = withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code +agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus +agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} = + withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId + agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m () agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} = withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 5a2c246ac..6931f53be 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -81,9 +81,10 @@ class Monad m => MonadAgentStore s m where createNtfToken :: s -> NtfToken -> m () -- TODO this should also return old tokens so that they are deleted from the server - getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken) -- return current token if it exists + getDeviceNtfToken :: s -> DeviceToken -> m (Maybe NtfToken, [NtfToken]) -- return current token if it exists updateNtfTokenRegistration :: s -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m () updateNtfToken :: s -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> m () + removeNtfToken :: s -> NtfToken -> m () -- * Queue types diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 52235f852..0400ef07e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -33,12 +33,12 @@ import Control.Exception (bracket) import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Crypto.Random (ChaChaDRG, randomBytesGenerate) -import Data.Bifunctor (second) +import Data.Bifunctor (first, second) import Data.ByteString (ByteString) import qualified Data.ByteString.Base64.URL as U import Data.Char (toLower) import Data.Functor (($>)) -import Data.List (find, foldl') +import Data.List (find, foldl', partition) import qualified Data.Map.Strict as M import Data.Maybe (listToMaybe) import Data.Text (Text) @@ -567,35 +567,36 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk) createNtfToken :: SQLiteStore -> NtfToken -> m () - createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} = + createNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction} = liftIO . withTransaction st $ \db -> do upsertNtfServer_ db srv DB.execute db [sql| INSERT INTO ntf_tokens - (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_priv_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_pub_key, tkn_priv_key, tkn_pub_dh_key, tkn_priv_dh_key, tkn_dh_secret, tkn_status, tkn_action) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) |] - (provider, token, host, port, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) + (provider, token, host, port, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) - getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken) - getDeviceNtfToken st t@(DeviceToken provider token) = - liftIO . withTransaction st $ \db -> - fmap ntfToken . listToMaybe - <$> DB.query - db - [sql| - SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, - t.tkn_id, t.tkn_priv_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action - FROM ntf_tokens t - JOIN ntf_servers s USING (ntf_host, ntf_port) - WHERE t.provider = ? AND t.device_token = ? - |] - (provider, token) + getDeviceNtfToken :: SQLiteStore -> DeviceToken -> m (Maybe NtfToken, [NtfToken]) + getDeviceNtfToken st t = + liftIO . withTransaction st $ \db -> do + tokens <- + map ntfToken + <$> DB.query_ + db + [sql| + SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, t.tkn_status, t.tkn_action + FROM ntf_tokens t + JOIN ntf_servers s USING (ntf_host, ntf_port) + |] + pure . first listToMaybe $ partition ((t ==) . deviceToken) tokens where - ntfToken (host, port, keyHash, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) = + ntfToken (host, port, keyHash, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction) = let ntfServer = ProtocolServer {host, port, keyHash} - in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPrivKey, ntfDhSecret, ntfTknStatus, ntfTknAction} + ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) + in NtfToken {deviceToken = t, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction} updateNtfTokenRegistration :: SQLiteStore -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> m () updateNtfTokenRegistration st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = @@ -623,6 +624,17 @@ instance (MonadUnliftIO m, MonadError StoreError m) => MonadAgentStore SQLiteSto |] (tknStatus, tknAction, updatedAt, provider, token, host, port) + removeNtfToken :: SQLiteStore -> NtfToken -> m () + removeNtfToken st NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} = + liftIO . withTransaction st $ \db -> + DB.execute + db + [sql| + DELETE FROM ntf_tokens + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (provider, token, host, port) + -- * Auxiliary helpers instance ToField QueueStatus where toField = toField . serializeQueueStatus diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs index 3587e05b1..4b93c0db9 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs @@ -23,7 +23,10 @@ CREATE TABLE ntf_tokens ( ntf_host TEXT NOT NULL, ntf_port TEXT NOT NULL, tkn_id BLOB, -- token ID assigned by notifications server - tkn_priv_key BLOB NOT NULL, -- private key to sign token commands + tkn_pub_key BLOB NOT NULL, -- client's public key to verify token commands (used by server, for repeat registraions) + tkn_priv_key BLOB NOT NULL, -- client's private key to sign token commands + tkn_pub_dh_key BLOB NOT NULL, -- client's public DH key (for repeat registraions) + tkn_priv_dh_key BLOB NOT NULL, -- client's private DH key (for repeat registraions) tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications tkn_status TEXT NOT NULL, tkn_action BLOB, diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 9f275e805..7caa6dedb 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -52,6 +52,7 @@ module Simplex.Messaging.Crypto CryptoPublicKey (..), CryptoPrivateKey (..), KeyPair, + ASignatureKeyPair, DhSecret (..), DhSecretX25519, ADhSecret (..), @@ -870,6 +871,14 @@ cbDecrypt secret (CbNonce nonce) packet newtype CbNonce = CbNonce {unCbNonce :: ByteString} deriving (Show) +instance StrEncoding CbNonce where + strEncode (CbNonce s) = strEncode s + strP = cbNonce <$> strP + +instance ToJSON CbNonce where + toJSON = strToJSON + toEncoding = strToJEncoding + cbNonce :: ByteString -> CbNonce cbNonce s | len == 24 = CbNonce s diff --git a/src/Simplex/Messaging/Notifications/Client.hs b/src/Simplex/Messaging/Notifications/Client.hs index 572831c8b..f7dc4fc52 100644 --- a/src/Simplex/Messaging/Notifications/Client.hs +++ b/src/Simplex/Messaging/Notifications/Client.hs @@ -33,6 +33,12 @@ ntfRegisterToken c pKey newTkn = ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO () ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId +ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus +ntfCheckToken c pKey tknId = + sendNtfCommand c (Just pKey) tknId TCHK >>= \case + NRTkn stat -> pure stat + _ -> throwE PCEUnexpectedResponse + ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO () ntfDeleteToken = okNtfCommand TDEL @@ -48,7 +54,7 @@ ntfCreateSubsciption c pKey newSub = ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus ntfCheckSubscription c pKey subId = sendNtfCommand c (Just pKey) subId SCHK >>= \case - NRStat stat -> pure stat + NRSub stat -> pure stat _ -> throwE PCEUnexpectedResponse ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO () @@ -65,7 +71,7 @@ okNtfCommand cmd c pKey entId = _ -> throwE PCEUnexpectedResponse data NtfTknAction - = NTARegister C.APublicVerifyKey -- public key to send to the server + = NTARegister | NTAVerify NtfRegCode -- code to verify token | NTACheck | NTACron Word16 @@ -74,14 +80,14 @@ data NtfTknAction instance Encoding NtfTknAction where smpEncode = \case - NTARegister key -> smpEncode ('R', key) + NTARegister -> "R" NTAVerify code -> smpEncode ('V', code) NTACheck -> "C" NTACron interval -> smpEncode ('I', interval) NTADelete -> "D" smpP = A.anyChar >>= \case - 'R' -> NTARegister <$> smpP + 'R' -> pure NTARegister 'V' -> NTAVerify <$> smpP 'C' -> pure NTACheck 'I' -> NTACron <$> smpP @@ -96,8 +102,12 @@ data NtfToken = NtfToken { deviceToken :: DeviceToken, ntfServer :: NtfServer, ntfTokenId :: Maybe NtfTokenId, + -- | key used by the ntf server to verify transmissions + ntfPubKey :: C.APublicVerifyKey, -- | key used by the ntf client to sign transmissions ntfPrivKey :: C.APrivateSignKey, + -- | client's DH keys (to repeat registration if necessary) + ntfDhKeys :: C.KeyPair 'C.X25519, -- | shared DH secret used to encrypt/decrypt notifications e2e ntfDhSecret :: Maybe C.DhSecretX25519, -- | token status @@ -107,14 +117,16 @@ data NtfToken = NtfToken } deriving (Show) -newNtfToken :: DeviceToken -> NtfServer -> C.APrivateSignKey -> C.APublicVerifyKey -> NtfToken -newNtfToken deviceToken ntfServer ntfPrivKey ntfPubKey = +newNtfToken :: DeviceToken -> NtfServer -> C.ASignatureKeyPair -> C.KeyPair 'C.X25519 -> NtfToken +newNtfToken deviceToken ntfServer (ntfPubKey, ntfPrivKey) ntfDhKeys = NtfToken { deviceToken, ntfServer, ntfTokenId = Nothing, + ntfPubKey, ntfPrivKey, + ntfDhKeys, ntfDhSecret = Nothing, ntfTknStatus = NTNew, - ntfTknAction = Just $ NTARegister ntfPubKey - } \ No newline at end of file + ntfTknAction = Just NTARegister + } diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 17c6ece56..d5985d213 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -16,6 +16,7 @@ import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Kind import Data.Maybe (isNothing) +import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Type.Equality import Data.Word (Word16) import Database.SQLite.Simple.FromField (FromField (..)) @@ -51,6 +52,7 @@ instance NtfEntityI 'Subscription where sNtfEntity = SSubscription data NtfCommandTag (e :: NtfEntity) where TNEW_ :: NtfCommandTag 'Token TVFY_ :: NtfCommandTag 'Token + TCHK_ :: NtfCommandTag 'Token TDEL_ :: NtfCommandTag 'Token TCRN_ :: NtfCommandTag 'Token SNEW_ :: NtfCommandTag 'Subscription @@ -66,6 +68,7 @@ instance NtfEntityI e => Encoding (NtfCommandTag e) where smpEncode = \case TNEW_ -> "TNEW" TVFY_ -> "TVFY" + TCHK_ -> "TCHK" TDEL_ -> "TDEL" TCRN_ -> "TCRN" SNEW_ -> "SNEW" @@ -82,6 +85,7 @@ instance ProtocolMsgTag NtfCmdTag where decodeTag = \case "TNEW" -> Just $ NCT SToken TNEW_ "TVFY" -> Just $ NCT SToken TVFY_ + "TCHK" -> Just $ NCT SToken TCHK_ "TDEL" -> Just $ NCT SToken TDEL_ "TCRN" -> Just $ NCT SToken TCRN_ "SNEW" -> Just $ NCT SSubscription SNEW_ @@ -147,6 +151,8 @@ data NtfCommand (e :: NtfEntity) where TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token -- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token TVFY :: NtfRegCode -> NtfCommand 'Token + -- | check token status + TCHK :: NtfCommand 'Token -- | delete token - all subscriptions will be removed and no more notifications will be sent TDEL :: NtfCommand 'Token -- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable @@ -171,6 +177,7 @@ instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where encodeProtocol = \case TNEW newTkn -> e (TNEW_, ' ', newTkn) TVFY code -> e (TVFY_, ' ', code) + TCHK -> e TCHK_ TDEL -> e TDEL_ TCRN int -> e (TCRN_, ' ', int) SNEW newSub -> e (SNEW_, ' ', newSub) @@ -209,6 +216,7 @@ instance ProtocolEncoding NtfCmd where NtfCmd SToken <$> case tag of TNEW_ -> TNEW <$> _smpP TVFY_ -> TVFY <$> _smpP + TCHK_ -> pure TCHK TDEL_ -> pure TDEL TCRN_ -> TCRN <$> _smpP NCT SSubscription tag -> @@ -224,7 +232,8 @@ data NtfResponseTag = NRId_ | NROk_ | NRErr_ - | NRStat_ + | NRTkn_ + | NRSub_ | NRPong_ deriving (Show) @@ -233,7 +242,8 @@ instance Encoding NtfResponseTag where NRId_ -> "ID" NROk_ -> "OK" NRErr_ -> "ERR" - NRStat_ -> "STAT" + NRTkn_ -> "TKN" + NRSub_ -> "SUB" NRPong_ -> "PONG" smpP = messageTagP @@ -242,7 +252,8 @@ instance ProtocolMsgTag NtfResponseTag where "ID" -> Just NRId_ "OK" -> Just NROk_ "ERR" -> Just NRErr_ - "STAT" -> Just NRStat_ + "TKN" -> Just NRTkn_ + "SUB" -> Just NRSub_ "PONG" -> Just NRPong_ _ -> Nothing @@ -250,7 +261,8 @@ data NtfResponse = NRId NtfEntityId C.PublicKeyX25519 | NROk | NRErr ErrorType - | NRStat NtfSubStatus + | NRTkn NtfTknStatus + | NRSub NtfSubStatus | NRPong instance ProtocolEncoding NtfResponse where @@ -259,7 +271,8 @@ instance ProtocolEncoding NtfResponse where NRId entId dhKey -> e (NRId_, ' ', entId, dhKey) NROk -> e NROk_ NRErr err -> e (NRErr_, ' ', err) - NRStat stat -> e (NRStat_, ' ', stat) + NRTkn stat -> e (NRTkn_, ' ', stat) + NRSub stat -> e (NRSub_, ' ', stat) NRPong -> e NRPong_ where e :: Encoding a => a -> ByteString @@ -269,7 +282,8 @@ instance ProtocolEncoding NtfResponse where NRId_ -> NRId <$> _smpP <*> smpP NROk_ -> pure NROk NRErr_ -> NRErr <$> _smpP - NRStat_ -> NRStat <$> _smpP + NRTkn_ -> NRTkn <$> _smpP + NRSub_ -> NRSub <$> _smpP NRPong_ -> pure NRPong checkCredentials (_, _, entId, _) cmd = case cmd of @@ -301,22 +315,22 @@ instance Encoding SMPQueueNtf where (smpServer, notifierId, notifierKey) <- smpP pure $ SMPQueueNtf smpServer notifierId notifierKey -data PushProvider = PPApple +data PushProvider = PPApns deriving (Eq, Ord, Show) instance Encoding PushProvider where smpEncode = \case - PPApple -> "A" + PPApns -> "A" smpP = A.anyChar >>= \case - 'A' -> pure PPApple + 'A' -> pure PPApns _ -> fail "bad PushProvider" instance TextEncoding PushProvider where textEncode = \case - PPApple -> "apple" + PPApns -> "apple" textDecode = \case - "apple" -> Just PPApple + "apple" -> Just PPApns _ -> Nothing instance FromField PushProvider where fromField = fromTextField_ textDecode @@ -363,7 +377,7 @@ instance Encoding NtfSubStatus where "ACTIVE" -> pure NSActive "END" -> pure NSEnd "SMP_AUTH" -> pure NSSMPAuth - _ -> fail "bad NtfError" + _ -> fail "bad NtfSubStatus" data NtfTknStatus = -- | Token created in DB @@ -380,26 +394,27 @@ data NtfTknStatus NTExpired deriving (Eq, Show) -instance TextEncoding NtfTknStatus where - textEncode = \case - NTNew -> "new" - NTRegistered -> "registered" - NTInvalid -> "invalid" - NTConfirmed -> "confirmed" - NTActive -> "active" - NTExpired -> "expired" - textDecode = \case - "new" -> Just NTNew - "registered" -> Just NTRegistered - "invalid" -> Just NTInvalid - "confirmed" -> Just NTConfirmed - "active" -> Just NTActive - "expired" -> Just NTExpired - _ -> Nothing +instance Encoding NtfTknStatus where + smpEncode = \case + NTNew -> "NEW" + NTRegistered -> "REGISTERED" + NTInvalid -> "INVALID" + NTConfirmed -> "CONFIRMED" + NTActive -> "ACTIVE" + NTExpired -> "EXPIRED" + smpP = + A.takeTill (== ' ') >>= \case + "NEW" -> pure NTNew + "REGISTERED" -> pure NTRegistered + "INVALID" -> pure NTInvalid + "CONFIRMED" -> pure NTConfirmed + "ACTIVE" -> pure NTActive + "EXPIRED" -> pure NTExpired + _ -> fail "bad NtfTknStatus" -instance FromField NtfTknStatus where fromField = fromTextField_ textDecode +instance FromField NtfTknStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8 -instance ToField NtfTknStatus where toField = toField . textEncode +instance ToField NtfTknStatus where toField = toField . decodeLatin1 . smpEncode checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e) checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 942b2db42..5d81b3715 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -10,11 +10,13 @@ module Simplex.Messaging.Notifications.Server where +import Control.Logger.Simple import Control.Monad.Except import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Reader import Crypto.Random (MonadRandom) import Data.ByteString.Char8 (ByteString) +import qualified Data.Text as T import Network.Socket (ServiceName) import Simplex.Messaging.Client.Agent import qualified Simplex.Messaging.Crypto as C @@ -26,9 +28,12 @@ import Simplex.Messaging.Notifications.Transport import Simplex.Messaging.Protocol (ErrorType (..), SignedTransmission, Transmission, encodeTransmission, tGet, tPut) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server +import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport (..), THandle (..), TProxy, Transport) import Simplex.Messaging.Transport.Server (runTransportServer) import Simplex.Messaging.Util +import UnliftIO (async, uninterruptibleCancel) +import UnliftIO.Concurrent (threadDelay) import UnliftIO.Exception import UnliftIO.STM @@ -106,7 +111,7 @@ ntfSubscriber NtfSubscriber {subQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} ntfPush :: MonadUnliftIO m => NtfPushServer -> m () ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do (tkn@NtfTknData {token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ) - liftIO $ putStrLn $ "sending push notification to " <> show pp + logDebug $ "sending push notification to " <> T.pack (show pp) status <- readTVarIO tknStatus case (status, ntf) of (_, PNVerification _) -> do @@ -116,13 +121,13 @@ ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do (NTActive, PNCheckMessages) -> do deliverNotification pp tkn ntf _ -> do - liftIO $ putStrLn "bad notification token status" + logError "bad notification token status" where deliverNotification :: PushProvider -> PushProviderClient deliverNotification pp tkn ntf = do deliver <- liftIO $ getPushClient s pp -- TODO retry later based on the error - deliver tkn ntf `catchError` \e -> liftIO (putStrLn $ "Push provider error (" <> show pp <> "): " <> show e) >> throwError e + deliver tkn ntf `catchError` \e -> logError (T.pack $ "Push provider error (" <> show pp <> "): " <> show e) >> throwError e runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m () runNtfClientTransport th@THandle {sessionId} = do @@ -138,14 +143,14 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m () receive th NtfServerClient {rcvQ, sndQ} = forever $ do - t@(_, _, (corrId, subId, cmdOrError)) <- tGet th - liftIO $ putStrLn "receive" + t@(_, _, (corrId, entId, cmdOrError)) <- tGet th + logDebug "received transmission" case cmdOrError of - Left e -> write sndQ (corrId, subId, NRErr e) + Left e -> write sndQ (corrId, entId, NRErr e) Right cmd -> verifyNtfTransmission t cmd >>= \case VRVerified req -> write rcvQ req - VRFailed -> write sndQ (corrId, subId, NRErr AUTH) + VRFailed -> write sndQ (corrId, entId, NRErr AUTH) where write q t = atomically $ writeTBQueue q t @@ -159,41 +164,31 @@ data VerificationResult = VRVerified NtfRequest | VRFailed verifyNtfTransmission :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do + st <- asks store case cmd of - NtfCmd SToken (TNEW n@(NewNtfTkn _ k _)) -> - -- TODO check that token is not already in store, if it is - verify that the saved key is the same + NtfCmd SToken c@(TNEW n@(NewNtfTkn _ k _)) -> do + r_ <- atomically $ getNtfToken st entId pure $ if verifyCmdSignature sig_ signed k - then VRVerified (NtfReqNew corrId (ANE SToken n)) + then case r_ of + Just r@(NtfTkn NtfTknData {tknVerifyKey}) + | k == tknVerifyKey -> tknCmd r c + | otherwise -> VRFailed + _ -> VRVerified (NtfReqNew corrId (ANE SToken n)) else VRFailed NtfCmd SToken c -> do - st <- asks store - atomically (getNtfToken st entId) >>= \case - Just r@(NtfTkn NtfTknData {tknVerifyKey}) -> - pure $ - if verifyCmdSignature sig_ signed tknVerifyKey - then VRVerified (NtfReqCmd SToken r (corrId, entId, c)) - else VRFailed - _ -> pure VRFailed -- TODO dummy verification + r_ <- atomically $ getNtfToken st entId + pure $ case r_ of + Just r@(NtfTkn NtfTknData {tknVerifyKey}) + | verifyCmdSignature sig_ signed tknVerifyKey -> tknCmd r c + | otherwise -> VRFailed + _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed _ -> pure VRFailed - --- do --- st <- asks store --- case cmd of --- NCSubCreate tokenId smpQueue -> verifyCreateCmd verifyKey newSub <$> atomically (getNtfSubViaSMPQueue st smpQueue) --- _ -> verifySubCmd <$> atomically (getNtfSub st subId) --- where --- verifyCreateCmd k newSub sub_ --- | verifyCmdSignature sig_ signed k = case sub_ of --- Just sub -> if k == subVerifyKey sub then VRCommand sub else VRFail --- _ -> VRCreate newSub --- | otherwise = VRFail --- verifySubCmd = \case --- Just sub -> if verifyCmdSignature sig_ signed $ subVerifyKey sub then VRCommand sub else VRFail --- _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFail + where + tknCmd r c = VRVerified (NtfReqCmd SToken r (corrId, entId, c)) client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m () -client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ} = +client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ, intervalNotifiers} = forever $ atomically (readTBQueue rcvQ) >>= processCommand @@ -202,30 +197,68 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push processCommand :: NtfRequest -> m (Transmission NtfResponse) processCommand = \case NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do - liftIO $ putStrLn "TNEW" + logDebug "TNEW - new token" st <- asks store - (srvDhPubKey, srvDrivDhKey) <- liftIO C.generateKeyPair' - let dhSecret = C.dh' dhPubKey srvDrivDhKey + ks@(srvDhPubKey, srvDhPrivKey) <- liftIO C.generateKeyPair' + let dhSecret = C.dh' dhPubKey srvDhPrivKey tknId <- getId regCode <- getRegCode atomically $ do - tkn <- mkNtfTknData newTkn dhSecret regCode + tkn <- mkNtfTknData newTkn ks dhSecret regCode addNtfToken st tknId tkn writeTBQueue pushQ (tkn, PNVerification regCode) pure (corrId, "", NRId tknId srvDhPubKey) - NtfReqCmd SToken (NtfTkn NtfTknData {tknStatus, tknRegCode}) (corrId, tknId, cmd) -> do + NtfReqCmd SToken (NtfTkn tkn@NtfTknData {tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do status <- readTVarIO tknStatus (corrId,tknId,) <$> case cmd of - TNEW _newTkn -> do - liftIO $ putStrLn "TNEW'" - pure NROk -- TODO when duplicate token sent + TNEW (NewNtfTkn _ _ dhPubKey) -> do + logDebug "TNEW - registered token" + let dhSecret = C.dh' dhPubKey srvDhPrivKey + -- it is required that DH secret is the same, to avoid failed verifications if notification is delaying + if tknDhSecret == dhSecret + then do + atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode) + pure $ NRId tknId srvDhPubKey + else pure $ NRErr AUTH TVFY code -- this allows repeated verification for cases when client connection dropped before server response | (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do + logDebug "TVFY - token verified" atomically $ writeTVar tknStatus NTActive pure NROk - | otherwise -> pure $ NRErr AUTH - TDEL -> pure NROk - TCRN _int -> pure NROk + | otherwise -> do + logDebug "TVFY - incorrect code or token status" + pure $ NRErr AUTH + TCHK -> pure $ NRTkn status + TDEL -> do + logDebug "TDEL" + st <- asks store + atomically $ deleteNtfToken st tknId + pure NROk + TCRN 0 -> + logDebug "TCRN 0" + >> atomically (TM.lookupDelete tknId intervalNotifiers) + >>= mapM_ (uninterruptibleCancel . action) + >> pure NROk + TCRN int + | int < 20 -> pure $ NRErr QUOTA + | otherwise -> do + logDebug "TCRN" + atomically (TM.lookup tknId intervalNotifiers) >>= \case + Nothing -> runIntervalNotifier int + Just IntervalNotifier {interval, action} -> + unless (interval == int) $ do + uninterruptibleCancel action + runIntervalNotifier int + pure NROk + where + runIntervalNotifier interval = do + action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60 + let notifier = IntervalNotifier {action, token = tkn, interval} + atomically $ TM.insert tknId notifier intervalNotifiers + where + intervalNotifier delay = forever $ do + threadDelay delay + atomically $ writeTBQueue pushQ (tkn, PNCheckMessages) NtfReqNew corrId (ANE SSubscription _newSub) -> pure (corrId, "", NROk) NtfReqCmd SSubscription _sub (corrId, subId, cmd) -> (corrId,subId,) <$> case cmd of diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 67b83f1af..9c94c5ace 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -7,9 +7,11 @@ module Simplex.Messaging.Notifications.Server.Env where +import Control.Concurrent.Async (Async) import Control.Monad.IO.Unlift import Crypto.Random import Data.ByteString.Char8 (ByteString) +import Data.Word (Word16) import Data.X509.Validation (Fingerprint (..)) import Network.Socket import qualified Network.TLS as T @@ -59,7 +61,7 @@ newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsCo subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig -- TODO not creating APNS client on start to pass CI test, has to be replaced with mock APNS server - -- void . liftIO $ newPushClient pushServer PPApple + -- void . liftIO $ newPushClient pushServer PPApns tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile pure NtfEnv {config, subscriber, pushServer, store, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp} @@ -78,20 +80,28 @@ newNtfSubscriber qSize smpAgentCfg = do data NtfPushServer = NtfPushServer { pushQ :: TBQueue (NtfTknData, PushNotification), pushClients :: TMap PushProvider PushProviderClient, + intervalNotifiers :: TMap NtfTokenId IntervalNotifier, apnsConfig :: APNSPushClientConfig } +data IntervalNotifier = IntervalNotifier + { action :: Async (), + token :: NtfTknData, + interval :: Word16 + } + newNtfPushServer :: Natural -> APNSPushClientConfig -> STM NtfPushServer newNtfPushServer qSize apnsConfig = do pushQ <- newTBQueue qSize pushClients <- TM.empty - pure NtfPushServer {pushQ, pushClients, apnsConfig} + intervalNotifiers <- TM.empty + pure NtfPushServer {pushQ, pushClients, intervalNotifiers, apnsConfig} newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient newPushClient NtfPushServer {apnsConfig, pushClients} = \case - PPApple -> do + PPApns -> do c <- apnsPushProviderClient <$> createAPNSPushClient apnsConfig - atomically $ TM.insert PPApple c pushClients + atomically $ TM.insert PPApns c pushClients pure c getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index 8a67f9fcf..0143fd20b 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -1,12 +1,16 @@ +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Use newtype instead of data" #-} module Simplex.Messaging.Notifications.Server.Push.APNS where +import Control.Logger.Simple import Control.Monad.Except import Crypto.Hash.Algorithms (SHA256 (..)) import qualified Crypto.PubKey.ECC.ECDSA as EC @@ -20,6 +24,7 @@ import Data.Aeson (FromJSON, ToJSON, (.=)) import qualified Data.Aeson as J import qualified Data.Aeson.Encoding as JE import Data.Bifunctor (first) +import qualified Data.ByteString.Base64 as B64 import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Builder (lazyByteString) import Data.ByteString.Char8 (ByteString) @@ -27,7 +32,6 @@ import qualified Data.ByteString.Lazy.Char8 as LB import qualified Data.CaseInsensitive as CI import Data.Int (Int64) import Data.Map.Strict (Map) -import Data.Maybe (fromMaybe) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8With) @@ -226,7 +230,7 @@ getApnsJWTToken APNSPushClient {apnsCfg = APNSPushClientConfig {appTeamId, token atomically $ writeTVar jwtToken t pure signedJWT' where - jwtTokenAge (JWTToken _ JWTClaims {iat}) = (iat -) . systemSeconds <$> getSystemTime + jwtTokenAge (JWTToken _ JWTClaims {iat}) = subtract iat . systemSeconds <$> getSystemTime mkApnsJWTToken :: Text -> JWTHeader -> EC.PrivateKey -> IO (JWTToken, SignedJWTToken) mkApnsJWTToken appTeamId jwtHeader privateKey = do @@ -256,15 +260,15 @@ apnsNotification :: NtfTknData -> C.CbNonce -> Int -> PushNotification -> Either apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case PNVerification (NtfRegCode code) -> encrypt code $ \code' -> - apn APNSBackground {contentAvailable = 1} . Just $ J.object ["verification" .= code'] + apn APNSBackground {contentAvailable = 1} . Just $ J.object ["verification" .= code', "nonce" .= nonce] PNMessage srv nId -> encrypt (strEncode srv <> "/" <> strEncode nId) $ \ntfQueue -> - apn apnMutableContent . Just $ J.object ["checkMessage" .= ntfQueue] + apn apnMutableContent . Just $ J.object ["checkMessage" .= ntfQueue, "nonce" .= nonce] PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True] where encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification - encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen + encrypt ntfData f = f . safeDecodeUtf8 . B64.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen apn aps notificationData = APNSNotification {aps, notificationData} apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or some other app event", category = Nothing} apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing} @@ -300,40 +304,39 @@ data PushProviderError type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProviderError IO () -newtype APNSErrorReponse = APNSErrorReponse {reason :: Text} +-- this is not a newtype on purpose to have a correct JSON encoding as a record +data APNSErrorReponse = APNSErrorReponse {reason :: Text} deriving (Generic, FromJSON) apnsPushProviderClient :: APNSPushClient -> PushProviderClient -apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken PPApple tknStr} pn = do +apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken PPApns tknStr} pn = do http2 <- liftHTTPS2 $ getApnsHTTP2Client c nonce <- atomically $ C.pseudoRandomCbNonce nonceDrg apnsNtf <- liftEither $ first PPCryptoError $ apnsNotification tkn nonce (paddedNtfLength apnsCfg) pn - liftIO $ putStrLn $ "APNS notification: " <> show apnsNtf req <- liftIO $ apnsRequest c tknStr apnsNtf - liftIO $ putStrLn $ "APNS request: " <> show req HTTP2Response {response, respBody} <- liftHTTPS2 $ sendRequest http2 req let status = H.responseStatus response - reason = fromMaybe "" $ J.decodeStrict' =<< respBody - liftIO $ putStrLn $ "APNS response: " <> show status <> " " <> T.unpack reason - result status reason + reason' = maybe "?" reason $ J.decodeStrict' respBody + logDebug $ "APNS response: " <> T.pack (show status) <> " " <> reason' + result status reason' where result :: Maybe Status -> Text -> ExceptT PushProviderError IO () - result status reason + result status reason' | status == Just N.ok200 = pure () | status == Just N.badRequest400 = - case reason of + case reason' of "BadDeviceToken" -> throwError PPTokenInvalid "DeviceTokenNotForTopic" -> throwError PPTokenInvalid "TopicDisallowed" -> throwError PPPermanentError - _ -> err status reason - | status == Just N.forbidden403 = case reason of + _ -> err status reason' + | status == Just N.forbidden403 = case reason' of "ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed "InvalidProviderToken" -> throwError PPPermanentError - _ -> err status reason + _ -> err status reason' | status == Just N.gone410 = throwError PPTokenInvalid | status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater -- Just tooManyRequests429 -> TODO TooManyRequests - too many requests for the same token - | otherwise = err status reason + | otherwise = err status reason' err :: Maybe Status -> Text -> ExceptT PushProviderError IO () err s r = throwError $ PPResponseError s r liftHTTPS2 a = ExceptT $ first PPConnection <$> a diff --git a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs index 366ba26c2..9123a690f 100644 --- a/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs +++ b/src/Simplex/Messaging/Notifications/Server/Subscriptions.hs @@ -28,14 +28,15 @@ data NtfTknData = NtfTknData { token :: DeviceToken, tknStatus :: TVar NtfTknStatus, tknVerifyKey :: C.APublicVerifyKey, + tknDhKeys :: C.KeyPair 'C.X25519, tknDhSecret :: C.DhSecretX25519, tknRegCode :: NtfRegCode } -mkNtfTknData :: NewNtfEntity 'Token -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData -mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhSecret tknRegCode = do +mkNtfTknData :: NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData +mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do tknStatus <- newTVar NTRegistered - pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhSecret, tknRegCode} + pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode} -- data NtfSubscriptionsStore = NtfSubscriptionsStore @@ -64,8 +65,12 @@ getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st) addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () addNtfToken st tknId tkn@NtfTknData {token} = do - TM.insert tknId tkn (tokens st) - TM.insert token tknId (tokenIds st) + TM.insert tknId tkn $ tokens st + TM.insert token tknId $ tokenIds st + +deleteNtfToken :: NtfStore -> NtfTokenId -> STM () +deleteNtfToken st tknId = do + TM.lookupDelete tknId (tokens st) >>= mapM_ (\NtfTknData {token} -> TM.delete token $ tokenIds st) -- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e)) -- getNtfRec st ent entId = case ent of diff --git a/src/Simplex/Messaging/Transport/Client/HTTP2.hs b/src/Simplex/Messaging/Transport/Client/HTTP2.hs index 8e3a35741..2d019abd3 100644 --- a/src/Simplex/Messaging/Transport/Client/HTTP2.hs +++ b/src/Simplex/Messaging/Transport/Client/HTTP2.hs @@ -9,11 +9,13 @@ module Simplex.Messaging.Transport.Client.HTTP2 where import Control.Concurrent.Async import Control.Exception (IOException, catch, finally) import qualified Control.Exception as E +import Control.Logger.Simple (logDebug) import Control.Monad.Except import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Default (def) import Data.Maybe (isNothing) +import qualified Data.Text as T import qualified Data.X509.CertificateStore as XS import Foreign (mallocBytes) import Network.HPACK (BufferSize, HeaderTable) @@ -41,13 +43,12 @@ data HTTPS2Client = HTTPS2Client data HTTP2Response = HTTP2Response { response :: Response, - respBody :: Maybe ByteString, + respBody :: ByteString, respTrailers :: Maybe HeaderTable } data HTTP2SClientConfig = HTTP2SClientConfig { qSize :: Natural, - maxBody :: Int, connTimeout :: Int, tcpKeepAlive :: Maybe KeepAliveOpts, caStoreFile :: FilePath, @@ -59,8 +60,7 @@ defaultHTTP2SClientConfig :: HTTP2SClientConfig defaultHTTP2SClientConfig = HTTP2SClientConfig { qSize = 64, - maxBody = 500000, - connTimeout = 5000000, + connTimeout = 10000000, tcpKeepAlive = Nothing, caStoreFile = "/etc/ssl/cert.pem", suportedTLSParams = @@ -112,24 +112,14 @@ getHTTPS2Client host port config@HTTP2SClientConfig {tcpKeepAlive, connTimeout, (req, respVar) <- atomically $ readTBQueue reqQ sendReq req $ \r -> do let writeResp respBody respTrailers = atomically $ putTMVar respVar HTTP2Response {response = r, respBody, respTrailers} - case H.responseBodySize r of - Just sz -> - if sz <= maxBody config - then do - respBody <- getResponseBody r "" sz - respTrailers <- join <$> mapM (const $ H.getResponseTrailers r) respBody - writeResp respBody respTrailers - else writeResp Nothing Nothing - _ -> writeResp Nothing Nothing + respBody <- getResponseBody r "" + respTrailers <- H.getResponseTrailers r + writeResp respBody respTrailers - getResponseBody :: Response -> ByteString -> Int -> IO (Maybe ByteString) - getResponseBody r s sz = - H.getResponseBodyChunk r >>= \chunk -> do - if chunk == "" - then pure (if B.length s == sz then Just s else Nothing) - else do - let s' = s <> chunk - if B.length s' > sz then pure Nothing else getResponseBody r s' sz + getResponseBody :: Response -> ByteString -> IO ByteString + getResponseBody r s = + H.getResponseBodyChunk r >>= \chunk -> + if B.null chunk then pure s else getResponseBody r $ s <> chunk -- | Disconnects client from the server and terminates client threads. closeHTTPS2Client :: HTTPS2Client -> IO () diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index a37bcff19..a08e72623 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -15,7 +15,7 @@ import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) import Simplex.Messaging.Agent.Protocol -import Simplex.Messaging.Notifications.Protocol (DeviceToken (DeviceToken), PushProvider (PPApple)) +import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), PushProvider (..)) import Simplex.Messaging.Protocol (ErrorType (..), MsgBody) import Simplex.Messaging.Transport (ATransport (..)) import System.Timeout @@ -194,7 +194,7 @@ testNotificationToken :: IO () testNotificationToken = do alice <- getSMPAgentClient cfg initAgentServers Right () <- runExceptT $ do - registerNtfToken alice $ DeviceToken PPApple "abcd" + registerNtfToken alice $ DeviceToken PPApns "abcd" pure () exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()