diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index dae98d301..a67d03b9d 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -79,8 +79,6 @@ import UnliftIO.STM import GHC.Conc (listThreads) #endif -import qualified Data.ByteString.Base64 as B64 - runNtfServer :: NtfServerConfig -> IO () runNtfServer cfg = do started <- newEmptyTMVarIO diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs index 224aeaec7..8a8c475ac 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -33,11 +33,11 @@ import Data.Containers.ListUtils (nubOrd) import Data.Either (fromRight) import Data.Functor (($>)) import Data.Int (Int64) -import Data.List (foldl', intercalate) +import Data.List (findIndex, foldl') import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, mapMaybe) +import Data.Maybe (fromMaybe, mapMaybe) import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T @@ -587,9 +587,10 @@ toLastNtf (qRow :. (ts, nonce, Binary encMeta)) = importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> IO (Int64, Int64, Int64) importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do - (tCnt, tIds) <- importTokens - sCnt <- importSubscriptions tIds - nCnt <- importLastNtfs + (tIds, tCnt) <- importTokens + subLookup <- readTVarIO $ subscriptionLookup stmStore + sCnt <- importSubscriptions tIds subLookup + nCnt <- importLastNtfs tIds subLookup pure (tCnt, sCnt, nCnt) where importTokens = do @@ -597,59 +598,65 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do tokens <- filterTokens allTokens let skipped = length allTokens - length tokens when (skipped /= 0) $ putStrLn $ "Total skipped tokens " <> show skipped - tCnt <- withConnection s $ \db -> foldM (insertToken db) 0 tokens - void $ checkCount "token" (length tokens) tCnt + -- uncomment this line instead of the next to import tokens one by one. + -- tCnt <- withConnection s $ \db -> foldM (importTkn db) 0 tokens + tRows <- mapM (fmap ntfTknToRow . mkTknRec) tokens + tCnt <- withConnection s $ \db -> DB.executeMany db insertNtfTknQuery tRows let tokenIds = S.fromList $ map (\NtfTknData {ntfTknId} -> ntfTknId) tokens - pure (tCnt, tokenIds) + (tokenIds,) <$> checkCount "token" (length tokens) tCnt where filterTokens tokens = do let deviceTokens = foldl' (\m t -> M.alter (Just . (t :) . fromMaybe []) (tokenKey t) m) M.empty tokens tokenSubs <- readTVarIO (tokenSubscriptions stmStore) filterM (keepTokenRegistration deviceTokens tokenSubs) tokens tokenKey NtfTknData {token, tknVerifyKey} = strEncode token <> ":" <> C.toPubKey C.pubKeyBytes tknVerifyKey - keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, token, tknStatus} = + keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, tknStatus} = case M.lookup (tokenKey tkn) deviceTokens of Just ts - | length ts >= 2 -> + | length ts < 2 -> pure True + | otherwise -> readTVarIO tknStatus >>= \case NTConfirmed -> do - anyActive <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> (NTActive ==) <$> readTVarIO tknStatus') ts - noSubs <- S.null <$> maybe (pure S.empty) readTVarIO (M.lookup ntfTknId tokenSubs) - if anyActive - then ( - if noSubs - then False <$ putStrLn ("Skipped inactive token " <> enc ntfTknId <> " (no subscriptions)") - else pure True - ) + hasSubs <- maybe (pure False) (\v -> not . S.null <$> readTVarIO v) $ M.lookup ntfTknId tokenSubs + if hasSubs + then pure True else do - let noSubsStr = if noSubs then " no subscriptions" else " has subscriptions" - putStrLn $ "Error: more than one registration for token " <> enc ntfTknId <> " " <> show token <> noSubsStr - pure True + anyActive <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> (NTActive ==) <$> readTVarIO tknStatus') ts + if anyActive + then False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId) + else case findIndex (\NtfTknData {ntfTknId = tId} -> tId == ntfTknId) ts of + Just 0 -> pure True -- keeping the first token + Just _ -> False <$ putStrLn ("Skipped duplicate inactive token " <> enc ntfTknId <> " (no active token)") + Nothing -> True <$ putStrLn "Error: no device token in the list" _ -> pure True - | otherwise -> pure True Nothing -> True <$ putStrLn "Error: no device token in lookup map" - insertToken db !n tkn@NtfTknData {ntfTknId} = do - tknRow <- ntfTknToRow <$> mkTknRec tkn - (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) -> - putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n - importSubscriptions tIds = do - allSubs <- M.elems <$> readTVarIO (subscriptions stmStore) - let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs - skipped = length allSubs - length subs - when (skipped /= 0) $ putStrLn $ "Skipped subscriptions (no tokens) " <> show skipped + -- importTkn db !n tkn@NtfTknData {ntfTknId} = do + -- tknRow <- ntfTknToRow <$> mkTknRec tkn + -- (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) -> + -- putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n + importSubscriptions :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 + importSubscriptions tIds subLookup = do + subs <- filterSubs . M.elems =<< readTVarIO (subscriptions stmStore) srvIds <- importServers subs putStrLn $ "Importing " <> show (length subs) <> " subscriptions..." - -- uncomment this line instead of the next 2 lines to import subs one by one. - (sCnt, missingTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs - -- sCnt <- foldM (importSubs srvIds) 0 $ toChunks 100000 subs - -- let missingTkns = M.empty - putStrLn $ "Imported " <> show sCnt <> " subscriptions" - unless (M.null missingTkns) $ do - putStrLn $ show (M.size missingTkns) <> " missing tokens:" - forM_ (M.assocs missingTkns) $ \(tId, sIds) -> - putStrLn $ "Token " <> enc tId <> " " <> show (length sIds) <> " subscriptions: " <> intercalate ", " (map enc sIds) + -- uncomment this line instead of the next to import subs one by one. + -- (sCnt, errTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs + sCnt <- foldM (importSubs srvIds) 0 $ toChunks 500000 subs checkCount "subscription" (length subs) sCnt where + filterSubs allSubs = do + let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs + skipped = length allSubs - length subs + when (skipped /= 0) $ putStrLn $ "Skipped " <> show skipped <> " subscriptions of missing tokens" + let (removedSubTokens, removeSubs, dupQueues) = foldl' addSubToken (S.empty, S.empty, S.empty) subs + unless (null removeSubs) $ putStrLn $ "Skipped " <> show (S.size removeSubs) <> " duplicate subscriptions of " <> show (S.size removedSubTokens) <> " tokens for " <> show (S.size dupQueues) <> " queues" + pure $ filter (\NtfSubData {ntfSubId} -> S.notMember ntfSubId removeSubs) subs + where + addSubToken acc@(!stIds, !sIds, !qs) NtfSubData {ntfSubId, smpQueue, tokenId} = + case M.lookup smpQueue subLookup of + Just sId | sId /= ntfSubId -> + (S.insert tokenId stIds, S.insert ntfSubId sIds, S.insert smpQueue qs) + _ -> acc importSubs srvIds !n subs = do rows <- mapM (ntfSubRow srvIds) subs cnt <- withConnection s $ \db -> DB.executeMany db insertNtfSubQuery $ L.toList rows @@ -657,19 +664,19 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" hFlush stdout pure n' - importSub db srvIds (!n, !missingTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do - subRow <- ntfSubRow srvIds sub - E.try (DB.execute db insertNtfSubQuery subRow) >>= \case - Right i -> do - let n' = n + i - when (n' `mod` 100000 == 0) $ do - putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" - hFlush stdout - pure (n', missingTkns) - Left (e :: E.SomeException) -> do - when (n `mod` 100000 == 0) $ putStrLn "" - putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e - pure (n, M.alter (Just . (sId :) . fromMaybe []) tokenId missingTkns) + -- importSub db srvIds (!n, !errTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do + -- subRow <- ntfSubRow srvIds sub + -- E.try (DB.execute db insertNtfSubQuery subRow) >>= \case + -- Right i -> do + -- let n' = n + i + -- when (n' `mod` 100000 == 0) $ do + -- putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" + -- hFlush stdout + -- pure (n', errTkns) + -- Left (e :: E.SomeException) -> do + -- when (n `mod` 100000 == 0) $ putStrLn "" + -- putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e + -- pure (n, M.alter (Just . maybe [sId] (sId :)) tokenId errTkns) ntfSubRow srvIds sub = case M.lookup srv srvIds of Just sId -> ntfSubToRow sId <$> mkSubRec sub Nothing -> E.throwIO $ userError $ "no matching server ID for server " <> show srv @@ -682,19 +689,32 @@ importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do where srvQuery = "INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) RETURNING smp_server_id" srvs = nubOrd $ map ntfSubServer subs - importLastNtfs = do - subLookup <- readTVarIO $ subscriptionLookup stmStore - ntfRows <- fmap concat . mapM (lastNtfRows subLookup) . M.assocs =<< readTVarIO (tokenLastNtfs stmStore) + importLastNtfs :: S.Set NtfTokenId -> M.Map SMPQueueNtf NtfSubscriptionId -> IO Int64 + importLastNtfs tIds subLookup = do + ntfs <- readTVarIO (tokenLastNtfs stmStore) + ntfRows <- filterLastNtfRows ntfs nCnt <- withConnection s $ \db -> DB.executeMany db lastNtfQuery ntfRows checkCount "last notification" (length ntfRows) nCnt where lastNtfQuery = "INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) VALUES (?,?,?,?,?)" - lastNtfRows :: M.Map SMPQueueNtf NtfSubscriptionId -> (NtfTokenId, TVar (NonEmpty PNMessageData)) -> IO [(NtfTokenId, NtfSubscriptionId, SystemTime, C.CbNonce, Binary ByteString)] - lastNtfRows subLookup (tId, ntfs) = fmap catMaybes . mapM ntfRow . L.toList =<< readTVarIO ntfs + filterLastNtfRows ntfs = do + (skippedTkns, ntfCnt, (skippedQueues, ntfRows)) <- foldM lastNtfRows (S.empty, 0, (S.empty, [])) $ M.assocs ntfs + let skipped = ntfCnt - length ntfRows + when (skipped /= 0) $ putStrLn $ "Skipped last notifications " <> show skipped <> " for " <> show (S.size skippedTkns) <> " missing tokens and " <> show (S.size skippedQueues) <> " missing subscriptions with token present" + pure ntfRows + lastNtfRows (!stIds, !cnt, !acc) (tId, ntfVar) = do + ntfs <- L.toList <$> readTVarIO ntfVar + let cnt' = cnt + length ntfs + pure $ + if S.member tId tIds + then (stIds, cnt', foldl' ntfRow acc ntfs) + else (S.insert tId stIds, cnt', acc) where - ntfRow PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of - Just ntfSubId -> pure $ Just (tId, ntfSubId, ntfTs, nmsgNonce, Binary encNMsgMeta) - Nothing -> Nothing <$ putStrLn ("Error: no subscription " <> show smpQueue <> " for notification of token " <> enc tId) + ntfRow (!qs, !rows) PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of + Just ntfSubId -> + let row = (tId, ntfSubId, ntfTs, nmsgNonce, Binary encNMsgMeta) + in (qs, row : rows) + Nothing -> (S.insert smpQueue qs, rows) checkCount name expected inserted | fromIntegral expected == inserted = do putStrLn $ "Imported " <> show inserted <> " " <> name <> "s." @@ -711,12 +731,21 @@ exportNtfDbStore NtfPostgresStore {dbStoreLog = Nothing} _ = exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFile = (,,) <$> exportTokens <*> exportSubscriptions <*> exportLastNtfs where - exportTokens = - withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn -> + exportTokens = do + tCnt <- withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn -> logCreateToken sl (rowToNtfTkn tkn) $> (i + 1) - exportSubscriptions = - withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> - logCreateSubscription sl (toNtfSub sub) $> (i + 1) + putStrLn $ "Exported " <> show tCnt <> " tokens" + pure tCnt + exportSubscriptions = do + sCnt <- withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> do + let i' = i + 1 + logCreateSubscription sl (toNtfSub sub) + when (i' `mod` 500000 == 0) $ do + putStr $ "Exported " <> show i' <> " subscriptions" <> "\r" + hFlush stdout + pure i' + putStrLn $ "Exported " <> show sCnt <> " subscriptions" + pure sCnt where ntfSubQuery = [sql| diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index cefe1720b..ce69e5c11 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -570,7 +570,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag threadDelay 500000 suspendAgent alice 0 closeDBStore store - callCommand "sync" + threadDelay 500000 >> callCommand "sync" >> threadDelay 500000 putStrLn "before opening the database from another agent" -- aliceNtf client doesn't have subscription and is allowed to get notification message @@ -578,7 +578,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag (Just SMPMsgMeta {msgFlags = MsgFlags True}) :| _ <- getConnectionMessages aliceNtf [cId] pure () - callCommand "sync" + threadDelay 500000 >> callCommand "sync" >> threadDelay 500000 putStrLn "after closing the database in another agent" reopenDBStore store foregroundAgent alice diff --git a/tests/CLITests.hs b/tests/CLITests.hs index e5fb784ee..51d5d6c68 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -1,6 +1,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NamedFieldPuns #-} module CLITests where @@ -8,6 +9,7 @@ import AgentTests.FunctionalAPITests (runRight_) import Control.Logger.Simple import Control.Monad import qualified Crypto.PubKey.RSA as RSA +import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy as BL import qualified Data.HashMap.Strict as HM import Data.Ini (Ini (..), lookupValue, readIniFile, writeIniFile) @@ -41,8 +43,11 @@ import UnliftIO.Concurrent (threadDelay) import UnliftIO.Exception (bracket) #if defined(dbServerPostgres) -import NtfClient (ntfTestServerDBConnectInfo) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.Types (Query (..)) +import NtfClient (ntfTestServerDBConnectInfo, ntfTestServerDBConnstr, ntfTestStoreDBOpts) import SMPClient (postgressBracket) +import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Notifications.Server.Main #endif @@ -77,7 +82,7 @@ cliTests = do it "with store log, no password" $ smpServerTest True False it "static files" smpServerTestStatic #if defined(dbServerPostgres) - aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ + around_ (postgressBracket ntfTestServerDBConnectInfo) $ before_ (createNtfSchema ntfTestServerDBConnectInfo ntfTestStoreDBOpts) $ describe "Ntf server CLI" $ do it "should initialize, start and delete the server (no store log)" $ ntfServerTest False it "should initialize, start and delete the server (with store log)" $ ntfServerTest True @@ -192,9 +197,15 @@ smpServerTestStatic = do in map (X.signedObject . X.getSigned) cc #if defined(dbServerPostgres) +createNtfSchema :: PSQL.ConnectInfo -> DBOpts -> IO () +createNtfSchema connInfo DBOpts {schema} = do + db <- PSQL.connect connInfo + void $ PSQL.execute_ db $ Query $ "CREATE SCHEMA " <> schema + PSQL.close db + ntfServerTest :: Bool -> IO () ntfServerTest storeLog = do - capture_ (withArgs (["init"] <> ["--disable-store-log" | not storeLog]) $ ntfServerCLI ntfCfgPath ntfLogPath) + capture_ (withArgs (["init", "--database=" <> B.unpack ntfTestServerDBConnstr] <> ["--disable-store-log" | not storeLog]) $ ntfServerCLI ntfCfgPath ntfLogPath) >>= (`shouldSatisfy` (("Server initialized, you can modify configuration in " <> ntfCfgPath <> "/ntf-server.ini") `isPrefixOf`)) Right ini <- readIniFile $ ntfCfgPath <> "/ntf-server.ini" lookupValue "STORE_LOG" "enable" ini `shouldBe` Right (if storeLog then "on" else "off")