diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 39e351f99..a2aeb9d68 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -52,6 +52,7 @@ module Simplex.Messaging.Agent registerNtfToken, verifyNtfToken, enableNtfCron, + checkNtfToken, deleteNtfToken, logConnection, ) @@ -89,10 +90,10 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..)) import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, MsgBody) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM) +import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM, ($>>=)) import Simplex.Messaging.Version import System.Random (randomR) import UnliftIO.Async (async, race_) @@ -172,6 +173,9 @@ verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m () enableNtfCron c = withAgentEnv c .: enableNtfCron' c +checkNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m NtfTknStatus +checkNtfToken c = withAgentEnv c . checkNtfToken' c + deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m () deleteNtfToken c = withAgentEnv c . deleteNtfToken' c @@ -552,7 +556,7 @@ registerNtfToken' c deviceToken = registerToken tkn _ -> throwError $ CMD PROHIBITED where - t tkn = withToken tkn Nothing + t tkn = withToken c tkn Nothing registerToken :: NtfToken -> m () registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do (tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey @@ -565,7 +569,7 @@ verifyNtfToken' c deviceToken code nonce = withStore (`getDeviceNtfToken` deviceToken) >>= \case (Just tkn@NtfToken {ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret}, _) -> do code' <- liftEither . bimap cryptoError NtfRegCode $ C.cbDecrypt dhSecret nonce code - withToken tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $ + withToken c tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $ agentNtfVerifyToken c tknId tkn code' _ -> throwError $ CMD PROHIBITED @@ -574,7 +578,7 @@ enableNtfCron' c deviceToken interval = do when (interval < 20) . throwError $ CMD PROHIBITED withStore (`getDeviceNtfToken` deviceToken) >>= \case (Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive}, _) -> - withToken tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $ + withToken c tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $ agentNtfEnableCron c tknId tkn interval _ -> throwError $ CMD PROHIBITED @@ -583,6 +587,12 @@ cronSuccess interval | interval == 0 = (NTActive, Just NTACheck) | otherwise = (NTActive, Just $ NTACron interval) +checkNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m NtfTknStatus +checkNtfToken' c deviceToken = + withStore (`getDeviceNtfToken` deviceToken) >>= \case + (Just tkn@NtfToken {ntfTokenId = Just tknId}, _) -> agentNtfCheckToken c tknId tkn + _ -> throwError $ CMD PROHIBITED + deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m () deleteNtfToken' c deviceToken = withStore (`getDeviceNtfToken` deviceToken) >>= \case @@ -593,15 +603,23 @@ deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m () deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do forM_ ntfTokenId $ \tknId -> do withStore $ \st -> updateNtfToken st tkn ntfTknStatus (Just NTADelete) - agentNtfDeleteToken c tknId tkn + agentNtfDeleteToken c tknId tkn `catchError` \case + NTF AUTH -> pure () + e -> throwError e withStore $ \st -> removeNtfToken st tkn -withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a -withToken tkn from_ (toStatus, toAction_) f = do +withToken :: AgentMonad m => AgentClient -> NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a +withToken c tkn@NtfToken {deviceToken} from_ (toStatus, toAction_) f = do forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action) - res <- f - withStore $ \st -> updateNtfToken st tkn toStatus toAction_ - pure res + tryError f >>= \case + Right res -> do + withStore $ \st -> updateNtfToken st tkn toStatus toAction_ + pure res + Left e@(NTF AUTH) -> do + withStore $ \st -> removeNtfToken st tkn + registerNtfToken' c deviceToken + throwError e + Left e -> throwError e setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m () setNtfServers' c servers = do @@ -678,7 +696,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, sessId, rId, cmd) _ -> prohibited >> ack _ -> prohibited >> ack SMP.END -> - atomically (TM.lookup srv smpClients >>= fmap join . mapM tryReadTMVar >>= processEND) + atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND) >>= logServer "<--" c srv rId where processEND = \case diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 8108f36fd..e063ae9bc 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -8,6 +9,7 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Agent.Client ( AgentClient (..), @@ -66,7 +68,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol -import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey) +import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -132,10 +134,15 @@ type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorTy class ProtocolServerClient msg where getProtocolServerClient :: AgentMonad m => AgentClient -> ProtocolServer -> m (ProtocolClient msg) + protocolError :: ErrorType -> AgentErrorType -instance ProtocolServerClient BrokerMsg where getProtocolServerClient = getSMPServerClient +instance ProtocolServerClient BrokerMsg where + getProtocolServerClient = getSMPServerClient + protocolError = SMP -instance ProtocolServerClient NtfResponse where getProtocolServerClient = getNtfServerClient +instance ProtocolServerClient NtfResponse where + getProtocolServerClient = getNtfServerClient + protocolError = NTF getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient getSMPServerClient c@AgentClient {smpClients, msgQ} srv = @@ -148,7 +155,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = connectClient = do cfg <- asks $ smpCfg . config u <- askUnliftIO - liftEitherError protocolClientError (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) + liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) clientDisconnected :: UnliftIO m -> IO () clientDisconnected u = do @@ -207,7 +214,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv = e@PCEResponseTimeout -> throwError e e@PCENetworkError -> throwError e e -> do - liftIO $ notifySub (ERR $ protocolClientError e) connId + liftIO $ notifySub (ERR $ protocolClientError SMP e) connId atomically $ removePendingSubscription c srv connId notifySub :: ACommand 'Agent -> ConnId -> IO () @@ -223,7 +230,7 @@ getNtfServerClient c@AgentClient {ntfClients} srv = connectClient :: m NtfClient connectClient = do cfg <- asks $ ntfCfg . config - liftEitherError protocolClientError (getProtocolClient srv cfg Nothing clientDisconnected) + liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected) clientDisconnected :: IO () clientDisconnected = do @@ -322,18 +329,18 @@ withLogClient_ c srv qId cmdStr action = do logServer "<--" c srv qId "OK" return res -withClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withClient c srv action = withClient_ c srv $ liftClient . action +withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withClient c srv action = withClient_ c srv $ liftClient (protocolError @msg) . action -withLogClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a -withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient . action +withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a +withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient (protocolError @msg) . action -liftClient :: AgentMonad m => ExceptT ProtocolClientError IO a -> m a -liftClient = liftError protocolClientError +liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> ExceptT ProtocolClientError IO a -> m a +liftClient = liftError . protocolClientError -protocolClientError :: ProtocolClientError -> AgentErrorType -protocolClientError = \case - PCEProtocolError e -> SMP e +protocolClientError :: (ErrorType -> AgentErrorType) -> ProtocolClientError -> AgentErrorType +protocolClientError protocolError_ = \case + PCEProtocolError e -> protocolError_ e PCEResponseError e -> BROKER $ RESPONSE e PCEUnexpectedResponse -> BROKER UNEXPECTED PCEResponseTimeout -> BROKER TIMEOUT @@ -428,14 +435,14 @@ sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, withLogClient_ c server sndId "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg - liftClient $ sendSMPMessage smp Nothing sndId msg + liftClient SMP $ sendSMPMessage smp Nothing sndId msg sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database" sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m () sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo = withLogClient_ c smpServer senderId "SEND " $ \smp -> do msg <- mkInvitation - liftClient $ sendSMPMessage smp Nothing senderId msg + liftClient SMP $ sendSMPMessage smp Nothing senderId msg where mkInvitation :: m ByteString -- this is only encrypted with per-queue E2E, not with double ratchet @@ -469,7 +476,7 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg = withLogClient_ c server sndId "SEND " $ \smp -> do let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg - liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg + liftClient SMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519) agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey = diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index dbe47d908..8be63a060 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -686,6 +686,8 @@ data AgentErrorType CONN {connErr :: ConnectionErrorType} | -- | SMP protocol errors forwarded to agent clients SMP {smpErr :: ErrorType} + | -- | NTF protocol errors forwarded to agent clients + NTF {ntfErr :: ErrorType} | -- | SMP server errors BROKER {brokerErr :: BrokerErrorType} | -- | errors of other agents @@ -774,6 +776,7 @@ instance StrEncoding AgentErrorType where "CMD " *> (CMD <$> parseRead1) <|> "CONN " *> (CONN <$> parseRead1) <|> "SMP " *> (SMP <$> strP) + <|> "NTF " *> (NTF <$> strP) <|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP) <|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP) <|> "BROKER " *> (BROKER <$> parseRead1) @@ -783,6 +786,7 @@ instance StrEncoding AgentErrorType where CMD e -> "CMD " <> bshow e CONN e -> "CONN " <> bshow e SMP e -> "SMP " <> strEncode e + NTF e -> "NTF " <> strEncode e BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e BROKER e -> "BROKER " <> bshow e diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 7a62dc465..1f9d215cd 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -28,7 +28,7 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (tryE, whenM) +import Simplex.Messaging.Util (tryE, whenM, ($>>=)) import System.Timeout (timeout) import UnliftIO (async, forConcurrently_) import UnliftIO.Exception (Exception) @@ -295,7 +295,7 @@ removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMP removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateSignKey) -getSubKey subs srv s = fmap join . mapM (TM.lookup s) =<< TM.lookup srv subs +getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s hasSub :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM Bool hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index d5985d213..d3ad6d043 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -264,6 +264,7 @@ data NtfResponse | NRTkn NtfTknStatus | NRSub NtfSubStatus | NRPong + deriving (Show) instance ProtocolEncoding NtfResponse where type Tag NtfResponse = NtfResponseTag @@ -361,7 +362,7 @@ data NtfSubStatus NSEnd | -- | SMP AUTH error NSSMPAuth - deriving (Eq) + deriving (Eq, Show) instance Encoding NtfSubStatus where smpEncode = \case diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index babe98a70..f7100cf14 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -117,7 +117,7 @@ ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do (_, PNVerification _) -> do -- TODO check token status deliverNotification pp tkn ntf - atomically $ writeTVar tknStatus NTConfirmed + atomically $ modifyTVar tknStatus $ \status' -> if status' == NTActive then NTActive else NTConfirmed (NTActive, PNCheckMessages) -> do deliverNotification pp tkn ntf _ -> do @@ -166,26 +166,26 @@ verifyNtfTransmission :: verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do st <- asks store case cmd of - NtfCmd SToken c@(TNEW n@(NewNtfTkn _ k _)) -> do - r_ <- atomically $ getNtfToken st entId + NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _)) -> do + r_ <- atomically $ getNtfTokenRegistration st tkn pure $ if verifyCmdSignature sig_ signed k then case r_ of - Just r@(NtfTkn NtfTknData {tknVerifyKey}) - | k == tknVerifyKey -> tknCmd r c + Just t@NtfTknData {tknVerifyKey} + | k == tknVerifyKey -> verifiedTknCmd t c | otherwise -> VRFailed - _ -> VRVerified (NtfReqNew corrId (ANE SToken n)) + _ -> VRVerified (NtfReqNew corrId (ANE SToken tkn)) else VRFailed NtfCmd SToken c -> do - r_ <- atomically $ getNtfToken st entId - pure $ case r_ of - Just r@(NtfTkn NtfTknData {tknVerifyKey}) - | verifyCmdSignature sig_ signed tknVerifyKey -> tknCmd r c + t_ <- atomically $ getNtfToken st entId + pure $ case t_ of + Just t@NtfTknData {tknVerifyKey} + | verifyCmdSignature sig_ signed tknVerifyKey -> verifiedTknCmd t c | otherwise -> VRFailed _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed _ -> pure VRFailed where - tknCmd r c = VRVerified (NtfReqCmd SToken r (corrId, entId, c)) + verifiedTknCmd t c = VRVerified (NtfReqCmd SToken (NtfTkn t) (corrId, entId, c)) client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m () client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ, intervalNotifiers} = @@ -204,11 +204,11 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push tknId <- getId regCode <- getRegCode atomically $ do - tkn <- mkNtfTknData newTkn ks dhSecret regCode + tkn <- mkNtfTknData tknId newTkn ks dhSecret regCode addNtfToken st tknId tkn writeTBQueue pushQ (tkn, PNVerification regCode) pure (corrId, "", NRId tknId srvDhPubKey) - NtfReqCmd SToken (NtfTkn tkn@NtfTknData {tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do + NtfReqCmd SToken (NtfTkn tkn@NtfTknData {ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do status <- readTVarIO tknStatus (corrId,tknId,) <$> case cmd of TNEW (NewNtfTkn _ _ dhPubKey) -> do @@ -218,12 +218,15 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push if tknDhSecret == dhSecret then do atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode) - pure $ NRId tknId srvDhPubKey + pure $ NRId ntfTknId srvDhPubKey else pure $ NRErr AUTH TVFY code -- this allows repeated verification for cases when client connection dropped before server response | (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do logDebug "TVFY - token verified" + st <- asks store atomically $ writeTVar tknStatus NTActive + tIds <- atomically $ removeInactiveTokenRegistrations st tkn + forM_ tIds cancelInvervalNotifications pure NROk | otherwise -> do logDebug "TVFY - incorrect code or token status" @@ -233,12 +236,12 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push logDebug "TDEL" st <- asks store atomically $ deleteNtfToken st tknId + cancelInvervalNotifications tknId pure NROk - TCRN 0 -> + TCRN 0 -> do logDebug "TCRN 0" - >> atomically (TM.lookupDelete tknId intervalNotifiers) - >>= mapM_ (uninterruptibleCancel . action) - >> pure NROk + cancelInvervalNotifications tknId + pure NROk TCRN int | int < 20 -> pure $ NRErr QUOTA | otherwise -> do @@ -274,6 +277,10 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push getRandomBytes n = do gVar <- asks idsDrg atomically (C.pseudoRandomBytes n gVar) + cancelInvervalNotifications :: NtfTokenId -> m () + cancelInvervalNotifications tknId = + atomically (TM.lookupDelete tknId intervalNotifiers) + >>= mapM_ (uninterruptibleCancel . action) -- NReqCreate corrId tokenId smpQueue -> pure (corrId, "", NROk) -- do diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index bc7040b67..e39e7d3e1 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -2,30 +2,38 @@ {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} module Simplex.Messaging.Notifications.Server.Store where import Control.Concurrent.STM +import Control.Monad +import Data.ByteString.Char8 (ByteString) +import qualified Data.Map.Strict as M import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util ((<$$>)) +import Simplex.Messaging.Util (whenM, ($>>=)) data NtfStore = NtfStore { tokens :: TMap NtfTokenId NtfTknData, - tokenIds :: TMap DeviceToken NtfTokenId + tokenRegistrations :: TMap DeviceToken (TMap ByteString NtfTokenId) } newNtfStore :: STM NtfStore newNtfStore = do tokens <- TM.empty - tokenIds <- TM.empty - pure NtfStore {tokens, tokenIds} + tokenRegistrations <- TM.empty + pure NtfStore {tokens, tokenRegistrations} data NtfTknData = NtfTknData - { token :: DeviceToken, + { ntfTknId :: NtfTokenId, + token :: DeviceToken, tknStatus :: TVar NtfTknStatus, tknVerifyKey :: C.APublicVerifyKey, tknDhKeys :: C.KeyPair 'C.X25519, @@ -33,10 +41,10 @@ data NtfTknData = NtfTknData tknRegCode :: NtfRegCode } -mkNtfTknData :: NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData -mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do +mkNtfTknData :: NtfTokenId -> NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData +mkNtfTknData ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do tknStatus <- newTVar NTRegistered - pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode} + pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode} -- data NtfSubscriptionsStore = NtfSubscriptionsStore @@ -58,19 +66,57 @@ data NtfEntityRec (e :: NtfEntity) where NtfTkn :: NtfTknData -> NtfEntityRec 'Token NtfSub :: NtfSubData -> NtfEntityRec 'Subscription -data ANtfEntityRec = forall e. NtfEntityI e => NER (SNtfEntity e) (NtfEntityRec e) - -getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token)) -getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st) +getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) +getNtfToken st tknId = TM.lookup tknId (tokens st) addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () -addNtfToken st tknId tkn@NtfTknData {token} = do +addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do TM.insert tknId tkn $ tokens st - TM.insert token tknId $ tokenIds st + TM.lookup token regs >>= \case + Just tIds -> TM.insert regKey tknId tIds + _ -> do + tIds <- TM.singleton regKey tknId + TM.insert token tIds regs + where + regs = tokenRegistrations st + regKey = C.toPubKey C.pubKeyBytes tknVerifyKey + +getNtfTokenRegistration :: NtfStore -> NewNtfEntity 'Token -> STM (Maybe NtfTknData) +getNtfTokenRegistration st (NewNtfTkn token tknVerifyKey _) = + TM.lookup token (tokenRegistrations st) + $>>= TM.lookup regKey + $>>= (`TM.lookup` tokens st) + where + regKey = C.toPubKey C.pubKeyBytes tknVerifyKey + +removeInactiveTokenRegistrations :: NtfStore -> NtfTknData -> STM [NtfTokenId] +removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = + TM.lookup token (tokenRegistrations st) + >>= maybe (pure []) removeRegs + where + removeRegs :: TMap ByteString NtfTokenId -> STM [NtfTokenId] + removeRegs tknRegs = do + tIds <- filter ((/= tId) . snd) . M.assocs <$> readTVar tknRegs + forM_ tIds $ \(regKey, tId') -> do + TM.delete regKey tknRegs + TM.delete tId' $ tokens st + pure $ map snd tIds deleteNtfToken :: NtfStore -> NtfTokenId -> STM () deleteNtfToken st tknId = do - TM.lookupDelete tknId (tokens st) >>= mapM_ (\NtfTknData {token} -> TM.delete token $ tokenIds st) + TM.lookupDelete tknId (tokens st) + >>= mapM_ + ( \NtfTknData {token, tknVerifyKey} -> + TM.lookup token regs + >>= mapM_ + ( \tIds -> do + TM.delete (regKey tknVerifyKey) tIds + whenM (TM.null tIds) $ TM.delete token regs + ) + ) + where + regs = tokenRegistrations st + regKey = C.toPubKey C.pubKeyBytes -- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e)) -- getNtfRec st ent entId = case ent of diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 6809899f8..c2ce106ed 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -98,7 +98,7 @@ smpServer started = do m () serverThread s subQ subs clientSubs unsub = forever $ do atomically updateSubscribers - >>= fmap join . mapM endPreviousSubscriptions + $>>= endPreviousSubscriptions >>= mapM_ unsub where updateSubscribers :: STM (Maybe (QueueId, Client)) @@ -110,8 +110,7 @@ smpServer started = do else do yes <- readTVar $ connected c' pure $ if yes then Just (qId, c') else Nothing - TM.lookupInsert qId clnt (subs s) - >>= fmap join . mapM clientToBeNotified + TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s) endPreviousSubscriptions (qId, c) = do void . forkIO . atomically $ diff --git a/src/Simplex/Messaging/Server/QueueStore/STM.hs b/src/Simplex/Messaging/Server/QueueStore/STM.hs index 401e7ee30..d04c8f14b 100644 --- a/src/Simplex/Messaging/Server/QueueStore/STM.hs +++ b/src/Simplex/Messaging/Server/QueueStore/STM.hs @@ -18,7 +18,7 @@ import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (ifM) +import Simplex.Messaging.Util (ifM, ($>>=)) import UnliftIO.STM data QueueStore = QueueStore @@ -51,9 +51,8 @@ instance MonadQueueStore QueueStore STM where where getVar = case party of SRecipient -> TM.lookup qId queues - SSender -> TM.lookup qId senders >>= get - SNotifier -> TM.lookup qId notifiers >>= get - get = fmap join . mapM (`TM.lookup` queues) + SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues) + SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues) secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec) secureQueue QueueStore {queues} rId sKey = @@ -91,4 +90,4 @@ toResult :: Maybe a -> Either ErrorType a toResult = maybe (Left AUTH) Right withQueue :: RecipientId -> TMap RecipientId (TVar QueueRec) -> (TVar QueueRec -> STM (Maybe a)) -> STM (Either ErrorType a) -withQueue rId queues f = toResult <$> (TM.lookup rId queues >>= fmap join . mapM f) +withQueue rId queues f = toResult <$> TM.lookup rId queues $>>= f diff --git a/src/Simplex/Messaging/TMap.hs b/src/Simplex/Messaging/TMap.hs index b5bc01253..761a41c93 100644 --- a/src/Simplex/Messaging/TMap.hs +++ b/src/Simplex/Messaging/TMap.hs @@ -2,6 +2,7 @@ module Simplex.Messaging.TMap ( TMap, empty, singleton, + Simplex.Messaging.TMap.null, Simplex.Messaging.TMap.lookup, member, insert, @@ -30,6 +31,10 @@ singleton :: k -> a -> STM (TMap k a) singleton k v = newTVar $ M.singleton k v {-# INLINE singleton #-} +null :: TMap k a -> STM Bool +null m = M.null <$> readTVar m +{-# INLINE null #-} + lookup :: Ord k => k -> TMap k a -> STM (Maybe a) lookup k m = M.lookup k <$> readTVar m {-# INLINE lookup #-} diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index ea53bc60e..0d741bb43 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -65,3 +65,6 @@ whenM b a = ifM b a $ pure () unlessM :: Monad m => m Bool -> m () -> m () unlessM b = ifM b $ pure () {-# INLINE unlessM #-} + +($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b) +f $>>= g = f >>= fmap join . mapM g diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 9b6a9e0f6..752c97760 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -3,6 +3,9 @@ module AgentTests.NotificationTests where +-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging) + +import Control.Concurrent (threadDelay) import Control.Monad.Except import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT @@ -11,21 +14,36 @@ import qualified Data.ByteString.Base64.URL as U import Data.ByteString.Char8 (ByteString) import Data.Text.Encoding (encodeUtf8) import NtfClient -import SMPAgentClient (agentCfg, initAgentServers) +import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2) import Simplex.Messaging.Agent +import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS +import Simplex.Messaging.Protocol (ErrorType (AUTH)) import Simplex.Messaging.Transport (ATransport) +import Simplex.Messaging.Util (tryE) +import System.Directory (removeFile) import Test.Hspec import UnliftIO.STM notificationTests :: ATransport -> Spec -notificationTests t = do - describe "Managing notification tokens" $ - it "should register and verify notification token" $ - withAPNSMockServer $ \apns -> withNtfServer t $ testNotificationToken apns +notificationTests t = + after_ (removeFile testDB) $ + describe "Managing notification tokens" $ do + it "should register and verify notification token" $ + withAPNSMockServer $ \apns -> + withNtfServer t $ testNotificationToken apns + it "should allow repeated registration with the same credentials" $ \_ -> + withAPNSMockServer $ \apns -> + withNtfServer t $ testNtfTokenRepeatRegistration apns + it "should allow the second registration with different credentials and delete the first after verification" $ \_ -> + withAPNSMockServer $ \apns -> + withNtfServer t $ testNtfTokenSecondRegistration apns + it "should re-register token when notification server is restarted" $ \_ -> + withAPNSMockServer $ \apns -> + testNtfTokenServerRestart t apns testNotificationToken :: APNSMockServer -> IO () testNotificationToken APNSMockServer {apnsQ} = do @@ -40,10 +58,110 @@ testNotificationToken APNSMockServer {apnsQ} = do liftIO $ sendApnsResponse APNSRespOk verifyNtfToken a tkn verification nonce enableNtfCron a tkn 30 + NTActive <- checkNtfToken a tkn + deleteNtfToken a tkn + -- agent deleted this token + Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn pure () pure () - where - (.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString - v .-> key = do - J.Object o <- pure v - liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o + +(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString +v .-> key = do + J.Object o <- pure v + liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o + +-- logCfg :: LogConfig +-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} + +testNtfTokenRepeatRegistration :: APNSMockServer -> IO () +testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do + -- setLogLevel LogError -- LogDebug + -- withGlobalLogging logCfg $ do + a <- getSMPAgentClient agentCfg initAgentServers + Right () <- runExceptT $ do + let tkn = DeviceToken PPApns "abcd" + registerNtfToken a tkn + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + liftIO $ sendApnsResponse APNSRespOk + registerNtfToken a tkn + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- + atomically $ readTBQueue apnsQ + _ <- ntfData' .-> "verification" + _ <- C.cbNonce <$> ntfData' .-> "nonce" + liftIO $ sendApnsResponse' APNSRespOk + -- can still use the first verification code, it is the same after decryption + verifyNtfToken a tkn verification nonce + enableNtfCron a tkn 30 + NTActive <- checkNtfToken a tkn + pure () + pure () + +testNtfTokenSecondRegistration :: APNSMockServer -> IO () +testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do + -- setLogLevel LogError -- LogDebug + -- withGlobalLogging logCfg $ do + a <- getSMPAgentClient agentCfg initAgentServers + a' <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers + Right () <- runExceptT $ do + let tkn = DeviceToken PPApns "abcd" + registerNtfToken a tkn + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + liftIO $ sendApnsResponse APNSRespOk + verifyNtfToken a tkn verification nonce + + registerNtfToken a' tkn + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- + atomically $ readTBQueue apnsQ + verification' <- ntfData' .-> "verification" + nonce' <- C.cbNonce <$> ntfData' .-> "nonce" + liftIO $ sendApnsResponse' APNSRespOk + + -- at this point the first token is still active + NTActive <- checkNtfToken a tkn + -- and the second is not yet verified + NTConfirmed <- checkNtfToken a' tkn + -- now the second token registration is verified + verifyNtfToken a' tkn verification' nonce' + -- the first registration is removed + Left (NTF AUTH) <- tryE $ checkNtfToken a tkn + -- and the second is active + NTActive <- checkNtfToken a' tkn + enableNtfCron a' tkn 30 + pure () + pure () + +testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO () +testNtfTokenServerRestart t APNSMockServer {apnsQ} = do + a <- getSMPAgentClient agentCfg initAgentServers + let tkn = DeviceToken PPApns "abcd" + Right ntfData <- withNtfServer t . runExceptT $ do + registerNtfToken a tkn + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <- + atomically $ readTBQueue apnsQ + liftIO $ sendApnsResponse APNSRespOk + pure ntfData + -- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server + threadDelay 1000000 + disconnectAgentClient a + a' <- getSMPAgentClient agentCfg initAgentServers + -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, + -- so that repeat verification happens without restarting the clients, when notification arrives + Right () <- withNtfServer t . runExceptT $ do + verification <- ntfData .-> "verification" + nonce <- C.cbNonce <$> ntfData .-> "nonce" + Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn verification nonce + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <- + atomically $ readTBQueue apnsQ + verification' <- ntfData' .-> "verification" + nonce' <- C.cbNonce <$> ntfData' .-> "nonce" + liftIO $ sendApnsResponse' APNSRespOk + verifyNtfToken a' tkn verification' nonce' + NTActive <- checkNtfToken a' tkn + enableNtfCron a' tkn 30 + pure ()