-- Mirror of https://github.com/simplex-chat/simplexmq.git
-- Synced 2026-03-31 22:46:27 +00:00
-- 421 lines, 16 KiB, Haskell
{-# LANGUAGE BangPatterns #-}
|
|
{-# LANGUAGE DerivingStrategies #-}
|
|
{-# LANGUAGE DuplicateRecordFields #-}
|
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
{-# LANGUAGE InstanceSigs #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE MultiWayIf #-}
|
|
{-# LANGUAGE NamedFieldPuns #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE QuasiQuotes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
{-# LANGUAGE TypeApplications #-}
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
|
|
module Simplex.Messaging.Server.MsgStore.Postgres
|
|
( PostgresMsgStore,
|
|
PostgresMsgStoreCfg (..),
|
|
PostgresQueue,
|
|
exportDbMessages,
|
|
getDbMessageStats,
|
|
getDbMessageCount,
|
|
deleteAllMessages,
|
|
batchInsertMessages,
|
|
updateQueueCounts,
|
|
)
|
|
where
|
|
|
|
import Control.Concurrent.STM
|
|
import qualified Control.Exception as E
|
|
import Control.Monad
|
|
import Control.Monad.Reader
|
|
import Control.Monad.Trans.Except
|
|
import qualified Data.ByteString as B
|
|
import qualified Data.ByteString.Builder as BB
|
|
import qualified Data.ByteString.Lazy as LB
|
|
import Data.Functor (($>))
|
|
import Data.IORef
|
|
import Data.Int (Int64)
|
|
import Data.List (intersperse)
|
|
import qualified Data.Map.Strict as M
|
|
import Data.Text (Text)
|
|
import Data.Time.Clock.System (SystemTime (..))
|
|
import Database.PostgreSQL.Simple (Binary (..), Only (..), (:.) (..))
|
|
import qualified Database.PostgreSQL.Simple as DB
|
|
import qualified Database.PostgreSQL.Simple.Copy as DB
|
|
import Database.PostgreSQL.Simple.SqlQQ (sql)
|
|
import Database.PostgreSQL.Simple.ToField (ToField (..))
|
|
import Simplex.Messaging.Agent.Store.Postgres.Common
|
|
import qualified Simplex.Messaging.Crypto as C
|
|
import Simplex.Messaging.Protocol
|
|
import Simplex.Messaging.Server.MsgStore
|
|
import Simplex.Messaging.Server.MsgStore.Types
|
|
import Simplex.Messaging.Server.QueueStore
|
|
import Simplex.Messaging.Server.QueueStore.Postgres
|
|
import Simplex.Messaging.Server.QueueStore.Types
|
|
import Simplex.Messaging.Server.StoreLog (foldLogLines)
|
|
import Simplex.Messaging.Encoding.String
|
|
import Simplex.Messaging.Util (maybeFirstRow, maybeFirstRow', (<$$>))
|
|
import System.IO (Handle, hFlush, stdout)
|
|
|
|
-- | Message store backed by PostgreSQL.
--
-- All message operations go through the database handle owned by the queue
-- store ('queueStore_').
data PostgresMsgStore = PostgresMsgStore
  { -- | Store configuration, including the per-queue message quota.
    config :: PostgresMsgStoreCfg,
    -- | Queue store; also provides the database connection used for messages.
    queueStore_ :: PostgresQueueStore'
  }
|
|
|
|
-- | Configuration for 'PostgresMsgStore'.
data PostgresMsgStoreCfg = PostgresMsgStoreCfg
  { -- | Configuration of the underlying PostgreSQL queue store.
    queueStoreCfg :: PostgresStoreCfg,
    -- | Per-queue message quota, passed as the last argument of the
    -- @write_message@ SQL function.
    quota :: Int
  }
|
|
|
|
-- | PostgreSQL queue store specialised to 'PostgresQueue' handles.
type PostgresQueueStore' = PostgresQueueStore PostgresQueue

-- | Handle for a queue whose messages are kept in PostgreSQL.
data PostgresQueue = PostgresQueue
  { -- | Recipient ID; used as the @recipient_id@ key in SQL queries.
    recipientId' :: RecipientId,
    -- | Cached queue record.
    -- NOTE(review): presumably 'Nothing' once the queue is deleted - confirm in QueueStore.
    queueRec' :: TVar (Maybe QueueRec)
  }
|
|
|
|
-- spec: spec/modules/Simplex/Messaging/Server/MsgStore/Postgres.md#msgqueue-is-unit-type
instance StoreQueueClass PostgresQueue where
  queueRec (PostgresQueue _ qrVar) = qrVar
  {-# INLINE queueRec #-}
  recipientId (PostgresQueue rId _) = rId
  {-# INLINE recipientId #-}
  -- No in-process lock is taken: the wrapped action is returned unchanged.
  withQueueLock PostgresQueue {} _ = id -- TODO [messages] maybe it's just transaction?
  {-# INLINE withQueueLock #-}
|
|
|
|
-- | Environment of 'DBStoreIO': the database connection that store
-- operations run on.
-- NOTE(review): presumably a connection inside a transaction opened by withDB' - confirm.
newtype DBTransaction = DBTransaction
  { dbConn :: DB.Connection
  }

-- | Store monad of 'PostgresMsgStore': 'IO' with access to a 'DBTransaction'.
type DBStoreIO a = ReaderT DBTransaction IO a
|
|
|
|
instance MsgStoreClass PostgresMsgStore where
  type StoreMonad PostgresMsgStore = ReaderT DBTransaction IO
  -- Messages live only in the database, so there is no per-queue in-memory message queue.
  type MsgQueue PostgresMsgStore = ()
  type QueueStore PostgresMsgStore = PostgresQueueStore'
  type StoreQueue PostgresMsgStore = PostgresQueue
  type MsgStoreConfig PostgresMsgStore = PostgresMsgStoreCfg

  -- Open the underlying PostgreSQL queue store.
  -- NOTE(review): the meaning of the 'False' flag is defined by newQueueStore - confirm there.
  newMsgStore :: PostgresMsgStoreCfg -> IO PostgresMsgStore
  newMsgStore config = do
    queueStore_ <- newQueueStore @PostgresQueue (queueStoreCfg config, False)
    pure PostgresMsgStore {config, queueStore_}

  closeMsgStore :: PostgresMsgStore -> IO ()
  closeMsgStore = closeQueueStore @PostgresQueue . queueStore_

  -- Methods that call 'error' below are not used with this store.
  withActiveMsgQueues _ _ = error "withActiveMsgQueues not used"

  unsafeWithAllMsgQueues _ _ _ = error "unsafeWithAllMsgQueues not used"

  -- Expire old messages via the expire_old_messages procedure and return the
  -- statistics it reports; newMessageStats if it returns no row.
  -- @oldMsg = now - ttl@ is the cutoff timestamp passed to the procedure.
  expireOldMessages :: Bool -> PostgresMsgStore -> Int64 -> Int64 -> IO MessageStats
  expireOldMessages _tty ms now ttl =
    maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db ->
      DB.query db "CALL expire_old_messages(?,?,?,0,0,0)" (oldQueue, oldMsg, batchSize)
    where
      st = dbStore $ queueStore_ ms
      oldQueue = 0 :: Int64 -- expire all queues
      oldMsg = now - ttl
      batchSize = 10000 :: Int
      toMessageStats (expiredMsgsCount, storedMsgsCount, storedQueues) =
        MessageStats {expiredMsgsCount, storedMsgsCount, storedQueues}

  -- Fold over all non-deleted queues whose rcv_service_id is the given service.
  -- Each queue is passed to @f@ with its oldest message (lowest message_id), or
  -- Nothing if the queue has no messages.
  -- The oldest message is looked up per queue with LATERAL ... LIMIT 1, instead of
  -- numbering every row of the messages table with a window function.
  -- The query has no trailing ';' because DB.fold wraps it in a cursor declaration.
  foldRcvServiceMessages :: PostgresMsgStore -> ServiceId -> (a -> RecipientId -> Either ErrorType (Maybe (QueueRec, Message)) -> IO a) -> a -> IO (Either ErrorType a)
  foldRcvServiceMessages ms serviceId f acc =
    runExceptT $ withDB' "foldRcvServiceMessages" (queueStore_ ms) $ \db ->
      DB.fold
        db
        [sql|
          SELECT q.recipient_id, q.recipient_keys, q.rcv_dh_secret,
            q.sender_id, q.sender_key, q.queue_mode,
            q.notifier_id, q.notifier_key, q.rcv_ntf_dh_secret, q.ntf_service_id,
            q.status, q.updated_at, q.link_id, q.rcv_service_id,
            m.msg_id, m.msg_ts, m.msg_quota, m.msg_ntf_flag, m.msg_body
          FROM msg_queues q
          LEFT JOIN LATERAL (
            SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
            FROM messages
            WHERE recipient_id = q.recipient_id
            ORDER BY message_id ASC
            LIMIT 1
          ) m ON TRUE
          WHERE q.rcv_service_id = ? AND q.deleted_at IS NULL
        |]
        (Only serviceId)
        acc
        f'
    where
      f' a (qRow :. mRow) =
        let (rId, qr) = rowToQueueRec qRow
            msg_ = toMaybeMessage mRow
         in f a rId $ Right ((qr,) <$> msg_)

  logQueueStates _ = error "logQueueStates not used"

  logQueueState _ = error "logQueueState not used"

  queueStore = queueStore_
  {-# INLINE queueStore #-}

  -- Sizes of the queue store's in-memory maps; this store has no journals or queue locks.
  loadedQueueCounts :: PostgresMsgStore -> IO LoadedQueueCounts
  loadedQueueCounts ms = do
    loadedQueueCount <- M.size <$> readTVarIO queues
    loadedNotifierCount <- M.size <$> readTVarIO notifiers
    notifierLockCount <- M.size <$> readTVarIO notifierLocks
    pure LoadedQueueCounts {loadedQueueCount, loadedNotifierCount, openJournalCount = 0, queueLockCount = 0, notifierLockCount}
    where
      PostgresQueueStore {queues, notifiers, notifierLocks} = queueStore_ ms

  mkQueue :: PostgresMsgStore -> Bool -> RecipientId -> QueueRec -> IO PostgresQueue
  mkQueue _ _keepLock rId qr = PostgresQueue rId <$> newTVarIO (Just qr)
  {-# INLINE mkQueue #-}

  getMsgQueue _ _ _ = pure ()
  {-# INLINE getMsgQueue #-}

  getPeekMsgQueue :: PostgresMsgStore -> PostgresQueue -> DBStoreIO (Maybe ((), Message))
  getPeekMsgQueue _ q = ((),) <$$> tryPeekMsg_ q ()

  withIdleMsgQueue :: Int64 -> PostgresMsgStore -> PostgresQueue -> (() -> DBStoreIO a) -> DBStoreIO (Maybe a, Int)
  withIdleMsgQueue _ _ _ _ = error "withIdleMsgQueue not used"

  deleteQueue :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType QueueRec)
  deleteQueue ms q = deleteStoreQueue (queueStore_ ms) q
  {-# INLINE deleteQueue #-}

  -- Delete the queue and return its record with the message count read just before deletion.
  -- NOTE(review): size read and deletion are separate DB calls, so they are not atomic.
  deleteQueueSize :: PostgresMsgStore -> PostgresQueue -> IO (Either ErrorType (QueueRec, Int))
  deleteQueueSize ms q = runExceptT $ do
    size <- getQueueSize ms q
    qr <- ExceptT $ deleteStoreQueue (queueStore_ ms) q
    pure (qr, size)

  getQueueMessages_ _ _ _ = error "getQueueMessages_ not used"

  -- Store a message with the write_message SQL function.
  -- Returns Just (msg, wasEmpty) if the message was stored; Nothing if no row is
  -- returned or quota_written is set.
  -- NOTE(review): quota_written presumably means a quota marker was stored instead
  -- (queue full) - confirm against the SQL function.
  writeMsg :: PostgresMsgStore -> PostgresQueue -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool))
  writeMsg ms q _ msg =
    uninterruptibleMask_ $
      withDB' "writeMsg" (queueStore_ ms) $ \db -> do
        let (msgQuota, ntf, body) = case msg of
              Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body')
              MessageQuota {} -> (True, False, B.empty)
        rows <-
          DB.query
            db
            "SELECT quota_written, was_empty FROM write_message(?,?,?,?,?,?,?)"
            (recipientId' q, Binary (messageId msg), systemSeconds (messageTs msg), msgQuota, ntf, Binary body, quota)
        pure $ toResult rows
    where
      toResult = \case
        ((quotaWritten, wasEmpty) : _) -> if quotaWritten then Nothing else Just (msg, wasEmpty)
        [] -> Nothing
      PostgresMsgStore {config = PostgresMsgStoreCfg {quota}} = ms

  setOverQuota_ :: PostgresQueue -> IO () -- can ONLY be used while restoring messages, not while server running
  setOverQuota_ _ = error "TODO setOverQuota_" -- TODO [messages]

  getQueueSize_ :: () -> DBStoreIO Int
  getQueueSize_ _ = error "getQueueSize_ not used"

  -- msg_queue_size of a non-deleted queue; 0 if the queue is not found.
  getQueueSize :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO Int
  getQueueSize ms q =
    withDB' "getQueueSize" (queueStore_ ms) $ \db ->
      maybeFirstRow' 0 fromOnly $
        DB.query db "SELECT msg_queue_size FROM msg_queues WHERE recipient_id = ? AND deleted_at IS NULL" (Only (recipientId' q))

  -- Oldest message of the queue (lowest message_id), on the current connection.
  tryPeekMsg_ :: PostgresQueue -> () -> DBStoreIO (Maybe Message)
  tryPeekMsg_ q _ = do
    db <- asks dbConn
    liftIO $
      maybeFirstRow toMessage $
        DB.query
          db
          [sql|
            SELECT msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
            FROM messages
            WHERE recipient_id = ?
            ORDER BY message_id ASC LIMIT 1
          |]
          (Only (recipientId' q))

  tryDeleteMsg_ :: PostgresQueue -> () -> Bool -> DBStoreIO ()
  tryDeleteMsg_ _q _ _ = error "tryDeleteMsg_ not used"

  -- Run a store action on a connection obtained from withDB', with async exceptions masked.
  isolateQueue :: PostgresMsgStore -> PostgresQueue -> Text -> DBStoreIO a -> ExceptT ErrorType IO a
  isolateQueue ms _q op a = uninterruptibleMask_ $ withDB' op (queueStore_ ms) $ runReaderT a . DBTransaction

  unsafeRunStore _ _ _ = error "unsafeRunStore not used"

  tryPeekMsg :: PostgresMsgStore -> PostgresQueue -> ExceptT ErrorType IO (Maybe Message)
  tryPeekMsg ms q = isolateQueue ms q "tryPeekMsg" $ tryPeekMsg_ q ()
  {-# INLINE tryPeekMsg #-}

  -- Delete the message with the given ID via try_del_msg; returns the row it reports, if any.
  tryDelMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message)
  tryDelMsg ms q msgId =
    uninterruptibleMask_ $
      withDB' "tryDelMsg" (queueStore_ ms) $ \db ->
        maybeFirstRow toMessage $
          DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_msg(?, ?)" (recipientId' q, Binary msgId)

  -- Delete a message and peek the next one via try_del_peek_msg.
  -- A single returned row is treated as the deleted message if its ID matches
  -- msgId, otherwise as the next message.
  tryDelPeekMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message)
  tryDelPeekMsg ms q msgId =
    uninterruptibleMask_ $
      withDB' "tryDelPeekMsg" (queueStore_ ms) $ \db ->
        toResult . map toMessage
          <$> DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_peek_msg(?, ?)" (recipientId' q, Binary msgId)
    where
      toResult = \case
        [] -> (Nothing, Nothing)
        [msg]
          | messageId msg == msgId -> (Just msg, Nothing)
          | otherwise -> (Nothing, Just msg)
        deleted : next : _ -> (Just deleted, Just next)

  -- Delete expired messages of one queue via delete_expired_msgs; returns the count it reports (0 if no row).
  deleteExpiredMsgs :: PostgresMsgStore -> PostgresQueue -> Int64 -> ExceptT ErrorType IO Int
  deleteExpiredMsgs ms q old =
    uninterruptibleMask_ $
      maybeFirstRow' 0 (fromIntegral @Int64 . fromOnly) $ withDB' "deleteExpiredMsgs" (queueStore_ ms) $ \db ->
        DB.query db "SELECT delete_expired_msgs(?, ?)" (recipientId' q, old)
|
|
|
|
-- | Lift 'E.uninterruptibleMask_' to 'ExceptT': the whole action, including its
-- 'Left' result, runs with asynchronous exceptions uninterruptibly masked.
-- Used around database operations in this module.
uninterruptibleMask_ :: ExceptT ErrorType IO a -> ExceptT ErrorType IO a
uninterruptibleMask_ action = ExceptT $ E.uninterruptibleMask_ $ runExceptT action
{-# INLINE uninterruptibleMask_ #-}
|
|
|
|
-- | Convert a nullable message row (as produced by a LEFT JOIN) into a message.
-- Returns 'Nothing' unless all five columns are present.
toMaybeMessage :: (Maybe (Binary MsgId), Maybe Int64, Maybe Bool, Maybe Bool, Maybe (Binary MsgBody)) -> Maybe Message
toMaybeMessage (msgId_, ts_, quota_, ntf_, body_) =
  toMessage <$> ((,,,,) <$> msgId_ <*> ts_ <*> quota_ <*> ntf_ <*> body_)
|
|
|
|
-- | Convert a message row @(msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body)@
-- into a 'Message'.
-- Rows with @msg_quota@ set become 'MessageQuota'; their flag and body are ignored.
toMessage :: (Binary MsgId, Int64, Bool, Bool, Binary MsgBody) -> Message
toMessage (Binary msgId, ts, msgQuota, ntf, Binary body) =
  if msgQuota
    then MessageQuota {msgId, msgTs}
    else Message {msgId, msgTs, msgFlags = MsgFlags ntf, msgBody = C.unsafeMaxLenBS body} -- TODO [messages] unsafeMaxLenBS?
  where
    -- timestamps are written as whole seconds (systemSeconds), so nanoseconds are 0
    msgTs = MkSystemTime ts 0
|
|
|
|
-- | Export every row of the @messages@ table to handle @h@, one 'MLRv3' record
-- per line, ordered by recipient_id then message_id.
--
-- Rows are buffered and written in batches of 1000. When @tty@ is set, progress
-- is printed to stdout. Returns the number of exported messages.
exportDbMessages :: Bool -> PostgresMsgStore -> Handle -> IO Int
exportDbMessages tty ms h = do
  -- batch buffer, newest row first (reversed again by writeMessages)
  rows <- newIORef []
  n <- withConnection st $ \db -> DB.foldWithOptions_ opts db query 0 $ \i r -> do
    let i' = i + 1
    if i' `mod` 1000 > 0
      then modifyIORef' rows (r :) -- strict: avoid a chain of (r :) thunks
      else do
        readIORef rows >>= writeMessages . (r :)
        writeIORef rows []
        when tty $ putStr (progress i' <> "\r") >> hFlush stdout
    pure i'
  -- flush the last, partial batch
  readIORef rows >>= \rs -> unless (null rs) $ writeMessages rs
  when tty $ putStrLn $ progress n
  pure n
  where
    st = dbStore $ queueStore_ ms
    opts = DB.defaultFoldOptions {DB.fetchQuantity = DB.Fixed 1000}
    query =
      [sql|
        SELECT recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body
        FROM messages
        ORDER BY recipient_id, message_id ASC
      |]
    writeMessages = BB.hPutBuilder h . encodeMessages . reverse
    encodeMessages = mconcat . map (\(Only rId :. msg) -> BB.byteString (strEncode $ MLRv3 rId $ toMessage msg) <> BB.char8 '\n')
    progress i = "Processed: " <> show i <> " records"
|
|
|
|
-- | Count non-deleted queues and the messages stored in them.
-- 'expiredMsgsCount' is always 0; 'newMessageStats' is returned if the query yields no row.
--
-- @JOIN ... USING@ requires a parenthesised column list; @deleted_at@ is qualified
-- so it unambiguously refers to the queue.
getDbMessageStats :: PostgresMsgStore -> IO MessageStats
getDbMessageStats ms =
  maybeFirstRow' newMessageStats toMessageStats $ withConnection st $ \db ->
    DB.query_
      db
      [sql|
        SELECT
          (SELECT COUNT (1) FROM msg_queues WHERE deleted_at IS NULL),
          (SELECT COUNT (1) FROM messages m JOIN msg_queues q USING (recipient_id) WHERE q.deleted_at IS NULL)
      |]
  where
    st = dbStore $ queueStore_ ms
    toMessageStats (storedQueues, storedMsgsCount) =
      MessageStats {storedQueues, storedMsgsCount, expiredMsgsCount = 0}
|
|
|
|
-- | Total number of rows in the @messages@ table (no filtering by queue state);
-- 0 if the query yields no row.
getDbMessageCount :: PostgresMsgStore -> IO Int64
getDbMessageCount ms = maybeFirstRow' 0 fromOnly $ withConnection st countMessages
  where
    st = dbStore $ queueStore_ ms
    countMessages db = DB.query_ db "SELECT COUNT(*) FROM messages"
|
|
|
|
-- | Delete all messages and reset the per-queue message counters
-- (@msg_queue_size@, @msg_can_write@, @msg_queue_expire@) to their empty-queue values.
--
-- Both statements run inside 'withTransaction'. Otherwise a failure after the
-- TRUNCATE could leave queues with stale sizes, or with @msg_can_write = FALSE@,
-- after their messages are gone.
deleteAllMessages :: PostgresMsgStore -> IO ()
deleteAllMessages ms =
  withTransaction (dbStore $ queueStore_ ms) $ \db -> do
    void $ DB.execute_ db "TRUNCATE messages"
    void $
      DB.execute_
        db
        [sql|
          UPDATE msg_queues
          SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE
          WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE
        |]
|
|
|
|
-- | Recompute per-queue message counters from the @messages@ table:
--
-- * every queue is first reset to the empty-queue values;
-- * queues that have messages get @msg_queue_size@ = number of messages,
--   @msg_can_write@ = no quota marker present, and
--   @msg_queue_expire@ = at least one non-quota message present.
--
-- All statements run inside 'withTransaction'. Otherwise a failure part-way
-- could leave every queue zeroed, and the @queue_stats@ temp table on the pooled
-- connection, which would make the next call fail on CREATE TEMP TABLE.
updateQueueCounts :: PostgresMsgStore -> IO ()
updateQueueCounts ms =
  withTransaction (dbStore $ queueStore_ ms) $ \db -> do
    void $
      DB.execute_
        db
        [sql|
          CREATE TEMP TABLE queue_stats AS
          SELECT recipient_id,
            COUNT(*) AS size,
            SUM(CASE WHEN msg_quota THEN 1 ELSE 0 END) AS quota_count
          FROM messages
          GROUP BY recipient_id
        |]
    void $
      DB.execute_
        db
        [sql|
          UPDATE msg_queues
          SET msg_queue_size = 0, msg_can_write = TRUE, msg_queue_expire = FALSE
          WHERE msg_queue_size != 0 OR msg_can_write = FALSE OR msg_queue_expire = TRUE
        |]
    void $
      DB.execute_
        db
        [sql|
          UPDATE msg_queues q
          SET msg_queue_size = s.size,
            msg_can_write = s.quota_count = 0,
            msg_queue_expire = s.size > s.quota_count
          FROM queue_stats s
          WHERE q.recipient_id = s.recipient_id
        |]
    void $ DB.execute_ db "DROP TABLE queue_stats"
|
|
|
|
-- | Import messages from store-log file @f@ into the @messages@ table with a
-- single CSV COPY inside one transaction.
--
-- Each line is decoded as an 'MLRv3' record. Lines that fail to decode are
-- reported on stdout with their 1-based line number and skipped.
-- Returns the number of rows sent to COPY.
batchInsertMessages :: StoreQueueClass q => Bool -> FilePath -> PostgresQueueStore q -> IO Int64
batchInsertMessages tty f toStore = do
  putStrLn "Importing messages..."
  snd <$> withTransaction (dbStore toStore) copyMessages
  where
    copyMessages db = do
      DB.copy_
        db
        [sql|
          COPY messages (recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body)
          FROM STDIN WITH (FORMAT CSV)
        |]
      counts <- foldLogLines tty f (putMessage db) (0 :: Int, 0)
      _ <- DB.putCopyEnd db
      pure counts
    -- accumulator: (lines read, rows copied)
    putMessage db (!lineNo, !copied) _eof s = do
      let lineNo' = lineNo + 1
      copied' <- case strDecode s of
        Left e -> copied <$ putStrLn ("Error parsing line " <> show lineNo' <> ": " <> e)
        Right (MLRv3 rId msg) -> (copied + 1) <$ DB.putCopyData db (messageRecToText rId msg)
      pure (lineNo', copied')
|
|
|
|
-- | Render one message as a CSV line for
-- @COPY messages (...) FROM STDIN WITH (FORMAT CSV)@.
-- Column order: recipient_id, msg_id, msg_ts, msg_quota, msg_ntf_flag, msg_body.
messageRecToText :: RecipientId -> Message -> B.ByteString
messageRecToText rId msg = LB.toStrict $ BB.toLazyByteString csvLine
  where
    csvLine = mconcat (intersperse (BB.char7 ',') columns) <> BB.char7 '\n'
    columns =
      map
        renderField
        [ toField rId,
          toField $ Binary (messageId msg),
          toField $ systemSeconds (messageTs msg),
          toField isQuota,
          toField ntfFlag,
          toField $ Binary payload
        ]
    -- quota markers are stored with no notification flag and an empty body
    (isQuota, ntfFlag, payload) = case msg of
      MessageQuota {} -> (True, False, B.empty)
      Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body')
|