diff --git a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs index 721a71946..5716dc2c6 100644 --- a/src/Simplex/Messaging/Server/MsgStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/MsgStore/Postgres.hs @@ -27,6 +27,7 @@ module Simplex.Messaging.Server.MsgStore.Postgres where import Control.Concurrent.STM +import qualified Control.Exception as E import Control.Monad import Control.Monad.Reader import Control.Monad.Trans.Except @@ -161,15 +162,16 @@ instance MsgStoreClass PostgresMsgStore where writeMsg :: PostgresMsgStore -> PostgresQueue -> Bool -> Message -> ExceptT ErrorType IO (Maybe (Message, Bool)) writeMsg ms q _ msg = - withDB' "writeMsg" (queueStore_ ms) $ \db -> do - let (msgQuota, ntf, body) = case msg of - Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body') - MessageQuota {} -> (True, False, B.empty) - toResult <$> - DB.query - db - "SELECT quota_written, was_empty FROM write_message(?,?,?,?,?,?,?)" - (recipientId' q, Binary (messageId msg), systemSeconds (messageTs msg), msgQuota, ntf, Binary body, quota) + uninterruptibleMask_ $ + withDB' "writeMsg" (queueStore_ ms) $ \db -> do + let (msgQuota, ntf, body) = case msg of + Message {msgFlags = MsgFlags ntf', msgBody = C.MaxLenBS body'} -> (False, ntf', body') + MessageQuota {} -> (True, False, B.empty) + toResult <$> + DB.query + db + "SELECT quota_written, was_empty FROM write_message(?,?,?,?,?,?,?)" + (recipientId' q, Binary (messageId msg), systemSeconds (messageTs msg), msgQuota, ntf, Binary body, quota) where toResult = \case ((msgQuota, wasEmpty) : _) -> if msgQuota then Nothing else Just (msg, wasEmpty) @@ -206,7 +208,7 @@ instance MsgStoreClass PostgresMsgStore where tryDeleteMsg_ _q _ _ = error "tryDeleteMsg_ not used" -- do isolateQueue :: PostgresMsgStore -> PostgresQueue -> Text -> DBStoreIO a -> ExceptT ErrorType IO a - isolateQueue ms _q op a = withDB' op (queueStore_ ms) $ runReaderT a . DBTransaction + isolateQueue ms _q op a = uninterruptibleMask_ $ withDB' op (queueStore_ ms) $ runReaderT a . DBTransaction unsafeRunStore _ _ _ = error "unsafeRunStore not used" @@ -216,15 +218,17 @@ instance MsgStoreClass PostgresMsgStore where tryDelMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message) tryDelMsg ms q msgId = - withDB' "tryDelMsg" (queueStore_ ms) $ \db -> - maybeFirstRow toMessage $ - DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_msg(?, ?)" (recipientId' q, Binary msgId) + uninterruptibleMask_ $ + withDB' "tryDelMsg" (queueStore_ ms) $ \db -> + maybeFirstRow toMessage $ + DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_msg(?, ?)" (recipientId' q, Binary msgId) tryDelPeekMsg :: PostgresMsgStore -> PostgresQueue -> MsgId -> ExceptT ErrorType IO (Maybe Message, Maybe Message) tryDelPeekMsg ms q msgId = - withDB' "tryDelPeekMsg" (queueStore_ ms) $ \db -> - toResult . map toMessage - <$> DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_peek_msg(?, ?)" (recipientId' q, Binary msgId) + uninterruptibleMask_ $ + withDB' "tryDelPeekMsg" (queueStore_ ms) $ \db -> + toResult . map toMessage + <$> DB.query db "SELECT r_msg_id, r_msg_ts, r_msg_quota, r_msg_ntf_flag, r_msg_body FROM try_del_peek_msg(?, ?)" (recipientId' q, Binary msgId) where toResult = \case [] -> (Nothing, Nothing) @@ -235,8 +239,13 @@ instance MsgStoreClass PostgresMsgStore where deleteExpiredMsgs :: PostgresMsgStore -> PostgresQueue -> Int64 -> ExceptT ErrorType IO Int deleteExpiredMsgs ms q old = - maybeFirstRow' 0 (fromIntegral @Int64 . fromOnly) $ withDB' "deleteExpiredMsgs" (queueStore_ ms) $ \db -> - DB.query db "SELECT delete_expired_msgs(?, ?)" (recipientId' q, old) + uninterruptibleMask_ $ + maybeFirstRow' 0 (fromIntegral @Int64 . fromOnly) $ withDB' "deleteExpiredMsgs" (queueStore_ ms) $ \db -> + DB.query db "SELECT delete_expired_msgs(?, ?)" (recipientId' q, old) + +uninterruptibleMask_ :: ExceptT ErrorType IO a -> ExceptT ErrorType IO a +uninterruptibleMask_ = ExceptT . E.uninterruptibleMask_ . runExceptT +{-# INLINE uninterruptibleMask_ #-} toMessage :: (Binary MsgId, Int64, Bool, Bool, Binary MsgBody) -> Message toMessage (Binary msgId, ts, msgQuota, ntf, Binary body)