diff --git a/simplexmq.cabal b/simplexmq.cabal index 2590afc96..43eb6b31c 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -216,15 +216,6 @@ library Simplex.FileTransfer.Server.Stats Simplex.FileTransfer.Server.Store Simplex.FileTransfer.Server.StoreLog - Simplex.Messaging.Notifications.Server - Simplex.Messaging.Notifications.Server.Control - Simplex.Messaging.Notifications.Server.Env - Simplex.Messaging.Notifications.Server.Main - Simplex.Messaging.Notifications.Server.Push.APNS - Simplex.Messaging.Notifications.Server.Push.APNS.Internal - Simplex.Messaging.Notifications.Server.Stats - Simplex.Messaging.Notifications.Server.Store - Simplex.Messaging.Notifications.Server.StoreLog Simplex.Messaging.Server Simplex.Messaging.Server.CLI Simplex.Messaging.Server.Control @@ -257,6 +248,18 @@ library if flag(server_postgres) exposed-modules: + Simplex.Messaging.Notifications.Server + Simplex.Messaging.Notifications.Server.Control + Simplex.Messaging.Notifications.Server.Env + Simplex.Messaging.Notifications.Server.Main + Simplex.Messaging.Notifications.Server.Push.APNS + Simplex.Messaging.Notifications.Server.Push.APNS.Internal + Simplex.Messaging.Notifications.Server.Stats + Simplex.Messaging.Notifications.Server.Store + Simplex.Messaging.Notifications.Server.Store.Migrations + Simplex.Messaging.Notifications.Server.Store.Postgres + Simplex.Messaging.Notifications.Server.Store.Types + Simplex.Messaging.Notifications.Server.StoreLog Simplex.Messaging.Server.QueueStore.Postgres Simplex.Messaging.Server.QueueStore.Postgres.Migrations other-modules: @@ -340,6 +343,8 @@ library , sqlcipher-simple ==0.4.* if flag(server_postgres) cpp-options: -DdbServerPostgres + build-depends: + hex-text ==0.1.* if impl(ghc >= 9.6.2) build-depends: bytestring ==0.11.* @@ -352,6 +357,10 @@ library executable ntf-server if flag(client_library) buildable: False + if flag(server_postgres) + cpp-options: -DdbServerPostgres + else + buildable: False main-is: Main.hs other-modules: Paths_simplexmq @@ -444,7 +453,6 @@ test-suite simplexmq-test AgentTests.EqInstances AgentTests.FunctionalAPITests AgentTests.MigrationTests - AgentTests.NotificationTests AgentTests.ServerChoice AgentTests.ShortLinkTests CLITests @@ -460,8 +468,6 @@ test-suite simplexmq-test CoreTests.UtilTests CoreTests.VersionRangeTests FileDescriptionTests - NtfClient - NtfServerTests RemoteControl ServerTests SMPAgentClient @@ -484,6 +490,9 @@ test-suite simplexmq-test AgentTests.SQLiteTests if flag(server_postgres) other-modules: + AgentTests.NotificationTests + NtfClient + NtfServerTests ServerTests.SchemaDump hs-source-dirs: tests @@ -537,6 +546,8 @@ test-suite simplexmq-test , warp-tls , yaml default-language: Haskell2010 + if flag(server_postgres) + cpp-options: -DdbServerPostgres if flag(client_postgres) cpp-options: -DdbPostgres else @@ -550,5 +561,3 @@ test-suite simplexmq-test if flag(client_postgres) || flag(server_postgres) build-depends: postgresql-simple ==0.7.* - if flag(server_postgres) - cpp-options: -DdbServerPostgres diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 1a7a67806..dab0a4040 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -412,14 +412,22 @@ removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM () removeSubscription = removeSub_ . srvSubs {-# INLINE removeSubscription #-} +removePendingSub :: SMPClientAgent -> SMPServer -> SMPSub -> STM () +removePendingSub = removeSub_ . pendingSrvSubs +{-# INLINE removePendingSub #-} + removeSub_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSub -> STM () removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s) +removeSubscriptions :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM () +removeSubscriptions = removeSubs_ . srvSubs +{-# INLINE removeSubscriptions #-} + removePendingSubs :: SMPClientAgent -> SMPServer -> SMPSubParty -> [QueueId] -> STM () removePendingSubs = removeSubs_ . pendingSrvSubs {-# INLINE removePendingSubs #-} -removeSubs_ :: TMap SMPServer (TMap SMPSub C.APrivateAuthKey) -> SMPServer -> SMPSubParty -> [QueueId] -> STM () +removeSubs_ :: TMap SMPServer (TMap SMPSub s) -> SMPServer -> SMPSubParty -> [QueueId] -> STM () removeSubs_ subs srv party qs = TM.lookup srv subs >>= mapM_ (`modifyTVar'` (`M.withoutKeys` ss)) where ss = S.fromList $ map (party,) qs diff --git a/src/Simplex/Messaging/Encoding.hs b/src/Simplex/Messaging/Encoding.hs index 15718e297..ef0033dfb 100644 --- a/src/Simplex/Messaging/Encoding.hs +++ b/src/Simplex/Messaging/Encoding.hs @@ -143,7 +143,7 @@ instance Encoding Large where instance Encoding SystemTime where smpEncode = smpEncode . systemSeconds {-# INLINE smpEncode #-} - smpP = MkSystemTime <$> smpP <*> pure 0 + smpP = (`MkSystemTime` 0) <$> smpP {-# INLINE smpP #-} _smpP :: Encoding a => Parser a diff --git a/src/Simplex/Messaging/Encoding/String.hs b/src/Simplex/Messaging/Encoding/String.hs index 97e7d087b..c963ec99a 100644 --- a/src/Simplex/Messaging/Encoding/String.hs +++ b/src/Simplex/Messaging/Encoding/String.hs @@ -140,7 +140,7 @@ instance StrEncoding Int64 where instance StrEncoding SystemTime where strEncode = strEncode . systemSeconds - strP = MkSystemTime <$> strP <*> pure 0 + strP = (`MkSystemTime` 0) <$> strP instance StrEncoding UTCTime where strEncode = B.pack . iso8601Show diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index 2167814c1..b23bd4e91 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -517,8 +517,11 @@ instance Encoding NtfSubStatus where instance StrEncoding NtfSubStatus where strEncode = smpEncode + {-# INLINE strEncode #-} strP = smpP + {-# INLINE strP #-} +-- TODO [ntfdb] check what happens in agent when token in not yet registered data NtfTknStatus = -- | Token created in DB NTNew @@ -534,6 +537,17 @@ data NtfTknStatus NTExpired deriving (Eq, Show) +allowNtfSubCommands :: NtfTknStatus -> Bool +allowNtfSubCommands = \case + NTNew -> False + NTRegistered -> False + -- TODO [ntfdb] we could have separate statuses to show whether it became invalid + -- after verification (allow commands) or before (do not allow) + NTInvalid _ -> True + NTConfirmed -> False + NTActive -> True + NTExpired -> True + instance Encoding NtfTknStatus where smpEncode = \case NTNew -> "NEW" diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 84aebf9db..c1f1aa9ab 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -20,10 +20,8 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Data.Bifunctor (first) -import qualified Data.ByteString.Builder as BLD import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy.Char8 as LB import Data.Either (partitionEithers) import Data.Functor (($>)) import Data.IORef @@ -33,7 +31,8 @@ import Data.List (intercalate, partition, sort) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes) +import Data.Maybe (mapMaybe) +import qualified Data.Set as S import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime) @@ -51,14 +50,16 @@ import Simplex.Messaging.Notifications.Server.Control import Simplex.Messaging.Notifications.Server.Env import Simplex.Messaging.Notifications.Server.Push.APNS (PushNotification (..), PushProviderError (..)) import Simplex.Messaging.Notifications.Server.Stats -import Simplex.Messaging.Notifications.Server.Store -import Simplex.Messaging.Notifications.Server.StoreLog +import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore, TokenNtfMessageRecord (..), stmStoreTokenLastNtf) +import Simplex.Messaging.Notifications.Server.Store.Postgres +import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Transport import Simplex.Messaging.Protocol (EntityId (..), ErrorType (..), ProtocolServer (host), SMPServer, SignedTransmission, Transmission, pattern NoEntity, pattern SMPServer, encodeTransmission, tGet, tPut) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server import Simplex.Messaging.Server.Control (CPClientRole (..)) -import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) +import Simplex.Messaging.Server.Env.STM (StartOptions (..)) +import Simplex.Messaging.Server.QueueStore (getSystemDate) import Simplex.Messaging.Server.Stats (PeriodStats (..), PeriodStatCounts (..), periodStatCounts, updatePeriodStats) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM @@ -66,7 +67,7 @@ import Simplex.Messaging.Transport (ATransport (..), THandle (..), THandleAuth ( import Simplex.Messaging.Transport.Buffer (trimCR) import Simplex.Messaging.Transport.Server (AddHTTP, runTransportServer, runLocalTCPServer) import Simplex.Messaging.Util -import System.Exit (exitFailure) +import System.Exit (exitFailure, exitSuccess) import System.IO (BufferMode (..), hClose, hPrint, hPutStrLn, hSetBuffering, hSetNewlineMode, universalNewlineMode) import System.Mem.Weak (deRefWeak) import UnliftIO (IOMode (..), UnliftIO, askUnliftIO, async, uninterruptibleCancel, unliftIO, withFile) @@ -78,6 +79,8 @@ import UnliftIO.STM import GHC.Conc (listThreads) #endif +import qualified Data.ByteString.Base64 as B64 + runNtfServer :: NtfServerConfig -> IO () runNtfServer cfg = do started <- newEmptyTMVarIO @@ -89,11 +92,14 @@ runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtf type M a = ReaderT NtfEnv IO a ntfServer :: NtfServerConfig -> TMVar Bool -> M () -ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do - restoreServerLastNtfs +ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg, startOptions} started = do restoreServerStats s <- asks subscriber ps <- asks pushServer + when (maintenance startOptions) $ do + liftIO $ putStrLn "Server started in 'maintenance' mode, exiting" + stopServer + liftIO $ exitSuccess resubscribe s raceAny_ (ntfSubscriber s : ntfPush ps : map runServer transports <> serverStatsThread_ cfg <> controlPortThread_ cfg) `finally` stopServer where @@ -124,7 +130,7 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do logInfo "Server stopped" saveServer :: M () - saveServer = withNtfLog closeStoreLog >> saveServerLastNtfs >> saveServerStats + saveServer = asks store >>= liftIO . closeNtfDbStore >> saveServerStats serverStatsThread_ :: NtfServerConfig -> [M ()] serverStatsThread_ NtfServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} = @@ -330,10 +336,23 @@ ntfServer cfg@NtfServerConfig {transports, transportConfig = tCfg} started = do resubscribe :: NtfSubscriber -> M () resubscribe NtfSubscriber {newSubQ} = do logInfo "Preparing SMP resubscriptions..." - subs <- readTVarIO =<< asks (subscriptions . store) - subs' <- filterM (fmap ntfShouldSubscribe . readTVarIO . subStatus) $ M.elems subs - atomically . writeTBQueue newSubQ $ map NtfSub subs' - logInfo $ "SMP resubscriptions queued (" <> tshow (length subs') <> " subscriptions)" + st <- asks store + batchSize <- asks $ subsBatchSize . config + liftIO $ do + srvs <- getUsedSMPServers st + count <- foldM (subscribeSrvSubs st batchSize) (0 :: Int) srvs + logInfo $ "SMP resubscriptions queued (" <> tshow count <> " subscriptions)" + where + subscribeSrvSubs st batchSize !count srv = do + (n, subs_) <- + foldNtfSubscriptions st srv batchSize (0, []) $ \(!i, subs) sub -> + if length subs == batchSize + then write (L.fromList subs) $> (i + 1, []) + else pure (i + 1, sub : subs) + mapM_ write $ L.nonEmpty subs_ + pure $ count + n + where + write subs = atomically $ writeTBQueue newSubQ (srv, subs) ntfSubscriber :: NtfSubscriber -> M () ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = do @@ -341,44 +360,44 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge where subscribe :: M () subscribe = forever $ do - subs <- atomically (readTBQueue newSubQ) - let ss = L.groupAllWith server subs - batchSize <- asks $ subsBatchSize . config - forM_ ss $ \serverSubs -> do - let srv = server $ L.head serverSubs - batches = toChunks batchSize $ L.toList serverSubs - SMPSubscriber {newSubQ = subscriberSubQ} <- getSMPSubscriber srv - mapM_ (atomically . writeTQueue subscriberSubQ) batches - - server :: NtfEntityRec 'Subscription -> SMPServer - server (NtfSub sub) = ntfSubServer sub + (srv, subs) <- atomically $ readTBQueue newSubQ + -- TODO [ntfdb] as we now group by server before putting subs to queue, + -- maybe this "subscribe" thread can be removed completely, + -- and the caller would directly write to SMPSubscriber queues + SMPSubscriber {subscriberSubQ} <- getSMPSubscriber srv + atomically $ writeTQueue subscriberSubQ subs + -- TODO [ntfdb] this does not guarantee that only one subscriber per server is created + -- there should be TMVar in the map + -- This does not need changing if single newSubQ remains, but if it is removed, it need to change getSMPSubscriber :: SMPServer -> M SMPSubscriber getSMPSubscriber smpServer = liftIO (TM.lookupIO smpServer smpSubscribers) >>= maybe createSMPSubscriber pure where createSMPSubscriber = do - sub@SMPSubscriber {subThreadId} <- liftIO newSMPSubscriber + sub@SMPSubscriber {subThreadId} <- liftIO $ newSMPSubscriber smpServer atomically $ TM.insert smpServer sub smpSubscribers tId <- mkWeakThreadId =<< forkIO (runSMPSubscriber sub) atomically . writeTVar subThreadId $ Just tId pure sub runSMPSubscriber :: SMPSubscriber -> M () - runSMPSubscriber SMPSubscriber {newSubQ = subscriberSubQ} = + runSMPSubscriber SMPSubscriber {smpServer, subscriberSubQ} = do + st <- asks store forever $ do + -- TODO [ntfdb] possibly, the subscriptions can be batched here and sent every say 5 seconds + -- this should be analysed once we have prometheus stats subs <- atomically $ readTQueue subscriberSubQ - let subs' = L.map (\(NtfSub sub) -> sub) subs - srv = server $ L.head subs - logSubStatus srv "subscribing" $ length subs - mapM_ (\NtfSubData {smpQueue} -> updateSubStatus smpQueue NSPending) subs' - liftIO $ subscribeQueues srv subs' + -- TODO [ntfdb] validate/partition that SMP server matches and log internal error if not + updated <- liftIO $ batchUpdateSubStatus st subs NSPending + logSubStatus smpServer "subscribing" (L.length subs) updated + liftIO $ subscribeQueues smpServer subs -- \| Subscribe to queues. The list of results can have a different order. - subscribeQueues :: SMPServer -> NonEmpty NtfSubData -> IO () + subscribeQueues :: SMPServer -> NonEmpty NtfSubRec -> IO () subscribeQueues srv subs = subscribeQueuesNtfs ca srv (L.map sub subs) where - sub NtfSubData {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey) + sub NtfSubRec {smpQueue = SMPQueueNtf {notifierId}, notifierKey} = (notifierId, notifierKey) receiveSMP :: M () receiveSMP = forever $ do @@ -395,91 +414,83 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge NtfPushServer {pushQ} <- asks pushServer stats <- asks serverStats liftIO $ updatePeriodStats (activeSubs stats) ntfId - tkn_ <- atomically (findNtfSubscriptionToken st smpQueue) - forM_ tkn_ $ \tkn -> do - let newNtf = PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} - lastNtfs <- liftIO $ addTokenLastNtf st (ntfTknId tkn) newNtf - atomically (writeTBQueue pushQ (tkn, PNMessage lastNtfs)) + let newNtf = PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} + ntfs_ <- liftIO $ addTokenLastNtf st newNtf + forM_ ntfs_ $ \(tkn, lastNtfs) -> atomically $ writeTBQueue pushQ (tkn, PNMessage lastNtfs) + -- TODO [ntfdb] track queued notifications separately? incNtfStat ntfReceived - Right SMP.END -> - whenM (atomically $ activeClientSession' ca sessionId srv) $ - updateSubStatus smpQueue NSEnd - Right SMP.DELD -> updateSubStatus smpQueue NSDeleted + Right SMP.END -> do + whenM (atomically $ activeClientSession' ca sessionId srv) $ do + st <- asks store + void $ liftIO $ updateSrvSubStatus st smpQueue NSEnd + Right SMP.DELD -> do + st <- asks store + void $ liftIO $ updateSrvSubStatus st smpQueue NSDeleted Right (SMP.ERR e) -> logError $ "SMP server error: " <> tshow e Right _ -> logError "SMP server unexpected response" Left e -> logError $ "SMP client error: " <> tshow e - receiveAgent = + receiveAgent = do + st <- asks store forever $ atomically (readTBQueue agentQ) >>= \case CAConnected srv -> logInfo $ "SMP server reconnected " <> showServer' srv CADisconnected srv subs -> do - logSubStatus srv "disconnected" $ length subs - forM_ subs $ \(_, ntfId) -> do - let smpQueue = SMPQueueNtf srv ntfId - updateSubStatus smpQueue NSInactive - CASubscribed srv _ subs -> do - forM_ subs $ \ntfId -> updateSubStatus (SMPQueueNtf srv ntfId) NSActive - logSubStatus srv "subscribed" $ length subs - CASubError srv _ errs -> - forM errs (\(ntfId, err) -> handleSubError (SMPQueueNtf srv ntfId) err) - >>= logSubErrors srv . catMaybes . L.toList + forM_ (L.nonEmpty $ map snd $ S.toList subs) $ \nIds -> do + updated <- liftIO $ batchUpdateSrvSubStatus st srv nIds NSInactive + logSubStatus srv "disconnected" (L.length nIds) updated + CASubscribed srv _ nIds -> do + updated <- liftIO $ batchUpdateSrvSubStatus st srv nIds NSActive + logSubStatus srv "subscribed" (L.length nIds) updated + CASubError srv _ errs -> do + forM_ (L.nonEmpty $ mapMaybe (\(nId, err) -> (nId,) <$> subErrorStatus err) $ L.toList errs) $ \subStatuses -> do + updated <- liftIO $ batchUpdateSrvSubStatuses st srv subStatuses + logSubErrors srv subStatuses updated - logSubStatus srv event n = - when (n > 0) . logInfo $ - "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subscriptions)" + logSubStatus :: SMPServer -> T.Text -> Int -> Int64 -> M () + logSubStatus srv event n updated = + logInfo $ "SMP server " <> event <> " " <> showServer' srv <> " (" <> tshow n <> " subs, " <> tshow updated <> " subs updated)" - logSubErrors :: SMPServer -> [NtfSubStatus] -> M () - logSubErrors srv errs = forM_ (L.group $ sort errs) $ \errs' -> do - logError $ "SMP subscription errors on server " <> showServer' srv <> ": " <> tshow (L.head errs') <> " (" <> tshow (length errs') <> " errors)" + logSubErrors :: SMPServer -> NonEmpty (SMP.NotifierId, NtfSubStatus) -> Int64 -> M () + logSubErrors srv subs updated = forM_ (L.group $ L.sort $ L.map snd subs) $ \ss -> do + logError $ "SMP server subscription errors " <> showServer' srv <> ": " <> tshow (L.head ss) <> " (" <> tshow (length ss) <> " errors, " <> tshow updated <> " subs updated)" showServer' = decodeLatin1 . strEncode . host - handleSubError :: SMPQueueNtf -> SMPClientError -> M (Maybe NtfSubStatus) - handleSubError smpQueue = \case - PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth $> Just NSAuth + subErrorStatus :: SMPClientError -> Maybe NtfSubStatus + subErrorStatus = \case + PCEProtocolError AUTH -> Just NSAuth PCEProtocolError e -> updateErr "SMP error " e PCEResponseError e -> updateErr "ResponseError " e PCEUnexpectedResponse r -> updateErr "UnexpectedResponse " r PCETransportError e -> updateErr "TransportError " e PCECryptoError e -> updateErr "CryptoError " e - PCEIncompatibleHost -> let e = NSErr "IncompatibleHost" in updateSubStatus smpQueue e $> Just e - PCEResponseTimeout -> pure Nothing - PCENetworkError -> pure Nothing - PCEIOError _ -> pure Nothing + PCEIncompatibleHost -> Just $ NSErr "IncompatibleHost" + PCEResponseTimeout -> Nothing + PCENetworkError -> Nothing + PCEIOError _ -> Nothing where - updateErr :: Show e => ByteString -> e -> M (Maybe NtfSubStatus) - updateErr errType e = updateSubStatus smpQueue (NSErr $ errType <> bshow e) $> Just (NSErr errType) - - updateSubStatus smpQueue status = do - st <- asks store - atomically (findNtfSubscription st smpQueue) >>= mapM_ update - where - update NtfSubData {ntfSubId, subStatus} = do - old <- atomically $ stateTVar subStatus (,status) - when (old /= status) $ withNtfLog $ \sl -> logSubscriptionStatus sl ntfSubId status + -- Note on moving to PostgreSQL: the idea of logging errors without e is removed here + updateErr :: Show e => ByteString -> e -> Maybe NtfSubStatus + updateErr errType e = Just $ NSErr $ errType <> bshow e ntfPush :: NtfPushServer -> M () ntfPush s@NtfPushServer {pushQ} = forever $ do - (tkn@NtfTknData {ntfTknId, token = t@(DeviceToken pp _), tknStatus}, ntf) <- atomically (readTBQueue pushQ) + (tkn@NtfTknRec {ntfTknId, token = t@(DeviceToken pp _), tknStatus}, ntf) <- atomically (readTBQueue pushQ) liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp) - status <- readTVarIO tknStatus case ntf of PNVerification _ -> deliverNotification pp tkn ntf >>= \case Right _ -> do - status_ <- atomically $ stateTVar tknStatus $ \case - NTActive -> (Nothing, NTActive) - NTConfirmed -> (Nothing, NTConfirmed) - _ -> (Just NTConfirmed, NTConfirmed) - forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status' + st <- asks store + void $ liftIO $ setTknStatusConfirmed st tkn incNtfStatT t ntfVrfDelivered Left _ -> incNtfStatT t ntfVrfFailed - PNCheckMessages -> checkActiveTkn status $ do + PNCheckMessages -> checkActiveTkn tknStatus $ do deliverNotification pp tkn ntf >>= incNtfStatT t . (\case Left _ -> ntfCronFailed; Right () -> ntfCronDelivered) - PNMessage {} -> checkActiveTkn status $ do + PNMessage {} -> checkActiveTkn tknStatus $ do stats <- asks serverStats liftIO $ updatePeriodStats (activeTokens stats) ntfTknId deliverNotification pp tkn ntf @@ -489,8 +500,8 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do checkActiveTkn status action | status == NTActive = action | otherwise = liftIO $ logError "bad notification token status" - deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> M (Either PushProviderError ()) - deliverNotification pp tkn@NtfTknData {ntfTknId} ntf = do + deliverNotification :: PushProvider -> NtfTknRec -> PushNotification -> M (Either PushProviderError ()) + deliverNotification pp tkn@NtfTknRec {ntfTknId} ntf = do deliver <- liftIO $ getPushClient s pp liftIO (runExceptT $ deliver tkn ntf) >>= \case Right _ -> pure $ Right () @@ -499,7 +510,10 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do PPRetryLater -> retryDeliver PPCryptoError _ -> err e PPResponseError {} -> err e - PPTokenInvalid r -> updateTknStatus tkn (NTInvalid $ Just r) >> err e + PPTokenInvalid r -> do + st <- asks store + void $ liftIO $ updateTknStatus st tkn $ NTInvalid $ Just r + err e PPPermanentError -> err e where retryDeliver :: M (Either PushProviderError ()) @@ -508,15 +522,13 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do liftIO (runExceptT $ deliver tkn ntf) >>= \case Right _ -> pure $ Right () Left e -> case e of - PPTokenInvalid r -> updateTknStatus tkn (NTInvalid $ Just r) >> err e + PPTokenInvalid r -> do + st <- asks store + void $ liftIO $ updateTknStatus st tkn $ NTInvalid $ Just r + err e _ -> err e err e = logError ("Push provider error (" <> tshow pp <> ", " <> tshow ntfTknId <> "): " <> tshow e) $> Left e -updateTknStatus :: NtfTknData -> NtfTknStatus -> M () -updateTknStatus NtfTknData {ntfTknId, tknStatus} status = do - old <- atomically $ stateTVar tknStatus (,status) - when (old /= status) $ withNtfLog $ \sl -> logTokenStatus sl ntfTknId status - runNtfClientTransport :: Transport c => THandleNTF c 'TServer -> M () runNtfClientTransport th@THandle {params} = do qSize <- asks $ clientQSize . config @@ -563,160 +575,144 @@ send h@THandle {params} NtfServerClient {sndQ, sndActiveAt} = forever $ do void . liftIO $ tPut h $ L.map (\t -> Right (Nothing, encodeTransmission params t)) ts atomically . (writeTVar sndActiveAt $!) =<< liftIO getSystemTime -data VerificationResult = VRVerified (Maybe NtfTknData, NtfRequest) | VRFailed +data VerificationResult = VRVerified NtfRequest | VRFailed verifyNtfTransmission :: Maybe (THandleAuth 'TServer, C.CbNonce) -> SignedTransmission ErrorType NtfCmd -> NtfCmd -> M VerificationResult verifyNtfTransmission auth_ (tAuth, authorized, (corrId, entId, _)) cmd = do st <- asks store case cmd of + -- TODO [ntfdb] this looks suspicious, as if it can prevent repeated registrations NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _)) -> do - r_ <- atomically $ getNtfTokenRegistration st tkn + r_ <- liftIO $ getNtfTokenRegistration st tkn pure $ if verifyCmdAuthorization auth_ tAuth authorized k then case r_ of - Just t@NtfTknData {tknVerifyKey} - | k == tknVerifyKey -> verifiedTknCmd t c + Right t@NtfTknRec {tknVerifyKey} + -- keys will be the same because of condition in `getNtfTokenRegistration` + | k == tknVerifyKey -> VRVerified $ tknCmd t c | otherwise -> VRFailed - Nothing -> VRVerified (Nothing, NtfReqNew corrId (ANE SToken tkn)) + Left _ -> VRVerified (NtfReqNew corrId (ANE SToken tkn)) else VRFailed NtfCmd SToken c -> do - t_ <- liftIO $ getNtfTokenIO st entId - verifyToken t_ (`verifiedTknCmd` c) - NtfCmd SSubscription c@(SNEW sub@(NewNtfSub tknId smpQueue _)) -> do - s_ <- atomically $ findNtfSubscription st smpQueue - case s_ of - Nothing -> do - t_ <- atomically $ getActiveNtfToken st tknId - verifyToken' t_ $ VRVerified (t_, NtfReqNew corrId (ANE SSubscription sub)) - Just s@NtfSubData {tokenId = subTknId} -> - if subTknId == tknId - then do - t_ <- atomically $ getActiveNtfToken st subTknId - verifyToken' t_ $ verifiedSubCmd t_ s c - else pure $ maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed - NtfCmd SSubscription PING -> pure $ VRVerified (Nothing, NtfReqPing corrId entId) - NtfCmd SSubscription c -> do - s_ <- liftIO $ getNtfSubscriptionIO st entId - case s_ of - Just s@NtfSubData {tokenId = subTknId} -> do - t_ <- atomically $ getActiveNtfToken st subTknId - verifyToken' t_ $ verifiedSubCmd t_ s c - _ -> pure $ maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed + t_ <- liftIO $ getNtfToken st entId + verifyToken_' t_ (`tknCmd` c) + NtfCmd SSubscription c@(SNEW sub@(NewNtfSub tknId smpQueue _)) -> + liftIO $ verify <$> findNtfSubscription st tknId smpQueue + where + verify = \case + Right (t, s_) -> verifyToken t $ case s_ of + Nothing -> NtfReqNew corrId (ANE SSubscription sub) + Just s -> subCmd s c + -- TODO [ntfdb] it should simply return error if it is not AUTH + Left _ -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed + NtfCmd SSubscription PING -> pure $ VRVerified $ NtfReqPing corrId entId + NtfCmd SSubscription c -> liftIO $ verify <$> getNtfSubscription st entId + where + verify = \case + Right (t, s) -> verifyToken t $ subCmd s c + -- TODO [ntfdb] it should simply return error if it is not AUTH + Left _ -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed where - verifiedTknCmd t c = VRVerified (Just t, NtfReqCmd SToken (NtfTkn t) (corrId, entId, c)) - verifiedSubCmd t_ s c = VRVerified (t_, NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c)) - verifyToken :: Maybe NtfTknData -> (NtfTknData -> VerificationResult) -> M VerificationResult - verifyToken t_ positiveVerificationResult = - pure $ case t_ of - Just t@NtfTknData {tknVerifyKey} -> - if verifyCmdAuthorization auth_ tAuth authorized tknVerifyKey - then positiveVerificationResult t - else VRFailed - _ -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed - verifyToken' :: Maybe NtfTknData -> VerificationResult -> M VerificationResult - verifyToken' t_ = verifyToken t_ . const + tknCmd t c = NtfReqCmd SToken (NtfTkn t) (corrId, entId, c) + subCmd s c = NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c) + verifyToken_' :: Either ErrorType NtfTknRec -> (NtfTknRec -> NtfRequest) -> M VerificationResult + verifyToken_' t_ result = pure $ case t_ of + Right t -> verifyToken t $ result t + -- TODO [ntfdb] it should simply return error if it is not AUTH + Left _ -> maybe False (dummyVerifyCmd auth_ authorized) tAuth `seq` VRFailed + verifyToken :: NtfTknRec -> NtfRequest -> VerificationResult + verifyToken NtfTknRec {tknVerifyKey} r + | verifyCmdAuthorization auth_ tAuth authorized tknVerifyKey = VRVerified r + | otherwise = VRFailed client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M () client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPushServer {pushQ, intervalNotifiers} = - forever $ do - ts <- liftIO getSystemDate + forever $ atomically (readTBQueue rcvQ) - >>= mapM (\(tkn_, req) -> updateTokenDate ts tkn_ >> processCommand req) + >>= mapM processCommand >>= atomically . writeTBQueue sndQ where - updateTokenDate :: RoundedSystemTime -> Maybe NtfTknData -> M () - updateTokenDate ts' = mapM_ $ \NtfTknData {ntfTknId, tknUpdatedAt} -> do - let t' = Just ts' - t <- atomically $ swapTVar tknUpdatedAt t' - unless (t' == t) $ withNtfLog $ \s -> logUpdateTokenTime s ntfTknId ts' processCommand :: NtfRequest -> M (Transmission NtfResponse) processCommand = \case - NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> do + NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn token _ dhPubKey)) -> (corrId,NoEntity,) <$> do logDebug "TNEW - new token" - st <- asks store - ks@(srvDhPubKey, srvDhPrivKey) <- atomically . C.generateKeyPair =<< asks random + (srvDhPubKey, srvDhPrivKey) <- atomically . C.generateKeyPair =<< asks random let dhSecret = C.dh' dhPubKey srvDhPrivKey tknId <- getId regCode <- getRegCode ts <- liftIO $ getSystemDate - tkn <- liftIO $ mkNtfTknData tknId newTkn ks dhSecret regCode ts - atomically $ addNtfToken st tknId tkn - atomically $ writeTBQueue pushQ (tkn, PNVerification regCode) - incNtfStatT token ntfVrfQueued - withNtfLog (`logCreateToken` tkn) - incNtfStatT token tknCreated - pure (corrId, NoEntity, NRTknId tknId srvDhPubKey) - NtfReqCmd SToken (NtfTkn tkn@NtfTknData {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do - status <- readTVarIO tknStatus + let tkn = mkNtfTknRec tknId newTkn srvDhPrivKey dhSecret regCode ts + withNtfStore (`addNtfToken` tkn) $ \_ -> do + atomically $ writeTBQueue pushQ (tkn, PNVerification regCode) + incNtfStatT token ntfVrfQueued + incNtfStatT token tknCreated + pure $ NRTknId tknId srvDhPubKey + NtfReqCmd SToken (NtfTkn tkn@NtfTknRec {token, ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhPrivKey}) (corrId, tknId, cmd) -> do (corrId,tknId,) <$> case cmd of TNEW (NewNtfTkn _ _ dhPubKey) -> do logDebug "TNEW - registered token" - let dhSecret = C.dh' dhPubKey srvDhPrivKey + let dhSecret = C.dh' dhPubKey tknDhPrivKey -- it is required that DH secret is the same, to avoid failed verifications if notification is delaying if tknDhSecret == dhSecret then do atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode) incNtfStatT token ntfVrfQueued - pure $ NRTknId ntfTknId srvDhPubKey + pure $ NRTknId ntfTknId $ C.publicKey tknDhPrivKey else pure $ NRErr AUTH TVFY code -- this allows repeated verification for cases when client connection dropped before server response - | (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do + | (tknStatus == NTRegistered || tknStatus == NTConfirmed || tknStatus == NTActive) && tknRegCode == code -> do logDebug "TVFY - token verified" - st <- asks store - updateTknStatus tkn NTActive - tIds <- atomically $ removeInactiveTokenRegistrations st tkn - forM_ tIds cancelInvervalNotifications - incNtfStatT token tknVerified - pure NROk + withNtfStore (`setTokenActive` tkn) $ \tIds -> do + -- TODO [ntfdb] this will be unnecessary if all cron notifications move to one thread + forM_ tIds cancelInvervalNotifications + incNtfStatT token tknVerified + pure NROk | otherwise -> do logDebug "TVFY - incorrect code or token status" + liftIO $ print tkn + let NtfRegCode c = code + liftIO $ print $ B64.encode c pure $ NRErr AUTH TCHK -> do logDebug "TCHK" - pure $ NRTkn status + pure $ NRTkn tknStatus TRPL token' -> do logDebug "TRPL - replace token" - st <- asks store regCode <- getRegCode - atomically $ do - removeTokenRegistration st tkn - writeTVar tknStatus NTRegistered - let tkn' = tkn {token = token', tknRegCode = regCode} - addNtfToken st tknId tkn' - writeTBQueue pushQ (tkn', PNVerification regCode) - incNtfStatT token ntfVrfQueued - withNtfLog $ \s -> logUpdateToken s tknId token' regCode - incNtfStatT token tknReplaced - pure NROk + let tkn' = tkn {token = token', tknStatus = NTRegistered, tknRegCode = regCode} + withNtfStore (`replaceNtfToken` tkn') $ \_ -> do + atomically $ writeTBQueue pushQ (tkn', PNVerification regCode) + incNtfStatT token ntfVrfQueued + incNtfStatT token tknReplaced + pure NROk TDEL -> do logDebug "TDEL" - st <- asks store - qs <- atomically $ deleteNtfToken st tknId - forM_ qs $ \SMPQueueNtf {smpServer, notifierId} -> - atomically $ removeSubscription ca smpServer (SPNotifier, notifierId) - cancelInvervalNotifications tknId - withNtfLog (`logDeleteToken` tknId) - incNtfStatT token tknDeleted - pure NROk + withNtfStore (`deleteNtfToken` tknId) $ \ss -> do + forM_ ss $ \(smpServer, nIds) -> do + atomically $ removeSubscriptions ca smpServer SPNotifier nIds + atomically $ removePendingSubs ca smpServer SPNotifier nIds + cancelInvervalNotifications tknId + incNtfStatT token tknDeleted + pure NROk TCRN 0 -> do logDebug "TCRN 0" - atomically $ writeTVar tknCronInterval 0 - cancelInvervalNotifications tknId - withNtfLog $ \s -> logTokenCron s tknId 0 - pure NROk + withNtfStore (\st -> updateTknCronInterval st ntfTknId 0) $ \_ -> do + -- TODO [ntfdb] move cron intervals to one thread + cancelInvervalNotifications tknId + pure NROk TCRN int | int < 20 -> pure $ NRErr QUOTA | otherwise -> do logDebug "TCRN" - atomically $ writeTVar tknCronInterval int - liftIO (TM.lookupIO tknId intervalNotifiers) >>= \case - Nothing -> runIntervalNotifier int - Just IntervalNotifier {interval, action} -> - unless (interval == int) $ do - uninterruptibleCancel action - runIntervalNotifier int - withNtfLog $ \s -> logTokenCron s tknId int - pure NROk + withNtfStore (\st -> updateTknCronInterval st ntfTknId int) $ \_ -> do + -- TODO [ntfdb] move cron intervals to one thread + liftIO (TM.lookupIO tknId intervalNotifiers) >>= \case + Nothing -> runIntervalNotifier int + Just IntervalNotifier {interval, action} -> + unless (interval == int) $ do + uninterruptibleCancel action + runIntervalNotifier int + pure NROk where runIntervalNotifier interval = do action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60 @@ -726,20 +722,20 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu intervalNotifier delay = forever $ do liftIO $ threadDelay' delay atomically $ writeTBQueue pushQ (tkn, PNCheckMessages) - NtfReqNew corrId (ANE SSubscription newSub) -> do + NtfReqNew corrId (ANE SSubscription newSub@(NewNtfSub _ (SMPQueueNtf srv _) _)) -> do logDebug "SNEW - new subscription" - st <- asks store subId <- getId - sub <- atomically $ mkNtfSubData subId newSub + let sub = mkNtfSubRec subId newSub resp <- - atomically (addNtfSubscription st subId sub) >>= \case - Just _ -> atomically (writeTBQueue newSubQ [NtfSub sub]) $> NRSubId subId - _ -> pure $ NRErr AUTH - withNtfLog (`logCreateSubscription` sub) - incNtfStat subCreated + withNtfStore (`addNtfSubscription` sub) $ \case + True -> do + atomically $ writeTBQueue newSubQ (srv, [sub]) + incNtfStat subCreated + pure $ NRSubId subId + -- TODO [ntfdb] we must allow repeated inserts that don't change credentials + False -> pure $ NRErr AUTH pure (corrId, NoEntity, resp) - NtfReqCmd SSubscription (NtfSub NtfSubData {smpQueue = SMPQueueNtf {smpServer, notifierId}, notifierKey = registeredNKey, subStatus}) (corrId, subId, cmd) -> do - status <- readTVarIO subStatus + NtfReqCmd SSubscription (NtfSub NtfSubRec {smpQueue = SMPQueueNtf {smpServer, notifierId}, notifierKey = registeredNKey, subStatus}) (corrId, subId, cmd) -> do (corrId,subId,) <$> case cmd of SNEW (NewNtfSub _ _ notifierKey) -> do logDebug "SNEW - existing subscription" @@ -750,15 +746,14 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu else NRErr AUTH SCHK -> do logDebug "SCHK" - pure $ NRSub status + pure $ NRSub subStatus SDEL -> do logDebug "SDEL" - st <- asks store - atomically $ deleteNtfSubscription st subId - atomically $ removeSubscription ca smpServer (SPNotifier, notifierId) - withNtfLog (`logDeleteSubscription` subId) - incNtfStat subDeleted - pure NROk + withNtfStore (`deleteNtfSubscription` subId) $ \_ -> do + atomically $ removeSubscription ca smpServer (SPNotifier, notifierId) + atomically $ removePendingSub ca smpServer (SPNotifier, notifierId) + incNtfStat subDeleted + pure NROk PING -> pure NRPong NtfReqPing corrId entId -> pure (corrId, entId, NRPong) getId :: M NtfEntityId @@ -772,8 +767,12 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu atomically (TM.lookupDelete tknId intervalNotifiers) >>= mapM_ (uninterruptibleCancel . action) -withNtfLog :: (StoreLog 'WriteMode -> IO a) -> M () -withNtfLog action = liftIO . mapM_ action =<< asks storeLog +withNtfStore :: (NtfPostgresStore -> IO (Either ErrorType a)) -> (a -> M NtfResponse) -> M NtfResponse +withNtfStore stAction continue = do + st <- asks store + liftIO (stAction st) >>= \case + Left e -> pure $ NRErr e + Right a -> continue a incNtfStatT :: DeviceToken -> (NtfServerStats -> IORef Int) -> M () incNtfStatT (DeviceToken PPApnsNull _) _ = pure () @@ -784,43 +783,24 @@ incNtfStat statSel = do stats <- asks serverStats liftIO $ atomicModifyIORef'_ (statSel stats) (+ 1) -saveServerLastNtfs :: M () -saveServerLastNtfs = asks (storeLastNtfsFile . config) >>= mapM_ saveLastNtfs +restoreServerLastNtfs :: NtfSTMStore -> FilePath -> IO () +restoreServerLastNtfs st f = + whenM (doesFileExist f) $ do + logInfo $ "restoring last notifications from file " <> T.pack f + runExceptT (liftIO (B.readFile f) >>= mapM restoreNtf . B.lines) >>= \case + Left e -> do + logError . T.pack $ "error restoring last notifications: " <> e + exitFailure + Right _ -> do + renameFile f $ f <> ".bak" + logInfo "last notifications restored" where - saveLastNtfs f = do - logInfo $ "saving last notifications to file " <> T.pack f - NtfStore {tokenLastNtfs} <- asks store - liftIO . withFile f WriteMode $ \h -> - readTVarIO tokenLastNtfs >>= mapM_ (saveTokenLastNtfs h) . M.assocs - logInfo "notifications saved" + restoreNtf s = do + TNMRv1 tknId ntf <- liftEither . first (ntfErr "parsing") $ strDecode s + liftIO $ stmStoreTokenLastNtf st tknId ntf where - -- reverse on save, to save notifications in order, will become reversed again when restoring. - saveTokenLastNtfs h (tknId, v) = BLD.hPutBuilder h . encodeLastNtfs tknId . L.reverse =<< readTVarIO v - encodeLastNtfs tknId = mconcat . L.toList . L.map (\ntf -> BLD.byteString (strEncode $ TNMRv1 tknId ntf) <> BLD.char8 '\n') - -restoreServerLastNtfs :: M () -restoreServerLastNtfs = - asks (storeLastNtfsFile . config) >>= mapM_ restoreLastNtfs - where - restoreLastNtfs f = - whenM (doesFileExist f) $ do - logInfo $ "restoring last notifications from file " <> T.pack f - st <- asks store - runExceptT (liftIO (LB.readFile f) >>= mapM (restoreNtf st) . LB.lines) >>= \case - Left e -> do - logError . T.pack $ "error restoring last notifications: " <> e - liftIO exitFailure - Right _ -> do - renameFile f $ f <> ".bak" - logInfo "last notifications restored" - where - restoreNtf st s' = do - TNMRv1 tknId ntf <- liftEither . first (ntfErr "parsing") $ strDecode s - liftIO $ storeTokenLastNtf st tknId ntf - where - s = LB.toStrict s' - ntfErr :: Show e => String -> e -> String - ntfErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s) + ntfErr :: Show e => String -> e -> String + ntfErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s) saveServerStats :: M () saveServerStats = diff --git a/src/Simplex/Messaging/Notifications/Server/Env.hs b/src/Simplex/Messaging/Notifications/Server/Env.hs index 3859a3df1..46f3e9f2d 100644 --- a/src/Simplex/Messaging/Notifications/Server/Env.hs +++ b/src/Simplex/Messaging/Notifications/Server/Env.hs @@ -9,7 +9,6 @@ module Simplex.Messaging.Notifications.Server.Env where import Control.Concurrent (ThreadId) import Control.Concurrent.Async (Async) -import Control.Logger.Simple import Crypto.Random import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) @@ -25,16 +24,17 @@ import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Stats -import Simplex.Messaging.Notifications.Server.Store -import Simplex.Messaging.Notifications.Server.StoreLog +import Simplex.Messaging.Notifications.Server.Store.Postgres +import Simplex.Messaging.Notifications.Server.Store.Types import Simplex.Messaging.Notifications.Transport (NTFVersion, VersionRangeNTF) import Simplex.Messaging.Protocol (BasicAuth, CorrId, SMPServer, Transmission) +import Simplex.Messaging.Server.Env.STM (StartOptions) import Simplex.Messaging.Server.Expiration +import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport (ATransport, THandleParams, TransportPeer (..)) import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials, TransportServerConfig, loadFingerprint, loadServerCredential) -import System.IO (IOMode (..)) import System.Mem.Weak (Weak) import UnliftIO.STM @@ -52,8 +52,7 @@ data NtfServerConfig = NtfServerConfig apnsConfig :: APNSPushClientConfig, subsBatchSize :: Int, inactiveClientExpiration :: Maybe ExpirationConfig, - storeLogFile :: Maybe FilePath, - storeLastNtfsFile :: Maybe FilePath, + dbStoreConfig :: PostgresStoreCfg, ntfCredentials :: ServerCredentials, -- stats config - see SMP server config logStatsInterval :: Maybe Int64, @@ -61,7 +60,8 @@ data NtfServerConfig = NtfServerConfig serverStatsLogFile :: FilePath, serverStatsBackupFile :: Maybe FilePath, ntfServerVRange :: VersionRangeNTF, - transportConfig :: TransportServerConfig + transportConfig :: TransportServerConfig, + startOptions :: StartOptions } defaultInactiveClientExpiration :: ExpirationConfig @@ -75,8 +75,7 @@ data NtfEnv = NtfEnv { config :: NtfServerConfig, subscriber :: NtfSubscriber, pushServer :: NtfPushServer, - store :: NtfStore, - storeLog :: Maybe (StoreLog 'WriteMode), + store :: NtfPostgresStore, random :: TVar ChaChaDRG, tlsServerCreds :: T.Credential, serverIdentity :: C.KeyHash, @@ -84,22 +83,23 @@ data NtfEnv = NtfEnv } newNtfServerEnv :: NtfServerConfig -> IO NtfEnv -newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, ntfCredentials} = do +newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, dbStoreConfig, ntfCredentials} = do random <- C.newRandom - store <- newNtfStore - logInfo "restoring subscriptions..." - storeLog <- mapM (`readWriteNtfStore` store) storeLogFile - logInfo "restored subscriptions" + store <- newNtfDbStore dbStoreConfig + -- TODO [ntfdb] this should happen with compacting on start + -- logInfo "restoring subscriptions..." + -- storeLog <- mapM (`readWriteNtfStore` store) storeLogFile + -- logInfo "restored subscriptions" subscriber <- newNtfSubscriber subQSize smpAgentCfg random pushServer <- newNtfPushServer pushQSize apnsConfig tlsServerCreds <- loadServerCredential ntfCredentials Fingerprint fp <- loadFingerprint ntfCredentials serverStats <- newNtfServerStats =<< getCurrentTime - pure NtfEnv {config, subscriber, pushServer, store, storeLog, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} + pure NtfEnv {config, subscriber, pushServer, store, random, tlsServerCreds, serverIdentity = C.KeyHash fp, serverStats} data NtfSubscriber = NtfSubscriber { smpSubscribers :: TMap SMPServer SMPSubscriber, - newSubQ :: TBQueue [NtfEntityRec 'Subscription], + newSubQ :: TBQueue (SMPServer, NonEmpty NtfSubRec), -- should match SMPServer smpAgent :: SMPClientAgent } @@ -111,18 +111,19 @@ newNtfSubscriber qSize smpAgentCfg random = do pure NtfSubscriber {smpSubscribers, newSubQ, smpAgent} data SMPSubscriber = SMPSubscriber - { newSubQ :: TQueue (NonEmpty (NtfEntityRec 'Subscription)), + { smpServer :: SMPServer, + subscriberSubQ :: TQueue (NonEmpty NtfSubRec), subThreadId :: TVar (Maybe (Weak ThreadId)) } -newSMPSubscriber :: IO SMPSubscriber -newSMPSubscriber = do - newSubQ <- newTQueueIO +newSMPSubscriber :: SMPServer -> IO SMPSubscriber +newSMPSubscriber smpServer = do + subscriberSubQ <- newTQueueIO subThreadId <- newTVarIO Nothing - pure SMPSubscriber {newSubQ, subThreadId} + pure SMPSubscriber {smpServer, subscriberSubQ, subThreadId} data NtfPushServer = NtfPushServer - { pushQ :: TBQueue (NtfTknData, PushNotification), + { pushQ :: TBQueue (NtfTknRec, PushNotification), pushClients :: TMap PushProvider PushProviderClient, intervalNotifiers :: TMap NtfTokenId IntervalNotifier, apnsConfig :: APNSPushClientConfig @@ -130,7 +131,7 @@ data NtfPushServer = NtfPushServer data IntervalNotifier = IntervalNotifier { action :: Async (), - token :: NtfTknData, + token :: NtfTknRec, interval :: Word16 } @@ -159,7 +160,7 @@ data NtfRequest | NtfReqPing CorrId NtfEntityId data NtfServerClient = NtfServerClient - { rcvQ :: TBQueue (NonEmpty (Maybe NtfTknData, NtfRequest)), + { rcvQ :: TBQueue (NonEmpty NtfRequest), sndQ :: TBQueue (NonEmpty (Transmission NtfResponse)), ntfThParams :: THandleParams NTFVersion 'TServer, connected :: TVar Bool, diff --git a/src/Simplex/Messaging/Notifications/Server/Main.hs b/src/Simplex/Messaging/Notifications/Server/Main.hs index 5418ec4e9..aa0e036ba 100644 --- a/src/Simplex/Messaging/Notifications/Server/Main.hs +++ b/src/Simplex/Messaging/Notifications/Server/Main.hs @@ -11,29 +11,42 @@ module Simplex.Messaging.Notifications.Server.Main where import Control.Monad ((<$!>)) +import qualified Data.ByteString.Char8 as B import Data.Functor (($>)) import Data.Ini (lookupValue, readIniFile) +import Data.Int (Int64) import Data.Maybe (fromMaybe) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.IO as T -import Network.Socket (HostName) +import Network.Socket (HostName, ServiceName) import Options.Applicative +import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) +import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (HostMode (..), NetworkConfig (..), ProtocolClientConfig (..), SocksMode (..), defaultNetworkConfig, textToHostMode) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Notifications.Server (runNtfServer) +import Simplex.Messaging.Notifications.Server (runNtfServer, restoreServerLastNtfs) import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..), defaultInactiveClientExpiration) import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig) +import Simplex.Messaging.Notifications.Server.Store (newNtfSTMStore) +import Simplex.Messaging.Notifications.Server.Store.Postgres (exportNtfDbStore, importNtfSTMStore, newNtfDbStore) +import Simplex.Messaging.Notifications.Server.StoreLog (readWriteNtfSTMStore) import Simplex.Messaging.Notifications.Transport (supportedServerNTFVRange) import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), pattern NtfServer) import Simplex.Messaging.Server.CLI +import Simplex.Messaging.Server.Env.STM (StartOptions) import Simplex.Messaging.Server.Expiration -import Simplex.Messaging.Transport (simplexMQVersion) +import Simplex.Messaging.Server.Main.Init (iniDbOpts) +import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) +import Simplex.Messaging.Server.StoreLog (closeStoreLog) +import Simplex.Messaging.Transport (ATransport, simplexMQVersion) import Simplex.Messaging.Transport.Client (TransportHost (..)) -import Simplex.Messaging.Transport.Server (ServerCredentials (..), TransportServerConfig (..), defaultTransportServerConfig) -import Simplex.Messaging.Util (tshow) -import System.Directory (createDirectoryIfMissing, doesFileExist) +import Simplex.Messaging.Transport.Server (AddHTTP, ServerCredentials (..), TransportServerConfig (..), defaultTransportServerConfig) +import Simplex.Messaging.Util (ifM, tshow) +import System.Directory (createDirectoryIfMissing, doesFileExist, renameFile) +import System.Exit (exitFailure) import System.FilePath (combine) import System.IO (BufferMode (..), hSetBuffering, stderr, stdout) import Text.Read (readMaybe) @@ -45,14 +58,8 @@ ntfServerCLI cfgPath logPath = doesFileExist iniFile >>= \case True -> exitError $ "Error: server is already initialized (" <> iniFile <> " exists).\nRun `" <> executableName <> " start`." _ -> initializeServer opts - OnlineCert certOpts -> - doesFileExist iniFile >>= \case - True -> genOnline cfgPath certOpts - _ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`." - Start -> - doesFileExist iniFile >>= \case - True -> readIniFile iniFile >>= either exitError runServer - _ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`." + OnlineCert certOpts -> withIniFile $ \_ -> genOnline cfgPath certOpts + Start opts -> withIniFile $ runServer opts Delete -> do confirmOrExit "WARNING: deleting the server will make all queues inaccessible, because the server identity (certificate fingerprint) will change.\nTHIS CANNOT BE UNDONE!" @@ -60,13 +67,75 @@ ntfServerCLI cfgPath logPath = deleteDirIfExists cfgPath deleteDirIfExists logPath putStrLn "Deleted configuration and log files" + Database cmd dbOpts@DBOpts {connstr, schema} -> withIniFile $ \ini -> do + schemaExists <- checkSchemaExists connstr schema + storeLogExists <- doesFileExist storeLogFilePath + lastNtfsExists <- doesFileExist defaultLastNtfsFile + case cmd of + SCImport + | schemaExists && (storeLogExists || lastNtfsExists) -> exitConfigureNtfStore connstr schema + | schemaExists -> do + putStrLn $ "Schema " <> B.unpack schema <> " already exists in PostrgreSQL database: " <> B.unpack connstr + exitFailure + | not storeLogExists -> do + putStrLn $ storeLogFilePath <> " file does not exist." + exitFailure + | not lastNtfsExists -> do + putStrLn $ defaultLastNtfsFile <> " file does not exist." + exitFailure + | otherwise -> do + storeLogFile <- getRequiredStoreLogFile ini + confirmOrExit + ("WARNING: store log file " <> storeLogFile <> " will be compacted and imported to PostrgreSQL database: " <> B.unpack connstr <> ", schema: " <> B.unpack schema) + "Notification server store not imported" + stmStore <- newNtfSTMStore + sl <- readWriteNtfSTMStore True storeLogFile stmStore + closeStoreLog sl + restoreServerLastNtfs stmStore defaultLastNtfsFile + let storeCfg = PostgresStoreCfg {dbOpts = dbOpts {createSchema = True}, dbStoreLogPath = Nothing, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} + ps <- newNtfDbStore storeCfg + (tCnt, sCnt, nCnt) <- importNtfSTMStore ps stmStore + renameFile storeLogFile $ storeLogFile <> ".bak" + putStrLn $ "Import completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show nCnt <> " last token notifications." + putStrLn "Configure database options in INI file." + SCExport + | schemaExists && storeLogExists -> exitConfigureNtfStore connstr schema + | not schemaExists -> do + putStrLn $ "Schema " <> B.unpack schema <> " does not exist in PostrgreSQL database: " <> B.unpack connstr + exitFailure + | storeLogExists -> do + putStrLn $ storeLogFilePath <> " file already exists." + exitFailure + | lastNtfsExists -> do + putStrLn $ defaultLastNtfsFile <> " file already exists." + exitFailure + | otherwise -> do + confirmOrExit + ("WARNING: PostrgreSQL database schema " <> B.unpack schema <> " (database: " <> B.unpack connstr <> ") will be exported to store log file " <> storeLogFilePath) + "Notification server store not imported" + let storeCfg = PostgresStoreCfg {dbOpts, dbStoreLogPath = Just storeLogFilePath, confirmMigrations = MCConsole, deletedTTL = iniDeletedTTL ini} + st <- newNtfDbStore storeCfg + (tCnt, sCnt, nCnt) <- exportNtfDbStore st defaultLastNtfsFile + putStrLn $ "Export completed: " <> show tCnt <> " tokens, " <> show sCnt <> " subscriptions, " <> show nCnt <> " last token notifications." where + withIniFile a = + doesFileExist iniFile >>= \case + True -> readIniFile iniFile >>= either exitError a + _ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`." + getRequiredStoreLogFile ini = do + case enableStoreLog' ini $> storeLogFilePath of + Just storeLogFile -> do + ifM + (doesFileExist storeLogFile) + (pure storeLogFile) + (putStrLn ("Store log file " <> storeLogFile <> " not found") >> exitFailure) + Nothing -> putStrLn "Store log disabled, see `[STORE_LOG] enable`" >> exitFailure iniFile = combine cfgPath "ntf-server.ini" serverVersion = "SMP notifications server v" <> simplexMQVersion defaultServerPort = "443" executableName = "ntf-server" storeLogFilePath = combine logPath "ntf-server-store.log" - initializeServer InitOptions {enableStoreLog, signAlgorithm, ip, fqdn} = do + initializeServer InitOptions {enableStoreLog, dbOptions, signAlgorithm, ip, fqdn} = do clearDirIfExists cfgPath clearDirIfExists logPath createDirectoryIfMissing True cfgPath @@ -88,6 +157,10 @@ ntfServerCLI cfgPath logPath = \# and restoring it when the server is started.\n\ \# Log is compacted on start (deleted objects are removed).\n" <> ("enable: " <> onOff enableStoreLog <> "\n\n") + <> "# Database connection settings for PostgreSQL database.\n" + <> iniDbOpts dbOptions defaultNtfDBOpts + <> "Time to retain deleted entities in the database, days.\n" + <> ("# db_deleted_ttl: " <> tshow defaultDeletedTTL <> "\n\n") <> "# Last notifications are optionally saved and restored when the server restarts,\n\ \# they are preserved in the .bak file until the next restart.\n" <> ("restore_last_notifications: " <> onOff enableStoreLog <> "\n\n") @@ -125,26 +198,29 @@ ntfServerCLI cfgPath logPath = \disconnect: off\n" <> ("# ttl: " <> tshow (ttl defaultInactiveClientExpiration) <> "\n") <> ("# check_interval: " <> tshow (checkInterval defaultInactiveClientExpiration) <> "\n") - runServer ini = do + enableStoreLog' = settingIsOn "STORE_LOG" "enable" + runServer startOptions ini = do hSetBuffering stdout LineBuffering hSetBuffering stderr LineBuffering fp <- checkSavedFingerprint cfgPath defaultX509Config let host = either (const "") T.unpack $ lookupValue "TRANSPORT" "host" ini port = T.unpack $ strictIni "TRANSPORT" "port" ini - cfg@NtfServerConfig {transports, storeLogFile} = serverConfig + cfg@NtfServerConfig {transports} = serverConfig srv = ProtoServerWithAuth (NtfServer [THDomainName host] (if port == "443" then "" else port) (C.KeyHash fp)) Nothing printServiceInfo serverVersion srv - printServerConfig transports storeLogFile + printNtfServerConfig transports dbStoreConfig runNtfServer cfg where - enableStoreLog = settingIsOn "STORE_LOG" "enable" ini logStats = settingIsOn "STORE_LOG" "log_stats" ini c = combine cfgPath . ($ defaultX509Config) - restoreLastNtfsFile path = case iniOnOff "STORE_LOG" "restore_last_notifications" ini of - Just True -> Just path - Just False -> Nothing - -- if the setting is not set, it is enabled when store log is enabled - _ -> enableStoreLog $> path + dbStoreLogPath = enableStoreLog' ini $> storeLogFilePath + dbStoreConfig = + PostgresStoreCfg + { dbOpts = iniDBOptions ini defaultNtfDBOpts, + dbStoreLogPath, + confirmMigrations = MCYesUp, + deletedTTL = iniDeletedTTL ini + } serverConfig = NtfServerConfig { transports = iniTransports ini, @@ -180,8 +256,7 @@ ntfServerCLI cfgPath logPath = { ttl = readStrictIni "INACTIVE_CLIENTS" "ttl" ini, checkInterval = readStrictIni "INACTIVE_CLIENTS" "check_interval" ini }, - storeLogFile = enableStoreLog $> storeLogFilePath, - storeLastNtfsFile = restoreLastNtfsFile $ combine logPath "ntf-server-last-notifications.log", + dbStoreConfig, ntfCredentials = ServerCredentials { caCertificateFile = Just $ c caCrtFile, @@ -196,32 +271,67 @@ ntfServerCLI cfgPath logPath = transportConfig = defaultTransportServerConfig { logTLSErrors = fromMaybe False $ iniOnOff "TRANSPORT" "log_tls_errors" ini - } + }, + startOptions } + iniDeletedTTL ini = readIniDefault (86400 * defaultDeletedTTL) "STORE_LOG" "db_deleted_ttl" ini + defaultLastNtfsFile = combine logPath "ntf-server-last-notifications.log" + exitConfigureNtfStore connstr schema = do + putStrLn $ "Error: both " <> storeLogFilePath <> " file and " <> B.unpack schema <> " schema are present (database: " <> B.unpack connstr <> ")." + putStrLn "Configure notification server storage." + exitFailure + +printNtfServerConfig :: [(ServiceName, ATransport, AddHTTP)] -> PostgresStoreCfg -> IO () +printNtfServerConfig transports PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}, dbStoreLogPath} = do + B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema + printServerConfig "NTF" transports dbStoreLogPath data CliCommand = Init InitOptions | OnlineCert CertOptions - | Start + | Start StartOptions | Delete + | Database StoreCmd DBOpts + +data StoreCmd = SCImport | SCExport data InitOptions = InitOptions { enableStoreLog :: Bool, + dbOptions :: DBOpts, signAlgorithm :: SignAlgorithm, ip :: HostName, fqdn :: Maybe HostName } deriving (Show) +defaultNtfDBOpts :: DBOpts +defaultNtfDBOpts = + DBOpts + { connstr = "postgresql://ntf@/ntf_server_store", + schema = "ntf_server", + poolSize = 10, + createSchema = False + } + +-- time to retain deleted tokens and subscriptions in the database (days), for debugging +defaultDeletedTTL :: Int64 +defaultDeletedTTL = 21 + cliCommandP :: FilePath -> FilePath -> FilePath -> Parser CliCommand cliCommandP cfgPath logPath iniFile = hsubparser ( command "init" (info (Init <$> initP) (progDesc $ "Initialize server - creates " <> cfgPath <> " and " <> logPath <> " directories and configuration files")) <> command "cert" (info (OnlineCert <$> certOptionsP) (progDesc $ "Generate new online TLS server credentials (configuration: " <> iniFile <> ")")) - <> command "start" (info (pure Start) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) + <> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) <> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files")) + <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultNtfDBOpts) (progDesc "Import/export notifications server store to/from PostgreSQL database")) ) where + databaseCmdP = + hsubparser + ( command "import" (info (pure SCImport) (progDesc $ "Import store logs into a new PostgreSQL database schema")) + <> command "export" (info (pure SCExport) (progDesc $ "Export PostgreSQL database schema to store logs")) + ) initP :: Parser InitOptions initP = do enableStoreLog <- @@ -234,6 +344,7 @@ cliCommandP cfgPath logPath iniFile = <> short 'l' <> help "Enable store log for persistence (DEPRECATED, enabled by default)" ) + dbOptions <- dbOptsP defaultNtfDBOpts signAlgorithm <- option (maybeReader readMaybe) @@ -261,4 +372,4 @@ cliCommandP cfgPath logPath iniFile = <> showDefault <> metavar "FQDN" ) - pure InitOptions {enableStoreLog, signAlgorithm, ip, fqdn} + pure InitOptions {enableStoreLog, dbOptions, signAlgorithm, ip, fqdn} diff --git a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs index ec9cd272c..7439e4fea 100644 --- a/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs +++ b/src/Simplex/Messaging/Notifications/Server/Push/APNS.hs @@ -50,7 +50,7 @@ import Network.Socket (HostName, ServiceName) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Push.APNS.Internal -import Simplex.Messaging.Notifications.Server.Store (NtfTknData (..)) +import Simplex.Messaging.Notifications.Server.Store.Types (NtfTknRec (..)) import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.Transport.HTTP2 (HTTP2Body (..)) import Simplex.Messaging.Transport.HTTP2.Client @@ -263,8 +263,8 @@ disconnectApnsHTTP2Client APNSPushClient {https2Client} = ntfCategoryCheckMessage :: Text ntfCategoryCheckMessage = "NTF_CAT_CHECK_MESSAGE" -apnsNotification :: NtfTknData -> C.CbNonce -> Int -> PushNotification -> Either C.CryptoError APNSNotification -apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case +apnsNotification :: NtfTknRec -> C.CbNonce -> Int -> PushNotification -> Either C.CryptoError APNSNotification +apnsNotification NtfTknRec {tknDhSecret} nonce paddedLen = \case PNVerification (NtfRegCode code) -> encrypt code $ \code' -> apn APNSBackground {contentAvailable = 1} . Just $ J.object ["nonce" .= nonce, "verification" .= code'] @@ -313,7 +313,7 @@ data PushProviderError | PPPermanentError deriving (Show, Exception) -type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProviderError IO () +type PushProviderClient = NtfTknRec -> PushNotification -> ExceptT PushProviderError IO () -- this is not a newtype on purpose to have a correct JSON encoding as a record data APNSErrorResponse = APNSErrorResponse {reason :: Text} @@ -321,7 +321,7 @@ data APNSErrorResponse = APNSErrorResponse {reason :: Text} $(JQ.deriveFromJSON defaultJSON ''APNSErrorResponse) apnsPushProviderClient :: APNSPushClient -> PushProviderClient -apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken _ tknStr} pn = do +apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknRec {token = DeviceToken _ tknStr} pn = do http2 <- liftHTTPS2 $ getApnsHTTP2Client c nonce <- atomically $ C.randomCbNonce nonceDrg apnsNtf <- liftEither $ first PPCryptoError $ apnsNotification tkn nonce (paddedNtfLength apnsCfg) pn diff --git a/src/Simplex/Messaging/Notifications/Server/Store.hs b/src/Simplex/Messaging/Notifications/Server/Store.hs index 259a933b6..4b8a4e230 100644 --- a/src/Simplex/Messaging/Notifications/Server/Store.hs +++ b/src/Simplex/Messaging/Notifications/Server/Store.hs @@ -30,7 +30,7 @@ import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (whenM, ($>>=)) -data NtfStore = NtfStore +data NtfSTMStore = NtfSTMStore { tokens :: TMap NtfTokenId NtfTknData, -- multiple registrations exist to protect from malicious registrations if token is compromised tokenRegistrations :: TMap DeviceToken (TMap ByteString NtfTokenId), @@ -40,15 +40,15 @@ data NtfStore = NtfStore tokenLastNtfs :: TMap NtfTokenId (TVar (NonEmpty PNMessageData)) } -newNtfStore :: IO NtfStore -newNtfStore = do +newNtfSTMStore :: IO NtfSTMStore +newNtfSTMStore = do tokens <- TM.emptyIO tokenRegistrations <- TM.emptyIO subscriptions <- TM.emptyIO tokenSubscriptions <- TM.emptyIO subscriptionLookup <- TM.emptyIO tokenLastNtfs <- TM.emptyIO - pure NtfStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup, tokenLastNtfs} + pure NtfSTMStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup, tokenLastNtfs} data NtfTknData = NtfTknData { ntfTknId :: NtfTokenId, @@ -80,18 +80,11 @@ data NtfSubData = NtfSubData ntfSubServer :: NtfSubData -> SMPServer ntfSubServer NtfSubData {smpQueue = SMPQueueNtf {smpServer}} = smpServer -data NtfEntityRec (e :: NtfEntity) where - NtfTkn :: NtfTknData -> NtfEntityRec 'Token - NtfSub :: NtfSubData -> NtfEntityRec 'Subscription +stmGetNtfTokenIO :: NtfSTMStore -> NtfTokenId -> IO (Maybe NtfTknData) +stmGetNtfTokenIO st tknId = TM.lookupIO tknId (tokens st) -getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) -getNtfToken st tknId = TM.lookup tknId (tokens st) - -getNtfTokenIO :: NtfStore -> NtfTokenId -> IO (Maybe NtfTknData) -getNtfTokenIO st tknId = TM.lookupIO tknId (tokens st) - -addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM () -addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do +stmAddNtfToken :: NtfSTMStore -> NtfTokenId -> NtfTknData -> STM () +stmAddNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do TM.insert tknId tkn $ tokens st TM.lookup token regs >>= \case Just tIds -> TM.insert regKey tknId tIds @@ -102,16 +95,8 @@ addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do regs = tokenRegistrations st regKey = C.toPubKey C.pubKeyBytes tknVerifyKey -getNtfTokenRegistration :: NtfStore -> NewNtfEntity 'Token -> STM (Maybe NtfTknData) -getNtfTokenRegistration st (NewNtfTkn token tknVerifyKey _) = - TM.lookup token (tokenRegistrations st) - $>>= TM.lookup regKey - $>>= (`TM.lookup` tokens st) - where - regKey = C.toPubKey C.pubKeyBytes tknVerifyKey - -removeInactiveTokenRegistrations :: NtfStore -> NtfTknData -> STM [NtfTokenId] -removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = +stmRemoveInactiveTokenRegistrations :: NtfSTMStore -> NtfTknData -> STM [NtfTokenId] +stmRemoveInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = TM.lookup token (tokenRegistrations st) >>= maybe (pure []) removeRegs where @@ -125,8 +110,8 @@ removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} = void $ deleteTokenSubs st tId' pure $ map snd tIds -removeTokenRegistration :: NtfStore -> NtfTknData -> STM () -removeTokenRegistration st NtfTknData {ntfTknId = tId, token, tknVerifyKey} = +stmRemoveTokenRegistration :: NtfSTMStore -> NtfTknData -> STM () +stmRemoveTokenRegistration st NtfTknData {ntfTknId = tId, token, tknVerifyKey} = TM.lookup token (tokenRegistrations st) >>= mapM_ removeReg where removeReg regs = @@ -134,8 +119,8 @@ removeTokenRegistration st NtfTknData {ntfTknId = tId, token, tknVerifyKey} = >>= mapM_ (\tId' -> when (tId == tId') $ TM.delete k regs) k = C.toPubKey C.pubKeyBytes tknVerifyKey -deleteNtfToken :: NtfStore -> NtfTokenId -> STM [SMPQueueNtf] -deleteNtfToken st tknId = do +stmDeleteNtfToken :: NtfSTMStore -> NtfTokenId -> STM [SMPQueueNtf] +stmDeleteNtfToken st tknId = do void $ TM.lookupDelete tknId (tokens st) $>>= \NtfTknData {token, tknVerifyKey} -> TM.lookup token regs $>>= \tIds -> @@ -147,7 +132,7 @@ deleteNtfToken st tknId = do regs = tokenRegistrations st regKey = C.toPubKey C.pubKeyBytes -deleteTokenSubs :: NtfStore -> NtfTokenId -> STM [SMPQueueNtf] +deleteTokenSubs :: NtfSTMStore -> NtfTokenId -> STM [SMPQueueNtf] deleteTokenSubs st tknId = do qs <- TM.lookupDelete tknId (tokenSubscriptions st) @@ -159,32 +144,11 @@ deleteTokenSubs st tknId = do $>>= \NtfSubData {smpQueue} -> TM.delete smpQueue (subscriptionLookup st) $> Just smpQueue -getNtfSubscriptionIO :: NtfStore -> NtfSubscriptionId -> IO (Maybe NtfSubData) -getNtfSubscriptionIO st subId = TM.lookupIO subId (subscriptions st) +stmGetNtfSubscriptionIO :: NtfSTMStore -> NtfSubscriptionId -> IO (Maybe NtfSubData) +stmGetNtfSubscriptionIO st subId = TM.lookupIO subId (subscriptions st) -findNtfSubscription :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfSubData) -findNtfSubscription st smpQueue = do - TM.lookup smpQueue (subscriptionLookup st) - $>>= \subId -> TM.lookup subId (subscriptions st) - -findNtfSubscriptionToken :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfTknData) -findNtfSubscriptionToken st smpQueue = do - findNtfSubscription st smpQueue - $>>= \NtfSubData {tokenId} -> getActiveNtfToken st tokenId - -getActiveNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData) -getActiveNtfToken st tknId = - getNtfToken st tknId $>>= \tkn@NtfTknData {tknStatus} -> do - tStatus <- readTVar tknStatus - pure $ if tStatus == NTActive then Just tkn else Nothing - -mkNtfSubData :: NtfSubscriptionId -> NewNtfEntity 'Subscription -> STM NtfSubData -mkNtfSubData ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = do - subStatus <- newTVar NSNew - pure NtfSubData {ntfSubId, smpQueue, tokenId, subStatus, notifierKey} - -addNtfSubscription :: NtfStore -> NtfSubscriptionId -> NtfSubData -> STM (Maybe ()) -addNtfSubscription st subId sub@NtfSubData {smpQueue, tokenId} = +stmAddNtfSubscription :: NtfSTMStore -> NtfSubscriptionId -> NtfSubData -> STM (Maybe ()) +stmAddNtfSubscription st subId sub@NtfSubData {smpQueue, tokenId} = TM.lookup tokenId (tokenSubscriptions st) >>= maybe newTokenSub pure >>= insertSub where newTokenSub = do @@ -198,8 +162,8 @@ addNtfSubscription st subId sub@NtfSubData {smpQueue, tokenId} = -- return Nothing if subscription existed before pure $ Just () -deleteNtfSubscription :: NtfStore -> NtfSubscriptionId -> STM () -deleteNtfSubscription st subId = do +stmDeleteNtfSubscription :: NtfSTMStore -> NtfSubscriptionId -> STM () +stmDeleteNtfSubscription st subId = do TM.lookupDelete subId (subscriptions st) >>= mapM_ ( \NtfSubData {smpQueue, tokenId} -> do @@ -208,32 +172,10 @@ deleteNtfSubscription st subId = do forM_ ts_ $ \ts -> modifyTVar' ts $ S.delete subId ) -addTokenLastNtf :: NtfStore -> NtfTokenId -> PNMessageData -> IO (NonEmpty PNMessageData) -addTokenLastNtf st tknId newNtf = - TM.lookupIO tknId (tokenLastNtfs st) >>= maybe (atomically maybeNewTokenLastNtfs) (atomically . addNtf) - where - maybeNewTokenLastNtfs = - TM.lookup tknId (tokenLastNtfs st) >>= maybe newTokenLastNtfs addNtf - newTokenLastNtfs = do - v <- newTVar [newNtf] - TM.insert tknId v $ tokenLastNtfs st - pure [newNtf] - addNtf v = - stateTVar v $ \ntfs -> let !ntfs' = rebuildList ntfs in (ntfs', ntfs') - where - rebuildList :: NonEmpty PNMessageData -> NonEmpty PNMessageData - rebuildList = foldr keepPrevNtf [newNtf] - where - PNMessageData {smpQueue = newNtfQ} = newNtf - keepPrevNtf ntf@PNMessageData {smpQueue} ntfs - | smpQueue /= newNtfQ && length ntfs < maxNtfs = ntf <| ntfs - | otherwise = ntfs - maxNtfs = 6 - -- This function is expected to be called after store log is read, -- as it checks for token existence when adding last notification. -storeTokenLastNtf :: NtfStore -> NtfTokenId -> PNMessageData -> IO () -storeTokenLastNtf (NtfStore {tokens, tokenLastNtfs}) tknId ntf = do +stmStoreTokenLastNtf :: NtfSTMStore -> NtfTokenId -> PNMessageData -> IO () +stmStoreTokenLastNtf (NtfSTMStore {tokens, tokenLastNtfs}) tknId ntf = do TM.lookupIO tknId tokenLastNtfs >>= atomically . maybe newTokenLastNtfs (`modifyTVar'` (ntf <|)) where newTokenLastNtfs = TM.lookup tknId tokenLastNtfs >>= maybe insertForExistingToken (`modifyTVar'` (ntf <|)) diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs new file mode 100644 index 000000000..a9de42668 --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Store/Migrations.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Notifications.Server.Store.Migrations where + +import Data.List (sortOn) +import Data.Text (Text) +import qualified Data.Text as T +import Simplex.Messaging.Agent.Store.Shared +import Text.RawString.QQ (r) + +ntfServerSchemaMigrations :: [(String, Text, Maybe Text)] +ntfServerSchemaMigrations = + [ ("20250417_initial", m20250417_initial, Nothing) + ] + +-- | The list of migrations in ascending order by date +ntfServerMigrations :: [Migration] +ntfServerMigrations = sortOn name $ map migration ntfServerSchemaMigrations + where + migration (name, up, down) = Migration {name, up, down = down} + +m20250417_initial :: Text +m20250417_initial = + T.pack + [r| +CREATE TABLE tokens( + token_id BYTEA NOT NULL, + push_provider TEXT NOT NULL, + push_provider_token BYTEA NOT NULL, + status TEXT NOT NULL, + verify_key BYTEA NOT NULL, + dh_priv_key BYTEA NOT NULL, + dh_secret BYTEA NOT NULL, + reg_code BYTEA NOT NULL, + cron_interval BIGINT NOT NULL, + cron_sent_at BIGINT, + updated_at BIGINT, + PRIMARY KEY (token_id) +); + +CREATE UNIQUE INDEX idx_tokens_push_provider_token ON tokens(push_provider, push_provider_token, verify_key); +CREATE INDEX idx_tokens_cron_sent_at ON tokens((cron_sent_at + cron_interval)); + +CREATE TABLE smp_servers( + smp_server_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + smp_host TEXT NOT NULL, + smp_port TEXT NOT NULL, + smp_keyhash BYTEA NOT NULL +); + +CREATE UNIQUE INDEX idx_smp_servers ON smp_servers(smp_host, smp_port, smp_keyhash); + +CREATE TABLE subscriptions( + subscription_id BYTEA NOT NULL, + token_id BYTEA NOT NULL REFERENCES tokens ON DELETE CASCADE ON UPDATE RESTRICT, + smp_server_id BIGINT REFERENCES smp_servers ON DELETE RESTRICT ON UPDATE RESTRICT, + smp_notifier_id BYTEA NOT NULL, + smp_notifier_key BYTEA NOT NULL, + status TEXT NOT NULL, + PRIMARY KEY (subscription_id) +); + +CREATE UNIQUE INDEX idx_subscriptions_smp_server_id_notifier_id ON subscriptions(smp_server_id, smp_notifier_id); +CREATE INDEX idx_subscriptions_smp_server_id_status ON subscriptions(smp_server_id, status); +CREATE INDEX idx_subscriptions_token_id ON subscriptions(token_id); + +CREATE TABLE last_notifications( + token_ntf_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + token_id BYTEA NOT NULL REFERENCES tokens ON DELETE CASCADE ON UPDATE RESTRICT, + subscription_id BYTEA NOT NULL REFERENCES subscriptions ON DELETE CASCADE ON UPDATE RESTRICT, + sent_at BIGINT NOT NULL, + nmsg_nonce BYTEA NOT NULL, + nmsg_data BYTEA NOT NULL +); + +CREATE INDEX idx_last_notifications_token_id_sent_at ON last_notifications(token_id, sent_at); +CREATE INDEX idx_last_notifications_subscription_id ON last_notifications(subscription_id); + +CREATE UNIQUE INDEX idx_last_notifications_token_subscription ON last_notifications(token_id, subscription_id); + |] diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs new file mode 100644 index 000000000..afe60e5b5 --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Store/Postgres.hs @@ -0,0 +1,809 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Notifications.Server.Store.Postgres where + +import Control.Concurrent.STM +import qualified Control.Exception as E +import Control.Logger.Simple +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class +import Control.Monad.Trans.Except +import Data.Bitraversable (bimapM) +import qualified Data.ByteString.Base64.URL as B64 +import Data.ByteString.Char8 (ByteString) +import qualified Data.ByteString.Char8 as B +import Data.Containers.ListUtils (nubOrd) +import Data.Either (fromRight) +import Data.Functor (($>)) +import Data.Int (Int64) +import Data.List (foldl', intercalate) +import Data.List.NonEmpty (NonEmpty (..)) +import qualified Data.List.NonEmpty as L +import qualified Data.Map.Strict as M +import Data.Maybe (catMaybes, fromMaybe, mapMaybe) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding (decodeLatin1, encodeUtf8) +import Data.Time.Clock.System (SystemTime (..)) +import Data.Word (Word16) +import Database.PostgreSQL.Simple (Binary (..), In (..), Only (..), Query, ToRow, (:.) (..)) +import qualified Database.PostgreSQL.Simple as DB +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +import Network.Socket (ServiceName) +import Simplex.Messaging.Agent.Store.AgentStore () +import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) +import Simplex.Messaging.Agent.Store.Postgres.Common +import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder, fromTextField_) +import Simplex.Messaging.Encoding +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Notifications.Protocol +import Simplex.Messaging.Notifications.Server.Store (NtfSTMStore (..), NtfSubData (..), NtfTknData (..), TokenNtfMessageRecord (..), ntfSubServer) +import Simplex.Messaging.Notifications.Server.Store.Migrations +import Simplex.Messaging.Notifications.Server.Store.Types +import Simplex.Messaging.Notifications.Server.StoreLog +import Simplex.Messaging.Parsers (parseAll) +import Simplex.Messaging.Protocol (EntityId (..), EncNMsgMeta, ErrorType (..), NotifierId, NtfPrivateAuthKey, NtfPublicAuthKey, SMPServer, pattern SMPServer) +import Simplex.Messaging.Server.QueueStore (RoundedSystemTime, getSystemDate) +import Simplex.Messaging.Server.QueueStore.Postgres (handleDuplicate, withLog_) +import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) +import Simplex.Messaging.Server.StoreLog (openWriteStoreLog) +import Simplex.Messaging.Transport.Client (TransportHost) +import Simplex.Messaging.Util (anyM, firstRow, maybeFirstRow, toChunks, tshow) +import System.Exit (exitFailure) +import System.IO (IOMode (..), hFlush, stdout, withFile) +import Text.Hex (decodeHex) + +#if !defined(dbPostgres) +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Util (eitherToMaybe) +#endif + +data NtfPostgresStore = NtfPostgresStore + { dbStore :: DBStore, + dbStoreLog :: Maybe (StoreLog 'WriteMode), + deletedTTL :: Int64 + } + +mkNtfTknRec :: NtfTokenId -> NewNtfEntity 'Token -> C.PrivateKeyX25519 -> C.DhSecretX25519 -> NtfRegCode -> RoundedSystemTime -> NtfTknRec +mkNtfTknRec ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhPrivKey tknDhSecret tknRegCode ts = + NtfTknRec {ntfTknId, token, tknStatus = NTRegistered, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval = 0, tknUpdatedAt = Just ts} + +ntfSubServer' :: NtfSubRec -> SMPServer +ntfSubServer' NtfSubRec {smpQueue = SMPQueueNtf {smpServer}} = smpServer + +data NtfEntityRec (e :: NtfEntity) where + NtfTkn :: NtfTknRec -> NtfEntityRec 'Token + NtfSub :: NtfSubRec -> NtfEntityRec 'Subscription + +newNtfDbStore :: PostgresStoreCfg -> IO NtfPostgresStore +newNtfDbStore PostgresStoreCfg {dbOpts, dbStoreLogPath, confirmMigrations, deletedTTL} = do + dbStore <- either err pure =<< createDBStore dbOpts ntfServerMigrations confirmMigrations + dbStoreLog <- mapM (openWriteStoreLog True) dbStoreLogPath + pure NtfPostgresStore {dbStore, dbStoreLog, deletedTTL} + where + err e = do + logError $ "STORE: newNtfStore, error opening PostgreSQL database, " <> tshow e + exitFailure + +closeNtfDbStore :: NtfPostgresStore -> IO () +closeNtfDbStore NtfPostgresStore {dbStore, dbStoreLog} = do + closeDBStore dbStore + mapM_ closeStoreLog dbStoreLog + +addNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) +addNtfToken st tkn = + withDB "addNtfToken" st $ \db -> + E.try (DB.execute db insertNtfTknQuery $ ntfTknToRow tkn) + >>= bimapM handleDuplicate (\_ -> withLog "addNtfToken" st (`logCreateToken` tkn)) + +insertNtfTknQuery :: Query +insertNtfTknQuery = + [sql| + INSERT INTO tokens + (token_id, push_provider, push_provider_token, status, verify_key, dh_priv_key, dh_secret, reg_code, cron_interval, updated_at) + VALUES (?,?,?,?,?,?,?,?,?,?) + |] + +replaceNtfToken :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) +replaceNtfToken st NtfTknRec {ntfTknId, token = token@(DeviceToken pp ppToken), tknStatus, tknRegCode = code@(NtfRegCode regCode)} = + withDB "replaceNtfToken" st $ \db -> runExceptT $ do + ExceptT $ assertUpdated <$> + DB.execute + db + [sql| + UPDATE tokens + SET push_provider = ?, push_provider_token = ?, status = ?, reg_code = ? + WHERE token_id = ? + |] + (pp, Binary ppToken, tknStatus, Binary regCode, ntfTknId) + withLog "replaceNtfToken" st $ \sl -> logUpdateToken sl ntfTknId token code + +ntfTknToRow :: NtfTknRec -> NtfTknRow +ntfTknToRow NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = + let DeviceToken pp ppToken = token + NtfRegCode regCode = tknRegCode + in (ntfTknId, pp, Binary ppToken, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, Binary regCode, tknCronInterval, tknUpdatedAt) + +getNtfToken :: NtfPostgresStore -> NtfTokenId -> IO (Either ErrorType NtfTknRec) +getNtfToken st tknId = getNtfToken_ st " WHERE token_id = ?" (Only tknId) + +getNtfTokenRegistration :: NtfPostgresStore -> NewNtfEntity 'Token -> IO (Either ErrorType NtfTknRec) +getNtfTokenRegistration st (NewNtfTkn (DeviceToken pp ppToken) tknVerifyKey _) = + getNtfToken_ st " WHERE push_provider = ? AND push_provider_token = ? AND verify_key = ?" (pp, Binary ppToken, tknVerifyKey) + +getNtfToken_ :: ToRow q => NtfPostgresStore -> Query -> q -> IO (Either ErrorType NtfTknRec) +getNtfToken_ st cond params = + withDB "getNtfToken" st $ \db -> runExceptT $ do + tkn <- ExceptT $ firstRow rowToNtfTkn AUTH $ DB.query db (ntfTknQuery <> cond) params + liftIO $ updateTokenDate st db tkn + pure tkn + +updateTokenDate :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> IO () +updateTokenDate st db NtfTknRec {ntfTknId, tknUpdatedAt} = do + ts <- getSystemDate + when (maybe True (ts /=) tknUpdatedAt) $ do + void $ DB.execute db "UPDATE tokens SET updated_at = ? WHERE token_id = ?" (ts, ntfTknId) + withLog "updateTokenDate" st $ \sl -> logUpdateTokenTime sl ntfTknId ts + +type NtfTknRow = (NtfTokenId, PushProvider, Binary ByteString, NtfTknStatus, NtfPublicAuthKey, C.PrivateKeyX25519, C.DhSecretX25519, Binary ByteString, Word16, Maybe RoundedSystemTime) + +ntfTknQuery :: Query +ntfTknQuery = + [sql| + SELECT token_id, push_provider, push_provider_token, status, verify_key, dh_priv_key, dh_secret, reg_code, cron_interval, updated_at + FROM tokens + |] + +rowToNtfTkn :: NtfTknRow -> NtfTknRec +rowToNtfTkn (ntfTknId, pp, Binary ppToken, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, Binary regCode, tknCronInterval, tknUpdatedAt) = + let token = DeviceToken pp ppToken + tknRegCode = NtfRegCode regCode + in NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} + +deleteNtfToken :: NtfPostgresStore -> NtfTokenId -> IO (Either ErrorType [(SMPServer, [NotifierId])]) +deleteNtfToken st tknId = + withDB "deleteNtfToken" st $ \db -> runExceptT $ do + -- This SELECT obtains exclusive lock on token row and prevents any inserts + -- into other tables for this token ID until the deletion completes. + _ <- ExceptT $ firstRow (fromOnly @Int) AUTH $ + DB.query db "SELECT 1 FROM tokens WHERE token_id = ? FOR UPDATE" (Only tknId) + subs <- + liftIO $ map toServerSubs <$> + DB.query + db + [sql| + SELECT p.smp_host, p.smp_port, p.smp_keyhash, + string_agg(s.smp_notifier_id :: TEXT, ',') AS notifier_ids + FROM smp_servers p + JOIN subscriptions s ON s.smp_server_id = p.smp_server_id + WHERE s.token_id = ? + GROUP BY p.smp_host, p.smp_port, p.smp_keyhash; + |] + (Only tknId) + liftIO $ void $ DB.execute db "DELETE FROM tokens WHERE token_id = ?" (Only tknId) + withLog "deleteNtfToken" st (`logDeleteToken` tknId) + pure subs + where + toServerSubs :: SMPServerRow :. Only Text -> (SMPServer, [NotifierId]) + toServerSubs (srv :. Only nIdsStr) = (rowToSrv srv, parseByteaString nIdsStr) + parseByteaString :: Text -> [NotifierId] + parseByteaString s = mapMaybe (fmap EntityId . decodeHex . T.drop 2) $ T.splitOn "," s -- drop 2 to remove "\\x" + +type SMPServerRow = (NonEmpty TransportHost, ServiceName, C.KeyHash) + +type SMPQueueNtfRow = (NonEmpty TransportHost, ServiceName, C.KeyHash, NotifierId) + +rowToSrv :: SMPServerRow -> SMPServer +rowToSrv (host, port, kh) = SMPServer host port kh + +srvToRow :: SMPServer -> SMPServerRow +srvToRow (SMPServer host port kh) = (host, port, kh) + +smpQueueToRow :: SMPQueueNtf -> SMPQueueNtfRow +smpQueueToRow (SMPQueueNtf (SMPServer host port kh) nId) = (host, port, kh, nId) + +rowToSMPQueue :: SMPQueueNtfRow -> SMPQueueNtf +rowToSMPQueue (host, port, kh, nId) = SMPQueueNtf (SMPServer host port kh) nId + +updateTknCronInterval :: NtfPostgresStore -> NtfTokenId -> Word16 -> IO (Either ErrorType ()) +updateTknCronInterval st tknId cronInt = + withDB "updateTknCronInterval" st $ \db -> runExceptT $ do + ExceptT $ assertUpdated <$> + DB.execute db "UPDATE tokens SET cron_interval = ? WHERE token_id = ?" (cronInt, tknId) + withLog "updateTknCronInterval" st $ \sl -> logTokenCron sl tknId 0 + +-- Reads servers that have subscriptions that need subscribing. +-- It is executed on server start, and it is supposed to crash on database error +getUsedSMPServers :: NtfPostgresStore -> IO [SMPServer] +getUsedSMPServers st = + withTransaction (dbStore st) $ \db -> + map rowToSrv <$> + DB.query + db + [sql| + SELECT p.smp_host, p.smp_port, p.smp_keyhash + FROM smp_servers p + WHERE EXISTS ( + SELECT 1 FROM subscriptions s + WHERE s.smp_server_id = p.smp_server_id + AND s.status IN ? + ) + |] + (Only (In [NSNew, NSPending, NSActive, NSInactive])) + +foldNtfSubscriptions :: NtfPostgresStore -> SMPServer -> Int -> s -> (s -> NtfSubRec -> IO s) -> IO s +foldNtfSubscriptions st srv fetchCount state action = + withConnection (dbStore st) $ \db -> + DB.foldWithOptions opts db query params state $ \s -> action s . toNtfSub + where + query = + [sql| + SELECT s.subscription_id, s.token_id, s.smp_notifier_id, s.status, s.smp_notifier_key + FROM subscriptions s + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + WHERE p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? + AND s.status IN ? + |] + params = srvToRow srv :. Only (In [NSNew, NSPending, NSActive, NSInactive]) + opts = DB.defaultFoldOptions {DB.fetchQuantity = DB.Fixed fetchCount} + toNtfSub (ntfSubId, tokenId, nId, subStatus, notifierKey) = + NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf srv nId, subStatus, notifierKey} + +findNtfSubscription :: NtfPostgresStore -> NtfTokenId -> SMPQueueNtf -> IO (Either ErrorType (NtfTknRec, Maybe NtfSubRec)) +findNtfSubscription st tknId q@(SMPQueueNtf srv nId) = + withDB "findNtfSubscription" st $ \db -> runExceptT $ do + r@(tkn@NtfTknRec {tknStatus}, _) <- + ExceptT $ firstRow (rowToNtfTknMaybeSub q) AUTH $ + DB.query + db + [sql| + SELECT t.token_id, t.push_provider, t.push_provider_token, t.status, t.verify_key, t.dh_priv_key, t.dh_secret, t.reg_code, t.cron_interval, t.updated_at, + s.subscription_id, s.smp_notifier_key, s.status + FROM tokens t + LEFT JOIN subscriptions s ON s.token_id = t.token_id AND s.smp_notifier_id = ? + LEFT JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? + WHERE t.token_id = ? + |] + (Only nId :. srvToRow srv :. Only tknId) + liftIO $ updateTokenDate st db tkn + unless (allowNtfSubCommands tknStatus) $ throwE AUTH + pure r + +getNtfSubscription :: NtfPostgresStore -> NtfSubscriptionId -> IO (Either ErrorType (NtfTknRec, NtfSubRec)) +getNtfSubscription st subId = + withDB "getNtfSubscription" st $ \db -> runExceptT $ do + r@(tkn@NtfTknRec {tknStatus}, _) <- + ExceptT $ firstRow rowToNtfTknSub AUTH $ + DB.query + db + [sql| + SELECT t.token_id, t.push_provider, t.push_provider_token, t.status, t.verify_key, t.dh_priv_key, t.dh_secret, t.reg_code, t.cron_interval, t.updated_at, + s.subscription_id, s.smp_notifier_key, s.status, + p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id + FROM subscriptions s + JOIN tokens t ON t.token_id = s.token_id + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + WHERE s.subscription_id = ? + |] + (Only subId) + liftIO $ updateTokenDate st db tkn + unless (allowNtfSubCommands tknStatus) $ throwE AUTH + pure r + +type NtfSubRow = (NtfSubscriptionId, NtfPrivateAuthKey, NtfSubStatus) + +type MaybeNtfSubRow = (Maybe NtfSubscriptionId, Maybe NtfPrivateAuthKey, Maybe NtfSubStatus) + +rowToNtfTknSub :: NtfTknRow :. NtfSubRow :. SMPQueueNtfRow -> (NtfTknRec, NtfSubRec) +rowToNtfTknSub (tknRow :. (ntfSubId, notifierKey, subStatus) :. qRow) = + let tkn@NtfTknRec {ntfTknId = tokenId} = rowToNtfTkn tknRow + smpQueue = rowToSMPQueue qRow + in (tkn, NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus}) + +rowToNtfTknMaybeSub :: SMPQueueNtf -> NtfTknRow :. MaybeNtfSubRow -> (NtfTknRec, Maybe NtfSubRec) +rowToNtfTknMaybeSub smpQueue (tknRow :. subRow) = + let tkn@NtfTknRec {ntfTknId = tokenId} = rowToNtfTkn tknRow + sub_ = case subRow of + (Just ntfSubId, Just notifierKey, Just subStatus) -> + Just NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus} + _ -> Nothing + in (tkn, sub_) + +mkNtfSubRec :: NtfSubscriptionId -> NewNtfEntity 'Subscription -> NtfSubRec +mkNtfSubRec ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = + NtfSubRec {ntfSubId, tokenId, smpQueue, subStatus = NSNew, notifierKey} + +updateTknStatus :: NtfPostgresStore -> NtfTknRec -> NtfTknStatus -> IO (Either ErrorType ()) +updateTknStatus st tkn status = + withDB' "updateTknStatus" st $ \db -> updateTknStatus_ st db tkn status + +updateTknStatus_ :: NtfPostgresStore -> DB.Connection -> NtfTknRec -> NtfTknStatus -> IO () +updateTknStatus_ st db NtfTknRec {ntfTknId} status = do + updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ?" (status, ntfTknId, status) + when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId status + +-- unless it was already active +setTknStatusConfirmed :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType ()) +setTknStatusConfirmed st NtfTknRec {ntfTknId} = + withDB' "updateTknStatus" st $ \db -> do + updated <- DB.execute db "UPDATE tokens SET status = ? WHERE token_id = ? AND status != ? AND status != ?" (NTConfirmed, ntfTknId, NTConfirmed, NTActive) + when (updated > 0) $ withLog "updateTknStatus" st $ \sl -> logTokenStatus sl ntfTknId NTConfirmed + +setTokenActive :: NtfPostgresStore -> NtfTknRec -> IO (Either ErrorType [NtfTokenId]) +setTokenActive st tkn@NtfTknRec {ntfTknId, token = DeviceToken pp ppToken} = + withDB' "setTokenActive" st $ \db -> do + updateTknStatus_ st db tkn NTActive + -- this removes other instances of the same token, e.g. because of repeated token registration attempts + tknIds <- + liftIO $ map fromOnly <$> + DB.query + db + [sql| + DELETE FROM tokens + WHERE push_provider = ? AND push_provider_token = ? AND token_id != ? + RETURNING token_id + |] + (pp, Binary ppToken, ntfTknId) + withLog "deleteNtfToken" st $ \sl -> mapM_ (logDeleteToken sl) tknIds + pure tknIds + +addNtfSubscription :: NtfPostgresStore -> NtfSubRec -> IO (Either ErrorType Bool) +addNtfSubscription st sub = + withDB "addNtfSubscription" st $ \db -> runExceptT $ do + srvId :: Int64 <- ExceptT $ upsertServer db $ ntfSubServer' sub + n <- liftIO $ DB.execute db insertNtfSubQuery $ ntfSubToRow srvId sub + withLog "addNtfSubscription" st (`logCreateSubscription` sub) + pure $ n > 0 + where + -- It is possible to combine these two statements into one with CTEs, + -- to reduce roundtrips in case of `insert`, but it would be making 2 queries in all cases. + -- With 2 statements it will succeed on the first `select` in most cases. + upsertServer db srv = getServer >>= maybe insertServer (pure . Right) + where + getServer = + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT smp_server_id + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + |] + (srvToRow srv) + insertServer = + firstRow fromOnly (STORE "error inserting SMP server when adding subscription") $ + DB.query + db + [sql| + INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) + ON CONFLICT (smp_host, smp_port, smp_keyhash) + DO UPDATE SET smp_host = EXCLUDED.smp_host + RETURNING smp_server_id + |] + (srvToRow srv) + +insertNtfSubQuery :: Query +insertNtfSubQuery = + [sql| + INSERT INTO subscriptions (token_id, smp_server_id, smp_notifier_id, subscription_id, smp_notifier_key, status) + VALUES (?,?,?,?,?,?) + |] + +ntfSubToRow :: Int64 -> NtfSubRec -> (NtfTokenId, Int64, NotifierId) :. NtfSubRow +ntfSubToRow srvId NtfSubRec {ntfSubId, tokenId, smpQueue = SMPQueueNtf _ nId, notifierKey, subStatus} = + (tokenId, srvId, nId) :. (ntfSubId, notifierKey, subStatus) + +deleteNtfSubscription :: NtfPostgresStore -> NtfSubscriptionId -> IO (Either ErrorType ()) +deleteNtfSubscription st subId = + withDB "deleteNtfSubscription" st $ \db -> runExceptT $ do + ExceptT $ assertUpdated <$> + DB.execute db "DELETE FROM subscriptions WHERE subscription_id = ?" (Only subId) + withLog "deleteNtfSubscription" st (`logDeleteSubscription` subId) + +updateSrvSubStatus :: NtfPostgresStore -> SMPQueueNtf -> NtfSubStatus -> IO (Either ErrorType ()) +updateSrvSubStatus st q status = + withDB' "updateSrvSubStatus" st $ \db -> do + subId_ :: Maybe NtfSubscriptionId <- + maybeFirstRow fromOnly $ + DB.query + db + [sql| + UPDATE subscriptions s + SET status = ? + FROM smp_servers p + WHERE p.smp_server_id = s.smp_server_id + AND p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? + AND s.status != ? + RETURNING s.subscription_id + |] + (Only status :. smpQueueToRow q :. Only status) + forM_ subId_ $ \subId -> + withLog "updateSrvSubStatus" st $ \sl -> logSubscriptionStatus sl subId status + +batchUpdateSrvSubStatus :: NtfPostgresStore -> SMPServer -> NonEmpty NotifierId -> NtfSubStatus -> IO Int64 +batchUpdateSrvSubStatus st srv nIds status = + batchUpdateStatus_ st srv $ \srvId -> + -- without executeMany + -- L.toList $ L.map (status,srvId,,status) nIds + L.toList $ L.map (status,srvId,) nIds + +batchUpdateSrvSubStatuses :: NtfPostgresStore -> SMPServer -> NonEmpty (NotifierId, NtfSubStatus) -> IO Int64 +batchUpdateSrvSubStatuses st srv subs = + batchUpdateStatus_ st srv $ \srvId -> + -- without executeMany + -- L.toList $ L.map (\(nId, status) -> (status, srvId, nId, status)) subs + L.toList $ L.map (\(nId, status) -> (status, srvId, nId)) subs + +-- without executeMany +-- batchUpdateStatus_ :: NtfPostgresStore -> SMPServer -> (Int64 -> [(NtfSubStatus, Int64, NotifierId, NtfSubStatus)]) -> IO Int64 +batchUpdateStatus_ :: NtfPostgresStore -> SMPServer -> (Int64 -> [(NtfSubStatus, Int64, NotifierId)]) -> IO Int64 +batchUpdateStatus_ st srv mkParams = + fmap (fromRight (-1)) $ withDB "batchUpdateStatus_" st $ \db -> runExceptT $ do + srvId <- ExceptT $ getSMPServerId db + let params = mkParams srvId + subs <- + liftIO $ + DB.returning + db + [sql| + UPDATE subscriptions s + SET status = upd.status + FROM (VALUES(?, ?, ?)) AS upd(status, smp_server_id, smp_notifier_id) + WHERE s.smp_server_id = upd.smp_server_id + AND s.smp_notifier_id = (upd.smp_notifier_id :: BYTEA) + AND s.status != upd.status + RETURNING s.subscription_id, s.status + |] + params + -- TODO [ntfdb] below is equivalent without using executeMany. + -- executeMany "works", and logs updates. + -- We do not have tests that validate correct subscription status, + -- and the potential problem is BYTEA conversation - VALUES are inserted as TEXT in this case for some reason. + -- subs <- + -- liftIO $ fmap catMaybes $ forM params $ + -- maybeFirstRow id . DB.query db "UPDATE subscriptions SET status = ? WHERE smp_server_id = ? AND smp_notifier_id = ? AND status != ? RETURNING subscription_id, status" + -- logWarn $ "batchUpdateStatus_: " <> tshow (length subs) + withLog "batchUpdateStatus_" st $ forM_ subs . uncurry . logSubscriptionStatus + pure $ fromIntegral $ length subs + where + getSMPServerId db = + firstRow fromOnly AUTH $ + DB.query + db + [sql| + SELECT smp_server_id + FROM smp_servers + WHERE smp_host = ? AND smp_port = ? AND smp_keyhash = ? + |] + (srvToRow srv) + +batchUpdateSubStatus :: NtfPostgresStore -> NonEmpty NtfSubRec -> NtfSubStatus -> IO Int64 +batchUpdateSubStatus st subs status = + fmap (fromRight (-1)) $ withDB' "batchUpdateSubStatus" st $ \db -> do + let params = L.toList $ L.map (\NtfSubRec {ntfSubId} -> (status, ntfSubId)) subs + subIds <- + DB.returning + db + [sql| + UPDATE subscriptions s + SET status = upd.status + FROM (VALUES(?, ?)) AS upd(status, subscription_id) + WHERE s.subscription_id = (upd.subscription_id :: BYTEA) + AND s.status != upd.status + RETURNING s.subscription_id + |] + params + -- TODO [ntfdb] below is equivalent without using executeMany - see comment above. + -- let params = L.toList $ L.map (\NtfSubRec {ntfSubId} -> (status, ntfSubId, status)) subs + -- subIds <- + -- fmap catMaybes $ forM params $ + -- maybeFirstRow id . DB.query db "UPDATE subscriptions SET status = ? WHERE subscription_id = ? AND status != ? RETURNING subscription_id" + -- logWarn $ "batchUpdateSubStatus: " <> tshow (length subIds) + withLog "batchUpdateSubStatus" st $ \sl -> + forM_ subIds $ \(Only subId) -> logSubscriptionStatus sl subId status + pure $ fromIntegral $ length subIds + +addTokenLastNtf :: NtfPostgresStore -> PNMessageData -> IO (Either ErrorType (NtfTknRec, NonEmpty PNMessageData)) +addTokenLastNtf st newNtf = + withDB "addTokenLastNtf" st $ \db -> runExceptT $ do + (tkn@NtfTknRec {ntfTknId = tId, tknStatus}, sId) <- + ExceptT $ firstRow toTokenSubId AUTH $ + DB.query + db + [sql| + SELECT t.token_id, t.push_provider, t.push_provider_token, t.status, t.verify_key, t.dh_priv_key, t.dh_secret, t.reg_code, t.cron_interval, t.updated_at, + s.subscription_id + FROM tokens t + JOIN subscriptions s ON s.token_id = t.token_id + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + WHERE p.smp_host = ? AND p.smp_port = ? AND p.smp_keyhash = ? AND s.smp_notifier_id = ? + FOR UPDATE OF t, s + |] + (smpQueueToRow q) + unless (tknStatus == NTActive) $ throwE AUTH + lastNtfs_ <- + liftIO $ map toLastNtf <$> + DB.query + db + [sql| + WITH new AS ( + INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) + VALUES (?,?,?,?,?) + ON CONFLICT (token_id, subscription_id) + DO UPDATE SET + sent_at = EXCLUDED.sent_at, + nmsg_nonce = EXCLUDED.nmsg_nonce, + nmsg_data = EXCLUDED.nmsg_data + ), + last AS ( + SELECT token_ntf_id, subscription_id, sent_at, nmsg_nonce, nmsg_data + FROM last_notifications + WHERE token_id = ? + ORDER BY sent_at DESC + LIMIT ? + ), + delete AS ( + DELETE FROM last_notifications + WHERE token_id = ? + AND sent_at < (SELECT min(sent_at) FROM last) + ) + SELECT p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id, + l.sent_at, l.nmsg_nonce, l.nmsg_data + FROM last l + JOIN subscriptions s ON s.subscription_id = l.subscription_id + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + ORDER BY token_ntf_id DESC + |] + (tId, sId, ntfTs, nmsgNonce, Binary encNMsgMeta, tId, maxNtfs, tId) + let lastNtfs = fromMaybe (newNtf :| []) (L.nonEmpty lastNtfs_) + pure (tkn, lastNtfs) + where + maxNtfs = 6 :: Int + PNMessageData {smpQueue = q, ntfTs, nmsgNonce, encNMsgMeta} = newNtf + toTokenSubId :: NtfTknRow :. Only NtfSubscriptionId -> (NtfTknRec, NtfSubscriptionId) + toTokenSubId (tknRow :. Only sId) = (rowToNtfTkn tknRow, sId) + +toLastNtf :: SMPQueueNtfRow :. (SystemTime, C.CbNonce, Binary EncNMsgMeta) -> PNMessageData +toLastNtf (qRow :. (ts, nonce, Binary encMeta)) = + PNMessageData {smpQueue = rowToSMPQueue qRow, ntfTs = ts, nmsgNonce = nonce, encNMsgMeta = encMeta} + +importNtfSTMStore :: NtfPostgresStore -> NtfSTMStore -> IO (Int64, Int64, Int64) +importNtfSTMStore NtfPostgresStore {dbStore = s} stmStore = do + (tCnt, tIds) <- importTokens + sCnt <- importSubscriptions tIds + nCnt <- importLastNtfs + pure (tCnt, sCnt, nCnt) + where + importTokens = do + allTokens <- M.elems <$> readTVarIO (tokens stmStore) + tokens <- filterTokens allTokens + let skipped = length allTokens - length tokens + when (skipped /= 0) $ putStrLn $ "Total skipped tokens " <> show skipped + tCnt <- withConnection s $ \db -> foldM (insertToken db) 0 tokens + void $ checkCount "token" (length tokens) tCnt + let tokenIds = S.fromList $ map (\NtfTknData {ntfTknId} -> ntfTknId) tokens + pure (tCnt, tokenIds) + where + filterTokens tokens = do + let deviceTokens = foldl' (\m t -> M.alter (Just . (t :) . fromMaybe []) (tokenKey t) m) M.empty tokens + tokenSubs <- readTVarIO (tokenSubscriptions stmStore) + filterM (keepTokenRegistration deviceTokens tokenSubs) tokens + tokenKey NtfTknData {token, tknVerifyKey} = strEncode token <> ":" <> C.toPubKey C.pubKeyBytes tknVerifyKey + keepTokenRegistration deviceTokens tokenSubs tkn@NtfTknData {ntfTknId, token, tknStatus} = + case M.lookup (tokenKey tkn) deviceTokens of + Just ts + | length ts >= 2 -> + readTVarIO tknStatus >>= \case + NTConfirmed -> do + anyActive <- anyM $ map (\NtfTknData {tknStatus = tknStatus'} -> (NTActive ==) <$> readTVarIO tknStatus') ts + noSubs <- S.null <$> maybe (pure S.empty) readTVarIO (M.lookup ntfTknId tokenSubs) + if anyActive + then ( + if noSubs + then False <$ putStrLn ("Skipped inactive token " <> enc ntfTknId <> " (no subscriptions)") + else pure True + ) + else do + let noSubsStr = if noSubs then " no subscriptions" else " has subscriptions" + putStrLn $ "Error: more than one registration for token " <> enc ntfTknId <> " " <> show token <> noSubsStr + pure True + _ -> pure True + | otherwise -> pure True + Nothing -> True <$ putStrLn "Error: no device token in lookup map" + insertToken db !n tkn@NtfTknData {ntfTknId} = do + tknRow <- ntfTknToRow <$> mkTknRec tkn + (DB.execute db insertNtfTknQuery tknRow >>= pure . (n + )) `E.catch` \(e :: E.SomeException) -> + putStrLn ("Error inserting token " <> enc ntfTknId <> " " <> show e) $> n + importSubscriptions tIds = do + allSubs <- M.elems <$> readTVarIO (subscriptions stmStore) + let subs = filter (\NtfSubData {tokenId} -> S.member tokenId tIds) allSubs + skipped = length allSubs - length subs + when (skipped /= 0) $ putStrLn $ "Skipped subscriptions (no tokens) " <> show skipped + srvIds <- importServers subs + putStrLn $ "Importing " <> show (length subs) <> " subscriptions..." + -- uncomment this line instead of the next 2 lines to import subs one by one. + (sCnt, missingTkns) <- withConnection s $ \db -> foldM (importSub db srvIds) (0, M.empty) subs + -- sCnt <- foldM (importSubs srvIds) 0 $ toChunks 100000 subs + -- let missingTkns = M.empty + putStrLn $ "Imported " <> show sCnt <> " subscriptions" + unless (M.null missingTkns) $ do + putStrLn $ show (M.size missingTkns) <> " missing tokens:" + forM_ (M.assocs missingTkns) $ \(tId, sIds) -> + putStrLn $ "Token " <> enc tId <> " " <> show (length sIds) <> " subscriptions: " <> intercalate ", " (map enc sIds) + checkCount "subscription" (length subs) sCnt + where + importSubs srvIds !n subs = do + rows <- mapM (ntfSubRow srvIds) subs + cnt <- withConnection s $ \db -> DB.executeMany db insertNtfSubQuery $ L.toList rows + let n' = n + cnt + putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" + hFlush stdout + pure n' + importSub db srvIds (!n, !missingTkns) sub@NtfSubData {ntfSubId = sId, tokenId} = do + subRow <- ntfSubRow srvIds sub + E.try (DB.execute db insertNtfSubQuery subRow) >>= \case + Right i -> do + let n' = n + i + when (n' `mod` 100000 == 0) $ do + putStr $ "Imported " <> show n' <> " subscriptions" <> "\r" + hFlush stdout + pure (n', missingTkns) + Left (e :: E.SomeException) -> do + when (n `mod` 100000 == 0) $ putStrLn "" + putStrLn $ "Error inserting subscription " <> enc sId <> " for token " <> enc tokenId <> " " <> show e + pure (n, M.alter (Just . (sId :) . fromMaybe []) tokenId missingTkns) + ntfSubRow srvIds sub = case M.lookup srv srvIds of + Just sId -> ntfSubToRow sId <$> mkSubRec sub + Nothing -> E.throwIO $ userError $ "no matching server ID for server " <> show srv + where + srv = ntfSubServer sub + importServers subs = do + sIds <- withConnection s $ \db -> map fromOnly <$> DB.returning db srvQuery (map srvToRow srvs) + void $ checkCount "server" (length srvs) (length sIds) + pure $ M.fromList $ zip srvs sIds + where + srvQuery = "INSERT INTO smp_servers (smp_host, smp_port, smp_keyhash) VALUES (?, ?, ?) RETURNING smp_server_id" + srvs = nubOrd $ map ntfSubServer subs + importLastNtfs = do + subLookup <- readTVarIO $ subscriptionLookup stmStore + ntfRows <- fmap concat . mapM (lastNtfRows subLookup) . M.assocs =<< readTVarIO (tokenLastNtfs stmStore) + nCnt <- withConnection s $ \db -> DB.executeMany db lastNtfQuery ntfRows + checkCount "last notification" (length ntfRows) nCnt + where + lastNtfQuery = "INSERT INTO last_notifications(token_id, subscription_id, sent_at, nmsg_nonce, nmsg_data) VALUES (?,?,?,?,?)" + lastNtfRows :: M.Map SMPQueueNtf NtfSubscriptionId -> (NtfTokenId, TVar (NonEmpty PNMessageData)) -> IO [(NtfTokenId, NtfSubscriptionId, SystemTime, C.CbNonce, Binary ByteString)] + lastNtfRows subLookup (tId, ntfs) = fmap catMaybes . mapM ntfRow . L.toList =<< readTVarIO ntfs + where + ntfRow PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} = case M.lookup smpQueue subLookup of + Just ntfSubId -> pure $ Just (tId, ntfSubId, ntfTs, nmsgNonce, Binary encNMsgMeta) + Nothing -> Nothing <$ putStrLn ("Error: no subscription " <> show smpQueue <> " for notification of token " <> enc tId) + checkCount name expected inserted + | fromIntegral expected == inserted = do + putStrLn $ "Imported " <> show inserted <> " " <> name <> "s." + pure inserted + | otherwise = do + putStrLn $ "Incorrect " <> name <> " count: expected " <> show expected <> ", imported " <> show inserted + putStrLn "Import aborted, fix data and repeat" + exitFailure + enc = B.unpack . B64.encode . unEntityId + +exportNtfDbStore :: NtfPostgresStore -> FilePath -> IO (Int, Int, Int) +exportNtfDbStore NtfPostgresStore {dbStoreLog = Nothing} _ = + putStrLn "Internal error: export requires store log" >> exitFailure +exportNtfDbStore NtfPostgresStore {dbStore = s, dbStoreLog = Just sl} lastNtfsFile = + (,,) <$> exportTokens <*> exportSubscriptions <*> exportLastNtfs + where + exportTokens = + withConnection s $ \db -> DB.fold_ db ntfTknQuery 0 $ \ !i tkn -> + logCreateToken sl (rowToNtfTkn tkn) $> (i + 1) + exportSubscriptions = + withConnection s $ \db -> DB.fold_ db ntfSubQuery 0 $ \ !i sub -> + logCreateSubscription sl (toNtfSub sub) $> (i + 1) + where + ntfSubQuery = + [sql| + SELECT s.token_id, s.subscription_id, s.smp_notifier_key, s.status, + p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id + FROM subscriptions s + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + |] + toNtfSub :: Only NtfTokenId :. NtfSubRow :. SMPQueueNtfRow -> NtfSubRec + toNtfSub (Only tokenId :. (ntfSubId, notifierKey, subStatus) :. qRow) = + let smpQueue = rowToSMPQueue qRow + in NtfSubRec {ntfSubId, tokenId, smpQueue, notifierKey, subStatus} + exportLastNtfs = + withFile lastNtfsFile WriteMode $ \h -> + withConnection s $ \db -> DB.fold_ db lastNtfsQuery 0 $ \ !i (Only tknId :. ntfRow) -> + B.hPutStr h (encodeLastNtf tknId $ toLastNtf ntfRow) $> (i + 1) + where + -- Note that the order here is ascending, to be compatible with how it is imported + lastNtfsQuery = + [sql| + SELECT s.token_id, p.smp_host, p.smp_port, p.smp_keyhash, s.smp_notifier_id, + n.sent_at, n.nmsg_nonce, n.nmsg_data + FROM last_notifications n + JOIN subscriptions s ON s.subscription_id = n.subscription_id + JOIN smp_servers p ON p.smp_server_id = s.smp_server_id + ORDER BY token_ntf_id ASC + |] + encodeLastNtf tknId ntf = strEncode (TNMRv1 tknId ntf) `B.snoc` '\n' + +withDB' :: String -> NtfPostgresStore -> (DB.Connection -> IO a) -> IO (Either ErrorType a) +withDB' op st action = withDB op st $ fmap Right . action + +withDB :: forall a. String -> NtfPostgresStore -> (DB.Connection -> IO (Either ErrorType a)) -> IO (Either ErrorType a) +withDB op st action = + E.uninterruptibleMask_ $ E.try (withTransaction (dbStore st) action) >>= either logErr pure + where + logErr :: E.SomeException -> IO (Either ErrorType a) + logErr e = logError ("STORE: " <> T.pack err) $> Left (STORE err) + where + err = op <> ", withDB, " <> show e + +withLog :: MonadIO m => String -> NtfPostgresStore -> (StoreLog 'WriteMode -> IO ()) -> m () +withLog op NtfPostgresStore {dbStoreLog} = withLog_ op dbStoreLog +{-# INLINE withLog #-} + +assertUpdated :: Int64 -> Either ErrorType () +assertUpdated 0 = Left AUTH +assertUpdated _ = Right () + +-- TODO [ntfdb] change instance and maybe field type to not round to a second, for more reliable sorting of the most recent notifications +instance FromField SystemTime where fromField f = fmap (`MkSystemTime` 0) . fromField f + +instance ToField SystemTime where toField = toField . systemSeconds + +instance FromField NtfSubStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8 + +instance ToField NtfSubStatus where toField = toField . decodeLatin1 . smpEncode + +#if !defined(dbPostgres) +instance FromField PushProvider where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 + +instance ToField PushProvider where toField = toField . decodeLatin1 . strEncode + +instance FromField NtfTknStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8 + +instance ToField NtfTknStatus where toField = toField . decodeLatin1 . smpEncode + +instance FromField (C.PrivateKey 'C.X25519) where fromField = blobFieldDecoder C.decodePrivKey + +instance ToField (C.PrivateKey 'C.X25519) where toField = toField . Binary . C.encodePrivKey + +instance FromField C.APrivateAuthKey where fromField = blobFieldDecoder C.decodePrivKey + +instance ToField C.APrivateAuthKey where toField = toField . Binary . C.encodePrivKey + +instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 + +instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1 . strEncode + +instance FromField C.KeyHash where fromField = blobFieldDecoder $ parseAll strP + +instance ToField C.KeyHash where toField = toField . Binary . strEncode + +instance FromField C.CbNonce where fromField = blobFieldDecoder $ parseAll smpP + +instance ToField C.CbNonce where toField = toField . Binary . smpEncode +#endif diff --git a/src/Simplex/Messaging/Notifications/Server/Store/Types.hs b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs new file mode 100644 index 000000000..802906386 --- /dev/null +++ b/src/Simplex/Messaging/Notifications/Server/Store/Types.hs @@ -0,0 +1,109 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Notifications.Server.Store.Types where + +import Control.Applicative (optional) +import Control.Concurrent.STM +import qualified Data.ByteString.Char8 as B +import Data.Word (Word16) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode, NtfSubStatus, NtfSubscriptionId, NtfTokenId, NtfTknStatus, SMPQueueNtf) +import Simplex.Messaging.Notifications.Server.Store (NtfSubData (..), NtfTknData (..)) +import Simplex.Messaging.Protocol (NtfPrivateAuthKey, NtfPublicAuthKey) +import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) + +data NtfTknRec = NtfTknRec + { ntfTknId :: NtfTokenId, + token :: DeviceToken, + tknStatus :: NtfTknStatus, + tknVerifyKey :: NtfPublicAuthKey, + tknDhPrivKey :: C.PrivateKeyX25519, + tknDhSecret :: C.DhSecretX25519, + tknRegCode :: NtfRegCode, + tknCronInterval :: Word16, + tknUpdatedAt :: Maybe RoundedSystemTime + } + deriving (Show) + +mkTknData :: NtfTknRec -> IO NtfTknData +mkTknData NtfTknRec {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhPrivKey = pk, tknDhSecret, tknRegCode, tknCronInterval = cronInt, tknUpdatedAt = updatedAt} = do + tknStatus <- newTVarIO status + tknCronInterval <- newTVarIO cronInt + tknUpdatedAt <- newTVarIO updatedAt + let tknDhKeys = (C.publicKey pk, pk) + pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} + +mkTknRec :: NtfTknData -> IO NtfTknRec +mkTknRec NtfTknData {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhKeys = (_, tknDhPrivKey), tknDhSecret, tknRegCode, tknCronInterval = cronInt, tknUpdatedAt = updatedAt} = do + tknStatus <- readTVarIO status + tknCronInterval <- readTVarIO cronInt + tknUpdatedAt <- readTVarIO updatedAt + pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} + +instance StrEncoding NtfTknRec where + strEncode NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey = pk, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = + B.unwords + [ "tknId=" <> strEncode ntfTknId, + "token=" <> strEncode token, + "tokenStatus=" <> strEncode tknStatus, + "verifyKey=" <> strEncode tknVerifyKey, + "dhKeys=" <> strEncode (C.publicKey pk, pk), + "dhSecret=" <> strEncode tknDhSecret, + "regCode=" <> strEncode tknRegCode, + "cron=" <> strEncode tknCronInterval + ] + <> maybe "" updatedAtStr tknUpdatedAt + where + updatedAtStr t = " updatedAt=" <> strEncode t + strP = do + ntfTknId <- "tknId=" *> strP_ + token <- "token=" *> strP_ + tknStatus <- "tokenStatus=" *> strP_ + tknVerifyKey <- "verifyKey=" *> strP_ + (_ :: C.PublicKeyX25519, tknDhPrivKey) <- "dhKeys=" *> strP_ + tknDhSecret <- "dhSecret=" *> strP_ + tknRegCode <- "regCode=" *> strP_ + tknCronInterval <- "cron=" *> strP + tknUpdatedAt <- optional $ " updatedAt=" *> strP + pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhPrivKey, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} + +data NtfSubRec = NtfSubRec + { ntfSubId :: NtfSubscriptionId, + smpQueue :: SMPQueueNtf, + notifierKey :: NtfPrivateAuthKey, + tokenId :: NtfTokenId, + subStatus :: NtfSubStatus + } + deriving (Show) + +mkSubData :: NtfSubRec -> IO NtfSubData +mkSubData NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do + subStatus <- newTVarIO status + pure NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} + +mkSubRec :: NtfSubData -> IO NtfSubRec +mkSubRec NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do + subStatus <- readTVarIO status + pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} + +instance StrEncoding NtfSubRec where + strEncode NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} = + B.unwords + [ "subId=" <> strEncode ntfSubId, + "smpQueue=" <> strEncode smpQueue, + "notifierKey=" <> strEncode notifierKey, + "tknId=" <> strEncode tokenId, + "subStatus=" <> strEncode subStatus + ] + strP = do + ntfSubId <- "subId=" *> strP_ + smpQueue <- "smpQueue=" *> strP_ + notifierKey <- "notifierKey=" *> strP_ + tokenId <- "tknId=" *> strP_ + subStatus <- "subStatus=" *> strP + pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} diff --git a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs index fa0ae373c..87c09826e 100644 --- a/src/Simplex/Messaging/Notifications/Server/StoreLog.hs +++ b/src/Simplex/Messaging/Notifications/Server/StoreLog.hs @@ -10,7 +10,7 @@ module Simplex.Messaging.Notifications.Server.StoreLog ( StoreLog, NtfStoreLogRecord (..), - readWriteNtfStore, + readWriteNtfSTMStore, logCreateToken, logTokenStatus, logUpdateToken, @@ -24,23 +24,19 @@ module Simplex.Messaging.Notifications.Server.StoreLog ) where -import Control.Applicative (optional) import Control.Concurrent.STM -import Control.Logger.Simple import Control.Monad import qualified Data.Attoparsec.ByteString.Char8 as A +import qualified Data.ByteString.Base64.URL as B64 import qualified Data.ByteString.Char8 as B -import qualified Data.ByteString.Lazy.Char8 as LB -import qualified Data.Text as T import Data.Word (Word16) -import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Server.Store -import Simplex.Messaging.Protocol (NtfPrivateAuthKey) +import Simplex.Messaging.Notifications.Server.Store.Types +import Simplex.Messaging.Protocol (EntityId (..)) import Simplex.Messaging.Server.QueueStore (RoundedSystemTime) import Simplex.Messaging.Server.StoreLog -import Simplex.Messaging.Util (safeDecodeUtf8) import System.IO data NtfStoreLogRecord @@ -55,52 +51,6 @@ data NtfStoreLogRecord | DeleteSubscription NtfSubscriptionId deriving (Show) -data NtfTknRec = NtfTknRec - { ntfTknId :: NtfTokenId, - token :: DeviceToken, - tknStatus :: NtfTknStatus, - tknVerifyKey :: C.APublicAuthKey, - tknDhKeys :: C.KeyPair 'C.X25519, - tknDhSecret :: C.DhSecretX25519, - tknRegCode :: NtfRegCode, - tknCronInterval :: Word16, - tknUpdatedAt :: Maybe RoundedSystemTime - } - deriving (Show) - -mkTknData :: NtfTknRec -> IO NtfTknData -mkTknData NtfTknRec {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval = cronInt, tknUpdatedAt = updatedAt} = do - tknStatus <- newTVarIO status - tknCronInterval <- newTVarIO cronInt - tknUpdatedAt <- newTVarIO updatedAt - pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} - -mkTknRec :: NtfTknData -> IO NtfTknRec -mkTknRec NtfTknData {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval = cronInt, tknUpdatedAt = updatedAt} = do - tknStatus <- readTVarIO status - tknCronInterval <- readTVarIO cronInt - tknUpdatedAt <- readTVarIO updatedAt - pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} - -data NtfSubRec = NtfSubRec - { ntfSubId :: NtfSubscriptionId, - smpQueue :: SMPQueueNtf, - notifierKey :: NtfPrivateAuthKey, - tokenId :: NtfTokenId, - subStatus :: NtfSubStatus - } - deriving (Show) - -mkSubData :: NtfSubRec -> IO NtfSubData -mkSubData NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do - subStatus <- newTVarIO status - pure NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} - -mkSubRec :: NtfSubData -> STM NtfSubRec -mkSubRec NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do - subStatus <- readTVar status - pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} - instance StrEncoding NtfStoreLogRecord where strEncode = \case CreateToken tknRec -> strEncode (Str "TCREATE", tknRec) @@ -125,56 +75,12 @@ instance StrEncoding NtfStoreLogRecord where "SDELETE " *> (DeleteSubscription <$> strP) ] -instance StrEncoding NtfTknRec where - strEncode NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} = - B.unwords - [ "tknId=" <> strEncode ntfTknId, - "token=" <> strEncode token, - "tokenStatus=" <> strEncode tknStatus, - "verifyKey=" <> strEncode tknVerifyKey, - "dhKeys=" <> strEncode tknDhKeys, - "dhSecret=" <> strEncode tknDhSecret, - "regCode=" <> strEncode tknRegCode, - "cron=" <> strEncode tknCronInterval - ] - <> maybe "" updatedAtStr tknUpdatedAt - where - updatedAtStr t = " updatedAt=" <> strEncode t - strP = do - ntfTknId <- "tknId=" *> strP_ - token <- "token=" *> strP_ - tknStatus <- "tokenStatus=" *> strP_ - tknVerifyKey <- "verifyKey=" *> strP_ - tknDhKeys <- "dhKeys=" *> strP_ - tknDhSecret <- "dhSecret=" *> strP_ - tknRegCode <- "regCode=" *> strP_ - tknCronInterval <- "cron=" *> strP - tknUpdatedAt <- optional $ " updatedAt=" *> strP - pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval, tknUpdatedAt} - -instance StrEncoding NtfSubRec where - strEncode NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} = - B.unwords - [ "subId=" <> strEncode ntfSubId, - "smpQueue=" <> strEncode smpQueue, - "notifierKey=" <> strEncode notifierKey, - "tknId=" <> strEncode tokenId, - "subStatus=" <> strEncode subStatus - ] - strP = do - ntfSubId <- "subId=" *> strP_ - smpQueue <- "smpQueue=" *> strP_ - notifierKey <- "notifierKey=" *> strP_ - tokenId <- "tknId=" *> strP_ - subStatus <- "subStatus=" *> strP - pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} - logNtfStoreRecord :: StoreLog 'WriteMode -> NtfStoreLogRecord -> IO () logNtfStoreRecord = writeStoreLogRecord {-# INLINE logNtfStoreRecord #-} -logCreateToken :: StoreLog 'WriteMode -> NtfTknData -> IO () -logCreateToken s tkn = logNtfStoreRecord s . CreateToken =<< mkTknRec tkn +logCreateToken :: StoreLog 'WriteMode -> NtfTknRec -> IO () +logCreateToken s = logNtfStoreRecord s . CreateToken logTokenStatus :: StoreLog 'WriteMode -> NtfTokenId -> NtfTknStatus -> IO () logTokenStatus s tknId tknStatus = logNtfStoreRecord s $ TokenStatus tknId tknStatus @@ -191,8 +97,8 @@ logDeleteToken s tknId = logNtfStoreRecord s $ DeleteToken tknId logUpdateTokenTime :: StoreLog 'WriteMode -> NtfTokenId -> RoundedSystemTime -> IO () logUpdateTokenTime s tknId t = logNtfStoreRecord s $ UpdateTokenTime tknId t -logCreateSubscription :: StoreLog 'WriteMode -> NtfSubData -> IO () -logCreateSubscription s sub = logNtfStoreRecord s . CreateSubscription =<< atomically (mkSubRec sub) +logCreateSubscription :: StoreLog 'WriteMode -> NtfSubRec -> IO () +logCreateSubscription s = logNtfStoreRecord s . CreateSubscription logSubscriptionStatus :: StoreLog 'WriteMode -> NtfSubscriptionId -> NtfSubStatus -> IO () logSubscriptionStatus s subId subStatus = logNtfStoreRecord s $ SubscriptionStatus subId subStatus @@ -200,49 +106,54 @@ logSubscriptionStatus s subId subStatus = logNtfStoreRecord s $ SubscriptionStat logDeleteSubscription :: StoreLog 'WriteMode -> NtfSubscriptionId -> IO () logDeleteSubscription s subId = logNtfStoreRecord s $ DeleteSubscription subId -readWriteNtfStore :: FilePath -> NtfStore -> IO (StoreLog 'WriteMode) -readWriteNtfStore = readWriteStoreLog readNtfStore writeNtfStore +readWriteNtfSTMStore :: Bool -> FilePath -> NtfSTMStore -> IO (StoreLog 'WriteMode) +readWriteNtfSTMStore tty = readWriteStoreLog (readNtfStore tty) writeNtfStore -readNtfStore :: FilePath -> NtfStore -> IO () -readNtfStore f st = mapM_ (addNtfLogRecord . LB.toStrict) . LB.lines =<< LB.readFile f +readNtfStore :: Bool -> FilePath -> NtfSTMStore -> IO () +readNtfStore tty f st = readLogLines tty f $ \_ -> processLine where - addNtfLogRecord s = case strDecode s of - Left e -> logError $ "Log parsing error (" <> T.pack e <> "): " <> safeDecodeUtf8 (B.take 100 s) - Right lr -> case lr of - CreateToken r@NtfTknRec {ntfTknId} -> do - tkn <- mkTknData r - atomically $ addNtfToken st ntfTknId tkn - TokenStatus tknId status -> do - tkn_ <- getNtfTokenIO st tknId - forM_ tkn_ $ \tkn@NtfTknData {tknStatus} -> do - atomically $ writeTVar tknStatus status - when (status == NTActive) $ void $ atomically $ removeInactiveTokenRegistrations st tkn - UpdateToken tknId token' tknRegCode -> do - getNtfTokenIO st tknId - >>= mapM_ - ( \tkn@NtfTknData {tknStatus} -> do - atomically $ removeTokenRegistration st tkn - atomically $ writeTVar tknStatus NTRegistered - atomically $ addNtfToken st tknId tkn {token = token', tknRegCode} - ) - TokenCron tknId cronInt -> - getNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknCronInterval} -> atomically $ writeTVar tknCronInterval cronInt) - DeleteToken tknId -> - atomically $ void $ deleteNtfToken st tknId - UpdateTokenTime tknId t -> - getNtfTokenIO st tknId - >>= mapM_ (\NtfTknData {tknUpdatedAt} -> atomically $ writeTVar tknUpdatedAt $ Just t) - CreateSubscription r@NtfSubRec {ntfSubId} -> do - sub <- mkSubData r - void $ atomically $ addNtfSubscription st ntfSubId sub - SubscriptionStatus subId status -> do - getNtfSubscriptionIO st subId - >>= mapM_ (\NtfSubData {subStatus} -> atomically $ writeTVar subStatus status) - DeleteSubscription subId -> - atomically $ deleteNtfSubscription st subId + processLine s = either printError procNtfLogRecord (strDecode s) + where + printError e = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> B.take 100 s + procNtfLogRecord = \case + CreateToken r@NtfTknRec {ntfTknId} -> do + tkn <- mkTknData r + atomically $ stmAddNtfToken st ntfTknId tkn + TokenStatus tknId status -> do + tkn_ <- stmGetNtfTokenIO st tknId + forM_ tkn_ $ \tkn@NtfTknData {tknStatus} -> do + atomically $ writeTVar tknStatus status + when (status == NTActive) $ void $ atomically $ stmRemoveInactiveTokenRegistrations st tkn + UpdateToken tknId token' tknRegCode -> do + stmGetNtfTokenIO st tknId + >>= mapM_ + ( \tkn@NtfTknData {tknStatus} -> do + atomically $ stmRemoveTokenRegistration st tkn + atomically $ writeTVar tknStatus NTRegistered + atomically $ stmAddNtfToken st tknId tkn {token = token', tknRegCode} + ) + TokenCron tknId cronInt -> + stmGetNtfTokenIO st tknId + >>= mapM_ (\NtfTknData {tknCronInterval} -> atomically $ writeTVar tknCronInterval cronInt) + DeleteToken tknId -> + atomically $ void $ stmDeleteNtfToken st tknId + UpdateTokenTime tknId t -> + stmGetNtfTokenIO st tknId + >>= mapM_ (\NtfTknData {tknUpdatedAt} -> atomically $ writeTVar tknUpdatedAt $ Just t) + CreateSubscription r@NtfSubRec {tokenId, ntfSubId} -> do + sub <- mkSubData r + atomically (stmAddNtfSubscription st ntfSubId sub) >>= \case + Just () -> pure () + Nothing -> B.putStrLn $ "Warning: no token " <> enc tokenId <> ", subscription " <> enc ntfSubId + where + enc = B64.encode . unEntityId + SubscriptionStatus subId status -> do + stmGetNtfSubscriptionIO st subId + >>= mapM_ (\NtfSubData {subStatus} -> atomically $ writeTVar subStatus status) + DeleteSubscription subId -> + atomically $ stmDeleteNtfSubscription st subId -writeNtfStore :: StoreLog 'WriteMode -> NtfStore -> IO () -writeNtfStore s NtfStore {tokens, subscriptions} = do - mapM_ (logCreateToken s) =<< readTVarIO tokens - mapM_ (logCreateSubscription s) =<< readTVarIO subscriptions +writeNtfStore :: StoreLog 'WriteMode -> NtfSTMStore -> IO () +writeNtfStore s NtfSTMStore {tokens, subscriptions} = do + mapM_ (logCreateToken s <=< mkTknRec) =<< readTVarIO tokens + mapM_ (logCreateSubscription s <=< mkSubRec) =<< readTVarIO subscriptions diff --git a/src/Simplex/Messaging/Server/CLI.hs b/src/Simplex/Messaging/Server/CLI.hs index 8592aa228..a678825af 100644 --- a/src/Simplex/Messaging/Server/CLI.hs +++ b/src/Simplex/Messaging/Server/CLI.hs @@ -28,9 +28,10 @@ import Data.X509.Validation (Fingerprint (..)) import Network.Socket (HostName, ServiceName) import Options.Applicative import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (ProtoServerWithAuth (..), ProtocolServer (..), ProtocolTypeI) -import Simplex.Messaging.Server.Env.STM (AServerStoreCfg (..), ServerStoreCfg (..), StorePaths (..)) +import Simplex.Messaging.Server.Env.STM (AServerStoreCfg (..), ServerStoreCfg (..), StartOptions (..), StorePaths (..)) import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) import Simplex.Messaging.Transport (ATransport (..), TLS, Transport (..)) import Simplex.Messaging.Transport.Server (AddHTTP, loadFileFingerprint) @@ -174,6 +175,70 @@ certOptionsP = do ) pure CertOptions {signAlgorithm_, commonName_} +dbOptsP :: DBOpts -> Parser DBOpts +dbOptsP DBOpts {connstr = defDBConnStr, schema = defDBSchema, poolSize = defDBPoolSize} = do + connstr <- + strOption + ( long "database" + <> short 'd' + <> metavar "DB_CONN" + <> help "Database connection string" + <> value defDBConnStr + <> showDefault + ) + schema <- + strOption + ( long "schema" + <> metavar "DB_SCHEMA" + <> help "Database schema" + <> value defDBSchema + <> showDefault + ) + poolSize <- + option + auto + ( long "pool-size" + <> metavar "POOL_SIZE" + <> help "Database pool size" + <> value defDBPoolSize + <> showDefault + ) + pure DBOpts {connstr, schema, poolSize, createSchema = False} + +startOptionsP :: Parser StartOptions +startOptionsP = do + maintenance <- + switch + ( long "maintenance" + <> short 'm' + <> help "Do not start the server, only perform start and stop tasks" + ) + compactLog <- + switch + ( long "compact-log" + <> help "Compact store log (always enabled with `memory` storage for queues)" + ) + skipWarnings <- + switch + ( long "skip-warnings" + <> help "Start the server with non-critical start warnings" + ) + confirmMigrations <- + option + parseConfirmMigrations + ( long "confirm-migrations" + <> metavar "CONFIRM_MIGRATIONS" + <> help "Confirm PostgreSQL database migration: up, down (default is manual confirmation)" + <> value MCConsole + ) + pure StartOptions {maintenance, compactLog, skipWarnings, confirmMigrations} + where + parseConfirmMigrations :: ReadM MigrationConfirmation + parseConfirmMigrations = eitherReader $ \case + "up" -> Right MCYesUp + "down" -> Right MCYesUpDown + _ -> Left "invalid migration confirmation, pass 'up' or 'down'" + genOnline :: FilePath -> CertOptions -> IO () genOnline cfgPath CertOptions {signAlgorithm_, commonName_} = do (signAlgorithm, commonName) <- @@ -294,18 +359,27 @@ iniTransports ini = webPort = T.unpack <$> eitherToMaybe (lookupValue "WEB" "https" ini) ports = map T.unpack . T.splitOn "," -printServerConfig :: [(ServiceName, ATransport, AddHTTP)] -> Maybe FilePath -> IO () -printServerConfig transports logFile = do +iniDBOptions :: Ini -> DBOpts -> DBOpts +iniDBOptions ini _default@DBOpts {connstr, schema, poolSize} = + DBOpts + { connstr = either (const connstr) encodeUtf8 $ lookupValue "STORE_LOG" "db_connection" ini, + schema = either (const schema) encodeUtf8 $ lookupValue "STORE_LOG" "db_schema" ini, + poolSize = readIniDefault poolSize "STORE_LOG" "db_pool_size" ini, + createSchema = False + } + +printServerConfig :: String -> [(ServiceName, ATransport, AddHTTP)] -> Maybe FilePath -> IO () +printServerConfig protocol transports logFile = do putStrLn $ case logFile of Just f -> "Store log: " <> f _ -> "Store log disabled." - printServerTransports transports + printServerTransports protocol transports -printServerTransports :: [(ServiceName, ATransport, AddHTTP)] -> IO () -printServerTransports ts = do +printServerTransports :: String -> [(ServiceName, ATransport, AddHTTP)] -> IO () +printServerTransports protocol ts = do forM_ ts $ \(p, ATransport t, addHTTP) -> do let descr = p <> " (" <> transportName t <> ")..." - putStrLn $ "Serving SMP protocol on port " <> descr + putStrLn $ "Serving " <> protocol <> " protocol on port " <> descr when addHTTP $ putStrLn $ "Serving static site on port " <> descr unless (any (\(p, _, _) -> p == "443") ts) $ putStrLn @@ -314,11 +388,11 @@ printServerTransports ts = do printSMPServerConfig :: [(ServiceName, ATransport, AddHTTP)] -> AServerStoreCfg -> IO () printSMPServerConfig transports (ASSCfg _ _ cfg) = case cfg of - SSCMemory sp_ -> printServerConfig transports $ (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_ - SSCMemoryJournal {storeLogFile} -> printServerConfig transports $ Just storeLogFile + SSCMemory sp_ -> printServerConfig "SMP" transports $ (\StorePaths {storeLogFile} -> storeLogFile) <$> sp_ + SSCMemoryJournal {storeLogFile} -> printServerConfig "SMP" transports $ Just storeLogFile SSCDatabaseJournal {storeCfg = PostgresStoreCfg {dbOpts = DBOpts {connstr, schema}}} -> do B.putStrLn $ "PostgreSQL database: " <> connstr <> ", schema: " <> schema - printServerTransports transports + printServerTransports "SMP" transports deleteDirIfExists :: FilePath -> IO () deleteDirIfExists path = whenM (doesDirectoryExist path) $ removeDirectoryRecursive path diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 5c3e9b6bf..5f3be4a98 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -370,6 +370,7 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp logInfo $ "restoring queues from file " <> T.pack f sl <- readWriteQueueStore False mkQ f st setStoreLog st sl +#if defined(dbServerPostgres) compactDbStoreLog = \case Just f -> do logInfo $ "compacting queues in file " <> T.pack f @@ -381,6 +382,7 @@ newEnv config@ServerConfig {smpCredentials, httpCredentials, serverStoreCfg, smp Nothing -> do logError "Error: `--compact-log` used without `db_store_log` INI option" exitFailure +#endif getCredentials protocol creds = do files <- missingCreds unless (null files) $ do diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 19e80bada..844d8d86d 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -247,13 +247,6 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = where iniStoreQueues = fromRight "memory" $ lookupValue "STORE_LOG" "store_queues" ini iniStoreMessage = fromRight "memory" $ lookupValue "STORE_LOG" "store_messages" ini - iniDBOptions ini = - DBOpts - { connstr = either (const defaultDBConnStr) encodeUtf8 $ lookupValue "STORE_LOG" "db_connection" ini, - schema = either (const defaultDBSchema) encodeUtf8 $ lookupValue "STORE_LOG" "db_schema" ini, - poolSize = readIniDefault defaultDBPoolSize "STORE_LOG" "db_pool_size" ini, - createSchema = False - } iniDeletedTTL ini = readIniDefault (86400 * defaultDeletedTTL) "STORE_LOG" "db_deleted_ttl" ini defaultStaticPath = combine logPath "www" enableStoreLog' = settingIsOn "STORE_LOG" "enable" @@ -411,7 +404,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = ASSCfg SQSMemory SMSJournal $ SSCMemoryJournal {storeLogFile = storeLogFilePath, storeMsgsPath = storeMsgsJournalDir} ASType SQSPostgres SMSJournal -> let dbStoreLogPath = enableDbStoreLog' ini $> storeLogFilePath - storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini} + storeCfg = PostgresStoreCfg {dbOpts = iniDBOptions ini defaultDBOpts, dbStoreLogPath, confirmMigrations = MCYesUp, deletedTTL = iniDeletedTTL ini} in ASSCfg SQSPostgres SMSJournal $ SSCDatabaseJournal {storeCfg, storeMsgsPath' = storeMsgsJournalDir}, storeNtfsFile = restoreMessagesFile storeNtfsFilePath, -- allow creating new queues by default @@ -512,7 +505,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = unless (storeLogExists) $ putStrLn $ "store_queues is `memory`, " <> storeLogFilePath <> " file will be created." #if defined(dbServerPostgres) SQSPostgres -> do - let DBOpts {connstr, schema} = iniDBOptions ini + let DBOpts {connstr, schema} = iniDBOptions ini defaultDBOpts schemaExists <- checkSchemaExists connstr schema case enableDbStoreLog' ini of Just () @@ -669,7 +662,7 @@ cliCommandP cfgPath logPath iniFile = <> command "start" (info (Start <$> startOptionsP) (progDesc $ "Start server (configuration: " <> iniFile <> ")")) <> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files")) <> command "journal" (info (Journal <$> journalCmdP) (progDesc "Import/export messages to/from journal storage")) - <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP) (progDesc "Import/export queues to/from PostgreSQL database storage")) + <> command "database" (info (Database <$> databaseCmdP <*> dbOptsP defaultDBOpts) (progDesc "Import/export queues to/from PostgreSQL database storage")) ) where initP :: Parser InitOptions @@ -684,7 +677,7 @@ cliCommandP cfgPath logPath iniFile = <> short 'l' <> help "Enable store log for persistence (DEPRECATED, enabled by default)" ) - dbOptions <- dbOptsP + dbOptions <- dbOptsP defaultDBOpts logStats <- switch ( long "daily-stats" @@ -815,32 +808,6 @@ cliCommandP cfgPath logPath iniFile = disableWeb, scripted } - startOptionsP = do - maintenance <- - switch - ( long "maintenance" - <> short 'm' - <> help "Do not start the server, only perform start and stop tasks" - ) - compactLog <- - switch - ( long "compact-log" - <> help "Compact store log (always enabled with `memory` storage for queues)" - ) - skipWarnings <- - switch - ( long "skip-warnings" - <> help "Start the server with non-critical start warnings" - ) - confirmMigrations <- - option - parseConfirmMigrations - ( long "confirm-migrations" - <> metavar "CONFIRM_MIGRATIONS" - <> help "Confirm PostgreSQL database migration: up, down (default is manual confirmation)" - <> value MCConsole - ) - pure StartOptions {maintenance, compactLog, skipWarnings, confirmMigrations} journalCmdP = storeCmdP "message log file" "journal storage" databaseCmdP = storeCmdP "queue store log file" "PostgreSQL database schema" storeCmdP src dest = @@ -849,39 +816,6 @@ cliCommandP cfgPath logPath iniFile = <> command "export" (info (pure SCExport) (progDesc $ "Export " <> dest <> " to " <> src)) <> command "delete" (info (pure SCDelete) (progDesc $ "Delete " <> dest)) ) - dbOptsP = do - connstr <- - strOption - ( long "database" - <> short 'd' - <> metavar "DB_CONN" - <> help "Database connection string" - <> value defaultDBConnStr - <> showDefault - ) - schema <- - strOption - ( long "schema" - <> metavar "DB_SCHEMA" - <> help "Database schema" - <> value defaultDBSchema - <> showDefault - ) - poolSize <- - option - auto - ( long "pool-size" - <> metavar "POOL_SIZE" - <> help "Database pool size" - <> value defaultDBPoolSize - <> showDefault - ) - pure DBOpts {connstr, schema, poolSize, createSchema = False} - parseConfirmMigrations :: ReadM MigrationConfirmation - parseConfirmMigrations = eitherReader $ \case - "up" -> Right MCYesUp - "down" -> Right MCYesUpDown - _ -> Left "invalid migration confirmation, pass 'up' or 'down'" parseBasicAuth :: ReadM ServerPassword parseBasicAuth = eitherReader $ fmap ServerPassword . strDecode . B.pack entityP :: String -> String -> String -> Parser (Maybe Entity, Maybe Text) diff --git a/src/Simplex/Messaging/Server/Main/Init.hs b/src/Simplex/Messaging/Server/Main/Init.hs index 4c218c5cc..7b1b320be 100644 --- a/src/Simplex/Messaging/Server/Main/Init.hs +++ b/src/Simplex/Messaging/Server/Main/Init.hs @@ -4,11 +4,9 @@ module Simplex.Messaging.Server.Main.Init where -import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import qualified Data.List.NonEmpty as L import Data.Maybe (fromMaybe, isNothing) -import Numeric.Natural (Natural) import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1) @@ -27,14 +25,14 @@ import System.FilePath (()) defaultControlPort :: Int defaultControlPort = 5224 -defaultDBConnStr :: ByteString -defaultDBConnStr = "postgresql://smp@/smp_server_store" - -defaultDBSchema :: ByteString -defaultDBSchema = "smp_server" - -defaultDBPoolSize :: Natural -defaultDBPoolSize = 10 +defaultDBOpts :: DBOpts +defaultDBOpts = + DBOpts + { connstr = "postgresql://smp@/smp_server_store", + schema = "smp_server", + poolSize = 10, + createSchema = False + } -- time to retain deleted queues in the database (days), for debugging defaultDeletedTTL :: Int64 @@ -77,13 +75,11 @@ iniFileContent cfgPath logPath opts host basicAuth controlPortPwds = \# `database`- PostgreSQL databass (requires `store_messages: journal`).\n\ \store_queues: memory\n\n\ \# Database connection settings for PostgreSQL database (`store_queues: database`).\n" - <> (optDisabled' (connstr == defaultDBConnStr) <> "db_connection: " <> safeDecodeUtf8 connstr <> "\n") - <> (optDisabled' (schema == defaultDBSchema) <> "db_schema: " <> safeDecodeUtf8 schema <> "\n") - <> (optDisabled' (poolSize == defaultDBPoolSize) <> "db_pool_size: " <> tshow poolSize <> "\n\n") + <> iniDbOpts dbOptions defaultDBOpts <> "# Write database changes to store log file\n\ \# db_store_log: off\n\n\ \# Time to retain deleted queues in the database, days.\n" - <> ("db_deleted_ttl: " <> tshow defaultDeletedTTL <> "\n\n") + <> ("# db_deleted_ttl: " <> tshow defaultDeletedTTL <> "\n\n") <> "# Message storage mode: `memory` or `journal`.\n\ \store_messages: memory\n\n\ \# When store_messages is `memory`, undelivered messages are optionally saved and restored\n\ @@ -164,7 +160,6 @@ iniFileContent cfgPath logPath opts host basicAuth controlPortPwds = <> (webDisabled <> "key: " <> T.pack httpsKeyFile <> "\n") where InitOptions {enableStoreLog, dbOptions, socksProxy, ownDomains, controlPort, webStaticPath, disableWeb, logStats} = opts - DBOpts {connstr, schema, poolSize} = dbOptions defaultServerPorts = "5223,443" defaultStaticPath = logPath "www" httpsCertFile = cfgPath "web.crt" @@ -221,6 +216,12 @@ informationIniContent InitOptions {sourceCode, serverInfo} = <> "\n" <> countryStr optName (country =<< entity) +iniDbOpts :: DBOpts -> DBOpts -> Text +iniDbOpts DBOpts {connstr, schema, poolSize} DBOpts {connstr = defConnstr, schema = defSchema, poolSize = defPoolSize} = + (optDisabled' (connstr == defConnstr) <> "db_connection: " <> safeDecodeUtf8 connstr <> "\n") + <> (optDisabled' (schema == defSchema) <> "db_schema: " <> safeDecodeUtf8 schema <> "\n") + <> (optDisabled' (poolSize == defPoolSize) <> "db_pool_size: " <> tshow poolSize <> "\n\n") + optDisabled :: Maybe a -> Text optDisabled = optDisabled' . isNothing {-# INLINE optDisabled #-} diff --git a/src/Simplex/Messaging/Server/NtfStore.hs b/src/Simplex/Messaging/Server/NtfStore.hs index 7895f64e9..383fe014e 100644 --- a/src/Simplex/Messaging/Server/NtfStore.hs +++ b/src/Simplex/Messaging/Server/NtfStore.hs @@ -28,7 +28,7 @@ data MsgNtf = MsgNtf storeNtf :: NtfStore -> NotifierId -> MsgNtf -> IO () storeNtf (NtfStore ns) nId ntf = do TM.lookupIO nId ns >>= atomically . maybe newNtfs (`modifyTVar'` (ntf :)) - -- TODO coalesce messages here once the client is updated to process multiple messages + -- TODO [ntfdb] coalesce messages here once the client is updated to process multiple messages -- for single notification. -- when (isJust prevNtf) $ incStat $ msgNtfReplaced stats where diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index a4625b2f7..38158313d 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -23,6 +23,8 @@ module Simplex.Messaging.Server.QueueStore.Postgres PostgresStoreCfg (..), batchInsertQueues, foldQueueRecs, + handleDuplicate, + withLog_, ) where @@ -56,6 +58,7 @@ import Database.PostgreSQL.Simple.SqlQQ (sql) import GHC.IO (catchAny) import Simplex.Messaging.Agent.Client (withLockMap) import Simplex.Messaging.Agent.Lock (Lock) +import Simplex.Messaging.Agent.Store.AgentStore () import Simplex.Messaging.Agent.Store.Postgres (createDBStore, closeDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) @@ -530,8 +533,12 @@ withDB op st action = err = op <> ", withDB, " <> show e withLog :: MonadIO m => String -> PostgresQueueStore q -> (StoreLog 'WriteMode -> IO ()) -> m () -withLog op PostgresQueueStore {dbStoreLog} action = - forM_ dbStoreLog $ \sl -> liftIO $ action sl `catchAny` \e -> +withLog op PostgresQueueStore {dbStoreLog} = withLog_ op dbStoreLog +{-# INLINE withLog #-} + +withLog_ :: MonadIO m => String -> Maybe (StoreLog 'WriteMode) -> (StoreLog 'WriteMode -> IO ()) -> m () +withLog_ op sl_ action = + forM_ sl_ $ \sl -> liftIO $ action sl `catchAny` \e -> logWarn $ "STORE: " <> T.pack (op <> ", withLog, " <> show e) handleDuplicate :: SqlError -> IO ErrorType @@ -541,15 +548,15 @@ handleDuplicate e = case constraintViolation e of -- The orphan instances below are copy-pasted, but here they are defined specifically for PostgreSQL -instance ToField EntityId where toField (EntityId s) = toField $ Binary s - -deriving newtype instance FromField EntityId - instance ToField (NonEmpty C.APublicAuthKey) where toField = toField . Binary . smpEncode instance FromField (NonEmpty C.APublicAuthKey) where fromField = blobFieldDecoder smpDecode #if !defined(dbPostgres) +instance ToField EntityId where toField (EntityId s) = toField $ Binary s + +deriving newtype instance FromField EntityId + instance FromField QueueMode where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8 instance ToField QueueMode where toField = toField . decodeLatin1 . smpEncode diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 0b261672a..07f806b56 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -12,12 +12,12 @@ import AgentTests.ConnectionRequestTests import AgentTests.DoubleRatchetTests (doubleRatchetTests) import AgentTests.FunctionalAPITests (functionalAPITests) import AgentTests.MigrationTests (migrationTests) -import AgentTests.NotificationTests (notificationTests) import AgentTests.ServerChoice (serverChoiceTests) import AgentTests.ShortLinkTests (shortLinkTests) import Simplex.Messaging.Server.Env.STM (AStoreType (..)) import Simplex.Messaging.Transport (ATransport (..)) import Test.Hspec + #if defined(dbPostgres) import Fixtures import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem) @@ -25,6 +25,12 @@ import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem) import AgentTests.SQLiteTests (storeTests) #endif +#if defined(dbServerPostgres) +import AgentTests.NotificationTests (notificationTests) +import SMPClient (postgressBracket) +import NtfClient (ntfTestServerDBConnectInfo) +#endif + agentCoreTests :: Spec agentCoreTests = do describe "Migration tests" migrationTests @@ -41,7 +47,10 @@ agentTests ps = do #endif describe "Functional API" $ functionalAPITests ps describe "Chosen servers" serverChoiceTests - describe "Notification tests" $ notificationTests ps +#if defined(dbServerPostgres) + around_ (postgressBracket ntfTestServerDBConnectInfo) $ + describe "Notification tests" $ notificationTests ps +#endif #if !defined(dbPostgres) describe "SQLite store" storeTests #endif diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index ea0ebd29b..64196cf3f 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -59,7 +59,7 @@ import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.IO as TIO import NtfClient import SMPAgentClient (agentCfg, initAgentServers, initAgentServers2, testDB, testDB2, testNtfServer, testNtfServer2) -import SMPClient (cfgMS, cfgJ2QS, cfgVPrev, serverStoreConfig, testPort, testPort2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, xit'') +import SMPClient (cfgMS, cfgJ2QS, cfgVPrev, ntfTestPort, ntfTestPort2, serverStoreConfig, testPort, testPort2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn, xit'') import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMessage) import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) @@ -120,10 +120,12 @@ notificationTests ps@(t, _) = do it "should keep working with active token until replaced" $ withAPNSMockServer $ \apns -> testNtfTokenChangeServers t apns - xit'' "should re-register token in NTInvalid status after register attempt" $ + -- TODO [ntfdb] modify database in the test + xit "should re-register token in NTInvalid status after register attempt" $ withAPNSMockServer $ \apns -> testNtfTokenReRegisterInvalid t apns - xit'' "should re-register token in NTInvalid status after checking token" $ + -- TODO [ntfdb] modify database in the test + xit "should re-register token in NTInvalid status after checking token" $ withAPNSMockServer $ \apns -> testNtfTokenReRegisterInvalidOnCheck t apns describe "notification server tests" $ do @@ -163,12 +165,12 @@ notificationTests ps@(t, _) = do it "should keep sending notifications for old token" $ withSmpServer ps $ withAPNSMockServer $ \apns -> - withNtfServerOn t ntfTestPort $ + withNtfServer t $ testNotificationsOldToken apns it "should update server from new token" $ withSmpServer ps $ withAPNSMockServer $ \apns -> - withNtfServerOn t ntfTestPort2 . withNtfServerThreadOn t ntfTestPort $ \ntf -> + withNtfServerOn t ntfTestPort2 ntfTestDBCfg2 . withNtfServerThreadOn t ntfTestPort ntfTestDBCfg $ \ntf -> testNotificationsNewToken apns ntf testNtfMatrix :: HasCallStack => (ATransport, AStoreType) -> (APNSMockServer -> AgentMsgId -> AgentClient -> AgentClient -> IO ()) -> Spec @@ -278,7 +280,7 @@ testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestart t apns = do let tkn = DeviceToken PPApnsTest "abcd" ntfData <- withAgent 1 agentCfg initAgentServers testDB $ \a -> - withNtfServerStoreLog t $ \_ -> runRight $ do + withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -288,7 +290,7 @@ testNtfTokenServerRestart t apns = do withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, -- so that repeat verification happens without restarting the clients, when notification arrives - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do verification <- ntfData .-> "verification" nonce <- C.cbNonce <$> ntfData .-> "nonce" verifyNtfToken a' tkn nonce verification @@ -299,7 +301,7 @@ testNtfTokenServerRestartReverify :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestartReverify t apns = do let tkn = DeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a -> do - ntfData <- withNtfServerStoreLog t $ \_ -> runRight $ do + ntfData <- withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -309,11 +311,11 @@ testNtfTokenServerRestartReverify t apns = do nonce <- C.cbNonce <$> ntfData .-> "nonce" Left (BROKER _ NETWORK) <- tryE $ verifyNtfToken a tkn nonce verification pure () - threadDelay 1000000 + threadDelay 1500000 withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, -- so that repeat verification happens without restarting the clients, when notification arrives - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do NTActive <- registerNtfToken a' tkn NMPeriodic NTActive <- checkNtfToken a' tkn pure () @@ -322,7 +324,7 @@ testNtfTokenServerRestartReverifyTimeout :: ATransport -> APNSMockServer -> IO ( testNtfTokenServerRestartReverifyTimeout t apns = do let tkn = DeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do - (nonce, verification) <- withNtfServerStoreLog t $ \_ -> runRight $ do + (nonce, verification) <- withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -344,11 +346,11 @@ testNtfTokenServerRestartReverifyTimeout t apns = do (NTConfirmed, Just (NTAVerify code), PPApnsTest, "abcd" :: ByteString) Just NtfToken {ntfTknStatus = NTConfirmed, ntfTknAction = Just (NTAVerify _)} <- withTransaction store getSavedNtfToken pure () - threadDelay 1000000 + threadDelay 1500000 withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token, -- so that repeat verification happens without restarting the clients, when notification arrives - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do NTActive <- registerNtfToken a' tkn NMPeriodic NTActive <- checkNtfToken a' tkn pure () @@ -357,7 +359,7 @@ testNtfTokenServerRestartReregister :: ATransport -> APNSMockServer -> IO () testNtfTokenServerRestartReregister t apns = do let tkn = DeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a -> - withNtfServerStoreLog t $ \_ -> runRight $ do + withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just _}} <- getMockNotification apns tkn @@ -367,7 +369,7 @@ testNtfTokenServerRestartReregister t apns = do withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, and client might have lost verification notification. -- so that repeat registration happens when client is restarted. - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do NTRegistered <- registerNtfToken a' tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -381,7 +383,7 @@ testNtfTokenServerRestartReregisterTimeout :: ATransport -> APNSMockServer -> IO testNtfTokenServerRestartReregisterTimeout t apns = do let tkn = DeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers testDB $ \a@AgentClient {agentEnv = Env {store}} -> do - withNtfServerStoreLog t $ \_ -> runRight $ do + withNtfServer t $ runRight $ do NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just _}} <- getMockNotification apns tkn @@ -402,7 +404,7 @@ testNtfTokenServerRestartReregisterTimeout t apns = do withAgent 2 agentCfg initAgentServers testDB $ \a' -> -- server stopped before token is verified, and client might have lost verification notification. -- so that repeat registration happens when client is restarted. - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do NTRegistered <- registerNtfToken a' tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn @@ -422,8 +424,8 @@ testNtfTokenMultipleServers :: ATransport -> APNSMockServer -> IO () testNtfTokenMultipleServers t apns = do let tkn = DeviceToken PPApnsTest "abcd" withAgent 1 agentCfg initAgentServers2 testDB $ \a -> - withNtfServerThreadOn t ntfTestPort $ \ntf -> - withNtfServerThreadOn t ntfTestPort2 $ \ntf2 -> runRight_ $ do + withNtfServerThreadOn t ntfTestPort ntfTestDBCfg $ \ntf -> + withNtfServerThreadOn t ntfTestPort2 ntfTestDBCfg2 $ \ntf2 -> runRight_ $ do -- register a new token, the agent picks a server and stores its choice NTRegistered <- registerNtfToken a tkn NMPeriodic APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- @@ -444,7 +446,7 @@ testNtfTokenMultipleServers t apns = do testNtfTokenChangeServers :: ATransport -> APNSMockServer -> IO () testNtfTokenChangeServers t apns = - withNtfServerThreadOn t ntfTestPort $ \ntf -> do + withNtfServerThreadOn t ntfTestPort ntfTestDBCfg $ \ntf -> do tkn1 <- withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight $ do tkn <- registerTestToken a "abcd" NMInstant apns NTActive <- checkNtfToken a tkn @@ -467,14 +469,14 @@ testNtfTokenChangeServers t apns = Left BROKER {brokerErr = NETWORK} <- tryError $ registerTestToken a "qwer" NMInstant apns -- ok, it's down for now getTestNtfTokenPort a >>= \port2 -> liftIO $ port2 `shouldBe` ntfTestPort2 -- but the token got updated killThread ntf - withNtfServerOn t ntfTestPort2 $ runRight_ $ do + withNtfServerOn t ntfTestPort2 ntfTestDBCfg2 $ runRight_ $ do liftIO $ threadDelay 1000000 -- for notification server to reconnect tkn <- registerTestToken a "qwer" NMInstant apns checkNtfToken a tkn >>= \r -> liftIO $ r `shouldBe` NTActive testNtfTokenReRegisterInvalid :: ATransport -> APNSMockServer -> IO () testNtfTokenReRegisterInvalid t apns = do - tkn <- withNtfServerStoreLog t $ \_ -> do + tkn <- withNtfServer t $ do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight $ do tkn <- registerTestToken a "abcd" NMInstant apns NTActive <- checkNtfToken a tkn @@ -482,13 +484,13 @@ testNtfTokenReRegisterInvalid t apns = do threadDelay 250000 -- start server to compact - withNtfServerStoreLog t $ \_ -> pure () + withNtfServer t $ pure () threadDelay 250000 replaceSubstringInFile ntfTestStoreLogFile "tokenStatus=ACTIVE" "tokenStatus=INVALID" threadDelay 250000 - withNtfServerStoreLog t $ \_ -> do + withNtfServer t $ do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight_ $ do NTInvalid Nothing <- registerNtfToken a tkn NMInstant tkn1 <- registerTestToken a "abcd" NMInstant apns @@ -503,7 +505,7 @@ replaceSubstringInFile filePath oldText newText = do testNtfTokenReRegisterInvalidOnCheck :: ATransport -> APNSMockServer -> IO () testNtfTokenReRegisterInvalidOnCheck t apns = do - tkn <- withNtfServerStoreLog t $ \_ -> do + tkn <- withNtfServer t $ do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight $ do tkn <- registerTestToken a "abcd" NMInstant apns NTActive <- checkNtfToken a tkn @@ -511,13 +513,13 @@ testNtfTokenReRegisterInvalidOnCheck t apns = do threadDelay 250000 -- start server to compact - withNtfServerStoreLog t $ \_ -> pure () + withNtfServer t $ pure () threadDelay 250000 replaceSubstringInFile ntfTestStoreLogFile "tokenStatus=ACTIVE" "tokenStatus=INVALID" threadDelay 250000 - withNtfServerStoreLog t $ \_ -> do + withNtfServer t $ do withAgent 1 agentCfg initAgentServers testDB $ \a -> runRight_ $ do NTInvalid Nothing <- checkNtfToken a tkn tkn1 <- registerTestToken a "abcd" NMInstant apns @@ -526,7 +528,7 @@ testNtfTokenReRegisterInvalidOnCheck t apns = do testRunNTFServerTests :: ATransport -> NtfServer -> IO (Maybe ProtocolTestFailure) testRunNTFServerTests t srv = - withNtfServerOn t ntfTestPort $ + withNtfServer t $ withAgent 1 agentCfg initAgentServers testDB $ \a -> testProtocolServer a 1 $ ProtoServerWithAuth srv Nothing @@ -567,7 +569,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag threadDelay 500000 suspendAgent alice 0 closeDBStore store - threadDelay 1000000 + threadDelay 1500000 putStrLn "before opening the database from another agent" -- aliceNtf client doesn't have subscription and is allowed to get notification message @@ -575,7 +577,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag (Just SMPMsgMeta {msgFlags = MsgFlags True}) :| _ <- getConnectionMessages aliceNtf [cId] pure () - threadDelay 1000000 + threadDelay 1500000 putStrLn "after closing the database in another agent" reopenDBStore store foregroundAgent alice @@ -753,7 +755,7 @@ testChangeToken apns = withAgent 1 agentCfg initAgentServers testDB2 $ \bob -> d testNotificationsStoreLog :: (ATransport, AStoreType) -> APNSMockServer -> IO () testNotificationsStoreLog ps@(t, _) apns = withAgentClients2 $ \alice bob -> do withSmpServerStoreMsgLogOn ps testPort $ \_ -> do - (aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runRight $ do + (aliceId, bobId) <- withNtfServer t $ runRight $ do (aliceId, bobId) <- makeConnection alice bob _ <- registerTestToken alice "abcd" NMInstant apns liftIO $ threadDelay 250000 @@ -762,19 +764,17 @@ testNotificationsStoreLog ps@(t, _) apns = withAgentClients2 $ \alice bob -> do void $ messageNotificationData alice apns get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False ackMessage alice bobId 2 Nothing - liftIO $ killThread threadId pure (aliceId, bobId) liftIO $ threadDelay 250000 - withNtfServerStoreLog t $ \threadId -> runRight_ $ do + withNtfServer t $ runRight_ $ do liftIO $ threadDelay 250000 3 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello again" get bob ##> ("", aliceId, SENT 3) void $ messageNotificationData alice apns get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False ackMessage alice bobId 3 Nothing - liftIO $ killThread threadId runRight_ $ do 4 <- sendMessage bob aliceId (SMP.MsgFlags True) "message 4" @@ -784,7 +784,7 @@ testNotificationsStoreLog ps@(t, _) apns = withAgentClients2 $ \alice bob -> do noNotifications apns withSmpServerStoreMsgLogOn ps testPort $ \_ -> - withNtfServerStoreLog t $ \_ -> runRight_ $ do + withNtfServer t $ runRight_ $ do void $ messageNotificationData alice apns testNotificationsSMPRestart :: (ATransport, AStoreType) -> APNSMockServer -> IO () diff --git a/tests/CLITests.hs b/tests/CLITests.hs index 7ba2316ca..e5fb784ee 100644 --- a/tests/CLITests.hs +++ b/tests/CLITests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -19,7 +20,6 @@ import qualified Network.HTTP.Client as H1 import qualified Network.HTTP2.Client as H2 import Simplex.FileTransfer.Server.Main (xftpServerCLI) import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Notifications.Server.Main import Simplex.Messaging.Server.Main (smpServerCLI, smpServerCLI_) import Simplex.Messaging.Transport (TLS (..), defaultSupportedParams, defaultSupportedParamsHTTPS, simplexMQVersion, supportedClientSMPRelayVRange) import Simplex.Messaging.Transport.Client (TransportClientConfig (..), defaultTransportClientConfig, runTLSTransportClient, smpClientHandshake) @@ -40,6 +40,12 @@ import UnliftIO.Async (async, cancel) import UnliftIO.Concurrent (threadDelay) import UnliftIO.Exception (bracket) +#if defined(dbServerPostgres) +import NtfClient (ntfTestServerDBConnectInfo) +import SMPClient (postgressBracket) +import Simplex.Messaging.Notifications.Server.Main +#endif + cfgPath :: FilePath cfgPath = "tests/tmp/cli/etc/opt/simplex" @@ -70,9 +76,12 @@ cliTests = do it "no store log, no password" $ smpServerTest False False it "with store log, no password" $ smpServerTest True False it "static files" smpServerTestStatic - describe "Ntf server CLI" $ do - it "should initialize, start and delete the server (no store log)" $ ntfServerTest False - it "should initialize, start and delete the server (with store log)" $ ntfServerTest True +#if defined(dbServerPostgres) + aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ + describe "Ntf server CLI" $ do + it "should initialize, start and delete the server (no store log)" $ ntfServerTest False + it "should initialize, start and delete the server (with store log)" $ ntfServerTest True +#endif describe "XFTP server CLI" $ do it "should initialize, start and delete the server (no store log)" $ xftpServerTest False it "should initialize, start and delete the server (with store log)" $ xftpServerTest True @@ -182,6 +191,7 @@ smpServerTestStatic = do let X.CertificateChain cc = tlsServerCerts tls in map (X.signedObject . X.getSigned) cc +#if defined(dbServerPostgres) ntfServerTest :: Bool -> IO () ntfServerTest storeLog = do capture_ (withArgs (["init"] <> ["--disable-store-log" | not storeLog]) $ ntfServerCLI ntfCfgPath ntfLogPath) @@ -195,10 +205,11 @@ ntfServerTest storeLog = do r <- lines <$> capture_ (withArgs ["start"] $ (100000 `timeout` ntfServerCLI ntfCfgPath ntfLogPath) `catchAll_` pure (Just ())) r `shouldContain` ["SMP notifications server v" <> simplexMQVersion] r `shouldContain` (if storeLog then ["Store log: " <> ntfLogPath <> "/ntf-server-store.log"] else ["Store log disabled."]) - r `shouldContain` ["Serving SMP protocol on port 443 (TLS)..."] + r `shouldContain` ["Serving NTF protocol on port 443 (TLS)..."] capture_ (withStdin "Y" . withArgs ["delete"] $ ntfServerCLI ntfCfgPath ntfLogPath) >>= (`shouldSatisfy` ("WARNING: deleting the server will make all queues inaccessible" `isPrefixOf`)) doesFileExist (cfgPath <> "/ca.key") `shouldReturn` False +#endif xftpServerTest :: Bool -> IO () xftpServerTest storeLog = do diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 190815832..e7a7c2ba5 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -28,12 +28,15 @@ import qualified Data.ByteString.Char8 as B import Data.List.NonEmpty (NonEmpty) import qualified Data.Map.Strict as M import Data.Text (Text) +import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo) import GHC.Generics (Generic) import Network.HTTP.Types (Status) import qualified Network.HTTP.Types as N import qualified Network.HTTP2.Server as H import Network.Socket -import SMPClient (prevRange, serverBracket) +import SMPClient (ntfTestPort, prevRange, serverBracket) +import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C @@ -45,6 +48,8 @@ import Simplex.Messaging.Notifications.Server.Push.APNS import Simplex.Messaging.Notifications.Server.Push.APNS.Internal import Simplex.Messaging.Notifications.Transport import Simplex.Messaging.Protocol +import Simplex.Messaging.Server.Env.STM (StartOptions (..)) +import Simplex.Messaging.Server.QueueStore.Postgres.Config (PostgresStoreCfg (..)) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client @@ -60,12 +65,6 @@ import UnliftIO.STM testHost :: NonEmpty TransportHost testHost = "localhost" -ntfTestPort :: ServiceName -ntfTestPort = "6001" - -ntfTestPort2 :: ServiceName -ntfTestPort2 = "6002" - apnsTestPort :: ServiceName apnsTestPort = "6010" @@ -75,9 +74,46 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" ntfTestStoreLogFile :: FilePath ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log" +ntfTestStoreLogFile2 :: FilePath +ntfTestStoreLogFile2 = "tests/tmp/ntf-server-store.log.2" + ntfTestStoreLastNtfsFile :: FilePath ntfTestStoreLastNtfsFile = "tests/tmp/ntf-server-last-notifications.log" +ntfTestStoreDBOpts :: DBOpts +ntfTestStoreDBOpts = + DBOpts + { connstr = ntfTestServerDBConnstr, + schema = "ntf_server", + poolSize = 3, + createSchema = True + } + +ntfTestStoreDBOpts2 :: DBOpts +ntfTestStoreDBOpts2 = ntfTestStoreDBOpts {schema = "smp_server2"} + +ntfTestServerDBConnstr :: ByteString +ntfTestServerDBConnstr = "postgresql://ntf_test_server_user@/ntf_test_server_db" + +ntfTestServerDBConnectInfo :: ConnectInfo +ntfTestServerDBConnectInfo = + defaultConnectInfo { + connectUser = "ntf_test_server_user", + connectDatabase = "ntf_test_server_db" + } + +ntfTestDBCfg :: PostgresStoreCfg +ntfTestDBCfg = + PostgresStoreCfg + { dbOpts = ntfTestStoreDBOpts, + dbStoreLogPath = Just ntfTestStoreLogFile, + confirmMigrations = MCYesUp, + deletedTTL = 86400 + } + +ntfTestDBCfg2 :: PostgresStoreCfg +ntfTestDBCfg2 = ntfTestDBCfg {dbOpts = ntfTestStoreDBOpts2, dbStoreLogPath = Just ntfTestStoreLogFile2} + testNtfClient :: Transport c => (THandleNTF c 'TClient -> IO a) -> IO a testNtfClient client = do Right host <- pure $ chooseTransportHost defaultNetworkConfig testHost @@ -106,8 +142,7 @@ ntfServerCfg = }, subsBatchSize = 900, inactiveClientExpiration = Just defaultInactiveClientExpiration, - storeLogFile = Nothing, - storeLastNtfsFile = Nothing, + dbStoreConfig = ntfTestDBCfg, ntfCredentials = ServerCredentials { caCertificateFile = Just "tests/fixtures/ca.crt", @@ -120,7 +155,8 @@ ntfServerCfg = serverStatsLogFile = "tests/ntf-server-stats.daily.log", serverStatsBackupFile = Nothing, ntfServerVRange = supportedServerNTFVRange, - transportConfig = defaultTransportServerConfig + transportConfig = defaultTransportServerConfig, + startOptions = StartOptions {maintenance = False, compactLog = False, skipWarnings = False, confirmMigrations = MCYesUp} } ntfServerCfgVPrev :: NtfServerConfig @@ -134,11 +170,9 @@ ntfServerCfgVPrev = smpCfg' = smpCfg smpAgentCfg' serverVRange' = serverVRange smpCfg' -withNtfServerStoreLog :: ATransport -> (ThreadId -> IO a) -> IO a -withNtfServerStoreLog t = withNtfServerCfg ntfServerCfg {storeLogFile = Just ntfTestStoreLogFile, storeLastNtfsFile = Just ntfTestStoreLastNtfsFile, transports = [(ntfTestPort, t, False)]} - -withNtfServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a -withNtfServerThreadOn t port' = withNtfServerCfg ntfServerCfg {transports = [(port', t, False)]} +withNtfServerThreadOn :: HasCallStack => ATransport -> ServiceName -> PostgresStoreCfg -> (HasCallStack => ThreadId -> IO a) -> IO a +withNtfServerThreadOn t port' dbStoreConfig = + withNtfServerCfg ntfServerCfg {transports = [(port', t, False)], dbStoreConfig} withNtfServerCfg :: HasCallStack => NtfServerConfig -> (ThreadId -> IO a) -> IO a withNtfServerCfg cfg@NtfServerConfig {transports} = @@ -149,11 +183,11 @@ withNtfServerCfg cfg@NtfServerConfig {transports} = (\started -> runNtfServerBlocking started cfg) (pure ()) -withNtfServerOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => IO a) -> IO a -withNtfServerOn t port' = withNtfServerThreadOn t port' . const +withNtfServerOn :: HasCallStack => ATransport -> ServiceName -> PostgresStoreCfg -> (HasCallStack => IO a) -> IO a +withNtfServerOn t port' dbStoreConfig = withNtfServerThreadOn t port' dbStoreConfig . const withNtfServer :: HasCallStack => ATransport -> (HasCallStack => IO a) -> IO a -withNtfServer t = withNtfServerOn t ntfTestPort +withNtfServer t = withNtfServerOn t ntfTestPort ntfTestDBCfg runNtfTest :: forall c a. Transport c => (THandleNTF c 'TClient -> IO a) -> IO a runNtfTest test = withNtfServer (transport @c) $ testNtfClient test diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index b2e868cc2..8af15aa59 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -13,6 +13,7 @@ module NtfServerTests where import Control.Concurrent (threadDelay) +import Control.Monad (void) import qualified Data.Aeson as J import qualified Data.Aeson.Types as JT import Data.Bifunctor (first) @@ -113,9 +114,20 @@ testNotificationSubscription (ATransport t) createQueue = APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}} <- getMockNotification apns tkn let dhSecret = C.dh' ntfDh dhPriv - Right verification = ntfData .-> "verification" - Right nonce = C.cbNonce <$> ntfData .-> "nonce" - Right code = NtfRegCode <$> C.cbDecrypt dhSecret nonce verification + decryptCode nd = + let Right verification = nd .-> "verification" + Right nonce = C.cbNonce <$> nd .-> "nonce" + Right pt = C.cbDecrypt dhSecret nonce verification + in NtfRegCode pt + let code = decryptCode ntfData + -- test repeated request - should return the same token ID + RespNtf "1a" NoEntity (NRTknId tId1 ntfDh1) <- signSendRecvNtf nh tknKey ("1a", NoEntity, TNEW $ NewNtfTkn tkn tknPub dhPub) + tId1 `shouldBe` tId + ntfDh1 `shouldBe` ntfDh + APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData1}} <- + getMockNotification apns tkn + let code1 = decryptCode ntfData1 + code `shouldBe` code1 RespNtf "2" _ NROk <- signSendRecvNtf nh tknKey ("2", tId, TVFY code) RespNtf "2a" _ (NRTkn NTActive) <- signSendRecvNtf nh tknKey ("2a", tId, TCHK) -- ntf server subscribes to queue notifications diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 2903de05c..9fce42669 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -15,8 +15,7 @@ import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import NtfClient (ntfTestPort) -import SMPClient (proxyVRangeV8, testPort) +import SMPClient (proxyVRangeV8, ntfTestPort, testPort) import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 07dc60723..97470703b 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -45,7 +45,12 @@ import UnliftIO.Timeout (timeout) import Util #if defined(dbServerPostgres) -import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo) +import Database.PostgreSQL.Simple (defaultConnectInfo) +#endif + +#if defined(dbPostgres) || defined(dbServerPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..)) +import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists, dropDatabaseAndUser) #endif testHost :: NonEmpty TransportHost @@ -60,6 +65,12 @@ testPort = "5001" testPort2 :: ServiceName testPort2 = "5002" +ntfTestPort :: ServiceName +ntfTestPort = "6001" + +ntfTestPort2 :: ServiceName +ntfTestPort2 = "6002" + testKeyHash :: C.KeyHash testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" @@ -381,3 +392,11 @@ smpTest4 _ msType test' = smpTestN msType 4 _test unexpected :: (HasCallStack, Show a) => a -> Expectation unexpected r = expectationFailure $ "unexpected response " <> show r + +#if defined(dbPostgres) || defined(dbServerPostgres) +postgressBracket :: ConnectInfo -> IO a -> IO a +postgressBracket connInfo = + E.bracket_ + (dropDatabaseAndUser connInfo >> createDBAndUserIfNotExists connInfo) + (dropDatabaseAndUser connInfo) +#endif diff --git a/tests/Test.hs b/tests/Test.hs index 9ebdec8f7..653538faf 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -21,7 +21,6 @@ import CoreTests.VersionRangeTests import FileDescriptionTests (fileDescriptionTests) import GHC.IO.Exception (IOException (..)) import qualified GHC.IO.Exception as IOException -import NtfServerTests (ntfServerTests) import RemoteControl (remoteControlTests) import SMPProxyTests (smpProxyTests) import ServerTests @@ -43,13 +42,14 @@ import AgentTests.SchemaDump (schemaDumpTest) #endif #if defined(dbServerPostgres) +import NtfServerTests (ntfServerTests) +import NtfClient (ntfTestServerDBConnectInfo) import SMPClient (testServerDBConnectInfo) import ServerTests.SchemaDump #endif #if defined(dbPostgres) || defined(dbServerPostgres) -import Database.PostgreSQL.Simple (ConnectInfo (..)) -import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists, dropDatabaseAndUser) +import SMPClient (postgressBracket) #endif logCfg :: LogConfig @@ -57,6 +57,7 @@ logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} main :: IO () main = do + -- TODO [ntfdb] running wiht LogWarn level shows potential issue "Queue count differs" setLogLevel LogError -- LogInfo withGlobalLogging logCfg $ do setEnv "APNS_KEY_ID" "H82WD9K9AQ" @@ -95,7 +96,7 @@ main = do describe "Server schema dump" serverSchemaDumpTest aroundAll_ (postgressBracket testServerDBConnectInfo) $ describe "SMP server via TLS, postgres+jornal message store" $ - before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests + before (pure (transport @TLS, ASType SQSPostgres SMSJournal)) serverTests #endif describe "SMP server via TLS, jornal message store" $ do describe "SMP syntax" $ serverSyntaxTests (transport @TLS) @@ -105,8 +106,9 @@ main = do -- xdescribe "SMP server via WebSockets" $ do -- describe "SMP syntax" $ serverSyntaxTests (transport @WS) -- before (pure (transport @WS, ASType SQSMemory SMSJournal)) serverTests - describe "Notifications server" $ ntfServerTests (transport @TLS) #if defined(dbServerPostgres) + aroundAll_ (postgressBracket ntfTestServerDBConnectInfo) $ do + describe "Notifications server" $ ntfServerTests (transport @TLS) aroundAll_ (postgressBracket testServerDBConnectInfo) $ do describe "SMP client agent, postgres+jornal message store" $ agentTests (transport @TLS, ASType SQSPostgres SMSJournal) describe "SMP proxy, postgres+jornal message store" $ @@ -132,11 +134,3 @@ eventuallyRemove path retries = case retries of _ -> E.throwIO ioe where action = removeDirectoryRecursive path - -#if defined(dbPostgres) || defined(dbServerPostgres) -postgressBracket :: ConnectInfo -> IO a -> IO a -postgressBracket connInfo = - E.bracket_ - (dropDatabaseAndUser connInfo >> createDBAndUserIfNotExists connInfo) - (dropDatabaseAndUser connInfo) -#endif diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index f7e880083..bfb601465 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -454,7 +454,7 @@ testXFTPAgentSendRestore = withGlobalLogging logCfgNoLogs $ do pure rfd1 -- prefix path should be removed after sending file - threadDelay 200000 + threadDelay 500000 doesDirectoryExist prefixPath `shouldReturn` False doesFileExist encPath `shouldReturn` False