diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 715ea9a99..125a8e5d2 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -105,7 +105,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge receiveSMP :: m () receiveSMP = forever $ do - (_srv, _sessId, _ntfId, msg) <- atomically $ readTBQueue msgQ + (srv, _sessId, ntfId, msg) <- atomically $ readTBQueue msgQ case msg of SMP.NMSG -> do -- check when the last NMSG was received from this queue @@ -114,7 +114,11 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge -- decide whether it should be sent as hidden or visible -- construct and possibly encrypt notification -- send it - pure () + NtfPushServer {pushQ} <- asks pushServer + st <- asks store + atomically $ + findNtfSubscriptionToken st (SMPQueueNtf srv ntfId) + >>= mapM_ (\tkn -> writeTBQueue pushQ (tkn, PNMessage srv ntfId)) _ -> pure () pure () @@ -209,35 +213,42 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do else VRFailed NtfCmd SToken c -> do t_ <- atomically $ getNtfToken st entId - pure $ case t_ of - Just t@NtfTknData {tknVerifyKey} - | verifyCmdSignature sig_ signed tknVerifyKey -> verifiedTknCmd t c - | otherwise -> VRFailed - _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed - NtfCmd SSubscription c@(SNEW sub@(NewNtfSub tknId _ _)) -> do - -- TODO move active token check here to differentiate error - r_ <- atomically $ findNtfSubscription st sub - pure $ case r_ of - Just (NtfTknData {tknVerifyKey}, sub_) -> - if verifyCmdSignature sig_ signed tknVerifyKey - then case sub_ of - Just s@NtfSubData {tokenId} - | tknId == tokenId -> verifiedSubCmd s c - | otherwise -> VRFailed - _ -> VRVerified (NtfReqNew corrId (ANE SSubscription sub)) - else VRFailed - _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed + verifyToken t_ (`verifiedTknCmd` c) + NtfCmd SSubscription c@(SNEW sub@(NewNtfSub tknId smpQueue _)) -> do + s_ <- atomically $ findNtfSubscription st smpQueue + case s_ of + Nothing -> do + -- TODO move active token check here to differentiate error + t_ <- atomically $ getActiveNtfToken st tknId + verifyToken' t_ $ VRVerified (NtfReqNew corrId (ANE SSubscription sub)) + Just s@NtfSubData {tokenId = subTknId} -> + if subTknId == tknId + then do + -- TODO move active token check here to differentiate error + t_ <- atomically $ getActiveNtfToken st subTknId + verifyToken' t_ $ verifiedSubCmd s c + else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed NtfCmd SSubscription c -> do - -- TODO move active token check here to differentiate error - r_ <- atomically $ getNtfSubscription st entId - pure $ case r_ of - Just (s, NtfTknData {tknVerifyKey}) - | verifyCmdSignature sig_ signed tknVerifyKey -> verifiedSubCmd s c - | otherwise -> VRFailed - _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed + s_ <- atomically $ getNtfSubscription st entId + case s_ of + Just s@NtfSubData {tokenId = subTknId} -> do + -- TODO move active token check here to differentiate error + t_ <- atomically $ getActiveNtfToken st subTknId + verifyToken' t_ $ verifiedSubCmd s c + _ -> pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed where verifiedTknCmd t c = VRVerified (NtfReqCmd SToken (NtfTkn t) (corrId, entId, c)) verifiedSubCmd s c = VRVerified (NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c)) + verifyToken :: Maybe NtfTknData -> (NtfTknData -> VerificationResult) -> m VerificationResult + verifyToken t_ positiveVerificationResult = + pure $ case t_ of + Just t@NtfTknData {tknVerifyKey} -> + if verifyCmdSignature sig_ signed tknVerifyKey + then positiveVerificationResult t + else VRFailed + _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed + verifyToken' :: Maybe NtfTknData -> VerificationResult -> m VerificationResult + verifyToken' t_ = verifyToken t_ . const client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m () client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ} NtfPushServer {pushQ, intervalNotifiers} = diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 37a0b4f37..3497cad93 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -20,14 +20,14 @@ import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Protocol (NtfPrivateSignKey) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (whenM, ($>>=), (<$$>)) +import Simplex.Messaging.Util (whenM, ($>>=)) data NtfStore = NtfStore { tokens :: TMap NtfTokenId NtfTknData, tokenRegistrations :: TMap DeviceToken (TMap ByteString NtfTokenId), subscriptions :: TMap NtfSubscriptionId NtfSubData, tokenSubscriptions :: TMap NtfTokenId (TVar (Set NtfSubscriptionId)), - subscriptionLookup :: TMap (NtfTokenId, SMPQueueNtf) NtfSubscriptionId + subscriptionLookup :: TMap SMPQueueNtf NtfSubscriptionId } newNtfStore :: STM NtfStore @@ -127,20 +127,19 @@ deleteNtfToken st tknId = do regs = tokenRegistrations st regKey = C.toPubKey C.pubKeyBytes -getNtfSubscription :: NtfStore -> NtfSubscriptionId -> STM (Maybe (NtfSubData, NtfTknData)) +getNtfSubscription :: NtfStore -> NtfSubscriptionId -> STM (Maybe NtfSubData) getNtfSubscription st subId = TM.lookup subId (subscriptions st) - $>>= \sub@NtfSubData {tokenId} -> - (sub,) <$$> getActiveNtfToken st tokenId -findNtfSubscription :: NtfStore -> NewNtfEntity 'Subscription -> STM (Maybe (NtfTknData, Maybe NtfSubData)) -findNtfSubscription st (NewNtfSub tknId smpQueue _) = do - getActiveNtfToken st tknId >>= mapM (\tkn -> (tkn,) <$> getSub) - where - getSub :: STM (Maybe NtfSubData) - getSub = - TM.lookup (tknId, smpQueue) (subscriptionLookup st) - $>>= (`TM.lookup` subscriptions st) +findNtfSubscription :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfSubData) +findNtfSubscription st smpQueue = do + TM.lookup smpQueue (subscriptionLookup st) + $>>= \subId -> TM.lookup subId (subscriptions st) + +findNtfSubscriptionToken :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfTknData) +findNtfSubscriptionToken st smpQueue = do + findNtfSubscription st smpQueue + $>>= \NtfSubData {tokenId} -> getActiveNtfToken st tokenId getActiveNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) getActiveNtfToken st tknId = @@ -160,7 +159,7 @@ addNtfSubscription st subId sub@NtfSubData {smpQueue, tokenId} = insertSub ts = do modifyTVar' ts $ insert subId TM.insert subId sub $ subscriptions st - TM.insert (tokenId, smpQueue) subId (subscriptionLookup st) + TM.insert smpQueue subId (subscriptionLookup st) -- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e)) -- getNtfRec st ent entId = case ent of